In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

from bioinf_common.plotting import annotated_barplot

In [None]:
# Use seaborn's 'talk' plotting context for every figure in this notebook.
sns.set_context('talk')

# Parameters

In [None]:
# NOTE(review): `snakemake` is not defined anywhere in this notebook; presumably
# it is injected by Snakemake's notebook integration at run time — confirm.

# Input files: the SNP database (read under "SNP database") and the list of
# TAD files (read under "TADs").
fname = snakemake.input.fname
tad_fname_list = snakemake.input.tad_fname_list

# TAD source whose rows are used for the single-dataset statistics and plots.
main_dataset = snakemake.config['parameters']['main_dataset']

# Directory that receives every table and figure written by this notebook.
outdir = Path(snakemake.output.outdir)

In [None]:
# Create the output directory (including missing parents); a pre-existing
# directory is not an error.
outdir.mkdir(parents=True, exist_ok=True)

# Read data

## SNP database

In [None]:
# 'associated_genes' is forced to `str` because it is later split on ','
# (see the "Gene counts" section).
df_all = pd.read_csv(fname, dtype={'associated_genes': str})
df_all.head()

## TADs

In [None]:
# Load every TAD file and tag its rows with the TAD source and window size
# taken from the file name, which must consist of exactly four dot-separated
# fields: `<ignored>.<tad_source>.<window_size>.<ignored>`.
df_list = []
# The loop variable is deliberately NOT called `fname`: that name holds the
# SNP database path and must not be overwritten, otherwise re-running the
# "SNP database" cell afterwards would silently read a TAD file instead.
for tad_fname in tqdm(tad_fname_list):
    _, tad_source, window_size, _ = os.path.basename(tad_fname).split('.')

    df_single_tads = pd.read_csv(tad_fname)
    df_single_tads['tad_source'] = tad_source
    df_single_tads['window_size'] = int(window_size)

    df_list.append(df_single_tads)

# Build the combined frame once instead of growing it inside the loop.
df_tads_all = pd.concat(df_list)
df_tads_all['length'] = df_tads_all['tad_stop'] - df_tads_all['tad_start']

df_tads_all.head()

# Select data subset

In [None]:
%%time
# Restrict the SNP database to chromosome 6, positions 25,000,000-34,000,000
# (both bounds inclusive), keeping one row per
# (TAD source, window size, disease, SNP) combination.
in_chr6_region = (df_all['chromosome_hg38'] == '6') & df_all['position_hg38'].between(
    25_000_000, 34_000_000
)
df_sub = df_all[in_chr6_region].drop_duplicates(
    subset=['tad_source', 'window_size', 'diseaseId', 'snpId']
)

In [None]:
# Number of rows and columns left after restricting to the chr6 region.
df_sub.shape

In [None]:
# Peek at the first two rows of the chr6 subset.
df_sub.head(2)

# Overview tables

In [None]:
# Sub-directory of `outdir` that collects all exported SNP list tables.
outdir_lists = outdir / 'snp_lists'
outdir_lists.mkdir(parents=True, exist_ok=True)

In [None]:
# Disease identifiers for which dedicated per-disease tables are exported below.
disease_list = ['EFO_0001071', 'EFO_0000571', 'EFO_0000708']

## Total SNP-list per disease

In [None]:
# Full SNP association list of the chr6 subset for the selected diseases,
# reduced to the exported columns and sorted for stable output.
totallist_columns = [
    'tad_source',
    'window_size',
    'diseaseId',
    'snpId',
    'chromosome_hg38',
    'position_hg38',
    'variant_type_hg38',
    'variant_group_hg38',
    '20in',
]
totallist_sort_keys = ['tad_source', 'window_size', 'diseaseId', 'snpId']

is_selected_disease = df_sub['diseaseId'].isin(disease_list)
df_totallist = df_sub.loc[is_selected_disease, totallist_columns].sort_values(
    totallist_sort_keys
)

