In [None]:
import numpy as np
import pandas as pd
import os
import sys
import matplotlib.pyplot as plt
from collections import Counter
from util import *
from matplotlib_venn import venn3
from tqdm.notebook import tqdm
import seaborn as sns

%matplotlib inline

### Define paths

In [None]:
# Input: directory containing the candidate-gene list files (one file per "<dataset>-<paralog set>" name)
CANDIDATE_GENES = "results/candidate_genes/public_20Q2"

# Output: pickle file that the final essential-gene table is written to
RESULTS_FILE = "results/essential_candidates/public_20Q2/essential_genes-all.pkl"

## 1.  Load data

### Load paralog genes interacting with common essential genes

In [None]:
# Collect the candidate gene IDs from every "<dataset>-<paralog set>" list file.
# Each file is expected to hold one integer gene ID per line.
candidate_genes = []
for dataset in ["CORUM", "STRING", "SIGNOR", "HuRI"]:
    for paralog_set in ["DGD", "PANTHER"]:
        with open(os.path.join(CANDIDATE_GENES, f"{dataset}-{paralog_set}"), 'r') as f:
            # int() already ignores surrounding whitespace; blank lines (e.g. a
            # trailing empty line) are skipped instead of raising ValueError.
            genes = [int(line) for line in f if line.strip()]
        candidate_genes += genes
        print(f"{dataset}-{paralog_set}:\t{len(genes)}")

# The same gene can appear in several lists; keep each ID once.
candidate_genes = list(set(candidate_genes))
len(candidate_genes)

### Load cell line info

In [None]:
# Sample annotation table, indexed by DepMap_ID.
cell_line_inf = get_from_taiga(name='public-20q2-075d', version=22, file='sample_info')
cell_line_inf = cell_line_inf.set_index('DepMap_ID')

# Make ALL subtype specific for b-ALL and t-ALL.
# na=False keeps rows with a missing sub-subtype out of the mask instead of
# propagating NaN into the boolean expression.
# The 'ALL' test is re-evaluated for the second assignment on purpose: rows
# already relabelled 't-ALL' must not be relabelled again.
cell_line_inf.loc[(cell_line_inf['lineage_subtype'] == 'ALL') &
                  (cell_line_inf['lineage_sub_subtype'].str.contains('t', na=False)),
                  'lineage_subtype'] = 't-ALL'
cell_line_inf.loc[(cell_line_inf['lineage_subtype'] == 'ALL') &
                  (cell_line_inf['lineage_sub_subtype'].str.contains('b', na=False)),
                  'lineage_subtype'] = 'b-ALL'

# Use the lineage subtype for the cancers listed in PEDIATRIC_CANCERS and the
# primary disease for everything else. Assigning the result (rather than calling
# fillna(inplace=True) on an attribute-accessed column) also works when pandas
# Copy-on-Write is active, where the chained in-place call leaves the frame unchanged.
cell_line_inf['specified_disease'] = cell_line_inf['lineage_subtype'].where(
    cell_line_inf['lineage_subtype'].isin(PEDIATRIC_CANCERS),
    cell_line_inf['primary_disease'],
)

# Number of cell lines per disease label.
diseases = dict(cell_line_inf.specified_disease.value_counts())

cell_line_inf.head()

### Load Achilles gene dependency data

In [None]:
# Achilles gene dependency table (Taiga dataset 'public-20q2-075d', version 22).
# NOTE(review): split_attribute='header' presumably splits the column headers
# into (gene ID, gene name) pairs - confirm in util.get_from_taiga.
achilles_gene_dep = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='Achilles_gene_dependency',
    split_attribute='header',
)
achilles_gene_dep.head()

### Load Achilles gene effect data

In [None]:
# Achilles gene effect table (same Taiga dataset and version as the dependency table).
achilles_gene_eff = get_from_taiga(
    name='public-20q2-075d',
    version=22,
    file='Achilles_gene_effect',
    split_attribute='header',
)
achilles_gene_eff.head()

### Load DEMETER RNAi gene dependency data

