In [None]:
import os
import glob
import collections

import pandas as pd
from natsort import natsorted, index_natsorted, order_by_index

import seaborn as sns
import matplotlib.pyplot as plt
from dna_features_viewer import GraphicFeature, GraphicRecord

from tqdm.auto import tqdm


In [None]:
# Switch seaborn to its 'talk' plotting context for the figures in this notebook.
sns.set_context('talk')

# Load data

In [None]:
# Directory with the aggregated pipeline results.
root = 'aggregated_results/pipeline_run/results/'
# TAD data set; '{version}' is replaced by the genome assembly further down.
source = 'Rao_IMR90_40k_{version}_7'  # 'dixon_ES_40k_hg19_13'

# Run-configuration part shared by all result file names. It still carries the
# '{version}' placeholder inherited from `source`.
_run_config = f'__do_further_investigations:False;input_files+tad_coordinates:data_newleopoldtads_{source}.csv;git_branch:master'

# TAD coordinate tables, one per assembly.
# NOTE(review): the 'tads_hg38' prefix is used for both assemblies -- confirm
# that this matches the pipeline's file naming.
fname_tads_hg19 = f'tads_hg38{_run_config}.tsv'.format(version='hg19')
fname_tads_hg38 = f'tads_hg38{_run_config}.tsv'.format(version='hg38')

# SNP tables, one per assembly.
fname_snps_hg19 = f'final{_run_config}.csv'.format(version='hg19')
fname_snps_hg38 = f'final{_run_config}.csv'.format(version='hg38')

In [None]:
# TAD tables are tab-separated.
df_tads_hg19, df_tads_hg38 = (
    pd.read_csv(os.path.join(root, tads_file), sep='\t')
    for tads_file in (fname_tads_hg19, fname_tads_hg38)
)

# SNP tables use pandas' default comma separator.
df_snps_hg19, df_snps_hg38 = (
    pd.read_csv(os.path.join(root, snps_file))
    for snps_file in (fname_snps_hg19, fname_snps_hg38)
)

## Overview

In [None]:
# Side-by-side preview of both TAD tables; axis=1 lines rows up by their index label.
pd.concat([df_tads_hg19, df_tads_hg38], axis=1).head()

# Compute SNP-TAD intersections

In [None]:
def get_tad_intervals(df_tads, border_size=20_000):
    """Split every TAD into a body interval and two border intervals.

    Parameters
    ----------
    df_tads : pandas.DataFrame
        One row per TAD; the columns `tad_start` and `tad_stop` are read.
        The frame itself is not modified.
    border_size : int
        Width of each border segment, in the same units as the coordinates.

    Returns
    -------
    (df_tad_body, df_tad_border)
        `df_tad_body` has one row per TAD with `start`/`stop` shrunk by
        `border_size` on each side. `df_tad_border` has two rows per TAD
        (`border_side` 'left' then 'right'), each also carrying every field
        of the originating row as reported by `itertuples()`.
    """
    # Expose the former index as the `tad_idx` column.
    tads = df_tads.reset_index().rename(columns={'index': 'tad_idx'})

    # Body: the TAD without its two border segments.
    df_tad_body = tads.assign(
        start=tads['tad_start'] + border_size,
        stop=tads['tad_stop'] - border_size,
    )

    # Borders: one record per side and TAD.
    border_records = []
    for tad in tqdm(tads.itertuples(), total=len(tads)):
        tad_fields = tad._asdict()

        left = {
            'start': tad.tad_start,
            'stop': tad.tad_start + border_size,
            'border_side': 'left',
        }
        right = {
            'start': tad.tad_stop - border_size,
            'stop': tad.tad_stop,
            'border_side': 'right',
        }
        # Row fields are merged last, so a same-named row field takes precedence.
        left.update(tad_fields)
        right.update(tad_fields)

        border_records.append(left)
        border_records.append(right)
    df_tad_border = pd.DataFrame(border_records)

    return df_tad_body, df_tad_border

In [None]:
df_tad_body_hg38, df_tad_border_hg38 = get_tad_intervals(df_tads_hg38)

# Split the SNP table by chromosome once, instead of re-filtering the whole
# table for every border row. sort=False avoids ordering the chromosome labels.
snps_by_chrom = {
    chrom: chrom_snps
    for chrom, chrom_snps in df_snps_hg38.groupby('chromosome', sort=False)
}

