In [None]:
import collections
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
import pyranges as pr

from scipy import ndimage
from scipy.stats import fisher_exact

import statannot
import seaborn as sns
import matplotlib.pyplot as plt
from adjustText import adjust_text

import matplotlib.transforms as tx
from matplotlib.colors import SymLogNorm
from matplotlib.gridspec import GridSpec
from matplotlib.collections import LineCollection

import cooler

from natsort import natsorted
from tqdm.auto import tqdm, trange

In [None]:
# use seaborn's 'talk' plotting context for all figures
sns.set_context('talk')
# never truncate columns when a DataFrame is displayed
pd.set_option('display.max_columns', None)

# Parameters

In [None]:
# CSV inputs, provided through the `snakemake` object
fname_data = snakemake.input.fname_data
fname_enr = snakemake.input.fname_enr

# contact matrix and TAD table used for the overview sketch (Figure 1)
sketch_hicfile = snakemake.input.sketch_hicfile
sketch_tadfile = snakemake.input.sketch_tadfile

# region shown in the sketch; indexed below as (chromosome, start, end)
sketch_region = snakemake.config['sketch']['region']

# Every figure is saved into this directory. Create it up front so that the
# first `savefig` cannot fail because the directory does not exist yet.
outdir = Path(snakemake.output.outdir)
outdir.mkdir(parents=True, exist_ok=True)

# Load data

## Database data

In [None]:
# database table; later cells read snpId, diseaseId, is_cancer, the hg19 coordinates and '20in' from it
df_data = pd.read_csv(fname_data, low_memory=True)
df_data.head()

## Enrichment data

In [None]:
# enrichment table; later cells read pval_boundary, filter, TAD_type, tad_source and '#border_snp' from it
df_enr = pd.read_csv(fname_enr, low_memory=True)

In [None]:
max_enrichment = 16

# -log10 of the boundary p-value; p == 0 would be infinite and is set to `max_enrichment` instead
_pval_boundary = df_enr['pval_boundary']
with np.errstate(divide='ignore'):
    df_enr['pval_boundary_trans'] = np.where(
        _pval_boundary == 0,
        max_enrichment,
        -np.log10(_pval_boundary)
    )

In [None]:
df_enr.head()

# Publication Figures

## Figure 1: Overview Sketch

Interesting regions:
* chr6:25960228-31407777 (hg19), chr6:25960000-31440000 (hg38)

### Read contact matrix

In [None]:
# open the cooler file holding the contact matrix
c = cooler.Cooler(sketch_hicfile)

# contact values (balance=False) and bin table restricted to the sketch region
mat = c.matrix(balance=False).fetch(sketch_region)
df_bins = c.bins().fetch(sketch_region)

# label rows and columns with the start coordinate of each bin
df_mat = pd.DataFrame(mat, index=df_bins['start'], columns=df_bins['start'])

In [None]:
df_mat.head()

### Read TADs

In [None]:
# TAD table with its coordinate columns renamed to Chromosome / Start / End
tad_column_names = {'chrname': 'Chromosome', 'tad_start': 'Start', 'tad_stop': 'End'}
tad_table = pd.read_csv(sketch_tadfile).rename(columns=tad_column_names)

df_tads = pr.PyRanges(tad_table)
df_tads.head()

In [None]:
# single-interval PyRanges describing the sketch region
region_record = pd.DataFrame({
    'Chromosome': [sketch_region[0]],
    'Start': [sketch_region[1]],
    'End': [sketch_region[2]]
})
df_region = pr.PyRanges(region_record)
df_region

In [None]:
# keep the TADs that overlap the sketch region
tad_list = df_tads.overlap(df_region)
# disabled alternative: keep only TADs lying completely inside the region
# tad_list = tad_list[(tad_list.Start >= sketch_region[1]) & (tad_list.End <= sketch_region[2])]
tad_list

#### Find TAD boundaries

In [None]:
# Signed boundary width in bp. It is subtracted from Start and added to End,
# so the negative value places both boundary intervals inside the TAD.
boundary_size = -20_000

