In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Default location of the source data. setdefault keeps any value that is
# already exported in the environment, so the notebook can be pointed at another
# copy of the data without editing this cell.
os.environ.setdefault('PATH_SOURCE_DATA',
                      '/workspace/projects/boostdm/nature-release/source-data')

In [None]:
import sys
sys.path.append('./scripts/')
import pickle
import gzip
import functools
import operator
import glob

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import pandas as pd

from sklearn.metrics import precision_recall_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc

import matplotlib.pyplot as plt

from fetch_data_needle import create_observed_dataset, get_mutations, get_plot_data
from plot_needle import plot_observed_distribution

# Figure 1a.

In [None]:
# Fetched once at cell level and reused by every needleplot() call below,
# which passes it on to create_observed_dataset.
obs_muts = get_mutations()

def needleplot(gene, ttype, plotname=None):
    """Draw the observed-mutation needle plot for one gene in one tumor type.

    Builds the observed dataset for ``(gene, ttype)`` from the module-level
    ``obs_muts``, converts it with ``get_plot_data`` and hands the resulting
    positional arguments to ``plot_observed_distribution``.

    Parameters
    ----------
    gene : gene symbol forwarded to the data and plotting helpers.
    ttype : tumor-type code forwarded to the data and plotting helpers.
    plotname : forwarded unchanged to ``plot_observed_distribution``
        (default ``None``).
    """
    observed = create_observed_dataset(gene, ttype, obs_muts)
    plot_observed_distribution(gene, ttype, *get_plot_data(observed), plotname=plotname)

In [None]:
# Needle plot for EGFR in LUAD; plotname is forwarded to plot_observed_distribution.
gene, ttype = 'EGFR', 'LUAD'
needleplot(gene, ttype, plotname=f'./raw_plots/{gene}.{ttype}')

In [None]:
# Same gene in a second tumor type (GBM), for comparison with the LUAD plot above.
gene, ttype = 'EGFR', 'GBM'
needleplot(gene, ttype, plotname=f'./raw_plots/{gene}.{ttype}')

# Figure 1e.

In [None]:
# Score key -> label shown in legends and on axis ticks.
# Insertion order is kept because plot_tp53_boxplot iterates over it.
method_label = {
    'boostDM_score': 'boostDM',
    'chasm_score': 'CHASMplus',
    'cadd_score': 'CADD',
    'tp53_kato': 'TP53 (Kato et al.)',
    'tp53_giacomelli': 'TP53 (Giacomelli et al.)',
    'vest4': 'VEST4',
    'sift': 'SIFT',
    'sift4g': 'SIFT4G',
    'Polyphen2_HVAR': 'Polyphen2 (HVAR)',
    'Polyphen2_HDIV': 'Polyphen2 (HDIV)',
    'fathmm': 'FATHMM',
    'MutationAssessor': 'MutationAssessor',
}

# Colors for the boostDM curves.
palette_boostdm = {
    'boostDM_score': '#ac0f0f',
    'boostDM_score_strict': '#ad6f0f',
}

# Colors for the bioinformatics predictors.
palette_bioinfo = {
    'chasm_score': 'pink',
    'cadd_score': '#b491c8',
    'vest4': '#1565c0',
    'sift': '#663a82',
    'sift4g': '#1e88e5',
    'Polyphen2_HVAR': '#90caf9',
    'Polyphen2_HDIV': 'black',
    'fathmm': '#3c1361',
    'MutationAssessor': '#52307c',
}

# Colors for the experimental assays.
palette_experimental = {
    'tp53_kato': '#607c3c',
    'tp53_giacomelli': '#b5e550',
    'pten_mighell': '#ececa3',
    'ras_bandaru': '#abc32f',
}

# Single lookup table used by all plotting functions.
palettes = {**palette_boostdm, **palette_bioinfo, **palette_experimental}

# Root folder of the benchmark data; plot_prc and scatter_fscore read the
# cv_<method> subfolders below it.
cv_path = os.path.join(os.environ['PATH_SOURCE_DATA'], 'benchmark-cvdata')

In [None]:
# Precision-Recall Curve

