In [None]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns

from sklearn.preprocessing import scale
from sklearn.preprocessing import robust_scale

In [None]:
# Global seaborn styling for every figure produced in this notebook.
sns.set()
sns.set_style("white")
# The "ticks" style is applied after "white"; the dict overrides the major tick lengths.
sns.set_style("ticks", {"xtick.major.size":8, "ytick.major.size":8})
# NOTE(review): the return value of axes_style() is discarded here, so this call
# presumably does not change the active style - confirm and drop if so.
sns.axes_style("whitegrid")
sns.set_palette("muted")
# Last expression of the cell: its value (the palette) becomes the cell output.
sns.color_palette("muted")

In [None]:
# Matplotlib defaults shared by all figures (several are saved as PDF further down).
plt.rcParams['pdf.use14corefonts'] = True

# Three font-size tiers used by the rc settings below.
SMALL_SIZE = 12
MEDIUM_SIZE = 16
BIGGER_SIZE = 22

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

In [None]:
#Separating data for each drug/cell

def get_pos_map(obj_list, test_df, col):
    """Map each object to the index labels of the rows that mention it.

    Args:
        obj_list: Objects (e.g. drug SMILES or cell-line names) to look up.
        test_df: DataFrame containing column ``col``.
        col: Name of the column whose values are matched against ``obj_list``.

    Returns:
        dict: ``{obj: [index labels of rows where test_df[col] == obj]}``, with an
        empty list for objects that never occur. Labels are appended in row order.

    Note:
        Callers in this notebook pass the labels to ``np.take``, i.e. they treat
        them as positions; that is only valid when ``test_df`` has a default
        0..n-1 index.
    """
    pos_map = {obj: [] for obj in obj_list}
    # Walk the index and the single column in lockstep instead of iterrows(),
    # which materialises a full Series per row just to read one value.
    for idx, value in zip(test_df.index, test_df[col]):
        if value in pos_map:
            pos_map[value].append(idx)
    return pos_map

In [None]:
#Arrange the obj_list in the descending order of the scores

def sort_scores(obj_list, scores):
    """Pair each object with its score and return the pairs ordered by score, highest first.

    Args:
        obj_list: Objects; ``obj_list[i]`` is scored by ``scores[i]``.
        scores: Indexable collection of scores, at least as long as ``obj_list``.

    Returns:
        dict: ``{obj: score}`` whose iteration order is descending score.
        Ties keep the order in which the objects appear in ``obj_list``.
    """
    paired = {obj: scores[position] for position, obj in enumerate(obj_list)}
    ranked = sorted(paired.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)

In [None]:
def plot_drug_performance(drug_corr_map):
    """Bar chart of per-drug correlation; bars above 0.5 are red, the rest blue.

    Args:
        drug_corr_map: ``{drug: correlation}``; bars are drawn in the map's
            iteration order within each colour group (red group first).

    Returns:
        The matplotlib Figure. The fraction of drugs above 0.5 is printed and
        the figure is shown as a side effect.
    """
    fig, ax = plt.subplots(figsize=(6,6))

    # Split the drugs at the 0.5 threshold; anything not strictly above it
    # (including NaN) lands in the blue group.
    above = [(drug, rho) for drug, rho in drug_corr_map.items() if rho > 0.5]
    below = [(drug, rho) for drug, rho in drug_corr_map.items() if not rho > 0.5]

    ratio = float(len(above))/float(len(drug_corr_map.keys()))
    print('Red ratio = ' + str(ratio))

    ax.bar([drug for drug, _ in above], [rho for _, rho in above], color='red', width=1.0)
    ax.bar([drug for drug, _ in below], [rho for _, rho in below], color='blue', width=1.0)
    # One bar per drug: individual tick labels are suppressed.
    ax.set_xticks([])
    ax.set_xlabel('Drugs')
    ax.set_ylabel('Performance\nSpearman ρ (Predicted vs. Actual)')
    plt.show()
    return fig

In [None]:
def create_drug_performance_plot(drugs, drug_corr_list):
    """Rank drugs by correlation, draw the performance bar chart and print the median.

    Args:
        drugs: Drug identifiers; ``drugs[i]`` is scored by ``drug_corr_list[i]``.
        drug_corr_list: Per-drug correlation values.

    Returns:
        The Figure produced by ``plot_drug_performance``.
    """
    ranked = sort_scores(drugs, drug_corr_list)
    figure = plot_drug_performance(ranked)
    median_rho = np.median(list(ranked.values()))
    print('Median spearman rho:', median_rho)
    return figure

In [None]:
def create_scatter_plot(Y, X, y_title, x_title):
    """Scatter two score vectors against each other on equal axes with an identity line.

    Args:
        Y: Values plotted on the y axis.
        X: Values plotted on the x axis (same length as ``Y``).
        y_title: Label (and internal column name) for ``Y``.
        x_title: Label (and internal column name) for ``X``; must differ from
            ``y_title``, otherwise the two columns collapse into one.

    Returns:
        The matplotlib Figure. The p-value of ``stats.ttest_ind(X, Y)`` is
        printed as a side effect.
    """
    scatter_df = pd.DataFrame({x_title: X, y_title: Y})

    fig, ax = plt.subplots(figsize=(6,6))
    sns.scatterplot(data=scatter_df, x=x_title, y=y_title, s=35, ax=ax)

    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    # Use one common range for both axes so the diagonal is the identity line.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
        np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
    ]
    # Draw on this figure's axes explicitly rather than on pyplot's "current" axes.
    ax.plot(lims, lims, 'k--', alpha=0.75, zorder=0)
    ax.set_aspect('equal')
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    # NOTE(review): ttest_ind treats X and Y as independent samples, but the callers
    # in this notebook pass per-object scores in matching order (paired data) -
    # confirm whether a paired test (stats.ttest_rel) is what is intended.
    print('t-test p-value:', stats.ttest_ind(X, Y)[1])

    return fig