df_totallist.to_csv(outdir_lists / 'chr6_all_snp_associations.csv.gz', index=False)
df_totallist.head()

In [None]:
# Write one additional file per (TAD source, window size, disease) combination.
totallist_groups = df_totallist.groupby(['tad_source', 'window_size', 'diseaseId'])
for group_key, group in totallist_groups:
    tad_source, window_size, diseaseId = group_key
    out_name = (
        f'chr6_all_snp_associations_{tad_source}_{window_size}_{diseaseId}.csv.gz'
    )
    group.to_csv(outdir_lists / out_name, index=False)

## SNP-Class Count List

In [None]:
%%time

# For every (TAD source, disease, SNP) combination of the chr6 subset, count
# how many rows carry each of the '20in' labels 'tad', 'border' and 'outside'.
class_count_rows = []
groups_per_tadsource = df_sub.groupby(['tad_source', 'diseaseId', 'snpId'])
for group_key, group in groups_per_tadsource:
    tad_source, diseaseId, snpId = group_key
    counts = group['20in'].value_counts().to_dict()

    row = {
        'tad_source': tad_source,
        'diseaseId': diseaseId,
        'snpId': snpId,
        # Chromosome and position are taken from the first row of the group.
        'chromosome_hg38': group['chromosome_hg38'].iloc[0],
        'position_hg38': group['position_hg38'].iloc[0],
    }
    # Labels that do not occur in the group are reported with a count of 0.
    for label in ('tad', 'border', 'outside'):
        row[f'{label}_count'] = counts.get(label, 0)
    class_count_rows.append(row)

df_snp_class_counts_per_tadsource = pd.DataFrame(class_count_rows)
df_snp_class_counts_per_tadsource.to_csv(
    outdir_lists / 'chr6_snp_class_counts_per_tadsource.csv.gz', index=False
)
df_snp_class_counts_per_tadsource.head()

In [None]:
%%time

# Same class counts as in the per-TAD-source table, but pooled over all TAD
# sources: one row per (disease, SNP) combination of the chr6 subset.
pooled_count_rows = []
for group_key, group in df_sub.groupby(['diseaseId', 'snpId']):
    diseaseId, snpId = group_key
    counts = group['20in'].value_counts().to_dict()

    row = {
        'diseaseId': diseaseId,
        'snpId': snpId,
        # Chromosome and position are taken from the first row of the group.
        'chromosome_hg38': group['chromosome_hg38'].iloc[0],
        'position_hg38': group['position_hg38'].iloc[0],
    }
    # Labels that do not occur in the group are reported with a count of 0.
    for label in ('tad', 'border', 'outside'):
        row[f'{label}_count'] = counts.get(label, 0)
    pooled_count_rows.append(row)

df_snp_class_counts = pd.DataFrame(pooled_count_rows)
df_snp_class_counts.to_csv(outdir_lists / 'chr6_snp_class_counts.csv.gz', index=False)
df_snp_class_counts.head()

### Per selected disease

In [None]:
# Export the per-TAD-source class counts separately for each selected disease.
# The guard skips an empty table, which would have no 'diseaseId' column.
if not df_snp_class_counts_per_tadsource.empty:
    is_selected_disease = df_snp_class_counts_per_tadsource['diseaseId'].isin(
        disease_list
    )
    selected_counts = df_snp_class_counts_per_tadsource[is_selected_disease]
    for diseaseId, group in selected_counts.groupby('diseaseId'):
        out_name = f'chr6_snp_class_counts_per_tadsource_{diseaseId}.csv.gz'
        group.to_csv(outdir_lists / out_name, index=False)

In [None]:
# Export the pooled class counts separately for each selected disease.
# The guard skips an empty table, which would have no 'diseaseId' column.
if not df_snp_class_counts.empty:
    is_selected_disease = df_snp_class_counts['diseaseId'].isin(disease_list)
    selected_counts = df_snp_class_counts[is_selected_disease]
    for diseaseId, group in selected_counts.groupby('diseaseId'):
        out_name = f'chr6_snp_class_counts_{diseaseId}.csv.gz'
        group.to_csv(outdir_lists / out_name, index=False)

