In [None]:
import collections
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
import pyranges as pr

from scipy import ndimage, stats
from scipy.stats import fisher_exact

import statannot
from statannot.statannot import simple_text, pval_annotation_text

import seaborn as sns
import matplotlib.pyplot as plt
from adjustText import adjust_text

import matplotlib.transforms as tx
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.colors import SymLogNorm
from matplotlib.gridspec import GridSpec
from matplotlib.collections import LineCollection

import cooler

from natsort import natsorted
from tqdm.auto import tqdm, trange

from bioinf_common.tools import multipletests_nan
from bioinf_common.plotting import add_identity

In [None]:
sns.set_context('talk')
pd.set_option('display.max_columns', None)

# Parameters

In [None]:
fname_data = snakemake.input.fname_data
fname_enr = snakemake.input.fname_enr

sketch_hicfile = snakemake.input.sketch_hicfile
sketch_tadfile = snakemake.input.sketch_tadfile

sketch_region = snakemake.config['sketch']['region']

outdir = Path(snakemake.output.outdir)

In [None]:
outdir.mkdir(parents=True, exist_ok=True)

# Load data

## Database data

In [None]:
df_data = pd.read_csv(fname_data, low_memory=True)
df_data.head()

In [None]:
# Lookup table diseaseId -> is_cancer, taken from the first row seen for each disease.
# Only the `is_cancer` column is converted to a dict; converting the whole frame
# (every column of every disease) just to read one column back out is wasted work.
iscancer_map = (
    df_data
    .drop_duplicates(subset=['diseaseId'])
    .set_index('diseaseId')['is_cancer']
    .to_dict()
)

## Enrichment data

In [None]:
df_enr = pd.read_csv(fname_enr, low_memory=True)

In [None]:
# Value used instead of -log10(0) (= inf) for p-values that are exactly zero.
max_enrichment = 16

# Add a -log10 transformed copy of the corrected and the uncorrected border p-value.
for source_column, target_column in (
    ('pval_border', 'pval_border_trans'),
    ('pval_border__notcorrected', 'pval_border__notcorrected_trans'),
):
    df_enr[target_column] = df_enr[source_column].apply(
        lambda pval: max_enrichment if pval == 0 else -np.log10(pval)
    )

In [None]:
df_enr.head()

# Publication Figures

## Figure 1: Overview Sketch

### Read contact matrix

In [None]:
c = cooler.Cooler(sketch_hicfile)

mat = c.matrix(balance=False).fetch(sketch_region)
df_bins = c.bins().fetch(sketch_region)

df_mat = pd.DataFrame(mat, index=df_bins['start'], columns=df_bins['start'])

In [None]:
df_mat.head()

### Read TADs

In [None]:
df_tads = pr.PyRanges(pd.read_csv(sketch_tadfile).rename(columns={'chrname': 'Chromosome', 'tad_start': 'Start', 'tad_stop': 'End'}))
df_tads.head()

In [None]:
df_region = pr.PyRanges(pd.DataFrame({
    'Chromosome': [sketch_region[0]],
    'Start': [sketch_region[1]],
    'End': [sketch_region[2]]
}))
df_region

In [None]:
tad_list = df_tads.overlap(df_region)
# tad_list = tad_list[(tad_list.Start >= sketch_region[1]) & (tad_list.End <= sketch_region[2])]
tad_list

#### Find TAD boundaries

In [None]:
# Signed border width in bp. With a negative value the borders computed below
# lie inside the interval: [Start, Start + 20kb] and [End - 20kb, End].
border_size = -20_000


def get_boundaries(df):
    """Return the front and back border interval of every row in `df`.

    `df` needs `Start` and `End` columns; all other columns are carried over.
    The result stacks two modified copies of `df` (the input is not changed):

    * rows with type 'front': `End` is replaced by `Start - border_size`
    * rows with type 'back': `Start` is replaced by `End + border_size`

    The original row index is kept, so every index label appears twice.
    """
    front = df.assign(End=df.Start - border_size, type='front')
    back = df.assign(Start=df.End + border_size, type='back')

    return pd.concat([front, back])

In [None]:
border_list = tad_list.apply(get_boundaries).overlap(df_region)
border_list

### Read SNPs

In [None]:
# One row per distinct (chromosome, position, snpId) with hg19 coordinates,
# reshaped into the Chromosome/Start/End layout that PyRanges expects.
# NOTE(review): if `chromosome_hg19` was parsed as float, `astype(str)` yields
# names such as 'chr1.0' -- confirm the column dtype.
snp_positions = (
    df_data[['chromosome_hg19', 'position_hg19', 'snpId']]
    .drop_duplicates()
    .dropna()
    .rename(columns={'chromosome_hg19': 'Chromosome', 'position_hg19': 'Start'})
    .assign(
        Chromosome=lambda d: 'chr' + d['Chromosome'].astype(str),
        # every SNP becomes an interval of length 1
        End=lambda d: d['Start'] + 1,
    )
)

df_snps = pr.PyRanges(snp_positions)
df_snps

In [None]:
# find SNPs in TAD borders
snp_list = df_snps.overlap(border_list)
snp_list

In [None]:
# classify SNPs
def classify(x):
    """Return the `is_cancer` value that occurs most often in the rows of `x`."""
    label_counts = x['is_cancer'].value_counts()
    return label_counts.idxmax()

snp_cancer_map = df_data[['diseaseId', 'snpId', 'is_cancer']].drop_duplicates().dropna().groupby('snpId').apply(classify).to_dict()
list(snp_cancer_map.items())[:2]

### Create plot

In [None]:
# Figure 1: the upper half of the 45-degree rotated contact matrix with TADs
# and TAD borders outlined on top, and an axis below it that marks the SNPs
# located in TAD borders.