In [None]:
def get_boxplot(df, x_title, y_title):
    """Draw one box per column of ``df`` (outliers hidden) and show the figure.

    Args:
        df: DataFrame whose columns are the groups to plot.
        x_title: x-axis label.
        y_title: y-axis label.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(6,6))

    box_options = dict(showfliers=False, widths=0.5, patch_artist=True)
    ax.boxplot(df, **box_options)

    column_labels = list(df.columns)
    ax.set_xticklabels(column_labels, rotation=45)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)

    plt.show()
    return fig

In [None]:
def get_violinplot(df, x_title, y_title):
    """Draw one violin per column of ``df`` (medians marked) and show the figure.

    Args:
        df: DataFrame whose columns are the groups to plot.
        x_title: x-axis label.
        y_title: y-axis label.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(6,6))
    ax.violinplot(df, showmedians=True)

    column_labels = list(df.columns)
    # violinplot places the violins at x = 1..N but leaves the automatic tick
    # locator in place; pin the ticks to those positions before labelling so every
    # column name sits under its own violin.
    ax.set_xticks(range(1, len(column_labels) + 1))
    ax.set_xticklabels(column_labels, rotation=45)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)

    plt.show()

    return fig

In [None]:
def get_corr_list(obj_list, test_df, pred, zscore_method, col='smiles'):
    """Spearman correlation between predicted and observed values, per object.

    Args:
        obj_list: Objects (values of ``test_df[col]``) to score.
        test_df: DataFrame with column ``col`` and a response column named
            ``zscore_method``; row i must correspond to ``pred[i]``.
        pred: Predicted values aligned with the rows of ``test_df``.
        zscore_method: Name of the observed-response column in ``test_df``.
        col: Column that identifies the object of each row.

    Returns:
        list: One correlation per entry of ``obj_list``. The value is 0 when the
        object has no rows, or when either the predictions or the observed values
        are constant (the rank correlation is undefined in that case).
    """
    corr_list = [0.0] * len(obj_list)
    pos_map = get_pos_map(obj_list, test_df, col)
    for i, obj in enumerate(obj_list):
        positions = pos_map[obj]
        if len(positions) == 0:
            continue
        test_vals = np.asarray(np.take(test_df[zscore_method], positions))
        pred_vals = np.take(pred, positions)
        # spearmanr returns NaN when one side is constant; guard both sides so a
        # NaN cannot leak into the averages and medians computed downstream.
        pred_constant = np.all(pred_vals == pred_vals[0])
        test_constant = np.all(test_vals == test_vals[0])
        if pred_constant or test_constant:
            corr_list[i] = 0.0
        else:
            corr_list[i] = stats.spearmanr(pred_vals, test_vals)[0]
    return corr_list

In [None]:
def get_avg_corr_list(obj_list, all_test, all_pred, zscore_method, col='smiles'):
    """Average the per-object correlations over every run in ``all_pred``.

    Args:
        obj_list: Objects to score.
        all_test: ``{run key: test DataFrame}``.
        all_pred: ``{run key: predictions}``, with the same keys as ``all_test``.
        zscore_method: Name of the observed-response column in the test frames.
        col: Column that identifies the object of each row.

    Returns:
        tuple: ``(avg_corr, corr_df)`` where ``avg_corr`` is an array holding the
        mean correlation per object and ``corr_df`` has one column per run.
    """
    run_keys = list(all_pred.keys())
    corr_df = pd.DataFrame(columns=run_keys)
    running_total = np.zeros(len(obj_list))

    for run in run_keys:
        run_corr = np.array(
            get_corr_list(obj_list, all_test[run], all_pred[run], zscore_method, col)
        )
        corr_df[run] = run_corr
        running_total += run_corr

    running_total /= len(run_keys)

    return running_total, corr_df

In [None]:
def calc_std_vals(df, zscore_method):
    """Compute the centre/scale used to standardise AUC for each (dataset, drug) group.

    Args:
        df: DataFrame with columns ``dataset``, ``drug`` and ``auc``.
        zscore_method: ``'zscore'`` (mean / standard deviation), ``'robustz'``
            (median / inter-quartile range); any other value yields centre 0.0
            and scale 1.0, i.e. no standardisation.

    Returns:
        DataFrame with columns ``dataset``, ``drug``, ``center``, ``scale`` and one
        row per group (empty, with the same columns, when ``df`` has no rows).
        A scale that is NaN or zero is replaced by 1.0 so later division is safe.
    """
    columns = ['dataset', 'drug', 'center', 'scale']
    rows = []

    for (dataset_name, drug_name), aucs in df.groupby(['dataset', 'drug'])['auc']:
        if zscore_method == 'zscore':
            center = aucs.mean()
            spread = aucs.std()
        elif zscore_method == 'robustz':
            center = aucs.median()
            spread = aucs.quantile(0.75) - aucs.quantile(0.25)
        else:
            center, spread = 0.0, 1.0
        # A single-sample group has NaN std and a constant group has zero spread;
        # fall back to 1.0 in both cases.
        if math.isnan(spread) or spread == 0.0:
            spread = 1.0
        rows.append([dataset_name, drug_name, center, spread])

    # Build the frame once; unlike pd.concat over per-group frames this also
    # works when there are no groups at all.
    return pd.DataFrame(rows, columns=columns)