## Other statistics

In [None]:
# All statistics and plots below use a single configuration of the SNP
# database: the main TAD dataset at window size 10.
is_main_dataset = df_all['tad_source'] == main_dataset
is_window_10 = df_all['window_size'] == 10
df = df_all[is_main_dataset & is_window_10]
df.shape

In [None]:
# One row per SNP. `drop_duplicates` keeps the first occurrence by default, so
# per-row columns such as 'is_cancer' or '20in' reflect that first row only.
# NOTE(review): a SNP linked to several diseases is thus classified by
# whichever row happens to come first — confirm this is intended.
df_uniqsnp = df.drop_duplicates('snpId')

In [None]:
# One row per disease (the first occurrence of each 'diseaseId' is kept).
df_uniqefo = df.drop_duplicates('diseaseId')

In [None]:
# TADs of the same configuration as `df`: main TAD dataset, window size 10.
is_main_tad_source = df_tads_all['tad_source'] == main_dataset
is_tad_window_10 = df_tads_all['window_size'] == 10
df_tads = df_tads_all[is_main_tad_source & is_tad_window_10]
df_tads.shape

### EFO database

In [None]:
# Number of diseases and how many of them are flagged as cancer.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
n_efo = df_uniqefo.shape[0]
n_cancer_efo = df_uniqefo['is_cancer'].sum()

pd.DataFrame(
    {
        '#efo': [n_efo],
        '#cancer_efo': [n_cancer_efo],
        '%cancer_efo': [n_cancer_efo / n_efo],
    }
)

### SNP database

In [None]:
# Show a single example row of the per-SNP table.
df_uniqsnp.head(1)

In [None]:
# Share of unique SNPs labelled 'border' in the '20in' column.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
n_snps = df_uniqsnp.shape[0]
n_border_snps = (df_uniqsnp['20in'] == 'border').sum()

pd.DataFrame(
    {
        '#snps': [n_snps],
        '#border_snps': [n_border_snps],
        '%border_snps': [n_border_snps / n_snps],
    }
)

In [None]:
# Among unique cancer SNPs: how many are labelled 'border'.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
is_cancer_snp = df_uniqsnp['is_cancer']
is_border_snp = df_uniqsnp['20in'] == 'border'

n_cancer_snps = is_cancer_snp.sum()
n_border_cancer_snps = (is_cancer_snp & is_border_snp).sum()

pd.DataFrame(
    {
        '#cancer_snps': [n_cancer_snps],
        '#border_cancer_snps': [n_border_cancer_snps],
        '%border_cancer_snps': [n_border_cancer_snps / n_cancer_snps],
    }
)

In [None]:
# Among unique non-cancer SNPs: how many are labelled 'border'.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
is_noncancer_snp = ~df_uniqsnp['is_cancer']
is_border_snp = df_uniqsnp['20in'] == 'border'

n_noncancer_snps = is_noncancer_snp.sum()
n_border_noncancer_snps = (is_noncancer_snp & is_border_snp).sum()

pd.DataFrame(
    {
        '#noncancer_snps': [n_noncancer_snps],
        '#border_noncancer_snps': [n_border_noncancer_snps],
        '%border_noncancer_snps': [n_border_noncancer_snps / n_noncancer_snps],
    }
)

In [None]:
# Among unique cancer SNPs: how many belong to the 'intergenic' variant group.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
is_cancer_snp = df_uniqsnp['is_cancer']
is_intergenic_snp = df_uniqsnp['variant_group_hg38'] == 'intergenic'

n_cancer_snps = is_cancer_snp.sum()
n_intergenic_cancer_snps = (is_cancer_snp & is_intergenic_snp).sum()

pd.DataFrame(
    {
        '#cancer_snps': [n_cancer_snps],
        '#intergenic_cancer_snps': [n_intergenic_cancer_snps],
        '%intergenic_cancer_snps': [n_intergenic_cancer_snps / n_cancer_snps],
    }
)

