## This is just a pileup plotting notebook that relies on pre-calculated results stored in an HDF5 file ...

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
# %config InlineBackend.print_figure_kwargs={'bbox_inches':None}
import pandas as pd
import numpy as np
# Hi-C utilities imports:
import cooler
import bioframe
import cooltools
# Visualization imports:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, Normalize
from matplotlib import colors
import matplotlib.patches as patches
from matplotlib.ticker import EngFormatter

In [None]:
# bbi for stackups ...
import bbi
# functions and assets specific to this repo/project ...
from data_catalog import bws, bws_vlim, telo_dict
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
# import mpire for nested multi-processing

import matplotlib.lines as lines
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch, Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.inset_locator import BboxConnector
from matplotlib import cm
# from mpl_toolkits.axes_grid1.Size import Fixed


# keep text editable in exported vector figures (TrueType fonts in PDF, real text
# instead of paths in SVG) and draw thin axes spines ...
mpl.rcParams.update(
    {
        "pdf.fonttype": 42,
        "svg.fonttype": "none",
        "axes.linewidth": 0.5,
    }
)

In [None]:
def to_mb(bp_val):
    """
    Format a base-pair value as megabases for axis/legend labels.

    Whole-megabase values are rendered without decimals (e.g. "2"),
    anything else with a single decimal (e.g. "2.5").
    """
    if np.mod(bp_val, 1_000_000):
        # just give 1 decimal if not even Mb
        return f"{bp_val/1_000_000:.1f}"
    # int() so that float inputs (e.g. 2e6) render as "2" rather than "2.0"
    return f"{int(bp_val // 1_000_000)}"


# given the range - generate pretty axis name
def _get_name(_left, _right, _amount, open_right_above=80_000_000):
    """
    Build a label like "1-1.5 Mb: 42" for a distance range [_left, _right].

    A range starting at 0 is shown as "<right Mb", and a range whose right edge
    exceeds ``open_right_above`` (bp) is shown as ">left Mb". ``_amount`` is
    appended verbatim after the colon.
    """
    if np.isclose(_left, 0.0):
        return f"<{to_mb(_right)} Mb: {_amount}"
    if _right > open_right_above:
        return f">{to_mb(_left)} Mb: {_amount}"
    return f"{to_mb(_left)}-{to_mb(_right)} Mb: {_amount}"

### Chrom arms as a view

In [None]:
# Use bioframe to fetch the genomic features from the UCSC.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# # remove "bad" chromosomes and near-empty arms ...
# excluded_arms = ["chr13_p", "chr14_p", "chr15_p", "chr21_p", "chr22_p", "chrM_p", "chrY_p", "chrY_q", "chrX_p", "chrX_q"]
# hg38_arms = hg38_arms_full[~hg38_arms_full["name"].isin(excluded_arms)].reset_index(drop=True)

# keep all autosomal arms (chr1..chr22). They are picked by chromosome name rather
# than by their position in the table, so the selection does not depend on row order.
# (a handful of arms works too, e.g. included_arms = ["chr1_q", "chr2_p", "chr4_q", "chr6_q"])
_autosomes = [f"chr{_i}" for _i in range(1, 23)]
included_arms = hg38_arms_full.loc[hg38_arms_full["chrom"].isin(_autosomes), "name"].to_list()
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

# There is a problem with our arms view of the chromosomes ...

The way we build it now, the end of the p-arm is always equal to the start of the q-arm ...

After binning, this can lead to a situation where the last bin of the p-arm overlaps (is not strictly upstream of) the first bin of the q-arm ...

This makes `cooltools.api.is_valid_expected` crash ...

Let's try solving that by shifting the start of every q-arm downstream by one bin (`binsize` bp) ...

In [None]:
def adjust_arm_view(
    view_df,
    binsize,
):
    """
    Adjust an arm-based view of the genome to fix slightly overlapping p and q arms.

    The start of every region whose name contains "q" is shifted downstream by
    ``binsize``. After binning, the first q-arm bin then no longer coincides
    with the last p-arm bin.

    Parameters
    ----------
    view_df : pd.DataFrame
        View whose first four columns are chrom, start, end, name (in that order).
    binsize : int
        Shift, in bp, applied to q-arm starts. Typically this is the resolution
        of the cooler the view is used with.

    Returns
    -------
    pd.DataFrame
        A new view with a fresh RangeIndex and the same columns as ``view_df``.
        ``view_df`` itself is left untouched.
    """
    adjusted = view_df.reset_index(drop=True).copy()
    # columns are addressed by position (start = 2nd, name = 4th), so views with
    # differently named columns keep working
    start_col, name_col = adjusted.columns[1], adjusted.columns[3]
    is_q_arm = adjusted[name_col].astype(str).str.contains("q", regex=False)
    adjusted.loc[is_q_arm, start_col] += binsize
    return adjusted


### Read pre-called native compartments
## ... and pick one list of anchors and annotate it with epigenetic marks ...

In [None]:
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

# load every ID-anchor set (tab-separated, first row used as the header) ...
id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    id_anchors_dict[id_name] = pd.read_csv(fname, sep="\t")
    # ...
    print(f"loaded {len(id_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")


# Pick the anchor sets to work with. The copies matter: `_anchors` gets annotation
# columns added in place further down. Without copies those columns would also
# appear in id_anchors_dict (and `_anchors`/`_anchorsG1` would be one and the same
# object), so re-running cells would see stale annotations ...
_anchors = id_anchors_dict["5hr_2X_enrichment_signal"].copy()
_anchorsG1 = id_anchors_dict["5hr_2X_enrichment_signal"].copy()
_anchorsCyto = id_anchors_dict["pCyto_2X_enrichment_signal"].copy()


## Annotate da heck out of those anchors for pileup subgrouping ...

In [None]:
# # check if anchors overlap other anchors ....
# _g1_index = bioframe.overlap(
#     _anchors,
#     _anchorsG1.eval(
#         """
#         peak_start = peak_start - 15_000
#         peak_end = peak_end + 15_000
#         """
#     ),
#     return_input=False,
#     return_index=True,
#     return_overlap=False,
#     suffixes=('', '_'),
#     keep_order=True,
#     cols1=("chrom","peak_start","peak_end"),
#     cols2=("chrom","peak_start","peak_end"),
# ).dropna()["index"].unique()

