In [26]:
import os

import torch
import pandas as pd

from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

from SurvTRACE.survtrace.utils import set_random_seed
from SurvTRACE.survtrace.config import STConfig


from utils.covariate_data_processing import pbc2_proccess_covariate, padded_mask_processing

In [7]:
# Setup: start from SurvTRACE's 'metabric' configuration to save time; whatever
# differs for PBC2 is meant to be adjusted later by our own processing functions.
STConfig['data'] = 'metabric'

# Apply the seed stored in the config through SurvTRACE's seeding helper.
set_random_seed(STConfig['seed'])

# Training hyperparameters.
hparams = dict(
    batch_size=64,
    weight_decay=1e-4,
    learning_rate=1e-3,
    epochs=20,
)

In [8]:
# Location of the pre-processed PBC2 covariate CSV.
# Set the PBC2_COVARIATE_CSV environment variable to point at your own copy;
# when it is unset, the original author's local path is used unchanged.
PBC2_COVARIATE_CSV = os.environ.get(
    "PBC2_COVARIATE_CSV",
    "/Users/davidlee/Documents/GitHub/Surtimesurvival/Data Project/Pycox Lib/PBC2 Convariate Data/pbc2_data_proccessed_auton_covariate.csv",
)

# Load the whole file into a DataFrame with pandas' default integer index.
df = pd.read_csv(PBC2_COVARIATE_CSV)

In [9]:
# Set the sequence identifier columns (and a copy of the raw outcome columns)
# aside, then remove the identifiers from the frame used in the next steps.
# The label slice `0:` starts at index label 0, i.e. it keeps every row of the
# default integer index produced by read_csv.
id_columns = ['seq_id', 'seq_time_id']
outcome_columns = ['event', 'time']

df_temp = df.loc[0:, id_columns]
df_event_time_temp = df.loc[0:, outcome_columns]
df = df.drop(columns=id_columns)

In [10]:
# Preview the first rows of the identifier columns (seq_id, seq_time_id) that were set aside.
df_temp.head()

Unnamed: 0,seq_id,seq_time_id
0,0,0.569489
1,0,1.09517
2,1,5.31979
3,1,6.261636
4,1,7.266455