result = []
for row in tqdm(df_tad_border_hg38.itertuples(), total=df_tad_border_hg38.shape[0]):
    # TAD chromosome names carry a 3-character prefix ('chr...') that the SNP table lacks
    chrom = row.chrname[3:]
    assert chrom in snps_by_chrom, f'no SNPs found for chromosome {chrom!r}'
    sub = snps_by_chrom[chrom]

    # `between` is inclusive on both ends
    matches = sub[sub['position'].between(row.start, row.stop)]

    # report every SNP at most once per border segment
    for res_row in matches.drop_duplicates(subset='snpId').itertuples():
        result.append({
            'snpId': res_row.snpId,
            'chromosome': res_row.chromosome,
            'position': res_row.position,
            'tad_idx': row.tad_idx,
            'border_side': row.border_side
        })

# Explicit columns keep the frame usable downstream even if nothing matched.
df_res = pd.DataFrame(
    result,
    columns=['snpId', 'chromosome', 'position', 'tad_idx', 'border_side'])

In [None]:
# Preview the first SNP/border matches.
df_res.head()

# Select case

In [None]:
# Number of matched SNPs per border segment, largest first.
snp_counts = (
    df_res
    .groupby(['tad_idx', 'border_side'])['snpId']
    .count()
    .sort_values(ascending=False)
    .reset_index(name='count')
)
snp_counts.head()

In [None]:
case_num = 5

# Take the first `case_num` rows of `snp_counts` (the border segments with the
# most SNPs) and collect the SNPs of each, tagged with a running case index.
case_list = []
for i, row in enumerate(snp_counts.head(case_num).itertuples()):
    in_case = df_res['tad_idx'].eq(row.tad_idx) & df_res['border_side'].eq(row.border_side)
    case_list.append(df_res[in_case].assign(case_idx=i))
df_cases = pd.concat(case_list)

In [None]:
# Preview the SNPs of the selected cases.
df_cases.head()

# Plot comparison

## Generate features

### TAD bodies and borders

In [None]:
def generate_tad_features(df_body, df_border):
    """Build one `GraphicFeature` per TAD body row and per TAD border row.

    Both frames must provide the columns `start`, `stop` and `tad_idx`.
    Body features come first (blue, label '<tad_idx> (body)'), followed by
    border features (red, label '<tad_idx> (border)').
    """
    sections = (
        (df_body, 'body', 'blue'),
        (df_border, 'border', 'red'),
    )
    return [
        GraphicFeature(
            start=interval.start, end=interval.stop,
            label=f'{interval.tad_idx} ({section})', color=color)
        for frame, section, color in sections
        for interval in frame.itertuples()
    ]

### SNPs

#### hg38

In [None]:
def generate_snp_features_hg38(df_snp_list):
    """Create the plot features for a set of SNPs in hg38 coordinates.

    The result holds one black, 1-wide feature per row of `df_snp_list`
    (columns `snpId`, `position`), followed by the body and border features
    of all hg38 TADs on the SNPs' chromosome. All SNPs must share a single
    chromosome, otherwise the assertion fails.

    Reads the notebook-level frames `df_tad_body_hg38` and `df_tad_border_hg38`.
    """
    # SNPs
    snp_features = [
        GraphicFeature(
            start=snp.position, end=snp.position+1,
            label=snp.snpId, color='black')
        for snp in df_snp_list.itertuples()
    ]

    # TADs on the (single) chromosome of the SNPs
    assert df_snp_list['chromosome'].unique().size == 1
    chrname = f"chr{df_snp_list['chromosome'].iloc[0]}"

    body_on_chrom = df_tad_body_hg38[df_tad_body_hg38['chrname'] == chrname]
    border_on_chrom = df_tad_border_hg38[df_tad_border_hg38['chrname'] == chrname]

    return snp_features + generate_tad_features(body_on_chrom, border_on_chrom)

#### hg19

In [None]:
def get_snp_position_hg19(snpId):
    """Return the `position` stored for `snpId` in the hg19 SNP table.

    Reads the notebook-level frame `df_snps_hg19`. If the SNP occurs on
    several rows, the first one is used; if it does not occur at all, the
    assertion fails.
    """
    hits = df_snps_hg19.loc[df_snps_hg19['snpId'] == snpId]
    first_hit = hits.drop_duplicates(subset='snpId')
    assert len(first_hit) == 1
    return first_hit['position'].iloc[0]