# # annotate G1 status of some of the pCyto anchors ...
# _anchors["G1status"] = False
# _anchors.loc[_g1_index, "G1status"] = True


# pad each pCyto anchor by this many bp on both sides before testing for overlap ...
_cyto_pad = 15_000

# indices of `_anchors` rows whose footprint overlaps any (padded) pCyto anchor.
# DataFrame.eval returns a padded copy, so `_anchorsCyto` itself is not modified ...
_cyto_index = bioframe.overlap(
    _anchors,
    _anchorsCyto.eval(
        f"""
        peak_start = peak_start - {_cyto_pad}
        peak_end = peak_end + {_cyto_pad}
        """
    ),
    return_input=False,
    return_index=True,
    return_overlap=False,
    suffixes=('', '_'),
    keep_order=True,
    cols1=("chrom","peak_start","peak_end"),
    cols2=("chrom","peak_start","peak_end"),
).dropna()["index"].unique()  # dropna: keep only rows that found an overlap partner

# annotate the "Cyto status" of the anchors: True if they overlap a pCyto anchor ...
_anchors["Cytostatus"] = _anchors.index.isin(_cyto_index)

In [None]:
# _anchors.query("~Cytostatus")

In [None]:
bw_kyes_to_use = [
    'mG.atac',
    'H3K27me3',
    'H3K4me3',
    'H3K27ac',
    'ctcf',
    'dots',
]

# dot anchors treated as one more track (note: this adds a key to the imported `bws` dict) ...
bws["dots"] = "mega_dots_anchors.bb"

for k, bw in bws.items():
    if k not in bw_kyes_to_use:
        continue
    print(f"working on {k} ...")
    # average track signal over the anchor interval (start-end) ...
    _anchors[f"{k}"] = bbi.stackup(
        bw,
        _anchors["chrom"],
        _anchors["start"],
        _anchors["end"],
        bins=1,
    ).flatten()
    # ... and over the anchor footprint (peak_start-peak_end). This has to stay
    # inside the loop so that every selected track (incl. "dots") gets its own
    # `<track>_footprint` column ...
    _anchors[f"{k}_footprint"] = bbi.stackup(
        bw,
        _anchors["chrom"],
        _anchors["peak_start"],
        _anchors["peak_end"],
        bins=1,
    ).flatten()


In [None]:
# stacked histogram of anchor sizes, split by whether the anchor footprint has any dot-anchor signal ...
_size_groups = {
    "dots_footprint == 0": _anchors.query("dots_footprint == 0")["size"],
    "dots_footprint > 0": _anchors.query("dots_footprint > 0")["size"],
}
_ax = plt.gca()
_ax.hist(
    list(_size_groups.values()),
    bins=np.linspace(20_000, 250_000, 50),
    stacked=True,
    label=list(_size_groups.keys()),
)
_ax.legend()
_ax.set_xlabel("ID anchor footprint")
_ax.set_title("ID set: 5hr_2X_enrichment_signal")

## Pre-define coolers that drive all of that - just in case ...

In [None]:
def _open_telo_coolers(binsize):
    """Open every cooler of `telo_dict` at resolution `binsize` (group /resolutions/<binsize>)."""
    return {
        _key: cooler.Cooler(f"{_path}::/resolutions/{binsize}")
        for _key, _path in telo_dict.items()
    }


# cooler files that we'll work on, at 10 kb and at 25 kb :
binsize10 = 10_000
telo_clrs10 = _open_telo_coolers(binsize10)

binsize25 = 25_000
telo_clrs25 = _open_telo_coolers(binsize25)

## Now let's load the HDF5 file with all of the pileups and the anchor indices for the all-by-all dataframes ...

In [None]:
import h5py

In [None]:
# ! ls /data/sergpolly/tmp/Pileups_ID*

In [None]:

# fr.items()
def print_attrs(name, obj):
    """
    Print an object's name, indented by its depth, followed by its attributes.

    Meant as an h5py ``visititems`` callback: ``name`` is the '/'-separated path
    of the visited object and ``obj`` the object itself. Objects without an
    ``attrs`` mapping only get their name printed. If reading the attributes
    fails, the error is printed instead of silently ignored. Always returns
    None, so the visit continues.
    """
    # Create indent
    shift = name.count('/') * '    '
    item_name = name.split("/")[-1]
    print(shift + item_name)
    attrs = getattr(obj, "attrs", None)
    if attrs is None:
        return None
    try:
        for key, val in attrs.items():
            print(shift + '    ' + f"{key}: {val}")
    except Exception as err:  # best effort: report unreadable attributes, keep visiting
        print(shift + '    ' + f"<could not read attributes: {err!r}>")
    return None


# HDF5 file with the pre-calculated pileups and the anchor indices they were computed for ...
pileups_h5_path = "/data/sergpolly/tmp/Pileups_ID_by_distance.hdf5"
# pileups_h5_path = "/data/sergpolly/tmp/Pileups_ID_by_distance_pCyto.hdf5"


def _pair_anchors(anchors, left_idx, right_idx):
    """
    Rebuild an all-by-all anchor-pair table from positional anchor indices.

    Row k of the result is ``anchors.iloc[left_idx[k]]`` (columns suffixed "1")
    next to ``anchors.iloc[right_idx[k]]`` (columns suffixed "2").
    NOTE(review): only valid if the stored indices were computed against this
    exact ``anchors`` table (same rows, same order) - TODO confirm.
    """
    return pd.concat(
        [
            anchors.iloc[left_idx].add_suffix("1").reset_index(drop=True),
            anchors.iloc[right_idx].add_suffix("2").reset_index(drop=True),
        ],
        axis=1,
    ).reset_index(drop=True)


