In [None]:
import os
import pickle 

import h5py
from copy import deepcopy as dcopy

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
import cooltools.lib.plotting
from cooltools.lib import numutils
import seaborn as sns
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
import scipy.stats as ss

In [None]:
def filter_diags(hmap, ignore_diags):
    """Return a copy of ``hmap`` with its first ``ignore_diags`` diagonals set to NaN.

    Every entry ``[r, c]`` with ``abs(r - c) < ignore_diags`` is blanked, i.e. the
    main diagonal plus ``ignore_diags - 1`` off-diagonals on either side.  The
    input array is never modified.

    Parameters
    ----------
    hmap : array-like, 2-D
        Matrix to filter.
    ignore_diags : int
        Number of diagonals to blank, counted outward from the main diagonal.
        Zero (or a negative value) leaves every entry untouched.

    Returns
    -------
    numpy.ndarray
        Filtered copy.  Integer and boolean input is promoted to float because
        those dtypes cannot store NaN.
    """
    hmap_diag_filtered = np.array(hmap, copy=True)
    if not np.issubdtype(hmap_diag_filtered.dtype, np.inexact):
        # Writing NaN into an integer array raises, so promote first.
        hmap_diag_filtered = hmap_diag_filtered.astype(float)

    n_rows, n_cols = hmap_diag_filtered.shape
    # Distance of every entry from the main diagonal.
    offset = np.abs(np.subtract.outer(np.arange(n_rows), np.arange(n_cols)))
    hmap_diag_filtered[offset < ignore_diags] = np.nan
    return hmap_diag_filtered

def calc_ixns(subcomps_coarsened, ooe, n_diags=(0, 0)):
    """Average a square contact map over pairs of bin labels.

    Parameters
    ----------
    subcomps_coarsened : numpy.ndarray, 1-D
        One integer label per bin of ``ooe``.  Labels 0, 1 and 2 are used;
        label 2 (called "X" below) additionally gets a same-run / different-run
        split.
    ooe : numpy.ndarray, 2-D, square
        Contact map to average (the name suggests observed-over-expected;
        that is not checked here).
    n_diags : sequence of two ints, optional
        ``n_diags[0]`` diagonals are blanked with ``filter_diags`` in the map
        used for the same-run X average, ``n_diags[1]`` diagonals in the map
        used for every other average.  A falsy value (``None``, ``[]``)
        disables filtering altogether.

    Returns
    -------
    numpy.ndarray, shape (4, 4)
        ``[i, j]`` for ``i, j`` in 0..2 is the NaN-ignoring mean over entries
        whose row bin has label ``i`` and column bin has label ``j``, except
        ``[2, 2]``, which is the mean over X-X pairs lying in *different*
        contiguous X runs.  ``[3, 3]`` is the mean over X-X pairs lying in the
        *same* contiguous X run.  All remaining entries are NaN.  Both X-X
        averages only use pairs with row index <= column index.
    """
    if n_diags:
        trans_ooe = filter_diags(ooe, n_diags[1])
        cis_ooe = filter_diags(ooe, n_diags[0])
    else:
        trans_ooe = cis_ooe = ooe

    mean_ixns = np.full((4, 4), np.nan)
    for i in range(3):
        for j in range(3):
            mean_ixns[i][j] = np.nanmean(trans_ooe[subcomps_coarsened == i].T[subcomps_coarsened == j])

    x_loc = subcomps_coarsened == 2
    x_idx = np.flatnonzero(x_loc)

    # Two X bins belong to the same contiguous run exactly when no non-X bin
    # sits between them, i.e. when the running count of non-X bins is equal at
    # both positions.
    run_id = np.cumsum(~x_loc)[x_idx]
    same_run = run_id[:, None] == run_id[None, :]
    # Only pairs with row <= column are used (upper triangle incl. diagonal).
    ordered = x_idx[:, None] <= x_idx[None, :]

    n_bins = ooe.shape[0]
    XX_inter_mask = np.zeros((n_bins, n_bins), dtype=bool)
    XX_intra_mask = np.zeros((n_bins, n_bins), dtype=bool)
    x_pairs = np.ix_(x_idx, x_idx)
    XX_inter_mask[x_pairs] = ~same_run & ordered
    XX_intra_mask[x_pairs] = same_run & ordered

    mean_ixns[2, 2] = np.nanmean(trans_ooe[XX_inter_mask])
    mean_ixns[3, 3] = np.nanmean(cis_ooe[XX_intra_mask])

    return mean_ixns