In [None]:
def generate_snp_features_hg19(df_snp_list):
    """Create the plot features for a set of SNPs in hg19 coordinates.

    Each SNP of `df_snp_list` is placed at the position returned by
    `get_snp_position_hg19` (black, 1 wide). The SNP features are followed by
    the body and border features of all hg19 TADs on the SNPs' chromosome.
    All SNPs must share a single chromosome, otherwise the assertion fails.

    The hg19 TAD intervals are recomputed from `df_tads_hg19` on every call.
    """
    # SNPs, placed at their hg19 position
    snp_features = []
    for snp in df_snp_list.itertuples():
        hg19_start = get_snp_position_hg19(snp.snpId)
        snp_features.append(GraphicFeature(
            start=hg19_start, end=hg19_start+1,
            label=snp.snpId, color='black'))

    # TADs on the (single) chromosome of the SNPs
    assert df_snp_list['chromosome'].unique().size == 1
    chrname = f"chr{df_snp_list['chromosome'].iloc[0]}"

    df_tad_body_hg19, df_tad_border_hg19 = get_tad_intervals(df_tads_hg19)
    body_on_chrom = df_tad_body_hg19[df_tad_body_hg19['chrname'] == chrname]
    border_on_chrom = df_tad_border_hg19[df_tad_border_hg19['chrname'] == chrname]

    return snp_features + generate_tad_features(body_on_chrom, border_on_chrom)

## General plotting

In [None]:
def plot_region(features, ax, region_start, region_end):
    """Draw `features` on `ax`, restricted to [region_start, region_end].

    A `GraphicRecord` extending 1,000,000 positions beyond `region_end` is
    built from the features, cropped to the requested region and plotted.
    """
    full_record = GraphicRecord(
        sequence_length=region_end + 1_000_000,
        features=features)
    full_record.crop((region_start, region_end)).plot(ax=ax)

In [None]:
def plot_snp_selection(df_snp_list, name):
    """Plot a set of SNPs with the surrounding TADs in hg38 (top) and hg19 (bottom).

    Parameters
    ----------
    df_snp_list : pandas.DataFrame
        SNPs to show, as accepted by `generate_snp_features_hg38` and
        `generate_snp_features_hg19`.
    name : str
        Inserted into the output file name
        `images/coordinate_comparison_{name}.pdf`.

    Raises
    ------
    ValueError
        If no feature with a label starting with 'rs' is available to
        position the plotting window.
    """
    # generate features
    features_hg38 = generate_snp_features_hg38(df_snp_list)
    features_hg19 = generate_snp_features_hg19(df_snp_list)

    # determine window: both panels share one region spanning the SNP positions
    # of both assemblies; SNP features are recognised by their 'rs' label
    all_snp_positions = [gf.start for gf in (features_hg19+features_hg38) if gf.label.startswith('rs')]
    if not all_snp_positions:
        raise ValueError(
            f"no features with an 'rs' label to position the window for {name!r}")

    window = 30_000
    region_start = min(all_snp_positions) - window
    region_end = max(all_snp_positions) + window

    # plot
    plt.figure(figsize=(12, 12))

    for subplot_spec, title, features in (
        (211, 'hg38', features_hg38),
        (212, 'hg19', features_hg19),
    ):
        ax = plt.subplot(subplot_spec)
        plot_region(features, ax, region_start, region_end)
        plt.title(title)

    plt.tight_layout()

    # create the output directory on demand so saving works on a fresh checkout
    out_path = f'images/coordinate_comparison_{name}.pdf'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)

## Plot

In [None]:
# One hg38-vs-hg19 comparison figure per selected case.
for case_idx, group in df_cases.groupby('case_idx'):
    plot_snp_selection(group, f'case{case_idx}')

# Special requests

In [None]:
# Keep only the first row for this SNP (presumably the table can list a SNP
# on several rows -- verify); the slice keeps a one-row DataFrame.
sub = df_snps_hg38[df_snps_hg38['snpId'] == 'rs13218875'].iloc[:1]

In [None]:
# Plot the requested SNP in both assemblies.
plot_snp_selection(sub, 'rs13218875')

# SNP shifts per disease

## Data preparation

In [None]:
# Label SNPs without a TAD relation as 'undef'.
# The result is assigned back instead of calling `fillna(..., inplace=True)` on
# the selected column: that chained in-place form is not guaranteed to write
# through to the parent frame (it has no effect under pandas Copy-on-Write).
df_snps_hg19['TAD_relation'] = df_snps_hg19['TAD_relation'].fillna('undef')
df_snps_hg38['TAD_relation'] = df_snps_hg38['TAD_relation'].fillna('undef')

## Global overview

In [None]:
# hg19: number of SNP rows per TAD relation, split by `is_cancer`.
df_snps_hg19.groupby('is_cancer')['TAD_relation'].value_counts()

In [None]:
# hg38: number of SNP rows per TAD relation, split by `is_cancer`.
df_snps_hg38.groupby('is_cancer')['TAD_relation'].value_counts()

