In [None]:
import os
import pickle as pkl
import re
import sys
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from tqdm.notebook import tqdm

try:
    # Biopython's package is "Bio"; the lowercase name may only resolve on
    # case-insensitive filesystems.
    from Bio import Entrez
except ImportError:
    from bio import Entrez

from util import *

sns.set(color_codes=False)
sns.set_style("whitegrid")

%matplotlib inline

### Define paths

In [None]:
# Input files
ESSENTIAL_GENES  = f'results/essential_candidates/public_20Q2/essential_genes-all.pkl'
INTERACTION_DATA = f"results/essential_candidates/public_20Q2/essential_genes_annotated.pkl"
CORRELATION_DATA = f'results/essential_candidates/public_20Q2/expression_correlations.pkl'
ADDED_FEATURES   = f'results/essential_candidates/public_20Q2/added_features_median_padj.pkl'
NCBI_GENE_NAMES  = 'data/misc/ncbi_gene_names.pkl'
TH_DM_MAP        = 'data/treehouse/th_dm_map.csv'
TREEHOUSE_INFO   = 'data/treehouse/clinical_TumorCompendium_v11_PolyA_2020-04-09.tsv'

## 1.  Load data

### Load correlation data

In [None]:
# Per-gene expression-correlation results from the previous step (pickled frame).
correlation_data = pd.read_pickle(filepath_or_buffer=CORRELATION_DATA)
correlation_data.head(n=5)

### Load paralog-common essential interaction data

In [None]:
# Paralog / common-essential interaction annotations (pickled frame).
interaction_data = pd.read_pickle(filepath_or_buffer=INTERACTION_DATA)
interaction_data.head(n=5)

### Load essentiality data

In [None]:
# Essential-gene candidates with their dependent cell lines (pickled frame).
essential_genes = pd.read_pickle(filepath_or_buffer=ESSENTIAL_GENES)
essential_genes.head(n=5)

### Load added features

In [None]:
# Extra per-gene features (enrichment, prediction scores, ...; pickled frame).
added_features = pd.read_pickle(filepath_or_buffer=ADDED_FEATURES)
added_features.head(n=5)

### Load cell line info


In [None]:
# Sample annotations for the 20Q2 release, indexed by DepMap cell-line ID.
cell_line_inf = get_from_taiga(name='public-20q2-075d', version=22, file='sample_info').set_index('DepMap_ID')

# Split ALL into t-ALL / b-ALL from the sub-subtype annotation. `na=False` makes lines
# without a sub-subtype simply "not match" instead of putting NaN into the mask.
# The == 'ALL' test is re-evaluated for the second split, so lines already relabelled
# 't-ALL' are never overwritten with 'b-ALL'.
cell_line_inf.loc[(cell_line_inf['lineage_subtype'] == 'ALL') &
                  (cell_line_inf['lineage_sub_subtype'].str.contains('t', na=False)), 'lineage_subtype'] = 't-ALL'
cell_line_inf.loc[(cell_line_inf['lineage_subtype'] == 'ALL') &
                  (cell_line_inf['lineage_sub_subtype'].str.contains('b', na=False)), 'lineage_subtype'] = 'b-ALL'

# Pediatric cancers are labelled by their lineage subtype, all others by primary disease.
# Built in a single assignment: the former chained `column.fillna(..., inplace=True)`
# silently does nothing under pandas copy-on-write.
use_subtype = (cell_line_inf['lineage_subtype'].isin(PEDIATRIC_CANCERS) &
               cell_line_inf['lineage_subtype'].notna())
cell_line_inf['specified_disease'] = cell_line_inf['lineage_subtype'].where(
    use_subtype, cell_line_inf['primary_disease'])

# Disease -> number of cell lines with that disease (value_counts drops NaN).
DISEASES = dict(cell_line_inf.specified_disease.value_counts())

cell_line_inf.head()

### Load dependency data

In [None]:
# Gene-effect (CERES) scores, fetched from Taiga with the header split applied.
eff_data = get_from_taiga(file='Achilles_gene_effect', split_attribute='header',
                          name='public-20q2-075d', version=22)
eff_data.head(n=5)

### Load expression data

In [None]:
# CCLE expression, fetched from Taiga with the header split applied.
exp_data = get_from_taiga(file='CCLE_expression', split_attribute='header',
                          name='public-20q2-075d', version=22)
