## This is just a pileup-plotting notebook that relies on pre-calculated results stored in an HDF5 file ...

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
# %config InlineBackend.print_figure_kwargs={'bbox_inches':None}
import pandas as pd
import numpy as np
# Hi-C utilities imports:
import cooler
import bioframe
import cooltools
# Visualization imports:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, Normalize
from matplotlib import colors
import matplotlib.patches as patches
from matplotlib.ticker import EngFormatter

In [None]:
# bbi for stackups ...
import bbi
# functions and assets specific to this repo/project ...
from data_catalog import bws, bws_vlim, telo_dict
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
# import mpire for nested multi-processing

import matplotlib.lines as lines
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch, Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.inset_locator import BboxConnector
from matplotlib import cm
# from mpl_toolkits.axes_grid1.Size import

# Keep text editable in exported figures (TrueType/type-42 fonts in PDF, real <text>
# elements in SVG) and draw thin axes spines throughout ...
mpl.rcParams.update({
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
    "axes.linewidth": 0.5,
})

In [None]:
def to_mb(bp_val, max_decimals=2):
    """
    Format a base-pair value in megabases for plot labels.

    Whole-Mb values are printed without a decimal part ("1", "10", "1000"),
    other values are rounded to at most `max_decimals` decimals with trailing
    zeros dropped ("0.25", "1.5").

    Parameters
    ----------
    bp_val : int or float
        Genomic coordinate/distance in bp.
    max_decimals : int, optional
        Maximum number of decimals kept for values that are not whole Mb.

    Returns
    -------
    str
    """
    mb_val = round(float(bp_val) / 1_000_000, max_decimals)
    if mb_val.is_integer():
        # also covers float inputs such as 1_000_000.0 -> "1" (not "1.0")
        return f"{int(mb_val)}"
    # e.g. 250_000 -> "0.25" (a 1-decimal format would have printed "0.2")
    return f"{mb_val:.{max_decimals}f}".rstrip("0")

# given the range - generate pretty axis name
def _get_name(_left, _right, _amount, open_right_bp=80_000_000):
    """
    Build a plot title for the distance bin (_left, _right] holding `_amount` pairs.

    Parameters
    ----------
    _left, _right : int or float
        Bin edges in bp.
    _amount : int
        Number of pairs in the bin (appended after a colon).
    open_right_bp : int or float, optional
        Bins whose right edge exceeds this value are treated as open-ended
        and shown as ">left Mb".

    Returns
    -------
    str
        "<right Mb: n" for bins starting at 0, ">left Mb: n" for open-ended
        bins, "left-right Mb: n" otherwise.
    """
    if np.isclose(_left, 0.0):
        return f"<{to_mb(_right)} Mb: {_amount}"
    if _right > open_right_bp:
        return f">{to_mb(_left)} Mb: {_amount}"
    return f"{to_mb(_left)}-{to_mb(_right)} Mb: {_amount}"

### Chrom arms as a view

In [None]:
# Use bioframe to fetch the genomic features from the UCSC.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# # remove "bad" chromosomes and near-empty arms ...
# excluded_arms = ["chr13_p", "chr14_p", "chr15_p", "chr21_p", "chr22_p", "chrM_p", "chrY_p", "chrY_q", "chrX_p", "chrX_q"]
# hg38_arms = hg38_arms_full[~hg38_arms_full["name"].isin(excluded_arms)].reset_index(drop=True)

# Keep the arms of all autosomes (chr1 ... chr22). Selecting them by chromosome name
# (instead of taking the first 44 rows) does not depend on the row order of
# `hg38_arms_full`. To restrict to one chromosome (or a few arms) instead, use e.g.
# included_arms = ["chr1_q", "chr2_p", "chr4_q", "chr6_q"]
_is_autosome = hg38_arms_full["chrom"].astype(str).str.fullmatch(r"chr\d+")
included_arms = hg38_arms_full.loc[_is_autosome, "name"].to_list()
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

# There is a problem with our arms view of the chromosomes ...

The way we do it now, the end of each p-arm is always equal to the start of the corresponding q-arm ...

After binning, this can lead to a situation where the last bin of the p-arm is not upstream of the first q-arm bin (both arms end up sharing the bin that contains the arm boundary) ...

This makes `cooltools.api.is_valid_expected` crash ...

Let's try solving that by shifting the start of every q-arm downstream by one bin (`binsize` bp) ...

