## Load packages

In [None]:
import os
import pickle
import pandas as pd
import numpy as np

from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sksurv.metrics import concordance_index_censored

## Load data

In [None]:
# Folder holding every CSV / pickle file this notebook reads or writes.
datapath = "data"

# Clinical/survival table and radiomic feature table; both carry an 'RID'
# column that later cells turn into the index.
survival_csv  = f"{datapath}/MCI_survival.csv"
radiomics_csv = f"{datapath}/MCI_radiomics.csv"

df     = pd.read_csv(survival_csv)
df_rad = pd.read_csv(radiomics_csv)

## Split data

In [None]:
def _load_pickled_ids(split_name):
    """Return the object pickled in ``{datapath}/{split_name}.txt``.

    Despite the ``.txt`` extension these files are pickles.
    NOTE(review): ``pickle.load`` can execute arbitrary code -- only load
    trusted, locally produced files.
    """
    with open(f"{datapath}/{split_name}.txt", "rb") as fp:
        return pickle.load(fp)


train_ids = _load_pickled_ids("train")
test_ids  = _load_pickled_ids("test")

In [None]:
# Carve a validation split out of the training ids; the fixed seed keeps it reproducible.
_VALID_FRACTION = 0.2
_SPLIT_SEED = 333
train_ids, valid_ids = train_test_split(
    train_ids, test_size=_VALID_FRACTION, random_state=_SPLIT_SEED
)

In [None]:
# Index the clinical table by patient id, then keep the rows of each split
# (rows stay in df's own order).
df = df.set_index('RID')
df_train, df_valid, df_test = (
    df[df.index.isin(split_ids)] for split_ids in (train_ids, valid_ids, test_ids)
)

In [None]:
# Index the radiomic table by patient id so it can be selected by RID.
df_rad = df_rad.set_index('RID')

#### Select radiomic features

In [None]:
# Radiomic blocks to keep.
#   tissue: any of "GM", "WM", "CSF"
#   type:   any of "shape", "texture"
radiomic_tissue = ["GM"]
radiomic_type   = ["shape"]

In [None]:
# Positional column indices inside one tissue's radiomic block (the columns
# whose name ends with the tissue tag). The ranges are hard-coded, so they
# assume a fixed column layout in the radiomics CSV -- TODO confirm if it changes.
_RADIOMIC_TYPE_COLS = {
    "shape":   list(range(18, 32)),
    "texture": list(range(0, 18)) + list(range(32, 101)),
}

rad_filtered = []
for rt in radiomic_tissue:
    # All radiomic columns for this tissue (column name ends with e.g. "GM").
    my_df = df_rad.filter(regex=f'{rt}$', axis=1)
    for t in radiomic_type:
        # Fail loudly on an unknown type instead of re-appending a stale selection.
        if t not in _RADIOMIC_TYPE_COLS:
            raise ValueError(
                f"Unknown radiomic type {t!r}; expected one of {sorted(_RADIOMIC_TYPE_COLS)}"
            )
        my_df_t = my_df.iloc[:, _RADIOMIC_TYPE_COLS[t]]
        rad_filtered.append(my_df_t)

In [None]:
# Side-by-side merge of every selected (tissue, type) block, aligned on RID.
df_rad_filter = pd.concat(rad_filtered, axis="columns")

In [None]:
def _radiomics_for(clinical_split):
    """Return the ``df_rad_filter`` rows for ``clinical_split``'s RIDs, in that order.

    A plain ``query('RID in @ids')`` would keep df_rad's own row order, which
    need not match the clinical table's order. Later cells reorder clinical
    and radiomic arrays with the same permutation, so rows must correspond
    one-to-one.

    Raises
    ------
    KeyError
        If some clinical RID has no radiomic row.
    """
    missing = clinical_split.index.difference(df_rad_filter.index)
    if len(missing) > 0:
        raise KeyError(
            f"{len(missing)} RID(s) have no radiomic features, e.g. {list(missing[:5])}"
        )
    return df_rad_filter.loc[clinical_split.index]


