In [None]:
# Use Theano as the Keras backend.
# With Theano, many Keras models can be built in one session without the slowdown seen with TensorFlow.

from keras import backend as K
import os
from importlib import reload

def set_keras_backend(backend):
    """Make ``backend`` the active Keras backend if it is not already.

    Sets the ``KERAS_BACKEND`` environment variable and reloads the
    ``keras.backend`` module (imported above as ``K``) so the change takes
    effect in the running kernel. An AssertionError is raised if the backend
    reported after the reload is still not ``backend``.
    """
    if K.backend() == backend:
        # The requested backend is already active; nothing to do.
        return

    os.environ['KERAS_BACKEND'] = backend
    reload(K)
    assert K.backend() == backend

set_keras_backend("theano")

In [None]:
# Edit this path so that it points to your local checkout of the TransferSRNN directory.
# The data paths used in later cells (e.g. 'binding/...') are relative, so they resolve against this directory.

%cd /home/nolelin/TransferSRNN

In [None]:
import numpy as np
from keras.models import Sequential, Model
from keras import optimizers, layers, regularizers
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Input, Dense, Activation
from keras.models import load_model
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
from nnet_survival import *

In [None]:
import numpy as np
import pandas as pd

from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.optimizers import SGD, RMSprop
import theano.tensor as T

from lifelines.utils import concordance_index
from lifelines import CoxPHFitter

from lifelines.datasets import load_rossi
from sklearn import preprocessing

In [None]:
# Collect the distinct disease types in the feature table; these are the diseases
# for which baseline and transfer models are compared later in the notebook.

df = pd.read_csv('binding/all_features/all_features.csv')
disease_names = list(df['disease_type'].unique())
disease_names

In [None]:
# Display the feature table loaded in the previous cell for a quick visual check
df

In [None]:
# Map every disease type to the list of case_ids of its patients

disease_caseids = {}
df2 = pd.read_csv('binding/all_features/all_features.csv')
for disease in disease_names:
    belongs_to_disease = df2['disease_type'] == disease
    diseasedf = df2.loc[belongs_to_disease]
    disease_caseids[disease] = list(diseasedf['case_id'])

In [None]:
# For each of 10 runs, draw a random 80% of every disease's patients for
# training/validation; the remaining patients are that run's test set.

from random import shuffle

runs_constrained_caseids = {}
for run in range(10):
    constrained_caseids = []
    for disease, caseids in disease_caseids.items():
        # Shuffles the stored list in place, then keeps its first 80%.
        shuffle(caseids)
        num_patients = len(caseids)
        constrained_caseids += caseids[:int(0.8 * num_patients)]
    runs_constrained_caseids[run] = constrained_caseids

In [None]:
# Clean and preprocess data for global model
#
# For each of the 10 runs this builds train / validation / test feature matrices,
# the survival target arrays (via make_surv_array) and the event indicators for
# the pooled, all-disease cohort, and stores them in the runs_* dictionaries
# keyed by run index. A run that fails the event-fraction / size filter below is
# left out of the dictionaries.

import numpy as np
import sys
import numpy.ma as ma
from sklearn import preprocessing
from sklearn.impute import SimpleImputer


def _parse_feature_lists(column):
    """Parse stringified lists such as "[0.1, NA, 2.0]" into a 2-D float array ('NA' becomes NaN)."""
    return np.asarray([[float(tok) if tok != 'NA' else np.nan for tok in cell[1:-1].split(", ")]
                       for cell in list(column)])


def _continuous_block(frame):
    """Return (expression features + mutation load + days_to_birth, raw mutation matrix) for ``frame``."""
    mut = _parse_feature_lists(frame['mut_features'])
    mut_load = np.sum(mut, axis=1).reshape(frame.shape[0], 1)
    exp = _parse_feature_lists(frame['exp_features'])
    dtb = np.expand_dims(np.asarray(frame['days_to_birth']), axis=1)
    return np.hstack((exp, mut_load, dtb)), mut


def _times_and_events(frame, output_cols):
    """Return (times, events): the first output column when present (event=1), otherwise the second (event=0)."""
    days = np.asarray(frame[output_cols])
    times = np.asarray([row[0] if not np.isnan(row[0]) else row[1] for row in days])
    events = np.asarray([0 if np.isnan(row[0]) else 1 for row in days])
    return times, events