In [None]:
def convert_exp_to_df(experimental_ixns, idx):
    """Lay out a 4x4 interaction matrix as a one-row-per-index DataFrame.

    Parameters
    ----------
    experimental_ixns : array-like, indexable as ``[row, col]``
        Interaction matrix.  The entries read are ``[0,0]``, ``[0,1]``,
        ``[1,1]``, ``[0,2]``, ``[1,2]``, ``[2,2]`` and ``[3,3]``.
    idx : list-like
        Index for the resulting frame (e.g. ``[0]`` for a single row).

    Returns
    -------
    pandas.DataFrame
        Columns ``AA, AB, BB, AS, BS, SS_inter, SS_intra`` filled from the
        matrix, followed by ``AA_attr, BB_attr, SS_attr`` filled with NaN.
    """
    # Read from the argument rather than from a same-named notebook global so
    # the function works for any matrix it is handed.
    exp_df = pd.DataFrame({
        "AA": experimental_ixns[0, 0],
        "AB": experimental_ixns[0, 1],
        "BB": experimental_ixns[1, 1],
        "AS": experimental_ixns[0, 2],
        "BS": experimental_ixns[1, 2],
        "SS_inter": experimental_ixns[2, 2],
        "SS_intra": experimental_ixns[3, 3],
    }, index=idx)

    # Placeholder columns with no experimental counterpart.
    for attr_col in ("AA_attr", "BB_attr", "SS_attr"):
        exp_df[attr_col] = np.nan

    return exp_df

In [None]:
def X_comp_plotting_v2(ixn_df, experimental_ixns, comp_tol=0.1, intra_max_factor=1.2, inter_max=1):
    """Plot each simulation's S-interaction profile and return the simulations that pass.

    For every row, the ``AS``, ``BS`` and ``SS_inter`` values are drawn at
    x = ``SS_intra`` (one marker per quantity, joined by a vertical line).
    A row passes, and is drawn in purple instead of grey, when all of the
    following hold:

    * ``abs(BS - AS) < comp_tol``
    * ``SS_intra < intra_max_factor * experimental_ixns[3, 3]``
    * ``SS_inter < inter_max``

    Parameters
    ----------
    ixn_df : pandas.DataFrame
        One row per simulation.  Columns may use either the ``X`` names
        (``AX, BX, XX_inter, XX_intra, XX_attr``) or the ``S`` names; ``X``
        names are renamed to ``S`` names on a copy.  The caller's frame is not
        modified.
    experimental_ixns : array-like, indexable as ``[row, col]``
        Reference matrix; only ``[3, 3]`` is read.
    comp_tol, intra_max_factor, inter_max : float, optional
        Thresholds of the three conditions above.

    Returns
    -------
    pandas.DataFrame
        The passing rows of the renamed, re-indexed (0..n-1) copy, including
        the helper columns ``comp_sim`` (first condition) and ``c`` (colour).
    """
    passing_color, failing_color = "#5F2052", "#808080"

    ixn_df = ixn_df.rename(columns={
        "AX": "AS",
        "BX": "BS",
        "XX_inter": "SS_inter",
        "XX_intra": "SS_intra",
        "XX_attr": "SS_attr",
    }).reset_index(drop=True)

    y_cats = ["AS", "BS", "SS_inter"]
    markers = ["o", "X", "s"]

    # Classify each simulation.
    ixn_df["comp_sim"] = (ixn_df["BS"] - ixn_df["AS"]).abs() < comp_tol
    passes = (
        (ixn_df["SS_intra"] < intra_max_factor * experimental_ixns[3, 3])
        & ixn_df["comp_sim"]
        & (ixn_df["SS_inter"] < inter_max)
    )
    ixn_df["c"] = np.where(passes, passing_color, failing_color)

    # Passing simulations are layered above failing ones; markers sit one
    # level above the connecting line of the same colour.
    ordering = {failing_color: 1, passing_color: 3}

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set_xlim((1.55, 2.65))
    ax.set_ylim((0.45, 1.25))

    for color in ixn_df["c"].unique():
        subset = ixn_df.loc[ixn_df["c"] == color]

        # Vertical connector through the three quantities of each simulation.
        for x_val, y_vals in zip(subset["SS_intra"], subset[y_cats].to_numpy()):
            ax.plot(np.full(len(y_cats), x_val), y_vals, c=color, zorder=ordering[color])

        for marker, y in zip(markers, y_cats):
            sns.scatterplot(ax=ax, data=subset, x="SS_intra", y=y, c=color, marker=marker,
                            s=70, zorder=ordering[color] + 1, linewidth=0.4)

    leg_list = [
        mlines.Line2D([], [], color='#444444', marker=marker, markersize=7, label=y)
        for marker, y in zip(markers, y_cats)
    ]
    ax.legend(handles=leg_list)

    ax.set_title("S Interaction Profiles")
    ax.set_ylabel("S Contact Enrichment, Off Diagonal")
    ax.set_xlabel("S Contact Enrichment, On Diagonal")

    return ixn_df.loc[ixn_df["c"] != failing_color]