In [None]:
def adjust_arm_view(
    view_df,
    binsize,
):
    """
    Adjust an arm-based view of the genome to fix slightly overlapping p and q arms.

    The start of every q-arm is shifted downstream by one bin, so that after
    binning the last p-arm bin and the first q-arm bin no longer coincide.

    Parameters
    ----------
    view_df : pandas.DataFrame
        Genomic view with (at least) "start" and "name" columns; rows whose
        name contains "q" are treated as q-arms.
    binsize : int
        Bin size in bp, added to the start of every q-arm.

    Returns
    -------
    pandas.DataFrame
        A new view with the same columns as `view_df` (the input is not
        modified) and a fresh RangeIndex.
    """
    adjusted = view_df.copy()
    # same q-arm test as before: "q" anywhere in the arm name (e.g. "chr1_q")
    is_q_arm = adjusted["name"].astype(str).str.contains("q", regex=False)
    adjusted.loc[is_q_arm, "start"] = adjusted.loc[is_q_arm, "start"] + binsize
    return adjusted.reset_index(drop=True)


### Read pre-called native compartments
## ... and pick one list of anchors and annotate it with epigenetic marks ...

In [None]:
# tab-separated anchor tables (read with a header row), keyed by anchor-set name ...
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    id_anchors_dict[id_name] = pd.read_csv(fname, sep="\t")
    print(f"loaded {len(id_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")


# The anchor set used for everything below. Work on a copy: the following cells add
# annotation columns to `_anchors`, which would otherwise also mutate the cached
# frame in `id_anchors_dict` in place.
_anchors = id_anchors_dict["5hr_2X_enrichment_signal"].copy()


## Annotate da heck out of those anchors for pileup subgrouping ...

In [None]:
# keys of the `bws` track catalog used to annotate the anchors ...
bw_kyes_to_use = [
    'mG.atac',
    'H3K27me3',
    'H3K4me3',
    'H3K27ac',
    'ctcf',
    'dots',
]

# NOTE: this adds a "dots" entry to the `bws` dict imported from data_catalog
bws["dots"] = "mega_dots_anchors.bb"

for k, bw in bws.items():
    if k not in bw_kyes_to_use:
        continue
    print(f"working on {k} ...")
    # one summary value per anchor (bins=1) over the whole anchor interval ...
    _anchors[f"{k}"] = bbi.stackup(
        bw,
        _anchors["chrom"],
        _anchors["start"],
        _anchors["end"],
        bins=1,
    ).flatten()
    # ... and over the anchor footprint (peak_start..peak_end). This is computed
    # per track inside the loop, so e.g. "dots_footprint" (used below) no longer
    # depends on "dots" happening to be the last key of `bws`.
    _anchors[f"{k}_footprint"] = bbi.stackup(
        bw,
        _anchors["chrom"],
        _anchors["peak_start"],
        _anchors["peak_end"],
        bins=1,
    ).flatten()


In [None]:
# anchor size distribution, stacked by whether the anchor footprint overlaps a dot anchor
fig, ax = plt.subplots()
_sizes_no_dot = _anchors.query("dots_footprint == 0")["size"]
_sizes_with_dot = _anchors.query("dots_footprint > 0")["size"]
ax.hist(
    [_sizes_no_dot, _sizes_with_dot],
    bins=np.linspace(20_000, 250_000, 50),
    stacked=True,
    label=["dots_footprint == 0", "dots_footprint > 0"],
)
ax.legend()
ax.set_xlabel("ID anchor footprint")
ax.set_title("ID set: 5hr_2X_enrichment_signal")

## Pre-define coolers that drive all of that - just in case ...

In [None]:
# cooler files that we'll work on, opened at two resolutions :
def _open_telo_coolers(binsize):
    """Open every multi-resolution cooler listed in `telo_dict` at resolution `binsize`."""
    return {
        name: cooler.Cooler(f"{path}::/resolutions/{binsize}")
        for name, path in telo_dict.items()
    }


binsize10 = 10_000
telo_clrs10 = _open_telo_coolers(binsize10)

binsize25 = 25_000
telo_clrs25 = _open_telo_coolers(binsize25)

## Now let's load the HDF5 file with all of the pileups and the anchor indices for the all-by-all dataframes ...

In [None]:
import h5py

In [None]:

# Callback for h5py's `visititems`: prints the file tree and each object's attributes.
def print_attrs(name, obj):
    """
    Print the last path component of `name`, indented by its depth, then its attributes.

    Parameters
    ----------
    name : str
        Slash-separated path of the visited object within the file.
    obj : object
        The visited object; objects without an `attrs` mapping are printed by name only.

    Returns
    -------
    None
        Always None (a non-None return value would stop `visititems` early).
    """
    shift = name.count('/') * '    '
    item_name = name.split("/")[-1]
    print(shift + item_name)
    attrs = getattr(obj, "attrs", None)
    if attrs is None:
        return
    try:
        for key, val in attrs.items():
            print(shift + '    ' + f"{key}: {val}")
    except Exception as exc:
        # e.g. an attribute value that cannot be read: report it instead of hiding it,
        # but keep walking the rest of the file
        print(shift + '    ' + f"<could not read attributes: {exc!r}>")


# pre-computed pileups + anchor indices for the all-by-all pair tables ...
pileups_h5_path = "/data/sergpolly/tmp/Pileups_ID_by_distance.hdf5"


def _pair_table(anchors, idx1, idx2):
    """
    Build an all-by-all pair table from two arrays of positional anchor indices.

    Row r pairs `anchors.iloc[idx1[r]]` (columns suffixed "1") with
    `anchors.iloc[idx2[r]]` (columns suffixed "2"); the result has a fresh RangeIndex,
    so its row labels match the row order of the corresponding pileup stacks.
    """
    return pd.concat(
        [
            anchors.iloc[idx1].add_suffix("1").reset_index(drop=True),
            anchors.iloc[idx2].add_suffix("2").reset_index(drop=True),
        ],
        axis=1,
    ).reset_index(drop=True)


with h5py.File(pileups_h5_path, 'r') as fr:
    fr.visititems(print_attrs)

    # check general metadata ...
    _pileup_meta = dict(fr.attrs)
    for k, v in _pileup_meta.items():
        print(f"{k}: {v}")

    print("...")
    print("restoring cis all-by-all table ...")
    # anchor indices are used positionally in `_anchors` (assuming index and cluster are the same) ...
    cis_left = fr.get("cis/indices").get("anchor1")[()]
    cis_right = fr.get("cis/indices").get("anchor2")[()]
    _df_intra_arm = _pair_table(_anchors, cis_left, cis_right)
    # center-to-center distance between the two anchors of each pair
    _df_intra_arm["dist"] = _df_intra_arm.eval(".5*(start2+end2) - .5*(start1+end1)")

    print("restoring trans all-by-all table ...")
    trans_left = fr.get("trans/indices").get("anchor1")[()]
    trans_right = fr.get("trans/indices").get("anchor2")[()]
    tr_feat = _pair_table(_anchors, trans_left, trans_right)

    print("extracting cis pileups as is...")
    # sort out the results per sample ...
    fullstacks_cis = {}
    cis_pileups_grp = fr.get("cis/pileups")
    for _sample in cis_pileups_grp.keys():
        fullstacks_cis[_sample] = cis_pileups_grp.get(_sample)[()]

    print("extracting trans pileups and calculating means ...")
    # pair groups (labels == positions, since tr_feat has a RangeIndex):
    # neither anchor footprint overlaps a dot / both anchor footprints overlap a dot
    _dotless_idx = tr_feat.query("(dots_footprint1==0)&(dots_footprint2==0)").index
    _dotted_idx = tr_feat.query("(dots_footprint1>0)&(dots_footprint2>0)").index
    print(f"trans pairs: {len(tr_feat)} total, {len(_dotless_idx)} dotless, {len(_dotted_idx)} dotted")

    # now average those sub-pileups, stored in the order [dotless, dotted, all] :
    stack_means = {}
    trans_pileups_grp = fr.get("trans/pileups")
    for _sample in trans_pileups_grp.keys():
        print(f"    processing trans pileup {_sample} ...")
        _stack = trans_pileups_grp.get(_sample)[()]
        stack_means[_sample] = [
            np.nanmean(_stack[_dotless_idx], axis=0),
            np.nanmean(_stack[_dotted_idx], axis=0),
            np.nanmean(_stack, axis=0),
        ]


In [None]:
# Just reporting the number of MCD interactions genome wide ...
# cis pairs form the intra-arm table; the second table comes from the HDF5 "trans" group
print(f"number of intra-arm interactions {len(_df_intra_arm)}")
print(f"number of trans interactions {len(tr_feat)}")