def _build_global_split(frame, imp, scaler, disease_cols, output_cols, breaks, keep_cols=None):
    """Build (x, y, events, keep_cols) for one split of the pooled cohort.

    ``keep_cols`` is the boolean mask of continuous columns to keep. When it is
    None the mask is derived from this split (columns that are not all zero
    after scaling); pass the training mask for validation/test so that all
    splits end up with the same feature columns.
    """
    cont, mut = _continuous_block(frame)
    cont = scaler.transform(imp.transform(cont))
    if keep_cols is None:
        keep_cols = ~np.all(cont == 0, axis=0)
    x = np.hstack((cont[:, keep_cols], mut, frame[disease_cols]))
    times, events = _times_and_events(frame, output_cols)
    y = make_surv_array(times, [True if e == 1 else False for e in events], breaks)
    return x, y, events, keep_cols


runs_x_train = {}
runs_x_test = {}
runs_x_val = {}
runs_y_train = {}
runs_y_test = {}
runs_y_val = {}
runs_e_train = {}
runs_e_test = {}
runs_e_val = {}
runs_n_intervals = {}
runs_breaks = {}

# ---- Work that does not depend on the run: done once instead of once per run ----
output_cols = ['days_to_death', 'days_to_followup']
all_features = pd.read_csv('binding/all_features/all_features.csv')
all_features = all_features.dropna(subset=output_cols, how='all')

disease_dfs = []
for disease in disease_names:
    # .copy() so that adding the disease_code column does not write into a slice of all_features
    disease_df = all_features.loc[all_features['case_id'].isin(disease_caseids[disease])].copy()
    disease_df['disease_code'] = disease
    disease_dfs.append(disease_df)
pooled_df = pd.concat(disease_dfs)
pooled_df = pd.get_dummies(pooled_df, columns=['disease_code'])
disease_cols = [x for x in pooled_df.columns if 'disease_code' in x]

# NOTE(review): the imputer and scaler are fit on ALL patients, including those that
# end up in a run's test split. Kept as in the original analysis; confirm this is intended.
imp = SimpleImputer(missing_values=np.nan, strategy='mean')
x, _ = _continuous_block(pooled_df)
x = np.where(np.isnan(x), ma.array(x, mask=np.isnan(x)).mean(axis=0), x)
imp.fit(x)
x = imp.transform(x)
scaler = preprocessing.StandardScaler().fit(x)

# Interval boundaries, 200 days wide, spanning the observed times
y_whole, _ = _times_and_events(pooled_df, output_cols)
breaks = np.arange(min(y_whole), max(y_whole), 200)
n_intervals = len(breaks) - 1

# ---- Per-run split ----
for run in range(10):
    print(run)
    in_run = pooled_df['case_id'].isin(runs_constrained_caseids[run])
    test = pooled_df.loc[~in_run]
    df = pooled_df.loc[in_run].sample(frac=1)
    train_size = int(0.75 * df.shape[0])
    train = df.head(train_size)
    validate = df.tail(df.shape[0] - train_size)

    # The column mask is taken from the training split and reused, so the three
    # matrices always have the same width.
    x_train, y_train, e_train, keep_cols = _build_global_split(
        train, imp, scaler, disease_cols, output_cols, breaks)
    x_val, y_val, e_val, _ = _build_global_split(
        validate, imp, scaler, disease_cols, output_cols, breaks, keep_cols)
    x_test, y_test, e_test, _ = _build_global_split(
        test, imp, scaler, disease_cols, output_cols, breaks, keep_cols)

    print(x_train.shape, x_val.shape, x_test.shape)

    # Require at least 20% observed events in every split and at least 150 training patients
    if np.sum(e_train) < 0.2 * len(e_train) or np.sum(e_val) < 0.2 * len(e_val) or np.sum(e_test) < 0.2 * len(e_test) or len(e_train) < 150:
        print("skipping run", run, "(too few events or fewer than 150 training patients)")
        continue

    runs_x_train[run] = x_train
    runs_x_test[run] = x_test
    runs_x_val[run] = x_val
    runs_y_train[run] = y_train
    runs_y_test[run] = y_test
    runs_y_val[run] = y_val
    runs_e_train[run] = e_train
    runs_e_test[run] = e_test
    runs_e_val[run] = e_val
    runs_n_intervals[run] = n_intervals
    runs_breaks[run] = breaks

