# RNN for mortality prediction from patient events

In [171]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(path, "rb") as fh:
        return pickle.load(fh)


events_items = _load_pickle("events_item.p")
events_values = _load_pickle("events_value.p")
patients = _load_pickle('patients.p')
# +1 so that every stored code can be used as a column index into a
# vector of width max_code (see the dataset's np.zeros([..., max_code])).
max_code = _load_pickle('events_maxcode.p') + 1

# Sanity checks that the expected preprocessed data was loaded.
assert len(events_items)==174272 and len(events_values)==174272 and len(patients)==9822, "Wrong dataframes?"
assert max_code==127, "MAX CODE changed?"

In [172]:
# Seed every random number generator used in this notebook so runs repeat.
seed = 230729
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this does not
# change hash randomization of the already-running process; it only reaches
# Python subprocesses launched afterwards.
os.environ["PYTHONHASHSEED"] = str(seed)

# Maximum number of patients to keep; 0 keeps all of them.
PATIENTS = 10000  # 0 - ALL

#### Dataset and Dataloader

In [173]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Per-patient event sequences paired with a mortality label.

    Every row of ``events_items`` for a subject becomes one time step of that
    subject's sample; the row's codes index into a dense vector of width
    ``num_codes`` that holds the matching entries of ``events_values``
    (NaN values are written as 0.0).
    """

    def __init__(self, patients, events_items, events_values, num_codes=None):
        """
        Args:
            patients: DataFrame with ``subject_id`` and ``mortality_flag``
                columns; used to look up each subject's label.
            events_items: DataFrame with ``subject_id`` and ``codes`` columns;
                each ``codes`` cell is a sequence of integer column indices.
            events_values: DataFrame with ``subject_id`` and ``values``
                columns; assumed row-aligned with ``events_items`` so that
                ``values[j]`` belongs to ``codes[j]`` — TODO confirm upstream.
            num_codes: width of each time-step vector; defaults to the
                module-level ``max_code``.
        """
        # Group once and take the subject order from the groupby result itself,
        # so self.patients[i], self.items[i] and self.values[i] always describe
        # the same subject. (Series.unique() keeps first-appearance order while
        # groupby sorts its keys, so the two disagree whenever the events are
        # not already sorted by subject_id.)
        items_by_subject = events_items.groupby('subject_id')['codes'].apply(list)
        values_by_subject = (
            events_values.groupby('subject_id')['values'].apply(list)
            .loc[items_by_subject.index]
        )

        self.patients = items_by_subject.index.to_numpy()
        self.y = patients
        self.items = items_by_subject.values
        self.values = values_by_subject.values
        self.num_codes = max_code if num_codes is None else num_codes
        # subject_id -> mortality flag, so __getitem__ does an O(1) lookup
        # instead of scanning the whole patients frame for every sample.
        self._mortality = dict(zip(patients['subject_id'], patients['mortality_flag']))

    def __len__(self):
        """
        Return the number of patients that have at least one event.
        """
        return len(self.patients)

    def __getitem__(self, index):
        """
        Build one sample.

        Returns:
            - subject_id (int)
            - events: float64 array of shape (num_event_rows, num_codes) with
              each row's values placed at its code positions
            - mortality_flag (int)
        """
        visits = self.items[index]
        visit_values = self.values[index]

        events = np.zeros([len(visits), self.num_codes])
        for i, codes in enumerate(visits):
            row_values = visit_values[i]
            for j, code in enumerate(codes):
                v = row_values[j]
                events[i, code] = 0.0 if math.isnan(v) else v

        subject_id = int(self.patients[index])
        mortality_flag = int(self._mortality[subject_id])
        return subject_id, events, mortality_flag

In [174]:
# Optionally cap the cohort, then keep only events of the retained patients.
if PATIENTS > 0:
    patients = patients[:PATIENTS]

kept_ids = patients['subject_id']
events_items = events_items[events_items['subject_id'].isin(kept_ids)]
events_values = events_values[events_values['subject_id'].isin(kept_ids)]

# The dataset is indexed by subjects present in events_items, so patients
# without any events are not part of it and its length can be smaller.
dataset = CustomDataset(patients, events_items, events_values)

print("Patients:", len(patients))
print("Len of dataset:", len(dataset))

Patients: 9822
Len of dataset: 9806


In [175]:
def collate_fn(data):
    """Pad a batch of per-patient event matrices to a common length.

    Args:
        data: sequence of ``(subject_id, events, mortality_flag)`` samples,
            where ``events`` is a 2-D array of shape ``(num_steps, num_codes)``.

    Returns:
        subject_ids: int32 tensor ``(batch,)``.
        events: float32 tensor ``(batch, max_steps, num_codes)``, zero-padded
            after each patient's last step.
        masks: int32 tensor ``(batch, max_steps)``; 1 for real steps, 0 for padding.
        mortality: float32 tensor ``(batch,)``.
    """
    subject_id, events, mortality_flag = zip(*data)

    batch_size = len(events)
    max_steps = max(len(p) for p in events)
    num_codes = np.shape(events[0])[1]

    # Fill pre-allocated buffers instead of passing a Python list of arrays to
    # torch.tensor, which is very slow and emits a PyTorch UserWarning.
    padded = np.zeros((batch_size, max_steps, num_codes))
    mask = np.zeros((batch_size, max_steps))
    for i, p in enumerate(events):
        padded[i, :len(p)] = p
        mask[i, :len(p)] = 1

    return (
        torch.tensor(subject_id).int(),
        torch.from_numpy(padded).float(),
        torch.from_numpy(mask).int(),
        torch.tensor(mortality_flag).float(),
    )

In [177]:
from torch.utils.data import DataLoader

# Smoke-test the dataset + collate function on the first batch.
probe_batch_size = 10
loader = DataLoader(dataset, batch_size=probe_batch_size, collate_fn=collate_fn)
loader_iter = iter(loader)
subjects, events, masks, y = next(loader_iter)

# Shape sanity check. The step axis depends on the longest record in the
# batch, so only its agreement between `events` and `masks` is checked.
expected_batch = min(probe_batch_size, len(dataset))
assert subjects.shape == torch.Size([expected_batch]) and y.shape == torch.Size([expected_batch]), "Wrong dimensions!"
assert events.dim() == 3 and events.shape[0] == expected_batch and events.shape[2] == max_code, "Wrong dimensions!"
assert masks.shape == events.shape[:2], "Wrong dimensions!"

In [178]:
from torch.utils.data.dataset import random_split

# 80/20 train/validation split. random_split draws from torch's global RNG,
# which torch.manual_seed(seed) fixed earlier, so the split is repeatable.
n_total = len(dataset)
split = int(n_total * 0.8)
lengths = [split, n_total - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 7844
Length of val dataset: 1962


In [179]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn, batch_size=32):
    """Wrap the train/validation datasets in DataLoaders.

    Args:
        train_dataset: dataset used for training; reshuffled every epoch.
        val_dataset: dataset used for evaluation; iterated in fixed order.
        collate_fn: function that turns a list of samples into a batch.
        batch_size: samples per batch for both loaders (default 32).

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, val_loader


