In [None]:
import collections
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
import pyranges as pr

from scipy import ndimage
from scipy.stats import fisher_exact

import statannot
import seaborn as sns
import matplotlib.pyplot as plt
from adjustText import adjust_text

import matplotlib.transforms as tx
from matplotlib.colors import SymLogNorm
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ConnectionPatch
from matplotlib.collections import LineCollection

from dna_features_viewer import GraphicFeature, GraphicRecord

import cooler

from natsort import natsorted
from tqdm.auto import tqdm, trange

In [None]:
# Notebook-wide presentation defaults: larger seaborn fonts and no column truncation in frame previews.
sns.set_context(context='talk')
pd.options.display.max_columns = None

# Parameters

In [None]:
# `snakemake` is supplied by the Snakemake script directive (it is not imported here).
sm_input = snakemake.input

fname_data = sm_input.fname_data
sketch_hicfile = sm_input.sketch_hicfile
sketch_tadfile = sm_input.sketch_tadfile
tad_fname_list = sm_input.tad_fname_list

# region to draw: indexed below as [chromosome, start, end]
sketch_region = snakemake.config['sketch']['region']

outdir = Path(snakemake.output.outdir)

# Read data

## Read general data

In [None]:
# low_memory=False infers each column's dtype from the whole file. The default chunked parsing
# (low_memory=True) can give a column mixed types, e.g. int and str chromosome values.
df_data = pd.read_csv(fname_data, low_memory=False)
df_data.head()

In [None]:
# classify SNPs
def classify(x):
    """Return the most frequent ``is_cancer`` value in the frame ``x`` (one SNP's rows)."""
    label_counts = x['is_cancer'].value_counts()
    return label_counts.idxmax()

# Majority is_cancer label per SNP. drop_duplicates over (diseaseId, snpId, is_cancer) makes each
# distinct disease row count once in the vote. Selecting [['is_cancer']] keeps the grouping
# column out of the frames passed to `classify`. That avoids the pandas >= 2.2 deprecation of
# DataFrameGroupBy.apply operating on grouping columns.
snp_cancer_map = (
    df_data[['diseaseId', 'snpId', 'is_cancer']]
    .drop_duplicates()
    .dropna()
    .groupby('snpId')[['is_cancer']]
    .apply(classify)
    .to_dict()
)
list(snp_cancer_map.items())[:2]

## Read contact matrix

In [None]:
# Raw (unbalanced) contact counts for the sketch region, labelled on both axes by bin start.
hic_cooler = cooler.Cooler(sketch_hicfile)
region_bins = hic_cooler.bins().fetch(sketch_region)
region_counts = hic_cooler.matrix(balance=False).fetch(sketch_region)

bin_starts = region_bins['start']
df_mat = pd.DataFrame(region_counts, index=bin_starts, columns=bin_starts)

In [None]:
df_mat.head(5)  # preview the first rows of the region contact matrix

## Read TAD data

In [None]:
# TADs of the sketch sample, with columns renamed to the PyRanges naming convention.
tad_column_map = {'chrname': 'Chromosome', 'tad_start': 'Start', 'tad_stop': 'End'}
raw_tads = pd.read_csv(sketch_tadfile)
df_tads = pr.PyRanges(raw_tads.rename(columns=tad_column_map))
df_tads.head()

## Read SNPs

In [None]:
# One row per distinct SNP with its hg19 position, stored as a 1-bp PyRanges interval.
snp_coords = (
    df_data[['chromosome_hg19', 'position_hg19', 'snpId']]
    .drop_duplicates()
    .dropna()
    .rename(columns={'chromosome_hg19': 'Chromosome', 'position_hg19': 'Start'})
)
# Columns that contained NaN are read as floats. Without normalisation that gives names like
# 'chr1.0', which never match the TAD chromosomes, and non-integer coordinates.
snp_coords['Chromosome'] = 'chr' + snp_coords['Chromosome'].astype(str).str.replace(r'\.0$', '', regex=True)
snp_coords['Start'] = snp_coords['Start'].astype(np.int64)
snp_coords['End'] = snp_coords['Start'] + 1

df_snps = pr.PyRanges(snp_coords)
df_snps

# Preprocessing

In [None]:
# The sketch region as a single-interval PyRanges; it is the target of the overlap queries.
region_columns = {
    'Chromosome': [sketch_region[0]],
    'Start': [sketch_region[1]],
    'End': [sketch_region[2]],
}
df_region = pr.PyRanges(pd.DataFrame(region_columns))
df_region

## TAD subsetting