with h5py.File(pileups_h5_path, 'r') as fr:
    # print the layout of the file together with every object's attributes ...
    fr.visititems(print_attrs)

    # check general metadata ...
    _pileup_meta = dict(fr.attrs)
    for k, v in _pileup_meta.items():
        print(f"{k}: {v}")

    print("...")
    print("restoring cis all-by-all table ...")
    cis_left = fr["cis/indices/anchor1"][()]
    cis_right = fr["cis/indices/anchor2"][()]
    _df_intra_arm = _pair_anchors(_anchors, cis_left, cis_right)
    # distance between anchor centers (second minus first) ...
    _df_intra_arm["dist"] = _df_intra_arm.eval(".5*(start2+end2) - .5*(start1+end1)")

    print("restoring trans all-by-all table ...")
    trans_left = fr["trans/indices/anchor1"][()]
    trans_right = fr["trans/indices/anchor2"][()]
    tr_feat = _pair_anchors(_anchors, trans_left, trans_right)

    print("extracting cis pileups as is...")
    # keep the full per-pair cis stacks, so cis subgroups can be averaged later ...
    cis_pileups_grp = fr["cis/pileups"]
    fullstacks_cis = {
        _sample: cis_pileups_grp[_sample][()] for _sample in cis_pileups_grp.keys()
    }

    print("extracting trans pileups and calculating means ...")
    # create indexes for pileup groups
    # (assumes row k of tr_feat corresponds to pileup k of every trans stack) ...
    _dotless_idx = tr_feat.query("(~Cytostatus1) & (~Cytostatus2)").index
    _dotted_idx = tr_feat.query("Cytostatus1 & Cytostatus2").index
    # _dotless_idx = tr_feat.query("(~G1status1) & (~G1status2)").index
    # _dotted_idx = tr_feat.query("G1status1 & G1status2").index
    # _dotless_idx = tr_feat.query("(dots_footprint1==0)&(dots_footprint2==0)").index
    # _dotted_idx = tr_feat.query("(dots_footprint1>0)&(dots_footprint2>0)").index
    print(
        f"    trans pairs: {len(tr_feat)} total, "
        f"{len(_dotless_idx)} in group 0, {len(_dotted_idx)} in group 1"
    )

    # now average those sub-pileups :
    stack_means = {}
    trans_pileups_grp = fr["trans/pileups"]
    for _sample in trans_pileups_grp.keys():
        print(f"    processing trans pileup {_sample} ...")
        _stack = trans_pileups_grp[_sample][()]
        # [0]: neither anchor flagged, [1]: both anchors flagged, [2]: all pairs
        stack_means[_sample] = [
            np.nanmean(_stack[_dotless_idx], axis=0),
            np.nanmean(_stack[_dotted_idx], axis=0),
            np.nanmean(_stack, axis=0),
        ]


# plotting pups ...

In [None]:
# pileup select samples only ! (groups of sample names to be shown together)
_select_sample_groups = [
    # full timecourses: the m-samples, then the p-samples ...
    *(
        [f"{_arm}{_stage}" for _stage in ("Mito", "Telo", "Cyto", "5hR1R2", "10hR1R2")]
        for _arm in ("m", "p")
    ),
    # N93 samples: the m-ones, then the p-ones ...
    *([f"N93{_arm}{_hrs}" for _hrs in (5, 10)] for _arm in ("m", "p")),
    # m / p / mixed "mp" side by side ...
    [f"{_arm}10hR1R2" for _arm in ("m", "p", "mp")],
    [f"N93{_arm}10" for _arm in ("m", "p", "mp")],
]


# Full figure 3D ...

In [None]:
margin = 0.2
tcourse_spacing = 0.1
matw = 0.35
cbarh = 0.07

# One colormap object shared by every heatmap and by the colorbar. Values above vmax
# are drawn in a very dark red. It is a copy, so the registered "RdBu_r" stays untouched.
pileup_cmap = mpl.colormaps["RdBu_r"].copy()
pileup_cmap.set_over("#300000")

imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap=pileup_cmap,
    interpolation="nearest",
)

# timecourse_samples = ["Mito", "Telo", "Cyto", "5hR1R2", "10hR1R2"]
timecourse_samples = ["Mito", "Telo", "Cyto", "5hR1R2"]
_nsamples = len(timecourse_samples)

_flank = 100_000
_dfff = _df_intra_arm
# _cis_subidx = _dfff.query("(dots_footprint1==0)&(dots_footprint2==0)").index
# _cis_subidx = _dfff.query("(dots_footprint1>0)&(dots_footprint2>0)").index
# _cis_G1idx = _dfff.query("G1status1 & G1status2").index
_cis_G1idx = _dfff.query("~Cytostatus1 & ~Cytostatus2").index
_cis_Cytoidx = _dfff.query("Cytostatus1 & Cytostatus2").index
# _cis_subidx = _dfff.index
# positions within each stack_means[sample] list ...
_trans_G1idx = 0
_trans_Cytoidx = 1

# The first items are for padding and the second items are for the axes, sizes are in inch.
h = [ Size.Fixed(margin) ] + \
    (_nsamples-1)*[ Size.Fixed(matw), Size.Fixed(0.25*margin) ] + \
    [ Size.Fixed(matw), Size.Fixed(tcourse_spacing) ] + \
    (_nsamples-1)*[ Size.Fixed(matw), Size.Fixed(0.25*margin) ] + \
    [ Size.Fixed(matw), Size.Fixed(margin) ]
# goes from bottom to the top ...
v = [ Size.Fixed(margin), Size.Fixed(cbarh), Size.Fixed(0.5*margin),
     Size.Fixed(matw), Size.Fixed(0.25*margin), Size.Fixed(matw), Size.Fixed(1.5*margin),
     Size.Fixed(matw), Size.Fixed(0.25*margin), Size.Fixed(matw), Size.Fixed(margin)]


# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")

# ...
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

axs_g1_m = {}
axs_g1_p = {}
axs_g1_trans_m = {}
axs_g1_trans_p = {}

axs_cyto_m = {}
axs_cyto_p = {}
axs_cyto_trans_m = {}
axs_cyto_trans_p = {}

