In [None]:
import os
import sys
import urllib
import tempfile
import collections

import numpy as np
import pandas as pd
import networkx as nx

from bs4 import BeautifulSoup

from tqdm import tqdm_notebook as tqdm

In [None]:
# Register tqdm with pandas so DataFrame.progress_apply (used in later cells)
# shows a progress bar. The plain tqdm class is imported under another name
# because `tqdm` is already bound to tqdm_notebook in the imports cell.
from tqdm import tqdm as tqdm_orig
tqdm_orig.pandas()

## Load DisGeNET variant–disease associations

In [None]:
# Curated DisGeNET variant-disease associations. The disease identifier is
# renamed to UMLS_CUI so it can later be joined with the ontology mapping.
df = (
    pd.read_table(
        'data/curated_variant_disease_associations.tsv.gz',
        usecols=['snpId', 'diseaseId', 'diseaseName', 'source'])
    .rename(columns={'diseaseId': 'UMLS_CUI'})
)

In [None]:
# Preview the first associations
df.head(n=5)

## Decide whether to use hg19 or hg38

In [None]:
# Genome build for SNP and TAD coordinates: True -> hg38, False -> hg19.
# The next cell defines build-dependent globals (genome_version,
# tad_data_fname, snp_position_convert).
# Restart kernel and rerun all cells after changing this
USE_HG38 = True

In [None]:
# Select the genome build and define how SNP positions are expressed in it.
if USE_HG38:
    genome_version = 'hg38'
    tad_data_fname = 'results/tads_hESC_hg38.tsv'

    def snp_position_convert(row):
        """Return ``row.position`` unchanged: positions are already in hg38."""
        return row.position
else:
    genome_version = 'hg19'
    tad_data_fname = 'data/tads_hESC_hg19_with_ids.txt'

    # get hg19 SNP-positions: rsID -> (chromosome, hg19 position from 'end').
    # A leading 'chr' is stripped only when present, matching the chromosome
    # normalisation that parse_tad_annotations applies to TAD names (the old
    # unconditional x[3:] mangled names without the prefix).
    df_snp_pos_map = pd.read_csv('results/snp_positions_hg19.csv')
    df_snp_pos_map['chrom'] = (
        df_snp_pos_map['chrom'].astype(str).str.replace(r'^chr', '', regex=True))
    df_snp_pos_map['pos'] = list(zip(df_snp_pos_map['chrom'], df_snp_pos_map['end']))

    snp_pos_hg19 = df_snp_pos_map.set_index('SNPS').to_dict()['pos']

    def snp_position_convert(row):
        """Convert ``row.position`` from hg38 to hg19.

        Returns NaN when the SNP has no hg19 mapping. Asserts that the hg19
        chromosome matches ``row.chromosome``.
        """
        if row.snpId not in snp_pos_hg19:
            return np.nan

        chrom, pos_hg19 = snp_pos_hg19[row.snpId]
        assert row.chromosome == chrom, (
            f'{row.snpId}: chromosome {row.chromosome!r} (hg38) != {chrom!r} (hg19)')
        return pos_hg19

## Disease ontology

### Load data

In [None]:
# The edge list was generated once from the OWL file with onto2nx:
#     g = onto2nx.parse_owl_rdf('data/doid.owl')
#     nx.write_edgelist(g, 'results/doid_graph.edgelist.gz')
# It is reversed on load; the cancer-subtree step below uses nx.descendants,
# which presumes parent -> child edges. NOTE(review): confirm edge direction.
doid_graph = nx.read_edgelist('results/doid_graph.edgelist.gz', create_using=nx.DiGraph()).reverse()

# nx.info() was removed in NetworkX 3.0; report the graph size directly.
print(f'{type(doid_graph).__name__} with {doid_graph.number_of_nodes()} nodes '
      f'and {doid_graph.number_of_edges()} edges')

In [None]:
# Parse the raw ontology to recover node labels and cross-references.
# Opened in binary mode so the XML parser uses the encoding declared in the
# file rather than the platform's default text encoding.
with open('data/doid.owl', 'rb') as fd:
    soup = BeautifulSoup(fd, 'xml')

In [None]:
# Collect the label and UMLS cross-references of every ontology class and
# attach them to the matching graph nodes.
node_owl_data = {}

