In [None]:
# set the number of threads for many common libraries
# (OpenMP, OpenBLAS, MKL, Apple vecLib and numexpr each read one of the variables below);
# this cell runs before numpy & co. are imported, since thread pools are typically sized at load time
# https://superfastpython.com/numpy-number-blas-threads/
from os import environ
N_THREADS = '1'
environ.update(
    (_thread_var, N_THREADS)
    for _thread_var in (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    )
)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
import pandas as pd
import numpy as np
# Hi-C utilities imports:
import cooler
import bioframe
import cooltools
# Visualization imports:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, Normalize
from matplotlib import colors
import matplotlib.patches as patches
from matplotlib.ticker import EngFormatter

In [None]:
# bbi for stackups ...
import bbi
# functions and assets specific to this repo/project ...
from data_catalog import bws, bws_vlim, telo_dict
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
# import mpire for nested multi-processing
from mpire import WorkerPool

In [None]:
def to_mb(bp_val, decimals=2):
    """
    Format a genomic distance given in bp as a compact Mb string.

    Whole megabases are printed without a decimal part ("10"). Otherwise up to
    `decimals` decimal places are kept and trailing zeros are dropped
    ("0.25", "2.5"). A fixed 1-decimal format would render 250_000 bp as "0.2",
    which mislabels the 250 kb distance-bin edge used in the plots.

    Parameters
    ----------
    bp_val : int or float
        Distance in base pairs.
    decimals : int
        Maximum number of decimal places for non-whole Mb values.

    Returns
    -------
    str
    """
    mb_val = bp_val / 1_000_000
    if float(mb_val).is_integer():
        return f"{int(mb_val)}"
    text = f"{mb_val:.{decimals}f}"
    # only strip zeros after a decimal point (decimals=0 yields e.g. "20")
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    return text

# given the range - generate pretty axis name
def _get_name(_left, _right, _amount, open_right=250_000_000):
    """
    Build a pretty axis title for the distance bin (_left, _right] holding `_amount` pairs.

    - bins starting at 0 become "<right Mb: amount";
    - bins whose right edge exceeds `open_right` are treated as open-ended catch-all
      bins and become ">left Mb: amount";
    - all other bins become "left-right Mb: amount".

    The default `open_right` (250 Mb) is just above the longest hg38 chromosome
    (chr1, ~249 Mb). Only catch-all bins such as (10 Mb, 1 Gb] qualify, while real bins
    like (50 Mb, 90 Mb] keep both edges in their label.
    """
    if np.isclose(_left, 0.0):
        return f"<{to_mb(_right)} Mb: {_amount}"
    elif _right > open_right:
        return f">{to_mb(_left)} Mb: {_amount}"
    else:
        return f"{to_mb(_left)}-{to_mb(_right)} Mb: {_amount}"

### Chrom arms as a view

In [None]:
# Use bioframe to fetch the genomic features from the UCSC.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# # remove "bad" chromosomes and near-empty arms ...
# excluded_arms = ["chr13_p", "chr14_p", "chr15_p", "chr21_p", "chr22_p", "chrM_p", "chrY_p", "chrY_q", "chrX_p", "chrX_q"]
# hg38_arms = hg38_arms_full[~hg38_arms_full["name"].isin(excluded_arms)].reset_index(drop=True)

# Keep all autosomal arms (chr1..chr22, p and q). They are selected by chromosome name rather
# than by taking the first 44 rows, so the selection does not silently depend on row order.
# (can do 1 chromosome (or arm) as well, e.g.: included_arms = ["chr1_q", "chr2_p", "chr4_q", "chr6_q"])
_autosomes = [f"chr{i}" for i in range(1, 23)]
included_arms = hg38_arms_full.loc[hg38_arms_full["chrom"].isin(_autosomes), "name"].to_list()
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

# There is a problem with our arms view of the chromosomes ...

The way we do it now, the end of each p-arm is always equal to the start of the corresponding q-arm ...

After binning, this can lead to a situation where the last bin of the p-arm is the same bin as the first bin of the q-arm (the bin containing the arm boundary), i.e. the binned arms overlap slightly ...

This makes `cooltools.api.is_valid_expected` crash ...

Let's try solving that by shifting the start of every q-arm downstream by one bin (`binsize` bp) ...