# Build the shuffled training loader and the fixed-order validation loader.
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn=collate_fn)

#### Naive RNN

In [180]:
class NaiveRNN(nn.Module):
    """GRU classifier over padded sequences of per-step code vectors.

    The GRU hidden states at all real (unmasked) time steps are summed, then
    passed through a linear layer and a sigmoid, giving one probability per
    sequence in the batch.
    """

    def __init__(self, num_codes=max_code, emb_size=128):
        super().__init__()
        # The code vectors go straight into the GRU (there is no embedding
        # layer), so the GRU input size is the number of codes.
        self.rnn = nn.GRU(num_codes, hidden_size=emb_size, batch_first=True)
        self.fc1 = nn.Linear(emb_size, 1)
        self.sig = nn.Sigmoid()

    def forward(self, events, masks):
        """Return per-sequence probabilities.

        Args:
            events: batch-first input to the GRU, last dimension ``num_codes``
                — presumably (batch, steps, num_codes) from collate_fn.
            masks: (batch, steps) tensor, 1 at real steps and 0 at padding.
        """
        hidden, _ = self.rnn(events)
        # Zero the hidden states at padded steps before pooling; the mask
        # broadcasts across the hidden dimension.
        pooled = (hidden * masks.unsqueeze(-1)).sum(dim=1)
        logits = self.fc1(pooled)
        return self.sig(logits).flatten()
    