for owl_class in tqdm(soup.find_all('Class')):
    # last path component of the class IRI, e.g. 'DOID_162'
    doid = owl_class['rdf:about'].split('/')[-1]
    label = owl_class.find('rdfs:label').get_text()

    # keep only xrefs of the form 'UMLS_CUI:<id>' and strip the prefix
    xref_texts = (xref.get_text() for xref in owl_class.find_all('oboInOwl:hasDbXref'))
    umls_cuis = [text.split(':')[-1] for text in xref_texts if text.startswith('UMLS_CUI:')]

    # each class is expected exactly once
    assert doid not in node_owl_data
    node_owl_data[doid] = {'label': label, 'UMLS_CUI': umls_cuis}

nx.set_node_attributes(doid_graph, node_owl_data)

In [None]:
# check out exemplary node (cancer)
# Node -> attribute dict; inspect the 'cancer' term (DOID_162) as an example.
doid_data = {node: attrs for node, attrs in doid_graph.nodes(data=True)}

doid_data['DOID_162']

In [None]:
# One row per (DOID, UMLS CUI) pair; nodes without CUIs contribute no rows.
data_cui = [
    (node, attrs['label'], cui)
    for node, attrs in tqdm(doid_data.items())
    for cui in attrs['UMLS_CUI']
]

df_cui = pd.DataFrame(data_cui, columns=['DOID', 'DO_label', 'UMLS_CUI'])
df_cui.head()

### Find cancer subtree

In [None]:
# Cancer subtree: DOID_162 ('cancer') plus every term below it.
# nx.descendants excludes the source node itself, so the root is added
# explicitly; otherwise the generic 'cancer' term is labelled non-cancer.
CANCER_ROOT = 'DOID_162'
cancer_nodes = nx.descendants(doid_graph, CANCER_ROOT) | {CANCER_ROOT}

# every graph node gets exactly one is_cancer flag
data_cancer = [(node, node in cancer_nodes) for node in doid_graph.nodes()]

df_iscancer = pd.DataFrame(data_cancer, columns=['DOID', 'is_cancer'])
df_iscancer.head()

### Merge data sources

In [None]:
# Coverage check before merging: total nodes, nodes with a UMLS mapping, and
# nodes with a cancer flag (the last should equal the total).
print(f'Nodes in doid.owl: {len(doid_data)}')
print(f'Nodes with UMLS_CUI: {df_cui["DOID"].unique().size}')
print(f'(Non)cancer nodes (should be all): {df_iscancer["DOID"].unique().size}')

In [None]:
# Attach the cancer flag to every (DOID, UMLS_CUI) pair.
df_onto = pd.merge(df_cui, df_iscancer, on='DOID')

print(df_onto.shape)
df_onto.head()

In [None]:
# save disease cancer-classification
tmp = df_onto[['UMLS_CUI','is_cancer','DO_label']].copy()
tmp.rename(columns={'UMLS_CUI': 'term', 'is_cancer': 'type', 'DO_label': 'label'}, inplace=True)
tmp['type'] = tmp['type'].apply(lambda x: 'cancer' if x else 'disease')
tmp.to_csv('results/disease_terms.csv', index=False)

## Infer TAD relations

### Load SNP positions

In [None]:
# Distinct SNP coordinates (hg38) with incomplete rows removed.
# 'chromosome' is read as text. Values such as '1' and 'X' share the column,
# so pandas' chunked parser could otherwise return a mix of numbers and
# strings. Numeric names would then miss the string chromosome keys of the
# TAD annotations.
df_snppos = (
    pd.read_table(
        'data/all_variant_disease_pmid_associations.tsv.gz',
        usecols=['snpId', 'chromosome', 'position'],
        dtype={'chromosome': str})
    .drop_duplicates()
    .dropna()
    .assign(position=lambda frame: frame['position'].astype(int))
)

In [None]:
# Express every SNP position in the selected genome build. In hg19 mode,
# SNPs without a mapping come back as NaN.
new_snp_positions = df_snppos.progress_apply(snp_position_convert, axis=1)

# save mapping overview (unmapped SNPs keep an empty hg19 position)
if not USE_HG38:
    df_tmp = df_snppos.rename(columns={'position': 'position_hg38'})
    df_tmp['position_hg19'] = new_snp_positions
    df_tmp.to_csv('results/snps_hg19_hg38.csv', index=False)

# Drop unmapped SNPs. Any NaN turns the column into floats, so positions are
# cast back to integers afterwards.
df_snppos = (
    df_snppos
    .assign(position=new_snp_positions)
    .dropna()
    .astype({'position': int})
)

In [None]:
# Preview the converted SNP coordinates
df_snppos.head(n=5)

### Load TAD data

