Merge branch 'master' of https://github.com/bst-mug/n2c2
michelole committed Apr 3, 2018
2 parents 836701a + 121b18a commit ef882c7
Showing 3 changed files with 432 additions and 1 deletion.
@@ -71,6 +71,7 @@ public class BILSTMClassifier implements Classifier {
public BILSTMClassifier(List<Patient> examples) {

this.patientExamples = examples;
this.wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));

initializeTokenizer();
initializeTruncateLength();
@@ -82,6 +83,9 @@ private void initializeTokenizer() {
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
}

/**
* SOFTMAX activation and MCXENT loss function for binary classification.
*/
private void initializeNetwork() {

// initialize network
@@ -102,6 +106,30 @@ private void initializeNetwork() {
this.net.setListeners(new ScoreIterationListener(1));
}

/**
* SIGMOID activation and XENT loss function for binary multi-label
* classification.
*/
private void initializeNetworkBinaryMultiLabel() {

// initialize network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.ADAM).adamMeanDecay(0.9)
.adamVarDecay(0.999).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0).learningRate(2e-2).list()
.layer(0,
new GravesBidirectionalLSTM.Builder().nIn(vectorSize).nOut(truncateLength)
.activation(Activation.TANH).build())
.layer(1,
new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.XENT).nIn(truncateLength).nOut(13).build())
.pretrain(false).backprop(true).build();

this.net = new MultiLayerNetwork(conf);
this.net.init();
this.net.setListeners(new ScoreIterationListener(1));
}
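
The body of initializeNetwork() is collapsed in this view; per its Javadoc it uses SOFTMAX with MCXENT for binary classification. A minimal sketch of what that configuration could look like under the same builder pattern is shown below (the output size of 2 is an assumption, not taken from the commit):

// Sketch only, assuming a two-class softmax output; not part of this commit.
MultiLayerConfiguration binaryConf = new NeuralNetConfiguration.Builder().updater(Updater.ADAM).adamMeanDecay(0.9)
        .adamVarDecay(0.999).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
        .gradientNormalizationThreshold(1.0).learningRate(2e-2).list()
        .layer(0,
                new GravesBidirectionalLSTM.Builder().nIn(vectorSize).nOut(truncateLength)
                        .activation(Activation.TANH).build())
        .layer(1,
                new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(truncateLength).nOut(2).build())
        .pretrain(false).backprop(true).build();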

/**
* Get the longest token sequence across all patients, with respect to the word
* vectors existing in the Google corpus.
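
The body of initializeTruncateLength() is also collapsed in this diff; a rough sketch of the computation the Javadoc describes follows (Patient#getText() and the exact field names are assumptions, not taken from the commit):

// Sketch only, assuming Patient exposes its note text via getText(); not part of this commit.
private void initializeTruncateLength() {
    int maxLength = 0;
    for (Patient patient : patientExamples) {
        List<String> tokens = tokenizerFactory.create(patient.getText()).getTokens();
        int covered = 0;
        for (String token : tokens) {
            // count only tokens that have a vector in the loaded word-vector model
            if (wordVectors.hasWord(token)) {
                covered++;
            }
        }
        maxLength = Math.max(maxLength, covered);
    }
    this.truncateLength = maxLength;
}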
@@ -130,7 +158,6 @@ public void train(List<Patient> examples) {

// start training
try {
WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
N2c2PatientIterator train = new N2c2PatientIterator(examples, wordVectors, miniBatchSize, truncateLength);

System.out.println("Starting training");
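The remainder of train() is collapsed in this view. Assuming N2c2PatientIterator implements DataSetIterator, a typical epoch loop over it would look roughly like the following sketch (the epoch count is an assumption):

// Sketch only: assumed epoch loop; not taken from the commit.
int nEpochs = 10;
for (int epoch = 0; epoch < nEpochs; epoch++) {
    net.fit(train);
    train.reset();
    System.out.println("Epoch " + epoch + " complete");
}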
