Skip to content

Commit

Permalink
Merge ab1441d into 7e48155
Browse files Browse the repository at this point in the history
  • Loading branch information
michelole committed Jun 14, 2019
2 parents 7e48155 + ab1441d commit db98b69
Show file tree
Hide file tree
Showing 23 changed files with 184 additions and 186 deletions.
13 changes: 9 additions & 4 deletions src/main/java/at/medunigraz/imi/bst/n2c2/ClassifierRunner.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package at.medunigraz.imi.bst.n2c2;

import at.medunigraz.imi.bst.n2c2.classifier.factory.*;
import at.medunigraz.imi.bst.n2c2.classifier.factory.ClassifierFactory;
import at.medunigraz.imi.bst.n2c2.classifier.factory.FactoryProvider;
import at.medunigraz.imi.bst.n2c2.evaluator.BasicEvaluator;
import at.medunigraz.imi.bst.n2c2.evaluator.Evaluator;
import at.medunigraz.imi.bst.n2c2.evaluator.OfficialEvaluator;
Expand All @@ -23,9 +24,13 @@ public class ClassifierRunner {
private static final Logger LOG = LogManager.getLogger();

private static final ClassifierFactory[] CLASSIFIERS = new ClassifierFactory[] {
new MajorityClassifierFactory(), new RuleBasedClassifierFactory(), // lower and upper bound
new SVMClassifierFactory(), new PerceptronClassifierFactory(), // linear methods
new NNClassifierFactory() // non-linear methods
FactoryProvider.getMajorityFactory(),
FactoryProvider.getRBCFactory(),
FactoryProvider.getSVMFactory(),
FactoryProvider.getSelfTrainedPerceptronFactory(),
FactoryProvider.getPreTrainedPerceptronFactory(),
FactoryProvider.getLSTMSelfTrainedFactory(),
FactoryProvider.getLSTMPreTrainedFactory()
};

public static void main(String[] args) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ public class PerceptronClassifier extends CriterionBasedClassifier {
*/
private static final File PRETRAINED_VECTORS = new File(PerceptronClassifier.class.getClassLoader().getResource("BioWordVec-vectors.vec").getFile());

public PerceptronClassifier(Criterion c) {
private final boolean preTrained;

public PerceptronClassifier(Criterion c, boolean preTrained) {
super(c);
this.preTrained = preTrained;
}

private String preprocess(String text) {
Expand Down Expand Up @@ -51,7 +54,10 @@ public void train(List<Patient> examples) {
trainData.put(preprocess(p.getText()), p.getEligibility(criterion).name());
}

// FastTextFacade.train(trainData);
FastTextFacade.train(trainData, PRETRAINED_VECTORS);
if (preTrained) {
FastTextFacade.train(trainData, PRETRAINED_VECTORS);
} else {
FastTextFacade.train(trainData);
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

import java.util.List;

public interface ClassifierFactory {
public interface ClassifierFactory<C extends Classifier> {

Classifier getClassifier(Criterion criterion);
C getClassifier(Criterion criterion);

default List<Patient> trainAndPredict(List<Patient> train, List<Patient> toPredict) {
List<Patient> prediction = DatasetUtil.stripTags(toPredict);

for (Criterion criterion : Criterion.classifiableValues()) {
Classifier classifier = this.getClassifier(criterion);
C classifier = this.getClassifier(criterion);
classifier.train(train);
prediction = classifier.predict(prediction);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package at.medunigraz.imi.bst.n2c2.classifier.factory;

import at.medunigraz.imi.bst.n2c2.classifier.CriterionBasedClassifier;
import at.medunigraz.imi.bst.n2c2.model.Criterion;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class CriterionBasedClassifierFactory implements ClassifierFactory<CriterionBasedClassifier> {

protected final Map<Criterion, CriterionBasedClassifier> classifierByCriterion = new HashMap<>();

protected CriterionBasedClassifierFactory() {
// Only used by subclasses
}

public CriterionBasedClassifierFactory(Class cls) {
Arrays.stream(Criterion.classifiableValues()).forEach(c -> {
try {
classifierByCriterion.put(c, (CriterionBasedClassifier) cls.getDeclaredConstructor(Criterion.class).newInstance(c));
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
});
}

@Override
public CriterionBasedClassifier getClassifier(Criterion criterion) {
return classifierByCriterion.get(criterion);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package at.medunigraz.imi.bst.n2c2.classifier.factory;

import at.medunigraz.imi.bst.n2c2.classifier.FakeClassifier;
import at.medunigraz.imi.bst.n2c2.classifier.MajorityClassifier;
import at.medunigraz.imi.bst.n2c2.classifier.PatientBasedClassifier;
import at.medunigraz.imi.bst.n2c2.nn.BiLSTMCharacterTrigramClassifier;
import at.medunigraz.imi.bst.n2c2.nn.LSTMPreTrainedEmbeddingsClassifier;
import at.medunigraz.imi.bst.n2c2.nn.LSTMSelfTrainedEmbeddingsClassifier;
import at.medunigraz.imi.bst.n2c2.rules.RuleBasedClassifier;

public abstract class FactoryProvider {

public static CriterionBasedClassifierFactory getMajorityFactory() {
return new CriterionBasedClassifierFactory(MajorityClassifier.class);
}

public static CriterionBasedClassifierFactory getRBCFactory() {
return new CriterionBasedClassifierFactory(RuleBasedClassifier.class);
}

public static CriterionBasedClassifierFactory getSVMFactory() {
return new SVMClassifierFactory();
}

public static CriterionBasedClassifierFactory getSelfTrainedPerceptronFactory() {
return new PerceptronClassifierFactory();
}

public static CriterionBasedClassifierFactory getPreTrainedPerceptronFactory() {
return new PerceptronClassifierFactory(true);
}

public static PatientBasedClassifierFactory getBiLSTMCharacterTrigramFactory() {
return new PatientBasedClassifierFactory(BiLSTMCharacterTrigramClassifier.class);
}

public static ClassifierFactory<PatientBasedClassifier> getLSTMPreTrainedFactory() {
return new PatientBasedClassifierFactory(LSTMPreTrainedEmbeddingsClassifier.class);
}

public static PatientBasedClassifierFactory getLSTMSelfTrainedFactory() {
return new PatientBasedClassifierFactory(LSTMSelfTrainedEmbeddingsClassifier.class);
}

public static CriterionBasedClassifierFactory getFakeClassifierFactory() {
return new CriterionBasedClassifierFactory(FakeClassifier.class);
}
}

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package at.medunigraz.imi.bst.n2c2.classifier.factory;

import at.medunigraz.imi.bst.n2c2.classifier.PatientBasedClassifier;
import at.medunigraz.imi.bst.n2c2.model.Criterion;


public final class PatientBasedClassifierFactory implements ClassifierFactory<PatientBasedClassifier> {

private final PatientBasedClassifier classifier;

public PatientBasedClassifierFactory(Class cls) {
try {
classifier = (PatientBasedClassifier) cls.newInstance();
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
}

@Override
public PatientBasedClassifier getClassifier(Criterion criterion) {
return classifier;
}

}
Original file line number Diff line number Diff line change
@@ -1,34 +1,20 @@
package at.medunigraz.imi.bst.n2c2.classifier.factory;

import at.medunigraz.imi.bst.n2c2.classifier.Classifier;
import at.medunigraz.imi.bst.n2c2.classifier.PerceptronClassifier;
import at.medunigraz.imi.bst.n2c2.model.Criterion;

import java.util.HashMap;
import java.util.Map;
import java.util.Arrays;

public class PerceptronClassifierFactory implements ClassifierFactory {
public class PerceptronClassifierFactory extends CriterionBasedClassifierFactory {

private static final Map<Criterion, Classifier> classifierByCriterion = new HashMap<>();
static {
classifierByCriterion.put(Criterion.MAKES_DECISIONS, new PerceptronClassifier(Criterion.MAKES_DECISIONS));
classifierByCriterion.put(Criterion.HBA1C, new PerceptronClassifier(Criterion.HBA1C));
classifierByCriterion.put(Criterion.ASP_FOR_MI, new PerceptronClassifier(Criterion.ASP_FOR_MI));
classifierByCriterion.put(Criterion.ALCOHOL_ABUSE, new PerceptronClassifier(Criterion.ALCOHOL_ABUSE));
classifierByCriterion.put(Criterion.ADVANCED_CAD, new PerceptronClassifier(Criterion.ADVANCED_CAD));
classifierByCriterion.put(Criterion.CREATININE, new PerceptronClassifier(Criterion.CREATININE));
classifierByCriterion.put(Criterion.ENGLISH, new PerceptronClassifier(Criterion.ENGLISH));
classifierByCriterion.put(Criterion.MI_6MOS, new PerceptronClassifier(Criterion.MI_6MOS));
classifierByCriterion.put(Criterion.DRUG_ABUSE, new PerceptronClassifier(Criterion.DRUG_ABUSE));
classifierByCriterion.put(Criterion.MAJOR_DIABETES, new PerceptronClassifier(Criterion.MAJOR_DIABETES));
classifierByCriterion.put(Criterion.KETO_1YR, new PerceptronClassifier(Criterion.KETO_1YR));
classifierByCriterion.put(Criterion.ABDOMINAL, new PerceptronClassifier(Criterion.ABDOMINAL));
classifierByCriterion.put(Criterion.DIETSUPP_2MOS, new PerceptronClassifier(Criterion.DIETSUPP_2MOS));
public PerceptronClassifierFactory() {
Arrays.stream(Criterion.classifiableValues()).forEach(c ->
classifierByCriterion.put(c, new PerceptronClassifier(c, false)));
}

@Override
public Classifier getClassifier(Criterion criterion) {
return classifierByCriterion.get(criterion);
public PerceptronClassifierFactory(boolean preTrained) {
Arrays.stream(Criterion.classifiableValues()).forEach(c ->
classifierByCriterion.put(c, new PerceptronClassifier(c, preTrained)));
}

}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
package at.medunigraz.imi.bst.n2c2.classifier.factory;

import at.medunigraz.imi.bst.n2c2.classifier.Classifier;
import at.medunigraz.imi.bst.n2c2.classifier.svm.SVMClassifier;
import at.medunigraz.imi.bst.n2c2.config.Config;
import at.medunigraz.imi.bst.n2c2.model.Criterion;

import java.util.HashMap;
import java.util.Map;

public class SVMClassifierFactory implements ClassifierFactory {

private static final Map<Criterion, Classifier> classifierByCriterion = new HashMap<>();
public class SVMClassifierFactory extends CriterionBasedClassifierFactory {

public SVMClassifierFactory() {
classifierByCriterion.put(Criterion.MAKES_DECISIONS, new SVMClassifier(Criterion.MAKES_DECISIONS, Config.SVM_COST_MAKES_DECISIONS));
Expand Down Expand Up @@ -48,8 +42,4 @@ public SVMClassifierFactory(double cost) {
classifierByCriterion.put(Criterion.DIETSUPP_2MOS, new SVMClassifier(Criterion.DIETSUPP_2MOS, cost).withMonths(6));
}

@Override
public Classifier getClassifier(Criterion criterion) {
return classifierByCriterion.get(criterion);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ public static void main(String[] args) throws IOException {
List<Patient> testPatients = DatasetUtil.loadFromFolder(new File("data/test"));

Map<String, List<Patient>> data = new LinkedHashMap<>();
data.put("MAJ", new MajorityClassifierFactory().trainAndPredict(trainPatients, testPatients));
data.put("NN", new NNClassifierFactory().trainAndPredict(trainPatients, testPatients));
data.put("SVM", new SVMClassifierFactory().trainAndPredict(trainPatients, testPatients));
data.put("RBC", new RuleBasedClassifierFactory().trainAndPredict(trainPatients, testPatients));
data.put("GT", new FakeClassifierFactory().trainAndPredict(trainPatients, testPatients));
// TODO expand with Perceptron and self-trained LSTM
data.put("MAJ", FactoryProvider.getMajorityFactory().trainAndPredict(trainPatients, testPatients));
data.put("NN", FactoryProvider.getLSTMPreTrainedFactory().trainAndPredict(trainPatients, testPatients));
data.put("SVM", FactoryProvider.getSVMFactory().trainAndPredict(trainPatients, testPatients));
data.put("RBC", FactoryProvider.getRBCFactory().trainAndPredict(trainPatients, testPatients));
data.put("GT", FactoryProvider.getFakeClassifierFactory().trainAndPredict(trainPatients, testPatients));

dataToCsv(data, new File("false-analysis.csv"));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package at.medunigraz.imi.bst.n2c2.runner;

import at.medunigraz.imi.bst.n2c2.classifier.factory.ClassifierFactory;
import at.medunigraz.imi.bst.n2c2.classifier.factory.NNClassifierFactory;
import at.medunigraz.imi.bst.n2c2.classifier.factory.FactoryProvider;
import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.model.Eligibility;
import at.medunigraz.imi.bst.n2c2.model.Patient;
Expand Down Expand Up @@ -52,7 +52,7 @@ public static void main(String[] args) {

List<Patient> trainPatients = DatasetUtil.loadFromFolder(trainFolder);

ClassifierFactory factory = new NNClassifierFactory();
ClassifierFactory factory = FactoryProvider.getLSTMPreTrainedFactory();

// BaseNNClassifier extends PatientBasedClassifier, so it's a single classifier for any criterion.
BaseNNClassifier nnClassifier = (BaseNNClassifier) factory.getClassifier(Criterion.ABDOMINAL);
Expand Down
Loading

0 comments on commit db98b69

Please sign in to comment.