In [None]:
# DEMETER2 RNAi gene dependency table (Taiga dataset 'demeter2-achilles-5386', version 13).
# NOTE(review): remap_index is handed the CCLE_Name column and 'DepMap_ID';
# presumably it re-keys the rows from CCLE names to DepMap IDs - confirm in util.
demeter_gene_dep = remap_index(
    get_from_taiga(
        name='demeter2-achilles-5386',
        version=13,
        file='gene_dependency',
        split_attribute='header',
    ),
    cell_line_inf.CCLE_Name,
    'DepMap_ID',
)
demeter_gene_dep.head()

### Load DEMETER RNAi gene effect data

In [None]:
# DEMETER2 RNAi gene effect table (Taiga dataset 'demeter2-achilles-5386', version 13).
# NOTE(review): remap_index is handed the CCLE_Name column and 'DepMap_ID';
# presumably it re-keys the rows from CCLE names to DepMap IDs - confirm in util.
demeter_gene_eff = remap_index(
    get_from_taiga(
        name='demeter2-achilles-5386',
        version=13,
        file='gene_effect',
        split_attribute='header',
    ),
    cell_line_inf.CCLE_Name,
    'DepMap_ID',
)
demeter_gene_eff.head()

### Load Sanger gene dependency data

In [None]:
# Sanger CRISPR (Project Score) gene dependency table
# (Taiga dataset 'sanger-crispr-project-score--e20b', version 4).
sanger_gene_dep = get_from_taiga(
    name='sanger-crispr-project-score--e20b',
    version=4,
    file='gene_dependency',
    split_attribute='header',
)
sanger_gene_dep.head()

### Load Sanger gene effect data

In [None]:
# Sanger CRISPR (Project Score) gene effect table
# (Taiga dataset 'sanger-crispr-project-score--e20b', version 4).
sanger_gene_eff = get_from_taiga(
    name='sanger-crispr-project-score--e20b',
    version=4,
    file='gene_effect',
    split_attribute='header',
)
sanger_gene_eff.head()

### Load LRT scores

In [None]:
# LRT score table (Taiga dataset 'crispr-avana-4171', version 1).
# NOTE(review): split_attribute='column' with col='Row.name' presumably splits
# that column into the (gene ID, gene name) row index - confirm in util.
lrt_scores = get_from_taiga(
    name='crispr-avana-4171',
    version=1,
    file='CRISPR-LRT',
    split_attribute='column',
    col='Row.name',
)
lrt_scores.head()

## 2.  Filter on those genes which are skewed to the left

In [None]:
# Keep the candidate genes (matched on the first index level) whose LRT score
# is at least 100, and remember their index entries.
skewed_genes = lrt_scores[
    (lrt_scores.LRT >= 100)
    & lrt_scores.index.get_level_values(0).isin(candidate_genes)
].index
len(skewed_genes)

In [None]:
def get_left_skewed_genes(gene_effs_data, skewed_gene_list):
    """Return the genes from `skewed_gene_list` whose effect scores are left-skewed.

    A gene counts as left-skewed when the mean of its scores across all rows of
    `gene_effs_data` is smaller than the median.

    Parameters
    ----------
    gene_effs_data : pandas.DataFrame
        Gene effect table with one column per gene (rows are transposed away
        before the selection).
    skewed_gene_list : list-like
        Column labels of the genes to test. They are looked up with `.loc`, so
        they must be present in `gene_effs_data.columns`.

    Returns
    -------
    pandas.Index
        Labels of the tested genes whose mean is below their median.
    """
    # Use the argument, not the notebook-level `skewed_genes` variable, so the
    # function does not depend on hidden global state.
    skewed_gene_effs = gene_effs_data.T.loc[skewed_gene_list]
    is_left_skewed = skewed_gene_effs.mean(axis=1) < skewed_gene_effs.median(axis=1)
    return skewed_gene_effs.loc[is_left_skewed].index

In [None]:
# Apply the left-skew filter to the gene effect table of each screen.
achilles_left_skewed_genes, demeter_left_skewed_genes, sanger_left_skewed_genes = (
    get_left_skewed_genes(gene_effects, skewed_genes)
    for gene_effects in (achilles_gene_eff, demeter_gene_eff, sanger_gene_eff)
)