# plotting pups ...

In [None]:
# Groups of samples to compare side by side: in the legacy plotting cells each inner
# list becomes one figure with one row per sample (names are keys of the pileup dicts).
_select_sample_groups = [
    # "m"-prefixed samples
    [
        "mMito",
        "mTelo",
        "mCyto",
        "m5hR1R2",
        "m10hR1R2"
    ],
    # "p"-prefixed samples
    [
        "pMito",
        "pTelo",
        "pCyto",
        "p5hR1R2",
        "p10hR1R2",
    ],
    # N93 "m" samples
    [
        "N93m5",
        "N93m10",
    ],
    # N93 "p" samples
    [
        "N93p5",
        "N93p10",
    ],
    # m10hR1R2 vs p10hR1R2 vs mp10hR1R2
    [
        "m10hR1R2",
        "p10hR1R2",
        "mp10hR1R2",
    ],
    # N93m10 vs N93p10 vs N93mp10
    [
        "N93m10",
        "N93p10",
        "N93mp10",
    ],
]


In [None]:
margin = 0.2
matw = 0.5
cbarh = 0.08

# shared image settings: log-scaled color range from 1/2.5 to 2.5 ...
imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.5, vmax=2.5),
        cmap="RdBu_r",
        interpolation="nearest",
)

_flank = 100_000
_dfff = _df_intra_arm
# # _dfff = _df_intra_arm.query("(dots_footprint1==0)&(dots_footprint2==0)")
# dist_bins = [0, 500_000, 10_000_000, 1_000_000_000]
# dist_bins = [0, 250_000, 1_000_000, 10_000_000, 1_000_000_000]
dist_bins = [0, 1_000_000, 10_000_000, 1_000_000_000]
# cis pairs grouped by distance; `observed=True` keeps only the non-empty bins
_dist_groups = _dfff.groupby(pd.cut( _dfff["dist"], dist_bins ), observed=True)
ndist = len(_dist_groups)


# Fixed layout, sizes in inches. Columns (left to right): margin, a [matrix, gap] pair per
# distance group, then the trans matrix and a margin -> matrix columns sit at the odd
# indices nx = 1, 3, ..., 2*ndist+1 (the last one is trans).
h = [Size.Fixed(margin)] + ndist*[Size.Fixed(matw), Size.Fixed(0.2*margin)] + [Size.Fixed(matw), Size.Fixed(margin)]
# Rows go from bottom to top: margin, colorbar (ny=1), gap, sample_p row (ny=3), gap, sample_m row (ny=5).
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + [Size.Fixed(0.6*margin), Size.Fixed(matw)] + [Size.Fixed(0.2*margin), Size.Fixed(matw)]


def _align_edge_ticklabels(ticklabels, horizontal):
    """
    Keep edge tick labels inside the matrix extent: label 0 is aligned to the start
    edge, label 2 to the end edge, and any other label is centered.
    """
    first, last = ("left", "right") if horizontal else ("top", "bottom")
    for _tidx, tick in enumerate(ticklabels):
        _align = first if _tidx == 0 else last if _tidx == 2 else "center"
        if horizontal:
            tick.set_horizontalalignment(_align)
        else:
            tick.set_verticalalignment(_align)


