In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
import pyranges as pr

from scipy import ndimage
from scipy.stats import fisher_exact

import statannot
import seaborn as sns
import matplotlib.pyplot as plt

import matplotlib.transforms as tx
from matplotlib.colors import SymLogNorm
from matplotlib.collections import LineCollection

import cooler

from natsort import natsorted
from tqdm.auto import tqdm

In [None]:
# Notebook-wide display settings: seaborn's 'talk' context for plots and
# no limit on the number of DataFrame columns shown.
sns.set_context(context='talk')
pd.options.display.max_columns = None

# Parameters

In [None]:
# Input tables provided by Snakemake (read with pandas below).
fname_data = snakemake.input.fname_data
fname_enr = snakemake.input.fname_enr

# Inputs for the overview sketch: Hi-C file (opened with cooler) and TAD table (CSV).
sketch_hicfile = snakemake.input.sketch_hicfile
sketch_tadfile = snakemake.input.sketch_tadfile

# Region shown in the sketch. It is indexed as [0], [1], [2] further down,
# presumably (chromosome, start, end) -- verify against the Snakemake config.
sketch_region = snakemake.config['sketch']['region']

# Directory into which all figures are saved.
outdir = Path(snakemake.output.outdir)

# Load data

## Database data

In [None]:
# Read the database table (SNP/disease records used throughout) and preview it.
df_data = pd.read_csv(fname_data, low_memory=True)
df_data.head()

## Enrichment data

In [None]:
# Read the enrichment table (one `pval_boundary` per row, used for Figures 2 and 3).
df_enr = pd.read_csv(fname_enr, low_memory=True)

In [None]:
# Value used in place of -log10(p) when p is exactly 0 (the logarithm would be infinite).
max_enrichment = 16

# Vectorised -log10 transform of the boundary p-values. `errstate` only silences
# the divide-by-zero warning for p == 0, whose result is replaced by
# `max_enrichment` anyway. The result is always a float column.
with np.errstate(divide='ignore'):
    df_enr['pval_boundary_trans'] = np.where(
        df_enr['pval_boundary'] == 0,
        max_enrichment,
        -np.log10(df_enr['pval_boundary']),
    )

In [None]:
# Preview the enrichment table.
df_enr.head()

# Create figures

## Figure 1: Overview Sketch

Interesting regions:
* chr6:25960228-31407777 (hg19), chr6:25960000-31440000 (hg38)

### Read contact matrix

In [None]:
# Open the Hi-C file with cooler.
c = cooler.Cooler(sketch_hicfile)

# Unbalanced (balance=False) contact matrix and bin table, both restricted to the sketch region.
mat = c.matrix(balance=False).fetch(sketch_region)
df_bins = c.bins().fetch(sketch_region)

# Label rows and columns of the matrix with the start coordinate of each bin.
df_mat = pd.DataFrame(mat, index=df_bins['start'], columns=df_bins['start'])

In [None]:
# Preview the contact matrix.
df_mat.head()

### Read TADs

In [None]:
# Load the TAD table and rename its coordinate columns to
# Chromosome/Start/End before wrapping it in a PyRanges object.
df_tads = pr.PyRanges(
    pd.read_csv(sketch_tadfile).rename(
        columns=dict(chrname='Chromosome', tad_start='Start', tad_stop='End')
    )
)
df_tads.head()

In [None]:
# Single-interval PyRanges describing the sketch region.
# NOTE(review): assumes sketch_region is a (chromosome, start, end) sequence;
# a 'chr:start-end' string would be indexed character by character -- confirm.
df_region = pr.PyRanges(pd.DataFrame({
    'Chromosome': [sketch_region[0]],
    'Start': [sketch_region[1]],
    'End': [sketch_region[2]]
}))
df_region

