# Reproducibility Project Notebook for Readmission Prediction via Deep Contextual Embedding of Clinical Concepts

- Data processing
- Content model implementation
- Training and evaluation

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
import os
import pickle
import random
import numpy as np
import time
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from util import *
from common import full_eval

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, precision_recall_curve, auc

## Process EHR data

Based on EHR data sorted by time, convert the description text to numerical ids.


In [3]:
# Raw tab-separated synthetic EHR export; read by sort_data() and by the preview cell below.
raw_file = './data/S1_File.txt' # this is the original synthetic data file, we use sorted version instead
# Copy of the raw data sorted by (PID, DAY_ID); written tab-separated despite the .csv extension.
input_file = './resource/s1_sorted.csv'
# Tab-separated "<id>\t<description>" vocabulary written by dump_vocab() and read by load_vocab().
vocab_file = './resource/vocab.txt'
# Descriptions whose occurrence count exceeds the stop-word threshold, written by dump_vocab().
stop_file = './resource/stop.txt'
# Pickled id -> description mapping written by load_vocab().
vocab_pkl = './resource/vocab.pkl'

### About Raw Data

Synthetic data based on real EHR data. 3000 patients in total. No demographic information is included.

- `PID` patient id
- `DAY_ID` numerical date identifier with time difference preserved
- `DX_GROUP_DESCRIPTION` diagnosis text descriptions
- `SERVICE_LOCATION`
- `OP_DATE` record posting date

In [4]:
# Preview the first three rows of the raw (unsorted) tab-separated file.
df = pd.read_csv(raw_file, sep='\t', header=0)
print(df[0:3])

   PID  DAY_ID                               DX_GROUP_DESCRIPTION  \
0    1   73888                                    ANGINA PECTORIS   
1    1   73888  MONONEURITIS OF UPPER LIMB AND MONONEURITIS MU...   
2    1   73888  SYMPTOMS INVOLVING RESPIRATORY SYSTEM AND OTHE...   

  SERVICE_LOCATION  OP_DATE  
0   DOCTORS OFFICE    74084  
1   DOCTORS OFFICE    74084  
2   DOCTORS OFFICE    74084  


The data should be first sorted by time

In [5]:
def sort_data():
    """Sort the raw records by patient and day and write them to ``input_file``.

    Reads the tab-separated ``raw_file``, orders the rows by ``PID`` and then
    ``DAY_ID`` (both ascending), discards the pre-sort row index and writes the
    result tab-separated, without an index column, to ``input_file``.
    """
    raw_records = pd.read_csv(raw_file, sep='\t', header=0)
    ordered = raw_records.sort_values(by=['PID', 'DAY_ID'], ascending=True)
    # reset_index() materialises the old row index as an "index" column; drop it again
    ordered = ordered.reset_index()
    ordered = ordered.drop(columns=["index"])
    ordered.to_csv(input_file, sep='\t', index=False)

`dump_vocab` parses the sorted S1 data and collects the `DX_GROUP_DESCRIPTION` vocabulary.

Descriptions with low occurrence (rare words) are filtered out; descriptions with very high occurrence are additionally stored in `stop.txt`.

In [12]:
def dump_vocab():
    """Build the diagnosis-description vocabulary and the stop-word list.

    Counts how often each ``DX_GROUP_DESCRIPTION`` occurs in ``input_file``,
    prints a few statistics, keeps the descriptions occurring more than
    ``rare_word`` times and writes them (ordered by ascending frequency, ids
    starting at 2) to ``vocab_file``. Kept descriptions occurring more than
    ``stop_word`` times are written to ``stop_file``.
    """
    records = pd.read_csv(input_file, sep='\t', header=0)

    # one row per unique description with its occurrence count in column SIZE
    freq = records.groupby('DX_GROUP_DESCRIPTION').size().to_frame('SIZE').reset_index()
    print(freq[0:3])

    # statistics: the three most frequent descriptions, then the histogram of counts
    print(freq.sort_values(by='SIZE', ascending=False)[0:3])
    print(freq.groupby('SIZE').size().to_frame('COUNT').reset_index())

    # discard rare descriptions
    kept = freq[freq['SIZE'] > rare_word]
    print(kept)

    # vocabulary ids follow ascending frequency; the +2 offset leaves id 1 free for unknown
    vocab = kept.sort_values(by='SIZE').reset_index()['DX_GROUP_DESCRIPTION']
    vocab.index = vocab.index + 2
    vocab.to_csv(vocab_file, sep='\t', header=False, index=True)

    # very frequent descriptions form the stop-word list
    frequent = kept[kept['SIZE'] > stop_word]
    frequent.reset_index()['DX_GROUP_DESCRIPTION'] \
        .to_csv(stop_file, sep='\t', header=False, index=False)