def _set_edge_xticks(ax, mat_size, flank):
    """Label the left edge / center / right edge of a mat_size-wide matrix as -flank/0/+flank kb."""
    ax.set_xticks([0-0.5, mat_size/2-0.5, mat_size-0.5])
    ax.set_xticklabels([-flank//1000, 0, flank//1000], fontsize=6)
    ax.tick_params(length=1.5, pad=1)
    _align_edge_ticklabels(ax.xaxis.get_majorticklabels(), horizontal=True)


def _set_edge_yticks_right(ax, mat_size, flank):
    """Put +flank/0/-flank kb ticks (top edge / center / bottom edge) on the right y-axis."""
    ax.yaxis.tick_right()
    ax.set_yticks(
        [0-0.5, mat_size/2-0.5, mat_size-0.5],
        labels=[flank//1000, 0, -flank//1000],
        rotation=90,
        fontsize=6,
    )
    ax.tick_params(length=1.5, pad=1)
    _align_edge_ticklabels(ax.yaxis.get_majorticklabels(), horizontal=False)


for _fig_fname, (sample_m, sample_p) in {
    "Fig2E.svg": ("m5hR1R2", "p5hR1R2"),
    "FigE2C.svg": ("N93m5", "N93p5"),
}.items():
    # set figsize based on the tiling provided ...
    fig_width = sum(_h.fixed_size for _h in h)
    fig_height = sum(_v.fixed_size for _v in v)
    fig = plt.figure(
        figsize=(fig_width, fig_height),
        # facecolor='lightblue'
    )
    print(f"figure overall is {fig_width=} {fig_height=}")

    divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

    _cis_stack_m = fullstacks_cis[sample_m]
    _cis_stack_p = fullstacks_cis[sample_p]
    # last entry of stack_means[...] is the average over all trans pairs
    _trans_stack_m = stack_means[sample_m][-1]
    _trans_stack_p = stack_means[sample_p][-1]

    # one matrix column per distance group, then the trans column (column index ndist);
    # indexing by position avoids relying on a loop variable leaking out of the loop
    axs_m = {}
    axs_p = {}
    for i, _dist_key in enumerate([*_dist_groups.groups, "trans"]):
        axs_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3))
        axs_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5))

    # colorbar sits below the trans column
    cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*ndist+1, ny=1))
    cbar_ax.set_xticks([])
    cbar_ax.set_yticks([])

    # cis pileups: average over all pairs of each distance group ...
    for jj, (_dist_key, _dist_idx) in enumerate(_dist_groups.groups.items()):
        axm, axp = axs_m[_dist_key], axs_p[_dist_key]
        _hm = axm.imshow(np.nanmean(_cis_stack_m[_dist_idx], axis=0), **imshow_kwargs)
        _hm.cmap.set_over("#300000")
        _mat = np.nanmean(_cis_stack_p[_dist_idx], axis=0)
        _hm = axp.imshow(_mat, **imshow_kwargs)
        _hm.cmap.set_over("#300000")
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        axm.set_title(f"{_dist_key}", fontsize=6)
        _mat_size = _mat.shape[0]
        # only the bottom (sample_p) row gets x-axis coordinates
        _set_edge_xticks(axp, _mat_size, _flank)
        if jj == 0:
            axm.set_ylabel(sample_m, fontsize=6)
            axp.set_ylabel(sample_p, fontsize=6)

    # trans pileups are treated separately ...
    axm, axp = axs_m["trans"], axs_p["trans"]
    _hm = axm.imshow(_trans_stack_m, **imshow_kwargs)
    _hm.cmap.set_over("#300000")
    _hm = axp.imshow(_trans_stack_p, **imshow_kwargs)
    _hm.cmap.set_over("#300000")
    for _ax in [axp, axm]:
        _ax.set_xticks([])
        _ax.set_yticks([])
    axm.set_title("trans", fontsize=6)
    _mat_size = _trans_stack_m.shape[0]
    _set_edge_xticks(axp, _mat_size, _flank)
    # the rightmost (trans) column also carries y-axis coordinates, on its right side
    _set_edge_yticks_right(axp, _mat_size, _flank)
    _set_edge_yticks_right(axm, _mat_size, _flank)

    # add a single colorbar ...
    fig.colorbar(
        cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
        cax=cbar_ax,
        orientation="horizontal",
    )
    _cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
    cbar_ax.set_xticks(_cbar_ticks)
    cbar_ax.set_xticklabels(_cbar_ticks, fontsize=6)
    cbar_ax.tick_params(length=1.5, pad=1)
    cbar_ax.minorticks_off()
    _align_edge_ticklabels(cbar_ax.xaxis.get_majorticklabels(), horizontal=True)

    fig.savefig(_fig_fname, dpi=300)

# Legacy non-publication ready stuff ...

In [None]:
_flank = 100_000
num_trans_groups = 3
# trans sub-groups, in the order their averages are stored in stack_means[sample]
ggg = ["dotless", "dotted", "all"]