exp_data.head(n=5)

### Load Treehouse-DepMap disease mapping

In [None]:
# Treehouse <-> DepMap disease-name mapping; despite the .csv extension the file
# is parsed as tab-separated.
th_dm_map = pd.read_csv(TH_DM_MAP, delimiter="\t")
th_dm_map.head(n=5)

In [None]:
# DepMap disease -> list of matching Treehouse disease names. Every DepMap disease
# seen in the cell-line info gets an entry (empty when there is no mapping).
dm_th_disease_map = {disease: [] for disease in DISEASES}
for depmap_name in set(th_dm_map.depmap_name):
    is_match = th_dm_map.depmap_name == depmap_name
    dm_th_disease_map[depmap_name] = list(th_dm_map.loc[is_match, "treehouse_name"])

dm_th_disease_map

### Load Treehouse info

In [None]:
# Treehouse clinical table (tab-separated); the first column becomes the index
# (presumably the Treehouse sample ID).
treehouse_info = pd.read_table(TREEHOUSE_INFO, index_col=0)
print(treehouse_info.shape)
treehouse_info.head(n=5)

In [None]:
# Treehouse disease -> list of Treehouse IDs (index values) with that disease.
# Series.items() replaces Series.iteritems(), which was removed in pandas 2.0.
th_diseases = {}
for th_id, disease in treehouse_info.disease.items():
    th_diseases.setdefault(disease, []).append(th_id)

th_diseases

In [None]:
# On-disk cache of NCBI lookups: GeneID -> ("SYMBOL (GeneID)", description).
# NOTE: pickle.load can execute arbitrary code; only load this cache from a trusted path.
with open(NCBI_GENE_NAMES, 'rb') as f:
    ncbi_gene_names = pkl.load(f)

def get_gene_name(geneID):
    """Return ``("SYMBOL (geneID)", description)`` for an NCBI gene ID.

    Cache misses are fetched from NCBI Entrez (``rettype="gene_table"``): the first
    token of the first response line is used as the symbol and the remaining
    tokens as the description. The pickle cache is rewritten after every miss.

    Raises:
        ValueError: if NCBI returns an empty first line for ``geneID``.
    """
    if geneID not in ncbi_gene_names:
        Entrez.email = "test@gmail.com"  # NOTE(review): NCBI asks for a real contact address
        handle = Entrez.efetch("gene", id=str(geneID), rettype="gene_table", retmode="text")
        try:
            info = handle.readline().split()
        finally:
            # Release the HTTP connection even if reading fails.
            handle.close()
        if not info:
            raise ValueError(f"NCBI returned no gene table for gene ID {geneID}")
        name = info[0]
        ncbi_gene_names[geneID] = f"{name} ({geneID})", f"{' '.join(info[1:]).strip()}"
        with open(NCBI_GENE_NAMES, 'wb') as f:
            pkl.dump(ncbi_gene_names, f)
    return ncbi_gene_names[geneID]

In [None]:
def dict_transpose(dct):
    """Swap the two key levels of a nested mapping.

    ``{outer: {inner: value}}`` becomes ``{inner: {outer: value}}``, returned as a
    ``defaultdict(dict)``. Insertion order follows the outer, then inner, iteration.
    """
    transposed = defaultdict(dict)
    triples = ((outer, inner, value)
               for outer, nested in dct.items()
               for inner, value in nested.items())
    for outer_key, inner_key, value in triples:
        transposed[inner_key][outer_key] = value
    return transposed

