In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%load_ext line_profiler

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import glob
from tqdm import tqdm
import os
import h5py
from pyliftover import LiftOver
from collections import defaultdict

from loop_helpers import read_loop_frame, monte_carlo, mk_sum, loop_monte_carlo

from NucFrames import NucFrames

np.random.seed(42)

In [None]:
# Collect the per-cell structure files. glob.glob() returns paths in arbitrary
# (filesystem-dependent) order, so sort them: the "Cell N" labels below are
# assigned purely by position and would otherwise point at different files from
# one run or machine to the next.
frame_files = sorted(glob.glob("/mnt/SSD/LayeredNuc/frames/*.hdf5"))
# One display label per file, positionally aligned with frame_files.
# NOTE(review): later cells zip cell_labels with nfs, which assumes NucFrames
# iterates in the order the files were passed in — confirm.
cell_labels = ["Cell {}".format(i + 1) for i in range(len(frame_files))]
nfs = NucFrames(frame_files)

In [None]:
# Read the loop list into a dataframe, then separate loops whose two anchors
# share a chromosome (cis) from those spanning two chromosomes (trans).
loop_file = "/home/lpa24/dev/cam/data/rao_et_al_data/GSE63525_CH12-LX_HiCCUPS_looplist_with_motifs.txt"
df = read_loop_frame(loop_file)

same_chromosome = df["chr1"] == df["chr2"]
different_chromosome = df["chr1"] != df["chr2"]
cis = df.loc[same_chromosome]
trans = df.loc[different_chromosome]
print(cis.shape)

In [None]:
# Convert the base-pair positions of both loop anchors ("a" and "b") into
# indices via nfs.chrm_bp_to_idx, one chromosome at a time, and keep only the
# loops for which both anchors are reported as valid.
print(cis.shape)
new_cis = []
# The per-chromosome group is deliberately not named `df`: that name holds the
# full loop frame loaded above, and rebinding it here would silently replace it
# with the last chromosome's rows for anyone re-running earlier cells.
for chrm, chrm_loops in cis.groupby("chr1"):
    # chrm_bp_to_idx returns three values; the first is used as the index
    # array and the third as a validity mask (the second is unused here).
    idx_a, _, valid_a = nfs.chrm_bp_to_idx(chrm, chrm_loops["a"].values.astype(np.int32))
    idx_b, _, valid_b = nfs.chrm_bp_to_idx(chrm, chrm_loops["b"].values.astype(np.int32))
    valid = np.logical_and(valid_a, valid_b)

    # Copy so the new columns are written to an independent frame rather than
    # to a slice of `cis`.
    new_df = chrm_loops[valid].copy()
    new_df["idx_a"] = idx_a[valid]
    new_df["idx_b"] = idx_b[valid]
    new_cis.append(new_df)
cis = pd.concat(new_cis)
print(cis.shape)
    

In [None]:
# Gather, per chromosome, every index reported by the "rt_viol" track across
# all files, then store them as sorted arrays in `sphase`.
# NOTE(review): neither `nuc_files` nor `nf` is assigned by any cell above this
# one in the visible notebook, so this cell relies on state from elsewhere —
# confirm where they are meant to come from.
invalid = defaultdict(set)
for nuc_file in nuc_files:
    try:
        idx_dict, track_dict = nf.track_to_idx("rt_viol", nuc_file)
    except KeyError:
        # A KeyError from the lookup is ignored and the file is skipped.
        continue
    for chrm, v in idx_dict.items():
        invalid[chrm].update(v.flatten())

sphase = {}
for chrm, vals in invalid.items():
    vals = np.sort(np.array(list(vals)))
    sphase[chrm] = vals

In [None]:
# Drop every loop that has an anchor index listed in `sphase` for its
# chromosome.
print(cis.shape)
# Used for chromosomes that have no entry in `sphase`: nothing is flagged there,
# so all of their loops are kept instead of raising a KeyError.
no_flagged_idx = np.array([], dtype=int)
kept_groups = []
# Group by the bare column name (not a one-element list) so the group key is
# the chromosome itself on every pandas version, never a 1-tuple.
for c, f in cis.groupby("chr1"):
    vals = sphase.get(c, no_flagged_idx)
    keep = ~f["idx_a"].isin(vals) & ~f["idx_b"].isin(vals)
    kept_groups.append(f[keep])