# tick positions (pixel centers) and matching offsets in kbp for 25 kb trans pileups;
# they only depend on _flank and binsize25, so compute them once up front
ticks_pixels = np.linspace(0, _flank*2//binsize25, 5)
ticks_kbp = ((ticks_pixels - ticks_pixels[-1]/2)*binsize25//1000).astype(int)

for _sample_group in _select_sample_groups:
    n_rows = len(_sample_group)
    n_groups = len(ggg)

    f, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_groups + 1,
        figsize=(3*n_groups, 3*n_rows),
        width_ratios=[1]*n_groups + [0.05],
        sharex=True,
        sharey=True,
    )

    # the narrow last column is replaced by one colorbar axis spanning rows 1-2
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    for i, (_axs, k) in enumerate(zip(axs, _sample_group)):
        # one row per sample ...
        _stacks = stack_means[k]
        print(k)
        for j, (ax, _q) in enumerate(zip(_axs, ggg)):
            # ... one column per trans sub-group
            _ccc = ax.imshow(
                _stacks[j],
                cmap='RdBu_r',
                norm=LogNorm(vmin=1/2.5, vmax=2.5),
                aspect="auto",
            )
            _ccc.cmap.set_over("#400000")
            if i == 0:
                _axname = _q
                ax.set_title(_axname)
            if i == n_rows - 1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            ax.set_yticks(ticks_pixels, ticks_kbp)
            if j == 0:
                ax.set_ylabel(f"{k}", fontsize=14)

    plt.colorbar(_ccc, label="obs/exp", cax=axcb)

In [None]:
_flank = 100_000
_dfff = _df_intra_arm
# alternative pair selections tried before:
# _dfff = _df_intra_arm.query("(valency1>1)&(valency2>1)")
# _dfff = _df_intra_arm.query("(ctcf1<1.9)&(ctcf2<1.9)")
# _dfff = _df_intra_arm.query("(dots1>0)&(dots2>0)")
# _dfff = _df_intra_arm.query("(dots_footprint1==0)&(dots_footprint2==0)")
#_dfff = _df_intra_arm.query("((H3K27ac1>4)&(H3K27ac2>4))&(valency1>2)&(valency2>2)")

print(f"dealing with {len(_dfff)} elements total ...")
# dist_bins = [0, 300_000, 1_000_000, 10_000_000, 1_000_000_000]
dist_bins = [0, 250_000, 500_000, 1_000_000, 2_500_000, 5_000_000, 10_000_000, 1_000_000_000]
# cis pairs grouped by distance; `observed=True` keeps only the non-empty bins
ggg = _dfff.groupby(pd.cut(_dfff["dist"], dist_bins), observed=True)
nquants = len(ggg)

# tick positions (pixel centers) and matching offsets in kbp for 10 kb cis pileups
ticks_pixels = np.linspace(0, _flank*2//binsize10, 5)
ticks_kbp = ((ticks_pixels - ticks_pixels[-1]/2)*binsize10//1000).astype(int)

for _sample_group in _select_sample_groups:
    n_rows = len(_sample_group)

    f, axs = plt.subplots(
        nrows=n_rows,
        ncols=nquants + 1,
        figsize=(3*nquants, 3*n_rows),
        width_ratios=[1]*nquants + [0.05],
        sharex=True,
        sharey=True,
    )

    # the narrow last column is replaced by one colorbar axis spanning rows 1-2
    gs = axs[0, -1].get_gridspec()
    for ax in axs[:, -1]:
        ax.remove()
    axcb = f.add_subplot(gs[1:3, -1])

    for i, (_axs, k) in enumerate(zip(axs, _sample_group)):
        # one row per sample ...
        _stacks = fullstacks_cis[k]
        print(k)
        for j, (ax, (_q, _mtx)) in enumerate(zip(_axs, ggg.groups.items())):
            # ... one column per distance group: mean pileup over its pairs
            _avg_pileup = np.nanmean(_stacks[_mtx], axis=0)
            _ccc = ax.imshow(
                _avg_pileup,
                cmap='RdBu_r',
                norm=LogNorm(vmin=1/2.5, vmax=2.5),
                aspect="auto",
            )
            _ccc.cmap.set_over("#400000")
            if i == 0:
                _axname = _get_name(_q.left, _q.right, len(_mtx))
                ax.set_title(_axname)
            if i == n_rows - 1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            ax.set_yticks(ticks_pixels, ticks_kbp)
            if j == 0:
                ax.set_ylabel(f"{k}", fontsize=14)

    plt.colorbar(_ccc, label="obs/exp", cax=axcb)

In [None]:
_flank = 100_000
_dfff = _df_intra_arm
# _dfff = _df_intra_arm.query("(valency1>1)&(valency2>1)")
# _dfff = _df_intra_arm.query("(size1>20_000)&(size2>20_000)")
# _dfff = _df_intra_arm.query("(size2>70_000)")

#_dfff = _df_intra_arm.query("((H3K27ac1>4)&(H3K27ac2>4))&(valency1>2)&(valency2>2)")
# _dfff = _df_intra_arm.query("(H3K27ac1<1)&(H3K27ac2<1)")

print(f"dealing with {len(_dfff)} elements total ...")
# dist_bins = [0, 300_000, 1_000_000, 1_500_000, 3_000_000, 5_000_000, 10_000_000, 1_000_000_000]
dist_bins = [0, 500_000, 10_000_000, 1_000_000_000]
# dist_bins = [0, 1_000_000_000]
# cis pairs grouped by distance; `observed=True` keeps only the non-empty bins
ggg = _dfff.groupby(pd.cut( _dfff["dist"], dist_bins ), observed=True)

# Trans sub-pileup shown in the last column. stack_means[sample] stores the
# [dotless, dotted, all] averages in this order, so look the position up by name
# instead of hard-coding it (a fixed [1] showed "dotted" while this said "all").
trans_category = "all"
_trans_idx = ["dotless", "dotted", "all"].index(trans_category)
nquants = len(ggg)

for _sample_group in _select_sample_groups:

    f, axs = plt.subplots(
        nrows=len(_sample_group),
        ncols=nquants+1+1,  # distance groups + trans + colorbar
        figsize=(3*(nquants+1), 3*len(_sample_group)),
        width_ratios=[1]*(nquants+1)+[0.05],
        sharex=False,
        sharey=False,
    )

    gs = axs[0, -1].get_gridspec()
    # remove axes for the last column, replaced by one colorbar axis spanning rows 1-2 ...
    for ax in axs[:, -1]:
        ax.remove()

    axcb = f.add_subplot(gs[1:3, -1])

    # imshow ...
    imshow_kwargs = dict(
        cmap='RdBu_r',
        norm=LogNorm(vmin=1/2.5,vmax=2.5),
        aspect=1,
    )

    for i, (_axs, k) in enumerate(zip(axs,_sample_group)):
        # going over samples ...
        _stacks = fullstacks_cis[k]
        print(k)
        for j, (ax, (_q, _mtx)) in enumerate(zip(_axs, ggg.groups.items())):
            # going over groupings (by dist, or whatever ...)
            _ccc = ax.imshow( np.nanmean(_stacks[_mtx], axis=0), **imshow_kwargs )
            _ccc.cmap.set_over("#300000")
            ticks_pixels = np.linspace(0, _flank*2//binsize10, 5)
            ticks_kbp = ((ticks_pixels-ticks_pixels[-1]/2)*binsize10//1000).astype(int)
            if i == 0:
                # top row
                _axname = _get_name(_q.left, _q.right, len(_mtx))
                ax.set_title(_axname)
            if i == len(_sample_group)-1:
                ax.set_xticks(ticks_pixels, ticks_kbp)
                ax.set_xlabel('relative position, kbp')
            else:
                ax.set_xticks(ticks_pixels,[])
            if j<1:
                ax.set_ylabel(f"{k}", fontsize=14)
                ax.set_yticks(ticks_pixels, ticks_kbp)
            else:
                ax.set_yticks(ticks_pixels,[])
        # trans pileup goes into the column right after the distance groups
        # (indexed by nquants, so this does not depend on the leaked loop variable `j`) ...
        ax = _axs[nquants]
        _ccc = ax.imshow( stack_means[k][_trans_idx], **imshow_kwargs )
        _ccc.cmap.set_over("#300000")
        # NOTE(review): ticks assume the trans pileups are binned at binsize25 with the
        # same +/-_flank window (an older comment mentioned a 250 kb flank) - confirm
        ticks_pixels = np.linspace(0, _flank*2//binsize25, 5)
        ticks_kbp = ((ticks_pixels-ticks_pixels[-1]/2)*binsize25//1000).astype(int)
        if i == 0:
            # top row
            ax.set_title(f"trans: {trans_category}")
        if i == len(_sample_group)-1:
            ax.set_xticks(ticks_pixels, ticks_kbp)
            ax.set_xlabel('relative position, kbp')
        else:
            ax.set_xticks(ticks_pixels, [])
        if nquants == 0:
            # trans is the first (only) matrix column, so it carries the row label
            ax.set_ylabel(f"{k}", fontsize=14)
            ax.set_yticks(ticks_pixels, ticks_kbp)
        else:
            ax.set_yticks(ticks_pixels,[])
    plt.colorbar(_ccc, label="obs/exp", cax=axcb)