In [None]:
import os
import glob
import collections

import pandas as pd
from natsort import natsorted, index_natsorted, order_by_index

import seaborn as sns
import matplotlib.pyplot as plt
from dna_features_viewer import GraphicFeature, GraphicRecord

from tqdm.auto import tqdm

from bioinf_common.plotting import get_distinct_colors, create_custom_legend

In [None]:
# Apply seaborn's 'talk' plotting context to the figures created below.
sns.set_context('talk')

# Load data

In [None]:
# Directory containing the pipeline result tables; `read_glob` builds its glob pattern from it.
source_dir = 'MY_RUN/agg_both/pipeline_run/results'

In [None]:
def read_glob(fname_base, fname_suffix='csv', sep=','):
    """Load and concatenate all pipeline result tables matching `fname_base`.

    Files are searched in the module-level `source_dir`. The run configuration
    encoded in each file name ("key=value" pairs separated by ',') is parsed
    to annotate every row with the data set it originates from.

    Parameters
    ----------
    fname_base : str
        Leading part of the file name (everything before the first '.').
    fname_suffix : str
        File extension of the result tables.
    sep : str
        Column separator passed on to `pd.read_csv`.

    Returns
    -------
    pd.DataFrame
        Row-wise concatenation of all loaded tables with the additional
        string columns 'source', 'version' and 'window_size'.

    Raises
    ------
    ValueError
        If no table was loaded (no file matched or all matches were skipped).
    """
    pattern = (
        f'{source_dir}/{fname_base}.do_further_investigations=False,'
        f'input_files+tad_coordinates=data_newleopoldtads_*.csv.{fname_suffix}'
    )

    # glob yields files in arbitrary order; sort to get a reproducible row order
    fname_list = sorted(glob.glob(pattern))

    df_list = []
    for fname in tqdm(fname_list):
        # extract meta information (everything between first and last '.')
        config_str = '.'.join(os.path.basename(fname).split('.')[1:-1])
        # split on the first '=' only, so values may contain '=' themselves
        info = dict(e.split('=', 1) for e in config_str.split(','))
        _, _, *source, version, k = info['input_files+tad_coordinates'].split('.')[0].split('_')
        source = '_'.join(source)

        # skip, as we have only hg19 data for GM1278
        if source == 'Rao_GM1278_40k':
            continue

        # read data
        tmp = pd.read_csv(fname, sep=sep)
        tmp['source'] = source
        tmp['version'] = version
        tmp['window_size'] = k
        df_list.append(tmp)

    # pd.concat would fail with a cryptic message on an empty list
    if not df_list:
        raise ValueError(f'No result tables found for pattern "{pattern}"')

    return pd.concat(df_list)

## Disease/SNP data

In [None]:
# Load all 'final' result tables (default: comma-separated .csv) and preview them.
df_final = read_glob('final')
df_final.head()

## TAD data

In [None]:
# Load all 'tads_hg38' tables, which are tab-separated .tsv files, and preview them.
df_tads = read_glob('tads_hg38', fname_suffix='tsv', sep='\t')
df_tads.head()

# Precompute coordinates

In [None]:
def get_tad_intervals(df_tads, border_size=20_000):
    """Split every TAD into its body and its two borders.

    Parameters
    ----------
    df_tads : pd.DataFrame
        One row per TAD; must provide the columns 'tad_start' and 'tad_stop'.
    border_size : int
        Width of the border cut off at each end of a TAD.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Body frame: the input rows (old index kept as 'tad_idx') plus
        'start'/'stop' of the TAD with both borders removed.
        Border frame: two rows per TAD (left border first, then right) with
        'start', 'stop', 'border_side' and all fields of the TAD row,
        including its positional 'Index'.
    """
    # keep the original index as explicit TAD identifier
    tads = df_tads.reset_index().rename(columns={'index': 'tad_idx'})

    # body: shrink the TAD by one border width on each side
    df_tad_body = tads.assign(
        start=tads['tad_start'] + border_size,
        stop=tads['tad_stop'] - border_size,
    )

    # borders: one record per side, carrying along all TAD fields
    border_records = []
    for tad in tqdm(tads.itertuples(), total=len(tads)):
        tad_fields = tad._asdict()

        left = {
            'start': tad.tad_start,
            'stop': tad.tad_start + border_size,
            'border_side': 'left',
        }
        left.update(tad_fields)

        right = {
            'start': tad.tad_stop - border_size,
            'stop': tad.tad_stop,
            'border_side': 'right',
        }
        right.update(tad_fields)

        border_records += [left, right]

    return df_tad_body, pd.DataFrame(border_records)

