diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifier.java index 10ff343..2953387 100644 --- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifier.java +++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifier.java @@ -71,6 +71,7 @@ public class BILSTMClassifier implements Classifier { public BILSTMClassifier(List examples) { this.patientExamples = examples; + this.wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH)); initializeTokenizer(); initializeTruncateLength(); @@ -82,6 +83,9 @@ private void initializeTokenizer() { tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); } + /** + * SOFTMAX activation and MCXENT loss function for binary classification. + */ private void initializeNetwork() { // initialize network @@ -102,6 +106,30 @@ private void initializeNetwork() { this.net.setListeners(new ScoreIterationListener(1)); } + /** + * SIGMOID activation and XENT loss function for binary multi-label + * classification. + */ + private void initializeNetworkBinaryMultiLabel() { + + // initialize network + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.ADAM).adamMeanDecay(0.9) + .adamVarDecay(0.999).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER) + .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) + .gradientNormalizationThreshold(1.0).learningRate(2e-2).list() + .layer(0, + new GravesBidirectionalLSTM.Builder().nIn(vectorSize).nOut(truncateLength) + .activation(Activation.TANH).build()) + .layer(1, + new RnnOutputLayer.Builder().activation(Activation.SIGMOID) + .lossFunction(LossFunctions.LossFunction.XENT).nIn(truncateLength).nOut(13).build()) + .pretrain(false).backprop(true).build(); + + this.net = new MultiLayerNetwork(conf); + this.net.init(); + this.net.setListeners(new ScoreIterationListener(1)); + } + /** * Get longest token sequence of all patients with respect to existing word * vector out of Google corpus. @@ -130,7 +158,6 @@ public void train(List examples) { // start training try { - WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH)); N2c2PatientIterator train = new N2c2PatientIterator(examples, wordVectors, miniBatchSize, truncateLength); System.out.println("Starting training"); diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/N2c2PatientIteratorBML.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/N2c2PatientIteratorBML.java new file mode 100644 index 0000000..42faf73 --- /dev/null +++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/N2c2PatientIteratorBML.java @@ -0,0 +1,366 @@ +package at.medunigraz.imi.bst.n2c2.nn; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.NoSuchElementException; + +import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; +import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; +import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; +import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import at.medunigraz.imi.bst.n2c2.model.Criterion; +import at.medunigraz.imi.bst.n2c2.model.Eligibility; +import at.medunigraz.imi.bst.n2c2.model.Patient; + +/** + * Date iterator refactored from dl4j examples. + * + * @author Markus + * + */ +public class N2c2PatientIteratorBML implements DataSetIterator { + + private static final long serialVersionUID = 1L; + + private final WordVectors wordVectors; + + private final int batchSize; + private final int vectorSize; + private final int truncateLength; + + private int cursor = 0; + private final TokenizerFactory tokenizerFactory; + + private List patients; + + /** + * Patient data iterator for the n2c2 task. + * + * @param patients + * Patient data. + * @param wordVectors + * Word vectors object. + * @param batchSize + * Mini batch size use for processing. + * @param truncateLength + * Maximum length of token sequence. + * @throws IOException + */ + public N2c2PatientIteratorBML(List patients, WordVectors wordVectors, int batchSize, int truncateLength) + throws IOException { + + this.patients = patients; + this.batchSize = batchSize; + this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length; + + this.wordVectors = wordVectors; + this.truncateLength = truncateLength; + + tokenizerFactory = new DefaultTokenizerFactory(); + tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#next(int) + */ + @Override + public DataSet next(int num) { + if (cursor >= patients.size()) + throw new NoSuchElementException(); + try { + return nextPatientsDataSet(num); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Next data set implementation. + * + * @param num + * Mini batch size. + * @return DataSet Patients data set. + * @throws IOException + */ + private DataSet nextPatientsDataSet(int num) throws IOException { + + HashMap> binaryMultiHotVectorMap = new HashMap>(); + + // load narrative from patient + List narratives = new ArrayList<>(num); + for (int i = 0; i < num && cursor < totalExamples(); i++) { + String narrative = patients.get(cursor).getText(); + narratives.add(narrative); + + ArrayList binaryMultiHotVector = new ArrayList(); + + // 0 + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.ABDOMINAL).equals(Eligibility.MET)); + + // + binaryMultiHotVector + .add(patients.get(cursor).getEligibility(Criterion.ADVANCED_CAD).equals(Eligibility.MET)); + + // + binaryMultiHotVector + .add(patients.get(cursor).getEligibility(Criterion.ALCOHOL_ABUSE).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.ASP_FOR_MI).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.CREATININE).equals(Eligibility.MET)); + + // + binaryMultiHotVector + .add(patients.get(cursor).getEligibility(Criterion.DIETSUPP_2MOS).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.DRUG_ABUSE).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.ENGLISH).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.HBA1C).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.KETO_1YR).equals(Eligibility.MET)); + + // + binaryMultiHotVector + .add(patients.get(cursor).getEligibility(Criterion.MAJOR_DIABETES).equals(Eligibility.MET)); + + // + binaryMultiHotVector + .add(patients.get(cursor).getEligibility(Criterion.MAKES_DECISIONS).equals(Eligibility.MET)); + + // + binaryMultiHotVector.add(patients.get(cursor).getEligibility(Criterion.MI_6MOS).equals(Eligibility.MET)); + + binaryMultiHotVectorMap.put(i, binaryMultiHotVector); + cursor++; + } + + // filter unknown words and tokenize + List> allTokens = new ArrayList<>(narratives.size()); + int maxLength = 0; + for (String narrative : narratives) { + List tokens = tokenizerFactory.create(narrative).getTokens(); + List tokensFiltered = new ArrayList<>(); + for (String token : tokens) { + if (wordVectors.hasWord(token)) + tokensFiltered.add(token); + } + allTokens.add(tokensFiltered); + maxLength = Math.max(maxLength, tokensFiltered.size()); + } + + // truncate if sequence is longer than truncateLength + if (maxLength > truncateLength) + maxLength = truncateLength; + + INDArray features = Nd4j.create(narratives.size(), vectorSize, maxLength); + INDArray labels = Nd4j.create(narratives.size(), 13, maxLength); + + INDArray featuresMask = Nd4j.zeros(narratives.size(), maxLength); + INDArray labelsMask = Nd4j.zeros(narratives.size(), maxLength); + + int[] temp = new int[2]; + for (int i = 0; i < narratives.size(); i++) { + List tokens = allTokens.get(i); + temp[0] = i; + + // get word vectors for each token in narrative + for (int j = 0; j < tokens.size() && j < maxLength; j++) { + String token = tokens.get(j); + INDArray vector = wordVectors.getWordVectorMatrix(token); + features.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j) }, + vector); + + temp[1] = j; + featuresMask.putScalar(temp, 1.0); + } + + int lastIdx = Math.min(tokens.size(), maxLength); + + // set binary multi-labels + ArrayList binaryMultiHotVector = binaryMultiHotVectorMap.get(i); + int labelIndex = 0; + for (Boolean label : binaryMultiHotVector) { + labels.putScalar(new int[] { i, labelIndex, lastIdx - 1 }, label == true ? 1.0 : 0.0); + labelIndex++; + } + // out exists at the final step of the sequence + labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0); + } + return new DataSet(features, labels, featuresMask, labelsMask); + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#totalExamples() + */ + @Override + public int totalExamples() { + return this.patients.size(); + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#inputColumns() + */ + @Override + public int inputColumns() { + return vectorSize; + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#totalOutcomes() + */ + @Override + public int totalOutcomes() { + return 2; + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#reset() + */ + @Override + public void reset() { + cursor = 0; + } + + /* + * (non-Javadoc) + * + * @see + * org.nd4j.linalg.dataset.api.iterator.DataSetIterator#resetSupported() + */ + public boolean resetSupported() { + return true; + } + + /* + * (non-Javadoc) + * + * @see + * org.nd4j.linalg.dataset.api.iterator.DataSetIterator#asyncSupported() + */ + @Override + public boolean asyncSupported() { + return true; + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#batch() + */ + @Override + public int batch() { + return batchSize; + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#cursor() + */ + @Override + public int cursor() { + return cursor; + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#numExamples() + */ + @Override + public int numExamples() { + return totalExamples(); + } + + /* + * (non-Javadoc) + * + * @see + * org.nd4j.linalg.dataset.api.iterator.DataSetIterator#setPreProcessor(org. + * nd4j.linalg.dataset.api.DataSetPreProcessor) + */ + @Override + public void setPreProcessor(DataSetPreProcessor preProcessor) { + throw new UnsupportedOperationException(); + } + + /* + * (non-Javadoc) + * + * @see org.nd4j.linalg.dataset.api.iterator.DataSetIterator#getLabels() + */ + @Override + public List getLabels() { + return Arrays.asList("positive", "negative"); + } + + /* + * (non-Javadoc) + * + * @see java.util.Iterator#hasNext() + */ + @Override + public boolean hasNext() { + return cursor < numExamples(); + } + + /* + * (non-Javadoc) + * + * @see java.util.Iterator#next() + */ + @Override + public DataSet next() { + return next(batchSize); + } + + /* + * (non-Javadoc) + * + * @see java.util.Iterator#remove() + */ + @Override + public void remove() { + + } + + /* + * (non-Javadoc) + * + * @see + * org.nd4j.linalg.dataset.api.iterator.DataSetIterator#getPreProcessor() + */ + @Override + public DataSetPreProcessor getPreProcessor() { + throw new UnsupportedOperationException("Not implemented"); + } +} diff --git a/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifierTest.java b/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifierTest.java new file mode 100644 index 0000000..fc1a422 --- /dev/null +++ b/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMClassifierTest.java @@ -0,0 +1,38 @@ +package at.medunigraz.imi.bst.n2c2.nn; + +import static org.junit.Assert.assertEquals; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.filefilter.TrueFileFilter; +import org.junit.Ignore; +import org.xml.sax.SAXException; + +import at.medunigraz.imi.bst.n2c2.dao.PatientDAO; +import at.medunigraz.imi.bst.n2c2.model.Patient; + +public class BILSTMClassifierTest { + + @Ignore + public void train() throws IOException, SAXException { + + // read in patients + File sampleDirectory = new File("Z:/n2c2/data/samples/"); + List sampleFiles = (List) FileUtils.listFiles(sampleDirectory, TrueFileFilter.INSTANCE, + TrueFileFilter.INSTANCE); + + List patients = new ArrayList(); + for (File patientSample : sampleFiles) { + patients.add(new PatientDAO().fromXML(patientSample)); + } + + BILSTMClassifier classifier = new BILSTMClassifier(patients); + classifier.train(patients); + + assertEquals(true, true); + } +}