In [15]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
import warnings

from load import read_csv

from pycox.models import LogisticHazard
from pycox.evaluation import EvalSurv

from dataset import Dataset
from fedcox import Federation
from net import MLP 
from discretiser import Discretiser
from interpolate import surv_const_pdf, surv_const_pdf_df

from sklearn.model_selection import KFold
from sklearn_pandas import DataFrameMapper
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from tabulate import tabulate

from torch.utils.data import DataLoader



In [16]:
# One seed drives both the NumPy generator and torch's global RNG so that
# data splits and network initialisation are repeatable together.
SEED = 123
rng = np.random.default_rng(SEED)
_ = torch.manual_seed(SEED)  # assigned to `_` so the cell shows no output


In [17]:
def count_benign_malignant(df):
    """Print how many rows fall into each tumour-site group.

    A row belongs to a group when any of the group's one-hot ``SITE_*``
    columns equals 1. After the three group counts, the number of rows that
    belong to all three groups at once is printed on its own line.

    Args:
        df: DataFrame containing the columns ``SITE_C70``, ``SITE_C71``,
            ``SITE_C72``, ``SITE_D32``, ``SITE_D33`` and ``SITE_D35``.

    Returns:
        None. Results are written to stdout only.
    """
    site_groups = (
        ('malignant brain: ', ['SITE_C71']),
        ('malignant other: ', ['SITE_C70', 'SITE_C72']),
        ('benign: ', ['SITE_D32', 'SITE_D33', 'SITE_D35']),
    )

    # Build every mask before printing anything, so a missing column fails
    # before any partial output is produced.
    group_masks = [
        (label, df[columns].eq(1).any(axis=1)) for label, columns in site_groups
    ]

    for label, mask in group_masks:
        print(label, mask.sum())

    in_every_group = group_masks[0][1] & group_masks[1][1] & group_masks[2][1]
    print(in_every_group.sum())


In [18]:
# Read the cohort from disk and report how many records were loaded.
datapath = './Data/data.csv'
data = read_csv(datapath)
n_records = len(data)
print(n_records)

# PATIENTID is dropped before any further processing; show what remains.
data = data.drop(columns=['PATIENTID'])
print(data.columns)

40018
Index(['GRADE', 'AGE', 'SEX', 'QUINTILE_2015', 'TUMOUR_COUNT', 'SACT',
       'REGIMEN_COUNT', 'CLINICAL_TRIAL_INDICATOR',
       'CHEMO_RADIATION_INDICATOR', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT',
       'DAYS_TO_FIRST_SURGERY', 'DAYS_SINCE_DIAGNOSIS', 'SITE_C70', 'SITE_C71',
       'SITE_C72', 'SITE_D32', 'SITE_D33', 'SITE_D35', 'BENIGN_BEHAVIOUR',
       'CREG_L0201', 'CREG_L0301', 'CREG_L0401', 'CREG_L0801', 'CREG_L0901',
       'CREG_L1001', 'CREG_L1201', 'CREG_L1701', 'LAT_9', 'LAT_B', 'LAT_L',
       'LAT_M', 'LAT_R', 'ETH_A', 'ETH_B', 'ETH_C', 'ETH_M', 'ETH_O', 'ETH_U',
       'ETH_W', 'EVENT'],
      dtype='object')


In [19]:
# Feature preprocessing: each column is assigned to exactly one treatment.
# Columns scaled with StandardScaler.
cols_standardise = [
    'GRADE', 'AGE', 'QUINTILE_2015', 'NORMALISED_HEIGHT', 'NORMALISED_WEIGHT',
]
# Columns scaled with MinMaxScaler.
cols_minmax = ['SEX', 'TUMOUR_COUNT', 'REGIMEN_COUNT']
# Columns passed through the mapper unchanged.
cols_leave = [
    'SACT', 'CLINICAL_TRIAL_INDICATOR', 'CHEMO_RADIATION_INDICATOR',
    'BENIGN_BEHAVIOUR',
    'SITE_C70', 'SITE_C71', 'SITE_C72', 'SITE_D32', 'SITE_D33', 'SITE_D35',
    'CREG_L0201', 'CREG_L0301', 'CREG_L0401', 'CREG_L0801', 'CREG_L0901',
    'CREG_L1001', 'CREG_L1201', 'CREG_L1701',
    'LAT_9', 'LAT_B', 'LAT_L', 'LAT_M', 'LAT_R',
    'ETH_A', 'ETH_B', 'ETH_C', 'ETH_M', 'ETH_O', 'ETH_U', 'ETH_W',
    'DAYS_TO_FIRST_SURGERY',
]

