In [None]:
import numpy as np
import pandas as pd
import os
import sys
import matplotlib.pyplot as plt
import pickle as pkl
from tqdm.notebook import tqdm
from scipy.stats import pearsonr
from itertools import product
from util import *
import seaborn as sns

%matplotlib inline

### Define paths

In [None]:
# Directory holding the results of the public 20Q2 essential-candidate analysis;
# it contains both the input gene table and this notebook's output.
RESULTS_DIR = "results/essential_candidates/public_20Q2"

# Input files
ESSENTIAL_GENES = f"{RESULTS_DIR}/essential_genes_annotated.pkl"
NCBI_GENE_NAMES = "data/misc/ncbi_gene_names.pkl"

# Output files
RESULTS_FILE = f"{RESULTS_DIR}/expression_correlations.pkl"

## 1.  Load data

### Load identified genes

In [None]:
# Candidate gene table. Later cells iterate over its rows, use the index as the gene ID
# and read the `gene`, `paralogs` and `interacting_paralogs` columns.
identified_genes = pd.read_pickle(ESSENTIAL_GENES)
identified_genes.head()

### Load effect data

In [None]:
# Gene-effect table of the public 20Q2 release. Later cells index it as
# eff_data.loc[cell_lines, gene_id], i.e. rows are cell lines and columns are genes.
# NOTE(review): presumably split_attribute='header' reduces the column labels to bare
# gene IDs -- confirm in util.get_from_taiga.
eff_data = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='Achilles_gene_effect',
    split_attribute='header',
)
eff_data.head()

### Load expression data

In [None]:
# Expression table of the public 20Q2 release. Later cells index it as
# exp_data.loc[cell_lines, gene_id], i.e. rows are cell lines and columns are genes.
# NOTE(review): presumably split_attribute='header' reduces the column labels to bare
# gene IDs -- confirm in util.get_from_taiga.
exp_data = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='CCLE_expression',
    split_attribute='header',
)
exp_data.head()

### Load cell line info

In [None]:
# Sample annotations, indexed by DepMap cell-line ID.
cell_line_inf = get_from_taiga(
    name='public-20q2-075d', version=22, file='sample_info'
).set_index('DepMap_ID')

# Split the generic 'ALL' subtype into 't-ALL' / 'b-ALL' based on a substring match on
# lineage_sub_subtype. The 'ALL' mask is deliberately recomputed on every pass: rows
# relabelled 't-ALL' in the first pass must not be relabelled again in the second.
# na=False makes rows without a sub-subtype an explicit "no match".
# NOTE(review): 't' / 'b' match anywhere in the string -- verify the sub-subtype
# vocabulary cannot contain both letters.
for marker, label in (('t', 't-ALL'), ('b', 'b-ALL')):
    relabel = (
        (cell_line_inf['lineage_subtype'] == 'ALL')
        & cell_line_inf['lineage_sub_subtype'].str.contains(marker, na=False)
    )
    cell_line_inf.loc[relabel, 'lineage_subtype'] = label

# specified_disease = lineage_subtype for the subtypes listed in PEDIATRIC_CANCERS,
# primary_disease for every other cell line.
# Assigned as a new column instead of `column.fillna(..., inplace=True)`: that chained
# in-place call operates on a temporary under pandas Copy-on-Write and would silently
# leave the column unfilled.
is_pediatric = cell_line_inf['lineage_subtype'].isin(PEDIATRIC_CANCERS)
cell_line_inf['specified_disease'] = (
    cell_line_inf['lineage_subtype']
    .where(is_pediatric)
    .fillna(cell_line_inf['primary_disease'])
)

diseases = cell_line_inf.specified_disease.unique()

cell_line_inf.head()

In [None]:
# On-disk cache of gene-name lookups; get_gene_name reads it and adds missing entries.
# NOTE(review): unpickling runs arbitrary code from the file -- only load a cache that
# was produced locally.
with open(NCBI_GENE_NAMES, mode='rb') as cache_file:
    ncbi_gene_names = pkl.load(cache_file)