In [None]:
def standardize_data(df, std_df, zscore_method):
    """Apply per-(dataset, drug) centre/scale values to the AUC column.

    Args:
        df: DataFrame with columns ``cell_line``, ``smiles``, ``auc``,
            ``dataset`` and ``drug``.
        std_df: DataFrame with columns ``dataset``, ``drug``, ``center``, ``scale``.
        zscore_method: Name given to the standardised output column.

    Returns:
        DataFrame with columns ``cell_line``, ``smiles`` and ``zscore_method``,
        where the last one holds ``(auc - center) / scale``.
    """
    joined = df.merge(std_df, how="left", on=['dataset', 'drug'], sort=False)
    standardized = (joined['auc'] - joined['center']) / joined['scale']
    joined[zscore_method] = standardized
    return joined[['cell_line', 'smiles', zscore_method]]

In [None]:
def normalize_auc(train_std_df, test_df, zscore_method):
    """Standardise test AUC values, preferring the centre/scale learned on training data.

    For every (dataset, drug) group in ``test_df`` the centre and scale found in
    ``train_std_df`` are used when that group is present there; otherwise the
    values computed from the test data itself are kept.

    Args:
        train_std_df: DataFrame with columns ``dataset``, ``drug``, ``center``, ``scale``.
        test_df: DataFrame with columns ``cell_line``, ``smiles``, ``auc``,
            ``dataset`` and ``drug``.
        zscore_method: Standardisation method, also the name of the output column.

    Returns:
        DataFrame with columns ``cell_line``, ``smiles`` and ``zscore_method``.
    """
    test_std_df = calc_std_vals(test_df, zscore_method)

    # Index the training parameters once instead of running DataFrame.query for
    # every test group. setdefault keeps the first row if a key is duplicated.
    train_params = {}
    for ds, drug, center, scale_val in zip(train_std_df['dataset'], train_std_df['drug'],
                                           train_std_df['center'], train_std_df['scale']):
        train_params.setdefault((ds, drug), (float(center), float(scale_val)))

    for i, ds, drug in zip(test_std_df.index, test_std_df['dataset'], test_std_df['drug']):
        # The test-side drug name is compared as a string, as before.
        params = train_params.get((ds, str(drug)))
        if params is not None:
            test_std_df.loc[i, 'center'] = params[0]
            test_std_df.loc[i, 'scale'] = params[1]

    test_df = standardize_data(test_df, test_std_df, zscore_method)
    return test_df

In [None]:
def get_cv_data(dataset, ont, zscore_method, fold_size=5):
    """Load test data and predictions for each cross-validation fold.

    For fold ``i`` in ``1..fold_size`` this reads the test file, the model's
    ``predict.txt`` and ``std.txt``, standardises the test AUC values and prints
    the overall Spearman correlation of the fold.

    Args:
        dataset: Suffix inserted into the test-file and model-directory names.
        ont: Ontology / model name used in the model-directory name.
        zscore_method: Standardisation method (also part of the directory name).
        fold_size: Number of folds to load.

    Returns:
        tuple: ``(all_test, all_pred)``, two dicts keyed ``'Fold<i>'`` holding the
        standardised test DataFrame and the prediction array of each fold.
    """
    test_columns = ['cell_line', 'smiles', 'auc', 'dataset', 'drug']
    std_columns = ['dataset', 'drug', 'center', 'scale']

    corr_total = 0.0
    all_pred = {}
    all_test = {}

    for fold in range(1, fold_size+1):
        fold_tag = str(fold)

        test_file = '../data/' + fold_tag + '_test_cg' + dataset + '.txt'
        raw_test_df = pd.read_csv(test_file, sep='\t', header=None, names=test_columns)

        modeldir = '../model_' + ont + dataset + '_' + fold_tag + '_' + zscore_method
        pred = np.loadtxt(modeldir + '/predict.txt')

        train_std_df = pd.read_csv(modeldir + '/std.txt', sep='\t', header=None, names=std_columns)
        test_df = normalize_auc(train_std_df, raw_test_df, zscore_method)

        key = 'Fold' + fold_tag
        all_pred[key] = pred
        all_test[key] = test_df

        corr = stats.spearmanr(pred, test_df[zscore_method])[0]
        print('Correlation for #{}: {:.3f}'.format(fold, corr))
        corr_total += corr

    print('Avg Correlation: {:.3f}'.format(corr_total/fold_size))

    return all_test, all_pred

In [None]:
def get_avg_cv_data(dataset, ont, zscore_method):
    """Load test data and predictions for 5 folds x 5 network replicates ('a'..'e').

    Args:
        dataset: Suffix inserted into the test-file and model-directory names.
        ont: Base ontology / model name; the replicate letter is appended to it.
        zscore_method: Standardisation method (also part of the directory name).

    Returns:
        tuple: ``(all_test, all_pred)``, two dicts keyed ``'Fold<i><letter>'``
        holding the standardised test DataFrame and the prediction array of each
        run. The Spearman correlation of each run and their mean are printed.
    """
    corr_total = 0.0
    all_pred = {}
    all_test = {}

    for i in range(1, 6):

        test_file = '../data/' + str(i) + '_test_cg' + dataset + '.txt'
        # The raw test file depends only on the fold, so read it once per fold
        # rather than once per replicate; normalize_auc returns a new frame and
        # leaves this one untouched.
        raw_test_df = pd.read_csv(test_file, sep='\t', header=None, names=['cell_line', 'smiles', 'auc', 'dataset', 'drug'])

        for s in ['a', 'b', 'c', 'd', 'e']:
            net = ont + '_' + s + dataset
            modeldir = '../model_' + net + '_' + str(i) + '_' + zscore_method
            pred_file = modeldir + '/predict.txt'
            pred = np.loadtxt(pred_file)

            train_std_df = pd.read_csv(modeldir + '/std.txt', sep='\t', header=None, names=['dataset', 'drug', 'center', 'scale'])
            test_df = normalize_auc(train_std_df, raw_test_df, zscore_method)

            key = 'Fold' + str(i) + s
            all_pred[key] = pred
            all_test[key] = test_df

            corr = stats.spearmanr(pred, test_df[zscore_method])[0]
            print('Correlation for #{}{}: {:.3f}'.format(i, s, corr))
            corr_total += corr

    print('Avg Correlation: {:.3f}'.format(corr_total/len(all_pred.keys())))

    return all_test, all_pred