def get_boundaries(df):
    """Return the front and back boundary intervals of every interval in `df`.

    `df` needs `Start` and `End` columns and is not modified. The result stacks
    two copies of `df`: rows with type == 'front' span [Start, Start - boundary_size],
    rows with type == 'back' span [End + boundary_size, End].
    """
    front = df.assign(End=df.Start - boundary_size, type='front')
    back = df.assign(Start=df.End + boundary_size, type='back')

    return pd.concat([front, back])

In [None]:
# replace every TAD in the region by its two boundary intervals (see get_boundaries)
border_list = tad_list.apply(get_boundaries)
border_list

### Read SNPs

In [None]:
# one row per SNP with complete hg19 coordinates, using the Chromosome / Start column names
tmp = (
    df_data[['chromosome_hg19', 'position_hg19', 'snpId']]
    .drop_duplicates()
    .dropna()
    .rename(columns={'chromosome_hg19': 'Chromosome', 'position_hg19': 'Start'})
)

def _chrom_label(value):
    """Format a chromosome identifier as 'chr<name>', writing integral floats without '.0'."""
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return f'chr{value}'

# NOTE(review): if the CSV column contains missing values it may be parsed as float;
# a plain str() conversion would then give labels such as 'chr6.0' that match no TAD.
tmp['Chromosome'] = tmp['Chromosome'].map(_chrom_label)

# positions are complete after dropna(), so they can be stored as integers
tmp['Start'] = tmp['Start'].astype('int64')
tmp['End'] = tmp['Start'] + 1

df_snps = pr.PyRanges(tmp)
df_snps

In [None]:
# find SNPs in TAD borders: keep the SNP intervals that overlap a boundary interval
snp_list = df_snps.overlap(border_list)
snp_list

In [None]:
# classify SNPs
def classify(x):
    """Return the most frequent value of the `is_cancer` column of `x`."""
    label_counts = x['is_cancer'].value_counts()
    return label_counts.idxmax()

# unique, complete (disease, SNP, is_cancer) combinations
disease_snp_labels = df_data[['diseaseId', 'snpId', 'is_cancer']].drop_duplicates().dropna()

# per SNP: the is_cancer value shared by most of its diseases
snp_cancer_map = disease_snp_labels.groupby('snpId').apply(classify).to_dict()
list(snp_cancer_map.items())[:2]

### Create plot

In [None]:
# rotate contact matrix by 45 degrees
mat_rot = ndimage.rotate(df_mat, 45, order=0, reshape=True, cval=0, prefilter=False)

# create figure
s = 2
fig = plt.figure(figsize=(s*8, s*8))

# main plot
ax_main = plt.gca()

# plot rotated contact matrix
ax_main.matshow(
    mat_rot,
    norm=SymLogNorm(1),
    cmap='YlOrRd',
    origin='lower',
    extent=(
        df_mat.index[0] + .5, df_mat.index[-1] + .5,
        df_mat.index[0] + .5, df_mat.index[-1] + .5
    ),
    aspect='equal'
)

# adjust axes: show only the part of the image above its vertical centre
center_height = (df_mat.index[0] + df_mat.index[-1]) / 2 + .5
ax_main.set_ylim(center_height, df_mat.index[-1] + .5)

ax_main.set_xlabel(sketch_region[0])
ax_main.xaxis.set_label_position('top')
ax_main.set_xlim(sketch_region[1], sketch_region[2])

# ticks and tick labels on the top edge only
ax_main.tick_params(
    axis='both',
    reset=True,
    which='both',
    top=True, labeltop=True,
    right=False, labelright=False,
    left=False, labelleft=False,
    bottom=False, labelbottom=False,
)

ax_main.ticklabel_format(axis='both', style='plain')

# highlight TADs: triangle over [Start, End] whose apex is half the TAD length high
for row in tad_list.df.itertuples():
    tmp = np.sin(np.deg2rad(90)) * (row.End - row.Start) / 2

    pg = plt.Polygon([
        [row.Start, center_height],
        [(row.Start + row.End) / 2, center_height + tmp],
        [row.End, center_height]
    ], edgecolor='black', facecolor='none')
    ax_main.add_patch(pg)

# highlight TAD boundaries: right-angled triangle whose upright side is at
# End for 'front' boundaries and at Start for 'back' boundaries
for row in border_list.df.itertuples():
    tmp = np.tan(np.deg2rad(45)) * (row.End - row.Start)

    pg = plt.Polygon([
        [row.Start, center_height],
        [row.End if row.type == 'front' else row.Start, center_height + tmp],
        [row.End, center_height]
    ], edgecolor='black', facecolor='gray', alpha=.6)
    ax_main.add_patch(pg)