for i, _sample in enumerate(timecourse_samples):
    # G1-specific ...
    # mind the gaps/marging between actual plots ...
    axs_g1_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=5))
    axs_g1_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5))
    axs_g1_trans_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=3))
    axs_g1_trans_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3))
    # Cyto-specific ...
    # mind the gaps/marging between actual plots ...
    axs_cyto_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=9))
    axs_cyto_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=9))
    axs_cyto_trans_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=7))
    axs_cyto_trans_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=7))

# The colorbar spans the last two p-sample columns (matrix, gap, matrix). The column
# index is computed explicitly instead of relying on the loop variable `i` above.
_last = _nsamples - 1
cbar_ax = fig.add_axes(
    divider.get_position(),
    axes_locator=divider.new_locator(nx=2*(_last+_nsamples-1)+1, nx1=2*(_last+_nsamples+1), ny=1)
)
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])


def _flank_ticks(mat_size):
    """Tick positions (outer pixel edges and center) and kb labels for a square pileup."""
    positions = [0-0.5, mat_size/2-0.5, mat_size-0.5]
    labels = [-_flank//1000, 0, _flank//1000]
    return positions, labels


def _add_flank_xticks(ax, mat_size):
    """Label the bottom edge of a pileup with -flank/0/+flank (kb); edge labels extend inwards."""
    positions, labels = _flank_ticks(mat_size)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, fontsize=4)
    ax.tick_params(length=1.5, pad=1)
    for tick, ha in zip(ax.xaxis.get_majorticklabels(), ("left", "center", "right")):
        tick.set_horizontalalignment(ha)


def _add_flank_yticks_right(ax, mat_size):
    """Label the right edge of a pileup with rotated -flank/0/+flank (kb) ticks."""
    positions, labels = _flank_ticks(mat_size)
    ax.yaxis.tick_right()
    ax.set_yticks(positions, labels=labels, rotation=90, fontsize=4)
    ax.tick_params(length=1.5, pad=1)
    for tick, va in zip(ax.yaxis.get_majorticklabels(), ("top", "center", "bottom")):
        tick.set_verticalalignment(va)


def _plot_pileup_group(axs_cis_m, axs_cis_p, axs_trans_m, axs_trans_p,
                       cis_idx, trans_idx, *, with_xticks, with_titles):
    """
    Fill one group of rows (cis and trans, m- and p-samples) for every timecourse sample.

    cis pileups are averaged over the `cis_idx` pairs of each sample's full cis stack.
    trans pileups are the pre-averaged `stack_means[sample][trans_idx]`.
    """
    for jj, _sample in enumerate(timecourse_samples):
        axm, axp = axs_cis_m[_sample], axs_cis_p[_sample]
        taxm, taxp = axs_trans_m[_sample], axs_trans_p[_sample]
        sample_m = f"m{_sample}"
        sample_p = f"p{_sample}"
        _cis_stack_m = np.nanmean(fullstacks_cis[sample_m][cis_idx], axis=0)
        _cis_stack_p = np.nanmean(fullstacks_cis[sample_p][cis_idx], axis=0)
        _trans_stack_m = stack_means[sample_m][trans_idx]
        _trans_stack_p = stack_means[sample_p][trans_idx]
        for _ax, _mat in ((axm, _cis_stack_m), (axp, _cis_stack_p),
                          (taxm, _trans_stack_m), (taxp, _trans_stack_p)):
            _ax.imshow(_mat, **imshow_kwargs)
            _ax.set_xticks([])
            _ax.set_yticks([])
        # row labels on the left-most column only ...
        if jj == 0:
            axm.set_ylabel("cis", fontsize=6, labelpad=1)
            taxm.set_ylabel("trans", fontsize=6, labelpad=1)
        if with_xticks:
            _add_flank_xticks(taxm, _trans_stack_m.shape[0])
            _add_flank_xticks(taxp, _trans_stack_m.shape[0])
        if with_titles:
            axm.set_title(_sample, fontsize=6, pad=1)
            axp.set_title(_sample, fontsize=6, pad=1)
        # the very last column gets y-ticks on its right side ...
        if jj == len(timecourse_samples) - 1:
            _add_flank_yticks_right(taxp, _trans_stack_m.shape[0])
            _add_flank_yticks_right(axp, _cis_stack_m.shape[0])


# lower rows: pairs where neither anchor is flagged by Cytostatus, kb ticks below the trans row ...
_plot_pileup_group(
    axs_g1_m, axs_g1_p, axs_g1_trans_m, axs_g1_trans_p,
    _cis_G1idx, _trans_G1idx,
    with_xticks=True, with_titles=False,
)
# upper rows: pairs where both anchors are flagged, sample names as titles ...
_plot_pileup_group(
    axs_cyto_m, axs_cyto_p, axs_cyto_trans_m, axs_cyto_trans_p,
    _cis_Cytoidx, _trans_Cytoidx,
    with_xticks=False, with_titles=True,
)

# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
cbar_ax.set_xticks([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
cbar_ax.set_xticklabels([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax], fontsize=6)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.5, pad=1)
for tick, ha in zip(cbar_ax.xaxis.get_majorticklabels(), ("left", "center", "right")):
    tick.set_horizontalalignment(ha)


fig.savefig("fig3D_timecourse.svg", dpi=300)

In [None]:
# margin = 0.2
# tcourse_spacing = 0.1
# matw = 0.45
# cbarh = 0.075

# imshow_kwargs = dict(
#         norm=LogNorm(vmin=1/2.5, vmax=2.5),
#         cmap="RdBu_r",
#         interpolation="nearest",
# )

# # timecourse_samples = ["Mito", "Telo", "Cyto", "5hR1R2", "10hR1R2"]
# timecourse_samples = ["Mito", "Telo", "Cyto", "5hR1R2"]
# _nsamples = len(timecourse_samples)

# _flank = 100_000
# _dfff = _df_intra_arm
# # _cis_subidx = _dfff.query("(dots_footprint1==0)&(dots_footprint2==0)").index
# # _cis_subidx = _dfff.query("(dots_footprint1>0)&(dots_footprint2>0)").index
# # _cis_subidx = _dfff.query("G1status1 & G1status2").index
# # _cis_subidx = _dfff.query("~Cytostatus1 & ~Cytostatus2").index
# _cis_subidx = _dfff.index
# _trans_idx = 0

# # The first items are for padding and the second items are for the axes, sizes are in inch.
# h = [ Size.Fixed(margin) ] + \
#     (_nsamples-1)*[ Size.Fixed(matw), Size.Fixed(0.25*margin) ] + \
#     [ Size.Fixed(matw), Size.Fixed(tcourse_spacing) ] + \
#     (_nsamples-1)*[ Size.Fixed(matw), Size.Fixed(0.25*margin) ] + \
#     [ Size.Fixed(matw), Size.Fixed(margin) ]
# # goes from bottom to the top ...
# v = [ Size.Fixed(margin), Size.Fixed(cbarh), Size.Fixed(0.2*margin),
#      Size.Fixed(matw), Size.Fixed(tcourse_spacing), Size.Fixed(matw), Size.Fixed(margin)]


# # set figsize based on the tiling provided ...
# fig_width = sum(_h.fixed_size for _h in h)
# fig_height = sum(_v.fixed_size for _v in v)
# fig = plt.figure(
#     figsize=(fig_width, fig_height),
#     # facecolor='lightblue'
# )
# print(f"figure size {fig_width=} {fig_height=}")

# # ...
# divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

# axs_m = {}
# axs_p = {}
# axs_trans_m = {}
# axs_trans_p = {}

# for i, _sample in enumerate(timecourse_samples):
#     # mind the gaps/marging between actual plots ...
#     axs_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=5))
#     axs_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5))
#     axs_trans_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*(i+_nsamples)+1, ny=3))
#     axs_trans_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3))

# cbar_ax = fig.add_axes(
#     divider.get_position(),
#     axes_locator=divider.new_locator(nx=2*(i+_nsamples-1)+1, nx1=2*(i+_nsamples+1), ny=1)
# )
# cbar_ax.set_xticks([])
# cbar_ax.set_yticks([])


# for jj, _sample in enumerate(timecourse_samples):
#     axm, axp = axs_m[_sample], axs_p[_sample]
#     taxm, taxp = axs_trans_m[_sample], axs_trans_p[_sample]
#     #
#     sample_m = f"m{_sample}"
#     sample_p = f"p{_sample}"
#     _cis_stack_m = np.nanmean(fullstacks_cis[sample_m][_cis_subidx], axis=0)
#     _cis_stack_p = np.nanmean(fullstacks_cis[sample_p][_cis_subidx], axis=0)
#     _trans_stack_m = stack_means[sample_m][_trans_idx]
#     _trans_stack_p = stack_means[sample_p][_trans_idx]
#     # #
#     # # cis pileups first ...
#     _hm = axm.imshow( _cis_stack_m, **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     _hm = axp.imshow( _cis_stack_p, **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     # # trans pileups second ...
#     _hm = taxm.imshow( _trans_stack_m, **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     _hm = taxp.imshow( _trans_stack_p, **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     for _ax in [axp, axm, taxp, taxm]:
#         _ax.set_xticks([])
#         _ax.set_yticks([])
#     # axm.set_title(f"{_dist_key}", fontsize=9)
#     # axp.set_xticks(np.arange(len(ticklabels)))
#     # axp.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
#     # if jj == 0:
#     #     axm.set_ylabel(sample_m)
#     #     axp.set_ylabel(sample_p)

# # add a single colorbar ...
# fig.colorbar(
#     cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
#     cax=cbar_ax,
#     orientation="horizontal",
# )
# cbar_ax.set_xticks([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
# cbar_ax.set_xticklabels([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
# cbar_ax.minorticks_off()


# Mostly legacy stuff below ...

## create a function that assigns distance in dots

### Load the dots (not anchors) first ...

In [None]:
# ##################################################################
# ############################################# anchors ...
# anchor_fnames = {
#     "mega_ctrl": "dot_anchors_10kb_MEGA/mG1s_MEGA.bed",
# }
# # ...
# dot_anchors_dict = {}
# for id_name, fname in anchor_fnames.items():
#     dot_anchors_dict[id_name] = pd.read_csv(fname, sep="\t")
#     # ...
#     print(f"loaded {len(dot_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")
# # ...
# ##################################################################
# ############################################# dots themselves ...
# BEDPE dot calls (tab-separated, first row is the header), keyed by sample label ...
dot_fnames = {
    "mega_ctrl": "dots_10kb_MEGA_samples/mG1s_MEGA_10kb_wheader.bedpe",
    "mega_depl": "dots_10kb_MEGA_samples/pG1s_MEGA_10kb_wheader.bedpe",
    "mega_mito": "dots_10kb_MEGA_samples/Ms_MEGA_10kb_wheader.bedpe",
    "cyto": "dots_10kb_samples/mCyto_10kb_wheader.bedpe",
}
# read them one at a time into a dictionary, reporting how many dots each file holds ...
dots_dict = {}
for id_name, fname in dot_fnames.items():
    _loaded_dots = pd.read_csv(fname, sep="\t")
    dots_dict[id_name] = _loaded_dots
    print(f"loaded {len(_loaded_dots):5d} dots {id_name:>20} in BEDPE format ...")


# the dot set used by the cells below ...
# _the_anchors = dot_anchors_dict["mega_ctrl"]
_the_dots = dots_dict["mega_ctrl"]

In [None]:
# show the selected ("mega_ctrl") dot calls ...
_the_dots

In [None]:
def get_dot_distance(
    df_grid,
    dots,
):
    """Distance, counted in dot calls, between the two anchors of each grid pair.

    Each anchor of ``df_grid`` (``chrom1/peak_start1/peak_end1`` and
    ``chrom2/peak_start2/peak_end2``) is overlapped with both anchors of every
    dot in ``dots`` (``chrom1/start1/end1`` and ``chrom2/start2/end2``).
    For every combination of (dot hit by the left anchor, dot hit by the right
    anchor), the absolute difference of the two dots' row labels in ``dots`` is
    taken. The smallest difference per grid pair is reported.
    NOTE(review): this is only a genomic "distance in dots" if ``dots`` is
    position-sorted with a range index -- confirm for the inputs used.

    Parameters
    ----------
    df_grid : pandas.DataFrame
        Anchor pairs with columns chrom1, peak_start1, peak_end1,
        chrom2, peak_start2, peak_end2.
    dots : pandas.DataFrame
        Dot calls with columns chrom1, start1, end1, chrom2, start2, end2.

    Returns
    -------
    pandas.DataFrame
        A single ``dot_order`` column indexed by the ``df_grid`` row labels
        (bioframe's ``index`` column). The value is missing where either anchor
        overlaps no dot.
    """
    # Only the row labels of overlapping intervals are needed downstream.
    overlap_kwargs = dict(
        return_input=False,
        return_index=True,
        suffixes=("", "_dot"),
        keep_order=True,
    )
    left_cols = ("chrom1", "peak_start1", "peak_end1")
    right_cols = ("chrom2", "peak_start2", "peak_end2")
    dot_anchor_cols = [("chrom1", "start1", "end1"), ("chrom2", "start2", "end2")]

    def _anchor_hits(grid_cols):
        # Dots ("index_dot") hit by one grid anchor through either dot anchor.
        # bioframe's default how="left" keeps a missing-value row for grid rows
        # without a hit, so every grid row appears in the result.
        hits = [
            bioframe.overlap(
                df_grid,
                dots,
                cols1=grid_cols,
                cols2=dot_cols,
                **overlap_kwargs,
            )
            for dot_cols in dot_anchor_cols
        ]
        return pd.concat(hits, ignore_index=True)

    # All (left dot, right dot) combinations for each grid pair.
    combos = _anchor_hits(left_cols).merge(
        _anchor_hits(right_cols),
        on="index",
        suffixes=("_x", "_y"),
    )
    combos["dot_order"] = (combos["index_dot_x"] - combos["index_dot_y"]).abs()
    # min() skips missing values, so any real combination determines the result.
    return combos.groupby("index")["dot_order"].min().to_frame()

In [None]:
# Quick look at the anchor-pair table used by the pileups below
# (_df_intra_arm is defined in an earlier cell).
_df_intra_arm

In [None]:
# Anchor pairs whose anchors hit dot calls fewer than 3 rows apart in _the_dots.
_close_to_dot = get_dot_distance(_df_intra_arm, _the_dots).query("dot_order < 3")
_close_to_dot.index

# Nup93 figure: Extended Fig. 5

In [None]:
# Time-course figure: for each time point, cis (top row) and trans (middle row)
# pileups of the control "N93m<t>" (left block) and "N93p<t>" (right block)
# samples, sharing a single log-scaled colorbar (bottom right).
w, h = 6, 2.9
margin = 0.2
matw = 0.75*1.15
cbarh = 0.1

fig = plt.figure(figsize=(w, h))

imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap="RdBu_r",
    interpolation="none",
)

