<a href="https://colab.research.google.com/github/balaksuiuc/CS598IQVIAClaims/blob/main/src/cost_prediction_RNN_embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook trains a GRU-based RNN on embedded per-visit code sequences to predict patient treatment costs.

In [10]:
import pandas, numpy
import urllib.request
import os, datetime
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
import sklearn

import os
import pickle
import random
import numpy  as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

# Colab-only: mount Google Drive under /content/drive (the data directory used
# later lives below this mount point). force_remount=True asks for a fresh mount.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Print the interpreter version so it is recorded with the notebook outputs.
import sys
print(sys.version_info)

Mounted at /content/drive
sys.version_info(major=3, minor=7, micro=10, releaselevel='final', serial=0)


1.1 Load preprocessed data

In [11]:
# set seed
# Seed Python, NumPy and PyTorch so the split, shuffling and weight
# initialisation are repeatable from a fresh kernel.
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it here
# does not change hashing in the running kernel (only in child processes).
os.environ["PYTHONHASHSEED"]=str(seed)

DATA_DIR = "/content/drive/MyDrive/iqvia_data/bala_seq_data/"
DATA_PATH = DATA_DIR


def _load_pickle(file_name, data_path=DATA_PATH):
    """Unpickle ``file_name`` from ``data_path``, closing the file afterwards.

    SECURITY: ``pickle.load`` can run arbitrary code while deserialising.
    Only use this on files that you produced yourself or otherwise trust.
    """
    with open(os.path.join(data_path, file_name), 'rb') as handle:
        return pickle.load(handle)


pids  = _load_pickle('pids.pkl')
morts = _load_pickle('morts.pkl')  # this is a list of floats
seqs  = _load_pickle('seqs.pkl')

# seqs is flattened twice below, i.e. it is treated as patient -> visit -> code.
# num_types is the number of distinct codes across all patients.
# NOTE(review): the model sizes its embedding table as num_types + 1, which only
# covers every code if codes are integers in [0, num_types] -- confirm upstream.
num_types = len(numpy.unique(list(
    itertools.chain.from_iterable(itertools.chain.from_iterable(seqs))
)))

1.2 Define custom dataset and collate_fn

In [15]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Map-style dataset that pairs each input sequence with its target value."""

    def __init__(self, seqs, morts):
        # x: per-sample input sequences, y: the matching targets (same order).
        self.x, self.y = seqs, morts

    def __len__(self):
        """Number of samples, taken from the stored input sequences."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(sequence, target)`` tuple stored at ``index``."""
        sample, target = self.x[index], self.y[index]
        return sample, target
    
## collate_fn
def collate_fn(data):
    """Pad a batch of variable-length patient records into dense tensors.

    Arguments:
        data: list of ``(sequence, label)`` samples. Each sequence is iterated
            as a list of visits and each visit as a list of codes; the codes
            are written into a ``torch.long`` tensor.

    Returns:
        x: long tensor of shape (# patients, max # visits, max # codes per
            visit); unused positions are left at 0.
        masks: bool tensor of the same shape, True where ``x`` holds a real code.
        rev_x: like ``x`` but with each patient's real visits in reverse order
            (padding visits stay at the end).
        rev_masks: ``masks`` reversed in the same way.
        y: float tensor of shape (# patients,) holding the labels.
    """
    sequences, labels = zip(*data)
    y = torch.tensor(labels, dtype=torch.float)

    num_patients = len(sequences)
    # default=0 lets a patient without visits (or a batch of such patients)
    # be padded instead of raising ValueError from max() on an empty sequence.
    max_num_visits = max((len(patient) for patient in sequences), default=0)
    max_num_diagcodes = max(
        (len(visit) for patient in sequences for visit in patient), default=0
    )

    shape = (num_patients, max_num_visits, max_num_diagcodes)
    x = torch.zeros(shape, dtype=torch.long)
    masks = torch.zeros(shape, dtype=torch.bool)
    rev_x = torch.zeros(shape, dtype=torch.long)
    rev_masks = torch.zeros(shape, dtype=torch.bool)

    for i_patient, patient in enumerate(sequences):
        n_visits = len(patient)
        for j_visit, visit in enumerate(patient):
            n_codes = len(visit)
            if n_codes == 0:
                continue
            # One slice write per visit instead of one tensor write per code.
            x[i_patient, j_visit, :n_codes] = torch.tensor(list(visit), dtype=torch.long)
            masks[i_patient, j_visit, :n_codes] = True
        if n_visits:
            # Flip only the real visits so padding remains at the tail.
            rev_x[i_patient, :n_visits] = torch.flip(x[i_patient, :n_visits], [0])
            rev_masks[i_patient, :n_visits] = torch.flip(masks[i_patient, :n_visits], [0])

    return x, masks, rev_x, rev_masks, y