filtered_cis = pd.concat(kept_groups)
print(filtered_cis.shape)

In [None]:
# Pooled test over all chromosomes: for each chromosome, count in how many
# cells every index pair lies within that cell's contact distance, sum those
# counts over the loop anchor pairs, and hand the result to monte_carlo().
#
# NOTE(review): this rebinds filtered_cis to the `cis` frame itself, so any
# filtered frame built in an earlier cell is not used below — confirm this is
# intentional.
filtered_cis = cis
# Key used to look up each cell's threshold in contact_dists (presumably a
# percentile of its distance distribution — TODO confirm).
contact_percentile = 90

# Passed to monte_carlo() as B (presumably the number of random draws).
samples = 500000

# Running sum of the truth_count values returned by monte_carlo().
truths = 0
# Number of (chr1, chr2) groups processed.
chrm_count = 0

# Running sum of loop_total over all groups.
total = 0

for (chrm, _), df in filtered_cis.groupby(["chr1", "chr2"]):
    chrm_count += 1
    print("Chromosome {}".format(chrm))
    # One boolean "within contact distance" matrix per cell.
    dists_stack = []
    for nf in nfs:
        dist_arr = nf.chrms[chrm].dists
        contact_dist = nf.chrms[chrm].contact_dists[contact_percentile]
        # dist_arr is rebound from distances to the thresholded boolean matrix.
        dist_arr = dist_arr <= contact_dist
        dists_stack.append(dist_arr)

    dists_stack = np.array(dists_stack)
    # Per index pair: number of cells in which the pair is within threshold.
    dists_count = np.sum(dists_stack, axis=0)

    idx_a = df["idx_a"]
    idx_b = df["idx_b"]

    # Sum of the per-pair cell counts over this chromosome's loops.
    loop_total = np.sum(dists_count[idx_a, idx_b])

    total += loop_total

    # Same sum evaluated at the mirrored positions (max_idx - idx); it is
    # printed and drawn as a yellow vertical line for comparison.
    max_idx = dists_count.shape[0] - 1
    test_loop_total = np.sum(dists_count[max_idx - idx_a, max_idx - idx_b])
    print(test_loop_total)
    #test_sum = mk_sum(dists_count, np.array([idx_a]), np.array([idx_b]))

    # Absolute index separation of each loop's two anchors.
    seps = np.abs(idx_a - idx_b)
    #print(list(zip(zip(seps, idx_a), idx_b)))
    #%lprun -m loop_helpers monte_carlo(dists_count, seps, loop_total, idx_a, idx_b, B=samples)
    # Title and line are issued through the pyplot state machine before
    # monte_carlo(plot=True) is called.
    # NOTE(review): presumably monte_carlo draws onto the same current axes — confirm.
    plt.title("Chromosome: {}".format(chrm))
    plt.axvline(test_loop_total, color='y')
    # Only truth_count is used below; p is not read in this cell.
    (truth_count, p) = monte_carlo(dists_count, seps, loop_total, idx_a, idx_b, B=samples, plot=True)
    truths += truth_count

# Pooled ratio over all chromosomes, with 1 added to numerator and denominator.
print((1 + truths) / (1 + (samples * chrm_count)))
print(total)

In [None]:
# Per-chromosome figure: for every loop, record in which cells its two anchors
# are within the cell's contact distance, compute a per-loop p-value with
# loop_monte_carlo(), and save a heatmap (cells x loops) with the p-values drawn
# as a bar strip on top.
contact_percentile = 90

