In [None]:
import os
import itertools
from pathlib import Path

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from natsort import natsorted
from tqdm.auto import tqdm
from scipy.special import binom

from bioinf_common.plotting import annotated_barplot, add_identity

In [None]:
# Notebook-wide display settings: show every DataFrame column and use the
# larger seaborn "talk" font context for all figures.
pd.set_option('display.max_columns', None)
sns.set_context('talk')

# Parameters

In [None]:
# Input/output paths are injected by Snakemake when this notebook runs as a rule.
fname_data = snakemake.input.fname_data
fname_enr = snakemake.input.fname_enr

outdir = Path(snakemake.output.outdir)
# Every figure below is saved into `outdir`; create it up front so the first
# savefig does not fail, and so re-running the notebook is harmless.
outdir.mkdir(parents=True, exist_ok=True)

# Load data

In [None]:
# enrichment table read from the `fname_enr` Snakemake input
df_enr = pd.read_csv(filepath_or_buffer=fname_enr, low_memory=True)
df_enr.head(n=5)

In [None]:
# disease data table read from the `fname_data` Snakemake input
df_data = pd.read_csv(filepath_or_buffer=fname_data, low_memory=True)
df_data.head(n=5)

In [None]:
# diseaseId -> is_cancer lookup. Only the needed column is converted to a dict
# (instead of converting every column); if a diseaseId repeats, the last row wins.
iscancer_map = df_data.set_index('diseaseId')['is_cancer'].to_dict()

## Data overview

In [None]:
# report the (rows, columns) shape of both input tables
for label, frame in [('Data', df_data), ('Enrichment', df_enr)]:
    print(label, frame.shape)

In [None]:
# summary statistics of the disease data table
summary_data = df_data.describe()
summary_data

In [None]:
# summary statistics of the enrichment table
summary_enr = df_enr.describe()
summary_enr

# Investigate various ways of measuring signal strength

## Define signal measures

In [None]:
def enrichment_quotient(df, p_thres=0.05):
    """Ratio of mean thresholded -log10 p-values, cancer vs. non-cancer diseases.

    Entries of ``pval_boundary_neglog`` below ``-log10(p_thres)`` (i.e. not
    significant at ``p_thres``) are treated as 0 before averaging, so each mean
    only accumulates significant signal. The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a boolean ``is_cancer`` column and a numeric
        ``pval_boundary_neglog`` column.
    p_thres : float, default 0.05
        Significance threshold on the untransformed p-value.

    Returns
    -------
    float
        ``mean(cancer) / mean(noncancer)``; NaN when the non-cancer mean is 0
        (or when a class has no rows, since its mean is then NaN).
    """
    neglog = df['pval_boundary_neglog']
    # zero out non-significant entries (NaNs compare False and stay NaN)
    clipped = neglog.mask(neglog < -np.log10(p_thres), 0)

    is_cancer = df['is_cancer']
    cancer_signal = clipped[is_cancer].mean()
    noncancer_signal = clipped[~is_cancer].mean()

    # guard against dividing by a zero denominator
    if noncancer_signal == 0:
        return np.nan
    return cancer_signal / noncancer_signal

In [None]:
def count_quotient(df, p_thres=0.05):
    """Ratio of significant-disease fractions, cancer vs. non-cancer.

    A row counts as significant when ``pval_boundary_neglog >= -log10(p_thres)``.
    The result is
    ``(|cancer_sig| / |cancer_all|) / (|noncancer_sig| / |noncancer_all|)``.
    The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a boolean ``is_cancer`` column and a numeric
        ``pval_boundary_neglog`` column.
    p_thres : float, default 0.05
        Significance threshold on the untransformed p-value.

    Returns
    -------
    float
        The quotient, or NaN when it is undefined: no cancer rows at all
        (previously a ZeroDivisionError) or no significant non-cancer rows.
    """
    is_cancer = df['is_cancer']
    is_sig = df['pval_boundary_neglog'] >= -np.log10(p_thres)

    all_cancer_num = int(is_cancer.sum())
    all_noncancer_num = int((~is_cancer).sum())
    sig_cancer_num = int((is_cancer & is_sig).sum())
    sig_noncancer_num = int((~is_cancer & is_sig).sum())

    # Both fractions must be defined and the denominator non-zero;
    # sig_noncancer_num > 0 also implies all_noncancer_num > 0.
    if all_cancer_num == 0 or sig_noncancer_num == 0:
        return np.nan

    return (sig_cancer_num / all_cancer_num) / (sig_noncancer_num / all_noncancer_num)

## Apply measures

In [None]:
# Parameters of the signal-strength analysis
border_type = '20in'           # value of the `TAD_type` column to keep
pvalue_type = 'pval_boundary'  # alternative: 'pval_boundary__notcorrected'
filter_type = 'nofilter'       # value of the `filter` column to keep

In [None]:
# pre-transform data: one row per (tad_source, window_size, filter, diseaseId, TAD_type)
df_trans = (
    df_enr
    .groupby(['tad_source', 'window_size', 'filter', 'diseaseId', 'TAD_type'])
    .first()
    .reset_index()
)

# Keep the selected TAD type / filter. `.copy()` makes this an independent frame,
# so the `.loc` assignment below does not raise SettingWithCopyWarning.
df_trans = df_trans[
    (df_trans['TAD_type'] == border_type) & (df_trans['filter'] == filter_type)
].copy()

# p-values of exactly 0 would give -log10(0) = inf; clamp them to a small floor
df_trans.loc[df_trans[pvalue_type] == 0, pvalue_type] = 1e-16

df_trans['pval_boundary_neglog'] = -np.log10(df_trans[pvalue_type])

In [None]:
# first rows of the pre-transformed table
df_trans.head(n=5)

In [None]:
# Compute both signal measures for every (tad_source, window_size) group.
# LaTeX legend labels; the boundary subscript follows `border_type` so the label
# stays correct if that parameter changes (it is identical for '20in').
enrichment_label = (
    r'$\frac{\langle-\log_{10}(p_{cancer,boundary_{' + border_type
    + r'}})\rangle}{\langle-\log_{10}(p_{noncancer,boundary_{' + border_type
    + r'}})\rangle}$'
)
count_label = r'$\frac{|\mathrm{cancer}_{sig}| / |\mathrm{cancer}_{all}|}{|\mathrm{noncancer}_{sig}| / |\mathrm{noncancer}_{all}|}$'

signal_measures = [
    (enrichment_quotient, enrichment_label),
    (count_quotient, count_label),
]

signal_data = [
    {
        'tad_source': tad_source,
        'window_size': window_size,
        'signal': measure(group),
        'type': label,
    }
    for (tad_source, window_size), group in df_trans.groupby(['tad_source', 'window_size'])
    for measure, label in signal_measures
]
df_signal = pd.DataFrame(signal_data)

In [None]:
# first rows of the per-group signal table
df_signal.head(n=5)

## Visualize results

In [None]:
def my_bar(*args, baseline=1, **kwargs):
    """Annotated bar plot whose bars are drawn relative to ``baseline``.

    Meant to be called via ``FacetGrid.map_dataframe``, which supplies the
    facet's rows as ``data`` plus the ``x``/``y``/``hue`` keyword arguments.
    The ``y`` values of a copy of ``data`` are shifted down by ``baseline``, and
    ``bottom=baseline`` is passed on. For the quotient measures, a value of 1
    means "cancer equals non-cancer".

    Parameters
    ----------
    baseline : float, default 1
        Value the bars start from. Presumably forwarded as the bar bottom by
        ``annotated_barplot``; confirm in bioinf_common.
    """
    # shift a copy so the caller's frame is not modified
    data = kwargs['data'].copy()
    data[kwargs['y']] -= baseline
    kwargs['data'] = data

    # natural sort of the x categories keeps the bar order identical across facets;
    # fall back to 'window_size' (the original hard-coded column) if x is absent
    x_col = kwargs.get('x', 'window_size')

    annotated_barplot(
        *args, **kwargs,
        order=natsorted(data[x_col].unique()),
        anno_kws=dict(label_offset=6, label_size=9),
        palette=sns.color_palette(),
        bottom=baseline)

In [None]:
g = sns.FacetGrid(
    df_signal,
    col='tad_source', col_wrap=2,
    sharex=False, sharey=True,
    height=7, aspect=2)

# my_bar draws the quotient bars relative to its default baseline of 1
g.map_dataframe(my_bar, x='window_size', y='signal', hue='type')

# label the grid's x axes; plt.xlabel would only label the last (current) facet
g.set_xlabels('Window size')
# one legend per facet for the signal measure (hue='type')
for ax in g.axes.ravel():
    ax.legend()

g.savefig(outdir / 'signal_vs_datasource.pdf')

# Filter-specific enriched disease fractions

## Detailed view

In [None]:
def detailed_view(df, fname):
    """Save a filter x tad_source grid of per-window significant-disease fractions.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``filter``, ``tad_source``, ``window_size``, ``is_cancer``
        and ``sig_count`` (plotted as the significant-disease fraction).
    fname : str
        File name of the figure, written into ``outdir``.
    """
    g = sns.FacetGrid(df, row='filter', col='tad_source', height=5, aspect=2)

    # Explicit orders: otherwise each facet infers its own category order from its
    # subset, so bar positions and hue colours may not match between facets.
    g.map_dataframe(
        sns.barplot, x='window_size', y='sig_count', hue='is_cancer', palette='tab10',
        order=natsorted(df['window_size'].unique()),
        hue_order=sorted(df['is_cancer'].unique()))

    g.set_axis_labels('Window size', 'Sig. disease fraction')
    g.add_legend(title='is_cancer')

    # show x tick labels on every facet
    for ax in g.axes.flat:
        ax.tick_params(labelbottom=True)

    g.savefig(outdir / fname)
    # close this grid's figure explicitly (the caller creates many of them in a loop)
    plt.close(g.fig)

## Aggregated view

In [None]:
def aggregated_view(df, fname):
    """Save per-tad_source bar plots of significant-disease fractions per filter.

    ``sig_count`` is averaged within each (filter, tad_source, is_cancer)
    combination, i.e. over all remaining dimensions (e.g. window_size).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``filter``, ``tad_source``, ``is_cancer`` and ``sig_count``.
    fname : str
        File name of the figure, written into ``outdir``.
    """
    # aggregate data
    df_agg = df.groupby(['filter', 'tad_source', 'is_cancer'])['sig_count'].mean().reset_index()

    g = sns.FacetGrid(df_agg, row='tad_source', height=5, aspect=2)

    # explicit orders keep filters and hue colours consistent across facets
    g.map_dataframe(
        sns.barplot, x='filter', y='sig_count', hue='is_cancer', palette='tab10',
        order=natsorted(df_agg['filter'].unique()),
        hue_order=sorted(df_agg['is_cancer'].unique()))

    g.set_axis_labels('Filter', 'Sig. disease fraction')
    g.add_legend(title='is_cancer')

    # show x tick labels on every facet
    for ax in g.axes.flat:
        ax.tick_params(labelbottom=True)

    g.savefig(outdir / fname)
    # close this grid's figure explicitly (the caller creates many of them in a loop)
    plt.close(g.fig)

## Generate plots

In [None]:
# For every TAD type: fraction of rows with p <= .05 per
# (filter, tad_source, window_size, is_cancer), once for uncorrected and once
# for corrected p-values, saved as detailed and aggregated bar plots.
sig_keys = ['filter', 'tad_source', 'window_size', 'is_cancer']

for tad_type, group in tqdm(df_enr.groupby('TAD_type')):
    # (p-value column, detailed-view file name, aggregated-view file name)
    # NOTE(review): the corrected aggregated file is called "..._aggregated_fractions_...",
    # unlike its uncorrected counterpart; kept unchanged so existing outputs stay stable.
    plot_specs = [
        ('pval_boundary__notcorrected',
         f'sig_disease_fractions_{tad_type}__notcorrected.pdf',
         f'sig_disease_fractions_aggregated_{tad_type}__notcorrected.pdf'),
        ('pval_boundary',
         f'sig_disease_fractions_{tad_type}.pdf',
         f'sig_disease_fractions_aggregated_fractions_{tad_type}.pdf'),
    ]

    for pval_col, fname_detailed, fname_aggregated in plot_specs:
        # vectorised: mean of the boolean "significant" flag per group
        df_sig = (group.assign(sig_count=group[pval_col] <= .05)
                       .groupby(sig_keys)['sig_count']
                       .mean()
                       .reset_index())

        detailed_view(df_sig, fname_detailed)
        aggregated_view(df_sig, fname_aggregated)

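## Majority vote

A disease counts as significant if its p-value is ≤ 0.05 in more than half of its entries (e.g. across window sizes) for a given TAD type, filter and TAD source.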

### Uncorrected p-values

In [None]:
# A disease is "majority significant" when more than half of its rows for a given
# (TAD_type, filter, tad_source) have an uncorrected p <= .05; sig_frac is the
# fraction of such diseases per (TAD_type, filter, tad_source, is_cancer).
majority_keys = ['TAD_type', 'filter', 'tad_source', 'is_cancer']
df_majority__notcorrected = (
    df_enr
    .assign(_is_sig=df_enr['pval_boundary__notcorrected'] <= .05)
    .groupby(majority_keys + ['diseaseId'])['_is_sig']
    .mean()
    .gt(.5)
    .groupby(level=majority_keys)
    .mean()
    .to_frame('sig_frac')
    .reset_index()
)
df_majority__notcorrected.head()

In [None]:
g = sns.FacetGrid(df_majority__notcorrected, row='TAD_type', col='tad_source', height=5, aspect=2)

# Explicit category orders so every facet lists the filters in the same order
# and uses the same colour for each is_cancer value.
g.map_dataframe(
    sns.barplot, x='filter', y='sig_frac', hue='is_cancer', palette='tab10',
    order=natsorted(df_majority__notcorrected['filter'].unique()),
    hue_order=sorted(df_majority__notcorrected['is_cancer'].unique()))

g.set_axis_labels('Filter', 'Disease fraction sig. in $>0.5$ cases')
g.add_legend(title='is_cancer')

# show x tick labels on every facet
for ax in g.axes.flat:
    ax.tick_params(labelbottom=True)

g.savefig(outdir / 'sig_disease_fractions_majority__notcorrected.pdf')

### Corrected p-values

In [None]:
# Same majority vote as for uncorrected p-values, but on corrected p-values:
# a disease counts when more than half of its rows have p <= .05.
majority_keys = ['TAD_type', 'filter', 'tad_source', 'is_cancer']
df_majority = (
    df_enr
    .assign(_is_sig=df_enr['pval_boundary'] <= .05)
    .groupby(majority_keys + ['diseaseId'])['_is_sig']
    .mean()
    .gt(.5)
    .groupby(level=majority_keys)
    .mean()
    .to_frame('sig_frac')
    .reset_index()
)
df_majority.head()

In [None]:
g = sns.FacetGrid(df_majority, row='TAD_type', col='tad_source', height=5, aspect=2)

# Explicit category orders so every facet lists the filters in the same order
# and uses the same colour for each is_cancer value.
g.map_dataframe(
    sns.barplot, x='filter', y='sig_frac', hue='is_cancer', palette='tab10',
    order=natsorted(df_majority['filter'].unique()),
    hue_order=sorted(df_majority['is_cancer'].unique()))

g.set_axis_labels('Filter', 'Disease fraction sig. in $>0.5$ cases')
g.add_legend(title='is_cancer')

# show x tick labels on every facet
for ax in g.axes.flat:
    ax.tick_params(labelbottom=True)

g.savefig(outdir / 'sig_disease_fractions_majority.pdf')

# SNP counts

## Aggregate counts

In [None]:
# Per disease and configuration: fraction of SNPs lying in TAD boundaries, for
# the TAD type / filter selected by `border_type` / `filter_type` defined above
# (instead of repeating the '20in' / 'nofilter' literals).
snp_cols = ['diseaseId', '#border_snp', '#snp', 'tad_source', 'window_size', 'is_cancer', 'filter']
is_selected = (df_enr['TAD_type'] == border_type) & (df_enr['filter'] == filter_type)

sub = df_enr.loc[is_selected, snp_cols].drop_duplicates()
sub['snp_fraction'] = sub['#border_snp'] / sub['#snp']

sub.head()

## Plot count data

In [None]:
tad_sources = natsorted(sub['tad_source'].unique())

g = sns.FacetGrid(
    sub,
    col='tad_source', col_wrap=min(2, len(tad_sources)),
    col_order=tad_sources,
    sharex=False, sharey=True,
    height=7, aspect=2)

# explicit orders keep window sizes and is_cancer colours aligned across facets
g.map_dataframe(
    sns.boxplot,
    x='window_size', y='snp_fraction', hue='is_cancer',
    order=natsorted(sub['window_size'].unique()),
    hue_order=sorted(sub['is_cancer'].unique()))

g.set_axis_labels('Window size', r'$\frac{|snp_{boundary}|}{|snp_{all}|}$ per disease')
g.add_legend(title='is_cancer')

g.savefig(outdir / 'snp_numbers.pdf')

# Variant type effect

## Prepare data

In [None]:
# One corrected boundary p-value per (TAD_type, filter, tad_source, window_size,
# diseaseId); repeated index entries are dropped, keeping the first occurrence.
index_cols = ['TAD_type', 'filter', 'tad_source', 'window_size', 'diseaseId']
df_trans = df_enr.set_index(index_cols).loc[:, ['pval_boundary']]
is_dup = df_trans.index.duplicated(keep='first')
df_trans = df_trans[~is_dup]

In [None]:
# first rows of the deduplicated p-value table
df_trans.head(n=5)

In [None]:
# Per-filter p-value tables for the selected TAD type (`border_type`, instead of
# repeating the '20in' literal). The remaining index levels are
# (tad_source, window_size, diseaseId).
df_border = df_trans.loc[border_type]

df_none = df_border.loc['nofilter']
df_exonic = df_border.loc['exonic']
df_intronic = df_border.loc['intronic']
df_intergenic = df_border.loc['intergenic']
df_nonexonic = df_border.loc['nonexonic']

In [None]:
# report how many entries each per-filter table has
filter_tables = {
    'none': df_none,
    'exonic': df_exonic,
    'intronic': df_intronic,
    'intergenic': df_intergenic,
    'nonexonic': df_nonexonic,
}
for name, df_tmp in filter_tables.items():
    print(name)
    print('#entries:', len(df_tmp))
    print()

## Aggregate data

In [None]:
# Place the per-filter p-values side by side. The DataFrame constructor aligns
# the Series on their index (union of entries; missing combinations become NaN).
source_tables = {
    'enrichment_none': df_none,
    'enrichment_exonic': df_exonic,
    'enrichment_intronic': df_intronic,
    'enrichment_intergenic': df_intergenic,
    'enrichment_nonexonic': df_nonexonic,
}
df_merged = pd.DataFrame({col: table['pval_boundary'] for col, table in source_tables.items()})
df_merged.head()

In [None]:
# replace 0 by per-group minimum
tmp = df_merged.copy()

for col in tmp.columns:
    # replace values
    idx = tmp[col] != 0
    min_ = tmp[idx].groupby(level=['tad_source', 'window_size'])[col].apply(lambda x: x.min())
    
    if not tmp.loc[~idx, col].empty:
        tmp.loc[~idx, col] = min_
    
    # sanity check
    assert (df_merged.loc[idx & (~np.isnan(tmp[col])), col] == tmp.loc[idx & (~np.isnan(tmp[col])), col]).all()
    
assert not (tmp == 0).any().any()
df_merged = tmp

In [None]:
# merged p-values after zero replacement
df_merged.head(n=5)

In [None]:
# transform to log-space
df_merged_log = df_merged.applymap(lambda x: -np.log10(x) if x > 0 else np.nan if np.isnan(x) else -1)
assert (df_merged_log != -1).all().all()

In [None]:
# add cancer labels
df_merged_log['is_cancer'] = df_merged_log.index.get_level_values('diseaseId').map(iscancer_map)

In [None]:
# -log10 p-values with cancer labels
df_merged_log.head(n=5)

## Visualize

In [None]:
def custom_scatter(x, y, data, color):
    """Scatter two -log10 p-value columns of one facet against each other.

    Intended for ``FacetGrid.map_dataframe``, which passes the facet's rows as
    ``data`` and a ``color``. The plot contains:

    - dashed red lines at -log10(0.05) on both axes,
    - a dashed grey identity line,
    - the same axis range on x and y,
    - an arrow annotation on the disease with the largest ``y`` value.

    Parameters
    ----------
    x, y : str
        Column names of ``data`` to plot on the x and y axis.
    data : pandas.DataFrame
        Facet rows; after ``reset_index()`` it must have a ``diseaseId`` column.
    color
        Marker colour forwarded to ``sns.scatterplot``.
    """
    ax = sns.scatterplot(x=x, y=y, data=data, color=color)

    ax.axhline(-np.log10(.05), color='red', ls='dashed')
    ax.axvline(-np.log10(.05), color='red', ls='dashed')
    add_identity(ax, color='grey', ls='dashed')

    # Fix axis ranges. Use this call's own x/y columns (not loop globals), and
    # skip NaNs in either column: max_ is NaN only if both are entirely missing.
    # Python's max() with a NaN argument depends on argument order.
    max_ = pd.concat([data[x], data[y]]).max()
    max_ *= 1.05

    if not np.isnan(max_):
        ax.set_xlim((-.1, max_))
        ax.set_ylim((-.1, max_))

    # annotate the disease with the strongest y-signal
    tmp = data.reset_index()
    # drop NaNs first: idxmax on an all-NaN column cannot give a valid row label
    y_values = tmp[y].dropna()

    if not y_values.empty:
        sel = tmp.loc[y_values.idxmax()]

        ax.annotate(
            sel.diseaseId,
            xy=(sel[x], sel[y]), xytext=(50, 0),
            xycoords='data', textcoords='offset points',
            fontsize=10, ha='center', va='center',
            arrowprops=dict(arrowstyle='->')
        )
    else:
        print('Warning, no disease annotation possible')

In [None]:
target_dir = outdir / 'enrichment_variants'
# parents/exist_ok: re-running the notebook must not fail on an existing directory
target_dir.mkdir(parents=True, exist_ok=True)

# All pairs of filter variants to compare. Materialising the list gives tqdm the
# correct total; the old hard-coded binom(4, 2) was wrong for 5 columns.
column_pairs = list(itertools.combinations(df_merged.columns, 2))

for x_axis_data_source, y_axis_data_source in tqdm(column_pairs):
    for idx, group in tqdm(df_merged_log.groupby(level=['tad_source', 'window_size']), leave=False):
        g = sns.FacetGrid(
            group, col='is_cancer',
            sharex=False, sharey=False,
            height=5, aspect=1)

        g.map_dataframe(custom_scatter, x=x_axis_data_source, y=y_axis_data_source)
        g.set_axis_labels(x_axis_data_source, y_axis_data_source)

        # leave room for the suptitle; act on this grid's figure, not pyplot's current one
        g.fig.subplots_adjust(top=0.8)
        g.fig.suptitle(', '.join(str(x) for x in idx))

        g.savefig(target_dir / f'enrichment_variants__{"_".join(str(x) for x in idx)}_{x_axis_data_source}_{y_axis_data_source}.pdf')
        plt.close(g.fig)