In [None]:
import numpy as np
import pandas as pd
import os
import sys
import matplotlib.pyplot as plt
import pickle as pkl
from tqdm.notebook import tqdm
from scipy.stats import ttest_ind, chi2_contingency
from util import *
from scipy.stats import pearsonr, norm
from bio import Entrez
import seaborn as sns
from statsmodels.stats.multitest import multipletests

%matplotlib inline

### Define paths

In [None]:
# Input files
# Previous results (the output file below is written to the same directory)
RESULTS_DIR      = 'results/essential_candidates/public_20Q2'
CORRELATION_DATA = f'{RESULTS_DIR}/expression_correlations.pkl'
ESSENTIAL_GENES  = f'{RESULTS_DIR}/essential_genes-all.pkl'
INTERACTION_DATA = f'{RESULTS_DIR}/essential_genes_annotated.pkl'

# New datasets
TREEHOUSE_DIR    = 'data/treehouse'
TREEHOUSE_DATA   = f'{TREEHOUSE_DIR}/treehouse_depmap_genes.csv'
TREEHOUSE_INFO   = f'{TREEHOUSE_DIR}/clinical_TumorCompendium_v11_PolyA_2020-04-09.tsv'
TH_DM_MAP        = f'{TREEHOUSE_DIR}/th_dm_map.csv'

# Local cache of NCBI gene names (read and extended by get_gene_name)
NCBI_GENE_NAMES  = 'data/misc/ncbi_gene_names.pkl'

# Output files
RESULTS_FILE     = f'{RESULTS_DIR}/added_features_median_padj.pkl'

## 1.  Load data

### Load candidate gene data

In [None]:
# One row per candidate gene, combining one column from each earlier result table.
# join='inner' keeps only the index labels present in all three tables.
gene_col       = pd.read_pickle(CORRELATION_DATA).gene
cell_lines_col = pd.read_pickle(ESSENTIAL_GENES).cell_lines
paralogs_col   = pd.read_pickle(INTERACTION_DATA).paralogs

gene_list = pd.concat([gene_col, cell_lines_col, paralogs_col], axis=1, join='inner')
gene_list.head()

### Load cell line info

In [None]:
cell_line_inf = get_from_taiga(name='public-20q2-075d', version=22, file='sample_info')
cell_line_inf = cell_line_inf.set_index('DepMap_ID')

# Split 'ALL' into 't-ALL' / 'b-ALL' based on lineage_sub_subtype. The 't' pass runs first,
# and the mask is recomputed for the 'b' pass, so a row already relabelled 't-ALL' is not
# relabelled again. na=False: rows without a sub-subtype never match, instead of relying on
# how `&` treats NaN.
for new_subtype, marker in (('t-ALL', 't'), ('b-ALL', 'b')):
    is_marked_all = ((cell_line_inf['lineage_subtype'] == 'ALL') &
                     cell_line_inf['lineage_sub_subtype'].str.contains(marker, na=False))
    cell_line_inf.loc[is_marked_all, 'lineage_subtype'] = new_subtype

# specified_disease: the lineage subtype for subtypes listed in PEDIATRIC_CANCERS, otherwise
# the primary disease. It is assigned as a whole column because fillna(inplace=True) on the
# column attribute does not write back to the frame under pandas copy-on-write.
is_pediatric = cell_line_inf['lineage_subtype'].isin(PEDIATRIC_CANCERS)
cell_line_inf['specified_disease'] = (cell_line_inf['lineage_subtype']
                                      .where(is_pediatric)
                                      .fillna(cell_line_inf['primary_disease']))

diseases = cell_line_inf['specified_disease'].unique()

cell_line_inf.head()

### Load gene effect data

In [None]:
# Gene-effect matrix. Its rows are intersected with cell-line IDs below, and its columns are
# looked up by the gene IDs that index gene_list.
eff_data = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='Achilles_gene_effect',
    split_attribute='header',
)
eff_data.head()