In [None]:
def disease_score(disease, cnt, cnt_for_disease, disease_correlations,
                  enriched_diseases, enriched_paralog_th_diseases, th_disease_map=None):
    """Heuristic score of how strongly a gene's dependency is associated with `disease`.

    NaN-ignoring sum (np.nansum) of four components:
      * ``cnt / cnt_for_disease`` when the disease has more than 2 cell lines, else 0;
      * the maximum value in ``disease_correlations[disease]`` if present, else 0;
      * 0.6 if ``disease`` is a key of ``enriched_diseases``;
      * 0.6 if any Treehouse disease mapped to ``disease`` is a key in one of the
        inner mappings of ``enriched_paralog_th_diseases``.

    Args:
        disease: DepMap disease name.
        cnt: number of this gene's dependent cell lines with ``disease``.
        cnt_for_disease: total number of cell lines with ``disease``.
        disease_correlations: ``{disease: {key: correlation}}``.
        enriched_diseases: container of enriched DepMap diseases (membership only).
        enriched_paralog_th_diseases: ``{paralog: {treehouse_disease: ...}}``.
        th_disease_map: DepMap disease -> list of Treehouse disease names; defaults
            to the notebook-level ``dm_th_disease_map``. Diseases missing from the
            map contribute nothing instead of raising KeyError.

    Returns:
        The summed score as a float.
    """
    if th_disease_map is None:
        th_disease_map = dm_th_disease_map
    # Set of every Treehouse disease any paralog is enriched in (O(1) membership).
    enriched_paralog_diseases = {th_dis
                                 for per_paralog in enriched_paralog_th_diseases.values()
                                 for th_dis in per_paralog.keys()}
    return np.nansum([cnt/cnt_for_disease if cnt_for_disease > 2 else 0,
                      max(disease_correlations[disease].values()) if disease in disease_correlations else 0,
                      .6 if disease in enriched_diseases else 0,
                      .6 if any(th_dis in enriched_paralog_diseases
                                for th_dis in th_disease_map.get(disease, ())) else 0,
                     ])

## 2.  Combine results

In [None]:
unique_diseases = set()
# Disease -> number of cell lines with that disease.
cnt_for_lines = Counter(cell_line_inf.specified_disease.values)

combined = pd.concat([correlation_data[['gene', 'avg_correlation_all_paralogs', 'max_correlation_all_paralogs', 
                                        'avg_correlation_interacting_paralogs', 
                                        'max_correlation_interacting_paralogs', 'correlations', 
                                        'max_disease_specific_all', 'max_disease_specific_interacting', 
                                        'disease_specific_correlations']], 
                      interaction_data[['n_paralogs', 'n_common_essentials', '% paralogs interacting with any', 
                                        '% paralogs interacting with all', 'interaction_graph', 'paralogs', 
                                        'common_essentials', 'interacting_paralogs',
                                        'interacting_common_essentials']],
                      essential_genes[['n_lines', 'cell_lines']],
                      added_features[['enriched_diseases', 'paralog_predict_score',
                                      'paralog_mutation_correlation', 'avg_th_expression',
                                      'up_enriched_th_diseases', 'down_enriched_paralog_th_diseases']]], 
                     axis=1, join='inner')


def _object_array(values):
    """Pack `values` into a 1-D object array with exactly one element per row.

    Filling element-wise stops numpy/pandas from reading nested lists or dicts as
    extra dimensions when the array is assigned as a DataFrame column.
    """
    packed = np.empty(len(values), dtype=object)
    for i, value in enumerate(values):
        packed[i] = value
    return packed


def _rank_diseases(row, diseases):
    """Rank the diseases of a gene's dependent lines by `disease_score`, best first."""
    cnt_for_gene = Counter(diseases.values)
    disease_correlations = dict_transpose(row.disease_specific_correlations)
    enriched_paralogs = row.down_enriched_paralog_th_diseases[0]
    # Score each disease once; used for both the ordering and the reported scores.
    scores = {disease: disease_score(disease, cnt, cnt_for_lines[disease], disease_correlations,
                                     row.enriched_diseases, enriched_paralogs)
              for disease, cnt in cnt_for_gene.items()}
    top_diseases = dict(sorted(cnt_for_gene.items(), key=lambda x: scores[x[0]], reverse=True))
    return dict(
        diseases        = list(top_diseases.keys()),
        total_lines     = [cnt_for_lines[disease] for disease in top_diseases],
        dependent_lines = list(top_diseases.values()),
        max_correlation = [max(disease_correlations[disease].items(), key=lambda x: x[1])
                           if disease in disease_correlations else None
                           for disease in top_diseases],
        enriched_p      = [row.enriched_diseases[disease]['p']
                           if disease in row.enriched_diseases else None
                           for disease in top_diseases],
        enriched_th_avg = [any(th_dis in row.up_enriched_th_diseases[0]
                               for th_dis in dm_th_disease_map.get(disease, []))
                           for disease in top_diseases],
        th_enr_paralogs = [[p for p, d in enriched_paralogs.items()
                            if set(dm_th_disease_map.get(disease, [])).intersection(d.keys())]
                           for disease in top_diseases],
        score           = [scores[disease] for disease in top_diseases],
    )


