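# Machine Learning Library Strategy

Explore whether alignment-derived features separate SRA library strategies (e.g. RNA-Seq, WGS, ChIP-Seq). The features are per-gene junction counts and per-sample medians of normalized gene, junction, and intergenic counts.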

In [2]:
import os
import sys
from pathlib import Path

from IPython.display import display, HTML, Markdown
import numpy as np
import pandas as pd

from dask import dataframe as dd
from dask import delayed
from dask.distributed import Client

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

# Project level imports
sys.path.insert(0, '../lib')
from ncbi_remap.plotting import make_figs
from ncbi_remap.normalization import cpm

# Read-only handle on the project's HDF5 store (sample tables such as 'aln/complete').
store = pd.HDFStore(path='../sra.h5', mode='r')

In [2]:
from pymongo import MongoClient

# The MongoDB host name may be recorded in ../output/.mongodb_host;
# when that file does not exist we fall back to a local server.
try:
    host = Path('../output/.mongodb_host').read_text().strip()
except FileNotFoundError:
    host = 'localhost'

mongoClient = MongoClient(host=host, port=27017)
db = mongoClient['sra']
# Collection holding per-SRX NCBI metadata (queried by `_id` below).
ncbi = db['ncbi']

In [None]:
# Local Dask cluster configuration. The original cell referenced `LocalCluster`,
# `cpus` and `mem` without defining them, so a fresh Run-All failed here.
# NOTE(review): the next cell calls `Client()` with no arguments, which starts
# its own cluster instead of using this one -- pass `cluster` to `Client` if
# this configuration is the one intended.
from dask.distributed import LocalCluster

cpus = os.cpu_count() or 1   # os.cpu_count() may return None
mem = 'auto'                 # let Dask split system memory across workers
cluster = LocalCluster(n_workers=cpus, memory_limit=mem)

In [3]:
# Connect a Dask client; with no scheduler address it launches a local cluster.
daskClient = Client(address=None)
daskClient

0,1
Client  Scheduler: tcp://127.0.0.1:42171  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 8  Cores: 8  Memory: 26.14 GB


In [4]:
# Distinct SRX accessions listed in the 'aln/complete' table, in order of first appearance.
samples = store['aln/complete']['srx'].unique().tolist()

In [5]:
# Number of distinct SRX accessions taken from the 'aln/complete' table.
len(samples)

26545

In [6]:
# Library strategy (sra.experiment.library_strategy) for every sample, indexed by SRX.
# `columns=` guarantees the 'srx' and 'label' columns exist even when the aggregate
# returns nothing, or when documents omit the projected field (MongoDB drops missing
# field paths from $project output). Such samples get a NaN label instead of
# breaking `set_index` / `query` below.
labels = pd.DataFrame(
    list(ncbi.aggregate([
        {'$match': {'_id': {'$in': samples}}},
        {'$project': {
            '_id': 0,
            'srx': '$_id',
            'label': '$sra.experiment.library_strategy',
        }},
    ])),
    columns=['srx', 'label'],
).set_index('srx')

In [7]:
# SRX accessions for the three library strategies compared in this notebook.
# Samples with a missing label never compare equal and are excluded.
rnaseq, dnaseq, chipseq = (
    labels.index[labels['label'] == strategy].unique().tolist()
    for strategy in ('RNA-Seq', 'WGS', 'ChIP-Seq')
)

In [8]:
# Sample counts per strategy: (RNA-Seq, WGS, ChIP-Seq).
tuple(map(len, (rnaseq, dnaseq, chipseq)))

(13829, 1970, 3003)

## Junction count distributions

In [9]:
@delayed
def read_juncs(srx, chroms=None):
    """Lazily sum junction read counts per gene for one sample.

    Reads ``../aln-wf/output/junction_counts/{srx}.parquet``. It keeps junctions
    whose two sites are on the same chromosome and whose chromosome is in
    ``chroms``, drops rows with a missing gene, sample or count, and sums
    ``count`` per (FBgn, srx).

    Parameters
    ----------
    srx : str
        SRX accession; used only to build the parquet path.
    chroms : list of str, optional
        Chromosomes to keep. Defaults to chrX, chr2L, chr2R, chr3L, chr3R,
        chr4 and chrY.

    Returns
    -------
    pandas.Series
        Summed ``count`` indexed by (FBgn, srx). Because of ``@delayed``,
        calling this returns a dask ``Delayed`` that yields the Series.
    """
    if chroms is None:
        chroms = ['chrX', 'chr2L', 'chr2R', 'chr3L', 'chr3R', 'chr4', 'chrY']

    junc = pd.read_parquet(f'../aln-wf/output/junction_counts/{srx}.parquet')

    # Boolean masks instead of a query string built by interpolating the list.
    keep = junc['Site1_chr'].isin(chroms) & (junc['Site1_chr'] == junc['Site2_chr'])

    # Chained (non-inplace) ops avoid mutating a slice of `junc`.
    dat = (
        junc.loc[keep, ['PrimaryGene', 'srx', 'count']]
        .rename(columns={'PrimaryGene': 'FBgn'})
        .dropna()
    )
    return dat.groupby(['FBgn', 'srx'])['count'].sum()

In [None]:
# Boxplots of per-gene junction counts for up to `_size` samples of each library strategy.
_size = 2000
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(30, 20), sharey=True)

for ax, (strategy, srxs) in zip(
    (ax1, ax2, ax3),
    [('RNA-Seq', rnaseq), ('DNA-Seq', dnaseq), ('ChIP-Seq', chipseq)],
):
    subset = srxs[:_size]
    if subset:  # pd.concat([]) raises, so leave the panel empty when there are no samples
        futures = daskClient.compute([read_juncs(x) for x in subset])
        juncs = pd.concat(daskClient.gather(futures))
        sns.boxplot(x='srx', y='count', data=juncs.reset_index(),
                    showfliers=False, color='grey', ax=ax)
        # Thousands of SRX tick labels are unreadable; hide them (after boxplot creates them).
        plt.setp(ax.get_xticklabels(), visible=False)
    # Report the number actually plotted: a strategy may have fewer than `_size`
    # samples (WGS has 1970), so `_size` itself would overstate it.
    ax.set_title(f'{strategy}: {len(subset)} samples')

fig.tight_layout();



## Normalized read counts

In [9]:
@delayed
def get_norm_counts(srx, chroms=None):
    """Lazily compute per-feature-type median normalized counts for one sample.

    Builds one long table of counts from three sources, all for sample ``srx``:

    * gene counts from ``../aln-wf/output/gene_counts/{srx}.parquet``;
    * junction counts from ``../aln-wf/output/junction_counts/{srx}.parquet``,
      restricted to same-chromosome junctions on ``chroms`` and summed per gene;
    * intergenic counts from ``../aln-wf/output/intergenic_counts/{srx}.parquet``,
      restricted to regions that ``../output/dmel_r6-11.intergenic.bed``
      places on ``chroms``.

    The combined table is normalized with the project ``cpm`` helper
    (``log='log10'``). The median is then taken per (srx, var_type).

    Parameters
    ----------
    srx : str
        SRX accession; used only to build file paths.
    chroms : list of str, optional
        Chromosomes to keep. Defaults to chrX, chr2L, chr2R, chr3L, chr3R,
        chr4 and chrY.

    Returns
    -------
    pandas.DataFrame
        One row per srx, one column per var_type ('gene', 'junction',
        'intergenic'). Because of ``@delayed`` this is returned via a dask
        ``Delayed``.
    """
    if chroms is None:
        chroms = ['chrX', 'chr2L', 'chr2R', 'chr3L', 'chr3R', 'chr4', 'chrY']

    # Gene level counts
    gene = (
        pd.read_parquet(f'../aln-wf/output/gene_counts/{srx}.parquet')
        .reset_index()[['FBgn', 'srx', 'count']]
        .assign(var_type='gene')
    )

    # Junction level counts (same logic as `read_juncs`, summed per gene)
    junc_raw = pd.read_parquet(f'../aln-wf/output/junction_counts/{srx}.parquet')
    keep = junc_raw['Site1_chr'].isin(chroms) & (junc_raw['Site1_chr'] == junc_raw['Site2_chr'])
    junc = (
        junc_raw.loc[keep, ['PrimaryGene', 'srx', 'count']]
        .rename(columns={'PrimaryGene': 'FBgn'})
        .dropna()
        .groupby(['FBgn', 'srx'])['count'].sum()
        .reset_index()
        .assign(var_type='junction')
    )

    # Intergenic counts, limited to regions annotated on `chroms`.
    # NOTE(review): the BED annotation is re-read for every sample; loading it
    # once outside the task would save a file read per sample.
    inter_annot = pd.read_csv('../output/dmel_r6-11.intergenic.bed', sep='\t', header=None,
                              names=['chrom', 'start', 'end', 'FBgn'], index_col='FBgn')
    inter_names = inter_annot.index[inter_annot['chrom'].isin(chroms)].unique()
    # `isin` instead of interpolating (potentially thousands of) names into a query string.
    inter_raw = pd.read_parquet(f'../aln-wf/output/intergenic_counts/{srx}.parquet').reset_index()
    inter = (
        inter_raw.loc[inter_raw['FBgn'].isin(inter_names), ['FBgn', 'srx', 'count']]
        .assign(var_type='intergenic')
    )

    # combine, normalize, aggregate
    df = pd.concat([gene, junc, inter])
    norm = cpm(df.set_index(['FBgn', 'srx', 'var_type']), log='log10').reset_index()
    # `norm` still carries the string FBgn column. numeric_only=True skips it
    # explicitly: older pandas dropped it silently, pandas >= 2.0 raises TypeError.
    medians = norm.groupby(['srx', 'var_type']).median(numeric_only=True).unstack()
    medians.columns = medians.columns.droplevel(0)
    return medians

In [10]:
# Submit one normalization task per sample; results are collected in the next cell.
futures = daskClient.compute(list(map(get_norm_counts, samples)))

In [None]:
# Wait for all per-sample results and stack them into one frame (one row per srx).
data = pd.concat(objs=daskClient.gather(futures))

In [None]:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Attach each sample's library strategy to its per-feature-type medians.
# Samples whose label is missing (NaN) are dropped by groupby and not plotted.
_dat = data.join(labels)
for strategy, grp in _dat.groupby('label'):
    ax.scatter(grp.gene, grp.junction, grp.intergenic, label=strategy)

ax.set_xlabel('gene')
ax.set_ylabel('junction')
ax.set_zlabel('intergenic')
ax.set_title('Per-sample median log10 cpm-normalized counts by library strategy')
# Use the explicit Axes API (not pyplot state); `;` hides the Legend repr.
ax.legend(loc=(1, 0.5));