In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import scipy
import scipy.stats
from scipy.signal import find_peaks
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore')
%matplotlib inline

In [12]:
def load_pred(pred_dir, ct, chrom, pred_len=200, avg_stripe=False):
    file = osp.join(pred_dir, ct, 'prediction_{}_chr{}.npz'.format(ct, chrom))
    temp = np.load(file)['arr_0']
    chrom_len = temp.shape[0]
    prep = np.insert(temp, pred_len, 0, axis=1)
    mat = np.array([
        np.insert(np.zeros(chrom_len+pred_len+1), i, prep[i]) for i in range(chrom_len)
    ])
    summed = np.vstack((
        np.zeros((pred_len, mat.shape[1])), mat
    )).T[:chrom_len+pred_len, :chrom_len+pred_len]
    if avg_stripe:
        summed = (summed + np.vstack((
            np.zeros((pred_len, mat.shape[1])), mat
        ))[:chrom_len+pred_len, :chrom_len+pred_len])/2
    
    return summed[pred_len:-pred_len, pred_len:-pred_len]


def quantile_norm(pred1, pred2):
    l = pred1.shape[0]
    pred1_diag = np.array([np.pad(np.diagonal(pred1, offset=i), (0, i), 'constant') for i in range(200)]).T
    pred2_diag = np.array([np.pad(np.diagonal(pred2, offset=i), (0, i), 'constant') for i in range(200)]).T
    pred = np.column_stack((pred1_diag.ravel(), pred2_diag.ravel()))
    df, df_sorted = pd.DataFrame(pred), pd.DataFrame(np.sort(pred, axis=0))
    df_mean = df_sorted.mean(axis=1)
    df_mean.index += 1
    df_qn = df.rank(method='min').stack().astype(int).map(df_mean).unstack()
    pred1_stripe, pred2_stripe = df_qn[0].values.reshape(-1,200), df_qn[1].values.reshape(-1,200)
    
    pred1_qn, pred2_qn = np.zeros_like(pred1), np.zeros_like(pred2)
    for i in range(200):
        idx = np.arange(l-i, dtype=int)
        pred1_qn[idx, idx+i] = pred1_qn[idx+i, idx] = pred1_stripe[:l-i, i]
        pred2_qn[idx, idx+i] = pred2_qn[idx+i, idx] = pred2_stripe[:l-i, i]
    
    return pred1_qn, pred2_qn


def quantile_norm_multi(preds):
    return


def topdom(mat, size):
    padmat = np.pad(mat, size, mode='constant', constant_values=np.nan)
    dim = padmat.shape[0]
    signal = np.array([
        np.nanmean(padmat[i-size:i+size, i-size:i+size]) for i in range(dim)
    ][size:-size])
    
    return signal


def generate_dim(mindim, maxdim, numdim):
    mindim, maxdim = max(1, mindim), min(100, maxdim)
    return np.linspace(mindim, maxdim, num=numdim, dtype=int)


def get_tads(mat, sizes):
    signal = np.array([topdom(mat, i) for i in tqdm(sizes)])
    rows, idxs = [], []
    for i in range(len(signal)):
        idx = find_peaks(signal[i], prominence=(0.25, ))[0]
        row = np.full_like(idx, i)
        rows.append(row)
        idxs.append(idx)
    tads = np.array([
        np.concatenate(rows, axis=None), np.concatenate(idxs, axis=None)
    ])
    
    return tads


def merge_tads(tads1, tads2, sizes, closethresh):
    alltads = []
    for s in sizes:
        n, m = tads1.shape[1], tads2.shape[1]
        tads1_s, tads2_s = tads1[1, tads1[0]==s], tads2[1, tads2[0]==s]
        i, j = 0, 0
        while i < n and j < m:
            if abs(tads1_s[i] - tads2_s[j]) <= closethresh:
                alltads.append([s, int(mean(tads1_s[i], tads2_s[j]))])
                i += 1
                j += 1
            else:
                if tads1_s[i] < tads2_s[j]:
                    alltads.append([s, tads1_s[i]])
                    i += 1
                else:
                    alltads.append([s, tads2_s[j]])
                    j += 1
                    
    return np.array(alltads)


def tads_to_coords(tads, sizes):
    coords = np.array([
        tads[1] - sizes[tads[0]], tads[1] + sizes[tads[0]]
    ])
    
    return coords


def get_tad_coords(mat1, mat2, mindim=10, maxdim=100, numdim=10, closethresh=2):
    sizes = generate_dim(mindim, maxdim, numdim)
    tads1, tads2 = get_tads(mat1, sizes), get_tads(mat2, sizes)
    alltads = merge_tads(tads1, tads2, sizes, closethresh)
    coords = tads_to_coords(alltads, sizes)
    
    return coords

In [10]:
%%time

chrom = 16
pred_dir = '/data/leslie/suny4/predictions/chromafold/'
ct1 = 'mycGCB_am_gfp_myc_gcb_thelp_sample'
ct2 = 'mycGCB_am_gfp_myc_gcb_nothelp_sample'

pred1 = load_pred(pred_dir, ct1, chrom, avg_stripe=True)
pred2 = load_pred(pred_dir, ct2, chrom, avg_stripe=True)
pred1_qn, pred2_qn = quantile_norm(pred1, pred2)

CPU times: user 5.6 s, sys: 2.68 s, total: 8.28 s
Wall time: 8.28 s


In [13]:
coords = get_tad_coords(pred1_qn, pred2_qn)
coords

100%|██████████| 10/10 [00:13<00:00,  1.33s/it]
100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


IndexError: index 0 is out of bounds for axis 0 with size 0