In [None]:
# Keep the TADs that overlap the sketch region.
tad_list = df_tads.overlap(df_region)
# Disabled alternative: keep only TADs lying completely inside the region.
# tad_list = tad_list[(tad_list.Start >= sketch_region[1]) & (tad_list.End <= sketch_region[2])]
tad_list

#### Find TAD boundaries

In [None]:
# Signed boundary width in bp. The value is negative, so the arithmetic in
# `get_boundaries` places both boundary intervals inside the TAD.
boundary_size = -20_000

def get_boundaries(df):
    """Return the front and back boundary interval of every row in `df`.

    `df` needs `Start` and `End` columns. The result stacks two copies of
    `df`: rows tagged ``type == 'front'`` span ``[Start, Start - boundary_size]``
    and rows tagged ``type == 'back'`` span ``[End + boundary_size, End]``.
    The input frame is not modified.
    """
    front = df.assign(End=df.Start - boundary_size, type='front')
    back = df.assign(Start=df.End + boundary_size, type='back')
    return pd.concat([front, back])

In [None]:
# Derive the front/back boundary intervals of the selected TADs.
border_list = tad_list.apply(get_boundaries)
border_list

### Read SNPs

In [None]:
# Unique hg19 SNP positions as 1-bp intervals (End = Start + 1), with the
# chromosome name prefixed by 'chr'.
tmp = (
    df_data[['chromosome_hg19', 'position_hg19', 'snpId']]
    .drop_duplicates()
    .dropna()
    .rename(columns={'chromosome_hg19': 'Chromosome', 'position_hg19': 'Start'})
    .assign(
        Chromosome=lambda frame: 'chr' + frame['Chromosome'].astype(str),
        End=lambda frame: frame['Start'] + 1,
    )
)

df_snps = pr.PyRanges(tmp)
df_snps

In [None]:
# Keep only the SNPs that overlap one of the TAD boundary intervals.
snp_list = df_snps.overlap(border_list)
snp_list

In [None]:
# classify SNPs
def classify(x):
    """Return the most frequent `is_cancer` value among the rows of `x`."""
    label_counts = x['is_cancer'].value_counts()
    return label_counts.idxmax()

# Map every SNP to the `is_cancer` value that is most frequent among its
# distinct (disease, SNP) records, then show two entries as a sanity check.
snp_cancer_map = (
    df_data[['diseaseId', 'snpId', 'is_cancer']]
    .drop_duplicates()
    .dropna()
    .groupby('snpId')
    .apply(classify)
    .to_dict()
)
list(snp_cancer_map.items())[:2]

### Create plot

In [None]:
# Overview sketch: rotated contact matrix with TADs, TAD boundaries and
# boundary SNPs drawn on top, saved as 'sketch.pdf'.

# rotate contact matrix
# (45 degrees, so the matrix diagonal becomes horizontal; reshape=True enlarges the output to hold the rotated array)
mat_rot = ndimage.rotate(df_mat, 45, order=0, reshape=True, cval=0, prefilter=False)

# create figure
s = 2  # scale factor applied to the base figure size of 8x6
plt.figure(figsize=(s*8, s*6))

# plot rotated contact matrix
# The extent maps the image onto the genomic coordinates of the bins on both axes;
# fignum=0 is expected to draw into the figure created above rather than a new one.
plt.matshow(
    mat_rot,
    norm=SymLogNorm(1),
    cmap='YlOrRd',
    origin='lower',
    extent=(
        df_mat.index[0] + .5, df_mat.index[-1] + .5,
        df_mat.index[0] + .5, df_mat.index[-1] + .5
    ),
    fignum=0,
    aspect='equal'
)

# adjust axes
# center_height is the vertical midpoint of the image; only the part above it is shown.
center_height = (df_mat.index[0] + df_mat.index[-1]) / 2 + .5
plt.ylim(center_height, df_mat.index[-1] + .5)

plt.xlabel(sketch_region[0])
plt.xlim(sketch_region[1], sketch_region[2])

