# RNN for events

In [178]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    The file is opened in a ``with`` block so the handle is closed as soon
    as the object has been read, instead of being left for the garbage
    collector.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point this
    # at files produced by our own preprocessing step.
    with open(path, "rb") as handle:
        return pickle.load(handle)


events_items = _load_pickle("events_item.p")
events_values = _load_pickle("events_value.p")
patients = _load_pickle("patients.p")
max_code = _load_pickle("events_maxcode.p")

# Sanity checks: fail fast if the pickles on disk are not the expected ones.
assert len(events_items)==174272 and len(events_values)==174272 and len(patients)==9822, "Wrong dataframes?"
assert max_code==126, "MAX CODE changed?"

In [179]:
# Pin every random number generator used here to one seed so that
# reruns are repeatable.
# MAX_CODE repeats the value that max_code is asserted to equal at load time.
MAX_CODE = 126
seed = 230729
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here only reaches child processes, not hashing in this session.
os.environ["PYTHONHASHSEED"] = f"{seed}"
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [180]:
#events_items['daystodischarge'].max()

In [208]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Per-patient event sequences paired with a mortality label.

    One sample corresponds to one ``subject_id``. The rows of
    ``events_items`` and ``events_values`` are grouped by ``subject_id``;
    each row becomes one time step whose ``codes`` list picks the columns
    to fill and whose ``values`` list supplies the numbers written there.
    """

    def __init__(self, patients, events_items, events_values, num_codes=None):
        """
        Group the event tables by patient.

        Inputs:
            - patients: DataFrame with 'subject_id' and 'mortality_flag'
            - events_items: DataFrame with 'subject_id' and 'codes'
            - events_values: DataFrame with 'subject_id' and 'values'
            - num_codes: width of every event row; when None the
              module-level ``max_code`` is used (the previous behaviour)

        Raises:
            ValueError: if the two event tables cover different patients.
        """

        grouped_items = events_items.groupby('subject_id')['codes'].apply(list)
        grouped_values = events_values.groupby('subject_id')['values'].apply(list)

        # Both groupbys sort by subject_id; make sure they describe the same
        # patients so position i means the same patient in both arrays.
        if not grouped_items.index.equals(grouped_values.index):
            raise ValueError(
                "events_items and events_values cover different subject_ids"
            )

        # Take the ids from the grouped index rather than from unique():
        # unique() keeps order of appearance, which only matches the sorted
        # groupby output when the input happens to be sorted already.
        self.patients = grouped_items.index.to_numpy()
        self.y = patients
        self.items = grouped_items.values
        self.values = grouped_values.values
        # NOTE(review): codes index columns directly, so every code must be
        # < num_codes - confirm whether codes are 0- or 1-based.
        self.num_codes = max_code if num_codes is None else num_codes


    def __len__(self):

        """
        Return the number of patients.
        """

        return len(self.patients)


    def __getitem__(self, index):

        """
        Generates one sample of data.

        Outputs:
            - subject_id
            - tensor of visits, multi-hot items values
              (one row per event row, ``num_codes`` columns, NaN stored as 0)
            - mortality flag

        Raises:
            ValueError: if ``patients`` does not hold exactly one row for
            the requested subject.
        """

        step_codes = self.items[index]
        step_values = self.values[index]
        events = torch.zeros(len(step_codes), self.num_codes)

        for i, codes in enumerate(step_codes):
            for j, code in enumerate(codes):
                v = step_values[i][j]
                events[i, code] = v if not math.isnan(v) else 0.0

        patient_id = self.patients[index]
        # Mask the label frame once and read the scalar with .iloc[0];
        # calling int() on a one-element Series is deprecated in pandas.
        flags = self.y.loc[self.y['subject_id'] == patient_id, 'mortality_flag']
        if len(flags) != 1:
            raise ValueError(
                f"expected exactly one patients row for subject_id "
                f"{patient_id}, found {len(flags)}"
            )

        subject_id = int(patient_id)
        mortality_flag = int(flags.iloc[0])

        return subject_id, events, mortality_flag

In [209]:
# Build the dataset from the three tables loaded at the top of the notebook.
dataset = CustomDataset(
    patients=patients,
    events_items=events_items,
    events_values=events_values,
)

In [211]:
# Show the first sample; __getitem__ returns (subject_id, events, mortality_flag).
dataset[0]

(13,
 tensor([[  0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 155.0000,   0.0000,
            0.5000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
            0.0000,   0.0000,   0.0