In [None]:
# Keep every TAD that overlaps the region, including TADs that extend past its edges.
tad_list = df_tads.overlap(df_region)
tad_list

## Border computation

In [None]:
boundary_size = -20000  # border width in bp, stored negated: front borders end at Start - boundary_size, back borders start at End + boundary_size

In [None]:
def get_boundaries(df, size=None):
    """Build the front and back border intervals of every TAD in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        TAD intervals with ``Start`` and ``End`` columns (when called through
        ``PyRanges.apply``, one per-chromosome frame at a time).
    size : int, optional
        Border width in bp. Only its magnitude is used, so the notebook's
        negative convention and a positive width give the same result.
        Defaults to the notebook-level ``boundary_size``.

    Returns
    -------
    pandas.DataFrame
        The front borders (``Start`` .. ``Start + width``) followed by the back
        borders (``End - width`` .. ``End``), tagged in a new ``type`` column
        ('front' / 'back'). ``df`` itself is not modified.
    """
    if size is None:
        size = boundary_size
    width = abs(size)

    front = df.copy()
    front['End'] = front['Start'] + width
    front['type'] = 'front'

    back = df.copy()
    back['Start'] = back['End'] - width
    back['type'] = 'back'

    return pd.concat([front, back])

In [None]:
# Front/back border intervals of every TAD that overlaps the region (PyRanges calls get_boundaries per frame).
border_list = tad_list.apply(get_boundaries)
border_list

## SNP subsetting

In [None]:
# SNPs whose position falls inside any TAD border interval
snp_list = df_snps.overlap(border_list)
snp_list

In [None]:
snp_list = snp_list.overlap(df_region)  # borders of TADs that straddle the region edge can lie outside it; keep in-region SNPs only

# Plot

## Generate features

In [None]:
# One list of GraphicFeature objects per TAD source. A TAD too short to hold two borders is
# drawn as a single yellow feature. A longer TAD is a blue body flanked by two red borders of
# width |boundary_size|.
features_tads = collections.defaultdict(list)

# Use the magnitude so the coordinates below match the length check and do not depend on
# boundary_size being negative.
border_width = abs(boundary_size)

for fname in tqdm(tad_fname_list):
    # File names must have exactly four dot-separated fields; the second names the track.
    # pathlib is used because `os` is not imported in this notebook.
    _, tad_source, window_size, _ = Path(fname).name.split('.')
    # NOTE(review): window_size is unused. Files that share a tad_source but differ in window
    # size therefore end up in the same track; confirm this is intended.
    name = tad_source

    df_tads_current = pr.PyRanges(pd.read_csv(fname).rename(columns={'chrname': 'Chromosome', 'tad_start': 'Start', 'tad_stop': 'End'}))
    tad_list_current = df_tads_current.overlap(df_region)

    for row in tad_list_current.df.itertuples():
        tad_len = row.End - row.Start

        if tad_len < 2 * border_width:
            features_tads[name].append(GraphicFeature(
                start=row.Start, end=row.End,
                color='yellow'))
        else:
            # body
            features_tads[name].append(GraphicFeature(
                start=row.Start + border_width, end=row.End - border_width,
                color='blue'))

            # borders
            features_tads[name].append(GraphicFeature(
                start=row.Start, end=row.Start + border_width,
                color='red'))
            features_tads[name].append(GraphicFeature(
                start=row.End - border_width, end=row.End,
                color='red'))

features_tads = dict(features_tads)

## Main figure

In [None]:
# Layout, top to bottom, with a shared x (genomic) axis:
# heatmap row (relative height 10), one row per TAD track (height 1 each), SNP row (height 5).
fig, ax_list = plt.subplots(
    nrows=1 + len(features_tads) + 1, ncols=1,
    gridspec_kw={'height_ratios': [10] + [1] * len(features_tads) + [5]},
    sharex=True,
    figsize=(15, 25))

# heatmap
# Rotate the contact matrix by 45 degrees so its diagonal lies horizontally.
# Settings: nearest-neighbour (order=0), canvas enlarged to fit (reshape=True), empty corners
# filled with 0.
mat_rot = ndimage.rotate(df_mat, 45, order=0, reshape=True, cval=0, prefilter=False)

ax = ax_list[0]
ax.matshow(
    mat_rot,
    norm=SymLogNorm(1),  # symmetric-log colour scale with linear threshold 1
    cmap='YlOrRd',
    origin='lower',
    # map the rotated image onto genomic coordinates (df_mat is indexed by bin start) on both axes
    extent=(
        df_mat.index[0] + .5, df_mat.index[-1] + .5,
        df_mat.index[0] + .5, df_mat.index[-1] + .5
    ),
    aspect='auto'
)