# rotate contact matrix by 45 degrees; order=0 avoids interpolating counts,
# reshape=True enlarges the output so that the whole rotated matrix fits,
# cells outside the input are filled with 0
mat_rot = ndimage.rotate(df_mat, 45, order=0, reshape=True, cval=0, prefilter=False)

# create figure (`s` scales the base figure size of 8x8)
s = 2
fig = plt.figure(figsize=(s*8, s*8))

# main plot
ax_main = plt.gca()

# plot rotated contact matrix; `extent` maps both image axes onto the
# coordinates stored in the index of `df_mat`
plotted_mat = ax_main.matshow(
    mat_rot,
    norm=SymLogNorm(1),
    cmap='YlOrRd',
    origin='lower',
    extent=(
        df_mat.index[0] + .5, df_mat.index[-1] + .5,
        df_mat.index[0] + .5, df_mat.index[-1] + .5
    ),
    aspect='equal'
)

# contact colorbar, placed inside the main axes (position in axes fractions)
ax_colobar = ax_main.inset_axes([.9, .6, .025, .3])
plt.colorbar(plotted_mat, cax=ax_colobar)
ax_colobar.set_ylabel('Contacts', rotation=90)

# adjust axes: `center_height` is the vertical middle of the image extent;
# only the part of the image above it is shown
center_height = (df_mat.index[0] + df_mat.index[-1]) / 2 + .5
ax_main.set_ylim(center_height, df_mat.index[-1] + .5)

ax_main.set_xlabel(sketch_region[0])
ax_main.xaxis.set_label_position('top') 
ax_main.set_xlim(df_mat.index[0], df_mat.index[-1])

# ticks and tick labels on the top side only
ax_main.tick_params(
    axis='both',
    reset=True,
    which='both',
    top=True, labeltop=True,
    right=False, labelright=False,
    left=False, labelleft=False,
    bottom=False, labelbottom=False,
)

ax_main.ticklabel_format(axis='both', style='plain')