# auxiliary plot, placed directly below the main axes
ax_height = .3
ax_aux = ax_main.inset_axes([0, -ax_height, 1, ax_height])

# Link the x-axes of both plots. Joining through `get_shared_x_axes()` is
# deprecated and unavailable in recent Matplotlib releases, so `Axes.sharex`
# is preferred; the old call is only used where `sharex` does not exist.
if hasattr(ax_aux, 'sharex'):
    ax_aux.sharex(ax_main)
else:
    ax_main.get_shared_x_axes().join(ax_main, ax_aux)
ax_aux.set_xlim(sketch_region[1], sketch_region[2])

# highlight SNPs: one vertical tick per SNP, solid if its snp_cancer_map entry is truthy, dashed otherwise
snp_pos = snp_list.df['Start']
snp_colors = 'black' #['orange' if snp_cancer_map[s] else 'blue' for s in snp_list.df['snpId']]
snp_ls = ['solid' if snp_cancer_map[s] else 'dashed' for s in snp_list.df['snpId']]

# x in data coordinates, y as fraction of the auxiliary axes height
trans = tx.blended_transform_factory(ax_aux.transData, ax_aux.transAxes)
xy_pairs = np.column_stack([np.repeat(snp_pos, 2), np.tile([.8, 1], len(snp_pos))])

# one segment from (pos, .8) to (pos, 1) per SNP
line_segs = xy_pairs.reshape([len(snp_pos), 2, 2])
ax_aux.add_collection(LineCollection(line_segs, transform=trans, colors=snp_colors, ls=snp_ls))

ax_aux.axis('off')

# label every SNP with its id; adjust_text then repositions the labels
annotation_list = []
for row in snp_list.df.itertuples():
    pos = row.Start
    id_ = row.snpId

    a = ax_aux.annotate(
        id_,
        xy=(pos, .8), xytext=(pos, .5),
        xycoords=('data', 'axes fraction'), textcoords=('data', 'axes fraction'),
        arrowprops=dict(arrowstyle='-'),
        annotation_clip=False,
        fontsize=13)
    annotation_list.append(a)

adjust_text(annotation_list, ax=ax_aux)

# save figure
plt.tight_layout()
plt.savefig(outdir / 'sketch.pdf')

## Figure 2: p-value Histogram for nice case

In [None]:
def plot_histogram(df, fname, bins=np.linspace(0, 3, 50)):
    """Plot normalised histograms of `pval_boundary_trans`, one per `is_cancer` group.

    Values above 3 (the upper edge of the default `bins`) are set to 3 before
    plotting. A dashed red line marks -log10(0.05). The figure is written to
    `outdir / fname`.
    """
    plt.figure(figsize=(8, 6))

    for is_cancer, group in df.groupby('is_cancer'):
        # cap large values; `group` itself is left untouched
        capped = group['pval_boundary_trans'].clip(upper=3)

        sns.distplot(
            capped,
            kde=False, norm_hist=True, bins=bins,
            label=is_cancer
        )

    # axes
    plt.xlabel(r'$-log_{10}(\mathrm{pvalue})$')
    plt.ylabel('Frequency')
    plt.ylim(0, 1)

    # significance threshold and legend
    plt.axvline(-np.log10(.05), ls='dashed', color='red')
    plt.legend(loc='best', title='is_cancer')

    plt.tight_layout()
    plt.savefig(outdir / fname)

### Subfigure 2a: nofilter

In [None]:
# rows without SNP filter, TAD type '20in' and more than 10 border SNPs
nofilter_mask = (
    df_enr['filter'].eq('nofilter')
    & df_enr['TAD_type'].eq('20in')
    & df_enr['#border_snp'].gt(10)
)
sub_nofilter = df_enr[nofilter_mask]

In [None]:
# subfigure 2a: histogram of the unfiltered subset, saved into outdir
plot_histogram(sub_nofilter, 'pvalue_histogram_nofilter.pdf')

### Subfigure 2b: intergenic filter