# Only the bottom (genomic position) axis keeps ticks and labels.
plt.tick_params(
    axis='both',
    reset=True,
    which='both',
    top=False, right=False,
    left=False, labelleft=False,
    bottom=True, labelbottom=True,
)

plt.ticklabel_format(axis='both', style='plain')

# highlight TADs
# Each TAD is outlined as a triangle standing on the base line (center_height).
for row in tad_list.df.itertuples():
    # sin(90 deg) is 1, so the apex height is half the TAD length.
    tmp = np.sin(np.deg2rad(90)) * (row.End - row.Start) / 2

    pg = plt.Polygon([
        [row.Start, center_height], 
        [(row.Start + row.End) / 2, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='none')
    plt.gca().add_patch(pg)

# highlight TAD boundaries
# Each boundary is a filled right triangle; its vertical edge is on the
# End side for 'front' boundaries and on the Start side otherwise.
for row in border_list.df.itertuples():
    # tan(45 deg) is 1 (up to rounding), so the height equals the boundary length.
    tmp = np.tan(np.deg2rad(45)) * (row.End - row.Start)

    pg = plt.Polygon([
        [row.Start, center_height], 
        [row.End if row.type == 'front' else row.Start, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='gray', alpha=.6)
    plt.gca().add_patch(pg)
    
# highlight SNPs
# Boundary SNPs are drawn as short ticks at the top of the axes,
# orange when snp_cancer_map marks them as cancer and blue otherwise.
# NOTE(review): snp_cancer_map[s] raises KeyError for a SNP missing from the map;
# the map and the SNP table are built with dropna() over different columns -- confirm every SNP is covered.
snp_pos = snp_list.df['Start']
snp_colors = ['orange' if snp_cancer_map[s] else 'blue' for s in snp_list.df['snpId']]

ax = plt.gca()
# x in data coordinates, y in axes coordinates: ticks span 98%-100% of the axes height.
trans = tx.blended_transform_factory(ax.transData, ax.transAxes)
xy_pairs = np.column_stack([np.repeat(snp_pos, 2), np.tile([.98, 1], len(snp_pos))])

# One segment per SNP: two (x, y) end points each.
line_segs = xy_pairs.reshape([len(snp_pos), 2, 2])
ax.add_collection(LineCollection(line_segs, transform=trans, colors=snp_colors))
    
# save figure
plt.tight_layout()
plt.savefig(outdir / 'sketch.pdf')

## Figure 2: p-value Histogram for nice case

In [None]:
def plot_histogram(df, fname, bins=np.linspace(0, 3, 50)):
    """Plot normalised histograms of `pval_boundary_trans`, one per `is_cancer` group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns `is_cancer` and `pval_boundary_trans`.
    fname : str
        File name of the saved figure, relative to `outdir`.
    bins : array-like or int
        Bin specification forwarded to seaborn. When bin edges are given,
        values above the last edge are clipped to it so that they are
        collected in the last bin; for an integer bin count the previous
        fixed cap of 3 is used.
    """
    # Derive the cap from the bins instead of hard-coding 3, so that custom
    # bin edges and the clipping stay consistent (identical for the default).
    upper = np.max(bins) if np.ndim(bins) > 0 else 3

    plt.figure(figsize=(8, 6))

    for is_cancer, group in df.groupby('is_cancer'):
        sns.distplot(
            group['pval_boundary_trans'].clip(upper=upper),
            kde=False, norm_hist=True, bins=bins,
            label=is_cancer
        )

    plt.xlabel(r'$-log_{10}(\mathrm{pvalue})$')
    plt.ylabel('Frequency')

    plt.ylim(0, 1)

    # Mark the conventional significance threshold p = 0.05.
    plt.axvline(-np.log10(.05), ls='dashed', color='red')
    plt.legend(loc='best', title='is_cancer')

    plt.tight_layout()
    plt.savefig(outdir / fname)

### Subfigure 2a: nofilter

In [None]:
# Enrichment rows for the 'nofilter' SNP set and '20in' TAD boundaries,
# restricted to entries with more than 10 boundary SNPs.
sub_nofilter = df_enr.loc[
    df_enr['filter'].eq('nofilter')
    & df_enr['TAD_type'].eq('20in')
    & df_enr['#border_snp'].gt(10)
]

In [None]:
# Subfigure 2a: histogram for the unfiltered SNP set.
plot_histogram(df=sub_nofilter, fname='pvalue_histogram_nofilter.pdf')

### Subfigure 2b: intergenic filter

In [None]:
# Enrichment rows for the 'intergenic' SNP set and '20in' TAD boundaries,
# restricted to entries with more than 10 boundary SNPs.
sub_intergenic = df_enr.loc[
    df_enr['filter'].eq('intergenic')
    & df_enr['TAD_type'].eq('20in')
    & df_enr['#border_snp'].gt(10)
]

In [None]:
# Subfigure 2b: histogram for the intergenic SNP set.
plot_histogram(df=sub_intergenic, fname='pvalue_histogram_intergenic.pdf')

## Figure 3: Zoomed-out multi-dataset overview

In [None]:
# Reminder of the enrichment table layout before aggregating it.
df_enr.head(2)

In [None]:
%%time

# Per disease (within each TAD_type / filter / tad_source / is_cancer
# combination): is the boundary p-value <= 0.05 in more than half of its rows?
df_detailed = (
    df_enr
    .groupby(['TAD_type', 'filter', 'tad_source', 'is_cancer', 'diseaseId'])
    .apply(lambda grp: grp['pval_boundary'].le(.05).mean() > .5)
    .to_frame('majority_is_sig')
    .reset_index()
)

# Fraction of diseases flagged above, per combination (diseases collapsed).
df_majority = (
    df_detailed
    .groupby(['TAD_type', 'filter', 'tad_source', 'is_cancer'])['majority_is_sig']
    .mean()
    .to_frame('sig_frac')
    .reset_index()
)
df_majority.head()

In [None]:
def select_one(df, col):
    """Return the single distinct value of column `col` in `df`.

    Raises AssertionError unless the column holds exactly one distinct value.
    """
    column = df[col]
    assert column.nunique() == 1
    return column.iloc[0]

def compute_pvalue(x, y):
    """Return the Fisher exact test p-value for the association of `x` and `y`.

    `x` and `y` are equally long sequences of (binary) labels that are
    cross-tabulated into a contingency table.

    If either variable takes fewer than two distinct values, the table has a
    missing row or column and no association can be tested; the p-value is
    then 1.0, which is what the test yields for a 2x2 table with an all-zero
    row or column.
    """
    ct = pd.crosstab(x, y)

    # Degenerate table (e.g. no disease is significant in this facet).
    if ct.shape[0] < 2 or ct.shape[1] < 2:
        return 1.0

    return fisher_exact(ct)[1]

def custom_barplot(*args, **kwargs):
    """Draw a grouped bar plot and annotate each x group with a p-value.

    Written as a callback for `FacetGrid.map_dataframe`: `kwargs` must hold
    `data`, `x`, `y` and `hue`; everything is forwarded to `sns.barplot`.
    For every x category the two hue bars (False vs. True) are annotated
    with the p-value computed by `compute_pvalue` from the matching rows of
    the module-level `df_detailed`. `data` must contain exactly one
    `TAD_type` and one `tad_source`.
    """
    data, x, y, hue = (kwargs[key] for key in ('data', 'x', 'y', 'hue'))

    order = data[x].unique()
    hue_order = data[hue].unique()

    # standard barplot
    ax = sns.barplot(*args, **kwargs, order=order, hue_order=hue_order)

    def facet_pvalue(filter_):
        """p-value of is_cancer vs. majority_is_sig for one filter of this facet."""
        tad_type = select_one(data, 'TAD_type')
        tad_source = select_one(data, 'tad_source')

        in_facet = (
            (df_detailed['TAD_type'] == tad_type)
            & (df_detailed['tad_source'] == tad_source)
            & (df_detailed['filter'] == filter_)
        )
        sub_detailed = df_detailed[in_facet]

        return compute_pvalue(sub_detailed['is_cancer'], sub_detailed['majority_is_sig'])

    # one annotation per x category, comparing its two hue bars
    box_pairs = [((filter_, False), (filter_, True)) for filter_ in order]
    pvalues = [facet_pvalue(filter_) for filter_ in order]

    # p-values are supplied explicitly, so statannot performs no test itself
    statannot.add_stat_annotation(
        ax, plot='barplot',
        data=data, x=x, y=y, hue=hue,
        order=order, hue_order=hue_order,
        box_pairs=box_pairs,
        text_format='simple',
        pvalues=pvalues,
        perform_stat_test=False
    )

In [None]:
# One facet per (TAD_type, tad_source) combination.
g = sns.FacetGrid(df_majority, row='TAD_type', col='tad_source', height=5, aspect=2)

# Annotated bar plot in every facet: significant fraction per filter, split by is_cancer.
g.map_dataframe(custom_barplot, x='filter', y='sig_frac', hue='is_cancer', palette='tab10')

g.set_axis_labels('Filter', 'Disease fraction sig. in $>0.5$ cases')
g.add_legend(title='is_cancer')

# Show the x tick labels on every facet.
for ax in g.axes.flat:
    ax.tick_params(labelbottom=True)

g.savefig(outdir / 'dataset_overview.pdf')

## Figure 4: Comparison between 20in and 40in borders

In [None]:
# TODO: create Figure 4 (comparison between 20in and 40in borders)

## Figure 5: multipartite graph

In [None]:
# Keep the records whose '20in' column is 'boundary'.
sub = df_data[df_data['20in'] == 'boundary']

In [None]:
# Graph with one edge per (diseaseId, snpId) row of `sub`.
graph = nx.from_pandas_edgelist(sub, 'diseaseId', 'snpId')
# `nx.info` is not available in NetworkX >= 3.0; print a short summary
# built from methods that exist in every version.
print(f'{type(graph).__name__} with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges')

In [None]:
# Project the disease-SNP graph onto the disease nodes.
graph_proj = nx.bipartite.projected_graph(graph, sub['diseaseId'].unique().tolist())
# `nx.info` is not available in NetworkX >= 3.0; print a short summary
# built from methods that exist in every version.
print(f'{type(graph_proj).__name__} with {graph_proj.number_of_nodes()} nodes and {graph_proj.number_of_edges()} edges')

In [None]:
# Node positions computed by Graphviz 'neato' (requires pygraphviz); '-Goverlap=scale' is passed through to Graphviz.
pos = nx.drawing.nx_agraph.graphviz_layout(graph_proj, prog='neato', args='-Goverlap=scale')

In [None]:
# Map each disease to its `is_cancer` value. Only the needed column is
# converted (instead of turning every column of `sub` into a dict first);
# for diseases with several rows the last row wins, as before.
iscancer_map = sub.set_index('diseaseId')['is_cancer'].to_dict()
# Node colours in node order: orange where `is_cancer` is truthy, blue otherwise.
node_color_list = ['orange' if iscancer_map[n] else 'blue' for n in graph_proj.nodes()]

In [None]:
# Draw the projected disease graph (nodes coloured by is_cancer) and save it.
plt.figure(figsize=(2*8, 2*6))

nx.draw_networkx_nodes(graph_proj, pos, node_size=20, node_color=node_color_list)
# Edges are drawn semi-transparent.
nx.draw_networkx_edges(graph_proj, pos, alpha=.2)

plt.axis('off')

plt.tight_layout()
plt.savefig(outdir / 'network.pdf')