In [4]:
#%%capture
!pip install pyhealth torch polars numpy
import polars as pl
import pandas as pd
import os
import sys
import random
import pickle
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
# set seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

Collecting pyhealth
  Using cached pyhealth-1.1.3-py2.py3-none-any.whl (113 kB)
Collecting torch
  Using cached torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl (619.9 MB)
Collecting rdkit>=2022.03.4 (from pyhealth)
  Using cached rdkit-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
Collecting scikit-learn>=0.24.2 (from pyhealth)
  Using cached scikit_learn-1.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.6 MB)
Collecting networkx>=2.6.3 (from pyhealth)
  Using cached networkx-3.1-py3-none-any.whl (2.1 MB)
Collecting tqdm (from pyhealth)
  Using cached tqdm-4.65.0-py3-none-any.whl (77 kB)
Collecting sympy (from torch)
  Using cached sympy-1.11.1-py3-none-any.whl (6.5 MB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch)
  Using cached nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch)
  Using cached nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x8

In [82]:
import sys
obj = pickle.load(open(os.path.join('data/', 'allevents_by_episode'), 'rb'))
obj.estimated_size()/(1024**2)
patients = list(range(len(obj)))
seqs = obj.select(pl.col('uniq_itemid').cast(pl.List(pl.List(pl.Int32)))).to_series().to_list()
mortality = obj.select(pl.col('mortality_tf')).to_series().to_list()


In [27]:
itemset = list(set(list(set([each_j for i in seqs for each_i in i for each_j in each_i]))))
code2idx = {itemset[i]: i for i in range(len(itemset))}
idx2code = {i: itemset[i] for i in range(len(itemset))}

In [83]:
# Customer dataset
from torch.utils.data import Dataset, DataLoader, random_split
class CustomDataset(Dataset):
    def __init__(self, patients, seqs, mortality):
        self.patients = patients
        self.seqs = seqs
        self.labels = mortality
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, index):
        return self.patients[index], self.seqs[index], self.labels[index]
dataset = CustomDataset(patients, seqs, mortality)
assert len(dataset) == len(obj) #TODO write test separately 

In [117]:
# Collate function
def collate_fn(data):
    patients, seqs, labels = zip(*data)
    num_patients = len(patients)
    max_num_events = max([len(event) for event in seqs])
    max_num_items = max([len(itemid) for event in seqs for itemid in event])
    tensor_shape = (num_patients, max_num_events, max_num_items)
    x =        torch.zeros(tensor_shape, dtype=torch.long)
    rev_x =    torch.zeros(tensor_shape, dtype=torch.long)
    masks =    torch.zeros(tensor_shape, dtype=torch.bool)
    rev_masks = torch.zeros(tensor_shape, dtype=torch.bool) 
    y =        torch.tensor(labels, dtype=torch.long)
    

    for i_patient, events in enumerate(seqs):
        for i_event, item in enumerate(events):
            padded_item = torch.concat([torch.tensor(item),
                                        torch.zeros(max_num_items - len(item))]).long()
            x[i_patient, i_event, :] = padded_item
            masks[i_patient, i_event, :] = torch.where(padded_item!=0,1,0)  
    for i_patient, events in enumerate(seqs):
        idx_all_real_events = torch.sum(x[i_patient, :, :], dim=(1))!= 0
        idx_padded_events =torch.sum(x[i_patient, :, :], dim=(1))== 0
        fliped = torch.flip(x[i_patient, idx_all_real_events, :].unsqueeze(1), (0,)).squeeze(1)
        rev_x[i_patient, :, :] = torch.concat((fliped, x[i_patient, idx_padded_events, :] ))
        rev_masks[i_patient, :, :] = torch.where(rev_x[i_patient, :, :] != 0, True, False)
    return x, masks, rev_x, rev_masks, y

    

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)
assert x.shape == masks.shape == (10, 8, 313)
assert y.shape == (10,)