In [None]:
# Derive body and border intervals for all loaded TADs (default border size).
df_tad_body, df_tad_border = get_tad_intervals(df_tads)

In [None]:
# Preview the TAD body intervals.
df_tad_body.head()

In [None]:
# Preview the TAD border intervals.
df_tad_border.head()

# Helper functions

## Misc

## Plotting

In [None]:
def generate_snp_features(sub_final):
    """Create one black, single-position `GraphicFeature` per row of `sub_final`.

    Each row must provide 'position' and 'snpId'; the SNP id becomes the label.
    """
    features = []
    for snp in sub_final.itertuples():
        feature = GraphicFeature(
            start=snp.position, end=snp.position+1,
            label=snp.snpId, color='black')
        features.append(feature)
    return features

In [None]:
def generate_tad_features(sub_body, sub_border):
    """Group TAD body (blue) and border (red) features by window size.

    Both frames must provide 'start', 'stop' and 'window_size'.
    Returns a plain dict mapping window size to its list of `GraphicFeature`s,
    bodies first, borders afterwards.
    """
    features_by_window = collections.defaultdict(list)

    # bodies are handled before borders, which fixes the order within each list
    for intervals, color in ((sub_body, 'blue'), (sub_border, 'red')):
        for interval in intervals.itertuples():
            feature = GraphicFeature(
                start=interval.start, end=interval.stop,
                color=color)
            features_by_window[interval.window_size].append(feature)

    return dict(features_by_window)

In [None]:
def plot_column(features_snps, features_tads, title, plot_region, ax_list, annotation=None):
    """Draw one figure column: SNP track on top, one TAD track per window size below.

    Parameters
    ----------
    features_snps : list
        Features drawn on the first axis.
    features_tads : dict
        Maps window size to the features of its track; tracks are drawn in
        natural sort order of the window sizes on `ax_list[1:]`.
    title : str
        Title of the first axis.
    plot_region : tuple
        (start, end) of the region the records are cropped to.
    ax_list : sequence of axes
        Needs one axis for the SNPs plus one per window size.
    annotation : str, optional
        Text placed (rotated by 90 degrees) at the top-left of the first axis.
    """
    # record length is set 1 Mb beyond the region end
    # (presumably so that cropping never exceeds the record - TODO confirm)
    record_length = plot_region[1]+1_000_000

    snp_ax = ax_list[0]
    snp_ax.set_title(title)

    # annotate
    if annotation is not None:
        snp_ax.annotate(
            annotation,
            xy=(0, 1), xytext=(0, 1),
            xycoords='axes fraction',
            rotation=90, fontsize=10)

    # SNP track
    snp_record = GraphicRecord(sequence_length=record_length, features=features_snps)
    snp_record.crop(plot_region).plot(ax=snp_ax)

    # TAD tracks, one axis per window size
    for ax_idx, window_size in enumerate(natsorted(features_tads.keys()), start=1):
        tad_ax = ax_list[ax_idx]

        tad_record = GraphicRecord(
            sequence_length=record_length, features=features_tads[window_size])
        tad_record.crop(plot_region).plot(ax=tad_ax, with_ruler=False)

        # strip everything but the window-size label from the axis
        tad_ax.axis('on')
        tad_ax.tick_params(
            axis='y', which='both',
            left=False, labelleft=False)
        tad_ax.axes.get_xaxis().set_visible(False)
        for spine in tad_ax.spines.values():
            spine.set_visible(False)
        tad_ax.set_ylabel(window_size, rotation=0, size='large')

