In [None]:
import tensorflow as tf
from keras.metrics import AUC
from keras.callbacks import LearningRateScheduler, EarlyStopping
from keras.metrics import Precision, Recall
from keras.losses import BinaryCrossentropy
import datetime
from model.nn import multichannel_network
from model.data import Dataset, DataGenerator
from sklearn.utils import class_weight
import numpy as np

# Seed every RNG the pipeline may draw from (Python `random`, NumPy, TensorFlow)
# so weight initialisation -- and the data split, if Dataset.split relies on a
# global RNG (unverified) -- is reproducible across kernel restarts.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Dataset Setting:
## choose from ['methylation', 'gene_expression', 'cnv', 'mutation']
FEATURE = ['gene_expression', 'cnv', 'methylation', 'mutation']
ds = Dataset(
    feature_contained=FEATURE,
    dataset='CTRP',
    set_label=True,
    response='AUC',
    threshold=.58)
# Settings tried so far, presumably (dataset, response, label threshold, learning rate):
# CTRP, "AUC", 0.58, 0.001
# GDSC, "AUC", .88, 0.001

# model parameters settings
lr_rate = 0.001
dropout_rate = .5
batch_size = 64
epochs = 2

# Split train, test and validation set for training and testing, build generators
partition = ds.split(validation=True)
train = partition['train']
test = partition['test']
validation = partition['validation']
train_generator = DataGenerator(sample_barcode=train, **ds.get_config(), batch_size=batch_size)
validation_generator = DataGenerator(sample_barcode=validation, **ds.get_config(), batch_size=batch_size)
test_generator = DataGenerator(sample_barcode=test, **ds.get_config(), batch_size=batch_size)

# Training parameters
# Balanced class weights computed from the training labels only.
train_labels = [ds.labels[x] for x in train]
train_classes = np.unique(train_labels)
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=train_classes,
                                                  y=train_labels)