# Instantiate the model with its default sizes and display its structure.
naive_rnn = NaiveRNN(num_codes=max_code, emb_size=128)
naive_rnn

NaiveRNN(
  (rnn): GRU(127, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=1, bias=True)
  (sig): Sigmoid()
)

### Training and evaluation

In [181]:
# The model already applies a sigmoid, so plain BCELoss on probabilities.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=naive_rnn.parameters(), lr=1e-3)


In [182]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):
    """Evaluate a binary classifier on a validation loader.

    Args:
        model: module called as ``model(x, masks)`` returning one probability
            per sample.
        val_loader: iterable of ``(subjects, x, masks, y)`` batches.

    Returns:
        ``(precision, recall, f1, roc_auc)``. Precision/recall/F1 use a 0.5
        threshold on the predicted probabilities; ROC-AUC uses the raw scores.
    """
    was_training = model.training
    model.eval()
    Y_pred = []
    Y_test = []
    try:
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for _subjects, x, masks, y in val_loader:
                outputs = model(x, masks)
                Y_pred.append(outputs.detach().cpu().numpy())
                Y_test.append(y.numpy())
    finally:
        # Restore the caller's mode so a training loop calling this between
        # epochs is not left in eval mode.
        model.train(was_training)

    Y_pred = np.concatenate(Y_pred, axis=0)
    Y_test = np.concatenate(Y_test, axis=0)
    roc_auc = roc_auc_score(Y_test, Y_pred)
    Y_pred_labels = np.where(Y_pred >= 0.5, 1.0, 0.0)
    precision, recall, f1, _support = precision_recall_fscore_support(
        Y_test, Y_pred_labels, average='binary'
    )
    return precision, recall, f1, roc_auc

In [183]:
def train(model, train_loader, val_loader, n_epochs, opt=None, loss_fn=None):
    """Train ``model`` and print validation metrics after every epoch.

    Args:
        model: module called as ``model(events, masks)``.
        train_loader: iterable of ``(subjects, events, masks, target)`` batches.
        val_loader: loader passed to ``eval_model`` after each epoch.
        n_epochs: number of passes over ``train_loader``.
        opt: optimizer to step; defaults to the module-level ``optimizer``.
        loss_fn: loss called as ``loss_fn(output, target)``; defaults to the
            module-level ``criterion``.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    for epoch in range(n_epochs):
        # eval_model() calls model.eval(); switch back to train mode at the
        # start of every epoch, not just before the first one.
        model.train()
        for _subjects, events, masks, target in train_loader:
            opt.zero_grad()
            output = model(events, masks)
            loss = loss_fn(output, target)
            loss.backward()
            opt.step()

        precision, recall, f1, roc_auc = eval_model(model, val_loader)
        print(f"Epoch {epoch}: precision={precision} recall={recall}, f1={f1}, roc_auc={roc_auc}")
    

    
# Train for a fixed number of epochs, reporting validation metrics after each.
n_epochs = 5
train(naive_rnn, train_loader, val_loader, n_epochs=n_epochs)

Epoch 0: precision=0.5360824742268041 recall=0.22510822510822512, f1=0.3170731707317073, roc_auc=0.6239393189133223
Epoch 1: precision=0.5289256198347108 recall=0.27705627705627706, f1=0.36363636363636365, roc_auc=0.7423854789539365
Epoch 2: precision=0.5294117647058824 recall=0.38961038961038963, f1=0.4488778054862843, roc_auc=0.7797672191086402
Epoch 3: precision=0.46308724832214765 recall=0.5974025974025974, f1=0.5217391304347827, roc_auc=0.8328219056122004
Epoch 4: precision=0.624 recall=0.33766233766233766, f1=0.4382022471910112, roc_auc=0.7695024020847244