In [None]:
def plot_ws(
    case_idx, snp_list, df_final, df_body, df_border,
    out_dir='images', plot_window=30_000, figsize=(20, 30)
):
    """Plot SNPs and their surrounding TADs for hg19 and hg38 side by side.

    The figure has one column per genome version. Each column shows the SNPs
    on top and one TAD track per window size below (see `plot_column`).
    The result is saved to `{out_dir}/coordinate_comparison_{case_idx}.pdf`.

    Parameters
    ----------
    case_idx : str
        Identifier used in the output file name.
    snp_list : list of str
        SNP ids to plot; all must occur in `df_final` and lie on one chromosome.
    df_final : pd.DataFrame
        SNP table with the columns 'source', 'snpId', 'chromosome', 'version'
        and 'position'.
    df_body, df_border : pd.DataFrame
        TAD body/border intervals with the columns 'source', 'version',
        'chrname', 'start', 'stop' and 'window_size'.
    out_dir : str
        Output directory; created if it does not exist.
    plot_window : int
        Margin added left and right of the outermost SNPs.
    figsize : tuple
        Figure size passed on to matplotlib.

    Raises
    ------
    AssertionError
        If the inputs stem from different sources, contain unknown SNPs or
        the SNPs are spread over several chromosomes.
    ValueError
        If no SNP with an 'rs' identifier is available to define the plot region.
    """
    versions = ('hg19', 'hg38')

    # check that sources match
    base_source = df_final.iloc[0]['source']
    print(f'Using source "{base_source}"')

    assert [base_source] == \
        df_final['source'].unique().tolist() == \
        df_body['source'].unique().tolist() == \
        df_border['source'].unique().tolist(), \
    'All data must be from same source'

    # restrict to requested SNPs once instead of re-filtering for every step
    df_snps = df_final[df_final['snpId'].isin(snp_list)]

    # check that all SNPs are known
    assert df_snps['snpId'].unique().size == len(set(snp_list)), 'Unknown SNPs'

    # check that SNPs all lie on same chromosome
    chrom_sub = df_snps['chromosome']
    assert chrom_sub.unique().size == 1, 'All SNPs must lie on same chromosome'

    # format instead of concatenating, so numerically parsed chromosomes work too
    cur_chrom = f'chr{chrom_sub.iloc[0]}'
    print(f'All SNPs are on chromesome "{cur_chrom}"')

    # generate SNP features
    features_snps = {
        version: generate_snp_features(
            df_snps[df_snps['version'] == version].drop_duplicates(subset='snpId'))
        for version in versions
    }
    for version in versions:
        print(f'SNPs ({version}): {len(features_snps[version])}')

    # generate TAD features
    features_tads = {
        version: generate_tad_features(
            df_body[(df_body['version'] == version) & (df_body['chrname'] == cur_chrom)],
            df_border[(df_border['version'] == version) & (df_border['chrname'] == cur_chrom)])
        for version in versions
    }
    for version in versions:
        print(f'TADs ({version}): {len(features_tads[version])}')

    # determine plot region (shared by both columns)
    all_snp_positions = [
        gf.start
        for version in versions
        for gf in features_snps[version]
        if gf.label is not None and gf.label.startswith('rs')
    ]
    if not all_snp_positions:
        raise ValueError('No SNP with "rs" identifier available to determine plot region')

    # coordinates cannot be negative, so clamp the left margin
    region_start = max(0, min(all_snp_positions) - plot_window)
    region_end = max(all_snp_positions) + plot_window

    print(f'Plot region: {region_start}-{region_end}')

    # do actual plot
    # Both columns share the axes grid, so it must fit the version with the
    # most window sizes. squeeze=False keeps `ax` 2D even without TAD tracks.
    n_tad_rows = max(len(features_tads[version]) for version in versions)
    fig, ax = plt.subplots(
        nrows=1 + n_tad_rows, ncols=2,
        gridspec_kw={'height_ratios': [5] + [1] * n_tad_rows},
        figsize=figsize, squeeze=False)

    for col, version in enumerate(versions):
        plot_column(
            features_snps[version], features_tads[version], version,
            (region_start, region_end), ax[:, col])

        # hide TAD axes this version has no data for
        for unused_ax in ax[1 + len(features_tads[version]):, col]:
            unused_ax.axis('off')

    # savefig does not create missing directories
    os.makedirs(out_dir, exist_ok=True)

    fig.tight_layout()
    fig.savefig(f'{out_dir}/coordinate_comparison_{case_idx}.pdf')
    plt.close(fig)

## Minimal test case

In [None]:
# Artificial data set: a single SNP (rs1) on chromosome 1 whose position
# differs between hg19 and hg38.
tmp_final = pd.DataFrame({
    'snpId': ['rs1', 'rs1'],
    'chromosome': ['1', '1'],
    'position': [113, 150],
    'source': ['artificial', 'artificial'],
    'version': ['hg19', 'hg38'],
    'window_size': [1, 1]
})