# Passed to loop_monte_carlo() / monte_carlo() as B. A later cell also reads
# this global.
B = 10000
# One annotated frame per chromosome, concatenated by a later cell.
ds = []
for (chrm, _), chrm_loops in filtered_cis.groupby(["chr1", "chr2"]):
    d = chrm_loops.copy()
    idx_a = d["idx_a"]
    idx_b = d["idx_b"]
    d["idx_sep"] = idx_b - idx_a
    #d = d[d["idx_sep"] >= 6]

    # One boolean "within contact distance" matrix per cell. The per-cell
    # frames are read through `.chrms[...]`, the accessor used by the other
    # cells of this notebook (this loop previously used `.chrm[...]`).
    contact_arr_stack = []
    for label, nf in zip(cell_labels, nfs):
        chrm_frame = nf.chrms[chrm]
        contact_arr = chrm_frame.dists <= chrm_frame.contact_dists[contact_percentile]
        contact_arr_stack.append(contact_arr)
        d[label] = contact_arr[idx_a, idx_b]

    # Per index pair: number of cells in which the pair is within threshold.
    sum_pseudo_contact_mat = np.sum(np.array(contact_arr_stack), axis=0)

    # One p-value per loop, in row order. Iterating over the two index columns
    # directly keeps their integer dtype (iterrows() builds a mixed-dtype row
    # object for every loop).
    p_vals = np.array([
        loop_monte_carlo(a, b, sum_pseudo_contact_mat, B=B)
        for a, b in zip(idx_a.values, idx_b.values)
    ])
    d["loop_p_val"] = p_vals
    ds.append(d.copy())

    # Calculate cellwise pvalues. The contact matrices built above are reused
    # instead of being recomputed, and each entry is tagged with the cell's
    # label (the name previously used here was never defined in this cell).
    seps = np.abs(idx_a - idx_b)
    cell_pval_dicts = []
    for label, contact_arr in zip(cell_labels, contact_arr_stack):
        cell_loop_total = np.sum(contact_arr[idx_a, idx_b])
        cell_truth_count, cell_pval = monte_carlo(contact_arr, seps, cell_loop_total, idx_a, idx_b, B=B)
        cell_pval_dicts.append({"cell": label, "pval": cell_pval})
    cell_pval = pd.DataFrame(cell_pval_dicts)

    d = d.sort_values(by="loop_p_val", axis=0)

    fig = plt.figure(figsize=(25, 5))
    fig.suptitle("Chromosome {}".format(chrm))
    ax = fig.add_subplot(1, 1, 1)
    ax.tick_params(labelsize=8, axis="x")
    ax.tick_params(labelsize=15, axis="y")

    # Select the per-cell columns by label. The previous positional slice
    # (`.ix[:, 7:15]`) relied on an accessor that pandas has removed, on exactly
    # eight cells, and on no other cell having added a column to the frame.
    main = sns.heatmap(d[cell_labels].T, ax=ax, cbar=False, xticklabels=d.idx_sep.values)
    #main.yaxis.set_yticks(rotation=0)
    for l in main.yaxis.get_ticklabels():
        l.set_rotation(0)
    divider = make_axes_locatable(ax)

    # Bar strip of per-loop p-values, in the same (sorted) order as the
    # heatmap columns.
    loop_pval_ax = divider.append_axes("top", size="20%")
    #loop_pval_ax.plot(np.arange(d.shape[0]), d.loop_p_val.values)
    #sns.barplot(np.arange(d.shape[0]), d.loop_p_val.values, ax=loop_pval_ax, color="b", linewidth=0)
    loop_pval_ax.bar(np.arange(d.shape[0]), d.loop_p_val.values, width=1)
    loop_pval_ax.set_xlim(0, d.shape[0])
    #pval_ax.get_yaxis().set_visible(False)
    loop_pval_ax.get_xaxis().set_ticks([])
    loop_pval_ax.get_yaxis().set_ticks([0, 0.5, 1.0])

    # Disabled: per-cell p-value strip on the right-hand side.
    #cell_pval_ax = divider.append_axes("right", size="10%")
    #cell_pval_ax.barh(np.arange(cell_pval.shape[0]),
    #                 cell_pval.pval.values,
    #                 height=1)
    #cell_pval_ax.set_xlim(0, 1.0)
    #cell_pval_ax.get_yaxis().set_visible(False)
    #cell_pval_ax.xaxis.tick_top()
    #for label in cell_pval_ax.xaxis.get_ticklabels():
    #    label.set_rotation(-90)
    plt.savefig("/home/lpa24/dev/cam/data/figures/extended_data_fig_7/chrm_{}_loop_pval.svg".format(chrm))