def get_gene_name(geneID):
    """Return the cached ``(label, description)`` pair for a gene ID.

    On a cache miss the first line of the gene table is fetched through
    ``Entrez.efetch``; its first token becomes the gene name and the remaining
    tokens the description. The new entry is stored in ``ncbi_gene_names`` and
    the whole cache is re-pickled to ``NCBI_GENE_NAMES``.

    Parameters
    ----------
    geneID : hashable
        Gene identifier used both as cache key and (as ``str``) in the query.

    Returns
    -------
    tuple of (str, str)
        ``("<name> (<geneID>)", "<description>")``.

    Raises
    ------
    IndexError
        If the fetched first line is empty, so no gene name can be extracted.
    """
    # NOTE(review): `Entrez` is not imported in the visible part of this notebook;
    # presumably it comes from `from util import *` -- confirm.
    if geneID not in ncbi_gene_names:
        # NOTE(review): placeholder contact address; replace with a real one.
        Entrez.email = "test@gmail.com"
        handle = Entrez.efetch("gene", id=str(geneID), rettype="gene_table", retmode="text")
        try:
            info = handle.readline().split()
        finally:
            # Always release the connection, also when reading fails.
            handle.close()
        if not info:
            # Same exception type the bare info[0] lookup raised, but with a reason.
            raise IndexError(f"Empty gene table returned for gene ID {geneID}")
        name = info[0]
        ncbi_gene_names[geneID] = f"{name} ({geneID})", f"{' '.join(info[1:]).strip()}"
        # Persist immediately so the lookup survives a kernel restart.
        with open(NCBI_GENE_NAMES, 'wb') as f:
            pkl.dump(ncbi_gene_names, f)
    return ncbi_gene_names[geneID]

## 2.  Find correlations between genes and paralogs

In [None]:
def get_correlation(cell_lines, dep_gene, exp_gene):
    """Pearson correlation between one gene's effect and another gene's expression.

    Parameters
    ----------
    cell_lines : iterable
        Row labels present in both ``eff_data`` and ``exp_data``. Any iterable
        (set, list, Index) is accepted.
    dep_gene : hashable
        Column of ``eff_data`` to read.
    exp_gene : hashable
        Column of ``exp_data`` to read.

    Returns
    -------
    float
        Pearson R over the cell lines where both values are present, or
        ``np.nan`` when fewer than three such cell lines remain or when either
        vector is constant (the correlation is undefined in that case).
    """
    # pandas >= 2.0 refuses a set as .loc indexer; one shared list also guarantees
    # that both vectors are read in the same cell-line order.
    lines = list(cell_lines)
    d = eff_data.loc[lines, dep_gene].values.flatten().astype(np.float64)
    e = exp_data.loc[lines, exp_gene].values.flatten().astype(np.float64)

    # Keep only the positions where both values are present.
    valid = ~(np.isnan(d) | np.isnan(e))
    d, e = d[valid], e[valid]

    if len(d) < 3:
        return np.nan

    # A constant vector has zero variance; return NaN directly instead of letting
    # pearsonr emit a warning.
    if np.ptp(d) == 0 or np.ptp(e) == 0:
        return np.nan

    r = pearsonr(d, e)[0]  # TODO: do something with the p-value?
    # Normalise any NaN result to the np.nan object, so the short-sample path and the
    # computed path return the same sentinel.
    return np.nan if np.isnan(r) else r

In [None]:
# For every identified gene, correlate its gene effect with the expression of each of
# its paralogs -- across all cell lines and per disease -- and summarise per gene.

thres = .3  # NOTE(review): not referenced anywhere in this cell
cell_lines = set(eff_data.index).intersection(exp_data.index)
# .loc needs list-like labels (pandas >= 2.0 rejects sets), so index with lists.
all_lines = list(cell_lines)

# The cell lines belonging to a disease do not depend on the gene or paralog:
# compute them once instead of once per gene x paralog x disease.
lines_per_disease = {
    disease: list(cell_lines.intersection(
        cell_line_inf.loc[cell_line_inf.specified_disease == disease].index))
    for disease in diseases
}

# One list per output column; every gene with at least one correlation appends exactly
# one value to each list.
# NOTE: the "avg_*" columns hold the median (np.median), not the mean.
results = {
    "geneID": [],
    "gene": [],
    "n_paralogs": [],
    "avg_correlation_all_paralogs": [],
    "max_correlation_all_paralogs": [],
    "max_disease_specific_all": [],
    "avg_correlation_interacting_paralogs": [],
    "max_correlation_interacting_paralogs": [],
    "max_disease_specific_interacting": [],
    "correlations": [],
    "disease_specific_correlations": [],
}

not_found = 0