# compute_class_weight returns weights in the order of `classes`, so key each
# weight by its actual label; enumerate() is only right when the labels are
# exactly 0..k-1 (e.g. it silently mislabels a train split containing only class 1).
weights_dict = {int(label): weight for label, weight in zip(train_classes, class_weights)}
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# NOTE(review): this callback is not in the `callbacks` list passed to
# model.fit, so nothing is written to `log_dir` unless it is added there.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
def scheduler(epoch, lr):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    Args:
        epoch (int): index of the epoch about to start.
        lr (float): current learning rate.

    Returns:
        float: ``lr * 0.1`` when ``epoch`` is a non-zero multiple of 5
        (5, 10, 15, ...), otherwise ``lr`` unchanged.
    """
    is_decay_epoch = epoch != 0 and epoch % 5 == 0
    return lr * 0.1 if is_decay_epoch else lr
# Callbacks: step-decay learning rate driven by `scheduler`, and early stopping
# after 10 epochs without improvement in validation loss.
reduce_lr = LearningRateScheduler(schedule=scheduler)
early_stop = EarlyStopping(monitor='val_loss', patience=10)


# model building
model = multichannel_network(
    dataset=ds,
    train_sample_barcode=train,
    dropout=dropout_rate,
)

# Name both AUC metrics explicitly: unnamed, they share Keras' default base
# name "auc", which hides which value is ROC-AUC and which is PR-AUC.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_rate),
    loss=BinaryCrossentropy(),
    metrics=[
        Precision(name="precision"),
        Recall(name="recall"),
        AUC(curve='ROC', name="auroc"),
        AUC(curve='PR', name="auprc"),
    ],
)

history = model.fit(
    x=train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    callbacks=[reduce_lr, early_stop],
    class_weight=weights_dict,
)

# return_dict=True pairs each test-set score with its metric name.
scores = model.evaluate(x=test_generator, return_dict=True)
scores


In [None]:
def make_predict(model, candidate, ds, batch_size):
    """Predict sensitivity scores for a list of cell line / drug barcodes.

    Args:
        model: Trained network. Called as ``model(batch, training=False)`` where
            ``batch`` maps each feature name to a float32 array with exactly
            ``batch_size`` rows. Its output is assumed to hold one score row per
            input row (TODO confirm it is a single column).
        candidate (list): list["CELLINE_DRUG"] barcodes. Each is split at its
            FIRST underscore into cell line and drug name.
        ds (Dataset): Dataset Object. The omics channels listed in
            ``ds.feature_contained`` are read from ``ds.omics_data`` (rows
            selected by cell line). The ``fingerprint`` and ``rdkit2d`` drug
            features are read from ``ds.drug_info.drug_feature`` (rows selected
            by drug name).
        batch_size (int): Rows per model call. The last, partial batch is
            zero-padded to ``batch_size``; padded rows are dropped from the output.

    Returns:
        pandas.DataFrame: One row per candidate, in input order, with columns
        ``Sensitivity`` (model score), ``DRUG_NAME`` and ``CELL_LINE``. An empty
        ``candidate`` yields an empty frame with those columns.
    """
    import pandas as pd

    n_samples = len(candidate)
    if n_samples == 0:
        return pd.DataFrame(columns=['Sensitivity', 'DRUG_NAME', 'CELL_LINE'])

    omics_names = ("cnv", "gene_expression", "mutation", "methylation")
    drug_feature_names = ("fingerprint", "rdkit2d")

    # split('_', 1) keeps a drug name that itself contains "_" intact.
    pairs = [barcode.split('_', 1) for barcode in candidate]
    celline_candidate = [pair[0] for pair in pairs]
    drug_candidate = [pair[1] for pair in pairs]

    # Only channels present in ds.feature_contained are loaded, and batches are
    # built from exactly these keys, so an absent channel can never be requested.
    feature = {
        name: ds.omics_data[name].loc[celline_candidate].values.astype(np.float32)
        for name in ds.feature_contained
        if name in omics_names
    }
    for name in drug_feature_names:
        feature[name] = ds.drug_info.drug_feature[name].loc[drug_candidate].values.astype(np.float32)

    n_batches = -(-n_samples // batch_size)  # ceiling division
    predictions = []
    for idx, start in enumerate(range(0, n_samples, batch_size)):
        print(f"{idx}/{n_batches}")
        batch = {}
        for name, values in feature.items():
            rows = values[start:start + batch_size]
            if rows.shape[0] < batch_size:
                # Pad the final partial batch so every call sees batch_size rows
                # (presumably the network expects a fixed batch size -- TODO
                # confirm). Keep the feature dtype (float32), not np.zeros' float64.
                padded = np.zeros((batch_size, values.shape[1]), dtype=values.dtype)
                padded[:rows.shape[0]] = rows
                rows = padded
            batch[name] = rows
        predictions.append(np.asarray(model(batch, training=False)))

    # Stack along the batch axis, then drop the padded rows.
    result = np.concatenate(predictions, axis=0)[:n_samples]
    df = pd.DataFrame(data=result, columns=['Sensitivity'])
    df['DRUG_NAME'] = drug_candidate
    df['CELL_LINE'] = celline_candidate
    return df

# ANALYSIS

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

## GDSC and CTRPv2 Basic Statistics

CTRP AUC Histogram

In [None]:
# Distribution of the raw response AUC; a figure should stand alone, so label it.
ax = sns.histplot(data=ds.response['AUC'])
ax.set(title="Distribution of response AUC", xlabel="AUC", ylabel="Count");

## AUROC, AUPRC and Confusion Matrix on the Test Set

In [None]:
# Predict with the same batch size the model was trained with (`batch_size` from
# the training cell); make_predict pads its last batch to this size.
pred_df = make_predict(model=model, candidate=test, ds=ds, batch_size=batch_size)

In [None]:
# Sanity check: make_predict must return exactly one row per test barcode.
assert len(pred_df) == len(test), "expected one prediction per test sample"

In [None]:
# Ground-truth labels, in the same order as the prediction rows.
pred_df['true_labels'] = [ds.labels[barcode] for barcode in test]

In [None]:
# make_predict already names its score column 'Sensitivity' (there is no
# 'AUC_predicted' column to rename), so just assert the expected column exists.
assert 'Sensitivity' in pred_df.columns

In [None]:
# Binarise the scores: strictly above 0.5 -> 1, matching the default threshold
# rule of the Keras Precision/Recall metrics used in training. The score column
# is 'Sensitivity'; 'AUC_predicted' does not exist and raised KeyError.
pred_df['pred_labels'] = (pred_df['Sensitivity'] > 0.5).astype(int)

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, PrecisionRecallDisplay
from sklearn.metrics import confusion_matrix

# Confusion matrix of the thresholded test-set predictions.
y_true_test = list(pred_df['true_labels'])
y_pred_test = list(pred_df['pred_labels'])
cm = confusion_matrix(y_true=y_true_test, y_pred=y_pred_test)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
disp.plot(cmap=plt.cm.Blues)


In [None]:
# ROC curve of the raw test-set scores against the true labels.
RocCurveDisplay.from_predictions(
    y_true=list(pred_df['true_labels']),
    y_pred=list(pred_df['Sensitivity']),
)

In [None]:
# Precision-recall curve of the raw test-set scores against the true labels.
PrecisionRecallDisplay.from_predictions(
    y_true=list(pred_df['true_labels']),
    y_pred=list(pred_df['Sensitivity']),
)

## Predicted Scores: Sensitive vs. Insensitive Cell Line-Drug Pairs

In [None]:
# Distribution of predicted scores, split by the true label (0 vs 1).
sns.violinplot(x="true_labels", y="Sensitivity", data=pred_df)

## Test Cancer Types (Lung, Colon, Breast, Stomach) Across Drugs

Find all tumour cell lines for each cancer type: lung, breast, colon (small or large intestine) and stomach.

In [None]:
model_list = pd.read_csv('data/raw_data/model_list_20230307.csv')


def tumour_model_names(models, tissues):
    """Return ``model_name`` for the tumour models whose tissue is in ``tissues``.

    Args:
        models (pandas.DataFrame): model list with ``tissue``, ``tissue_status``
            and ``model_name`` columns.
        tissues (list[str]): tissue names to keep.

    Returns:
        pandas.Series: the matching ``model_name`` values (original index kept).
    """
    keep = models['tissue'].isin(tissues) & (models['tissue_status'] == 'Tumour')
    return models.loc[keep, 'model_name']


lung_ccl = tumour_model_names(model_list, ['Lung'])
breast_ccl = tumour_model_names(model_list, ['Breast'])
# "Colon" covers both small- and large-intestine tumours.
colo_ccl = tumour_model_names(model_list, ['Small Intestine', 'Large Intestine'])
stomach_ccl = tumour_model_names(model_list, ['Stomach'])


In [None]:
# Keep only cell lines present in the dataset; candidate drugs are those that
# never appear in the processed experiments.
celline_barcode = set(ds.celline_barcode)
colo_ccl = set(colo_ccl).intersection(celline_barcode)
lung_ccl = set(lung_ccl).intersection(celline_barcode)
breast_ccl = set(breast_ccl).intersection(celline_barcode)
stomach_ccl = set(stomach_ccl).intersection(celline_barcode)
known_drugs = set(ds.drug_info.all_drugs.index)
tested_drugs = set(ds.processed_experiment['DRUG_NAME'])
drug_candidate = known_drugs - tested_drugs

In [None]:
import itertools
l = {"colo": colo_ccl, "lung": lung_ccl, "breast": breast_ccl, "stomach": stomach_ccl}
# Every (cell line, candidate drug) combination per cancer type, as "CELL_DRUG" barcodes.
experiment_candidate_dict = {
    name: ["_".join(pair) for pair in itertools.product(cell_lines, drug_candidate)]
    for name, cell_lines in l.items()
}

In [None]:
# Replace each barcode list with its prediction frame. Entries that are already
# DataFrames (cell re-run) are skipped, so re-running this cell is safe.
for name, itm in experiment_candidate_dict.items():
    if isinstance(itm, pd.DataFrame):
        continue
    experiment_candidate_dict[name] = make_predict(model=model, candidate=itm, ds=ds, batch_size=batch_size)

In [None]:
# Preview of the breast-cancer prediction frame (avoid dumping every row).
experiment_candidate_dict['breast'].head()

In [None]:
# Mean predicted Sensitivity per drug for each cancer type. Select the score
# column before averaging: CELL_LINE is a string column, and pandas >= 2.0
# raises TypeError when .mean() meets non-numeric columns.
breast_drug_df = experiment_candidate_dict['breast'].groupby('DRUG_NAME')[['Sensitivity']].mean()
colo_drug_df = experiment_candidate_dict['colo'].groupby('DRUG_NAME')[['Sensitivity']].mean()
lung_drug_df = experiment_candidate_dict['lung'].groupby('DRUG_NAME')[['Sensitivity']].mean()
stomach_drug_df = experiment_candidate_dict['stomach'].groupby('DRUG_NAME')[['Sensitivity']].mean()

In [None]:
def _top_drugs(drug_df, k=30):
    """Set of the ``k`` index labels with the highest mean predicted Sensitivity."""
    ranked = drug_df.sort_values(by='Sensitivity', ascending=False)
    return set(ranked.head(k).index)


colo_set = _top_drugs(colo_drug_df)
stomach_set = _top_drugs(stomach_drug_df)
lung_set = _top_drugs(lung_drug_df)
breast_set = _top_drugs(breast_drug_df)

In [None]:
# Top-ranked drugs shared by all four cancer types.
common_set = colo_set & stomach_set & lung_set & breast_set

In [None]:
# Top-ranked drugs specific to each cancer type (not shared by all four).
for cancer, top_drugs in (("Colon", colo_set), ("Stomach", stomach_set),
                          ("Lung", lung_set), ("Breast", breast_set)):
    print(f"{cancer}: {top_drugs.difference(common_set)}")

In [None]:
# Check whether 'trastuzumab' is among the labels of ds.drug_info.all_drugs' index.
'trastuzumab' in ds.drug_info.all_drugs.index

## Test Cancer-Type-Specific Drugs Across Cancer Types

Gefitinib (expected to be lung-cancer-specific)

In [None]:
# One "CELL_gefitinib" barcode per cell line in the dataset.
gefitinib_entry = ["_".join([cell, "gefitinib"]) for cell in ds.celline_barcode]

In [None]:
# Use the training batch size rather than a duplicated literal 64.
gefitinib_df = make_predict(model=model, candidate=gefitinib_entry, ds=ds, batch_size=batch_size)

In [None]:
# Rebuild the full lung tumour list: `lung_ccl` was narrowed to a set of
# dataset cell lines earlier in the notebook.
is_lung_tumour = (model_list['tissue'] == 'Lung') & (model_list['tissue_status'] == 'Tumour')
lung_ccl = model_list.loc[is_lung_tumour, 'model_name']

In [None]:
# Gefitinib predictions restricted to lung tumour cell lines.
gefitinib_df.loc[gefitinib_df['CELL_LINE'].isin(lung_ccl)]

## PRISM Holdout Validation

In [None]:
import pandas as pd
import numpy as np
from model.data import Dataset

In [None]:
# PRISM secondary-screen dose-response curve parameters.
prism = pd.read_csv(
    'data/PRISM/secondary-screen-dose-response-curve-parameters.csv',
    low_memory=False,
)

In [None]:
import matplotlib.pyplot as plt
from sklearn.preprocessing import minmax_scale
# Min-max rescale PRISM AUC to [0, 1], then display selected quantiles.
# NOTE(review): the 0.1728 label threshold used later was presumably read off
# this output -- TODO confirm which quantile it corresponds to.
prism = prism.assign(auc=minmax_scale(prism['auc']))
prism['auc'].quantile(q=[0.3333, 0.66666, 0.9]) # threshold = 0.1728

In [None]:
# One row per PRISM drug (first SMILES seen), and the drugs the dataset also knows.
prism_drug = (
    pd.DataFrame({'DRUG_NAME': prism['name'], 'CanonicalSMILES': prism['smiles']})
    .drop_duplicates(subset='DRUG_NAME', keep='first')
    .reset_index(drop=True)
)
prism_drug_overlapped = list(set(prism_drug['DRUG_NAME']) & set(ds.drug_info.all_drugs.index))

In [None]:
# Keep only the part of each CCLE name before the first underscore.
prism['ccle_name'] = [str(name).partition("_")[0] for name in prism['ccle_name']]

In [None]:
# PRISM rows whose cell line and drug are both known to the dataset.
in_dataset_cells = prism['ccle_name'].isin(ds.celline_barcode)
in_overlap_drugs = prism['name'].isin(prism_drug_overlapped)
prism_experiment = prism.loc[in_dataset_cells & in_overlap_drugs].reset_index(drop=True)

In [None]:
# "CELL_DRUG" barcodes in the same format make_predict expects.
experiment_id = [
    "_".join((str(cell), drug))
    for cell, drug in zip(prism_experiment['ccle_name'], prism_experiment['name'])
]

In [None]:
# Ground truth for the PRISM hold-out: label 1 when the scaled AUC is below the
# threshold. NOTE(review): rows with NaN AUC compare False and get label 0.
prism_auc_threshold = 0.1728
experiment_df = pd.DataFrame({
    'SAMPLE_BARCODE': experiment_id,
    'AUC': prism_experiment['auc'],
    'LABELS': [1 if auc < prism_auc_threshold else 0 for auc in prism_experiment['auc']],
})


In [None]:
# PRISM barcodes not used for training. sorted() fixes the order: iterating a
# set of strings varies between interpreter sessions (hash randomisation).
experiment_candidate_prism = sorted(set(experiment_df['SAMPLE_BARCODE']).difference(train))

In [None]:
# Use the training batch size rather than a duplicated literal 64.
pred_df = make_predict(model=model, candidate=experiment_candidate_prism, ds=ds, batch_size=batch_size)

In [None]:
# Attach the PRISM ground truth (AUC, LABELS) to each prediction by barcode.
is_candidate = experiment_df['SAMPLE_BARCODE'].isin(experiment_candidate_prism)
ref = experiment_df.loc[is_candidate].drop_duplicates(subset='SAMPLE_BARCODE')
pred_df["SAMPLE_BARCODE"] = pred_df['CELL_LINE'] + "_" + pred_df['DRUG_NAME']
pred_df = pred_df.join(ref.set_index('SAMPLE_BARCODE'), on='SAMPLE_BARCODE')

In [None]:
# Binarise at 0.5 with ">" (strictly above -> 1), the same rule used for the
# test-set predictions and by the Keras Precision/Recall metrics.
pred_df['LABELS_PRED'] = (pred_df['Sensitivity'] > 0.5).astype(int)

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, PrecisionRecallDisplay
from sklearn.metrics import confusion_matrix

# ROC curve on the PRISM hold-out pairs.
prism_true = list(pred_df['LABELS'])
prism_score = list(pred_df['Sensitivity'])
RocCurveDisplay.from_predictions(y_true=prism_true, y_pred=prism_score)

In [None]:
# Confusion matrix of the thresholded PRISM predictions.
prism_true = list(pred_df['LABELS'])
prism_pred = list(pred_df['LABELS_PRED'])
cm = confusion_matrix(y_true=prism_true, y_pred=prism_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
disp.plot(cmap=plt.cm.Blues)


## Feature Ablation

## Cross-dataset Validation