In [None]:
# rows with the intergenic SNP filter, TAD type '20in' and more than 10 border SNPs
intergenic_mask = (
    df_enr['filter'].eq('intergenic')
    & df_enr['TAD_type'].eq('20in')
    & df_enr['#border_snp'].gt(10)
)
sub_intergenic = df_enr[intergenic_mask]

In [None]:
# subfigure 2b: histogram of the intergenic subset, saved into outdir
plot_histogram(sub_intergenic, 'pvalue_histogram_intergenic.pdf')

## Figure 3: Zoomed-out multi-dataset overview

In [None]:
df_enr.head(2)

In [None]:
%%time

# per disease and configuration: is p <= 0.05 in more than half of the rows?
disease_keys = ['TAD_type', 'filter', 'tad_source', 'is_cancer', 'diseaseId']
df_detailed = (
    df_enr.assign(is_sig=df_enr['pval_boundary'] <= .05)
    .groupby(disease_keys)['is_sig']
    .mean()
    .gt(.5)
    .to_frame('majority_is_sig')
    .reset_index()
)

# per configuration and cancer status: fraction of diseases with such a majority
config_keys = ['TAD_type', 'filter', 'tad_source', 'is_cancer']
df_majority = (
    df_detailed.groupby(config_keys)['majority_is_sig']
    .mean()
    .to_frame('sig_frac')
    .reset_index()
)
df_majority.head()

In [None]:
def select_one(df, col):
    """Return the single distinct value of column `col`; assert that there is exactly one."""
    values = df[col]
    assert values.nunique() == 1
    return values.iloc[0]

def compute_pvalue(x, y):
    """Return the two-sided Fisher exact p-value for the association of `x` and `y`.

    `x` and `y` are equally long label sequences with at most two distinct
    values each. When one of them shows fewer than two distinct values (for
    instance when no observation is significant), the contingency table is
    padded with zero counts to 2x2 instead of letting `fisher_exact` fail;
    the p-value of such a table is 1.

    Raises
    ------
    ValueError
        If `x` or `y` has more than two distinct values.
    """
    ct = pd.crosstab(x, y).values
    n_rows, n_cols = ct.shape
    if n_rows > 2 or n_cols > 2:
        raise ValueError(f'expected at most a 2x2 contingency table, got {n_rows}x{n_cols}')

    # missing levels are represented by all-zero rows/columns
    table = np.zeros((2, 2), dtype=int)
    table[:n_rows, :n_cols] = ct

    return fisher_exact(table)[1]

def custom_barplot(*args, **kwargs):
    """Draw a grouped barplot and annotate each x group with a Fisher exact p-value.

    Expects the keyword arguments `data`, `x`, `y` and `hue`. `data` must
    contain a single `TAD_type` and a single `tad_source`. For every filter
    the p-value is computed from the matching rows of the global `df_detailed`
    (is_cancer vs. majority_is_sig) and drawn between the False and True bars.
    """
    data, x, y, hue = (kwargs[key] for key in ('data', 'x', 'y', 'hue'))

    order = ['nofilter', 'exonic', 'intronic', 'intergenic']
    assert set(order) <= set(data[x].unique())
    hue_order = data[hue].unique()

    # standard barplot
    ax = sns.barplot(*args, **kwargs, order=order, hue_order=hue_order)

    # rows of the detailed table that belong to this panel
    tad_type = select_one(data, 'TAD_type')
    tad_source = select_one(data, 'tad_source')
    in_panel = (df_detailed['TAD_type'] == tad_type) & (df_detailed['tad_source'] == tad_source)

    # one annotation per filter, spanning its two hue bars
    box_pairs = [((filter_, False), (filter_, True)) for filter_ in order]

    pvalues = []
    for filter_ in order:
        sub_detailed = df_detailed[in_panel & (df_detailed['filter'] == filter_)]
        pvalues.append(compute_pvalue(sub_detailed['is_cancer'], sub_detailed['majority_is_sig']))

    # draw the precomputed p-values; statannot must not run its own test
    statannot.add_stat_annotation(
        ax, plot='barplot',
        data=data, x=x, y=y, hue=hue,
        order=order, hue_order=hue_order,
        box_pairs=box_pairs,
        text_format='simple',
        pvalues=pvalues,
        perform_stat_test=False
    )

