# CIb LSx: ONTRAM 3D CNN
## Outcome: stroke

### Load dependencies

In [None]:
# Print the version of the Python interpreter found on the kernel's PATH (IPython shell escape)
!python -V

In [None]:
# !python -m pip install -U scikit-image

In [None]:
import os
import h5py
import pandas as pd
import numpy as np
import random
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
from scipy import ndimage
from sklearn import metrics
from sklearn import linear_model

# Tensorflow/Keras
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
print(keras.__version__)
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from keras.utils import to_categorical

# Own functions
from functions.plot_slices import plot_slices
from functions.ontram import ontram
from functions.fit_ontram import fit_ontram
from functions.fit_ontram_batches import fit_ontram_batches
from functions.plot_results import plot_results
from functions.methods import predict

### Config Variables

In [None]:
# Outcome to model; "stroke" and "mrs" are the two values handled when Y is selected below
OUTPUT_VARIABLE = "stroke"
# OUTPUT_VARIABLE = "mrs"
# Number of models trained per fold (loop variable `run` in the training cell)
N_ENSEMBLES = 5
# Number of cross-validation folds the class-wise index lists are split into
N_FOLDS = 5

In [None]:
# Project root; every other location is derived from it
DIR = "/tf/notebooks/katrin/"
# Where this notebook writes its results (histories, predictions, weights)
OUTPUT_DIR = DIR + 'results/stroke/ensemble/'
# Where the previously trained CIb models are read from
MODEL_DIR = DIR + 'results/stroke/ensemble/CIb/'
# Input data: HDF5 file with the image volumes and CSV with the tabular data
INPUT_IMG = DIR + "data/dicom_3d_128x128x30.h5"
INPUT_TAB = DIR + "data/baseline_data_DWI_imputed.csv"

### Import images

In [None]:
def decode_data(string):
    """Decode an iterable of byte strings into a list of ``str``.

    Every element is decoded as UTF-8; bytes that cannot be decoded are
    dropped (``errors="ignore"``). The result is a new list in input order.
    """
    return [raw.decode("UTF-8", errors="ignore") for raw in string]

In [None]:
# Read the image volumes ("X"), the two label arrays ("stroke", "Y") and the
# patient ids ("pat", stored as bytes and decoded to str) from the HDF5 file.
with h5py.File(INPUT_IMG, "r") as h5:
    print(h5.keys())
    X, Y_pat, Y_img = (h5[key][:] for key in ("X", "stroke", "Y"))
    pat = list(decode_data(h5["pat"]))

In [None]:
# Sanity check: shapes of the loaded arrays and number of patient ids
print(X.shape, Y_pat.shape, Y_img.shape, len(pat))

In [None]:
# Summary statistics of the image intensities before standardization
print(X.shape, X.min(), X.max(), X.mean(), X.std())

### Preprocessing
- standardize each patient's image volume to zero mean and unit variance

In [None]:
def standardize(array):
    """Scale *array* to zero mean and unit standard deviation.

    Mean and standard deviation are computed over all elements of *array*.

    A (numerically) constant input has no spread to normalise; dividing by
    its standard deviation would fill the result with NaN/inf. In that case
    an all-zero array of the same shape is returned instead.

    Parameters
    ----------
    array : array_like
        Values to standardize.

    Returns
    -------
    numpy.ndarray
        ``(array - mean) / sd``, or zeros if ``sd`` is (numerically) zero.
    """
    mean = np.mean(array)
    sd = np.std(array)
    centered = array - mean
    # Tolerance relative to the magnitude of the data so that rounding noise
    # in a constant array is not blown up to values of order one.
    if sd <= np.finfo(float).eps * max(abs(mean), 1.0):
        return np.zeros_like(centered)
    return centered / sd

In [None]:
# Standardize every volume on its own, then insert a singleton axis at
# position 4 of the stacked array.
X = np.expand_dims(np.asarray([standardize(volume) for volume in X]), axis = 4)

In [None]:
# Summary statistics after standardization and after adding the extra axis
print(X.shape, X.min(), X.max(), X.mean(), X.std())

### Import tabular data

In [None]:
# Display the path of the tabular data file
INPUT_TAB

In [None]:
# Load the tabular baseline data (comma-separated) and show the first rows
dat = pd.read_csv(INPUT_TAB, sep = ',')
dat.head(3)

In [None]:
# change values to numbers:
#   'no' -> 0 and 'yes' -> 1 in every column,
#   sex: 'female' -> 1, 'male' -> 0,
#   event: 'Stroke' -> 1, 'TIA' -> 0
dat = dat.replace('no', 0)
dat = dat.replace('yes', 1)
dat["sex"] = dat["sex"].replace('female', 1)
dat["sex"] = dat["sex"].replace('male', 0)
dat["event"] = dat["event"].replace('Stroke', 1)
dat["event"] = dat["event"].replace('TIA', 0)
# Zero-pad the patient ids to three digits; they are later compared with the
# string ids read from the image file. The loop variable is named `pid` so
# that the builtin `id` is not shadowed.
dat["p_id"] = [format(pid, '03d') for pid in dat["p_id"]]
dat.head(3)

In [None]:
# List the columns (variables) available in the tabular data
dat.columns

In [None]:
# Binary 3-month mRS: 0 where mrs_3months <= 2, 1 for every other row
dat["mrs_3months_binary"] = (~(dat.mrs_3months <= 2)).astype("int64")
# Histogram of the event column with two bins
plt.hist(dat.event, bins = 2)

In [None]:
# match tabular data to image data:
# for every image (in the order of `pat`) look up the patient's row in `dat`
# and collect the 12 tabular features plus both outcome labels.
n_images = X.shape[0]
X_tab = np.zeros((n_images, 12))
Y_mrs = np.zeros(n_images)
# One label per image. This replaces the Y_pat read from the HDF5 file and is
# sized by the number of images (not by the number of table rows), so that
# X, X_tab and the label vectors always share their first dimension.
Y_pat = np.zeros(n_images)
for i, p in enumerate(pat):
    matches = np.where(dat.p_id.values == p)[0]
    if matches.size == 0:
        raise ValueError("patient id {} not found in tabular data".format(p))
    # first matching row is used
    k = matches[0]
    dat_tmp = dat.iloc[k]
    X_tab[i,:] = np.array([dat_tmp.age, dat_tmp.sex, dat_tmp.mrs_before, dat_tmp.nihss_baseline, 
                           dat_tmp.stroke_before, dat_tmp.tia_before, dat_tmp.rf_hypertonia, 
                           dat_tmp.rf_diabetes, dat_tmp.rf_hypercholesterolemia, dat_tmp.rf_smoker, 
                           dat_tmp.rf_atrial_fibrillation, dat_tmp.rf_chd])
    Y_mrs[i] = dat_tmp.mrs_3months_binary
    Y_pat[i] = dat_tmp.event
X_tab

In [None]:
# Pick the outcome vector that corresponds to OUTPUT_VARIABLE
if OUTPUT_VARIABLE not in ("stroke", "mrs"):
    raise ValueError("unknown OUTPUT_VARIABLE: {}".format(OUTPUT_VARIABLE))
Y = Y_pat if OUTPUT_VARIABLE == "stroke" else Y_mrs

### Augmentation

In [None]:
# Plot one example volume (index j) as a visual check
j = 10
# NOTE(review): the whole `pat` list is passed, not pat[j] -- confirm against the signature of plot_slices
plot_slices(X[j], pat, "axial", modality = "DWI")

In [None]:
# zoom
def random_zoom3d(X_im, min_zoom, max_zoom):
    """Apply a random isotropic scaling to the first three axes of *X_im*.

    A factor ``z`` is drawn from ``[min_zoom, max_zoom)`` and placed on the
    first three diagonal entries of a 4x4 matrix (last entry 1), which is
    handed to ``ndimage.affine_transform`` with linear interpolation
    (``order=1``) and ``mode="nearest"``.

    NOTE(review): for a 4-D input the 4x4 matrix is presumably treated as a
    plain linear map without offset, i.e. the scaling is anchored at index 0
    and not at the volume centre -- confirm against the SciPy documentation.
    """
    z = np.random.sample() * (max_zoom - min_zoom) + min_zoom
    zoom_matrix = np.diag([z, z, z, 1.0])
    return ndimage.affine_transform(X_im, zoom_matrix, order = 1, mode = "nearest")

In [None]:
# rotate
def random_rotate3d(X_im, min_angle_xy, max_angle_xy, min_angle_xz, max_angle_xz, min_angle_yz, max_angle_yz):
    """Rotate *X_im* by a random angle in one randomly chosen plane.

    One angle is drawn per plane from its ``[min, max)`` range, then one of
    the three planes -- axes (0, 1), (0, 2) or (1, 2) -- is selected and only
    that rotation is applied. The output keeps the input shape
    (``reshape=False``); cubic spline interpolation (``order=3``) and
    ``mode="nearest"`` are used.
    """
    # All three angles are always drawn, so the random stream is consumed the
    # same way whichever plane ends up being selected.
    angle_xy = np.random.uniform(min_angle_xy, max_angle_xy)
    angle_xz = np.random.uniform(min_angle_xz, max_angle_xz)
    angle_yz = np.random.uniform(min_angle_yz, max_angle_yz)
    candidates = ((angle_xy, (0, 1)), (angle_xz, (0, 2)), (angle_yz, (1, 2)))
    angle, plane = candidates[np.random.choice([0, 1, 2])]
    return ndimage.rotate(X_im, angle = angle, axes = plane, mode = "nearest", reshape = False, order = 3)

In [None]:
# shifting
def random_shift3d(X_im, min_shift_x, max_shift_x, min_shift_y, max_shift_y, min_shift_z, max_shift_z):
    """Translate *X_im* by random offsets along its first three axes.

    One offset per axis is drawn from the corresponding ``[min, max)`` range;
    the fourth axis is not shifted. ``ndimage.shift`` is called with
    ``order=0`` and ``mode="nearest"``.
    """
    ranges = ((min_shift_x, max_shift_x),
              (min_shift_y, max_shift_y),
              (min_shift_z, max_shift_z))
    # drawn in x, y, z order
    offsets = [np.random.uniform(low, high) for low, high in ranges]
    offsets.append(0)
    return ndimage.shift(X_im, offsets, order = 0, mode = "nearest")

In [None]:
# flip
def random_flip3d(X_im):
    """Randomly reverse *X_im* along axis 1.

    ``np.random.choice([0, 1])`` is drawn once; on 0 the input is returned
    reversed along its second axis, otherwise it is returned unchanged.
    *X_im* must have four axes.
    """
    if np.random.choice([0, 1]) != 0:
        return X_im
    # the original author labels this a vertical flip
    return X_im[:, ::-1, :, :]

In [None]:
# smoothing
def random_gaussianfilter3d(X_im, sigma_max):
    """Smooth *X_im* with a Gaussian filter of random width.

    The standard deviation is drawn from ``[0, sigma_max)`` and passed as a
    single value to ``ndimage.gaussian_filter`` (``mode="nearest"``).
    """
    blur_sigma = np.random.uniform(0, sigma_max)
    smoothed = ndimage.gaussian_filter(X_im, sigma = blur_sigma, mode = "nearest")
    return smoothed

In [None]:
# combine augmentation functions:
def augment_batch(X_batch):
    """Return an augmented copy of *X_batch*.

    Every element along the first axis is passed, in this order, through
    random zoom (0.7 to 1.4), random rotation (+-30 in the (0, 1) plane,
    +-10 in the other two planes), random shift (+-20, +-20, +-5), random
    flip and random Gaussian smoothing (sigma up to 0.2). The input batch is
    not modified; the result has the same shape and dtype.
    """
    pipeline = (
        lambda vol: random_zoom3d(vol, 0.7, 1.4),
        lambda vol: random_rotate3d(vol, -30, 30, -10, 10, -10, 10),
        lambda vol: random_shift3d(vol, -20, 20, -20, 20, -5, 5),
        random_flip3d,
        lambda vol: random_gaussianfilter3d(vol, 0.2),
    )
    augmented = np.empty_like(X_batch)
    for idx in range(X_batch.shape[0]):
        volume = X_batch[idx]
        for step in pipeline:
            volume = step(volume)
        augmented[idx] = volume
    return augmented

### Define train validation test set

In [None]:
## get TIA = 0 and stroke = 1 indices
idx_0 = np.where(Y == 0)
idx_1 = np.where(Y == 1)
print("{} TIA patients".format(len(idx_0[0])))
print("{} stroke patients".format(len(idx_1[0])))

## shuffle indices (in place, reproducibly)
np.random.seed(2021)
np.random.shuffle(idx_0[0])
np.random.shuffle(idx_1[0])

## split the indices of each class into N_FOLDS parts
splits_0 = np.array_split(idx_0[0], N_FOLDS)
splits_1 = np.array_split(idx_1[0], N_FOLDS)

## define chosen splits for each fold, derived from N_FOLDS:
## fold f is tested on split f and validated on the next split (wrapping around);
## both are removed from the training data.
## For N_FOLDS = 5 this gives test [0, 1, 2, 3, 4], valid [1, 2, 3, 4, 0] and
## train_folds [[0, 1], [1, 2], [2, 3], [3, 4], [0, 4]].
if N_FOLDS < 3:
    raise ValueError("N_FOLDS must be at least 3 to leave training data, got {}".format(N_FOLDS))
test_folds = list(range(N_FOLDS))
valid_folds = [(f + 1) % N_FOLDS for f in test_folds]
train_folds = [sorted(pair) for pair in zip(test_folds, valid_folds)] ## remove these splits for training data

### Define models for image data
### Complex intercept Linear Shift
Train with
- imaging and tabular data
- imaging data as complex intercept
- tabular data as linear shift
- outcome = stroke
- Ensemble with 5 Models
- 5-Fold CV
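- "Warm start": the image network is initialised with the weights of the trained CIb model (same fold and run), and the linear shift with the logistic-regression coefficients of the fold plus Gaussian noise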

In [None]:
# linear shift
def linear_shift_x(x):
    """Build the linear-shift network for the tabular data.

    The model consists of an input named ``x_in`` whose shape is taken from
    ``x.shape[1:]`` and a single ``Dense(1)`` layer named ``x_out`` with
    linear activation and no bias.
    """
    tab_input = keras.Input(shape = x.shape[1:], name = 'x_in')
    shift_layer = layers.Dense(1, activation = 'linear', use_bias = False, name = 'x_out')
    return keras.Model(inputs = tab_input, outputs = shift_layer(tab_input))

# complex shift for image
def complex_intercept_b(input_shape, output_shape, input_name, activation = "linear"):
    """Build the 3D CNN used as complex intercept for the image data.

    Architecture: five stages of Convolution3D (3x3x3, 'same', ReLU) ->
    BatchNormalization -> MaxPooling3D with 32, 32, 64, 64 and 128 filters,
    followed by Flatten, two blocks of Dense(128, ReLU) -> BatchNormalization
    -> Dropout(0.3), and a final Dense(output_shape) whose output is
    multiplied by 0.1.

    Parameters
    ----------
    input_shape : shape of one input sample, passed to ``keras.Input``
    output_shape : number of units of the output layer
    input_name : name given to the input layer
    activation : activation of the output layer (default "linear")

    Returns
    -------
    keras.Model mapping the image input to the scaled output.
    """
    # One seeded He-normal initializer instance is shared by all conv and
    # hidden dense layers.
    initializer = keras.initializers.HeNormal(seed = 2802)

    # (number of filters, pooling window) per convolutional stage;
    # the first stage does not pool along the third axis -- possibly switch to (2, 2, 2)
    conv_stages = [(32, (2, 2, 1)),
                   (32, (2, 2, 2)),
                   (64, (2, 2, 2)),
                   (64, (2, 2, 2)),
                   (128, (2, 2, 2))]

    in_ = keras.Input(shape = input_shape, name = input_name)
    x = in_
    for n_filters, pool_window in conv_stages:
        x = layers.Convolution3D(n_filters, kernel_size=(3, 3, 3), padding = 'same',
                                 activation = 'relu', kernel_initializer = initializer)(x)
        x = layers.BatchNormalization(center=True, scale=True)(x)
        x = layers.MaxPooling3D(pool_size=pool_window)(x)

    x = layers.Flatten()(x)
    for _ in range(2):
        x = layers.Dense(128, activation = 'relu', kernel_initializer = initializer)(x)
        x = layers.BatchNormalization(center=True, scale=True)(x)
        x = layers.Dropout(0.3)(x)

    h = layers.Dense(output_shape, activation = activation)(x) # activation = linear

    # scale down the output to get rid of too large values for h
    out_ = layers.Lambda(lambda t: t * 0.1)(h)

    return keras.Model(inputs = in_, outputs = out_)

# Logistic regression on the tabular features; it is fitted on the training
# data of each fold and its coefficients initialise the linear shift weights.
logreg_model = linear_model.LogisticRegression(max_iter=1000)

### Train Models

In [None]:
# Train N_ENSEMBLES models per cross-validation fold. Per fold: build the
# train/validation/test split and fit the logistic regression that initialises
# the linear shift. Per run: load the pre-trained image network, train the
# combined model, restore the epoch with the lowest validation loss, and append
# training history, test predictions and linear-shift coefficients to CSV files.
for fold in range(N_FOLDS):
    
    ## define train, test and validation splits
    test_idx = np.concatenate((splits_0[test_folds[fold]], splits_1[test_folds[fold]]), axis = None)
    valid_idx = np.concatenate((splits_0[valid_folds[fold]], splits_1[valid_folds[fold]]), axis = None)

    ## training indices: every split except this fold's test and validation split.
    ## The remaining splits are concatenated directly; going through np.delete
    ## would first convert the list of (possibly unequal-length) splits into one
    ## array, which recent NumPy versions reject as a ragged array.
    held_out = set(train_folds[fold])
    train_0 = np.concatenate([split for s, split in enumerate(splits_0) if s not in held_out])
    train_1 = np.concatenate([split for s, split in enumerate(splits_1) if s not in held_out])
    
    train_idx = np.concatenate((train_0, train_1), axis = None)
    
    X_im_train = X[train_idx]
    X_im_test = X[test_idx]
    X_im_valid = X[valid_idx]
    
    X_tab_train = X_tab[train_idx]
    X_tab_test = X_tab[test_idx]
    X_tab_valid = X_tab[valid_idx]
    
    Y_train = Y[train_idx]
    Y_test = Y[test_idx]
    Y_valid = Y[valid_idx] 
    
    ## logistic regression on the tabular training data; its coefficients are
    ## the starting point for the linear shift weights of every run in this fold
    n_features = X_tab_train.shape[1]
    logreg_model.fit(X_tab_train, Y_train)
    LSx_weights = logreg_model.coef_.reshape(n_features, 1)
    
    Y_train = to_categorical(Y_train)
    Y_valid = to_categorical(Y_valid)
    Y_test = to_categorical(Y_test)

    for run in range(N_ENSEMBLES):
    
        ## create output directory
        folder_name = "CIb_LSx/fold_{}/run_{}/".format(fold, run)
        os.makedirs(OUTPUT_DIR + folder_name, exist_ok = True)
       
        print("training fold {}/{}, run {}/{}".format(fold+1, N_FOLDS, run+1, N_ENSEMBLES))
    
        ## compile and fit model
        nn_bl = complex_intercept_b(X_im_train.shape[1:], Y_train.shape[1]-1, "bl_in", "linear")
        ## load weights from trained model in CIb, same fold, same run
        nn_bl.load_weights('{}fold_{}/best_model_run{}.hdf5'.format(MODEL_DIR, fold, run))
        
        nn_x = linear_shift_x(X_tab_train)
        ## set weights from logistic regression with corresponding fold + noise
        ## (seeded per run, so the same run index gets the same noise in every fold)
        np.random.seed(1234 + run)
        noise = np.random.normal(loc = 0, scale = 0.1, size = n_features)
        nn_x.set_weights([np.add(LSx_weights.flatten(), noise).reshape(n_features, 1)])
        
        ci_ls = ontram(nn_bl = nn_bl, nn_x = nn_x, response_varying = True)
        
        ## the validation split is handed to fit_ontram through its x_test/y_test arguments;
        ## `learning_rate` is the supported argument name (`lr` is a deprecated alias)
        hist = fit_ontram(ci_ls,
                          y_train = Y_train,
                          x_train = X_tab_train,
                          x_train_im = X_im_train,
                          x_test = X_tab_valid,
                          x_test_im = X_im_valid, 
                          y_test = Y_valid,
                          batch_size = 32,
                          epochs = 200,
                          optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001),
                          augment_batch = augment_batch,
                          balance_batches = True,
                          output_dir = OUTPUT_DIR + folder_name)

        ## the very first model creates the CSV files (with header), all others append
        first_write = (run == 0 and fold == 0)
        csv_mode = 'w' if first_write else 'a'

        ## save training loss and accuracy
        out = pd.DataFrame({'fold': fold,
                            'run': run,
                            'train_loss': hist["train_loss"], 
                            'train_acc': hist["train_acc"],
                            'test_loss': hist["test_loss"], 
                            'test_acc': hist["test_acc"]})
        out.to_csv("{}CIb_LSx/ensemble_history.csv".format(OUTPUT_DIR), 
                   mode = csv_mode, header = first_write, index = False)

        ## save best model: epoch (0-based row index) with the lowest validation loss
        ## NOTE(review): assumes fit_ontram names its checkpoints model-<index>.hdf5 with the
        ## same 0-based index -- confirm against fit_ontram
        best_model = np.where(out.test_loss == np.min(out.test_loss))[0][0]
        print('best model run {}: {}'.format(run, best_model))
        ci_ls.model.load_weights('{}{}model-{:03d}.hdf5'.format(OUTPUT_DIR, folder_name, best_model))
        ci_ls.model.save_weights('{}CIb_LSx/fold_{}/best_model_run{}.hdf5'.format(OUTPUT_DIR, fold, run))
        
        # predict model on the held-out test split
        pred = predict(ci_ls, bl = X_im_test, x = X_tab_test, y = Y_test)
        out = pd.DataFrame({'pid': np.array(pat)[test_idx],
                            'fold': fold,
                            'run': run,
                            'pred_prob_tia': pred["pdf"][:, 0],
                            'pred_prob_stroke': pred["pdf"][:, 1],
                            'pred_label_stroke': pred["response"],
                            'patient_label_tia': Y_test[:, 0],
                            'patient_label_stroke': Y_test[:, 1]})
        out.to_csv("{}CIb_LSx/ensemble_predictions.csv".format(OUTPUT_DIR), 
                   mode = csv_mode, header = first_write, index = False)
            
        ## save model weights (one coefficient per tabular feature, in the column order of X_tab)
        names = ['age', 'sex', 'mrs_before', 'nihss_baseline', 'stroke_before', 
                     'tia_before', 'rf_hypertonia', 'rf_diabetes', 'rf_hypercholesterolemia', 
                     'rf_smoker', 'rf_atrial_fibrillation', 'rf_chd']
        weights = np.array(pred['beta_w']).flatten()
        out = pd.DataFrame({'fold': fold,
                            'run': run,
                            'names': names,
                            'coef': weights})
        out.to_csv("{}CIb_LSx/ensemble_weights.csv".format(OUTPUT_DIR), 
                   mode = csv_mode, header = first_write, index = False)

In [None]:
# Read back the test-set predictions written during training (all folds and runs)
pred = pd.read_csv("{}CIb_LSx/ensemble_predictions.csv".format(OUTPUT_DIR))
pred.head(5)

In [None]:
# Read back the linear-shift coefficients written during training
weights = pd.read_csv("{}CIb_LSx/ensemble_weights.csv".format(OUTPUT_DIR))
weights.head(5)

In [None]:
# Read back the per-epoch training history written during training
hist = pd.read_csv("{}CIb_LSx/ensemble_history.csv".format(OUTPUT_DIR))
hist.head(5)