This notebook clusters the cohorts (ALL subtypes) according to the composition of their mutated driver genes. The plots correspond to Figure 1d in the paper; the three panels (dendrogram, heatmap and barplot) were then combined with SVG-editing software. The input is the list of candidate driver mutations produced by driver_mutations_primary_ALL.ipynb.

In [None]:
import pandas as pd
import numpy as np
import os
import glob
from collections import OrderedDict

from scipy.stats import entropy
import scipy.cluster.hierarchy as hierarchy
from scipy.spatial.distance import pdist,squareform

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

from aux_data_in_pyvar import config_rcparams,COLORS_SUBTYPES

# Show every column/row and full cell contents when displaying DataFrames.
# `None` means "no limit"; the former `-1` for max_colwidth was deprecated in
# pandas 1.0 and is no longer accepted by current pandas versions.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

In [None]:
# FUNCTIONS

def normalize(rw, dict_sum):
    """Divide a row by the total stored for it in *dict_sum*.

    Meant to be used with ``DataFrame.apply(..., axis=1)``: ``rw.name`` is the
    row label and serves as the key into *dict_sum*.
    """
    row_total = dict_sum[rw.name]
    return rw / row_total

def jensen_shannon(x, y):
    """Return the Jensen-Shannon divergence between two probability vectors.

    Used as the ``metric`` callable for ``pdist``, which passes NumPy rows;
    any array-like (e.g. plain lists) is also accepted and converted to float
    arrays first. Without this conversion, ``x + y`` on lists would concatenate
    instead of adding.

    The value is the JS *divergence* in nats (natural log), not the JS
    distance (its square root).

    Parameters
    ----------
    x, y : array-like of float
        Non-negative vectors of equal length, each expected to sum to 1.
        ``scipy.stats.entropy`` renormalises its arguments, but the mixture
        ``m`` is only the true midpoint distribution when ``x`` and ``y`` are
        already normalised.

    Returns
    -------
    float
        Divergence, in ``[0, ln 2]`` for normalised inputs.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Midpoint (mixture) distribution of the two inputs.
    m = 0.5 * (x + y)
    return 0.5 * (entropy(x, m) + entropy(y, m))

In [None]:
# Apply the shared plotting configuration from aux_data_in_pyvar before any
# figure is created (presumably matplotlib rcParams, given the name — see that module).
config_rcparams()

In [None]:
# Candidate driver mutations produced by driver_mutations_primary_ALL.ipynb;
# adjust the path to wherever the file is stored.
df = pd.read_csv("candidate_driver_muts.tsv", sep='\t')

# Keep primary samples only, excluding the PHALL subtype (very few patients,
# only 4) and the Li et al. (2019) pediatric ALL cohort.
keep_rows = (
    (df['STAGE'] == 'primary')
    & (df['SUBTYPE_LABEL'] != 'PHALL')
    & (df['COHORT'] != 'PEDIATRIC ALL (Li et al., 2019, Blood)')
)
df_pry = df[keep_rows]

In [None]:
# Genes heavily mutated in two cohorts that come from the same project are
# suspected false positives and are excluded.
potential_false_positives = ['MSH3', 'MAP3K4']
is_blacklisted = df_pry['SYMBOL'].isin(potential_false_positives)
df_pry = df_pry.loc[~is_blacklisted]

In [None]:
# For each gene, count the distinct subtypes in which it carries a candidate
# driver mutation; genes present in the most subtypes come first.
gene_subtype_pairs = df_pry[['SYMBOL', 'SUBTYPE_LABEL']].drop_duplicates()
counter = (
    gene_subtype_pairs
    .groupby(['SYMBOL'])
    .count()
    .sort_values(by='SUBTYPE_LABEL', ascending=False)
    .reset_index()
    .rename(columns={'SUBTYPE_LABEL': 'COUNT'})
)
counter

In [None]:
# Keep genes mutated in MORE than two subtypes (i.e. at least three).
# NOTE(review): an earlier comment said "at least ... 2 cohorts", but the
# filter has always been the strict `COUNT > 2`; confirm which threshold the
# figure is meant to use.
genes_to_figure = counter.loc[counter['COUNT'] > 2, 'SYMBOL'].tolist()
df_pry = df_pry[df_pry['SYMBOL'].isin(genes_to_figure)]

# One row per distinct (gene, comparison, subtype) combination; COMPARISON is
# presumably one tumour/patient sample — confirm against the upstream notebook.
df_pry_subset = df_pry[['SYMBOL', 'COMPARISON', 'SUBTYPE_LABEL']].drop_duplicates()

In [None]:
# Matrix with subtypes as rows and genes as columns. Each cell is the fraction
# of the subtype's comparisons in df_pry_subset (num_pat) that carry a
# mutation in that gene.
#
# Built with groupby/nunique/unstack: the previous cell-by-cell
# DataFrame.set_value loop relies on an API removed in pandas 1.0.
# Rows and columns keep their order of first appearance in df_pry_subset,
# as before, because row order feeds the clustering below.

# Distinct comparisons per subtype (denominator).
num_pat = (
    df_pry_subset[['SUBTYPE_LABEL', 'COMPARISON']]
    .drop_duplicates()
    .groupby('SUBTYPE_LABEL')
    .count()
)

# Distinct mutated comparisons per (subtype, gene); absent pairs become 0.
mutated_per_cell = (
    df_pry_subset
    .groupby(['SUBTYPE_LABEL', 'SYMBOL'])['COMPARISON']
    .nunique()
    .unstack(fill_value=0)
)

df_pivot = (
    mutated_per_cell
    .div(num_pat['COMPARISON'], axis=0)
    .reindex(index=df_pry_subset.SUBTYPE_LABEL.unique(),
             columns=df_pry_subset.SYMBOL.unique(),
             fill_value=0.0)
)
# Clear the axis names inherited from groupby, so that seaborn does not turn
# them into heatmap axis labels later on.
df_pivot.index.name = None
df_pivot.columns.name = None
df_pivot

In [None]:
# Sum of each subtype's row in df_pivot, keyed by subtype label; used to
# rescale every row to sum to 1 in the next cell.
row_totals = df_pivot.sum(axis=1)
dicc_total = row_totals.to_dict()
dicc_total

In [None]:
# Divide every subtype row by its total so that each row sums to 1. This
# turns the rows into the distributions compared by jensen_shannon below.
df_pivot = df_pivot.apply(normalize, axis=1, args=(dicc_total,))
df_pivot

In [None]:
# Subtype order and per-(gene, subtype) comparison counts for the stacked barplot.
list_subtypes = df_pry_subset['SUBTYPE_LABEL'].unique()
counter = df_pry_subset.groupby(['SYMBOL', 'SUBTYPE_LABEL']).count()

# Order genes by their number of mutated comparisons (descending) and apply
# that order to the matrix columns.
patients_per_gene = (
    df_pry_subset[['SYMBOL', 'COMPARISON']]
    .groupby("SYMBOL")
    .count()
    .sort_values(by='COMPARISON', ascending=False)
)
sorter = list(patients_per_gene.index)
df_pivot = df_pivot[sorter]
list_genes = df_pivot.columns

In [None]:
# Absolute number of distinct mutated comparisons in each (subtype, gene)
# cell; these counts are written on top of the heatmap cells.
#
# Built with groupby/nunique/unstack: the previous cell-by-cell
# DataFrame.set_value loop relies on an API removed in pandas 1.0.
# Rows and columns keep their order of first appearance in df_pry_subset,
# and pairs with no mutation are 0.
df_annot = (
    df_pry_subset
    .groupby(['SUBTYPE_LABEL', 'SYMBOL'])['COMPARISON']
    .nunique()
    .unstack(fill_value=0)
    .reindex(index=df_pry_subset.SUBTYPE_LABEL.unique(),
             columns=df_pry_subset.SYMBOL.unique(),
             fill_value=0)
)
# Clear the axis names inherited from groupby, for parity with df_pivot.
df_annot.index.name = None
df_annot.columns.name = None

In [None]:
# Align the annotation matrix with df_pivot: any missing cell becomes 0 and
# the gene columns follow the same `sorter` order.
df_annot = df_annot.fillna(0)[sorter]

In [None]:
## MAKE PLOTS
# Figure 1d panels: subtype dendrogram (top-left), gene x subtype heatmap
# (bottom-left) and stacked barplot of mutated comparisons per gene
# (bottom-right).

output = "cancer_genes_in_ALL_primary.svg"

fig = plt.figure(figsize=(14, 14))
fig.suptitle("Mutated cancer genes in ALL subtypes")

gs = gridspec.GridSpec(ncols=2, nrows=2, width_ratios=[2, 1], height_ratios=[1, 4], hspace=0.5)

# HIERARCHICAL CLUSTERING
# Subtypes (rows of df_pivot) are clustered on pairwise Jensen-Shannon
# divergences between their normalised gene profiles.
# NOTE(review): scipy documents method='ward' as correct only for Euclidean
# distances; it is kept here to reproduce the published figure.
ax0 = fig.add_subplot(gs[0, 0])

X = df_pivot.values
Y = pdist(X, metric=jensen_shannon)
linkage = hierarchy.linkage(Y, method='ward')
dist_matrix = squareform(Y)  # square distance matrix; not used below, kept for inspection

dendro = hierarchy.dendrogram(linkage, truncate_mode='level',
                              labels=df_pivot.index.values,
                              leaf_rotation=90,
                              color_threshold=0,
                              above_threshold_color='gray',
                              no_plot=False,
                              ax=ax0)

# Leaf labels in dendrogram order (left to right), taken from the dendrogram
# result instead of from the rendered tick texts.
xlabels = list(dendro['ivl'])

# HEATMAP
# sns.heatmap sets its own y-limits, so no set_ylim is needed beforehand.
ax1 = fig.add_subplot(gs[1, 0])
df_values = df_pivot.reindex(xlabels).T  # genes x subtypes, subtypes in dendrogram order
df_nums = df_annot.reindex(xlabels).T    # matching absolute counts for the annotations

sns.heatmap(df_values, annot=df_nums, ax=ax1, cmap='Blues',
            cbar_kws={'shrink': 0.2, 'use_gridspec': True, 'pad': 0.01})

# One tick per gene, centred on its row: row i (heatmap cell, and the bar
# drawn below with align='edge', height=1) spans y in [i, i+1].
# The previous len(list_genes) + 1 integer ticks put each label on a row
# boundary and did not match the number of labels, which recent matplotlib
# versions reject in set_yticklabels.
gene_ticks = np.arange(len(list_genes)) + 0.5
ax1.set_yticks(gene_ticks)
ax1.set_yticklabels(labels=list_genes, rotation=0, fontstyle='italic', va='center')

# BARPLOT
# For each gene, horizontal bars stacked left to right in list_subtypes
# order, one segment per subtype with its number of mutated comparisons.
# Zero-width segments are still drawn so that every subtype gets a legend handle.
ax2 = fig.add_subplot(gs[1, 1], sharey=ax1)
counts_by_gene_subtype = counter['COMPARISON'].to_dict()  # {(gene, subtype): count}

for i, gene in enumerate(list_genes):
    left = 0
    for sub in list_subtypes:
        width = counts_by_gene_subtype.get((gene, sub), 0)
        ax2.barh(i, width, left=left, color=COLORS_SUBTYPES[sub], edgecolor='white',
                 height=1, label=sub, align='edge')
        left += width

ax2.spines['right'].set_visible(False)
ax2.spines['bottom'].set_visible(False)
ax2.xaxis.set_tick_params(reset=True, labeltop=True, top=True, bottom=False, labelbottom=False)
ax2.set_yticks(gene_ticks)
ax2.set_yticklabels(labels=list_genes, rotation=0, fontstyle='italic', va='center')

# Legend: barh was called once per (gene, subtype); keep one entry per subtype.
handles, labels = ax2.get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
ax2.legend(by_label.values(), by_label.keys(), prop={'size': 10}, ncol=2, bbox_to_anchor=(1, 1))

fig.tight_layout()
fig.savefig(output, bbox_inches='tight', dpi=300)
plt.show()