In [None]:
def get_top_100_corr(sorted_obj_list, corr_list, top_n=100):
    """Return the correlations of the first ``top_n`` objects of ``sorted_obj_list``.

    ``corr_list`` is matched to ``sorted_obj_list`` BY POSITION: ``corr_list[i]``
    must be the score of ``sorted_obj_list[i]``. Passing a score list that was
    computed for a differently ordered object list silently pairs objects with
    the wrong scores.

    Args:
        sorted_obj_list: Objects, already in the desired ranking order.
        corr_list: Scores, at least as long as ``sorted_obj_list``.
        top_n: Maximum number of leading objects to return (default 100).

    Returns:
        list: Scores of the first ``top_n`` objects, in the order of
        ``sorted_obj_list``.
    """
    # Direct positional pairing; sorting the pairs first has no effect on the result.
    corr_map = {obj: corr_list[i] for i, obj in enumerate(sorted_obj_list)}

    top_corr = []
    for i, obj in enumerate(sorted_obj_list):
        if i == top_n:
            break
        top_corr.append(corr_map[obj])

    return top_corr

In [None]:
def calc_var(obj_list, train_df, col):
    """Rank objects by the variance of their AUC values, highest variance first.

    Args:
        obj_list: Objects (values of ``train_df[col]``) to rank.
        train_df: DataFrame with column ``col`` and an ``auc`` column.
        col: Column that identifies the object of each row.

    Returns:
        dict: ``{obj: variance}`` ordered by descending variance. Objects with no
        rows in ``train_df`` get a variance of 0.0.
    """
    var_list = [0.0] * len(obj_list)
    pos_map = get_pos_map(obj_list, train_df, col)
    for i, obj in enumerate(obj_list):
        positions = pos_map[obj]
        if len(positions) == 0:
            # The variance of an empty selection is NaN, and NaN makes the
            # descending sort below unreliable; keep the 0.0 default instead.
            continue
        train_vals = np.take(train_df['auc'], positions)
        var_list[i] = np.var(train_vals)
    return sort_scores(obj_list, var_list)

In [None]:
# Name fragment used to build the input file paths below.
# NOTE(review): later cells reassign `dataset` to values such as '_strict' and pass it
# on as a file-name suffix - the variable is reused with a different meaning.
dataset = 'cg'

# Complete response table: one row per (cell line, drug) measurement; the file has no header row.
all_df = pd.read_csv('../data/drugcell_all_' + dataset + '.txt', sep="\t", header=None, names=['cell_line', 'smiles', 'auc', 'dataset', 'drug'])
# Second column ('D') of the drug index file.
drugs = list(pd.read_csv('../data/drug2ind_' + dataset + '.txt', sep='\t', header=None, names=['I', 'D'])['D'])
# Second column ('C') of the cell index file.
# NOTE(review): this path hard-codes "cg" instead of using `dataset` like the two reads above.
cell_lines = list(pd.read_csv("../data/cell2ind_cg.txt", sep="\t", header=None, names=['I', 'C'])['C'])

In [None]:
# Rank drugs by the variance of their AUC values in all_df (highest first).
ref_drug_variance_map = calc_var(drugs, all_df, 'smiles')
# Re-order `drugs` by that ranking, so drugs[:100] equals top_100_drugs. Later cells
# rely on this when they pass a score list ordered like `drugs` to get_top_100_corr.
drugs = list(ref_drug_variance_map.keys())
top_100_drugs = list(ref_drug_variance_map.keys())[:100]

In [None]:
# Rank cell lines by the variance of their AUC values in all_df (highest first).
ref_cell_variance_map = calc_var(cell_lines, all_df, 'cell_line')
# Re-order `cell_lines` by that ranking, exactly as is done for `drugs`.
# get_top_100_corr pairs objects and scores by position, so score lists computed
# for `cell_lines` only line up with top_100_cells when cell_lines[:100] equals
# top_100_cells.
cell_lines = list(ref_cell_variance_map.keys())
top_100_cells = cell_lines[:100]

In [None]:
# Model 'cg', 5-fold CV, method 'auc'. calc_std_vals uses center 0 / scale 1 for this
# method; normalize_auc still overrides them with std.txt values where a match exists.
ont = 'cg'
zscore_method = 'auc'

cg_all_test, cg_all_pred = get_cv_data('', ont, zscore_method)
# Per-drug correlation averaged over the folds, ordered like `drugs`.
cg_drug_corr_list, cg_drug_corr_df = get_avg_corr_list(drugs, cg_all_test, cg_all_pred, zscore_method)
cg_100 = get_top_100_corr(top_100_drugs, cg_drug_corr_list)

# Same per cell line, ordered like `cell_lines`.
cg_cell_corr_list, _ = get_avg_corr_list(cell_lines, cg_all_test, cg_all_pred, zscore_method, col='cell_line')
cg_cell_100 = get_top_100_corr(top_100_cells, cg_cell_corr_list)

