Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom delft train args #469

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
47 changes: 33 additions & 14 deletions grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.grobid.core.jni;

import org.apache.commons.lang3.StringUtils;

import org.grobid.core.GrobidModel;
import org.grobid.core.engines.label.TaggingLabels;
import org.grobid.core.exceptions.GrobidException;
Expand Down Expand Up @@ -252,7 +254,35 @@ public void run() {
LOGGER.error("DeLFT model training via JEP failed", e);
}
}
}
}

protected static List<String> getTrainCommand(String modelName, File trainingData, String architecture) {
String trainModule = GrobidProperties.getDeLFTTrainModule();
lfoppiano marked this conversation as resolved.
Show resolved Hide resolved
if (StringUtils.isEmpty(trainModule)) {
trainModule = "grobidTagger.py";
}
List<String> command = new ArrayList<>(Arrays.asList(
"python3",
trainModule,
modelName,
"train",
"--input", trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath()
));
if (architecture != null) {
command.add("--architecture");
command.add(architecture);
}
if (GrobidProperties.useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
command.add("--use-ELMo");
}
if (StringUtils.isNotEmpty(GrobidProperties.getDeLFTTrainArgs())) {
command.addAll(Arrays.asList(
GrobidProperties.getDeLFTTrainArgs().split(" ")
));
}
return command;
}

/**
* Train with an external process rather than with JNI, this approach appears to be more stable for the
Expand All @@ -261,19 +291,8 @@ public void run() {
public static void train(String modelName, File trainingData, File outputModel, String architecture) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
List<String> command = Arrays.asList("python3",
"grobidTagger.py",
modelName,
"train",
"--input", trainingData.getAbsolutePath(),
"--output", GrobidProperties.getInstance().getModelPath().getAbsolutePath());
if (architecture != null) {
command.add("--architecture");
command.add(architecture);
}
if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
command.add("--use-ELMo");
}
List<String> command = getTrainCommand(modelName, trainingData, architecture);
LOGGER.info("Running: {}", command);

ProcessBuilder pb = new ProcessBuilder(command);
File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ public static boolean isDeLFTRedirectOutput() {
);
}

public static String getDeLFTTrainModule() {
return getPropertyValue(
GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, ""
);
}

public static String getDeLFTTrainArgs() {
return getPropertyValue(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "");
}

public static String getGluttonHost() {
return getPropertyValue(GrobidPropertyKeys.PROP_GLUTTON_HOST);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public interface GrobidPropertyKeys {
String PROP_GROBID_DELFT_REDIRECT_OUTPUT = "grobid.delft.redirect_output";
String PROP_GROBID_DELFT_ELMO = "grobid.delft.useELMo";
String PROP_DELFT_ARCHITECTURE = "grobid.delft.architecture";
String PROP_GROBID_DELFT_TRAIN_MODULE = "grobid.delft.train.module";
String PROP_GROBID_DELFT_TRAIN_ARGS = "grobid.delft.train.args";
String PROP_USE_LANG_ID = "grobid.use_language_id";
String PROP_LANG_DETECTOR_FACTORY = "grobid.language_detector_factory";
String PROP_SENTENCE_DETECTOR_FACTORY = "grobid.sentence_detector_factory";
Expand Down
97 changes: 97 additions & 0 deletions grobid-core/src/test/java/org/grobid/core/jni/DeLFTModelTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.grobid.core.jni;

import java.io.File;

import org.junit.Before;
import org.junit.Test;

import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;

import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.GrobidPropertyKeys;


public class DeLFTModelTest {
private File trainingData = new File("test/train.data");

@Before
public void setUp() {
GrobidProperties.getInstance();
GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "false");
GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
}

@Test
public void testShouldBuildTrainCommand() {
assertThat(
DeLFTModel.getTrainCommand("model1", trainingData, null),
contains(
"python3", "grobidTagger.py", "model1", "train",
"--input", this.trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath()
)
);
}

@Test
public void testShouldAddUseELMO() {
GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "true");
assertThat(
DeLFTModel.getTrainCommand("model1", trainingData, null),
contains(
"python3", "grobidTagger.py", "model1", "train",
"--input", this.trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath(),
"--use-ELMo"
)
);
}

@Test
public void testShouldUseCustomTrainModule() {
GrobidProperties.getProps().put(
GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1.py"
);
assertThat(
DeLFTModel.getTrainCommand("model1", trainingData, null),
contains(
"python3", "module1.py", "model1", "train",
"--input", this.trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath()
)
);
}

@Test
public void testShouldAddSingleCustomTrainArg() {
GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1");
assertThat(
DeLFTModel.getTrainCommand("model1", trainingData, null),
contains(
"python3", "grobidTagger.py", "model1", "train",
"--input", this.trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath(),
"arg1"
)
);
}

@Test
public void testShouldAddMultipleCustomTrainArg() {
GrobidProperties.getProps().put(
GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1 arg2"
);
assertThat(
DeLFTModel.getTrainCommand("model1", trainingData, null),
contains(
"python3", "grobidTagger.py", "model1", "train",
"--input", this.trainingData.getAbsolutePath(),
"--output", GrobidProperties.getModelPath().getAbsolutePath(),
"arg1",
"arg2"
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,32 @@ public void testIsDeLFTRedirectOutputTrueIfSet() throws IOException {
assertTrue(GrobidProperties.isDeLFTRedirectOutput());
}

@Test
public void testShouldReturnEmptyTrainModuleByDefault() {
GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
assertEquals(GrobidProperties.getDeLFTTrainModule(), "");
}

@Test
public void testShouldReturnConfiguredModule() {
GrobidProperties.getProps().put(
GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1"
);
assertEquals(GrobidProperties.getDeLFTTrainModule(), "module1");
}

@Test
public void testShouldReturnEmptyTrainArgsByDefault() {
GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
assertEquals(GrobidProperties.getDeLFTTrainArgs(), "");
}

@Test
public void testShouldReturnConfiguredTrainArgs() {
GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "args");
assertEquals(GrobidProperties.getDeLFTTrainArgs(), "args");
}

/*@Test(expected = GrobidPropertyException.class)
public void testCheckPropertiesException_shouldThrowException() {
GrobidProperties.getProps().put(
Expand Down