From 9f1687e5f8f0492974484cdf7d0f83ed3e28a967 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 16 Nov 2017 10:14:06 +0100 Subject: [PATCH 01/37] moving logic for the calculation of precision/recall/f1 together with averages in the Stats class in order to be able to test it directly. --- .../java/org/grobid/trainer/LabelStat.java | 61 +- .../main/java/org/grobid/trainer/Stats.java | 143 ++- .../evaluation/EvaluationUtilities.java | 873 ++++++++---------- .../java/org/grobid/trainer/StatsTest.java | 307 ++++++ .../evaluation/EvaluationUtilitiesTest.java | 185 ++-- 5 files changed, 988 insertions(+), 581 deletions(-) create mode 100644 grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java index de7b7acbee..71fdb5bc1b 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java @@ -6,51 +6,58 @@ public final class LabelStat { private int observed = 0; // this is true positives private int expected = 0; // total expected number of items with this label - public void incrementFalseNegative() { - this.incrementFalseNegative(1); + public void incrementFalseNegative() { + this.incrementFalseNegative(1); } - - public void incrementFalsePositive() { - this.incrementFalsePositive(1); + + public void incrementFalsePositive() { + this.incrementFalsePositive(1); } - - public void incrementObserved() { + + public void incrementObserved() { this.incrementObserved(1); } - - public void incrementExpected() { - this.incrementExpected(1); + + public void incrementExpected() { + this.incrementExpected(1); } - public void incrementFalseNegative(int count) { - this.falseNegative += count; + public void incrementFalseNegative(int count) { + this.falseNegative += count; } - - public void incrementFalsePositive(int count) { - this.falsePositive += count; + + public void incrementFalsePositive(int count) { + this.falsePositive += count; } - - public void incrementObserved(int count) { - this.observed += count; + + public void incrementObserved(int count) { + this.observed += count; } - public void incrementExpected(int count) { - this.expected += count; + public void incrementExpected(int count) { + this.expected += count; } public int getExpected() { return this.expected; } + public int getFalseNegative() { return this.falseNegative; } + public int getFalsePositive() { return this.falsePositive; } + public int getObserved() { return this.observed; } + public int getAll() { + return observed + falseNegative + falsePositive; + } + public void setFalsePositive(int falsePositive) { this.falsePositive = falsePositive; } @@ -75,32 +82,26 @@ public double getPrecision() { if (observed == 0.0) { return 0.0; } - // if ((falsePositive + falseNegative) >= observed) - // return 0.0; - //return (double) (observed - (falsePositive + falseNegative) ) / (observed); - return (double) observed / (falsePositive + observed); + return ((double) observed) / (falsePositive + observed); } public double getRecall() { if (expected == 0.0) return 0.0; - // if ((falsePositive + falseNegative) >= observed) - // return 0.0; - //return (double) (observed - (falsePositive + falseNegative) ) / (expected); - return (double) observed / (expected); + return ((double) observed) / (expected); } public double getF1Score() { double precision = getPrecision(); double recall = getRecall(); - if ( (precision == 0) && (recall == 0.0) ) + if ((precision == 0.0) && (recall == 0.0)) return 0.0; return (2.0 * precision * recall) / (precision + recall); } - @Override + @Override public String toString() { StringBuilder builder = new StringBuilder(); builder diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java index e8b0dbee94..fb2ccc09a4 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java @@ -71,8 +71,8 @@ public LabelStat getLabelStat(String label) { return labelStat; } - public int size() { - return this.labelStats.size(); + public int size() { + return this.labelStats.size(); } public double getPrecision(String label) { @@ -95,5 +95,144 @@ public double getF1Score(String label) { throw new GrobidException("Unknown label: " + label); return labelStat.getF1Score(); } + + /** + * Return the micro average precision, which is the precision calculated + * on the cumulation of the TP and FP over the whole data + */ + public double getMicroAveragePrecision() { + double cumulatedTruePositive = 0.0; + double cumulatedFalsePositive = 0.0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + final LabelStat labelStat = getLabelStat(label); + if (labelStat.getExpected() != 0) { + cumulatedTruePositive += labelStat.getObserved(); + cumulatedFalsePositive += labelStat.getFalsePositive(); + } + } + + double precision = 0.0; + if (cumulatedTruePositive + cumulatedFalsePositive != 0) + precision = cumulatedTruePositive / (cumulatedTruePositive + cumulatedFalsePositive); + + return Math.min(1.0, precision); + } + + /** + * Calculate the macro average precision, which is the + * mean average among all the precision for each label + */ + public double getMacroAveragePrecision() { + int totalValidFields = 0; + double cumulated_precision = 0.0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + final LabelStat labelStat = getLabelStat(label); + if (labelStat.getExpected() != 0) { + totalValidFields++; + double labelPrecision = labelStat.getPrecision(); + cumulated_precision += labelPrecision; + } + } + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_precision / totalValidFields); + } + + public double getMicroAverageRecall() { + double cumulatedTruePositive = 0.0; + double cumulatedExpected = 0.0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + final LabelStat labelStat = getLabelStat(label); + if (labelStat.getExpected() != 0) { + cumulatedTruePositive += labelStat.getObserved(); + cumulatedExpected += labelStat.getExpected(); + } + } + + double recall = 0.0; + if (cumulatedExpected != 0.0) + recall = cumulatedTruePositive / cumulatedExpected; + + return Math.min(1.0, recall); + } + + public double getMacroAverageRecall() { + int totalValidFields = 0; + double cumulatedRecall = 0.0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + final LabelStat labelStat = getLabelStat(label); + if (labelStat.getExpected() != 0) { + totalValidFields++; + cumulatedRecall += labelStat.getRecall(); + } + } + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulatedRecall / totalValidFields); + } + + public int getTotalFields() { + int totalFields = 0; + for (String label : getLabels()) { + totalFields += getLabelStat(label).getAll(); + } + + return totalFields; + } + + public double getMicroAverageF1() { + double precision = getMicroAveragePrecision(); + double recall = getMicroAverageRecall(); + + double f1 = 0.0; + if (precision + recall != 0.0) + f1 = (2 * precision * recall) / (precision + recall); + + return f1; + } + + public double getMacroAverageF1() { + double cumulatedF1 = 0.0; + int totalValidFields = 0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + final LabelStat labelStat = getLabelStat(label); + if (labelStat.getExpected() != 0) { + totalValidFields++; + double labelF1 = labelStat.getF1Score(); + cumulatedF1 += labelF1; + } + } + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulatedF1 / totalValidFields); + } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 4621a7d9bd..08373a5ca7 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -27,84 +27,82 @@ import org.slf4j.LoggerFactory; /** - * Generic evaluation of a single-CRF model processing given an expected result. - * + * Generic evaluation of a single-CRF model processing given an expected result. + * * @author Patrice Lopez */ public class EvaluationUtilities { - protected static final Logger logger = LoggerFactory.getLogger(EvaluationUtilities.class); - - /** - * Method for running a CRF tagger for evaluation purpose (i.e. with - * expected and actual labels). - * - * @param ress - * list - * @param tagger - * a tagger - * @return a report - */ - public static String taggerRun(List ress, Tagger tagger) { - // clear internal context - tagger.clear(); - StringBuilder res = new StringBuilder(); - - // we have to re-inject the pre-tags because they are removed by the JNI - // parse method - ArrayList pretags = new ArrayList(); - // add context - for (String piece : ress) { - if (piece.trim().length() == 0) { - // parse and change internal stated as 'parsed' - if (!tagger.parse()) { - // throw an exception - throw new RuntimeException("CRF++ parsing failed."); - } - - for (int i = 0; i < tagger.size(); i++) { - for (int j = 0; j < tagger.xsize(); j++) { - res.append(tagger.x(i, j)).append("\t"); - } - res.append(pretags.get(i)).append("\t"); - res.append(tagger.y2(i)); - res.append("\n"); - } - res.append(" \n"); - // clear internal context - tagger.clear(); - pretags = new ArrayList(); - } else { - tagger.add(piece); - tagger.add("\n"); - // get last tag - StringTokenizer tokenizer = new StringTokenizer(piece, " \t"); - while (tokenizer.hasMoreTokens()) { - String toke = tokenizer.nextToken(); - if (!tokenizer.hasMoreTokens()) { - pretags.add(toke); - } - } - } - } - - // parse and change internal stated as 'parsed' - if (!tagger.parse()) { - // throw an exception - throw new RuntimeException("CRF++ parsing failed."); - } - - for (int i = 0; i < tagger.size(); i++) { - for (int j = 0; j < tagger.xsize(); j++) { - res.append(tagger.x(i, j)).append("\t"); - } - res.append(pretags.get(i)).append("\t"); - res.append(tagger.y2(i)); - res.append(System.lineSeparator()); - } - res.append(System.lineSeparator()); - - return res.toString(); - } + protected static final Logger logger = LoggerFactory.getLogger(EvaluationUtilities.class); + + /** + * Method for running a CRF tagger for evaluation purpose (i.e. with + * expected and actual labels). + * + * @param ress list + * @param tagger a tagger + * @return a report + */ + public static String taggerRun(List ress, Tagger tagger) { + // clear internal context + tagger.clear(); + StringBuilder res = new StringBuilder(); + + // we have to re-inject the pre-tags because they are removed by the JNI + // parse method + ArrayList pretags = new ArrayList(); + // add context + for (String piece : ress) { + if (piece.trim().length() == 0) { + // parse and change internal stated as 'parsed' + if (!tagger.parse()) { + // throw an exception + throw new RuntimeException("CRF++ parsing failed."); + } + + for (int i = 0; i < tagger.size(); i++) { + for (int j = 0; j < tagger.xsize(); j++) { + res.append(tagger.x(i, j)).append("\t"); + } + res.append(pretags.get(i)).append("\t"); + res.append(tagger.y2(i)); + res.append("\n"); + } + res.append(" \n"); + // clear internal context + tagger.clear(); + pretags = new ArrayList(); + } else { + tagger.add(piece); + tagger.add("\n"); + // get last tag + StringTokenizer tokenizer = new StringTokenizer(piece, " \t"); + while (tokenizer.hasMoreTokens()) { + String toke = tokenizer.nextToken(); + if (!tokenizer.hasMoreTokens()) { + pretags.add(toke); + } + } + } + } + + // parse and change internal stated as 'parsed' + if (!tagger.parse()) { + // throw an exception + throw new RuntimeException("CRF++ parsing failed."); + } + + for (int i = 0; i < tagger.size(); i++) { + for (int j = 0; j < tagger.xsize(); j++) { + res.append(tagger.x(i, j)).append("\t"); + } + res.append(pretags.get(i)).append("\t"); + res.append(tagger.y2(i)); + res.append(System.lineSeparator()); + } + res.append(System.lineSeparator()); + + return res.toString(); + } public static String evaluateStandard(String path, final GenericTagger tagger) { return evaluateStandard(path, new Function, String>() { @@ -115,258 +113,257 @@ public String apply(List strings) { }); } - public static String evaluateStandard(String path, Function, String> taggerFunction) { - String theResult = null; + public static String evaluateStandard(String path, Function, String> taggerFunction) { + String theResult = null; - try { - final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8")); + try { + final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8")); - String line = null; - List citationBlocks = new ArrayList(); - while ((line = bufReader.readLine()) != null) { - citationBlocks.add(line); - } - long time = System.currentTimeMillis(); - theResult = taggerFunction.apply(citationBlocks); - bufReader.close(); + String line = null; + List citationBlocks = new ArrayList(); + while ((line = bufReader.readLine()) != null) { + citationBlocks.add(line); + } + long time = System.currentTimeMillis(); + theResult = taggerFunction.apply(citationBlocks); + bufReader.close(); System.out.println("Labeling took: " + (System.currentTimeMillis() - time) + " ms"); } catch (Exception e) { - throw new GrobidException("An exception occurred while evaluating Grobid.", e); - } + throw new GrobidException("An exception occurred while evaluating Grobid.", e); + } - return reportMetrics(theResult); - } + return reportMetrics(theResult); + } - public static String reportMetrics(String theResult) { + public static String reportMetrics(String theResult) { StringBuilder report = new StringBuilder(); - - Stats wordStats = tokenLevelStats(theResult); - - // report token-level results - report.append("\n===== Token-level results =====\n\n"); - report.append(computeMetrics(wordStats)); - - Stats fieldStats = fieldLevelStats(theResult); - - report.append("\n===== Field-level results =====\n"); - report.append(computeMetrics(fieldStats)); - - report.append("\n===== Instance-level results =====\n\n"); - // instance-level: instances are separated by a new line in the result file - // third pass - theResult = theResult.replace("\n\n", "\n \n"); - StringTokenizer stt = new StringTokenizer(theResult, "\n"); - boolean allGood = true; - int correctInstance = 0; - int totalInstance = 0; - String line = null; - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - if ((line.trim().length() == 0) || (!stt.hasMoreTokens())) { - // instance done - totalInstance++; - if (allGood) { - correctInstance++; - } - // we reinit for a new instance - allGood = true; - } else { - StringTokenizer st = new StringTokenizer(line, "\t "); - String obtainedLabel = null; - String expectedLabel = null; - while (st.hasMoreTokens()) { - obtainedLabel = getPlainLabel(st.nextToken()); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if (!obtainedLabel.equals(expectedLabel)) { - // one error is enough to have the whole instance false, damn! - allGood = false; - } - } - } - - report.append(String.format("%-27s %d\n", "Total expected instances:", totalInstance)); - report.append(String.format("%-27s %d\n", "Correct instances:", correctInstance)); - double accuracy = (double) correctInstance / (totalInstance); - report.append(String.format("%-27s %s\n", - "Instance-level recall:", - TextUtilities.formatTwoDecimals(accuracy * 100))); - - return report.toString(); - } - - public static Stats tokenLevelStats(String theResult) { - Stats wordStats = new Stats(); - String line = null; - StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - - if (line.trim().length() == 0) { - continue; - } - // the two last tokens, separated by a tabulation, gives the - // expected label and, last, the resulting label -> for Wapiti - StringTokenizer st = new StringTokenizer(line, "\t "); - String obtainedLabel = null; - String expectedLabel = null; - - while (st.hasMoreTokens()) { - obtainedLabel = getPlainLabel(st.nextToken()); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if ((expectedLabel == null) || (obtainedLabel == null)) { - continue; - } - - processCounters(wordStats, obtainedLabel, expectedLabel); - if (!obtainedLabel.equals(expectedLabel)) { - logger.warn("Disagreement / expected: " + expectedLabel + " / obtained: " + obtainedLabel); - } - } - return wordStats; - } - - public static Stats fieldLevelStats(String theResult) { - Stats fieldStats = new Stats(); - - // field: a field is simply a sequence of token with the same label - - // we build first the list of fields in expected and obtained result - // with offset positions - List> expectedFields = new ArrayList>(); - List> obtainedFields = new ArrayList>(); - StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); - String line = null; - String previousExpectedLabel = null; - String previousObtainedLabel = null; - int pos = 0; // current token index - OffsetPosition currentObtainedPosition = new OffsetPosition(); - currentObtainedPosition.start = 0; - OffsetPosition currentExpectedPosition = new OffsetPosition(); - currentExpectedPosition.start = 0; - String obtainedLabel = null; - String expectedLabel = null; - while (stt.hasMoreTokens()) { - line = stt.nextToken(); - obtainedLabel = null; - expectedLabel = null; - StringTokenizer st = new StringTokenizer(line, "\t "); - while (st.hasMoreTokens()) { - obtainedLabel = st.nextToken(); - if (st.hasMoreTokens()) { - expectedLabel = obtainedLabel; - } - } - - if ( (obtainedLabel == null) || (expectedLabel == null) ) - continue; - - if ((previousObtainedLabel != null) && - (!obtainedLabel.equals(getPlainLabel(previousObtainedLabel)))) { - // new obtained field - currentObtainedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousObtainedLabel), - currentObtainedPosition); - currentObtainedPosition = new OffsetPosition(); - currentObtainedPosition.start = pos; - obtainedFields.add(theField); - } - - if ((previousExpectedLabel != null) && - (!expectedLabel.equals(getPlainLabel(previousExpectedLabel)))) { - // new expected field - currentExpectedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousExpectedLabel), - currentExpectedPosition); - currentExpectedPosition = new OffsetPosition(); - currentExpectedPosition.start = pos; - expectedFields.add(theField); - } - - previousExpectedLabel = expectedLabel; - previousObtainedLabel = obtainedLabel; - pos++; - } - // last fields of the sequence - if ((previousObtainedLabel != null)) { - currentObtainedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousObtainedLabel), - currentObtainedPosition); - obtainedFields.add(theField); - } - - if ((previousExpectedLabel != null)) { - currentExpectedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousExpectedLabel), - currentExpectedPosition); - expectedFields.add(theField); - } - - // we then simply compared the positions and labels of the two fields and update - // statistics - int obtainedFieldIndex = 0; - List> matchedObtainedFields = new ArrayList>(); - for(Pair expectedField : expectedFields) { - expectedLabel = expectedField.getA(); - int expectedStart = expectedField.getB().start; - int expectedEnd = expectedField.getB().end; - - LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(expectedLabel)); - labelStat.incrementExpected(); - - // try to find a match in the obtained fields - boolean found = false; - for(int i=obtainedFieldIndex; i obtainedField :obtainedFields) { - if (!matchedObtainedFields.contains(obtainedField)) { - obtainedLabel = obtainedField.getA(); - LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(obtainedLabel)); - labelStat.incrementFalsePositive(); - } - } - - return fieldStats; - } - - - private static String getPlainLabel(String label) { - if (label == null) - return null; - if (label.startsWith("I-") || label.startsWith("E-") || label.startsWith("B-")) { - return label.substring(2, label.length()); - } else - return label; + + // report token-level results + Stats wordStats = tokenLevelStats(theResult); + report.append("\n===== Token-level results =====\n\n"); + report.append(computeMetrics(wordStats)); + + // report field-level results + Stats fieldStats = fieldLevelStats(theResult); + report.append("\n===== Field-level results =====\n"); + report.append(computeMetrics(fieldStats)); + + // instance-level: instances are separated by a new line in the result file + // third pass + report.append("\n===== Instance-level results =====\n\n"); + theResult = theResult.replace("\n\n", "\n \n"); + StringTokenizer stt = new StringTokenizer(theResult, "\n"); + boolean allGood = true; + int correctInstance = 0; + int totalInstance = 0; + String line = null; + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + if ((line.trim().length() == 0) || (!stt.hasMoreTokens())) { + // instance done + totalInstance++; + if (allGood) { + correctInstance++; + } + // we reinit for a new instance + allGood = true; + } else { + StringTokenizer st = new StringTokenizer(line, "\t "); + String obtainedLabel = null; + String expectedLabel = null; + while (st.hasMoreTokens()) { + obtainedLabel = getPlainLabel(st.nextToken()); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if (!obtainedLabel.equals(expectedLabel)) { + // one error is enough to have the whole instance false, damn! + allGood = false; + } + } + } + + report.append(String.format("%-27s %d\n", "Total expected instances:", totalInstance)); + report.append(String.format("%-27s %d\n", "Correct instances:", correctInstance)); + double accuracy = (double) correctInstance / (totalInstance); + report.append(String.format("%-27s %s\n", + "Instance-level recall:", + TextUtilities.formatTwoDecimals(accuracy * 100))); + + return report.toString(); + } + + public static Stats tokenLevelStats(String theResult) { + Stats wordStats = new Stats(); + String line = null; + StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + + if (line.trim().length() == 0) { + continue; + } + // the two last tokens, separated by a tabulation, gives the + // expected label and, last, the resulting label -> for Wapiti + StringTokenizer st = new StringTokenizer(line, "\t "); + String obtainedLabel = null; + String expectedLabel = null; + + while (st.hasMoreTokens()) { + obtainedLabel = getPlainLabel(st.nextToken()); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if ((expectedLabel == null) || (obtainedLabel == null)) { + continue; + } + + processCounters(wordStats, obtainedLabel, expectedLabel); + if (!obtainedLabel.equals(expectedLabel)) { + logger.warn("Disagreement / expected: " + expectedLabel + " / obtained: " + obtainedLabel); + } + } + return wordStats; + } + + public static Stats fieldLevelStats(String theResult) { + Stats fieldStats = new Stats(); + + // field: a field is simply a sequence of token with the same label + + // we build first the list of fields in expected and obtained result + // with offset positions + List> expectedFields = new ArrayList<>(); + List> obtainedFields = new ArrayList<>(); + StringTokenizer stt = new StringTokenizer(theResult, System.lineSeparator()); + String line = null; + String previousExpectedLabel = null; + String previousObtainedLabel = null; + int pos = 0; // current token index + OffsetPosition currentObtainedPosition = new OffsetPosition(); + currentObtainedPosition.start = 0; + OffsetPosition currentExpectedPosition = new OffsetPosition(); + currentExpectedPosition.start = 0; + String obtainedLabel = null; + String expectedLabel = null; + while (stt.hasMoreTokens()) { + line = stt.nextToken(); + obtainedLabel = null; + expectedLabel = null; + StringTokenizer st = new StringTokenizer(line, "\t "); + while (st.hasMoreTokens()) { + obtainedLabel = st.nextToken(); + if (st.hasMoreTokens()) { + expectedLabel = obtainedLabel; + } + } + + if ((obtainedLabel == null) || (expectedLabel == null)) + continue; + + if ((previousObtainedLabel != null) && + (!obtainedLabel.equals(getPlainLabel(previousObtainedLabel)))) { + // new obtained field + currentObtainedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), + currentObtainedPosition); + currentObtainedPosition = new OffsetPosition(); + currentObtainedPosition.start = pos; + obtainedFields.add(theField); + } + + if ((previousExpectedLabel != null) && + (!expectedLabel.equals(getPlainLabel(previousExpectedLabel)))) { + // new expected field + currentExpectedPosition.end = pos - 1; + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), + currentExpectedPosition); + currentExpectedPosition = new OffsetPosition(); + currentExpectedPosition.start = pos; + expectedFields.add(theField); + } + + previousExpectedLabel = expectedLabel; + previousObtainedLabel = obtainedLabel; + pos++; + } + // last fields of the sequence + if ((previousObtainedLabel != null)) { + currentObtainedPosition.end = pos - 1; + Pair theField = new Pair(getPlainLabel(previousObtainedLabel), + currentObtainedPosition); + obtainedFields.add(theField); + } + + if ((previousExpectedLabel != null)) { + currentExpectedPosition.end = pos - 1; + Pair theField = new Pair(getPlainLabel(previousExpectedLabel), + currentExpectedPosition); + expectedFields.add(theField); + } + + // we then simply compared the positions and labels of the two fields and update + // statistics + int obtainedFieldIndex = 0; + List> matchedObtainedFields = new ArrayList>(); + for (Pair expectedField : expectedFields) { + expectedLabel = expectedField.getA(); + int expectedStart = expectedField.getB().start; + int expectedEnd = expectedField.getB().end; + + LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(expectedLabel)); + labelStat.incrementExpected(); + + // try to find a match in the obtained fields + boolean found = false; + for (int i = obtainedFieldIndex; i < obtainedFields.size(); i++) { + obtainedLabel = obtainedFields.get(i).getA(); + if (!expectedLabel.equals(obtainedLabel)) + continue; + if ((expectedStart == obtainedFields.get(i).getB().start) && + (expectedEnd == obtainedFields.get(i).getB().end)) { + // we have a match + labelStat.incrementObserved(); // TP + found = true; + obtainedFieldIndex = i; + matchedObtainedFields.add(obtainedFields.get(i)); + break; + } + // if we went too far, we can stop the pain + if (expectedEnd < obtainedFields.get(i).getB().start) { + break; + } + } + if (!found) { + labelStat.incrementFalseNegative(); + } + } + + // all the obtained fields without match in the expected fields are false positive + for (Pair obtainedField : obtainedFields) { + if (!matchedObtainedFields.contains(obtainedField)) { + obtainedLabel = obtainedField.getA(); + LabelStat labelStat = fieldStats.getLabelStat(getPlainLabel(obtainedLabel)); + labelStat.incrementFalsePositive(); + } + } + + return fieldStats; + } + + + private static String getPlainLabel(String label) { + if (label == null) + return null; + if (label.startsWith("I-") || label.startsWith("E-") || label.startsWith("B-")) { + return label.substring(2, label.length()); + } else + return label; } private static void processCounters(Stats stats, String obtained, String expected) { @@ -383,153 +380,93 @@ private static void processCounters(Stats stats, String obtained, String expecte } } - public static String computeMetrics(Stats stats) { - StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", - "label", - "accuracy", - "precision", - "recall", - "f1")); - - int cumulated_tp = 0; - int cumulated_fp = 0; - int cumulated_tn = 0; - int cumulated_fn = 0; - double cumulated_f0 = 0.0; - double cumulated_accuracy = 0.0; - double cumulated_precision = 0.0; - double cumulated_recall = 0.0; - int cumulated_all = 0; - int totalValidFields = 0; - - int totalFields = 0; - for (String label : stats.getLabels()) { - LabelStat labelStat = stats.getLabelStat(label); - totalFields += labelStat.getObserved(); - totalFields += labelStat.getFalseNegative(); - totalFields += labelStat.getFalsePositive(); - } - - for (String label : stats.getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - - LabelStat labelStat = stats.getLabelStat(label); - int tp = labelStat.getObserved(); // true positives - int fp = labelStat.getFalsePositive(); // false positives - int fn = labelStat.getFalseNegative(); // false negative - int tn = totalFields - tp - (fp + fn); // true negatives - int all = labelStat.getExpected(); // all expected - - if (all != 0) { - totalValidFields++; - } - - double accuracy = (double) (tp + tn) / (tp + fp + tn + fn); - if (accuracy < 0.0) - accuracy = 0.0; - - double precision; - if ((tp + fp) == 0) { - precision = 0.0; - } else { - precision = (double) (tp) / (tp + fp); - } - - double recall; - if ((tp == 0) || (all == 0)) { - recall = 0.0; - } else { - recall = (double) (tp) / all; - } - - double f0; - if (precision + recall == 0) { - f0 = 0.0; - } else { - f0 = (2 * precision * recall) / (precision + recall); - } - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", - label, - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - cumulated_tp += tp; - cumulated_fp += fp; - cumulated_tn += tn; - cumulated_fn += fn; - if (all != 0) { - cumulated_all += all; - cumulated_f0 += f0; - cumulated_accuracy += accuracy; - cumulated_precision += precision; - cumulated_recall += recall; - } - } - - report.append("\n"); - - // micro average over measures - double accuracy = 0.0; - if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) - accuracy = (double) (cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); - accuracy = Math.min(1.0, accuracy); - - double precision = 0.0; - if (cumulated_tp + cumulated_fp != 0) - precision = (double) cumulated_tp / (cumulated_tp + cumulated_fp); - precision = Math.min(1.0, precision); - - //recall = ((double) cumulated_tp) / (cumulated_tp + cumulated_fn); - double recall = 0.0; - if (cumulated_all != 0.0) - recall = ((double) cumulated_tp) / (cumulated_all); - recall = Math.min(1.0, recall); - - double f0 = 0.0; - if (precision + recall != 0.0) - f0 = (2 * precision * recall) / (precision + recall); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", - "all fields", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - // macro average over measures - if (totalValidFields == 0) - accuracy = 0.0; - else - accuracy = Math.min(1.0, cumulated_accuracy / (totalValidFields)); - - if (totalValidFields == 0) - precision = 0.0; - else - precision = Math.min(1.0, cumulated_precision / totalValidFields); - - if (totalValidFields == 0) - recall = 0.0; - else - recall = Math.min(1.0, cumulated_recall / totalValidFields); - - if (totalValidFields == 0) - f0 = 0.0; - else - f0 = Math.min(1.0, cumulated_f0 / totalValidFields); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", - "", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f0 * 100))); - - return report.toString(); - } + public static String computeMetrics(Stats stats) { + StringBuilder report = new StringBuilder(); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1")); + + int cumulated_tp = 0; + int cumulated_fp = 0; + int cumulated_tn = 0; + int cumulated_fn = 0; + double cumulated_accuracy = 0.0; + + int totalValidFields = 0; + + int totalFields = stats.getTotalFields(); + + for (String label : stats.getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = stats.getLabelStat(label); + int tp = labelStat.getObserved(); // true positives + int fp = labelStat.getFalsePositive(); // false positives + int fn = labelStat.getFalseNegative(); // false negative + int tn = totalFields - tp - (fp + fn); // true negatives + int expected = labelStat.getExpected(); // all expected + + double accuracy = (double) (tp + tn) / (tp + fp + tn + fn); + if (accuracy < 0.0) + accuracy = 0.0; + + double precision = labelStat.getPrecision(); + double recall = labelStat.getRecall(); + double f1Score = labelStat.getF1Score(); + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(accuracy * 100), + TextUtilities.formatTwoDecimals(precision * 100), + TextUtilities.formatTwoDecimals(recall * 100), + TextUtilities.formatTwoDecimals(f1Score * 100))); + + if (expected != 0) { + totalValidFields++; + + cumulated_tp += tp; + cumulated_fp += fp; + cumulated_tn += tn; + cumulated_fn += fn; + + cumulated_accuracy += accuracy; + } + } + + report.append("\n"); + + // micro average over measures + double accuracy = 0.0; + if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) + accuracy = ((double) cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); + accuracy = Math.min(1.0, accuracy); + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", + "all fields", + TextUtilities.formatTwoDecimals(accuracy * 100), + TextUtilities.formatTwoDecimals(stats.getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(stats.getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(stats.getMicroAverageF1() * 100))); + + // macro average over measures + if (totalValidFields == 0) + accuracy = 0.0; + else + accuracy = Math.min(1.0, cumulated_accuracy / totalValidFields); + + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", + "", + TextUtilities.formatTwoDecimals(accuracy * 100), + TextUtilities.formatTwoDecimals(stats.getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(stats.getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(stats.getMacroAverageF1() * 100))); + + return report.toString(); + } } diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java new file mode 100644 index 0000000000..636522c90d --- /dev/null +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -0,0 +1,307 @@ +package org.grobid.trainer; + +import org.junit.Before; +import org.junit.Test; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.*; + +public class StatsTest { + Stats target; + + @Before + public void setUp() throws Exception { + target = new Stats(); + } + + @Test + public void testPrecision_fullMatch() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + assertThat(target.getLabelStat("BAO").getPrecision(), is(1.0)); + } + + @Test + public void testPrecision_noMatch() throws Exception { + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + assertThat(target.getLabelStat("MIAO").getPrecision(), is(0.0)); + } + + @Test + public void testPrecision_missingMatch() throws Exception { + // The precision stays at 1.0 because none of the observed + // is wrong (no false positives) + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + assertThat(target.getLabelStat("CIAO").getPrecision(), is(1.0)); + } + + @Test + public void testPrecision_2wronglyRecognised() throws Exception { + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(1); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getLabelStat("ZIAO").getPrecision(), is(0.3333333333333333)); + } + + @Test + public void testRecall_fullMatch() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + assertThat(target.getLabelStat("BAO").getRecall(), is(1.0)); + } + + @Test + public void testRecall_noMatch() throws Exception { + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + assertThat(target.getLabelStat("MIAO").getRecall(), is(0.0)); + } + + @Test + public void testRecall_oneOverFour() throws Exception { + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + assertThat(target.getLabelStat("CIAO").getRecall(), is(0.25)); + + } + + @Test + public void testRecall_partialMatch() throws Exception { + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(3); + target.getLabelStat("ZIAO").setFalsePositive(1); + + assertThat(target.getLabelStat("ZIAO").getPrecision(), is(0.75)); + } + + + // Average measures + + @Test + public void testMicroAvgPrecision_shouldWork() throws Exception { + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAveragePrecision(), is(((double) 4 + 0 + 1 + 0) / (4 + 0 + 3 + 1 + 0 + 2))); + } + + + @Test + public void testMacroAvgPrecision_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double precisionBao = target.getLabelStat("BAO").getPrecision(); + final double precisionMiao = target.getLabelStat("MIAO").getPrecision(); + final double precisionCiao = target.getLabelStat("CIAO").getPrecision(); + final double precisionZiao = target.getLabelStat("ZIAO").getPrecision(); + assertThat(target.getMacroAveragePrecision(), + is((precisionBao + precisionMiao + precisionCiao + precisionZiao) / (4))); + } + + + @Test + public void testMicroAvgRecall_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); //TP + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAverageRecall(), is(((double) 4 + 0 + 1 + 0) / (4 + 4 + 4 + 4))); + } + + @Test + public void testMacroAvgRecall_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double recallBao = target.getLabelStat("BAO").getRecall(); + final double recallMiao = target.getLabelStat("MIAO").getRecall(); + final double recallCiao = target.getLabelStat("CIAO").getRecall(); + final double recallZiao = target.getLabelStat("ZIAO").getRecall(); + assertThat(target.getMacroAverageRecall(), + is((recallBao + recallMiao + recallCiao + recallZiao) / (4))); + } + + @Test + public void testMicroAvgF0_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); //TP + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + assertThat(target.getMicroAverageF1(), + is(((double) 2 * target.getMicroAveragePrecision() + * target.getMicroAverageRecall()) + / (target.getMicroAveragePrecision() + target.getMicroAverageRecall()))); + } + + @Test + public void testMacroAvgF0_shouldWork() throws Exception { + + target.getLabelStat("BAO").setExpected(4); + target.getLabelStat("BAO").setObserved(4); + + target.getLabelStat("MIAO").setExpected(4); + target.getLabelStat("MIAO").setObserved(0); + target.getLabelStat("MIAO").setFalsePositive(3); + target.getLabelStat("MIAO").setFalseNegative(1); + + target.getLabelStat("CIAO").setExpected(4); + target.getLabelStat("CIAO").setObserved(1); + target.getLabelStat("CIAO").setFalseNegative(3); + + target.getLabelStat("ZIAO").setExpected(4); + target.getLabelStat("ZIAO").setObserved(0); + target.getLabelStat("ZIAO").setFalsePositive(2); + + final double f1Bao = target.getLabelStat("BAO").getRecall(); + final double f1Miao = target.getLabelStat("MIAO").getRecall(); + final double f1Ciao = target.getLabelStat("CIAO").getRecall(); + final double f1Ziao = target.getLabelStat("ZIAO").getRecall(); + + assertThat(target.getMacroAverageF1(), + is((f1Bao + f1Miao + f1Ciao + f1Ziao) / (4))); + } + + + @Test + public void testMicroMacroAveragePrecision() throws Exception { + final LabelStat conceptual = target.getLabelStat("CONCEPTUAL"); + conceptual.setFalsePositive(1); + conceptual.setFalseNegative(1); + conceptual.setObserved(2); + conceptual.setExpected(3); + + final LabelStat location = target.getLabelStat("LOCATION"); + location.setFalsePositive(0); + location.setFalseNegative(0); + location.setObserved(2); + location.setExpected(2); + + final LabelStat media = target.getLabelStat("MEDIA"); + media.setFalsePositive(0); + media.setFalseNegative(0); + media.setObserved(7); + media.setExpected(7); + + final LabelStat national = target.getLabelStat("NATIONAL"); + national.setFalsePositive(1); + national.setFalseNegative(0); + national.setObserved(0); + national.setExpected(0); + + final LabelStat other = target.getLabelStat("O"); + other.setFalsePositive(0); + other.setFalseNegative(1); + other.setObserved(33); + other.setExpected(34); + + final LabelStat organisation = target.getLabelStat("ORGANISATION"); + organisation.setFalsePositive(0); + organisation.setFalseNegative(0); + organisation.setObserved(2); + organisation.setExpected(2); + + final LabelStat period = target.getLabelStat("PERIOD"); + period.setFalsePositive(0); + period.setFalseNegative(0); + period.setObserved(8); + period.setExpected(8); + + final LabelStat person = target.getLabelStat("PERSON"); + person.setFalsePositive(1); + person.setFalseNegative(0); + person.setObserved(0); + person.setExpected(0); + + final LabelStat personType = target.getLabelStat("PERSON_TYPE"); + personType.setFalsePositive(0); + personType.setFalseNegative(1); + personType.setObserved(0); + personType.setExpected(1); + + for (String label : target.getLabels()) { + System.out.println(label + " precision --> " + target.getLabelStat(label).getPrecision()); + System.out.println(label + " recall --> " + target.getLabelStat(label).getRecall()); + } + + System.out.println(target.getMacroAveragePrecision()); + System.out.println(target.getMicroAveragePrecision()); + } + +} \ No newline at end of file diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index b753e61fb9..9ec6ed8b1d 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -13,22 +13,27 @@ public class EvaluationUtilitiesTest { @Test - public void testMetricsAllGood() throws Exception { + public void testTokenLevelStats_allGood() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<2>\nd I-<1> I-<1>\ne <1> <1>\n"; - //System.out.println(result); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat1 = wordStats.getLabelStat("<1>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); assertThat(labelstat2.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(4)); assertThat(labelstat2.getExpected(), is(1)); + } + + @Test + public void testFieldLevelStats_allGood() throws Exception { + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<2>\nd I-<1> I-<1>\ne <1> <1>\n"; + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(2)); assertThat(labelstat2.getObserved(), is(1)); @@ -37,55 +42,66 @@ public void testMetricsAllGood() throws Exception { } @Test - public void testMetricsAllFalse() throws Exception { + public void testTokenLevelStats_noMatch() throws Exception { String result = "a I-<1> I-<2>\nb <1> <2>\nc <1> I-<2>\nd <1> <2>\ne <1> <2>\n"; - //System.out.println(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getFalseNegative(), is(5)); assertThat(labelstat2.getFalsePositive(), is(5)); + } - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + @Test + public void testFieldLevelStats_noMatch() throws Exception { + String result = "a I-<1> I-<2>\nb <1> <2>\nc <1> I-<2>\nd <1> <2>\ne <1> <2>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - assertThat(labelstat1.getObserved(), is(0)); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); + + assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(1)); assertThat(labelstat2.getExpected(), is(0)); } - @Test - public void testMetricsMixed1() throws Exception { - // label of c is false - // token 80 precision for label <1>, 0 for label <2> - // field: precision and recall are 0, because the whole - // sequence abcde with label <1> does not make sub-field - // ab and de correctly label with respect to positions + @Test + public void testTokenLevelStats_mixed() throws Exception { + // label of c is false + // token 80 precision for label <1>, 0 for label <2> String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> <1>\nd I-<1> <1>\ne <1> <1>\n"; //System.out.println(result); - Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(4)); - assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat1.getFalseNegative(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(1)); + + assertThat(labelstat2.getObserved(), is(0)); + assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + } - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + @Test + public void testFieldLevelStats_mixed() throws Exception { + // field: precision and recall are 0, because the whole + // sequence abcde with label <1> does not make sub-field + // ab and de correctly label with respect to positions + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> <1>\nd I-<1> <1>\ne <1> <1>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat2.getObserved(), is(0)); @@ -94,16 +110,12 @@ public void testMetricsMixed1() throws Exception { } @Test - public void testMetricsMixed1Bis() throws Exception { - // variant of testMetricsMixed1 where the I- prefix impact the field-level results - // with field ab correctly found + public void testTokenLevelStats2_mixed() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<1>\nd I-<1> <1>\ne <1> <1>\n"; - //System.out.println(result); - Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); assertThat(labelstat2.getObserved(), is(0)); @@ -111,32 +123,36 @@ public void testMetricsMixed1Bis() throws Exception { assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat1.getFalseNegative(), is(0)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + } - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + @Test + public void testFieldLevelStats2_mixed() throws Exception { + // variant of testMetricsMixed1 where the I- prefix impact the field-level results + // with field ab correctly found + + String result = "a I-<1> I-<1>\nb <1> <1>\nc I-<2> I-<1>\nd I-<1> <1>\ne <1> <1>\n"; + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(1)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(2)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); } @Test - public void testMetricsMixed2() throws Exception { + public void testTokenLevelStats3_mixed() throws Exception { String result = "a I-<1> I-<1>\nb <1> <1>\nc <1> I-<2>\nd <1> I-<1>\ne <1> <1>\n"; //System.out.println(result); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); assertThat(labelstat2.getObserved(), is(0)); @@ -144,11 +160,18 @@ public void testMetricsMixed2() throws Exception { assertThat(labelstat2.getExpected(), is(0)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(0)); - assertThat(labelstat1.getFalsePositive(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(0)); assertThat(labelstat2.getFalsePositive(), is(1)); - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); + } + + @Test + public void testFieldLevelStats3_mixed() throws Exception { + String result = "a I-<1> I-<1>\nb <1> <1>\nc <1> I-<2>\nd <1> I-<1>\ne <1> <1>\n"; + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); assertThat(labelstat2.getObserved(), is(0)); @@ -156,36 +179,36 @@ public void testMetricsMixed2() throws Exception { assertThat(labelstat2.getExpected(), is(0)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(0)); - assertThat(labelstat1.getFalsePositive(), is(2)); + assertThat(labelstat1.getFalsePositive(), is(2)); assertThat(labelstat2.getFalsePositive(), is(1)); } @Test - public void testMetricsMixed3() throws Exception { + public void testTokenLevelStats4_mixed() throws Exception { String result = "a I-<1> I-<1>\nb I-<2> <1>\nc <2> I-<2>\nd <2> <2>\ne I-<1> I-<1>\nf <1> <1>\ng I-<2> I-<2>\n"; - //System.out.println(result); + Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat("<1>"); - LabelStat labelstat2 = wordStats.getLabelStat("<2>"); + LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(3)); - assertThat(labelstat2.getObserved(), is(3)); assertThat(labelstat1.getExpected(), is(3)); - assertThat(labelstat2.getExpected(), is(4)); assertThat(labelstat1.getFalseNegative(), is(0)); + assertThat(labelstat1.getFalsePositive(), is(1)); + + assertThat(labelstat2.getObserved(), is(3)); + assertThat(labelstat2.getExpected(), is(4)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); - - labelstat1 = fieldStats.getLabelStat("<1>"); - labelstat2 = fieldStats.getLabelStat("<2>"); - - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + } + @Test + public void testFieldLevelStats4_mixed() throws Exception { + String result = "a I-<1> I-<1>\nb I-<2> <1>\nc <2> I-<2>\nd <2> <2>\ne I-<1> I-<1>\nf <1> <1>\ng I-<2> I-<2>\n"; + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); + LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(1)); assertThat(labelstat2.getObserved(), is(1)); @@ -193,20 +216,19 @@ public void testMetricsMixed3() throws Exception { assertThat(labelstat2.getExpected(), is(2)); assertThat(labelstat1.getFalseNegative(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(1)); } @Test - public void testMetricsReal() throws Exception { - String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); + public void testTokenLevelStats_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); result = result.replace(System.lineSeparator(), "\n"); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); - Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = wordStats.getLabelStat(""); - LabelStat labelstat2 = wordStats.getLabelStat(""); + LabelStat labelstat2 = wordStats.getLabelStat(""); assertThat(labelstat1.getObserved(), is(378)); assertThat(labelstat2.getObserved(), is(6)); @@ -214,25 +236,26 @@ public void testMetricsReal() throws Exception { assertThat(labelstat2.getExpected(), is(9)); assertThat(labelstat1.getFalseNegative(), is(0)); assertThat(labelstat2.getFalseNegative(), is(3)); - assertThat(labelstat1.getFalsePositive(), is(3)); + assertThat(labelstat1.getFalsePositive(), is(3)); assertThat(labelstat2.getFalsePositive(), is(0)); - labelstat1 = fieldStats.getLabelStat(""); - labelstat2 = fieldStats.getLabelStat(""); + } + + @Test + public void testFieldLevelStats_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); - /*System.out.println(labelstat1.toString()); - System.out.println(labelstat2.toString()); - String report = EvaluationUtilities.reportMetrics(result); - System.out.println(report);*/ + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + LabelStat labelstat1 = fieldStats.getLabelStat(""); + LabelStat labelstat2 = fieldStats.getLabelStat(""); - assertThat(labelstat1.getObserved(), is(1)); + assertThat(labelstat1.getObserved(), is(1)); assertThat(labelstat2.getObserved(), is(2)); assertThat(labelstat1.getExpected(), is(3)); assertThat(labelstat2.getExpected(), is(3)); assertThat(labelstat1.getFalseNegative(), is(2)); assertThat(labelstat2.getFalseNegative(), is(1)); - assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); } - } \ No newline at end of file From 6368a6a0484b4a9e31f851d4d48c9d79d6e4a653 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 16 Nov 2017 11:11:52 +0100 Subject: [PATCH 02/37] adding more real tests and verify manual calculations --- .../java/org/grobid/trainer/StatsTest.java | 84 +++++++++++++++++++ .../evaluation/EvaluationUtilitiesTest.java | 68 +++++++++++++-- ...{ex.txt.txt => sample.wapiti.output.1.txt} | 0 3 files changed, 147 insertions(+), 5 deletions(-) rename grobid-trainer/src/test/resources/{ex.txt.txt => sample.wapiti.output.1.txt} (100%) diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index 636522c90d..2c00f68353 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -1,5 +1,6 @@ package org.grobid.trainer; +import org.hamcrest.CoreMatchers; import org.junit.Before; import org.junit.Test; @@ -304,4 +305,87 @@ public void testMicroMacroAveragePrecision() throws Exception { System.out.println(target.getMicroAveragePrecision()); } + + @Test + public void testMicroMacroAverageMeasures_realTest() throws Exception { + LabelStat otherLabelStats = target.getLabelStat("O"); + otherLabelStats.setExpected(33); + otherLabelStats.setObserved(33); + otherLabelStats.setFalseNegative(1); + + assertThat(otherLabelStats.getPrecision(), is(1.0)); + assertThat(otherLabelStats.getRecall(), is(1.0)); + + + LabelStat conceptualLabelStats = target.getLabelStat("CONCEPTUAL"); + conceptualLabelStats.setObserved(2); + conceptualLabelStats.setExpected(3); + conceptualLabelStats.setFalseNegative(1); + conceptualLabelStats.setFalsePositive(1); + + assertThat(conceptualLabelStats.getPrecision(), is(0.6666666666666666)); + assertThat(conceptualLabelStats.getRecall(), is(0.6666666666666666)); + + + LabelStat periodLabelStats = target.getLabelStat("PERIOD"); + periodLabelStats.setObserved(8); + periodLabelStats.setExpected(8); + + assertThat(periodLabelStats.getPrecision(), is(1.0)); + assertThat(periodLabelStats.getRecall(), is(1.0)); + + + LabelStat mediaLabelStats = target.getLabelStat("MEDIA"); + mediaLabelStats.setObserved(7); + mediaLabelStats.setExpected(7); + + assertThat(mediaLabelStats.getPrecision(), is(1.0)); + assertThat(mediaLabelStats.getRecall(), is(1.0)); + + + LabelStat personTypeLabelStats = target.getLabelStat("PERSON_TYPE"); + personTypeLabelStats.setObserved(0); + personTypeLabelStats.setExpected(1); + personTypeLabelStats.setFalseNegative(1); + + assertThat(personTypeLabelStats.getPrecision(), is(0.0)); + assertThat(personTypeLabelStats.getRecall(), is(0.0)); + + + + LabelStat locationTypeLabelStats = target.getLabelStat("LOCATION"); + locationTypeLabelStats.setObserved(2); + locationTypeLabelStats.setExpected(2); + + assertThat(locationTypeLabelStats.getPrecision(), is(1.0)); + assertThat(locationTypeLabelStats.getRecall(), is(1.0)); + + + LabelStat organisationTypeLabelStats = target.getLabelStat("ORGANISATION"); + organisationTypeLabelStats.setObserved(2); + organisationTypeLabelStats.setExpected(2); + + assertThat(locationTypeLabelStats.getPrecision(), is(1.0)); + assertThat(locationTypeLabelStats.getRecall(), is(1.0)); + + + LabelStat personLabelStats = target.getLabelStat("PERSON"); + personLabelStats.setFalsePositive(1); + + assertThat(personLabelStats.getPrecision(), is(0.0)); + assertThat(personLabelStats.getRecall(), is(0.0)); + + // 2+8+2+2+7 / (2+8+2+2+7+1) + assertThat(target.getMicroAverageRecall(), is(0.9130434782608695)); //91.3 + + // 2+8+2+2+7 / (3+8+7+2+2+1) + assertThat(target.getMicroAveragePrecision(), is(0.9545454545454546)); //95.45 + + // 0.66 + 1.0 + 1.0 + 0.0 + 1.0 + 1.0 / 6 + assertThat(target.getMacroAverageRecall(), is(0.7777777777777777)); //77.78 + + // same as above + assertThat(target.getMacroAveragePrecision(), is(0.7777777777777777)); //77.78 + } + } \ No newline at end of file diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index 9ec6ed8b1d..40130ce6c7 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -140,7 +140,7 @@ public void testFieldLevelStats2_mixed() throws Exception { assertThat(labelstat1.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(2)); - + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); } @@ -196,16 +196,17 @@ public void testTokenLevelStats4_mixed() throws Exception { assertThat(labelstat1.getExpected(), is(3)); assertThat(labelstat1.getFalseNegative(), is(0)); assertThat(labelstat1.getFalsePositive(), is(1)); - + assertThat(labelstat2.getObserved(), is(3)); assertThat(labelstat2.getExpected(), is(4)); assertThat(labelstat2.getFalseNegative(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); } + @Test public void testFieldLevelStats4_mixed() throws Exception { String result = "a I-<1> I-<1>\nb I-<2> <1>\nc <2> I-<2>\nd <2> <2>\ne I-<1> I-<1>\nf <1> <1>\ng I-<2> I-<2>\n"; - + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = fieldStats.getLabelStat("<1>"); LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); @@ -222,7 +223,7 @@ public void testFieldLevelStats4_mixed() throws Exception { @Test public void testTokenLevelStats_realCase() throws Exception { - String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.1.txt"), StandardCharsets.UTF_8); result = result.replace(System.lineSeparator(), "\n"); Stats wordStats = EvaluationUtilities.tokenLevelStats(result); @@ -243,7 +244,7 @@ public void testTokenLevelStats_realCase() throws Exception { @Test public void testFieldLevelStats_realCase() throws Exception { - String result = IOUtils.toString(this.getClass().getResourceAsStream("/ex.txt.txt"), StandardCharsets.UTF_8); + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.1.txt"), StandardCharsets.UTF_8); Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); LabelStat labelstat1 = fieldStats.getLabelStat(""); @@ -258,4 +259,61 @@ public void testFieldLevelStats_realCase() throws Exception { assertThat(labelstat1.getFalsePositive(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); } + + @Test + public void testTokenLevelStats2_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.2.txt"), StandardCharsets.UTF_8); + + Stats stats = EvaluationUtilities.tokenLevelStats(result); + + LabelStat conceptualLabelStats = stats.getLabelStat("CONCEPTUAL"); + assertThat(conceptualLabelStats.getObserved(), is(2)); + assertThat(conceptualLabelStats.getExpected(), is(3)); + assertThat(conceptualLabelStats.getFalseNegative(), is(1)); + assertThat(conceptualLabelStats.getFalsePositive(), is(1)); + + LabelStat periodLabelStats = stats.getLabelStat("PERIOD"); + assertThat(periodLabelStats.getObserved(), is(8)); + assertThat(periodLabelStats.getExpected(), is(8)); + assertThat(periodLabelStats.getFalseNegative(), is(0)); + assertThat(periodLabelStats.getFalsePositive(), is(0)); + + LabelStat mediaLabelStats = stats.getLabelStat("MEDIA"); + assertThat(mediaLabelStats.getObserved(), is(7)); + assertThat(mediaLabelStats.getExpected(), is(7)); + assertThat(mediaLabelStats.getFalseNegative(), is(0)); + assertThat(mediaLabelStats.getFalsePositive(), is(0)); + + LabelStat personTypeLabelStats = stats.getLabelStat("PERSON_TYPE"); + assertThat(personTypeLabelStats.getObserved(), is(0)); + assertThat(personTypeLabelStats.getExpected(), is(1)); + assertThat(personTypeLabelStats.getFalseNegative(), is(1)); + assertThat(personTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat locationTypeLabelStats = stats.getLabelStat("LOCATION"); + assertThat(locationTypeLabelStats.getObserved(), is(2)); + assertThat(locationTypeLabelStats.getExpected(), is(2)); + assertThat(locationTypeLabelStats.getFalseNegative(), is(0)); + assertThat(locationTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat organisationTypeLabelStats = stats.getLabelStat("ORGANISATION"); + assertThat(organisationTypeLabelStats.getObserved(), is(2)); + assertThat(organisationTypeLabelStats.getExpected(), is(2)); + assertThat(organisationTypeLabelStats.getFalseNegative(), is(0)); + assertThat(organisationTypeLabelStats.getFalsePositive(), is(0)); + + LabelStat otherLabelStats = stats.getLabelStat("O"); + assertThat(otherLabelStats.getObserved(), is(33)); + assertThat(otherLabelStats.getExpected(), is(34)); + assertThat(otherLabelStats.getFalseNegative(), is(1)); + assertThat(otherLabelStats.getFalsePositive(), is(0)); + + LabelStat personLabelStats = stats.getLabelStat("PERSON"); + assertThat(personLabelStats.getObserved(), is(0)); + assertThat(personLabelStats.getExpected(), is(0)); + assertThat(personLabelStats.getFalseNegative(), is(0)); + assertThat(personLabelStats.getFalsePositive(), is(1)); + + + } } \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/ex.txt.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.1.txt similarity index 100% rename from grobid-trainer/src/test/resources/ex.txt.txt rename to grobid-trainer/src/test/resources/sample.wapiti.output.1.txt From 34134de6f6494003a5ac4caa608474039834290f Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 16 Nov 2017 14:24:18 +0100 Subject: [PATCH 03/37] as usual, the forgotten file --- .../test/resources/sample.wapiti.output.2.txt | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 grobid-trainer/src/test/resources/sample.wapiti.output.2.txt diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt new file mode 100644 index 0000000000..f0e9f4bd5b --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.2.txt @@ -0,0 +1,57 @@ +Articles articles A Ar Art Arti Artic s es les cles icles INITCAP NODIGIT 0 1 0 0 0 0 0 0 0 0 0 Xxxx Xx O O +from from f fr fro from from m om rom from from NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +European european E Eu Eur Euro Europ n an ean pean opean INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-CONCEPTUAL B-NATIONAL +Zionist zionist Z Zi Zio Zion Zioni t st ist nist onist INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx CONCEPTUAL B-CONCEPTUAL +newspapers newspapers n ne new news newsp s rs ers pers apers NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +1901 1901 1 19 190 1901 1901 1 01 901 1901 1901 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d B-PERIOD B-PERIOD +- - - - - - - - - - - - ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 - - PERIOD PERIOD +1931 1931 1 19 193 1931 1931 1 31 931 1931 1931 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +Section section S Se Sec Sect Secti n on ion tion ction INITCAP NODIGIT 0 1 0 0 0 0 0 1 0 0 0 Xxxx Xx O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +L l L L L L L L L L L L ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 1 0 X X B-MEDIA B-MEDIA +' ' ' ' ' ' ' ' ' ' ' ' ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 1 0 ' ' MEDIA MEDIA +echo echo e ec ech echo echo o ho cho echo echo NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 1 0 xxxx x MEDIA MEDIA +Sioniste sioniste S Si Sio Sion Sioni e te ste iste niste INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx MEDIA MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +with with w wi wit with with h th ith with with NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +reports reports r re rep repo repor s ts rts orts ports NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +on on o on on on on n on on on on NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Jewish jewish J Je Jew Jewi Jewis h sh ish wish ewish INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-PERSON_TYPE B-CONCEPTUAL +problems problems p pr pro prob probl s ms ems lems blems NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +in in i in in in in n in in in in NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 0 0 xx x O O +Bulgaria bulgaria B Bu Bul Bulg Bulga a ia ria aria garia INITCAP NODIGIT 0 0 0 0 1 0 0 1 0 0 0 Xxxx Xx B-LOCATION B-LOCATION +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +May may M Ma May May May y ay May May May INITCAP NODIGIT 0 1 1 0 0 0 1 1 0 0 0 Xxx Xx B-PERIOD B-PERIOD +1901 1901 1 19 190 1901 1901 1 01 901 1901 1901 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +; ; ; ; ; ; ; ; ; ; ; ; ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 ; ; O O +Page page P Pa Pag Page Page e ge age Page Page INITCAP NODIGIT 1 1 1 0 0 0 0 1 0 0 0 Xxxx Xx O B-PERSON +from from f fr fro from from m om rom from from NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +Die die D Di Die Die Die e ie Die Die Die INITCAP NODIGIT 0 1 0 0 0 0 0 1 0 0 0 Xxx Xx B-MEDIA B-MEDIA +Arbeit arbeit A Ar Arb Arbe Arbei t it eit beit rbeit INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx MEDIA MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Hapoel hapoel H Ha Hap Hapo Hapoe l el oel poel apoel INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 1 0 Xxxx Xx B-ORGANISATION B-ORGANISATION +Hatzair hatzair H Ha Hat Hatz Hatza r ir air zair tzair INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx ORGANISATION ORGANISATION +, , , , , , , , , , , , ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 , , O O +15 15 1 15 15 15 15 5 15 15 15 15 NOCAPS ALLDIGIT 0 0 0 0 0 0 0 0 0 0 0 dd d B-PERIOD B-PERIOD +February february F Fe Feb Febr Febru y ry ary uary ruary INITCAP NODIGIT 0 0 0 0 0 0 1 0 0 0 0 Xxxx Xx PERIOD PERIOD +1920 1920 1 19 192 1920 1920 0 20 920 1920 1920 NOCAPS ALLDIGIT 0 0 0 0 0 1 0 0 0 0 0 dddd d PERIOD PERIOD +; ; ; ; ; ; ; ; ; ; ; ; ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 ; ; O O +Issue issue I Is Iss Issu Issue e ue sue ssue Issue INITCAP NODIGIT 0 1 0 0 0 0 0 0 0 0 0 Xxxx Xx O O +of of o of of of of f of of of of NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Haolam haolam H Ha Hao Haol Haola m am lam olam aolam INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-MEDIA B-MEDIA +newspaper newspaper n ne new news newsp r er per aper paper NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 1 0 xxxx x O O +with with w wi wit with with h th ith with with NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +excerpts excerpts e ex exc exce excer s ts pts rpts erpts NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxxx x O O +on on o on on on on n on on on on NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xx x O O +Zionist zionist Z Zi Zio Zion Zioni t st ist nist onist INITCAP NODIGIT 0 0 0 0 0 0 0 0 0 0 0 Xxxx Xx B-CONCEPTUAL B-CONCEPTUAL +events events e ev eve even event s ts nts ents vents NOCAPS NODIGIT 1 1 0 0 0 0 0 0 0 0 0 xxxx x O O +in in i in in in in n in in in in NOCAPS NODIGIT 0 1 1 0 0 0 0 1 0 0 0 xx x O O +Europe europe E Eu Eur Euro Europ e pe ope rope urope INITCAP NODIGIT 0 0 0 0 0 0 0 1 0 0 0 Xxxx Xx B-LOCATION B-LOCATION +at at a at at at at t at at at at NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xx x O O +the the t th the the the e he the the the NOCAPS NODIGIT 0 1 0 0 0 0 0 0 0 0 0 xxx x O O +time time t ti tim time time e me ime time time NOCAPS NODIGIT 0 1 0 0 0 0 0 1 0 0 0 xxxx x O O +. . . . . . . . . . . . ALLCAPS NODIGIT 0 0 0 0 0 0 0 0 0 0 0 . . O O \ No newline at end of file From 8074cf8be7db5156308d487382e6e75f86797af4 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 16 Nov 2017 17:18:36 +0100 Subject: [PATCH 04/37] disable test not testing anything special --- grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index 2c00f68353..6fbc593c0b 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -2,6 +2,7 @@ import org.hamcrest.CoreMatchers; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import static org.hamcrest.Matchers.is; @@ -240,6 +241,7 @@ public void testMacroAvgF0_shouldWork() throws Exception { } + @Ignore("Not really useful") @Test public void testMicroMacroAveragePrecision() throws Exception { final LabelStat conceptual = target.getLabelStat("CONCEPTUAL"); From 2107af1311fd521c11424e5de19d3753e822bca8 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 16 Nov 2017 17:29:51 +0100 Subject: [PATCH 05/37] fix broken test --- .../src/test/java/org/grobid/trainer/StatsTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index 6fbc593c0b..586ec01246 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -231,10 +231,10 @@ public void testMacroAvgF0_shouldWork() throws Exception { target.getLabelStat("ZIAO").setObserved(0); target.getLabelStat("ZIAO").setFalsePositive(2); - final double f1Bao = target.getLabelStat("BAO").getRecall(); - final double f1Miao = target.getLabelStat("MIAO").getRecall(); - final double f1Ciao = target.getLabelStat("CIAO").getRecall(); - final double f1Ziao = target.getLabelStat("ZIAO").getRecall(); + final double f1Bao = target.getLabelStat("BAO").getF1Score(); + final double f1Miao = target.getLabelStat("MIAO").getF1Score(); + final double f1Ciao = target.getLabelStat("CIAO").getF1Score(); + final double f1Ziao = target.getLabelStat("ZIAO").getF1Score(); assertThat(target.getMacroAverageF1(), is((f1Bao + f1Miao + f1Ciao + f1Ziao) / (4))); From 99d3417956a222c07bcdd2ab8da7a1d0d49b8297 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 17 Nov 2017 11:52:30 +0100 Subject: [PATCH 06/37] refactoring to improve the efficiency when metrics are computed --- .../java/org/grobid/trainer/LabelStat.java | 34 ++ .../main/java/org/grobid/trainer/Stats.java | 238 -------------- .../evaluation/EndToEndEvaluation.java | 8 +- .../evaluation/EvaluationUtilities.java | 99 +----- .../org/grobid/trainer/evaluation/Stats.java | 308 ++++++++++++++++++ .../java/org/grobid/trainer/StatsTest.java | 6 +- .../evaluation/EvaluationUtilitiesTest.java | 1 - 7 files changed, 350 insertions(+), 344 deletions(-) delete mode 100644 grobid-trainer/src/main/java/org/grobid/trainer/Stats.java create mode 100644 grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java index 71fdb5bc1b..4f17e32877 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java @@ -6,36 +6,48 @@ public final class LabelStat { private int observed = 0; // this is true positives private int expected = 0; // total expected number of items with this label + private double accuracy = 0.0; + private int trueNegative; + private boolean hasChanged = false; + public void incrementFalseNegative() { this.incrementFalseNegative(1); + hasChanged = true; } public void incrementFalsePositive() { this.incrementFalsePositive(1); + hasChanged = true; } public void incrementObserved() { this.incrementObserved(1); + hasChanged = true; } public void incrementExpected() { this.incrementExpected(1); + hasChanged = true; } public void incrementFalseNegative(int count) { this.falseNegative += count; + hasChanged = true; } public void incrementFalsePositive(int count) { this.falsePositive += count; + hasChanged = true; } public void incrementObserved(int count) { this.observed += count; + hasChanged = true; } public void incrementExpected(int count) { this.expected += count; + hasChanged = true; } public int getExpected() { @@ -60,24 +72,36 @@ public int getAll() { public void setFalsePositive(int falsePositive) { this.falsePositive = falsePositive; + hasChanged = true; } public void setFalseNegative(int falseNegative) { this.falseNegative = falseNegative; + hasChanged = true; } public void setObserved(int observed) { this.observed = observed; + hasChanged = true; } public void setExpected(int expected) { this.expected = expected; + hasChanged = true; } public static LabelStat create() { return new LabelStat(); } + public double getAccuracy() { + double accuracy = (double) (observed + trueNegative) / (observed + falsePositive + trueNegative + falseNegative); + if (accuracy < 0.0) + accuracy = 0.0; + + return accuracy; + } + public double getPrecision() { if (observed == 0.0) { return 0.0; @@ -111,4 +135,14 @@ public String toString() { .append("; expected: ").append(expected); return builder.toString(); } + + public void setTrueNegative(int trueNegative) { + this.trueNegative = trueNegative; + } + + public boolean hasChanged() { + Boolean oldValue = new Boolean(hasChanged); + hasChanged = false; + return oldValue; + } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java deleted file mode 100644 index fb2ccc09a4..0000000000 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Stats.java +++ /dev/null @@ -1,238 +0,0 @@ -package org.grobid.trainer; - -import java.util.Set; -import java.util.TreeMap; - -import org.grobid.core.exceptions.*; - -public final class Stats { - private final TreeMap labelStats; - - public Stats() { - this.labelStats = new TreeMap<>(); - } - - public Set getLabels() { - return this.labelStats.keySet(); - } - - public void incrementFalsePositive(String label) { - this.incrementFalsePositive(label, 1); - } - - public void incrementFalsePositive(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementFalsePositive(count); - } - - public void incrementFalseNegative(String label) { - this.incrementFalseNegative(label, 1); - } - - public void incrementFalseNegative(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementFalseNegative(count); - } - - public void incrementObserved(String label) { - this.incrementObserved(label, 1); - } - - public void incrementObserved(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementObserved(count); - } - - public void incrementExpected(String label) { - this.incrementExpected(label, 1); - } - - public void incrementExpected(String label, int count) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - labelStat.incrementExpected(count); - } - - public LabelStat getLabelStat(String label) { - if (this.labelStats.containsKey(label)) { - return this.labelStats.get(label); - } - - LabelStat labelStat = LabelStat.create(); - this.labelStats.put(label, labelStat); - - return labelStat; - } - - public int size() { - return this.labelStats.size(); - } - - public double getPrecision(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getPrecision(); - } - - public double getRecall(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getRecall(); - } - - public double getF1Score(String label) { - LabelStat labelStat = this.getLabelStat(label); - if (labelStat == null) - throw new GrobidException("Unknown label: " + label); - return labelStat.getF1Score(); - } - - /** - * Return the micro average precision, which is the precision calculated - * on the cumulation of the TP and FP over the whole data - */ - public double getMicroAveragePrecision() { - double cumulatedTruePositive = 0.0; - double cumulatedFalsePositive = 0.0; - - for (String label : getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - final LabelStat labelStat = getLabelStat(label); - if (labelStat.getExpected() != 0) { - cumulatedTruePositive += labelStat.getObserved(); - cumulatedFalsePositive += labelStat.getFalsePositive(); - } - } - - double precision = 0.0; - if (cumulatedTruePositive + cumulatedFalsePositive != 0) - precision = cumulatedTruePositive / (cumulatedTruePositive + cumulatedFalsePositive); - - return Math.min(1.0, precision); - } - - /** - * Calculate the macro average precision, which is the - * mean average among all the precision for each label - */ - public double getMacroAveragePrecision() { - int totalValidFields = 0; - double cumulated_precision = 0.0; - - for (String label : getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - final LabelStat labelStat = getLabelStat(label); - if (labelStat.getExpected() != 0) { - totalValidFields++; - double labelPrecision = labelStat.getPrecision(); - cumulated_precision += labelPrecision; - } - } - - if (totalValidFields == 0) - return 0.0; - - return Math.min(1.0, cumulated_precision / totalValidFields); - } - - public double getMicroAverageRecall() { - double cumulatedTruePositive = 0.0; - double cumulatedExpected = 0.0; - - for (String label : getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - final LabelStat labelStat = getLabelStat(label); - if (labelStat.getExpected() != 0) { - cumulatedTruePositive += labelStat.getObserved(); - cumulatedExpected += labelStat.getExpected(); - } - } - - double recall = 0.0; - if (cumulatedExpected != 0.0) - recall = cumulatedTruePositive / cumulatedExpected; - - return Math.min(1.0, recall); - } - - public double getMacroAverageRecall() { - int totalValidFields = 0; - double cumulatedRecall = 0.0; - - for (String label : getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - - final LabelStat labelStat = getLabelStat(label); - if (labelStat.getExpected() != 0) { - totalValidFields++; - cumulatedRecall += labelStat.getRecall(); - } - } - - if (totalValidFields == 0) - return 0.0; - - return Math.min(1.0, cumulatedRecall / totalValidFields); - } - - public int getTotalFields() { - int totalFields = 0; - for (String label : getLabels()) { - totalFields += getLabelStat(label).getAll(); - } - - return totalFields; - } - - public double getMicroAverageF1() { - double precision = getMicroAveragePrecision(); - double recall = getMicroAverageRecall(); - - double f1 = 0.0; - if (precision + recall != 0.0) - f1 = (2 * precision * recall) / (precision + recall); - - return f1; - } - - public double getMacroAverageF1() { - double cumulatedF1 = 0.0; - int totalValidFields = 0; - - for (String label : getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - - final LabelStat labelStat = getLabelStat(label); - if (labelStat.getExpected() != 0) { - totalValidFields++; - double labelF1 = labelStat.getF1Score(); - cumulatedF1 += labelF1; - } - } - - if (totalValidFields == 0) - return 0.0; - - return Math.min(1.0, cumulatedF1 / totalValidFields); - } -} - diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java index 71b406d6c5..d132e0af8c 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EndToEndEvaluation.java @@ -1,17 +1,11 @@ package org.grobid.trainer.evaluation; import org.grobid.core.engines.config.GrobidAnalysisConfig; -import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.exceptions.*; import org.grobid.core.engines.Engine; -import org.grobid.core.data.BiblioItem; -import org.grobid.core.data.BibDataSet; import org.grobid.core.factory.GrobidFactory; import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.utilities.UnicodeUtil; -import org.grobid.trainer.Stats; -import org.grobid.trainer.sax.NLMHeaderSaxHandler; -import org.grobid.trainer.sax.FieldExtractSaxHandler; import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.utilities.NamespaceContextMap; import org.grobid.trainer.evaluation.utilities.FieldSpecification; @@ -26,7 +20,7 @@ import javax.xml.xpath.XPathFactory; import javax.xml.parsers.*; import org.xml.sax.*; -import org.xml.sax.helpers.*; + import javax.xml.xpath.XPathConstants; import com.rockymadden.stringmetric.similarity.RatcliffObershelpMetric; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 08373a5ca7..501e361648 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -5,23 +5,17 @@ import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.exceptions.GrobidException; import org.grobid.core.utilities.TextUtilities; -import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.utilities.OffsetPosition; import org.grobid.core.utilities.Pair; -import org.grobid.core.engines.tagging.GrobidCRFEngine; import java.io.BufferedReader; import java.io.FileInputStream; -import java.io.File; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; import org.grobid.trainer.LabelStat; -import org.grobid.trainer.Stats; - -import org.apache.commons.io.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -142,12 +136,12 @@ public static String reportMetrics(String theResult) { // report token-level results Stats wordStats = tokenLevelStats(theResult); report.append("\n===== Token-level results =====\n\n"); - report.append(computeMetrics(wordStats)); + report.append(wordStats.getReport()); // report field-level results Stats fieldStats = fieldLevelStats(theResult); report.append("\n===== Field-level results =====\n"); - report.append(computeMetrics(fieldStats)); + report.append(fieldStats.getReport()); // instance-level: instances are separated by a new line in the result file // third pass @@ -381,92 +375,7 @@ private static void processCounters(Stats stats, String obtained, String expecte } public static String computeMetrics(Stats stats) { - StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", - "label", - "accuracy", - "precision", - "recall", - "f1")); - - int cumulated_tp = 0; - int cumulated_fp = 0; - int cumulated_tn = 0; - int cumulated_fn = 0; - double cumulated_accuracy = 0.0; - - int totalValidFields = 0; - - int totalFields = stats.getTotalFields(); - - for (String label : stats.getLabels()) { - if (label.equals("") || label.equals("base") || label.equals("O")) { - continue; - } - - LabelStat labelStat = stats.getLabelStat(label); - int tp = labelStat.getObserved(); // true positives - int fp = labelStat.getFalsePositive(); // false positives - int fn = labelStat.getFalseNegative(); // false negative - int tn = totalFields - tp - (fp + fn); // true negatives - int expected = labelStat.getExpected(); // all expected - - double accuracy = (double) (tp + tn) / (tp + fp + tn + fn); - if (accuracy < 0.0) - accuracy = 0.0; - - double precision = labelStat.getPrecision(); - double recall = labelStat.getRecall(); - double f1Score = labelStat.getF1Score(); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", - label, - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(precision * 100), - TextUtilities.formatTwoDecimals(recall * 100), - TextUtilities.formatTwoDecimals(f1Score * 100))); - - if (expected != 0) { - totalValidFields++; - - cumulated_tp += tp; - cumulated_fp += fp; - cumulated_tn += tn; - cumulated_fn += fn; - - cumulated_accuracy += accuracy; - } - } - - report.append("\n"); - - // micro average over measures - double accuracy = 0.0; - if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) - accuracy = ((double) cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); - accuracy = Math.min(1.0, accuracy); - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", - "all fields", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(stats.getMicroAveragePrecision() * 100), - TextUtilities.formatTwoDecimals(stats.getMicroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(stats.getMicroAverageF1() * 100))); - - // macro average over measures - if (totalValidFields == 0) - accuracy = 0.0; - else - accuracy = Math.min(1.0, cumulated_accuracy / totalValidFields); - - - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", - "", - TextUtilities.formatTwoDecimals(accuracy * 100), - TextUtilities.formatTwoDecimals(stats.getMacroAveragePrecision() * 100), - TextUtilities.formatTwoDecimals(stats.getMacroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(stats.getMacroAverageF1() * 100))); - - return report.toString(); + return stats.getReport(); } + } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java new file mode 100644 index 0000000000..dafb7ef3f1 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -0,0 +1,308 @@ +package org.grobid.trainer.evaluation; + +import java.util.Set; +import java.util.TreeMap; + +import org.grobid.core.exceptions.*; +import org.grobid.core.utilities.TextUtilities; +import org.grobid.trainer.LabelStat; + +public final class Stats { + private final TreeMap labelStats; + + // State variable to know whether is required to recompute the statistics + private boolean requiredToRecomputeMetrics = true; + + private double cumulated_tp = 0; + private double cumulated_fp = 0; + private double cumulated_tn = 0; + private double cumulated_fn = 0; + private double cumulated_f1 = 0.0; + private double cumulated_accuracy = 0.0; + private double cumulated_precision = 0.0; + private double cumulated_recall = 0.0; + private double cumulated_expected = 0; + private int totalValidFields = 0; + + public Stats() { + this.labelStats = new TreeMap<>(); + } + + public Set getLabels() { + return this.labelStats.keySet(); + } + + public void incrementFalsePositive(String label) { + this.incrementFalsePositive(label, 1); + } + + public void incrementFalsePositive(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementFalsePositive(count); + requiredToRecomputeMetrics = true; + } + + public void incrementFalseNegative(String label) { + this.incrementFalseNegative(label, 1); + } + + public void incrementFalseNegative(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementFalseNegative(count); + requiredToRecomputeMetrics = true; + } + + public void incrementObserved(String label) { + this.incrementObserved(label, 1); + } + + public void incrementObserved(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementObserved(count); + requiredToRecomputeMetrics = true; + } + + public void incrementExpected(String label) { + this.incrementExpected(label, 1); + } + + public void incrementExpected(String label, int count) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + labelStat.incrementExpected(count); + requiredToRecomputeMetrics = true; + } + + public LabelStat getLabelStat(String label) { + if (this.labelStats.containsKey(label)) { + return this.labelStats.get(label); + } + + LabelStat labelStat = LabelStat.create(); + this.labelStats.put(label, labelStat); + requiredToRecomputeMetrics = true; + + return labelStat; + } + + public int size() { + return this.labelStats.size(); + } + + public double getPrecision(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getPrecision(); + } + + public double getRecall(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getRecall(); + } + + public double getF1Score(String label) { + LabelStat labelStat = this.getLabelStat(label); + if (labelStat == null) + throw new GrobidException("Unknown label: " + label); + return labelStat.getF1Score(); + } + + /** + * In order to compute metrics in an efficient way, they are computed all at the same time. + * Since the state of the object is important in this case, it's required to have a flag that + * allow the recompute of the metrics when one is required. + */ + public void computeMetrics() { + for (String label : getLabels()) { + if (getLabelStat(label).hasChanged()) { + requiredToRecomputeMetrics = true; + break; + } + } + + if (!requiredToRecomputeMetrics) + return; + + int totalFields = 0; + for (String label : getLabels()) { + LabelStat labelStat = getLabelStat(label); + totalFields += labelStat.getObserved(); + totalFields += labelStat.getFalseNegative(); + totalFields += labelStat.getFalsePositive(); + } + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + int tp = labelStat.getObserved(); // true positives + int fp = labelStat.getFalsePositive(); // false positives + int fn = labelStat.getFalseNegative(); // false negative + int tn = totalFields - tp - (fp + fn); // true negatives + labelStat.setTrueNegative(tn); + int expected = labelStat.getExpected(); // all expected + + if (expected != 0) { + totalValidFields++; + } + + if (expected != 0) { + cumulated_tp += tp; + cumulated_fp += fp; + cumulated_tn += tn; + cumulated_fn += fn; + + cumulated_expected += expected; + cumulated_f1 += labelStat.getF1Score(); + cumulated_accuracy += labelStat.getAccuracy(); + cumulated_precision += labelStat.getPrecision(); + cumulated_recall += labelStat.getRecall(); + } + } + + requiredToRecomputeMetrics = false; + } + + public String getReport() { + computeMetrics(); + + StringBuilder report = new StringBuilder(); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1")); + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), + TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), + TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), + TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100))); + } + + report.append("\n"); + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", + "all fields", + TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100))); + + report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", + "", + TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100))); + + return report.toString(); + } + + + public double getMicroAverageAccuracy() { + computeMetrics(); + + // macro average over measures + if (totalValidFields == 0) + return 0.0; + else + return Math.min(1.0, cumulated_accuracy / totalValidFields); + } + + public double getMacroAverageAccuracy() { + computeMetrics(); + + double accuracy = 0.0; + if (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn != 0.0) + accuracy = ((double) cumulated_tp + cumulated_tn) / (cumulated_tp + cumulated_fp + cumulated_tn + cumulated_fn); + + return Math.min(1.0, accuracy); + } + + + public double getMicroAveragePrecision() { + computeMetrics(); + + double precision = 0.0; + if (cumulated_tp + cumulated_fp != 0) + precision = cumulated_tp / (cumulated_tp + cumulated_fp); + + return Math.min(1.0, precision); + } + + public double getMacroAveragePrecision() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_precision / totalValidFields); + } + + public double getMicroAverageRecall() { + computeMetrics(); + + double recall = 0.0; + if (cumulated_expected != 0.0) + recall = cumulated_tp / cumulated_expected; + + return Math.min(1.0, recall); + } + + public double getMacroAverageRecall() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_recall / totalValidFields); + } + + public int getTotalValidFields() { + computeMetrics(); + return totalValidFields; + } + + public double getMicroAverageF1() { + double precision = getMicroAveragePrecision(); + double recall = getMicroAverageRecall(); + + double f1 = 0.0; + if (precision + recall != 0.0) + f1 = (2 * precision * recall) / (precision + recall); + + return f1; + } + + public double getMacroAverageF1() { + computeMetrics(); + + if (totalValidFields == 0) + return 0.0; + + return Math.min(1.0, cumulated_f1 / totalValidFields); + } +} + diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index 586ec01246..fd0e403fae 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -1,12 +1,12 @@ package org.grobid.trainer; -import org.hamcrest.CoreMatchers; +import org.grobid.trainer.evaluation.Stats; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import static org.hamcrest.Matchers.is; -import static org.junit.Assert.*; +import static org.junit.Assert.assertThat; public class StatsTest { Stats target; @@ -234,7 +234,7 @@ public void testMacroAvgF0_shouldWork() throws Exception { final double f1Bao = target.getLabelStat("BAO").getF1Score(); final double f1Miao = target.getLabelStat("MIAO").getF1Score(); final double f1Ciao = target.getLabelStat("CIAO").getF1Score(); - final double f1Ziao = target.getLabelStat("ZIAO").getF1Score(); + final double f1Ziao = target.getLabelStat("ZIAO").getRecall(); assertThat(target.getMacroAverageF1(), is((f1Bao + f1Miao + f1Ciao + f1Ziao) / (4))); diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index 40130ce6c7..3d82877b9d 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -2,7 +2,6 @@ import org.apache.commons.io.IOUtils; import org.grobid.trainer.LabelStat; -import org.grobid.trainer.Stats; import org.junit.Test; import java.nio.charset.StandardCharsets; From 457f8601692dc6c0466c04d504b8ea88c5ac6649 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 4 Jul 2019 09:47:37 +0900 Subject: [PATCH 07/37] forgotten import --- .../evaluation/EvaluationDOIMatching.java | 58 ++++++++----------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java index 7c19528f62..1fc72b3e90 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationDOIMatching.java @@ -1,48 +1,38 @@ package org.grobid.trainer.evaluation; -import org.grobid.core.engines.config.GrobidAnalysisConfig; -import org.grobid.core.engines.tagging.GenericTagger; -import org.grobid.core.exceptions.*; -import org.grobid.core.engines.Engine; -import org.grobid.core.data.BiblioItem; +import com.fasterxml.jackson.core.io.JsonStringEncoder; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.io.FileUtils; import org.grobid.core.data.BibDataSet; +import org.grobid.core.data.BiblioItem; +import org.grobid.core.engines.Engine; +import org.grobid.core.exceptions.GrobidResourceException; import org.grobid.core.factory.GrobidFactory; +import org.grobid.core.utilities.Consolidation.GrobidConsolidationService; import org.grobid.core.utilities.GrobidProperties; -import org.grobid.core.utilities.UnicodeUtil; -import org.grobid.trainer.Stats; -import org.grobid.trainer.sax.NLMHeaderSaxHandler; -import org.grobid.trainer.sax.FieldExtractSaxHandler; import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.utilities.NamespaceContextMap; -import org.grobid.trainer.evaluation.utilities.FieldSpecification; -import org.grobid.core.utilities.Consolidation.GrobidConsolidationService; - -import java.io.*; -import java.util.*; -import java.text.Normalizer; -import java.util.regex.Pattern; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import org.apache.commons.io.FileUtils; - -import org.w3c.dom.*; +import org.w3c.dom.Document; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; +import org.xml.sax.EntityResolver; +import org.xml.sax.InputSource; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.SAXParserFactory; import javax.xml.xpath.XPath; -import javax.xml.xpath.XPathFactory; -import javax.xml.parsers.*; -import org.xml.sax.*; -import org.xml.sax.helpers.*; import javax.xml.xpath.XPathConstants; - -import com.rockymadden.stringmetric.similarity.RatcliffObershelpMetric; -import scala.Option; - -import com.fasterxml.jackson.core.io.*; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.core.*; -import com.fasterxml.jackson.databind.*; -import com.fasterxml.jackson.databind.node.*; +import javax.xml.xpath.XPathFactory; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FilenameFilter; +import java.text.Normalizer; +import java.util.*; +import java.util.regex.Pattern; /** * Evaluation of the DOI matching for the extracted bibliographical references, From 080668c9028e5c290670c934d6f7a6c5d1ab33c4 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 10:50:59 +0900 Subject: [PATCH 08/37] first implementation of n-fold evaluation #453 --- .../org/grobid/trainer/AbstractTrainer.java | 181 ++++++++++++- .../grobid/trainer/CRFPPGenericTrainer.java | 2 + .../main/java/org/grobid/trainer/Trainer.java | 2 + .../org/grobid/trainer/TrainerRunner.java | 253 +++++++++--------- .../evaluation/EvaluationUtilities.java | 65 +++-- .../grobid/trainer/evaluation/ModelStats.java | 52 ++++ .../org/grobid/trainer/evaluation/Stats.java | 3 + 7 files changed, 402 insertions(+), 156 deletions(-) create mode 100644 grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index a9b69d9042..25bc21d0b1 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -1,5 +1,8 @@ package org.grobid.trainer; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; import org.grobid.core.GrobidModel; import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.engines.tagging.TaggerFactory; @@ -8,11 +11,17 @@ import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.engines.tagging.GrobidCRFEngine; import org.grobid.trainer.evaluation.EvaluationUtilities; +import org.grobid.trainer.evaluation.ModelStats; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.IOException; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; /** * @author Zholudev, Lopez @@ -45,13 +54,18 @@ public void setParams(double epsilon, int window, int nbMaxIterations) { this.nbMaxIterations = nbMaxIterations; } + @Override + public int createCRFPPData(final File corpusDir, final File trainingOutputPath) { + return createCRFPPData(corpusDir, trainingOutputPath, null, 1.0); + } + @Override public void train() { final File dataPath = trainDataPath; createCRFPPData(getCorpusPath(), dataPath); GenericTrainer trainer = TrainerFactory.getTrainer(); - if (epsilon != 0.0) + if (epsilon != 0.0) trainer.setEpsilon(epsilon); if (window != 0) trainer.setWindow(window); @@ -89,7 +103,7 @@ protected void renameModels(final File oldModelPath, final File tempModelPath) { @Override public String evaluate() { createCRFPPData(getEvalCorpusPath(), evalDataPath); - return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()); + return EvaluationUtilities.reportMetrics(EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger())); } @Override @@ -98,7 +112,7 @@ public String splitTrainEvaluate(Double split) { createCRFPPData(getCorpusPath(), dataPath, evalDataPath, split); GenericTrainer trainer = TrainerFactory.getTrainer(); - if (epsilon != 0.0) + if (epsilon != 0.0) trainer.setEpsilon(epsilon); if (window != 0) trainer.setWindow(window); @@ -111,7 +125,7 @@ public String splitTrainEvaluate(Double split) { dirModelPath.mkdir(); //throw new GrobidException("Cannot find the destination directory " + dirModelPath.getAbsolutePath() + " for the model " + model.toString()); } - + final File tempModelPath = new File(GrobidProperties.getModelPath(model).getAbsolutePath() + NEW_MODEL_EXT); final File oldModelPath = GrobidProperties.getModelPath(model); @@ -120,7 +134,124 @@ public String splitTrainEvaluate(Double split) { // if we are here, that means that training succeeded renameModels(oldModelPath, tempModelPath); - return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()); + return EvaluationUtilities.reportMetrics(EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger())); + } + + @Override + public String nFoldEvaluate(int folds) { + final File dataPath = trainDataPath; + createCRFPPData(getCorpusPath(), dataPath); + GenericTrainer trainer = TrainerFactory.getTrainer(); + + // Load in memory and Shuffle + List trainingData = new ArrayList<>(); + try (Stream stream = Files.lines(Paths.get(dataPath.getAbsolutePath()))) { + List instance = new ArrayList<>(); + ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); + while (iterator.hasNext()) { + String current = iterator.next(); + + if (StringUtils.isBlank(current)) { + if (CollectionUtils.isNotEmpty(instance)) { + trainingData.add(String.join("\n", instance)); + } + instance = new ArrayList<>(); + } else { + instance.add(current); + } + } + + } catch (IOException e) { + e.printStackTrace(); + } + +// trainingData.forEach(s -> { +// System.out.println(s); +// System.out.println("\n\n"); +// }); + Collections.shuffle(trainingData); +// Collections.shuffle(trainingData); + +// System.out.println("\n\n"); +// trainingData.forEach(s -> { +// System.out.println(s); +// System.out.println("\n\n"); +// }); + + // Split into folds + + int trainingSize = CollectionUtils.size(trainingData); + int foldSize = Math.floorDiv(trainingSize, folds); + + List> foldMap = IntStream.range(0, folds).mapToObj(foldIndex -> { + int foldStart = foldSize * foldIndex; + int foldEnd = foldStart + foldSize; + + if (foldIndex == folds - 1) { + foldEnd = trainingSize; + } + + List foldEvaluation = trainingData.subList(foldStart, foldEnd); + List foldTraining0 = trainingData.subList(0, foldStart); + List foldTraining1 = trainingData.subList(foldEnd, trainingSize); + List foldTraining = new ArrayList<>(); + foldTraining.addAll(foldTraining0); + foldTraining.addAll(foldTraining1); + + //Dump Evaluation + String tempEvaluationDataPath = getTempEvaluationDataPath().getAbsolutePath(); + System.out.println(tempEvaluationDataPath); + try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempEvaluationDataPath))) { +// System.out.println(String.join("\n\n\n", foldEvaluation)); + writer.write(String.join("\n\n", foldEvaluation)); + } catch (IOException e) { + e.printStackTrace(); + } + + //Dump Training + String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath(); + System.out.println(tempTrainingDataPath); + try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) { +// System.out.println(String.join("\n\n\n", foldTraining)); + writer.write(String.join("\n\n", foldTraining)); + } catch (IOException e) { + e.printStackTrace(); + } + + return new ImmutablePair<>(tempTrainingDataPath, tempEvaluationDataPath); + }).collect(Collectors.toList()); + + + // Train and evaluastion + + if (epsilon != 0.0) + trainer.setEpsilon(epsilon); + if (window != 0) + trainer.setWindow(window); + if (nbMaxIterations != 0) + trainer.setNbMaxIterations(nbMaxIterations); + + //We dump the model in the tmp directory + File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath()); + if (!tmpDirectory.exists()) { + LOGGER.warn("Cannot find the destination directory " + tmpDirectory); + } + + final File tempModelPath = new File(tmpDirectory + File.separator + "nfold_dummy_model"); + System.out.println("Saving model in " + tempModelPath); + + List evaluationResults = foldMap.stream().map(fold -> { + trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); + return EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); + }).collect(Collectors.toList()); + + + // Averages + OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); + OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); + + return "Average precision: " + averagePrecision.getAsDouble() + "\nAverage recall: " + averageRecall.getAsDouble() + "\nAverage F1: " + averageF1.getAsDouble() + "\n "; } protected final File getTempTrainingDataPath() { @@ -149,7 +280,7 @@ protected GenericTagger getTagger() { protected static File getFilePath2Resources() { File theFile = new File(GrobidProperties.get_GROBID_HOME_PATH().getAbsoluteFile() + File.separator + ".." + File.separator - + "grobid-trainer" + File.separator + "resources"); + + "grobid-trainer" + File.separator + "resources"); if (!theFile.exists()) { theFile = new File("resources"); } @@ -174,7 +305,7 @@ protected File getEvalCorpusPath() { public static File getEvalCorpusBasePath() { final String path2Evelutation = getFilePath2Resources().getAbsolutePath() + File.separator + "dataset" + File.separator + "patent" - + File.separator + "evaluation"; + + File.separator + "evaluation"; return new File(path2Evelutation); } @@ -220,4 +351,36 @@ public static void runSplitTrainingEvaluation(final Trainer trainer, Double spli } + public static void runNFoldEvaluation(final Trainer trainer, int numFolds) { + long start = System.currentTimeMillis(); + try { + String report = trainer.nFoldEvaluate(numFolds); + System.out.println(report); + } catch (Exception e) { + throw new GrobidException("An exception occurred while evaluating Grobid.", e); + } + long end = System.currentTimeMillis(); + System.out.println("Split, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + } + + /** + * Dispatch the example to the training or test data, based on the split ration and the drawing of + * a random number + */ + public Writer dispatchExample(Writer writerTraining, Writer writerEvaluation, double splitRatio) { + Writer writer = null; + if ((writerTraining == null) && (writerEvaluation != null)) { + writer = writerEvaluation; + } else if ((writerTraining != null) && (writerEvaluation == null)) { + writer = writerTraining; + } else { + if (Math.random() <= splitRatio) + writer = writerTraining; + else + writer = writerEvaluation; + } + return writer; + } + + } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java index b7105e3f0b..9a966ce236 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/CRFPPGenericTrainer.java @@ -13,6 +13,8 @@ * * User: zholudev * Date: 3/20/14 + * + * @deprecated use WapitiTrainer or DelftTrainer (requires http://github.com/kermitt2/delft) */ @Deprecated public class CRFPPGenericTrainer implements GenericTrainer { diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java index 2bff8058f2..4685f770a3 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java @@ -24,5 +24,7 @@ public interface Trainer { String splitTrainEvaluate(Double split); + String nFoldEvaluate(int folds); + GrobidModel getModel(); } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index 7a765bc162..5e83246083 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -3,137 +3,144 @@ import org.grobid.core.utilities.GrobidProperties; import java.io.File; +import java.util.Arrays; +import java.util.List; /** * Training application for training a target model. - * + * * @author Patrice Lopez */ public class TrainerRunner { - private enum RunType { - TRAIN, EVAL, SPLIT; - - public static RunType getRunType(int i) { - for (RunType t : values()) { - if (t.ordinal() == i) { - return t; - } - } - - throw new IllegalStateException("Unsupported RunType with ordinal " + i); - } - } - - /** - * Initialize the batch. - */ - protected static void initProcess(final String path2GbdHome, final String path2GbdProperties) { - GrobidProperties.getInstance(); - } - - /** - * Command line execution. - * - * @param args - * Command line arguments. - */ - public static void main(String[] args) { - if (args.length < 4) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - RunType mode = RunType.getRunType(Integer.parseInt(args[0])); - if ( (mode == RunType.SPLIT) && (args.length < 6) ) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - String path2GbdHome = null; - Double split = 0.0; - for (int i = 0; i < args.length; i++) { - if (args[i].equals("-gH")) { - if (i+1 == args.length) { - throw new IllegalStateException("Missing path to Grobid home. "); - } - path2GbdHome = args[i + 1]; - } - else if (args[i].equals("-s")) { - if (i+1 == args.length) { - throw new IllegalStateException("Missing split ratio value. "); - } - String splitRatio = args[i + 1]; - try { - split = Double.parseDouble(args[i + 1]); - } - catch(Exception e) { - throw new IllegalStateException("Invalid split value: " + args[i + 1]); - } - - } - } - - if (path2GbdHome == null) { - throw new IllegalStateException( - "Usage: {0 - train, 1 - evaluate, 2 - split, train and evaluate} {affiliation,chemical,date,citation,ebook,fulltext,header,name-citation,name-header,patent} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional}"); - } - - final String path2GbdProperties = path2GbdHome + File.separator + "config" + File.separator + "grobid.properties"; - - System.out.println("path2GbdHome=" + path2GbdHome + " path2GbdProperties=" + path2GbdProperties); - initProcess(path2GbdHome, path2GbdProperties); - - String model = args[1]; - - AbstractTrainer trainer; - - if (model.equals("affiliation") || model.equals("affiliation-address")) { - trainer = new AffiliationAddressTrainer(); - } else if (model.equals("chemical")) { - trainer = new ChemicalEntityTrainer(); - } else if (model.equals("date")) { - trainer = new DateTrainer(); - } else if (model.equals("citation")) { - trainer = new CitationTrainer(); - } else if (model.equals("monograph")) { - trainer = new MonographTrainer(); - } else if (model.equals("fulltext")) { - trainer = new FulltextTrainer(); - } else if (model.equals("header")) { - trainer = new HeaderTrainer(); - } else if (model.equals("name-citation")) { - trainer = new NameCitationTrainer(); - } else if (model.equals("name-header")) { - trainer = new NameHeaderTrainer(); - } else if (model.equals("patent")) { - trainer = new PatentParserTrainer(); - } else if (model.equals("segmentation")) { - trainer = new SegmentationTrainer(); - } else if (model.equals("reference-segmenter")) { + private static final List models = Arrays.asList("affiliation", "chemical", "date", "citation", "ebook", "fulltext", "header", "name-citation", "name-header", "patent"); + private static final List options = Arrays.asList("0 - train", "1 - evaluate", "2 - split, train and evaluate", "3 - n-fold evaluation"); + + private enum RunType { + TRAIN, EVAL, SPLIT, EVAL_N_FOLD; + + public static RunType getRunType(int i) { + for (RunType t : values()) { + if (t.ordinal() == i) { + return t; + } + } + + throw new IllegalStateException("Unsupported RunType with ordinal " + i); + } + } + + protected static void initProcess(final String path2GbdHome, final String path2GbdProperties) { + GrobidProperties.getInstance(); + } + + public static void main(String[] args) { + if (args.length < 4) { + throw new IllegalStateException( + "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + RunType mode = RunType.getRunType(Integer.parseInt(args[0])); + if ((mode == RunType.SPLIT || mode == RunType.EVAL_N_FOLD) && (args.length < 6)) { + throw new IllegalStateException( + "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + String path2GbdHome = null; + Double split = 0.0; + Integer numFolds = 0; + for (int i = 0; i < args.length; i++) { + if (args[i].equals("-gH")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing path to Grobid home. "); + } + path2GbdHome = args[i + 1]; + } else if (args[i].equals("-s")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing split ratio value. "); + } + try { + split = Double.parseDouble(args[i + 1]); + } catch (Exception e) { + throw new IllegalStateException("Invalid split value: " + args[i + 1]); + } + + } else if (args[i].equals("-n")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing number of folds value. "); + } + try { + numFolds = Integer.parseInt(args[i + 1]); + } catch (Exception e) { + throw new IllegalStateException("Invalid split value: " + args[i + 1]); + } + + } + } + + if (path2GbdHome == null) { + throw new IllegalStateException( + "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + } + + final String path2GbdProperties = path2GbdHome + File.separator + "config" + File.separator + "grobid.properties"; + + System.out.println("path2GbdHome=" + path2GbdHome + " path2GbdProperties=" + path2GbdProperties); + initProcess(path2GbdHome, path2GbdProperties); + + String model = args[1]; + + AbstractTrainer trainer; + + if (model.equals("affiliation") || model.equals("affiliation-address")) { + trainer = new AffiliationAddressTrainer(); + } else if (model.equals("chemical")) { + trainer = new ChemicalEntityTrainer(); + } else if (model.equals("date")) { + trainer = new DateTrainer(); + } else if (model.equals("citation")) { + trainer = new CitationTrainer(); + } else if (model.equals("monograph")) { + trainer = new MonographTrainer(); + } else if (model.equals("fulltext")) { + trainer = new FulltextTrainer(); + } else if (model.equals("header")) { + trainer = new HeaderTrainer(); + } else if (model.equals("name-citation")) { + trainer = new NameCitationTrainer(); + } else if (model.equals("name-header")) { + trainer = new NameHeaderTrainer(); + } else if (model.equals("patent")) { + trainer = new PatentParserTrainer(); + } else if (model.equals("segmentation")) { + trainer = new SegmentationTrainer(); + } else if (model.equals("reference-segmenter")) { trainer = new ReferenceSegmenterTrainer(); } else if (model.equals("figure")) { - trainer = new FigureTrainer(); - } else if (model.equals("table")) { - trainer = new TableTrainer(); - } else { - throw new IllegalStateException("The model " + model + " is unknown."); - } - - switch (mode) { - case TRAIN: - AbstractTrainer.runTraining(trainer); - break; - case EVAL: - AbstractTrainer.runEvaluation(trainer); - break; - case SPLIT: - AbstractTrainer.runSplitTrainingEvaluation(trainer, split); - break; - default: - throw new IllegalStateException("Invalid RunType: " + mode.name()); - } - System.exit(0); - } + trainer = new FigureTrainer(); + } else if (model.equals("table")) { + trainer = new TableTrainer(); + } else { + throw new IllegalStateException("The model " + model + " is unknown."); + } + + switch (mode) { + case TRAIN: + AbstractTrainer.runTraining(trainer); + break; + case EVAL: + AbstractTrainer.runEvaluation(trainer); + break; + case SPLIT: + AbstractTrainer.runSplitTrainingEvaluation(trainer, split); + break; + case EVAL_N_FOLD: + AbstractTrainer.runNFoldEvaluation(trainer, numFolds); + break; + default: + throw new IllegalStateException("Invalid RunType: " + mode.name()); + } + System.exit(0); + } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 9a8921e148..a28ecbd7f8 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -11,6 +11,7 @@ import java.io.BufferedReader; import java.io.FileInputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; @@ -98,53 +99,44 @@ public static String taggerRun(List ress, Tagger tagger) { return res.toString(); } - public static String evaluateStandard(String path, final GenericTagger tagger) { - return evaluateStandard(path, new Function, String>() { - @Override - public String apply(List strings) { - return tagger.label(strings); - } - }); + public static ModelStats evaluateStandard(String path, final GenericTagger tagger) { + return evaluateStandard(path, tagger::label); } - public static String evaluateStandard(String path, Function, String> taggerFunction) { + public static ModelStats evaluateStandard(String path, Function, String> taggerFunction) { String theResult = null; try { - final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8")); + final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8)); String line = null; - List citationBlocks = new ArrayList(); + List instance = new ArrayList(); while ((line = bufReader.readLine()) != null) { - citationBlocks.add(line); + instance.add(line); } long time = System.currentTimeMillis(); - theResult = taggerFunction.apply(citationBlocks); + theResult = taggerFunction.apply(instance); bufReader.close(); System.out.println("Labeling took: " + (System.currentTimeMillis() - time) + " ms"); } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } - return reportMetrics(theResult); + return computeStats(theResult); } - public static String reportMetrics(String theResult) { - StringBuilder report = new StringBuilder(); - + public static ModelStats computeStats(String theResult) { + ModelStats accumulator = new ModelStats(); // report token-level results Stats wordStats = tokenLevelStats(theResult); - report.append("\n===== Token-level results =====\n\n"); - report.append(wordStats.getReport()); + accumulator.setTokenStats(wordStats); // report field-level results Stats fieldStats = fieldLevelStats(theResult); - report.append("\n===== Field-level results =====\n"); - report.append(fieldStats.getReport()); + accumulator.setFieldStats(fieldStats); // instance-level: instances are separated by a new line in the result file // third pass - report.append("\n===== Instance-level results =====\n\n"); theResult = theResult.replace("\n\n", "\n \n"); StringTokenizer stt = new StringTokenizer(theResult, "\n"); boolean allGood = true; @@ -179,9 +171,34 @@ public static String reportMetrics(String theResult) { } } - report.append(String.format("%-27s %d\n", "Total expected instances:", totalInstance)); - report.append(String.format("%-27s %d\n", "Correct instances:", correctInstance)); - double accuracy = (double) correctInstance / (totalInstance); + accumulator.setTotalInstances(totalInstance); + accumulator.setCorrectInstance(correctInstance); + accumulator.setInstanceAccuracy(correctInstance); + double accuracy = (double) (correctInstance / totalInstance); + + + return accumulator; + } + + public static String reportMetrics(ModelStats accumulated) { + StringBuilder report = new StringBuilder(); + + // report token-level results + Stats wordStats = accumulated.getTokenStats(); + report.append("\n===== Token-level results =====\n\n"); + report.append(wordStats.getReport()); + + // report field-level results + Stats fieldStats = accumulated.getFieldStats(); + report.append("\n===== Field-level results =====\n"); + report.append(fieldStats.getReport()); + + // instance-level: instances are separated by a new line in the result file + // third pass + report.append("\n===== Instance-level results =====\n\n"); + report.append(String.format("%-27s %d\n", "Total expected instances:", accumulated.getTotalInstances())); + report.append(String.format("%-27s %d\n", "Correct instances:", accumulated.getCorrectInstance())); + double accuracy = (double) accumulated.getCorrectInstance() / (accumulated.getTotalInstances()); report.append(String.format("%-27s %s\n", "Instance-level recall:", TextUtilities.formatTwoDecimals(accuracy * 100))); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java new file mode 100644 index 0000000000..7078e191f1 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -0,0 +1,52 @@ +package org.grobid.trainer.evaluation; + +/** + * Represent all different evaluation given a specific model + */ +public class ModelStats { + private int totalInstances; + private int correctInstance; + private int instanceAccuracy; + private Stats tokenStats; + private Stats fieldStats; + + public void setTotalInstances(int totalInstances) { + this.totalInstances = totalInstances; + } + + public int getTotalInstances() { + return totalInstances; + } + + public void setCorrectInstance(int correctInstance) { + this.correctInstance = correctInstance; + } + + public int getCorrectInstance() { + return correctInstance; + } + + public void setInstanceAccuracy(int instanceAccuracy) { + this.instanceAccuracy = instanceAccuracy; + } + + public int getInstanceAccuracy() { + return instanceAccuracy; + } + + public void setTokenStats(Stats tokenStats) { + this.tokenStats = tokenStats; + } + + public Stats getTokenStats() { + return tokenStats; + } + + public void setFieldStats(Stats fieldStats) { + this.fieldStats = fieldStats; + } + + public Stats getFieldStats() { + return fieldStats; + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java index dafb7ef3f1..49d4c80ea6 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -7,6 +7,9 @@ import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.LabelStat; +/** + * Contains the single statistic computation for evaluation + */ public final class Stats { private final TreeMap labelStats; From 08f2e3ac84a8681f959968bd3e837332b8380a88 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 13:36:28 +0900 Subject: [PATCH 09/37] Adding more tests and more output #453 --- .../org/grobid/trainer/AbstractTrainer.java | 174 ++++++++++++------ .../AbstractTrainerIntegrationTest.java | 166 +++++++++++++++++ .../resources/sample.wapiti.output.date.txt | 23 +++ 3 files changed, 302 insertions(+), 61 deletions(-) create mode 100644 grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java create mode 100644 grobid-trainer/src/test/resources/sample.wapiti.output.date.txt diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 25bc21d0b1..7f32f018a6 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -5,20 +5,26 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.grobid.core.GrobidModel; import org.grobid.core.engines.tagging.GenericTagger; +import org.grobid.core.engines.tagging.GrobidCRFEngine; import org.grobid.core.engines.tagging.TaggerFactory; import org.grobid.core.exceptions.GrobidException; import org.grobid.core.factory.GrobidFactory; import org.grobid.core.utilities.GrobidProperties; -import org.grobid.core.engines.tagging.GrobidCRFEngine; +import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.EvaluationUtilities; import org.grobid.trainer.evaluation.ModelStats; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.BufferedWriter; +import java.io.File; +import java.io.IOException; +import java.io.Writer; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -144,46 +150,94 @@ public String nFoldEvaluate(int folds) { GenericTrainer trainer = TrainerFactory.getTrainer(); // Load in memory and Shuffle - List trainingData = new ArrayList<>(); - try (Stream stream = Files.lines(Paths.get(dataPath.getAbsolutePath()))) { - List instance = new ArrayList<>(); - ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); - while (iterator.hasNext()) { - String current = iterator.next(); + Path dataPath2 = Paths.get(dataPath.getAbsolutePath()); + List trainingData = loadAndShuffle(dataPath2); - if (StringUtils.isBlank(current)) { - if (CollectionUtils.isNotEmpty(instance)) { - trainingData.add(String.join("\n", instance)); - } - instance = new ArrayList<>(); - } else { - instance.add(current); - } - } + // Split into folds + List> foldMap = splitNFold(trainingData, folds); - } catch (IOException e) { - e.printStackTrace(); + // Train and evaluation + if (epsilon != 0.0) + trainer.setEpsilon(epsilon); + if (window != 0) + trainer.setWindow(window); + if (nbMaxIterations != 0) + trainer.setNbMaxIterations(nbMaxIterations); + + //We dump the model in the tmp directory + File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath()); + if (!tmpDirectory.exists()) { + LOGGER.warn("Cannot find the destination directory " + tmpDirectory); } -// trainingData.forEach(s -> { -// System.out.println(s); -// System.out.println("\n\n"); -// }); - Collections.shuffle(trainingData); -// Collections.shuffle(trainingData); + final File tempModelPath = new File(tmpDirectory + File.separator + "nfold_dummy_model"); + System.out.println("Saving model in " + tempModelPath); + + List evaluationResults = foldMap.stream().map(fold -> { + trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); + return EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); + }).collect(Collectors.toList()); -// System.out.println("\n\n"); -// trainingData.forEach(s -> { -// System.out.println(s); -// System.out.println("\n\n"); -// }); + System.out.println("Results: "); - // Split into folds + Comparator f1ScoreComparator = (o1, o2) -> { + if (o1.getFieldStats().getMacroAverageF1() > o1.getFieldStats().getMacroAverageF1()) { + return 1; + } else if (o1.getFieldStats().getMacroAverageF1() < o1.getFieldStats().getMacroAverageF1()) { + return -1; + } else { + return 0; + } + }; + // Output + StringBuilder sb = new StringBuilder(); + Optional worstModel = evaluationResults.stream().min(f1ScoreComparator); + sb.append("Worst Model").append("\n"); + ModelStats worstModelStats = worstModel.orElseGet(() -> { + throw new GrobidException("Something wrong when computing evaluations " + + "- worst model metrics not found. "); + }); + sb.append(EvaluationUtilities.reportMetrics(worstModelStats)); + + Optional bestModel = evaluationResults.stream().max(f1ScoreComparator); + ModelStats bestModelStats = bestModel.orElseGet(() -> { + throw new GrobidException("Something wrong when computing evaluations " + + "- best model metrics not found. "); + }); + sb.append(EvaluationUtilities.reportMetrics(bestModelStats)); + + // Averages + OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); + OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); + + double avgPrecision = averagePrecision.orElseGet(() -> { + throw new GrobidException("Missing average precision. Something went wrong. Please check. "); + }); + sb.append("Average precision: " + TextUtilities.formatTwoDecimals(avgPrecision * 100)).append("\n"); + + double avgRecall = averageRecall.orElseGet(() -> { + throw new GrobidException("Missing average recall. Something went wrong. Please check. "); + }); + sb.append("Average recall: " + TextUtilities.formatTwoDecimals(avgRecall * 100)).append("\n"); + + double avgF1 = averageF1.orElseGet(() -> { + throw new GrobidException("Missing average F1. Something went wrong. Please check. "); + }); + sb.append("Average F1: " + TextUtilities.formatTwoDecimals(avgF1 * 100)).append("\n"); + + return sb.toString(); + } + + /** + * Partition the corpus in n folds, dump them in n files and return the pairs of (trainingPath, evaluationPath) + */ + protected List> splitNFold(List trainingData, int folds) { int trainingSize = CollectionUtils.size(trainingData); int foldSize = Math.floorDiv(trainingSize, folds); - List> foldMap = IntStream.range(0, folds).mapToObj(foldIndex -> { + return IntStream.range(0, folds).mapToObj(foldIndex -> { int foldStart = foldSize * foldIndex; int foldEnd = foldStart + foldSize; @@ -200,7 +254,7 @@ public String nFoldEvaluate(int folds) { //Dump Evaluation String tempEvaluationDataPath = getTempEvaluationDataPath().getAbsolutePath(); - System.out.println(tempEvaluationDataPath); +// System.out.println(tempEvaluationDataPath); try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempEvaluationDataPath))) { // System.out.println(String.join("\n\n\n", foldEvaluation)); writer.write(String.join("\n\n", foldEvaluation)); @@ -210,7 +264,7 @@ public String nFoldEvaluate(int folds) { //Dump Training String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath(); - System.out.println(tempTrainingDataPath); +// System.out.println(tempTrainingDataPath); try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) { // System.out.println(String.join("\n\n\n", foldTraining)); writer.write(String.join("\n\n", foldTraining)); @@ -220,38 +274,36 @@ public String nFoldEvaluate(int folds) { return new ImmutablePair<>(tempTrainingDataPath, tempEvaluationDataPath); }).collect(Collectors.toList()); + } + /** + * Load the dataset in memory and shuffle it. Assuming that each empty line is a delimiter between instances. + * Empty line are filtered out from the output + */ + protected List loadAndShuffle(Path dataPath) { + List trainingData = new ArrayList<>(); + try (Stream stream = Files.lines(dataPath)) { + List instance = new ArrayList<>(); + ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); + while (iterator.hasNext()) { + String current = iterator.next(); - // Train and evaluastion - - if (epsilon != 0.0) - trainer.setEpsilon(epsilon); - if (window != 0) - trainer.setWindow(window); - if (nbMaxIterations != 0) - trainer.setNbMaxIterations(nbMaxIterations); + if (StringUtils.isBlank(current)) { + if (CollectionUtils.isNotEmpty(instance)) { + trainingData.add(String.join("\n", instance)); + } + instance = new ArrayList<>(); + } else { + instance.add(current); + } + } - //We dump the model in the tmp directory - File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath()); - if (!tmpDirectory.exists()) { - LOGGER.warn("Cannot find the destination directory " + tmpDirectory); + } catch (IOException e) { + e.printStackTrace(); } - final File tempModelPath = new File(tmpDirectory + File.separator + "nfold_dummy_model"); - System.out.println("Saving model in " + tempModelPath); - - List evaluationResults = foldMap.stream().map(fold -> { - trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); - return EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); - }).collect(Collectors.toList()); - - - // Averages - OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); - OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); - OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); - - return "Average precision: " + averagePrecision.getAsDouble() + "\nAverage recall: " + averageRecall.getAsDouble() + "\nAverage F1: " + averageF1.getAsDouble() + "\n "; + Collections.shuffle(trainingData); + return trainingData; } protected final File getTempTrainingDataPath() { diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java new file mode 100644 index 0000000000..4fc7dbd3ec --- /dev/null +++ b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java @@ -0,0 +1,166 @@ +package org.grobid.trainer; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.grobid.core.GrobidModels; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.*; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.ListIterator; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsCollectionWithSize.hasSize; + + +/** This test is creating temp files - some cannot be removed **/ +public class AbstractTrainerIntegrationTest { + + private AbstractTrainer target; + + @BeforeClass + public static void beforeClass() throws Exception { +// LibraryLoader.load(); + } + + @Before + public void setUp() throws Exception { + target = new AbstractTrainer(GrobidModels.DUMMY) { + + @Override + public int createCRFPPData(File corpusPath, File outputTrainingFile, File outputEvalFile, double splitRatio) { + // the file for writing the training data +// if (outputTrainingFile != null) { +// try (OutputStream os = new FileOutputStream(outputTrainingFile)) { +// try (Writer writer = new OutputStreamWriter(os, StandardCharsets.UTF_8)) { +// +// for (int i = 0; i < 100; i++) { +// double random = Math.random(); +// writer.write("blablabla" + random); +// writer.write("\n"); +// if (i % 10 == 0) { +// writer.write("\n"); +// } +// } +// } +// } catch (IOException e) { +// e.printStackTrace(); +// } +// } + + return 100; + } + }; + } + + @After + public void tearDown() throws Exception { + + } + + + @Test + public void testSplitNFold_n2_shouldWork() throws Exception { + List dummyTrainingData = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + double random = Math.random(); + dummyTrainingData.add("blablabla" + random); + } + + List> splitMapping = target.splitNFold(dummyTrainingData, 2); + assertThat(splitMapping, hasSize(2)); + + assertThat(splitMapping.get(0).getLeft(), endsWith("train")); + assertThat(splitMapping.get(0).getRight(), endsWith("test")); + + // Cleanup + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getRight())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getLeft())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + } + + @Test + public void testSplitNFold_n10_shouldWork() throws Exception { + List dummyTrainingData = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + double random = Math.random(); + dummyTrainingData.add("blablabla" + random); + } + + List> splitMapping = target.splitNFold(dummyTrainingData, 10); + assertThat(splitMapping, hasSize(10)); + + assertThat(splitMapping.get(0).getLeft(), endsWith("train")); + assertThat(splitMapping.get(0).getRight(), endsWith("test")); + + // Cleanup + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getRight())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + splitMapping.stream().forEach(f -> { + try { + Files.delete(Paths.get(f.getLeft())); + } catch (IOException e) { + e.printStackTrace(); + } + }); + } + + + @Test + public void testLoadAndShuffle_shouldWork() throws Exception { + Path path = Paths.get("src/test/resources/sample.wapiti.output.date.txt"); + List orderedTrainingData = new ArrayList<>(); + try (Stream stream = Files.lines(path)) { + + ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); + List instance = new ArrayList<>(); + while (iterator.hasNext()) { + String current = iterator.next(); + if (StringUtils.isBlank(current)) { + if (CollectionUtils.isNotEmpty(instance)) { + orderedTrainingData.add(String.join("\n", instance)); + } + instance = new ArrayList<>(); + } else { + instance.add(current); + } + } + } + List shuffledTrainingData = target.loadAndShuffle(path); + + assertThat(shuffledTrainingData, hasSize(orderedTrainingData.size())); + + assertThat(shuffledTrainingData.get(0), is(not(orderedTrainingData.get(0)))); + assertThat(shuffledTrainingData.get(0).split("\n").length, is(4)); + assertThat(shuffledTrainingData.get(1), is(not(orderedTrainingData.get(1)))); + } + +} \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt new file mode 100644 index 0000000000..b1005bfb59 --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.date.txt @@ -0,0 +1,23 @@ +Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I- +online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT +18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I- +2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I- +16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +, , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I- +2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I- +4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I- +, , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I- +2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- + + +Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I- +18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I- +May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I- +2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I- \ No newline at end of file From e1153965c5c395eeca687042cf20d092fdda0837 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 14:01:42 +0900 Subject: [PATCH 10/37] Adding dummy model for testing #410 #453 --- .../java/org/grobid/core/GrobidModels.java | 12 ++++- .../core/engines/tagging/GrobidCRFEngine.java | 3 +- .../org/grobid/trainer/AbstractTrainer.java | 1 - .../java/org/grobid/trainer/DummyTrainer.java | 50 +++++++++++++++++++ .../org/grobid/trainer/TrainerFactory.java | 2 + 5 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java diff --git a/grobid-core/src/main/java/org/grobid/core/GrobidModels.java b/grobid-core/src/main/java/org/grobid/core/GrobidModels.java index 01e859068f..fc27d5abe5 100755 --- a/grobid-core/src/main/java/org/grobid/core/GrobidModels.java +++ b/grobid-core/src/main/java/org/grobid/core/GrobidModels.java @@ -43,7 +43,11 @@ public enum GrobidModels implements GrobidModel { // ENTITIES_BIOTECH("entities/biotech"), ENTITIES_BIOTECH("bio"), ASTRO("astro"), - SOFTWARE("software"); + SOFTWARE("software"), + DUMMY("none"); + + //I cannot declare it before + public static final String DUMMY_FOLDER_LABEL = "none"; /** * Absolute path to the model. @@ -55,6 +59,12 @@ public enum GrobidModels implements GrobidModel { private static final ConcurrentMap models = new ConcurrentHashMap<>(); GrobidModels(String folderName) { + if(StringUtils.equals(DUMMY_FOLDER_LABEL, folderName)) { + modelPath = DUMMY_FOLDER_LABEL; + this.folderName = DUMMY_FOLDER_LABEL; + return; + } + this.folderName = folderName; File path = GrobidProperties.getModelPath(this); if (!path.exists()) { diff --git a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java index b12cae7e92..cb4d93270a 100644 --- a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java +++ b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GrobidCRFEngine.java @@ -9,7 +9,8 @@ public enum GrobidCRFEngine { WAPITI("wapiti"), CRFPP("crf"), - DELFT("delft"); + DELFT("delft"), + DUMMY("dummy"); private final String ext; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 7f32f018a6..5d30b0f4dd 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -24,7 +24,6 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java new file mode 100644 index 0000000000..4e90920609 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DummyTrainer.java @@ -0,0 +1,50 @@ +package org.grobid.trainer; + +import org.grobid.core.GrobidModel; + +import java.io.File; + +/** + * Dummy trainer which won't do anything. + */ +public class DummyTrainer implements GenericTrainer { + @Override + public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) { + + } + + @Override + public String getName() { + return null; + } + + @Override + public void setEpsilon(double epsilon) { + + } + + @Override + public void setWindow(int window) { + + } + + @Override + public double getEpsilon() { + return 0; + } + + @Override + public int getWindow() { + return 0; + } + + @Override + public int getNbMaxIterations() { + return 0; + } + + @Override + public void setNbMaxIterations(int iterations) { + + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java index b31ffb5eed..f9dc1b0dc3 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerFactory.java @@ -15,6 +15,8 @@ public static GenericTrainer getTrainer() { return new WapitiTrainer(); case DELFT: return new DeLFTTrainer(); + case DUMMY: + return new DummyTrainer(); default: throw new IllegalStateException("Unsupported Grobid sequence labelling engine: " + GrobidProperties.getGrobidCRFEngine()); } From b2737283095dc6509811d2d25fb8ba928eb4b2de Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 15:42:58 +0900 Subject: [PATCH 11/37] More unit tests and fixes #453 --- .../org/grobid/trainer/AbstractTrainer.java | 53 ++++-- .../AbstractTrainerIntegrationTest.java | 170 +++++++++++------- 2 files changed, 145 insertions(+), 78 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 5d30b0f4dd..c13a75ce74 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -4,6 +4,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.ImmutablePair; import org.grobid.core.GrobidModel; +import org.grobid.core.GrobidModels; import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.engines.tagging.GrobidCRFEngine; import org.grobid.core.engines.tagging.TaggerFactory; @@ -49,6 +50,10 @@ public abstract class AbstractTrainer implements Trainer { public AbstractTrainer(final GrobidModel model) { GrobidFactory.getInstance().createEngine(); this.model = model; + if (model.equals(GrobidModels.DUMMY)) { + // In case of dummy model we do not initialise (and create) temporary files + return; + } this.trainDataPath = getTempTrainingDataPath(); this.evalDataPath = getTempEvaluationDataPath(); } @@ -181,9 +186,9 @@ public String nFoldEvaluate(int folds) { Comparator f1ScoreComparator = (o1, o2) -> { - if (o1.getFieldStats().getMacroAverageF1() > o1.getFieldStats().getMacroAverageF1()) { + if (o1.getFieldStats().getMacroAverageF1() > o2.getFieldStats().getMacroAverageF1()) { return 1; - } else if (o1.getFieldStats().getMacroAverageF1() < o1.getFieldStats().getMacroAverageF1()) { + } else if (o1.getFieldStats().getMacroAverageF1() < o2.getFieldStats().getMacroAverageF1()) { return -1; } else { return 0; @@ -232,15 +237,18 @@ public String nFoldEvaluate(int folds) { /** * Partition the corpus in n folds, dump them in n files and return the pairs of (trainingPath, evaluationPath) */ - protected List> splitNFold(List trainingData, int folds) { + protected List> splitNFold(List trainingData, int numberFolds) { int trainingSize = CollectionUtils.size(trainingData); - int foldSize = Math.floorDiv(trainingSize, folds); + int foldSize = Math.floorDiv(trainingSize, numberFolds); + if (foldSize == 0) { + throw new IllegalArgumentException("There aren't enough training data for n-fold evaluation with fold of size " + numberFolds); + } - return IntStream.range(0, folds).mapToObj(foldIndex -> { + return IntStream.range(0, numberFolds).mapToObj(foldIndex -> { int foldStart = foldSize * foldIndex; int foldEnd = foldStart + foldSize; - if (foldIndex == folds - 1) { + if (foldIndex == numberFolds - 1) { foldEnd = trainingSize; } @@ -253,22 +261,20 @@ protected List> splitNFold(List trainingDa //Dump Evaluation String tempEvaluationDataPath = getTempEvaluationDataPath().getAbsolutePath(); -// System.out.println(tempEvaluationDataPath); try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempEvaluationDataPath))) { -// System.out.println(String.join("\n\n\n", foldEvaluation)); writer.write(String.join("\n\n", foldEvaluation)); + writer.write("\n"); } catch (IOException e) { - e.printStackTrace(); + throw new GrobidException("Error when dumping n-fold evaluation data into files. ", e); } //Dump Training String tempTrainingDataPath = getTempTrainingDataPath().getAbsolutePath(); -// System.out.println(tempTrainingDataPath); try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(tempTrainingDataPath))) { -// System.out.println(String.join("\n\n\n", foldTraining)); writer.write(String.join("\n\n", foldTraining)); + writer.write("\n"); } catch (IOException e) { - e.printStackTrace(); + throw new GrobidException("Error when dumping n-fold training data into files. ", e); } return new ImmutablePair<>(tempTrainingDataPath, tempEvaluationDataPath); @@ -276,10 +282,23 @@ protected List> splitNFold(List trainingDa } /** - * Load the dataset in memory and shuffle it. Assuming that each empty line is a delimiter between instances. - * Empty line are filtered out from the output + * Load the dataset in memory and shuffle it. */ protected List loadAndShuffle(Path dataPath) { + List trainingData = load(dataPath); + + Collections.shuffle(trainingData, new Random(839374947498L)); + + return trainingData; + } + + /** + * Read the Wapiti training files in list of String. + * Assuming that each empty line is a delimiter between instances. + * Each list element corresponds to one instance. + * Empty line are filtered out from the output. + */ + public List load(Path dataPath) { List trainingData = new ArrayList<>(); try (Stream stream = Files.lines(dataPath)) { List instance = new ArrayList<>(); @@ -296,12 +315,14 @@ protected List loadAndShuffle(Path dataPath) { instance.add(current); } } + if (CollectionUtils.isNotEmpty(instance)) { + trainingData.add(String.join("\n", instance)); + } } catch (IOException e) { - e.printStackTrace(); + throw new GrobidException("Error in n-fold, when loading training data. Failing. ", e); } - Collections.shuffle(trainingData); return trainingData; } diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java index 4fc7dbd3ec..c2f92bb2a2 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/AbstractTrainerIntegrationTest.java @@ -1,7 +1,5 @@ package org.grobid.trainer; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.ImmutablePair; import org.grobid.core.GrobidModels; import org.junit.After; @@ -9,23 +7,20 @@ import org.junit.BeforeClass; import org.junit.Test; -import java.io.*; -import java.nio.charset.StandardCharsets; +import java.io.File; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; -import java.util.ListIterator; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.collection.IsCollectionWithSize.hasSize; -/** This test is creating temp files - some cannot be removed **/ public class AbstractTrainerIntegrationTest { private AbstractTrainer target; @@ -70,21 +65,84 @@ public void tearDown() throws Exception { } + @Test + public void testLoad_shouldWork() throws Exception { + Path path = Paths.get("src/test/resources/sample.wapiti.output.date.txt"); + List expected = Arrays.asList( + "Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT \n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-"); + + List loadedTrainingData = target.load(path); + + assertThat(loadedTrainingData, is(expected)); + } + @Test - public void testSplitNFold_n2_shouldWork() throws Exception { + public void testSplitNFold_n3_shouldWork() throws Exception { List dummyTrainingData = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - double random = Math.random(); - dummyTrainingData.add("blablabla" + random); - } - - List> splitMapping = target.splitNFold(dummyTrainingData, 2); - assertThat(splitMapping, hasSize(2)); + dummyTrainingData.add(dummyExampleGeneration("1", 3)); + dummyTrainingData.add(dummyExampleGeneration("2", 4)); + dummyTrainingData.add(dummyExampleGeneration("3", 2)); + dummyTrainingData.add(dummyExampleGeneration("4", 6)); + dummyTrainingData.add(dummyExampleGeneration("5", 6)); + dummyTrainingData.add(dummyExampleGeneration("6", 2)); + dummyTrainingData.add(dummyExampleGeneration("7", 2)); + dummyTrainingData.add(dummyExampleGeneration("8", 3)); + dummyTrainingData.add(dummyExampleGeneration("9", 3)); + dummyTrainingData.add(dummyExampleGeneration("10", 3)); + + List> splitMapping = target.splitNFold(dummyTrainingData, 3); + assertThat(splitMapping, hasSize(3)); assertThat(splitMapping.get(0).getLeft(), endsWith("train")); assertThat(splitMapping.get(0).getRight(), endsWith("test")); + //Fold 1 + List fold1Training = target.load(Paths.get(splitMapping.get(0).getLeft())); + List fold1Evaluation = target.load(Paths.get(splitMapping.get(0).getRight())); + + System.out.println(Arrays.toString(fold1Training.toArray())); + System.out.println(Arrays.toString(fold1Evaluation.toArray())); + + assertThat(fold1Training, hasSize(7)); + assertThat(fold1Evaluation, hasSize(3)); + + //Fold 2 + List fold2Training = target.load(Paths.get(splitMapping.get(1).getLeft())); + List fold2Evaluation = target.load(Paths.get(splitMapping.get(1).getRight())); + + System.out.println(Arrays.toString(fold2Training.toArray())); + System.out.println(Arrays.toString(fold2Evaluation.toArray())); + + assertThat(fold2Training, hasSize(7)); + assertThat(fold2Evaluation, hasSize(3)); + + //Fold 3 + List fold3Training = target.load(Paths.get(splitMapping.get(2).getLeft())); + List fold3Evaluation = target.load(Paths.get(splitMapping.get(2).getRight())); + + System.out.println(Arrays.toString(fold3Training.toArray())); + System.out.println(Arrays.toString(fold3Evaluation.toArray())); + + assertThat(fold3Training, hasSize(6)); + assertThat(fold3Evaluation, hasSize(4)); + // Cleanup splitMapping.stream().forEach(f -> { try { @@ -102,65 +160,53 @@ public void testSplitNFold_n2_shouldWork() throws Exception { }); } - @Test - public void testSplitNFold_n10_shouldWork() throws Exception { + @Test(expected = IllegalArgumentException.class) + public void testSplitNFold_n10_shouldThrowException() throws Exception { List dummyTrainingData = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - double random = Math.random(); - dummyTrainingData.add("blablabla" + random); - } + dummyTrainingData.add(dummyExampleGeneration("1", 3)); + dummyTrainingData.add(dummyExampleGeneration("2", 4)); + dummyTrainingData.add(dummyExampleGeneration("3", 2)); + dummyTrainingData.add(dummyExampleGeneration("4", 6)); List> splitMapping = target.splitNFold(dummyTrainingData, 10); - assertThat(splitMapping, hasSize(10)); - assertThat(splitMapping.get(0).getLeft(), endsWith("train")); - assertThat(splitMapping.get(0).getRight(), endsWith("test")); - - // Cleanup - splitMapping.stream().forEach(f -> { - try { - Files.delete(Paths.get(f.getRight())); - } catch (IOException e) { - e.printStackTrace(); - } - }); - splitMapping.stream().forEach(f -> { - try { - Files.delete(Paths.get(f.getLeft())); - } catch (IOException e) { - e.printStackTrace(); - } - }); } + private String dummyExampleGeneration(String exampleId, int total) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < total; i++) { + sb.append("line " + i + " example " + exampleId).append("\n"); + } + sb.append("\n"); + return sb.toString(); + } @Test public void testLoadAndShuffle_shouldWork() throws Exception { Path path = Paths.get("src/test/resources/sample.wapiti.output.date.txt"); - List orderedTrainingData = new ArrayList<>(); - try (Stream stream = Files.lines(path)) { - - ListIterator iterator = stream.collect(Collectors.toList()).listIterator(); - List instance = new ArrayList<>(); - while (iterator.hasNext()) { - String current = iterator.next(); - if (StringUtils.isBlank(current)) { - if (CollectionUtils.isNotEmpty(instance)) { - orderedTrainingData.add(String.join("\n", instance)); - } - instance = new ArrayList<>(); - } else { - instance.add(current); - } - } - } + List orderedTrainingData = Arrays.asList( + "Available available A Av Ava Avai e le ble able LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "online online o on onl onli e ne ine line LINEIN NOCAPS NODIGIT 0 0 0 NOPUNCT \n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "January january J Ja Jan Janu y ry ary uary LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2010 2010 2 20 201 2010 0 10 010 2010 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "June june J Ju Jun June e ne une June LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "16 16 1 16 16 16 6 16 16 16 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2008 2008 2 20 200 2008 8 08 008 2008 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "November november N No Nov Nove r er ber mber LINESTART INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "4 4 4 4 4 4 4 4 4 4 LINEIN NOCAPS ALLDIGIT 1 0 0 NOPUNCT I-\n" + + ", , , , , , , , , , LINEIN ALLCAP NODIGIT 1 0 0 COMMA I-\n" + + "2009 2009 2 20 200 2009 9 09 009 2009 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-", + "Published published P Pu Pub Publ d ed hed shed LINESTART INITCAP NODIGIT 0 0 0 NOPUNCT I-\n" + + "18 18 1 18 18 18 8 18 18 18 LINEIN NOCAPS ALLDIGIT 0 0 0 NOPUNCT I-\n" + + "May may M Ma May May y ay May May LINEIN INITCAP NODIGIT 0 0 1 NOPUNCT I-\n" + + "2011 2011 2 20 201 2011 1 11 011 2011 LINEEND NOCAPS ALLDIGIT 0 1 0 NOPUNCT I-"); + List shuffledTrainingData = target.loadAndShuffle(path); assertThat(shuffledTrainingData, hasSize(orderedTrainingData.size())); - - assertThat(shuffledTrainingData.get(0), is(not(orderedTrainingData.get(0)))); - assertThat(shuffledTrainingData.get(0).split("\n").length, is(4)); - assertThat(shuffledTrainingData.get(1), is(not(orderedTrainingData.get(1)))); + assertThat(shuffledTrainingData, is(not(orderedTrainingData))); } } \ No newline at end of file From fbd798a2b77df668e978bbe5ed3c67ca38d385b7 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 16:03:42 +0900 Subject: [PATCH 12/37] Printing out more information #453 --- .../main/java/org/grobid/trainer/AbstractTrainer.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index c13a75ce74..ee7af17d83 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -25,6 +25,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -174,11 +175,15 @@ public String nFoldEvaluate(int folds) { LOGGER.warn("Cannot find the destination directory " + tmpDirectory); } - final File tempModelPath = new File(tmpDirectory + File.separator + "nfold_dummy_model"); - System.out.println("Saving model in " + tempModelPath); - + AtomicInteger counter = new AtomicInteger(0); List evaluationResults = foldMap.stream().map(fold -> { + final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() + + "_nfold_" + counter.getAndIncrement() + ".wapiti"); + System.out.println("Saving model in " + tempModelPath); + + System.out.println("Training input data: " + fold.getLeft()); trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); + System.out.println("Evaluation input data: " + fold.getRight()); return EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); }).collect(Collectors.toList()); From 59122b42e83378d6da8448bc156cdc75638f1743 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 16:12:06 +0900 Subject: [PATCH 13/37] Adding more information in output #453 --- .../main/java/org/grobid/trainer/AbstractTrainer.java | 9 ++++++--- .../grobid/trainer/evaluation/EvaluationUtilities.java | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index ee7af17d83..d2ed43a0aa 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -184,7 +184,9 @@ public String nFoldEvaluate(int folds) { System.out.println("Training input data: " + fold.getLeft()); trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); System.out.println("Evaluation input data: " + fold.getRight()); - return EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); + ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); + System.out.println(EvaluationUtilities.reportMetrics(modelStats)); + return modelStats; }).collect(Collectors.toList()); System.out.println("Results: "); @@ -207,14 +209,15 @@ public String nFoldEvaluate(int folds) { throw new GrobidException("Something wrong when computing evaluations " + "- worst model metrics not found. "); }); - sb.append(EvaluationUtilities.reportMetrics(worstModelStats)); + sb.append(EvaluationUtilities.reportMetrics(worstModelStats)).append("\n"); + sb.append("Best model:").append("\n"); Optional bestModel = evaluationResults.stream().max(f1ScoreComparator); ModelStats bestModelStats = bestModel.orElseGet(() -> { throw new GrobidException("Something wrong when computing evaluations " + "- best model metrics not found. "); }); - sb.append(EvaluationUtilities.reportMetrics(bestModelStats)); + sb.append(EvaluationUtilities.reportMetrics(bestModelStats)).append("\n"); // Averages OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index a28ecbd7f8..ecca84da5a 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -110,7 +110,7 @@ public static ModelStats evaluateStandard(String path, Function, St final BufferedReader bufReader = new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8)); String line = null; - List instance = new ArrayList(); + List instance = new ArrayList<>(); while ((line = bufReader.readLine()) != null) { instance.add(line); } From 55f060fe12d8d0d128ccb4a747286d6b60a7ca03 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 16:41:39 +0900 Subject: [PATCH 14/37] cleanup #453 --- .../evaluation/EvaluationUtilities.java | 18 +++++++----------- .../grobid/trainer/evaluation/ModelStats.java | 16 +++++++--------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index ecca84da5a..855ce2a366 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -126,14 +126,14 @@ public static ModelStats evaluateStandard(String path, Function, St } public static ModelStats computeStats(String theResult) { - ModelStats accumulator = new ModelStats(); + ModelStats modelStats = new ModelStats(); // report token-level results Stats wordStats = tokenLevelStats(theResult); - accumulator.setTokenStats(wordStats); + modelStats.setTokenStats(wordStats); // report field-level results Stats fieldStats = fieldLevelStats(theResult); - accumulator.setFieldStats(fieldStats); + modelStats.setFieldStats(fieldStats); // instance-level: instances are separated by a new line in the result file // third pass @@ -171,13 +171,10 @@ public static ModelStats computeStats(String theResult) { } } - accumulator.setTotalInstances(totalInstance); - accumulator.setCorrectInstance(correctInstance); - accumulator.setInstanceAccuracy(correctInstance); - double accuracy = (double) (correctInstance / totalInstance); + modelStats.setTotalInstances(totalInstance); + modelStats.setCorrectInstance(correctInstance); - - return accumulator; + return modelStats; } public static String reportMetrics(ModelStats accumulated) { @@ -198,10 +195,9 @@ public static String reportMetrics(ModelStats accumulated) { report.append("\n===== Instance-level results =====\n\n"); report.append(String.format("%-27s %d\n", "Total expected instances:", accumulated.getTotalInstances())); report.append(String.format("%-27s %d\n", "Correct instances:", accumulated.getCorrectInstance())); - double accuracy = (double) accumulated.getCorrectInstance() / (accumulated.getTotalInstances()); report.append(String.format("%-27s %s\n", "Instance-level recall:", - TextUtilities.formatTwoDecimals(accuracy * 100))); + TextUtilities.formatTwoDecimals(accumulated.getInstanceRecall() * 100))); return report.toString(); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 7078e191f1..32c55bc5e2 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -6,7 +6,6 @@ public class ModelStats { private int totalInstances; private int correctInstance; - private int instanceAccuracy; private Stats tokenStats; private Stats fieldStats; @@ -26,14 +25,6 @@ public int getCorrectInstance() { return correctInstance; } - public void setInstanceAccuracy(int instanceAccuracy) { - this.instanceAccuracy = instanceAccuracy; - } - - public int getInstanceAccuracy() { - return instanceAccuracy; - } - public void setTokenStats(Stats tokenStats) { this.tokenStats = tokenStats; } @@ -49,4 +40,11 @@ public void setFieldStats(Stats fieldStats) { public Stats getFieldStats() { return fieldStats; } + + public double getInstanceRecall() { + if (getTotalInstances() <= 0) { + return 0.0d; + } + return (double) getCorrectInstance() / (getTotalInstances()); + } } From 8cb585184017eff70ed8d6daa0a7e0bb937bcba7 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 11 Jul 2019 17:22:10 +0900 Subject: [PATCH 15/37] implementing output report on file - moved printing code within the model class #453 --- .gitignore | 1 + .java-version | 1 + .../org/grobid/trainer/AbstractTrainer.java | 34 +++++++++++++------ .../org/grobid/trainer/TrainerRunner.java | 29 +++++++++++++--- .../evaluation/EvaluationUtilities.java | 25 -------------- .../grobid/trainer/evaluation/ModelStats.java | 28 +++++++++++++++ 6 files changed, 79 insertions(+), 39 deletions(-) create mode 100644 .java-version diff --git a/.gitignore b/.gitignore index 0dea71b51d..1c603462c0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.java-version target grobid-home/tmp .DS_Store diff --git a/.java-version b/.java-version new file mode 100644 index 0000000000..f2052c3e9a --- /dev/null +++ b/.java-version @@ -0,0 +1 @@ +openjdk64-11.0.2 diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index d2ed43a0aa..f7c8383092 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -114,7 +114,7 @@ protected void renameModels(final File oldModelPath, final File tempModelPath) { @Override public String evaluate() { createCRFPPData(getEvalCorpusPath(), evalDataPath); - return EvaluationUtilities.reportMetrics(EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger())); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(); } @Override @@ -145,7 +145,7 @@ public String splitTrainEvaluate(Double split) { // if we are here, that means that training succeeded renameModels(oldModelPath, tempModelPath); - return EvaluationUtilities.reportMetrics(EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger())); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(); } @Override @@ -185,13 +185,12 @@ public String nFoldEvaluate(int folds) { trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); System.out.println("Evaluation input data: " + fold.getRight()); ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); - System.out.println(EvaluationUtilities.reportMetrics(modelStats)); + System.out.println(modelStats.toString()); return modelStats; }).collect(Collectors.toList()); System.out.println("Results: "); - Comparator f1ScoreComparator = (o1, o2) -> { if (o1.getFieldStats().getMacroAverageF1() > o2.getFieldStats().getMacroAverageF1()) { return 1; @@ -209,7 +208,7 @@ public String nFoldEvaluate(int folds) { throw new GrobidException("Something wrong when computing evaluations " + "- worst model metrics not found. "); }); - sb.append(EvaluationUtilities.reportMetrics(worstModelStats)).append("\n"); + sb.append(worstModelStats.toString()).append("\n"); sb.append("Best model:").append("\n"); Optional bestModel = evaluationResults.stream().max(f1ScoreComparator); @@ -217,7 +216,7 @@ public String nFoldEvaluate(int folds) { throw new GrobidException("Something wrong when computing evaluations " + "- best model metrics not found. "); }); - sb.append(EvaluationUtilities.reportMetrics(bestModelStats)).append("\n"); + sb.append(bestModelStats.toString()).append("\n"); // Averages OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); @@ -430,17 +429,32 @@ public static void runSplitTrainingEvaluation(final Trainer trainer, Double spli System.out.println("Split, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); } + public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile) { + + String report = runNFoldEvaluation(trainer, numFolds); + + try (BufferedWriter writer = Files.newBufferedWriter(outputFile)) { + writer.write(report); + writer.write("\n"); + } catch (IOException e) { + throw new GrobidException("Error when dumping n-fold training data into files. ", e); + } - public static void runNFoldEvaluation(final Trainer trainer, int numFolds) { + } + + public static String runNFoldEvaluation(final Trainer trainer, int numFolds) { long start = System.currentTimeMillis(); + String report = ""; try { - String report = trainer.nFoldEvaluate(numFolds); - System.out.println(report); + report = trainer.nFoldEvaluate(numFolds); + } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("Split, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + System.out.println("N-Fold evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + + return report; } /** diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index 5e83246083..5d9473021d 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -1,8 +1,12 @@ package org.grobid.trainer; +import org.apache.commons.lang3.StringUtils; import org.grobid.core.utilities.GrobidProperties; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Arrays; import java.util.List; @@ -47,8 +51,9 @@ public static void main(String[] args) { } String path2GbdHome = null; - Double split = 0.0; - Integer numFolds = 0; + double split = 0.0; + int numFolds = 0; + String outputFilePath = null; for (int i = 0; i < args.length; i++) { if (args[i].equals("-gH")) { if (i + 1 == args.length) { @@ -72,9 +77,15 @@ public static void main(String[] args) { try { numFolds = Integer.parseInt(args[i + 1]); } catch (Exception e) { - throw new IllegalStateException("Invalid split value: " + args[i + 1]); + throw new IllegalStateException("Invalid number of folds value: " + args[i + 1]); } + } else if (args[i].equals("-o")) { + if (i + 1 == args.length) { + throw new IllegalStateException("Missing output file. "); + } + outputFilePath = args[i + 1]; + } } @@ -135,7 +146,17 @@ public static void main(String[] args) { AbstractTrainer.runSplitTrainingEvaluation(trainer, split); break; case EVAL_N_FOLD: - AbstractTrainer.runNFoldEvaluation(trainer, numFolds); + if (StringUtils.isNotEmpty(outputFilePath)) { + Path outputPath = Paths.get(outputFilePath); + if (Files.exists(outputPath)) { + System.err.println("Output file exists. "); + } + AbstractTrainer.runNFoldEvaluation(trainer, numFolds, outputPath); + } else { + + String results = AbstractTrainer.runNFoldEvaluation(trainer, numFolds); + System.out.println(results); + } break; default: throw new IllegalStateException("Invalid RunType: " + mode.name()); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 855ce2a366..71bc226664 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -177,31 +177,6 @@ public static ModelStats computeStats(String theResult) { return modelStats; } - public static String reportMetrics(ModelStats accumulated) { - StringBuilder report = new StringBuilder(); - - // report token-level results - Stats wordStats = accumulated.getTokenStats(); - report.append("\n===== Token-level results =====\n\n"); - report.append(wordStats.getReport()); - - // report field-level results - Stats fieldStats = accumulated.getFieldStats(); - report.append("\n===== Field-level results =====\n"); - report.append(fieldStats.getReport()); - - // instance-level: instances are separated by a new line in the result file - // third pass - report.append("\n===== Instance-level results =====\n\n"); - report.append(String.format("%-27s %d\n", "Total expected instances:", accumulated.getTotalInstances())); - report.append(String.format("%-27s %d\n", "Correct instances:", accumulated.getCorrectInstance())); - report.append(String.format("%-27s %s\n", - "Instance-level recall:", - TextUtilities.formatTwoDecimals(accumulated.getInstanceRecall() * 100))); - - return report.toString(); - } - public static Stats tokenLevelStats(String theResult) { Stats wordStats = new Stats(); String line = null; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 32c55bc5e2..37fac902a5 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -1,5 +1,9 @@ package org.grobid.trainer.evaluation; +import org.grobid.core.utilities.TextUtilities; + +import java.io.PrintStream; + /** * Represent all different evaluation given a specific model */ @@ -47,4 +51,28 @@ public double getInstanceRecall() { } return (double) getCorrectInstance() / (getTotalInstances()); } + + public String toString() { + StringBuilder report = new StringBuilder(); + + // report token-level results + Stats wordStats = getTokenStats(); + report.append("\n===== Token-level results =====\n\n"); + report.append(wordStats.getReport()); + + // report field-level results + Stats fieldStats = getFieldStats(); + report.append("\n===== Field-level results =====\n"); + report.append(fieldStats.getReport()); + + // instance-level: instances are separated by a new line in the result file + report.append("\n===== Instance-level results =====\n\n"); + report.append(String.format("%-27s %d\n", "Total expected instances:", getTotalInstances())); + report.append(String.format("%-27s %d\n", "Correct instances:", getCorrectInstance())); + report.append(String.format("%-27s %s\n", + "Instance-level recall:", + TextUtilities.formatTwoDecimals(getInstanceRecall() * 100))); + + return report.toString(); + } } From ec1b17820cbea9c202b090bbb8c59c04c8562ebc Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 12 Jul 2019 10:52:28 +0900 Subject: [PATCH 16/37] Fixing results output #453 #58 --- .../org/grobid/trainer/AbstractTrainer.java | 42 +++++++++++-------- .../evaluation/EvaluationUtilities.java | 4 +- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index f7c8383092..ee0ad385b1 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -149,7 +149,7 @@ public String splitTrainEvaluate(Double split) { } @Override - public String nFoldEvaluate(int folds) { + public String nFoldEvaluate(int numFolds) { final File dataPath = trainDataPath; createCRFPPData(getCorpusPath(), dataPath); GenericTrainer trainer = TrainerFactory.getTrainer(); @@ -159,7 +159,7 @@ public String nFoldEvaluate(int folds) { List trainingData = loadAndShuffle(dataPath2); // Split into folds - List> foldMap = splitNFold(trainingData, folds); + List> foldMap = splitNFold(trainingData, numFolds); // Train and evaluation if (epsilon != 0.0) @@ -223,20 +223,23 @@ public String nFoldEvaluate(int folds) { OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); + sb.append("average over " + numFolds + " folds: ").append("\n"); + + double avgF1 = averageF1.orElseGet(() -> { + throw new GrobidException("Missing average F1. Something went wrong. Please check. "); + }); + sb.append("\tmacro f1 = " + TextUtilities.formatTwoDecimals(avgF1 * 100)).append("\n"); + double avgPrecision = averagePrecision.orElseGet(() -> { throw new GrobidException("Missing average precision. Something went wrong. Please check. "); }); - sb.append("Average precision: " + TextUtilities.formatTwoDecimals(avgPrecision * 100)).append("\n"); + sb.append("\tmacro precision = " + TextUtilities.formatTwoDecimals(avgPrecision * 100)).append("\n"); double avgRecall = averageRecall.orElseGet(() -> { throw new GrobidException("Missing average recall. Something went wrong. Please check. "); }); - sb.append("Average recall: " + TextUtilities.formatTwoDecimals(avgRecall * 100)).append("\n"); + sb.append("\tmacro recall = " + TextUtilities.formatTwoDecimals(avgRecall * 100)).append("\n"); - double avgF1 = averageF1.orElseGet(() -> { - throw new GrobidException("Missing average F1. Something went wrong. Please check. "); - }); - sb.append("Average F1: " + TextUtilities.formatTwoDecimals(avgF1 * 100)).append("\n"); return sb.toString(); } @@ -405,28 +408,33 @@ public File getEvalDataPath() { return evalDataPath; } - public static void runEvaluation(final Trainer trainer) { + public static String runEvaluation(final Trainer trainer) { long start = System.currentTimeMillis(); + String report = ""; try { - String report = trainer.evaluate(); - System.out.println(report); + report = trainer.evaluate(); } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("Evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + report += "\n\nEvaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; + + return report; } - public static void runSplitTrainingEvaluation(final Trainer trainer, Double split) { + public static String runSplitTrainingEvaluation(final Trainer trainer, Double split) { long start = System.currentTimeMillis(); + String report = ""; try { - String report = trainer.splitTrainEvaluate(split); - System.out.println(report); + report = trainer.splitTrainEvaluate(split); + } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("Split, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + report += "\n\nSplit, training and evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; + + return report; } public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile) { @@ -452,7 +460,7 @@ public static String runNFoldEvaluation(final Trainer trainer, int numFolds) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } long end = System.currentTimeMillis(); - System.out.println("N-Fold evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"); + report += "\n\nN-Fold evaluation for " + trainer.getModel() + " model is realized in " + (end - start) + " ms"; return report; } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 71bc226664..2db815fbe6 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -276,14 +276,14 @@ public static Stats fieldLevelStats(String theResult) { // last fields of the sequence if ((previousObtainedLabel != null)) { currentObtainedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousObtainedLabel), + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), currentObtainedPosition); obtainedFields.add(theField); } if ((previousExpectedLabel != null)) { currentExpectedPosition.end = pos - 1; - Pair theField = new Pair(getPlainLabel(previousExpectedLabel), + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), currentExpectedPosition); expectedFields.add(theField); } From c72ff35c5512899824c601a113ab6594498bbcf5 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 12 Jul 2019 11:07:35 +0900 Subject: [PATCH 17/37] removing token-level results #453 --- .../main/java/org/grobid/trainer/evaluation/ModelStats.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 37fac902a5..7d3a422715 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -56,9 +56,9 @@ public String toString() { StringBuilder report = new StringBuilder(); // report token-level results - Stats wordStats = getTokenStats(); - report.append("\n===== Token-level results =====\n\n"); - report.append(wordStats.getReport()); +// Stats wordStats = getTokenStats(); +// report.append("\n===== Token-level results =====\n\n"); +// report.append(wordStats.getReport()); // report field-level results Stats fieldStats = getFieldStats(); From 2ba3bfd5c0f114cf10f6b04d3166fbca6fe08caf Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 13:23:53 +0900 Subject: [PATCH 18/37] Improving output format #453 --- .../java/org/grobid/trainer/AbstractTrainer.java | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index ee0ad385b1..c934a4df88 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -175,6 +175,11 @@ public String nFoldEvaluate(int numFolds) { LOGGER.warn("Cannot find the destination directory " + tmpDirectory); } + // Output + StringBuilder sb = new StringBuilder(); + sb.append("Recap results for each fold:").append("\n\n"); + + AtomicInteger counter = new AtomicInteger(0); List evaluationResults = foldMap.stream().map(fold -> { final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() @@ -186,10 +191,14 @@ public String nFoldEvaluate(int numFolds) { System.out.println("Evaluation input data: " + fold.getRight()); ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); System.out.println(modelStats.toString()); + + sb.append(" ====================== Fold " + counter.get() + " ====================== ").append("\n"); + sb.append(modelStats.toString()).append("\n"); + return modelStats; }).collect(Collectors.toList()); - System.out.println("Results: "); + sb.append("\n").append("Summary results: ").append("\n"); Comparator f1ScoreComparator = (o1, o2) -> { if (o1.getFieldStats().getMacroAverageF1() > o2.getFieldStats().getMacroAverageF1()) { @@ -200,8 +209,7 @@ public String nFoldEvaluate(int numFolds) { return 0; } }; - // Output - StringBuilder sb = new StringBuilder(); + Optional worstModel = evaluationResults.stream().min(f1ScoreComparator); sb.append("Worst Model").append("\n"); ModelStats worstModelStats = worstModel.orElseGet(() -> { From 7af4150db562b3317709e17c4933e764fec000ab Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 13:29:48 +0900 Subject: [PATCH 19/37] Using java base librarie instead of guava #453 --- .../org/grobid/trainer/evaluation/EvaluationUtilities.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 2db815fbe6..0e5e3b87f9 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -1,10 +1,8 @@ package org.grobid.trainer.evaluation; -import com.google.common.base.Function; import org.chasen.crfpp.Tagger; import org.grobid.core.engines.tagging.GenericTagger; import org.grobid.core.exceptions.GrobidException; -import org.grobid.core.utilities.TextUtilities; import org.grobid.core.utilities.OffsetPosition; import org.grobid.core.utilities.Pair; @@ -15,6 +13,7 @@ import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; +import java.util.function.Function; import org.grobid.trainer.LabelStat; From 244a09a3d5d51287d425f9d9f24af4fe567e9e08 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 13:53:35 +0900 Subject: [PATCH 20/37] Print output #453 --- .../src/main/java/org/grobid/trainer/DateTrainer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index 1d3385b21f..de95b2060c 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -162,7 +162,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); DateTrainer trainer = new DateTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } From 156190c01cb1e8337bdc07edc324b4ab345a4f41 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 13:57:25 +0900 Subject: [PATCH 21/37] Print output #453 --- .../java/org/grobid/trainer/AffiliationAddressTrainer.java | 2 +- .../main/java/org/grobid/trainer/ChemicalEntityTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/CitationTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/FigureTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/FulltextTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/HeaderTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/MonographTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/NameCitationTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/NameHeaderTrainer.java | 2 +- .../java/org/grobid/trainer/ReferenceSegmenterTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/SegmentationTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/ShorttextTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/TableTrainer.java | 2 +- .../src/main/java/org/grobid/trainer/TrainerRunner.java | 4 ++-- 14 files changed, 15 insertions(+), 15 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java index da3816df49..458a68d666 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AffiliationAddressTrainer.java @@ -175,7 +175,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new AffiliationAddressTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java index 05ae7333f3..c5b3d13d7d 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ChemicalEntityTrainer.java @@ -431,7 +431,7 @@ public void addFeatures(List texts, public static void main(String[] args) { Trainer trainer = new ChemicalEntityTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java index d900aaac3f..c2cd493474 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/CitationTrainer.java @@ -195,7 +195,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new CitationTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/FigureTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/FigureTrainer.java index d2a14295d5..a6dcf0603b 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/FigureTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/FigureTrainer.java @@ -215,7 +215,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new FigureTrainer()); - AbstractTrainer.runEvaluation(new FigureTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new FigureTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/FulltextTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/FulltextTrainer.java index 0336ba31e4..b24298cd8c 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/FulltextTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/FulltextTrainer.java @@ -252,7 +252,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new FulltextTrainer()); - AbstractTrainer.runEvaluation(new FulltextTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new FulltextTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/HeaderTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/HeaderTrainer.java index 1e52177653..debb95b20a 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/HeaderTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/HeaderTrainer.java @@ -300,7 +300,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new HeaderTrainer()); - AbstractTrainer.runEvaluation(new HeaderTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new HeaderTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java index 159aaa93bb..1f95974ea9 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/MonographTrainer.java @@ -162,7 +162,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new MonographTrainer()); - AbstractTrainer.runEvaluation(new MonographTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new MonographTrainer())); System.exit(0); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java index 1720aea4a7..e127208ba8 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/NameCitationTrainer.java @@ -170,7 +170,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new NameCitationTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java index a5f2b5db11..7edcbdd720 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/NameHeaderTrainer.java @@ -229,7 +229,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); Trainer trainer = new NameHeaderTrainer(); AbstractTrainer.runTraining(trainer); - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java index e65fda701d..39aff5cfa4 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ReferenceSegmenterTrainer.java @@ -214,7 +214,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new ReferenceSegmenterTrainer()); - AbstractTrainer.runEvaluation(new ReferenceSegmenterTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new ReferenceSegmenterTrainer())); System.exit(0); } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java index ba9b2a57b8..7fdd04e536 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/SegmentationTrainer.java @@ -348,7 +348,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new SegmentationTrainer()); - AbstractTrainer.runEvaluation(new SegmentationTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new SegmentationTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java index 4403972be1..36d07604d1 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/ShorttextTrainer.java @@ -162,7 +162,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new ShorttextTrainer()); - AbstractTrainer.runEvaluation(new ShorttextTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new ShorttextTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java index 30e53fba28..9269906971 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TableTrainer.java @@ -221,7 +221,7 @@ public boolean accept(File dir, String name) { public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); AbstractTrainer.runTraining(new TableTrainer()); - AbstractTrainer.runEvaluation(new TableTrainer()); + System.out.println(AbstractTrainer.runEvaluation(new TableTrainer())); System.exit(0); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index 5d9473021d..2564cda11f 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -140,10 +140,10 @@ public static void main(String[] args) { AbstractTrainer.runTraining(trainer); break; case EVAL: - AbstractTrainer.runEvaluation(trainer); + System.out.println(AbstractTrainer.runEvaluation(trainer)); break; case SPLIT: - AbstractTrainer.runSplitTrainingEvaluation(trainer, split); + System.out.println(AbstractTrainer.runSplitTrainingEvaluation(trainer, split)); break; case EVAL_N_FOLD: if (StringUtils.isNotEmpty(outputFilePath)) { From 8b83ffe5049f855d09aa169bb5696828cab5415a Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 16:03:05 +0900 Subject: [PATCH 22/37] Adding label support in result evaluation + cosmetics #453 --- .../src/main/java/org/grobid/trainer/LabelStat.java | 4 ++++ .../main/java/org/grobid/trainer/TrainerRunner.java | 2 +- .../org/grobid/trainer/evaluation/ModelStats.java | 2 -- .../java/org/grobid/trainer/evaluation/Stats.java | 11 +++++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java index 4f17e32877..670fdcdc40 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java @@ -102,6 +102,10 @@ public double getAccuracy() { return accuracy; } + public long getSupport() { + return observed + falsePositive + falseNegative; + } + public double getPrecision() { if (observed == 0.0) { return 0.0; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index 2564cda11f..cb50d22bee 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -91,7 +91,7 @@ public static void main(String[] args) { if (path2GbdHome == null) { throw new IllegalStateException( - "Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); + "Grobid-home path not found.\n Usage: {" + String.join(", ", options) + "} {" + String.join(", ", models) + "} -gH /path/to/Grobid/home -s { [0.0 - 1.0] - split ratio, optional} -n {[int, num folds for n-fold evaluation, optional]}"); } final String path2GbdProperties = path2GbdHome + File.separator + "config" + File.separator + "grobid.properties"; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 7d3a422715..da1ff89d8d 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -2,8 +2,6 @@ import org.grobid.core.utilities.TextUtilities; -import java.io.PrintStream; - /** * Represent all different evaluation given a specific model */ diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java index 49d4c80ea6..7044b7e706 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -182,12 +182,13 @@ public String getReport() { computeMetrics(); StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s\n\n", + report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s %-7s\n\n", "label", "accuracy", "precision", "recall", - "f1")); + "f1", + "support")); for (String label : getLabels()) { if (label.equals("") || label.equals("base") || label.equals("O")) { @@ -196,12 +197,14 @@ public String getReport() { LabelStat labelStat = getLabelStat(label); - report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", + report.append(String.format("%-20s %-12s %-12s %-12s %-7s %-7s\n", label, TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), - TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100))); + TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100), + String.valueOf(labelStat.getSupport())) + ); } report.append("\n"); From 5c6f0cd940e4e571ddf27c90207a4f3cd4a7231e Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 16:32:40 +0900 Subject: [PATCH 23/37] output label support #453 --- .../org/grobid/trainer/evaluation/Stats.java | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java index 7044b7e706..6a70b48e3b 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -182,7 +182,7 @@ public String getReport() { computeMetrics(); StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-7s %-7s\n\n", + report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", "label", "accuracy", "precision", @@ -190,6 +190,8 @@ public String getReport() { "f1", "support")); + long supportSum = 0; + for (String label : getLabels()) { if (label.equals("") || label.equals("base") || label.equals("O")) { continue; @@ -197,31 +199,36 @@ public String getReport() { LabelStat labelStat = getLabelStat(label); - report.append(String.format("%-20s %-12s %-12s %-12s %-7s %-7s\n", + long support = labelStat.getSupport(); + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", label, TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100), - String.valueOf(labelStat.getSupport())) + String.valueOf(support)) ); + + supportSum += support; } report.append("\n"); - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (micro average)\n", - "all fields", + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (micro avg.)", TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100), TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100), TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100))); + TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100), + String.valueOf(supportSum))); - report.append(String.format("%-20s %-12s %-12s %-12s %-7s (macro average)\n", - "", + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100), TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100), TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100))); + TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100), + String.valueOf(supportSum))); return report.toString(); } From 25af2b79da2f6af49e490039145109aa0bedfd14 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 18:19:21 +0900 Subject: [PATCH 24/37] Set n = 10 by default and throw exception for n = 1#453 --- .../src/main/java/org/grobid/trainer/TrainerRunner.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java index cb50d22bee..67b5524973 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/TrainerRunner.java @@ -146,14 +146,19 @@ public static void main(String[] args) { System.out.println(AbstractTrainer.runSplitTrainingEvaluation(trainer, split)); break; case EVAL_N_FOLD: + if(numFolds == 1) { + throw new IllegalArgumentException("N should be > 1"); + } else { + if(numFolds == 0) { + numFolds = 10; + } + } if (StringUtils.isNotEmpty(outputFilePath)) { Path outputPath = Paths.get(outputFilePath); if (Files.exists(outputPath)) { System.err.println("Output file exists. "); } - AbstractTrainer.runNFoldEvaluation(trainer, numFolds, outputPath); } else { - String results = AbstractTrainer.runNFoldEvaluation(trainer, numFolds); System.out.println(results); } From 0049eccdfa0ab75f85c23f5a0350b12e2ed0228e Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 16 Jul 2019 23:58:07 +0900 Subject: [PATCH 25/37] Implementing averages on labels for 10-fold #453 --- .../org/grobid/trainer/AbstractTrainer.java | 97 +++++++++++++-- .../java/org/grobid/trainer/DateTrainer.java | 6 +- .../evaluation/EvaluationUtilities.java | 12 +- .../trainer/evaluation/LabelResult.java | 73 ++++++++++++ .../trainer/{ => evaluation}/LabelStat.java | 3 +- .../grobid/trainer/evaluation/ModelStats.java | 65 ++++++++--- .../org/grobid/trainer/evaluation/Stats.java | 110 +++++++++++------- .../java/org/grobid/trainer/StatsTest.java | 1 + .../evaluation/EvaluationUtilitiesTest.java | 1 - 9 files changed, 290 insertions(+), 78 deletions(-) create mode 100644 grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java rename grobid-trainer/src/main/java/org/grobid/trainer/{ => evaluation}/LabelStat.java (97%) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index c934a4df88..8ac6425a0b 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -13,10 +13,13 @@ import org.grobid.core.utilities.GrobidProperties; import org.grobid.core.utilities.TextUtilities; import org.grobid.trainer.evaluation.EvaluationUtilities; +import org.grobid.trainer.evaluation.LabelResult; import org.grobid.trainer.evaluation.ModelStats; +import org.grobid.trainer.evaluation.Stats; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.awt.*; import java.io.BufferedWriter; import java.io.File; import java.io.IOException; @@ -25,6 +28,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -179,7 +183,6 @@ public String nFoldEvaluate(int numFolds) { StringBuilder sb = new StringBuilder(); sb.append("Recap results for each fold:").append("\n\n"); - AtomicInteger counter = new AtomicInteger(0); List evaluationResults = foldMap.stream().map(fold -> { final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() @@ -198,12 +201,16 @@ public String nFoldEvaluate(int numFolds) { return modelStats; }).collect(Collectors.toList()); + sb.append("\n").append("Summary results: ").append("\n"); Comparator f1ScoreComparator = (o1, o2) -> { - if (o1.getFieldStats().getMacroAverageF1() > o2.getFieldStats().getMacroAverageF1()) { + Stats fieldStatsO1 = o1.getFieldStats(); + Stats fieldStatsO2 = o2.getFieldStats(); + + if (fieldStatsO1.getMacroAverageF1() > fieldStatsO2.getMacroAverageF1()) { return 1; - } else if (o1.getFieldStats().getMacroAverageF1() < o2.getFieldStats().getMacroAverageF1()) { + } else if (fieldStatsO1.getMacroAverageF1() < fieldStatsO2.getMacroAverageF1()) { return -1; } else { return 0; @@ -224,29 +231,99 @@ public String nFoldEvaluate(int numFolds) { throw new GrobidException("Something wrong when computing evaluations " + "- best model metrics not found. "); }); - sb.append(bestModelStats.toString()).append("\n"); + sb.append(bestModelStats.toString()).append("\n").append("\n"); // Averages + sb.append("Average over " + numFolds + " folds: ").append("\n"); + + TreeMap averagesLabelStats = new TreeMap<>(); + int totalInstances = 0; + int correctInstances = 0; + for(ModelStats ms : evaluationResults) { + totalInstances += ms.getTotalInstances(); + correctInstances += ms.getCorrectInstance(); + for(Map.Entry entry : ms.getFieldStats().getLabelsResults().entrySet()) { + String key = entry.getKey(); + if(averagesLabelStats.containsKey(key)) { + averagesLabelStats.get(key).setAccuracy(averagesLabelStats.get(key).getAccuracy() + entry.getValue().getAccuracy()); + averagesLabelStats.get(key).setF1Score(averagesLabelStats.get(key).getF1Score() + entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(averagesLabelStats.get(key).getRecall() + entry.getValue().getF1Score()); + averagesLabelStats.get(key).setPrecision(averagesLabelStats.get(key).getPrecision() + entry.getValue().getPrecision()); + averagesLabelStats.get(key).setSupport(averagesLabelStats.get(key).getSupport() + entry.getValue().getSupport()); + } else { + averagesLabelStats.put(key, new LabelResult(key)); + averagesLabelStats.get(key).setAccuracy(entry.getValue().getAccuracy()); + averagesLabelStats.get(key).setF1Score(entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(entry.getValue().getF1Score()); + averagesLabelStats.get(key).setPrecision(entry.getValue().getPrecision()); + averagesLabelStats.get(key).setSupport(entry.getValue().getSupport()); + } + } + } + + sb.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + for(String label : averagesLabelStats.keySet()) { + LabelResult labelResult = averagesLabelStats.get(label); + double avgAccuracy = labelResult.getAccuracy() / evaluationResults.size(); + averagesLabelStats.get(label).setAccuracy(avgAccuracy); + + double avgF1Score = labelResult.getF1Score() / evaluationResults.size(); + averagesLabelStats.get(label).setF1Score(avgF1Score); + + double avgPrecision = labelResult.getPrecision() / evaluationResults.size(); + averagesLabelStats.get(label).setPrecision(avgPrecision); + + double avgRecall = labelResult.getRecall() / evaluationResults.size(); + averagesLabelStats.get(label).setRecall(avgRecall); + + sb.append(labelResult.toString()); + } + OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); - OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); - OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); + OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averageAccuracy = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageAccuracy()).average(); - sb.append("average over " + numFolds + " folds: ").append("\n"); + double avgAccuracy = averageAccuracy.orElseGet(() -> { + throw new GrobidException("Missing average accuracy. Something went wrong. Please check. "); + }); double avgF1 = averageF1.orElseGet(() -> { throw new GrobidException("Missing average F1. Something went wrong. Please check. "); }); - sb.append("\tmacro f1 = " + TextUtilities.formatTwoDecimals(avgF1 * 100)).append("\n"); double avgPrecision = averagePrecision.orElseGet(() -> { throw new GrobidException("Missing average precision. Something went wrong. Please check. "); }); - sb.append("\tmacro precision = " + TextUtilities.formatTwoDecimals(avgPrecision * 100)).append("\n"); double avgRecall = averageRecall.orElseGet(() -> { throw new GrobidException("Missing average recall. Something went wrong. Please check. "); }); - sb.append("\tmacro recall = " + TextUtilities.formatTwoDecimals(avgRecall * 100)).append("\n"); + + sb.append("\n"); + + sb.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(avgAccuracy * 100), + TextUtilities.formatTwoDecimals(avgPrecision * 100), + TextUtilities.formatTwoDecimals( avgRecall * 100), + TextUtilities.formatTwoDecimals(avgF1 * 100)) +// String.valueOf(supportSum)) + ); + + sb.append("\n===== Instance-level results =====\n\n"); + sb.append(String.format("%-27s %d\n", "Total expected instances:", totalInstances)); + sb.append(String.format("%-27s %d\n", "Correct instances:", correctInstances)); + sb.append(String.format("%-27s %s\n", + "Instance-level recall:", + TextUtilities.formatTwoDecimals((double) correctInstances / totalInstances * 100))); return sb.toString(); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index de95b2060c..f0567c314d 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -161,8 +161,10 @@ else if ( (writer2 != null) && (writer3 == null) ) public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); DateTrainer trainer = new DateTrainer(); - AbstractTrainer.runTraining(trainer); - System.out.println(AbstractTrainer.runEvaluation(trainer)); +// AbstractTrainer.runTraining(trainer); +// System.out.println(AbstractTrainer.runEvaluation(trainer)); + + System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10)); System.exit(0); } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 0e5e3b87f9..454bdcdcf8 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -15,8 +15,6 @@ import java.util.StringTokenizer; import java.util.function.Function; -import org.grobid.trainer.LabelStat; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,7 +41,7 @@ public static String taggerRun(List ress, Tagger tagger) { // we have to re-inject the pre-tags because they are removed by the JNI // parse method - ArrayList pretags = new ArrayList(); + ArrayList pretags = new ArrayList<>(); // add context for (String piece : ress) { if (piece.trim().length() == 0) { @@ -64,7 +62,7 @@ public static String taggerRun(List ress, Tagger tagger) { res.append(" \n"); // clear internal context tagger.clear(); - pretags = new ArrayList(); + pretags = new ArrayList<>(); } else { tagger.add(piece); tagger.add("\n"); @@ -127,8 +125,8 @@ public static ModelStats evaluateStandard(String path, Function, St public static ModelStats computeStats(String theResult) { ModelStats modelStats = new ModelStats(); // report token-level results - Stats wordStats = tokenLevelStats(theResult); - modelStats.setTokenStats(wordStats); +// Stats wordStats = tokenLevelStats(theResult); +// modelStats.setTokenStats(wordStats); // report field-level results Stats fieldStats = fieldLevelStats(theResult); @@ -361,7 +359,7 @@ private static void processCounters(Stats stats, String obtained, String expecte } public static String computeMetrics(Stats stats) { - return stats.getReport(); + return stats.getOldReport(); } } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java new file mode 100644 index 0000000000..7a8ce608a8 --- /dev/null +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelResult.java @@ -0,0 +1,73 @@ +package org.grobid.trainer.evaluation; + +import org.grobid.core.utilities.TextUtilities; + +public class LabelResult { + + private final String label; + private double accuracy; + private double precision; + private double recall; + private double f1Score; + private long support; + + public LabelResult(String label) { + this.label = label; + } + + public void setAccuracy(double accuracy) { + this.accuracy = accuracy; + } + + public double getAccuracy() { + return accuracy; + } + + public String getLabel() { + return label; + } + + public void setPrecision(double precision) { + this.precision = precision; + } + + public double getPrecision() { + return precision; + } + + public void setRecall(double recall) { + this.recall = recall; + } + + public double getRecall() { + return recall; + } + + public void setF1Score(double f1Score) { + this.f1Score = f1Score; + } + + public double getF1Score() { + return f1Score; + } + + public void setSupport(long support) { + this.support = support; + + } + + public String toString() { + return String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(getAccuracy() * 100), + TextUtilities.formatTwoDecimals(getPrecision() * 100), + TextUtilities.formatTwoDecimals(getRecall() * 100), + TextUtilities.formatTwoDecimals(getF1Score() * 100), + String.valueOf(getSupport()) + ); + } + + public long getSupport() { + return support; + } +} diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java similarity index 97% rename from grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java rename to grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java index 670fdcdc40..c17d9b4dce 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/LabelStat.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java @@ -1,5 +1,6 @@ -package org.grobid.trainer; +package org.grobid.trainer.evaluation; +/** Model the results for each label **/ public final class LabelStat { private int falsePositive = 0; private int falseNegative = 0; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index da1ff89d8d..9392fc903f 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -2,15 +2,19 @@ import org.grobid.core.utilities.TextUtilities; +import java.util.Map; +import java.util.TreeMap; + /** * Represent all different evaluation given a specific model */ public class ModelStats { private int totalInstances; private int correctInstance; - private Stats tokenStats; + // private Stats tokenStats; private Stats fieldStats; + public void setTotalInstances(int totalInstances) { this.totalInstances = totalInstances; } @@ -27,13 +31,13 @@ public int getCorrectInstance() { return correctInstance; } - public void setTokenStats(Stats tokenStats) { - this.tokenStats = tokenStats; - } +// public void setTokenStats(Stats tokenStats) { +// this.tokenStats = tokenStats; +// } - public Stats getTokenStats() { - return tokenStats; - } +// public Stats getTokenStats() { +// return tokenStats; +// } public void setFieldStats(Stats fieldStats) { this.fieldStats = fieldStats; @@ -53,15 +57,39 @@ public double getInstanceRecall() { public String toString() { StringBuilder report = new StringBuilder(); - // report token-level results -// Stats wordStats = getTokenStats(); -// report.append("\n===== Token-level results =====\n\n"); -// report.append(wordStats.getReport()); - - // report field-level results Stats fieldStats = getFieldStats(); report.append("\n===== Field-level results =====\n"); - report.append(fieldStats.getReport()); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + + for (Map.Entry labelResult : fieldStats.getLabelsResults().entrySet()) { + report.append(labelResult.getValue()); + } + + report.append("\n"); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (micro avg.)", + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMicroAverageF1() * 100), + String.valueOf(getSupportSum()))); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(fieldStats.getMacroAverageF1() * 100), + String.valueOf(getSupportSum()))); + // instance-level: instances are separated by a new line in the result file report.append("\n===== Instance-level results =====\n\n"); @@ -73,4 +101,13 @@ public String toString() { return report.toString(); } + + public long getSupportSum() { + long supportSum = 0; + for (LabelResult labelResult : fieldStats.getLabelsResults().values()) { + supportSum += labelResult.getSupport(); + } + return supportSum; + } + } diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java index 6a70b48e3b..468a96d4da 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/Stats.java @@ -5,7 +5,6 @@ import org.grobid.core.exceptions.*; import org.grobid.core.utilities.TextUtilities; -import org.grobid.trainer.LabelStat; /** * Contains the single statistic computation for evaluation @@ -178,19 +177,10 @@ public void computeMetrics() { requiredToRecomputeMetrics = false; } - public String getReport() { + public TreeMap getLabelsResults() { computeMetrics(); - StringBuilder report = new StringBuilder(); - report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", - "label", - "accuracy", - "precision", - "recall", - "f1", - "support")); - - long supportSum = 0; + TreeMap result = new TreeMap<>(); for (String label : getLabels()) { if (label.equals("") || label.equals("base") || label.equals("O")) { @@ -198,39 +188,17 @@ public String getReport() { } LabelStat labelStat = getLabelStat(label); - - long support = labelStat.getSupport(); - report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", - label, - TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), - TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), - TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), - TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100), - String.valueOf(support)) - ); - - supportSum += support; + LabelResult labelResult = new LabelResult(label); + labelResult.setAccuracy(labelStat.getAccuracy()); + labelResult.setPrecision(labelStat.getPrecision()); + labelResult.setRecall(labelStat.getRecall()); + labelResult.setF1Score(labelStat.getF1Score()); + labelResult.setSupport(labelStat.getSupport()); + + result.put(label, labelResult); } - report.append("\n"); - - report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", - "all (micro avg.)", - TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100), - TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100), - TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100), - String.valueOf(supportSum))); - - report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", - "all (macro avg.)", - TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100), - TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100), - TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100), - TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100), - String.valueOf(supportSum))); - - return report.toString(); + return result; } @@ -317,5 +285,61 @@ public double getMacroAverageF1() { return Math.min(1.0, cumulated_f1 / totalValidFields); } + + public String getOldReport() { + computeMetrics(); + + StringBuilder report = new StringBuilder(); + report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", + "label", + "accuracy", + "precision", + "recall", + "f1", + "support")); + + long supportSum = 0; + + for (String label : getLabels()) { + if (label.equals("") || label.equals("base") || label.equals("O")) { + continue; + } + + LabelStat labelStat = getLabelStat(label); + + long support = labelStat.getSupport(); + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + label, + TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100), + TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100), + TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100), + TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100), + String.valueOf(support)) + ); + + supportSum += support; + } + + report.append("\n"); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (micro avg.)", + TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100), + String.valueOf(supportSum))); + + report.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", + "all (macro avg.)", + TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100), + TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100), + TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100), + String.valueOf(supportSum))); + + return report.toString(); + + } } diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index fd0e403fae..77c5612233 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -1,5 +1,6 @@ package org.grobid.trainer; +import org.grobid.trainer.evaluation.LabelStat; import org.grobid.trainer.evaluation.Stats; import org.junit.Before; import org.junit.Ignore; diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index 3d82877b9d..5d25a245ca 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -1,7 +1,6 @@ package org.grobid.trainer.evaluation; import org.apache.commons.io.IOUtils; -import org.grobid.trainer.LabelStat; import org.junit.Test; import java.nio.charset.StandardCharsets; From 31da48b855d369f8c9f312af56ed2d8eaeda097e Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 00:14:39 +0900 Subject: [PATCH 26/37] test jdk 11 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 6ad8e70cd4..2a3a948449 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ language: java sudo: true jdk: - - oraclejdk8 + - oraclejdk11 env: global: From ff9e6b293a78f55567bdd08951589bf714068ca7 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 00:21:50 +0900 Subject: [PATCH 27/37] calm down and go to sleep #453 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 2a3a948449..0cb6292ab9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ language: java sudo: true jdk: - - oraclejdk11 + - openjdk8 env: global: From 15c9a879ac7feb8cdcb3e3662b5120c9de525ea2 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 11:09:57 +0900 Subject: [PATCH 28/37] Adding raw result output #453 --- .../org/grobid/trainer/AbstractTrainer.java | 61 ++++++++++++++++--- .../main/java/org/grobid/trainer/Trainer.java | 11 ++-- .../evaluation/EvaluationUtilities.java | 1 + .../grobid/trainer/evaluation/ModelStats.java | 20 +++++- 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 8ac6425a0b..562e56f259 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -117,8 +117,19 @@ protected void renameModels(final File oldModelPath, final File tempModelPath) { @Override public String evaluate() { + return evaluate(false); + } + + @Override + public String evaluate(boolean includeRawResults) { createCRFPPData(getEvalCorpusPath(), evalDataPath); - return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), getTagger()).toString(includeRawResults); + } + + @Override + public String evaluate(GenericTagger tagger, boolean includeRawResults) { + createCRFPPData(getEvalCorpusPath(), evalDataPath); + return EvaluationUtilities.evaluateStandard(evalDataPath.getAbsolutePath(), tagger).toString(includeRawResults); } @Override @@ -154,6 +165,11 @@ public String splitTrainEvaluate(Double split) { @Override public String nFoldEvaluate(int numFolds) { + return nFoldEvaluate(numFolds, false); + } + + @Override + public String nFoldEvaluate(int numFolds, boolean includeRawResults) { final File dataPath = trainDataPath; createCRFPPData(getCorpusPath(), dataPath); GenericTrainer trainer = TrainerFactory.getTrainer(); @@ -187,16 +203,39 @@ public String nFoldEvaluate(int numFolds) { List evaluationResults = foldMap.stream().map(fold -> { final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() + "_nfold_" + counter.getAndIncrement() + ".wapiti"); - System.out.println("Saving model in " + tempModelPath); + sb.append("Saving model in " + tempModelPath).append("\n"); - System.out.println("Training input data: " + fold.getLeft()); + sb.append("Training input data: " + fold.getLeft()).append("\n"); trainer.train(getTemplatePath(), new File(fold.getLeft()), tempModelPath, GrobidProperties.getNBThreads(), model); - System.out.println("Evaluation input data: " + fold.getRight()); - ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), getTagger()); - System.out.println(modelStats.toString()); + sb.append("Evaluation input data: " + fold.getRight()).append("\n"); + + //TODO: find a better solution!! + GrobidModel tmpModel = new GrobidModel() { + @Override + public String getFolderName() { + return tmpDirectory.getAbsolutePath(); + } + + @Override + public String getModelPath() { + return tempModelPath.getAbsolutePath(); + } + + @Override + public String getModelName() { + return model.getModelName(); + } + + @Override + public String getTemplateName() { + return model.getTemplateName(); + } + }; + + ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), TaggerFactory.getTagger(tmpModel)); sb.append(" ====================== Fold " + counter.get() + " ====================== ").append("\n"); - sb.append(modelStats.toString()).append("\n"); + sb.append(modelStats.toString(includeRawResults)).append("\n"); return modelStats; }).collect(Collectors.toList()); @@ -493,11 +532,11 @@ public File getEvalDataPath() { return evalDataPath; } - public static String runEvaluation(final Trainer trainer) { + public static String runEvaluation(final Trainer trainer, boolean includeRawResults) { long start = System.currentTimeMillis(); String report = ""; try { - report = trainer.evaluate(); + report = trainer.evaluate(includeRawResults); } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); } @@ -507,6 +546,10 @@ public static String runEvaluation(final Trainer trainer) { return report; } + public static String runEvaluation(final Trainer trainer) { + return trainer.evaluate(false); + } + public static String runSplitTrainingEvaluation(final Trainer trainer, Double split) { long start = System.currentTimeMillis(); String report = ""; diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java index 4685f770a3..c42535e030 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/Trainer.java @@ -2,6 +2,7 @@ import org.grobid.core.GrobidModel; import org.grobid.core.GrobidModels; +import org.grobid.core.engines.tagging.GenericTagger; import java.io.File; @@ -16,15 +17,17 @@ public interface Trainer { void train(); - /** - * - * @return a report - */ String evaluate(); + String evaluate(boolean includeRawResults); + + String evaluate(GenericTagger tagger, boolean includeRawResults); + String splitTrainEvaluate(Double split); String nFoldEvaluate(int folds); + String nFoldEvaluate(int folds, boolean includeRawResults); + GrobidModel getModel(); } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 454bdcdcf8..1fdd2bde8e 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -124,6 +124,7 @@ public static ModelStats evaluateStandard(String path, Function, St public static ModelStats computeStats(String theResult) { ModelStats modelStats = new ModelStats(); + modelStats.setRawResults(theResult); // report token-level results // Stats wordStats = tokenLevelStats(theResult); // modelStats.setTokenStats(wordStats); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 9392fc903f..926ed6faf0 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -3,7 +3,6 @@ import org.grobid.core.utilities.TextUtilities; import java.util.Map; -import java.util.TreeMap; /** * Represent all different evaluation given a specific model @@ -13,6 +12,7 @@ public class ModelStats { private int correctInstance; // private Stats tokenStats; private Stats fieldStats; + private String rawResults; public void setTotalInstances(int totalInstances) { @@ -55,8 +55,19 @@ public double getInstanceRecall() { } public String toString() { + return toString(false); + } + + public String toString(boolean includeRawResults) { StringBuilder report = new StringBuilder(); + if(includeRawResults) { + report.append("=== RAw RESULTS ===").append("\n"); + report.append(getRawResults()).append("\n"); + report.append("=== END RAw RESULTS ===").append("\n"); + } + + Stats fieldStats = getFieldStats(); report.append("\n===== Field-level results =====\n"); report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", @@ -110,4 +121,11 @@ public long getSupportSum() { return supportSum; } + public String getRawResults() { + return rawResults; + } + + public void setRawResults(String rawResults) { + this.rawResults = rawResults; + } } From 3271c486e46d6a2e927fc5f15e2f7179530ba331 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 11:15:38 +0900 Subject: [PATCH 29/37] cosmetics #453 --- .../java/org/grobid/trainer/DateTrainer.java | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index f0567c314d..80133c8435 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -24,8 +24,8 @@ public DateTrainer() { } /** - * Add the selected features to a date example set - * + * Add the selected features to a date example set + * * @param corpusDir * a path where corpus files are located * @param trainingOutputPath @@ -38,8 +38,8 @@ public int createCRFPPData(final File corpusDir, final File trainingOutputPath) } /** - * Add the selected features to a date example set - * + * Add the selected features to a date example set + * * @param corpusDir * a path where corpus files are located * @param trainingOutputPath @@ -47,13 +47,13 @@ public int createCRFPPData(final File corpusDir, final File trainingOutputPath) * @param evalOutputPath * path where to store the temporary evaluation data * @param splitRatio - * ratio to consider for separating training and evaluation data, e.g. 0.8 for 80% - * @return the total number of used corpus items + * ratio to consider for separating training and evaluation data, e.g. 0.8 for 80% + * @return the total number of used corpus items */ @Override - public int createCRFPPData(final File corpusDir, - final File trainingOutputPath, - final File evalOutputPath, + public int createCRFPPData(final File corpusDir, + final File trainingOutputPath, + final File evalOutputPath, double splitRatio) { int totalExamples = 0; try { @@ -69,7 +69,7 @@ public int createCRFPPData(final File corpusDir, public boolean accept(File dir, String name) { return name.endsWith(".xml"); } - }); + }); if (refFiles == null) { throw new IllegalStateException("Folder " + corpusDir.getAbsolutePath() @@ -85,7 +85,7 @@ public boolean accept(File dir, String name) { os2 = new FileOutputStream(trainingOutputPath); writer2 = new OutputStreamWriter(os2, "UTF8"); } - + // the file for writing the evaluation data OutputStream os3 = null; Writer writer3 = null; @@ -93,7 +93,7 @@ public boolean accept(File dir, String name) { os3 = new FileOutputStream(evalOutputPath); writer3 = new OutputStreamWriter(os3, "UTF8"); } - + // get a factory for SAX parser SAXParserFactory spf = SAXParserFactory.newInstance(); @@ -114,22 +114,22 @@ public boolean accept(File dir, String name) { // we can now add the features String headerDates = FeaturesVectorDate.addFeaturesDate(labeled); - + // format with features for sequence tagging... // given the split ratio we write either in the training file or the evaluation file String[] chunks = headerDates.split("\n \n"); - + for(int i=0; i Date: Wed, 17 Jul 2019 11:26:37 +0900 Subject: [PATCH 30/37] Improving visualisation - more cosmetics #453 --- .../src/main/java/org/grobid/trainer/AbstractTrainer.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 562e56f259..b33229c6ba 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -201,6 +201,9 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) { AtomicInteger counter = new AtomicInteger(0); List evaluationResults = foldMap.stream().map(fold -> { + sb.append("\n"); + sb.append("====================== Fold " + counter.get() + " ====================== ").append("\n"); + final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() + "_nfold_" + counter.getAndIncrement() + ".wapiti"); sb.append("Saving model in " + tempModelPath).append("\n"); @@ -234,8 +237,9 @@ public String getTemplateName() { ModelStats modelStats = EvaluationUtilities.evaluateStandard(fold.getRight(), TaggerFactory.getTagger(tmpModel)); - sb.append(" ====================== Fold " + counter.get() + " ====================== ").append("\n"); - sb.append(modelStats.toString(includeRawResults)).append("\n"); + sb.append(modelStats.toString(includeRawResults)); + sb.append("\n"); + sb.append("\n"); return modelStats; }).collect(Collectors.toList()); From 62a897c97baa59c617b4591632a20a594b79671f Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 11:38:06 +0900 Subject: [PATCH 31/37] Adding output of raw results for n-fold evaluation #453 --- .../org/grobid/trainer/AbstractTrainer.java | 23 +++++++++++++------ .../java/org/grobid/trainer/DateTrainer.java | 2 +- .../grobid/trainer/evaluation/ModelStats.java | 4 ++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index b33229c6ba..dfe1d2695a 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -203,6 +203,7 @@ public String nFoldEvaluate(int numFolds, boolean includeRawResults) { List evaluationResults = foldMap.stream().map(fold -> { sb.append("\n"); sb.append("====================== Fold " + counter.get() + " ====================== ").append("\n"); + System.out.println("====================== Fold " + counter.get() + " ====================== "); final File tempModelPath = new File(tmpDirectory + File.separator + getModel().getModelName() + "_nfold_" + counter.getAndIncrement() + ".wapiti"); @@ -282,12 +283,12 @@ public String getTemplateName() { TreeMap averagesLabelStats = new TreeMap<>(); int totalInstances = 0; int correctInstances = 0; - for(ModelStats ms : evaluationResults) { + for (ModelStats ms : evaluationResults) { totalInstances += ms.getTotalInstances(); correctInstances += ms.getCorrectInstance(); - for(Map.Entry entry : ms.getFieldStats().getLabelsResults().entrySet()) { + for (Map.Entry entry : ms.getFieldStats().getLabelsResults().entrySet()) { String key = entry.getKey(); - if(averagesLabelStats.containsKey(key)) { + if (averagesLabelStats.containsKey(key)) { averagesLabelStats.get(key).setAccuracy(averagesLabelStats.get(key).getAccuracy() + entry.getValue().getAccuracy()); averagesLabelStats.get(key).setF1Score(averagesLabelStats.get(key).getF1Score() + entry.getValue().getF1Score()); averagesLabelStats.get(key).setRecall(averagesLabelStats.get(key).getRecall() + entry.getValue().getF1Score()); @@ -312,7 +313,7 @@ public String getTemplateName() { "f1", "support")); - for(String label : averagesLabelStats.keySet()) { + for (String label : averagesLabelStats.keySet()) { LabelResult labelResult = averagesLabelStats.get(label); double avgAccuracy = labelResult.getAccuracy() / evaluationResults.size(); averagesLabelStats.get(label).setAccuracy(avgAccuracy); @@ -356,7 +357,7 @@ public String getTemplateName() { "all (macro avg.)", TextUtilities.formatTwoDecimals(avgAccuracy * 100), TextUtilities.formatTwoDecimals(avgPrecision * 100), - TextUtilities.formatTwoDecimals( avgRecall * 100), + TextUtilities.formatTwoDecimals(avgRecall * 100), TextUtilities.formatTwoDecimals(avgF1 * 100)) // String.valueOf(supportSum)) ); @@ -570,8 +571,12 @@ public static String runSplitTrainingEvaluation(final Trainer trainer, Double sp } public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile) { + runNFoldEvaluation(trainer, numFolds, outputFile, false); + } + + public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path outputFile, boolean includeRawResults) { - String report = runNFoldEvaluation(trainer, numFolds); + String report = runNFoldEvaluation(trainer, numFolds, includeRawResults); try (BufferedWriter writer = Files.newBufferedWriter(outputFile)) { writer.write(report); @@ -583,10 +588,14 @@ public static void runNFoldEvaluation(final Trainer trainer, int numFolds, Path } public static String runNFoldEvaluation(final Trainer trainer, int numFolds) { + return runNFoldEvaluation(trainer, numFolds, false); + } + + public static String runNFoldEvaluation(final Trainer trainer, int numFolds, boolean includeRawResults) { long start = System.currentTimeMillis(); String report = ""; try { - report = trainer.nFoldEvaluate(numFolds); + report = trainer.nFoldEvaluate(numFolds, includeRawResults); } catch (Exception e) { throw new GrobidException("An exception occurred while evaluating Grobid.", e); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index 80133c8435..0d168dd149 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -162,7 +162,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); DateTrainer trainer = new DateTrainer(); - System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10)); + System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10, true)); AbstractTrainer.runTraining(trainer); System.out.println(AbstractTrainer.runEvaluation(trainer)); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java index 926ed6faf0..86847b62f6 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/ModelStats.java @@ -62,9 +62,9 @@ public String toString(boolean includeRawResults) { StringBuilder report = new StringBuilder(); if(includeRawResults) { - report.append("=== RAw RESULTS ===").append("\n"); + report.append("=== START RAW RESULTS ===").append("\n"); report.append(getRawResults()).append("\n"); - report.append("=== END RAw RESULTS ===").append("\n"); + report.append("=== END RAw RESULTS ===").append("\n").append("\n"); } From 4f308a562d68528cb2501fbb8d403d6641b77f2c Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 12:22:58 +0900 Subject: [PATCH 32/37] fixing copy-pasta distraction problem #453 --- .../src/main/java/org/grobid/trainer/AbstractTrainer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index dfe1d2695a..8806b71afc 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -331,8 +331,8 @@ public String getTemplateName() { } OptionalDouble averageF1 = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); - OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); - OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageF1()).average(); + OptionalDouble averagePrecision = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAveragePrecision()).average(); + OptionalDouble averageRecall = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageRecall()).average(); OptionalDouble averageAccuracy = evaluationResults.stream().mapToDouble(e -> e.getFieldStats().getMacroAverageAccuracy()).average(); double avgAccuracy = averageAccuracy.orElseGet(() -> { From e675b114b8424ece76cc89e42af3992830b97195 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 12:55:44 +0900 Subject: [PATCH 33/37] FIxing other minor and nasty annoying errors #453 --- .../src/main/java/org/grobid/trainer/AbstractTrainer.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java index 8806b71afc..804f1284cb 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/AbstractTrainer.java @@ -291,14 +291,14 @@ public String getTemplateName() { if (averagesLabelStats.containsKey(key)) { averagesLabelStats.get(key).setAccuracy(averagesLabelStats.get(key).getAccuracy() + entry.getValue().getAccuracy()); averagesLabelStats.get(key).setF1Score(averagesLabelStats.get(key).getF1Score() + entry.getValue().getF1Score()); - averagesLabelStats.get(key).setRecall(averagesLabelStats.get(key).getRecall() + entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(averagesLabelStats.get(key).getRecall() + entry.getValue().getRecall()); averagesLabelStats.get(key).setPrecision(averagesLabelStats.get(key).getPrecision() + entry.getValue().getPrecision()); averagesLabelStats.get(key).setSupport(averagesLabelStats.get(key).getSupport() + entry.getValue().getSupport()); } else { averagesLabelStats.put(key, new LabelResult(key)); averagesLabelStats.get(key).setAccuracy(entry.getValue().getAccuracy()); averagesLabelStats.get(key).setF1Score(entry.getValue().getF1Score()); - averagesLabelStats.get(key).setRecall(entry.getValue().getF1Score()); + averagesLabelStats.get(key).setRecall(entry.getValue().getRecall()); averagesLabelStats.get(key).setPrecision(entry.getValue().getPrecision()); averagesLabelStats.get(key).setSupport(entry.getValue().getSupport()); } @@ -315,6 +315,7 @@ public String getTemplateName() { for (String label : averagesLabelStats.keySet()) { LabelResult labelResult = averagesLabelStats.get(label); + double avgAccuracy = labelResult.getAccuracy() / evaluationResults.size(); averagesLabelStats.get(label).setAccuracy(avgAccuracy); From 149d3b71ceca44e1a4137cbb7f7e51d4be41fc30 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 17 Jul 2019 14:00:46 +0900 Subject: [PATCH 34/37] do not include the raw results in the output #453 --- .../src/main/java/org/grobid/trainer/DateTrainer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index 0d168dd149..80133c8435 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -162,7 +162,7 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); DateTrainer trainer = new DateTrainer(); - System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10, true)); + System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10)); AbstractTrainer.runTraining(trainer); System.out.println(AbstractTrainer.runEvaluation(trainer)); From 270c83ed5f7a5a14e98865e284c01dac0fca7e1a Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Thu, 8 Aug 2019 11:59:52 +0900 Subject: [PATCH 35/37] fixing test (cherry picked from commit 67243f42c09aac9f16ef5f7d1b23472058a99f10) --- .../test/java/org/grobid/core/sax/PDFALTOSaxHandlerTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grobid-core/src/test/java/org/grobid/core/sax/PDFALTOSaxHandlerTest.java b/grobid-core/src/test/java/org/grobid/core/sax/PDFALTOSaxHandlerTest.java index bac8f69818..41c78ee04e 100644 --- a/grobid-core/src/test/java/org/grobid/core/sax/PDFALTOSaxHandlerTest.java +++ b/grobid-core/src/test/java/org/grobid/core/sax/PDFALTOSaxHandlerTest.java @@ -81,7 +81,7 @@ public void testParsing_shouldWork() throws Exception { List tokenList = target.getTokenization(); - assertThat(tokenList.stream().filter(t -> t.getText().equals("newly")).count(), is(1)); + assertThat(tokenList.stream().filter(t -> t.getText().equals("newly")).count(), is(1L)); } } \ No newline at end of file From 7fc02a888eda01a7771df33bffee2db9f1268d35 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 9 Aug 2019 18:00:42 +0900 Subject: [PATCH 36/37] Remove 10-fold from date trainer - forgot there from testing --- grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java | 1 - 1 file changed, 1 deletion(-) diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java index 80133c8435..6241238e0f 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/DateTrainer.java @@ -162,7 +162,6 @@ public static void main(String[] args) throws Exception { GrobidProperties.getInstance(); DateTrainer trainer = new DateTrainer(); - System.out.println(AbstractTrainer.runNFoldEvaluation(trainer, 10)); AbstractTrainer.runTraining(trainer); System.out.println(AbstractTrainer.runEvaluation(trainer)); From abc94909f21d6a09fa3023b5e8ddd9e2398c3b45 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Tue, 20 Aug 2019 17:27:04 +0900 Subject: [PATCH 37/37] adding more tests for evaluation and fixing small bug on support metrics --- .../engines/tagging/GenericTaggerUtils.java | 35 ++++++--------- .../evaluation/EvaluationUtilities.java | 19 +++----- .../grobid/trainer/evaluation/LabelStat.java | 2 +- .../java/org/grobid/trainer/StatsTest.java | 1 + .../evaluation/EvaluationUtilitiesTest.java | 45 ++++++++++++++++--- .../test/resources/sample.wapiti.output.3.txt | 13 ++++++ 6 files changed, 74 insertions(+), 41 deletions(-) create mode 100644 grobid-trainer/src/test/resources/sample.wapiti.output.3.txt diff --git a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java index 2653028e7c..feb5c25098 100644 --- a/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java +++ b/grobid-core/src/main/java/org/grobid/core/engines/tagging/GenericTaggerUtils.java @@ -1,17 +1,15 @@ package org.grobid.core.engines.tagging; -import com.google.common.base.Function; import com.google.common.base.Joiner; import com.google.common.base.Splitter; import org.apache.commons.lang3.StringUtils; -//import org.grobid.core.utilities.Pair; +import org.apache.commons.lang3.tuple.Pair; import org.grobid.core.utilities.Triple; import org.wipo.analyzers.wipokr.utils.StringUtil; -import org.apache.commons.lang3.tuple.Pair; - import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import java.util.regex.Pattern; /** @@ -22,6 +20,7 @@ public class GenericTaggerUtils { public static final String START_ENTITY_LABEL_PREFIX = "I-"; public static final String START_ENTITY_LABEL_PREFIX_ALTERNATIVE = "B-"; + public static final String START_ENTITY_LABEL_PREFIX_ALTERNATIVE_2 = "E-"; public static final Pattern SEPARATOR_PATTERN = Pattern.compile("[\t ]"); /** @@ -30,31 +29,23 @@ public class GenericTaggerUtils { * Note an empty line in the result will be transformed to a 'null' pointer of a pair */ public static List> getTokensAndLabels(String labeledResult) { - Function, Pair> fromSplits = new Function, Pair>() { - @Override public Pair apply(List splits) { - return Pair.of(splits.get(0), splits.get(splits.size() - 1)); - } - }; - - return processLabeledResult(labeledResult, fromSplits); + return processLabeledResult(labeledResult, splits -> Pair.of(splits.get(0), splits.get(splits.size() - 1))); } /** * @param labeledResult labeled result from a tagger - * @return a list of triples - first element in a pair is a token itself, the second is a label (e.g. or I-) + * @return a list of triples - first element in a pair is a token itself, the second is a label (e.g. or I-) * and the third element is a string with the features * Note an empty line in the result will be transformed to a 'null' pointer of a pair */ public static List> getTokensWithLabelsAndFeatures(String labeledResult, final boolean addFeatureString) { - Function, Triple> fromSplits = new Function, Triple>() { - @Override public Triple apply(List splits) { - String featureString = addFeatureString ? Joiner.on("\t").join(splits.subList(0, splits.size() - 1)) : null; - return new Triple<>( - splits.get(0), - splits.get(splits.size() - 1), - featureString); - } + Function, Triple> fromSplits = splits -> { + String featureString = addFeatureString ? Joiner.on("\t").join(splits.subList(0, splits.size() - 1)) : null; + return new Triple<>( + splits.get(0), + splits.get(splits.size() - 1), + featureString); }; return processLabeledResult(labeledResult, fromSplits); @@ -82,6 +73,8 @@ public static String getPlainLabel(String label) { } public static boolean isBeginningOfEntity(String label) { - return StringUtils.startsWith(label, START_ENTITY_LABEL_PREFIX) || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE); + return StringUtils.startsWith(label, START_ENTITY_LABEL_PREFIX) + || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE) + || StringUtil.startsWith(label, START_ENTITY_LABEL_PREFIX_ALTERNATIVE_2); } } \ No newline at end of file diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java index 1fdd2bde8e..fdb91895d6 100755 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/EvaluationUtilities.java @@ -18,6 +18,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.grobid.core.engines.tagging.GenericTaggerUtils.getPlainLabel; + /** * Generic evaluation of a single-CRF model processing given an expected result. * @@ -249,7 +251,7 @@ public static Stats fieldLevelStats(String theResult) { (!obtainedLabel.equals(getPlainLabel(previousObtainedLabel)))) { // new obtained field currentObtainedPosition.end = pos - 1; - Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), currentObtainedPosition); currentObtainedPosition = new OffsetPosition(); currentObtainedPosition.start = pos; @@ -260,7 +262,7 @@ public static Stats fieldLevelStats(String theResult) { (!expectedLabel.equals(getPlainLabel(previousExpectedLabel)))) { // new expected field currentExpectedPosition.end = pos - 1; - Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), currentExpectedPosition); currentExpectedPosition = new OffsetPosition(); currentExpectedPosition.start = pos; @@ -274,14 +276,14 @@ public static Stats fieldLevelStats(String theResult) { // last fields of the sequence if ((previousObtainedLabel != null)) { currentObtainedPosition.end = pos - 1; - Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), + Pair theField = new Pair<>(getPlainLabel(previousObtainedLabel), currentObtainedPosition); obtainedFields.add(theField); } if ((previousExpectedLabel != null)) { currentExpectedPosition.end = pos - 1; - Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), + Pair theField = new Pair<>(getPlainLabel(previousExpectedLabel), currentExpectedPosition); expectedFields.add(theField); } @@ -336,15 +338,6 @@ public static Stats fieldLevelStats(String theResult) { } - private static String getPlainLabel(String label) { - if (label == null) - return null; - if (label.startsWith("I-") || label.startsWith("E-") || label.startsWith("B-")) { - return label.substring(2, label.length()); - } else - return label; - } - private static void processCounters(Stats stats, String obtained, String expected) { LabelStat expectedStat = stats.getLabelStat(expected); LabelStat obtainedStat = stats.getLabelStat(obtained); diff --git a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java index c17d9b4dce..4a47e47ec1 100644 --- a/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java +++ b/grobid-trainer/src/main/java/org/grobid/trainer/evaluation/LabelStat.java @@ -104,7 +104,7 @@ public double getAccuracy() { } public long getSupport() { - return observed + falsePositive + falseNegative; + return expected; } public double getPrecision() { diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java index 77c5612233..69fac5d2e1 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/StatsTest.java @@ -31,6 +31,7 @@ public void testPrecision_noMatch() throws Exception { target.getLabelStat("MIAO").setFalsePositive(3); target.getLabelStat("MIAO").setFalseNegative(1); assertThat(target.getLabelStat("MIAO").getPrecision(), is(0.0)); + assertThat(target.getLabelStat("MIAO").getSupport(), is(4L)); } @Test diff --git a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java index 5d25a245ca..6ce08b58fb 100644 --- a/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java +++ b/grobid-trainer/src/test/java/org/grobid/trainer/evaluation/EvaluationUtilitiesTest.java @@ -4,6 +4,7 @@ import org.junit.Test; import java.nio.charset.StandardCharsets; +import java.util.TreeMap; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; @@ -22,6 +23,9 @@ public void testTokenLevelStats_allGood() throws Exception { assertThat(labelstat2.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(4)); assertThat(labelstat2.getExpected(), is(1)); + + assertThat(labelstat1.getSupport(), is(4L)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test @@ -37,6 +41,9 @@ public void testFieldLevelStats_allGood() throws Exception { assertThat(labelstat2.getObserved(), is(1)); assertThat(labelstat1.getExpected(), is(2)); assertThat(labelstat2.getExpected(), is(1)); + + assertThat(labelstat1.getSupport(), is(2L)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test @@ -49,9 +56,12 @@ public void testTokenLevelStats_noMatch() throws Exception { LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getFalseNegative(), is(5)); + assertThat(labelstat1.getSupport(), is(5L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getFalsePositive(), is(5)); + assertThat(labelstat2.getSupport(), is(0L)); } @Test @@ -63,9 +73,12 @@ public void testFieldLevelStats_noMatch() throws Exception { LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(1)); + assertThat(labelstat1.getSupport(), is(1L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(0)); + assertThat(labelstat2.getSupport(), is(0L)); } @Test @@ -83,11 +96,13 @@ public void testTokenLevelStats_mixed() throws Exception { assertThat(labelstat1.getExpected(), is(4)); assertThat(labelstat1.getFalseNegative(), is(0)); assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getSupport(), is(4L)); assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat2.getFalseNegative(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test @@ -102,9 +117,12 @@ public void testFieldLevelStats_mixed() throws Exception { LabelStat labelstat2 = fieldStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(0)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(2)); + assertThat(labelstat1.getSupport(), is(2L)); + + assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat2.getExpected(), is(1)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test @@ -116,13 +134,16 @@ public void testTokenLevelStats2_mixed() throws Exception { LabelStat labelstat2 = wordStats.getLabelStat("<2>"); assertThat(labelstat1.getObserved(), is(4)); - assertThat(labelstat2.getObserved(), is(0)); assertThat(labelstat1.getExpected(), is(4)); - assertThat(labelstat2.getExpected(), is(1)); assertThat(labelstat1.getFalseNegative(), is(0)); - assertThat(labelstat2.getFalseNegative(), is(1)); assertThat(labelstat1.getFalsePositive(), is(1)); + assertThat(labelstat1.getSupport(), is(4L)); + + assertThat(labelstat2.getObserved(), is(0)); + assertThat(labelstat2.getExpected(), is(1)); + assertThat(labelstat2.getFalseNegative(), is(1)); assertThat(labelstat2.getFalsePositive(), is(0)); + assertThat(labelstat2.getSupport(), is(1L)); } @Test @@ -311,7 +332,19 @@ public void testTokenLevelStats2_realCase() throws Exception { assertThat(personLabelStats.getExpected(), is(0)); assertThat(personLabelStats.getFalseNegative(), is(0)); assertThat(personLabelStats.getFalsePositive(), is(1)); + } + + @Test + public void testTokenLevelStats3_realCase() throws Exception { + String result = IOUtils.toString(this.getClass().getResourceAsStream("/sample.wapiti.output.3.txt"), StandardCharsets.UTF_8); + + + Stats fieldStats = EvaluationUtilities.fieldLevelStats(result); + + TreeMap labelsResults = fieldStats.getLabelsResults(); + assertThat(labelsResults.get("").getSupport(), is(4L)); + assertThat(labelsResults.get("").getSupport(), is(2L)); } } \ No newline at end of file diff --git a/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt b/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt new file mode 100644 index 0000000000..0e58507e12 --- /dev/null +++ b/grobid-trainer/src/test/resources/sample.wapiti.output.3.txt @@ -0,0 +1,13 @@ +y 0 0 1 1 NOPUNCT 0 I- I- +r 0 0 0 0 NOPUNCT 0 +s 0 0 1 0 NOPUNCT 0 + +s 0 0 1 0 NOPUNCT 0 I- I- + +G 1 0 1 1 NOPUNCT 0 I- I- +P 1 0 1 1 NOPUNCT 0 I- I- +a 0 0 1 1 NOPUNCT 0 + +G 1 0 1 1 NOPUNCT 0 I- I- +P 1 0 1 1 NOPUNCT 0 I- I- +a 0 0 1 1 NOPUNCT 0 \ No newline at end of file