df_rad_train = _radiomics_for(df_train)
df_rad_valid = _radiomics_for(df_valid)
df_rad_test  = _radiomics_for(df_test)

## Preprocessing

In [None]:
# Column groups of the clinical table.
_ID_COLS   = ['RID']
_INFO_COLS = ['M']

# Demographic columns (names suggest one-hot race / marital / ethnicity indicators).
_DEMOG_COLS = [
    'gender', 'PTEDUCAT_norm', 'currage',
    'PTRACCAT_Asian', 'PTRACCAT_Black', 'PTRACCAT_Hawaii', 'PTRACCAT_Indian', 'PTRACCAT_White',
    'PTMARRY_Divorced', 'PTMARRY_Married', 'PTMARRY_Never married', 'PTMARRY_Widowed',
    'PTETHCAT_Hisp', 'PTETHCAT_NoHisp',
]

# Cognitive test scores.
_COG_COLS = ['CDRSB', 'ADAS13', 'MMSE', 'RAVLT_learning', 'FAQ']

# Imaging / biomarker measurements.
_BIO_COLS = [
    'Ventricles_norm', 'Hippocampus_norm', 'WholeBrain_norm',
    'Entorhinal_norm', 'Fusiform_norm', 'MidTemp_norm',
    'FDG', 'AV45',
]

# Survival-label columns: event onset/offset intervals and diagnosis codes.
_LABEL_COLS = ['onset_interval', 'offset_interval', 'currdx_bl', 'currdx']

In [None]:
# Percentage of missing values in each feature column (displayed as cell output).
_feature_cols = _DEMOG_COLS + _COG_COLS + _BIO_COLS
df.loc[:, _feature_cols].isna().sum() * 100 / len(df)

#### Imputation

In [None]:
def normalize(df, desc=None, eps=1e-9):
    """Min-max scale the columns of ``df`` and return the result as a NumPy array.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to scale.
    desc : pandas.DataFrame, optional
        Per-column statistics with one row per column of ``df`` and at least
        ``'min'`` and ``'max'`` columns, e.g. ``train_df.describe().transpose()``.
        Passing training statistics scales validation/test data with the
        training range. If omitted, ``df``'s own column min/max are used.
    eps : float, default 1e-9
        Added to the denominator so constant columns do not divide by zero.

    Returns
    -------
    numpy.ndarray
        Values of ``(df - min) / (max - min + eps)``.
    """
    if desc is None:
        col_min, col_max = df.min(), df.max()
    else:
        col_min, col_max = desc['min'], desc['max']
    return ((df - col_min) / (col_max - col_min + eps)).values

In [None]:
# SimpleImputer strategies accepted by `preprocessing`; 'iterative' selects IterativeImputer.
_SIMPLE_IMPUTE_STRATEGIES = ('mean', 'median', 'most_frequent', 'constant')


def _make_imputer(impute_mode, impute_value):
    """Return an unfitted imputer for an already-validated ``impute_mode``."""
    if impute_mode == 'iterative':
        return IterativeImputer()
    # fill_value is only consulted by the 'constant' strategy.
    return SimpleImputer(strategy=impute_mode, fill_value=impute_value)


def _impute(train_features, features, impute_mode, impute_value):
    """Fit an imputer on ``train_features`` and apply it to ``features``.

    Returns a DataFrame with the same columns and index as ``features``.
    """
    imputer = _make_imputer(impute_mode, impute_value)
    imputer.fit(train_features)
    return pd.DataFrame(imputer.transform(features),
                        columns=features.columns,
                        index=features.index)