In [None]:
# TAD table for the selected genome build (tab-separated)
df_tads = pd.read_csv(tad_data_fname, sep='\t')

In [None]:
# Preview the TAD table
df_tads.head(n=5)

### Helper functions

In [None]:
class RangeDict(dict):
    """Dict keyed by ``range`` objects that also supports point lookups.

    ``rd[some_range]`` is an ordinary exact-key lookup and raises ``KeyError``
    if the range is absent. Any other key is treated as a position. The stored
    ranges are scanned in insertion order, and the value of the first range
    containing the position is returned. If no range contains it, the result
    is ``None``.

    Ranges with ``step == 1`` are tested with ``start <= pos < stop``; other
    steps fall back to ``pos in range``. The scan is linear in the number of
    stored ranges.
    """

    def __getitem__(self, item):
        if type(item) == range:
            return super().__getitem__(item)

        for span in self:
            if span.step == 1:
                contains = span.start <= item < span.stop
            else:
                contains = item in span
            if contains:
                return self[span]
        return None

In [None]:
class EmptyTAD(Exception):
    """Raised by get_tad_lengths when a TAD's stop is not after its start."""


class OverlappingTADS(Exception):
    """Raised by get_tad_lengths when a same-chromosome neighbour overlaps a TAD."""

In [None]:
def get_tad_lengths(row, type_):
    """Split one TAD into (left boundary, interior, right boundary) ranges.

    ``type_`` is ``'<kb><mode>'``. The number is the boundary reach in kb
    (20/40/60/80). The mode says where the boundaries lie:

    * ``in``    -- inside the TAD: ``[start, start+in)`` and ``[stop-in, stop)``.
    * ``out``   -- just outside the TAD; the interior is the whole TAD.
    * ``inout`` -- straddling each TAD edge.

    Reaches are shrunk in two cases:

    * Small TADs: the inner reach becomes ``length // 4`` when the TAD is
      shorter than twice the reach.
    * Close neighbours on the same chromosome: the outer reach becomes
      ``gap // 4`` when the nearest gap is shorter than twice the reach.

    This keeps borders from overlapping.

    ``row`` needs ``tad_start``, ``tad_stop``, ``chrname``, ``prev_tad_stop``,
    ``prev_tad_chr``, ``next_tad_start`` and ``next_tad_chr``. Neighbour
    positions are NaN when there is no neighbour.

    Raises:
        RuntimeError: for an unknown ``type_``.
        EmptyTAD: when ``tad_stop <= tad_start``.
        OverlappingTADS: when a same-chromosome neighbour overlaps the TAD.
    """
    reach_bp = {
        '20in': 20_000, '20out': 20_000, '20inout': 20_000,
        '40in': 40_000, '40out': 40_000, '40inout': 40_000,
        '60in': 60_000,
        '80in': 80_000,
    }
    if type_ not in reach_bp:
        raise RuntimeError(f'Invalid type {type_}')

    start, stop = row.tad_start, row.tad_stop
    reach_in = reach_out = reach_bp[type_]

    length = stop - start
    if length <= 0:
        raise EmptyTAD(row)
    if length < 2 * reach_in:
        reach_in = length // 4

    # Gaps to the neighbouring TADs. A missing neighbour, or one on another
    # chromosome, never limits the outer reach.
    has_prev = not np.isnan(row.prev_tad_stop) and row.chrname == row.prev_tad_chr
    has_next = not np.isnan(row.next_tad_start) and row.chrname == row.next_tad_chr
    gap_prev = start - row.prev_tad_stop if has_prev else float('inf')
    gap_next = row.next_tad_start - stop if has_next else float('inf')

    nearest_gap = min(gap_prev, gap_next)
    if nearest_gap < 0:
        raise OverlappingTADS(row)
    if nearest_gap < 2 * reach_out:
        reach_out = nearest_gap // 4

    reach_in, reach_out = int(reach_in), int(reach_out)
    assert reach_in >= 0 and reach_out >= 0

    if type_.endswith('inout'):
        return (
            range(start - reach_out, start + reach_in),
            range(start + reach_in, stop - reach_in),
            range(stop - reach_in, stop + reach_out),
        )
    if type_.endswith('out'):
        return (
            range(start - reach_out, start),
            range(start, stop),
            range(stop, stop + reach_out),
        )
    return (
        range(start, start + reach_in),
        range(start + reach_in, stop - reach_in),
        range(stop - reach_in, stop),
    )


