In [1]:
import math
import os
import random
import itertools

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.preprocessing.sequence import pad_sequences
from IPython.display import clear_output
from auton_survival import datasets, preprocessing, models
from sklearn.model_selection import train_test_split

import torch
from torch import nn, matmul
from torch.nn.functional import softmax

clear_output()

In [2]:
# define the setup parameters
import matplotlib.pyplot as plt

from SurvTRACE.survtrace.dataset import load_data
from SurvTRACE.survtrace.evaluate_utils import Evaluator
from SurvTRACE.survtrace.utils import set_random_seed
from SurvTRACE.survtrace.model import SurvTraceSingle
from SurvTRACE.survtrace.train_utils import Trainer
from SurvTRACE.survtrace.config import STConfig

STConfig['data'] = 'metabric'
STConfig['duration_index'] = [0.11225496, 2.06987187, 3.72357902, 6.68738364, 14.30566203]


STConfig['num_numerical_feature'] = 21
STConfig['num_categorical_feature'] = 0
STConfig['num_feature'] = 21
STConfig['vocab_size'] = 0
STConfig['out_feature'] = 4

set_random_seed(STConfig['seed'])

hparams = {
    'batch_size': 2,
    'weight_decay': 1e-4,
    'learning_rate': 1e-3,
    'epochs': 1,
}

In [3]:
df = pd.read_csv('/Users/davidlee/Documents/GitHub/Surtimesurvival/Data Project/Pycox Lib/PBC2 Convariate Data/pbc2_data_proccessed_auton_covariate.csv')

In [4]:
df.head()

Unnamed: 0,event,time,seq_id,seq_time_id,seq_temporal_SGOT,seq_temporal_age,seq_temporal_albumin,seq_temporal_alkaline,seq_temporal_platelets,seq_temporal_prothrombin,...,seq_temporal_drug_1.0,seq_temporal_edema_1.0,seq_temporal_edema_2.0,seq_temporal_hepatomegaly_1.0,seq_temporal_hepatomegaly_2.0,seq_temporal_histologic_1.0,seq_temporal_histologic_2.0,seq_temporal_histologic_3.0,seq_temporal_spiders_1.0,seq_temporal_spiders_2.0
0,1.0,0.569489,0,0.569489,-1.485263,0.248058,-0.894575,0.195532,-0.529101,0.136768,...,0,1,0,1,0,0,0,1,1,0
1,1.0,0.569489,0,1.09517,0.195488,0.248058,-1.570646,0.285613,-0.456022,0.813132,...,0,1,0,1,0,0,0,1,1,0
2,0.0,14.152338,1,5.31979,-0.442126,1.292856,-1.431455,-0.605844,-1.395605,0.339677,...,0,1,0,1,0,0,1,0,1,0
3,0.0,14.152338,1,6.261636,-0.046806,1.292856,-1.172958,-0.512364,-1.259888,0.339677,...,0,1,0,1,0,0,1,0,1,0
4,0.0,14.152338,1,7.266455,0.29368,1.292856,-1.312149,-0.443529,-1.364286,0.339677,...,0,1,0,1,0,0,1,0,1,0


In [5]:
df_temp = df.loc[0:, ['seq_id', 'time']]
df_event_time_temp = df.loc[0:, ['seq_id' ,'event', 'time']]
df = df.drop(columns=['seq_time_id', 'event', 'time'])

In [6]:
# df.rename(columns = {'seq_out_time_to_event':'duration', 'seq_out_event':'event'}, inplace = True)
df.rename(columns = {'time':'duration'}, inplace = True)

In [7]:
# Check sequence length
def padded_mask_processing(df_train):
  max_seq_length = df_train.groupby("seq_id").size().max()
  num_patients = len(df_train["seq_id"].unique())
  print(max_seq_length, max_seq_length)
  padded_patients = []
  masks = []
  for patient_id, patient_data in df_train.groupby("seq_id"):
      padding_rows = max_seq_length - len(patient_data)

      current_patients = torch.zeros(max_seq_length, df_train.shape[1])
      curent_masks = torch.zeros(max_seq_length)
      current_patients[:len(patient_data)] = torch.tensor(patient_data.to_numpy())
      curent_masks[:len(patient_data)] = 1
      masks.append(curent_masks)
      padded_patients.append(current_patients)
  padded_patients = torch.stack(padded_patients)
  masks = torch.stack(masks)
  padded_patients = padded_patients[:,:,1:]
  return masks, padded_patients


In [8]:
df.head()

Unnamed: 0,seq_id,seq_temporal_SGOT,seq_temporal_age,seq_temporal_albumin,seq_temporal_alkaline,seq_temporal_platelets,seq_temporal_prothrombin,seq_temporal_serBilir,seq_temporal_serChol,seq_static_sex_1.0,...,seq_temporal_drug_1.0,seq_temporal_edema_1.0,seq_temporal_edema_2.0,seq_temporal_hepatomegaly_1.0,seq_temporal_hepatomegaly_2.0,seq_temporal_histologic_1.0,seq_temporal_histologic_2.0,seq_temporal_histologic_3.0,seq_temporal_spiders_1.0,seq_temporal_spiders_2.0
0,0,-1.485263,0.248058,-0.894575,0.195532,-0.529101,0.136768,3.28189,1.169016e-16,0,...,0,1,0,1,0,0,0,1,1,0
1,0,0.195488,0.248058,-1.570646,0.285613,-0.456022,0.813132,2.015877,-0.4694608,0,...,0,1,0,1,0,0,0,1,1,0
2,1,-0.442126,1.292856,-1.431455,-0.605844,-1.395605,0.339677,0.17271,-0.6589138,0,...,0,1,0,1,0,0,1,0,1,0
3,1,-0.046806,1.292856,-1.172958,-0.512364,-1.259888,0.339677,-0.013468,-0.6036567,0,...,0,1,0,1,0,0,1,0,1,0
4,1,0.29368,1.292856,-1.312149,-0.443529,-1.364286,0.339677,0.098239,1.169016e-16,0,...,0,1,0,1,0,0,1,0,1,0


