Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove side-effects when running several proceses of n-fold cross-validation of the same model while sharing the same grobid-home directory #565

Merged
merged 2 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ subprojects {
testCompile "org.powermock:powermock-module-junit4:2.0.0-beta.5"
testCompile "xmlunit:xmlunit:1.6"
testCompile "org.hamcrest:hamcrest-all:1.3"

compile 'org.apache.commons:commons-text:1.8'
}

task sourceJar(type: Jar) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.grobid.trainer;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.text.RandomStringGenerator;
import org.grobid.core.GrobidModel;
import org.grobid.core.GrobidModels;
import org.grobid.core.engines.tagging.GenericTagger;
Expand All @@ -12,6 +14,7 @@
import org.grobid.core.factory.GrobidFactory;
import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.TextUtilities;
import org.grobid.core.utilities.Utilities;
import org.grobid.trainer.evaluation.EvaluationUtilities;
import org.grobid.trainer.evaluation.LabelResult;
import org.grobid.trainer.evaluation.ModelStats;
Expand Down Expand Up @@ -51,6 +54,7 @@ public abstract class AbstractTrainer implements Trainer {
private File trainDataPath;
private File evalDataPath;
private GenericTagger tagger;
private RandomStringGenerator randomStringGenerator;

public AbstractTrainer(final GrobidModel model) {
GrobidFactory.getInstance().createEngine();
Expand All @@ -61,6 +65,9 @@ public AbstractTrainer(final GrobidModel model) {
}
this.trainDataPath = getTempTrainingDataPath();
this.evalDataPath = getTempEvaluationDataPath();
this.randomStringGenerator = new RandomStringGenerator.Builder()
.withinRange('a', 'z')
.build();
}

public void setParams(double epsilon, int window, int nbMaxIterations) {
Expand Down Expand Up @@ -174,6 +181,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
createCRFPPData(getCorpusPath(), dataPath);
GenericTrainer trainer = TrainerFactory.getTrainer();

String randomString = randomStringGenerator.generate(10);

// Load in memory and Shuffle
Path dataPath2 = Paths.get(dataPath.getAbsolutePath());
List<String> trainingData = loadAndShuffle(dataPath2);
Expand All @@ -195,6 +204,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
LOGGER.warn("Cannot find the destination directory " + tmpDirectory);
}

List<String> tempFilePaths = new ArrayList<>();

// Output
StringBuilder sb = new StringBuilder();
sb.append("Recap results for each fold:").append("\n\n");
Expand All @@ -206,9 +217,14 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
System.out.println("====================== Fold " + counter.get() + " ====================== ");

final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName()
+ "_nfold_" + counter.getAndIncrement() + ".wapiti");
+ "_nfold_" + counter.getAndIncrement() + "_" + randomString + ".wapiti");
sb.append("Saving model in " + tempModelPath).append("\n");

// Collecting generated paths to be deleted at the end of the process
tempFilePaths.add(tempModelPath.getAbsolutePath());
tempFilePaths.add(fold.getLeft());
tempFilePaths.add(fold.getRight());

sb.append("Training input data: " + fold.getLeft()).append("\n");
trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model);
sb.append("Evaluation input data: " + fold.getRight()).append("\n");
Expand Down Expand Up @@ -373,6 +389,14 @@ public String getTemplateName() {
"Instance-level recall:",
TextUtilities.formatTwoDecimals(averageCorrectInstances / averageTotalInstances * 100)));

// Cleanup
tempFilePaths.stream().forEach(f -> {
try {
Files.delete(Paths.get(f));
} catch (IOException e) {
LOGGER.warn("Error while performing the cleanup after n-fold cross-validation. Cannot delete the file: " + f, e);
}
});

return sb.toString();
}
Expand Down Expand Up @@ -411,7 +435,7 @@ protected List<ImmutablePair<String, String>> splitNFold(List<String> trainingDa
throw new GrobidException("Error when dumping n-fold evaluation data into files. ", e);
}

//Dump Training
//Remove temporary Training and models files (note: the original data files (.train and .eval) are not deleted)
String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath();
try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) {
writer.write(String.join("\n\n", foldTraining));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ public boolean accept(File dir, String name) {
*/
public static void main(String[] args) throws Exception {
GrobidProperties.getInstance();
AbstractTrainer.runTraining(new TableTrainer());
System.out.println(AbstractTrainer.runEvaluation(new TableTrainer()));
System.out.println(AbstractTrainer.runNFoldEvaluation(new TableTrainer(), 2));
// System.out.println(AbstractTrainer.runEvaluation(new TableTrainer()));
System.exit(0);
}
}