In [None]:
# Tuning for global models
#
# Grid-search the hidden-layer width and learning rate of the pooled model for
# every run. For each configuration the best validation loss and the epoch at
# which it occurred are stored in runs_global_errors under the key
# (run, (learning_rate, best_epochs, unit)).

from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import Masking
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import optimizers
from keras import regularizers
import sys

runs_global_errors = {}
units = [10, 50, 100, 200, 500, 1000]
learning_rates = [0.1, 0.01, 0.001, 0.0001, 0.00001]
# Only tune runs that the preprocessing cell actually stored: a run it skipped
# has no entry in runs_x_train, and indexing it would raise KeyError.
for run in sorted(runs_x_train.keys()):
    run_x_train, run_y_train = runs_x_train[run], runs_y_train[run]
    run_validation = (runs_x_val[run], runs_y_val[run])
    run_n_intervals = runs_n_intervals[run]
    for unit in units:
        for learning_rate in learning_rates:
            model = Sequential()
            model.add(Dense(unit, activation='relu', input_shape=(run_x_train.shape[1],)))
            model.add(BatchNormalization())
            model.add(Dense(unit, activation='relu'))
            model.add(BatchNormalization())
            model.add(Dense(run_n_intervals, activation='sigmoid'))
            adam = optimizers.Adam(lr=learning_rate)
            model.compile(loss=surv_likelihood(run_n_intervals), optimizer=adam)
            early_stopping = EarlyStopping(monitor='val_loss', patience=2)
            history = model.fit(run_x_train, run_y_train, batch_size=256, epochs=1000, shuffle=True, callbacks=[early_stopping], validation_data=run_validation)
            val_losses = history.history['val_loss']
            # Epochs are 1-based: index of the lowest validation loss plus one
            best_epochs = val_losses.index(min(val_losses)) + 1
            runs_global_errors[(run, (learning_rate, best_epochs, unit))] = min(val_losses)

In [None]:
# Train global models
#
# For every tuned run, pick the configuration with the lowest validation loss
# and retrain it on train + validation data for the tuned number of epochs.
# Keys of runs_global_errors are (run, (learning_rate, best_epochs, unit)).

