Skip to content

Commit

Permalink
output the model name for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Oct 12, 2020
1 parent b0560b5 commit 50b784b
Showing 1 changed file with 54 additions and 54 deletions.
108 changes: 54 additions & 54 deletions grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.*;
import java.util.concurrent.*;
import java.io.*;
import java.lang.StringBuilder;
import java.util.*;
Expand All @@ -20,7 +20,7 @@
import java.util.function.Consumer;

/**
*
*
* @author: Patrice
*/
public class DeLFTModel {
Expand All @@ -39,32 +39,32 @@ public DeLFTModel(GrobidModel model) {
}
}

/**
 * Runnable that creates a DeLFT Sequence object in a JEP (embedded Python) interpreter and
 * asks it to load the model from a directory on disk.
 *
 * NOTE(review): this span is taken from a rendered diff; several lines appear twice,
 * presumably the before/after sides of a whitespace-only change. Confirm against the real file.
 */
class InitModel implements Runnable {
class InitModel implements Runnable {
// Name of the model; also used as the name of the Python variable holding the model.
private String modelName;
// Directory passed to the Python-side load(dir_path=...) call.
private File modelPath;
public InitModel(String modelName, File modelPath) {

/**
 * @param modelName name of the model, reused as the Python variable name
 * @param modelPath path whose absolute form is given to load(dir_path=...)
 */
public InitModel(String modelName, File modelPath) {
this.modelName = modelName;
this.modelPath = modelPath;
}
}

/**
 * Evaluates two Python statements through JEP: the Sequence construction and the load call.
 *
 * @throws GrobidException wrapping the JepException if either evaluation fails
 */
@Override
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
// Python side: <modelName> = Sequence('<modelName with "_" replaced by "-">')
jep.eval(this.modelName+" = Sequence('" + this.modelName.replace("_", "-") + "')");
// Python side: <modelName>.load(dir_path='<absolute path of modelPath>')
jep.eval(this.modelName+".load(dir_path='"+modelPath.getAbsolutePath()+"')");
} catch(JepException e) {
// Unlike the other tasks shown in this file, initialization failure is rethrown, not just logged.
throw new GrobidException("DeLFT model initialization failed. ", e);
}
}
}
}
}