# highlight TADs: triangle standing on the lower plot edge between Start and
# End with its apex above the TAD midpoint; the factor sin(90 deg) equals 1,
# so the apex height is half the TAD length
for row in tad_list.df.itertuples():
    tmp = np.sin(np.deg2rad(90)) * (row.End - row.Start) / 2

    pg = plt.Polygon([
        [row.Start, center_height], 
        [(row.Start + row.End) / 2, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='none')
    ax_main.add_patch(pg)

# highlight TAD boundaries: gray triangle over each border interval whose
# third corner sits above `End` for 'front' borders and above `Start` for
# 'back' borders, at a height of (about) the border length (tan(45 deg) ~ 1)
for row in border_list.df.itertuples():
    tmp = np.tan(np.deg2rad(45)) * (row.End - row.Start)

    pg = plt.Polygon([
        [row.Start, center_height], 
        [row.End if row.type == 'front' else row.Start, center_height + tmp], 
        [row.End, center_height]
    ], edgecolor='black', facecolor='gray', alpha=.6)
    ax_main.add_patch(pg)

# auxiliary plot: inset axes directly below the main axes, same width,
# `ax_height` times its height
ax_height = .3
ax_aux = ax_main.inset_axes([0, -ax_height, 1, ax_height])

# couple the x-axes of both plots
# NOTE(review): modifying the object returned by `get_shared_x_axes()` was
# deprecated in Matplotlib 3.6 -- check against the pinned version;
# `ax_aux.sharex(ax_main)` looks like the replacement.
ax_main.get_shared_x_axes().join(ax_main, ax_aux)
ax_aux.set_xlim(df_mat.index[0], df_mat.index[-1])

# highlight SNPs: one vertical line per SNP in a TAD border; solid when
# `snp_cancer_map` is truthy for the SNP, dashed otherwise
dash_style = (0, (5, 5))

snp_pos = snp_list.df['Start']
snp_colors = 'black' #['orange' if snp_cancer_map[s] else 'blue' for s in snp_list.df['snpId']]
# NOTE(review): raises KeyError for a SNP without an entry in
# `snp_cancer_map` -- confirm that every plotted SNP has one
snp_ls = ['solid' if snp_cancer_map[s] else dash_style for s in snp_list.df['snpId']]

# x in data coordinates, y in axes fractions: each line runs from y=.8 to the
# top edge (y=1) of the auxiliary axes
trans = tx.blended_transform_factory(ax_aux.transData, ax_aux.transAxes)
xy_pairs = np.column_stack([np.repeat(snp_pos, 2), np.tile([.8, 1], len(snp_pos))])

# shape (number of SNPs, 2 end points, 2 coordinates)
line_segs = xy_pairs.reshape([len(snp_pos), 2, 2])
ax_aux.add_collection(LineCollection(line_segs, transform=trans, colors=snp_colors, ls=snp_ls))

ax_aux.axis('off')

# label every SNP line with its id; the connector between label and line uses
# the same solid/dashed style as the line itself
annotation_list = []
for row in snp_list.df.itertuples():
    pos = row.Start
    id_ = row.snpId

    linestyle = 'solid' if snp_cancer_map[id_] else dash_style

    a = ax_aux.annotate(
        id_,
        xy=(pos, .8), xytext=(pos, .5),
        xycoords=('data', 'axes fraction'), textcoords=('data', 'axes fraction'),
        arrowprops=dict(arrowstyle='-', linestyle=linestyle),
        annotation_clip=False,
        fontsize=13)
    annotation_list.append(a)

# let adjustText reposition the labels
adjust_text(annotation_list, ax=ax_aux)

# save figure
plt.tight_layout()
plt.savefig(outdir / 'figure1.pdf')

## Figure 2: p-value Histogram for nice case

### Helper functions

In [None]:
def plot_histogram(df, ax=None, bins=np.linspace(0, 4, 50), column_name='pval_border_trans'):
    """Draw per-group density histograms of a -log10 p-value column.

    Parameters
    ----------
    df : DataFrame with the column `column_name` and an `is_cancer` column
        (used as hue, each group is normalised on its own).
    ax : axes to draw on; the current axes are used when omitted.
    bins : bin edges passed on to `sns.histplot`.
    column_name : column of `df` shown on the x-axis.

    A dashed red line marks -log10(0.05).
    """
    if ax is None:
        ax = plt.gca()

    histogram_style = dict(
        stat='density', common_norm=False,
        bins=bins, element='step',
    )
    sns.histplot(data=df, x=column_name, hue='is_cancer', ax=ax, **histogram_style)

    # significance cutoff p = 0.05 on the transformed scale
    significance_cutoff = -np.log10(.05)
    ax.axvline(x=significance_cutoff, ls='dashed', color='red')

    ax.set_xlabel(r'$-log_{10}(\mathrm{pvalue})$')
    ax.set_ylabel('Density')

In [None]:
def plot_histogram_difference(df, ax=None, bins=np.linspace(0, 4, 50)):
    """Stem plot of the per-bin density difference (cancer minus non-cancer).

    Parameters
    ----------
    df : DataFrame with a boolean `is_cancer` column and the transformed
        p-values in `pval_border_trans`.
    ax : axes to draw on; the current axes are used when omitted.
    bins : bin edges shared by both histograms.

    Each stem is placed at the right edge of its bin (rounded to two decimals).
    A dashed red line marks -log10(0.05).
    """
    # density histograms of both groups on identical bins
    scores = df['pval_border_trans']
    cancer_mask = df['is_cancer']

    hist_cancer, bin_edges = np.histogram(scores[cancer_mask], bins=bins, density=True)
    hist_noncancer, _ = np.histogram(scores[~cancer_mask], bins=bins, density=True)

    hist_diff = hist_cancer - hist_noncancer
    stem_positions = [round(edge, 2) for edge in bin_edges[1:]]

    # plot
    if ax is None:
        ax = plt.gca()

    # NOTE(review): `use_line_collection` was deprecated in Matplotlib 3.6;
    # check it against the pinned Matplotlib version.
    markerline, _stemlines, _baseline = ax.stem(
        stem_positions, hist_diff,
        basefmt='grey', linefmt='k', markerfmt='ok',
        use_line_collection=True
    )
    markerline.set_markersize(5)

    significance_cutoff = -np.log10(.05)
    ax.axvline(significance_cutoff, ls='dashed', color='red')

    ax.set_xlabel(r'$-log_{10}(\mathrm{pvalue})$')
    ax.set_ylabel('Difference (Cancer $-$ Non-cancer)')

### Compute additional multiple-testing corrections

In [None]:
%%time
df_newcorrection = df_enr.groupby(['filter', 'TAD_type', 'tad_source', 'window_size'])[['pval_border__notcorrected']].transform(multipletests_nan)
df_newcorrection.head()

In [None]:
df_enr['pval_border__jointmultipletesting'] = df_newcorrection['pval_border__notcorrected']
df_enr['pval_border__jointmultipletesting_trans'] = df_enr['pval_border__jointmultipletesting'].apply(lambda x: max_enrichment if x == 0 else -np.log10(x))

### Main figures: nofilter

#### Subfigure 2a: nofilter

In [None]:
sub_nofilter = df_enr[
    (df_enr['filter'] == 'nofilter') &
    (df_enr['TAD_type'] == '20in') &
#     (df_enr['window_size'].isin([8,9,10,11,12])) &
    (df_enr['tad_source'].str.contains('Rao2014')) & (df_enr['tad_source'].str.contains('10kb')) &
    (df_enr['#snp'] >= 1)
]

In [None]:
stats.mannwhitneyu(sub_nofilter.loc[sub_nofilter['is_cancer'], 'pval_border'], sub_nofilter.loc[~sub_nofilter['is_cancer'], 'pval_border'])

In [None]:
plt.figure(figsize=(8, 6))
plot_histogram(sub_nofilter)

### Supplementary figures: nofilter

#### Supplementary subfigure 3a: nofilter

In [None]:
plt.figure(figsize=(8, 6))
plot_histogram_difference(sub_nofilter)

#### Supplementary subfigure 4a: nofilter

In [None]:
fig = plt.figure(figsize=(2*8, 6))
gs = fig.add_gridspec(nrows=1, ncols=2)

ax = fig.add_subplot(gs[0, 0])
plot_histogram(sub_nofilter, column_name='pval_border__notcorrected_trans')
ax.set_title('No multiple testing correction')

ax = fig.add_subplot(gs[0, 1])
plot_histogram(sub_nofilter, column_name='pval_border__jointmultipletesting_trans')
ax.set_title('Joint multiple testing correction')

### Main figures: intergenic

#### Subfigure 2b: intergenic filter

In [None]:
sub_intergenic = df_enr[
    (df_enr['filter'] == 'intergenic') &
    (df_enr['TAD_type'] == '20in') &
#     (df_enr['window_size'].isin([8,9,10,11,12])) &
    (df_enr['tad_source'].str.contains('Rao2014')) & (df_enr['tad_source'].str.contains('10kb')) &
    (df_enr['#snp'] >= 1)
]

In [None]:
stats.mannwhitneyu(sub_intergenic.loc[sub_intergenic['is_cancer'], 'pval_border'], sub_intergenic.loc[~sub_intergenic['is_cancer'], 'pval_border'])

In [None]:
plt.figure(figsize=(8, 6))
plot_histogram(sub_intergenic)

### Supplementary figures: intergenic

#### Supplementary subfigure 3b: intergenic

In [None]:
plt.figure(figsize=(8, 6))
plot_histogram_difference(sub_intergenic)

#### Supplementary subfigure 4b: intergenic

In [None]:
fig = plt.figure(figsize=(2*8, 6))
gs = fig.add_gridspec(nrows=1, ncols=2)

ax = fig.add_subplot(gs[0, 0])
plot_histogram(sub_intergenic, column_name='pval_border__notcorrected_trans')
ax.set_title('No multiple testing correction')

ax = fig.add_subplot(gs[0, 1])
plot_histogram(sub_intergenic, column_name='pval_border__jointmultipletesting_trans')
ax.set_title('Joint multiple testing correction')

### Panel figure

In [None]:
fig = plt.figure(figsize=(4*4, 4*3))
gs = fig.add_gridspec(nrows=2, ncols=2)

ax = fig.add_subplot(gs[0, 0])
plot_histogram(sub_nofilter, ax)
ax.set_title('nofilter')

ax = fig.add_subplot(gs[0, 1])
plot_histogram(sub_intergenic, ax)
ax.set_title('intergenic')

ax = fig.add_subplot(gs[1, 0])
plot_histogram_difference(sub_nofilter, ax)

ax = fig.add_subplot(gs[1, 1])
plot_histogram_difference(sub_intergenic, ax)

plt.tight_layout()
plt.savefig(outdir / 'figure2.pdf')

## Figure 3: Zoomed-out multi-dataset overview

In [None]:
df_enr.head(2)

### Parameter combinations

In [None]:
fraction_threshold = .5

window_size_list = df_enr['window_size'].unique()
snp_count_threshold_list = [0, 2, 5, 7, 10, 15, 20, 30, 40, 50]

In [None]:
def helper_function(x):
    """Return whether the single p-value held by `x` is at most 0.05.

    `x` must contain exactly one entry (asserted), i.e. the grouping that
    calls this helper has to identify one enrichment result per group.
    """
    n_values = x.shape[0]
    assert n_values == 1

    pvalue = x.iloc[0]
    return pvalue <= .05

In [None]:
%%time

df_list_detailed = []
for snp_threshold in tqdm(snp_count_threshold_list):
    for window_size in tqdm(window_size_list):
        sub = df_enr[(df_enr['#snp'] >= snp_threshold) & (df_enr['window_size'] == window_size)]

        tmp_detailed = (sub.groupby(['TAD_type', 'filter', 'tad_source', 'is_cancer', 'diseaseId'])['pval_border']
               .apply(helper_function)
               .to_frame('majority_is_sig')
               .reset_index()
        )

        tmp_detailed['snp_threshold'] = snp_threshold
        tmp_detailed['window_size'] = window_size
        
        df_list_detailed.append(tmp_detailed)
    
df_detailed = pd.concat(df_list_detailed)
df_detailed.head()

In [None]:
df_detailed.to_csv(outdir / 'snpthreshold_data.csv.gz', index=False)

### Majority vote over window sizes

In [None]:
fraction_threshold = .5

In [None]:
%%time

sub = df_enr[
    (df_enr['#snp'] >= 0) &
    (~df_enr['window_size'].isin([0, 1]))
]

df_majority = (sub.groupby(['TAD_type', 'filter', 'tad_source', 'is_cancer', 'diseaseId'])
       .apply(lambda x: (x['pval_border'] <= .05).mean() > fraction_threshold)
       .to_frame('majority_is_sig')
       .reset_index()
)
df_majority.head()

In [None]:
df_majority.to_csv(outdir / 'majorityvote_data.csv.gz', index=False)

## Cancer vs Non-cancer

### Helper functions

In [None]:
def select_one(df, col):
    """Return the value of column `col`, which must be constant within `df` (asserted)."""
    values = df[col]
    assert values.nunique() == 1
    return values.iloc[0]

In [None]:
def compute_pvalue(x, y):
    """Fisher exact test p-value for the association between `x` and `y`.

    The two inputs are cross-tabulated. If the table is not 2x2 (for example
    because one input has a single distinct value), a warning is printed and
    NaN is returned.
    """
    ct = pd.crosstab(x, y)

    if ct.shape == (2, 2):
        _, pvalue = fisher_exact(ct)
        return pvalue

    print(f'Warning: invalid contingency table: {ct}')
    return np.nan

In [None]:
def custom_barplot(*args, order=None, additional_varying_columns=None, ax=None, **kwargs):
    """Bar plot of the significant-disease fraction, cancer vs. non-cancer.

    For every level of the column `x`, the mean of `majority_is_sig` is shown
    separately for `is_cancer == False` and `is_cancer == True`. Each pair of
    bars is annotated with the p-value returned by `compute_pvalue` for the
    association between `is_cancer` and `majority_is_sig` within that level.

    Parameters
    ----------
    *args : ignored.
    order : levels of `x` to plot, in plotting order; defaults to the
        natural-sorted distinct values. Must be a subset of the observed
        values (asserted).
    additional_varying_columns : names of columns that may vary within `data`
        in addition to `x`, 'is_cancer', 'diseaseId' and 'majority_is_sig';
        every other column must be constant (asserted).
    ax : axes handed to seaborn.
    **kwargs : must contain `data` (DataFrame) and `x` (column name); other
        entries are not used.
    """
    # initial setup
    data, x = kwargs['data'], kwargs['x']

    if order is None:
        order = natsorted(data[x].unique())

    assert set(order) <= set(data[x].unique())

    # make sure that all other columns are constant (i.e. the grouping has worked)
    if additional_varying_columns is None:
        additional_varying_columns = set()

    varying_columns = {x, 'is_cancer', 'diseaseId', 'majority_is_sig'} | set(additional_varying_columns)
    for col in set(data.columns) - varying_columns:
        assert data[col].nunique() == 1, col

    # fraction of significant diseases per (x level, is_cancer)
    data_agg = (
        data.groupby([x, 'is_cancer'])['majority_is_sig']
        .apply(lambda flags: flags.mean())
        .to_frame('sig_frac')
        .reset_index()
    )

    # standard barplot
    ax = sns.barplot(
        data=data_agg,
        x=x, y='sig_frac', hue='is_cancer',
        palette='tab10',
        order=order, hue_order=[False, True],
        ax=ax)

    # one annotation per x level, comparing the non-cancer with the cancer bar
    box_pairs = [((level, False), (level, True)) for level in order]

    pvalues = []
    for level in order:
        level_rows = data[data[x] == level]
        pvalues.append(compute_pvalue(level_rows['is_cancer'], level_rows['majority_is_sig']))

    star_thresholds = [[1e-4, "****"], [1e-3, "***"], [1e-2, "**"], [0.05, "*"], [1, ""]]
    value_thresholds = [[1e-5, "1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]]

    text_annot_custom = []
    for pval in pvalues:
        txt_stars = pval_annotation_text(pval, star_thresholds)
        txt_value = simple_text(pval, '{:.2f}', value_thresholds, '')
        text_annot_custom.append(f'{txt_stars}\n{txt_value}')

    # p-values are supplied, so statannot only draws them
    statannot.add_stat_annotation(
        ax, plot='barplot',
        data=data_agg, x=x, y='sig_frac', hue='is_cancer',
        order=order, hue_order=[False, True],
        box_pairs=box_pairs,
        pvalues=pvalues,
        text_annot_custom=text_annot_custom,
        perform_stat_test=False,
        verbose=0
    )

### SNP count threshold effect

In [None]:
snpcount_outdir = outdir / 'snpcount_effect'
snpcount_outdir.mkdir(exist_ok=True)

In [None]:
for (tad_type, tad_source, window_size, filter_), group in df_detailed.groupby(['TAD_type', 'tad_source' , 'window_size', 'filter']):
    plt.figure(figsize=(16, 6))

    custom_barplot(x='snp_threshold', data=group)
    
    plt.xlabel('SNP count threshold')
    plt.ylabel('Significant disease fraction')
    plt.title(f'{tad_type} - {tad_source} - {window_size} - {filter_}')
    
    plt.tight_layout()
    plt.savefig(snpcount_outdir / f'dataset_overview_snpcountthreshold_{tad_type}_{tad_source}_{window_size}_{filter_}.pdf')
    plt.close()

### Filter effect

In [None]:
filter_outdir = outdir / 'filter_effect'
filter_outdir.mkdir(exist_ok=True)

#### Panel figure

In [None]:
aggregated_subset = df_majority[
    (df_majority['TAD_type'] == '20in') &
    (df_majority['tad_source'].str.contains('Rao2014')) & (df_majority['tad_source'].str.contains('10kb'))
].copy()
aggregated_subset.head()

In [None]:
# just for the visualization
aggregated_subset['filter'] = aggregated_subset['filter'].replace('nofilter', 'all SNPs')

In [None]:
aggregated_groups = aggregated_subset.groupby('tad_source').groups.items()

In [None]:
fig = plt.figure(figsize=(20, 25), constrained_layout=False)
gs = fig.add_gridspec(nrows=2, ncols=1, height_ratios=[1, 2])


# main figure
with sns.plotting_context('paper', font_scale=3):
    ax = fig.add_subplot(gs[0, :])

    custom_barplot(
        x='filter', data=aggregated_subset,
        additional_varying_columns={'tad_source'},
        order=['all SNPs', 'exonic', 'intronic', 'intergenic'],
        ax=ax)
    
    pal = sns.color_palette('tab10', desat=.75)  # because seaborn does it this way
    ax.legend(handles=[
        Patch(facecolor=pal[1], edgecolor=pal[1], label='Cancer'),
        Patch(facecolor=pal[0], edgecolor=pal[0], label='Non-cancer')
    ], loc='best')

    ax.set_xlabel('')
    ax.set_ylabel(f'Disease fraction sig. in $>{fraction_threshold}$ cases')

    
# subfigures
inner_grid = gs[1, :].subgridspec(nrows=3, ncols=2, wspace=None, hspace=.3)
axs = inner_grid.subplots()

ax_list = axs.ravel()
assert len(ax_list) == len(aggregated_groups)

for (tad_source, idx_list), ax in zip(aggregated_groups, ax_list):
    group = aggregated_subset.loc[idx_list]
    
    custom_barplot(
        x='filter', data=group,
        order=['all SNPs', 'exonic', 'intronic', 'intergenic'],
        ax=ax)
    ax.legend([],[], frameon=False)
    
    ax.set_title('-'.join(tad_source.split('-')[1:3]))
    ax.set_xlabel('')
    ax.set_ylabel(f'Disease fraction sig. in $>{fraction_threshold}$ cases')


# save figure
plt.tight_layout()
plt.savefig(outdir / 'figure3.pdf')

#### All figures

In [None]:
for (tad_type, tad_source), group in df_majority.groupby(['TAD_type', 'tad_source']):
    plt.figure(figsize=(12, 6))

    custom_barplot(
        x='filter', data=group,
        order=['nofilter', 'exonic', 'intronic', 'intergenic'])

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='is_cancer')

    plt.xlabel('Filter')
    plt.ylabel(f'Disease fraction sig. in $>{fraction_threshold}$ cases')
    plt.title(f'{tad_type} - {tad_source}')

    plt.tight_layout()
    plt.savefig(filter_outdir / f'dataset_overview_filter_{tad_type}_{tad_source}.pdf')
    plt.close()

### TAD-type effect

In [None]:
tadtype_outdir = outdir / 'tadtype_effect'
tadtype_outdir.mkdir(exist_ok=True)

In [None]:
for (filter_, tad_source), group in df_majority.groupby(['filter', 'tad_source']):
    plt.figure(figsize=(12, 6))

    custom_barplot(x='TAD_type', data=group)

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='is_cancer')

    plt.xlabel('Border size [kbp]')
    plt.ylabel(f'Disease fraction sig. in $>{fraction_threshold}$ cases')
    plt.title(f'{tad_source} - {filter_}')

    plt.tight_layout()
    plt.savefig(tadtype_outdir / f'dataset_overview_bordersize_{filter_}_{tad_source}.pdf')
    plt.close()

## Figure 4: multipartite graph

## Prepare data

### Subset disease-SNP associations

In [None]:
df_data.head(1)

In [None]:
df_data_sub = df_data[
    (df_data['tad_source'].str.contains('Rao2014')) & (df_data['tad_source'].str.contains('10kb')) &
    (df_data['window_size'] == 0)
]

### Identify enriched diseases

In [None]:
aggregated_subset.head(1)

In [None]:
isenriched_map = (aggregated_subset.groupby(['TAD_type', 'filter', 'diseaseId'])['majority_is_sig'].mean() > .5).to_dict()
list(isenriched_map.items())[0]

## Helper functions

In [None]:
node_color_map = {
    (True, True): 'red',
    (True, False): 'orange',
    (False, True): 'blueviolet',
    (False, False): 'lightblue',
}

In [None]:
def assemble_network(df, min_snp_num=1, hub_threshold=0, verbose=False, show_isolated_nodes=True):
    """Build the disease-disease network of shared SNPs.

    Steps: (1) bipartite graph of the diseaseId-snpId pairs in `df`,
    (2) projection onto the diseases with one parallel edge per shared SNP,
    (3) collapse of parallel edges into one edge whose `weight` is the number
    of shared SNPs, (4) restriction to edges with `weight >= min_snp_num`,
    (5) node attributes `is_cancer`, `is_enriched__nofilter` and
    `is_enriched__intergenic` from the globals `iscancer_map` and
    `isenriched_map`.

    Parameters
    ----------
    df : DataFrame with `diseaseId` and `snpId` columns.
    min_snp_num : minimal number of shared SNPs for an edge to be kept.
    hub_threshold : accepted but currently unused.
    verbose : print a graph summary after every step.
    show_isolated_nodes : keep diseases that share no SNP with any other
        disease as isolated nodes.

    NOTE(review): step (4) keeps only nodes that are incident to a retained
    edge, so diseases whose edges all fall below `min_snp_num` are absent
    from the result even when `show_isolated_nodes` is True -- confirm this
    is intended.
    """
    def report(graph_to_describe):
        if verbose:
            print(nx.info(graph_to_describe))

    # transform dataframe to networkx object
    bipartite = nx.from_pandas_edgelist(df, 'diseaseId', 'snpId')
    bipartite.name = 'bipartite graph'
    report(bipartite)

    # project bipartite graph onto the disease nodes
    disease_nodes = df['diseaseId'].unique().tolist()
    multi = nx.bipartite.projected_graph(bipartite, disease_nodes, multigraph=True)
    multi.name = 'projected multigraph'
    report(multi)

    # remove isolated nodes
    if not show_isolated_nodes:
        if verbose:
            print('Remove isolated nodes')
        multi.remove_nodes_from(list(nx.isolates(multi)))
        report(multi)

    # merge parallel edges into a single edge that counts them in `weight`
    graph_proj = nx.Graph()
    graph_proj.name = 'projected graph'

    for u, v, attrs in multi.edges(data=True):
        w = attrs.get('weight', 1)
        assert w == 1, 'oopsie'

        if not graph_proj.has_edge(u, v):
            graph_proj.add_edge(u, v, weight=0)
        graph_proj[u][v]['weight'] += w

    report(graph_proj)

    # subset graph based on SNP count threshold
    if verbose:
        print('Subset graph')
    retained_edges = [
        (u, v)
        for u, v, attrs in graph_proj.edges(data=True)
        if attrs['weight'] >= min_snp_num
    ]
    graph_proj = graph_proj.edge_subgraph(retained_edges).copy()
    graph_proj.add_nodes_from(nx.isolates(multi))
    report(graph_proj)

    # add node data
    nx.set_node_attributes(graph_proj, iscancer_map, 'is_cancer')

    for attribute, filter_label in (
        ('is_enriched__nofilter', 'all SNPs'),
        ('is_enriched__intergenic', 'intergenic'),
    ):
        enrichment_flags = {
            n: isenriched_map.get(('20in', filter_label, n), False)
            for n in graph_proj.nodes()
        }
        nx.set_node_attributes(graph_proj, enrichment_flags, attribute)

    return graph_proj

In [None]:
def plot_network(graph, ax=None, enrichment_color_key='is_enriched'):
    """Draw `graph` with nodes coloured by cancer status and enrichment.

    Parameters
    ----------
    graph : graph whose nodes carry an `is_cancer` attribute and the attribute
        named by `enrichment_color_key`.
    ax : axes to draw on; the current axes are used when omitted.
    enrichment_color_key : node attribute that holds the enrichment flag.
        The pair (is_cancer, enrichment flag) is looked up in the global
        `node_color_map`.

    NOTE(review): `assemble_network` sets `is_enriched__nofilter` and
    `is_enriched__intergenic` but no attribute called `is_enriched`, so the
    default key looks unusable for its graphs -- callers pass it explicitly.
    """
    # graph layout
    pos = nx.drawing.nx_agraph.graphviz_layout(graph, prog='neato', args='-Goverlap=scale')

    # node colors
    node_color_list = []
    for _, attrs in graph.nodes(data=True):
        color_key = (attrs['is_cancer'], attrs[enrichment_color_key])
        node_color_list.append(node_color_map[color_key])

    # edge color: one translucent black (RGBA) for all edges
    edge_color = (0, 0, 0, .2)

    # do plot
    if ax is None:
        ax = plt.gca()

    nx.draw_networkx_nodes(graph, pos, node_size=20, node_color=node_color_list, ax=ax)
    nx.draw_networkx_edges(graph, pos, edge_color=edge_color, ax=ax)

    ax.axis('off')

### Visualize networks

#### Nofilter

In [None]:
df_data_sub_nofilter = df_data_sub[(df_data_sub['filter_nofilter_hg38'])]

In [None]:
sub_border = df_data_sub_nofilter[(df_data_sub_nofilter['20in'] == 'border')]
graph_border = assemble_network(sub_border, 2)
print(nx.info(graph_border))

In [None]:
sub_nonborder = df_data_sub_nofilter[~(df_data_sub_nofilter['20in'] == 'border')]
graph_nonborder = assemble_network(sub_nonborder, 2)
print(nx.info(graph_nonborder))

In [None]:
plt.figure(figsize=(12, 6))

plt.subplot(121)
plot_network(graph_nonborder, enrichment_color_key='is_enriched__nofilter')
plt.title('nonborder SNPs')

plt.subplot(122)
plot_network(graph_border, enrichment_color_key='is_enriched__nofilter')
plt.title('border SNPs')

plt.tight_layout()
plt.savefig(outdir / 'networks_nofilter.pdf')

#### Intergenic

In [None]:
df_data_sub_intergenic = df_data_sub[(df_data_sub['filter_intergenic_hg38'])]

In [None]:
sub_border = df_data_sub_intergenic[(df_data_sub_intergenic['20in'] == 'border')]
graph_border = assemble_network(sub_border, 2)
print(nx.info(graph_border))

In [None]:
sub_nonborder = df_data_sub_intergenic[~(df_data_sub_intergenic['20in'] == 'border')]
graph_nonborder = assemble_network(sub_nonborder, 2)
print(nx.info(graph_nonborder))

In [None]:
plt.figure(figsize=(12, 6))

plt.subplot(121)
plot_network(graph_nonborder, enrichment_color_key='is_enriched__intergenic')
plt.title('nonborder SNPs')

plt.subplot(122)
plot_network(graph_border, enrichment_color_key='is_enriched__intergenic')
plt.title('border SNPs')

plt.tight_layout()
plt.savefig(outdir / 'networks_intergenic.pdf')

#### Network grid

In [None]:
sub = df_data_sub_nofilter[(df_data_sub_nofilter['20in'] == 'border')]
graph_border = assemble_network(sub, 3, hub_threshold=0)

sub = df_data_sub_nofilter[~(df_data_sub_nofilter['20in'] == 'border')]
graph_nonborder = assemble_network(sub, 3, hub_threshold=0)

In [None]:
sub = df_data_sub_intergenic[(df_data_sub_intergenic['20in'] == 'border')]
graph_border_intergenic = assemble_network(sub, 3, hub_threshold=0)

sub = df_data_sub_intergenic[~(df_data_sub_intergenic['20in'] == 'border')]
graph_nonborder_intergenic = assemble_network(sub, 3, hub_threshold=0)

In [None]:
# setup figure
fig = plt.figure(figsize=(12, 12))
gs = fig.add_gridspec(nrows=2, ncols=2)


# plot networks
ax = fig.add_subplot(gs[0, 0])
plot_network(nx.subgraph(graph_nonborder, graph_border.nodes()), ax=ax, enrichment_color_key='is_enriched__nofilter')
ax.set_title('nonborder SNPs')

ax = fig.add_subplot(gs[0, 1])
plot_network(graph_border, ax=ax, enrichment_color_key='is_enriched__nofilter')
ax.set_title('border SNPs')
ax.legend(handles=[
    Line2D(
        [0], [0], marker='o', color='w',
        label={True: 'is_cancer', False: 'not-is_cancer'}[is_cancer] + ' & ' + {True: 'is_enriched', False: 'not-is_enriched'}[is_enriched],
        markerfacecolor=color, markersize=10)
    for (is_cancer, is_enriched), color in node_color_map.items()
], loc='best', fontsize=8)

ax = fig.add_subplot(gs[1, 0])
plot_network(nx.subgraph(graph_nonborder_intergenic, graph_border_intergenic.nodes()), ax=ax, enrichment_color_key='is_enriched__intergenic')
ax.set_title('nonborder SNPs (intergenic)')

ax = fig.add_subplot(gs[1, 1])
plot_network(graph_border_intergenic, ax=ax, enrichment_color_key='is_enriched__intergenic')
ax.set_title('border SNPs (intergenic)')


# save figure
plt.tight_layout()
plt.savefig(outdir / 'figure4.pdf')

## Edge statistics

In [None]:
def compute_edge_asymmetries(graph, cancer_map=None):
    """Compare within-class with between-class edge density of a disease graph.

    For each class (cancer, non-cancer) the asymmetry is

        (edges within the class) / (possible pairs within the class)
        - (edges between the classes) / (possible pairs between the classes)

    Parameters
    ----------
    graph : object with `nodes()` and `edges()` (e.g. a networkx graph) whose
        nodes are disease ids.
    cancer_map : mapping disease id -> bool (True for cancer). Defaults to the
        notebook-level `iscancer_map`.

    Returns
    -------
    (cancer_asym, noncancer_asym) : tuple of floats. A value is NaN (and a
        warning is printed) whenever one of its two densities is undefined:
        fewer than two nodes of the class, or one of the classes is absent so
        that no between-class pair exists.
    """
    if cancer_map is None:
        cancer_map = iscancer_map

    disease_nodes = list(graph.nodes())
    disease_pairs = list(graph.edges())

    tmp = collections.Counter(cancer_map[d] for d in disease_nodes)
    node_type_counts = {
        'cancer': tmp.get(True, 0),
        'noncancer': tmp.get(False, 0)
    }

    tmp = collections.Counter((cancer_map[d1], cancer_map[d2]) for d1, d2 in disease_pairs)
    edge_type_counts = {
        'cancer_pairs': tmp.get((True, True), 0),
        'noncancer_pairs': tmp.get((False, False), 0),
        'across': tmp.get((True, False), 0) + tmp.get((False, True), 0)
    }

    # Between-class density is shared by both asymmetries. It is undefined
    # when either class is empty; previously this case raised
    # ZeroDivisionError as soon as the other class had at least two nodes.
    possible_across = node_type_counts['cancer'] * node_type_counts['noncancer']
    if possible_across == 0:
        print(f'Warning: division by zero (across: {node_type_counts})')
        across_density = np.nan
    else:
        across_density = edge_type_counts['across'] / possible_across

    def within_density(class_name):
        """Edges among nodes of one class divided by the number of possible pairs."""
        n_nodes = node_type_counts[class_name]
        possible_ordered_pairs = n_nodes * (n_nodes - 1)
        if possible_ordered_pairs == 0:
            print(f'Warning: division by zero ({class_name}: {node_type_counts})')
            return np.nan
        return 2 * edge_type_counts[f'{class_name}_pairs'] / possible_ordered_pairs

    cancer_asym = within_density('cancer') - across_density
    noncancer_asym = within_density('noncancer') - across_density

    return (cancer_asym, noncancer_asym)

In [None]:
# Edge asymmetries of the networks built from border SNPs ('20in' borders),
# one network per minimal number of shared SNPs (1 to 9).
sub = df_data_sub[df_data_sub['20in'] == 'border']

asym_records = []
for thres in trange(1, 10):
    graph = assemble_network(sub, thres, hub_threshold=0)
    cancer_asym, noncancer_asym = compute_edge_asymmetries(graph)

    asym_records.append(dict(
        threshold=thres,
        cancer_asymmetry=cancer_asym,
        noncancer_asymmetry=noncancer_asym,
        node_count=len(graph.nodes()),
        edge_count=len(graph.edges()),
    ))

df_asym_border = pd.DataFrame(asym_records)
df_asym_border.head()

In [None]:
# Edge asymmetries of the networks built from all SNPs that are not in a
# '20in' border, one network per minimal number of shared SNPs (1 to 9).
sub = df_data_sub[df_data_sub['20in'] != 'border']

asym_records = []
for thres in trange(1, 10):
    graph = assemble_network(sub, thres, hub_threshold=0)
    cancer_asym, noncancer_asym = compute_edge_asymmetries(graph)

    asym_records.append(dict(
        threshold=thres,
        cancer_asymmetry=cancer_asym,
        noncancer_asymmetry=noncancer_asym,
        node_count=len(graph.nodes()),
        edge_count=len(graph.edges()),
    ))

df_asym_nonborder = pd.DataFrame(asym_records)
df_asym_nonborder.head()

In [None]:
df_asym_border['type'] = 'border'
df_asym_nonborder['type'] = 'nonborder'

df_long = pd.melt(pd.concat([df_asym_border, df_asym_nonborder]), id_vars=['threshold', 'type'])
df_long.head()

In [None]:
plt.figure(figsize=(8, 6))

sns.lineplot(x='threshold', y='value', hue='variable', style='type', data=df_long, hue_order=['noncancer_asymmetry', 'cancer_asymmetry'])

plt.tight_layout()
plt.savefig(outdir / 'edge_asymmetries.pdf')

In [None]:
plt.figure(figsize=(8, 6))
sns.lineplot(x='threshold', y='value', hue='variable', style='type', data=df_long, hue_order=['node_count', 'edge_count'])

In [None]:
# compute: border asymmetry divided by non-border asymmetry, per threshold;
# the non-border rows are selected by the row labels of the border frame
border_rows = df_asym_border.index
nonborder_at_border_rows = df_asym_nonborder.loc[border_rows]

df_quotient = pd.DataFrame({
    'threshold': df_asym_border['threshold'],
    'noncancer_quotient': df_asym_border['noncancer_asymmetry'] / nonborder_at_border_rows['noncancer_asymmetry'],
    'cancer_quotient': df_asym_border['cancer_asymmetry'] / nonborder_at_border_rows['cancer_asymmetry'],
})

# plot: a quotient of 1 (dashed red line) means identical asymmetry in both networks
plt.figure(figsize=(8, 6))

quotient_long = pd.melt(df_quotient, id_vars='threshold')
sns.lineplot(
    data=quotient_long,
    x='threshold', y='value', hue='variable',
    hue_order=['noncancer_quotient', 'cancer_quotient'],
)
plt.axhline(1, color='red', ls='dashed')

plt.tight_layout()
plt.savefig(outdir / 'edge_asymmetry_quotients.pdf')

# Miscellaneous Figures

## SNP associations to Cancer and Non-Cancer

In [None]:
df_data.head()

In [None]:
%%time

tmp_sub = df_data.loc[:, ['diseaseId', 'snpId', 'is_cancer']].drop_duplicates().dropna()
tmp_sub['is_cancer_shuffled'] = np.random.permutation(tmp_sub['is_cancer'].values)

df_cancercounts = (tmp_sub.groupby('snpId')['is_cancer']
                          .apply(lambda x: pd.Series(
                              [x.tolist().count(False), x.tolist().count(True)],
                              index=['noncancer_count', 'cancer_count']))
                          .unstack())
df_cancercountsshuffled = (tmp_sub.groupby('snpId')['is_cancer_shuffled']
                                  .apply(lambda x: pd.Series(
                                      [x.tolist().count(False), x.tolist().count(True)],
                                      index=['noncancer_count', 'cancer_count']))
                                  .unstack())

In [None]:
df_cancercounts.head()

In [None]:
plt.figure(figsize=(2*8, 6))

plt.subplot(121)
sns.scatterplot(x='cancer_count', y='noncancer_count', data=df_cancercounts, alpha=.2)
plt.title('observed')

plt.subplot(122)
sns.scatterplot(x='cancer_count', y='noncancer_count', data=df_cancercountsshuffled, alpha=.2)
plt.title('shuffled')