runs_global_models = {}
# Iterate over the runs that have tuning results, not a fixed range(10):
# a run skipped during preprocessing has no entries to choose from.
for run in sorted({key[0] for key in runs_global_errors.keys()}):
    # NaN losses (e.g. from a diverged fit) can never be the best configuration
    this_runs = [x for x in runs_global_errors.keys()
                 if run == x[0] and not np.isnan(runs_global_errors[x])]
    if not this_runs:
        raise ValueError("run %d has no tuning result with a valid validation loss" % run)
    # min() keeps the first of equally good configurations, like a strict '<' scan would
    best_run = min(this_runs, key=runs_global_errors.get)
    print(best_run, runs_global_errors[best_run])
    best_learning_rate, best_epochs, best_units = best_run[1]
    total_x = np.vstack((runs_x_train[run], runs_x_val[run]))
    total_y = np.vstack((runs_y_train[run], runs_y_val[run]))
    model = Sequential()
    model.add(Dense(best_units, activation='relu', input_shape=(total_x.shape[1],)))
    model.add(BatchNormalization())
    model.add(Dense(best_units, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(runs_n_intervals[run], activation='sigmoid'))
    adam = optimizers.Adam(lr=best_learning_rate)
    model.compile(loss=surv_likelihood(runs_n_intervals[run]), optimizer=adam)
    model.fit(total_x, total_y, batch_size=256, epochs=best_epochs, shuffle=True)
    runs_global_models[run] = model

In [None]:
# Do same data pre-processing for each individual disease
#
# A (run, disease) pair is only stored when its training split has at least 150 patients.
# NOTE: the minimum event fraction is 0.0, so no pair is excluded for having too few
# deaths; raise MIN_EVENT_FRACTION (e.g. to 0.3) to enforce such a filter.
# NOTE(review): the imputer/scaler are fit per disease here, whereas the global model
# was trained on features scaled with the pooled scaler. The transfer models reuse the
# global model's input layer, so confirm that this difference is intended.

import numpy as np
import sys
import numpy.ma as ma
from sklearn import preprocessing
from sklearn.impute import SimpleImputer

MIN_EVENT_FRACTION = 0.0
MIN_TRAIN_PATIENTS = 150


def _parse_feature_lists(column):
    """Parse stringified lists such as "[0.1, NA, 2.0]" into a 2-D float array ('NA' becomes NaN)."""
    return np.asarray([[float(tok) if tok != 'NA' else np.nan for tok in cell[1:-1].split(", ")]
                       for cell in list(column)])


def _continuous_block(frame):
    """Return (expression features + mutation load + days_to_birth, raw mutation matrix) for ``frame``."""
    mut = _parse_feature_lists(frame['mut_features'])
    mut_load = np.sum(mut, axis=1).reshape(frame.shape[0], 1)
    exp = _parse_feature_lists(frame['exp_features'])
    dtb = np.expand_dims(np.asarray(frame['days_to_birth']), axis=1)
    return np.hstack((exp, mut_load, dtb)), mut


def _times_and_events(frame, output_cols):
    """Return (times, events): the first output column when present (event=1), otherwise the second (event=0)."""
    days = np.asarray(frame[output_cols])
    times = np.asarray([row[0] if not np.isnan(row[0]) else row[1] for row in days])
    events = np.asarray([0 if np.isnan(row[0]) else 1 for row in days])
    return times, events


def _build_disease_split(frame, imp, scaler, disease_idx, n_disease_cols, output_cols, breaks, keep_cols=None):
    """Build (x, y, times, events, keep_cols) for one split of a single disease.

    The disease one-hot block has ``n_disease_cols`` columns with a 1 in column
    ``disease_idx``. ``keep_cols`` works as a column mask for the continuous
    features: None derives it from this split, otherwise the given mask
    (normally the training one) is applied so all splits share the same columns.
    """
    cont, mut = _continuous_block(frame)
    cont = scaler.transform(imp.transform(cont))
    if keep_cols is None:
        keep_cols = ~np.all(cont == 0, axis=0)
    one_hot = np.zeros((frame.shape[0], n_disease_cols))
    one_hot[:, disease_idx] = 1
    x = np.hstack((cont[:, keep_cols], mut, one_hot))
    times, events = _times_and_events(frame, output_cols)
    y = make_surv_array(times, [True if e == 1 else False for e in events], breaks)
    return x, y, times, events, keep_cols


runs_disease_x_train = {}
runs_disease_x_test = {}
runs_disease_x_val = {}
runs_disease_y_train = {}
runs_disease_y_test = {}
runs_disease_y_val = {}
runs_disease_y_test_orig = {}
runs_disease_e_train = {}
runs_disease_e_test = {}
runs_disease_e_val = {}
runs_disease_n_intervals = {}

output_cols = ['days_to_death', 'days_to_followup']
all_features = pd.read_csv('binding/all_features/all_features.csv')
all_features = all_features.dropna(subset=output_cols, how='all')

# Per-disease frame, imputer, scaler and one-hot index do not depend on the run,
# so they are prepared once here instead of once per (run, disease).
disease_prepared = {}
for disease in disease_names:
    disease_df = all_features.loc[all_features['case_id'].isin(disease_caseids[disease])]
    imp = SimpleImputer(missing_values=np.nan, strategy='mean')
    x, _ = _continuous_block(disease_df)
    x = np.where(np.isnan(x), ma.array(x, mask=np.isnan(x)).mean(axis=0), x)
    imp.fit(x)
    x = imp.transform(x)
    scaler = preprocessing.StandardScaler().fit(x)
    disease_idx = disease_cols.index('disease_code_' + disease)
    disease_prepared[disease] = (disease_df, imp, scaler, disease_idx)

for run in range(10):
    # Runs skipped by the global preprocessing cell have no interval breaks
    if run not in runs_breaks:
        print("skipping run", run, "(no global preprocessing output)")
        continue
    breaks = runs_breaks[run]
    n_intervals = len(breaks) - 1
    for disease in disease_names:
        disease_df, imp, scaler, disease_idx = disease_prepared[disease]
        in_run = disease_df['case_id'].isin(runs_constrained_caseids[run])
        test = disease_df.loc[~in_run]
        df = disease_df.loc[in_run].sample(frac=1)
        train_size = int(0.75 * df.shape[0])
        train = df.head(train_size)
        validate = df.tail(df.shape[0] - train_size)

        # Column mask from the training split is reused for validation and test
        x_train, y_train, _, e_train, keep_cols = _build_disease_split(
            train, imp, scaler, disease_idx, len(disease_cols), output_cols, breaks)
        x_test, y_test, y_test_orig, e_test, _ = _build_disease_split(
            test, imp, scaler, disease_idx, len(disease_cols), output_cols, breaks, keep_cols)
        # Raw test times are recorded even when the pair is skipped below
        runs_disease_y_test_orig[(run, disease)] = y_test_orig
        x_val, y_val, _, e_val, _ = _build_disease_split(
            validate, imp, scaler, disease_idx, len(disease_cols), output_cols, breaks, keep_cols)

        print(x_train.shape, x_val.shape, x_test.shape)

        if np.sum(e_train) < MIN_EVENT_FRACTION * len(e_train) or np.sum(e_val) < MIN_EVENT_FRACTION * len(e_val) or np.sum(e_test) < MIN_EVENT_FRACTION * len(e_test) or len(e_train) < MIN_TRAIN_PATIENTS:
            print(np.sum(e_train), len(e_train), np.sum(e_val), len(e_val), np.sum(e_test), len(e_test))
            print("skipping ", disease)
            continue

        runs_disease_x_train[(run, disease)] = x_train
        runs_disease_x_test[(run, disease)] = x_test
        runs_disease_x_val[(run, disease)] = x_val
        runs_disease_y_train[(run, disease)] = y_train
        runs_disease_y_test[(run, disease)] = y_test
        runs_disease_y_val[(run, disease)] = y_val
        runs_disease_e_train[(run, disease)] = e_train
        runs_disease_e_test[(run, disease)] = e_test
        runs_disease_e_val[(run, disease)] = e_val
        runs_disease_n_intervals[(run, disease)] = n_intervals

In [None]:
# Tuning for transfer models (global model weights fine-tuned on one disease)
# For every (run, disease) and learning rate, record the best validation loss under
# the key (run, (disease, learning_rate, best_epochs)).

import keras
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import Masking
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import optimizers
from keras import regularizers
import operator

runs_disease_trained_global_errors = {}
learning_rates = [0.1, 0.01, 0.001, 0.0001, 0.00001]
for run in range(10):
    run_diseases = [item[1] for item in runs_disease_x_train.keys() if item[0] == run]
    for disease in run_diseases:
        key = (run, disease)
        for learning_rate in learning_rates:
            # Start from a fresh copy of the run's global model (same architecture and weights)
            model = keras.models.clone_model(runs_global_models[run])
            model.set_weights(runs_global_models[run].get_weights())
            adam = optimizers.Adam(lr=learning_rate)
            model.compile(loss=surv_likelihood(runs_disease_n_intervals[key]), optimizer=adam)
            early_stopping = EarlyStopping(monitor='val_loss', patience=2)
            history = model.fit(
                runs_disease_x_train[key],
                runs_disease_y_train[key],
                batch_size=256,
                epochs=1000,
                shuffle=True,
                callbacks=[early_stopping],
                validation_data=(runs_disease_x_val[key], runs_disease_y_val[key]),
            )
            val_losses = history.history['val_loss']
            best_epochs = val_losses.index(min(val_losses)) + 1
            runs_disease_trained_global_errors[(run, (disease, learning_rate, best_epochs))] = min(val_losses)

In [None]:
# Train each transfer model using global model weights for initialization
#
# For every (run, disease), take the learning rate / epoch count with the lowest
# validation loss from tuning and fine-tune a copy of the run's global model on
# train + validation data.
# Keys of runs_disease_trained_global_errors are (run, (disease, learning_rate, best_epochs)).

runs_disease_trained_global_models = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        # NaN losses (e.g. from a diverged fit) can never be the best configuration
        this_disease_runs = [x for x in runs_disease_trained_global_errors.keys()
                             if disease == x[1][0] and run == x[0]
                             and not np.isnan(runs_disease_trained_global_errors[x])]
        if not this_disease_runs:
            raise ValueError("no tuning result with a valid validation loss for run %d, disease %s" % (run, disease))
        # min() keeps the first of equally good configurations, like a strict '<' scan would
        best_run = min(this_disease_runs, key=runs_disease_trained_global_errors.get)
        print(best_run, runs_disease_trained_global_errors[best_run])
        _, best_learning_rate, best_epochs = best_run[1]
        total_x = np.vstack((runs_disease_x_train[(run, disease)], runs_disease_x_val[(run, disease)]))
        total_y = np.vstack((runs_disease_y_train[(run, disease)], runs_disease_y_val[(run, disease)]))
        model = keras.models.clone_model(runs_global_models[run])
        model.set_weights(runs_global_models[run].get_weights())
        adam = optimizers.Adam(lr=best_learning_rate)
        model.compile(loss=surv_likelihood(runs_disease_n_intervals[(run, disease)]), optimizer=adam)
        model.fit(total_x, total_y, batch_size=256, epochs=best_epochs, shuffle=True)
        runs_disease_trained_global_models[(run, disease)] = model

In [None]:
# Evaluate every fine-tuned (transfer) model on its held-out test split

runs_disease_trained_global_loss = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        transfer_model = runs_disease_trained_global_models[key]
        test_loss = transfer_model.evaluate(runs_disease_x_test[key], runs_disease_y_test[key])
        runs_disease_trained_global_loss[key] = test_loss

In [None]:
# Tuning for baseline models
# Grid-search units and learning rate for a model trained from scratch on each
# (run, disease), using a hand-written early-stopping loop. Results are stored under
# the key (run, (disease, learning_rate, best_epochs, unit)).

from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import Masking
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import optimizers
from keras import regularizers

patience = 2
runs_disease_baseline_errors = {}
units = [10, 50, 100, 200, 500, 1000]
learning_rates = [0.1, 0.01, 0.001, 0.0001, 0.00001]
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        for unit in units:
            for learning_rate in learning_rates:
                print(run, disease, unit, learning_rate)
                model = Sequential()
                model.add(Dense(unit, activation='relu', input_shape=(runs_disease_x_train[key].shape[1],)))
                model.add(BatchNormalization())
                model.add(Dense(unit, activation='relu'))
                model.add(BatchNormalization())
                model.add(Dense(runs_disease_n_intervals[key], activation='sigmoid'))
                adam = optimizers.Adam(lr=learning_rate)
                model.compile(loss=surv_likelihood(runs_disease_n_intervals[key]), optimizer=adam)

                num_epochs = 0
                min_loss = 1000000
                num_patience = 0
                while True:
                    # One epoch of training followed by a validation check
                    model.fit(runs_disease_x_train[key], runs_disease_y_train[key], batch_size=256, epochs=1, shuffle=True, verbose=0)
                    loss = model.evaluate(runs_disease_x_val[key], runs_disease_y_val[key], verbose=0)
                    if loss < min_loss:
                        min_loss = loss
                        num_patience = 0
                    elif num_patience < patience:
                        num_patience += 1
                    else:
                        # Patience exhausted: this last epoch is not counted
                        break
                    num_epochs += 1
                # The last `patience` counted epochs did not improve, so subtract them
                runs_disease_baseline_errors[(run, (disease, learning_rate, num_epochs - patience, unit))] = min_loss

In [None]:
# Train each baseline model
#
# For every (run, disease), rebuild the best configuration found during tuning and
# train it from scratch on train + validation data.
# Keys of runs_disease_baseline_errors are (run, (disease, learning_rate, best_epochs, unit)).

runs_disease_baseline_models = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        # Ignore configurations that cannot be the best one: a NaN loss, or a best
        # epoch count of 0, which the tuning loop records when no epoch ever improved.
        this_disease_runs = [x for x in runs_disease_baseline_errors.keys()
                             if disease == x[1][0] and run == x[0]
                             and not np.isnan(runs_disease_baseline_errors[x])
                             and x[1][2] > 0]
        if not this_disease_runs:
            raise ValueError("no usable tuning result for run %d, disease %s" % (run, disease))
        # min() keeps the first of equally good configurations, like a strict '<' scan would
        best_run = min(this_disease_runs, key=runs_disease_baseline_errors.get)
        print(best_run, runs_disease_baseline_errors[best_run])
        _, best_learning_rate, best_epochs, best_units = best_run[1]
        total_x = np.vstack((runs_disease_x_train[(run, disease)], runs_disease_x_val[(run, disease)]))
        total_y = np.vstack((runs_disease_y_train[(run, disease)], runs_disease_y_val[(run, disease)]))
        model = Sequential()
        model.add(Dense(best_units, activation='relu', input_shape=(total_x.shape[1],)))
        model.add(BatchNormalization())
        model.add(Dense(best_units, activation='relu'))
        model.add(BatchNormalization())
        model.add(Dense(runs_disease_n_intervals[(run, disease)], activation='sigmoid'))
        adam = optimizers.Adam(lr=best_learning_rate)
        model.compile(loss=surv_likelihood(runs_disease_n_intervals[(run, disease)]), optimizer=adam)
        model.fit(total_x, total_y, batch_size=256, epochs=best_epochs, shuffle=True)
        runs_disease_baseline_models[(run, disease)] = model

In [None]:
# Evaluate every baseline model on its held-out test split

runs_disease_baseline_loss = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        baseline_model = runs_disease_baseline_models[key]
        test_loss = baseline_model.evaluate(runs_disease_x_test[key], runs_disease_y_test[key])
        runs_disease_baseline_loss[key] = test_loss

In [None]:
# Box plot, per disease, of (baseline test loss - transfer test loss) across runs

disease_diff = {}
for key in runs_disease_trained_global_loss.keys():
    diff = runs_disease_baseline_loss[key] - runs_disease_trained_global_loss[key]
    # key is (run, disease): group the differences by disease
    disease_diff.setdefault(key[1], []).append(diff)

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np

data = [go.Box(y=diffs, name=disease_name) for disease_name, diffs in disease_diff.items()]
py.iplot(data)

In [None]:
# Feature importance, step 1:
# loss of each transfer model on its unshuffled train + validation data

original_transfer_losses = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        total_x = np.vstack((runs_disease_x_train[key], runs_disease_x_val[key]))
        total_y = np.vstack((runs_disease_y_train[key], runs_disease_y_val[key]))
        model = runs_disease_trained_global_models[key]
        original_transfer_loss = model.evaluate(total_x, total_y)
        original_transfer_losses[key] = original_transfer_loss

In [None]:
# Feature importance, step 2 (permutation importance):
# shuffle one feature column at a time and re-evaluate the transfer model

shuffled_losses = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        total_x = np.vstack((runs_disease_x_train[key], runs_disease_x_val[key]))
        total_y = np.vstack((runs_disease_y_train[key], runs_disease_y_val[key]))
        model = runs_disease_trained_global_models[key]
        num_features = runs_disease_x_train[key].shape[1]
        idx_new_loss = {}
        for col_idx in range(num_features):
            print(run, disease, col_idx)
            # Work on a copy so that only this one column is permuted
            x_train_copy = total_x.copy()
            np.random.shuffle(x_train_copy[:, col_idx])
            idx_new_loss[col_idx] = model.evaluate(x_train_copy, total_y)
        shuffled_losses[key] = idx_new_loss

In [None]:
# Feature importance, step 3:
# per (run, disease), loss increase caused by shuffling each feature column

runs_disease_loss_diff = {}
for run in range(10):
    for disease in [item[1] for item in runs_disease_x_train.keys() if item[0] == run]:
        key = (run, disease)
        orig_loss = original_transfer_losses[key]
        num_features = runs_disease_x_train[key].shape[1]
        idx_loss_diff = {
            col_idx: shuffled_losses[key][col_idx] - orig_loss
            for col_idx in range(num_features)
        }
        runs_disease_loss_diff[key] = idx_loss_diff

In [None]:
# Get indices of features which might have been removed for having all 0s
#
# Diagnostic cell: rebuilds a feature matrix from the whole feature table (all diseases
# together) and prints which continuous columns contain only zeros after imputation and
# scaling, together with the matrix shape after each step.
# Relies on names created by earlier cells: ma, SimpleImputer, preprocessing, disease_cols.

df = pd.read_csv('binding/all_features/all_features.csv')
df = df.dropna(subset=['days_to_death', 'days_to_followup'], how='all')

# Fit the imputer and scaler on expression features + mutation load + days_to_birth
imp = SimpleImputer(missing_values=np.nan, strategy='mean')
x_mut = np.asarray([[float(y) if y != 'NA' else np.nan for y in x[1:-1].split(", ")] for x in list(df['mut_features'])])
x_exp = np.asarray([[float(y) if y != 'NA' else np.nan for y in x[1:-1].split(", ")] for x in list(df['exp_features'])])
mut_load = np.sum(x_mut, axis=1).reshape(df.shape[0], 1)
x = np.hstack((x_exp, mut_load))
x = np.hstack((x, np.expand_dims(np.asarray(df['days_to_birth']), axis=1)))
x = np.where(np.isnan(x), ma.array(x, mask=np.isnan(x)).mean(axis=0), x)
imp.fit(x)
x = imp.transform(x)
scaler = preprocessing.StandardScaler().fit(x)

# NOTE(review): `disease` is not defined in this cell; it is whatever value an earlier
# cell's loop left behind. It only selects the one-hot column appended at the end and
# does not influence the printed "Indices deleted".
disease_idx = disease_cols.index('disease_code_' + disease)

# Shuffle and take the first 75% as the training split (`validate` is not used below)
df = df.sample(frac=1)
train_size = int(0.75 * df.shape[0])
test_size = df.shape[0] - train_size
train = df.head(train_size)
validate = df.tail(test_size)
output_cols = ['days_to_death', 'days_to_followup']
dtb_train = np.expand_dims(np.asarray(train['days_to_birth']), axis=1)
mut_train = np.asarray([[float(y) if y != 'NA' else np.nan for y in x[1:-1].split(", ")] for x in list(train['mut_features'])])
mut_train_load = np.sum(mut_train, axis=1).reshape(train.shape[0], 1)
exp_train = np.asarray([[float(y) if y != 'NA' else np.nan for y in x[1:-1].split(", ")] for x in list(train['exp_features'])])
print(exp_train.shape)
x_train = np.hstack((exp_train, mut_train_load))
print(x_train.shape)
x_train = np.hstack((x_train, dtb_train))
print(x_train.shape)
x_train = imp.transform(x_train)
x_train = scaler.transform(x_train)
# Columns without any non-zero entry; these are the ones removed on the next line
print("Indices deleted", np.where(~x_train.any(axis=0))[0])
x_train = x_train[:, ~np.all(x_train == 0, axis=0)]
print(x_train.shape)
# Append the raw (unscaled) mutation features, then the disease one-hot block
x_train = np.hstack((x_train, mut_train))
print(x_train.shape)
zero_mat = np.zeros((x_train.shape[0], len(disease_cols)))
zero_mat[:, disease_idx] = np.ones(x_train.shape[0])
x_train = np.hstack((x_train, zero_mat))
print(x_train.shape)

In [None]:
# Map each column index of the model input to a readable feature name.
# Column order: expression features, mut_load, days_to_birth, mutation features, disease one-hot.
# Feat index 117 always removed when removing features with all 0s

with open('binding/exp_order_list.txt', 'r') as f:
    exp_feats = ['{}_exp'.format(line.strip()) for line in f]

with open('binding/mut_order_list.txt', 'r') as f:
    mut_feats = ['{}_mut'.format(line.strip()) for line in f]

all_feats = exp_feats + ['mut_load', 'days_to_birth'] + mut_feats + disease_cols
all_feats.pop(117)

idx_feat = dict(enumerate(all_feats))

In [None]:
# Get feature importance for each specific disease
#
# For every disease, average each feature's loss increase (shuffled_loss - orig_loss)
# over all runs available for that disease and build a table sorted by importance.

total_feat_importance_dfs = []
# Number of features is read from the first available (run, disease) entry rather than
# from one hard-coded key, which raised KeyError whenever that pair had been skipped.
feat_size = len(next(iter(runs_disease_loss_diff.values()), {}))
# Sorted so the order of diseases in the output does not vary between kernels
uniq_diseases = sorted(set(x[1] for x in runs_disease_loss_diff.keys()))
importance_col = 'importance (shuffled_loss - orig_loss)'
for disease in uniq_diseases:
    # One {column index: loss difference} dict per run of this disease
    disease_dicts = [v for (k, v) in runs_disease_loss_diff.items() if k[1] == disease]
    avg_disease_feature_loss_changes = {
        i: np.mean([disease_dict[i] for disease_dict in disease_dicts])
        for i in range(feat_size)
    }
    avg_disease_named_feature_loss_changes = {idx_feat[k]: v for (k, v) in avg_disease_feature_loss_changes.items()}
    feat_importance_df = pd.DataFrame.from_dict(avg_disease_named_feature_loss_changes, orient='index', columns=[importance_col])
    feat_importance_df['feature'] = feat_importance_df.index
    feat_importance_df = feat_importance_df.sort_values(by=importance_col, ascending=False)
    feat_importance_df['disease'] = disease
    total_feat_importance_dfs.append(feat_importance_df)

In [None]:
# Stack the per-disease importance tables into one long table and display it
total_feat_importance_df = pd.concat(total_feat_importance_dfs)
total_feat_importance_df

In [None]:
# Save the combined table; the index is omitted because the 'feature' column already holds the feature names
total_feat_importance_df.to_csv('binding/ind_disease_transfer_feature_importance.csv', index=False)