In [11]:
# Summarise the remaining columns: names, non-null counts and dtypes.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1945 entries, 0 to 1944
Data columns (total 23 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   event                          1945 non-null   float64
 1   time                           1945 non-null   float64
 2   seq_temporal_SGOT              1945 non-null   float64
 3   seq_temporal_age               1945 non-null   float64
 4   seq_temporal_albumin           1945 non-null   float64
 5   seq_temporal_alkaline          1945 non-null   float64
 6   seq_temporal_platelets         1945 non-null   float64
 7   seq_temporal_prothrombin       1945 non-null   float64
 8   seq_temporal_serBilir          1945 non-null   float64
 9   seq_temporal_serChol           1945 non-null   float64
 10  seq_static_sex_1.0             1945 non-null   int64  
 11  seq_temporal_ascites_1.0       1945 non-null   int64  
 12  seq_temporal_ascites_2.0       1945 non-null   i

In [12]:
# Alternative for the "noauton" export, whose outcome columns carry other names:
# df = df.rename(columns={'seq_out_time_to_event': 'duration', 'seq_out_event': 'event'})

# "auton" export: only `time` has to be renamed to `duration`.
# Reassigning instead of using inplace=True keeps the step chainable and avoids
# mutating a frame that an earlier cell already displayed.
df = df.rename(columns={'time': 'duration'})

In [13]:
# Run the project's PBC2 covariate processing helper (imported from
# utils.covariate_data_processing) and unpack its four results.
# NOTE(review): `df` is rebound to the helper's second return value, so re-running
# this cell feeds the already processed frame back in — confirm that is safe,
# otherwise restart the kernel before re-running.
y, df, df_train, df_y_train = pbc2_proccess_covariate(df)

26.0
______
8
______
13
______
21
______
26
______
[2.73792575e-03 3.83309605e-02 1.86178951e-01 4.73661154e-01
 1.43056620e+01]
______
4




In [14]:
# Preview the first rows of the processed feature frame.
df_train.head()

Unnamed: 0,seq_static_sex_1.0,seq_temporal_ascites_1.0,seq_temporal_ascites_2.0,seq_temporal_drug_1.0,seq_temporal_edema_1.0,seq_temporal_edema_2.0,seq_temporal_hepatomegaly_1.0,seq_temporal_hepatomegaly_2.0,seq_temporal_histologic_1.0,seq_temporal_histologic_2.0,...,seq_temporal_spiders_1.0,seq_temporal_spiders_2.0,seq_temporal_SGOT,seq_temporal_age,seq_temporal_albumin,seq_temporal_alkaline,seq_temporal_platelets,seq_temporal_prothrombin,seq_temporal_serBilir,seq_temporal_serChol
0,0.0,3.0,4.0,6.0,9.0,10.0,13.0,14.0,16.0,18.0,...,23.0,24.0,-1.485263,0.248058,-0.894575,0.195532,-0.529101,0.136768,3.28189,1.24208e-16
1,0.0,3.0,4.0,6.0,9.0,10.0,13.0,14.0,16.0,18.0,...,23.0,24.0,0.195488,0.248058,-1.570646,0.285613,-0.456022,0.813132,2.015877,-0.4694608
2,0.0,3.0,4.0,6.0,9.0,10.0,13.0,14.0,16.0,19.0,...,23.0,24.0,-0.442126,1.292856,-1.431455,-0.605844,-1.395605,0.339677,0.17271,-0.6589138
3,0.0,3.0,4.0,6.0,9.0,10.0,13.0,14.0,16.0,19.0,...,23.0,24.0,-0.046806,1.292856,-1.172958,-0.512364,-1.259888,0.339677,-0.013468,-0.6036567
4,0.0,3.0,4.0,6.0,9.0,10.0,13.0,14.0,16.0,19.0,...,23.0,24.0,0.29368,1.292856,-1.312149,-0.443529,-1.364286,0.339677,0.098239,1.24208e-16


In [15]:
# Preview the first rows of the processed label frame (duration, event, proportion).
df_y_train.head()

Unnamed: 0,duration,event,proportion
0,3,1.0,0.006928
1,3,1.0,0.006928
2,3,0.0,0.988915
3,3,0.0,0.988915
4,3,0.0,0.988915


In [16]:
# Re-attach the `seq_id` column to both the feature frame and the label frame.
# The frames are aligned on their index; join='inner' keeps only index labels
# present on both sides.
# Any `seq_id` column left over from a previous run of this cell is dropped first
# (errors='ignore' makes that a no-op on the first run), so re-running the cell
# no longer appends a second, duplicate `seq_id` column.
seq_ids = df_temp['seq_id']

df_train = pd.concat(
    [df_train.drop(columns='seq_id', errors='ignore'), seq_ids],
    axis=1,
    join='inner',
)
df_y_train = pd.concat(
    [df_y_train.drop(columns='seq_id', errors='ignore'), seq_ids],
    axis=1,
    join='inner',
)

In [18]:
# Build the padded feature tensor and its mask with the project helper from
# utils.covariate_data_processing.
# In the recorded run the results have shapes (312, 16, 21) for padded_patients
# and (312, 16) for masks — presumably (sequences, time steps, features); TODO confirm.
masks, padded_patients = padded_mask_processing(df_train)

16 16


In [30]:
# Report the shape of the padded feature tensor, then of its mask, one per line.
print(padded_patients.shape, masks.shape, sep="\n")

torch.Size([312, 16, 21])
torch.Size([312, 16])


In [21]:
# Collapse the label frame to one row per seq_id (keeping the last row of each),
# renumber the rows from 0 and drop the identifier so only label columns remain.
df_y_train = (
    df_y_train
    .drop_duplicates(subset='seq_id', keep='last')
    .reset_index(drop=True)
    .drop(columns=['seq_id'])
)
df_y_train

Unnamed: 0,duration,event,proportion
0,3,1.0,0.006928
1,3,0.0,0.988915
2,3,1.0,0.019002
3,2,1.0,0.314286
4,3,0.0,0.263658
...,...,...,...
307,3,0.0,0.326405
308,3,0.0,0.294933
309,3,0.0,0.284046
310,3,0.0,0.264252


In [22]:
# Labels: the label frame converted to a tensor (one row per seq_id after deduplication).
Y_labels_data_tensor = torch.tensor(df_y_train.to_numpy())
# Features: the padded tensor is used as-is.
X_features_data_tensor = padded_patients

In [24]:
# Hold out 10% of the sequences for validation; features, labels and masks are
# split together so they stay aligned.
# random_state is pinned to the configured seed so the split is the same every
# time this cell runs, instead of depending on how much of the global RNG state
# earlier cells (or earlier runs of this cell) have consumed.
X_train, X_val, y_train, y_val, masks_train, masks_val = train_test_split(
    X_features_data_tensor,
    Y_labels_data_tensor,
    masks,
    test_size=0.1,
    random_state=STConfig['seed'],
)

# Wrap each split as (features, labels, mask) triples for the DataLoader.
train_data = TensorDataset(X_train, y_train, masks_train)
val_data = TensorDataset(X_val, y_val, masks_val)

In [27]:
# NOTE(review): this import would normally live in the notebook's import cell.
from model.survtimesurvival_model import TransformerClassifier

# Transformer hyperparameters
embed_dim, num_heads = 32, 2
ffn_hidden_dim, num_layers = 64, 2

# Data loaders: training batches are shuffled, validation batches are not.
# NOTE(review): hparams['batch_size'] is set to 64 in the setup cell, but a batch
# size of 1 is used here — confirm which one is intended.
batch_size = 1
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_data, shuffle=False, batch_size=batch_size)

# Model dimensions are read off the padded feature tensor:
# axis 1 gives the sequence length, axis 2 the per-step input width.
seq_length = X_features_data_tensor.shape[1]
input_dim = X_features_data_tensor.shape[2]
model = TransformerClassifier(
    input_dim,
    seq_length,
    embed_dim,
    num_heads,
    ffn_hidden_dim,
    num_layers,
)

In [28]:
# Smoke test: push the first training sample through the untrained model.
# The mask gets a leading axis of size 1; the features are passed as X_train[0]
# without one — NOTE(review): confirm the model expects this combination.
a = torch.unsqueeze(masks_train[0], dim=0)
output = model(X_train[0], a)

In [29]:
# Show the shape of the model output from the single-sample forward pass.
output.shape

torch.Size([1, 21])