In [None]:
# Venn diagram of the left-skewed genes found in each of the three screens;
# every label carries the number of genes for that screen.
fig = plt.figure(figsize=(5, 5), dpi=124, facecolor='w', edgecolor='k')
# plt.title('Selective dependencies from different screens.')
venn3(
    [set(genes) for genes in (achilles_left_skewed_genes,
                              demeter_left_skewed_genes,
                              sanger_left_skewed_genes)],
    set_labels=[f'{screen} ({len(genes)})'
                for screen, genes in (('Achilles CRISPR', achilles_left_skewed_genes),
                                      ('DEMETER RNAi', demeter_left_skewed_genes),
                                      ('Sanger CRISPR', sanger_left_skewed_genes))],
)
plt.show()

## 3.  Identify in which cell lines the candidate genes are a selective dependency

In [None]:
def get_essential_lines(left_skewed_genes, gene_dep):
    """Find, for each gene, the rows of `gene_dep` whose value exceeds 0.5.

    Parameters
    ----------
    left_skewed_genes : iterable of (geneID, name) pairs
        Each pair is used as a column key into `gene_dep`.
    gene_dep : pandas.DataFrame
        Table with one column per (geneID, name) pair; the row index labels are
        what gets collected.

    Returns
    -------
    dict
        Maps `int(geneID)` to `[name, number_of_rows, set_of_row_labels]`.
        Genes without any row above 0.5 are left out.
    """
    hits_by_gene = {}
    for gene_id, symbol in left_skewed_genes:
        above_threshold = gene_dep[gene_id, symbol] > .5
        dependent_lines = set(gene_dep.loc[above_threshold].index.values)
        if not dependent_lines:
            continue
        hits_by_gene[int(gene_id)] = [symbol, len(dependent_lines), dependent_lines]
    return hits_by_gene

In [None]:
# Merge the per-screen results into one record per gene:
#   essential_in_cell_lines[geneID] = [gene name, number of cell lines, set of cell lines]
essential_in_cell_lines = {}

# Gene IDs that were found essential in at least one cell line, per screen.
essential_genes_from_screen = {"Achilles CRISPR": [], "DEMETER RNAi": [], "Sanger CRISPR": []}

for left_skewed_genes, gene_dep, screen in [(achilles_left_skewed_genes, achilles_gene_dep, "Achilles CRISPR"),
                                            (demeter_left_skewed_genes, demeter_gene_dep, "DEMETER RNAi"),
                                            (sanger_left_skewed_genes, sanger_gene_dep, "Sanger CRISPR")]:
    for geneID, essential_lines in get_essential_lines(left_skewed_genes, gene_dep).items():
        essential_genes_from_screen[screen].append(geneID)
        if geneID not in essential_in_cell_lines:
            essential_in_cell_lines[geneID] = essential_lines
        else:
            # set.union() returns a new set and leaves the stored one untouched,
            # so the cell lines of later screens were silently dropped.
            # update() merges them into the stored set in place.
            essential_in_cell_lines[geneID][2].update(essential_lines[2])
            essential_in_cell_lines[geneID][1] = len(essential_in_cell_lines[geneID][2])

essential_in_cell_lines = pd.DataFrame(essential_in_cell_lines, index=["gene", "n_lines", "cell_lines"]).T

# Union of all per-gene cell-line sets (empty set when there are no genes).
affected_cell_lines = set().union(*essential_in_cell_lines.cell_lines.values)
print(f"{len(essential_in_cell_lines)} genes affecting {len(affected_cell_lines)} cell lines")
essential_in_cell_lines

In [None]:
# Venn diagram of the genes found essential in each screen; every label
# carries the number of distinct genes for that screen.
fig = plt.figure(figsize=(5, 5), dpi=124, facecolor='w', edgecolor='k')
# plt.title('Selective dependencies from different screens.')
venn3(
    list(map(set, essential_genes_from_screen.values())),
    set_labels=[f'{screen} ({len(set(gene_ids))})'
                for screen, gene_ids in essential_genes_from_screen.items()],
)
plt.show()

