-
Notifications
You must be signed in to change notification settings - Fork 427
/
TableTrainer.java
227 lines (201 loc) · 9.11 KB
/
TableTrainer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
package org.grobid.trainer;

import org.grobid.core.GrobidModels;
import org.grobid.core.exceptions.GrobidException;
import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.UnicodeUtil;
import org.grobid.trainer.sax.TEIFigureSaxParser;

import javax.xml.parsers.SAXParser;
import javax.xml.parsers.SAXParserFactory;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.StringTokenizer;
/**
 * Trainer for the table model: generates labeled training (and optionally
 * evaluation) data by aligning TEI-annotated corpus files with their raw
 * feature files.
 *
 * @author Patrice Lopez
 */
public class TableTrainer extends AbstractTrainer {

    public TableTrainer() {
        super(GrobidModels.TABLE);
        // adjusting CRF training parameters for this model (only with Wapiti)
        epsilon = 0.00001;
        window = 20;
    }

    /**
     * Generate training data only (no evaluation split): splitRatio 1.0 sends
     * every corpus item to the training file.
     *
     * @param corpusPath path to the corpus root (expects tei/ and raw/ subdirectories)
     * @param outputFile path where to store the temporary training data
     * @return the total number of used corpus items
     */
    @Override
    public int createCRFPPData(File corpusPath, File outputFile) {
        return addFeaturesTable(corpusPath.getAbsolutePath() + "/tei",
                corpusPath.getAbsolutePath() + "/raw",
                outputFile, null, 1.0);
    }

    /**
     * Add the selected features for the table model
     *
     * @param corpusDir          path where corpus files are located
     * @param trainingOutputPath path where to store the temporary training data
     * @param evalOutputPath     path where to store the temporary evaluation data
     * @param splitRatio         ratio to consider for separating training and evaluation data, e.g. 0.8 for 80%
     * @return the total number of used corpus items
     */
    @Override
    public int createCRFPPData(final File corpusDir,
                               final File trainingOutputPath,
                               final File evalOutputPath,
                               double splitRatio) {
        return addFeaturesTable(corpusDir.getAbsolutePath() + "/tei",
                corpusDir.getAbsolutePath() + "/raw",
                trainingOutputPath,
                evalOutputPath,
                splitRatio);
    }

    /**
     * Add the selected features for the table model.
     *
     * For each TEI file, the labels produced by the SAX parser are aligned
     * token-by-token with the corresponding raw feature file, and the result is
     * appended to the training file, the evaluation file, or randomly to one of
     * the two according to {@code splitRatio}.
     *
     * @param sourceTEIPathLabel path to corpus TEI files
     * @param sourceRawPathLabel path to corpus raw files
     * @param trainingOutputPath path where to store the temporary training data, or null
     * @param evalOutputPath     path where to store the temporary evaluation data, or null
     * @param splitRatio         ratio to consider for separating training and evaluation data, e.g. 0.8 for 80%
     * @return number of examples
     * @throws GrobidException wrapping any parsing or I/O failure
     */
    public int addFeaturesTable(String sourceTEIPathLabel,
                                String sourceRawPathLabel,
                                final File trainingOutputPath,
                                final File evalOutputPath,
                                double splitRatio) {
        // NOTE(review): nothing ever increments this counter (the original
        // increment is commented out in history), so the method returns 0 —
        // kept as-is to preserve the existing contract.
        int totalExamples = 0;
        Writer writer2 = null; // training data output
        Writer writer3 = null; // evaluation data output
        try {
            System.out.println("sourceTEIPathLabel: " + sourceTEIPathLabel);
            System.out.println("sourceRawPathLabel: " + sourceRawPathLabel);
            System.out.println("trainingOutputPath: " + trainingOutputPath);
            System.out.println("evalOutputPath: " + evalOutputPath);

            // we need first to generate the labeled files from the TEI annotated files
            File input = new File(sourceTEIPathLabel);
            // we process all tei files in the output directory
            File[] refFiles = input.listFiles(new FilenameFilter() {
                public boolean accept(File dir, String name) {
                    return name.endsWith(".tei.xml") || name.endsWith(".tei");
                }
            });

            if (refFiles == null) {
                return 0;
            }
            System.out.println(refFiles.length + " tei files");

            // the file for writing the training data
            if (trainingOutputPath != null) {
                writer2 = new OutputStreamWriter(
                        new FileOutputStream(trainingOutputPath), StandardCharsets.UTF_8);
            }

            // the file for writing the evaluation data
            if (evalOutputPath != null) {
                writer3 = new OutputStreamWriter(
                        new FileOutputStream(evalOutputPath), StandardCharsets.UTF_8);
            }

            // get a factory for SAX parser
            SAXParserFactory spf = SAXParserFactory.newInstance();

            for (File tf : refFiles) {
                String name = tf.getName();
                System.out.println(name);

                // the full text SAX parser can be reused for the tables
                TEIFigureSaxParser parser2 = new TEIFigureSaxParser();
                //get a new instance of parser
                SAXParser p = spf.newSAXParser();
                p.parse(tf, parser2);

                List<String> labeled = parser2.getLabeledResult();

                // the featured (raw) file matching this TEI file
                // NOTE(review): files accepted with a bare ".tei" extension keep
                // their full name here, since only ".tei.xml" is stripped —
                // confirm this matches the raw-file naming convention.
                File theRawFile = new File(sourceRawPathLabel + File.separator
                        + name.replace(".tei.xml", ""));
                if (!theRawFile.exists()) {
                    System.out.println("Raw file " + theRawFile +
                            " does not exist. Please have a look!");
                    continue;
                }

                String table = alignRawWithLabels(theRawFile, labeled);

                // dispatch the aligned example to training and/or evaluation output.
                // Bug fix: the original used two independent `if`s followed by an
                // `else`, so when only the evaluation writer existed the example
                // was written to it AND the random-split branch then dereferenced
                // the null training writer (NPE). An else-if chain is correct.
                if (writer2 != null && writer3 != null) {
                    if (Math.random() <= splitRatio)
                        writer2.write(table + "\n");
                    else
                        writer3.write(table + "\n");
                } else if (writer2 != null) {
                    writer2.write(table + "\n");
                } else if (writer3 != null) {
                    writer3.write(table + "\n");
                }
            }
        } catch (Exception e) {
            throw new GrobidException("An exception occured while running training for the table model.", e);
        } finally {
            // close on all paths, not only on success (the original leaked the
            // streams whenever an exception interrupted the loop)
            closeQuietly(writer2);
            closeQuietly(writer3);
        }
        return totalExamples;
    }

