In [None]:
import os

import pandas as pd

import matplotlib.pyplot as plt
from dna_features_viewer import GraphicFeature, GraphicRecord

from tqdm import tqdm_notebook as tqdm

# Load data

In [None]:
root = 'aggregated_results/pipeline_run/results/'
source = 'Rao_IMR90_40k_{version}_7'  # 'dixon_ES_40k_hg19_13'

# `source` still contains a literal '{version}' placeholder after the f-string
# interpolation below; it is filled in per assembly via str.format().
tads_fname_template = f'tads_hg38__do_further_investigations:False;input_files+tad_coordinates:data_newleopoldtads_{source}.csv;git_branch:master.tsv'
snps_fname_template = f'final__do_further_investigations:False;input_files+tad_coordinates:data_newleopoldtads_{source}.csv;git_branch:master.csv'

fname_tads_hg19, fname_tads_hg38 = (
    tads_fname_template.format(version=assembly) for assembly in ('hg19', 'hg38')
)
fname_snps_hg19, fname_snps_hg38 = (
    snps_fname_template.format(version=assembly) for assembly in ('hg19', 'hg38')
)

In [None]:
# TAD tables are tab-separated, SNP tables comma-separated; all live under `root`.
df_tads_hg19, df_tads_hg38 = (
    pd.read_table(os.path.join(root, tads_fname))
    for tads_fname in (fname_tads_hg19, fname_tads_hg38)
)

df_snps_hg19, df_snps_hg38 = (
    pd.read_csv(os.path.join(root, snps_fname))
    for snps_fname in (fname_snps_hg19, fname_snps_hg38)
)

## Overview

In [None]:
# Side-by-side preview of the hg19 and hg38 TAD tables (axis=1 aligns rows by index label).
# NOTE(review): presumably row i describes the same TAD in both tables -- confirm.
pd.concat([df_tads_hg19, df_tads_hg38], axis=1).head()

# Compute SNP-TAD intersections

In [None]:
def get_tad_intervals(df_tads, border_size=20_000):
    """Split every TAD into a body interval and a left/right border interval.

    Parameters
    ----------
    df_tads : pandas.DataFrame
        One row per TAD with the columns 'chrname', 'tad_start' and 'tad_stop'.
        The frame is not modified.
    border_size : int
        Width of each border, measured inwards from 'tad_start' / 'tad_stop'.

    Returns
    -------
    (df_tad_body, df_tad_border) : tuple of pandas.DataFrame
        df_tad_body has one row per TAD with the columns 'tad_idx', 'chrname',
        'start' and 'stop', covering [tad_start + border_size, tad_stop - border_size].
        df_tad_border has two rows per TAD (left border first, then right border)
        with the same columns plus 'border_side' ('left' or 'right').
        'tad_idx' is the index label of the TAD in `df_tads`. Both frames always
        carry their columns, even when `df_tads` is empty.

    Notes
    -----
    TADs shorter than 2 * border_size are not filtered: their borders overlap and
    their body ends up with start > stop.
    """
    # Read the index labels directly so that a named index works as well
    tad_idx = df_tads.index.to_numpy()
    chrname = df_tads['chrname'].to_numpy()
    tad_start = df_tads['tad_start'].to_numpy()
    tad_stop = df_tads['tad_stop'].to_numpy()

    # define TAD sections
    df_tad_body = pd.DataFrame({
        'tad_idx': tad_idx,
        'chrname': chrname,
        'start': tad_start + border_size,
        'stop': tad_stop - border_size
    })

    left_borders = pd.DataFrame({
        'tad_idx': tad_idx,
        'chrname': chrname,
        'start': tad_start,
        'stop': tad_start + border_size,
        'border_side': 'left'
    })
    right_borders = pd.DataFrame({
        'tad_idx': tad_idx,
        'chrname': chrname,
        'start': tad_stop - border_size,
        'stop': tad_stop,
        'border_side': 'right'
    })

    # Both frames share the positional index 0..n-1; a stable sort on it
    # interleaves them as left_0, right_0, left_1, right_1, ...
    df_tad_border = (pd.concat([left_borders, right_borders])
                       .sort_index(kind='mergesort')
                       .reset_index(drop=True))

    return df_tad_body, df_tad_border