In [None]:
def AB_comp_plotting(ixn_df_sub, experimental, err_range, title, save_path=None):
    """Plot AA/AB/BB enrichment per simulation and return the simulations close to experiment.

    A simulation is "good" when its ``AA`` lies strictly between
    ``(1 - err_range)`` and ``(1 + err_range)`` times ``experimental[0, 0]``
    and its ``BB`` lies strictly within the same relative band around
    ``experimental[1, 1]`` (band edges rounded to 3 decimals).  Good
    simulations are drawn in purple on top of the grey remaining ones.

    Parameters
    ----------
    ixn_df_sub : pandas.DataFrame
        One row per simulation with columns ``AA``, ``AB``, ``BB`` and
        ``index`` (a per-simulation identifier used as the plot hue).
    experimental : array-like, indexable as ``[row, col]``
        Reference matrix; ``[0, 0]`` and ``[1, 1]`` are read.
    err_range : float
        Relative tolerance, e.g. ``0.1`` for +/- 10 %.
    title : str
        Text used in the figure title and in the saved file name.
    save_path : str, optional
        Directory in which to save ``AB_<title>.pdf``; nothing is saved when falsy.

    Returns
    -------
    numpy.ndarray of int
        Index labels (``ixn_df_sub.index``) of the good simulations.
    """
    lo = 1 - err_range
    hi = 1 + err_range

    # Long format: one row per (simulation, interaction type).
    plotting_df = (
        ixn_df_sub[['AA', 'AB', "BB", "index"]]
        .set_index('index')
        .stack()
        .to_frame()
        .reset_index()
        .rename(columns={
            "index": "sim",
            "level_1": "Interaction Type",
            0: "Average Enrichment",
        })
    )

    in_range = (
        (ixn_df_sub["AA"] < round(hi * experimental[0, 0], 3))
        & (ixn_df_sub["AA"] > round(lo * experimental[0, 0], 3))
        & (ixn_df_sub["BB"] < round(hi * experimental[1, 1], 3))
        & (ixn_df_sub["BB"] > round(lo * experimental[1, 1], 3))
    )
    good_sims = np.array(ixn_df_sub.loc[in_range].index, dtype=int)
    # The plot is keyed on the "index" column, so select by that column's values.
    is_good = plotting_df['sim'].isin(ixn_df_sub.loc[in_range, "index"])

    fig, ax = plt.subplots()
    # (data, colour, zorder of the lines); markers are placed one level higher.
    layers = (
        (plotting_df.loc[~is_good], '#808080', 0),
        (plotting_df.loc[is_good], '#5F2052', 2),
    )
    for layer_df, color, line_zorder in layers:
        if layer_df.empty:
            continue
        # Remember what is already on the axes so only the artists added by
        # this layer are restyled; restyling everything would overwrite the
        # zorder of the previous layer.
        n_collections, n_lines = len(ax.collections), len(ax.lines)
        sns.pointplot(ax=ax, data=layer_df, x="Interaction Type", y="Average Enrichment",
                      hue="sim", palette=layer_df['sim'].nunique() * [color])
        plt.setp(list(ax.collections)[n_collections:], linewidth=0.2, edgecolors='white',
                 zorder=line_zorder + 1, alpha=0.8, sizes=[20])
        plt.setp(list(ax.lines)[n_lines:], linewidth=1.7, zorder=line_zorder, alpha=0.7)

    # Replace the per-simulation legend with a two-entry one.
    leg_lines = [Line2D([0], [0], color='#5F2052', lw=2),
                 Line2D([0], [0], color='#808080', lw=2)]
    ax.legend(leg_lines, ["Good A-B compartments", "Bad A-B compartments"], fontsize='small')

    ax.set_title(f'AA & BB error range: +/- {int(100*err_range)}% experimental', fontsize='small')
    fig.suptitle(f'A-B Compartmentalization for {title}', fontsize='medium')
    if save_path:
        save_fh = f'AB_{title.replace(", ", "_")}.pdf'
        fig.savefig(f'{save_path}/{save_fh}', format='pdf')

    return good_sims