### Load DepMap prediction data

In [None]:
# Prediction-model table. It is checked for gene IDs in its index below, and its 'best'
# column is used as a boolean row mask.
pred_data = get_from_taiga(
    name='tda-ensemble-predictions-3293',
    version=5,
    file='Avana_pred_models',
    split_attribute='column',
    col='gene',
)
pred_data.head()

In [None]:
# Preview only: the rows whose 'best' flag is set.
pred_data.loc[pred_data['best']]

### Load mutation data

In [None]:
# Mutation calls (one row per call; Entrez_Gene_Id / DepMap_ID / Variant_* columns are used below).
mut_data = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='CCLE_mutations',
)
mut_data.head()

### Load Treehouse expression data

In [None]:
# Expression table: looked up by gene name in its index and by Treehouse sample ID in its columns.
treehouse_data = pd.read_csv(filepath_or_buffer=TREEHOUSE_DATA, index_col=0)
treehouse_data.head()

### Load Treehouse info

In [None]:
# Clinical info, indexed by Treehouse sample ID; its 'disease' column groups samples below.
treehouse_info = pd.read_csv(TREEHOUSE_INFO, delimiter='\t', index_col=0)
treehouse_info.head()

In [None]:
# Local cache {gene_id: (label, description)}; this notebook extends and rewrites it.
with open(NCBI_GENE_NAMES, 'rb') as f:
    ncbi_gene_names = pkl.load(f)

# Set once for the session rather than on every cache miss.
Entrez.email = "test@gmail.com"


def get_gene_name(geneID):
    """Return ``(label, description)`` for a gene ID, using the local pickle cache.

    On a cache miss the gene is fetched with ``Entrez.efetch`` (db "gene",
    rettype "gene_table", text). The first whitespace-separated token of the first
    response line is taken as the name, and the remaining tokens as the description.
    ``label`` is ``"<name> (<geneID>)"``. The updated cache is written back to
    ``NCBI_GENE_NAMES``.

    Raises:
        ValueError: if the first response line is empty.
    """
    if geneID not in ncbi_gene_names:
        handle = Entrez.efetch("gene", id=str(geneID), rettype="gene_table", retmode="text")
        try:
            info = handle.readline().split()
        finally:
            handle.close()  # release the connection even if reading fails
        if not info:
            raise ValueError(f"Empty NCBI gene_table response for gene {geneID}")
        name = info[0]
        ncbi_gene_names[geneID] = f"{name} ({geneID})", f"{' '.join(info[1:]).strip()}"
        # Write to a temporary file, then atomically replace, so an interrupted write
        # cannot leave a truncated cache pickle behind.
        tmp_path = f"{NCBI_GENE_NAMES}.tmp"
        with open(tmp_path, 'wb') as f:
            pkl.dump(ncbi_gene_names, f)
        os.replace(tmp_path, NCBI_GENE_NAMES)
    return ncbi_gene_names[geneID]

### Load Treehouse-DepMap disease mapping

In [None]:
# Treehouse <-> DepMap disease-name mapping (columns used below: depmap_name, treehouse_name).
# NOTE(review): the file has a .csv extension but is parsed as tab-separated — confirm.
th_dm_map = pd.read_csv(TH_DM_MAP, delimiter="\t")
th_dm_map.head()

In [None]:
# DepMap disease name -> list of matching Treehouse disease names. Every name in DISEASES
# gets an entry (empty if unmapped); names that only appear in th_dm_map are added too.
# NOTE(review): DISEASES is not defined in the visible cells (the cell-line cell defines
# lowercase `diseases`); presumably it comes from `util` — confirm.
dm_th_disease_map = dict((disease, []) for disease in DISEASES)
depmap_names = th_dm_map['depmap_name']
for depmap_name in set(depmap_names):
    dm_th_disease_map[depmap_name] = list(th_dm_map.loc[depmap_names == depmap_name, 'treehouse_name'])