In [None]:
# Model 'fmg_718', 5-fold CV; reuses `zscore_method` set in an earlier cell.
ont = 'fmg_718'

fmg_all_test, fmg_all_pred = get_cv_data('', ont, zscore_method)
fmg_drug_corr_list, fmg_drug_corr_df = get_avg_corr_list(drugs, fmg_all_test, fmg_all_pred, zscore_method)

In [None]:
# Per-drug correlations: 'cg' model on the y axis vs 'fmg_718' model on the x axis.
cg_fmg_scatterplot = create_scatter_plot(cg_drug_corr_list, fmg_drug_corr_list, "CPG-NeST", "FMG-NeST")
cg_fmg_scatterplot.savefig("../plots/CPG-NeST_FMG-NeST.pdf", bbox_inches = 'tight')

In [None]:
# Model family 'random_718': get_avg_cv_data loads five replicates ('a'..'e') per fold,
# so the per-drug correlations are averaged over 25 runs.
ont = 'random_718'

random_all_test, random_all_pred = get_avg_cv_data('', ont, zscore_method)
random_drug_corr_list, random_drug_corr_df = get_avg_corr_list(drugs, random_all_test, random_all_pred, zscore_method)
random_100 = get_top_100_corr(top_100_drugs, random_drug_corr_list)

In [None]:
# Per-drug correlations: 'cg' model (y) vs the 'random_718' average (x), all drugs.
cg_random_scatterplot = create_scatter_plot(cg_drug_corr_list, random_drug_corr_list, "CPG-NeST", "RSG-NeST")
cg_random_scatterplot.savefig("../plots/CPG-NeST_RSG-NeST.pdf", bbox_inches = 'tight')

In [None]:
# Same comparison restricted to the 100 highest-variance drugs.
cg_random_100_scatterplot = create_scatter_plot(cg_100, random_100, "CPG-NeST", "RSG-NeST")
cg_random_100_scatterplot.savefig("../plots/CPG-NeST_RSG-NeST_100.pdf", bbox_inches = 'tight')

In [None]:
# Model family 'cg_bb': five replicates per fold, loaded via get_avg_cv_data.
ont = 'cg_bb'

bb_all_test, bb_all_pred = get_avg_cv_data('', ont, zscore_method)
bb_drug_corr_list, bb_drug_corr_df = get_avg_corr_list(drugs, bb_all_test, bb_all_pred, zscore_method)

In [None]:
# Per-drug correlations: 'cg' model (y) vs the 'cg_bb' average (x).
cg_bb_scatterplot = create_scatter_plot(cg_drug_corr_list, bb_drug_corr_list, "CPG-NeST", "Shuffled-CPG-NeST")
cg_bb_scatterplot.savefig("../plots/CPG-NeST_Shuffled-CPG-NeST.pdf", bbox_inches = 'tight')

In [None]:
# Model 'cg_go', 5-fold CV.
ont = 'cg_go'

cg_go_all_test, cg_go_all_pred = get_cv_data('', ont, zscore_method)
cg_go_drug_corr_list, cg_go_drug_corr_df = get_avg_corr_list(drugs, cg_go_all_test, cg_go_all_pred, zscore_method)

In [None]:
# Model 'fmg_718_go', 5-fold CV.
ont = 'fmg_718_go'

fmg_go_all_test, fmg_go_all_pred = get_cv_data('', ont, zscore_method)
fmg_go_drug_corr_list, fmg_go_drug_corr_df = get_avg_corr_list(drugs, fmg_go_all_test, fmg_go_all_pred, zscore_method)

In [None]:
# Per-drug correlations: 'cg' model (y) vs 'cg_go' model (x).
cg_go_scatterplot = create_scatter_plot(cg_drug_corr_list, cg_go_drug_corr_list, "CPG-NeST", "CPG-GO")
cg_go_scatterplot.savefig("../plots/CPG-NeST_CPG-GO.pdf", bbox_inches = 'tight')

In [None]:
# Per-drug correlations: 'cg' model (y) vs 'fmg_718_go' model (x).
cg_fmg_go_scatterplot = create_scatter_plot(cg_drug_corr_list, fmg_go_drug_corr_list, "CPG-NeST", "FMG-GO")
cg_fmg_go_scatterplot.savefig("../plots/CPG-NeST_FMG-GO.pdf", bbox_inches = 'tight')

In [None]:
# Model 'cg' on the '_strict' split; here `dataset` is a file-name suffix.
ont = 'cg'
dataset = '_strict'

strict_all_test, strict_all_pred = get_cv_data(dataset, ont, zscore_method)
# Per-drug correlation averaged over the folds, ordered like `drugs`.
strict_drug_corr_list, strict_drug_corr_df = get_avg_corr_list(drugs, strict_all_test, strict_all_pred, zscore_method)
strict_100 = get_top_100_corr(top_100_drugs, strict_drug_corr_list)

# Same per cell line, ordered like `cell_lines`.
strict_cell_corr_list, _ = get_avg_corr_list(cell_lines, strict_all_test, strict_all_pred, zscore_method, col='cell_line')
strict_cell_100 = get_top_100_corr(top_100_cells, strict_cell_corr_list)

In [None]:
#cg_strict_scatterplot = create_scatter_plot(cg_drug_corr_list, strict_drug_corr_list, "CPG-NeST", "CPG-NeST-Strict")
#cg_strict_scatterplot.savefig("../plots/CPG-NeST_CPG-NeST-Strict.pdf", bbox_inches = 'tight')