In [None]:
# Share of unique SNPs that belong to the 'intergenic' variant group.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
n_snps = df_uniqsnp.shape[0]
n_intergenic_snps = (df_uniqsnp['variant_group_hg38'] == 'intergenic').sum()

pd.DataFrame(
    {
        '#snps': [n_snps],
        '#intergenic_snps': [n_intergenic_snps],
        '%intergenic_snps': [n_intergenic_snps / n_snps],
    }
)

In [None]:
# Among unique intergenic SNPs: how many are labelled 'border'.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
is_intergenic_snp = df_uniqsnp['variant_group_hg38'] == 'intergenic'
is_border_snp = df_uniqsnp['20in'] == 'border'

n_intergenic_snps = is_intergenic_snp.sum()
n_border_intergenic_snps = (is_intergenic_snp & is_border_snp).sum()

pd.DataFrame(
    {
        '#intergenic_snps': [n_intergenic_snps],
        '#border_intergenic_snps': [n_border_intergenic_snps],
        '%border_intergenic_snps': [n_border_intergenic_snps / n_intergenic_snps],
    }
)

In [None]:
# Among unique intergenic cancer SNPs: how many are labelled 'border'.
# NOTE: the '%...' column is a fraction (count / total), not scaled to 100.
is_intergenic_cancer_snp = df_uniqsnp['is_cancer'] & (
    df_uniqsnp['variant_group_hg38'] == 'intergenic'
)
is_border_snp = df_uniqsnp['20in'] == 'border'

n_intergenic_cancer_snps = is_intergenic_cancer_snp.sum()
n_border_intergenic_cancer_snps = (is_intergenic_cancer_snp & is_border_snp).sum()

pd.DataFrame(
    {
        '#intergenic_cancer_snps': [n_intergenic_cancer_snps],
        '#border_intergenic_cancer_snps': [n_border_intergenic_cancer_snps],
        # The fraction is taken relative to the intergenic *cancer* SNPs shown
        # in the first column (as in all sibling tables), not relative to all
        # intergenic SNPs, which understated the share.
        '%border_intergenic_cancer_snps': [
            n_border_intergenic_cancer_snps / n_intergenic_cancer_snps
        ],
    }
)

## TADs

In [None]:
# Total genome length used as denominator for the border coverage below.
genome_length = 3_092_480_053  # hg38
# Length attributed to a single TAD border; every TAD contributes two borders
# in the tables below. Presumably in base pairs, like `genome_length` — confirm.
border_length = 20_000

In [None]:
# Total border length (two borders per TAD) relative to the genome length.
# NOTE: the '%...' column is a fraction, not scaled to 100.
n_tads = df_tads.shape[0]
total_border_length = n_tads * 2 * border_length

pd.DataFrame(
    {
        'total_tad_length': [df_tads['length'].sum()],
        'total_border_length': [total_border_length],
        '%total_border_length': [total_border_length / genome_length],
    }
)

In [None]:
# Total border length (two borders per TAD) relative to the summed TAD length
# (instead of the genome length).
# NOTE: the '%...' column is a fraction, not scaled to 100.
n_tads = df_tads.shape[0]
total_border_length = n_tads * 2 * border_length
total_tad_length = df_tads['length'].sum()

pd.DataFrame(
    {
        'total_tad_length': [total_tad_length],
        'total_border_length': [total_border_length],
        '%total_border_length': [total_border_length / total_tad_length],
    }
)

# SNP lists

In [None]:
# List the columns available in the SNP database.
df_all.columns

## Disease specific

In [None]:
# Allow up to 500 columns when displaying DataFrames.
pd.set_option('display.max_columns', 500)

In [None]:
# Region of interest as (chromosome name with 'chr' prefix, start, end).
# The filter below strips the prefix and treats both bounds as inclusive.
region = ('chr6', 25_000_000, 34_000_000)