In [9]:
masks, padded_patients = padded_mask_processing(df)

16 16


In [10]:
padded_patients.shape

torch.Size([312, 16, 21])

In [11]:
masks.shape

torch.Size([312, 16])

In [12]:
padded_patients[0]

tensor([[-1.4853e+00,  2.4806e-01, -8.9458e-01,  1.9553e-01, -5.2910e-01,
          1.3677e-01,  3.2819e+00,  1.1690e-16,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00,
          0.0000e+00],
        [ 1.9549e-01,  2.4806e-01, -1.5706e+00,  2.8561e-01, -4.5602e-01,
          8.1313e-01,  2.0159e+00, -4.6946e-01,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e

In [13]:
df_y = df_event_time_temp.drop_duplicates(subset='seq_id', keep='last')

In [14]:
df_y = df_y.reset_index(drop=True)

In [15]:
df_y = df_y.drop(columns=['seq_id'])

In [16]:
df_y.head()

Unnamed: 0,event,time
0,1.0,0.569489
1,0.0,14.152338
2,1.0,0.736502
3,1.0,0.276531
4,0.0,4.120578


In [17]:
X_features_data_tensor = padded_patients
Y_labels_data_tensor = torch.Tensor(df_y.values)

In [18]:
import torch
from torch.utils.data import TensorDataset

X_train, X_val, y_train, y_val, masks_train, masks_val = train_test_split(X_features_data_tensor, Y_labels_data_tensor,
                                                                          masks, test_size=0.3)
train_data = TensorDataset(X_train, y_train, masks_train)
val_data = TensorDataset(X_val, y_val, masks_val)

In [19]:
X_train.shape

torch.Size([218, 16, 21])

In [20]:
y_train.shape

torch.Size([218, 2])

In [21]:
from model.custom_model_survtrace import Custom_SurvTrace

model = Custom_SurvTrace(STConfig)
a = masks_train[0:4]
# print(X_train[0].shape)
# print(X_train[0].unsqueeze(0).shape)
output = model(input_nums = X_train[0:4], our_mask = a)

torch.Size([4, 21])


In [22]:
output

((tensor([[[-1.7078e-04,  1.3087e-05, -7.9956e-05,  ..., -1.1270e-04,
             7.1550e-05, -3.5457e-05],
           [-1.6163e-04,  1.9880e-05,  1.9303e-04,  ...,  5.0802e-05,
            -6.5775e-05,  3.3731e-04],
           [ 1.0822e-04,  3.5568e-05,  9.1149e-05,  ...,  3.6070e-05,
             5.0647e-05, -2.2794e-05],
           ...,
           [ 2.0328e-03,  1.9553e-03, -1.4191e-03,  ...,  3.7438e-04,
            -1.0329e-04, -2.4697e-03],
           [ 8.4235e-05,  2.5445e-04,  5.9989e-04,  ...,  4.2486e-04,
            -3.4781e-04, -7.8210e-04],
           [-2.3195e-04, -5.2743e-05,  1.2989e-03,  ...,  6.6997e-05,
            -1.7274e-04,  4.5446e-04]],
  
          [[ 5.3063e-04, -4.0662e-05,  2.4843e-04,  ...,  3.5016e-04,
            -2.2231e-04,  1.1017e-04],
           [ 3.6755e-04, -4.5205e-05, -4.3895e-04,  ..., -1.1552e-04,
             1.4957e-04, -7.6702e-04],
           [-1.3747e-04, -4.5180e-05, -1.1578e-04,  ..., -4.5818e-05,
            -6.4335e-05,  2.8954e-05],

In [71]:
STConfig.keys()

dict_keys(['data', 'num_durations', 'horizons', 'seed', 'checkpoint', 'vocab_size', 'hidden_size', 'intermediate_size', 'num_hidden_layers', 'num_attention_heads', 'hidden_dropout_prob', 'num_feature', 'num_numerical_feature', 'num_categorical_feature', 'out_feature', 'num_event', 'hidden_act', 'attention_probs_dropout_prob', 'early_stop_patience', 'initializer_range', 'layer_norm_eps', 'max_position_embeddings', 'chunk_size_feed_forward', 'output_attentions', 'output_hidden_states', 'tie_word_embeddings', 'pruned_heads'])

In [22]:
df_train=X_train
df_y_train=y_train
df_val=X_val
df_y_val=y_val

In [23]:
# get model
model = SurvTraceSingle(STConfig)

# initialize a trainer
trainer = Trainer(model)
train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
        batch_size=16,
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],)

GPU not found! will use cpu for training!
train with single event


IndexError: too many indices for tensor of dimension 2