
Commit

support incremental training for DL models
kermitt2 committed Nov 24, 2022
1 parent b9d92c6 commit 2dce9b9
Showing 10 changed files with 69 additions and 30 deletions.
46 changes: 38 additions & 8 deletions grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
@@ -209,11 +209,11 @@ public String label(String data) {
* usually hangs... Possibly issues with IO threads at the level of JEP (output not consumed because
* of \r and no end of line?).
*/
public static void trainJNI(String modelName, File trainingData, File outputModel, String architecture) {
public static void trainJNI(String modelName, File trainingData, File outputModel, String architecture, boolean incremental) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
JEPThreadPool.getInstance().run(
new TrainTask(modelName, trainingData, GrobidProperties.getInstance().getModelPath(), architecture));
new TrainTask(modelName, trainingData, GrobidProperties.getInstance().getModelPath(), architecture, incremental));
} catch(InterruptedException e) {
LOGGER.error("Train DeLFT model " + modelName + " task failed", e);
}
@@ -224,13 +224,15 @@ private static class TrainTask implements Runnable {
private File trainPath;
private File modelPath;
private String architecture;
private boolean incremental;

public TrainTask(String modelName, File trainPath, File modelPath, String architecture) {
public TrainTask(String modelName, File trainPath, File modelPath, String architecture, boolean incremental) {
//System.out.println("train thread: " + Thread.currentThread().getId());
this.modelName = modelName;
this.trainPath = trainPath;
this.modelPath = modelPath;
this.architecture = architecture;
this.incremental = incremental;
}

@Override
@@ -273,7 +275,18 @@ public void run() {

// actual training
//start_time = time.time()
jep.eval("model.train(x_train, y_train, x_valid, y_valid)");
if (incremental) {
// if incremental training, we need to load the existing model
if (this.modelPath != null &&
this.modelPath.exists() &&
!this.modelPath.isDirectory()) {
jep.eval("model.train(x_train, y_train, x_valid, y_valid, incremental=True)");
} else {
throw new GrobidException("the path to the model to be used for starting incremental training is invalid: " +
this.modelPath.getAbsolutePath());
}
} else
jep.eval("model.train(x_train, y_train, x_valid, y_valid)");
//runtime = round(time.time() - start_time, 3)
//print("training runtime: %s seconds " % (runtime))

@@ -292,6 +305,8 @@ public void run() {
jep.eval("del model");
} catch(JepException e) {
LOGGER.error("DeLFT model training via JEP failed", e);
} catch(GrobidException e) {
LOGGER.error("GROBID call to DeLFT training via JEP failed", e);
}
}
}
@@ -300,7 +315,7 @@ public void run() {
* Train with an external process rather than with JNI, this approach appears to be more stable for the
* training process (JNI approach hangs after a while) and does not raise any runtime/integration issues.
*/
public static void train(String modelName, File trainingData, File outputModel, String architecture) {
public static void train(String modelName, File trainingData, File outputModel, String architecture, boolean incremental) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
List<String> command = new ArrayList<>();
@@ -322,16 +337,29 @@ public static void train(String modelName, File trainingData, File outputModel,
if (GrobidProperties.getInstance().useELMo(modelName) && modelName.toLowerCase().indexOf("bert") == -1) {
command.add("--use-ELMo");
}

if (GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(modelName) != -1) {
command.add("--max-sequence-length");
command.add(String.valueOf(GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(modelName)));
}

if (GrobidProperties.getInstance().getDelftTrainingBatchSize(modelName) != -1) {
command.add("--batch-size");
command.add(String.valueOf(GrobidProperties.getInstance().getDelftTrainingBatchSize(modelName)));
}
if (incremental) {
command.add("--incremental");

// if incremental training, we need to load the existing model
File modelPath = GrobidProperties.getInstance().getModelPath();
if (modelPath != null &&
modelPath.exists() &&
!modelPath.isDirectory()) {
command.add("--input-model");
command.add(GrobidProperties.getInstance().getModelPath().getAbsolutePath());
} else {
throw new GrobidException("the path to the model to be used for starting incremental training is invalid: " +
GrobidProperties.getInstance().getModelPath().getAbsolutePath());
}
}

ProcessBuilder pb = new ProcessBuilder(command);
File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
@@ -349,7 +377,9 @@ public static void train(String modelName, File trainingData, File outputModel,
LOGGER.error("IO error when training DeLFT model " + modelName, e);
} catch(InterruptedException e) {
LOGGER.error("Train DeLFT model " + modelName + " task failed", e);
}
} catch(GrobidException e) {
LOGGER.error("GROBID call to DeLFT training via JEP failed", e);
}
}

public synchronized void close() {
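For context, a minimal sketch of how the new flag could be exercised from Java; the model name, architecture and training-data path below are assumptions, not part of this commit. When incremental is true, the external-process path above appends --incremental and --input-model <path> to the DeLFT command, and the JNI path passes incremental=True to model.train.

// Sketch only: "header", "BidLSTM_CRF" and the training-data path are placeholders.
File trainingData = new File("/tmp/header.train");
File outputModel = GrobidProperties.getInstance().getModelPath();
// true -> resume training from the currently installed model instead of training from scratch
DeLFTModel.train("header", trainingData, outputModel, "BidLSTM_CRF", true);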
@@ -128,7 +128,8 @@ public Response getModel(String model, String architecture) {
for (final File currFile : files) {
if (currFile.getName().toLowerCase().endsWith(".hdf5")
|| currFile.getName().toLowerCase().endsWith(".json")
|| currFile.getName().toLowerCase().endsWith(".pkl")) {
|| currFile.getName().toLowerCase().endsWith(".pkl")
|| currFile.getName().toLowerCase().endsWith(".txt")) {
try {
ZipEntry ze = new ZipEntry(currFile.getName());
out.putNextEntry(ze);
@@ -309,14 +310,14 @@ public void run() {
switch (this.type.toLowerCase()) {
// possible values are `full`, `holdout`, `split`, `nfold`
case "full":
AbstractTrainer.runTraining(this.trainer);
AbstractTrainer.runTraining(this.trainer, false);
break;
case "holdout":
AbstractTrainer.runTraining(this.trainer);
results = AbstractTrainer.runEvaluation(this.trainer);
break;
case "split":
results = AbstractTrainer.runSplitTrainingEvaluation(this.trainer, this.ratio);
results = AbstractTrainer.runSplitTrainingEvaluation(this.trainer, this.ratio, false);
break;
case "nfold":
if (n == 0) {
@@ -79,7 +79,7 @@ public int createCRFPPData(final File corpusDir, final File trainingOutputPath)
}

@Override
public void train() {
public void train(boolean incremental) {
final File dataPath = trainDataPath;
createCRFPPData(getCorpusPath(), dataPath);
GenericTrainer trainer = TrainerFactory.getTrainer(model);
@@ -96,7 +96,7 @@ public void train() {
}
final File tempModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath() + NEW_MODEL_EXT);
final File oldModelPath = GrobidProperties.getModelPath(model);
trainer.train(getTemplatePath(), dataPath, tempModelPath, GrobidProperties.getWapitiNbThreads(), model);
trainer.train(getTemplatePath(), dataPath, tempModelPath, GrobidProperties.getWapitiNbThreads(), model, incremental);
// if we are here, that means that training succeeded
// rename model for CRF sequence labellers (not with DeLFT deep learning models)
if (GrobidProperties.getGrobidCRFEngine(this.model) != GrobidCRFEngine.DELFT)
@@ -134,7 +134,7 @@ public String evaluate(GenericTagger tagger, boolean includeRawResults) {
}

@Override
public String splitTrainEvaluate(Double split) {
public String splitTrainEvaluate(Double split, boolean incremental) {
final File dataPath = trainDataPath;
createCRFPPData(getCorpusPath(), dataPath, evalDataPath, split);
GenericTrainer trainer = TrainerFactory.getTrainer(model);
@@ -156,7 +156,7 @@ public String splitTrainEvaluate(Double split) {
final File tempModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath() + NEW_MODEL_EXT);
final File oldModelPath = GrobidProperties.getModelPath(model);

trainer.train(getTemplatePath(), dataPath, tempModelPath, GrobidProperties.getWapitiNbThreads(), model);
trainer.train(getTemplatePath(), dataPath, tempModelPath, GrobidProperties.getWapitiNbThreads(), model, incremental);

// if we are here, that means that training succeeded
renameModels(oldModelPath, tempModelPath);
@@ -220,7 +220,7 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
tempFilePaths.add(fold.getRight());

sb.append("Training input data: " + fold.getLeft()).append("\n");
trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getWapitiNbThreads(), model);
trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getWapitiNbThreads(), model, false);
sb.append("Evaluation input data: " + fold.getRight()).append("\n");

//TODO: find a better solution!!
@@ -548,8 +548,12 @@ public GrobidModel getModel() {
}

public static void runTraining(final Trainer trainer) {
runTraining(trainer, false);
}

public static void runTraining(final Trainer trainer, boolean incremental) {
long start = System.currentTimeMillis();
trainer.train();
trainer.train(incremental);
long end = System.currentTimeMillis();

System.out.println("Model for " + trainer.getModel() + " created in " + (end - start) + " ms");
@@ -577,11 +581,11 @@ public static String runEvaluation(final Trainer trainer) {
return trainer.evaluate(false);
}

public static String runSplitTrainingEvaluation(final Trainer trainer, Double split) {
public static String runSplitTrainingEvaluation(final Trainer trainer, Double split, boolean incremental) {
long start = System.currentTimeMillis();
String report = "";
try {
report = trainer.splitTrainEvaluate(split);
report = trainer.splitTrainEvaluate(split, incremental);

} catch (Exception e) {
throw new GrobidException("An exception occurred while evaluating Grobid.", e);
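As a usage note, the single-argument runTraining(trainer) is kept and delegates to the new overload with incremental=false, so existing callers are unaffected. A minimal sketch of the new overloads follows; the choice of trainer implementation is an assumption.

// Sketch only: HeaderTrainer stands in for any Trainer implementation.
Trainer trainer = new HeaderTrainer();
AbstractTrainer.runTraining(trainer, true);   // incremental training
String report = AbstractTrainer.runSplitTrainingEvaluation(trainer, 0.8, true);
System.out.println(report);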
@@ -29,7 +29,7 @@ public CRFPPGenericTrainer() {
}

@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model, boolean incremental) {
crfppTrainer.train(template.getAbsolutePath(), trainingData.getAbsolutePath(), outputModel.getAbsolutePath(), numThreads);
if (!crfppTrainer.what().isEmpty()) {
LOGGER.warn("CRF++ Trainer warnings:\n" + crfppTrainer.what());
@@ -14,8 +14,8 @@ public class DeLFTTrainer implements GenericTrainer {
public static final String DELFT = "delft";

@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
DeLFTModel.train(model.getModelName(), trainingData, outputModel, GrobidProperties.getDelftArchitecture(model));
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model, boolean incremental) {
DeLFTModel.train(model.getModelName(), trainingData, outputModel, GrobidProperties.getDelftArchitecture(model), incremental);
}

@Override
@@ -9,7 +9,7 @@
*/
public class DummyTrainer implements GenericTrainer {
@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model, boolean incremental) {

}

@@ -5,7 +5,7 @@
import java.io.File;

public interface GenericTrainer {
void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model);
void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model, boolean incremental);
String getName();
public void setEpsilon(double epsilon);
public void setWindow(int window);
4 changes: 2 additions & 2 deletions grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java
@@ -12,15 +12,15 @@ public interface Trainer {

int createCRFPPData(File corpusPath, File outputTrainingFile, File outputEvalFile, double splitRatio);

void train();
void train(boolean incremental);

String evaluate();

String evaluate(boolean includeRawResults);

String evaluate(GenericTagger tagger, boolean includeRawResults);

String splitTrainEvaluate(Double split);
String splitTrainEvaluate(Double split, boolean incremental);

String nFoldEvaluate(int folds);

@@ -53,6 +53,7 @@ public static void main(String[] args) {
double split = 0.0;
int numFolds = 0;
String outputFilePath = null;
boolean incremental = false;
for (int i = 0; i < args.length; i++) {
if (args[i].equals("-gH")) {
if (i + 1 == args.length) {
@@ -85,6 +86,9 @@
}
outputFilePath = args[i + 1];

} else if (args[i].equals("-i")) {
incremental = true;

}
}

@@ -136,13 +140,13 @@ public static void main(String[] args) {

switch (mode) {
case TRAIN:
AbstractTrainer.runTraining(trainer);
AbstractTrainer.runTraining(trainer, incremental);
break;
case EVAL:
System.out.println(AbstractTrainer.runEvaluation(trainer));
break;
case SPLIT:
System.out.println(AbstractTrainer.runSplitTrainingEvaluation(trainer, split));
System.out.println(AbstractTrainer.runSplitTrainingEvaluation(trainer, split, incremental));
break;
case EVAL_N_FOLD:
if(numFolds == 0) {
@@ -154,7 +158,7 @@
System.err.println("Output file exists. ");
}
} else {
String results = AbstractTrainer.runNFoldEvaluation(trainer, numFolds);
String results = AbstractTrainer.runNFoldEvaluation(trainer, numFolds, incremental);
System.out.println(results);
}
break;
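In practice the existing command-line training call only needs the extra -i switch to run incrementally. A hedged sketch through the runner's entry point follows; the action code and model name are assumed to follow the usual TrainerRunner conventions and are not part of this diff, only the -i flag comes from this commit.

// Sketch only: "0" (train action) and "header" are assumed positional arguments;
// "-i" enables incremental training.
TrainerRunner.main(new String[] {"0", "header", "-gH", "/path/to/grobid-home", "-i"});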
@@ -18,7 +18,7 @@ public class WapitiTrainer implements GenericTrainer {
protected int nbMaxIterations = 2000; // by default maximum of training iterations

@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model, boolean incremental) {
System.out.println("\tepsilon: " + epsilon);
System.out.println("\twindow: " + window);
System.out.println("\tnb max iterations: " + nbMaxIterations);