In [None]:
# Stack the per-chromosome frames collected in `ds` and write them out as CSV.
csv_path = "/home/lpa24/dev/cam/data/figures/loop_pseudocontact_analysis.csv"
d = pd.concat(ds)
print(d.columns)
d.to_csv(csv_path)

In [None]:
# Main figure: heatmap (cells x loops) of whether each loop's anchors are
# within contact distance in each cell, restricted to loops at least 7 indices
# apart, with a per-loop p-value bar strip on top.
contact_percentile = 90
d = filtered_cis.copy()

d["idx_sep"] = d["idx_b"] - d["idx_a"]

d = d[d.idx_sep >= 7]

new_dfs = []
p_vals = []
for (chrm, _), f in d.groupby(["chr1", "chr2"]):
    contact_arr_stack = []

    new_f = f.copy()
    idx_a = f["idx_a"].values
    idx_b = f["idx_b"].values
    # NOTE(review): per-cell data is read as `frame.chrms[chrm]` from `nfs`,
    # matching the other cells of this notebook. This cell previously went
    # through `nf.nuc_names` / `nf.dists[name, chrm, chrm]` /
    # `nf.contact_dists[name][...]`, which no visible cell provides — confirm
    # the per-chromosome threshold is the one intended here.
    for label, cell_frame in zip(cell_labels, nfs):
        chrm_frame = cell_frame.chrms[chrm]
        contact_arr = chrm_frame.dists <= chrm_frame.contact_dists[contact_percentile]
        contact_arr_stack.append(contact_arr)
        new_f[label] = contact_arr[idx_a, idx_b]

    # The pooled matrix depends only on the chromosome, so build it once per
    # group rather than once per loop. p-values are still computed in row
    # order (groups are concatenated in this same order below).
    sum_pseudo_contact_mat = np.sum(np.array(contact_arr_stack), axis=0)
    # B is the global set by the per-chromosome figure cell.
    for a, b in zip(idx_a, idx_b):
        p_vals.append(loop_monte_carlo(a, b, sum_pseudo_contact_mat, B=B))
    new_dfs.append(new_f)
d = pd.concat(new_dfs)

names = d.apply(lambda row: "chr{}, ({:.1f}-{:.1f})Mb".format(row["chr1"], row["a"] / 1e6, row["b"] / 1e6), axis=1)
d["names"] = names

p_vals = np.array(p_vals)
d["loop_p_val"] = p_vals

fig = plt.figure(figsize=(15, 5))
#fig.suptitle("Chromosome {}".format(chrm))
ax = fig.add_subplot(1, 1, 1)
ax.tick_params(labelsize=8, axis="x")
ax.tick_params(labelsize=15, axis="y")
# Select the per-cell columns by label; the previous positional slice
# (`.ix[:, 7:15]`) used an accessor pandas has removed and broke as soon as the
# number of cells or the column layout changed.
m = d[cell_labels].T

#CALC NUMBER OF LOOPS THAT NEVER CONTACT
print(np.sum(m.sum(axis=0) == 0))

print(m.shape)

# A single positional ordering (ascending number of cells in contact; stable,
# so ties keep their original order) drives the heatmap columns, the tick
# labels and the p-value bars. Ordering them by separate sorts and a
# name -> position lookup can disagree on ties and on loops whose rounded
# names coincide.
order = np.argsort(m.values.sum(axis=0), kind="mergesort")
names = d.names.values[order]

main = sns.heatmap(m.iloc[:, order], ax=ax, cbar=False, xticklabels=names)

# Reorder the rows to match the heatmap columns; name_sort records each
# loop's position in the plot.
d = d.iloc[order].copy()
d["name_sort"] = np.arange(d.shape[0])

divider = make_axes_locatable(ax)
loop_pval_ax = divider.append_axes("top", size="20%")
#loop_pval_ax.plot(np.arange(d.shape[0]), d.loop_p_val.values)
#sns.barplot(np.arange(d.shape[0]), d.loop_p_val.values, ax=loop_pval_ax, color="b", linewidth=0)
loop_pval_ax.bar(np.arange(d.shape[0]), d.loop_p_val.values, width=1)
loop_pval_ax.set_xlim(0, d.shape[0])
#pval_ax.get_yaxis().set_visible(False)
loop_pval_ax.get_xaxis().set_ticks([])
loop_pval_ax.get_yaxis().set_ticks([0, 0.5, 1.0])