In [None]:
def make_ixn_df(base_path, lambda_, d_X, d_AB, mtx_fh, ABs, XX=0):
    """Collect the cached interaction matrices of one simulation group into a DataFrame.

    For every ``(AA, BB)`` pair the file
    ``{base_path}/sims_mon1d/lambda-{lambda_}_dX-{d_X}_dAB-{d_AB}/``
    ``AA{AA:.2f}_BB{BB:.2f}_XX{XX:.2f}/SC-1JoinedPol_errTol-0.01_coll-0.01/results/heatmaps/{mtx_fh}``
    is loaded if it exists; otherwise the interaction columns of that row are NaN.

    Parameters
    ----------
    base_path : str
        Root directory of the sweep.
    lambda_, d_X, d_AB
        Group parameters; used in the directory name and copied into the
        ``lambda``, ``d_S`` and ``d_AB`` columns.
    mtx_fh : str
        File name of the saved 4x4 matrix (``.npy``).
    ABs : iterable of (AA, BB) pairs
        Values used in the per-simulation directory name and stored in
        ``AA_attr`` / ``BB_attr``.
    XX : float, optional
        Third value of the per-simulation directory name, also stored in
        ``SS_attr``.  Defaults to 0.

    Returns
    -------
    pandas.DataFrame
        One row per ``(AA, BB)`` pair with columns ``AA_attr, BB_attr, AA, AB,
        AS, BB, BS, SS_inter, SS_intra, lambda, d_S, d_AB, SS_attr``.
    """
    sim_group_path = f'{base_path}/sims_mon1d/lambda-{lambda_}_dX-{d_X}_dAB-{d_AB}'
    sim_dir = 'SC-1JoinedPol_errTol-0.01_coll-0.01'
    cols = ["AA_attr", "BB_attr", "AA", "AB", "AS", "BB", "BS", "SS_inter", "SS_intra"]

    # Matrix positions of cols[2:], in that order.  Picking them explicitly
    # keeps every row at 7 values even when a stored entry is itself NaN.
    ixn_rows = [0, 0, 0, 1, 1, 2, 3]
    ixn_cols = [0, 1, 2, 1, 2, 2, 3]

    concat_ixn_list = []
    for AA, BB in ABs:
        comp_dir = f'AA{AA:.2f}_BB{BB:.2f}_XX{XX:.2f}'
        hmap_path = f'{sim_group_path}/{comp_dir}/{sim_dir}/results/heatmaps'
        mtx_path = f'{hmap_path}/{mtx_fh}'
        if os.path.exists(mtx_path):
            with open(mtx_path, 'rb') as o:
                mean_ixns = np.load(o)
            ixns_arr = mean_ixns[ixn_rows, ixn_cols]
        else:
            # Simulation not processed (yet): keep the row, mark values missing.
            ixns_arr = np.full(len(ixn_rows), np.nan)
        concat_ixn_list.append(np.concatenate([np.array([AA, BB]), ixns_arr]))

    ixn_df = pd.DataFrame(concat_ixn_list, columns=cols)
    ixn_df["lambda"] = lambda_
    ixn_df["d_S"] = d_X
    ixn_df["d_AB"] = d_AB
    ixn_df["SS_attr"] = XX
    return ixn_df