In [116]:
train, val, test = int(len(dataset)*0.8), int(len(dataset)*0.1), len(dataset) - int(len(dataset)*0.8) -  int(len(dataset)*0.1)
lengths = [train, val, test]
train_dataset, val_dataset, test_dataset = random_split(dataset=dataset, lengths=lengths)

def load_data(train_dataset, val_dataset, test_dataset, collate_fn):
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset, collate_fn)

In [None]:
def sum_embeddings_with_masks(x, masks):
    """
    x    (batch_size, #item, embedding_dim)
    mask (batch_size, #item)
    return (batch_size, embedding_dim)
    """
    mask = mask.unsqueeze(-1).expand(x.shape[0], x.shape[1], x.shape[2])
    return torch.sum(mask * x, 1)
def get_last_item(hidden_states, masks):
    """
    hidden_states: (batch_size, #item, embedding_dim)
    masks:         (batch_size, #item, embedding_dim)
    return last_hidden_state: (batch_size, embedding_dim)
    """
    #print(torch.sum(masks, 1))
    idx_last_item = torch.argmin(torch.sum(masks, 2), 1)
    print(idx_last_item)
    return hidden_states[:, 
                         torch.where(idx_last_item - 1 < 0, max(idx_last_item), idx_last_item - 1),
                         :]


In [None]:
max_num_items = 5
batch_size = 3
embedding_dims = 4
torch.random.manual_seed(12345)
hidden_states = torch.randn((batch_size, max_num_items, embedding_dims))
#print(hidden_states)
masks = torch.ones_like(hidden_states)
masks[:,3:,:] = 0 
masks[0,2:,:] = 0 
masks[2,1:,:] = 0
masks[0,:,2:] = 0
masks[1,:,3:] = 0
masks = masks.bool()
#print(masks)
out = get_last_item(hidden_states, masks)
out

tensor([2, 3, 1])


tensor([[[-1.0148e+00, -5.4286e-01,  4.3074e-01, -1.9257e+00],
         [ 1.2756e+00, -1.1316e+00,  8.6800e-01,  7.0788e-01],
         [ 2.0585e-01, -9.3001e-01,  1.1425e-01, -4.4503e-01],
         [-8.5306e-01, -8.4074e-01, -3.9633e-01, -2.5913e-01],
         [-6.7731e-01,  7.0912e-02, -4.5838e-01,  1.6847e+00]],

        [[ 1.4235e-01,  6.4272e-01, -7.0122e-01,  1.0413e+00],
         [ 1.5445e+00,  1.1718e+00, -3.8031e-01,  1.7336e+00],
         [-4.5109e-01, -8.9362e-01, -2.7579e-01,  7.6457e-01],
         [-1.3222e+00, -2.5249e-01, -2.0878e+00, -5.7322e-01],
         [-8.1685e-01, -4.7587e-01,  8.2872e-01, -1.6278e-01]],

        [[-1.4798e+00,  4.8731e-01, -3.0128e+00,  4.4386e-01],
         [ 3.5976e-01, -1.2348e-02,  2.1852e-01, -1.2815e+00],
         [ 2.4112e+00,  1.9991e+00,  7.8479e-01, -1.0195e+00],
         [-2.1058e-01,  6.2684e-01,  9.3176e-01,  1.8675e-01],
         [ 9.5893e-01, -1.1371e+00,  1.3051e-03,  1.3174e+00]]])

tensor([[[False, False, False,  ...,  True, False, False],
         [False, False,  True,  ..., False, False, False],
         [ True, False,  True,  ..., False, False,  True],
         ...,
         [False,  True, False,  ...,  True,  True,  True],
         [False,  True, False,  ..., False, False, False],
         [False, False,  True,  ..., False,  True, False]],

        [[False,  True, False,  ..., False, False, False],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True, False, False],
         ...,
         [ True, False, False,  ..., False, False, False],
         [False, False, False,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True, False, False]],

        [[False, False,  True,  ..., False, False, False],
         [ True,  True,  True,  ...,  True, False, False],
         [ True,  True,  True,  ...,  True, False, False],
         ...,
         [False,  True, False,  ..., False,  True,  True],
         [