In [None]:
def adjust_arm_view(
    view_df,
    binsize,
):
    """
    Adjust an arm-based view of the genome to fix slightly overlapping p and q arms.

    In the arm view, the end of each p-arm equals the start of its q-arm. After binning,
    both arms can therefore share the bin that contains the boundary. Shifting every
    q-arm start downstream by `binsize` removes that shared bin.

    Parameters
    ----------
    view_df : pandas.DataFrame
        View with exactly four columns in (chrom, start, end, name) order.
    binsize : int
        Shift in bp applied to the start of every region whose name contains "q".

    Returns
    -------
    pandas.DataFrame
        A new view with the same columns as `view_df` and a fresh RangeIndex.
        `view_df` itself is not modified.
    """
    # Take the column names from the input view itself rather than a notebook global.
    # The function then works for any view and does not need `hg38_arms` to be defined.
    return pd.DataFrame(
        [
            (chrom, start + binsize if "q" in name else start, end, name)
            for chrom, start, end, name in view_df.itertuples(index=False)
        ],
        columns=view_df.columns,
    )


## Now let's get to pileups! First, calculate expected for all samples ...

In [None]:
# cooler handles for every sample at the two resolutions used below:
# 10 kb bins for cis (expected & pileups), 25 kb bins for trans
binsize10 = 10_000
telo_clrs10 = {
    name: cooler.Cooler(f"{path}::/resolutions/{binsize10}")
    for name, path in telo_dict.items()
}

binsize25 = 25_000
telo_clrs25 = {
    name: cooler.Cooler(f"{path}::/resolutions/{binsize25}")
    for name, path in telo_dict.items()
}

## We'll keep alternating between the 10 kb and 25 kb resolutions: 10 kb for cis and 25 kb for trans ...

### cis-expected first:

In [None]:
def _job(packed_data, sample):
    """
    Worker: cis-expected of a single sample.

    `packed_data` is the WorkerPool shared object: (expected kwargs, {sample: cooler}).
    Returns (sample, expected) so the results can be keyed by sample afterwards.
    """
    kwargs, coolers = packed_data
    # in order to use spawn/forkserver we have to import inside the worker
    from cooltools import expected_cis
    return sample, expected_cis(coolers[sample], **kwargs)

# expected parameters shared by all samples (arm-based view, see cooltools.expected_cis for intra_only)
exp_kwargs = {
    "view_df": adjust_arm_view(hg38_arms, binsize10),
    "intra_only": False,
    "nproc": 12,
}

# daemon=False is required: _job itself is multiprocessing-based (nproc) ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=(exp_kwargs, telo_clrs10),
    start_method="forkserver",  # a bit faster than spawn; fork would be the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs10, progress_bar=True)

# results is a list of (sample, expected) pairs -> {sample: expected}
telo_exps_cis = dict(results)
# # old way of doing it
# telo_exps_cis = {k: cooltools.expected_cis( _clr, **exp_kwargs) for k, _clr in telo_clrs10.items()}

### trans-expected second

In [None]:
def _job(packed_data, sample):
    """
    Worker: trans-expected of a single sample.

    `packed_data` is the WorkerPool shared object: a 1-tuple holding {sample: cooler}.
    Returns (sample, expected), with the expected indexed by (region1, region2) and sorted.
    """
    (coolers,) = packed_data
    # workers started via forkserver do not see the notebook's imports
    from cooltools import expected_trans
    _exp = expected_trans(coolers[sample], chunksize=1000000, nproc=12)
    return sample, _exp.set_index(["region1", "region2"]).sort_index()

# daemon=False is required: _job itself is multiprocessing-based (nproc) ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=(telo_clrs25, ),
    start_method="forkserver",
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs25, progress_bar=True)

# results is a list of (sample, expected) pairs -> {sample: expected}
telo_exps_trans = dict(results)

# Read pre-called native compartments (anchors) and pick one for pileups:
## Skip all of the anchor characterization and annotation ...

In [None]:
# Pre-called anchor sets: short name -> path of the tab-separated (BED-like, with header) file.
# Only one of them is selected for the pileups below; all of them are loaded for reference.
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