dm_th_disease_map

## 2.  Annotate genes

### Add disease enrichment info (DepMap gene effect)

In [None]:
def get_enrichments(sample_id, data, idx_category_map, category_field, idx_field=None, skewed=None, p_cutoff=.05):
    """Test, per category, whether a column's values differ inside vs. outside that category.

    For every distinct value ``cat`` of ``idx_category_map[category_field]``, the values of
    column ``sample_id`` of ``data`` are split into the rows belonging to ``cat`` (a) and all
    other rows (b). The two groups are compared with ``scipy.stats.ttest_ind`` (default
    settings, equal variances). Missing values are ignored.

    Parameters
    ----------
    sample_id : label
        Column of ``data`` to test.
    data : pandas.DataFrame
        Rows are looked up by the row keys of ``idx_category_map`` (see ``idx_field``).
    idx_category_map : pandas.DataFrame
        Table assigning each row key to a category.
    category_field : str
        Column of ``idx_category_map`` holding the category labels.
    idx_field : str, optional
        Column of ``idx_category_map`` holding the row keys into ``data``. When None, the
        index of ``idx_category_map`` is used.
    skewed : {None, 'left', 'right'}
        'left' only tests categories whose median is below that of the rest, 'right' only
        those whose median is above it, and None tests all.
    p_cutoff : float or None
        Only categories with p < p_cutoff are reported; None reports every tested category.

    Returns
    -------
    dict
        ``{category: {'p': p_value, 'n': number of non-missing values in the category}}``.
        Categories with fewer than 3 non-missing values on either side, and categories whose
        test is undefined (NaN p-value), are left out.
    """
    assert skewed in [None, 'left', 'right']
    labels = idx_category_map[category_field]

    def _row_keys(mask):
        # Row keys into `data` for the entries of idx_category_map selected by `mask`.
        selected = idx_category_map.loc[mask]
        return selected.index if idx_field is None else selected[idx_field]

    if idx_field is None:
        all_keys = idx_category_map.index
    else:
        all_keys = idx_category_map[idx_field]
    total_data = data.loc[all_keys, sample_id]

    enriched_categories = {}
    for cat in set(labels):
        in_cat = labels == cat
        # Missing values would make the medians and the t-test NaN; use observed values only.
        a = total_data.loc[_row_keys(in_cat)].dropna()
        b = total_data.loc[_row_keys(~in_cat)].dropna()

        if len(a) <= 2 or len(b) <= 2:
            continue
        if skewed == 'left' and np.median(a.values) >= np.median(b.values):
            continue
        if skewed == 'right' and np.median(a.values) <= np.median(b.values):
            continue

        _, p = ttest_ind(a, b, axis=None)
        # An undefined test (e.g. zero variance in both groups) carries no evidence and would
        # poison a later multiple-testing correction.
        if np.isnan(p):
            continue
        if p_cutoff is None or p < p_cutoff:
            enriched_categories[cat] = dict(p=p, n=len(a))
    return enriched_categories

In [None]:
# For each candidate gene, test per disease ('specified_disease') whether the gene effect in
# that disease is lower (skewed='left') than in the gene's other listed cell lines. Raw
# p-values are collected first so they can be BH-FDR corrected jointly over all tests.
EFFECT_ENRICHMENT_FDR = 0.1  # adjusted-p threshold for reporting a disease

effect_enrichment_records = []
for geneID, row in tqdm(gene_list.iterrows(), total=len(gene_list)):
    # A list, not a set: recent pandas rejects set indexers in .loc. Cell lines without
    # sample info are dropped instead of raising KeyError.
    lines = list(set(row.cell_lines).intersection(eff_data.index).intersection(cell_line_inf.index))
    cell_lines = cell_line_inf.loc[lines]
    enr_d = get_enrichments(geneID, eff_data, cell_lines, 'specified_disease',
                            skewed='left', p_cutoff=None)
    for disease, res in enr_d.items():
        effect_enrichment_records.append(dict(gene=geneID, enr_d=disease, p_val=res["p"], n=res["n"]))