In [None]:
def draw_sorted_bar(dataset, title_text, cutoff):
    """Draw a bar chart of genes ordered by the number of affected cell lines.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must provide the columns 'gene' (bar label) and 'n_lines' (bar height).
    title_text : str
        Only referenced by the commented-out figure title below; currently it
        has no effect on the plot.
    cutoff : number
        Only rows with `n_lines` strictly greater than this value are drawn.

    Returns
    -------
    None
        The figure is shown with `plt.show()`. When no row passes the cutoff a
        short message is printed and nothing is drawn.
    """
    selected = dataset.loc[dataset.n_lines > cutoff, ['gene', 'n_lines']]
    if selected.empty:
        # plt.bar() would otherwise be called without arguments and raise TypeError.
        print(f"No genes with more than {cutoff} affected cell lines; nothing to plot.")
        return

    # Highest count first; sorted() is stable, so ties keep their table order.
    ranked = sorted(selected.values.tolist(), key=lambda pair: pair[1], reverse=True)
    genes, counts = (list(column) for column in zip(*ranked))

    plt.figure(figsize=(15, 7), dpi=124, facecolor='w', edgecolor='k')
    plt.xticks(rotation=65)
    plt.bar(genes, counts)

#     plt.figtext(.5,.92,f'Essential paralog genes in {title_text} cell lines', fontsize=20, ha='center')
#     plt.figtext(.5,.89,f'For genes essential in ≥{cutoff} cell lines.',fontsize=12,ha='center')

    plt.xlabel("Gene")
    plt.ylabel("Number of affected cell lines")
    plt.grid(linestyle='-', axis='y')
    plt.show()

# Plot the genes with more than 50 affected cell lines (the cutoff is compared with `>`).
draw_sorted_bar(essential_in_cell_lines, 'any', 50)

In [None]:
# Count, for every gene with more than 5 affected cell lines, how many of those
# cell lines belong to each disease. Result: one column per gene (labelled
# "<gene> (<gene ID>)"), one row per disease.
excluded_diseases = ["Unknown", "Embryonal Cancer", "Teratoma", "Adrenal Cancer"]
tracked_diseases = [d for d in diseases if d not in excluded_diseases]

lines_per_disease = {}
for idx, row in tqdm(essential_in_cell_lines.iterrows(), total=len(essential_in_cell_lines)):
    if row.n_lines > 5:
        counts = dict.fromkeys(tracked_diseases, 0)
        for cell_line in row.cell_lines:
            if cell_line in cell_line_inf.index:
                d = cell_line_inf.at[cell_line, 'specified_disease']
                # Cell lines whose disease is excluded (or missing) have no
                # counter; skip them instead of raising KeyError.
                if d in counts:
                    counts[d] += 1
        lines_per_disease[f"{row.gene} ({idx})"] = counts

lines_per_disease = pd.DataFrame(lines_per_disease)
lines_per_disease

In [None]:
# Use the list "selection" to define the genes (by gene ID) to plot the clustermap for
selection = [10006, 483, 23545, 1000, 8450, 131118, 1871, 63916, 2625, 9759, 3845, 10905, 4170, 29107, 9943, 5290, 6688, 27183, 9525, 7454]

# Columns are labelled "<gene> (<gene ID>)", so the integer IDs in `selection`
# cannot index them directly; recover the ID from each label first.
column_gene_ids = {g: int(g.split(" ")[-1].strip("()")) for g in lines_per_disease.columns}
selected_columns = [g for g in lines_per_disease.columns if column_gene_ids[g] in selection]

missing_ids = sorted(set(selection) - set(column_gene_ids.values()))
if missing_ids:
    print(f"Selected gene IDs without a column in lines_per_disease: {missing_ids}")

lines_per_disease = lines_per_disease[selected_columns]

# NOTE(review): the original colour map tested membership in an undefined name
# `sel` (NameError). `selection` is used here, which marks every plotted gene
# red; substitute a separate list of IDs if only a subset should be highlighted.
col_map = {g: "r" if column_gene_ids[g] in selection else "k" for g in lines_per_disease.columns}

# clustermap creates its own figure (sized by `figsize`), so no plt.figure()
# call is needed beforehand - it only produced an empty extra figure.
sns.clustermap(lines_per_disease, xticklabels=True, yticklabels=True, figsize=(50, 10),
               standard_scale=0,
               dendrogram_ratio=(.05, .2), cbar_pos=(-.03,.3,.01,.5),
               col_colors=lines_per_disease.columns.map(col_map))
plt.show()

## 4.  Save results

In [None]:
# Make sure the output directory exists so a fresh checkout can run end to end.
results_dir = os.path.dirname(RESULTS_FILE)
if results_dir:
    os.makedirs(results_dir, exist_ok=True)
essential_in_cell_lines.to_pickle(RESULTS_FILE)