# load every anchor set (tab-separated tables with a header row) and report its size
id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    _anchor_table = pd.read_csv(fname, sep="\t")
    id_anchors_dict[id_name] = _anchor_table
    print(f"loaded {len(_anchor_table):5d} ID anchors {id_name:>20} in BED format ...")

# the anchor set used for all of the pileups below (its file name is stored as HDF5 metadata)
_selected_anchor_set = "pCyto_2X_enrichment_signal"
_anchors = id_anchors_dict[_selected_anchor_set]
_anchors_fname = id_anchor_fnames[_selected_anchor_set]

# Create combinations of `_anchors` for the pileup - i.e. the all-by-all type of situation !
## cis and trans - done a bit differently ...

In [None]:
# cis: anchor pairs separated by at most 100 Mb (midpoint to midpoint)
_df = bioframe.pair_by_distance(
    _anchors,
    min_sep=0,
    max_sep=100_000_000,
    relative_to="midpoints",
    keep_order=True,
    suffixes=("1", "2"),
)

# distance from the diagonal = separation of the two anchor midpoints
_df["dist"] = (_df["start2"] + _df["end2"]) / 2 - (_df["start1"] + _df["end1"]) / 2

# we have to explicitly keep only intra-arm pairs - the cis expected is arm-based
_intra_arm_index = (
    cooltools.lib.common.assign_view_auto(_df, adjust_arm_view(hg38_arms, binsize10))
    .query("region1 == region2")
    .index
)
_df_intra_arm = _df.loc[_intra_arm_index].reset_index(drop=True)
display(_df_intra_arm)

# trans: every unordered pair of distinct anchors (upper triangle of the all-by-all matrix,
# the `np.triu_indices` recipe), then keep only the inter-chromosomal pairs
_left, _right = np.triu_indices(len(_anchors), k=1)

ALL_pairwise_anchors = pd.concat(
    [
        _anchors.iloc[_idx].add_suffix(_sfx).reset_index(drop=True)
        for _idx, _sfx in ((_left, "1"), (_right, "2"))
    ],
    axis=1,
)
tr_feat = (
    ALL_pairwise_anchors
    .query("chrom1 != chrom2")
    .reset_index(drop=True)
)
display(tr_feat)

In [None]:
# quick look at all cis anchor pairs (before the intra-arm filter)
_df

# This is the step of recalculating trans-pileups

Intermediate results are kept in `/data/sergpolly/tmp`

In [None]:
import os
# Directory with the cached per-sample trans obs/exp stacks (new_{sample}_trans_stack_oe.npy),
# i.e. the same directory the trans-pileup cell writes them to.
trans_oe_tmp_dir = "/data/sergpolly/tmp"
# Recalculate unconditionally; set to False to reuse the cached stacks when all of them exist.
force_recalculate_trans_oe_pups = True

# check if intermediate results are still there ...
_missing_trans_stacks = [
    s for s in telo_clrs25
    if not os.path.isfile(f"{trans_oe_tmp_dir}/new_{s}_trans_stack_oe.npy")
]
if _missing_trans_stacks:
    print(f"some of the tmp npy files are missing {_missing_trans_stacks} ! need to recalculate! ")
recalculate_trans_oe_pups = force_recalculate_trans_oe_pups or bool(_missing_trans_stacks)

In [None]:
# ! ls -lah /data/sergpolly/tmp/*.npy

