In [None]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import glob

In [None]:
# Three font-size presets reused by every matplotlib setting below.
SMALL_SIZE = 12
MEDIUM_SIZE = 16
BIGGER_SIZE = 22

# Apply all text-size defaults in one pass through matplotlib's rcParams.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend entries
    'figure.titlesize': BIGGER_SIZE,  # figure-level title
})

In [None]:
def generate_test_data(genie_data, mutation_data, drugs):
    """Build the sample list, mutation matrix and placeholder test set.

    A sample is selected when its ``CDK4/6 Inhibitor Overall`` or its
    ``mTOR Inhibitor Overall`` value in ``genie_data`` equals ``"Yes"``.
    Neither input frame is modified.

    Parameters
    ----------
    genie_data : pandas.DataFrame
        Must contain ``Sample ID`` and the two inhibitor columns named above.
    mutation_data : pandas.DataFrame
        Must contain ``sampleId`` and ``Altered``; every other column is
        passed through unchanged.
    drugs : iterable
        Drug identifiers; each one is paired with every selected sample.

    Returns
    -------
    cell_line_df : pandas.DataFrame
        Single column ``C`` holding the selected sample ids in sorted order.
    filtered_mut_data : pandas.DataFrame
        Rows of ``mutation_data`` whose ``sampleId`` is selected, sorted by
        ``sampleId``, with the ``sampleId`` and ``Altered`` columns removed.
    test_data_df : pandas.DataFrame
        Columns ``C``, ``D``, ``AUC``: one row per (drug, sample) pair, all
        samples for the first drug first, with ``AUC`` set to 0.5 throughout.
    """
    # Boolean masks instead of a query string: the column names contain
    # spaces and '/', which query() only accepts via backticks on recent pandas.
    treated = (
        (genie_data['CDK4/6 Inhibitor Overall'] == 'Yes')
        | (genie_data['mTOR Inhibitor Overall'] == 'Yes')
    )
    cell_lines = sorted(genie_data.loc[treated, 'Sample ID'])
    cell_line_df = pd.DataFrame(cell_lines, columns=['C'])

    # Chained, non-mutating calls: sorting/dropping in place on a filtered
    # frame raises SettingWithCopyWarning.
    # NOTE(review): a selected sample without a row in mutation_data is simply
    # absent here, so row positions only line up with cell_line_df's index if
    # every selected sample has exactly one mutation row -- confirm upstream.
    filtered_mut_data = (
        mutation_data[mutation_data['sampleId'].isin(cell_lines)]
        .sort_values(by='sampleId')
        .drop(columns=['sampleId', 'Altered'])
    )

    # Drug-major order: all samples for the first drug, then the next drug, ...
    test_data_df = pd.DataFrame(
        [(sample, drug, 0.5) for drug in drugs for sample in cell_lines],
        columns=['C', 'D', 'AUC'],
    )

    return cell_line_df, filtered_mut_data, test_data_df