In [None]:
# one panel per (TAD_type, tad_source) combination
g = sns.FacetGrid(df_majority, row='TAD_type', col='tad_source', height=5, aspect=2)

# custom_barplot receives the panel's rows as `data` plus the keyword arguments given here
g.map_dataframe(custom_barplot, x='filter', y='sig_frac', hue='is_cancer', palette='tab10')

g.set_axis_labels('Filter', 'Disease fraction sig. in $>0.5$ cases')
g.add_legend(title='is_cancer')

# show x tick labels on every panel
for ax in g.axes.flat:
    ax.tick_params(labelbottom=True)

g.savefig(outdir / 'dataset_overview.pdf')

## Figure 4: multipartite graph

### Helper functions

In [None]:
def _describe_graph(graph):
    """Return a one-line summary (name, node count, edge count) of `graph`."""
    return f'{graph.name}: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges'


def assemble_network(df, min_snp_num=1, hub_threshold=0, verbose=False):
    """Build a disease-disease graph in which edge weights count shared SNPs.

    Parameters
    ----------
    df : DataFrame with the columns `diseaseId` and `snpId`; every row links
        a disease to a SNP.
    min_snp_num : only disease pairs sharing at least this many SNPs stay connected.
    hub_threshold : currently unused; the hub-disease filter it belonged to is
        disabled. Kept so that existing calls keep working.
    verbose : print a summary of every intermediate graph.

    Returns
    -------
    A view on the projected graph restricted to the edges whose `weight` is at
    least `min_snp_num` (and to the nodes these edges touch).
    """
    # NOTE: `nx.info` is not available in NetworkX >= 3.0, hence the local summary helper.

    # transform dataframe to networkx object
    graph = nx.from_pandas_edgelist(df, 'diseaseId', 'snpId')
    graph.name = 'bipartite graph'
    if verbose: print(_describe_graph(graph))

    # project bipartite graph onto the diseases: one parallel edge per shared SNP
    graph_proj_multi = nx.bipartite.projected_graph(graph, df['diseaseId'].unique().tolist(), multigraph=True)
    graph_proj_multi.name = 'projected multigraph'
    if verbose: print(_describe_graph(graph_proj_multi))

    # remove isolated nodes
    if verbose: print('Remove isolated nodes')
    graph_proj_multi.remove_nodes_from(list(nx.isolates(graph_proj_multi)))
    if verbose: print(_describe_graph(graph_proj_multi))

    # merge multiple edges in MultiGraph into single edge with weight attribute
    graph_proj = nx.Graph()
    graph_proj.name = 'projected graph'

    for u, v, data in graph_proj_multi.edges(data=True):
        w = data.get('weight', 1)
        assert w == 1, 'oopsie'

        if graph_proj.has_edge(u, v):
            graph_proj[u][v]['weight'] += w
        else:
            graph_proj.add_edge(u, v, weight=w)

    if verbose: print(_describe_graph(graph_proj))

    # subset graph based on SNP count threshold
    if verbose: print('Subset graph')
    graph_proj = graph_proj.edge_subgraph([(u, v) for u, v, data in graph_proj.edges(data=True) if data['weight'] >= min_snp_num])
    if verbose: print(_describe_graph(graph_proj))

    return graph_proj

In [None]:
def plot_network(graph):
    """Draw `graph` on the current axes, colouring nodes by cancer status.

    Node positions come from Graphviz `neato`. Nodes whose `is_cancer` entry in
    the global `df_data` is truthy are orange, all others blue. Edges are drawn
    in black with 20% opacity.
    """
    # graph layout
    pos = nx.drawing.nx_agraph.graphviz_layout(graph, prog='neato', args='-Goverlap=scale')

    # node colors; only the `is_cancer` column is converted to a dict
    iscancer_map = df_data.set_index('diseaseId')['is_cancer'].loc[list(graph.nodes())].to_dict()
    node_color_list = ['orange' if iscancer_map[n] else 'blue' for n in graph.nodes()]

    # do plot
    nx.draw_networkx_nodes(graph, pos, node_size=20, node_color=node_color_list)
    # Colour name plus `alpha` instead of an RGBA tuple: a 4-tuple of numbers is
    # interpreted as one numeric value per edge when the graph has exactly four edges.
    nx.draw_networkx_edges(graph, pos, edge_color='black', alpha=.2)

    plt.axis('off')

