In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import torch
import scipy
import scipy.stats
from scipy.signal import find_peaks
from scipy.sparse import csr_matrix
from sklearn import metrics
import matplotlib.pyplot as plt
import pickle
import gffutils
from tqdm import tqdm
import itertools
import coolbox
from coolbox.api import *
import warnings
import sqlite3
import json

warnings.filterwarnings('ignore')

In [2]:
def set_diagonal(mat, value=0):
    """Overwrite the first super- and sub-diagonal of a square matrix in place.

    Note that, despite the name, the main diagonal itself is left untouched:
    only the entries at ``(k, k + 1)`` and ``(k + 1, k)`` are written.

    Parameters
    ----------
    mat : numpy.ndarray
        Two-dimensional square array. It is modified in place.
    value : scalar or array-like, optional
        Value assigned to both off-diagonals (default ``0``).

    Returns
    -------
    numpy.ndarray
        The same object that was passed in as ``mat``.

    Raises
    ------
    ValueError
        If the first two dimensions of ``mat`` differ.
    """
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    if n_rows != n_cols:
        raise ValueError(
            'Matrix is not square ({}, {})'.format(n_rows, n_cols)
        )

    # Index pairs (k, k + 1) for k = 0 .. n_rows - 2.
    positions = np.arange(n_rows)
    head, tail = positions[:-1], positions[1:]

    mat[head, tail] = value  # first super-diagonal
    mat[tail, head] = value  # first sub-diagonal

    return mat


def load_pred(pred_dir, ct, chrom, pred_len=200, avg_stripe=False):
    """Load a per-position prediction file and expand it into a square matrix.

    The file ``<pred_dir>/<ct>/prediction_<ct>_chr<chrom>.npz`` is read and its
    ``'arr_0'`` entry (one row per position) is laid out as a dense 2-D matrix.

    Parameters
    ----------
    pred_dir : str
        Directory that contains one sub-directory per cell type.
    ct : str
        Cell-type name; used both as sub-directory and inside the file name.
    chrom : int or str
        Chromosome label, formatted into ``chr{chrom}`` in the file name.
    pred_len : int, optional
        Column index at which a zero column is inserted into every row, and
        the number of rows/columns trimmed from each border of the result.
        Must be positive: with ``0`` the final ``[pred_len:-pred_len]`` slice
        is empty. Presumably half the width of one stored row -- TODO confirm
        against the code that writes the ``.npz`` files.
    avg_stripe : bool, optional
        If ``True``, average the transposed layout with the untransposed one
        instead of returning the transposed layout alone.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(chrom_len - pred_len, chrom_len - pred_len)`` where
        ``chrom_len`` is the number of rows in the file. The first
        off-diagonals are zeroed by ``set_diagonal``.

    Notes
    -----
    A dense ``chrom_len``-sized square matrix is materialised, so memory use
    grows quadratically with the number of rows in the file.
    """
    file = osp.join(pred_dir, ct, 'prediction_{}_chr{}.npz'.format(ct, chrom))
    # Use the archive as a context manager so the underlying file handle is
    # closed as soon as the array has been read (np.load keeps it open otherwise).
    with np.load(file) as archive:
        temp = archive['arr_0']

    chrom_len = temp.shape[0]
    # The zero column inserted here ends up on the main diagonal of the padded
    # matrix built below.
    prep = np.insert(temp, pred_len, 0, axis=1)
    # Row i of `prep` is spliced into a zero vector starting at column i, which
    # shifts every row one column further right than the previous one.
    mat = np.array([
        np.insert(np.zeros(chrom_len+pred_len+1), i, prep[i]) for i in range(chrom_len)
    ])

    # Prepend `pred_len` zero rows once; both orientations are cropped from it.
    size = chrom_len + pred_len
    padded = np.vstack((np.zeros((pred_len, mat.shape[1])), mat))
    summed = padded.T[:size, :size]
    if avg_stripe:
        summed = (summed + padded[:size, :size]) / 2

    # Drop the `pred_len`-wide border, then zero the first off-diagonals.
    pred = set_diagonal(summed[pred_len:-pred_len, pred_len:-pred_len])

    return pred


def load_database(db_file, gtf_file):
    """Return a gffutils database, building it from ``gtf_file`` if needed.

    Parameters
    ----------
    db_file : str
        Path of the gffutils database file. If it already exists it is opened
        with ``gffutils.FeatureDB``; otherwise it is the destination passed to
        ``gffutils.create_db``.
    gtf_file : str
        Annotation file handed to ``gffutils.create_db``; only used when
        ``db_file`` does not exist yet.

    Returns
    -------
    The object returned by ``gffutils.FeatureDB`` or ``gffutils.create_db``.
    """
    # Slow path first: no database on disk yet, so build one and say so.
    if not osp.isfile(db_file):
        print('creating db from raw. This might take a while.')
        return gffutils.create_db(gtf_file, db_file)

    return gffutils.FeatureDB(db_file)

In [3]:
# Notebook configuration. NOTE(review): these are absolute, machine-specific
# paths; the notebook will not run elsewhere without editing them.