# Sanity check: the three lists plus the two columns not listed above
# (DAYS_SINCE_DIAGNOSIS and EVENT) should account for every column of `data`.
n_listed_columns = len(cols_standardise) + len(cols_minmax) + len(cols_leave)
print(len(data.columns) == n_listed_columns + 2)

standardise = [([name], StandardScaler()) for name in cols_standardise]
minmax = [([name], MinMaxScaler()) for name in cols_minmax]
leave = [(name, None) for name in cols_leave]

x_mapper = DataFrameMapper([*standardise, *minmax, *leave])

# Discretisation of the time axis.
num_durations = 50
discretiser = Discretiser(num_durations, scheme='km')


True


In [20]:
def train_val_split(df, t_index, v_index, x_mapper, fit_transform=True, label_discretiser=None):
    """Split ``df`` into training and validation features and survival targets.

    Args:
        df: DataFrame holding the feature columns plus ``DAYS_SINCE_DIAGNOSIS``
            and ``EVENT``.
        t_index: Positional row indices of the training rows (as produced by
            ``KFold.split``).
        v_index: Positional row indices of the validation rows.
        x_mapper: Feature transformer exposing ``fit_transform`` and
            ``transform``.
        fit_transform: When True, ``x_mapper`` and the discretiser are fitted on
            the training rows; when False, their existing fit is reused.
        label_discretiser: Object exposing ``fit_transform(durations, events)``
            and ``transform(durations, events)``. Defaults to the notebook-level
            ``discretiser``.

    Returns:
        ``(x_t, y_t, x_v, y_v)``: ``x_t`` / ``x_v`` are float32 feature arrays,
        ``y_t`` is the discretiser's output for the training rows, and ``y_v``
        is the raw ``(durations, events)`` tuple of the validation rows.
    """
    if label_discretiser is None:
        label_discretiser = discretiser

    # The indices are positions, not labels, so select with .iloc: .loc only
    # coincides with it when df carries a default 0..n-1 index.
    df_t = df.iloc[t_index]
    df_v = df.iloc[v_index]

    if fit_transform:
        x_t = x_mapper.fit_transform(df_t).astype('float32')
    else:
        x_t = x_mapper.transform(df_t).astype('float32')
    x_v = x_mapper.transform(df_v).astype('float32')

    y_t = (df_t.DAYS_SINCE_DIAGNOSIS.values, df_t.EVENT.values)
    # Warnings raised while discretising the labels are deliberately silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if fit_transform:
            y_t = label_discretiser.fit_transform(*y_t)
        else:
            y_t = label_discretiser.transform(*y_t)

    # Validation targets stay undiscretised.
    y_v = (df_v.DAYS_SINCE_DIAGNOSIS.values, df_v.EVENT.values)

    return x_t, y_t, x_v, y_v

In [21]:
# federation parameters - excl lr
num_centers = 4
optimizer = 'adam'
batch_size = 256
local_epochs = 1
epochs = 100
print_every = 100

log = f'./training_log_C{num_centers}L{local_epochs}.txt'
with open(log, 'a') as f:
    print(f'-- Centers: {num_centers}, Local rounds: {local_epochs} --', file=f)

# CV setup
n_splits = 5
random_state = rng.integers(0,1000)
scores = []
parameters = []

kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
cv_round = 0
for train_index, test_index in kf.split(data):
    with open(log, 'a') as f:
        print(f'-- Eval CV fold: {cv_round} --', file=f)
    cv_round += 1
    x_train, y_train, x_test, y_test = train_val_split(data, train_index, test_index, x_mapper, fit_transform=True)

    test_loader = DataLoader(Dataset(x_test, y_test), batch_size=256, shuffle=False)

    # MLP parameters - excl dropout
    dim_in = x_train.shape[1]
    num_nodes = [32, 32]
    dim_out = len(discretiser.cuts)
    batch_norm = True
    
    best_lr = 0.001
    best_dropout = 0.1
    tuning = True
    if tuning:
        # parameter tuning
        learning_rates = [0.1, 0.01, 0.001]
        dropouts = [0.1, 0.5] 
        best_score = 0
        best_lr = None
        best_dropout = None
        para_splits = 5
        para_kf = KFold(n_splits=para_splits)
        para_round = 0
        for t_index, v_index in kf.split(x_train):
            x_t, y_t, x_v, y_v = train_val_split(data.loc[train_index].reset_index(), t_index, v_index, x_mapper, fit_transform=False)

            val_loader = DataLoader(Dataset(x_v, y_v), batch_size=256, shuffle=False)
            
            for lr in learning_rates:
                for dropout in dropouts:
                    net = MLP(dim_in=dim_in, num_nodes=num_nodes, dim_out=dim_out, batch_norm=batch_norm, dropout=dropout)
                    fed = Federation(features=x_t, labels=y_t, net=net, num_centers=num_centers, optimizer=optimizer, lr=lr, batch_size=batch_size, local_epochs=local_epochs)
                    fed.fit(epochs=epochs, print_every=print_every, verbose=False)    

                    surv = fed.predict_surv(val_loader)[0]
                    surv = surv_const_pdf_df(surv, discretiser.cuts) # interpolation
                    
                    ev = EvalSurv(surv, *y_v, censor_surv='km')
                    score = ev.concordance_td('antolini')
                    if score > best_score:
                        best_lr, best_dropout = lr, dropout
                        best_score = score
                    with open(log, 'a') as f:
                        print(f'Tuning CV fold {para_round}: conc = {score}, lr = {lr}, dropout = {dropout}', file=f)
            para_round += 1

    net = MLP(dim_in=dim_in, num_nodes=num_nodes, dim_out=dim_out, batch_norm=batch_norm, dropout=best_dropout)
    fed = Federation(features=x_train, labels=y_train, net=net, num_centers=num_centers, optimizer=optimizer, lr=best_lr, batch_size=batch_size, local_epochs=local_epochs)
    fed.fit(epochs=epochs, print_every=print_every)    

    surv = fed.predict_surv(test_loader)[0]
    surv = surv_const_pdf_df(surv, discretiser.cuts) # interpolation
    
    ev = EvalSurv(surv, *y_test, censor_surv='km')
    score = ev.concordance_td('antolini')
    scores.append(score)
    parameters.append({'lr' : best_lr, 'dropout' : best_dropout})

    with open(log, 'a') as f:
        print(f'>> Best parameters: conc = {score}, LR = {best_lr}, dropout ={best_dropout}', file=f)

with open(log, 'a') as f:
    print('Avg concordance: ', sum(scores) / len(scores), file=f)


Early stop at epoch 16
Tuning CV fold 0: conc = 0.7499163148583105, lr = 0.1, dropout = 0.1
Early stop at epoch 7
Tuning CV fold 0: conc = 0.7391691891076443, lr = 0.1, dropout = 0.5
Early stop at epoch 17
Tuning CV fold 0: conc = 0.754545781547903, lr = 0.01, dropout = 0.1
Early stop at epoch 14
Tuning CV fold 0: conc = 0.7518303363018625, lr = 0.01, dropout = 0.5


KeyboardInterrupt: 

In [None]:
# Disabled cell: a non-federated pycox LogisticHazard run, kept for reference.
# The `if False:` guard means none of this executes.
# NOTE(review): `df_train` is not defined in any cell shown in this notebook,
# and this block rebinds `log`, `batch_size` and `epochs`, which the federated
# training cell also uses - resolve both before re-enabling.
if False:
    # Discretise the durations into `num_durations` intervals (quantile scheme).
    labtrans = LogisticHazard.label_transform(num_durations, scheme='quantiles')
    get_target = lambda df: (df['DAYS_SINCE_DIAGNOSIS'].values, df['EVENT'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    # y_val = labtrans.transform(*get_target(df_val))

    train = (x_train, y_train)
    # val = (x_val, y_val)

    # Network shape.
    in_features = x_train.shape[1]
    num_nodes = [32, 32]
    out_features = labtrans.out_features
    batch_norm = True
    dropout = 0.1

    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)
    model = LogisticHazard(net, tt.optim.Adam(0.001), duration_index=labtrans.cuts)

    batch_size = 256
    epochs = 3
    callbacks = [tt.cb.EarlyStopping()]

    # log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)
    # No val_data is passed in the active call below.
    log = model.fit(x_train, y_train, batch_size, epochs, callbacks)

    surv = model.interpolate(10).predict_surv_df(x_test)

    # Same evaluation call as the federated cell.
    ev = EvalSurv(surv, *y_test, censor_surv='km')
    ev.concordance_td('antolini')