## Disease specific

In [None]:
# hg19: the same counts, additionally split by disease.
df_snps_hg19.groupby(['is_cancer', 'diseaseId'])['TAD_relation'].value_counts().head()

In [None]:
# hg38: the same counts, additionally split by disease.
df_snps_hg38.groupby(['is_cancer', 'diseaseId'])['TAD_relation'].value_counts().head()

In [None]:
# Per-disease difference in SNP counts (hg19 - hg38) for every TAD relation.
# `sub(..., fill_value=0)` treats a (disease, TAD relation) combination that
# occurs in only one assembly as a count of 0 in the other. A plain `-` would
# yield NaN for those rows and silently drop them from the comparison.
df_shift = (
    df_snps_hg19.groupby(['is_cancer', 'diseaseId'])['TAD_relation'].value_counts()
    .sub(
        df_snps_hg38.groupby(['is_cancer', 'diseaseId'])['TAD_relation'].value_counts(),
        fill_value=0,
    )
    .to_frame('shift')
    .reset_index()
)
df_shift.head()

In [None]:
# One panel per (is_cancer, TAD_relation) combination, each showing the
# distribution of the per-disease shifts.
g = sns.FacetGrid(df_shift, row='is_cancer', col='TAD_relation', height=3, aspect=2)

g.map_dataframe(sns.boxplot, x='shift')

# symlog keeps zero and negative shifts plottable while compressing large values
g.set(xscale='symlog')
g.set_xlabels('Difference in SNP number')

g.fig.suptitle('Per-disease SNP counts ($hg19 - hg38$)', size=16)
# leave room for the figure title above the panels
g.fig.subplots_adjust(top=.8)

# create the output directory on demand so saving works on a fresh checkout
os.makedirs('images', exist_ok=True)
g.savefig('images/snp_shifts.pdf')

# Plots for all window sizes

## Read TADs for all window sizes

In [None]:
# NOTE: this cell re-binds `root`, `source`, `fname_tads_hg19` and
# `fname_tads_hg38` from the "Load data" section; from here on the two file
# names are glob patterns.
root = 'aggregated_results/pipeline_run/results/'
# '*' stands for the trailing token of the data-set name, which varies per file.
source = 'Rao_IMR90_40k_{version}_*'

# File-name pattern; '{version}' (inherited from `source`) is filled in below.
_tads_pattern = f'tads_hg38__do_further_investigations:False;input_files+tad_coordinates:data_newleopoldtads_{source}.csv;git_branch:master.tsv'

fname_tads_hg19 = _tads_pattern.format(version='hg19')
fname_tads_hg38 = _tads_pattern.format(version='hg38')

In [None]:
df_list = []

# Both assemblies are read the same way, so handle them in a single loop.
for version, pattern in (('hg19', fname_tads_hg19), ('hg38', fname_tads_hg38)):
    # glob returns its matches in arbitrary order; sort for a reproducible row order
    for fname in sorted(glob.glob(os.path.join(root, pattern))):
        tmp = pd.read_csv(fname, sep='\t')
        tmp['version'] = version
        # window size = last '_'-separated token of the second ';'-separated
        # field of the file name (everything before the first '.')
        tmp['window_size'] = os.path.basename(fname).split('.')[0].split(';')[1].split('_')[-1]
        df_list.append(tmp)

if not df_list:
    raise FileNotFoundError(
        f'no TAD files found in {root!r} matching {fname_tads_hg19!r} or {fname_tads_hg38!r}')

# ignore_index gives every row a unique label; otherwise each file would
# contribute its own 0..n-1 labels and the `tad_idx` derived from the index by
# get_tad_intervals would repeat across files.
df_tadcoords = pd.concat(df_list, ignore_index=True)

In [None]:
# Preview the combined TAD table (all assemblies and window sizes).
df_tadcoords.head()

In [None]:
# Body/border intervals for all assemblies and window sizes at once,
# using the default border size of get_tad_intervals.
df_tadcoords_body, df_tadcoords_border = get_tad_intervals(df_tadcoords)

In [None]:
# Show both previews from a single cell.
display(df_tadcoords_body.head())
display(df_tadcoords_border.head())

## Plot

In [None]:
# Assembly used for the plot below; any other value makes the SNP feature cell raise RuntimeError.
coord_version = 'hg38'  # hg19, hg38