def _score_paralogs(row):
    """Per-paralog evidence and combined score, keyed by "SYMBOL (GeneID)".

    The score is the NaN-ignoring sum of the prediction score, minus element [1] of
    the mutation-correlation entry, the expression correlation, element [1] of the
    best (disease, correlation) pair, and 1 if the paralog is interacting.
    """
    predictions = row.paralog_predict_score[0]
    mutation_corrs = row.paralog_mutation_correlation[0]
    scored = {}
    for p in row.paralogs:
        prediction = predictions[p] if p in predictions else np.nan
        mutation_corr = mutation_corrs[p] if p in mutation_corrs else np.nan
        expression_corr = row.correlations[p] if p in row.correlations else np.nan
        max_dis_corr = (max(row.disease_specific_correlations[p].items(), key=lambda x: x[1])
                        if p in row.disease_specific_correlations else np.nan)
        interacting = p in row.interacting_paralogs
        scored[get_gene_name(p)[0]] = dict(
            prediction      = prediction,
            mutation_corr   = mutation_corr,
            expression_corr = expression_corr,
            max_dis_corr    = max_dis_corr,
            interacting     = interacting,
            # `is not np.nan` is reliable here: missing entries are the np.nan object itself.
            score           = np.nansum([prediction,
                                         -mutation_corr[1] if mutation_corr is not np.nan else 0,
                                         expression_corr,
                                         max_dis_corr[1] if max_dis_corr is not np.nan else 0,
                                         1 if interacting else 0]),
        )
    return scored


### ADD FEATURES HERE ###
# Per-row results are collected first and assigned as whole columns afterwards:
# writing lists/dicts into single cells via `.loc[idx, col] = [...]` behaves
# differently across pandas versions.
n_diseases_col, diseases_col, n_interacting_col = [], [], []
top_diseases_col, top_score_col, top_paralogs_col = [], [], []

for idx, row in tqdm(combined.iterrows(), total=len(combined)):
    # pandas >= 2.0 rejects set indexers, so materialise the intersection as a list.
    cell_lines = list(set(row.cell_lines).intersection(cell_line_inf.index))
    diseases = cell_line_inf.loc[cell_lines].specified_disease
    unique_diseases.update(diseases)

    top_diseases = _rank_diseases(row, diseases)

    n_diseases_col.append(len(diseases.unique()))
    diseases_col.append([diseases])  # wrapped in a list: later cells read `.diseases[0]`
    n_interacting_col.append(len(row.interacting_paralogs))
    top_diseases_col.append(top_diseases)
    # A gene without annotated dependent lines has no top disease.
    top_score_col.append(top_diseases['score'][0] if top_diseases['score'] else np.nan)
    top_paralogs_col.append(_score_paralogs(row))

combined['n_diseases'] = n_diseases_col
combined['diseases'] = _object_array(diseases_col)
combined['n_interacting_paralogs'] = n_interacting_col
combined['top_diseases'] = _object_array(top_diseases_col)
combined['top_disease_score'] = top_score_col
combined['top_paralogs'] = _object_array(top_paralogs_col)

combined.index.name = 'GeneID'

pd.set_option("display.max_columns", 100)
    
combined.head()

### Select interesting candidates

In [None]:
# Candidates with 1-5 interacting paralogs that either correlate well with an
# interacting paralog overall, or combine a high top-disease score with a strong
# disease-specific correlation.
strong_overall = combined.max_correlation_interacting_paralogs > .24
strong_in_disease = ((combined.top_disease_score > 1.68) &
                     (combined.max_disease_specific_interacting.str[1].str[1] > .7))
few_interacting = ((combined.n_interacting_paralogs <= 5) &
                   (combined.n_interacting_paralogs > 0))
selection = combined.loc[(strong_overall | strong_in_disease) & few_interacting]

overview_columns = ['gene',
                    'n_common_essentials',
                    'n_paralogs',
                    'n_interacting_paralogs',
                    'n_lines', 'n_diseases',
                    'max_correlation_interacting_paralogs',
                    'top_disease_score',
                    'max_disease_specific_interacting',
                    'top_paralogs']
selection[overview_columns]

### Save results

In [None]:
# The pickle keeps the nested per-gene objects (dicts, Series) intact.
combined.to_pickle("results/annotated_candidates.pkl")