def trycatch(f):
    """Decorator that turns any exception raised by ``f`` into a ``None`` return.

    The wrapped function is called with the arguments it was given. If it
    raises an ``Exception``, a ``RuntimeWarning`` describing the failure is
    emitted and ``None`` is returned, so callers can keep iterating over many
    inputs without stopping at the first failure.

    ``functools.wraps`` keeps the name and docstring of ``f`` on the wrapper.

    Parameters
    ----------
    f : callable to protect.

    Returns
    -------
    callable with the same signature as ``f`` that returns ``f``'s result, or
    ``None`` when ``f`` raised.
    """

    @functools.wraps(f)
    def func(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            # Still best-effort, but no longer silent: report what went wrong.
            name = getattr(f, '__name__', repr(f))
            warnings.warn(f'{name} failed with {type(e).__name__}: {e}',
                          RuntimeWarning, stacklevel=2)
            return None
    return func


def plot_pr_curve(testy, model_probs, ax, **kwargs):
    """Draw one precision-recall curve on ``ax``.

    Parameters
    ----------
    testy : true labels, passed to ``precision_recall_curve``.
    model_probs : predicted scores, passed to ``precision_recall_curve``.
    ax : matplotlib axes to draw on.
    **kwargs : forwarded to ``ax.plot`` (e.g. ``label``, ``color``, ``lw``).
    """
    prec, rec, _thresholds = precision_recall_curve(testy, model_probs)

    # recall on the x axis, precision on the y axis
    ax.plot(rec, prec, **kwargs)
    ax.set(xlabel='Recall', ylabel='Precision')

    # only the upper half of the precision range is shown
    ax.set_ylim(0.5, 1.01)
    

# Where each benchmark score lives under ``cv_path``:
#   score key -> (path template relative to cv_path,
#                 column holding the raw score,
#                 whether '.' placeholder entries are dropped besides nulls)
_PRC_SOURCES = {
    'chasm_score': ('cv_chasm/{gene}.{ttype}.cv_chasm.pickle.gz', 'chasm_score', True),
    'cadd_score': ('cv_cadd/{gene}.{ttype}.cv_cadd.pickle.gz', 'cadd_score', False),
    'boostDM_score': ('cv_boostdm/{gene}.{ttype}.cv_boostdm.pickle.gz', 'boostDM_score', False),
    'tp53_kato': ('cv_tp53_kato/{gene}.{ttype}.cv_tp53.pickle.gz', 'tp53_score', False),
    'pten_mighell': ('cv_pten_mighell/{gene}.{ttype}.cv_pten.pickle.gz', 'pten_mighell', False),
    'ras_bandaru': ('cv_ras_bandaru/{gene}.{ttype}.pickle.gz', 'ras_score', False),
    'tp53_giacomelli': ('cv_tp53_giacomelli/{gene}.{ttype}.pickle.gz', 'tp53_natgen_score', False),
    'sift': ('cv_sift/{gene}.{ttype}.pickle.gz', 'SIFT', True),
    'sift4g': ('cv_sift4g/{gene}.{ttype}.pickle.gz', 'SIFT4G', True),
    'Polyphen2_HDIV': ('cv_Polyphen2_HDIV/{gene}.{ttype}.pickle.gz', 'Polyphen2_HDIV', True),
    'Polyphen2_HVAR': ('cv_Polyphen2_HVAR/{gene}.{ttype}.pickle.gz', 'Polyphen2_HVAR', True),
    'fathmm': ('cv_fathmm/{gene}.{ttype}.pickle.gz', 'FATHMM', True),
    'MutationAssessor': ('cv_MutationAssessor/{gene}.{ttype}.pickle.gz', 'MutationAssessor', True),
    'vest4': ('cv_vest4/{gene}.{ttype}.pickle.gz', 'VEST4', True),
}


def _load_cv_scores(gene, ttype, score):
    """Load the benchmark table of ``score`` for ``(gene, ttype)``.

    Returns ``(X, y)``: ``X`` is the score column reshaped to ``(n, 1)`` and
    ``y`` the matching ``'driver'`` column, after dropping rows whose score is
    null (and, for the sources flagged in ``_PRC_SOURCES``, equal to ``'.'``).

    Raises ``ValueError`` for a score key missing from ``_PRC_SOURCES``; errors
    from opening or unpickling the file propagate unchanged.
    """
    try:
        template, column, drop_dots = _PRC_SOURCES[score]
    except KeyError:
        raise ValueError(f'unknown score: {score!r}') from None

    path = os.path.join(cv_path, template.format(gene=gene, ttype=ttype))
    # NOTE: pickle executes arbitrary code on load; only use trusted source data.
    with gzip.open(path, 'rb') as g:
        tables = pickle.load(g)
    # the pickle holds several tables (presumably one per CV split); stack them
    df = pd.concat(tables, axis=0)

    keep = ~df[column].isnull()
    if drop_dots:
        keep = (df[column] != '.') & keep
    df = df[keep]

    return df[column].values.reshape(-1, 1), df['driver'].values


@trycatch
def plot_prc(gene, ttype, score, ax=None, plot=True, **kwargs):
    """Compute (and optionally draw) the precision-recall curve of one method.

    The raw scores of ``score`` for ``(gene, ttype)`` are loaded from the
    benchmark data, a one-feature logistic regression is fitted on them against
    the ``'driver'`` labels, and the precision-recall curve of the fitted
    probabilities is evaluated on those same rows.

    Parameters
    ----------
    gene, ttype : identify the benchmark file to read.
    score : score key; one of the keys of ``_PRC_SOURCES``.
    ax : matplotlib axes to draw on; only used when ``plot`` is true.
    plot : when true, the curve is drawn with ``plot_pr_curve`` and labelled
        with the method name and its auPRC.
    **kwargs : forwarded to ``plot_pr_curve`` (and from there to ``ax.plot``).

    Returns
    -------
    ``(auc_score, precision, recall, positive, negative)`` where ``positive``
    is the sum of the labels and ``negative`` the remaining row count.
    Because of the ``trycatch`` decorator, ``None`` is returned instead
    whenever anything fails (missing file, unknown score, failed fit, ...).
    """
    X, y = _load_cv_scores(gene, ttype, score)
    assert X.shape[0] == len(y)

    # number of mutations per class
    n = len(y)
    positive = sum(y)
    negative = n - positive

    # fit simple logistic model and score the same rows it was fitted on
    model = LogisticRegression(solver='lbfgs')
    model.fit(X, y)
    probs = model.predict_proba(X)[:, 1]

    # calculate the precision-recall auc
    precision, recall, _ = precision_recall_curve(y, probs)
    auc_score = auc(recall, precision)

    # plot precision-recall curves
    if plot:
        # fall back to the raw key for scores that have no entry in method_label
        label = f'{method_label.get(score, score)}: auPRC={auc_score:.2}'
        plot_pr_curve(y, probs, ax, label=label, **kwargs)

    return auc_score, precision, recall, positive, negative


def plot_tp53_boxplot(seed=0):
    """Box plot of the auPRC of every method in ``method_label`` for TP53.

    For each method, ``plot_prc`` is evaluated (without drawing) on a fixed
    list of tumor types. Combinations for which ``plot_prc`` returns ``None``
    are skipped, and methods without any result are left out of the figure.
    Methods are ordered by decreasing median auPRC; individual values are
    overlaid as jittered points colored with ``palettes``. The figure is saved
    to ``./raw_plots/TP53.boxplot.svg`` and shown.

    Parameters
    ----------
    seed : seed of the horizontal jitter of the points, so the figure is
        reproducible; pass ``None`` for a different jitter on every call.

    Raises
    ------
    RuntimeError if no auPRC could be computed for any method.
    """
    gene = 'TP53'
    fig, ax = plt.subplots(figsize=(10, 3))
    # fixed order so that a given seed always produces the same figure
    tumor_types = ['BLCA', 'ESCA', 'GBM', 'LUAD', 'LUSC', 'OV', 'PAAD', 'PRAD', 'UCEC']

    # one (score key, display label, auPRC values) entry per method
    per_method = []
    for score, label in method_label.items():
        aucs = []
        for ttype in tumor_types:
            result = plot_prc(gene, ttype, score, ax=ax, plot=False)
            if result is None:
                # plot_prc failed for this combination: nothing to add
                continue
            aucs.append(result[0])
        if aucs:
            per_method.append((score, label, aucs))

    if not per_method:
        plt.close(fig)
        raise RuntimeError('no auPRC could be computed for any method')

    # best median first
    medians = [np.nanmedian(aucs) for _, _, aucs in per_method]
    per_method = [per_method[i] for i in np.argsort(medians)[::-1]]

    ax.boxplot([aucs for _, _, aucs in per_method], showfliers=False)

    # boxes sit at x = 1, 2, ...; spread the points slightly around each one
    rng = np.random.RandomState(seed)
    for position, (score, _, aucs) in enumerate(per_method, start=1):
        jittered = position + rng.normal(0, 0.05, size=len(aucs))
        ax.scatter(jittered, aucs, color=palettes[score], s=30)

    ax.set_ylabel('auPRC')
    ax.set_xlabel('Method')
    ax.set_xticklabels([label for _, label, _ in per_method], rotation=90)
    ax.set_title('Performance in TP53 across tumor-types')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.savefig('./raw_plots/TP53.boxplot.svg', dpi=300, bbox_inches='tight')
    plt.show()
    
    
def plot_tp53_bioinfo(gene, ttype):
    """Precision-recall curves of boostDM and five bioinformatics predictors.

    All curves for ``(gene, ttype)`` are drawn on one 4x4 figure via
    ``plot_prc``; the figure is saved to
    ``./raw_plots/{gene}.{ttype}.prc.bioinfo.svg`` and shown.
    """
    fig, ax = plt.subplots(figsize=(4,4))

    # (score key, line width, opacity), drawn in this order
    curves = [
        ('boostDM_score', 5, 1),
        ('chasm_score', 4, 0.8),
        ('cadd_score', 4, 0.8),
        ('sift4g', 4, 0.8),
        ('vest4', 4, 0.8),
        ('Polyphen2_HDIV', 4, 0.5),
    ]
    for score, width, opacity in curves:
        plot_prc(gene, ttype, score, ax=ax, color=palettes[score], lw=width, alpha=opacity)

    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.set_title(f'{gene} ({ttype})')
    ax.legend(loc=(0.5, 0.3))
    plt.savefig(f'./raw_plots/{gene}.{ttype}.prc.bioinfo.svg', dpi=300, bbox_inches='tight')
    plt.show()
    
    
def plot_tp53_experimental(gene, ttype):
    """Precision-recall curves of boostDM and the two experimental TP53 scores.

    All curves for ``(gene, ttype)`` are drawn on one 4x4 figure via
    ``plot_prc`` and the figure is shown. It is not written to disk: the
    ``savefig`` call is disabled.
    """
    fig, ax = plt.subplots(figsize=(4,4))

    # (score key, line width, opacity), drawn in this order
    curves = [
        ('boostDM_score', 5, 1),
        ('tp53_kato', 4, 0.8),
        ('tp53_giacomelli', 4, 0.8),
    ]
    for score, width, opacity in curves:
        plot_prc(gene, ttype, score, ax=ax, color=palettes[score], lw=width, alpha=opacity)

    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.set_title(f'{gene} ({ttype})')
    ax.legend(loc=(0.5, 0.3))
    #plt.savefig(f'./raw_plots/{gene}.{ttype}.prc.experimental.svg', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# TP53 in BLCA: boostDM against the two experimental TP53 scores.
plot_tp53_experimental('TP53', 'BLCA')

In [None]:
# TP53 in BLCA: boostDM against the bioinformatics predictors.
plot_tp53_bioinfo('TP53', 'BLCA')

In [None]:
# auPRC of every method in method_label for TP53, across tumor types.
plot_tp53_boxplot()

# Figure 1f.

In [None]:
# (gene, tumor type) pairs that scatter_fscore labels and outlines,
# keyed by the method being compared against boostDM.
highlighted = {
    'cadd': [
        ('KRAS', 'PAAD'),
        ('EGFR', 'LUAD'),
        ('CTNNB1', 'HC'),
        ('FOXA1', 'PRAD'),
        ('TP53', 'BLCA'),
    ],
    'chasm': [
        ('AR', 'PRAD'),
        ('EGFR', 'LUAD'),
        ('FOXA1', 'PRAD'),
        ('TP53', 'BLCA'),
    ],
    'ras_bandaru': [
        ('NRAS', 'CM'),
        ('NRAS', 'MM'),
        ('NRAS', 'COREAD'),
        ('KRAS', 'CH'),
        ('KRAS', 'COREAD'),
        ('KRAS', 'LUAD'),
        ('KRAS', 'PAAD'),
        ('KRAS', 'ALL'),
        ('KRAS', 'MM'),
        ('HRAS', 'BLCA'),
        ('HRAS', 'HNSC'),
    ],
    'sift4g': [
        ('KRAS', 'PAAD'),
        ('EGFR', 'LUAD'),
        ('PIK3CA', 'UCEC'),
        ('FOXA1', 'PRAD'),
        ('TP53', 'BLCA'),
    ],
}

In [None]:
def scatter_fscore(score):
    """Scatter the auPRC of boostDM against the auPRC of another method.

    Every benchmark file found for ``score`` under ``cv_path`` yields one
    ``(gene, ttype)`` pair, for which ``plot_prc`` is evaluated (without
    drawing) for both boostDM and the other method. Pairs where either call
    returns ``None`` are skipped, and only results with at least 200 positive
    and 200 negative rows are plotted. Pairs listed in ``highlighted[score]``
    are labelled and outlined. The figure is saved to
    ``./raw_plots/boostdm_vs_{score}.prc.global.svg`` and shown.

    Parameters
    ----------
    score : one of ``'chasm'``, ``'cadd'``, ``'sift4g'``, ``'ras_bandaru'``.

    Raises
    ------
    ValueError if ``score`` is not one of the supported methods.
    """
    # Validate before creating the figure so a typo does not leave an empty plot.
    if score in ('chasm', 'cadd'):
        file_pattern = f'*.*.cv_{score}.pickle.gz'
        prc_arg = f'{score}_score'
    elif score in ('sift4g', 'ras_bandaru'):
        file_pattern = '*.*.pickle.gz'
        prc_arg = score
    else:
        raise ValueError(
            f"unsupported score {score!r}; expected one of "
            "'chasm', 'cadd', 'sift4g', 'ras_bandaru'"
        )
    glob_fn = os.path.join(cv_path, f'cv_{score}', file_pattern)
    color_label = prc_arg  # the palette key is the same as the plot_prc score key

    min_class_size = 200  # required number of positive and of negative rows

    fig, ax = plt.subplots(figsize=(4,4))

    boostdm, method = {}, {}
    for fn in glob.glob(glob_fn):
        # file names start with '<gene>.<ttype>.'
        gene, ttype = tuple(os.path.basename(fn).split('.')[:2])
        res_boostdm = plot_prc(gene, ttype, 'boostDM_score', ax=ax, plot=False)
        res_method = plot_prc(gene, ttype, prc_arg, ax=ax, plot=False)
        if (res_boostdm is None) or (res_method is None):
            continue
        boostdm[(gene, ttype)] = res_boostdm
        method[(gene, ttype)] = res_method

    # plot_prc returns (auc, precision, recall, positive, negative)
    aucs_boostdm = {k: s[0] for k, s in boostdm.items()
                    if (s[3] >= min_class_size) and (s[4] >= min_class_size)}
    aucs_method = {k: s[0] for k, s in method.items()
                   if (s[3] >= min_class_size) and (s[4] >= min_class_size)}

    # sorted so that the drawing order does not depend on set iteration order
    labels = sorted(set(aucs_boostdm).intersection(aucs_method))
    xs = [aucs_boostdm[k] for k in labels]
    ys = [aucs_method[k] for k in labels]
    ax.scatter(xs, ys, s=200, alpha=0.5, color=palettes[color_label])

    for label, x_val, y_val in zip(labels, xs, ys):
        if label in highlighted[score]:
            ax.text(x_val, y_val, f'{label[0]} ({label[1]})')
            ax.scatter([x_val], [y_val], s=200, color='white', edgecolor='black')

    # diagonal: both methods reach the same auPRC
    ax.plot([0.85, 1], [0.85, 1], '--', c='r', lw=4)
    ax.set_ylabel(f'{score.upper()} (auPRC)')
    ax.set_xlabel('boostDM (auPRC)')
    ax.set_xlim(0.84, 1.01)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.savefig(f'./raw_plots/boostdm_vs_{score}.prc.global.svg', dpi=200, bbox_inches='tight')
    plt.show()

In [None]:
# boostDM vs CADD auPRC, one point per (gene, tumor type) benchmark file.
scatter_fscore('cadd')

In [None]:
# boostDM vs CHASMplus auPRC, one point per (gene, tumor type) benchmark file.
scatter_fscore('chasm')

In [None]:
# boostDM vs SIFT4G auPRC, one point per (gene, tumor type) benchmark file.
scatter_fscore('sift4g')

In [None]:
# boostDM vs the 'ras_bandaru' experimental score, one point per benchmark file.
scatter_fscore('ras_bandaru')