In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
import pickle

### Load preprocessed data to save time

In [2]:
with open("los.pkl","rb") as f:
    los_7 = pickle.load(f)
    
with open("mort.pkl","rb") as f:
    mort = pickle.load(f)
    
with open("demograph_interventions_SDPRL.pkl","rb") as f:
    demo_iseqs = pickle.load(f)

with open("demograph_vitals_zero_SDPRL.pkl","rb") as f:
    demo_vseqs = pickle.load(f)

In [3]:
class CustomMMDataset(Dataset):
    
    def __init__(self, seqs1, seqs2, labels):
        self.x1 = seqs1
        self.x2 = seqs2
        self.y = labels
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        # your code here
        return len(self.y)
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        
        # SDPRL: m1_pos,m2_pos, m1_neg,m2_neg, y
        return self.x1[index][0], self.x2[index][0], self.x1[index][1], self.x2[index][1], self.y[index]

In [4]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    # your code here
    train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn)
    
    return train_loader, val_loader

In [5]:
def collate_fn(data):
    sequences1pos, sequences2pos,sequences1neg, sequences2neg, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    ########################################################
    #padding and masking for modal 1 & 2 for positive sample
    ########################################################
    num_patients = len(sequences1pos)
    num_interventions = [len(patient) for patient in sequences1pos]
    num_codes = [len(intervention) for patient in sequences1pos for intervention in patient]

    max_num_interventions = max(num_interventions)
    max_num_codes = max(num_codes)
    
    x1_pos = torch.zeros((num_patients, max_num_interventions, max_num_codes), dtype=torch.long)
    masks1_pos = torch.zeros((num_patients, max_num_interventions, max_num_codes), dtype=torch.bool)

    for i_patient, patient in enumerate(sequences1pos):
        for j_intervention, intervention in enumerate(patient):
            # your code here
            for idx, code in enumerate(intervention):
                x1_pos[i_patient,j_intervention,idx] = code
                masks1_pos[i_patient,j_intervention,idx] = 1
    

    num_vitals = [len(patient) for patient in sequences2pos]
    max_num_vitals = max(num_interventions)
    vitals_dim = 107
    x2_pos = torch.zeros((num_patients, max_num_vitals, vitals_dim), dtype=torch.long)
    masks2_pos = torch.zeros((num_patients, max_num_vitals, vitals_dim), dtype=torch.bool)
    
    for i_patient, patient in enumerate(sequences2pos):
        for j_vitals, vitals in enumerate(patient):
            # your code here
            x2_pos[i_patient, j_vitals] = torch.tensor(vitals, dtype=torch.long)
            masks2_pos[i_patient, j_vitals] = 1
            
    ########################################################
    #padding and masking for modal 1 & 2 for NEGATIVE sample
    ########################################################
    num_patients = len(sequences1neg)
    num_interventions = [len(patient) for patient in sequences1neg]
    num_codes = [len(intervention) for patient in sequences1neg for intervention in patient]

    max_num_interventions = max(num_interventions)
    max_num_codes = max(num_codes)
    
    x1_neg = torch.zeros((num_patients, max_num_interventions, max_num_codes), dtype=torch.long)
    masks1_neg = torch.zeros((num_patients, max_num_interventions, max_num_codes), dtype=torch.bool)

    for i_patient, patient in enumerate(sequences1neg):
        for j_intervention, intervention in enumerate(patient):
            # your code here
            for idx, code in enumerate(intervention):
                x1_neg[i_patient,j_intervention,idx] = code
                masks1_neg[i_patient,j_intervention,idx] = 1
    

    num_vitals = [len(patient) for patient in sequences2neg]
    max_num_vitals = max(num_interventions)
    vitals_dim = 107
    x2_neg = torch.zeros((num_patients, max_num_vitals, vitals_dim), dtype=torch.long)
    masks2_neg = torch.zeros((num_patients, max_num_vitals, vitals_dim), dtype=torch.bool)
    
    for i_patient, patient in enumerate(sequences2neg):
        for j_vitals, vitals in enumerate(patient):
            # your code here
            x2_neg[i_patient, j_vitals] = torch.tensor(vitals, dtype=torch.long)
            masks2_neg[i_patient, j_vitals] = 1
    
    return x1_pos, masks1_pos, x2_pos, masks2_pos,\
            x1_neg, masks1_neg, x2_neg, masks2_neg, y