# Semicolon-separated CSV copies for spreadsheet tools such as Excel.
for frame, csv_path in ((combined, "results/annotated_candidates.csv"),
                        (selection, "results/most_interesting_candidates.csv")):
    frame.to_csv(csv_path, sep=";")

## 3.  Print some of the data

In [None]:
print('List of genes of interest:')
cnt = 1
for idx, row in combined.loc[selection.index].iterrows():
    n_shown = int(row.n_interacting_paralogs)
    paralogs = list(row.interacting_paralogs)
    essentials = list(row.interacting_common_essentials)

    print(f" {cnt}.\t{row.gene} ({idx}) has {row.n_paralogs} paralogs and interacts with {row.n_common_essentials} common essentials.")
    print(f"\t{n_shown} of the paralogs interact with {len(essentials)} of the common essentials.")
    print("\t Interacting paralogs:\tInteracting common essentials:")
    # Two-column listing; the shorter column is padded with blanks.
    for i in range(max(n_shown, len(essentials))):
        left = get_gene_name(paralogs[i])[0] if i < row.n_interacting_paralogs else '      '
        right = get_gene_name(essentials[i])[0] if i < len(essentials) else ' '
        gap = '\t' if len(left) < 8 else ' '
        print(f"\t    {left}", gap, f"\t   {right}")

    print()
    cnt += 1

In [None]:
top_n = 5

pd.set_option("display.max_columns", 100)
# One column per gene listing its `top_n` best diseases as "disease [score]".
# Genes are filtered on `top_n` (not a hard-coded 5) so every column has exactly
# `top_n` entries; pd.DataFrame requires equal-length columns.
top_disease_table = {'rank': list(range(1, top_n + 1))}
for idx, row in selection.loc[selection.n_diseases.astype(int) >= top_n].iterrows():
    ranked = list(zip(row.top_diseases['diseases'], row.top_diseases['score']))[:top_n]
    top_disease_table[get_gene_name(idx)[0]] = [f"{d} [{s:.2f}]" for d, s in ranked]

pd.DataFrame(top_disease_table).set_index('rank')

In [None]:
# For each selected gene: its paralogs and the first five entries of every
# top-disease field.
for idx, row in selection.iterrows():
    print(get_gene_name(idx)[0])
    paralog_names = [get_gene_name(p)[0] for p in row.paralogs]
    print(f"paralogs\t{paralog_names}")
    for field, values in row.top_diseases.items():
        print(f"{field}\t{values[:5]}")
    print()

In [None]:
n_cancers = len(PEDIATRIC_CANCERS)

# gene name -> {disease: (score, number of dependent lines)}
scores = {}
for idx, row in selection.iterrows():
    ranking = row.top_diseases
    scores[get_gene_name(idx)[0]] = dict(zip(ranking['diseases'],
                                             zip(ranking['score'], ranking['dependent_lines'])))

# Long format: one row per (disease, gene); the (score, lines) pairs are then split.
selection_diseases_scores = (
    pd.DataFrame(scores)
    .assign(disease=lambda frame: frame.index)
    .melt(id_vars=['disease'], value_vars=list(scores.keys()), var_name="gene", value_name="score")
)
pairs = selection_diseases_scores.score
selection_diseases_scores['lines'] = pairs.str[1]
selection_diseases_scores['score'] = pairs.str[0]
selection_diseases_scores

### Create interactive plots of the expression correlations

In [None]:
# Cell lines that have both dependency (CERES) and expression data.
CELL_LINES = set(eff_data.index).intersection(exp_data.index)