In [None]:
# Same identifiers as the `disease_list` defined in the "Overview tables"
# section; repeated here so this section can be read on its own.
disease_list = ['EFO_0001071', 'EFO_0000571', 'EFO_0000708']

In [None]:
# Distinct (disease, SNP, variant type, variant group) combinations for the
# selected diseases in the main dataset.
is_selected_disease = df['diseaseId'].isin(disease_list)
variant_columns = ['diseaseId', 'snpId', 'variant_type_hg38', 'variant_group_hg38']

df.loc[is_selected_disease, variant_columns].drop_duplicates()

In [None]:
# Border SNPs of the selected diseases inside `region`, one row per distinct
# combination of the displayed columns, ordered by disease and position.
region_chromosome = region[0][3:]  # drop the 'chr' prefix: 'chr6' -> '6'
is_region_border_snp = (
    df['diseaseId'].isin(disease_list)
    & (df['chromosome_hg38'] == region_chromosome)
    & (df['position_hg38'] >= region[1])
    & (df['position_hg38'] <= region[2])
    # Equality, not `<=`: the latter is a lexicographic string comparison that
    # would also accept any label sorting before 'border'.
    & (df['20in'] == 'border')
)

df.loc[
    is_region_border_snp,
    [
        'diseaseId',
        'snpId',
        'chromosome_hg38',
        'position_hg38',
        'variant_type_hg38',
        'variant_group_hg38',
    ],
].drop_duplicates().sort_values(['diseaseId', 'chromosome_hg38', 'position_hg38'])

# Plot database statistics

## Number of entries per disease

In [None]:
# Number of database rows per disease, annotated with the disease's cancer flag.

# The lookup table must hold exactly one row per disease. Merging against the
# raw `df` columns would repeat every disease once per database row and thereby
# weight the distribution plotted below by the row counts themselves.
disease_is_cancer = df[['diseaseId', 'is_cancer']].drop_duplicates('diseaseId')

disease_counts = (
    df['diseaseId']
    .value_counts()
    # Name index and values explicitly: the defaults chosen by `value_counts`
    # and `reset_index` differ between pandas 1.x and 2.x.
    .rename_axis('diseaseId')
    .reset_index(name='count')
    .sort_values('count')
    .merge(disease_is_cancer, how='left', on='diseaseId')
)

disease_counts.head()

In [None]:
# Distribution of the per-disease row counts, split by cancer flag (log scale).
ax = sns.boxplot(data=disease_counts, x='is_cancer', y='count')
ax.set_title('#rows associated with single diseases')
ax.set_yscale('log')

plt.tight_layout()
plt.savefig(outdir / 'disease_count_distribution.pdf')

## Odds ratio distribution

In [None]:
# Summary statistics of the odds ratios in the main dataset.
df['odds_ratio'].describe()

In [None]:
# Box plot of the odds ratios, restricted to values strictly below the
# 75% quantile of the non-missing odds ratios.
odds_ratio = df['odds_ratio'].dropna()
odds_ratio_cutoff = odds_ratio.quantile(0.75)
odds_ratio_below_cutoff = odds_ratio[odds_ratio < odds_ratio_cutoff]

ax = sns.boxplot(odds_ratio_below_cutoff, orient='v')
ax.set_title('Odds ratios (< 75% quantile) for all diseases')

plt.tight_layout()
plt.savefig(outdir / 'oddsratio_distribution.pdf')

## VEP statistics

### Raw variant types

In [None]:
# Name of the first column of `df` whose label contains 'variant_type'.
variant_type_col = df.filter(like='variant_type').columns[0]

In [None]:
# Number of distinct (SNP, variant type) pairs per variant type.
variant_type_counts = (
    df[['snpId', variant_type_col]]
    .drop_duplicates()[variant_type_col]
    .value_counts()
    # Name index and values explicitly. Renaming an 'index' column after
    # `reset_index()` only works on pandas < 2.0; since 2.0 the index already
    # carries the source column's name, so the 'variant_type' column required
    # by the plot below would be missing.
    .rename_axis('variant_type')
    .reset_index(name='count')
)