### Create network

In [None]:
df_data.head()

In [None]:
# rows whose '20in' column equals 'boundary'; diseases are linked when they share >= 2 such SNPs
is_boundary_row = df_data['20in'] == 'boundary'
sub_boundary = df_data[is_boundary_row]
graph_boundary = assemble_network(sub_boundary, min_snp_num=2)

In [None]:
# all remaining rows ('20in' is anything but 'boundary'); same threshold of >= 2 shared SNPs
is_nonboundary_row = df_data['20in'] != 'boundary'
sub_nonboundary = df_data[is_nonboundary_row]
graph_nonboundary = assemble_network(sub_nonboundary, min_snp_num=2)

### Plot networks

In [None]:
# both networks side by side; plot_network draws on the subplot that is current
plt.figure(figsize=(12, 6))

# left panel
plt.subplot(121)
plot_network(graph_nonboundary)
plt.title('nonboundary SNPs')

# right panel
plt.subplot(122)
plot_network(graph_boundary)
plt.title('boundary SNPs')

plt.tight_layout()
plt.savefig(outdir / 'networks.pdf')

### Edge statistics

In [None]:
def _pair_density(edge_count, pair_count):
    """Return `edge_count / pair_count`, or NaN when there is no possible pair."""
    return edge_count / pair_count if pair_count > 0 else np.nan


def compute_edge_asymmetries(graph):
    """Compare edge densities within and across the cancer / non-cancer node classes.

    Nodes are classified through the `is_cancer` column of the global `df_data`.
    For each class the result is

        (edges within the class) / (possible pairs within the class)
        - (edges across classes) / (possible cross-class pairs)

    Returns
    -------
    (cancer_asym, noncancer_asym). A value is NaN when one of its densities is
    undefined, i.e. the class has fewer than two nodes or one class is empty.
    """
    disease_nodes = list(graph.nodes())
    # only the `is_cancer` column is needed for the lookup
    iscancer_map = df_data.set_index('diseaseId')['is_cancer'].loc[disease_nodes].to_dict()

    # a Counter yields 0 for classes / edge types that do not occur
    node_counts = collections.Counter(iscancer_map[d] for d in disease_nodes)
    n_cancer = node_counts[True]
    n_noncancer = node_counts[False]

    edge_counts = collections.Counter((iscancer_map[d1], iscancer_map[d2]) for d1, d2 in graph.edges())
    cancer_pairs = edge_counts[(True, True)]
    noncancer_pairs = edge_counts[(False, False)]
    across = edge_counts[(True, False)] + edge_counts[(False, True)]

    across_density = _pair_density(across, n_cancer * n_noncancer)

    # n * (n - 1) / 2 possible pairs within a class, hence the factor 2
    cancer_asym = _pair_density(2 * cancer_pairs, n_cancer * (n_cancer - 1)) - across_density
    noncancer_asym = _pair_density(2 * noncancer_pairs, n_noncancer * (n_noncancer - 1)) - across_density

    return (cancer_asym, noncancer_asym)

In [None]:
# rows whose '20in' column equals 'boundary'
sub = df_data[df_data['20in'] == 'boundary']

# one record per shared-SNP threshold 1..9
tmp = []
for thres in trange(1, 10):
    graph = assemble_network(sub, thres, hub_threshold=0)
    asyms = compute_edge_asymmetries(graph)
    cancer_asym, noncancer_asym = asyms

    tmp.append(dict(
        threshold=thres,
        cancer_asymmetry=cancer_asym,
        noncancer_asymmetry=noncancer_asym,
        node_count=graph.number_of_nodes(),
        edge_count=graph.number_of_edges()
    ))

df_asym_boundary = pd.DataFrame(tmp)
df_asym_boundary.head()

In [None]:
# rows whose '20in' column is anything but 'boundary'
sub = df_data[df_data['20in'] != 'boundary']

# one record per shared-SNP threshold 1..9
tmp = []
for thres in trange(1, 10):
    graph = assemble_network(sub, thres, hub_threshold=0)
    asyms = compute_edge_asymmetries(graph)
    cancer_asym, noncancer_asym = asyms

    tmp.append(dict(
        threshold=thres,
        cancer_asymmetry=cancer_asym,
        noncancer_asymmetry=noncancer_asym,
        node_count=graph.number_of_nodes(),
        edge_count=graph.number_of_edges()
    ))