def preprocessing(df, data_type=None, impute_mode='constant', impute_value=0):
    """Build the feature matrix and survival labels for one data split.

    Parameters
    ----------
    df : pandas.DataFrame
        Clinical rows of one split, containing ``_DEMOG_COLS``, ``_COG_COLS``,
        ``_BIO_COLS`` and the first three ``_LABEL_COLS``.
    data_type : str, optional
        ``'train'`` writes the raw cognitive/biomarker features to
        ``MCI_cognitive_train.csv`` / ``MCI_biomarker_train.csv`` under
        ``datapath`` and fits the imputers on them. Any other value reads those
        files back instead, so the training call must run first.
    impute_mode : str, default 'constant'
        ``'iterative'`` (IterativeImputer) or a SimpleImputer strategy:
        ``'mean'``, ``'median'``, ``'most_frequent'``, ``'constant'``.
    impute_value : scalar, default 0
        Fill value used by the ``'constant'`` strategy.

    Returns
    -------
    all_features : pandas.DataFrame
        Demographics, imputed cognitive and biomarker features, plus a float32
        ``'target'`` column.
    times : pandas.Series
        ``'offset_interval'``; censored rows (infinite offset) are replaced by
        ``onset_interval + 730`` (a 2-year delay, in days).
    events : pandas.Series
        Boolean ``'Event'``: True where ``offset_interval`` was finite.

    Raises
    ------
    ValueError
        If ``impute_mode`` is not supported (checked before any file I/O).
    """
    if impute_mode != 'iterative' and impute_mode not in _SIMPLE_IMPUTE_STRATEGIES:
        raise ValueError(
            f"impute_mode must be 'iterative' or one of {_SIMPLE_IMPUTE_STRATEGIES}, "
            f"got {impute_mode!r}"
        )

    df = df.copy()
    demog_features = df[_DEMOG_COLS].copy()
    cog_features   = df[_COG_COLS].copy()
    bio_features   = df[_BIO_COLS].copy()
    labels         = df[_LABEL_COLS[0:3]].copy()

    # Imputers are always fitted on the (un-imputed) training features so that
    # validation/test data never leak into the imputation model.
    cog_train_path = os.path.join(datapath, 'MCI_cognitive_train.csv')
    bio_train_path = os.path.join(datapath, 'MCI_biomarker_train.csv')
    if data_type == 'train':
        cog_features.to_csv(cog_train_path, header=True, index=True, index_label=['RID'])
        bio_features.to_csv(bio_train_path, header=True, index=True, index_label=['RID'])
        train_cog_features = cog_features.copy()
        train_bio_features = bio_features.copy()
    else:
        train_cog_features = pd.read_csv(cog_train_path, index_col=['RID'])
        train_bio_features = pd.read_csv(bio_train_path, index_col=['RID'])

    imp_cog_features = _impute(train_cog_features, cog_features, impute_mode, impute_value)
    imp_bio_features = _impute(train_bio_features, bio_features, impute_mode, impute_value)

    # Binary target: 1.0 where the diagnosis code in _LABEL_COLS[2] equals 3
    # (meaning of code 3 is defined by the dataset -- confirm), else 0.0.
    labels['target'] = (labels[_LABEL_COLS[2]] == 3).values.astype(np.float32)

    all_features = (demog_features
                    .join(imp_cog_features)
                    .join(imp_bio_features)
                    .join(labels[['target']]))

    # Survival labels: an infinite offset marks a censored patient.
    onset_col, offset_col = _LABEL_COLS[0], _LABEL_COLS[1]
    censored = np.isinf(labels[offset_col])
    labels['Event'] = np.array(~censored, dtype=np.bool_)
    # Censored patients: use a 2-year (730-day) delay after the onset interval.
    labels.loc[censored, offset_col] = 730 + labels.loc[censored, onset_col]
    times = labels.pop(offset_col)
    events = labels.pop('Event')

    return all_features, times, events

In [None]:
# The training call must come first: it writes the training features that the
# validation/test calls read back to fit their imputers.
_impute_kwargs = dict(impute_mode='iterative', impute_value=0)
tr_features, tr_times, tr_events = preprocessing(df_train, data_type='train', **_impute_kwargs)
vl_features, vl_times, vl_events = preprocessing(df_valid, data_type='valid', **_impute_kwargs)
ts_features, ts_times, ts_events = preprocessing(df_test,  data_type='test',  **_impute_kwargs)

#### Normalize clinical and radiomic data