In [6]:
mm_dataset = CustomMMDataset(demo_iseqs, demo_vseqs, los_7)

In [7]:
split = int(len(mm_dataset)*0.8)

lengths = [split, len(mm_dataset) - split]
train_dataset, val_dataset = random_split(mm_dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 27577
Length of val dataset: 6895


In [8]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [9]:
loader_iter = iter(train_loader)
x1_pos, masks1_pos, x2_pos, masks2_pos,\
            x1_neg, masks1_neg, x2_neg, masks2_neg, y = next(loader_iter)

### SDPRL

In [11]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        
    NOTE: Do NOT use for loop.

    """
    x_masked = x * masks.unsqueeze(dim=-1)
    # your code here
    return x_masked.sum(dim=2)

In [12]:
def get_last_visit(hidden_states, masks):
    """
    TODO: obtain the hidden state for the last true visit (not padding visits)

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
        
    NOTE: DO NOT use for loop.
    
    HINT: Consider using `torch.gather()`.
    """
    
    # your code here
    first_zero_idx = torch.argmin(masks.sum(dim=2),axis=1)
    last_nonzero_idx = first_zero_idx - 1
    batch_size = hidden_states.shape[0]
    return hidden_states[list(range(batch_size)), last_nonzero_idx]

In [13]:
class SDPRL_RNN(nn.Module):
    
    def __init__(self, num_codes_m1,num_codes_m2):
        super().__init__()

        # your code here
        self.m1_embedding = nn.Embedding(num_codes_m1+1,embedding_dim=128)
        self.m1_rnn = nn.LSTM(128, 128, 1, batch_first=True, bidirectional=True)
        
        self.m2_embedding_num = nn.Linear(in_features=104, out_features=96)
        self.m2_embedding_cat = nn.Embedding(num_codes_m2+1,embedding_dim=32)
        self.m2_rnn = nn.LSTM(128, 128, 1, batch_first=True, bidirectional=True)
        
        self.fc1 = nn.Linear(in_features=256, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)
        
        self.relu = nn.ReLU()

    
    def forward(self, x1_pos, masks1_pos, x2_pos, masks2_pos, x1_neg=None, masks1_neg=None, x2_neg=None, masks2_neg=None):
        batch_size = x1_pos.shape[0]
        
        # FIRST MODAL pos
        # 1. Pass the sequence through the embedding layer;
        x1_pos = self.m1_embedding(x1_pos)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x1_pos = sum_embeddings_with_mask(x1_pos, masks1_pos)
        # 3. Pass the embegginds through the RNN layer;
        output1_pos, (hn1, cn1) = self.m1_rnn(x1_pos)
        # 4. Obtain the hidden state at the last visit.
        true_h_n1_pos = get_last_visit(output1_pos, masks1_pos)
        
        # SECOND MODAL pos
        num_x2_pos = x2_pos[:,:,:-3]
        cat_x2_pos = x2_pos[:,:,-3:]
        num_x2_pos = self.m2_embedding_num(num_x2_pos.float())
        cat_x2_pos = self.m2_embedding_cat(cat_x2_pos)
        cat_x2_pos = cat_x2_pos.sum(dim=2)
        x2_pos = torch.concat([num_x2_pos, cat_x2_pos], dim=-1)
        output2_pos, (hn2, cn2) = self.m2_rnn(x2_pos)
        true_h_n2_pos = get_last_visit(output2_pos, masks2_pos)
        
        # FOR PREDICTION
        # concat hidden stats and pass through 
        true_h_n_pos = torch.maximum(true_h_n1_pos, true_h_n2_pos)
        x = self.relu(self.fc1(true_h_n_pos))
        out_pos = self.fc2(x).view(batch_size)
        
        if x1_neg != None:
            # FIRST MODAL neg
            # 1. Pass the sequence through the embedding layer;
            x1_neg = self.m1_embedding(x1_neg)
            # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
            x1_neg = sum_embeddings_with_mask(x1_neg, masks1_neg)
            # 3. Pass the embegginds through the RNN layer;
            output1_neg, (hn1, cn1) = self.m1_rnn(x1_neg)
            # 4. Obtain the hidden state at the last visit.
            true_h_n1_neg = get_last_visit(output1_neg, masks1_neg)

            # SECOND MODAL pos
            num_x2_neg = x2_neg[:,:,:-3]
            cat_x2_neg = x2_neg[:,:,-3:]
            num_x2_neg = self.m2_embedding_num(num_x2_neg.float())
            cat_x2_neg = self.m2_embedding_cat(cat_x2_neg)
            cat_x2_neg = cat_x2_neg.sum(dim=2)
            x2_neg = torch.concat([num_x2_neg, cat_x2_neg], dim=-1)
            output2_neg, (hn2, cn2) = self.m2_rnn(x2_neg)
            true_h_n2_neg = get_last_visit(output2_neg, masks2_neg)

            # FOR TRAINING/Calculating loss
            # lm1m2
            exp_pos = torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_pos, true_h_n2_pos) / TAU)
            lm1m2_denom = 32 * torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_pos, true_h_n2_neg) / TAU)
            lm1m2 = -torch.log(exp_pos/lm1m2_denom)
            # lm2m1
            lm2m1_denom = 32 * torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_neg, true_h_n2_pos) / TAU)
            lm2m1 = -torch.log(exp_pos/lm2m1_denom)    
            return out_pos, lm1m2, lm2m1
            
            
        
        # prediction mode
        return out_pos

In [14]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score

def eval_model(model, val_loader):
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x1_pos, masks1_pos, x2_pos, masks2_pos,\
        x1_neg, masks1_neg, x2_neg, masks2_neg, y in val_loader:
        
        y_hat = model(x1_pos, masks1_pos, x2_pos, masks2_pos)
        y_score = torch.cat((y_score,  y_hat.detach().to('cpu')), dim=0)
        y_hat = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    """
    TODO:
        Calculate precision, recall, f1, and roc auc scores.
        Use `average='binary'` for calculating precision, recall, and fscore.
    """
    # p, r, f,_ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    p, r, f = None, None, None
    roc_auc = roc_auc_score(y_true, y_score)
    auprc = average_precision_score(y_true, y_score, average='macro')

    return p, r, f, roc_auc, auprc

In [15]:
def train(model, train_loader, val_loader, n_epochs):
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x1_pos, masks1_pos, x2_pos, masks2_pos,\
            x1_neg, masks1_neg, x2_neg, masks2_neg, y in train_loader:

            optimizer.zero_grad()
            y_pred, lm1m2, lm2m1 = model(x1_pos, masks1_pos, x2_pos, masks2_pos,\
                                        x1_neg, masks1_neg, x2_neg, masks2_neg)
            task_loss = task_criterion(y_pred, y)
            loss = lambda1 * task_loss + lambda2 * torch.mean(lm1m2) + lambda3 * torch.mean(lm2m1)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc, auprc = eval_model(model, val_loader)
        print('Epoch: {} \t Validation roc_auc: {:.4f}, auprc: {:.4f}'
              .format(epoch+1,roc_auc,auprc))

LOS 7 TASK

In [16]:
TAU = 0.1
lambda1 = 0.05
lambda2 = 0.05
lambda3 = 1

In [17]:
sdprlrnn = SDPRL_RNN(14+49, 49)

In [18]:
task_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(sdprlrnn.parameters(), lr=0.0001, weight_decay=0.001)

In [None]:
# number of epochs to train the model
n_epochs = 10
train(sdprlrnn, train_loader, val_loader, n_epochs)


In [20]:
p, r, f, roc_auc,auprc = eval_model(sdprlrnn, val_loader)
print(roc_auc)
print(auprc)

0.7897234784855925
0.1920386786378866


ICU MORT TASK

In [30]:
mm_dataset = CustomMMDataset(demo_iseqs, demo_vseqs, mort)

In [31]:
split = int(len(mm_dataset)*0.8)

lengths = [split, len(mm_dataset) - split]
train_dataset, val_dataset = random_split(mm_dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 27577
Length of val dataset: 6895


In [32]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [33]:
sdprlrnn = SDPRL_RNN(14+49, 49)

In [34]:
task_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(sdprlrnn.parameters(), lr=0.0001, weight_decay=0.001)

In [None]:
# number of epochs to train the model
n_epochs = 10
train(sdprlrnn, train_loader, val_loader, n_epochs)


In [36]:
p, r, f, roc_auc,auprc = eval_model(sdprlrnn, val_loader)
print(roc_auc)
print(auprc)

0.8013975051378864
0.2270470319798003


### Transformer

In [69]:
class SDPRL_TF(nn.Module):
    
    def __init__(self, num_codes_m1,num_codes_m2):
        super().__init__()

        # your code here
        self.m1_embedding = nn.Embedding(num_codes_m1+1,embedding_dim=128)
        encoder_layer1 = nn.TransformerEncoderLayer(d_model=128, nhead=2)
        self.transformer_encoder1 = nn.TransformerEncoder(encoder_layer1, num_layers=2)
        
        self.m2_embedding_num = nn.Linear(in_features=104, out_features=96)
        self.m2_embedding_cat = nn.Embedding(num_codes_m2+1,embedding_dim=32)
        encoder_layer2 = nn.TransformerEncoderLayer(d_model=128, nhead=2)
        self.transformer_encoder2 = nn.TransformerEncoder(encoder_layer2, num_layers=2)
        
        self.fc1 = nn.Linear(in_features=128, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)
        
        self.relu = nn.ReLU()

    
    def forward(self, x1_pos, masks1_pos, x2_pos, masks2_pos, x1_neg=None, masks1_neg=None, x2_neg=None, masks2_neg=None):
        batch_size = x1_pos.shape[0]
        
        # FIRST MODAL pos
        # 1. Pass the sequence through the embedding layer;
        x1_pos = self.m1_embedding(x1_pos)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x1_pos = sum_embeddings_with_mask(x1_pos, masks1_pos)
        # 3. Pass the embegginds through the RNN layer;
        output1_pos = self.transformer_encoder1(x1_pos)
        # 4. Obtain the hidden state at the last visit.
        true_h_n1_pos = get_last_visit(output1_pos, masks1_pos)
        
        # SECOND MODAL pos
        num_x2_pos = x2_pos[:,:,:-3]
        cat_x2_pos = x2_pos[:,:,-3:]
        num_x2_pos = self.m2_embedding_num(num_x2_pos.float())
        cat_x2_pos = self.m2_embedding_cat(cat_x2_pos)
        cat_x2_pos = cat_x2_pos.sum(dim=2)
        x2_pos = torch.concat([num_x2_pos, cat_x2_pos], dim=-1)
        output2_pos = self.transformer_encoder2(x2_pos)
        true_h_n2_pos = get_last_visit(output2_pos, masks2_pos)
        
        # FOR PREDICTION
        # concat hidden stats and pass through 
        true_h_n_pos = torch.maximum(true_h_n1_pos, true_h_n2_pos)
        x = self.relu(self.fc1(true_h_n_pos))
        out_pos = self.fc2(x).view(batch_size)
        
        if x1_neg != None:
            # FIRST MODAL neg
            # 1. Pass the sequence through the embedding layer;
            x1_neg = self.m1_embedding(x1_neg)
            # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
            x1_neg = sum_embeddings_with_mask(x1_neg, masks1_neg)
            # 3. Pass the embegginds through the RNN layer;
            output1_neg = self.transformer_encoder1(x1_neg)
            # 4. Obtain the hidden state at the last visit.
            true_h_n1_neg = get_last_visit(output1_neg, masks1_neg)

            # SECOND MODAL pos
            num_x2_neg = x2_neg[:,:,:-3]
            cat_x2_neg = x2_neg[:,:,-3:]
            num_x2_neg = self.m2_embedding_num(num_x2_neg.float())
            cat_x2_neg = self.m2_embedding_cat(cat_x2_neg)
            cat_x2_neg = cat_x2_neg.sum(dim=2)
            x2_neg = torch.concat([num_x2_neg, cat_x2_neg], dim=-1)
            output2_neg = self.transformer_encoder2(x2_neg)
            true_h_n2_neg = get_last_visit(output2_neg, masks2_neg)

            # FOR TRAINING/Calculating loss
            # lm1m2
            exp_pos = torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_pos, true_h_n2_pos) / TAU)
            lm1m2_denom = 32 * torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_pos, true_h_n2_neg) / TAU)
            lm1m2 = -torch.log(exp_pos/lm1m2_denom)
            # lm2m1
            lm2m1_denom = 32 * torch.exp(torch.nn.functional.cosine_similarity(true_h_n1_neg, true_h_n2_pos) / TAU)
            lm2m1 = -torch.log(exp_pos/lm2m1_denom)    
            return out_pos, lm1m2, lm2m1
            
            
        
        # prediction mode
        return out_pos

In [70]:
mm_dataset = CustomMMDataset(demo_iseqs, demo_vseqs, los_7)

In [71]:
split = int(len(mm_dataset)*0.8)

lengths = [split, len(mm_dataset) - split]
train_dataset, val_dataset = random_split(mm_dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 27577
Length of val dataset: 6895


In [72]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [73]:
loader_iter = iter(train_loader)
x1_pos, masks1_pos, x2_pos, masks2_pos,\
            x1_neg, masks1_neg, x2_neg, masks2_neg, y = next(loader_iter)

In [74]:
sdprltf = SDPRL_TF(14+49, 49)

In [75]:
task_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(sdprltf.parameters(), lr=0.0001, weight_decay=0.001)

In [None]:
# number of epochs to train the model
n_epochs = 10
train(sdprltf, train_loader, val_loader, n_epochs)


In [None]:
p, r, f, roc_auc,auprc = eval_model(sdprlrnn, val_loader)
print(roc_auc)
print(auprc)

MORT ICU

In [77]:
mm_dataset = CustomMMDataset(demo_iseqs, demo_vseqs, mort)

In [78]:
split = int(len(mm_dataset)*0.8)

lengths = [split, len(mm_dataset) - split]
train_dataset, val_dataset = random_split(mm_dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 27577
Length of val dataset: 6895


In [79]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [80]:
loader_iter = iter(train_loader)
x1_pos, masks1_pos, x2_pos, masks2_pos,\
            x1_neg, masks1_neg, x2_neg, masks2_neg, y = next(loader_iter)

In [81]:
sdprltf = SDPRL_TF(14+49, 49)

In [82]:
task_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(sdprltf.parameters(), lr=0.0001, weight_decay=0.001)

In [None]:
# number of epochs to train the model
n_epochs = 10
train(sdprltf, train_loader, val_loader, n_epochs)


In [84]:
p, r, f, roc_auc,auprc = eval_model(sdprlrnn, val_loader)
print(roc_auc)
print(auprc)

0.7944453064391002
0.21529256966372448