enriched_diseases = pd.DataFrame(effect_enrichment_records, columns=["gene", "enr_d", "p_val", "n"])

# NaN p-values would propagate through the Benjamini-Hochberg step and make the adjusted
# values NaN, so only non-missing p-values are corrected.
enriched_diseases["p_adj"] = np.nan
valid = enriched_diseases["p_val"].notna()
if valid.any():
    enriched_diseases.loc[valid, "p_adj"] = multipletests(enriched_diseases.loc[valid, "p_val"],
                                                          method='fdr_bh')[1]

# Per gene: {disease: {'p': adjusted p, 'n': cell lines in the disease}}.
significant = enriched_diseases.loc[enriched_diseases["p_adj"] < EFFECT_ENRICHMENT_FDR]
for geneID in gene_list.index:
    enrichments = {r.enr_d: dict(p=r.p_adj, n=r.n)
                   for r in significant.loc[significant["gene"] == geneID].itertuples()}
    gene_list.loc[geneID, 'enriched_diseases'] = [enrichments]

gene_list

### Add DepMap prediction info

In [None]:
# For each candidate gene, record which of its paralogs appear among the feature0..feature9
# entries of its 'best' prediction row, together with that feature's importance:
# {paralog_id: importance}.
N_PRED_FEATURES = 10  # columns feature0 .. feature9 (plus featureN_importance)
predictions = pred_data.loc[pred_data.best]

for geneID, row in tqdm(gene_list.iterrows(), total=len(gene_list)):
    pred_paralogs = {}
    # Check `predictions`, the frame indexed below; a gene present in pred_data but without
    # a 'best' row would otherwise raise KeyError.
    if geneID in predictions.index:
        for i in range(N_PRED_FEATURES):
            raw_feature = predictions.loc[geneID, f"feature{i}"]
            # Skip missing features and names without a '_' separator; they would otherwise
            # raise outside the try block.
            if not isinstance(raw_feature, str) or '_' not in raw_feature:
                continue
            prediction = raw_feature.split('_', 1)[0]
            if ' ' not in prediction:
                continue

            try:
                # assumes prediction looks like "<name> (<gene id>)" — TODO confirm
                pred_id = int(prediction.split(' ', 1)[1].strip('()'))
            except ValueError:
                print(geneID)
                print(i, prediction)
                continue

            if pred_id in row.paralogs:
                pred_paralogs[pred_id] = predictions.loc[geneID, f"feature{i}_importance"]

    gene_list.loc[geneID, 'paralog_predict_score'] = [pred_paralogs]

gene_list

### Add mutation correlations

In [None]:
# Preview only: non-silent mutation calls (the result is not stored).
non_silent = mut_data['Variant_Classification'] != 'Silent'
mut_data.loc[non_silent, ['Hugo_Symbol', 'Entrez_Gene_Id', 'DepMap_ID', 'Variant_Classification']]

In [None]:
def get_correlation(data1, data2, min_points=3):
    """Pearson correlation between two equally sized arrays, ignoring missing values.

    Both inputs are flattened and cast to float64. Positions where either input is NaN
    are dropped pairwise before correlating.

    Parameters
    ----------
    data1, data2 : array-like
        Values to correlate; must contain the same number of elements.
    min_points : int
        Minimum number of complete (non-NaN) pairs required.

    Returns
    -------
    float
        Pearson's r. Returns NaN when fewer than ``min_points`` complete pairs remain,
        or when either input is constant over those pairs (r is undefined).
    """
    x = np.asarray(data1, dtype=np.float64).ravel()
    y = np.asarray(data2, dtype=np.float64).ravel()

    # Keep only positions where both values are present.
    complete = ~(np.isnan(x) | np.isnan(y))
    x, y = x[complete], y[complete]

    if len(x) < min_points:
        return np.nan
    # Constant input: r is undefined; return NaN directly instead of letting pearsonr warn.
    if np.ptp(x) == 0 or np.ptp(y) == 0:
        return np.nan

    return pearsonr(x, y)[0]  # TODO: do something with the p-value?