In [None]:
# Trans pileups: for every sample, stack the snippets around all inter-chromosomal anchor
# pairs (`tr_feat`), divide them by the trans expected of the corresponding chromosome pair
# and cache the obs/exp stack on disk as new_{sample}_trans_stack_oe.npy.
# this takes up a LOT of RAM ...
if recalculate_trans_oe_pups:
    # worker - the shared_objects below are unpacked in the same order
    def _job(packed_data, sample):
        # unpack data
        clr_dict, features, exp_dict = packed_data
        kwargs = dict(flank=100_000, min_diag=None, nproc=18)
        from cooltools import pileup
        _clr = clr_dict[sample]
        _stack = pileup( _clr, features, **kwargs)
        # now divide that stack by expected:
        _exp = exp_dict[sample]
        # one (chrom1, chrom2) key per feature, in feature order -> one expected value per snippet.
        # NOTE(review): assumes every pair exists in the expected index in this orientation
        # (i.e. anchors sorted like the cooler chromosomes); otherwise .loc raises a KeyError.
        features_chroms_order = features.set_index(["chrom1","chrom2"]).index
        _stack_exp = _exp.loc[features_chroms_order, "balanced.avg"].to_numpy()
        # imported here too: forkserver workers do not inherit the notebook namespace
        import numpy as np
        # [:, None, None] broadcasts the per-feature expected over both snippet axes -
        # assumes the stack is (n_features, D, D), as it is indexed later - TODO confirm
        np.save(f"/data/sergpolly/tmp/new_{sample}_trans_stack_oe.npy", _stack/_stack_exp[:,None,None])
        return True

    # have to use daemon=False, because _job is multiprocessing-based already ...
    with WorkerPool(
        n_jobs=4,
        daemon=False,
        shared_objects=(telo_clrs25, tr_feat, telo_exps_trans),
        start_method="forkserver",
        use_dill=True,
    ) as wpool:
        # results is only a list of True flags - the stacks themselves live in the .npy files
        results = wpool.map(_job, telo_clrs25, progress_bar=True)


## Do cis pileups in parallel ...

In [None]:
def _job(packed_data, sample):
    """
    Worker: obs/exp cis pileup of all intra-arm pairs for one sample.

    `packed_data` is the WorkerPool shared object:
    (pair features, {sample: cooler}, {sample: cis expected}, view).
    Returns (sample, stack).
    """
    features, coolers, expecteds, view_df = packed_data
    # workers started via forkserver do not see the notebook's imports
    from cooltools import pileup
    _pstack = pileup(
        coolers[sample],
        features,
        expected_df=expecteds[sample],
        view_df=view_df,
        flank=100_000,
        nproc=12,
    )
    return sample, _pstack

# daemon=False is required: _job itself is multiprocessing-based (nproc) ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=(
        _df_intra_arm,
        telo_clrs10,
        telo_exps_cis,
        adjust_arm_view(hg38_arms, binsize10),
    ),
    start_method="forkserver",
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs10, progress_bar=True)

# results is a list of (sample, stack) pairs -> {sample: stack}
fullstacks_cis = dict(results)

## Now let's store all of the results in a single HDF5 file for convenience ...

In [None]:
import h5py

In [None]:
# Store the cis & trans pileups of all samples, together with the anchor ids of every pair,
# in a single HDF5 file. Mode 'x' creates the file and fails if it already exists -
# delete an existing file first to re-create it.
pileups_hdf5_fname = "/data/sergpolly/tmp/Pileups_ID_by_distance_pCyto.hdf5"
# directory with the cached per-sample trans obs/exp stacks
trans_stacks_tmp_dir = "/data/sergpolly/tmp"
with h5py.File(pileups_hdf5_fname, 'x') as f:
    # metadata: resolutions used for cis/trans and the anchor file the pairs were built from
    f.attrs["cis_binsize"] = binsize10
    f.attrs["trans_binsize"] = binsize25
    f.attrs["id_anchors_fname"] = _anchors_fname
    # CIS ...
    print("saving cis data ...")
    _cis_grp = f.create_group("cis")
    # indices: anchor ids of both sides of every intra-arm pair, row-aligned with the stacks
    _idxs_subgrp = _cis_grp.create_group("indices")
    _idxs_subgrp.create_dataset("anchor1", data=_df_intra_arm["cluster1"].to_numpy())
    _idxs_subgrp.create_dataset("anchor2", data=_df_intra_arm["cluster2"].to_numpy())
    # pileups: one 3D stack per sample
    _pups_subgrp = _cis_grp.create_group("pileups")
    for _sample, _arr in fullstacks_cis.items():
        _pups_subgrp.create_dataset(_sample, data=_arr)

    # TRANS ...
    print("re-saving trans data ...")
    _trans_grp = f.create_group("trans")
    # indices: anchor ids of both sides of every inter-chromosomal pair, row-aligned with the stacks
    _idxs_subgrp = _trans_grp.create_group("indices")
    _idxs_subgrp.create_dataset("anchor1", data=tr_feat["cluster1"].to_numpy())
    _idxs_subgrp.create_dataset("anchor2", data=tr_feat["cluster2"].to_numpy())
    # pileups: one 3D stack per sample, read back from the cached .npy files
    _pups_subgrp = _trans_grp.create_group("pileups")
    # the trans stacks were computed (and cached) per sample of the 25 kb coolers - iterate those
    for _sample in telo_clrs25:
        print(f"    {_sample} ...")
        _arr = np.load(f"{trans_stacks_tmp_dir}/new_{_sample}_trans_stack_oe.npy")
        _pups_subgrp.create_dataset(_sample, data=_arr)


