diff --git a/.gitignore b/.gitignore index c12f5de69a..7e733b7e33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.java-version target grobid-home/tmp .DS_Store diff --git a/.java-version b/.java-version new file mode 100644 index 0000000000..f2052c3e9a --- /dev/null +++ b/.java-version @@ -0,0 +1 @@ +openjdk64-11.0.2 diff --git a/grobid-core/src/main/java/org/grobid/core/GrobidModels.java b/grobid-core/src/main/java/org/grobid/core/GrobidModels.java index 01e859068f..fc27d5abe5 100755 --- a/grobid-core/src/main/java/org/grobid/core/GrobidModels.java +++ b/grobid-core/src/main/java/org/grobid/core/GrobidModels.java @@ -43,7 +43,11 @@ public enum GrobidModels implements GrobidModel { // ENTITIES_BIOTECH("entities/biotech"), ENTITIES_BIOTECH("bio"), ASTRO("astro"), - SOFTWARE("software"); + SOFTWARE("software"), + DUMMY("none"); + + //I cannot declare it before + public static final String DUMMY_FOLDER_LABEL = "none"; /** * Absolute path to the model. @@ -55,6 +59,12 @@ public enum GrobidModels implements GrobidModel { private static final ConcurrentMap models = new ConcurrentHashMap<>(); GrobidModels(String folderName) { + if(StringUtils.equals(DUMMY_FOLDER_LABEL, folderName)) { + modelPath = DUMMY_FOLDER_LABEL; + this.folderName = DUMMY_FOLDER_LABEL; + return; + } + this.folderName = folderName; File path = GrobidProperties.getModelPath(this); if (!path.exists()) { diff --git a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java index 2653028e7c..feb5c25098 100644 --- a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java +++ b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java @@ -1,17 +1,15 @@ package org.grobid.core.engines.tagging; -import com.google.common.base.Function; import com.google.common.base.Joiner; import com.google.common.base.Splitter; import org.apache.commons.lang3.StringUtils; -//import org.grobid.core.utilities.Pair; +import org.apache.commons.lang3.tuple.Pair; import org.grobid.core.utilities.Triple; import org.wipo.analyzers.wipokr.utils.StringUtil; -import org.apache.commons.lang3.tuple.Pair; - import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import java.util.regex.Pattern; /** @@ -22,6 +20,7 @@ public class GenericTaggerUtils { public static final String START_ENTITY_LABEL_PREFIX = "I-"; public static final String START_ENTITY_LABEL_PREFIX_ALTERNATIVE = "B-"; + public static final String START_ENTITY_LABEL_PREFIX_ALTERNATIVE_2 = "E-"; public static final Pattern SEPARATOR_PATTERN = Pattern.compile("[\t ]"); /** @@ -30,31 +29,23 @@ public class GenericTaggerUtils { * Note an empty line in the result will be transformed to a 'null' pointer of a pair */ public static List> getTokensAndLabels(String labeledResult) { - Function, Pair> fromSplits = new Function, Pair>() { - @Override public Pair apply(List splits) { - return Pair.of(splits.get(0), splits.get(splits.size() - 1)); - } - }; - - return processLabeledResult(labeledResult, fromSplits); + return processLabeledResult(labeledResult, splits -> Pair.of(splits.get(0), splits.get(splits.size() - 1))); } /** * @param labeledResult labeled result from a tagger - * @return a list of triples - first element in a pair is a token itself, the second is a label (e.g. or I-) + * @return a list of triples - first element in a pair is a token itself, the second is a label (e.g. or I-) * and the third element is a string with the features * Note an empty line in the result will be transformed to a 'null' pointer of a pair */ public static List> getTokensWithLabelsAndFeatures(String labeledResult, final boolean addFeatureString) { - Function, Triple> fromSplits = new Function, Triple>() { - @Override public Triple apply(List splits) { - String featureString = addFeatureString ? Joiner.on("\t").join(splits.subList(0, splits.size() - 1)) : null; - return new Triple<>( - splits.get(0), - splits.get(splits.size() - 1), - featureString); - } + Function, Triple> fromSplits = splits -> { + String featureString = addFeatureString ? Joiner.on("\t").join(splits.subList(0, splits.size() - 1)) : null; + return new Triple<>( + splits.get(0), + splits.get(splits.size() - 1), + featureString); }; return processLabeledResult(labeledResult, fromSplits); @@ -82,6 +73,8 @@ public static String getPlainLabel(String label) { } public static boolean isBeginningOfEntity(String label) { - return StringUtils.startsWith(label, START_ENTITY_LABEL_PREFIX) || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE); + return StringUtils.startsWith(label, START_ENTITY_LABEL_PREFIX) + || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE) + || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE_2); } } \ No newline at end of file diff --git a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java index b12cae7e92..cb4d93270a 100644 --- a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java +++ b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java @@ -9,7 +9,8 @@ public enum GrobidCRFEngine { WAPITI("wapiti"), CRFPP("crf"), - DELFT("delft"); + DELFT("delft"), + DUMMY("dummy"); private final String ext; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index a9b69d9042..804f1284cb 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -1,18 +1,38 @@ package org.grobid.trainer; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; import org.grobid.core.GrobidModel; +import org.grobid.core.GrobidModels; import org.grobid.core.engines.tagging.GenericTagger; +import org.grobid.core.engines.tagging.GrobidCRFEngine; import org.grobid.core.engines.tagging.TaggerFactory; import org.grobid.core.exceptions.GrobidException; import org.grobid.core.factory.GrobidFactory; import org.grobid.core.utilities.GrobidProperties; -import org.grobid.core.engines.tagging.GrobidCRFEngine; +import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.EvaluationUtilities; +import org.grobid.trainer.evaluation.LabelResult; +import org.grobid.trainer.evaluation.ModelStats; +import org.grobid.trainer.evaluation.Stats; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.awt.*; +import java.io.BufferedWriter; import java.io.File; import java.io.IOException; +import java.io.Writer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; /** * @author Zholudev, Lopez @@ -35,6 +55,10 @@ public abstract class AbstractTrainer implements Trainer { public AbstractTrainer(final GrobidModel model) { GrobidFactory.getInstance().createEngine(); this.model = model; + if (model.equals(GrobidModels.DUMMY)) { + // In case of dummy model we do not initialise (and create) temporary files + return; + } this.trainDataPath = getTempTrainingDataPath(); this.evalDataPath = getTempEvaluationDataPath(); } @@ -45,13 +69,18 @@ public void setParams(double epsilon, int window, int nbMaxIterations) { this.nbMaxIterations = nbMaxIterations; } + @Override + public int createCRFPPData(final File corpusDir, final File trainingOutputPath) { + return createCRFPPData(corpusDir, trainingOutputPath, null, 1.0); + } + @Override public void train() { final File dataPath = trainDataPath; createCRFPPData(getCorpusPath(), dataPath); GenericTrainer trainer = TrainerFactory.getTrainer(); - if (epsilon != 0.0) + if (epsilon != 0.0) trainer.setEpsilon(epsilon); if (window != 0) trainer.setWindow(window); @@ -88,8 +117,19 @@ protected void renameModels(final File oldModelPath, final File tempModelPath) { @Override public String evaluate() { + return evaluate(false); + } + + @Override + public String evaluate(boolean includeRawResults) { createCRFPPData(getEvalCorpusPath(), evalDataPath); - return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(includeRawResults); + } + + @Override + public String evaluate(GenericTagger tagger, boolean includeRawResults) { + createCRFPPData(getEvalCorpusPath(), evalDataPath); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), tagger).toString(includeRawResults); } @Override @@ -98,7 +138,7 @@ public String splitTrainEvaluate(Double split) { createCRFPPData(getCorpusPath(), dataPath, evalDataPath, split); GenericTrainer trainer = TrainerFactory.getTrainer(); - if (epsilon != 0.0) + if (epsilon != 0.0) trainer.setEpsilon(epsilon); if (window != 0) trainer.setWindow(window); @@ -111,7 +151,7 @@ public String splitTrainEvaluate(Double split) { dirModelPath.mkdir(); //throw new GrobidException("Cannot find the destination directory " + dirModelPath.getAbsolutePath() + " for the model " + model.toString()); } - + final File tempModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath() + NEW_MODEL_EXT); final File oldModelPath = GrobidProperties.getModelPath(model); @@ -120,7 +160,310 @@ public String splitTrainEvaluate(Double split) { // if we are here, that means that training succeeded renameModels(oldModelPath, tempModelPath); - return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(); + } + + @Override + public String nFoldEvaluate(int numFolds) { + return nFoldEvaluate(numFolds, false); + } + + @Override + public String nFoldEvaluate(int numFolds, boolean includeRawResults) { + final File dataPath = trainDataPath; + createCRFPPData(getCorpusPath(), dataPath); + GenericTrainer trainer = TrainerFactory.getTrainer(); + + // Load in memory and Shuffle + Path dataPath2 = Paths.get(dataPath.getAbsolutePath()); + List trainingData = loadAndShuffle(dataPath2); + + // Split into folds + List> foldMap = splitNFold(trainingData, numFolds); + + // Train and evaluation + if (epsilon != 0.0) + trainer.setEpsilon(epsilon); + if (window != 0) + trainer.setWindow(window); + if (nbMaxIterations != 0) + trainer.setNbMaxIterations(nbMaxIterations); + + //We dump the model in the tmp directory + File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath()); + if (!tmpDirectory.exists()) { + LOGGER.warn("Cannot find the destination directory " + tmpDirectory); + } + + // Output + StringBuilder sb = new StringBuilder(); + sb.append("Recap results for each fold:").append("\n\n"); + + AtomicInteger counter = new AtomicInteger(0); + List evaluationResults = foldMap.stream().map(fold -> { + sb.append("\n"); + sb.append("====================== Fold " + counter.get() + " ====================== ").append("\n"); + System.out.println("====================== Fold " + counter.get() + " ====================== "); + + final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() + + "_nfold_" + counter.getAndIncrement() + ".wapiti"); + sb.append("Saving model in " + tempModelPath).append("\n"); + + sb.append("Training input data: " + fold.getLeft()).append("\n"); + trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); + sb.append("Evaluation input data: " + fold.getRight()).append("\n"); + + //TODO: find a better solution!! + GrobidModel tmpModel = new GrobidModel() { + @Override + public String getFolderName() { + return tmpDirectory.getAbsolutePath(); + } + + @Override + public String getModelPath() { + return tempModelPath.getAbsolutePath(); + } + + @Override + public String getModelName() { + return model.getModelName(); + } + + @Override + public String getTemplateName() { + return model.getTemplateName(); + } + }; + + ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), TaggerFactory.getTagger(tmpModel)); + + sb.append(modelStats.toString(includeRawResults)); + sb.append("\n"); + sb.append("\n"); + + return modelStats; + }).collect(Collectors.toList()); + + + sb.append("\n").append("Summary results: ").append("\n"); + + Comparator f1ScoreComparator = (o1, o2) -> { + Stats fieldStatsO1 = o1.getFieldStats(); + Stats fieldStatsO2 = o2.getFieldStats(); + + if (fieldStatsO1.getMacroAverageF1() > fieldStatsO2.getMacroAverageF1()) { + return 1; + } else if (fieldStatsO1.getMacroAverageF1() < fieldStatsO2.getMacroAverageF1()) { + return -1; + } else { + return 0; + } + }; + + Optional worstModel = evaluationResults.stream().min(f1ScoreComparator); + sb.append("Worst Model").append("\n"); + ModelStats worstModelStats = worstModel.orElseGet(() -> { + throw new GrobidException("Something wrong when computing evaluations " + + "- worst model metrics not found. "); + }); + sb.append(worstModelStats.toString()).append("\n"); + + sb.append("Best model:").append("\n"); + Optional bestModel = evaluationResults.stream().max(f1ScoreComparator); + ModelStats bestModelStats = bestModel.orElseGet(() -> { + throw new GrobidException("Something wrong when computing evaluations " + + "- best model metrics not found. "); + }); + sb.append(bestModelStats.toString()).append("\n").append("\n"); + + // Averages + sb.append("Average over " + numFolds + " folds: ").append("\n"); + + TreeMap averagesLabelStats = new TreeMap<>(); + int totalInstances = 0; + int correctInstances = 0; + for (ModelStats ms : evaluationResults) { + totalInstances += ms.getTotalInstances(); + correctInstances += ms.getCorrectInstance(); + for (Map.Entry entry : ms.getFieldStats().getLabelsResults().entrySet()) { + String key = entry.getKey(); + if (averagesLabelStats.containsKey(key)) { + averagesLabelStats.get(key).setAccuracy(averagesLabelStats.get(key).getAccuracy() + entry.getValue().getAccuracy()); + averagesLabelStats.get(key).setF1Score(averagesLabelStats.get(key).getF1Score() + entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(averagesLabelStats.get(key).getRecall() + entry.getValue().getRecall()); + averagesLabelStats.get(key).setPrecision(averagesLabelStats.get(key).getPrecision() + entry.getValue().getPrecision()); + averagesLabelStats.get(key).setSupport(averagesLabelStats.get(key).getSupport() + entry.getValue().getSupport()); + } else { + averagesLabelStats.put(key, new LabelResult(key)); + averagesLabelStats.get(key).setAccuracy(entry.getValue().getAccuracy()); + averagesLabelStats.get(key).setF1Score(entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(entry.getValue().getRecall()); + averagesLabelStats.get(key).setPrecision(entry.getValue().getPrecision()); + averagesLabelStats.get(key).setSupport(entry.getValue().getSupport()); + } + } + } + + sb.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + for (String label : averagesLabelStats.keySet()) { + LabelResult labelResult = averagesLabelStats.get(label); + + double avgAccuracy = labelResult.getAccuracy() / evaluationResults.size(); + averagesLabelStats.get(label).setAccuracy(avgAccuracy); + + double avgF1Score = labelResult.getF1Score() / evaluationResults.size(); + averagesLabelStats.get(label).setF1Score(avgF1Score); + + double avgPrecision = labelResult.getPrecision() / evaluationResults.size(); + averagesLabelStats.get(label).setPrecision(avgPrecision); + + double avgRecall = labelResult.getRecall() / evaluationResults.size(); + averagesLabelStats.get(label).setRecall(avgRecall); + + sb.append(labelResult.toString()); + } + + OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); + OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); + OptionalDouble averageAccuracy = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageAccuracy()).average(); + + double avgAccuracy = averageAccuracy.orElseGet(() -> { + throw new GrobidException("Missing average accuracy. Something went wrong. Please check. "); + }); + + double avgF1 = averageF1.orElseGet(() -> { + throw new GrobidException("Missing average F1. Something went wrong. Please check. "); + }); + + double avgPrecision = averagePrecision.orElseGet(() -> { + throw new GrobidException("Missing average precision. Something went wrong. Please check. "); + }); + + double avgRecall = averageRecall.orElseGet(() -> { + throw new GrobidException("Missing average recall. Something went wrong. Please check. "); + }); + + sb.append("\n"); + + sb.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(avgAccuracy * 100), + TextUtilities.formatTwoDecimals(avgPrecision * 100), + TextUtilities.formatTwoDecimals(avgRecall * 100), + TextUtilities.formatTwoDecimals(avgF1 * 100)) +// String.valueOf(supportSum)) + ); + + sb.append("\n===== Instance-level results =====\n\n"); + sb.append(String.format("%-27s %d\n", "Total expected instances:", totalInstances)); + sb.append(String.format("%-27s %d\n", "Correct instances:", correctInstances)); + sb.append(String.format("%-27s %s\n", + "Instance-level recall:", + TextUtilities.formatTwoDecimals((double) correctInstances / totalInstances * 100))); + + + return sb.toString(); + } + + /** + * Partition the corpus in n folds, dump them in n files and return the pairs of (trainingPath, evaluationPath) + */ + protected List> splitNFold(List trainingData, int numberFolds) { + int trainingSize = CollectionUtils.size(trainingData); + int foldSize = Math.floorDiv(trainingSize, numberFolds); + if (foldSize == 0) { + throw new IllegalArgumentException("There aren't enough training data for n-fold evaluation with fold of size " + numberFolds); + } + + return IntStream.range(0, numberFolds).mapToObj(foldIndex -> { + int foldStart = foldSize * foldIndex; + int foldEnd = foldStart + foldSize; + + if (foldIndex == numberFolds - 1) { + foldEnd = trainingSize; + } + + List foldEvaluation = trainingData.subList(foldStart, foldEnd); + List foldTraining0 = trainingData.subList(0, foldStart); + List foldTraining1 = trainingData.subList(foldEnd, trainingSize); + List foldTraining = new ArrayList<>(); + foldTraining.addAll(foldTraining0); + foldTraining.addAll(foldTraining1); + + //Dump Evaluation + String tempEvaluationDataPath = getTempEvaluationDataPath().getAbsolutePath(); + try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempEvaluationDataPath))) { + writer.write(String.join("\n\n", foldEvaluation)); + writer.write("\n"); + } catch (IOException e) { + throw new GrobidException("Error when dumping n-fold evaluation data into files. ", e); + } + + //Dump Training + String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath(); + try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) { + writer.write(String.join("\n\n", foldTraining)); + writer.write("\n"); + } catch (IOException e) { + throw new GrobidException("Error when dumping n-fold training data into files. ", e); + } + + return new ImmutablePair<>(tempTrainingDataPath, tempEvaluationDataPath); + }).collect(Collectors.toList()); + } + + /** + * Load the dataset in memory and shuffle it. + */ + protected List loadAndShuffle(Path dataPath) { + List trainingData = load(dataPath); + + Collections.shuffle(trainingData, new Random(839374947498L)); + + return trainingData; + } + + /** + * Read the Wapiti training files in list of String. + * Assuming that each empty line is a delimiter between instances. + * Each list element corresponds to one instance. + * Empty line are filtered out from the output. + */ + public List load(Path dataPath) { + List trainingData = new ArrayList<>(); + try (Stream stream = Files.lines(dataPath)) { + List instance = new ArrayList<>(); + ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); + while (iterator.hasNext()) { + String current = iterator.next(); + + if (StringUtils.isBlank(current)) { + if (CollectionUtils.isNotEmpty(instance)) { + trainingData.add(String.join("\n", instance)); + } + instance = new ArrayList<>(); + } else { + instance.add(current); + } + } + if (CollectionUtils.isNotEmpty(instance)) { + trainingData.add(String.join("\n", instance)); + } + + } catch (IOException e) { + throw new GrobidException("Error in n-fold, when loading training data. Failing. ", e); + } + + return trainingData; } protected final File getTempTrainingDataPath() { @@ -149,7 +492,7 @@ protected GenericTagger getTagger() { protected static File getFilePath2Resources() { File theFile = new File(GrobidProperties.get_GROBID_HOME_PATH().getAbsoluteFile() + File.separator + ".." + File.separator - + "grobid-trainer" + File.separator + "resources"); + + "grobid-trainer" + File.separator + "resources"); if (!theFile.exists()) { theFile = new File("resources"); } @@ -174,7 +517,7 @@ protected File getEvalCorpusPath() { public static File getEvalCorpusBasePath() { final String path2Evelutation = getFilePath2Resources().getAbsolutePath() + File.separator + "dataset" + File.separator + "patent" - + File.separator + "evaluation"; + + File.separator + "evaluation"; return new File(path2Evelutation); } @@ -195,28 +538,92 @@ public File getEvalDataPath() { return evalDataPath; } - public static void runEvaluation(final Trainer trainer) { + public static String runEvaluation(final Trainer trainer, boolean includeRawResults) { long start = System.currentTimeMillis(); + String report = ""; try { - String report = trainer.evaluate(); - System.out.println(report); + report = trainer.evaluate(includeRawResults); } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("Evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + report += "\n\nEvaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; + + return report; + } + + public static String runEvaluation(final Trainer trainer) { + return trainer.evaluate(false); } - public static void runSplitTrainingEvaluation(final Trainer trainer, Double split) { + public static String runSplitTrainingEvaluation(final Trainer trainer, Double split) { long start = System.currentTimeMillis(); + String report = ""; try { - String report = trainer.splitTrainEvaluate(split); - System.out.println(report); + report = trainer.splitTrainEvaluate(split); + + } catch (Exception e) { + throw new GrobidException("An exception occurred while evaluating Grobid.", e); + } + long end = System.currentTimeMillis(); + report += "\n\nSplit, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; + + return report; + } + + public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile) { + runNFoldEvaluation(trainer, numFolds, outputFile, false); + } + + public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile, boolean includeRawResults) { + + String report = runNFoldEvaluation(trainer, numFolds, includeRawResults); + + try (BufferedWriter writer = Files.newBufferedWriter(outputFile)) { + writer.write(report); + writer.write("\n"); + } catch (IOException e) { + throw new GrobidException("Error when dumping n-fold training data into files. ", e); + } + + } + + public static String runNFoldEvaluation(final Trainer trainer, int numFolds) { + return runNFoldEvaluation(trainer, numFolds, false); + } + + public static String runNFoldEvaluation(final Trainer trainer, int numFolds, boolean includeRawResults) { + long start = System.currentTimeMillis(); + String report = ""; + try { + report = trainer.nFoldEvaluate(numFolds, includeRawResults); + } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("Split, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + report += "\n\nN-Fold evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; + + return report; + } + + /** + * Dispatch the example to the training or test data, based on the split ration and the drawing of + * a random number + */ + public Writer dispatchExample(Writer writerTraining, Writer writerEvaluation, double splitRatio) { + Writer writer = null; + if ((writerTraining == null) && (writerEvaluation != null)) { + writer = writerEvaluation; + } else if ((writerTraining != null) && (writerEvaluation == null)) { + writer = writerTraining; + } else { + if (Math.random() <= splitRatio) + writer = writerTraining; + else + writer = writerEvaluation; + } + return writer; } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java index da3816df49..458a68d666 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java @@ -175,7 +175,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new AffiliationAddressTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java index b7105e3f0b..9a966ce236 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java @@ -13,6 +13,8 @@ * * User: zholudev * Date: 3/20/14 + * + * @deprecated use WapitiTrainer or DelftTrainer (requires http://github.com/kermitt2/delft) */ @Deprecated public class CRFPPGenericTrainer implements GenericTrainer { diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java index 05ae7333f3..c5b3d13d7d 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java @@ -431,7 +431,7 @@ public void addFeatures(List texts, public static void main(String[] args) { Trainer trainer = new ChemicalEntityTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java index d900aaac3f..c2cd493474 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java @@ -195,7 +195,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new CitationTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index 1d3385b21f..6241238e0f 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -24,8 +24,8 @@ public DateTrainer() { } /** - * Add the selected features to a date example set - * + * Add the selected features to a date example set + * * @param corpusDir * a path where corpus files are located * @param trainingOutputPath @@ -38,8 +38,8 @@ public int createCRFPPData(final File corpusDir, final File trainingOutputPath) } /** - * Add the selected features to a date example set - * + * Add the selected features to a date example set + * * @param corpusDir * a path where corpus files are located * @param trainingOutputPath @@ -47,13 +47,13 @@ public int createCRFPPData(final File corpusDir, final File trainingOutputPath) * @param evalOutputPath * path where to store the temporary evaluation data * @param splitRatio - * ratio to consider for separating training and evaluation data, e.g. 0.8 for 80% - * @return the total number of used corpus items + * ratio to consider for separating training and evaluation data, e.g. 0.8 for 80% + * @return the total number of used corpus items */ @Override - public int createCRFPPData(final File corpusDir, - final File trainingOutputPath, - final File evalOutputPath, + public int createCRFPPData(final File corpusDir, + final File trainingOutputPath, + final File evalOutputPath, double splitRatio) { int totalExamples = 0; try { @@ -69,7 +69,7 @@ public int createCRFPPData(final File corpusDir, public boolean accept(File dir, String name) { return name.endsWith(".xml"); } - }); + }); if (refFiles == null) { throw new IllegalStateException("Folder " + corpusDir.getAbsolutePath() @@ -85,7 +85,7 @@ public boolean accept(File dir, String name) { os2 = new FileOutputStream(trainingOutputPath); writer2 = new OutputStreamWriter(os2, "UTF8"); } - + // the file for writing the evaluation data OutputStream os3 = null; Writer writer3 = null; @@ -93,7 +93,7 @@ public boolean accept(File dir, String name) { os3 = new FileOutputStream(evalOutputPath); writer3 = new OutputStreamWriter(os3, "UTF8"); } - + // get a factory for SAX parser SAXParserFactory spf = SAXParserFactory.newInstance(); @@ -114,22 +114,22 @@ public boolean accept(File dir, String name) { // we can now add the features String headerDates = FeaturesVectorDate.addFeaturesDate(labeled); - + // format with features for sequence tagging... // given the split ratio we write either in the training file or the evaluation file String[] chunks = headerDates.split("\n \n"); - + for(int i=0; i= observed) - // return 0.0; - //return (double) (observed - (falsePositive + falseNegative) ) / (observed); - return (double) observed / (falsePositive + observed); - } - - public double getRecall() { - if (expected == 0.0) - return 0.0; - // if ((falsePositive + falseNegative) >= observed) - // return 0.0; - //return (double) (observed - (falsePositive + falseNegative) ) / (expected); - return (double) observed / (expected); - } - - public double getF1Score() { - double precision = getPrecision(); - double recall = getRecall(); - - if ( (precision == 0) && (recall == 0.0) ) - return 0.0; - - return (2.0 * precision * recall) / (precision + recall); - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder(); - builder - .append("falsePositive: ").append(falsePositive) - .append("; falseNegative: ").append(falseNegative) - .append("; observed: ").append(observed) - .append("; expected: ").append(expected); - return builder.toString(); - } -} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java index 159aaa93bb..1f95974ea9 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java @@ -162,7 +162,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new MonographTrainer()); - AbstractTrainer.runEvaluation(new MonographTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new MonographTrainer())); System.exit(0); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java index 1720aea4a7..e127208ba8 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java @@ -170,7 +170,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new NameCitationTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java index a5f2b5db11..7edcbdd720 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java @@ -229,7 +229,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new NameHeaderTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java index e65fda701d..39aff5cfa4 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java @@ -214,7 +214,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new ReferenceSegmenterTrainer()); - AbstractTrainer.runEvaluation(new ReferenceSegmenterTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new ReferenceSegmenterTrainer())); System.exit(0); } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java index ba9b2a57b8..7fdd04e536 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java @@ -348,7 +348,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new SegmentationTrainer()); - AbstractTrainer.runEvaluation(new SegmentationTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new SegmentationTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java index 4403972be1..36d07604d1 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java @@ -162,7 +162,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new ShorttextTrainer()); - AbstractTrainer.runEvaluation(new ShorttextTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new ShorttextTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java deleted file mode 100644 index e8b0dbee94..0000000000 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java +++ /dev/null @@ -1,99 +0,0 @@ -package org.grobid.trainer; - -import java.util.Set; -import java.util.TreeMap; - -import org.grobid.core.exceptions.*; - -public final class Stats { - private final TreeMap labelStats; - - public Stats() { - this.labelStats = new TreeMap<>(); - } - - public Set getLabels() { - return this.labelStats.keySet(); - } - - public void incrementFalsePositive(String label) { - this.incrementFalsePositive(label, 1); - } - - public void incrementFalsePositive(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementFalsePositive(count); - } - - public void incrementFalseNegative(String label) { - this.incrementFalseNegative(label, 1); - } - - public void incrementFalseNegative(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementFalseNegative(count); - } - - public void incrementObserved(String label) { - this.incrementObserved(label, 1); - } - - public void incrementObserved(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementObserved(count); - } - - public void incrementExpected(String label) { - this.incrementExpected(label, 1); - } - - public void incrementExpected(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementExpected(count); - } - - public LabelStat getLabelStat(String label) { - if (this.labelStats.containsKey(label)) { - return this.labelStats.get(label); - } - - LabelStat labelStat = LabelStat.create(); - this.labelStats.put(label, labelStat); - - return labelStat; - } - - public int size() { - return this.labelStats.size(); - } - - public double getPrecision(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getPrecision(); - } - - public double getRecall(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getRecall(); - } - - public double getF1Score(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getF1Score(); - } -} - diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java index 30e53fba28..9269906971 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java @@ -221,7 +221,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new TableTrainer()); - AbstractTrainer.runEvaluation(new TableTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new TableTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java index 2bff8058f2..c42535e030 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java @@ -2,6 +2,7 @@ import org.grobid.core.GrobidModel; import org.grobid.core.GrobidModels; +import org.grobid.core.engines.tagging.GenericTagger; import java.io.File; @@ -16,13 +17,17 @@ public interface Trainer { void train(); - /** - * - * @return a report - */ String evaluate(); + String evaluate(boolean includeRawResults); + + String evaluate(GenericTagger tagger, boolean includeRawResults); + String splitTrainEvaluate(Double split); + String nFoldEvaluate(int folds); + + String nFoldEvaluate(int folds, boolean includeRawResults); + GrobidModel getModel(); } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java index b31ffb5eed..f9dc1b0dc3 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java @@ -15,6 +15,8 @@ public static GenericTrainer getTrainer() { return new WapitiTrainer(); case DELFT: return new DeLFTTrainer(); + case DUMMY: + return new DummyTrainer(); default: throw new IllegalStateException("Unsupported Grobid sequence labelling engine: " + GrobidProperties.getGrobidCRFEngine()); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index 7a765bc162..67b5524973 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -1,139 +1,172 @@ package org.grobid.trainer; +import org.apache.commons.lang3.StringUtils; import org.grobid.core.utilities.GrobidProperties; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; /** * Training application for training a target model. - * + * * @author Patrice Lopez */ public class TrainerRunner { - private enum RunType { - TRAIN, EVAL, SPLIT; - - public static RunType getRunType(int i) { - for (RunType t : values()) { - if (t.ordinal() == i) { - return t; - } - } - - throw new IllegalStateException("Unsupported RunType with ordinal " + i); - } - } - - /** - * Initialize the batch. - */ - protected static void initProcess(final String path2GbdHome, final String path2GbdProperties) { - GrobidProperties.getInstance(); - } - - /** - * Command line execution. - * - * @param args - * Command line arguments. - */ - public static void main(String[] args) { - if (args.length < 4) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - RunType mode = RunType.getRunType(Integer.parseInt(args[0])); - if ( (mode == RunType.SPLIT) && (args.length < 6) ) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - String path2GbdHome = null; - Double split = 0.0; - for (int i = 0; i < args.length; i++) { - if (args[i].equals("-gH")) { - if (i+1 == args.length) { - throw new IllegalStateException("Missing path to Grobid home. "); - } - path2GbdHome = args[i + 1]; - } - else if (args[i].equals("-s")) { - if (i+1 == args.length) { - throw new IllegalStateException("Missing split ratio value. "); - } - String splitRatio = args[i + 1]; - try { - split = Double.parseDouble(args[i + 1]); - } - catch(Exception e) { - throw new IllegalStateException("Invalid split value: " + args[i + 1]); - } - - } - } - - if (path2GbdHome == null) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - final String path2GbdProperties = path2GbdHome + File.separator + "config" + File.separator + "grobid.properties"; - - System.out.println("path2GbdHome=" + path2GbdHome + " path2GbdProperties=" + path2GbdProperties); - initProcess(path2GbdHome, path2GbdProperties); - - String model = args[1]; - - AbstractTrainer trainer; - - if (model.equals("affiliation") || model.equals("affiliation-address")) { - trainer = new AffiliationAddressTrainer(); - } else if (model.equals("chemical")) { - trainer = new ChemicalEntityTrainer(); - } else if (model.equals("date")) { - trainer = new DateTrainer(); - } else if (model.equals("citation")) { - trainer = new CitationTrainer(); - } else if (model.equals("monograph")) { - trainer = new MonographTrainer(); - } else if (model.equals("fulltext")) { - trainer = new FulltextTrainer(); - } else if (model.equals("header")) { - trainer = new HeaderTrainer(); - } else if (model.equals("name-citation")) { - trainer = new NameCitationTrainer(); - } else if (model.equals("name-header")) { - trainer = new NameHeaderTrainer(); - } else if (model.equals("patent")) { - trainer = new PatentParserTrainer(); - } else if (model.equals("segmentation")) { - trainer = new SegmentationTrainer(); - } else if (model.equals("reference-segmenter")) { + private static final List models = Arrays.asList("affiliation", "chemical", "date", "citation", "ebook", "fulltext", "header", "name-citation", "name-header", "patent"); + private static final List options = Arrays.asList("0 - train", "1 - evaluate", "2 - split, train and evaluate", "3 - n-fold evaluation"); + + private enum RunType { + TRAIN, EVAL, SPLIT, EVAL_N_FOLD; + + public static RunType getRunType(int i) { + for (RunType t : values()) { + if (t.ordinal() == i) { + return t; + } + } + + throw new IllegalStateException("Unsupported RunType with ordinal " + i); + } + } + + protected static void initProcess(final String path2GbdHome, final String path2GbdProperties) { + GrobidProperties.getInstance(); + } + + public static void main(String[] args) { + if (args.length < 4) { + throw new IllegalStateException( + "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + RunType mode = RunType.getRunType(Integer.parseInt(args[0])); + if ((mode == RunType.SPLIT || mode == RunType.EVAL_N_FOLD) && (args.length < 6)) { + throw new IllegalStateException( + "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + String path2GbdHome = null; + double split = 0.0; + int numFolds = 0; + String outputFilePath = null; + for (int i = 0; i < args.length; i++) { + if (args[i].equals("-gH")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing path to Grobid home. "); + } + path2GbdHome = args[i + 1]; + } else if (args[i].equals("-s")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing split ratio value. "); + } + try { + split = Double.parseDouble(args[i + 1]); + } catch (Exception e) { + throw new IllegalStateException("Invalid split value: " + args[i + 1]); + } + + } else if (args[i].equals("-n")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing number of folds value. "); + } + try { + numFolds = Integer.parseInt(args[i + 1]); + } catch (Exception e) { + throw new IllegalStateException("Invalid number of folds value: " + args[i + 1]); + } + + } else if (args[i].equals("-o")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing output file. "); + } + outputFilePath = args[i + 1]; + + } + } + + if (path2GbdHome == null) { + throw new IllegalStateException( + "Grobid-home path not found.\n Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + final String path2GbdProperties = path2GbdHome + File.separator + "config" + File.separator + "grobid.properties"; + + System.out.println("path2GbdHome=" + path2GbdHome + " path2GbdProperties=" + path2GbdProperties); + initProcess(path2GbdHome, path2GbdProperties); + + String model = args[1]; + + AbstractTrainer trainer; + + if (model.equals("affiliation") || model.equals("affiliation-address")) { + trainer = new AffiliationAddressTrainer(); + } else if (model.equals("chemical")) { + trainer = new ChemicalEntityTrainer(); + } else if (model.equals("date")) { + trainer = new DateTrainer(); + } else if (model.equals("citation")) { + trainer = new CitationTrainer(); + } else if (model.equals("monograph")) { + trainer = new MonographTrainer(); + } else if (model.equals("fulltext")) { + trainer = new FulltextTrainer(); + } else if (model.equals("header")) { + trainer = new HeaderTrainer(); + } else if (model.equals("name-citation")) { + trainer = new NameCitationTrainer(); + } else if (model.equals("name-header")) { + trainer = new NameHeaderTrainer(); + } else if (model.equals("patent")) { + trainer = new PatentParserTrainer(); + } else if (model.equals("segmentation")) { + trainer = new SegmentationTrainer(); + } else if (model.equals("reference-segmenter")) { trainer = new ReferenceSegmenterTrainer(); } else if (model.equals("figure")) { - trainer = new FigureTrainer(); - } else if (model.equals("table")) { - trainer = new TableTrainer(); - } else { - throw new IllegalStateException("The model " + model + " is unknown."); - } - - switch (mode) { - case TRAIN: - AbstractTrainer.runTraining(trainer); - break; - case EVAL: - AbstractTrainer.runEvaluation(trainer); - break; - case SPLIT: - AbstractTrainer.runSplitTrainingEvaluation(trainer, split); - break; - default: - throw new IllegalStateException("Invalid RunType: " + mode.name()); - } - System.exit(0); - } + trainer = new FigureTrainer(); + } else if (model.equals("table")) { + trainer = new TableTrainer(); + } else { + throw new IllegalStateException("The model " + model + " is unknown."); + } + + switch (mode) { + case TRAIN: + AbstractTrainer.runTraining(trainer); + break; + case EVAL: + System.out.println(AbstractTrainer.runEvaluation(trainer)); + break; + case SPLIT: + System.out.println(AbstractTrainer.runSplitTrainingEvaluation(trainer, split)); + break; + case EVAL_N_FOLD: + if(numFolds == 1) { + throw new IllegalArgumentException("N should be > 1"); + } else { + if(numFolds == 0) { + numFolds = 10; + } + } + if (StringUtils.isNotEmpty(outputFilePath)) { + Path outputPath = Paths.get(outputFilePath); + if (Files.exists(outputPath)) { + System.err.println("Output file exists. "); + } + } else { + String results = AbstractTrainer.runNFoldEvaluation(trainer, numFolds); + System.out.println(results); + } + break; + default: + throw new IllegalStateException("Invalid RunType: " + mode.name()); + } + System.exit(0); + } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java index c6ef5620b1..d4acb63d10 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java @@ -1,17 +1,11 @@ package org.grobid.trainer.evaluation; import org.grobid.core.engines.config.GrobidAnalysisConfig; -import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.exceptions.*; import org.grobid.core.engines.Engine; -import org.grobid.core.data.BiblioItem; -import org.grobid.core.data.BibDataSet; import org.grobid.core.factory.GrobidFactory; import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.utilities.UnicodeUtil; -import org.grobid.trainer.Stats; -import org.grobid.trainer.sax.NLMHeaderSaxHandler; -import org.grobid.trainer.sax.FieldExtractSaxHandler; import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.utilities.NamespaceContextMap; import org.grobid.trainer.evaluation.utilities.FieldSpecification; @@ -26,7 +20,7 @@ import javax.xml.xpath.XPathFactory; import javax.xml.parsers.*; import org.xml.sax.*; -import org.xml.sax.helpers.*; + import javax.xml.xpath.XPathConstants; import com.rockymadden.stringmetric.similarity.RatcliffObershelpMetric; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java index 7c19528f62..1fc72b3e90 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java @@ -1,48 +1,38 @@ package org.grobid.trainer.evaluation; -import org.grobid.core.engines.config.GrobidAnalysisConfig; -import org.grobid.core.engines.tagging.GenericTagger; -import org.grobid.core.exceptions.*; -import org.grobid.core.engines.Engine; -import org.grobid.core.data.BiblioItem; +import com.fasterxml.jackson.core.io.JsonStringEncoder; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.io.FileUtils; import org.grobid.core.data.BibDataSet; +import org.grobid.core.data.BiblioItem; +import org.grobid.core.engines.Engine; +import org.grobid.core.exceptions.GrobidResourceException; import org.grobid.core.factory.GrobidFactory; +import org.grobid.core.utilities.Consolidation.GrobidConsolidationService; import org.grobid.core.utilities.GrobidProperties; -import org.grobid.core.utilities.UnicodeUtil; -import org.grobid.trainer.Stats; -import org.grobid.trainer.sax.NLMHeaderSaxHandler; -import org.grobid.trainer.sax.FieldExtractSaxHandler; import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.utilities.NamespaceContextMap; -import org.grobid.trainer.evaluation.utilities.FieldSpecification; -import org.grobid.core.utilities.Consolidation.GrobidConsolidationService; - -import java.io.*; -import java.util.*; -import java.text.Normalizer; -import java.util.regex.Pattern; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import org.apache.commons.io.FileUtils; - -import org.w3c.dom.*; +import org.w3c.dom.Document; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; +import org.xml.sax.EntityResolver; +import org.xml.sax.InputSource; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.SAXParserFactory; import javax.xml.xpath.XPath; -import javax.xml.xpath.XPathFactory; -import javax.xml.parsers.*; -import org.xml.sax.*; -import org.xml.sax.helpers.*; import javax.xml.xpath.XPathConstants; - -import com.rockymadden.stringmetric.similarity.RatcliffObershelpMetric; -import scala.Option; - -import com.fasterxml.jackson.core.io.*; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.core.*; -import com.fasterxml.jackson.databind.*; -import com.fasterxml.jackson.databind.node.*; +import javax.xml.xpath.XPathFactory; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FilenameFilter; +import java.text.Normalizer; +import java.util.*; +import java.util.regex.Pattern; /** * Evaluation of the DOI matching for the extracted bibliographical references, diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 725246e4a3..fdb91895d6 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -1,373 +1,343 @@ package org.grobid.trainer.evaluation; -import com.google.common.base.Function; import org.chasen.crfpp.Tagger; import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.exceptions.GrobidException; -import org.grobid.core.utilities.TextUtilities; -import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.utilities.OffsetPosition; import org.grobid.core.utilities.Pair; -import org.grobid.core.engines.tagging.GrobidCRFEngine; import java.io.BufferedReader; import java.io.FileInputStream; -import java.io.File; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; - -import org.grobid.trainer.LabelStat; -import org.grobid.trainer.Stats; - -import org.apache.commons.io.FileUtils; +import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.grobid.core.engines.tagging.GenericTaggerUtils.getPlainLabel; + /** - * Generic evaluation of a single-CRF model processing given an expected result. - * + * Generic evaluation of a single-CRF model processing given an expected result. + * * @author Patrice Lopez */ public class EvaluationUtilities { - protected static final Logger logger = LoggerFactory.getLogger(EvaluationUtilities.class); - - /** - * Method for running a CRF tagger for evaluation purpose (i.e. with - * expected and actual labels). - * - * @param ress - * list - * @param tagger - * a tagger - * @return a report - */ - public static String taggerRun(List ress, Tagger tagger) { - // clear internal context - tagger.clear(); - StringBuilder res = new StringBuilder(); - - // we have to re-inject the pre-tags because they are removed by the JNI - // parse method - ArrayList pretags = new ArrayList(); - // add context - for (String piece : ress) { - if (piece.trim().length() == 0) { - // parse and change internal stated as 'parsed' - if (!tagger.parse()) { - // throw an exception - throw new RuntimeException("CRF++ parsing failed."); - } - - for (int i = 0; i < tagger.size(); i++) { - for (int j = 0; j < tagger.xsize(); j++) { - res.append(tagger.x(i, j)).append("\t"); - } - res.append(pretags.get(i)).append("\t"); - res.append(tagger.y2(i)); - res.append("\n"); - } - res.append(" \n"); - // clear internal context - tagger.clear(); - pretags = new ArrayList(); - } else { - tagger.add(piece); - tagger.add("\n"); - // get last tag - StringTokenizer tokenizer = new StringTokenizer(piece, " \t"); - while (tokenizer.hasMoreTokens()) { - String toke = tokenizer.nextToken(); - if (!tokenizer.hasMoreTokens()) { - pretags.add(toke); - } - } - } - } - - // parse and change internal stated as 'parsed' - if (!tagger.parse()) { - // throw an exception - throw new RuntimeException("CRF++ parsing failed."); - } - - for (int i = 0; i < tagger.size(); i++) { - for (int j = 0; j < tagger.xsize(); j++) { - res.append(tagger.x(i, j)).append("\t"); - } - res.append(pretags.get(i)).append("\t"); - res.append(tagger.y2(i)); - res.append(System.lineSeparator()); - } - res.append(System.lineSeparator()); - - return res.toString(); - } - - public static String evaluateStandard(String path, final GenericTagger tagger) { - return evaluateStandard(path, new Function, String>() { - @Override - public String apply(List strings) { - return tagger.label(strings); + protected static final Logger logger = LoggerFactory.getLogger(EvaluationUtilities.class); + + /** + * Method for running a CRF tagger for evaluation purpose (i.e. with + * expected and actual labels). + * + * @param ress list + * @param tagger a tagger + * @return a report + */ + public static String taggerRun(List ress, Tagger tagger) { + // clear internal context + tagger.clear(); + StringBuilder res = new StringBuilder(); + + // we have to re-inject the pre-tags because they are removed by the JNI + // parse method + ArrayList pretags = new ArrayList<>(); + // add context + for (String piece : ress) { + if (piece.trim().length() == 0) { + // parse and change internal stated as 'parsed' + if (!tagger.parse()) { + // throw an exception + throw new RuntimeException("CRF++ parsing failed."); + } + + for (int i = 0; i < tagger.size(); i++) { + for (int j = 0; j < tagger.xsize(); j++) { + res.append(tagger.x(i, j)).append("\t"); + } + res.append(pretags.get(i)).append("\t"); + res.append(tagger.y2(i)); + res.append("\n"); + } + res.append(" \n"); + // clear internal context + tagger.clear(); + pretags = new ArrayList<>(); + } else { + tagger.add(piece); + tagger.add("\n"); + // get last tag + StringTokenizer tokenizer = new StringTokenizer(piece, " \t"); + while (tokenizer.hasMoreTokens()) { + String toke = tokenizer.nextToken(); + if (!tokenizer.hasMoreTokens()) { + pretags.add(toke); + } + } + } + } + + // parse and change internal stated as 'parsed' + if (!tagger.parse()) { + // throw an exception + throw new RuntimeException("CRF++ parsing failed."); + } + + for (int i = 0; i < tagger.size(); i++) { + for (int j = 0; j < tagger.xsize(); j++) { + res.append(tagger.x(i, j)).append("\t"); } - }); + res.append(pretags.get(i)).append("\t"); + res.append(tagger.y2(i)); + res.append(System.lineSeparator()); + } + res.append(System.lineSeparator()); + + return res.toString(); } - public static String evaluateStandard(String path, Function, String> taggerFunction) { - String theResult = null; + public static ModelStats evaluateStandard(String path, final GenericTagger tagger) { + return evaluateStandard(path, tagger::label); + } - try { - final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8")); + public static ModelStats evaluateStandard(String path, Function, String> taggerFunction) { + String theResult = null; - String line = null; - List citationBlocks = new ArrayList(); - while ((line = bufReader.readLine()) != null) { - citationBlocks.add(line); - } - long time = System.currentTimeMillis(); - theResult = taggerFunction.apply(citationBlocks); - bufReader.close(); + try { + final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8)); + + String line = null; + List instance = new ArrayList<>(); + while ((line = bufReader.readLine()) != null) { + instance.add(line); + } + long time = System.currentTimeMillis(); + theResult = taggerFunction.apply(instance); + bufReader.close(); System.out.println("Labeling took: " + (System.currentTimeMillis() - time) + " ms"); } catch (Exception e) { - throw new GrobidException("An exception occurred while evaluating Grobid.", e); - } - - return reportMetrics(theResult); - } - - public static String reportMetrics(String theResult) { - StringBuilder report = new StringBuilder(); - - Stats wordStats = tokenLevelStats(theResult); - - // report token-level results - report.append("\n===== Token-level results =====\n\n"); - report.append(computeMetrics(wordStats)); - - Stats fieldStats = fieldLevelStats(theResult); - - report.append("\n===== Field-level results =====\n"); - report.append(computeMetrics(fieldStats)); - - report.append("\n===== Instance-level results =====\n\n"); - // instance-level: instances are separated by a new line in the result file - // third pass - theResult = theResult.replace("\n\n", "\n \n"); - StringTokenizer stt = new StringTokenizer(theResult, "\n"); - boolean allGood = true; - int correctInstance = 0; - int totalInstance = 0; - String line = null; - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - if ((line.trim().length() == 0) || (!stt.hasMoreTokens())) { - // instance done - totalInstance++; - if (allGood) { - correctInstance++; - } - // we reinit for a new instance - allGood = true; - } else { - StringTokenizer st = new StringTokenizer(line, "\t "); - String obtainedLabel = null; - String expectedLabel = null; - while (st.hasMoreTokens()) { - obtainedLabel = getPlainLabel(st.nextToken()); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if (!obtainedLabel.equals(expectedLabel)) { - // one error is enough to have the whole instance false, damn! - allGood = false; - } - } - } - - report.append(String.format("%-27s %d\n", "Total expected instances:", totalInstance)); - report.append(String.format("%-27s %d\n", "Correct instances:", correctInstance)); - double accuracy = (double) correctInstance / (totalInstance); - report.append(String.format("%-27s %s\n", - "Instance-level recall:", - TextUtilities.formatTwoDecimals(accuracy * 100))); - - return report.toString(); - } - - public static Stats tokenLevelStats(String theResult) { - Stats wordStats = new Stats(); - String line = null; - StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - - if (line.trim().length() == 0) { - continue; - } - // the two last tokens, separated by a tabulation, gives the - // expected label and, last, the resulting label -> for Wapiti - StringTokenizer st = new StringTokenizer(line, "\t "); - String obtainedLabel = null; - String expectedLabel = null; - - while (st.hasMoreTokens()) { - obtainedLabel = getPlainLabel(st.nextToken()); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if ((expectedLabel == null) || (obtainedLabel == null)) { - continue; - } - - processCounters(wordStats, obtainedLabel, expectedLabel); + throw new GrobidException("An exception occurred while evaluating Grobid.", e); + } + + return computeStats(theResult); + } + + public static ModelStats computeStats(String theResult) { + ModelStats modelStats = new ModelStats(); + modelStats.setRawResults(theResult); + // report token-level results +// Stats wordStats = tokenLevelStats(theResult); +// modelStats.setTokenStats(wordStats); + + // report field-level results + Stats fieldStats = fieldLevelStats(theResult); + modelStats.setFieldStats(fieldStats); + + // instance-level: instances are separated by a new line in the result file + // third pass + theResult = theResult.replace("\n\n", "\n \n"); + StringTokenizer stt = new StringTokenizer(theResult, "\n"); + boolean allGood = true; + int correctInstance = 0; + int totalInstance = 0; + String line = null; + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + if ((line.trim().length() == 0) || (!stt.hasMoreTokens())) { + // instance done + totalInstance++; + if (allGood) { + correctInstance++; + } + // we reinit for a new instance + allGood = true; + } else { + StringTokenizer st = new StringTokenizer(line, "\t "); + String obtainedLabel = null; + String expectedLabel = null; + while (st.hasMoreTokens()) { + obtainedLabel = getPlainLabel(st.nextToken()); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if (!obtainedLabel.equals(expectedLabel)) { + // one error is enough to have the whole instance false, damn! + allGood = false; + } + } + } + + modelStats.setTotalInstances(totalInstance); + modelStats.setCorrectInstance(correctInstance); + + return modelStats; + } + + public static Stats tokenLevelStats(String theResult) { + Stats wordStats = new Stats(); + String line = null; + StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + + if (line.trim().length() == 0) { + continue; + } + // the two last tokens, separated by a tabulation, gives the + // expected label and, last, the resulting label -> for Wapiti + StringTokenizer st = new StringTokenizer(line, "\t "); + String obtainedLabel = null; + String expectedLabel = null; + + while (st.hasMoreTokens()) { + obtainedLabel = getPlainLabel(st.nextToken()); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if ((expectedLabel == null) || (obtainedLabel == null)) { + continue; + } + + processCounters(wordStats, obtainedLabel, expectedLabel); /*if (!obtainedLabel.equals(expectedLabel)) { - logger.warn("Disagreement / expected: " + expectedLabel + " / obtained: " + obtainedLabel); + logger.warn("Disagreement / expected: " + expectedLabel + " / obtained: " + obtainedLabel); }*/ - } - return wordStats; - } - - public static Stats fieldLevelStats(String theResult) { - Stats fieldStats = new Stats(); - - // field: a field is simply a sequence of token with the same label - - // we build first the list of fields in expected and obtained result - // with offset positions - List> expectedFields = new ArrayList>(); - List> obtainedFields = new ArrayList>(); - StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); - String line = null; - String previousExpectedLabel = null; - String previousObtainedLabel = null; - int pos = 0; // current token index - OffsetPosition currentObtainedPosition = new OffsetPosition(); - currentObtainedPosition.start = 0; - OffsetPosition currentExpectedPosition = new OffsetPosition(); - currentExpectedPosition.start = 0; - String obtainedLabel = null; - String expectedLabel = null; - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - obtainedLabel = null; - expectedLabel = null; - StringTokenizer st = new StringTokenizer(line, "\t "); - while (st.hasMoreTokens()) { - obtainedLabel = st.nextToken(); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if ( (obtainedLabel == null) || (expectedLabel == null) ) - continue; - - if ((previousObtainedLabel != null) && - (!obtainedLabel.equals(getPlainLabel(previousObtainedLabel)))) { - // new obtained field - currentObtainedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousObtainedLabel), - currentObtainedPosition); - currentObtainedPosition = new OffsetPosition(); - currentObtainedPosition.start = pos; - obtainedFields.add(theField); - } - - if ((previousExpectedLabel != null) && - (!expectedLabel.equals(getPlainLabel(previousExpectedLabel)))) { - // new expected field - currentExpectedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousExpectedLabel), - currentExpectedPosition); - currentExpectedPosition = new OffsetPosition(); - currentExpectedPosition.start = pos; - expectedFields.add(theField); - } - - previousExpectedLabel = expectedLabel; - previousObtainedLabel = obtainedLabel; - pos++; - } - // last fields of the sequence - if ((previousObtainedLabel != null)) { - currentObtainedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousObtainedLabel), - currentObtainedPosition); - obtainedFields.add(theField); - } - - if ((previousExpectedLabel != null)) { - currentExpectedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousExpectedLabel), - currentExpectedPosition); - expectedFields.add(theField); - } - - // we then simply compared the positions and labels of the two fields and update - // statistics - int obtainedFieldIndex = 0; - List> matchedObtainedFields = new ArrayList>(); - for(Pair expectedField : expectedFields) { - expectedLabel = expectedField.getA(); - int expectedStart = expectedField.getB().start; - int expectedEnd = expectedField.getB().end; - - LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(expectedLabel)); - labelStat.incrementExpected(); - - // try to find a match in the obtained fields - boolean found = false; - for(int i=obtainedFieldIndex; i obtainedField :obtainedFields) { - if (!matchedObtainedFields.contains(obtainedField)) { - obtainedLabel = obtainedField.getA(); - LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(obtainedLabel)); - labelStat.incrementFalsePositive(); - } - } - - return fieldStats; - } - - - private static String getPlainLabel(String label) { - if (label == null) - return null; - if (label.startsWith("I-") || label.startsWith("E-") || label.startsWith("B-")) { - return label.substring(2, label.length()); - } else - return label; + } + return wordStats; + } + + public static Stats fieldLevelStats(String theResult) { + Stats fieldStats = new Stats(); + + // field: a field is simply a sequence of token with the same label + + // we build first the list of fields in expected and obtained result + // with offset positions + List> expectedFields = new ArrayList<>(); + List> obtainedFields = new ArrayList<>(); + StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); + String line = null; + String previousExpectedLabel = null; + String previousObtainedLabel = null; + int pos = 0; // current token index + OffsetPosition currentObtainedPosition = new OffsetPosition(); + currentObtainedPosition.start = 0; + OffsetPosition currentExpectedPosition = new OffsetPosition(); + currentExpectedPosition.start = 0; + String obtainedLabel = null; + String expectedLabel = null; + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + obtainedLabel = null; + expectedLabel = null; + StringTokenizer st = new StringTokenizer(line, "\t "); + while (st.hasMoreTokens()) { + obtainedLabel = st.nextToken(); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if ((obtainedLabel == null) || (expectedLabel == null)) + continue; + + if ((previousObtainedLabel != null) && + (!obtainedLabel.equals(getPlainLabel(previousObtainedLabel)))) { + // new obtained field + currentObtainedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), + currentObtainedPosition); + currentObtainedPosition = new OffsetPosition(); + currentObtainedPosition.start = pos; + obtainedFields.add(theField); + } + + if ((previousExpectedLabel != null) && + (!expectedLabel.equals(getPlainLabel(previousExpectedLabel)))) { + // new expected field + currentExpectedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), + currentExpectedPosition); + currentExpectedPosition = new OffsetPosition(); + currentExpectedPosition.start = pos; + expectedFields.add(theField); + } + + previousExpectedLabel = expectedLabel; + previousObtainedLabel = obtainedLabel; + pos++; + } + // last fields of the sequence + if ((previousObtainedLabel != null)) { + currentObtainedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), + currentObtainedPosition); + obtainedFields.add(theField); + } + + if ((previousExpectedLabel != null)) { + currentExpectedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), + currentExpectedPosition); + expectedFields.add(theField); + } + + // we then simply compared the positions and labels of the two fields and update + // statistics + int obtainedFieldIndex = 0; + List> matchedObtainedFields = new ArrayList>(); + for (Pair expectedField : expectedFields) { + expectedLabel = expectedField.getA(); + int expectedStart = expectedField.getB().start; + int expectedEnd = expectedField.getB().end; + + LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(expectedLabel)); + labelStat.incrementExpected(); + + // try to find a match in the obtained fields + boolean found = false; + for (int i = obtainedFieldIndex; i < obtainedFields.size(); i++) { + obtainedLabel = obtainedFields.get(i).getA(); + if (!expectedLabel.equals(obtainedLabel)) + continue; + if ((expectedStart == obtainedFields.get(i).getB().start) && + (expectedEnd == obtainedFields.get(i).getB().end)) { + // we have a match + labelStat.incrementObserved(); // TP + found = true; + obtainedFieldIndex = i; + matchedObtainedFields.add(obtainedFields.get(i)); + break; + } + // if we went too far, we can stop the pain + if (expectedEnd < obtainedFields.get(i).getB().start) { + break; + } + } + if (!found) { + labelStat.incrementFalseNegative(); + } + } + + // all the obtained fields without match in the expected fields are false positive + for (Pair obtainedField : obtainedFields) { + if (!matchedObtainedFields.contains(obtainedField)) { + obtainedLabel = obtainedField.getA(); + LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(obtainedLabel)); + labelStat.incrementFalsePositive(); + } + } + + return fieldStats; } + private static void processCounters(Stats stats, String obtained, String expected) { LabelStat expectedStat = stats.getLabelStat(expected); LabelStat obtainedStat = stats.getLabelStat(obtained); @@ -382,153 +352,8 @@ private static void processCounters(Stats stats, String obtained, String expecte } } - public static String computeMetrics(Stats stats) { - StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", - "label", - "accuracy", - "precision", - "recall", - "f1")); - - int cumulated_tp = 0; - int cumulated_fp = 0; - int cumulated_tn = 0; - int cumulated_fn = 0; - double cumulated_f0 = 0.0; - double cumulated_accuracy = 0.0; - double cumulated_precision = 0.0; - double cumulated_recall = 0.0; - int cumulated_all = 0; - int totalValidFields = 0; - - int totalFields = 0; - for (String label : stats.getLabels()) { - LabelStat labelStat = stats.getLabelStat(label); - totalFields += labelStat.getObserved(); - totalFields += labelStat.getFalseNegative(); - totalFields += labelStat.getFalsePositive(); - } - - for (String label : stats.getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - - LabelStat labelStat = stats.getLabelStat(label); - int tp = labelStat.getObserved(); // true positives - int fp = labelStat.getFalsePositive(); // false positives - int fn = labelStat.getFalseNegative(); // false negative - int tn = totalFields - tp - (fp + fn); // true negatives - int all = labelStat.getExpected(); // all expected - - if (all != 0) { - totalValidFields++; - } - - double accuracy = (double) (tp + tn) / (tp + fp + tn + fn); - if (accuracy < 0.0) - accuracy = 0.0; - - double precision; - if ((tp + fp) == 0) { - precision = 0.0; - } else { - precision = (double) (tp) / (tp + fp); - } - - double recall; - if ((tp == 0) || (all == 0)) { - recall = 0.0; - } else { - recall = (double) (tp) / all; - } - - double f0; - if (precision + recall == 0) { - f0 = 0.0; - } else { - f0 = (2 * precision * recall) / (precision + recall); - } - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", - label, - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - cumulated_tp += tp; - cumulated_fp += fp; - cumulated_tn += tn; - cumulated_fn += fn; - if (all != 0) { - cumulated_all += all; - cumulated_f0 += f0; - cumulated_accuracy += accuracy; - cumulated_precision += precision; - cumulated_recall += recall; - } - } - - report.append("\n"); - - // micro average over measures - double accuracy = 0.0; - if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) - accuracy = (double) (cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); - accuracy = Math.min(1.0, accuracy); - - double precision = 0.0; - if (cumulated_tp + cumulated_fp != 0) - precision = (double) cumulated_tp / (cumulated_tp + cumulated_fp); - precision = Math.min(1.0, precision); - - //recall = ((double) cumulated_tp) / (cumulated_tp + cumulated_fn); - double recall = 0.0; - if (cumulated_all != 0.0) - recall = ((double) cumulated_tp) / (cumulated_all); - recall = Math.min(1.0, recall); - - double f0 = 0.0; - if (precision + recall != 0.0) - f0 = (2 * precision * recall) / (precision + recall); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", - "all fields", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - // macro average over measures - if (totalValidFields == 0) - accuracy = 0.0; - else - accuracy = Math.min(1.0, cumulated_accuracy / (totalValidFields)); - - if (totalValidFields == 0) - precision = 0.0; - else - precision = Math.min(1.0, cumulated_precision / totalValidFields); - - if (totalValidFields == 0) - recall = 0.0; - else - recall = Math.min(1.0, cumulated_recall / totalValidFields); - - if (totalValidFields == 0) - f0 = 0.0; - else - f0 = Math.min(1.0, cumulated_f0 / totalValidFields); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", - "", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - return report.toString(); - } + public static String computeMetrics(Stats stats) { + return stats.getOldReport(); + } + } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java new file mode 100644 index 0000000000..7a8ce608a8 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java @@ -0,0 +1,73 @@ +package org.grobid.trainer.evaluation; + +import org.grobid.core.utilities.TextUtilities; + +public class LabelResult { + + private final String label; + private double accuracy; + private double precision; + private double recall; + private double f1Score; + private long support; + + public LabelResult(String label) { + this.label = label; + } + + public void setAccuracy(double accuracy) { + this.accuracy = accuracy; + } + + public double getAccuracy() { + return accuracy; + } + + public String getLabel() { + return label; + } + + public void setPrecision(double precision) { + this.precision = precision; + } + + public double getPrecision() { + return precision; + } + + public void setRecall(double recall) { + this.recall = recall; + } + + public double getRecall() { + return recall; + } + + public void setF1Score(double f1Score) { + this.f1Score = f1Score; + } + + public double getF1Score() { + return f1Score; + } + + public void setSupport(long support) { + this.support = support; + + } + + public String toString() { + return String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(getAccuracy() * 100), + TextUtilities.formatTwoDecimals(getPrecision() * 100), + TextUtilities.formatTwoDecimals(getRecall() * 100), + TextUtilities.formatTwoDecimals(getF1Score() * 100), + String.valueOf(getSupport()) + ); + } + + public long getSupport() { + return support; + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java new file mode 100644 index 0000000000..4a47e47ec1 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java @@ -0,0 +1,153 @@ +package org.grobid.trainer.evaluation; + +/** Model the results for each label **/ +public final class LabelStat { + private int falsePositive = 0; + private int falseNegative = 0; + private int observed = 0; // this is true positives + private int expected = 0; // total expected number of items with this label + + private double accuracy = 0.0; + private int trueNegative; + private boolean hasChanged = false; + + public void incrementFalseNegative() { + this.incrementFalseNegative(1); + hasChanged = true; + } + + public void incrementFalsePositive() { + this.incrementFalsePositive(1); + hasChanged = true; + } + + public void incrementObserved() { + this.incrementObserved(1); + hasChanged = true; + } + + public void incrementExpected() { + this.incrementExpected(1); + hasChanged = true; + } + + public void incrementFalseNegative(int count) { + this.falseNegative += count; + hasChanged = true; + } + + public void incrementFalsePositive(int count) { + this.falsePositive += count; + hasChanged = true; + } + + public void incrementObserved(int count) { + this.observed += count; + hasChanged = true; + } + + public void incrementExpected(int count) { + this.expected += count; + hasChanged = true; + } + + public int getExpected() { + return this.expected; + } + + public int getFalseNegative() { + return this.falseNegative; + } + + public int getFalsePositive() { + return this.falsePositive; + } + + public int getObserved() { + return this.observed; + } + + public int getAll() { + return observed + falseNegative + falsePositive; + } + + public void setFalsePositive(int falsePositive) { + this.falsePositive = falsePositive; + hasChanged = true; + } + + public void setFalseNegative(int falseNegative) { + this.falseNegative = falseNegative; + hasChanged = true; + } + + public void setObserved(int observed) { + this.observed = observed; + hasChanged = true; + } + + public void setExpected(int expected) { + this.expected = expected; + hasChanged = true; + } + + public static LabelStat create() { + return new LabelStat(); + } + + public double getAccuracy() { + double accuracy = (double) (observed + trueNegative) / (observed + falsePositive + trueNegative + falseNegative); + if (accuracy < 0.0) + accuracy = 0.0; + + return accuracy; + } + + public long getSupport() { + return expected; + } + + public double getPrecision() { + if (observed == 0.0) { + return 0.0; + } + return ((double) observed) / (falsePositive + observed); + } + + public double getRecall() { + if (expected == 0.0) + return 0.0; + return ((double) observed) / (expected); + } + + public double getF1Score() { + double precision = getPrecision(); + double recall = getRecall(); + + if ((precision == 0.0) && (recall == 0.0)) + return 0.0; + + return (2.0 * precision * recall) / (precision + recall); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder + .append("falsePositive: ").append(falsePositive) + .append("; falseNegative: ").append(falseNegative) + .append("; observed: ").append(observed) + .append("; expected: ").append(expected); + return builder.toString(); + } + + public void setTrueNegative(int trueNegative) { + this.trueNegative = trueNegative; + } + + public boolean hasChanged() { + Boolean oldValue = new Boolean(hasChanged); + hasChanged = false; + return oldValue; + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java new file mode 100644 index 0000000000..86847b62f6 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -0,0 +1,131 @@ +package org.grobid.trainer.evaluation; + +import org.grobid.core.utilities.TextUtilities; + +import java.util.Map; + +/** + * Represent all different evaluation given a specific model + */ +public class ModelStats { + private int totalInstances; + private int correctInstance; + // private Stats tokenStats; + private Stats fieldStats; + private String rawResults; + + + public void setTotalInstances(int totalInstances) { + this.totalInstances = totalInstances; + } + + public int getTotalInstances() { + return totalInstances; + } + + public void setCorrectInstance(int correctInstance) { + this.correctInstance = correctInstance; + } + + public int getCorrectInstance() { + return correctInstance; + } + +// public void setTokenStats(Stats tokenStats) { +// this.tokenStats = tokenStats; +// } + +// public Stats getTokenStats() { +// return tokenStats; +// } + + public void setFieldStats(Stats fieldStats) { + this.fieldStats = fieldStats; + } + + public Stats getFieldStats() { + return fieldStats; + } + + public double getInstanceRecall() { + if (getTotalInstances() <= 0) { + return 0.0d; + } + return (double) getCorrectInstance() / (getTotalInstances()); + } + + public String toString() { + return toString(false); + } + + public String toString(boolean includeRawResults) { + StringBuilder report = new StringBuilder(); + + if(includeRawResults) { + report.append("=== START RAW RESULTS ===").append("\n"); + report.append(getRawResults()).append("\n"); + report.append("=== END RAw RESULTS ===").append("\n").append("\n"); + } + + + Stats fieldStats = getFieldStats(); + report.append("\n===== Field-level results =====\n"); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + + for (Map.Entry labelResult : fieldStats.getLabelsResults().entrySet()) { + report.append(labelResult.getValue()); + } + + report.append("\n"); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (micro avg.)", + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageF1() * 100), + String.valueOf(getSupportSum()))); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageF1() * 100), + String.valueOf(getSupportSum()))); + + + // instance-level: instances are separated by a new line in the result file + report.append("\n===== Instance-level results =====\n\n"); + report.append(String.format("%-27s %d\n", "Total expected instances:", getTotalInstances())); + report.append(String.format("%-27s %d\n", "Correct instances:", getCorrectInstance())); + report.append(String.format("%-27s %s\n", + "Instance-level recall:", + TextUtilities.formatTwoDecimals(getInstanceRecall() * 100))); + + return report.toString(); + } + + public long getSupportSum() { + long supportSum = 0; + for (LabelResult labelResult : fieldStats.getLabelsResults().values()) { + supportSum += labelResult.getSupport(); + } + return supportSum; + } + + public String getRawResults() { + return rawResults; + } + + public void setRawResults(String rawResults) { + this.rawResults = rawResults; + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java new file mode 100644 index 0000000000..468a96d4da --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -0,0 +1,345 @@ +package org.grobid.trainer.evaluation; + +import java.util.Set; +import java.util.TreeMap; + +import org.grobid.core.exceptions.*; +import org.grobid.core.utilities.TextUtilities; + +/** + * Contains the single statistic computation for evaluation + */ +public final class Stats { + private final TreeMap labelStats; + + // State variable to know whether is required to recompute the statistics + private boolean requiredToRecomputeMetrics = true; + + private double cumulated_tp = 0; + private double cumulated_fp = 0; + private double cumulated_tn = 0; + private double cumulated_fn = 0; + private double cumulated_f1 = 0.0; + private double cumulated_accuracy = 0.0; + private double cumulated_precision = 0.0; + private double cumulated_recall = 0.0; + private double cumulated_expected = 0; + private int totalValidFields = 0; + + public Stats() { + this.labelStats = new TreeMap<>(); + } + + public Set getLabels() { + return this.labelStats.keySet(); + } + + public void incrementFalsePositive(String label) { + this.incrementFalsePositive(label, 1); + } + + public void incrementFalsePositive(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementFalsePositive(count); + requiredToRecomputeMetrics = true; + } + + public void incrementFalseNegative(String label) { + this.incrementFalseNegative(label, 1); + } + + public void incrementFalseNegative(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementFalseNegative(count); + requiredToRecomputeMetrics = true; + } + + public void incrementObserved(String label) { + this.incrementObserved(label, 1); + } + + public void incrementObserved(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementObserved(count); + requiredToRecomputeMetrics = true; + } + + public void incrementExpected(String label) { + this.incrementExpected(label, 1); + } + + public void incrementExpected(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementExpected(count); + requiredToRecomputeMetrics = true; + } + + public LabelStat getLabelStat(String label) { + if (this.labelStats.containsKey(label)) { + return this.labelStats.get(label); + } + + LabelStat labelStat = LabelStat.create(); + this.labelStats.put(label, labelStat); + requiredToRecomputeMetrics = true; + + return labelStat; + } + + public int size() { + return this.labelStats.size(); + } + + public double getPrecision(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getPrecision(); + } + + public double getRecall(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getRecall(); + } + + public double getF1Score(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getF1Score(); + } + + /** + * In order to compute metrics in an efficient way, they are computed all at the same time. + * Since the state of the object is important in this case, it's required to have a flag that + * allow the recompute of the metrics when one is required. + */ + public void computeMetrics() { + for (String label : getLabels()) { + if (getLabelStat(label).hasChanged()) { + requiredToRecomputeMetrics = true; + break; + } + } + + if (!requiredToRecomputeMetrics) + return; + + int totalFields = 0; + for (String label : getLabels()) { + LabelStat labelStat = getLabelStat(label); + totalFields += labelStat.getObserved(); + totalFields += labelStat.getFalseNegative(); + totalFields += labelStat.getFalsePositive(); + } + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + int tp = labelStat.getObserved(); // true positives + int fp = labelStat.getFalsePositive(); // false positives + int fn = labelStat.getFalseNegative(); // false negative + int tn = totalFields - tp - (fp + fn); // true negatives + labelStat.setTrueNegative(tn); + int expected = labelStat.getExpected(); // all expected + + if (expected != 0) { + totalValidFields++; + } + + if (expected != 0) { + cumulated_tp += tp; + cumulated_fp += fp; + cumulated_tn += tn; + cumulated_fn += fn; + + cumulated_expected += expected; + cumulated_f1 += labelStat.getF1Score(); + cumulated_accuracy += labelStat.getAccuracy(); + cumulated_precision += labelStat.getPrecision(); + cumulated_recall += labelStat.getRecall(); + } + } + + requiredToRecomputeMetrics = false; + } + + public TreeMap getLabelsResults() { + computeMetrics(); + + TreeMap result = new TreeMap<>(); + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + LabelResult labelResult = new LabelResult(label); + labelResult.setAccuracy(labelStat.getAccuracy()); + labelResult.setPrecision(labelStat.getPrecision()); + labelResult.setRecall(labelStat.getRecall()); + labelResult.setF1Score(labelStat.getF1Score()); + labelResult.setSupport(labelStat.getSupport()); + + result.put(label, labelResult); + } + + return result; + } + + + public double getMicroAverageAccuracy() { + computeMetrics(); + + // macro average over measures + if (totalValidFields == 0) + return 0.0; + else + return Math.min(1.0, cumulated_accuracy / totalValidFields); + } + + public double getMacroAverageAccuracy() { + computeMetrics(); + + double accuracy = 0.0; + if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) + accuracy = ((double) cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); + + return Math.min(1.0, accuracy); + } + + + public double getMicroAveragePrecision() { + computeMetrics(); + + double precision = 0.0; + if (cumulated_tp + cumulated_fp != 0) + precision = cumulated_tp / (cumulated_tp + cumulated_fp); + + return Math.min(1.0, precision); + } + + public double getMacroAveragePrecision() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_precision / totalValidFields); + } + + public double getMicroAverageRecall() { + computeMetrics(); + + double recall = 0.0; + if (cumulated_expected != 0.0) + recall = cumulated_tp / cumulated_expected; + + return Math.min(1.0, recall); + } + + public double getMacroAverageRecall() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_recall / totalValidFields); + } + + public int getTotalValidFields() { + computeMetrics(); + return totalValidFields; + } + + public double getMicroAverageF1() { + double precision = getMicroAveragePrecision(); + double recall = getMicroAverageRecall(); + + double f1 = 0.0; + if (precision + recall != 0.0) + f1 = (2 * precision * recall) / (precision + recall); + + return f1; + } + + public double getMacroAverageF1() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_f1 / totalValidFields); + } + + public String getOldReport() { + computeMetrics(); + + StringBuilder report = new StringBuilder(); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + long supportSum = 0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + + long support = labelStat.getSupport(); + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), + TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), + TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), + TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100), + String.valueOf(support)) + ); + + supportSum += support; + } + + report.append("\n"); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (micro avg.)", + TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100), + String.valueOf(supportSum))); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100), + String.valueOf(supportSum))); + + return report.toString(); + + } +} + diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java new file mode 100644 index 0000000000..c2f92bb2a2 --- /dev/null +++ b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java @@ -0,0 +1,212 @@ +package org.grobid.trainer; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.grobid.core.GrobidModels; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsCollectionWithSize.hasSize; + + +public class AbstractTrainerIntegrationTest { + + private AbstractTrainer target; + + @BeforeClass + public static void beforeClass() throws Exception { +// LibraryLoader.load(); + } + + @Before + public void setUp() throws Exception { + target = new AbstractTrainer(GrobidModels.DUMMY) { + + @Override + public int createCRFPPData(File corpusPath, File outputTrainingFile, File outputEvalFile, double splitRatio) { + // the file for writing the training data +// if (outputTrainingFile != null) { +// try (OutputStream os = new FileOutputStream(outputTrainingFile)) { +// try (Writer writer = new OutputStreamWriter(os, StandardCharsets.UTF_8)) { +// +// for (int i = 0; i < 100; i++) { +// double random = Math.random(); +// writer.write("blablabla" + random); +// writer.write("\n"); +// if (i % 10 == 0) { +// writer.write("\n"); +// } +// } +// } +// } catch (IOException e) { +// e.printStackTrace(); +// } +// } + + return 100; + } + }; + } + + @After + public void tearDown() throws Exception { + + } + + @Test + public void testLoad_shouldWork() throws Exception { + Path path = Paths.get("src/test/resources/sample.wapiti.output.date.txt"); + List expected = Arrays.asList( + "Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT \n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-"); + + List loadedTrainingData = target.load(path); + + assertThat(loadedTrainingData, is(expected)); + } + + + @Test + public void testSplitNFold_n3_shouldWork() throws Exception { + List dummyTrainingData = new ArrayList<>(); + dummyTrainingData.add(dummyExampleGeneration("1", 3)); + dummyTrainingData.add(dummyExampleGeneration("2", 4)); + dummyTrainingData.add(dummyExampleGeneration("3", 2)); + dummyTrainingData.add(dummyExampleGeneration("4", 6)); + dummyTrainingData.add(dummyExampleGeneration("5", 6)); + dummyTrainingData.add(dummyExampleGeneration("6", 2)); + dummyTrainingData.add(dummyExampleGeneration("7", 2)); + dummyTrainingData.add(dummyExampleGeneration("8", 3)); + dummyTrainingData.add(dummyExampleGeneration("9", 3)); + dummyTrainingData.add(dummyExampleGeneration("10", 3)); + + List> splitMapping = target.splitNFold(dummyTrainingData, 3); + assertThat(splitMapping, hasSize(3)); + + assertThat(splitMapping.get(0).getLeft(), endsWith("train")); + assertThat(splitMapping.get(0).getRight(), endsWith("test")); + + //Fold 1 + List fold1Training = target.load(Paths.get(splitMapping.get(0).getLeft())); + List fold1Evaluation = target.load(Paths.get(splitMapping.get(0).getRight())); + + System.out.println(Arrays.toString(fold1Training.toArray())); + System.out.println(Arrays.toString(fold1Evaluation.toArray())); + + assertThat(fold1Training, hasSize(7)); + assertThat(fold1Evaluation, hasSize(3)); + + //Fold 2 + List fold2Training = target.load(Paths.get(splitMapping.get(1).getLeft())); + List fold2Evaluation = target.load(Paths.get(splitMapping.get(1).getRight())); + + System.out.println(Arrays.toString(fold2Training.toArray())); + System.out.println(Arrays.toString(fold2Evaluation.toArray())); + + assertThat(fold2Training, hasSize(7)); + assertThat(fold2Evaluation, hasSize(3)); + + //Fold 3 + List fold3Training = target.load(Paths.get(splitMapping.get(2).getLeft())); + List fold3Evaluation = target.load(Paths.get(splitMapping.get(2).getRight())); + + System.out.println(Arrays.toString(fold3Training.toArray())); + System.out.println(Arrays.toString(fold3Evaluation.toArray())); + + assertThat(fold3Training, hasSize(6)); + assertThat(fold3Evaluation, hasSize(4)); + + // Cleanup + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getRight())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getLeft())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + } + + @Test(expected = IllegalArgumentException.class) + public void testSplitNFold_n10_shouldThrowException() throws Exception { + List dummyTrainingData = new ArrayList<>(); + dummyTrainingData.add(dummyExampleGeneration("1", 3)); + dummyTrainingData.add(dummyExampleGeneration("2", 4)); + dummyTrainingData.add(dummyExampleGeneration("3", 2)); + dummyTrainingData.add(dummyExampleGeneration("4", 6)); + + List> splitMapping = target.splitNFold(dummyTrainingData, 10); + + } + + private String dummyExampleGeneration(String exampleId, int total) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < total; i++) { + sb.append("line " + i + " example " + exampleId).append("\n"); + } + sb.append("\n"); + return sb.toString(); + } + + @Test + public void testLoadAndShuffle_shouldWork() throws Exception { + Path path = Paths.get("src/test/resources/sample.wapiti.output.date.txt"); + List orderedTrainingData = Arrays.asList( + "Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT \n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-"); + + List shuffledTrainingData = target.loadAndShuffle(path); + + assertThat(shuffledTrainingData, hasSize(orderedTrainingData.size())); + assertThat(shuffledTrainingData, is(not(orderedTrainingData))); + } + +} \ No newline at end of file diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java new file mode 100644 index 0000000000..69fac5d2e1 --- /dev/null +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -0,0 +1,395 @@ +package org.grobid.trainer; + +import org.grobid.trainer.evaluation.LabelStat; +import org.grobid.trainer.evaluation.Stats; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +public class StatsTest { + Stats target; + + @Before + public void setUp() throws Exception { + target = new Stats(); + } + + @Test + public void testPrecision_fullMatch() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + assertThat(target.getLabelStat("BAO").getPrecision(), is(1.0)); + } + + @Test + public void testPrecision_noMatch() throws Exception { + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + assertThat(target.getLabelStat("MIAO").getPrecision(), is(0.0)); + assertThat(target.getLabelStat("MIAO").getSupport(), is(4L)); + } + + @Test + public void testPrecision_missingMatch() throws Exception { + // The precision stays at 1.0 because none of the observed + // is wrong (no false positives) + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + assertThat(target.getLabelStat("CIAO").getPrecision(), is(1.0)); + } + + @Test + public void testPrecision_2wronglyRecognised() throws Exception { + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(1); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getLabelStat("ZIAO").getPrecision(), is(0.3333333333333333)); + } + + @Test + public void testRecall_fullMatch() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + assertThat(target.getLabelStat("BAO").getRecall(), is(1.0)); + } + + @Test + public void testRecall_noMatch() throws Exception { + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + assertThat(target.getLabelStat("MIAO").getRecall(), is(0.0)); + } + + @Test + public void testRecall_oneOverFour() throws Exception { + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + assertThat(target.getLabelStat("CIAO").getRecall(), is(0.25)); + + } + + @Test + public void testRecall_partialMatch() throws Exception { + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(3); + target.getLabelStat("ZIAO").setFalsePositive(1); + + assertThat(target.getLabelStat("ZIAO").getPrecision(), is(0.75)); + } + + + // Average measures + + @Test + public void testMicroAvgPrecision_shouldWork() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAveragePrecision(), is(((double) 4 + 0 + 1 + 0) / (4 + 0 + 3 + 1 + 0 + 2))); + } + + + @Test + public void testMacroAvgPrecision_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double precisionBao = target.getLabelStat("BAO").getPrecision(); + final double precisionMiao = target.getLabelStat("MIAO").getPrecision(); + final double precisionCiao = target.getLabelStat("CIAO").getPrecision(); + final double precisionZiao = target.getLabelStat("ZIAO").getPrecision(); + assertThat(target.getMacroAveragePrecision(), + is((precisionBao + precisionMiao + precisionCiao + precisionZiao) / (4))); + } + + + @Test + public void testMicroAvgRecall_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); //TP + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAverageRecall(), is(((double) 4 + 0 + 1 + 0) / (4 + 4 + 4 + 4))); + } + + @Test + public void testMacroAvgRecall_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double recallBao = target.getLabelStat("BAO").getRecall(); + final double recallMiao = target.getLabelStat("MIAO").getRecall(); + final double recallCiao = target.getLabelStat("CIAO").getRecall(); + final double recallZiao = target.getLabelStat("ZIAO").getRecall(); + assertThat(target.getMacroAverageRecall(), + is((recallBao + recallMiao + recallCiao + recallZiao) / (4))); + } + + @Test + public void testMicroAvgF0_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); //TP + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAverageF1(), + is(((double) 2 * target.getMicroAveragePrecision() + * target.getMicroAverageRecall()) + / (target.getMicroAveragePrecision() + target.getMicroAverageRecall()))); + } + + @Test + public void testMacroAvgF0_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double f1Bao = target.getLabelStat("BAO").getF1Score(); + final double f1Miao = target.getLabelStat("MIAO").getF1Score(); + final double f1Ciao = target.getLabelStat("CIAO").getF1Score(); + final double f1Ziao = target.getLabelStat("ZIAO").getRecall(); + + assertThat(target.getMacroAverageF1(), + is((f1Bao + f1Miao + f1Ciao + f1Ziao) / (4))); + } + + + @Ignore("Not really useful") + @Test + public void testMicroMacroAveragePrecision() throws Exception { + final LabelStat conceptual = target.getLabelStat("CONCEPTUAL"); + conceptual.setFalsePositive(1); + conceptual.setFalseNegative(1); + conceptual.setObserved(2); + conceptual.setExpected(3); + + final LabelStat location = target.getLabelStat("LOCATION"); + location.setFalsePositive(0); + location.setFalseNegative(0); + location.setObserved(2); + location.setExpected(2); + + final LabelStat media = target.getLabelStat("MEDIA"); + media.setFalsePositive(0); + media.setFalseNegative(0); + media.setObserved(7); + media.setExpected(7); + + final LabelStat national = target.getLabelStat("NATIONAL"); + national.setFalsePositive(1); + national.setFalseNegative(0); + national.setObserved(0); + national.setExpected(0); + + final LabelStat other = target.getLabelStat("O"); + other.setFalsePositive(0); + other.setFalseNegative(1); + other.setObserved(33); + other.setExpected(34); + + final LabelStat organisation = target.getLabelStat("ORGANISATION"); + organisation.setFalsePositive(0); + organisation.setFalseNegative(0); + organisation.setObserved(2); + organisation.setExpected(2); + + final LabelStat period = target.getLabelStat("PERIOD"); + period.setFalsePositive(0); + period.setFalseNegative(0); + period.setObserved(8); + period.setExpected(8); + + final LabelStat person = target.getLabelStat("PERSON"); + person.setFalsePositive(1); + person.setFalseNegative(0); + person.setObserved(0); + person.setExpected(0); + + final LabelStat personType = target.getLabelStat("PERSON_TYPE"); + personType.setFalsePositive(0); + personType.setFalseNegative(1); + personType.setObserved(0); + personType.setExpected(1); + + for (String label : target.getLabels()) { + System.out.println(label + " precision --> " + target.getLabelStat(label).getPrecision()); + System.out.println(label + " recall --> " + target.getLabelStat(label).getRecall()); + } + + System.out.println(target.getMacroAveragePrecision()); + System.out.println(target.getMicroAveragePrecision()); + } + + + @Test + public void testMicroMacroAverageMeasures_realTest() throws Exception { + LabelStat otherLabelStats = target.getLabelStat("O"); + otherLabelStats.setExpected(33); + otherLabelStats.setObserved(33); + otherLabelStats.setFalseNegative(1); + + assertThat(otherLabelStats.getPrecision(), is(1.0)); + assertThat(otherLabelStats.getRecall(), is(1.0)); + + + LabelStat conceptualLabelStats = target.getLabelStat("CONCEPTUAL"); + conceptualLabelStats.setObserved(2); + conceptualLabelStats.setExpected(3); + conceptualLabelStats.setFalseNegative(1); + conceptualLabelStats.setFalsePositive(1); + + assertThat(conceptualLabelStats.getPrecision(), is(0.6666666666666666)); + assertThat(conceptualLabelStats.getRecall(), is(0.6666666666666666)); + + + LabelStat periodLabelStats = target.getLabelStat("PERIOD"); + periodLabelStats.setObserved(8); + periodLabelStats.setExpected(8); + + assertThat(periodLabelStats.getPrecision(), is(1.0)); + assertThat(periodLabelStats.getRecall(), is(1.0)); + + + LabelStat mediaLabelStats = target.getLabelStat("MEDIA"); + mediaLabelStats.setObserved(7); + mediaLabelStats.setExpected(7); + + assertThat(mediaLabelStats.getPrecision(), is(1.0)); + assertThat(mediaLabelStats.getRecall(), is(1.0)); + + + LabelStat personTypeLabelStats = target.getLabelStat("PERSON_TYPE"); + personTypeLabelStats.setObserved(0); + personTypeLabelStats.setExpected(1); + personTypeLabelStats.setFalseNegative(1); + + assertThat(personTypeLabelStats.getPrecision(), is(0.0)); + assertThat(personTypeLabelStats.getRecall(), is(0.0)); + + + + LabelStat locationTypeLabelStats = target.getLabelStat("LOCATION"); + locationTypeLabelStats.setObserved(2); + locationTypeLabelStats.setExpected(2); + + assertThat(locationTypeLabelStats.getPrecision(), is(1.0)); + assertThat(locationTypeLabelStats.getRecall(), is(1.0)); + + + LabelStat organisationTypeLabelStats = target.getLabelStat("ORGANISATION"); + organisationTypeLabelStats.setObserved(2); + organisationTypeLabelStats.setExpected(2); + + assertThat(locationTypeLabelStats.getPrecision(), is(1.0)); + assertThat(locationTypeLabelStats.getRecall(), is(1.0)); + + + LabelStat personLabelStats = target.getLabelStat("PERSON"); + personLabelStats.setFalsePositive(1); + + assertThat(personLabelStats.getPrecision(), is(0.0)); + assertThat(personLabelStats.getRecall(), is(0.0)); + + // 2+8+2+2+7 / (2+8+2+2+7+1) + assertThat(target.getMicroAverageRecall(), is(0.9130434782608695)); //91.3 + + // 2+8+2+2+7 / (3+8+7+2+2+1) + assertThat(target.getMicroAveragePrecision(), is(0.9545454545454546)); //95.45 + + // 0.66 + 1.0 + 1.0 + 0.0 + 1.0 + 1.0 / 6 + assertThat(target.getMacroAverageRecall(), is(0.7777777777777777)); //77.78 + + // same as above + assertThat(target.getMacroAveragePrecision(), is(0.7777777777777777)); //77.78 + } + +} \ No newline at end of file diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index b753e61fb9..6ce08b58fb 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -1,11 +1,10 @@ package org.grobid.trainer.evaluation; import org.apache.commons.io.IOUtils; -import org.grobid.trainer.LabelStat; -import org.grobid.trainer.Stats; import org.junit.Test; import java.nio.charset.StandardCharsets; +import java.util.TreeMap; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; @@ -13,130 +12,166 @@ public class EvaluationUtilitiesTest { @Test - public void testMetricsAllGood() throws Exception { + public void testTokenLevelStats_allGood() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<2>\nd I-<1> I-<1>\ne <1> <1>\n"; - //System.out.println(result); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat1 = wordStats.getLabelStat("<1>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); assertThat(labelstat2.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(4)); assertThat(labelstat2.getExpected(), is(1)); - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + assertThat(labelstat1.getSupport(), is(4L)); + assertThat(labelstat2.getSupport(), is(1L)); + } + + @Test + public void testFieldLevelStats_allGood() throws Exception { + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<2>\nd I-<1> I-<1>\ne <1> <1>\n"; + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(2)); assertThat(labelstat2.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(2)); assertThat(labelstat2.getExpected(), is(1)); + + assertThat(labelstat1.getSupport(), is(2L)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test - public void testMetricsAllFalse() throws Exception { + public void testTokenLevelStats_noMatch() throws Exception { String result = "a I-<1> I-<2>\nb <1> <2>\nc <1> I-<2>\nd <1> <2>\ne <1> <2>\n"; - //System.out.println(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getFalseNegative(), is(5)); + assertThat(labelstat1.getSupport(), is(5L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getFalsePositive(), is(5)); + assertThat(labelstat2.getSupport(), is(0L)); + } + + @Test + public void testFieldLevelStats_noMatch() throws Exception { + String result = "a I-<1> I-<2>\nb <1> <2>\nc <1> I-<2>\nd <1> <2>\ne <1> <2>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); - assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); + assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(1)); + assertThat(labelstat1.getSupport(), is(1L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(0)); + assertThat(labelstat2.getSupport(), is(0L)); } - @Test - public void testMetricsMixed1() throws Exception { - // label of c is false - // token 80 precision for label <1>, 0 for label <2> - // field: precision and recall are 0, because the whole - // sequence abcde with label <1> does not make sub-field - // ab and de correctly label with respect to positions + @Test + public void testTokenLevelStats_mixed() throws Exception { + // label of c is false + // token 80 precision for label <1>, 0 for label <2> String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> <1>\nd I-<1> <1>\ne <1> <1>\n"; //System.out.println(result); - Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(4)); - assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat1.getFalseNegative(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getSupport(), is(4L)); + + assertThat(labelstat2.getObserved(), is(0)); + assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + assertThat(labelstat2.getSupport(), is(1L)); + } - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + @Test + public void testFieldLevelStats_mixed() throws Exception { + // field: precision and recall are 0, because the whole + // sequence abcde with label <1> does not make sub-field + // ab and de correctly label with respect to positions + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> <1>\nd I-<1> <1>\ne <1> <1>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(2)); + assertThat(labelstat1.getSupport(), is(2L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test - public void testMetricsMixed1Bis() throws Exception { - // variant of testMetricsMixed1 where the I- prefix impact the field-level results - // with field ab correctly found + public void testTokenLevelStats2_mixed() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<1>\nd I-<1> <1>\ne <1> <1>\n"; - //System.out.println(result); - Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(4)); - assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat1.getFalseNegative(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getSupport(), is(4L)); + + assertThat(labelstat2.getObserved(), is(0)); + assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + assertThat(labelstat2.getSupport(), is(1L)); + } + + @Test + public void testFieldLevelStats2_mixed() throws Exception { + // variant of testMetricsMixed1 where the I- prefix impact the field-level results + // with field ab correctly found - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<1>\nd I-<1> <1>\ne <1> <1>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(1)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(2)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); } @Test - public void testMetricsMixed2() throws Exception { + public void testTokenLevelStats3_mixed() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc <1> I-<2>\nd <1> I-<1>\ne <1> <1>\n"; //System.out.println(result); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); assertThat(labelstat2.getObserved(), is(0)); @@ -144,11 +179,18 @@ public void testMetricsMixed2() throws Exception { assertThat(labelstat2.getExpected(), is(0)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(0)); - assertThat(labelstat1.getFalsePositive(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(0)); assertThat(labelstat2.getFalsePositive(), is(1)); - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + } + + @Test + public void testFieldLevelStats3_mixed() throws Exception { + String result = "a I-<1> I-<1>\nb <1> <1>\nc <1> I-<2>\nd <1> I-<1>\ne <1> <1>\n"; + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat2.getObserved(), is(0)); @@ -156,36 +198,37 @@ public void testMetricsMixed2() throws Exception { assertThat(labelstat2.getExpected(), is(0)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(0)); - assertThat(labelstat1.getFalsePositive(), is(2)); + assertThat(labelstat1.getFalsePositive(), is(2)); assertThat(labelstat2.getFalsePositive(), is(1)); } @Test - public void testMetricsMixed3() throws Exception { + public void testTokenLevelStats4_mixed() throws Exception { String result = "a I-<1> I-<1>\nb I-<2> <1>\nc <2> I-<2>\nd <2> <2>\ne I-<1> I-<1>\nf <1> <1>\ng I-<2> I-<2>\n"; - //System.out.println(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(3)); - assertThat(labelstat2.getObserved(), is(3)); assertThat(labelstat1.getExpected(), is(3)); - assertThat(labelstat2.getExpected(), is(4)); assertThat(labelstat1.getFalseNegative(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(1)); + + assertThat(labelstat2.getObserved(), is(3)); + assertThat(labelstat2.getExpected(), is(4)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + } - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + @Test + public void testFieldLevelStats4_mixed() throws Exception { + String result = "a I-<1> I-<1>\nb I-<2> <1>\nc <2> I-<2>\nd <2> <2>\ne I-<1> I-<1>\nf <1> <1>\ng I-<2> I-<2>\n"; - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(1)); assertThat(labelstat2.getObserved(), is(1)); @@ -193,20 +236,19 @@ public void testMetricsMixed3() throws Exception { assertThat(labelstat2.getExpected(), is(2)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(1)); } @Test - public void testMetricsReal() throws Exception { - String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); + public void testTokenLevelStats_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.1.txt"), StandardCharsets.UTF_8); result = result.replace(System.lineSeparator(), "\n"); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat(""); - LabelStat labelstat2 = wordStats.getLabelStat(""); + LabelStat labelstat2 = wordStats.getLabelStat(""); assertThat(labelstat1.getObserved(), is(378)); assertThat(labelstat2.getObserved(), is(6)); @@ -214,25 +256,95 @@ public void testMetricsReal() throws Exception { assertThat(labelstat2.getExpected(), is(9)); assertThat(labelstat1.getFalseNegative(), is(0)); assertThat(labelstat2.getFalseNegative(), is(3)); - assertThat(labelstat1.getFalsePositive(), is(3)); + assertThat(labelstat1.getFalsePositive(), is(3)); assertThat(labelstat2.getFalsePositive(), is(0)); - labelstat1 = fieldStats.getLabelStat(""); - labelstat2 = fieldStats.getLabelStat(""); + } + + @Test + public void testFieldLevelStats_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.1.txt"), StandardCharsets.UTF_8); - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat(""); + LabelStat labelstat2 = fieldStats.getLabelStat(""); - assertThat(labelstat1.getObserved(), is(1)); + assertThat(labelstat1.getObserved(), is(1)); assertThat(labelstat2.getObserved(), is(2)); assertThat(labelstat1.getExpected(), is(3)); assertThat(labelstat2.getExpected(), is(3)); assertThat(labelstat1.getFalseNegative(), is(2)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); } + @Test + public void testTokenLevelStats2_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.2.txt"), StandardCharsets.UTF_8); + + Stats stats = EvaluationUtilities.tokenLevelStats(result); + + LabelStat conceptualLabelStats = stats.getLabelStat("CONCEPTUAL"); + assertThat(conceptualLabelStats.getObserved(), is(2)); + assertThat(conceptualLabelStats.getExpected(), is(3)); + assertThat(conceptualLabelStats.getFalseNegative(), is(1)); + assertThat(conceptualLabelStats.getFalsePositive(), is(1)); + + LabelStat periodLabelStats = stats.getLabelStat("PERIOD"); + assertThat(periodLabelStats.getObserved(), is(8)); + assertThat(periodLabelStats.getExpected(), is(8)); + assertThat(periodLabelStats.getFalseNegative(), is(0)); + assertThat(periodLabelStats.getFalsePositive(), is(0)); + + LabelStat mediaLabelStats = stats.getLabelStat("MEDIA"); + assertThat(mediaLabelStats.getObserved(), is(7)); + assertThat(mediaLabelStats.getExpected(), is(7)); + assertThat(mediaLabelStats.getFalseNegative(), is(0)); + assertThat(mediaLabelStats.getFalsePositive(), is(0)); + + LabelStat personTypeLabelStats = stats.getLabelStat("PERSON_TYPE"); + assertThat(personTypeLabelStats.getObserved(), is(0)); + assertThat(personTypeLabelStats.getExpected(), is(1)); + assertThat(personTypeLabelStats.getFalseNegative(), is(1)); + assertThat(personTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat locationTypeLabelStats = stats.getLabelStat("LOCATION"); + assertThat(locationTypeLabelStats.getObserved(), is(2)); + assertThat(locationTypeLabelStats.getExpected(), is(2)); + assertThat(locationTypeLabelStats.getFalseNegative(), is(0)); + assertThat(locationTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat organisationTypeLabelStats = stats.getLabelStat("ORGANISATION"); + assertThat(organisationTypeLabelStats.getObserved(), is(2)); + assertThat(organisationTypeLabelStats.getExpected(), is(2)); + assertThat(organisationTypeLabelStats.getFalseNegative(), is(0)); + assertThat(organisationTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat otherLabelStats = stats.getLabelStat("O"); + assertThat(otherLabelStats.getObserved(), is(33)); + assertThat(otherLabelStats.getExpected(), is(34)); + assertThat(otherLabelStats.getFalseNegative(), is(1)); + assertThat(otherLabelStats.getFalsePositive(), is(0)); + + LabelStat personLabelStats = stats.getLabelStat("PERSON"); + assertThat(personLabelStats.getObserved(), is(0)); + assertThat(personLabelStats.getExpected(), is(0)); + assertThat(personLabelStats.getFalseNegative(), is(0)); + assertThat(personLabelStats.getFalsePositive(), is(1)); + } + + @Test + public void testTokenLevelStats3_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.3.txt"), StandardCharsets.UTF_8); + + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + + TreeMap labelsResults = fieldStats.getLabelsResults(); + + assertThat(labelsResults.get("").getSupport(), is(4L)); + assertThat(labelsResults.get("").getSupport(), is(2L)); + + } } \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/ex.txt.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.1.txt similarity index 100% rename from grobid-trainer/src/test/resources/ex.txt.txt rename to grobid-trainer/src/test/resources/sample.wapiti.output.1.txt diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt new file mode 100644 index 0000000000..f0e9f4bd5b --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt @@ -0,0 +1,57 @@ +Articles articles A Ar Art Arti Artic s es les cles icles INITCAP NODIGIT 0 1 0 0 0 0 0 0 0 0 0 Xxxx Xx O O +from from f fr fro from from m om rom from from NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +European european E Eu Eur Euro Europ n an ean pean opean INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-CONCEPTUAL B-NATIONAL +Zionist zionist Z Zi Zio Zion Zioni t st ist nist onist INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx CONCEPTUAL B-CONCEPTUAL +newspapers newspapers n ne new news newsp s rs ers pers apers NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +1901 1901 1 19 190 1901 1901 1 01 901 1901 1901 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d B-PERIOD B-PERIOD +- - - - - - - - - - - - ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 - - PERIOD PERIOD +1931 1931 1 19 193 1931 1931 1 31 931 1931 1931 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +Section section S Se Sec Sect Secti n on ion tion ction INITCAP NODIGIT 0 1 0 0 0 0 0 1 0 0 0 Xxxx Xx O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +L l L L L L L L L L L L ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 1 0 X X B-MEDIA B-MEDIA +' ' ' ' ' ' ' ' ' ' ' ' ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 1 0 ' ' MEDIA MEDIA +echo echo e ec ech echo echo o ho cho echo echo NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 1 0 xxxx x MEDIA MEDIA +Sioniste sioniste S Si Sio Sion Sioni e te ste iste niste INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx MEDIA MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +with with w wi wit with with h th ith with with NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +reports reports r re rep repo repor s ts rts orts ports NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +on on o on on on on n on on on on NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Jewish jewish J Je Jew Jewi Jewis h sh ish wish ewish INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-PERSON_TYPE B-CONCEPTUAL +problems problems p pr pro prob probl s ms ems lems blems NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +in in i in in in in n in in in in NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 0 0 xx x O O +Bulgaria bulgaria B Bu Bul Bulg Bulga a ia ria aria garia INITCAP NODIGIT 0 0 0 0 1 0 0 1 0 0 0 Xxxx Xx B-LOCATION B-LOCATION +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +May may M Ma May May May y ay May May May INITCAP NODIGIT 0 1 1 0 0 0 1 1 0 0 0 Xxx Xx B-PERIOD B-PERIOD +1901 1901 1 19 190 1901 1901 1 01 901 1901 1901 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +; ; ; ; ; ; ; ; ; ; ; ; ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 ; ; O O +Page page P Pa Pag Page Page e ge age Page Page INITCAP NODIGIT 1 1 1 0 0 0 0 1 0 0 0 Xxxx Xx O B-PERSON +from from f fr fro from from m om rom from from NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +Die die D Di Die Die Die e ie Die Die Die INITCAP NODIGIT 0 1 0 0 0 0 0 1 0 0 0 Xxx Xx B-MEDIA B-MEDIA +Arbeit arbeit A Ar Arb Arbe Arbei t it eit beit rbeit INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx MEDIA MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Hapoel hapoel H Ha Hap Hapo Hapoe l el oel poel apoel INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 1 0 Xxxx Xx B-ORGANISATION B-ORGANISATION +Hatzair hatzair H Ha Hat Hatz Hatza r ir air zair tzair INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx ORGANISATION ORGANISATION +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +15 15 1 15 15 15 15 5 15 15 15 15 NOCAPS ALLDIGIT 0 0 0 0 0 0 0 0 0 0 0 dd d B-PERIOD B-PERIOD +February february F Fe Feb Febr Febru y ry ary uary ruary INITCAP NODIGIT 0 0 0 0 0 0 1 0 0 0 0 Xxxx Xx PERIOD PERIOD +1920 1920 1 19 192 1920 1920 0 20 920 1920 1920 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +; ; ; ; ; ; ; ; ; ; ; ; ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 ; ; O O +Issue issue I Is Iss Issu Issue e ue sue ssue Issue INITCAP NODIGIT 0 1 0 0 0 0 0 0 0 0 0 Xxxx Xx O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Haolam haolam H Ha Hao Haol Haola m am lam olam aolam INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-MEDIA B-MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +with with w wi wit with with h th ith with with NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +excerpts excerpts e ex exc exce excer s ts pts rpts erpts NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +on on o on on on on n on on on on NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Zionist zionist Z Zi Zio Zion Zioni t st ist nist onist INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-CONCEPTUAL B-CONCEPTUAL +events events e ev eve even event s ts nts ents vents NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +in in i in in in in n in in in in NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 0 0 xx x O O +Europe europe E Eu Eur Euro Europ e pe ope rope urope INITCAP NODIGIT 0 0 0 0 0 0 0 1 0 0 0 Xxxx Xx B-LOCATION B-LOCATION +at at a at at at at t at at at at NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +time time t ti tim time time e me ime time time NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xxxx x O O +. . . . . . . . . . . . ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 . . O O \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt new file mode 100644 index 0000000000..0e58507e12 --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt @@ -0,0 +1,13 @@ +y 0 0 1 1 NOPUNCT 0 I- I- +r 0 0 0 0 NOPUNCT 0 +s 0 0 1 0 NOPUNCT 0 + +s 0 0 1 0 NOPUNCT 0 I- I- + +G 1 0 1 1 NOPUNCT 0 I- I- +P 1 0 1 1 NOPUNCT 0 I- I- +a 0 0 1 1 NOPUNCT 0 + +G 1 0 1 1 NOPUNCT 0 I- I- +P 1 0 1 1 NOPUNCT 0 I- I- +a 0 0 1 1 NOPUNCT 0 \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt new file mode 100644 index 0000000000..b1005bfb59 --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt @@ -0,0 +1,23 @@ +Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I- +online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT +18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I- +2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I- +16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +, , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I- +2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I- +4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I- +, , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I- +2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I- +18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I- +2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- \ No newline at end of file