In [None]:
# Min-max scale the biomarker columns of every split using the training min/max.
tr_desc = tr_features.describe()
_bio_stats = tr_desc[_BIO_COLS].transpose()
for _split in (tr_features, vl_features, ts_features):
    _split.loc[:, _BIO_COLS] = normalize(_split[_BIO_COLS], _bio_stats)

In [None]:
# Min-max scale every radiomic split with the training split's per-column min/max.
tr_rad_desc = df_rad_train.describe()
_rad_stats = tr_rad_desc.transpose()
tr_rad_features, vl_rad_features, ts_rad_features = (
    normalize(_split, _rad_stats) for _split in (df_rad_train, df_rad_valid, df_rad_test)
)

In [None]:
# Fixed score bounds per cognitive test, in the same order as _COG_COLS; each
# score is rescaled as (x - min) / (max - min + 1e-9) with these bounds rather
# than with data-driven ones.
_COG_MAX = [18, 85, 30, 14, 30]
_COG_MIN = [0, 0, 0, -5, 0]
for cog, _lo, _hi in zip(_COG_COLS, _COG_MIN, _COG_MAX):
    _span = _hi - _lo + 1e-9
    for _split in (tr_features, vl_features, ts_features):
        _split[cog] = ((_split[cog] - _lo) / _span).values

In [None]:
# Min-max scale age using the training split's range (computed before any split is modified).
min_age = tr_features['currage'].min()
max_age = tr_features['currage'].max()
_age_span = max_age - min_age + 1e-9
for _split in (tr_features, vl_features, ts_features):
    _split['currage'] = ((_split['currage'] - min_age) / _age_span).values

#### Organize data

In [None]:
# Detach the binary 'target' column from each feature frame (pop removes it in place).
tr_target = tr_features.pop("target")
vl_target = vl_features.pop("target")
ts_target = ts_features.pop("target")

In [None]:
# Survival times become (n, 1) column arrays (the shape the scaler expects);
# event indicators become plain 1-D arrays.
tr_times, vl_times, ts_times = (
    _t.to_numpy().reshape(-1, 1) for _t in (tr_times, vl_times, ts_times)
)
tr_events, vl_events, ts_events = (
    _e.to_numpy() for _e in (tr_events, vl_events, ts_events)
)

In [None]:
# Standardize clinical features, radiomic features and survival times with
# statistics fitted on the training split only.
X_scaler = StandardScaler().fit(tr_features)
tr_features, vl_features, ts_features = (
    X_scaler.transform(_split) for _split in (tr_features, vl_features, ts_features)
)

X_rad_scaler = StandardScaler().fit(tr_rad_features)
tr_rad_features, vl_rad_features, ts_rad_features = (
    X_rad_scaler.transform(_split)
    for _split in (tr_rad_features, vl_rad_features, ts_rad_features)
)

# Times are (n, 1) columns: scale them, then flatten back to 1-D.
Y_scaler = StandardScaler().fit(tr_times.reshape(-1, 1))
tr_times, vl_times, ts_times = (
    Y_scaler.transform(_split).flatten() for _split in (tr_times, vl_times, ts_times)
)

In [None]:
def _order_split(clinical_split, rad_split, features, target, times, events, rad_features):
    """Reorder one split by descending survival time while keeping rows aligned.

    Parameters
    ----------
    clinical_split : pandas.DataFrame
        Clinical split whose RID index gives the row order of ``features``,
        ``target``, ``times`` and ``events``.
    rad_split : pandas.DataFrame
        Radiomic split whose RID index gives the row order of ``rad_features``.
    features, times, events, rad_features : numpy.ndarray
    target : pandas.Series

    Returns
    -------
    tuple
        ``(order, ids, features, target, times, events, rad_features)``.
        ``order`` is the permutation applied, and ``ids`` are the RIDs in that
        order. Taking them from the split's own index (not from the shuffled
        ``train_test_split`` output) keeps ids matched to their rows.

    Raises
    ------
    KeyError
        If a clinical RID has no radiomic row.
    """
    order = np.argsort(times)[::-1]
    # Position of each clinical RID in the radiomic arrays, so the radiomic rows
    # follow the clinical order even if the source tables were ordered differently.
    rad_pos = rad_split.index.get_indexer(clinical_split.index)
    if (rad_pos < 0).any():
        raise KeyError("Some RIDs in the clinical split have no radiomic features.")
    ids = clinical_split.index.to_numpy()
    return (order, ids[order], features[order], target.values[order],
            times[order], events[order], rad_features[rad_pos][order])


