Skip to content

Commit

Permalink
Merge 67bf75a into 47c8978
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Apr 9, 2020
2 parents 47c8978 + 67bf75a commit 25733c9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Expand Up @@ -85,6 +85,8 @@ subprojects {
testCompile "org.powermock:powermock-module-junit4:2.0.0-beta.5"
testCompile "xmlunit:xmlunit:1.6"
testCompile "org.hamcrest:hamcrest-all:1.3"

compile 'org.apache.commons:commons-text:1.8'
}

task sourceJar(type: Jar) {
Expand Down
@@ -1,8 +1,10 @@
package org.grobid.trainer;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.text.RandomStringGenerator;
import org.grobid.core.GrobidModel;
import org.grobid.core.GrobidModels;
import org.grobid.core.engines.tagging.GenericTagger;
Expand All @@ -12,6 +14,7 @@
import org.grobid.core.factory.GrobidFactory;
import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.TextUtilities;
import org.grobid.core.utilities.Utilities;
import org.grobid.trainer.evaluation.EvaluationUtilities;
import org.grobid.trainer.evaluation.LabelResult;
import org.grobid.trainer.evaluation.ModelStats;
Expand Down Expand Up @@ -51,6 +54,7 @@ public abstract class AbstractTrainer implements Trainer {
private File trainDataPath;
private File evalDataPath;
private GenericTagger tagger;
private RandomStringGenerator randomStringGenerator;

public AbstractTrainer(final GrobidModel model) {
GrobidFactory.getInstance().createEngine();
Expand All @@ -61,6 +65,9 @@ public AbstractTrainer(final GrobidModel model) {
}
this.trainDataPath = getTempTrainingDataPath();
this.evalDataPath = getTempEvaluationDataPath();
this.randomStringGenerator = new RandomStringGenerator.Builder()
.withinRange('a', 'z')
.build();
}

public void setParams(double epsilon, int window, int nbMaxIterations) {
Expand Down Expand Up @@ -174,6 +181,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
createCRFPPData(getCorpusPath(), dataPath);
GenericTrainer trainer = TrainerFactory.getTrainer();

String randomString = randomStringGenerator.generate(10);

// Load in memory and Shuffle
Path dataPath2 = Paths.get(dataPath.getAbsolutePath());
List<String> trainingData = loadAndShuffle(dataPath2);
Expand All @@ -195,6 +204,8 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
LOGGER.warn("Cannot find the destination directory " + tmpDirectory);
}

List<String> tempFilePaths = new ArrayList<>();

// Output
StringBuilder sb = new StringBuilder();
sb.append("Recap results for each fold:").append("\n\n");
Expand All @@ -206,9 +217,14 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) {
System.out.println("====================== Fold " + counter.get() + " ====================== ");

final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName()
+ "_nfold_" + counter.getAndIncrement() + ".wapiti");
+ "_nfold_" + counter.getAndIncrement() + "_" + randomString + ".wapiti");
sb.append("Saving model in " + tempModelPath).append("\n");

// Collecting generated paths to be deleted at the end of the process
tempFilePaths.add(tempModelPath.getAbsolutePath());
tempFilePaths.add(fold.getLeft());
tempFilePaths.add(fold.getRight());

sb.append("Training input data: " + fold.getLeft()).append("\n");
trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model);
sb.append("Evaluation input data: " + fold.getRight()).append("\n");
Expand Down Expand Up @@ -373,6 +389,14 @@ public String getTemplateName() {
"Instance-level recall:",
TextUtilities.formatTwoDecimals(averageCorrectInstances / averageTotalInstances * 100)));

// Cleanup: attempt to delete every path that was collected in tempFilePaths.
// Each deletion is wrapped in its own try/catch, so an IOException raised for
// one file is only logged as a warning and the remaining paths are still processed.
tempFilePaths.stream().forEach(f -> {
try {
// Files.delete signals failure by throwing instead of returning a status flag.
Files.delete(Paths.get(f));
} catch (IOException e) {
LOGGER.warn("Error while performing the cleanup after n-fold cross-validation. Cannot delete the file: " + f, e);
}
});

return sb.toString();
}
Expand Down Expand Up @@ -411,7 +435,7 @@ protected List<ImmutablePair<String, String>> splitNFold(List<String> trainingDa
throw new GrobidException("Error when dumping n-fold evaluation data into files. ", e);
}

//Dump Training
//Remove temporary Training and models files (note: the original data files (.train and .eval) are not deleted)
String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath();
try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) {
writer.write(String.join("\n\n", foldTraining));
Expand Down
Expand Up @@ -220,8 +220,8 @@ public boolean accept(File dir, String name) {
*/
public static void main(String[] args) throws Exception {
GrobidProperties.getInstance();
AbstractTrainer.runTraining(new TableTrainer());
System.out.println(AbstractTrainer.runEvaluation(new TableTrainer()));
System.out.println(AbstractTrainer.runNFoldEvaluation(new TableTrainer(), 2));
// System.out.println(AbstractTrainer.runEvaluation(new TableTrainer()));
System.exit(0);
}
}

0 comments on commit 25733c9

Please sign in to comment.