# Time points to show. Earlier variants: ["Mito", "Telo", "Cyto", "5hR1R2", "10hR1R2"].
timecourse_samples = ["5", "10"]
_nsamples = len(timecourse_samples)

_flank = 100_000
_dfff = _df_intra_arm

# cis subset: pairs where both anchors have a non-zero dot footprint.
# Other subsets tried: "(dots_footprint1==0)&(dots_footprint2==0)",
# "G1status1 & G1status2", "~Cytostatus1 & ~Cytostatus2", or all of _dfff.index.
_cis_subidx = _dfff.query("(dots_footprint1>0)&(dots_footprint2>0)").index
# Which precomputed trans pileup in stack_means[sample] to show.
_trans_idx = 1

# Column sizes (inches): the m block is [pad, matrix] per sample, then a wider
# gap, then the p block is [matrix, pad] per sample. The m matrix of sample i
# sits in column 2*i+1 and the p matrix in column 2*(i+_nsamples)+1.
h = _nsamples*[Size.Fixed(0.25*margin), Size.Fixed(matw)] + [Size.Fixed(margin)] + _nsamples*[Size.Fixed(matw), Size.Fixed(0.25*margin)]
# Rows, bottom to top: [pad, colorbar] + [pad, matrix] x 2, so the colorbar is
# in row 1, trans pileups in row 3 and cis pileups in row 5.
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + 2*[Size.Fixed(margin), Size.Fixed(matw)]
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

axs_m = {}
axs_p = {}
axs_trans_m = {}
axs_trans_p = {}
for i, _sample in enumerate(timecourse_samples):
    _nx_m = 2*i + 1
    _nx_p = 2*(i + _nsamples) + 1
    axs_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_p, ny=5))
    axs_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_m, ny=5))
    axs_trans_p[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_p, ny=3))
    axs_trans_m[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_m, ny=3))

# The colorbar spans the last four grid columns (the two right-most p panels and
# their padding). The span is derived from the grid itself, not from the loop
# variable left over above.
cbar_ax = fig.add_axes(
    divider.get_position(),
    axes_locator=divider.new_locator(nx=len(h) - 4, nx1=len(h), ny=1),
)
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])

for _sample in timecourse_samples:
    axm, axp = axs_m[_sample], axs_p[_sample]
    taxm, taxp = axs_trans_m[_sample], axs_trans_p[_sample]
    sample_m = f"N93m{_sample}"
    sample_p = f"N93p{_sample}"
    # Average the per-pair cis snippets over the selected subset.
    _cis_stack_m = np.nanmean(fullstacks_cis[sample_m][_cis_subidx], axis=0)
    _cis_stack_p = np.nanmean(fullstacks_cis[sample_p][_cis_subidx], axis=0)
    _trans_stack_m = stack_means[sample_m][_trans_idx]
    _trans_stack_p = stack_means[sample_p][_trans_idx]
    for _ax, _stack in [
        (axm, _cis_stack_m),
        (axp, _cis_stack_p),
        (taxm, _trans_stack_m),
        (taxp, _trans_stack_p),
    ]:
        _hm = _ax.imshow(_stack, **imshow_kwargs)
        # Values above vmax are drawn in a dark red.
        _hm.cmap.set_over("#300000")
        _ax.set_xticks([])
        _ax.set_yticks([])