In [None]:
# Horizontal bar chart of the variant type counts (log-scaled counts),
# drawn in a single colour: the first colour of the current palette.
plt.figure(figsize=(16, 8))
bar_color = sns.color_palette()[0]
ax = sns.barplot(
    data=variant_type_counts,
    x='count',
    y='variant_type',
    orient='h',
    color=bar_color,
)
ax.set_title('#variant_type in database')
ax.set_xscale('log')

plt.tight_layout()
plt.savefig(outdir / 'variant_type_counts.pdf')

### Variant groups

In [None]:
# Name of the first column of `df` whose label contains 'variant_group'.
variant_group_col = df.filter(like='variant_group').columns[0]

In [None]:
# Number of unique SNPs per variant group (first row per SNP is used).
variant_group_counts = (
    df[['snpId', variant_group_col]]
    .drop_duplicates('snpId')[variant_group_col]
    .value_counts()
    # Name index and values explicitly. Renaming an 'index' column after
    # `reset_index()` only works on pandas < 2.0; since 2.0 the index already
    # carries the source column's name, so the 'variant_group' column required
    # by the plot below would be missing.
    .rename_axis('variant_group')
    .reset_index(name='count')
)

In [None]:
# Bar chart of the variant group counts, drawn with the project's
# `annotated_barplot` helper.
plt.figure(figsize=(8, 6))
annotation_options = {'label_offset': 8, 'label_size': 12}
annotated_barplot(
    x='variant_group',
    y='count',
    data=variant_group_counts,
    anno_kws=annotation_options,
)

plt.tight_layout()
plt.savefig(outdir / 'variant_group_counts.pdf')

## Gene counts

In [None]:
# Per database row: the list of associated genes (comma-separated in the
# source column) and its length. Rows without a gene list count as 0 genes.
gene_lists = df['associated_genes'].str.split(',')
gene_counts = gene_lists.apply(
    lambda genes: len(genes) if isinstance(genes, list) else 0
)

df_tmp = pd.DataFrame(
    {
        'diseaseId': df['diseaseId'],
        'associated_genes': gene_lists,
        'gene_count': gene_counts,
    }
)
df_tmp.head()

In [None]:
# Distribution of the summed per-row gene counts per disease (log scale).
# NOTE(review): the per-disease sum counts a gene once per row it appears in,
# whereas the title reports globally unique genes — confirm this is intended.
gene_count_per_disease = df_tmp.groupby('diseaseId')['gene_count'].sum()
ax = sns.boxplot(y=gene_count_per_disease)

ax.set_xlabel('All diseases')
ax.set_ylabel('#associated genes')

unique_genes = {
    gene
    for gene_list in df_tmp['associated_genes']
    if isinstance(gene_list, list)
    for gene in gene_list
}
ax.set_title(f'{len(unique_genes)} unique genes in total')

ax.set_yscale('log')

plt.tight_layout()
plt.savefig(outdir / 'gene_counts.pdf')

## Filter statistics

In [None]:
# Column-wise sums of all columns whose label contains 'filter_'. After
# `reset_index()` the column labels end up in a column called 'index'.
filter_columns = df.filter(like='filter_')
filter_sums = filter_columns.sum(axis=0)
df_filter_stats = filter_sums.to_frame('count').reset_index()
df_filter_stats.head()

In [None]:
# Horizontal bar chart of the per-filter sums, drawn in a single colour:
# the first colour of the current palette.
plt.figure(figsize=(16, 8))
bar_color = sns.color_palette()[0]
ax = sns.barplot(
    data=df_filter_stats, x='count', y='index', orient='h', color=bar_color
)

ax.set_xlabel('Entry count')
ax.set_ylabel('Filter type')

plt.tight_layout()
plt.savefig(outdir / 'filter_counts.pdf')