(sort_tr_idx, train_ids, tr_features, tr_target, tr_times, tr_events,
 tr_rad_features) = _order_split(df_train, df_rad_train, tr_features, tr_target,
                                 tr_times, tr_events, tr_rad_features)

(sort_vl_idx, valid_ids, vl_features, vl_target, vl_times, vl_events,
 vl_rad_features) = _order_split(df_valid, df_rad_valid, vl_features, vl_target,
                                 vl_times, vl_events, vl_rad_features)

(sort_ts_idx, test_ids, ts_features, ts_target, ts_times, ts_events,
 ts_rad_features) = _order_split(df_test, df_rad_test, ts_features, ts_target,
                                 ts_times, ts_events, ts_rad_features)

In [None]:
def _censored_pct(events):
    """Percentage of samples whose event indicator is False (right censored)."""
    return np.sum(~events) * 100. / len(events)


print("%.2f%% samples are right censored in training data." % _censored_pct(tr_events))
print("%.2f%% samples are right censored in validation data." % _censored_pct(vl_events))
print("%.2f%% samples are right censored in test data." % _censored_pct(ts_events))

In [None]:
# Sanity check: exp(-t) is a risk score derived from the true times themselves
# (shorter time -> higher risk), so this is an upper-bound / oracle c-index.
_oracle_risk = np.exp(-ts_times)
cindex = concordance_index_censored(ts_events, ts_times, _oracle_risk)
print(f"Concordance index on test data with actual risk scores: {cindex[0]:.3f}")

## Create model

#### Unimodal

In [None]:
import models as M

In [None]:
model_name  = 'SPAN'
input_name  = 'clinical' # 'clinical' or 'radiomic'

# Training matrix of the selected modality: the hidden width must be derived
# from the features the model actually receives.
_uni_train_inputs = {'clinical': tr_features, 'radiomic': tr_rad_features}
if input_name not in _uni_train_inputs:
    raise ValueError(
        f"input_name must be one of {sorted(_uni_train_inputs)}, got {input_name!r}"
    )

uni_paras = dict(
    # NOTE(review): this is a float; confirm M.ADSurv converts it to an integer width.
    num_hidden=_uni_train_inputs[input_name].shape[-1]*4.7619,
    num_outputs=1,
    num_layers=1,
    dropout=0.1416,
    kernel_initializer="glorot_uniform",
    l2_regularizer = 3.1416*1e-2,
    encode_model=model_name,
    hidden_activation='selu',
    att_activation='softmax',
    output_activation='linear',
    input_name=f'{input_name}_features',
    name=model_name,
)

uni_model = M.ADSurv(**uni_paras)

#### Multimodal

In [None]:
multi_name  = 'SPAN'
input_names = ['clinical', 'radiomic']

# Combined input width across both modalities; the hidden size is derived from it.
_n_multi_inputs = tr_features.shape[-1] + tr_rad_features.shape[-1]

multi_paras = {
    # NOTE(review): float value; confirm M.MultiClinRad expects/handles a float width.
    'num_hidden': _n_multi_inputs * 14.2857 / 2,
    'num_outputs': 1,
    'dropout': 0.21416,
    'kernel_initializer': "glorot_uniform",
    'l2_regularizer': 3.1416*1e-2,
    'encode_model': multi_name,
    'hidden_activation': 'selu',
    'att_activation': 'softmax',
    'output_activation': 'linear',
    'input_names': [f'{n}_features' for n in input_names],
    'fusion_type': 'weight',
    'name': multi_name,
}

multi_model = M.MultiClinRad(**multi_paras)