In [None]:
df_tad_body_hg38, df_tad_border_hg38 = get_tad_intervals(df_tads_hg38)

# Split the SNP table by chromosome once instead of re-filtering the whole table
# for every border. Row order within each chromosome is kept as in df_snps_hg38.
snps_by_chrom_hg38 = {
    chrom: chrom_snps
    for chrom, chrom_snps in df_snps_hg38.groupby('chromosome', sort=False)
}

result = []
for row in tqdm(df_tad_border_hg38.itertuples(), total=df_tad_border_hg38.shape[0]):
    # TAD chromosome names carry a 3-character prefix ('chr') that the SNP table lacks
    chrom = row.chrname[3:]
    sub = snps_by_chrom_hg38.get(chrom)
    # sanity check: every TAD chromosome is expected to have SNPs
    assert sub is not None and sub.shape[0] > 0, f'no SNPs found on chromosome {chrom!r}'

    # `between` is inclusive on both ends of the border interval
    matches = sub[sub['position'].between(row.start, row.stop)]

    # report each SNP at most once per border
    for res_row in matches.drop_duplicates(subset='snpId').itertuples():
        result.append({
            'snpId': res_row.snpId,
            'chromosome': res_row.chromosome,
            'position': res_row.position,
            'tad_idx': row.tad_idx,
            'border_side': row.border_side
        })

# Explicit columns keep the frame usable downstream even when no SNP hit any border
df_res = pd.DataFrame(
    result, columns=['snpId', 'chromosome', 'position', 'tad_idx', 'border_side'])

In [None]:
# Preview: one row per (SNP, TAD border) hit
df_res.head()

# Select case

In [None]:
# Number of SNP hits per (TAD, border side), largest count first
hits_per_border = df_res.groupby(['tad_idx', 'border_side'])['snpId'].count()
snp_counts = (hits_per_border.rename('count')
                             .sort_values(ascending=False)
                             .reset_index())
snp_counts.head()

In [None]:
case_num = 5

# For each of the `case_num` borders with the most SNP hits, collect its hits and
# tag them with the border's rank (0 = most hits) as `case_idx`.
case_list = [
    df_res[(df_res['tad_idx'] == top_border.tad_idx)
           & (df_res['border_side'] == top_border.border_side)].assign(case_idx=rank)
    for rank, top_border in enumerate(snp_counts.head(case_num).itertuples())
]
df_cases = pd.concat(case_list)

In [None]:
# Preview: SNP hits of the selected cases, tagged with `case_idx`
df_cases.head()

# Plot comparison

## Generate features

### TAD bodies and borders

In [None]:
def generate_tad_features(df_body, df_border):
    """Turn TAD body and border intervals into plot features.

    Both frames need the columns 'tad_idx', 'start' and 'stop'. Bodies are
    drawn in blue and borders in red; the returned list holds all body
    features first, followed by all border features.
    """
    body_features = [
        GraphicFeature(start=body.start, end=body.stop,
                       label=f'{body.tad_idx} (body)', color='blue')
        for body in df_body.itertuples()
    ]
    border_features = [
        GraphicFeature(start=border.start, end=border.stop,
                       label=f'{border.tad_idx} (border)', color='red')
        for border in df_border.itertuples()
    ]
    return body_features + border_features

### SNPs

#### hg38

In [None]:
def generate_snp_features_hg38(df_snp_list):
    """Build the hg38 plot features for one selection of SNPs.

    `df_snp_list` needs the columns 'snpId', 'chromosome' and 'position', and
    all of its rows must lie on the same chromosome. Returns the SNP markers
    (black, one base wide) followed by the TAD body/border features of that
    chromosome, taken from the module-level hg38 interval tables.
    """
    # SNPs
    snp_markers = [
        GraphicFeature(start=snp.position, end=snp.position+1,
                       label=snp.snpId, color='black')
        for snp in df_snp_list.itertuples()
    ]

    # TADs: restricted to the single chromosome the SNPs sit on
    assert df_snp_list['chromosome'].unique().size == 1
    chrom_name = 'chr{}'.format(df_snp_list['chromosome'].iloc[0])

    bodies_on_chrom = df_tad_body_hg38[df_tad_body_hg38['chrname'] == chrom_name]
    borders_on_chrom = df_tad_border_hg38[df_tad_border_hg38['chrname'] == chrom_name]

    return snp_markers + generate_tad_features(bodies_on_chrom, borders_on_chrom)