In [None]:
#cg_strict_100_scatterplot = create_scatter_plot(cg_100, strict_100, "CPG-NeST", "CPG-NeST-Strict")
#cg_strict_100_scatterplot.savefig("../plots/CPG-NeST_CPG-NeST-Strict_100.pdf", bbox_inches = 'tight')

In [None]:
# Model family 'random_718' on the '_strict' split (five replicates per fold).
dataset = '_strict'
ont = 'random_718'

random_strict_all_test, random_strict_all_pred = get_avg_cv_data(dataset, ont, zscore_method)
random_strict_drug_corr_list, random_strict_drug_corr_df = get_avg_corr_list(drugs, random_strict_all_test, random_strict_all_pred, zscore_method)
random_strict_100 = get_top_100_corr(top_100_drugs, random_strict_drug_corr_list)

In [None]:
# '_strict' split, 100 highest-variance drugs: 'cg' model (y) vs 'random_718' average (x).
strict_random_strict_100_scatterplot = create_scatter_plot(strict_100, random_strict_100, "CPG-NeST-Strict", "RSG-NeST-Strict")
strict_random_strict_100_scatterplot.savefig("../plots/CPG-NeST-Strict_RSG-NeST-Strict_100.pdf", bbox_inches = 'tight')

In [None]:
# Model 'cg' on the '_cell_loo' split with 100 folds, method 'auc'.
ont = 'cg'
dataset = "_cell_loo"
zscore_method = 'auc'

cg_logo_test, cg_logo_pred = get_cv_data(dataset, ont, zscore_method, fold_size=100)

# Pool all folds into one test frame / prediction vector; both are concatenated in
# the same fold order, so row i of the frame still matches prediction i.
cg_logo_test_list = [cg_logo_test[k] for k in cg_logo_test.keys()]
cg_logo_pred_list = [cg_logo_pred[k] for k in cg_logo_pred.keys()]
cg_logo_test_concat = pd.concat(cg_logo_test_list, axis=0, ignore_index=True, sort=False)
cg_logo_pred_concat = np.concatenate(cg_logo_pred_list, axis=0)

cg_logo_drug_corr_list = get_corr_list(drugs, cg_logo_test_concat, cg_logo_pred_concat, zscore_method)

# Score cell lines on the pooled data as well. Averaging per-fold lists with
# get_avg_corr_list would count a cell line as 0 in every fold that does not
# contain it and divide by the number of folds, shrinking its score.
logo_cell_corr_list = get_corr_list(top_100_cells, cg_logo_test_concat, cg_logo_pred_concat, zscore_method, col='cell_line')
logo_cell_100 = get_top_100_corr(top_100_cells, logo_cell_corr_list)

In [None]:
# Model 'cg' on the default split, method 'zscore' (mean / std per dataset-drug group).
# NOTE: `ont` is not set here; the value left by an earlier cell is used.
zscore_method = 'zscore'

cg_all_test_zscore, cg_all_pred_zscore = get_cv_data('', ont, zscore_method)
cg_drug_corr_list_zscore, _ = get_avg_corr_list(drugs, cg_all_test_zscore, cg_all_pred_zscore, zscore_method)

cg_cell_corr_list_zscore, _ = get_avg_corr_list(cell_lines, cg_all_test_zscore, cg_all_pred_zscore, zscore_method, col='cell_line')
cg_cell_100_zscore = get_top_100_corr(top_100_cells, cg_cell_corr_list_zscore)

In [None]:
# '_strict' split, method 'zscore'. `ont` is again taken from an earlier cell.
dataset = '_strict'
zscore_method = 'zscore'

strict_all_test_zscore, strict_all_pred_zscore = get_cv_data(dataset, ont, zscore_method)
strict_drug_corr_list_zscore, _ = get_avg_corr_list(drugs, strict_all_test_zscore, strict_all_pred_zscore, zscore_method)

strict_cell_corr_list_zscore, _ = get_avg_corr_list(cell_lines, strict_all_test_zscore, strict_all_pred_zscore, zscore_method, col='cell_line')
strict_cell_100_zscore = get_top_100_corr(top_100_cells, strict_cell_corr_list_zscore)

In [None]:
# Model 'cg' on the '_cell_loo' split with 100 folds, method 'zscore'.
ont = 'cg'
dataset = "_cell_loo"
zscore_method = 'zscore'

cg_logo_zscore_test, cg_logo_zscore_pred = get_cv_data(dataset, ont, zscore_method, fold_size=100)

# Pool all folds; test rows and predictions are concatenated in the same fold order.
cg_logo_zscore_test_list = [cg_logo_zscore_test[k] for k in cg_logo_zscore_test.keys()]
cg_logo_zscore_pred_list = [cg_logo_zscore_pred[k] for k in cg_logo_zscore_pred.keys()]
cg_logo_zscore_test_concat = pd.concat(cg_logo_zscore_test_list, axis=0, ignore_index=True, sort=False)
cg_logo_zscore_pred_concat = np.concatenate(cg_logo_zscore_pred_list, axis=0)

cg_logo_zscore_drug_corr_list = get_corr_list(drugs, cg_logo_zscore_test_concat, cg_logo_zscore_pred_concat, zscore_method)

# Score cell lines on the pooled data as well. Averaging per-fold lists with
# get_avg_corr_list would count a cell line as 0 in every fold that does not
# contain it and divide by the number of folds, shrinking its score.
logo_zscore_cell_corr_list = get_corr_list(top_100_cells, cg_logo_zscore_test_concat, cg_logo_zscore_pred_concat, zscore_method, col='cell_line')
logo_zscore_cell_100 = get_top_100_corr(top_100_cells, logo_zscore_cell_corr_list)