In [None]:
# check the size of the pileups HDF5 file; the writing cell opens it in mode 'x' (fails if the
# file exists), so uncomment the `rm` to delete it before writing it again
! ls -lah /data/sergpolly/tmp/Pileups_ID_by_distance_pCyto.hdf5
# ! rm /data/sergpolly/tmp/Pileups_ID_by_distance_pCyto.hdf5

In [None]:
# NOTE(review): this file (no "_pCyto" suffix) is not written by this notebook -
# presumably the output of an earlier run with a different anchor set
! ls -lah /data/sergpolly/tmp/Pileups_ID_by_distance.hdf5
# ! rm /data/sergpolly/tmp/Pileups_ID_by_distance.hdf5

In [None]:
# # create indexes for pileup groups
# _enrich1_idx = tr_feat.query("(dots1>0)&(dots2>0)").index
# _enrich10_idx = tr_feat.query("(dots_footprint1==0)&(dots_footprint2==0)").index
# _deplete_idx = tr_feat.query("(dots_footprint1>0)&(dots_footprint2>0)").index
# # _enrich1_idx = tr_feat.query("(valency1>1)&(valency2>1)").index
# # _enrich10_idx = tr_feat.query("(valency1>=10)&(valency2>=10)").index
# # _deplete_idx = tr_feat.query("(valency1==1)&(valency2==1)").index
# len(tr_feat), len(_enrich10_idx), len(_enrich1_idx), len(_deplete_idx)

# # now average those sub-pileups :

# def _job(sample):
#     # extract stack of observed and expected per sample ...
#     _stack = np.load(f"/data/sergpolly/tmp/new_{sample}_trans_stack_oe.npy")
#     return (
#         sample,
#         np.nanmean(_stack[_enrich10_idx], axis=0),
#         np.nanmean(_stack[_enrich1_idx], axis=0),
#         np.nanmean(_stack[_deplete_idx], axis=0),
#         np.nanmean(_stack, axis=0)
#     )

# with WorkerPool( n_jobs=4, start_method="fork", use_dill=True ) as wpool:
#     results = wpool.map(_job, telo_clrs25, progress_bar=True)
# # unpack results ...
# stack_means = {s: [e10, e1, ed, eall] for s, e10, e1, ed, eall in results }

# plotting pups ...

In [None]:
# pileup select samples only! - one figure per group, one row per sample
_select_sample_groups = [
    # m-ones
    ["mMito", "mTelo", "mCyto", "m5hR1R2", "m10hR1R2"],
    # p-ones
    ["pMito", "pTelo", "pCyto", "p5hR1R2", "p10hR1R2"],
    # N93, m ...
    ["N93m5", "N93m10"],
    # N93, p ...
    ["N93p5", "N93p10"],
    # 10h: m vs p vs the mixed one (mp)
    ["m10hR1R2", "p10hR1R2", "mp10hR1R2"],
    # N93 10h: m vs p vs the mixed one (mp)
    ["N93m10", "N93p10", "N93mp10"],
]


In [None]:
# Averaged trans sub-pileups: stack_means[sample] is a list of 4 averaged snippets in the order
# of `ggg` below (e10, e1, ed, eall). It is only produced by the commented-out averaging cell,
# so fail early with a clear message instead of a bare NameError in the middle of a figure.
if "stack_means" not in globals():
    raise NameError(
        "`stack_means` is not defined: run the (commented-out) trans sub-pileup averaging cell first"
    )

_flank=100_000
num_trans_groups = 4
ggg = [ "e10", "e1", "ed", "eall" ]