# One TAD per genome version; the small border size matches the tiny coordinates.
tmp_body, tmp_border = get_tad_intervals(pd.DataFrame({
    'chrname': ['chr1', 'chr1'],
    'tad_start': [100, 60],
    'tad_stop': [200, 120],
    'source': ['artificial', 'artificial'],
    'version': ['hg19', 'hg38'],
    'window_size': [1, 1]
}), border_size=20)

In [None]:
# Plot the artificial example; window and figure size are reduced to fit its tiny coordinates.
snp_list = ['rs1']
plot_ws(
    'minimal_example', snp_list, tmp_final, tmp_body, tmp_border, 
    plot_window=10, figsize=(20, 9))

# Actual cases

In [None]:
# Data source all following analyses are restricted to.
source_subset = 'dixon_ES_40k'

In [None]:
# Restrict SNP and TAD tables to the selected source (SNPs additionally to TAD type "20in").
sub_final = df_final[(df_final['TAD_type'] == '20in') & (df_final['source'] == source_subset)]
sub_tad_body = df_tad_body[df_tad_body['source'] == source_subset]
sub_tad_border = df_tad_border[df_tad_border['source'] == source_subset]

## Minimal (real) example

In [None]:
# Plot a single real SNP. A dedicated case id is used, as 'minimal_example'
# would overwrite the PDF of the artificial example above.
snp_list = ['rs3798343']
plot_ws('minimal_real_example', snp_list, sub_final, sub_tad_body, sub_tad_border)

## Find SNPs in most inhabited TAD-border of enriched diseases

In [None]:
# Keep hg38 entries whose boundary p-value is at most 0.05.
sub_enr = sub_final.query('version == "hg38" and pval_boundary <= .05')

In [None]:
# For every disease, find the TAD border containing most of its SNPs.

# chromosome names without 'chr' prefix; loop-invariant, so computed only once
border_chroms = sub_tad_border['chrname'].str[3:]

tmp_list = []
for disease, group in tqdm(sub_enr.groupby('diseaseId')):
    sub_group = group.drop_duplicates(subset='snpId')

    # count border memberships
    border_counts = collections.defaultdict(set)
    for row in tqdm(sub_group.itertuples(), total=sub_group.shape[0], leave=False):
        # NOTE(review): `sub_enr` only holds hg38 rows while `sub_tad_border`
        # is not filtered by version, so hg38 positions are also matched
        # against hg19 borders - confirm that this is intended.
        borders = sub_tad_border[
            (border_chroms == row.chromosome) &
            (sub_tad_border['start'] <= row.position) &
            (sub_tad_border['stop'] >= row.position)
        ]
        ser_border_idx = borders['Index'].map(str) + '_' + borders['border_side'] + '__' + borders['version'] + '_' + borders['window_size'].map(str)

        for b in ser_border_idx.tolist():
            border_counts[b].add(row.snpId)

    # sanity checks
    # TODO

    # no SNP of this disease lies in any border; max() would fail on empty input
    if not border_counts:
        continue

    # find best border (ties are resolved in favor of the border seen first)
    bb_idx = max(border_counts, key=lambda b: len(border_counts[b]))
    tmp_list.append({
        'diseaseId': disease,
        'border_idx': bb_idx,
        # sets have no stable order; sort for reproducible output
        'snp_list': ';'.join(sorted(border_counts[bb_idx]))
    })

# explicit columns keep the frame usable even if no disease was selected
df_select = pd.DataFrame(tmp_list, columns=['diseaseId', 'border_idx', 'snp_list'])
df_select.head()

## Basic statistics

In [None]:
# Distribution of the number of SNPs in the selected border of each disease.
snps_per_border = df_select['snp_list'].str.split(';').str.len()

# `sns.distplot` is deprecated since seaborn 0.11; prefer its replacement
# `sns.histplot` and only fall back for older seaborn versions
plot_hist = getattr(sns, 'histplot', None)
if plot_hist is not None:
    ax = plot_hist(snps_per_border, kde=False)
else:
    ax = sns.distplot(snps_per_border, kde=False)

ax.set(xlabel='SNPs in selected border', ylabel='Number of diseases');

## Plot results

In [None]:
# Create one comparison plot per disease, named after the disease id.
for selected in df_select.itertuples():
    print(selected)
    selected_snps = selected.snp_list.split(';')
    plot_ws(selected.diseaseId, selected_snps, sub_final, sub_tad_body, sub_tad_border)
    print()