for idx, row in tqdm(identified_genes.iterrows(), total=len(identified_genes)):
    paralogs = row.paralogs
    interacting_paralogs = row.interacting_paralogs

    correlations = {}              # paralog -> R across all cell lines
    correlations_per_disease = {}  # paralog -> {disease: R}
    max_disease_specific = {}      # paralog -> (disease, R) with the highest R

    # IDEA: only look at interacting paralogs?
    for paralog in paralogs:
        if idx not in eff_data.columns or paralog not in exp_data.columns:
            continue

        corr = get_correlation(all_lines, idx, paralog)
        # NaN means too few data points or constant input. Test by value: a NaN coming
        # out of a computation is not the same object as np.nan, so `is` would miss it.
        if np.isnan(corr):
            continue
        correlations[paralog] = corr

        per_disease = {}
        for disease, disease_lines in lines_per_disease.items():
            disease_corr = get_correlation(disease_lines, idx, paralog)
            if not np.isnan(disease_corr):
                per_disease[disease] = disease_corr
        correlations_per_disease[paralog] = per_disease

        # A paralog may have no disease with enough data; max() of an empty dict
        # would raise, so such paralogs simply get no disease-specific maximum.
        if per_disease:
            max_disease_specific[paralog] = max(per_disease.items(), key=lambda x: x[1])

    if len(correlations) > 0:
        results["geneID"].append(idx)
        results["gene"].append(row.gene)
        results["n_paralogs"].append(len(paralogs))

        all_correlations = list(correlations.values())
        results["avg_correlation_all_paralogs"].append(np.median(all_correlations))
        results["max_correlation_all_paralogs"].append(max(all_correlations))
        # Entry is (paralog, (disease, R)), or None when no disease had enough data.
        results["max_disease_specific_all"].append(
            max(max_disease_specific.items(), key=lambda x: x[1][1], default=None))

        interacting_correlations = [c for p, c in correlations.items() if p in interacting_paralogs]
        if len(interacting_correlations) > 0:
            results["avg_correlation_interacting_paralogs"].append(np.median(interacting_correlations))
            results["max_correlation_interacting_paralogs"].append(max(interacting_correlations))
            interacting_disease_specific = {p: c
                                            for p, c in max_disease_specific.items()
                                            if p in interacting_paralogs}
            results["max_disease_specific_interacting"].append(
                max(interacting_disease_specific.items(), key=lambda x: x[1][1], default=None))
        else:
            results["avg_correlation_interacting_paralogs"].append(None)
            results["max_correlation_interacting_paralogs"].append(None)
            results["max_disease_specific_interacting"].append(None)

        results["correlations"].append(correlations)
        results["disease_specific_correlations"].append(correlations_per_disease)
    else:
        not_found += 1
        print(f"For gene {idx} no (significant) correlations found...")
        print(f"There were {len(paralogs)} evaluations.")

if not_found > 0:
    print(f"For {not_found} genes no correlations were determined!")

results = pd.DataFrame(results).set_index('geneID')
results

## 3.  Results

In [None]:
cell_lines = set(eff_data.index).intersection(exp_data.index)

