In [26]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchtuples as tt

from load import read_csv

from pycox.models import LogisticHazard
from pycox.evaluation import EvalSurv

from dataset import Dataset
from fedcox import Federation
from net import MLP 
from discretiser import Discretiser
from interpolate import surv_const_pdf, surv_const_pdf_df

from sklearn_pandas import DataFrameMapper
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from tabulate import tabulate

from torch.utils.data import DataLoader



In [27]:
np.random.seed(1234)
_ = torch.manual_seed(123)

In [28]:
datapath = './Data/data.csv'
data = read_csv(datapath)
print(len(data))

40018


In [29]:
df_test = data.sample(frac=0.2)
df_train = data.drop(df_test.index)
df_val = df_train.sample(frac=0.1)
df_train = df_train.drop(df_val.index)

In [30]:
m_brain = df_test['SITE_C71'] == 1
m_other = (df_test['SITE_C70'] == 1) | (df_test['SITE_C72'] == 1)
benign = (df_test['SITE_D32'] == 1) | (df_test['SITE_D33'] == 1) | (df_test['SITE_D35'] == 1)

print('malignant brain: ',m_brain.sum())
print('malignant other: ',m_other.sum())
print('benign: ',benign.sum())

overlap = (m_brain & m_other & benign).sum()
print(overlap)


malignant brain:  3867
malignant other:  192
benign:  3951
0


In [31]:
# cols = list(df_train.columns.values)
# print(cols)
df_train = df_train[['GRADE', 'AGE', 'SEX', 'QUINTILE_2015', 'TUMOUR_COUNT', 'SACT', 'REGIMEN_COUNT', 'CLINICAL_TRIAL_INDICATOR', 'CHEMO_RADIATION_INDICATOR', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT', 'BENIGN_BEHAVIOUR', 'SITE_C70', 'SITE_C71', 'SITE_C72', 'SITE_D32', 'SITE_D33', 'SITE_D35', 'CREG_L0201', 'CREG_L0301', 'CREG_L0401', 'CREG_L0801', 'CREG_L0901', 'CREG_L1001', 'CREG_L1201', 'CREG_L1701', 'LAT_9', 'LAT_B', 'LAT_L', 'LAT_M', 'LAT_R', 'ETH_A', 'ETH_B', 'ETH_C', 'ETH_M', 'ETH_O', 'ETH_U', 'ETH_W', 'DAYS_TO_FIRST_SURGERY','DAYS_SINCE_DIAGNOSIS','EVENT']]
df_train.head()

Unnamed: 0,GRADE,AGE,SEX,QUINTILE_2015,TUMOUR_COUNT,SACT,REGIMEN_COUNT,CLINICAL_TRIAL_INDICATOR,CHEMO_RADIATION_INDICATOR,NORMALISED_HEIGHT,...,ETH_A,ETH_B,ETH_C,ETH_M,ETH_O,ETH_U,ETH_W,DAYS_TO_FIRST_SURGERY,DAYS_SINCE_DIAGNOSIS,EVENT
0,4,69.0,1,4.0,2,0,0.0,0,0,0.0,...,0,0,0,0,0,0,1,0,751,0
1,4,88.0,1,1.0,2,1,5.0,1,1,1.755,...,0,0,0,0,0,0,1,0,17,1
2,4,79.0,1,2.0,2,0,0.0,0,0,0.0,...,0,0,0,0,0,0,1,0,252,1
5,-1,28.0,1,1.0,2,0,0.0,0,0,0.0,...,0,0,0,0,0,1,0,0,839,0
7,-1,65.0,1,2.0,2,0,0.0,0,0,0.0,...,0,0,0,0,0,0,1,12,770,0


In [32]:
cols_standardise = ['GRADE', 'AGE', 'QUINTILE_2015', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT']
cols_minmax = ['SEX', 'TUMOUR_COUNT', 'REGIMEN_COUNT']
cols_leave = ['SACT', 'CLINICAL_TRIAL_INDICATOR', 'CHEMO_RADIATION_INDICATOR','BENIGN_BEHAVIOUR','SITE_C70', 'SITE_C71', 'SITE_C72', 'SITE_D32','SITE_D33','SITE_D35','CREG_L0201','CREG_L0301','CREG_L0401','CREG_L0801','CREG_L0901','CREG_L1001','CREG_L1201','CREG_L1701','LAT_9','LAT_B','LAT_L','LAT_M','LAT_R','ETH_A','ETH_B','ETH_C','ETH_M','ETH_O','ETH_U','ETH_W','DAYS_TO_FIRST_SURGERY']

print(len(df_train.columns) == len(cols_standardise + cols_minmax + cols_leave))

standardise = [([col], StandardScaler()) for col in cols_standardise]
minmax = [([col], MinMaxScaler()) for col in cols_minmax]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardise + minmax + leave)

False


In [33]:
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

In [34]:
num_durations = 50

discretiser = Discretiser(num_durations, scheme='km')

y_train = (df_train.DAYS_SINCE_DIAGNOSIS.values, df_train.EVENT.values)
y_train = discretiser.fit_transform(*y_train)

y_val = (df_val.DAYS_SINCE_DIAGNOSIS.values, df_val.EVENT.values)
y_val = discretiser.transform(*y_val)
val_laoder = DataLoader(Dataset(x_val, y_val), batch_size=256, shuffle=False)

y_test = (df_test.DAYS_SINCE_DIAGNOSIS.values, df_test.EVENT.values)
test_loader = DataLoader(Dataset(x_test, y_test), batch_size=256, shuffle=False)





















































In [35]:
dim_in = x_train.shape[1]
num_nodes = [32, 32]
dim_out = len(discretiser.cuts)
batch_norm = True
dropout = 0.1

net = MLP(dim_in=dim_in, num_nodes=num_nodes, dim_out=dim_out, batch_norm=batch_norm, dropout=dropout)

num_centers = 1
optimizer = 'adam'
lr = 0.001


32 50 True
<class 'int'> <class 'int'>


In [36]:
fed = Federation(features=x_train, labels=y_train, net=net, num_centers=num_centers, optimizer=optimizer, lr=lr, batch_size=256, local_epochs=5)
fed.fit(epochs=2)


 | Global Training Round : 1 |


 | Global Training Round : 2 |

 \Latest training stats after 2 global rounds:
Training loss : 2.504595217075977
Validation loss : 0.9047363184649369
Epochs exhausted


In [37]:
surv = fed.predict_surv(test_loader)[0]
surv = surv_const_pdf_df(surv, discretiser.cuts)

In [38]:
ev = EvalSurv(surv, *y_test, censor_surv='km')
ev.concordance_td('antolini')

0.7313799902611011

In [42]:
labtrans = LogisticHazard.label_transform(num_durations, scheme='quantiles')
get_target = lambda df: (df['DAYS_SINCE_DIAGNOSIS'].values, df['EVENT'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

In [43]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.1

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)


In [49]:
model = LogisticHazard(net, tt.optim.Adam(0.001), duration_index=labtrans.cuts)

In [50]:
batch_size = 256
epochs = 3
callbacks = [tt.cb.EarlyStopping()]

In [51]:
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

0:	[0s / 0s],		train_loss: 3.5182,	val_loss: 2.4398
1:	[0s / 0s],		train_loss: 2.3311,	val_loss: 2.1408
2:	[0s / 0s],		train_loss: 2.1850,	val_loss: 2.0938


In [52]:
surv = model.interpolate(10).predict_surv_df(x_test)

In [53]:
ev = EvalSurv(surv, *y_test, censor_surv='km')
ev.concordance_td('antolini')

0.7350419314255027