In [None]:
# Model on the default split, method 'robustz' (median / inter-quartile range per group).
# NOTE: `ont` is not set here; the value left by an earlier cell is used.
zscore_method = 'robustz'

cg_all_test_robustz, cg_all_pred_robustz = get_cv_data('', ont, zscore_method)
cg_drug_corr_list_robustz, _ = get_avg_corr_list(drugs, cg_all_test_robustz, cg_all_pred_robustz, zscore_method)

cg_cell_corr_list_robustz, _ = get_avg_corr_list(cell_lines, cg_all_test_robustz, cg_all_pred_robustz, zscore_method, col='cell_line')
cg_cell_100_robustz = get_top_100_corr(top_100_cells, cg_cell_corr_list_robustz)

In [None]:
# '_strict' split, method 'robustz'. `ont` is taken from an earlier cell.
dataset = '_strict'
zscore_method = 'robustz'

strict_all_test_robustz, strict_all_pred_robustz = get_cv_data(dataset, ont, zscore_method)
strict_drug_corr_list_robustz, _ = get_avg_corr_list(drugs, strict_all_test_robustz, strict_all_pred_robustz, zscore_method)

strict_cell_corr_list_robustz, _ = get_avg_corr_list(cell_lines, strict_all_test_robustz, strict_all_pred_robustz, zscore_method, col='cell_line')
strict_cell_100_robustz = get_top_100_corr(top_100_cells, strict_cell_corr_list_robustz)

In [None]:
# One column of per-drug correlations per split / standardisation combination.
# Every list was computed for `drugs`, so row i refers to drugs[i] in all columns.
drug_corr_df = pd.DataFrame({
    "AUC": cg_drug_corr_list,
    "Strict AUC": strict_drug_corr_list,
    "LOGO AUC": cg_logo_drug_corr_list,
    "Scaler": cg_drug_corr_list_zscore,
    "Strict scaler": strict_drug_corr_list_zscore,
    "LOGO scaler": cg_logo_zscore_drug_corr_list,
    "Robust scaler": cg_drug_corr_list_robustz,
    "Strict robust scaler": strict_drug_corr_list_robustz,
    })

In [None]:
# Distribution of per-drug correlations, one box per column of drug_corr_df.
drug_corr_boxplot = get_boxplot(drug_corr_df, "Cross-Validation stringencies", "Performance\nSpearman ρ (Predicted vs. Actual)")

In [None]:
# Same data as the box plot, drawn as violins.
drug_corr_violinplot = get_violinplot(drug_corr_df, "Cross-Validation stringencies", "Performance\nSpearman ρ (Predicted vs. Actual)")

In [None]:
from statistics import median

# Median / min / max of the per-drug correlations, grouped by standardisation method.
# Method 'auc': default split, '_strict' split, '_cell_loo' split.
print(median(cg_drug_corr_list), min(cg_drug_corr_list), max(cg_drug_corr_list))
print(median(strict_drug_corr_list), min(strict_drug_corr_list), max(strict_drug_corr_list))
print(median(cg_logo_drug_corr_list), min(cg_logo_drug_corr_list), max(cg_logo_drug_corr_list))
print('\n')
# Method 'zscore'.
print(median(cg_drug_corr_list_zscore), min(cg_drug_corr_list_zscore), max(cg_drug_corr_list_zscore))
print(median(strict_drug_corr_list_zscore), min(strict_drug_corr_list_zscore), max(strict_drug_corr_list_zscore))
print(median(cg_logo_zscore_drug_corr_list), min(cg_logo_zscore_drug_corr_list), max(cg_logo_zscore_drug_corr_list))
print('\n')
# Method 'robustz'. No '_cell_loo' run with 'robustz' is computed in this notebook,
# so this group has only two lines (printing a third one raised NameError).
print(median(cg_drug_corr_list_robustz), min(cg_drug_corr_list_robustz), max(cg_drug_corr_list_robustz))
print(median(strict_drug_corr_list_robustz), min(strict_drug_corr_list_robustz), max(strict_drug_corr_list_robustz))

In [None]:
# Default split, per-drug correlations: method 'auc' (y) vs method 'robustz' (x).
# NOTE(review): the variable name says "zscore" but the data plotted is the robustz list.
cg_zscore_scatterplot = create_scatter_plot(cg_drug_corr_list, cg_drug_corr_list_robustz, "CPG-NeST AUC", "CPG-NeST Robustz")
# NOTE(review): the disabled line below refers to a different figure and file name.
#cg_fmg_scatterplot.savefig("../plots/CPG-NeST_FMG-NeST.pdf", bbox_inches = 'tight')

In [None]:
# '_strict' split, per-drug correlations: method 'auc' (y) vs method 'robustz' (x).
# NOTE(review): the variable name says "zscore" but the data plotted is the robustz list.
cg_zscore_strict_scatterplot = create_scatter_plot(strict_drug_corr_list, strict_drug_corr_list_robustz, "CPG-NeST Strict AUC", "CPG-NeST Strict Robustz")
# NOTE(review): the disabled line below refers to a different figure and file name.
#cg_fmg_scatterplot.savefig("../plots/CPG-NeST_FMG-NeST.pdf", bbox_inches = 'tight')

In [None]:
# '_cell_loo' split, per-drug correlations: method 'auc' (y) vs method 'zscore' (x).
zscore_logo_scatterplot = create_scatter_plot(cg_logo_drug_corr_list, cg_logo_zscore_drug_corr_list, "CPG-NeST-LOGO", "CPG-NeST-LOGO-Zscore")