# colormap with the "over" color baked in - built once instead of mutating every image's cmap
_pup_cmap = mpl.colormaps["RdBu_r"].with_extremes(over="#400000")
# snippet ticks depend only on the flank and the trans binsize - compute them once
ticks_pixels = np.linspace(0, _flank*2//binsize25, 5)
ticks_kbp = ((ticks_pixels-ticks_pixels[-1]/2)*binsize25//1000).astype(int)

for _sample_group in _select_sample_groups:

    f, axs = plt.subplots(
        nrows=len(_sample_group),
        ncols=len(ggg)+1,
        figsize=(3*len(ggg), 3*len(_sample_group)),
        width_ratios=[1]*len(ggg)+[0.05],
        sharex=True,
        sharey=True,
    )

    # the narrow last column only hosts the colorbar: remove its axes, add one colorbar axes
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    for i, (_axs, k) in enumerate(zip(axs,_sample_group)):
        # going over samples ...
        _stacks = stack_means[k]
        print(k)
        for j, (ax, _q) in enumerate(zip(_axs, ggg)):
            # going over the trans pair groupings ...
            _ccc = ax.imshow(
                _stacks[j],
                cmap=_pup_cmap,
                norm=LogNorm(vmin=1/2.5,vmax=2.5),
                # norm=colors.CenteredNorm(vcenter=1,halfrange=0.9,clip=False),
                # norm=colors.TwoSlopeNorm(1, vmin=0.5, vmax=2),
                aspect="auto",
            )
            if i == 0:
                # top row
                ax.set_title(_q)
            if i == len(_sample_group)-1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            ax.set_yticks(ticks_pixels, ticks_kbp)
            if j<1:
                ax.set_ylabel(f"{k}", fontsize=14)

    plt.colorbar(_ccc, label="obs/exp", cax=axcb)

In [None]:
_flank = 100_000
# NOTE(review): `arm_feat` and `fullstacks_arm_cis` are not created anywhere in this notebook
# (presumably left over from an earlier arm-based analysis) - define them before running this cell.
# pair selections that were tried:
# _dfff = arm_feat
# _dfff = arm_feat.query("(valency1>1)&(valency2>1)")
#_dfff = arm_feat.query("((H3K27ac1>4)&(H3K27ac2>4))&(valency1>2)&(valency2>2)")
_dfff = arm_feat.query("(dots1>0)&(dots2>0)")

print(f"dealing with {len(_dfff)} elements total ...")
dist_bins = [0, 2_500_000, 50_000_000, 90_000_000, 170_000_000, 1_000_000_000]
# observed=True: only non-empty distance bins become groups (= panels); empty bins would
# otherwise yield all-NaN images from np.nanmean over zero snippets
ggg = _dfff.groupby(pd.cut( _dfff["dist"], dist_bins ), observed=True)
nquants = len(ggg)

# colormap with the "over" color baked in - built once instead of mutating every image's cmap
_pup_cmap = mpl.colormaps["RdBu_r"].with_extremes(over="#400000")
# snippet ticks depend only on the flank and the cis binsize - compute them once
ticks_pixels = np.linspace(0, _flank*2//binsize10, 5)
ticks_kbp = ((ticks_pixels-ticks_pixels[-1]/2)*binsize10//1000).astype(int)

for _sample_group in _select_sample_groups:

    f, axs = plt.subplots(
        nrows=len(_sample_group),
        ncols=nquants+1,
        figsize=(3*nquants, 3*len(_sample_group)),
        width_ratios=[1]*nquants+[0.05],
        sharex=True,
        sharey=True,
    )

    # the narrow last column only hosts the colorbar: remove its axes, add one colorbar axes
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    for i, (_axs, k) in enumerate(zip(axs,_sample_group)):
        # going over samples ...
        _stacks = fullstacks_arm_cis[k]
        print(k)
        for j, (ax, (_q, _mtx)) in enumerate(zip(_axs, ggg.groups.items())):
            # going over groupings (by dist, or whatever ...)
            _ccc = ax.imshow(
                np.nanmean(_stacks[_mtx], axis=0),
                cmap=_pup_cmap,
                norm=LogNorm(vmin=1/2.5,vmax=2.5),
                # norm=colors.CenteredNorm(vcenter=1,halfrange=0.9,clip=False),
                # norm=colors.TwoSlopeNorm(1, vmin=0.5, vmax=2),
                aspect="auto",
            )
            if i == 0:
                # top row
                ax.set_title(_get_name(_q.left, _q.right, len(_mtx)))
            if i == len(_sample_group)-1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            ax.set_yticks(ticks_pixels, ticks_kbp)
            if j<1:
                ax.set_ylabel(f"{k}", fontsize=14)

    plt.colorbar(_ccc, label="obs/exp", cax=axcb)

In [None]:
# quick look at the intra-arm cis pairs used for the cis pileups
_df_intra_arm

In [None]:
# scratch arithmetic left over from interactive exploration - NOTE(review): safe to delete
1400+700

In [None]:
# quick look at the selected anchor set
_anchors

In [None]:
_flank = 100_000
_dfff = _df_intra_arm
# other pair selections that were tried:
# _dfff = _df_intra_arm.query("(valency1>1)&(valency2>1)")
# _dfff = _df_intra_arm.query("(ctcf1<1.9)&(ctcf2<1.9)")
# _dfff = _df_intra_arm.query("(dots1>0)&(dots2>0)")
# _dfff = _df_intra_arm.query("(dots_footprint1==0)&(dots_footprint2==0)")
#_dfff = _df_intra_arm.query("((H3K27ac1>4)&(H3K27ac2>4))&(valency1>2)&(valency2>2)")

print(f"dealing with {len(_dfff)} elements total ...")
# dist_bins = [0, 300_000, 1_000_000, 10_000_000, 1_000_000_000]
dist_bins = [0, 250_000, 500_000, 1_000_000, 2_500_000, 5_000_000, 10_000_000, 1_000_000_000]
# one group (= figure column) per non-empty distance bin
_dist_bin_labels = pd.cut(_dfff["dist"], dist_bins)
ggg = _dfff.groupby(_dist_bin_labels, observed=True)
nquants = len(ggg)

# snippet ticks depend only on the flank and the cis binsize
ticks_pixels = np.linspace(0, _flank*2//binsize10, 5)
ticks_kbp = ((ticks_pixels - ticks_pixels[-1]/2) * binsize10 // 1000).astype(int)

for _sample_group in _select_sample_groups:
    _nsamples = len(_sample_group)
    f, axs = plt.subplots(
        nrows=_nsamples,
        ncols=nquants + 1,
        figsize=(3*nquants, 3*_nsamples),
        width_ratios=[1]*nquants + [0.05],
        sharex=True,
        sharey=True,
    )

    # the narrow last column only hosts the colorbar: drop its per-row axes, add one cax
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    for i, k in enumerate(_sample_group):
        _axs = axs[i]
        _stacks = fullstacks_cis[k]
        print(k)
        for j, (_q, _mtx) in enumerate(ggg.groups.items()):
            ax = _axs[j]
            # average the obs/exp snippets of all pairs in this distance bin
            _ccc = ax.imshow(
                np.nanmean(_stacks[_mtx], axis=0),
                cmap='RdBu_r',
                norm=LogNorm(vmin=1/2.5, vmax=2.5),
                aspect="auto",
            )
            _ccc.cmap.set_over("#400000")
            if i == 0:
                ax.set_title(_get_name(_q.left, _q.right, len(_mtx)))
            if i == _nsamples - 1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            ax.set_yticks(ticks_pixels, ticks_kbp)
            if j == 0:
                ax.set_ylabel(f"{k}", fontsize=14)

    plt.colorbar(_ccc, label="obs/exp", cax=axcb)

In [None]:
# distribution of H3K27ac for the first anchor of every intra-arm cis pair
# (99 equal-width bins between 0 and 120) - presumably used to choose the H3K27ac cutoffs
_df_intra_arm["H3K27ac1"].hist(bins=np.linspace(0,120, 100) )

In [None]:
# same H3K27ac distribution for the selected anchors themselves
_anchors["H3K27ac"].hist(bins=np.linspace(0,120, 100) )

In [None]:
_flank = 100_000
# pair selections that were tried (queries keep the original index labels, which are row
# positions into the full cis stacks because _df_intra_arm has a fresh RangeIndex):
# _dfff = _df_intra_arm
# _dfff = _df_intra_arm.query("(valency1>1)&(valency2>1)")
# _dfff = _df_intra_arm.query("(size1>20_000)&(size2>20_000)")
# _dfff = _df_intra_arm.query("(size2>70_000)")
#_dfff = _df_intra_arm.query("((H3K27ac1>4)&(H3K27ac2>4))&(valency1>2)&(valency2>2)")
_dfff = _df_intra_arm.query("(H3K27ac1<1)&(H3K27ac2<1)")

print(f"dealing with {len(_dfff)} elements total ...")
# dist_bins = [0, 300_000, 1_000_000, 1_500_000, 3_000_000, 5_000_000, 10_000_000, 1_000_000_000]
# dist_bins = [0, 1_000_000_000]
dist_bins = [0, 500_000, 10_000_000, 1_000_000_000]
ggg = _dfff.groupby(pd.cut( _dfff["dist"], dist_bins ), observed=True)
nquants = len(ggg)

# trans column: stack_means[sample] holds the averaged trans snippets in this order;
# the snippet is selected by name so that the column label always matches the plotted data
_trans_categories = ["e10", "e1", "ed", "eall"]
trans_category = "e1"
_trans_idx = _trans_categories.index(trans_category)

# snippet ticks: cis stacks are 10 kb binned, trans stacks 25 kb binned, both with a 100 kb flank
cis_ticks_pixels = np.linspace(0, _flank*2//binsize10, 5)
cis_ticks_kbp = ((cis_ticks_pixels-cis_ticks_pixels[-1]/2)*binsize10//1000).astype(int)
trans_ticks_pixels = np.linspace(0, _flank*2//binsize25, 5)
trans_ticks_kbp = ((trans_ticks_pixels-trans_ticks_pixels[-1]/2)*binsize25//1000).astype(int)

for _sample_group in _select_sample_groups:
    # columns: one per distance bin, one for trans, one narrow one for the colorbar
    f, axs = plt.subplots(
        nrows=len(_sample_group),
        ncols=nquants+1+1,
        figsize=(3*(nquants+1), 3*len(_sample_group)),
        width_ratios=[1]*(nquants+1)+[0.05],
        sharex=False,
        sharey=False,
    )

    # remove axes for the last column and add a single colorbar axes instead
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    # imshow ... (colormap with the "over" color baked in, instead of mutating each image's cmap)
    imshow_kwargs = dict(
        cmap=mpl.colormaps["RdBu_r"].with_extremes(over="#300000"),
        norm=LogNorm(vmin=1/2.5,vmax=2.5),
        aspect=1,
    )

    for i, (_axs, k) in enumerate(zip(axs,_sample_group)):
        _is_last_row = i == len(_sample_group)-1
        # going over samples ...
        _stacks = fullstacks_cis[k]
        print(k)
        for j, (ax, (_q, _mtx)) in enumerate(zip(_axs, ggg.groups.items())):
            # going over groupings (by dist, or whatever ...)
            _ccc = ax.imshow( np.nanmean(_stacks[_mtx], axis=0), **imshow_kwargs )
            if i == 0:
                # top row
                ax.set_title(_get_name(_q.left, _q.right, len(_mtx)))
            if _is_last_row:
                ax.set_xticks(cis_ticks_pixels, cis_ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            else:
                ax.set_xticks(cis_ticks_pixels,[])
            if j<1:
                ax.set_ylabel(f"{k}", fontsize=14)
                ax.set_yticks(cis_ticks_pixels, cis_ticks_kbp)
            else:
                ax.set_yticks(cis_ticks_pixels,[])
        # trans column right after the cis columns (index nquants); the trans pileups use the
        # same 100 kb flank as the cis ones, so no rescaling is needed
        ax = _axs[nquants]
        _ccc = ax.imshow( stack_means[k][_trans_idx], **imshow_kwargs )
        if i == 0:
            # top row
            ax.set_title(f"trans: {trans_category}")
        if _is_last_row:
            ax.set_xticks(trans_ticks_pixels, trans_ticks_kbp)
            ax.set_xlabel('relative position, kbp')
        else:
            ax.set_xticks(trans_ticks_pixels, [])
        ax.set_yticks(trans_ticks_pixels,[])
    plt.colorbar(_ccc, label="obs/exp", cax=axcb)