In [None]:
# Both values are embedded in the heatmap file names used further down.
cutoff_rad = 7.5
binSize = 10

# NOTE(review): pickle.load can execute arbitrary code; only open pickle files
# that were produced by a trusted source.
with open('chr15:0-6p5Mb.pkl', 'rb') as monInfo_file:
    monInfo = pickle.load(monInfo_file)
mon_id_tmp = monInfo['mon']
# Coarsen to one label per bin: keep the label of the first monomer of every
# `binSize` consecutive monomers.
mon_id = np.array([mon_id_tmp[i] for i in range(0, len(mon_id_tmp), binSize)])

In [None]:
# --- Sweep grid -------------------------------------------------------------
lambda_list = [55, 110]
dX_list = [19, 37, 55, 110]
# d_AB values are derived as d_X * AB_fac; the infinite factor is stored as
# d_AB = 0 by the loops that build the grid.
AB_fac = np.array([5, 10, np.inf])
# Only combinations with lambda / d_AB below this bound are kept.
max_lodAB = 0.3
# Third value in the per-simulation directory name (AA.._BB.._XX..).
XX = 0

# --- Input locations --------------------------------------------------------
# fix the double base_path later
base_path = '/net/levsha/share/emily/notebooks/sims/bombyx/flagship/compartment_sweep'
comp_params_only = pd.read_csv(f"{base_path}/AB_param_sets_15Sep23.csv", sep="\t", index_col=0)
# One (AA_attr, BB_attr) pair per row.
ABs = comp_params_only[["AA_attr", "BB_attr"]].to_numpy()

In [None]:
# Distances between consecutive label changes in mon_id.
# NOTE(review): because 0 is prepended to the "last index of each run"
# positions, the first run appears to be counted one bin short and the final
# run is not counted at all. Left unchanged on purpose: the resulting number is
# baked into the names of already-written cache files below. Please confirm.
comp_sizes = np.diff(np.concatenate([np.array([0]), np.where(np.diff(mon_id) != 0)[0]]))
# Cast to a plain int: ignore_diags is interpolated into file names, and a
# NumPy scalar inside a list renders as "np.int64(...)" under NumPy >= 2, which
# would silently point at different files than under NumPy 1.x.
biggest_comp = int(max(comp_sizes))
print(f'biggest compartment is {biggest_comp} bins')

# Number of diagonals to blank in the two maps handed to calc_ixns
# (first entry -> n_diags[0], second entry -> n_diags[1]).
ignore_diags_ooe = biggest_comp + 2
ignore_diags_ice = 0
ignore_diags = [ignore_diags_ice, ignore_diags_ooe]

# Reference interaction matrix computed from experimental data.
with open(f'/net/levsha/scratch2/emily/flagship/hic/mean_ixns_binSize-10000_diagsFiltered-{ignore_diags}.npy', 'rb') as f:
    experimental = np.load(f)