<__main__.CustomDataset object at 0x7f87e01fc250>


1.3 Dataset and train/validation split

In [16]:
## Dataset: wrap the loaded sequences and targets
from torch.utils.data import DataLoader
dataset = CustomDataset(seqs, morts)
print(dataset)

# Single-sample loader over the full dataset; in the visible cells it is only
# referenced by the disabled scratch snippets (here and in the next cell).
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)    
# Disabled scratch code: pull the first ten batches to inspect collate_fn output.
#dataset_loader = iter(loader)
#for i in range(0,10):
#    x,masks,rev_x,rev_masks,y=next(dataset_loader)

## 80/20 train/validation split
# No generator is passed to random_split, so the partition is drawn from
# torch's global RNG state.
from torch.utils.data.dataset import random_split
split = int(len(dataset)*0.8)
# The second length is the remainder, so the two lengths always sum to len(dataset).
lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)


<__main__.CustomDataset object at 0x7f87e0187c90>


1.4 Build train and validation DataLoaders

In [17]:
## split into train and val datasets
from torch.utils.data import DataLoader
def load_data(train_dataset, val_dataset, collate_fn, batch_size=32):
    """Build the training and validation ``DataLoader`` objects.

    Arguments:
        train_dataset: dataset used for training; batches are reshuffled
            every epoch.
        val_dataset: dataset used for validation; iterated in stored order.
        collate_fn: function that turns a list of samples into a batch.
        batch_size: number of samples per batch for both loaders (default 32).

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True
    )
    # Validation order is kept deterministic (shuffle=False is the default).
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False
    )
    return train_loader, val_loader

# Build the loaders from the train/validation split using the padding collate_fn.
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

# NOTE(review): the triple-quoted block below is disabled scratch code kept as a
# string literal, not a comment. As the last expression of the cell it is echoed
# as the cell output. Its loop unpacks five names per batch from `loader`.
'''
# testing to see how iterator works
i = 0
for step, batch,a,b,c in loader:
    print('i=',i)
    i = i + 1
''' 

"\n# testing to see how iterator works\ni = 0\nfor step, batch,a,b,c in loader:\n    print('i=',i)\n    i = i + 1\n"

2.1 RNN model

In [18]:
def sum_embeddings_with_mask(x, masks):
    """Sum embeddings over the second-to-last axis, zeroing masked-out entries.

    Arguments:
        x: embedding tensor; its second-to-last axis is summed away.
        masks: tensor with one axis fewer than ``x`` (no embedding axis);
            entries that are False/0 remove the matching embedding from the sum.

    Returns:
        Tensor with the shape of ``x`` minus its second-to-last axis.
    """
    # unsqueeze (not the in-place unsqueeze_) so the caller's mask tensor keeps
    # its shape after this call.
    code_weights = masks.unsqueeze(-1).expand(x.shape)
    return torch.sum(x * code_weights, dim=-2)

def get_last_visit(hidden_states, masks):
    """Pick, for every patient, the hidden state of the last real visit.

    A visit counts as real when at least one of its mask entries is set. The
    last real visit is taken to be at index ``(# real visits) - 1``, i.e. real
    visits are assumed to come before any padding visits.

    Arguments:
        hidden_states: tensor indexed as (batch, visit, hidden).
        masks: tensor whose first two axes are (batch, visit); every remaining
            axis is reduced, so a mask with an extra trailing singleton axis is
            accepted as well.

    Returns:
        Tensor of shape (batch, hidden). The batch axis is kept even when the
        batch holds a single patient.

    NOTE(review): a patient without any real visit gives index -1, which
    ``torch.gather`` is expected to reject -- confirm such rows cannot occur.
    """
    batch_size, num_visits = masks.shape[0], masks.shape[1]
    visit_is_real = masks.reshape(batch_size, num_visits, -1).sum(dim=-1) > 0
    last_true_visits = visit_is_real.sum(dim=1) - 1
    # Gather one time step per patient instead of expanding the index over
    # every visit and discarding all but the first row.
    index = last_true_visits.view(-1, 1, 1).expand(-1, 1, hidden_states.shape[-1])
    # Plain indexing (no squeeze) so batch size 1 / hidden size 1 keep 2-D shape.
    return torch.gather(hidden_states, dim=1, index=index)[:, 0, :]

class BaseRNN(nn.Module):
    """Embedding -> per-visit sum -> GRU -> linear head producing one value per patient.

    Arguments:
        num_codes: number of distinct codes; the embedding table gets
            ``num_codes + 1`` rows.
        emb_dim_size: width of each code embedding and of the GRU input
            (default 32).
        hidden_dim_size: GRU hidden size and input width of the linear head
            (default 32).
    """

    def __init__(self, num_codes, emb_dim_size=32, hidden_dim_size=32):
        super().__init__()
        self.embDimSize = emb_dim_size
        self.hiddenDimSize = hidden_dim_size
        self.embedding = nn.Embedding(num_embeddings=num_codes + 1, embedding_dim=self.embDimSize)
        # Layer sizes follow the configured dimensions instead of a literal 32,
        # so changing the embedding width cannot desynchronise the GRU input.
        self.rnn = nn.GRU(input_size=self.embDimSize, hidden_size=self.hiddenDimSize, batch_first=True)
        self.fc = nn.Linear(in_features=self.hiddenDimSize, out_features=1)

    def forward(self, x, masks, rev_x, rev_masks):
        """Return a 1-D tensor with one prediction per patient in the batch.

        ``rev_x`` and ``rev_masks`` are accepted so the model can be called with
        everything the collate function returns, but they are not used here.
        """
        embedding = self.embedding(x)
        # Collapse the codes of each visit into a single vector.
        sum_embedding = sum_embeddings_with_mask(embedding, masks)
        output, _ = self.rnn(sum_embedding)
        # Use the GRU output at each patient's last real visit.
        last_hidden = get_last_visit(output, masks)
        # No output activation: the head emits an unbounded real value.
        return self.fc(last_hidden).view(-1)



2.2 Loss, optimizer and model evaluation

In [19]:
## Instantiate the model; the vocabulary size comes from the loaded sequences.
model = BaseRNN(num_codes = num_types)
print(model)

## loss and optimizer
# L1 loss is the mean absolute error between predictions and targets.
criterion = nn.L1Loss()
# train() reads `optimizer` and `criterion` as module-level globals.
optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)

## eval_model
from sklearn.metrics import *
def regression_metrics(Y_pred, Y_True):
    """Score predictions against the true targets.

    Arguments:
        Y_pred: predicted values.
        Y_True: ground-truth values, in the same order as ``Y_pred``.

    Returns:
        ``(mae, r2)``: the mean absolute error and the R squared score.

    Only these two metrics are computed. Pearson's r, Spearman's correlation
    and Cumming's Prediction Measure (CPM) are not implemented here.
    """
    mae = mean_absolute_error(Y_True, Y_pred)
    r2 = r2_score(Y_True, Y_pred)
    return mae, r2

def eval_model(model, val_loader):
    """Run ``model`` over every batch of ``val_loader`` and score the result.

    The model is switched to eval mode and is left in that mode on return.
    Each batch must unpack into ``(x, masks, rev_x, rev_masks, labels)``.

    Returns:
        ``(mae, r2)`` as produced by ``regression_metrics``.
    """
    model.eval()
    targets, predictions = [], []
    for xSUB, masksSUB, rev_xSUB, rev_masksSUB, labelsSUB in val_loader:
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            batch_pred = model(xSUB, masksSUB, rev_xSUB, rev_masksSUB)
            targets += labelsSUB.detach().numpy().tolist()
            predictions += batch_pred.detach().numpy().reshape(-1).tolist()

    scores = regression_metrics(predictions, targets)
    mae, r2 = scores
    return (mae, r2)
# Baseline before any training step: (MAE, R2) of the freshly initialised model
# on the training loader.
print(eval_model(model, train_loader))

BaseRNN(
  (embedding): Embedding(13367, 32)
  (rnn): GRU(32, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=1, bias=True)
)
(1877.5939231460702, -0.013321903072579744)


2.3 Model training, and test sample validation

In [20]:
def train(model, train_loader, val_loader, n_epochs):
    """Train ``model`` for ``n_epochs`` and print per-epoch metrics.

    Uses the module-level ``optimizer`` and ``criterion``; both must be defined
    (and ``optimizer`` must own ``model``'s parameters) before this is called.

    Arguments:
        model: network called as ``model(x, masks, rev_x, rev_masks)``.
        train_loader: yields ``(x, masks, rev_x, rev_masks, y)`` training batches.
        val_loader: loader with the same batch layout, used for validation.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        None. One line per epoch is printed with the mean batch loss and the
        train/validation MAE from ``eval_model``.
    """
    model.train()
    for epoch in range(n_epochs):
        # eval_model() switches the model to eval mode at the end of every
        # epoch, so training mode has to be restored before each new epoch.
        model.train()
        train_loss = 0

        for xSUB, masksSUB, rev_xSUB, rev_masksSUB, y in train_loader:
            optimizer.zero_grad()
            y_pred = model(xSUB, masksSUB, rev_xSUB, rev_masksSUB)

            # Flatten the targets to 1-D to match the model output.
            y = y.view(y.shape[0])

            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Mean of the per-batch losses (batches may differ in size, so this can
        # deviate slightly from the MAE reported by eval_model).
        train_loss = train_loss / len(train_loader)
        train_MAE, _ = eval_model(model, train_loader)
        val_MAE, _ = eval_model(model, val_loader)
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss} \t Training MAE: {train_MAE} \t Validation MAE: {val_MAE}')

# Number of full passes over the training loader.
n_epochs = 25 
train(model, train_loader, val_loader, n_epochs)

# NOTE(review): the value printed as "Test MAE" is computed on val_loader, the
# same split reported as "Validation MAE" during training; no separate test
# split is created in the cells shown here.
test_mae, _ = eval_model(model, val_loader)
print('Test MAE: %.2f'%(test_mae))

Epoch: 1 	 Training Loss: 1847.8011940331146 	 Training MAE: 1824.6236310742977 	 Validation MAE: 1752.9916416150552
Epoch: 2 	 Training Loss: 1808.7991976921792 	 Training MAE: 1796.4273801128707 	 Validation MAE: 1724.423391162861
Epoch: 3 	 Training Loss: 1788.5290145355018 	 Training MAE: 1779.4945938665269 	 Validation MAE: 1707.756992827019
Epoch: 4 	 Training Loss: 1773.2840381025458 	 Training MAE: 1768.7862839207542 	 Validation MAE: 1698.5549432268674
Epoch: 5 	 Training Loss: 1767.5847320902644 	 Training MAE: 1759.842841738822 	 Validation MAE: 1693.6497900031795
Epoch: 6 	 Training Loss: 1756.102658987586 	 Training MAE: 1748.7064996376912 	 Validation MAE: 1689.0684605956742
Epoch: 7 	 Training Loss: 1745.4524324032184 	 Training MAE: 1737.059284052921 	 Validation MAE: 1685.298776993227
Epoch: 8 	 Training Loss: 1735.6582915636957 	 Training MAE: 1725.5150400325977 	 Validation MAE: 1682.1508143704302
Epoch: 9 	 Training Loss: 1724.5841141958085 	 Training MAE: 1713.8124

KeyboardInterrupt: ignored