In [None]:
def get_snp_position_hg19(snpId):
    """Return the `position` stored for `snpId` in the hg19 SNP table.

    NOTE: this re-declares the function of the same name defined in the
    "Generate features" section of this notebook, with the same behaviour.

    Reads the notebook-level frame `df_snps_hg19`. If the SNP occurs on
    several rows, the first one is used; if it does not occur at all, the
    assertion fails.
    """
    hits = df_snps_hg19.loc[df_snps_hg19['snpId'] == snpId]
    first_hit = hits.drop_duplicates(subset='snpId')
    assert len(first_hit) == 1
    return first_hit['position'].iloc[0]

### Determine data

In [None]:
# Use the first group yielded by the groupby (the smallest case_idx) as the example case.
case_idx, df_snp_list = list(df_cases.groupby('case_idx'))[0]

In [None]:
# All SNPs of the case must lie on one chromosome.
assert df_snp_list['chromosome'].unique().size == 1
# f-string, as in the feature generators above: also works if the chromosome
# label is not a str, where `'chr' + value` would raise TypeError.
chrom = f"chr{df_snp_list['chromosome'].iloc[0]}"

### SNPs

In [None]:
# Filled by the next cell, which appends to this list: re-run this cell first
# when re-running that one, otherwise the features are added a second time.
features_snps = []

In [None]:
# One black, 1-wide feature per SNP, positioned in the selected assembly.
for snp in df_snp_list.itertuples():
    if coord_version == 'hg38':
        # the case table already holds hg38 positions
        pos = snp.position
    elif coord_version == 'hg19':
        pos = get_snp_position_hg19(snp.snpId)
    else:
        raise RuntimeError(coord_version)

    snp_feature = GraphicFeature(
        start=pos, end=pos+1,
        label=snp.snpId, color='black')
    features_snps.append(snp_feature)

### TADs

In [None]:
# window_size -> list of TAD features; the two cells below append to it.
features_tads = collections.defaultdict(list)

In [None]:
# body
sub = df_tadcoords_body[
    (df_tadcoords_body['version'] == coord_version) &
    (df_tadcoords_body['chrname'] == chrom)
]
for row in sub.itertuples():
    features_tads[row.window_size].append(GraphicFeature(
        start=row.start, end=row.stop,
        color='blue'))  # label=f'{row.tad_idx}', 

In [None]:
# border
sub = df_tadcoords_border[
    (df_tadcoords_border['version'] == coord_version) &
    (df_tadcoords_border['chrname'] == chrom)
]
for row in sub.itertuples():
    features_tads[row.window_size].append(GraphicFeature(
        start=row.start, end=row.stop,
        color='red'))  # label=f'{row.tad_idx}', 

### Plotting

In [None]:
# Padding added on both sides of the outermost SNPs.
window = 30_000

# Start positions of all features whose label marks them as SNPs ('rs...').
all_snp_positions = [
    gf.start
    for gf in features_snps
    if gf.label is not None and gf.label.startswith('rs')
]

region_start, region_end = (
    min(all_snp_positions) - window,
    max(all_snp_positions) + window,
)

In [None]:
# One row for the SNPs plus one row per window size.
# squeeze=False always returns a 2-D axes array, so indexing by row also works
# when there are no TAD tracks (a single row would otherwise come back as a
# bare Axes object and `ax[0]` would fail).
fig, axes = plt.subplots(
    nrows=1 + len(features_tads), ncols=1, figsize=(20, 30), squeeze=False)
ax = axes[:, 0]

plt.suptitle(coord_version)

# top row: the SNPs (with ruler)
record = GraphicRecord(sequence_length=region_end+1_000_000, features=features_snps)
record_zoom = record.crop((region_start, region_end))
record_zoom.plot(ax=ax[0])

# plot TADs for each window size in own axis, in natural order of the sizes
for i, window_size in enumerate(natsorted(features_tads.keys())):
    track_ax = ax[i+1]
    sub_features = features_tads[window_size]

    record = GraphicRecord(sequence_length=region_end+1_000_000, features=sub_features)
    record_zoom = record.crop((region_start, region_end))
    record_zoom.plot(ax=track_ax, with_ruler=False)

    # turn the axis back on so the y-label is drawn, then hide everything
    # except that label: ticks, x-axis and frame
    track_ax.axis('on')
    track_ax.tick_params(
        axis='y', which='both',
        left=False, labelleft=False)
    track_ax.get_xaxis().set_visible(False)
    for spine in track_ax.spines.values():
        spine.set_visible(False)
    track_ax.set_ylabel(window_size, rotation=0, size='large')

plt.subplots_adjust(hspace=.5, top=0.95)

# create the output directory on demand so saving works on a fresh checkout
os.makedirs('images', exist_ok=True)
plt.savefig(f'images/coordinate_comparison_all_windowsizes_{coord_version}.pdf')