In [None]:
# Compute and cache the 4x4 mean-interaction matrix of every simulation in the
# sweep. Simulations without a heatmap are reported and skipped; simulations
# whose matrix file already exists are left untouched.
sim_dir = 'SC-1JoinedPol_errTol-0.01_coll-0.01'
mtx_fh = f'mean_ixns_binSize-{binSize}_diagsFiltered-{ignore_diags}.npy'

for lambda_ in lambda_list:
    for d_X in dX_list:
        # Candidate d_AB values for this (lambda, d_X): keep those with
        # lambda/d_AB below the bound, then store the infinite one as 0.
        dAB_list = d_X * AB_fac
        dAB_list = dAB_list[lambda_ / dAB_list < max_lodAB]
        dAB_list[dAB_list == np.inf] = 0
        dAB_list = np.round(dAB_list).astype(int)

        for d_AB in dAB_list:
            d_dir = f'lambda-{lambda_}_dX-{d_X}_dAB-{d_AB}'

            for AA, BB in ABs:
                comp_dir = f'AA{AA:.2f}_BB{BB:.2f}_XX{XX:.2f}'
                sim = f'{comp_dir}__{sim_dir}'
                sim_dir_path = f'{base_path}/sims_mon1d/{d_dir}/{comp_dir}/{sim_dir}'
                hmap_path = f'{sim_dir_path}/results/heatmaps'
                hmap_fh = f'{sim}__cutoff-{cutoff_rad:04.1f}_binSize-{binSize}_IC_OOE_chainMap.npy'

                # Nothing to process without the input heatmap.
                if not os.path.exists(f'{hmap_path}/{hmap_fh}'):
                    print(f'{hmap_fh} does not exist')
                    continue

                # Already cached by an earlier run.
                if os.path.exists(f'{hmap_path}/{mtx_fh}'):
                    continue

                with open(f'{hmap_path}/{hmap_fh}', 'rb') as f:
                    hmap = np.load(f)

                mean_ixns = calc_ixns(mon_id, hmap, n_diags=ignore_diags)

                with open(f'{hmap_path}/{mtx_fh}', 'wb') as o:
                    np.save(o, mean_ixns)

In [None]:
# Gather one DataFrame per (lambda, d_X, d_AB) group, then stack them into a
# single table with a fresh 0..n-1 index and a fixed column order.
concat_list = []
for lambda_ in lambda_list:
    for d_X in dX_list:
        # Same d_AB grid as used when the matrices were cached.
        dAB_list = d_X * AB_fac
        dAB_list = dAB_list[lambda_ / dAB_list < max_lodAB]
        dAB_list[dAB_list == np.inf] = 0
        dAB_list = np.round(dAB_list).astype(int)

        for d_AB in dAB_list:
            concat_list.append(make_ixn_df(base_path, lambda_, d_X, d_AB, mtx_fh, ABs))

full_df = pd.concat(concat_list, ignore_index=True)
full_df = full_df.loc[:, [
    "AA_attr", "BB_attr", 'SS_attr',
    "lambda", 'd_S', 'd_AB',
    "AA", "AB", "BB", "AS", "BS", "SS_inter", "SS_intra",
]]

In [None]:
# Rebuild full_df from concat_list, exactly as the cell that fills concat_list
# already does. Running this cell resets full_df to the raw columns, dropping
# any derived columns added to it afterwards.
full_df = pd.concat(concat_list, ignore_index=True).loc[:, [
    "AA_attr", "BB_attr", 'SS_attr',
    "lambda", 'd_S', 'd_AB',
    "AA", "AB", "BB", "AS", "BS", "SS_inter", "SS_intra",
]]