def parse_tad_annotations(type_, fname=tad_data_fname):
    """Build per-chromosome lookups labelling positions as TAD body or boundary.

    Parameters
    ----------
    type_ : str
        Boundary definition forwarded to ``get_tad_lengths`` (e.g. ``'20in'``).
    fname : str, optional
        Tab-separated TAD table with ``chrname``, ``tad_start`` and
        ``tad_stop`` columns. The default is bound to ``tad_data_fname`` as it
        was when this cell ran.

    Returns
    -------
    dict
        Maps a chromosome name (leading ``'chr'`` removed) to a ``RangeDict``.
        Each boundary range maps to ``'boundary'`` and each TAD interior maps
        to ``'tad'``.

    TADs that ``get_tad_lengths`` rejects (empty, or overlapping a neighbour)
    are counted, reported on stdout, and left out of the result.
    """
    print(f' > Parsing TADs ({type_})', file=sys.stderr)
    df_tad = pd.read_table(fname)

    # Neighbours are taken from the adjacent rows.
    # NOTE(review): assumes the table is sorted by chromosome and start; confirm.
    df_tad['prev_tad_stop'] = df_tad.tad_stop.shift(1)
    df_tad['next_tad_start'] = df_tad.tad_start.shift(-1)
    df_tad['prev_tad_chr'] = df_tad.chrname.shift(1)
    df_tad['next_tad_chr'] = df_tad.chrname.shift(-1)

    error_counter = collections.defaultdict(int)
    res = collections.defaultdict(RangeDict)
    for row in tqdm(df_tad.itertuples(), total=df_tad.shape[0]):
        try:
            rb1, rt, rb2 = get_tad_lengths(row, type_)
        except EmptyTAD:
            error_counter['empty_tad'] += 1
            # Skip this TAD. Falling through would store the previous row's
            # ranges again, or raise NameError on the first row.
            continue
        except OverlappingTADS:
            error_counter['overlapping_tads'] += 1
            continue

        # normalize chromosome name
        chrom = row.chrname
        if chrom.startswith('chr'):
            chrom = chrom[3:]

        # store range-associations
        res[chrom][rb1] = 'boundary'
        res[chrom][rt] = 'tad'
        res[chrom][rb2] = 'boundary'

    if error_counter:
        print('TAD errors:')
        for k, v in sorted(error_counter.items()):
            print(f' > {k}: {v}')

    return dict(res)

### Annotate SNPs with TAD regions

In [None]:
def access_range_dict(row, dict_):
    """Look up the annotation of a SNP row in per-chromosome range lookups.

    ``row`` must support ``row['chromosome']`` and ``row['position']``.
    Returns ``'undef'`` when ``dict_`` has no entry for the chromosome.
    Otherwise returns ``dict_[chromosome][position]``.
    """
    chrom_ranges = dict_.get(row['chromosome'])
    return 'undef' if chrom_ranges is None else chrom_ranges[row['position']]

In [None]:
# Annotate every SNP under each boundary definition. Each definition gets its
# own 'TAD_<type>' column; this replaces six copy-pasted cell fragments.
TAD_TYPES = ('20in', '40in', '20out', '40out', '20inout', '40inout')

tad_annotations = {}
for tad_type in TAD_TYPES:
    tad_annotations[tad_type] = parse_tad_annotations(tad_type)
    # bind the annotation as a default argument so the lambda never depends on
    # the loop variable's later value
    df_snppos[f'TAD_{tad_type}'] = df_snppos.progress_apply(
        lambda row, anno=tad_annotations[tad_type]: access_range_dict(row, anno),
        axis=1)

# keep the per-type names defined by the original cell available to later cells
tad_anno_20in = tad_annotations['20in']
tad_anno_40in = tad_annotations['40in']
tad_anno_20out = tad_annotations['20out']
tad_anno_40out = tad_annotations['40out']
tad_anno_20inout = tad_annotations['20inout']
tad_anno_40inout = tad_annotations['40inout']

In [None]:
# Preview SNPs with their TAD annotations
df_snppos.head(n=5)

## Merge into DisGeNET

In [None]:
# Start the final table from an independent copy of the DisGeNET associations
df_final = df.copy(deep=True)
df_final.shape

In [None]:
# Add the ontology mapping (DOID, label, cancer flag) per UMLS term
df_final = pd.merge(df_final, df_onto, on='UMLS_CUI')
df_final.shape

In [None]:
# Add SNP coordinates and TAD annotations per SNP
df_final = pd.merge(df_final, df_snppos, on='snpId')
df_final.shape