# RNN for events

In [1]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Load the pre-processed inputs. Each file is opened in a `with` block so the
# handle is closed as soon as the object has been read.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("events_item.p", "rb") as f:
    events_items = pickle.load(f)
with open("events_value.p", "rb") as f:
    events_values = pickle.load(f)
with open('patients.p', 'rb') as f:
    patients = pickle.load(f)
with open('events_maxcode.p', 'rb') as f:
    max_code = pickle.load(f)

# Sanity checks against the expected snapshot of the data.
assert len(events_items)==174272 and len(events_values)==174272 and len(patients)==9822, "Wrong dataframes?"
assert max_code==126, "MAX CODE changed?"

In [145]:
# Seed every random number generator used in this notebook with one value
# so that runs are repeatable.
seed = 230729
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so this
# presumably only affects processes spawned afterwards - confirm.
os.environ["PYTHONHASHSEED"] = str(seed)
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(seed)

#### Dataset and Dataloader

In [30]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Per-patient dataset of event sequences with a mortality label.

    Each sample is one patient: a (visits x codes) matrix holding the value
    recorded for each event code in each visit, plus the patient's
    `mortality_flag`.
    """

    def __init__(self, patients, events_items, events_values, num_codes=None):
        """
        Group the event rows by patient.

        Arguments:
            patients: frame with `subject_id` and `mortality_flag` columns
            events_items: frame with `subject_id` and `codes` columns
            events_values: frame with `subject_id` and `values` columns
            num_codes: width of the per-visit vector; defaults to the
                notebook-level `max_code`
        """

        # groupby sorts by subject_id, so ids, items and values are all taken
        # from the same grouped index and stay aligned position by position.
        # (`unique()` returns ids in order of appearance, which only matches
        # the grouped order when the frame is already sorted.)
        items_by_subject = events_items.groupby('subject_id')['codes'].apply(list)
        values_by_subject = events_values.groupby('subject_id')['values'].apply(list)

        self.patients = items_by_subject.index.values
        self.y = patients
        self.items = items_by_subject.values
        self.values = values_by_subject.reindex(items_by_subject.index).values
        self.num_codes = max_code if num_codes is None else num_codes

    def __len__(self):
        """
        Return the number of patients.
        """

        return len(self.patients)

    def __getitem__(self, index):
        """
        Generates one sample of data.

        Outputs:
            - subject_id
            - array of visits (visits x num_codes) holding the value of each
              recorded code; NaN values are stored as 0.0
            - mortality flag

        Raises:
            KeyError: if the patient has no row in the `patients` frame.
        """

        visit_codes = self.items[index]
        visit_values = self.values[index]

        events = np.zeros([len(visit_codes), self.num_codes])
        for i, codes in enumerate(visit_codes):
            for j, code in enumerate(codes):
                v = visit_values[i][j]
                events[i, code] = v if not math.isnan(v) else 0.0

        subject_id = int(self.patients[index])

        # Filter the label frame once; the first matching row is used.
        flags = self.y.loc[self.y['subject_id'] == subject_id, 'mortality_flag']
        if flags.empty:
            raise KeyError(f"subject_id {subject_id} not found in patients")
        mortality_flag = int(flags.iloc[0])

        return subject_id, events, mortality_flag

In [31]:
# Build the per-patient dataset from the frames loaded at the top of the notebook.
dataset = CustomDataset(patients, events_items, events_values)

In [130]:
def collate_fn(data):
    """
    Pad a batch of variable-length patients to a common number of visits.

    Arguments:
        data: list of (subject_id, events, mortality_flag) samples, where
              events is a (visits x codes) array

    Outputs:
        - subject ids, int tensor of shape (batch_size)
        - events, long tensor of shape (batch_size, max visits, codes);
          rows past a patient's last visit are zero
        - mask, int tensor of shape (batch_size, max visits); 1 for real
          visits and 0 for padding
        - mortality flags, int tensor of shape (batch_size)
    """
    subject_id, events, mortality_flag = zip(*data)

    maxvisits = max(len(p) for p in events)
    # Take the code dimension from the batch itself rather than a global.
    num_codes = np.asarray(events[0]).shape[1]

    # Fill pre-allocated arrays: building a tensor from a list of separate
    # ndarrays is much slower than converting one contiguous array.
    padded = np.zeros((len(events), maxvisits, num_codes))
    visit_mask = np.zeros((len(events), maxvisits))
    for row, p in enumerate(events):
        padded[row, :len(p)] = p
        visit_mask[row, :len(p)] = 1

    # NOTE(review): .long() truncates fractional event values - confirm this
    # is what the downstream model expects.
    result = torch.tensor(padded).long()
    mask = torch.tensor(visit_mask).int()

    return torch.tensor(subject_id).int(), result, mask, torch.tensor(mortality_flag).int()

In [142]:
# DataLoader is only imported further down the notebook; import it here so
# this cell also works when the file is run top to bottom.
from torch.utils.data import DataLoader

# Smoke test: pull one batch of 10 patients and check the collated shapes.
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
subjects, events, masks, y = next(loader_iter)

assert (
    subjects.shape == torch.Size([10])
    and events.shape == torch.Size([10, 34, 126])
    and masks.shape == torch.Size([10, 34])
    and y.shape == torch.Size([10])
), "Wrong dimensions!"

In [143]:
from torch.utils.data.dataset import random_split

# Random 80/20 split of the patients into training and validation subsets.
n_total = len(dataset)
split = int(n_total * 0.8)
lengths = [split, n_total - split]

train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 7844
Length of val dataset: 1962


In [144]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn, batch_size=32):
    """
    Wrap the training and validation datasets in DataLoaders.

    Arguments:
        train_dataset: dataset used for training; reshuffled every epoch
        val_dataset: dataset used for validation; kept in order
        collate_fn: function that merges a list of samples into a batch
        batch_size: number of samples per batch for both loaders (default 32)

    Outputs:
        train_loader, val_loader
    """

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader


# Build the batched loaders for the train/validation subsets using the padding collate function.
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

#### Naive RNN

In [None]:
class NaiveRNN(nn.Module):
    """
    GRU classifier over a patient's visit sequence.

    Pipeline: embedding -> per-visit sum of embeddings -> GRU -> hidden state
    at the last visit -> linear layer -> sigmoid.
    """

    def __init__(self, num_codes=max_code, emb_size=128):
        """
        Arguments:
            num_codes: number of rows in the embedding table
            emb_size: embedding size, also used as the GRU hidden size
        """
        super().__init__()

        self.embedding = nn.Embedding(num_codes, emb_size)
        self.rnn = nn.GRU(emb_size, hidden_size=emb_size, batch_first=True)
        self.fc1 = nn.Linear(emb_size, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x, masks, rev_x=None, rev_masks=None):
        """
        Steps:
            1. Pass the sequence through the embedding layer;
            2. Sum the embeddings of each visit with `sum_embeddings_with_mask()`;
            3. Pass the visit embeddings through the RNN layer;
            4. Obtain the hidden state at the last visit with `get_last_visit()`;
            5. Pass the hidden state through the linear and activation layers.

        Arguments:
            x: integer tensor of embedding indices, one sequence per patient
            masks: padding masks, passed on to `sum_embeddings_with_mask()`
                   and `get_last_visit()`
            rev_x, rev_masks: unused; accepted (and optional) so that the same
                   training and validation functions can drive models that do
                   need reversed inputs

        Outputs:
            probs: probabilities of shape (batch_size)

        NOTE(review): `sum_embeddings_with_mask` and `get_last_visit` are not
        defined in this part of the notebook, so the exact shapes they expect
        for `x` and `masks` should be confirmed against their definitions.
        """

        emb = self.embedding(x)
        visits_emb = sum_embeddings_with_mask(emb, masks)

        # The GRU returns (per-step outputs, final hidden state); only the
        # per-step outputs are needed to pick each patient's last real visit.
        rnn_outputs, _ = self.rnn(visits_emb)
        last_visit_hidden_state = get_last_visit(rnn_outputs, masks)

        logits = self.fc1(last_visit_hidden_state)
        return self.sig(logits).flatten()
    

# load the model here
# `types` is not defined anywhere in this notebook; size the embedding table
# with `max_code`, which is also the class default.
naive_rnn = NaiveRNN(num_codes=max_code)
naive_rnn