# One colorbar for all panels; they share imshow_kwargs["norm"].
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
cbar_ax.set_xticks(_cbar_ticks)
cbar_ax.set_xticklabels(_cbar_ticks)
cbar_ax.minorticks_off()

In [None]:
# Cis pileups of one control (m) and one depleted (p) sample, shown for three
# anchor-pair subsets (rows, top to bottom): "at the dot", "dotted", "dotless".
w, h = 6, 2.9
margin = 0.2
matw = 0.75*1.15
cbarh = 0.1

fig = plt.figure(figsize=(w, h))

imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap="RdBu_r",
    interpolation="none",
)

_flank = 100_000
_dfff = _df_intra_arm

# Fixed-size grid (inches). Columns: [pad, matrix] x 2, with the m sample in
# column 1 and the p sample in column 3.
h = 2*[Size.Fixed(margin), Size.Fixed(matw)]
# Rows, bottom to top: [pad, colorbar] + [pad, matrix] x 3, with the colorbar in
# row 1 and the matrices in rows 3, 5 and 7.
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + 3*[Size.Fixed(margin), Size.Fixed(matw)]
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

ax_m_dot = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=7))
ax_p_dot = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, ny=7))

ax_m_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=5))
ax_p_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, ny=5))

ax_m_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=3))
ax_p_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, ny=3))

for _ax in [ax_m_dot, ax_p_dot, ax_m_dotted, ax_p_dotted, ax_m_dotless, ax_p_dotless]:
    _ax.set_xticks([])
    _ax.set_yticks([])

# NOTE(review): the m and p samples come from different time points (5 vs 10),
# while elsewhere in this notebook N93m<t> is paired with N93p<t>. Confirm this
# pairing is intentional.
sample_m = "N93m5"
sample_p = "N93p10"

# One entry per figure row: (m axes, p axes, anchor-pair subset).
_row_subsets = [
    # at the dot: anchors hit dot calls fewer than 3 rows apart in _the_dots
    (ax_m_dot, ax_p_dot,
     get_dot_distance(_dfff, _the_dots).query("dot_order < 3").index),
    # dotted: both anchors have a non-zero dot footprint
    (ax_m_dotted, ax_p_dotted,
     _dfff.query("(dots_footprint1>0)&(dots_footprint2>0)").index),
    # dotless: neither anchor has a dot footprint
    (ax_m_dotless, ax_p_dotless,
     _dfff.query("(dots_footprint1==0)&(dots_footprint2==0)").index),
]
for _ax_m, _ax_p, _cis_subidx in _row_subsets:
    # Average the per-pair cis snippets of each sample over the subset.
    _cis_stack_m = np.nanmean(fullstacks_cis[sample_m][_cis_subidx], axis=0)
    _cis_stack_p = np.nanmean(fullstacks_cis[sample_p][_cis_subidx], axis=0)
    for _ax, _stack in [(_ax_m, _cis_stack_m), (_ax_p, _cis_stack_p)]:
        _hm = _ax.imshow(_stack, **imshow_kwargs)
        # Values above vmax are drawn in a dark red.
        _hm.cmap.set_over("#300000")

cbar_ax = fig.add_axes(
    divider.get_position(),
    axes_locator=divider.new_locator(nx=3, ny=1)
)
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])

# One colorbar for all panels; they share imshow_kwargs["norm"].
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
cbar_ax.set_xticks(_cbar_ticks)
cbar_ax.set_xticklabels(_cbar_ticks)
cbar_ax.minorticks_off()

In [None]:
# Cis and trans pileups of one control (m) and one depleted (p) sample.
# Columns 1/5 hold the m/p cis panels and columns 3/7 the m/p trans panels.
# Rows (top to bottom) are "at the dot" (cis only), "dotted" and "dotless".
w, h = 6, 2.9
margin = 0.2
matw = 0.75*1.15
cbarh = 0.1

fig = plt.figure(figsize=(w, h))

imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap="RdBu_r",
    interpolation="none",
)

_flank = 100_000
_dfff = _df_intra_arm

# Fixed-size grid (inches). Columns: [pad, matrix] x 4, so the matrices sit in
# columns 1, 3, 5 and 7.
h = 4*[Size.Fixed(margin), Size.Fixed(matw)]
# Rows, bottom to top: [pad, colorbar] + [pad, matrix] x 3, with the colorbar in
# row 1 and the matrices in rows 3, 5 and 7.
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + 3*[Size.Fixed(margin), Size.Fixed(matw)]
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

ax_m_dot = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=7))
ax_p_dot = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=5, ny=7))

ax_m_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=5))
ax_p_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=5, ny=5))

ax_m_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=3))
ax_p_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=5, ny=3))

ax_m_trans_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, ny=5))
ax_p_trans_dotted = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=7, ny=5))

ax_m_trans_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, ny=3))
ax_p_trans_dotless = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=7, ny=3))

for _ax in [
    ax_m_dot,
    ax_p_dot,
    ax_m_dotted,
    ax_p_dotted,
    ax_m_dotless,
    ax_p_dotless,
    ax_m_trans_dotted,
    ax_p_trans_dotted,
    ax_m_trans_dotless,
    ax_p_trans_dotless,
]:
    _ax.set_xticks([])
    _ax.set_yticks([])

# NOTE(review): the m and p samples come from different time points (5 vs 10),
# while elsewhere in this notebook N93m<t> is paired with N93p<t>. Confirm this
# pairing is intentional.
sample_m = "N93m5"
sample_p = "N93p10"

