From ac4dc2cf61d5a0491663d25b2f177981506972f6 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 3 Feb 2021 17:27:33 +0900 Subject: [PATCH] add algorithm option for wapiti --- .../java/org/grobid/trainer/AbstractTrainer.java | 8 ++++++++ .../org/grobid/trainer/CRFPPGenericTrainer.java | 8 +++++++- .../main/java/org/grobid/trainer/DeLFTTrainer.java | 6 +++++- .../main/java/org/grobid/trainer/DummyTrainer.java | 3 +++ .../java/org/grobid/trainer/GenericTrainer.java | 13 +++++++------ .../main/java/org/grobid/trainer/WapitiTrainer.java | 13 +++++++++++-- 6 files changed, 41 insertions(+), 10 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 81e6f8e66a..86e474c030 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -49,6 +49,7 @@ public abstract class AbstractTrainer implements Trainer { protected double epsilon = 0.0; // size of the interval for stopping criterion protected int window = 0; // similar to CRF++ protected int nbMaxIterations = 0; // maximum number of iterations in training + protected String algorithm = ""; // algorithm protected GrobidModel model; private File trainDataPath; @@ -94,6 +95,9 @@ public void train() { if (nbMaxIterations != 0) trainer.setNbMaxIterations(nbMaxIterations); + if (StringUtils.isNotBlank(algorithm)) + trainer.setAlgorithm(algorithm); + File dirModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath()).getParentFile(); if (!dirModelPath.exists()) { LOGGER.warn("Cannot find the destination directory " + dirModelPath.getAbsolutePath() + " for the model " + model.getModelName() + ". Creating it."); @@ -151,6 +155,8 @@ public String splitTrainEvaluate(Double split) { trainer.setWindow(window); if (nbMaxIterations != 0) trainer.setNbMaxIterations(nbMaxIterations); + if (StringUtils.isNotBlank(algorithm)) + trainer.setAlgorithm(algorithm); File dirModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath()).getParentFile(); if (!dirModelPath.exists()) { @@ -197,6 +203,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) { trainer.setWindow(window); if (nbMaxIterations != 0) trainer.setNbMaxIterations(nbMaxIterations); + if (StringUtils.isNotBlank(algorithm)) + trainer.setAlgorithm(algorithm); //We dump the model in the tmp directory File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath()); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java index 9a966ce236..307e11683e 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java @@ -26,6 +26,7 @@ public class CRFPPGenericTrainer implements GenericTrainer { protected double epsilon = 0.00001; // default size of the interval for stopping criterion protected int window = 20; // default similar to CRF++ protected int nbMaxIterations = 6000; + protected String algorithm = "crf-l2"; public CRFPPGenericTrainer() { crfppTrainer = new CRFPPTrainer(); @@ -70,7 +71,12 @@ public int getWindow() { public void setNbMaxIterations(int interations) { this.nbMaxIterations = interations; } - + + @Override + public void setAlgorithm(String algorithm) { + this.algorithm = algorithm; + } + @Override public int getNbMaxIterations() { return nbMaxIterations; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DeLFTTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DeLFTTrainer.java index 4add8aeee1..6473954ba0 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DeLFTTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DeLFTTrainer.java @@ -50,7 +50,11 @@ public int getWindow() { @Override public void setNbMaxIterations(int interations) { } - + + @Override + public void setAlgorithm(String algorithm) { + } + @Override public int getNbMaxIterations() { return 0; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java index 4e90920609..6d254ee555 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java @@ -45,6 +45,9 @@ public int getNbMaxIterations() { @Override public void setNbMaxIterations(int iterations) { + } + @Override + public void setAlgorithm(String algorithm) { } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/GenericTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/GenericTrainer.java index d083f8e482..aa03bba2ef 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/GenericTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/GenericTrainer.java @@ -11,10 +11,11 @@ public interface GenericTrainer { void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model); String getName(); - public void setEpsilon(double epsilon); - public void setWindow(int window); - public double getEpsilon(); - public int getWindow(); - public int getNbMaxIterations(); - public void setNbMaxIterations(int iterations); + void setEpsilon(double epsilon); + void setWindow(int window); + double getEpsilon(); + int getWindow(); + int getNbMaxIterations(); + void setNbMaxIterations(int iterations); + void setAlgorithm(String algorithm); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/WapitiTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/WapitiTrainer.java index f4c5193747..9ebe338079 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/WapitiTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/WapitiTrainer.java @@ -20,6 +20,8 @@ public class WapitiTrainer implements GenericTrainer { protected double epsilon = 0.00001; // default size of the interval for stopping criterion protected int window = 20; // default similar to CRF++ protected int nbMaxIterations = 2000; // by default maximum of training iterations + protected String algorithm = "l-bfgs"; // algorithm to be used, values: l-bfgs (default), sgd-l1, bcd, rprop, rprop+, rprop- + @Override public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) { @@ -27,8 +29,10 @@ public void train(File template, File trainingData, File outputModel, int numThr System.out.println("\twindow: " + window); System.out.println("\tnb max iterations: " + nbMaxIterations); System.out.println("\tnb threads: " + numThreads); + System.out.println("\talgorithm: " + algorithm); + WapitiModel.train(template, trainingData, outputModel, "--nthread " + numThreads + -// " --algo sgd-l1" + + " --algo " + algorithm + " -e " + BigDecimal.valueOf(epsilon).toPlainString() + " -w " + window + " -i " + nbMaxIterations @@ -64,7 +68,12 @@ public int getWindow() { public void setNbMaxIterations(int interations) { this.nbMaxIterations = interations; } - + + @Override + public void setAlgorithm(String algorithm) { + this.algorithm = algorithm; + } + @Override public int getNbMaxIterations() { return nbMaxIterations;