In [None]:
def create_survival_plot(genie_data, test_data, predict_data):
    """Print median overall survival for samples split by predicted AUC.

    Despite its name, this function draws no figure: it prints three numbers
    and returns ``None``.

    Steps:
      1. Keep the rows of ``genie_data`` where ``CDK4/6 Inhibitor Overall`` or
         ``mTOR Inhibitor Overall`` equals ``"Yes"``.
      2. Pair ``predict_data`` with ``test_data['C']`` by position and take the
         median prediction per sample.
      3. Split the samples on the median of all predictions: strictly above
         it versus at or below it.
      4. Print, in order: the overall median prediction, the median
         ``Overall Survival (Months)`` of the at-or-below group, and the
         median of the above group.

    Parameters
    ----------
    genie_data : pandas.DataFrame
        Must contain ``Sample ID``, ``Overall Survival (Months)`` and the two
        inhibitor columns named above. Not modified.
    test_data : pandas.DataFrame
        Must contain column ``C`` (sample id), one row per prediction.
    predict_data : array-like
        One-dimensional predictions, in the same row order as ``test_data``.

    Raises
    ------
    ValueError
        If ``predict_data`` and ``test_data`` have different lengths.
    """
    treated = (
        (genie_data['CDK4/6 Inhibitor Overall'] == 'Yes')
        | (genie_data['mTOR Inhibitor Overall'] == 'Yes')
    )
    treated_df = genie_data[treated]

    predictions = pd.Series(predict_data, name='P_AUC')
    if len(predictions) != len(test_data):
        # A length mismatch means the predictions were not produced from this
        # test set; pairing them anyway would yield meaningless medians.
        raise ValueError(
            'predict_data has %d rows but test_data has %d'
            % (len(predictions), len(test_data))
        )

    # Pair by position rather than by index label, so a test_data frame with
    # a non-default index is still matched row for row.
    pred_df = pd.DataFrame({
        'C': test_data['C'].to_numpy(),
        'P_AUC': predictions.to_numpy(),
    })
    pred_median_df = pred_df.groupby('C', as_index=False)['P_AUC'].median()

    # Inner join: samples without a treated clinical row are dropped.
    pred_median_df = pd.merge(pred_median_df, treated_df, left_on='C', right_on='Sample ID')

    pred_median = np.median(predict_data)
    print(pred_median)

    survival = pred_median_df['Overall Survival (Months)']
    above_median = pred_median_df['P_AUC'] > pred_median

    # Drop missing survival values first: a single NaN would otherwise turn
    # the whole group's median into NaN.
    dc_p = survival[~above_median].dropna().to_numpy()
    dc_n = survival[above_median].dropna().to_numpy()

    print(np.median(dc_p))
    print(np.median(dc_n))


In [None]:
# Clinical table (tab-separated).
genie_data = pd.read_csv(
    '../data/GENIE/brca_akt1_genie_2019_clinical_data.tsv',
    sep='\t',
)

# Mutation matrix: take the part after the first ':' of 'studyID:sampleId'
# as 'sampleId', then discard the combined column.
mutation_data = (
    pd.read_csv('../data/GENIE/sample_matrix.txt', sep='\t')
    .assign(sampleId=lambda frame: frame['studyID:sampleId'].str.split(":", expand=True)[1])
    .drop(columns=['studyID:sampleId'])
)

# Two-column, header-less drug table; only the second column ('D') is kept.
drug_table = pd.read_csv(
    '../data/GENIE/GENIE_drug2ind.txt',
    sep='\t',
    header=None,
    names=['I', 'D'],
)
drugs = drug_table['D']

# Predictions, read as a plain numeric array.
predict_data = np.loadtxt('../result/drugcell_genie.predict')

# NOTE(review): GENIE_test.txt is also written by a later cell of this
# notebook, and test_data is reassigned there before it is used -- confirm
# whether this read is still needed (it fails if the file does not exist yet).
test_data = pd.read_csv(
    '../data/GENIE/GENIE_test.txt',
    sep='\t',
    header=None,
    names=['C', 'D', 'AUC'],
)

In [None]:
# Derive the three input tables from the loaded clinical and mutation data.
cell_lines, filtered_mut_data, test_data = generate_test_data(
    genie_data=genie_data,
    mutation_data=mutation_data,
    drugs=drugs,
)

# Sample list: tab-separated, row index written as the first column, no header.
cell_lines.to_csv(
    '../data/GENIE/GENIE_cell2ind.txt',
    index=True,
    header=False,
    sep='\t',
)

# Mutation matrix: default (comma) separator, no header and no index.
filtered_mut_data.to_csv(
    '../data/GENIE/GENIE_cell2mutation.txt',
    index=False,
    header=False,
)

# Placeholder test set: tab-separated, no header and no index.
test_data.to_csv(
    "../data/GENIE/GENIE_test.txt",
    index=False,
    header=False,
    sep='\t',
)

In [None]:
# Print the overall median prediction and the median survival of each group.
create_survival_plot(
    genie_data=genie_data,
    test_data=test_data,
    predict_data=predict_data,
)