In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Dense, Input, Activation
from tensorflow.keras.layers import BatchNormalization, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.metrics import AUC
from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras import Sequential
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.models import clone_model, save_model, load_model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.metrics import Recall, Precision
from tensorflow import math
from numpy.random import seed
from tensorflow import random

In [2]:
gene_encodings_cancer = pd.read_csv('Gene_freq_encoding_FULL.csv')
gene_encodings_healthy = pd.read_csv('Healthy_Gene_freq_encoding.csv')

In [3]:
gene_encodings_cancer.shape

(5987, 30401)

In [4]:
gene_encodings_healthy.shape

(2504, 30547)

In [5]:
all_patient_genes = pd.concat([gene_encodings_cancer, gene_encodings_healthy]).dropna(axis='columns')

In [6]:
cancer_types = set(all_patient_genes['CancerType'])

In [7]:
label_dict = {}
for i, ct in enumerate(cancer_types):
    label_dict[ct] = i
label_dict

{'KIRC-US': 0,
 'UCEC-US': 1,
 'THCA-US': 2,
 'Healthy': 3,
 'GBM-US': 4,
 'BRCA-US': 5,
 'LUSC-US': 6,
 'SKCM-US': 7,
 'LGG-US': 8,
 'PRAD-US': 9,
 'COAD-US': 10,
 'BLCA-US': 11,
 'OV-US': 12}

In [8]:
all_patient_genes = all_patient_genes.sample(frac=1)

In [9]:
gene_labels = [label for label in all_patient_genes.columns if label.startswith("ENSG")]
len(gene_labels)

30085

In [10]:
datax = all_patient_genes[gene_labels].to_numpy()
datax.shape

(8491, 30085)

In [11]:
def to_categorical(val, n_class=13):
    result = np.zeros(n_class)
    result[val] = 1
    return result

In [12]:
datay = [to_categorical(label_dict[t]) for t in list(all_patient_genes['CancerType'])]
datay = np.array(datay)
datay

array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]])

In [13]:
seed(1)
random.set_seed(2)

# Model
- Try other initializers

In [14]:
patient_genes = Input(shape=(datax.shape[1],))
layer_1 = Dense(500, activation='relu')(patient_genes)
layer_1_bn = BatchNormalization()(layer_1)
layer_2 = Dense(150, activation='relu')(layer_1_bn)
layer_2_bn = BatchNormalization()(layer_2)
layer_3 = Dense(13, activation='relu')(layer_2_bn)
layer_3_bn = BatchNormalization()(layer_3)
layer_4 = Dense(13, activation='relu')(layer_3_bn)
layer_4_bn = BatchNormalization()(layer_4)
layer_5 = Dense(150, activation='relu')(layer_4_bn)
layer_5_bn = BatchNormalization()(layer_5)
layer_6 = Dense(500, activation='relu')(layer_5_bn)
layer_6_bn = BatchNormalization()(layer_6)
layer_7 = Dense(datax.shape[1], activation='linear', name='ae_output')(layer_6_bn)

# Input from layer 3 (latent representations)
layer_9 = Dense(13, activation='relu')(layer_3)
layer_9_do = Dropout(0.1)(layer_9)
layer_9_bn = BatchNormalization()(layer_9_do)

layer_10 = Dense(25, activation='relu')(layer_9_bn)
layer_10_do = Dropout(0.1)(layer_10)
layer_10_bn = BatchNormalization()(layer_10_do)

layer_11 = Dense(50, activation='relu')(layer_10_bn)
layer_11_do = Dropout(0.1)(layer_11)
layer_11_bn = BatchNormalization()(layer_11_do)

layer_12 = Dense(75, activation='relu')(layer_11_bn)
layer_12_do = Dropout(0.1)(layer_12)
layer_12_bn = BatchNormalization()(layer_12_do)

layer_13 = Dense(100, activation='relu')(layer_12_bn)
layer_13_do = Dropout(0.1)(layer_13)
layer_13_bn = BatchNormalization()(layer_13_do)

layer_14 = Dense(75, activation='relu')(layer_13_bn)
layer_14_do = Dropout(0.1)(layer_14)
layer_14_bn = BatchNormalization()(layer_14_do)

layer_15 = Dense(50, activation='relu')(layer_14_bn)
layer_15_do = Dropout(0.1)(layer_15)
layer_15_bn = BatchNormalization()(layer_15_do)