In [None]:
# Per-cell-line correlations for each split / method combination.
# NOTE(review): each pair below repeats a computation already done, with the same
# arguments, in the cell that loaded the corresponding run.
cg_cell_corr_list, _ = get_avg_corr_list(cell_lines, cg_all_test, cg_all_pred, 'auc', col='cell_line')
cg_cell_100 = get_top_100_corr(top_100_cells, cg_cell_corr_list)
    
strict_cell_corr_list, _ = get_avg_corr_list(cell_lines, strict_all_test, strict_all_pred, 'auc', col='cell_line')
strict_cell_100 = get_top_100_corr(top_100_cells, strict_cell_corr_list)

cg_cell_corr_list_zscore, _ = get_avg_corr_list(cell_lines, cg_all_test_zscore, cg_all_pred_zscore, 'zscore', col='cell_line')
cg_cell_100_zscore = get_top_100_corr(top_100_cells, cg_cell_corr_list_zscore)
    
strict_cell_corr_list_zscore, _ = get_avg_corr_list(cell_lines, strict_all_test_zscore, strict_all_pred_zscore, 'zscore', col='cell_line')
strict_cell_100_zscore = get_top_100_corr(top_100_cells, strict_cell_corr_list_zscore)

cg_cell_corr_list_robustz, _ = get_avg_corr_list(cell_lines, cg_all_test_robustz, cg_all_pred_robustz, 'robustz', col='cell_line')
cg_cell_100_robustz = get_top_100_corr(top_100_cells, cg_cell_corr_list_robustz)
    
strict_cell_corr_list_robustz, _ = get_avg_corr_list(cell_lines, strict_all_test_robustz, strict_all_pred_robustz, 'robustz', col='cell_line')
strict_cell_100_robustz = get_top_100_corr(top_100_cells, strict_cell_corr_list_robustz)

In [None]:
# Median / min / max per-cell-line correlation, '_strict' split, method 'robustz'.
# `median` comes from the `from statistics import median` line in an earlier cell.
print(median(strict_cell_corr_list_robustz), min(strict_cell_corr_list_robustz), max(strict_cell_corr_list_robustz))

In [None]:
# Per-cell-line correlations on the pooled '_cell_loo' folds, method 'auc'.
logo_cell_corr_list = get_corr_list(top_100_cells, cg_logo_test_concat, cg_logo_pred_concat, 'auc', col='cell_line')
logo_cell_100 = get_top_100_corr(top_100_cells, logo_cell_corr_list)

# Same for method 'zscore'. The pooled frames are named cg_logo_zscore_test_concat /
# cg_logo_zscore_pred_concat where they are built; the previous spelling
# (cg_logo_test_zscore_concat / cg_logo_pred_zscore_concat) is not defined anywhere.
logo_zscore_cell_corr_list = get_corr_list(top_100_cells, cg_logo_zscore_test_concat, cg_logo_zscore_pred_concat, 'zscore', col='cell_line')
logo_zscore_cell_100 = get_top_100_corr(top_100_cells, logo_zscore_cell_corr_list)

In [None]:
# One column of per-cell-line correlations per split / standardisation combination.
# NOTE(review): unlike drug_corr_df there is no plain "AUC" column (cg_cell_corr_list)
# and there are no LOGO columns - confirm this is intentional.
cell_corr_df = pd.DataFrame({
    "Strict AUC": strict_cell_corr_list,
    "Scaler": cg_cell_corr_list_zscore,
    "Strict scaler": strict_cell_corr_list_zscore,
    "Robust scaler": cg_cell_corr_list_robustz,
    "Strict robust scaler": strict_cell_corr_list_robustz,
    })

In [None]:
# Distribution of per-cell-line correlations, one box per column of cell_corr_df.
cell_corr_boxplot = get_boxplot(cell_corr_df, "Cross-Validation stringencies", "Performance\nSpearman ρ (Predicted vs. Actual)")

In [None]:
# Same data as the box plot, drawn as violins.
cell_corr_violinplot = get_violinplot(cell_corr_df, "Cross-Validation stringencies", "Performance\nSpearman ρ (Predicted vs. Actual)")

In [None]:
#ont = 'cg'
#dataset = "_drug_loo"

#top_100_drugs = list(ref_drug_variance_map.keys())[:100]

#cg_drug_loo_test, cg_drug_loo_pred = get_cv_data(dataset, ont, fold_size=100)

In [None]:
#cg_drug_loo_corr_list, _ = get_avg_corr_list(top_100_drugs, cg_drug_loo_test, cg_drug_loo_pred)
#cg_100 = get_top_100_corr(top_100_drugs, cg_drug_corr_list)
#cg_drug_loo_scatterplot = create_scatter_plot(cg_drug_loo_corr_list, cg_100, "Drug LOO NeST", "CTG-NeST")

In [None]:
#cg_drug_loo_test_list = [cg_drug_loo_test[k] for k in cg_drug_loo_test.keys()]
#cg_drug_loo_pred_list = [cg_drug_loo_pred[k] for k in cg_drug_loo_pred.keys()]
#cg_drug_loo_test_concat = pd.concat(cg_drug_loo_test_list, axis=0, ignore_index=True, sort=False)
#cg_drug_loo_pred_concat = np.concatenate(cg_drug_loo_pred_list, axis=0)

In [None]:
#cg_drug_loo_corr_list = get_corr_list(top_100_drugs, cg_drug_loo_test_concat, cg_drug_loo_pred_concat)

#cg_100 = get_top_100_corr(top_100_drugs, cg_drug_corr_list)
#drug_loo_100 = get_top_100_corr(top_100_drugs, cg_drug_loo_corr_list)
    
#cg_drug_loo_100_scatterplot = create_scatter_plot(drug_loo_100, cg_100, "Drug LOO NeST", "CPG-NeST")