# Saddleplots

## We'll just generate the data here, and then we'll reuse it in a separate notebook for plotting!

Welcome to the compartments and saddleplot notebook! 

This notebook illustrates cooltools functions used for investigating chromosomal compartments, visible as plaid patterns in mammalian interphase contact frequency maps.

These plaid patterns reflect tendencies of chromosome regions to make more frequent contacts with regions of the same type: active regions have increased contact frequency with other active regions, and inactive regions tend to contact other inactive regions more frequently. The strength of compartmentalization has been shown to vary through the cell cycle, across cell types, and after degradation of components of the cohesin complex.

In this notebook we:

* obtain compartment profiles using eigendecomposition
* calculate and visualize strength of compartmentalization using saddleplots

In [None]:
# import standard python libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import os, subprocess

In [None]:
# Import python package for working with cooler files and tools for analysis
import cooler
import cooltools.lib.plotting
import bioframe
import multiprocess as mp

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
# additional imports: bigWig access, cooltools, plotting helpers and the local data catalog
import bbi
import cooltools
import bioframe
from matplotlib.colors import LogNorm
from helper_func import saddleplot
from data_catalog import bws, bws_vlim, telo_dict

import saddle


In [None]:
# notebook-friendly progress bars
from tqdm.notebook import trange, tqdm
import warnings
import seaborn as sns

from mpire import WorkerPool


## Calculating per-chromosome compartmentalization

We first load the Hi-C data at 100 kbp resolution. 

Note that the current implementation of eigendecomposition in cooltools assumes that individual regions can be held in memory: for hg38 at 100 kb this is either a 2422x2422 matrix for chr2, or a 3255x3255 matrix for the full cooler here.
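As a rough check of the memory requirement, the sketch below estimates the size of a dense square matrix. It assumes 8-byte float64 values (an assumption; the dtype actually used may differ), and `dense_matrix_mb` is a hypothetical helper, not a cooltools function.

```python
def dense_matrix_mb(n_bins, bytes_per_value=8):
    """Size in megabytes (10**6 bytes) of a dense n_bins x n_bins matrix."""
    return n_bins * n_bins * bytes_per_value / 1e6

chr2_mb = dense_matrix_mb(2422)   # chr2 at 100 kb
full_mb = dense_matrix_mb(3255)   # the larger matrix quoted above
print(f"chr2: {chr2_mb:.1f} MB, full: {full_mb:.1f} MB")
```

Both matrices fit comfortably in memory; the estimate grows quadratically with the number of bins, so finer resolutions get expensive quickly.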

In [None]:
# define genomic view (chromosome arms) that will be used to pre-compute expected and saddles

# Use bioframe to fetch the genomic features from UCSC.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# # remove "bad" chromosomes and near-empty arms ...
included_arms = hg38_arms_full["name"].to_list()[:44] # all autosomal ones ...
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

### pre-load coolers and pre-calculate expected ...

In [None]:
# cooler files that we'll work on :
binsize10 = 10_000
telo_clrs10 = { _k: cooler.Cooler(f"{_path}::/resolutions/{binsize10}") for _k, _path in telo_dict.items() }
binsize25 = 25_000
telo_clrs25 = { _k: cooler.Cooler(f"{_path}::/resolutions/{binsize25}") for _k, _path in telo_dict.items() }

In [None]:
def _job(packed_data, sample):
    # packed data -> exp_kwargs and a dict with coolers for each sample
    exp_kwargs, clr_dict = packed_data
    _clr = clr_dict[sample]
    # in order to use spawn/forkserver we have to import for worker
    from cooltools import expected_cis
    _exp = expected_cis( _clr, **exp_kwargs)
    return (sample, _exp)

# define expected parameters in the form of kwargs-dict:
exp_kwargs = dict(
    view_df=hg38_arms,
    intra_only=False,
    nproc=12
)

# have to use daemon=False, because _job is multiprocessing-based already ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=( exp_kwargs, telo_clrs10 ),
    start_method="forkserver",  # little faster than spawn, fork is the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs10, progress_bar=True)

# sort out the results ...
telo_exps_cis = {sample: _exp for sample, _exp in results}
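The cell above follows a simple pattern: every job receives the shared objects plus one sample name and returns a `(sample, result)` pair, and the pairs are then collected into a dict. A minimal sketch of the same pattern is shown below. It swaps in the standard library's `ThreadPoolExecutor` for `mpire.WorkerPool` so that it runs anywhere, and `_toy_job` and the toy data are hypothetical stand-ins for `expected_cis` and the coolers.

```python
from concurrent.futures import ThreadPoolExecutor

def _toy_job(shared, sample):
    # unpack shared objects, mirroring the (exp_kwargs, clr_dict) tuple above
    scale, data = shared
    return sample, scale * sum(data[sample])

toy_data = {"sampleA": [1, 2, 3], "sampleB": [10, 20]}   # stand-ins for coolers
shared = (2, toy_data)

with ThreadPoolExecutor(max_workers=2) as pool:
    # iterating over a dict yields its keys, i.e. the sample names
    toy_pairs = list(pool.map(lambda s: _toy_job(shared, s), toy_data))

# sort out the results into a dict keyed by sample
toy_results = dict(toy_pairs)
print(toy_results)
```

Returning the sample name together with the result keeps the bookkeeping independent of the order in which jobs finish.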

### trans-expected second

In [None]:
def _job(packed_data, sample):
    # unpack data
    clr_dict, = packed_data
    exp_kwargs = dict(chunksize=1000000, nproc=12)
    from cooltools import expected_trans
    _clr = clr_dict[sample]
    _exp = expected_trans( _clr, **exp_kwargs).set_index(["region1", "region2"]).sort_index()
    return (sample, _exp)

# have to use daemon=False, because _job is multiprocessing-based already ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=(telo_clrs25, ),
    start_method="forkserver",
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs25, progress_bar=True)

# sort out the results ...
telo_exps_trans = {sample: _exp for sample, _exp in results}

# Load partition of the genome into clusters ...



## Saddleplots

A common way to visualize preferences captured by the eigenvector is by using saddleplots.

To generate a saddleplot, we first use the eigenvector to stratify genomic regions into groups with similar values of the eigenvector; this step is called "digitizing".
Contacts between each pair of groups are then averaged to create the saddleplot.

Cooltools operates on a `digitized` bedgraph-like track with four columns. The fourth (value) column is categorical. Categories have the following encoding:

    - `1..n` <-> values assigned to bins defined by vrange or qrange
    - `0` <-> left outlier values
    - `n+1` <-> right outlier values
    - `-1` <-> missing data (NaNs)
    
Track values can be digitized either by numeric values, by passing `vrange`, or by quantiles, by passing `qrange`.
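To make the encoding concrete, here is a minimal NumPy sketch of digitizing by value range and by quantiles. It is not the cooltools implementation: `digitize_track` is a hypothetical helper, and the handling of values that sit exactly on the upper edge is a simplification of this sketch.

```python
import numpy as np

def digitize_track(values, edges):
    """Codes 1..n for n = len(edges) - 1 bins; 0 and n+1 mark outliers, -1 marks NaN."""
    values = np.asarray(values, dtype=float)
    # np.digitize returns 0 below edges[0] and len(edges) at or above edges[-1]
    codes = np.digitize(values, edges)
    codes[np.isnan(values)] = -1   # NaNs would otherwise land in the last category
    return codes

e1 = np.array([-1.5, -0.4, 0.1, 0.6, 2.0, np.nan])

# vrange-style: 4 equal-width bins between -1 and 1
v_edges = np.linspace(-1.0, 1.0, 5)
v_codes = digitize_track(e1, v_edges)

# qrange-style: edges taken from quantiles of the observed (non-NaN) values
q_edges = np.nanquantile(e1, [0.0, 0.5, 1.0])
q_codes = digitize_track(e1, q_edges)

print(v_codes, q_codes)
```

In this simplified version the maximum value coincides with the last quantile edge and is therefore coded as a right outlier (`n+1`).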

To create saddles in cis with `saddle`, cooltools requires: a cooler, a table with expected as function of distance, and parameters for digitizing:

In [None]:
telo_trans_filt_exps = {}
for _k, _clr in tqdm(telo_clrs25.items()):
    _df = telo_exps_trans[_k].reset_index()
    m2 = _df["region2"].isin(["chrX","chrY","chrM"])
    m1 = _df["region1"].isin(["chrX","chrY","chrM"])
    telo_trans_filt_exps[_k] = _df[~(m1 | m2)]

# a view without M,X and Y chromosomes ...
sub_chrom_view = bioframe.make_viewframe(hg38_chromsizes)
bad_chroms = ["chrX","chrY","chrM"]
sub_chrom_view = sub_chrom_view[~sub_chrom_view["name"].isin(bad_chroms)]

# Let's generate CRE track with Allana-style bin assignment

In [None]:
! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_NEWatac_k2785_rest95.bed ./bigbed/

In [None]:
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_95atac_95rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_90atac_90rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_70atac_70rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_70atac_95rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_70atac_90rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_75atac_95rest.bed ./bigbed/
# ! scp newhpc:/home/allana.schooley-umw/as38w/Ranger/cres/all_cres_encodestyle_pMunion_75atac_90rest.bed ./bigbed/

In [None]:
# old way of stitching cCREs into a binned_CRE list ...

# # different CREs will be assigned categories corresponding to their importance
# cre_fnames = {
#     "PLS" : "bigbed/PLS_ccres.bed",
#     "pELS" : "bigbed/pELS_ccres.bed",
#     "dELS" : "bigbed/dELS_ccres.bed",
#     "openK4" : "bigbed/openK4_ccres.bed",
#     "ctcf" : "bigbed/ctcf_ccres.bed",
#     "justOpen" : "bigbed/justOpen_ccres.bed",
# }
# ticklabels = [
#     "PLS",
#     "pELS",
#     "dELS",
#     "opK4",
#     "CTCF",
#     "open",
#     "none",
# ]
# categories are going to be 1,2,3,4,5,6
# 0 - everything else ...
# cres = []
# for i, (k,fname) in enumerate(cre_fnames.items()):
#     #
#     df = pd.read_table(fname, sep="\t")[["chrom","start","end"]]
#     df["status"] = len(cre_fnames) - i
#     cres.append(df)

# cres_df = pd.concat(cres, ignore_index=True)
# cres_df = bioframe.sort_bedframe(cres_df, view_df=hg38_arms)
# # now assign CREs to bins ...
# display(clr_bins.head(2))
# display(clr_bins.tail(2))
# # now annotate bins with the anchors ...
# _bin_cres = bioframe.overlap(clr_bins, cres_df)
# # _ggg["status"]
# _bin_cres = _bin_cres.drop(columns=["chrom_","start_","end_"])
# _bin_cres["status_"] = _bin_cres["status_"].fillna(0).astype(int)
# # ...
# binned_cres = _bin_cres.groupby(["chrom","start","end"], observed=True)["status_"].max().reset_index()
# binned_cres["status_"] = binned_cres["status_"].astype("category")
# binned_cres = binned_cres.astype({"chrom":str})
# # ...
# binned_cres

In [None]:
# generate some binsized bins
clr_bins10 = cooler.binnify(hg38_chromsizes, binsize10)
clr_bins10 = clr_bins10[~clr_bins10["chrom"].isin(["chrX","chrY","chrM"])]
# now annotate those bins with useful info - e.g. ID anchors ...
display(clr_bins10.head(2))
display(clr_bins10.tail(2))

# generate some binsized bins
clr_bins25 = cooler.binnify(hg38_chromsizes, binsize25)
clr_bins25 = clr_bins25[~clr_bins25["chrom"].isin(["chrX","chrY","chrM"])]
# now annotate those bins with useful info - e.g. ID anchors ...
display(clr_bins25.head(2))
display(clr_bins25.tail(2))
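`cooler.binnify` tiles every chromosome with fixed-size bins and truncates the last bin at the chromosome end. The plain-Python sketch below mimics that behaviour on two made-up chromosomes; `toy_binnify` is a hypothetical stand-in that returns tuples rather than the DataFrame produced by `cooler.binnify`.

```python
def toy_binnify(chromsizes, binsize):
    """Return (chrom, start, end) tuples tiling each chromosome; the last bin is truncated."""
    bins = []
    for chrom, length in chromsizes.items():
        for start in range(0, length, binsize):
            bins.append((chrom, start, min(start + binsize, length)))
    return bins

# two made-up chromosomes, binned at 10 kb
toy_bins = toy_binnify({"chrA": 25_000, "chrB": 10_000}, 10_000)
for b in toy_bins:
    print(b)
```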



In [None]:
cre_to_code = {
    'pls':6,
    'pels':5,
    'nels':5,
    'dels':4,
    'openk4':3,
    'ctcf':2,
    'justopen':1,
}

cre_fnames = {
    "atac95_rest95" : "./bigbed/all_cres_encodestyle_pMunion_95atac_95rest.bed",  # control sorta, like before
    "atac90_rest90" : "./bigbed/all_cres_encodestyle_pMunion_90atac_90rest.bed",
    "atac70_rest70" : "./bigbed/all_cres_encodestyle_pMunion_70atac_70rest.bed",
    "atac70_rest95" : "./bigbed/all_cres_encodestyle_pMunion_70atac_95rest.bed",
    "atac70_rest90" : "./bigbed/all_cres_encodestyle_pMunion_70atac_90rest.bed",
    "atac75_rest95" : "./bigbed/all_cres_encodestyle_pMunion_75atac_95rest.bed",
    "atac75_rest90" : "./bigbed/all_cres_encodestyle_pMunion_75atac_90rest.bed",
    "allana_latest" : "./bigbed/all_cres_encodestyle_pMunion_NEWatac_k2785_rest95.bed",
}

binned_cre10_dfs = {}
binned_cre25_dfs = {}
for name, cre_fname in cre_fnames.items():
    df = pd.read_table(cre_fname, usecols=['chrom', 'start', 'end', 'name'])
    print(f"read {name=} with {len(df)} items ...")
    # # rename nels -> pels
    # df = df.replace({'name': {"nels": "pels"}})
    # sort by coordinate
    df = bioframe.sort_bedframe(df, view_df=hg38_arms)
    df["status"] = df["name"].map(cre_to_code)

    # now annotate bins with the anchors @10kb  ...
    _bin_cres = bioframe.overlap(clr_bins10, df)
    _bin_cres = _bin_cres.drop(columns=["chrom_","start_","end_"])
    _bin_cres["status_"] = _bin_cres["status_"].fillna(0).astype(int)
    # assign cCREs per bin according to the hierarchy ...
    binned_cres = _bin_cres.groupby(["chrom","start","end"], observed=True)["status_"].max().reset_index()
    binned_cres["status_"] = binned_cres["status_"].astype("category")
    # ...
    binned_cre10_dfs[name] = binned_cres.astype({ "chrom" : str })

    # now annotate bins with the anchors @25kb  ...
    _bin_cres = bioframe.overlap(clr_bins25, df)
    _bin_cres = _bin_cres.drop(columns=["chrom_","start_","end_"])
    _bin_cres["status_"] = _bin_cres["status_"].fillna(0).astype(int)
    # assign cCREs per bin according to the hierarchy ...
    binned_cres = _bin_cres.groupby(["chrom","start","end"], observed=True)["status_"].max().reset_index()
    binned_cres["status_"] = binned_cres["status_"].astype("category")
    # ...
    binned_cre25_dfs[name] = binned_cres.astype({ "chrom" : str })



ticklabels = [
    "PLS",
    "pELS",
    "dELS",
    "opK4",
    "CTCF",
    "open",
    "none",
]
# categories are going to be 1,2,3,4,5,6
# 0 - everything else ...


_working_cres10 = binned_cre10_dfs['allana_latest']
_working_cres25 = binned_cre25_dfs['allana_latest']
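The key step in the loop above is the hierarchy: when several cCREs fall into one bin, the bin receives the highest code, and bins without any cCRE receive 0. The pandas-only sketch below reproduces that logic on toy data. It is a simplification: the toy elements never cross bin boundaries, so integer division replaces `bioframe.overlap`, and all names prefixed with `toy_` are hypothetical.

```python
import pandas as pd

toy_code = {"pls": 6, "dels": 4, "ctcf": 2}   # subset of cre_to_code
binsize = 10_000
toy_cres = pd.DataFrame({
    "chrom": ["chr1", "chr1", "chr1"],
    "start": [1_200, 4_000, 23_500],
    "end":   [1_500, 4_300, 23_900],
    "name":  ["ctcf", "pls", "dels"],
})
toy_cres["status"] = toy_cres["name"].map(toy_code)
# toy elements sit inside a single bin, so the overlapping bin is start // binsize
toy_cres["bin_start"] = (toy_cres["start"] // binsize) * binsize

# keep the maximal code per bin ...
best = toy_cres.groupby(["chrom", "bin_start"])["status"].max().to_dict()

# ... and fall back to 0 ("none") for bins without any cCRE
toy_bins = pd.DataFrame({"chrom": ["chr1"] * 3, "start": [0, 10_000, 20_000]})
toy_bins["status"] = [
    int(best.get((c, s), 0)) for c, s in zip(toy_bins["chrom"], toy_bins["start"])
]
print(toy_bins)
```

The first bin contains both a CTCF site and a PLS and ends up as PLS (6), the second bin is empty (0), and the third carries the dELS code (4).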

In [None]:
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    id_anchors_dict[id_name] = pd.read_csv(fname, sep="\t")
    # ...
    print(f"loaded {len(id_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")

### Check enrichments of Nezar's IPGs

In [None]:
nezar_df = bioframe.read_table("GSM7990272_DLD1.360.NT.50000.E1-E128.comp_10.kmeans9_5.bed", schema="bed9")

# Now make saddle-compatible track - i.e. digitized into 0,1,2,3,...categories
_comp_dict = {
    'A1':2,
    'A2':3,
    'V+VI':4,
    'B2/B3':5,
    'B4':6,
}

k = "name"
_track = nezar_df[["chrom","start","end",k]].replace({k: _comp_dict})#.astype({k:"category"})
_track[k].unique()


In [None]:
# generate some binsized bins
clr_bins50 = cooler.binnify(hg38_chromsizes, 50_000)
clr_bins50 = clr_bins50[~clr_bins50["chrom"].isin(["chrX","chrY","chrM"])]

# now annotate bins with the compartment track @50kb  ...
_bin_assigned = bioframe.overlap(clr_bins50, _track)
_bin_assigned = _bin_assigned.drop(columns=["chrom_","start_","end_"])
_bin_assigned["name_"] = _bin_assigned["name_"].fillna(0).astype(int)

_cats = pd.CategoricalDtype(categories=list(range(7)), ordered=True)

_bin_assigned["name_"] = _bin_assigned["name_"].astype(_cats)
# ...
_track_bins = _bin_assigned.astype({ "chrom" : str })
_track_bins = _track_bins.rename(columns={"name_":"name"})
_track_bins
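Casting to a `CategoricalDtype` with a fixed list of categories means that every category is reported even when no bin carries it, which presumably keeps downstream tables the same shape across tracks. A small pandas sketch with made-up values:

```python
import pandas as pd

observed = pd.Series([0, 0, 2, 5, 5, 5])
cats = pd.CategoricalDtype(categories=list(range(7)), ordered=True)
as_cat = observed.astype(cats)

# value_counts on a categorical reports every declared category, including empty ones
counts = as_cat.value_counts().sort_index()
print(counts.tolist())

# without the categorical dtype only the observed values are reported
print(len(observed.value_counts()))
```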

In [None]:
# # indices of compartments back to their names ...
# name_to_descr = dict((i, name) for name, i in _comp_dict.items())

_chosen_idname = "5hr_2X_enrichment_signal"
# _chosen_idname = "MEGAminus_2X_enrichment"
# _chosen_idname = "MEGA_2X_enrichment"
# _chosen_idname = "pCyto_2X_enrichment_signal"
_anchors = id_anchors_dict[_chosen_idname]

In [None]:
_track_bins_id = _track_bins.copy()
# ...
_id_bins = bioframe.overlap( _track_bins, _anchors, return_index=True).dropna()["index"]
# ...
_track_bins_id.loc[_id_bins, "name"] = 1

In [None]:
cre_to_code = {
    'pls':6,
    'pels':5,
    'nels':5,
    'dels':4,
    'openk4':3,
    'ctcf':2,
    'justopen':1,
}

_cats = pd.CategoricalDtype(categories=list(range(6+1)), ordered=True)

cre_fnames = {
    # "atac95_rest95" : "./bigbed/all_cres_encodestyle_pMunion_95atac_95rest.bed",  # control sorta, like before
    # "atac90_rest90" : "./bigbed/all_cres_encodestyle_pMunion_90atac_90rest.bed",
    # "atac70_rest70" : "./bigbed/all_cres_encodestyle_pMunion_70atac_70rest.bed",
    # "atac70_rest95" : "./bigbed/all_cres_encodestyle_pMunion_70atac_95rest.bed",
    # "atac70_rest90" : "./bigbed/all_cres_encodestyle_pMunion_70atac_90rest.bed",
    # "atac75_rest95" : "./bigbed/all_cres_encodestyle_pMunion_75atac_95rest.bed",
    # "atac75_rest90" : "./bigbed/all_cres_encodestyle_pMunion_75atac_90rest.bed",
    "allana_latest" : "./bigbed/all_cres_encodestyle_pMunion_NEWatac_k2785_rest95.bed",
}

_status_query = "(status_ == 0)|(status_ == 5)|(status_ == 4)|(status_ == 6)|(status_ == 7)"

binned_cre10_dfs = {}
binned_cre25_dfs = {}
for name, cre_fname in cre_fnames.items():
    df = pd.read_table(cre_fname, usecols=['chrom', 'start', 'end', 'name'])
    print(f"read {name=} with {len(df)} items ...")
    # # rename nels -> pels
    # df = df.replace({'name': {"nels": "pels"}})
    # sort by coordinate
    df = bioframe.sort_bedframe(df, view_df=hg38_arms)
    df["status"] = df["name"].map(cre_to_code)

    # now annotate bins with the anchors @10kb  ...
    _bin_cres = bioframe.overlap(clr_bins10, df)
    _bin_cres = _bin_cres.drop(columns=["chrom_","start_","end_"])
    _bin_cres["status_"] = _bin_cres["status_"].fillna(0).astype(int)
    # assign cCREs per bin according to the hierarchy ...
    binned_cres = _bin_cres.groupby(["chrom","start","end"], observed=True)["status_"].max().reset_index()
    binned_cres["status_"] = binned_cres["status_"].astype(_cats)
    # value_counts of annotated bins before gene assignment ...
    print("@10kb\n")
    display(binned_cres["status_"].value_counts().sort_index())
    print("\n")
    #
    binned_cre10_dfs[name] = binned_cres.astype({ "chrom" : str })

    # now annotate bins with the anchors @25kb  ...
    _bin_cres = bioframe.overlap(clr_bins25, df)
    _bin_cres = _bin_cres.drop(columns=["chrom_","start_","end_"])
    _bin_cres["status_"] = _bin_cres["status_"].fillna(0).astype(int)
    # assign cCREs per bin according to the hierarchy ...
    binned_cres = _bin_cres.groupby(["chrom","start","end"], observed=True)["status_"].max().reset_index()
    binned_cres["status_"] = binned_cres["status_"].astype(_cats)
    # value_counts of annotated bins before gene assignment ...
    print("@25kb\n")
    display(binned_cres["status_"].value_counts().sort_index())
    print("\n")
    # ...
    binned_cre25_dfs[name] = binned_cres.astype({ "chrom" : str })



ticklabels = [
    "PLS",
    "pELS",
    "dELS",
    "opK4",
    "CTCF",
    "open",
    "none",
]
# # categories are going to be 1,2,3,4,5,6
# # 0 - everything else ...

_working_cres10 = binned_cre10_dfs["allana_latest"]
_working_cres25 = binned_cre25_dfs["allana_latest"]

In [None]:
df["name"].value_counts()

In [None]:
26855+2008

In [None]:
! head ./bigbed/all_cres_encodestyle_pMunion_NEWatac_k2785_rest95.bed


In [None]:
ooo = []
for k in range(7):
    _xxx = bioframe.overlap(_working_cres10, _track_bins_id).dropna().query(f"name_ == {k}")["status_"].value_counts().sort_index()
    _xxx.name = k
    ooo.append(_xxx)
_mat1 = pd.concat(ooo, axis=1)
print(_mat1)
_mmm1 = ((_mat1/_mat1.sum(axis=0)).T / (_mat1.sum(axis=1)/len(_working_cres10))).T

In [None]:
ooo = []
for k in range(7):
    _xxx = bioframe.overlap(_working_cres10, _track_bins).dropna().query(f"name_ == {k}")["status_"].value_counts().sort_index()
    _xxx.name = k
    ooo.append(_xxx)
_mat2 = pd.concat(ooo, axis=1)
print(_mat2)
_mmm2 = ((_mat2/_mat2.sum(axis=0)).T / (_mat2.sum(axis=1)/len(_working_cres10))).T
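The two cells above turn a table of counts (CRE status in rows, compartment category in columns) into an enrichment: the fraction of bins with a given status inside a compartment, divided by the overall fraction of bins with that status. A NumPy sketch of the same idea on toy numbers follows. Unlike the notebook, which normalizes by `len(_working_cres10)`, the hypothetical `enrichment` helper normalizes by the table total; the two differ if some bins drop out of the overlap.

```python
import numpy as np

def enrichment(counts):
    """counts[s, k]: number of bins with status s in compartment k."""
    counts = np.asarray(counts, dtype=float)
    frac_in_comp = counts / counts.sum(axis=0, keepdims=True)          # P(status | compartment)
    frac_overall = counts.sum(axis=1, keepdims=True) / counts.sum()    # P(status)
    return frac_in_comp / frac_overall

# status and compartment are associated: enrichment on the diagonal
toy_enr = enrichment([[80, 20],
                      [20, 80]])
# status and compartment are independent: enrichment is 1 everywhere
flat_enr = enrichment([[10, 30],
                       [20, 60]])
print(toy_enr)
print(flat_enr)
```

Values above 1 mean that a status is over-represented in a compartment, values below 1 that it is depleted.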

In [None]:
f,axs  = plt.subplots(ncols=2,sharey=True)



axs[0].imshow(
    _mmm1,
    vmin=0,
    vmax=5,
)


axs[1].imshow(
    _mmm2,
    vmin=0,
    vmax=5,
)

for i,ax in enumerate(axs):
    ax.set_xticks(np.arange(7))
    ax.set_xticklabels(["none","id"]+list(_comp_dict))
    if i ==0 :
        ax.set_yticks(np.arange(len(ticklabels)))
        ax.set_yticklabels(ticklabels[::-1]);

In [None]:
def _job(packed_data, sample):
    clr_dict, exp_dict, _atrack, view_df = packed_data
    _clr = clr_dict[sample]
    _exp = exp_dict[sample]
    from cooltools.api.saddle import saddle_stack
    _sum, _count = saddle_stack(
        _clr,
        _exp,
        _atrack,
        'cis',
        n_bins=None,
        drop_track_na=True,
        view_df=view_df
    )
    return sample, _sum, _count

# daemon=True is used here: _job is not given nproc, so it should not need child processes ...
with WorkerPool(
    n_jobs=16,
    daemon=True,
    shared_objects=( telo_clrs10, telo_exps_cis, _working_cres10, hg38_arms ),
    start_method="fork",  # fork is the fastest start method
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs10, progress_bar=True)

# sort out the results ...
interaction_sums = {}
interaction_counts = {}
for sample, _sum, _counts in results:
    interaction_sums[sample] = _sum
    interaction_counts[sample] = _counts
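Each job returns a sum and a count per sample, and a saddle is their ratio. The NumPy sketch below shows how such sums and counts can be accumulated from an observed/expected matrix and a digitized track. It is only an illustration under simplifying assumptions (one dense matrix, no distance stratification); `toy_saddle` is hypothetical and does not reproduce the internals of `saddle_stack`.

```python
import numpy as np

def toy_saddle(obs_over_exp, codes, n_cats):
    """Sum and count obs/exp pixels for every pair of track categories."""
    sums = np.zeros((n_cats, n_cats))
    counts = np.zeros((n_cats, n_cats))
    # skip masked (NaN) pixels
    ii, jj = np.nonzero(np.isfinite(obs_over_exp))
    # np.add.at accumulates correctly when the same category pair repeats
    np.add.at(sums, (codes[ii], codes[jj]), obs_over_exp[ii, jj])
    np.add.at(counts, (codes[ii], codes[jj]), 1)
    return sums, counts

codes = np.array([0, 0, 1, 1])            # two categories over four bins
oe = np.array([[2.0, 2.0, 0.5, 0.5],
               [2.0, 2.0, 0.5, 0.5],
               [0.5, 0.5, 2.0, 2.0],
               [0.5, 0.5, 2.0, np.nan]])  # NaN marks a masked pixel
S, C = toy_saddle(oe, codes, n_cats=2)
toy_saddle_mat = S / C
print(toy_saddle_mat)
```

Keeping sums and counts separate, rather than storing the ratio, allows them to be pooled later, for example over a range of distances.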

In [None]:
# trans saddles here yo !
def _job_trans(packed_data, sample):
    clr_dict, exp_dict, _atrack, view_df = packed_data
    _clr = clr_dict[sample]
    _exp = exp_dict[sample]
    from cooltools.api.saddle import saddle_stack
    _sum, _count = saddle_stack(
        _clr,
        _exp,
        _atrack,
        'trans',
        n_bins=None,
        drop_track_na=True,
        view_df=view_df,
    )
    return sample, _sum, _count

# daemon=True is used here: _job_trans is not given nproc, so it should not need child processes ...
with WorkerPool(
    n_jobs=16,
    daemon=True,
    shared_objects=( telo_clrs25, telo_trans_filt_exps, _working_cres25, sub_chrom_view ),
    start_method="fork",  # fork is the fastest start method
    use_dill=True,
) as wpool:
    results = wpool.map(_job_trans, telo_clrs25, progress_bar=True)

# sort out the results ...
interaction_sums_trans = {}
interaction_counts_trans = {}
for sample, _sum, _counts in results:
    interaction_sums_trans[sample] = _sum
    interaction_counts_trans[sample] = _counts

# Let's save these results using HDF5 for convenience and to practice ...

In [None]:
import h5py

In [None]:
# # we've got 4 dictionaries to store along with some metadata ...
# interaction_sums[sample] = _sum
# interaction_counts[sample] = _counts
# interaction_sums_trans[sample] = _sum
# interaction_counts_trans[sample] = _counts

In [None]:
# mode 'x': create the file, fail if it already exists (protects earlier results)
with h5py.File("saddles_cre_by_distance_Allana_latest_cre.hdf5", 'x') as f:
    # add metadata just in case
    f.attrs["cis_binsize"] = binsize10
    f.attrs["trans_binsize"] = binsize25
    # record the CRE file that was actually used for _working_cres10/25
    f.attrs["cre_fname"] = cre_fnames["allana_latest"]
    # one group per result type (cis and trans), one dataset per sample
    _to_store = {
        "sums": interaction_sums,
        "counts": interaction_counts,
        "sums_trans": interaction_sums_trans,
        "counts_trans": interaction_counts_trans,
    }
    for _grp_name, _arr_dict in _to_store.items():
        _grp = f.create_group(_grp_name)
        for _sample, _arr in _arr_dict.items():
            _grp.create_dataset(_sample, data=_arr)
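Since the stored results are meant to be reused in a separate plotting notebook, here is a minimal round-trip sketch: it writes toy arrays with the same group/dataset layout into a temporary file and reads them back into plain dicts. The toy arrays, the temporary path and the variable names are all hypothetical.

```python
import os
import tempfile

import h5py
import numpy as np

toy_sums = {"sampleA": np.arange(8.0).reshape(2, 2, 2)}
toy_counts = {"sampleA": np.ones((2, 2, 2))}

path = os.path.join(tempfile.mkdtemp(), "toy_saddles.hdf5")
with h5py.File(path, "x") as f:
    f.attrs["cis_binsize"] = 10_000
    for grp_name, arr_dict in {"sums": toy_sums, "counts": toy_counts}.items():
        grp = f.create_group(grp_name)
        for sample, arr in arr_dict.items():
            grp.create_dataset(sample, data=arr)

# read everything back into plain dicts of NumPy arrays
with h5py.File(path, "r") as f:
    loaded = {g: {s: ds[()] for s, ds in f[g].items()} for g in f}
    loaded_binsize = int(f.attrs["cis_binsize"])

print(sorted(loaded), loaded_binsize)
```

Indexing a dataset with `[()]` loads it fully into memory, so the dicts remain usable after the file is closed.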


In [None]:
! ls -lah saddles_cre_by_distance.hdf5
# ! rm saddles_cre_by_distance.hdf5
! ls -lah saddles_cre_by_distance_Allana_latest_cre.hdf5

In [None]:
def print_attrs(name, obj):
    # indent according to the depth in the HDF5 hierarchy
    shift = name.count('/') * '    '
    item_name = name.split("/")[-1]
    print(shift + item_name)
    for key, val in obj.attrs.items():
        print(shift + '    ' + f"{key}: {val}")


# alternative file: "saddles_cre_by_distance.hdf5"
# open read-only; the context manager closes the file for us
with h5py.File("saddles_cre_by_distance_Allana_latest_cre.hdf5", 'r') as fr:
    # walk over all groups and datasets ...
    fr.visititems(print_attrs)
    # check general metadata ...
    print(dict(fr.attrs))

In [None]:
sub_samples_m = [
    "mMito",
    "mTelo",
    "mCyto",
    "m5hR1R2",
    "m10hR1R2",
]
sub_samples_p = [
    "pMito",
    "pTelo",
    "pCyto",
    "p5hR1R2",
    "p10hR1R2",
]

# introduce distance ranges as slices over diagonals (in bins of binsize10)
# NOTE: this first definition is overridden by the next one and is kept for reference only
# 0-1Mb:   bins 0:101
# 1-7Mb:   bins 100:701
# 7-50Mb:  bins 700:5001
distances = {
    "short:<1Mb": slice(0,int(1_000_000/binsize10)+1),
    "mid:1Mb-7Mb": slice(int(1_000_000/binsize10),int(7_000_000/binsize10)+1),
    "long:7Mb-50Mb": slice(int(7_000_000/binsize10),int(50_000_000/binsize10)+1),
    "trans": slice(None),
}

# distance ranges that are actually used below
# 0-0.5Mb
# 0.5-10Mb
# 10-500Mb
distances = {
    "short:<0.5Mb": slice(0,int(500_000/binsize10)+1),
    "mid:0.5Mb-10Mb": slice(int(500_000/binsize10),int(10_000_000/binsize10)+1),
    "long:10Mb-500Mb": slice(int(10_000_000/binsize10),int(500_000_000/binsize10)+1),
    "trans": slice(None),
}

# distances = {
#     "<0.25Mb": slice(0,int(250_000/binsize10)+1),
#     "0.25-1Mb": slice(int(500_000/binsize10),int(1_000_000/binsize10)+1),
#     "1-10Mb": slice(int(1_000_000/binsize10),int(10_000_000/binsize10)+1),
#     ">10Mb": slice(int(10_000_000/binsize10),None),
#     # "50Mb-550Mb": slice(int(50_000_000/binsize10),int(550_000_000/binsize10)+1),
#     "trans": slice(None),
# }
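The slices above select a range of diagonals, i.e. genomic separations measured in bins. The sketch below shows how such a slice reduces a stack of per-diagonal sums and counts to a single saddle. It assumes that the stacks are arrays with the diagonal index as the leading axis, as the slicing in the plotting cells suggests; `dist_slice` and the toy stacks are hypothetical.

```python
import numpy as np

def dist_slice(lo_bp, hi_bp, binsize):
    """Slice over diagonals covering genomic separations lo_bp..hi_bp (inclusive)."""
    return slice(lo_bp // binsize, hi_bp // binsize + 1)

rng = np.random.default_rng(0)
# toy stacks: 200 diagonals x 7 x 7 categories
toy_counts = rng.integers(1, 10, size=(200, 7, 7)).astype(float)
toy_sums = toy_counts * 1.5      # every pixel has obs/exp = 1.5 by construction

sl = dist_slice(0, 500_000, 10_000)          # diagonals 0..50
toy_saddle = toy_sums[sl].sum(axis=0) / toy_counts[sl].sum(axis=0)
print(sl, toy_saddle.shape)
```

Here the ratio is taken between plain sums; the plotting cells use `np.nanmean` for both numerator and denominator, which gives the same ratio when no NaNs are present. Note that with the `+1` in the stop index, neighbouring ranges share their boundary diagonal.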

In [None]:
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2, vmax=2),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = row_axs[jj], row_axs[len(distances) + jj]
        if _dist_name != "trans":
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        else:
            # trans: slice(None) keeps the whole stack
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m.lstrip("m"))



In [None]:
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.25, vmax=2.25),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = row_axs[jj], row_axs[len(distances) + jj]
        if _dist_name != "trans":
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        else:
            # trans: slice(None) keeps the whole stack
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m.lstrip("m"))

In [None]:
# # # # the mix one - mp
sub_samples_m = [
    "N93m5",
    "N93m10",
]
# p ...
sub_samples_p = [
    "N93p5",
    "N93p10",
]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.5, vmax=2.5),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = row_axs[jj], row_axs[len(distances) + jj]
        if _dist_name != "trans":
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        else:
            # trans: slice(None) keeps the whole stack
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m.lstrip("m"))

In [None]:
from matplotlib import cm

In [None]:
sub_samples_m =[
        "m10hR1R2",
        "p10hR1R2",
        "mp10hR1R2",
    ]
sub_samples_p = [
        "N93m10",
        "N93p10",
        "N93mp10",
    ]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.5, vmax=2.5),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = row_axs[jj], row_axs[len(distances) + jj]
        if _dist_name != "trans":
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        else:
            # trans: slice(None) keeps the whole stack
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m)



# try adding an axes manually ...
cax = fig.add_axes([0.88,0.001,0.1,0.02])
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cax,
    orientation="horizontal",
)
cax.set_xticks([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
cax.set_xticklabels([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
cax.minorticks_off()