In [None]:
# Map Entrez gene ID -> set of DepMap cell-line IDs with a mutation in that gene annotated
# 'damaging' or 'other non-conserving'.
mutations = {}
selected_annotations = ['damaging', 'other non-conserving']
is_selected = mut_data['Variant_annotation'].isin(selected_annotations)
for entrez_id, depmap_id in mut_data.loc[is_selected, ['Entrez_Gene_Id', 'DepMap_ID']].values:
    mutations.setdefault(entrez_id, set()).add(depmap_id)
mutations

In [None]:
# Background distribution of correlations between a gene's effect and the mutation status
# of one of its paralogs, over all (candidate gene, mutated paralog) pairs. It is used
# below to z-score individual correlations.
all_muts = []
for geneID, row in tqdm(gene_list.iterrows(), total=len(gene_list)):
    for paralog in row.paralogs:
        if paralog not in mutations:
            continue
        # 1.0 for cell lines listed under the paralog in `mutations`, else 0.0
        mutated = np.array([1. if cell_line in mutations[paralog] else 0. for cell_line in eff_data.index])
        all_muts.append(get_correlation(eff_data[geneID].values, mutated))

# Filter by value: NaNs from pearsonr are not the `np.nan` singleton, so an identity check
# (`x is not np.nan`) would keep them and turn the mean/std into NaN.
all_muts = [x for x in all_muts if not np.isnan(x)]
mut_corr_avg = np.mean(all_muts)
mut_corr_std = np.std(all_muts)
print(mut_corr_avg, mut_corr_std)

In [None]:
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(15, 4), dpi=124, facecolor='w', edgecolor='k')

# KDE + rug of the background correlations.
# NOTE(review): sns.distplot and the `bw` KDE keyword are deprecated in newer seaborn.
sns.distplot(all_muts, hist=False, rug=True, norm_hist=False, kde_kws=dict(bw=.01), ax=ax)
ax.set_xlabel("Mutation correlation (Pearson R)")

# Mean and mean ± one standard deviation; lines run to y=14, past the y-limit of 13, so
# they span the visible plot height.
ax.vlines(mut_corr_avg, 0, 14, linestyles="--", label="Mean")
ax.vlines(mut_corr_avg + mut_corr_std, 0, 14, colors='gray', linestyles="--", label="Standard deviation")
ax.vlines(mut_corr_avg - mut_corr_std, 0, 14, colors='gray', linestyles="--")

ax.set_ylim(0, 13)
ax.legend()
ax.set_ylabel("Density")

plt.show()

In [None]:
# Per gene: {paralog: (r, z)}, where r correlates the gene's effect with the paralog's
# mutation status and z standardises r against the background (mut_corr_avg, mut_corr_std).
for geneID, row in tqdm(gene_list.iterrows(), total=len(gene_list)):
    mut_corrs = {}
    mutated_paralogs = [p for p in row.paralogs if p in mutations]
    for paralog in mutated_paralogs:
        status = np.array([1. if cell_line in mutations[paralog] else 0. for cell_line in eff_data.index])
        r = get_correlation(eff_data[geneID].values, status)
        mut_corrs[paralog] = (r, (r - mut_corr_avg) / mut_corr_std)

    gene_list.loc[geneID, 'paralog_mutation_correlation'] = [mut_corrs]

gene_list

### Add Treehouse data

In [None]:
# Group Treehouse sample IDs by disease: {disease: [sample IDs in file order]}.
# Series.iteritems() was removed in pandas 2.0; .items() is the equivalent.
th_diseases = {}
for th_id, disease in treehouse_info['disease'].items():
    th_diseases.setdefault(disease, []).append(th_id)

th_diseases

