In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import torch
import scipy
import scipy.stats
from scipy.signal import find_peaks
from scipy.sparse import csr_matrix
from sklearn import metrics
import matplotlib.pyplot as plt
import pickle
import gffutils
from tqdm import tqdm
import itertools
import coolbox
from coolbox.api import *
import warnings
import sqlite3
import json

# Silence every warning for the rest of the session; note this also hides
# deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

In [2]:
def set_diagonal(mat, value=0):
    """Overwrite the first off-diagonals of a square matrix, in place.

    Despite the name, the main diagonal is left untouched: only the entries
    directly above it (``mat[i, i + 1]``) and directly below it
    (``mat[i + 1, i]``) are assigned.

    Args:
        mat: array whose first two dimensions are equal; it is modified in place.
        value: value written to both off-diagonals (default 0).

    Returns:
        The same ``mat`` object that was passed in.

    Raises:
        ValueError: if the first two dimensions of ``mat`` differ.
    """
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    if n_rows != n_cols:
        raise ValueError(
            'Matrix is not square ({}, {})'.format(n_rows, n_cols)
        )

    # Index pairs (k, k + 1) walk along the superdiagonal; swapping them
    # walks along the subdiagonal.
    lead = np.arange(n_rows)[:-1]
    trail = np.arange(n_rows)[1:]
    mat[lead, trail] = value
    mat[trail, lead] = value

    return mat


def load_pred(pred_dir, ct, chrom, pred_len=200, avg_stripe=False):
    """Load a per-chromosome prediction archive and expand it into a square matrix.

    The array stored under ``'arr_0'`` in
    ``<pred_dir>/<ct>/prediction_<ct>_chr<chrom>.npz`` is read, a zero column
    is inserted at column index ``pred_len`` of every row, and row ``i`` is
    then placed at column offset ``i`` of an otherwise-zero matrix. After
    padding ``pred_len`` zero rows on top, the transposed result is cropped to
    a square, optionally averaged with its own transpose, trimmed by
    ``pred_len`` on every side and passed through ``set_diagonal``.

    Args:
        pred_dir: root directory holding one sub-directory per ``ct``.
        ct: sub-directory name, also embedded in the file name.
        chrom: chromosome identifier formatted into ``chr{}``.
        pred_len: column index at which the zero column is inserted and the
            number of rows/columns trimmed from each edge. Must be positive:
            with 0 the final ``[pred_len:-pred_len]`` slice is empty.
            NOTE(review): presumably each stored row has ``2 * pred_len``
            columns so the inserted zero is the centre — confirm.
        avg_stripe: when True, return the average of the expanded matrix and
            its transpose instead of the transposed matrix alone.

    Returns:
        A float array of shape ``(chrom_len - pred_len, chrom_len - pred_len)``
        where ``chrom_len`` is the number of rows in the stored array.
    """
    file = osp.join(pred_dir, ct, 'prediction_{}_chr{}.npz'.format(ct, chrom))
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # it is closed; the context manager releases the handle once the array
    # has been read into memory.
    with np.load(file) as archive:
        temp = archive['arr_0']
    chrom_len = temp.shape[0]
    side = chrom_len + pred_len

    prep = np.insert(temp, pred_len, 0, axis=1)
    # Row i is shifted right by i columns, turning each stored row into a
    # band of the expanded matrix.
    mat = np.array([
        np.insert(np.zeros(side + 1), i, prep[i]) for i in range(chrom_len)
    ])

    # Build the zero-padded matrix once; both the transposed and the
    # non-transposed crops below are taken from it.
    padded = np.vstack((np.zeros((pred_len, mat.shape[1])), mat))
    summed = padded.T[:side, :side]
    if avg_stripe:
        summed = (summed + padded[:side, :side]) / 2

    pred = set_diagonal(summed[pred_len:-pred_len, pred_len:-pred_len])

    return pred


def load_database(db_file, gtf_file):
    """Return a gffutils database, building it from the GTF if necessary.

    Args:
        db_file: path of the gffutils database file.
        gtf_file: path of the annotation file used to create the database
            when ``db_file`` does not exist yet.

    Returns:
        The object returned by ``gffutils.FeatureDB`` (existing file) or by
        ``gffutils.create_db`` (newly created file).
    """
    # Guard clause: only pay the creation cost when no database file exists.
    if not osp.isfile(db_file):
        print('creating db from raw. This might take a while.')
        return gffutils.create_db(gtf_file, db_file)

    return gffutils.FeatureDB(db_file)

In [3]:
# Notebook configuration: data locations and the sample to analyse.
# NOTE(review): these are absolute, machine-specific paths — adjust them
# before running the notebook on another system.
input_dir = '/data/leslie/suny4/processed_input/'
# Root directory handed to load_pred; it expects one sub-directory per `ct`.
pred_dir = '/data/leslie/suny4/predictions/chromafold/'
# Sample tag: used by load_pred as both sub-directory and file-name component.
ct = 'alexia_am_gfp_myc_thelp'
# Chromosome identifier, formatted into 'chr{}' by load_pred.
chrom = 13
# Path of the annotation database file.
db_file = '/data/leslie/suny4/data/chrom_size/gencode.vM10.basic.annotation.db'

In [None]:
# Expand the stored predictions for the configured sample and chromosome;
# avg_stripe=True averages the expanded matrix with its transpose.
pred = load_pred(pred_dir, ct, chrom, avg_stripe=True)