def make_plot(gene_id, paralog_id, diseases):
    """Scatter a gene's effect against a paralog's expression with regression lines.

    Draws one regression plot over all cell lines shared by ``eff_data`` and
    ``exp_data`` and one additional overlay per entry of ``diseases``. The
    Pearson R values shown in the legend are read from ``results``.

    Parameters
    ----------
    gene_id : hashable
        Column of ``eff_data`` (x axis); must be in the index of ``results``.
    paralog_id : hashable
        Column of ``exp_data`` (y axis); must be a key of the gene's
        ``correlations`` entry in ``results``.
    diseases : iterable
        Values of ``cell_line_inf.specified_disease`` to overlay.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    def _effect_and_expression(lines):
        # pandas >= 2.0 rejects sets as .loc indexers; one shared list keeps both
        # vectors in the same cell-line order.
        lines = list(lines)
        effect = eff_data.loc[lines, gene_id].values.flatten().astype(float)
        expression = exp_data.loc[lines, paralog_id].values.flatten().astype(float)
        return effect, expression

    gene_row = results.loc[gene_id]

    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(13, 6), dpi=124, facecolor='w', edgecolor='k')

    # seaborn >= 0.12 accepts the data vectors only as x= / y= keywords.
    d, e = _effect_and_expression(cell_lines)
    sns.regplot(x=d, y=e, ci=None,
                line_kws=dict(color="#d71d24", label=f"Correlation across all lines (Pearson R = {gene_row.correlations[paralog_id]:.2f})"),
                label=f"All cancer cell lines (n = {len(cell_lines)})")

    for disease in diseases:
        dis_lines = set(cell_line_inf.loc[cell_line_inf.specified_disease == disease].index).intersection(cell_lines)
        d, e = _effect_and_expression(dis_lines)
        # A disease without a stored correlation is labelled "nan" instead of
        # aborting the plot with a KeyError.
        disease_r = gene_row.disease_specific_correlations[paralog_id].get(disease, np.nan)
        sns.regplot(x=d, y=e, ci=None, label=f"{disease} lines (n = {len(dis_lines)})",
                    line_kws=dict(label=f"{disease} specific correlation (Pearson R = {disease_r:.2f})"))

    # `results` only has rows for the identified genes; a paralog that is not one of
    # them has no stored name, so fall back to its bare ID.
    if paralog_id in results.index:
        paralog_label = f"{results.gene[paralog_id]} ({paralog_id})"
    else:
        paralog_label = f"{paralog_id}"

    plt.xlabel(f"{results.gene[gene_id]} ({gene_id}) gene effect (CERES)")
    plt.ylabel(f"{paralog_label} expression (log2(TPM+1))")
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# Example: effect of gene 29107 vs. expression of paralog 55916, with the
# neuroblastoma cell lines overlaid.
make_plot(gene_id=29107, paralog_id=55916, diseases=["neuroblastoma"])

In [None]:
def rem_nans(x):
    """Return the entries of ``x`` that are not NaN, in their original order."""
    keep = np.logical_not(np.isnan(x))
    return x[keep]
# 95th percentile of the per-gene maximum correlation with interacting paralogs,
# computed over the genes that have a value.
interacting_max = rem_nans(results.max_correlation_interacting_paralogs)
np.percentile(interacting_max, 95)

In [None]:
# Genes of interest to highlight in plot
GOI = [10006, 483, 23545, 1000, 8450, 131118, 1871, 63916, 2625, 9759,
       3845, 10905, 4170, 29107, 9943, 5290, 6688, 27183, 9525, 7454]

# Long format: one row per (gene, summary statistic).
summary_columns = [
    "max_correlation_all_paralogs",
    "max_correlation_interacting_paralogs",
    "avg_correlation_all_paralogs",
    "avg_correlation_interacting_paralogs",
]
df = pd.melt(results[summary_columns].reset_index(), id_vars='geneID')

# Derive the two plot facets from the name of the summary column.
df["Paralog group"] = [
    "All paralogs" if over_all else "Interacting paralogs"
    for over_all in df.variable.str.contains("all")
]


def _statistic_label(is_max, is_disease):
    # The disease branch only fires for columns with "disease" in their name;
    # none of the columns selected above has it.
    if not is_max:
        return "Average of paralogs"
    if is_disease:
        return "Max disease specific correlation"
    return "Max of paralogs"


df["minmax"] = [
    _statistic_label(is_max, is_disease)
    for is_max, is_disease in zip(df.variable.str.contains("max"),
                                  df.variable.str.contains("disease"))
]

sns.set_style("whitegrid")
fig = plt.figure(figsize=(12, 4), dpi=124, facecolor='w', edgecolor='k')

# Distribution over all genes ...
sns.boxplot(y="minmax", x="value", hue="Paralog group", data=df, orient="h")

# ... with the genes of interest overlaid as individual points.
highlighted = df.loc[df.geneID.isin(GOI)]
sns.swarmplot(y="minmax", x="value", hue="Paralog group", data=highlighted,
              dodge=True, palette=("gold", "gold"))

plt.xlabel("Expression correlation (Pearson R)")
plt.ylabel(None)
# NOTE(review): 0.24 is hard-coded; presumably taken from a percentile of the
# correlations computed elsewhere in this notebook -- confirm.
plt.vlines(.24, -.5, 1.5, "#d71d24", "--", label="Selection cutoff")

# Drop the second-to-last legend handle and relabel the remaining four; this relies
# on the order in which the calls above registered their handles.
handles, labels = plt.gca().get_legend_handles_labels()
del handles[-2]

plt.legend(handles, ['All paralogs', 'Interacting paralogs', 'Selected candidates', 'Selection cutoff'])

plt.show()

### Save results

In [None]:
# Write the `results` DataFrame to RESULTS_FILE as a pickle.
results.to_pickle(RESULTS_FILE)