Skip to content

Commit

Permalink
Submit process_mimic.py, which will process MIMIC-III dataset to buil…
Browse files Browse the repository at this point in the history
…d the training set for RETAIN.
  • Loading branch information
Edward Yoonjae Choi committed Feb 26, 2017
1 parent dd1f113 commit f99aa16
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 12 deletions.
40 changes: 34 additions & 6 deletions README.md
Expand Up @@ -18,7 +18,7 @@ RETAIN implements an algorithm introduced in the following [paper](http://papers

The RETAIN paper formulates the model as being able to make prediction at each timestep (e.g. try to predict what diagnoses the patient will receive at each visit), and treats sequence classification (e.g. Given a patient record, will he be diagnosed with heart failure in the future?) as a special case, since sequence classification makes the prediction at the last timestep only.

This code, however, is implemented to perform the sequence classification task. For example, you can use this code to predict whether the given patient is a heart failure patient or not. Or you can predict whether this patient will be readmitted in the future. The more general version of RETAIN will be released shortly in the future.
This code, however, is implemented to perform the sequence classification task. For example, you can use this code to predict whether the given patient is a heart failure patient or not. Or you can predict whether this patient will be readmitted in the future. The more general version of RETAIN will be released in the future.

####Running RETAIN

Expand All @@ -30,7 +30,31 @@ This code, however, is implemented to perform the sequence classification task.

3. Download/clone the RETAIN code

**STEP 2: Preparing training data**
**STEP 2: Fast way to test GRAM with MIMIC-III**
This step describes how to train RETAIN, with minimum number of steps using MIMIC-III, to predict patients' mortality using their visit records.

0. You will first need to request access for [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/), a publicly avaiable electronic health records collected from ICU patients over 11 years.

1. You can use "process_mimic.py" to process MIMIC-III dataset and generate a suitable training dataset for RETAIN.
Place the script to the same location where the MIMIC-III CSV files are located, and run the script.
The execution command is `python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file>`.
Instructions are described inside the script.

2. Run RETAIN using the ".seqs" and ".morts" file generated by process_mimic.py.
The ".seqs" file contains the sequence of visits for each patient. Each visit consists of multiple diagnosis codes.
However we recommend using ".3digitICD9.seqs" file instead, as the results will be much more interpretable.
(Or you could use [Single-level Clical Classification Software for ICD9](https://www.hcup-us.ahrq.gov/toolssoftware/ccs/ccs.jsp#examples) to decrease the number of codes to a couple of hundreds, which will even more improve the performance)
The ".morts" file contains the sequence of mortality labels for each patient.
The command is `python retain.py <3digitICD9.seqs file> 942 <morts file> <output path> --simple_load --n_epochs 100 --dropout_context 0.8 --dropout_emb 0.0`.
`942` is the number of the entire 3-digit ICD9 codes used in the dataset.

3. To test the model for interpretation, please refer to Step 6. I personally found that _perinatal jaundice (ICD9 774)_ has high correlation with mortality.

4. The model reaches AUC above 0.8 with the above command, but the interpretations are not super clear.
You could tune the hyper-parameters, but I doubt things will dramatically imporve.
After all, MIMIC-III has only 7500 patients who made more than two hospital visits, and the most visit sequences are very short (average 2 visits per patient).

**STEP 3: How to prepare your own dataset**

1. RETAIN's training dataset needs to be a Python cPickled list of list of list. The outermost list corresponds to patients, the intermediate to the visit sequence each patient made, and the innermost to the medical codes (e.g. diagnosis codes, medication codes, procedure codes, etc.) that occurred within each visit.
First, medical codes need to be converted to an integer. Then a single visit can be seen as a list of integers. Then a patient can be seen as a list of visits.
Expand Down Expand Up @@ -64,7 +88,7 @@ Specify the path to your code representation file using `--embed_file <path to e
Additionally, even if you use your own medical code representations, you can re-train (a.k.a fine-tune) them as you train RETAIN.
Use `--embed_finetune` option to do this. If you are not providing your own medical code representations, RETAIN will use randomly initialized one, which obviously requires this fine-tuning process. Since the default is to use the fine-tuning, you do not need to worry about this.

**STEP 3: Running RETAIN**
**STEP 4: Running RETAIN**

1. The minimum input you need to run RETAIN is the "visit file", the number of unique medical codes in the "visit file",
the "label file", and the output path. The output path is where the learned weights and the log will be saved.
Expand All @@ -84,17 +108,21 @@ For example `--alpha_hidden_dim_size 128` will tell RETAIN to use a GRU with 128

7. My personal recommendation: use mild regularization (0.0001 ~ 0.001) on all four weights, and use moderate dropout on the context vector only. But this entirely depends on your data, so you should always tune the hyperparameters for yourself.

**STEP 4: Getting your results**
**STEP 5: Getting your results**

RETAIN checks the AUC of the validation set after each epoch, and if it is higher than all previous values, it will save the current model. The model file is generated by [numpy.savez_compressed](http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.savez_compressed.html).

**Step 5: Testing your model**
**Step 6: Testing your model**

1. Using the file "test_retain.py", you can calculate the contributions of each medical code at each visit. First you need to have a trained model that was saved by numpy.savez_compressed. Note that you need to know the configuration with which you trained RETAIN (e.g. use of `--time_file`, use of `--use_log_time`.)

2. Again, you need the "visit file" and "label file" prepared in the same way. This time, however, you do not need to follow the ".train", ".valid", ".test" rule. The testing script will try to load the file name as given.

3. You also need the mapping information between the actual string medical codes and their integer codes. (e.g. "Hypertension" is mapped to 24) This file (let's call this "mapping file") need to be a Python cPickled dictionary where the keys are the string medical codes and the values are the corresponding intergers. This file is required to print the contributions of each medical code in a user-friendly format.
3. You also need the mapping information between the actual string medical codes and their integer codes.
(e.g. "Hypertension" is mapped to 24)
This file (let's call this "mapping file") need to be a Python cPickled dictionary where the keys are the string medical codes and the values are the corresponding intergers.
(e.g. The mapping file generated by process_mimic.py is the ".types" file)
This file is required to print the contributions of each medical code in a user-friendly format.

4. For the additional options such as `--time_file` or `--use_log_time`, you should use exactly the same configuration with which you trained the model. For more detailed information, use "--help" option.

Expand Down
Binary file removed figs/retain.pdf
Binary file not shown.
165 changes: 165 additions & 0 deletions process_mimic.py
@@ -0,0 +1,165 @@
# This script processes MIMIC-III dataset and builds longitudinal diagnosis records for patients with at least two visits.
# The output data are cPickled, and suitable for training Doctor AI or RETAIN
# Written by Edward Choi (mp2893@gatech.edu)
# Usage: Put this script to the foler where MIMIC-III CSV files are located. Then execute the below command.
# python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file>

# Output files
# <output file>.pids: List of unique Patient IDs. Used for intermediate processing
# <output file>.morts: List of binary values indicating the mortality of each patient
# <output file>.dates: List of List of Python datetime objects. The outer List is for each patient. The inner List is for each visit made by each patient
# <output file>.seqs: List of List of List of integer diagnosis codes. The outer List is for each patient. The middle List contains visits made by each patient. The inner List contains the integer diagnosis codes that occurred in each visit
# <output file>.types: Python dictionary that maps string diagnosis codes to integer diagnosis codes.

import sys
import cPickle as pickle
from datetime import datetime

def convert_to_icd9(dxStr):
if dxStr.startswith('E'):
if len(dxStr) > 4: return dxStr[:4] + '.' + dxStr[4:]
else: return dxStr
else:
if len(dxStr) > 3: return dxStr[:3] + '.' + dxStr[3:]
else: return dxStr

def convert_to_3digit_icd9(dxStr):
if dxStr.startswith('E'):
if len(dxStr) > 4: return dxStr[:4]
else: return dxStr
else:
if len(dxStr) > 3: return dxStr[:3]
else: return dxStr

if __name__ == '__main__':
admissionFile = sys.argv[1]
diagnosisFile = sys.argv[2]
patientsFile = sys.argv[3]
outFile = sys.argv[4]

print 'Collecting mortality information'
pidDodMap = {}
infd = open(patientsFile, 'r')
infd.readline()
for line in infd:
tokens = line.strip().split(',')
pid = int(tokens[1])
dod_hosp = tokens[5]
if len(dod_hosp) > 0:
pidDodMap[pid] = 1
else:
pidDodMap[pid] = 0
infd.close()

print 'Building pid-admission mapping, admission-date mapping'
pidAdmMap = {}
admDateMap = {}
infd = open(admissionFile, 'r')
infd.readline()
for line in infd:
tokens = line.strip().split(',')
pid = int(tokens[1])
admId = int(tokens[2])
admTime = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')
admDateMap[admId] = admTime
if pid in pidAdmMap: pidAdmMap[pid].append(admId)
else: pidAdmMap[pid] = [admId]
infd.close()

print 'Building admission-dxList mapping'
admDxMap = {}
admDxMap_3digit = {}
infd = open(diagnosisFile, 'r')
infd.readline()
for line in infd:
tokens = line.strip().split(',')
admId = int(tokens[2])
dxStr = 'D_' + convert_to_icd9(tokens[4][1:-1]) ############## Uncomment this line and comment the line below, if you want to use the entire ICD9 digits.
dxStr_3digit = 'D_' + convert_to_3digit_icd9(tokens[4][1:-1])

if admId in admDxMap:
admDxMap[admId].append(dxStr)
else:
admDxMap[admId] = [dxStr]

if admId in admDxMap_3digit:
admDxMap_3digit[admId].append(dxStr_3digit)
else:
admDxMap_3digit[admId] = [dxStr_3digit]
infd.close()

print 'Building pid-sortedVisits mapping'
pidSeqMap = {}
pidSeqMap_3digit = {}
for pid, admIdList in pidAdmMap.iteritems():
if len(admIdList) < 2: continue

sortedList = sorted([(admDateMap[admId], admDxMap[admId]) for admId in admIdList])
pidSeqMap[pid] = sortedList

sortedList_3digit = sorted([(admDateMap[admId], admDxMap_3digit[admId]) for admId in admIdList])
pidSeqMap_3digit[pid] = sortedList_3digit

print 'Building pids, dates, mortality_labels, strSeqs'
pids = []
dates = []
seqs = []
morts = []
for pid, visits in pidSeqMap.iteritems():
pids.append(pid)
morts.append(pidDodMap[pid])
seq = []
date = []
for visit in visits:
date.append(visit[0])
seq.append(visit[1])
dates.append(date)
seqs.append(seq)

print 'Building pids, dates, strSeqs for 3digit ICD9 code'
seqs_3digit = []
for pid, visits in pidSeqMap_3digit.iteritems():
seq = []
for visit in visits:
seq.append(visit[1])
seqs_3digit.append(seq)

print 'Converting strSeqs to intSeqs, and making types'
types = {}
newSeqs = []
for patient in seqs:
newPatient = []
for visit in patient:
newVisit = []
for code in visit:
if code in types:
newVisit.append(types[code])
else:
types[code] = len(types)
newVisit.append(types[code])
newPatient.append(newVisit)
newSeqs.append(newPatient)

print 'Converting strSeqs to intSeqs, and making types for 3digit ICD9 code'
types_3digit = {}
newSeqs_3digit = []
for patient in seqs_3digit:
newPatient = []
for visit in patient:
newVisit = []
for code in set(visit):
if code in types_3digit:
newVisit.append(types_3digit[code])
else:
types_3digit[code] = len(types_3digit)
newVisit.append(types_3digit[code])
newPatient.append(newVisit)
newSeqs_3digit.append(newPatient)

pickle.dump(pids, open(outFile+'.pids', 'wb'), -1)
pickle.dump(dates, open(outFile+'.dates', 'wb'), -1)
pickle.dump(morts, open(outFile+'.morts', 'wb'), -1)
pickle.dump(newSeqs, open(outFile+'.seqs', 'wb'), -1)
pickle.dump(types, open(outFile+'.types', 'wb'), -1)
pickle.dump(newSeqs_3digit, open(outFile+'.3digitICD9.seqs', 'wb'), -1)
pickle.dump(types_3digit, open(outFile+'.3digitICD9.types', 'wb'), -1)
21 changes: 15 additions & 6 deletions retain.py
Expand Up @@ -16,6 +16,9 @@

from sklearn.metrics import roc_auc_score

_TEST_RATIO = 0.2
_VALIDATION_RATIO = 0.1

def unzip(zipped):
new_params = OrderedDict()
for key, value in zipped.iteritems():
Expand Down Expand Up @@ -236,7 +239,7 @@ def padMatrixWithoutTime(seqs, options):

return x, lengths

def load_data_debug(seqFile, labelFile, timeFile=''):
def load_data_simple(seqFile, labelFile, timeFile=''):
sequences = np.array(pickle.load(open(seqFile, 'rb')))
labels = np.array(pickle.load(open(labelFile, 'rb')))
if len(timeFile) > 0:
Expand All @@ -245,8 +248,8 @@ def load_data_debug(seqFile, labelFile, timeFile=''):
dataSize = len(labels)
np.random.seed(0)
ind = np.random.permutation(dataSize)
nTest = int(0.15 * dataSize)
nValid = int(0.10 * dataSize)
nTest = int(_TEST_RATIO * dataSize)
nValid = int(_VALIDATION_RATIO * dataSize)

test_indices = ind[:nTest]
valid_indices = ind[nTest:nTest+nValid]
Expand Down Expand Up @@ -390,7 +393,7 @@ def train_RETAIN(
labelFile='labelFile.txt',
numClass=1,
outFile='outFile.txt',
timeFile='timeFile.txt',
timeFile='',
modelFile='model.npz',
useLogTime=True,
embFile='embFile.txt',
Expand All @@ -408,6 +411,7 @@ def train_RETAIN(
dropoutRateContext=0.5,
logEps=1e-8,
solver='adadelta',
simpleLoad=False,
verbose=False
):
options = locals().copy()
Expand Down Expand Up @@ -470,7 +474,10 @@ def train_RETAIN(
get_cost = theano.function(inputs=[x, y, lengths], outputs=cost_noreg, name='get_cost')

print 'Loading data ... ',
trainSet, validSet, testSet = load_data(seqFile, labelFile, timeFile)
if simpleLoad:
trainSet, validSet, testSet = load_data_simple(seqFile, labelFile, timeFile)
else:
trainSet, validSet, testSet = load_data(seqFile, labelFile, timeFile)
n_batches = int(np.ceil(float(len(trainSet[0])) / float(batchSize)))
print 'done'

Expand Down Expand Up @@ -531,7 +538,7 @@ def parse_arguments(parser):
parser.add_argument('n_input_codes', type=int, metavar='<n_input_codes>', help='The number of unique input medical codes')
parser.add_argument('label_file', type=str, metavar='<label_file>', help='The path to the Pickled file containing label information of patients')
#parser.add_argument('n_output_codes', type=int, metavar='<n_output_codes>', help='The number of unique label medical codes')
parser.add_argument('out_file', metavar='out_file', help='The path to the output models. The models will be saved after every epoch')
parser.add_argument('out_file', metavar='<out_file>', help='The path to the output models. The models will be saved after every epoch')
parser.add_argument('--time_file', type=str, default='', help='The path to the Pickled file containing durations between visits of patients. If you are not using duration information, do not use this option')
parser.add_argument('--model_file', type=str, default='', help='The path to the Numpy-compressed file containing the model parameters. Use this option if you want to re-train an existing model')
parser.add_argument('--use_log_time', type=int, default=1, choices=[0,1], help='Use logarithm of time duration to dampen the impact of the outliers (0 for false, 1 for true) (default value: 1)')
Expand All @@ -550,6 +557,7 @@ def parse_arguments(parser):
parser.add_argument('--dropout_context', type=float, default=0.5, help='Dropout rate between the context vector c_i and the final classifier (default value: 0.5)')
parser.add_argument('--log_eps', type=float, default=1e-8, help='A small value to prevent log(0) (default value: 1e-8)')
parser.add_argument('--solver', type=str, default='adadelta', choices=['adadelta','adam'], help='Select which solver to train RETAIN: adadelta, or adam. (default: adadelta)')
parser.add_argument('--simple_load', action='store_true', help='Use an alternative way to load the dataset. Instead of you having to provide a trainign set, validation set, test set, this will automatically divide the dataset. (default false)')
parser.add_argument('--verbose', action='store_true', help='Print output after every 100 mini-batches (default false)')
args = parser.parse_args()
return args
Expand Down Expand Up @@ -582,5 +590,6 @@ def parse_arguments(parser):
dropoutRateContext=args.dropout_context,
logEps=args.log_eps,
solver=args.solver,
simpleLoad=args.simple_load,
verbose=args.verbose
)

0 comments on commit f99aa16

Please sign in to comment.