# Vertical midpoint of the extent = the rotated diagonal; show only the half above it.
center_height = (df_mat.index[0] + df_mat.index[-1]) / 2 + .5
ax.set_ylim(center_height, df_mat.index[-1] + .5)

ax.set_xlabel(sketch_region[0])
ax.xaxis.set_label_position('top') 
ax.set_xlim(sketch_region[1], sketch_region[2])

# Reset tick settings, then keep ticks and labels on the top edge only.
ax.tick_params(
    axis='both',
    reset=True,
    which='both',
    top=True, labeltop=True,
    right=False, labelright=False,
    left=False, labelleft=False,
    bottom=False, labelbottom=False)

ax.ticklabel_format(axis='both', style='plain')

# Outline each TAD as a triangle on the diagonal: base Start..End, apex above the midpoint at
# height (End - Start) / 2 (sin(90 deg) == 1, so the factor is unity).
for row in tad_list.df.itertuples():
    tmp = np.sin(np.deg2rad(90)) * (row.End - row.Start) / 2

    pg = plt.Polygon([
        [row.Start, center_height], 
        [(row.Start + row.End) / 2, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='none')
    ax.add_patch(pg)

# Shade each border as a right triangle whose height equals its width (tan(45 deg) ~= 1).
# The vertical side sits at End for 'front' borders and at Start for 'back' borders, so the
# sloped side follows the corresponding edge of the TAD triangle.
for row in border_list.df.itertuples():
    tmp = np.tan(np.deg2rad(45)) * (row.End - row.Start)

    pg = plt.Polygon([
        [row.Start, center_height], 
        [row.End if row.type == 'front' else row.Start, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='gray', alpha=.6)
    ax.add_patch(pg)

# TAD plots: one feature track per TAD source, in alphabetical order, in the rows between the
# heatmap and the SNP row.
track_axes = ax_list[1:-1]
for name, ax in zip(sorted(features_tads), track_axes):
    # Sequence length is padded 1 Mb past the region end (presumably so features reaching past the
    # region still fit in the record); the record is then cropped back to the region.
    cropped_record = GraphicRecord(
        sequence_length=sketch_region[2] + 1_000_000,
        features=features_tads[name],
    ).crop(sketch_region[1:])
    cropped_record.plot(ax=ax, with_ruler=False)

    ax.axis('off')
    # track label in the top-left corner, positioned in axes-fraction coordinates
    ax.text(0, 1, name, horizontalalignment='left', verticalalignment='top', fontsize=10, transform=ax.transAxes)
    
# SNP plot
ax = ax_list[-1]

# Line style for SNPs not classified as cancer-associated; cancer-associated SNPs are solid.
dash_style = 'dashed'

# Connector line plus text label for every border SNP inside the region.
annotation_list = []
for row in snp_list.df.itertuples():
    id_ = row.snpId
    pos = row.Start

    # SNPs missing from snp_cancer_map have no complete (disease, label) row in the input.
    # Draw them like non-cancer SNPs instead of raising KeyError.
    linestyle = 'solid' if snp_cancer_map.get(id_, False) else dash_style

    # Line from the bottom edge of the heatmap (y = 0, axes fraction) to the SNP row (y = 0.8).
    # x is in genomic data coordinates at both ends.
    con = ConnectionPatch(
        xyA=(pos, 0), coordsA=tx.blended_transform_factory(ax_list[0].transData, ax_list[0].transAxes),
        xyB=(pos, 0.8), coordsB=tx.blended_transform_factory(ax.transData, ax.transAxes),
        linewidth=.5,
        linestyle=linestyle)
    fig.add_artist(con)

    # SNP label at y = 0.5 with a leader line up to the connector end at y = 0.8.
    a = ax.annotate(
        id_,
        xy=(pos, .8), xytext=(pos, .5),
        xycoords=('data', 'axes fraction'), textcoords=('data', 'axes fraction'),
        arrowprops=dict(arrowstyle='-', linewidth=.5, linestyle=linestyle),
        annotation_clip=False,
        fontsize=13)
    annotation_list.append(a)

# Let adjustText spread overlapping labels; skip it when there are no labels to place.
if annotation_list:
    adjust_text(annotation_list, ax=ax)

ax.axis('off')

# save figure (create the output directory first in case it does not exist yet)
plt.tight_layout()
outdir.mkdir(parents=True, exist_ok=True)
plt.savefig(outdir / 'supp.pdf')