layer_16 = Dense(25, activation='relu')(layer_15_bn)
layer_16_do = Dropout(0.1)(layer_16)
layer_16_bn = BatchNormalization()(layer_16_do)

layer_17 = Dense(13, activation='softmax', name='cancer_output')(layer_16_bn)

In [15]:
model = Model(inputs=[patient_genes],
              outputs=[layer_7, layer_17],
             )

In [16]:
losses = {'cancer_output': 'categorical_crossentropy',
          'ae_output': 'mse',
         }

opt = RMSprop(learning_rate=0.001, rho=0.9, momentum=0.0)

In [17]:
model.compile(optimizer=opt,
              loss=losses,
              metrics={'cancer_output': ['accuracy', Recall(), Precision()], 
                       'ae_output': 'mse'},
             )

In [18]:
full_raw_data = [(datax[i],datay[i]) for i in range(datay.shape[0])]
np.random.shuffle(full_raw_data)
datax_train = np.array([np.array(x) for (x, _) in full_raw_data])
datay_train = np.array([np.array(y) for (_, y) in full_raw_data])

In [19]:
stand_datax_train = datax_train / np.max(datax_train)

In [20]:
reduce_lr = ReduceLROnPlateau(verbose=1)

In [21]:
history = model.fit(stand_datax_train[500:],
                   {"ae_output": stand_datax_train[500:], "cancer_output": datay_train[500:]},
                   batch_size=16, 
                   epochs=50,
                   validation_split=0.4,
                   callbacks=[reduce_lr],
                   )

Train on 5593 samples, validate on 2398 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50


Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50


KeyboardInterrupt: 

In [None]:
# Evaluate full network
model.evaluate(stand_datax_train[:500],
               {"ae_output": stand_datax_train[:500], "cancer_output": datay_train[:500]},
               )

In [None]:
model.save_weights('classifier_weights')
model.save()

# Interpretability

In [None]:
# SHAP libraries for neural network
from shap import GradientExplainer
from shap import decision_plot, summary_plot, multioutput_decision_plot

### Might want to try other backgrounds. Plots are not making sense.
### Try LIME

In [None]:
background = np.zeros((13, stand_datax_train.shape[1]))
summary_background = np.zeros((1, stand_datax_train.shape[1]))
index_to_explain = 23

In [None]:
explainer = GradientExplainer(full_nn, background)
to_explain = stand_datax_train[index_to_explain]

In [None]:
print(to_explain)
print(datay_train[index_to_explain])

In [None]:
list(cancer_types)[np.argmax(datay_train[index_to_explain])]

In [None]:
prediction_ov = full_nn(stand_datax_train[index_to_explain].reshape(1, stand_datax_train[1].shape[0]))

In [None]:
list(cancer_types)[np.argmax(prediction_ov)]

## Explain a patient

In [None]:
to_explain = to_explain.reshape(1, to_explain.shape[0])

In [None]:
# cannot do entire dataset. Pull one example of interest and perform analysis
shap_vals = explainer.shap_values(X=to_explain, ranked_outputs=None)

In [None]:
# this won't make sense for one example. Used for whole dataset.
summary_plot(shap_vals, 
             to_explain, 
             gene_names, 
             class_names=list(cancer_types), 
             color=plt.get_cmap("tab20c"),
             max_display=30
            )

In [None]:
def get_gene_names(gene_codes, ensemble_obj=None):
    gene_names = []
    for gene in gene_codes:
        try:
            gene_info = ensemble_obj.gene_by_id(gene_id=gene)
            gene_names.append(gene_info.gene_name + '(' + gene_info.biotype + ')')
        except ValueError:
            gene_names.append('GENE NOT FOUND')
    return gene_names

In [None]:
# get class names
list(cancer_types)

In [None]:
np.unique(to_explain[0], return_index=True, return_counts=True)

In [None]:
from pyensembl import EnsemblRelease

In [None]:
gene_names = get_gene_names(gene_labels, ensemble_obj=EnsemblRelease(77))

In [None]:
cancer_predictions = full_nn(stand_datax_train)

In [None]:
np.argmax(cancer_predictions[index_to_explain])

In [None]:
multioutput_decision_plot(base_values=list(background),
                          shap_values=shap_vals,
                          row_index=index_to_explain,
                          feature_names=gene_names,
                          highlight=[np.argmax(cancer_predictions[index_to_explain])],
                          legend_labels=list(cancer_types),
                         )