    /**
     * Align each token line of a raw feature file with its label from the TEI
     * parse, producing the final training representation for one corpus item.
     *
     * @param rawFile the raw feature file (one token plus features per line)
     * @param labeled the labeled tokens produced by the TEI SAX parser
     * @return the feature lines with the matched label appended to each
     * @throws IOException if the raw file cannot be read
     */
    private String alignRawWithLabels(File rawFile, List<String> labeled) throws IOException {
        StringBuilder table = new StringBuilder();
        int q = 0; // current position in the labeled token list
        try (BufferedReader bis = new BufferedReader(
                new InputStreamReader(new FileInputStream(rawFile), StandardCharsets.UTF_8))) {
            String line;
            while ((line = bis.readLine()) != null) {
                if (line.trim().length() < 2) {
                    // (near-)empty raw line marks a block boundary
                    table.append("\n");
                }

                // the token is the first field, tab- or space-delimited
                int ii = line.indexOf('\t');
                if (ii == -1) {
                    ii = line.indexOf(' ');
                }
                String token = null;
                if (ii != -1) {
                    token = line.substring(0, ii);
                    // unicode normalisation of the token - it should not be necessary if the training data
                    // has been generated by a recent version of grobid
                    token = UnicodeUtil.normaliseTextAndRemoveSpaces(token);
                }

                // we get the label in the labelled data file for the same token,
                // scanning only a small window ahead of the last match
                for (int pp = q; pp < labeled.size(); pp++) {
                    String localLine = labeled.get(pp);
                    if (localLine.length() == 0) {
                        q = pp + 1;
                        continue;
                    }
                    StringTokenizer st = new StringTokenizer(localLine, " \t");
                    if (st.hasMoreTokens()) {
                        String localToken = st.nextToken();
                        // unicode normalisation of the token - it should not be necessary if the training data
                        // has been generated by a recent version of grobid
                        localToken = UnicodeUtil.normaliseTextAndRemoveSpaces(localToken);

                        if (localToken.equals(token)) {
                            String tag = st.nextToken();
                            line = line.replace("\t", " ").replace("  ", " ");
                            table.append(line).append(" ").append(tag);
                            q = pp + 1;
                            pp = q + 10; // force loop exit after a match
                        }
                    }
                    if (pp - q > 5) {
                        // give up after a short look-ahead window without a match
                        break;
                    }
                }
            }
        }
        return table.toString();
    }

    /** Close a writer, swallowing close-time errors so they cannot mask a primary exception. */
    private static void closeQuietly(Writer w) {
        if (w != null) {
            try {
                w.close();
            } catch (IOException ignored) {
                // best effort: the write itself already succeeded or already threw
            }
        }
    }

    /**
     * Command line execution.
     *
     * @param args Command line arguments.
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        GrobidProperties.getInstance();
        System.out.println(AbstractTrainer.runNFoldEvaluation(new TableTrainer(), 2));
        // System.out.println(AbstractTrainer.runEvaluation(new TableTrainer()));
        System.exit(0);
    }
}