# CIS: one entry per figure row, as (m axes, p axes, anchor-pair subset).
_cis_rows = [
    # at the dot: anchors hit dot calls fewer than 3 rows apart in _the_dots
    (ax_m_dot, ax_p_dot,
     get_dot_distance(_dfff, _the_dots).query("dot_order < 3").index),
    # dotted: both anchors have a non-zero dot footprint
    (ax_m_dotted, ax_p_dotted,
     _dfff.query("(dots_footprint1>0)&(dots_footprint2>0)").index),
    # dotless: neither anchor has a dot footprint
    (ax_m_dotless, ax_p_dotless,
     _dfff.query("(dots_footprint1==0)&(dots_footprint2==0)").index),
]
for _ax_m, _ax_p, _cis_subidx in _cis_rows:
    # Average the per-pair cis snippets of each sample over the subset.
    _cis_stack_m = np.nanmean(fullstacks_cis[sample_m][_cis_subidx], axis=0)
    _cis_stack_p = np.nanmean(fullstacks_cis[sample_p][_cis_subidx], axis=0)
    for _ax, _stack in [(_ax_m, _cis_stack_m), (_ax_p, _cis_stack_p)]:
        _hm = _ax.imshow(_stack, **imshow_kwargs)
        # Values above vmax are drawn in a dark red.
        _hm.cmap.set_over("#300000")

# TRANS: precomputed pileups in stack_means[sample]. Entry 1 goes in the
# "dotted" row and entry 0 in the "dotless" row.
_trans_rows = [
    (ax_m_trans_dotted, ax_p_trans_dotted, 1),
    (ax_m_trans_dotless, ax_p_trans_dotless, 0),
]
for _ax_m, _ax_p, _trans_idx in _trans_rows:
    _trans_stack_m = stack_means[sample_m][_trans_idx]
    _trans_stack_p = stack_means[sample_p][_trans_idx]
    for _ax, _stack in [(_ax_m, _trans_stack_m), (_ax_p, _trans_stack_p)]:
        _hm = _ax.imshow(_stack, **imshow_kwargs)
        _hm.cmap.set_over("#300000")

# One colorbar for all panels; they share imshow_kwargs["norm"].
cbar_ax = fig.add_axes(
    divider.get_position(),
    axes_locator=divider.new_locator(nx=7, ny=1)
)
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])

fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
cbar_ax.set_xticks(_cbar_ticks)
cbar_ax.set_xticklabels(_cbar_ticks)
cbar_ax.minorticks_off()

In [None]:
# w, h = 6, 2.9
# margin = 0.2
# matw = 1.15
# cbarh = 0.1
# # cbarw = 0.7*matw

# fig = plt.figure(
#     figsize=(w, h),
#     # facecolor='lightblue'
# )

# imshow_kwargs = dict(
#         norm=LogNorm(vmin=1/2.5, vmax=2.5),
#         cmap="RdBu_r",
#         interpolation="none",
# )


# _flank = 100_000
# _dfff = _df_intra_arm
# # _dfff = _df_intra_arm.query("(dots_footprint1==0)&(dots_footprint2==0)")
# dist_bins = [0, 500_000, 10_000_000, 1_000_000_000]
# dist_bins = [0, 250_000, 1_000_000, 10_000_000, 1_000_000_000]
# _dist_groups = _dfff.groupby(pd.cut( _dfff["dist"], dist_bins ), observed=True)
# ndist = len(_dist_groups)


# # The first items are for padding and the second items are for the axes, sizes are in inch.
# h =  (len(_dist_groups)+1)*[Size.Fixed(margin), Size.Fixed(matw)]
# # goes from bottom to the top ...
# v = [Size.Fixed(margin), Size.Fixed(cbarh)] + \
#     2 * [Size.Fixed(0.5*margin), Size.Fixed(matw)]
# # ...
# divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

# sample_m = "m5hR1R2"
# sample_p = "p5hR1R2"
# _cis_stack_m = fullstacks_cis[sample_m]
# _cis_stack_p = fullstacks_cis[sample_p]

# _trans_stack_m = stack_means[sample_m][-1]
# _trans_stack_p = stack_means[sample_p][-1]

# axs_m = {}
# axs_p = {}
# cax_h = {}

# for i, _dist_key in enumerate(_dist_groups.groups):
#     # mind the gaps/marging between actual plots ...
#     axs_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3))
#     axs_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5))
# i = i + 1
# _dist_key = "trans"
# # mind the gaps/marging between actual plots ...
# axs_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3))
# axs_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5))

# cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=1))
# cbar_ax.set_xticks([])
# cbar_ax.set_yticks([])



# for jj, (_dist_key, _dist_idx) in enumerate(_dist_groups.groups.items()):
#     axm, axp = axs_m[_dist_key], axs_p[_dist_key]
#     # cis pileups first ...
#     _hm = axm.imshow( np.nanmean(_cis_stack_m[_dist_idx], axis=0), **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     _hm = axp.imshow( np.nanmean(_cis_stack_p[_dist_idx], axis=0), **imshow_kwargs)
#     _hm.cmap.set_over("#300000")
#     for _ax in [axp, axm]:
#         _ax.set_xticks([])
#         _ax.set_yticks([])
#     axm.set_title(f"{_dist_key}", fontsize=9)
#     # axp.set_xticks(np.arange(len(ticklabels)))
#     # axp.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
#     if jj == 0:
#         axm.set_ylabel(sample_m)
#         axp.set_ylabel(sample_p)

# stack_means[_sample]

# # treat trans separately ...
# jj = jj + 1
# _dist_key = "trans"
# _dist_idx = slice(None)
# axm, axp = axs_m[_dist_key], axs_p[_dist_key]
# _hm = axm.imshow( _trans_stack_m, **imshow_kwargs)
# _hm.cmap.set_over("#300000")
# _hm = axp.imshow( _trans_stack_p, **imshow_kwargs)
# _hm.cmap.set_over("#300000")
# for _ax in [axp, axm]:
#     _ax.set_xticks([])
#     _ax.set_yticks([])
# axm.set_title(f"{_dist_key}", fontsize=9)


# # add a single colorbar ...

# fig.colorbar(
#     cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
#     cax=cbar_ax,
#     orientation="horizontal",
# )
# cbar_ax.set_xticks([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
# cbar_ax.set_xticklabels([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
# cbar_ax.minorticks_off()