In [None]:
# Treehouse annotations per candidate gene:
#   avg_th_expression - median Treehouse expression of the gene per disease, highest first
#                       (0 for every disease when the gene is absent from Treehouse)
#   up_enriched_th_diseases - get_enrichments(..., skewed='right') on the gene (raw p < .05)
#   down_enriched_paralog_th_diseases - per paralog, diseases where the paralog's expression
#                       is lower than elsewhere (skewed='left'), BH-adjusted p < TH_PARALOG_FDR;
#                       layout {paralog_id: {disease: {'p': .., 'n': ..}}}, the layout the
#                       exploration cells below read from this column
TH_PARALOG_FDR = 0.05

paralog_records = []
for geneID, row in tqdm(gene_list.iterrows(), total=len(gene_list)):
    gene = row.gene
    # (gene ID, gene name) per paralog; the name is what indexes treehouse_data.
    paralogs = [(p, get_gene_name(p)[0].split(' ')[0]) for p in row.paralogs]
    avg_disease_scores = {d: 0 for d in th_diseases.keys()}
    up_enriched = {}

    if gene in treehouse_data.index:
        for disease, th_ids in th_diseases.items():
            avg_disease_scores[disease] = np.median(treehouse_data.loc[gene, th_ids].values)
        avg_disease_scores = dict(sorted(avg_disease_scores.items(), key=lambda x: x[1], reverse=True))
        # Only tested when the gene is present: get_enrichments raises KeyError otherwise.
        up_enriched = get_enrichments(gene, treehouse_data.T, treehouse_info, 'disease', skewed='right')
    else:
        print(f"{gene} ({geneID}) not in treehouse")

    gene_list.loc[geneID, 'avg_th_expression'] = [avg_disease_scores]
    gene_list.loc[geneID, 'up_enriched_th_diseases'] = [up_enriched]

    for p_id, p_name in paralogs:
        if p_name in treehouse_data.index:
            enr_d = get_enrichments(p_name, treehouse_data.T, treehouse_info,
                                    'disease', skewed='left', p_cutoff=None)
            for d, res in enr_d.items():
                paralog_records.append(dict(gene=geneID, paralog=p_id, enr_d=d,
                                            p_val=res["p"], n=res["n"]))

enriched_diseases = pd.DataFrame(paralog_records, columns=["gene", "paralog", "enr_d", "p_val", "n"])

# BH-FDR over all (gene, paralog, disease) tests. NaN p-values are excluded because they
# would propagate through the correction and make the adjusted values NaN.
enriched_diseases["p_adj"] = np.nan
valid = enriched_diseases["p_val"].notna()
if valid.any():
    enriched_diseases.loc[valid, "p_adj"] = multipletests(enriched_diseases.loc[valid, "p_val"],
                                                          method='fdr_bh')[1]

significant = enriched_diseases.loc[enriched_diseases["p_adj"] < TH_PARALOG_FDR]
for geneID, row in gene_list.iterrows():
    gene_hits = significant.loc[significant["gene"] == geneID]
    enrichments = {}
    for p_id in row.paralogs:
        # Keyed by paralog first, so several paralogs enriched in the same disease are all kept.
        per_disease = {r.enr_d: dict(p=r.p_adj, n=r.n)
                       for r in gene_hits.loc[gene_hits["paralog"] == p_id].itertuples()}
        if per_disease:
            enrichments[p_id] = per_disease
    # Written to its own column with .loc. 'enriched_diseases' holds the DepMap gene-effect
    # enrichments and must not be overwritten, and chained assignment
    # (gene_list[col][idx] = ...) does not write back under pandas copy-on-write.
    gene_list.loc[geneID, 'down_enriched_paralog_th_diseases'] = [enrichments]

gene_list