df_asym_nonboundary = pd.DataFrame(tmp)
df_asym_nonboundary.head()

In [None]:
# tag both tables with their SNP class (adds a 'type' column in place)
df_asym_boundary['type'] = 'boundary'
df_asym_nonboundary['type'] = 'nonboundary'

# long format: one row per (threshold, type, variable) with the value in 'value'
df_long = pd.melt(pd.concat([df_asym_boundary, df_asym_nonboundary]), id_vars=['threshold', 'type'])
df_long.head()

In [None]:
# asymmetry against threshold: colour = asymmetry kind, line style = boundary / nonboundary
plt.figure(figsize=(8, 6))

# hue_order selects the two asymmetry variables of df_long
sns.lineplot(x='threshold', y='value', hue='variable', style='type', data=df_long, hue_order=['noncancer_asymmetry', 'cancer_asymmetry'])

plt.tight_layout()
plt.savefig(outdir / 'edge_asymmetries.pdf')

In [None]:
# node and edge counts against threshold; this figure is only displayed, not saved
plt.figure(figsize=(8, 6))
sns.lineplot(x='threshold', y='value', hue='variable', style='type', data=df_long, hue_order=['node_count', 'edge_count'])

In [None]:
# compute: boundary asymmetry divided by nonboundary asymmetry, row by row
shared_index = df_asym_boundary.index
boundary_rows = df_asym_boundary.loc[shared_index]
nonboundary_rows = df_asym_nonboundary.loc[shared_index]

df_quotient = pd.DataFrame({
    'threshold': boundary_rows['threshold'],
    'noncancer_quotient': boundary_rows['noncancer_asymmetry'] / nonboundary_rows['noncancer_asymmetry'],
    'cancer_quotient': boundary_rows['cancer_asymmetry'] / nonboundary_rows['cancer_asymmetry']
})

# plot: the dashed red line at 1 marks equal asymmetry in both SNP classes
plt.figure(figsize=(8, 6))

df_quotient_long = pd.melt(df_quotient, id_vars='threshold')
sns.lineplot(x='threshold', y='value', hue='variable', data=df_quotient_long, hue_order=['noncancer_quotient', 'cancer_quotient'])
plt.axhline(1, color='red', ls='dashed')

plt.tight_layout()
plt.savefig(outdir / 'edge_asymmetry_quotients.pdf')

# Miscellaneous Figures

## SNP associations to Cancer and Non-Cancer

In [None]:
df_data.head()

In [None]:
%%time

# unique, complete (disease, SNP, is_cancer) combinations
tmp_sub = df_data.loc[:, ['diseaseId', 'snpId', 'is_cancer']].drop_duplicates().dropna()

# Control: the same labels in random order. The generator is seeded so that the
# shuffled counts (and the figure built from them) are identical on every run.
shuffle_rng = np.random.default_rng(0)
tmp_sub['is_cancer_shuffled'] = shuffle_rng.permutation(tmp_sub['is_cancer'].values)

def _count_cancer_labels(label_column):
    """Per SNP, count the rows of `tmp_sub` whose `label_column` is False / True.

    Returns a DataFrame indexed by snpId with the columns
    `noncancer_count` and `cancer_count`.
    """
    return (tmp_sub.groupby('snpId')[label_column]
                   .apply(lambda x: pd.Series(
                       [x.tolist().count(False), x.tolist().count(True)],
                       index=['noncancer_count', 'cancer_count']))
                   .unstack())

df_cancercounts = _count_cancer_labels('is_cancer')
df_cancercountsshuffled = _count_cancer_labels('is_cancer_shuffled')

In [None]:
df_cancercounts.head()

In [None]:
# per-SNP cancer vs. non-cancer association counts: observed labels next to the shuffled control
plt.figure(figsize=(2*8, 6))

# left panel: observed labels
plt.subplot(121)
sns.scatterplot(x='cancer_count', y='noncancer_count', data=df_cancercounts, alpha=.2)
plt.title('observed')

# right panel: shuffled labels
plt.subplot(122)
sns.scatterplot(x='cancer_count', y='noncancer_count', data=df_cancercountsshuffled, alpha=.2)
plt.title('shuffled')