def make_plot(gene_id, trendline_correlation_threshold=.4):
    """Interactive plot of `gene_id` dependency vs. expression of its interacting paralogs.

    Markers are the gene's dependent cell lines, coloured by disease, with one marker
    symbol per paralog. Per-disease OLS trendlines are fitted on all cell lines of the
    diseases involved. Trendlines whose disease-specific correlation exceeds
    `trendline_correlation_threshold` are visible; the others are legend-only.
    """
    thresh = trendline_correlation_threshold
    
    info = combined.loc[gene_id]
    # Ordered lists, since pandas >= 2.0 rejects set indexers. Diseases are looked up
    # for exactly these lines, in this order, so marker colours and hover text line up
    # with the plotted points. The stored `info.diseases` can cover a different set of
    # lines in a different order.
    _cell_lines   = sorted(CELL_LINES.intersection(info.cell_lines).intersection(cell_line_inf.index))
    _diseases     = list(cell_line_inf.loc[_cell_lines, 'specified_disease'])
    _all_lines    = sorted(CELL_LINES)
    _paralogs     = info.interacting_paralogs
    _correlations = info.correlations
    # x-coordinates of the dependent lines, shared by every paralog.
    _dependency   = eff_data.loc[_cell_lines, gene_id].T.values[0]
    
    fig = go.Figure()
    for i, _paralog in enumerate(_paralogs):    
        # Make trendlines first, so we can copy the colors of the diseases!
        tl_data = pd.concat([eff_data.loc[_all_lines, gene_id], 
                             exp_data.loc[_all_lines, _paralog], 
                             cell_line_inf.specified_disease], 
                            axis=1, join='inner')
        tl_data.columns = ['dependency', 'expression', 'disease']
        # px emits (scatter, trendline) trace pairs per disease; keep the trendlines.
        trendlines = px.scatter(tl_data.loc[tl_data.disease.isin(_diseases)], 
                                x='dependency', y='expression', color='disease',
                                color_continuous_scale='Rainbow', trendline="ols").data[1::2]
 
        # Customize trendlines
        good_shown = False
        rest_shown = False
        
        disease_color_map = {}        
        lines = {}  # disease -> (slope, intercept)

        for trendline in trendlines:
            disease = trendline['name']
            disease_color_map[disease] = trendline['marker']['color']
            
            if trendline['x'] is not None:
                # Tokens 2 and 6 of the hover text's 2nd line hold slope and intercept
                # (presumably "<y> = <a> * <x> + <b>" as written by plotly express).
                equation = re.split('<br>', trendline['hovertemplate'])[1].split()
                lines[disease] = (float(equation[2]), float(equation[6]))
        
        # Draw dependent cell lines
        fig.add_trace(go.Scatter(x=_dependency,
                                 y=exp_data.loc[_cell_lines, _paralog].T.values[0],
                                 mode='markers',
                                 marker=dict(color=[disease_color_map[d] for d in _diseases],
                                             symbol=i,
                                             line=dict(width=.7, color='DarkSlateGrey')),
                                 text=[f"Paralog: {get_gene_name(_paralog)[0]}<br>"
                                       f"Cell line: {line}<br>"
                                       f"Disease: <b>{dis}</b>" for line, dis in zip(_cell_lines, _diseases)],
                                 name=f"{get_gene_name(_paralog)[0]} [{_correlations[_paralog]:.2f}]",
                                ))
        
        for trendline in trendlines:
            disease = trendline['name']
            
            # Skip diseases without a disease-specific correlation, or whose trendline
            # could not be fitted. Previously the latter raised KeyError and lost the plot.
            if disease not in info.disease_specific_correlations[_paralog] or disease not in lines:
                continue

            # We align the trendline with the markers
            # So the hover text matches at the positions
            x = trendline['x'] 
            y = trendline['y']
            slope, intercept = lines[disease]
            trendline['x'] = np.array(x[0])
            trendline['y'] = np.array(y[0])
            for _x, _dis in zip(_dependency, _diseases):
                if disease == _dis:
                    trendline['x'] = np.append(trendline['x'], _x)
                    trendline['y'] = np.append(trendline['y'], slope * _x + intercept)
            # Extend to the right edge of the x-axis range set below.
            trendline['x'] = np.append(trendline['x'], -.4)
            trendline['y'] = np.append(trendline['y'], slope * -.4 + intercept)
            trendline['x'] = np.append(trendline['x'], x[-1])
            trendline['y'] = np.append(trendline['y'], y[-1])
            
            corr = info.disease_specific_correlations[_paralog][disease]
            
            trendline['hovertemplate'] = f"<b>{disease} trendline</b><br>"\
                                         f"Disease specific correlation: {corr:.3f}<br>"\
                                         f"Paralog: {get_gene_name(_paralog)[0]}"
            trendline['legendgroup']   = f"{_paralog} trendlines{' good' if corr > thresh else ''}"
            trendline['name']          = f"{get_gene_name(_paralog)[0].split()[0]} "\
                                         f"<b>Trendlines{f' [>{thresh}]' if corr > thresh else ''}</b>"
            if corr <= thresh:
                trendline['visible'] = 'legendonly'
                
            # One legend entry per group (weak / good) and paralog. The two flags are
            # independent, so the "good" group gets an entry even if no weak trendline
            # came first.
            if corr <= thresh and not rest_shown:
                trendline['showlegend'] = True
                rest_shown = True
            elif corr > thresh and not good_shown:
                trendline['showlegend'] = True
                good_shown = True

            fig.add_trace(trendline)
        
    # Add some titles etc.
    fig.update_layout(title=f'<b>Correlation plot of cell lines where {get_gene_name(gene_id)[0]} '\
                            f'is a dependency, with its interacting paralogs</b>', 
                      xaxis_title='Dependency score [CERES]',
                      yaxis_title='Expression of paralog [log<sub>2</sub>(TPM+1)]',
                      legend_title='<b>Paralogs [correlation]:</b>',
                      hovermode='x',
                      xaxis=dict(showspikes=True, spikethickness=1, spikecolor='DarkSlateGrey', 
                                 spikemode='across', range=[-2.5, -.4]),
                      yaxis_range=[-.5, 10.0],
                      showlegend=True,
                     )
    
    fig.show()