# Directory of processed inputs (not referenced by the cells shown here).
input_dir = '/data/leslie/suny4/processed_input/'
# Root directory of prediction files; load_pred() expects one sub-directory per
# cell type containing 'prediction_<ct>_chr<chrom>.npz'.
pred_dir = '/data/leslie/suny4/predictions/chromafold/'
# The two cell types / conditions that are loaded and compared.
ct1 = 'alexia_am_gfp_myc_thelp'
ct2 = 'alexia_am_gfp_myc_nothelp'
# Chromosome label, formatted as 'chr{chrom}' in the prediction file name.
chrom = 13
# gffutils database path -- presumably the `db_file` argument for
# load_database(); no matching `gtf_file` is defined in this cell.
db_file = '/data/leslie/suny4/data/chrom_size/gencode.vM10.basic.annotation.db'

In [4]:
%%time

# Load the chromosome `chrom` prediction matrix for each cell type, averaging
# the transposed and untransposed layouts (avg_stripe=True).
pred1 = load_pred(pred_dir, ct1, chrom, avg_stripe=True)
pred2 = load_pred(pred_dir, ct2, chrom, avg_stripe=True)

CPU times: user 4.22 s, sys: 3.17 s, total: 7.39 s
Wall time: 7.39 s


In [7]:
def quantile_normalize(preds, n_diag=200):
    """Quantile-normalize the near-diagonal band of a stack of square matrices.

    For every matrix the first ``n_diag`` upper diagonals (offsets
    ``0 .. n_diag - 1``) are flattened into one vector. Each value is then
    replaced by the mean, across matrices, of the values holding the same rank,
    so all matrices end up with the same value distribution. The normalized
    band is written back symmetrically; everything outside the band is zero.

    Parameters
    ----------
    preds : array-like, shape (N, H, H)
        Stack of ``N`` square matrices. Only the upper band is read.
    n_diag : int, optional
        Number of diagonals in the band (default ``200``). It is clamped to
        ``H`` so matrices smaller than the band are handled.

    Returns
    -------
    numpy.ndarray, shape (N, H, H)
        Symmetric, quantile-normalized matrices. The dtype equals that of
        ``preds`` when it is a floating type and ``float64`` otherwise, so that
        averaged quantiles are not truncated for integer input.

    Raises
    ------
    AssertionError
        If the matrices are not square.
    ValueError
        If ``n_diag`` is smaller than 1 or the band contains NaN.

    Notes
    -----
    Each diagonal is zero-padded to length ``H`` before flattening, and those
    padding zeros take part in the ranking together with the real values.
    Ties receive the lowest rank of their group (``method='min'``).
    """
    preds = np.asarray(preds)
    N, H, W = preds.shape
    assert H == W, f'Matrix is not square ({H}, {W})'
    if n_diag < 1:
        raise ValueError(f'n_diag must be at least 1, got {n_diag}')
    n_diag = min(int(n_diag), H)

    out_dtype = preds.dtype if np.issubdtype(preds.dtype, np.floating) else np.float64
    preds_qn = np.zeros(preds.shape, dtype=out_dtype)
    if N == 0 or H == 0:
        return preds_qn

    # One column per matrix; within a column the element at start position p
    # and diagonal offset i sits at index p * n_diag + i.
    # A list (not a generator) is required: NumPy's stacking functions reject
    # non-sequence iterables.
    pred_diag = np.column_stack([
        np.array([
            np.pad(np.diagonal(pred, offset=i), (0, i), 'constant')
            for i in range(n_diag)
        ]).T.ravel()
        for pred in preds
    ])
    if np.isnan(pred_diag).any():
        raise ValueError('preds contains NaN; cannot compute ranks')

    # Reference distribution: mean of the k-th smallest value across matrices.
    rank_means = pd.DataFrame(np.sort(pred_diag, axis=0)).mean(axis=1).values
    # 1-based 'min' ranks of every value within its own matrix.
    ranks = pd.DataFrame(pred_diag).rank(method='min').values.astype(np.intp)
    # Look up the reference value for each rank, then restore (N, H, n_diag).
    pred_diag_qn = rank_means[ranks - 1].T.reshape(N, H, n_diag)

    # Write each normalized diagonal back above and below the main diagonal.
    for i in range(n_diag):
        idx = np.arange(H - i, dtype=int)
        band = pred_diag_qn[:, :H - i, i]
        preds_qn[:, idx, idx + i] = band
        preds_qn[:, idx + i, idx] = band

    return preds_qn

In [8]:
%%time

# Quantile-normalize both prediction matrices together so that they share one
# value distribution before being compared.
preds_qn = quantile_normalize(np.array([pred1, pred2]))
# NOTE(review): this rebinds pred1/pred2 to the normalized matrices, so the
# cell is not idempotent -- re-running it without re-running the loading cell
# feeds already-normalized data back into quantile_normalize().
pred1, pred2 = preds_qn[0], preds_qn[1]

CPU times: user 3.2 s, sys: 1.29 s, total: 4.48 s
Wall time: 4.48 s