plt.tight_layout()
plt.savefig("/home/lpa24/dev/cam/data/figures/fig5_e.svg")


In [None]:
# Index separation between the two anchors of every loop.
# Work on a copy before adding the column: an earlier cell binds filtered_cis
# to the same object as `cis`, so inserting the column in place would also
# change `cis` and shift the column positions seen by cells that are re-run.
filtered_cis = filtered_cis.copy()
filtered_cis["sep"] = filtered_cis["idx_b"] - filtered_cis["idx_a"]
print(filtered_cis[filtered_cis["chr1"] == "14"].sort_values(["sep"]))

In [None]:
def cumulative_plot(loops, nf, chrm, a, b):
    """Plot cumulative histograms of loop and control distances across cells.

    The loop is the first row of ``loops`` whose ``chr1``, ``a`` and ``b``
    equal the given values. For every frame in the global ``nfs`` the distance
    between the loop's two indices (``idx_a``, ``idx_b``) is collected, along
    with a control distance between ``idx_b`` and the index one loop-length
    further on (``idx_b + (idx_b - idx_a)``). Both sets are binned with
    ``np.histogram`` and their cumulative counts drawn on the current pyplot
    axes (loop in blue, control in red).

    Parameters:
        loops: frame with ``chr1``, ``a``, ``b``, ``idx_a`` and ``idx_b`` columns.
        nf: accepted for interface compatibility but not read; the frames are
            taken from the global ``nfs``.
        chrm: chromosome key used for ``loops["chr1"]`` and ``frame.chrms``.
        a, b: values matched against the ``a`` and ``b`` columns.

    Raises:
        IndexError: if no row matches (taking element 0 of an empty selection).
            NOTE(review): presumably also raised by the distance lookup when
            the control index lies past the end of the chromosome — confirm.
    """
    is_target = (loops["chr1"] == chrm) & (loops["a"] == a) & (loops["b"] == b)
    loop = loops[is_target]

    idx_a = loop.idx_a.values[0]
    idx_b = loop.idx_b.values[0]
    sep = idx_b - idx_a
    # Control partner: as far beyond idx_b as idx_a is before it.
    idx_c = idx_b + sep

    dists = []
    controls = []
    for frame in nfs:
        dist_arr = frame.chrms[chrm].dists
        loop_dist = dist_arr[idx_a, idx_b]
        controls.append(dist_arr[idx_b, idx_c])
        dists.append(loop_dist)

    # Each series is binned independently and plotted against the left bin edges.
    for values, colour, label in ((dists, 'blue', "loop"), (controls, 'red', "control")):
        counts, edges = np.histogram(values)
        plt.plot(edges[:-1], np.cumsum(counts), c=colour, label=label)
    plt.legend()
    
# Draw one cumulative plot per loop with an index separation of at most 10,
# widest loops first. Relies on the "sep" column added to filtered_cis by an
# earlier cell.
for _, loop in filtered_cis[filtered_cis["sep"] <= 10].sort_values(["sep"], ascending=False).iterrows():
    a = loop["a"]
    b = loop["b"]
    try:
        cumulative_plot(filtered_cis, nf, loop["chr1"], a, b)
    except IndexError:
        # Deliberate best-effort: loops for which cumulative_plot raises
        # IndexError are skipped without a figure.
        continue
    # Title with the loop's own chromosome; it was previously hard-coded to
    # "14" for every loop regardless of where the loop actually lies.
    plt.title("Chrm {}, {} - {}, {} sep".format(loop["chr1"], a, b, loop["sep"]))
    plt.show()

In [None]:
    # Inspection only: display the column labels of the frame currently bound to `d`.
    d.columns

In [None]:
# Inspection only. The first expression is not the last in the cell, so its
# value is not displayed; only the list built from nf.store["nuc_names"] is.
# NOTE(review): in the visible cells `nf` is only ever bound as a loop variable
# over `nfs` — confirm that object really exposes `nuc_names` and `store`.
nf.nuc_names
list(nf.store["nuc_names"])