In [None]:
def enr_vals(gene_id, gene, cell_lines, name, typ, data=eff_data, cl_field=None, data_name="gene_effect"):
    """Collect long-format plotting records for one group of samples.

    Parameters
    ----------
    gene_id : gene ID passed to get_gene_name to build the "gene" label.
    gene : column of ``data`` whose values are collected.
    cell_lines : row labels to collect; labels absent from ``data.index`` are ignored.
    name : group name, stored as "<name> enrichment".
    typ : value of the "type" field (e.g. "enriched" / "other").
    data : source table (defaults to eff_data, bound when this cell runs).
    cl_field : unused; kept for backward compatibility.
    data_name : key under which the collected values are returned.

    Returns
    -------
    dict of equal-length lists with keys ``data_name``, "gene", "name" and "type".
    """
    # Index.intersection instead of a Python set: .loc rejects set indexers in recent pandas.
    dat_lines = data.index.intersection(cell_lines)
    n_lines = len(dat_lines)
    dat_vals = {data_name: list(data.loc[dat_lines, gene].values.flatten()),
                "gene": [get_gene_name(gene_id)[0]] * n_lines,
                "name": [f"{name} enrichment"] * n_lines,
                "type": [typ] * n_lines}
    return dat_vals

In [None]:
# Plot data for one hard-coded example. For gene 29107, the Treehouse expression of symbol
# NXT2 is collected per DepMap disease: samples from the down-enriched Treehouse diseases
# mapped to that DepMap disease ("enriched") vs. all other Treehouse samples ("other").
EXAMPLE_GENE_ID = 29107
EXAMPLE_PARALOG = "NXT2"

data = {"gene_effect": [], "name": [], "type": [], "gene": []}
# NOTE(review): `[0]` assumes the cell holds a one-element list, and `.items()` assumes the
# layout {paralog_id: diseases}; confirm against the cell that fills this column.
for p_idx, enr_diseases in gene_list.down_enriched_paralog_th_diseases[EXAMPLE_GENE_ID][0].items():
    # dm_th_disease_map: DepMap disease -> list of Treehouse diseases
    for dm_dis, th_diss in dm_th_disease_map.items():
        overlap = set(th_diss).intersection(enr_diseases)
        if not overlap:
            continue
        in_overlap = treehouse_info.disease.isin(overlap)
        groups = (("enriched", treehouse_info.loc[in_overlap].index),
                  ("other", treehouse_info.loc[~in_overlap].index))
        for typ, samples in groups:
            for k, v in enr_vals(p_idx, EXAMPLE_PARALOG, samples, dm_dis, typ, treehouse_data.T).items():
                data[k].extend(v)

data = pd.DataFrame(data)
data

In [None]:
fig, ax = plt.subplots(figsize=(15, 15), dpi=124, facecolor='w', edgecolor='k')

# The "gene_effect" column holds Treehouse expression here (enr_vals' default data_name).
sns.boxplot(data=data, x="gene_effect", y="name", hue="type", orient="h",
            palette=['royalblue', 'lightgray'], ax=ax)
ax.set_ylabel(None)
ax.set_xlabel("Treehouse expression score (log2(TPM+1))")
ax.set_title("NXT2 (55916)")
ax.legend()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(15, 15), dpi=124, facecolor='w', edgecolor='k')

# Explore distributions: expression of one symbol across all Treehouse samples vs. within
# each disease listed for the first paralog entry of gene 29107.
g = 'NXT2'
# Previously hand-picked: ['neuroblastoma', 'wilms tumor', 'rhabdomyosarcoma']
d = list(gene_list.down_enriched_paralog_th_diseases[29107][0].values())[0].keys()

# NOTE(review): sns.distplot is deprecated in newer seaborn.
sns.distplot(treehouse_data.loc[g], bins=50, label='all', kde=False, ax=ax)
for i in d:
    sns.distplot(treehouse_data.loc[g, th_diseases[i]], label=i, kde=False, ax=ax)

ax.legend()
plt.show()

## Save results

In [None]:
# Persist the annotated candidate table.
pd.to_pickle(gene_list, RESULTS_FILE)