Skip to content

Commit

Permalink
Merge pull request #658 from kermitt2/architecture-parameter
Browse files Browse the repository at this point in the history
Architecture parameter
  • Loading branch information
kermitt2 committed Oct 21, 2020
2 parents c3d3a73 + 9d257d5 commit a3ffce8
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 28 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ project("grobid-core") {

implementation "joda-time:joda-time:2.9.9"
implementation "org.apache.lucene:lucene-analyzers-common:4.5.1"
implementation 'black.ninia:jep:3.8.2'
implementation 'black.ninia:jep:3.9.1'
implementation 'org.apache.opennlp:opennlp-tools:1.9.1'
implementation group: 'org.jruby', name: 'jruby-complete', version: '9.2.13.0'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ protected AbstractParser(GrobidModel model, CntManager cntManager, GrobidCRFEngi
genericTagger = TaggerFactory.getTagger(model, engine);
}

/**
 * Creates a parser whose tagger is obtained from {@link TaggerFactory} for the given
 * model, engine and architecture.
 *
 * @param model        the model passed to the tagger factory
 * @param cntManager   the counter manager kept by this parser
 * @param engine       the sequence labelling engine passed to the tagger factory
 * @param architecture the architecture name passed to the tagger factory
 */
protected AbstractParser(GrobidModel model, CntManager cntManager, GrobidCRFEngine engine, String architecture) {
    this.cntManager = cntManager;
    this.genericTagger = TaggerFactory.getTagger(model, engine, architecture);
}

@Override
public String label(Iterable<String> data) {
return genericTagger.label(data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ public class DeLFTTagger implements GenericTagger {
private final DeLFTModel delftModel;

public DeLFTTagger(GrobidModel model) {
delftModel = new DeLFTModel(model);
delftModel = new DeLFTModel(model, null);
}

/**
 * Creates a tagger backed by a {@link DeLFTModel} built for the given model and architecture.
 *
 * @param model        the model passed to the {@link DeLFTModel} constructor
 * @param architecture the architecture name passed to the {@link DeLFTModel} constructor
 */
public DeLFTTagger(GrobidModel model, String architecture) {
    this.delftModel = new DeLFTModel(model, architecture);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ public class TaggerFactory {
private TaggerFactory() {}

public static synchronized GenericTagger getTagger(GrobidModel model) {
return getTagger(model, GrobidProperties.getGrobidCRFEngine(model));
return getTagger(model, GrobidProperties.getGrobidCRFEngine(model), GrobidProperties.getDelftArchitecture());
}

/**
 * Returns the tagger for the given model and engine, using the architecture returned by
 * {@link GrobidProperties#getDelftArchitecture()}.
 *
 * @param model  the model to obtain a tagger for
 * @param engine the sequence labelling engine to use
 * @return the tagger returned by the three-argument overload
 */
public static synchronized GenericTagger getTagger(GrobidModel model, GrobidCRFEngine engine) {
    final String configuredArchitecture = GrobidProperties.getDelftArchitecture();
    return getTagger(model, engine, configuredArchitecture);
}

public static synchronized GenericTagger getTagger(GrobidModel model, GrobidCRFEngine engine, String architecture) {
GenericTagger t = cache.get(model);
if (t == null) {
if(model.equals(GrobidModels.DUMMY)) {
Expand All @@ -44,16 +48,7 @@ public static synchronized GenericTagger getTagger(GrobidModel model, GrobidCRFE
t = new WapitiTagger(model);
break;
case DELFT:
// be sure the native JEP lib can be loaded
// try {
// String libraryFolder = LibraryLoader.getLibraryFolder();
// System.out.println(libraryFolder);
// LibraryLoader.addLibraryPath(libraryFolder);
// } catch (Exception e) {
// LOGGER.info("Loading JEP native library for DeLFT failed", e);
// }

t = new DeLFTTagger(model);
t = new DeLFTTagger(model, architecture);
break;
default:
throw new IllegalStateException("Unsupported Grobid sequence labelling engine: " + GrobidProperties.getGrobidCRFEngine());
Expand Down
48 changes: 35 additions & 13 deletions grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ public class DeLFTModel {
// Exploit the JNI CPython interpreter to load and execute a DeLFT deep learning model
private String modelName;

public DeLFTModel(GrobidModel model) {
public DeLFTModel(GrobidModel model, String architecture) {
this.modelName = model.getModelName().replace("-", "_");
try {
LOGGER.info("Loading DeLFT model for " + model.getModelName() + "...");
JEPThreadPool.getInstance().run(new InitModel(this.modelName, GrobidProperties.getInstance().getModelPath()));
LOGGER.info("Loading DeLFT model for " + model.getModelName() + " with architecture " + architecture + "...");
JEPThreadPool.getInstance().run(new InitModel(this.modelName, GrobidProperties.getInstance().getModelPath(), architecture));
} catch(InterruptedException e) {
LOGGER.error("DeLFT model " + this.modelName + " initialization failed", e);
}
Expand All @@ -42,17 +42,27 @@ public DeLFTModel(GrobidModel model) {
class InitModel implements Runnable {
private String modelName;
private File modelPath;
private String architecture;

public InitModel(String modelName, File modelPath) {
public InitModel(String modelName, File modelPath, String architecture) {
this.modelName = modelName;
this.modelPath = modelPath;
this.architecture = architecture;
}

@Override
public void run() {
Jep jep = JEPThreadPool.getInstance().getJEPInstance();
try {
jep.eval(this.modelName+" = Sequence('" + this.modelName.replace("_", "-") + "')");
String fullModelName = this.modelName.replace("_", "-");

if (architecture != null && !architecture.equals("BidLSTM_CRF"))
fullModelName += "-" + this.architecture;

if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1)
fullModelName += "-with_ELMo";

jep.eval(this.modelName+" = Sequence('" + fullModelName + "')");
jep.eval(this.modelName+".load(dir_path='"+modelPath.getAbsolutePath()+"')");
} catch(JepException e) {
throw new GrobidException("DeLFT model initialization failed. ", e);
Expand Down Expand Up @@ -171,10 +181,11 @@ public String label(String data) {
* usually hangs... Possibly issues with IO threads at the level of JEP (output not consumed because
* of \r and no end of line?).
*/
public static void trainJNI(String modelName, File trainingData, File outputModel) {
public static void trainJNI(String modelName, File trainingData, File outputModel, String architecture) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
JEPThreadPool.getInstance().run(new TrainTask(modelName, trainingData, GrobidProperties.getInstance().getModelPath()));
JEPThreadPool.getInstance().run(
new TrainTask(modelName, trainingData, GrobidProperties.getInstance().getModelPath(), architecture));
} catch(InterruptedException e) {
LOGGER.error("Train DeLFT model " + modelName + " task failed", e);
}
Expand All @@ -184,12 +195,14 @@ private static class TrainTask implements Runnable {
private String modelName;
private File trainPath;
private File modelPath;
private String architecture;

public TrainTask(String modelName, File trainPath, File modelPath) {
public TrainTask(String modelName, File trainPath, File modelPath, String architecture) {
//System.out.println("train thread: " + Thread.currentThread().getId());
this.modelName = modelName;
this.trainPath = trainPath;
this.modelPath = modelPath;
this.architecture = architecture;
}

@Override
Expand All @@ -203,13 +216,18 @@ public void run() {
jep.eval("print(len(x_valid), 'validation sequences')");

String useELMo = "False";
if (GrobidProperties.getInstance().useELMo()) {
if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
useELMo = "True";
}

// init model to be trained
jep.eval("model = Sequence('"+this.modelName+
"', max_epoch=100, recurrent_dropout=0.50, embeddings_name='glove-840B', use_ELMo="+useELMo+")");
if (architecture == null)
jep.eval("model = Sequence('"+this.modelName+
"', max_epoch=100, recurrent_dropout=0.50, embeddings_name='glove-840B', use_ELMo="+useELMo+")");
else
jep.eval("model = Sequence('"+this.modelName+
"', max_epoch=100, recurrent_dropout=0.50, embeddings_name='glove-840B', use_ELMo="+useELMo+
", model_type='"+architecture+"')");

// actual training
//start_time = time.time()
Expand Down Expand Up @@ -240,7 +258,7 @@ public void run() {
* Train with an external process rather than with JNI. This approach appears to be more stable for the
* training process (the JNI approach hangs after a while) and does not raise any runtime/integration issues.
*/
public static void train(String modelName, File trainingData, File outputModel) {
public static void train(String modelName, File trainingData, File outputModel, String architecture) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
List<String> command = Arrays.asList("python3",
Expand All @@ -249,7 +267,11 @@ public static void train(String modelName, File trainingData, File outputModel)
"train",
"--input", trainingData.getAbsolutePath(),
"--output", GrobidProperties.getInstance().getModelPath().getAbsolutePath());
if (GrobidProperties.getInstance().useELMo()) {
if (architecture != null) {
command.add("--architecture");
command.add(architecture);
}
if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
command.add("--use-ELMo");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,14 @@ else if (rawValue.equals("false"))
return false;
}

/**
 * Returns the DeLFT architecture property.
 *
 * @return the value looked up under {@link GrobidPropertyKeys#PROP_DELFT_ARCHITECTURE}
 */
public static String getDelftArchitecture() {
    final String architecture = getPropertyValue(GrobidPropertyKeys.PROP_DELFT_ARCHITECTURE);
    return architecture;
}

/**
 * Sets the DeLFT architecture property.
 *
 * @param theArchitecture the value stored under {@link GrobidPropertyKeys#PROP_DELFT_ARCHITECTURE}
 */
public static void setDelftArchitecture(final String theArchitecture) {
    final String key = GrobidPropertyKeys.PROP_DELFT_ARCHITECTURE;
    setPropertyValue(key, theArchitecture);
}

/**
* Returns the host for a proxy connection, given in the grobid-property
* file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public interface GrobidPropertyKeys {
String PROP_GROBID_DELFT_PATH = "grobid.delft.install";
String PROP_GROBID_DELFT_REDIRECT_OUTPUT = "grobid.delft.redirect_output";
String PROP_GROBID_DELFT_ELMO = "grobid.delft.useELMo";
String PROP_DELFT_ARCHITECTURE = "grobid.delft.architecture";
String PROP_USE_LANG_ID = "grobid.use_language_id";
String PROP_LANG_DETECTOR_FACTORY = "grobid.language_detector_factory";
String PROP_SENTENCE_DETECTOR_FACTORY = "grobid.sentence_detector_factory";
Expand Down
5 changes: 4 additions & 1 deletion grobid-home/config/grobid.properties
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,16 @@ grobid.crf.engine.fulltext=wapiti
#grobid.crf.engine.figure=wapiti
#grobid.crf.engine.table=wapiti
#grobid.crf.engine.name_citation=wapiti
#grobid.crf.engine.affiliation_address=wapiti
#grobid.crf.engine.affiliation_address=delft
#grobid.crf.engine.citation=delft

grobid.delft.install=../delft
grobid.delft.useELMo=false
grobid.delft.python.virtualEnv=
grobid.delft.redirect.output=true
grobid.delft.architecture=BidLSTM_CRF
#grobid.delft.architecture=scibert

grobid.pdf.blocks.max=100000
grobid.pdf.tokens.max=1000000

Expand Down
Binary file modified grobid-home/lib/lin-64/libjep.so
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.grobid.core.jni.DeLFTModel;
import org.grobid.core.GrobidModels;
import org.grobid.trainer.SegmentationTrainer;
import org.grobid.core.utilities.GrobidProperties;
import java.math.BigDecimal;

import java.io.File;
Expand All @@ -17,7 +18,7 @@ public class DeLFTTrainer implements GenericTrainer {

@Override
public void train(File template, File trainingData, File outputModel, int numThreads, GrobidModel model) {
DeLFTModel.train(model.getModelName(), trainingData, outputModel);
DeLFTModel.train(model.getModelName(), trainingData, outputModel, GrobidProperties.getDelftArchitecture());
}

@Override
Expand Down

0 comments on commit a3ffce8

Please sign in to comment.