In [13]:
def load_vocab():
    """Load the description -> id vocabulary from ``vocab_file``.

    Each non-blank line of ``vocab_file`` is expected to be ``<id>\\t<description>``.
    Blank lines (for example a trailing empty line) are skipped instead of
    raising ``IndexError``. As a side effect the inverse mapping
    (id -> description) is pickled to ``vocab_pkl`` via ``save_pkl``.

    Returns:
        dict mapping description text to its integer id.
    """
    word_to_index = {}
    with open(vocab_file, mode='r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # tolerate empty lines rather than failing on tokens[1]
                continue
            tokens = stripped.split('\t')
            word_to_index[tokens[1]] = int(tokens[0])
    print('dict size: ' + str(len(word_to_index)))
    # persist the reverse lookup (id -> description)
    save_pkl(vocab_pkl, {v: k for k, v in word_to_index.items()})
    return word_to_index

Events with the label 'INPATIENT HOSPITAL' signal a hospital admission.
Records are grouped by patient and date, then sorted by time, for event sequence parsing.

<img src=./doc/visit_sequence.png>

In [14]:
def extract_events():
    """Collect the days on which each patient had an inpatient hospital record.

    Returns:
        DataFrame indexed by ``PID`` with columns ``DAY_ID``,
        ``SERVICE_LOCATION`` and ``COUNT`` (number of matching records on that
        day), ordered by ``PID`` and ``DAY_ID``.
    """
    target_event = 'INPATIENT HOSPITAL'

    records = pd.read_csv(input_file, sep='\t', header=0)
    inpatient = records[records['SERVICE_LOCATION'] == target_event]

    # collapse to one row per (patient, day) that has at least one inpatient record
    per_day = inpatient.groupby(['PID', 'DAY_ID', 'SERVICE_LOCATION']).size()
    per_day = per_day.to_frame('COUNT').reset_index()
    per_day = per_day.sort_values(by=['PID', 'DAY_ID'], ascending=True)

    return per_day.set_index('PID')

`convert_format` groups records into a patient-date 2d array, while converting description text to numerical ids.
Unknown descriptions are represented by 1.

A visit is tagged with a positive readmission label if there is an inpatient event within 30 days of that visit.

In [15]:
def tag(events, pid, day_id):
    """Return the integer label (1 or 0) for the truth value of ``tag_logic``."""
    if tag_logic(events, pid, day_id):
        return 1
    return 0


def tag_logic(events, pid, day_id):
    """Tell whether a patient has an inpatient event in the 30-day window starting at ``day_id``.

    Arguments:
        events: DataFrame indexed by PID with a ``DAY_ID`` column
        pid: patient id (anything ``int()`` accepts)
        day_id: day identifier of the visit (anything ``int()`` accepts)

    Outputs:
        truthy value if some event day satisfies ``day_id <= DAY_ID < day_id + 30``;
        ``False`` when the patient has no entry in ``events``.

    Note that the window is inclusive of ``day_id`` itself.
    """
    try:
        history = events.loc[int(pid)]
    except KeyError:
        # patient never had a target event
        return False

    window_start = int(day_id)
    window_end = window_start + 30

    # a single matching row comes back from .loc as a Series, several rows as a DataFrame
    if isinstance(history, pd.Series):
        return (window_start <= history.DAY_ID) & (history.DAY_ID < window_end)

    in_window = (window_start <= history.DAY_ID) & (history.DAY_ID < window_end)
    return history.loc[in_window].shape[0] > 0

    
def convert_format(word_to_index, events):
    """Convert the sorted record file into per-patient visit sequences and labels.

    Streams ``input_file`` line by line (rows are expected to be ordered by
    PID, DAY_ID) and groups consecutive rows first by patient, then by day.

    Arguments:
        word_to_index: dict mapping description text to its numerical id
        events: inpatient events passed through to ``tag`` for labelling

    Outputs:
        docs: list over patients; each patient is a list of visits (one per
              distinct DAY_ID); each visit is a list of description ids
        labels: list over patients; each entry holds one label per visit,
                as returned by ``tag``
    """
    # order by PID, DAY_ID
    with open(input_file, mode='r') as f:
        # header: map every column name to its position in a row
        header = f.readline().strip().split('\t')
        print(header)
        pos = {}
        for key, value in enumerate(header):
            pos[value] = key
        print(pos)

        docs = []  # all patients
        doc = []  # packs all events of the same patient
        sent = []  # pack events in the same day
        labels = []  # per-patient label lists, parallel to docs
        label = []  # labels of the current patient, one per visit

        # init: prime the current patient/day from the first data row;
        # that row is re-read and consumed by the first loop iteration below
        line = f.readline()
        tokens = line.strip().split('\t')
        pid = tokens[pos['PID']]
        day_id = tokens[pos['DAY_ID']]
        label.append(tag(events, pid, day_id))

        # readline() returns '' only at end of file
        while line != '':
            tokens = line.strip().split('\t')
            c_pid = tokens[pos['PID']]
            c_day_id = tokens[pos['DAY_ID']]

            # move to next patient: flush the open visit and the finished patient
            if c_pid != pid:
                doc.append(sent)
                docs.append(doc)
                sent = []
                doc = []
                pid = c_pid
                day_id = c_day_id
                labels.append(label)
                label = [tag(events, pid, day_id)]
            else:
                # same patient, new day: flush the open visit and label the new one
                if c_day_id != day_id:
                    doc.append(sent)
                    sent = []
                    day_id = c_day_id
                    label.append(tag(events, pid, day_id))

            word = tokens[pos['DX_GROUP_DESCRIPTION']]
            try:
                sent.append(word_to_index[word])
            except KeyError:
                # description not in the vocabulary; `unknown` is not defined in this
                # notebook (presumably provided by `from util import *` — TODO confirm)
                sent.append(unknown)

            line = f.readline()

        # closure: flush the last visit and the last patient
        doc.append(sent)
        docs.append(doc)
        labels.append(label)

    return docs, labels

Then split sequence and labels into train, validation, test sets.

Each patient would only belong to one set.

In [16]:
def split_data(docs, labels):
    """Persist the train / validation / test partitions and the complete data set.

    The split is positional over the patient list: the first 2000 patients are
    used for training, the next 350 for validation and the rest for testing.
    Every partition is written to ``./resource`` as a pair of pickles
    (``X_<name>.pkl`` for the sequences and ``Y_<name>.pkl`` for the labels).
    """
    print(len(docs))
    print(len(labels))

    partitions = (
        ('train', slice(None, 2000)),
        ('valid', slice(2000, 2350)),
        ('test', slice(2350, None)),
        ('complete', slice(None)),
    )
    for name, span in partitions:
        save_pkl('./resource/X_' + name + '.pkl', docs[span])
        save_pkl('./resource/Y_' + name + '.pkl', labels[span])


Data preprocessing main logic commented out to directly use saved intermediate results in `resource/`

In [None]:
# sort_data()
# dump_vocab()
# word_to_index = load_vocab()
# events = extract_events()

# docs, labels = convert_format(word_to_index, events)
# split_data(docs, labels)

In [4]:
# Load the per-patient visit sequences (X_*) and per-visit labels (Y_*) of each split,
# i.e. the files written by split_data().
doc_train = load_pkl("resource/X_train.pkl")
lb_train = load_pkl("resource/Y_train.pkl")
doc_val = load_pkl("resource/X_valid.pkl")
lb_val = load_pkl("resource/Y_valid.pkl")
doc_test = load_pkl("resource/X_test.pkl")
lb_test = load_pkl("resource/Y_test.pkl")

 [*] load resource/X_train.pkl
 [*] load resource/Y_train.pkl
 [*] load resource/X_valid.pkl
 [*] load resource/Y_valid.pkl
 [*] load resource/X_test.pkl
 [*] load resource/Y_test.pkl


In [6]:
# Size of the code id space; multi-hot vectors and model inputs use num_codes - 1 = 491 slots
# (code id c is stored at position c - 1).
# NOTE(review): presumably ids 1 (unknown) through 491 — confirm against vocab.txt.
num_codes = 492

In [18]:
def split_sequence_hot_code(docs, labels, n_codes=None):
    """Expand patients into per-visit prefix samples with multi-hot encoded visits.

    For a patient with visits v_0..v_k this emits k+1 samples: the prefixes
    [v_0], [v_0, v_1], ..., each paired with the label of its last visit.
    Every visit becomes a 0/1 vector of length ``n_codes - 1`` in which code id
    ``c`` sets position ``c - 1``.

    Each visit is encoded once and the resulting vector object is shared by all
    prefixes that contain it (the original re-encoded it for every prefix).
    Treat the returned visit vectors as read-only.

    Arguments:
        docs: list over patients of lists of visits of code ids
        labels: list over patients of per-visit labels, parallel to ``docs``
        n_codes: size of the code id space; defaults to the notebook-level ``num_codes``

    Outputs:
        split_sequences: list of samples, each a list of multi-hot visit vectors
        split_labels: label of the last visit of each sample
    """
    if n_codes is None:
        n_codes = num_codes
    width = n_codes - 1

    split_sequences = []
    split_labels = []
    for i, patient_seq in enumerate(docs):
        patient_labels = labels[i]

        # encode every visit of this patient exactly once
        encoded = []
        for visit in patient_seq:
            visit_hc = [0] * width
            for mcode in visit:
                visit_hc[mcode - 1] = 1
            encoded.append(visit_hc)

        # one sample per prefix, labelled by its last visit
        for j in range(len(patient_seq)):
            split_sequences.append(encoded[:j + 1])
            split_labels.append(patient_labels[j])
    return split_sequences, split_labels

Optional: persist split result to `./resource`. Take up less than 2GiB of disk space.

In [20]:
# Expand every split into per-visit prefix samples with multi-hot encoded visits.
seq_train, labels_train = split_sequence_hot_code(doc_train, lb_train)
seq_val, labels_val = split_sequence_hot_code(doc_val, lb_val)
seq_test, labels_test = split_sequence_hot_code(doc_test, lb_test)

# store multi hot encoded data so the data-loader cell can reload it from disk
save_pkl('./resource/X_train_mhc.pkl', seq_train)
save_pkl('./resource/Y_train_mhc.pkl', labels_train)
save_pkl('./resource/X_valid_mhc.pkl', seq_val)
save_pkl('./resource/Y_valid_mhc.pkl', labels_val)
save_pkl('./resource/X_test_mhc.pkl', seq_test)
save_pkl('./resource/Y_test_mhc.pkl', labels_test)

## Prepare data loaders

In [7]:
# Reload the multi-hot encoded samples and labels from the *_mhc.pkl files
# (the same paths the encoding cell writes).
seq_train = load_pkl('./resource/X_train_mhc.pkl')
labels_train = load_pkl('./resource/Y_train_mhc.pkl')
seq_val = load_pkl('./resource/X_valid_mhc.pkl')
labels_val = load_pkl('./resource/Y_valid_mhc.pkl')
seq_test = load_pkl('./resource/X_test_mhc.pkl')
labels_test = load_pkl('./resource/Y_test_mhc.pkl')


class CustomDataset(Dataset):
    """Index-aligned pairing of samples and their labels.

    ``docs[i]`` and ``labels[i]`` are stored by reference and returned together
    as item ``i``.
    """

    def __init__(self, docs, labels):
        self.x, self.y = docs, labels

    def __len__(self):
        # the sample list defines the dataset size
        return len(self.x)

    def __getitem__(self, index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target


def collate_fn(data):
    """
    Pad a list of samples into dense batch tensors.

    Arguments:
        data: a list of (sequence, label) samples fetched from `CustomDataset`;
              a sequence is a list of visits, a visit is a list of numbers
              (the multi-hot encoding of its diagnosis codes)

    Outputs:
        x: a tensor of shape (# patients, max # visits, max visit vector length) of type torch.float,
           multi-hot encoding of the diagnosis codes within each visit, zero padded
        masks: a tensor of the same shape as x of type torch.bool, True on every position that belongs to a real visit
        rev_x: same as x but with the visits of every patient in reversed time order.
        rev_masks: same as masks but in reversed time order.
        y: a tensor of shape (# patients) of type torch.float
    """

    sequences, labels = zip(*data)
    y = torch.tensor(labels, dtype=torch.float)

    batch_size = len(sequences)
    longest = max([len(patient) for patient in sequences])
    widest = max([len(visit) for patient in sequences for visit in patient])
    shape = (batch_size, longest, widest)

    x = torch.zeros(shape, dtype=torch.float)
    rev_x = torch.zeros(shape, dtype=torch.float)
    masks = torch.zeros(shape, dtype=torch.bool)
    rev_masks = torch.zeros(shape, dtype=torch.bool)

    for row, patient in enumerate(sequences):
        last = len(patient) - 1
        for step, visit in enumerate(patient):
            width = len(visit)
            values = torch.tensor(visit).type(torch.float)
            # forward time order
            masks[row, step, :width] = True
            x[row, step, :width] = values
            # reversed time order: the last real visit lands in slot 0
            rev_masks[row, last - step, :width] = True
            rev_x[row, last - step, :width] = values

    return x, masks, rev_x, rev_masks, y



# Training batches are reshuffled every epoch; batch size is 16 for all three loaders.
dataset_train = CustomDataset(seq_train, labels_train)
train_loader = DataLoader(dataset_train, batch_size=16, collate_fn=collate_fn, shuffle=True)

# Validation and test loaders keep the stored sample order.
dataset_val = CustomDataset(seq_val, labels_val)
val_loader = DataLoader(dataset_val, batch_size=16, collate_fn=collate_fn, shuffle=False)

dataset_test = CustomDataset(seq_test, labels_test)
test_loader = DataLoader(dataset_test, batch_size=16, collate_fn=collate_fn, shuffle=False)

 [*] load ./resource/X_train_mhc.pkl
 [*] load ./resource/Y_train_mhc.pkl
 [*] load ./resource/X_valid_mhc.pkl
 [*] load ./resource/Y_valid_mhc.pkl
 [*] load ./resource/X_test_mhc.pkl
 [*] load ./resource/Y_test_mhc.pkl


# CONTENT model implementation



<img src=./doc/content_model_illustration.png>

First implement the recognition network (lower portion of the illustration above), which uses an MLP to produce the distribution parameters for patient context generation.

- The MLP produces a $\log(\sigma)$ vector and a $\mu$ vector
- The topic vector is then sampled from the distribution $N(\mu, \sigma^2)$

In [8]:
class Recognition(torch.nn.Module):
    """Recognition MLP that samples a per-patient topic vector theta."""

    def __init__(self, input_dim=num_codes-1, hidden_dim=200, topic_dim=50):
        """
        Define the recognition MLP that generates topic vector theta;

        Arguments:
            input_dim: generator does not take embeddings, directly put input dimension here
            hidden_dim: width of the two hidden layers
            topic_dim: size of the generated topic vector
        """
        super().__init__()

        self.a_att = nn.Linear(input_dim, hidden_dim)
        self.b_att = nn.Linear(hidden_dim, hidden_dim)
        self.u_ln = nn.Linear(hidden_dim, topic_dim)
        self.sigma_ln = nn.Linear(hidden_dim, topic_dim)

        self.hidden = hidden_dim

    def forward(self, x, masks):
        """

        Arguments:
            x: the multi hot encoded visits (batch_size, # visits, # total diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # total diagnosis codes)

        Outputs:
            gen: generated value from learned distribution, shape (batch_size, topic_dim)

        NOTE(review): noise is sampled on every call, including in eval mode, so
        evaluation metrics are stochastic — confirm this is intended.
        """
        # MLP to obtain mean and log_sigma values
        h = torch.relu(self.a_att(x))  # (batch, visit, input) -> (batch, visit, hidden)
        h = torch.relu(self.b_att(h))
        lu = self.u_ln(h)  # -> (batch, visit, n_topic)
        ls = self.sigma_ln(h)  # -> (batch, visit, n_topic)

        # a visit is real when any position of its mask is set
        visit_masks = torch.sum(masks, dim=-1).type(torch.bool)  # (batch, visit)
        weights = visit_masks.unsqueeze(-1).to(lu.dtype)  # (batch, visit, 1)
        # number of real visits per patient; clamped so an all-padding row gives 0 instead of NaN
        n_visits = weights.sum(dim=1).clamp(min=1)  # (batch, 1)

        # masked mean over the visit axis: (batch, n_topic) / (batch, 1)
        mean_u = torch.sum(lu * weights, dim=1) / n_visits
        mean_log_sigma = torch.sum(ls * weights, dim=1) / n_visits

        # generate from learned distribution; randn_like keeps the noise on the
        # same device and dtype as the parameters (torch.randn(shape) is CPU-only)
        noise = torch.randn_like(mean_u)
        gen = noise * torch.exp(mean_log_sigma) + mean_u  # (batch, n_topic)
        return gen

Utility function to obtain GRU output at last valid visit

In [9]:
def get_last_visit(hidden_states, masks):
    """
    Pick, for every sample, the hidden state of its last real (non-padded) visit.

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)

    A visit counts as real when at least one position of its mask is set; the
    number of real visits minus one is the index of the last real visit.
    """
    has_codes = masks.sum(dim=-1).gt(0)  # (batch_size, # visits)
    n_real_visits = has_codes.sum(dim=-1)  # (batch_size,)
    rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    # paired index tensors select one visit per row
    return hidden_states[rows, n_real_visits - 1]

Then the complete content structure, include the upper GRU network.

The outputs from the Recognition and GRU networks are stitched together by the matrices Q and B

$$ Q^T * H + B^T * \theta = Logits $$

In [10]:

class Content(torch.nn.Module):
    """
    CONTENT network: a GRU over embedded visits combined with the recognition
    module's topic vector to produce one probability per sample.
    """
    def __init__(self, input_dim=num_codes-1, embedding_dim=100, hidden_dim=200, topic_dim=50):
        """
        Arguments:
            input_dim: length of a multi hot encoded visit vector
            embedding_dim: size of the linear visit embedding fed to the GRU
            hidden_dim: GRU hidden size, also the hidden width of the recognition MLP
            topic_dim: size of the topic vector produced by the recognition module
        """
        super().__init__()

        # sequence branch
        self.fc_embedding = nn.Linear(in_features=input_dim, out_features=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        # topic branch, fed with the raw multi hot vectors
        self.recognition = Recognition(input_dim=input_dim, hidden_dim=hidden_dim, topic_dim=topic_dim)
        # one scalar projection per branch; their sum is the score
        self.fc_q = nn.Linear(hidden_dim, 1)
        self.fc_b = nn.Linear(topic_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, masks):
        """
        Arguments:
            x: the multi hot encoded visits (batch_size, # visits, # total diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # total diagnosis codes)

        Outputs:
            probability per sample, shape (batch_size,)
        """
        gru_out, _ = self.rnn(self.fc_embedding(x))
        last_hidden = get_last_visit(gru_out, masks)  # (batch_size, hidden_dim)
        theta = self.recognition(x, masks)  # (batch_size, n_topic)
        logit = self.fc_q(last_hidden) + self.fc_b(theta)  # (batch_size, 1)
        return self.sigmoid(logit).squeeze(dim=-1)

# Training and Evaluation

In [11]:
ctn = Content(input_dim=num_codes-1)  # total vocab 491

# load the loss function; BCELoss expects probabilities, which Content produces via its final sigmoid
criterion = nn.BCELoss()
# load the optimizer; `criterion` and `optimizer` are read as globals inside train()
optimizer = torch.optim.Adam(ctn.parameters(), lr=0.00002)

In [12]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model and report validation metrics after every epoch.

    Uses the notebook-level ``optimizer`` and ``criterion`` objects and the
    notebook's ``eval`` function, all resolved at call time.

    Arguments:
        model: the RNN model
        train_loader: training dataloader
        val_loader: validation dataloader
        n_epochs: total number of epochs, must be at least 1

    Outputs:
        validation roc_auc of the last epoch, rounded to 2 decimals

    Raises:
        ValueError: if n_epochs is smaller than 1 (previously this surfaced as an
            UnboundLocalError on the return statement)
    """
    if n_epochs < 1:
        raise ValueError('n_epochs must be at least 1, got {}'.format(n_epochs))

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        n_batches = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1
        # mean loss per batch; guarded so an empty loader does not divide by zero
        train_loss = train_loss / max(n_batches, 1)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch + 1, train_loss))
        p, r, f, roc_auc = eval(model, val_loader)
        print('Epoch: {} \t Validation p: {:.4f}, r:{:.4f}, f: {:.4f}, roc_auc: {:.4f}'.format(epoch + 1, p, r, f,
                                                                                               roc_auc))
    return round(roc_auc, 2)

In [13]:
def eval(model, val_loader):
    """
    Evaluate the model.

    Note: this notebook-level function shadows Python's builtin ``eval``.

    Arguments:
        model: the model
        val_loader: validation dataloader

    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score

    Precision, recall and f1 use a fixed decision threshold of 0.5 on the model output.
    """

    model.eval()
    scores, preds, targets = [], [], []
    # inference only: skip autograd bookkeeping to save memory and time
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y in val_loader:
            y_prob = model(x, masks)
            y_hat = torch.where(y_prob > 0.5, 1, 0)
            scores.append(y_prob.detach().to('cpu'))
            preds.append(y_hat.detach().to('cpu'))
            targets.append(y.detach().to('cpu'))

    # concatenate once instead of growing the tensors batch by batch
    y_score = torch.cat(scores, dim=0) if scores else torch.Tensor()
    y_pred = torch.cat(preds, dim=0) if preds else torch.LongTensor()
    y_true = torch.cat(targets, dim=0) if targets else torch.LongTensor()

    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    return p, r, f, roc_auc

## Training and evaluation with test data

Each epoch would take around 40 minutes with CPU training.

Best result is obtained after 4 epochs.

PR-AUC and ROC-AUC are the most significant metrics. The printed F score, given as a reference, is based on a threshold value of 0.5.

In [14]:
n_epochs = 4
# print wall-clock time before and after training to show how long it took
print(time.strftime("%H:%M:%S", time.localtime()))
train(ctn, train_loader, val_loader, n_epochs)
print(time.strftime("%H:%M:%S", time.localtime()))

# store model state dict (the models/ directory must already exist)
torch.save(ctn.state_dict(), "models/content_notebook.pth")


19:10:14


KeyboardInterrupt: 

Test should report a ROC-AUC over 0.8 and PR-AUC over 0.67, which outperforms the retain model.

In [15]:
# reload and evaluation

# Optional: restore a saved model instead of using the one trained above.
# NOTE(review): the training cell saves to "models/content_notebook.pth", while the
# path below is "models/content.pth" — confirm which checkpoint is intended.
#ctn = Content(input_dim=num_codes-1)
#ctn.load_state_dict(torch.load("models/content.pth"))

# full_eval comes from the project's `common` module and additionally returns PR-AUC
p, r, f, roc_auc, pr_auc = full_eval(ctn, test_loader)
print('Test p: {:.4f}, r:{:.4f}, f: {:.4f}, roc_auc: {:.4f}, pr_auc: {:.4f}'.format(p, r, f, roc_auc, pr_auc))

Test p: 0.7243, r:0.3398, f: 0.4626, roc_auc: 0.7659, pr_auc: 0.5826