In [None]:
# Derived per-simulation quantities.
# NOTE(review): d_AB is stored as 0 where the sweep used the infinite factor,
# so d_ratio is 0 (not inf) for those rows.
full_df["d_ratio"] = full_df["d_AB"]/full_df["d_S"]
# (same-type - cross-type) / (same-type + cross-type) enrichment.
full_df["AB_compScore"] = (full_df["AA"]+full_df["BB"]-full_df["AB"])/(full_df["AA"]+full_df["BB"]+full_df["AB"])
full_df['lod_S'] = full_df['lambda']/full_df['d_S']

In [None]:
# Experimental reference values as a single-row frame, plus the same values as
# a Series keyed by column name (convenient for column-wise arithmetic).
exp_df = convert_exp_to_df(experimental_ixns=experimental, idx=[0])
exp = exp_df.iloc[0]

In [None]:
# Squared deviation from experiment, restricted explicitly to the interaction
# columns that have an experimental counterpart.
ixn_cols = ["AA", "AB", "BB", "AS", "BS", "SS_inter", "SS_intra"]
sq_err = full_df[ixn_cols].sub(exp[ixn_cols], axis=1).pow(2)

# Missing entries are skipped; a simulation with no data at all gets NaN.
full_df['MSE'] = sq_err.mean(axis=1)
# min_count=1 keeps simulations without any data at NaN instead of a
# distance of 0, which would look like a perfect match.
full_df["Euc_dist"] = np.sqrt(sq_err.sum(axis=1, min_count=1))
# Rank 1 = smallest MSE, ties averaged. Simulations without an MSE stay
# unranked (NaN) and do not disturb the ranks of the others.
full_df["rank"] = full_df["MSE"].rank(method="average", na_option="keep")

In [None]:
# Keep a copy of the row labels as a regular column; AB_comp_plotting uses it
# to identify simulations.
full_df["index"] = full_df.index
good_sim_ids = []

# One A/B plot per (lambda, d_S, d_AB) group; collect the simulations each
# group reports as good.
for lambda_ in lambda_list:
    for d_X in dX_list:
        dAB_list = d_X * AB_fac
        dAB_list = dAB_list[lambda_ / dAB_list < max_lodAB]
        dAB_list[dAB_list == np.inf] = 0
        dAB_list = np.round(dAB_list).astype(int)

        for d_AB in dAB_list:
            in_group = (full_df['lambda'] == lambda_) & (full_df['d_S'] == d_X) & (full_df['d_AB'] == d_AB)
            ixn_df_sub = full_df.loc[in_group]

            # A single distinct AA value (NaN counts as one) means there is
            # nothing to compare within this group.
            if ixn_df_sub['AA'].nunique(dropna=False) == 1:
                continue

            good_sim_ids.append(
                AB_comp_plotting(ixn_df_sub, experimental, 0.1, f'Lambda={lambda_}, dS={d_X}, dAB={d_AB}')
            )

# Rows of full_df selected by any group.
narrowed_df = full_df.iloc[np.concatenate(good_sim_ids)]

In [None]:
# Second filter: plot the S-interaction profile of the A/B-filtered
# simulations and keep the ones the plotting function returns.
final_df_wIndex = X_comp_plotting_v2(ixn_df=narrowed_df, experimental_ixns=experimental)
# Same rows, renumbered 0..n-1.
final_df = final_df_wIndex.reset_index(drop=True)

In [None]:
# Full table of all parameter sets, tab-separated.
full_df.to_csv("full_param_sets_25Apr24.csv", sep="\t")

# Parameter sets whose MSE rank is 50 or better, with their scores.
rank_export_cols = ["AA_attr", "BB_attr", "SS_attr", "lambda", "d_S", "d_AB", "Euc_dist", "MSE", "rank"]
top_ranked = full_df.loc[full_df['rank'] <= 50, rank_export_cols]
top_ranked.to_csv("final_param_sets_RankFiltering_25Apr24.csv", sep="\t")

In [None]:
# Parameter sets that passed both plot-based filters; the plotting helper
# columns are not exported.
final_export = final_df.drop(columns=['comp_sim', 'c'])
final_export.to_csv("final_param_sets_QualitativeFiltering_25Apr24.csv", sep="\t")