In [None]:
# Best effort: a failing plot is reported and the loop moves on to the next gene.
for gene_id in selection.index:
    try:
        make_plot(gene_id, trendline_correlation_threshold=.5)
    except Exception as err:
        print(f"Couldn't create plot for {get_gene_name(gene_id)[0]}: {err}")

### Create other plots

In [None]:
# Style first, so the axes created below pick it up.
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(15, 7), dpi=128, facecolor='w', edgecolor='k')

# Disease scores of the selected genes, restricted to pediatric cancers.
pediatric_scores = selection_diseases_scores.loc[selection_diseases_scores.disease.isin(PEDIATRIC_CANCERS)]
ax = sns.swarmplot(data=pediatric_scores, x="gene", y="score", hue="disease",
                   size=8, palette="muted", ax=ax)

ax.tick_params(axis='x', labelrotation=90)
ax.set_xlabel(None)
ax.set_ylabel("Disease score", fontsize=15)
ax.tick_params(labelsize=12)

ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0., title="Pediatric diseases")
plt.show()

In [None]:
dis_col_map = {d: ("b" if d in PEDIATRIC_CANCERS else "r") for d in DISEASES}
# Gene x disease score matrix; -1 marks diseases without dependent lines for a gene.
score_matrix = (selection_diseases_scores
                .pivot(index="gene", columns="disease", values="score")
                .fillna(-1))
sns.clustermap(score_matrix, xticklabels=True, yticklabels=True,
               figsize=(10, 7), dendrogram_ratio=(.1, .2), cbar_pos=(-.08, .345, .02, .5),
               standard_scale=0, vmin=0, cbar_kws={'label': 'Disease score'})

plt.show()

In [None]:
# Diseases occurring among the dependent lines of any selected candidate. This is
# loop-invariant, so compute it once instead of once per (gene, disease) pair.
selected_diseases = {d for diseases in selection.diseases.values for d in diseases[0].values}

# "<gene> (<GeneID>)" -> {disease: fraction of that disease's lines depending on the gene}
lines_per_disease = {}
for idx, row in tqdm(combined.iterrows(), total=len(combined)):
    # `not ... > 2` rather than `<= 2`, so rows with a missing n_lines are skipped as before.
    if not row.n_lines > 2:
        continue
    gene = f"{row.gene} ({idx})"
    lines_per_disease[gene] = {d: 0 for d in DISEASES if d in selected_diseases}
    for cell_line in row.cell_lines:
        if cell_line in cell_line_inf.index:
            d = cell_line_inf.loc[cell_line].specified_disease
            if d in lines_per_disease[gene]:
                lines_per_disease[gene][d] += 1
    for d, cnt in DISEASES.items():
        if d in lines_per_disease[gene]:
            lines_per_disease[gene][d] /= cnt

lines_per_disease = pd.DataFrame(lines_per_disease)
lines_per_disease