#### hg19

In [None]:
def get_snp_position_hg19(snpId):
    """Look up the position of `snpId` in the hg19 SNP table.

    Used to translate an SNP found in the hg38 table to hg19. Repeated rows
    for the SNP are collapsed to the first one; the assertion fails when the
    SNP is absent from the hg19 table.
    """
    is_requested = df_snps_hg19['snpId'] == snpId
    hits = df_snps_hg19.loc[is_requested].drop_duplicates(subset='snpId')
    assert len(hits) == 1
    return hits['position'].iloc[0]

In [None]:
def generate_snp_features_hg19(df_snp_list):
    """Build the hg19 plot features for one selection of SNPs.

    `df_snp_list` needs the columns 'snpId' and 'chromosome', and all of its
    rows must lie on the same chromosome. SNP positions are looked up in the
    hg19 SNP table by ID. Returns the SNP markers (black, one base wide)
    followed by the hg19 TAD body/border features of that chromosome.
    """
    # SNPs: resolve every ID to its hg19 coordinate first
    snp_ids = list(df_snp_list['snpId'])
    hg19_positions = [get_snp_position_hg19(snp_id) for snp_id in snp_ids]
    snp_markers = [
        GraphicFeature(start=pos, end=pos+1, label=snp_id, color='black')
        for snp_id, pos in zip(snp_ids, hg19_positions)
    ]

    # TADs: restricted to the single chromosome the SNPs sit on
    assert df_snp_list['chromosome'].unique().size == 1
    chrom_name = 'chr{}'.format(df_snp_list['chromosome'].iloc[0])

    # the hg19 intervals are derived from the hg19 TAD table on every call
    bodies, borders = get_tad_intervals(df_tads_hg19)
    tad_features = generate_tad_features(
        bodies[bodies['chrname'] == chrom_name],
        borders[borders['chrname'] == chrom_name],
    )

    return snp_markers + tad_features

## General plotting

In [None]:
def plot_region(features, ax, region_start, region_end):
    """Draw `features` on `ax`, cropped to the window [region_start, region_end].

    The underlying record is created 1,000,000 positions longer than
    `region_end` before it is cropped to the requested window.
    """
    crop_window = (region_start, region_end)
    full_record = GraphicRecord(features=features, sequence_length=region_end+1_000_000)
    full_record.crop(crop_window).plot(ax=ax)

In [None]:
def plot_snp_selection(df_snp_list, name):
    """Plot one SNP selection against the hg38 and hg19 TAD layouts and save it.

    Parameters
    ----------
    df_snp_list : pandas.DataFrame
        SNPs of a single chromosome (columns 'snpId', 'chromosome', 'position').
    name : str
        Suffix of the output file 'images/coordinate_comparison_{name}.pdf'.

    Raises
    ------
    ValueError
        If no SNP feature (label starting with 'rs') is available to centre
        the plot window on.

    The same coordinate window is used for both panels: it spans all hg19 and
    hg38 SNP positions plus 30,000 positions of padding on either side.
    """
    # generate features
    features_hg38 = generate_snp_features_hg38(df_snp_list)
    features_hg19 = generate_snp_features_hg19(df_snp_list)

    # determine window
    # SNP features are told apart from TAD features by their 'rs' label prefix
    all_snp_positions = [gf.start for gf in (features_hg19+features_hg38) if gf.label.startswith('rs')]
    if not all_snp_positions:
        raise ValueError(f"no SNP features labelled 'rs...' found for {name!r}")

    window = 30_000
    region_start = min(all_snp_positions) - window
    region_end = max(all_snp_positions) + window

    # plot
    fig, (ax_hg38, ax_hg19) = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

    plot_region(features_hg38, ax_hg38, region_start, region_end)
    ax_hg38.set_title('hg38')

    plot_region(features_hg19, ax_hg19, region_start, region_end)
    ax_hg19.set_title('hg19')

    fig.tight_layout()

    # savefig does not create missing directories
    out_path = f'images/coordinate_comparison_{name}.pdf'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)

## Plot

In [None]:
# One comparison figure per selected case
for case_id, case_snps in df_cases.groupby('case_idx'):
    plot_snp_selection(case_snps, f'case{case_id}')