Skip to content

Commit

Permalink
add algorithm option for wapiti
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Feb 3, 2021
1 parent bfc10f7 commit ac4dc2c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 10 deletions.
Expand Up @@ -49,6 +49,7 @@ public abstract class AbstractTrainer implements Trainer {
protected double epsilon = 0.0; // size of the interval for stopping criterion
protected int window = 0; // similar to CRF++
protected int nbMaxIterations = 0; // maximum number of iterations in training
protected String algorithm = ""; // algorithm

protected GrobidModel model;
private File trainDataPath;
Expand Down Expand Up @@ -94,6 +95,9 @@ public void train() {
if (nbMaxIterations != 0)
trainer.setNbMaxIterations(nbMaxIterations);

if (StringUtils.isNotBlank(algorithm))
trainer.setAlgorithm(algorithm);

File dirModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath()).getParentFile();
if (!dirModelPath.exists()) {
LOGGER.warn("Cannot find the destination directory " + dirModelPath.getAbsolutePath() + " for the model " + model.getModelName() + ". Creating it.");
Expand Down Expand Up @@ -151,6 +155,8 @@ public String splitTrainEvaluate(Double split) {
trainer.setWindow(window);
if (nbMaxIterations != 0)
trainer.setNbMaxIterations(nbMaxIterations);
if (StringUtils.isNotBlank(algorithm))
trainer.setAlgorithm(algorithm);

File dirModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath()).getParentFile();
if (!dirModelPath.exists()) {
Expand Down Expand Up @@ -197,6 +203,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
trainer.setWindow(window);
if (nbMaxIterations != 0)
trainer.setNbMaxIterations(nbMaxIterations);
if (StringUtils.isNotBlank(algorithm))
trainer.setAlgorithm(algorithm);

//We dump the model in the tmp directory
File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath());
Expand Down
Expand Up @@ -26,6 +26,7 @@ public class CRFPPGenericTrainer implements GenericTrainer {
protected double epsilon = 0.00001; // default size of the interval for stopping criterion
protected int window = 20; // default similar to CRF++
protected int nbMaxIterations = 6000;
protected String algorithm = "crf-l2";

public CRFPPGenericTrainer() {
crfppTrainer = new CRFPPTrainer();
Expand Down Expand Up @@ -70,7 +71,12 @@ public int getWindow() {
public void setNbMaxIterations(int interations) {
this.nbMaxIterations = interations;
}


/**
 * Sets the training algorithm name by storing it in the {@code algorithm} field.
 * No validation is performed on the value.
 * NOTE(review): how the CRF++ back-end consumes this field is not visible here - confirm it is used.
 *
 * @param algorithm algorithm name to store
 */
@Override
public void setAlgorithm(String algorithm) {
this.algorithm = algorithm;
}

@Override
public int getNbMaxIterations() {
return nbMaxIterations;
Expand Down
Expand Up @@ -50,7 +50,11 @@ public int getWindow() {
// No-op implementation: the body is empty, so the requested iteration count is discarded.
@Override
public void setNbMaxIterations(int interations) {
}


// No-op implementation: the body is empty, so the requested algorithm name is discarded.
// NOTE(review): presumably this trainer has no selectable algorithm - confirm.
@Override
public void setAlgorithm(String algorithm) {
}

@Override
public int getNbMaxIterations() {
return 0;
Expand Down
Expand Up @@ -45,6 +45,9 @@ public int getNbMaxIterations() {

// No-op implementation: the body is empty, so the requested iteration count is discarded.
@Override
public void setNbMaxIterations(int iterations) {
}

// No-op implementation: the body is empty, so the requested algorithm name is discarded.
// NOTE(review): presumably this trainer has no selectable algorithm - confirm.
@Override
public void setAlgorithm(String algorithm) {
}
}
Expand Up @@ -11,10 +11,11 @@
public interface GenericTrainer {
void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model);
String getName();
public void setEpsilon(double epsilon);
public void setWindow(int window);
public double getEpsilon();
public int getWindow();
public int getNbMaxIterations();
public void setNbMaxIterations(int iterations);
void setEpsilon(double epsilon);
void setWindow(int window);
double getEpsilon();
int getWindow();
int getNbMaxIterations();
void setNbMaxIterations(int iterations);
void setAlgorithm(String algorithm);
}
13 changes: 11 additions & 2 deletions grobid-trainer/src/main/java/org/grobid/trainer/WapitiTrainer.java
Expand Up @@ -20,15 +20,19 @@ public class WapitiTrainer implements GenericTrainer {
protected double epsilon = 0.00001; // default size of the interval for stopping criterion
protected int window = 20; // default similar to CRF++
protected int nbMaxIterations = 2000; // by default maximum of training iterations
protected String algorithm = "l-bfgs"; // algorithm to be used, values: l-bfgs (default), sgd-l1, bcd, rprop, rprop+, rprop-


@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
System.out.println("\tepsilon: " + epsilon);
System.out.println("\twindow: " + window);
System.out.println("\tnb max iterations: " + nbMaxIterations);
System.out.println("\tnb threads: " + numThreads);
System.out.println("\talgorithm: " + algorithm);

WapitiModel.train(template, trainingData, outputModel, "--nthread " + numThreads +
// " --algo sgd-l1" +
" --algo " + algorithm +
" -e " + BigDecimal.valueOf(epsilon).toPlainString() +
" -w " + window +
" -i " + nbMaxIterations
Expand Down Expand Up @@ -64,7 +68,12 @@ public int getWindow() {
public void setNbMaxIterations(int interations) {
this.nbMaxIterations = interations;
}


/**
 * Selects the optimisation algorithm handed to Wapiti through the {@code --algo} option.
 *
 * <p>A {@code null} or blank value is ignored and the current setting is kept. The value is
 * concatenated verbatim into the Wapiti option string, so accepting it would produce
 * {@code --algo null} or an option with no value.
 *
 * @param algorithm algorithm name, one of: l-bfgs (default), sgd-l1, bcd, rprop, rprop+, rprop-
 */
@Override
public void setAlgorithm(String algorithm) {
    // Blank means "not set": keep whatever algorithm is currently configured.
    if (algorithm == null || algorithm.trim().isEmpty()) {
        return;
    }
    this.algorithm = algorithm;
}

@Override
public int getNbMaxIterations() {
return nbMaxIterations;
Expand Down

0 comments on commit ac4dc2c

Please sign in to comment.