In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
import warnings

from load import read_csv

from pycox.models import LogisticHazard
from pycox.evaluation import EvalSurv

from sklearn.model_selection import KFold
from sklearn_pandas import DataFrameMapper
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from torch.utils.data import DataLoader



In [2]:
datapath = './Data/data.csv'
data = read_csv(datapath)
print(data.shape)
data = data.drop(columns='PATIENTID')
print(data.columns)

(40018, 42)
Index(['GRADE', 'AGE', 'SEX', 'QUINTILE_2015', 'TUMOUR_COUNT', 'SACT',
       'REGIMEN_COUNT', 'CLINICAL_TRIAL_INDICATOR',
       'CHEMO_RADIATION_INDICATOR', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT',
       'DAYS_TO_FIRST_SURGERY', 'DAYS_SINCE_DIAGNOSIS', 'SITE_C70', 'SITE_C71',
       'SITE_C72', 'SITE_D32', 'SITE_D33', 'SITE_D35', 'BENIGN_BEHAVIOUR',
       'CREG_L0201', 'CREG_L0301', 'CREG_L0401', 'CREG_L0801', 'CREG_L0901',
       'CREG_L1001', 'CREG_L1201', 'CREG_L1701', 'LAT_9', 'LAT_B', 'LAT_L',
       'LAT_M', 'LAT_R', 'ETH_A', 'ETH_B', 'ETH_C', 'ETH_M', 'ETH_O', 'ETH_U',
       'ETH_W', 'EVENT'],
      dtype='object')


In [3]:
# standardisation of features
cols_standardise = ['GRADE', 'AGE', 'QUINTILE_2015', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT']
cols_minmax = ['SEX', 'TUMOUR_COUNT', 'REGIMEN_COUNT']
cols_leave = ['SACT', 'CLINICAL_TRIAL_INDICATOR', 'CHEMO_RADIATION_INDICATOR','BENIGN_BEHAVIOUR','SITE_C70', 'SITE_C71', 'SITE_C72', 'SITE_D32','SITE_D33','SITE_D35','CREG_L0201','CREG_L0301','CREG_L0401','CREG_L0801','CREG_L0901','CREG_L1001','CREG_L1201','CREG_L1701','LAT_9','LAT_B','LAT_L','LAT_M','LAT_R','ETH_A','ETH_B','ETH_C','ETH_M','ETH_O','ETH_U','ETH_W','DAYS_TO_FIRST_SURGERY']

all_cols = cols_standardise + cols_minmax + cols_leave

print(len(data.columns) == len(cols_standardise + cols_minmax + cols_leave) + 2)

standardise = [([col], StandardScaler()) for col in cols_standardise]
minmax = [([col], MinMaxScaler()) for col in cols_minmax]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardise + minmax + leave)

# discretisation
num_durations = 50
labtrans = LogisticHazard.label_transform(num_durations, scheme='quantiles')


True


In [4]:
df_test = data.sample(frac=0.1)
df_train = data.drop(df_test.index)
df_val = df_train.sample(frac=0.1)
df_train = df_train.drop(df_val.index)

get_target = lambda df: (df['DAYS_SINCE_DIAGNOSIS'].values, df['EVENT'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))
y_test = get_target(df_test)

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')


train = (x_train, y_train)
val = (x_val, y_val)

in_features = x_train.shape[1]
num_nodes = [168, 168, 168, 168]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.1

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)
model = LogisticHazard(net, tt.optim.Adam(0.0005), duration_index=labtrans.cuts)

batch_size = 256
epochs = 100
callbacks = [tt.cb.EarlyStopping()]

log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)
# log = model.fit(x_train, y_train, batch_size, epochs, callbacks)

surv = model.interpolate(10).predict_surv_df(x_test)

ev = EvalSurv(surv, *y_test, censor_surv='km')
print(ev.concordance_td('antolini'))

0:	[1s / 1s],		train_loss: 21.7236,	val_loss: 15.9305
1:	[1s / 2s],		train_loss: 9.1735,	val_loss: 4.2051
2:	[1s / 3s],		train_loss: 3.2375,	val_loss: 2.4347
3:	[1s / 4s],		train_loss: 2.3843,	val_loss: 2.1559
4:	[1s / 6s],		train_loss: 2.2136,	val_loss: 2.0959
5:	[1s / 7s],		train_loss: 2.1575,	val_loss: 2.0621
6:	[1s / 8s],		train_loss: 2.1434,	val_loss: 2.0540
7:	[1s / 9s],		train_loss: 2.1294,	val_loss: 2.0480
8:	[1s / 10s],		train_loss: 2.1160,	val_loss: 2.0493
9:	[1s / 12s],		train_loss: 2.1115,	val_loss: 2.0488
10:	[1s / 13s],		train_loss: 2.1109,	val_loss: 2.0440
11:	[1s / 14s],		train_loss: 2.1037,	val_loss: 2.0468
12:	[1s / 15s],		train_loss: 2.1055,	val_loss: 2.0442
13:	[1s / 16s],		train_loss: 2.0971,	val_loss: 2.0396
14:	[1s / 18s],		train_loss: 2.0908,	val_loss: 2.0378
15:	[1s / 19s],		train_loss: 2.0909,	val_loss: 2.0410
16:	[1s / 20s],		train_loss: 2.0872,	val_loss: 2.0405
17:	[1s / 21s],		train_loss: 2.0877,	val_loss: 2.0421
18:	[1s / 22s],		train_loss: 2.0798,	val_los