private class LabelTask implements Callable<String> {
private class LabelTask implements Callable<String> {
private String data;
private String modelName;

public LabelTask(String modelName, String data) {
public LabelTask(String modelName, String data) {
//System.out.println("label thread: " + Thread.currentThread().getId());
this.modelName = modelName;
this.data = data;
Expand Down Expand Up @@ -92,8 +92,8 @@ private void setJepStringValueWithFileFallback(
}

@Override
public String call() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
public String call() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
StringBuilder labelledData = new StringBuilder();
try {
//System.out.println(this.data);
Expand All @@ -102,7 +102,7 @@ public String call() {
this.setJepStringValueWithFileFallback(jep, "input", this.data);
jep.eval("x_all, f_all = load_data_crf_string(input)");
Object objectResults = jep.getValue(this.modelName+".tag(x_all, None)");

// inject back the labels
List<List<List<String>>> results = (List<List<List<String>>>) objectResults;
BufferedReader bufReader = new BufferedReader(new StringReader(data));
Expand Down Expand Up @@ -138,21 +138,21 @@ public String call() {
labelledData.append("\n");
j++;
}

// cleaning
jep.eval("del input");
jep.eval("del x_all");
jep.eval("del f_all");
//jep.eval("K.clear_session()");
} catch(JepException e) {
LOGGER.error("DeLFT model labelling via JEP failed", e);
LOGGER.error("DeLFT model " + this.modelName + " labelling via JEP failed", e);
} catch(IOException e) {
LOGGER.error("DeLFT model labelling failed", e);
LOGGER.error("DeLFT model " + this.modelName + " labelling failed", e);
}
//System.out.println(labelledData.toString());
return labelledData.toString();
}
}
}
}

public String label(String data) {
String result = null;
Expand All @@ -169,7 +169,7 @@ public String label(String data) {
/**
* Training via JNI CPython interpreter (JEP). It appears that after some epochs, the JEP thread
* usually hangs... Possibly issues with IO threads at the level of JEP (output not consumed because
* of \r and no end of line?).
* of \r and no end of line?).
*/
public static void trainJNI(String modelName, File trainingData, File outputModel) {
try {
Expand All @@ -180,21 +180,21 @@ public static void trainJNI(String modelName, File trainingData, File outputMode
}
}

private static class TrainTask implements Runnable {
private static class TrainTask implements Runnable {
private String modelName;
private File trainPath;
private File modelPath;

public TrainTask(String modelName, File trainPath, File modelPath) {
public TrainTask(String modelName, File trainPath, File modelPath) {
//System.out.println("train thread: " + Thread.currentThread().getId());
this.modelName = modelName;
this.trainPath = trainPath;
this.modelPath = modelPath;
}
}

@Override
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
// load data
jep.eval("x_all, y_all, f_all = load_data_and_labels_crf_file('" + this.trainPath.getAbsolutePath() + "')");
Expand All @@ -220,7 +220,7 @@ public void run() {
// saving the model
System.out.println(this.modelPath.getAbsolutePath());
jep.eval("model.save('"+this.modelPath.getAbsolutePath()+"')");

// cleaning
jep.eval("del x_all");
jep.eval("del y_all");
Expand All @@ -232,19 +232,19 @@ public void run() {
jep.eval("del model");
} catch(JepException e) {
LOGGER.error("DeLFT model training via JEP failed", e);
}
}
}
}
}
}

/**
* Train with an external process rather than with JNI; this approach appears to be more stable for the
* training process (JNI approach hangs after a while) and does not raise any runtime/integration issues.
* training process (JNI approach hangs after a while) and does not raise any runtime/integration issues.
*/
public static void train(String modelName, File trainingData, File outputModel) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
List<String> command = Arrays.asList("python3",
"grobidTagger.py",
List<String> command = Arrays.asList("python3",
"grobidTagger.py",
modelName,
"train",
"--input", trainingData.getAbsolutePath(),
Expand All @@ -256,9 +256,9 @@ public static void train(String modelName, File trainingData, File outputModel)
ProcessBuilder pb = new ProcessBuilder(command);
File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
pb.directory(delftPath);
Process process = pb.start();
Process process = pb.start();
//pb.inheritIO();
CustomStreamGobbler customStreamGobbler =
CustomStreamGobbler customStreamGobbler =
new CustomStreamGobbler(process.getInputStream(), System.out);
Executors.newSingleThreadExecutor().submit(customStreamGobbler);
SimpleStreamGobbler streamGobbler = new SimpleStreamGobbler(process.getErrorStream(), System.err::println);
Expand All @@ -281,22 +281,22 @@ public synchronized void close() {
}
}

/**
 * Runnable that deletes the Python variable named after the model from a JEP
 * (embedded Python) interpreter.
 *
 * NOTE(review): this span is taken from a rendered diff; several lines appear twice,
 * presumably the before/after sides of a whitespace-only change. Confirm against the real file.
 */
private class CloseModel implements Runnable {
private class CloseModel implements Runnable {
// Name of the Python variable to delete.
private String modelName;
public CloseModel(String modelName) {

/**
 * @param modelName name of the Python variable that will be deleted by run()
 */
public CloseModel(String modelName) {
this.modelName = modelName;
}
}

/**
 * Evaluates "del <modelName>" through JEP. A JepException is logged and not rethrown.
 */
@Override
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
// Python side: del <modelName>
jep.eval("del "+this.modelName);
} catch(JepException e) {
// Best-effort cleanup: the failure is only reported through the logger.
LOGGER.error("Closing DeLFT model failed", e);
}
}
}
}
}

private static String delft2grobidLabel(String label) {
Expand All @@ -306,19 +306,19 @@ private static String delft2grobidLabel(String label) {
label = label.replace(TaggingLabels.IOB_START_ENTITY_LABEL_PREFIX, TaggingLabels.GROBID_START_ENTITY_LABEL_PREFIX);
} else if (label.startsWith(TaggingLabels.IOB_INSIDE_LABEL_PREFIX)) {
label = label.replace(TaggingLabels.IOB_INSIDE_LABEL_PREFIX, TaggingLabels.GROBID_INSIDE_ENTITY_LABEL_PREFIX);
}
}
return label;
}

private static class SimpleStreamGobbler implements Runnable {
private InputStream inputStream;
private Consumer<String> consumer;

public SimpleStreamGobbler(InputStream inputStream, Consumer<String> consumer) {
this.inputStream = inputStream;
this.consumer = consumer;
}

@Override
public void run() {
new BufferedReader(new InputStreamReader(inputStream)).lines()
Expand All @@ -328,8 +328,8 @@ public void run() {

/**
* This is a custom gobbler that correctly reproduces the Keras training progress bar
* by injecting a \r for progress line updates.
*/
* by injecting a \r for progress line updates.
*/
private static class CustomStreamGobbler implements Runnable {
public static final Logger LOGGER = LoggerFactory.getLogger(CustomStreamGobbler.class);

Expand All @@ -341,7 +341,7 @@ public CustomStreamGobbler(InputStream is, PrintStream os) {
this.is = is;
this.os = os;
}

@Override
public void run() {
try {
Expand Down

0 comments on commit 50b784b

Please sign in to comment.