In [None]:
# clustermap creates and sizes its own figure, so no plt.figure() beforehand;
# a pre-created figure is left empty and rendered as a blank image.
# lines_per_disease columns are keyed "<gene> (<GeneID>)" from `combined`, so the
# selection keys are built the same way. Genes without a column (n_lines <= 2) are
# dropped instead of raising KeyError.
selected_genes = [f"{combined.loc[i, 'gene']} ({i})" for i in selection.index]
selected_genes = [g for g in selected_genes if g in lines_per_disease.columns]
grid = sns.clustermap(lines_per_disease[selected_genes].T,
                      xticklabels=True, yticklabels=True, figsize=(10, 7), dendrogram_ratio=(.1, .3),
                      cbar_pos=(-.08, .325, .02, .455),
                      vmin=0, vmax=1, cbar_kws=dict(label='Relative dependency'))
# Rotate the heatmap's disease labels explicitly rather than relying on the current axes.
plt.setp(grid.ax_heatmap.get_xticklabels(), rotation=45)
plt.show()

In [None]:
# Number of paralogs vs. the best expression correlation among all paralogs.
fig, ax = plt.subplots(figsize=(10, 5), dpi=124, facecolor='w', edgecolor='k')
sns.regplot(x="n_paralogs", y="max_correlation_all_paralogs", data=combined, ax=ax)
plt.show()

In [None]:
def enr_vals(gene_id, cell_lines, name, typ, p_val, data=eff_data, cl_field=None, data_name="gene_effect"):
    """Long-format plotting records for one gene over one group of cell lines.

    Args:
        gene_id: column label of ``data`` to read values from.
        cell_lines: iterable of cell-line IDs; those missing from ``data`` are ignored.
        name: disease name used in the group label.
        typ: group type stored in the "type" field (e.g. "enriched" / "other").
        p_val: adjusted p-value shown in the group label.
        data: frame indexed by cell line (defaults to ``eff_data`` at definition time).
        cl_field: unused; kept for backward compatibility.
        data_name: key under which the values are returned.

    Returns:
        Dict of equal-length lists with keys ``data_name``, "gene", "name" and "type".
    """
    # Index.intersection gives a deterministic, label-ordered result and avoids set
    # indexers, which pandas >= 2.0 rejects in .loc.
    dat_lines = data.index.intersection(list(cell_lines))
    n = len(dat_lines)
    return {data_name: list(data.loc[dat_lines, gene_id].values.flatten()),
            "gene": [get_gene_name(gene_id)[0]] * n,
            "name": [f"{name} enrichment\n(p_adj={p_val:.3f})"] * n,
            "type": [typ] * n}

In [None]:
# Gene-effect values of each selected gene's dependent lines, split per enriched
# disease into lines of that disease ("enriched") and all other lines ("other").
data = {"gene_effect": [], "name": [], "type": [], "gene": []}
for idx, row in selection.iterrows():
    is_dependent = cell_line_inf.index.isin(row.cell_lines)
    for enr_dis, inf in row.enriched_diseases.items():
        in_disease = cell_line_inf.specified_disease == enr_dis
        for mask, label in ((in_disease, "enriched"), (~in_disease, "other")):
            group_lines = cell_line_inf.loc[mask & is_dependent].index
            for key, values in enr_vals(idx, group_lines, enr_dis, label, inf['p']).items():
                data[key].extend(values)

data = pd.DataFrame(data)
data

In [None]:
# One horizontal boxplot panel per gene, with panel height proportional to its
# number of enrichment groups.
genes = sorted(set(data.gene))
if not genes:
    print("No enriched diseases among the selected genes; nothing to plot.")
else:
    # squeeze=False always yields a 2-D array, so a single gene still gives a list of axes.
    fig, axs = plt.subplots(len(genes), sharex=True, figsize=(15, 19), dpi=124, squeeze=False,
                            gridspec_kw={'height_ratios': [len(set(data.loc[data.gene == g, "name"]))
                                                           for g in genes]})
    axs = axs[:, 0]
    for i, (ax, gene) in enumerate(zip(axs, genes)):
        g = sns.boxplot(data=data.loc[data.gene == gene], x="gene_effect", y="name", hue="type",
                        orient="h", ax=ax, palette=['royalblue', 'lightgray'])
        g.set(ylabel=None, title=gene)

        plt.setp(ax.get_yticklabels(), fontsize=13)

        # Only the top panel keeps a legend, and only the bottom panel an x label.
        if i == 0:
            ax.legend()
        elif g.legend_ is not None:
            g.legend_.remove()
        if i == len(genes) - 1:
            g.set(xlabel="Gene effect score (CERES)")
        else:
            g.set(xlabel=None)

    plt.show()