In [None]:
# set the number of threads for many common libraries
# (done here, ahead of the numpy / scipy imports in the cells below)
# https://superfastpython.com/numpy-number-blas-threads/
from os import environ

N_THREADS = '1'
environ.update({
    'OMP_NUM_THREADS': N_THREADS,
    'OPENBLAS_NUM_THREADS': N_THREADS,
    'MKL_NUM_THREADS': N_THREADS,
    'VECLIB_MAXIMUM_THREADS': N_THREADS,
    'NUMEXPR_NUM_THREADS': N_THREADS,
})

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
import pandas as pd
import numpy as np
from itertools import chain

# Hi-C utilities imports:
import cooler
import bioframe
import cooltools
from cooltools.lib.numutils import fill_diag

# Visualization imports:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.patches as patches
from matplotlib.ticker import EngFormatter

from itertools import cycle

# from ipywidgets import interact, fixed

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# enable editable text in exported pdf/svg figures, and use thin axes frames ...
mpl.rcParams.update({
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
    "axes.linewidth": 0.5,
})

In [None]:
# ! pip install --upgrade --no-cache --no-deps --ignore-install cooler
# ls /home/dekkerlab/dots-test
# import higlass as hg
# import jscatter
import scipy
import logging
import multiprocess as mp
# import mpire for nested multi-processing
from mpire import WorkerPool
# bbi for stackups ...
import bbi
from data_catalog import bws, bws_vlim, telo_dict
from helper_func import get_stack, show_stacks, plot_stackups_lite, plot_stackups_sets, to_bigbed3


from tqdm import tqdm
from tqdm.notebook import trange, tqdm
import warnings

### arms here just in case ...

In [None]:
# Genomic view that will be used to call dots and pre-compute expected.

# chromosome sizes and centromeres are fetched with bioframe (from the UCSC),
# and turned into a table of chromosome arms
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)

# autosomal only: keep the arms whose name is among the first 44 of the full table
included_arms = hg38_arms_full["name"].iloc[:44].to_list()
hg38_arms = (
    hg38_arms_full
    .loc[hg38_arms_full["name"].isin(included_arms)]
    .reset_index(drop=True)
)

# Read pre-called Intrinsic Domain anchors - ID anchors

In [None]:
# name of an anchor set -> path of its pre-called, tab-separated table
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

# read every table into a DataFrame (first line of each file is taken as the header)
# and report how many anchors each one holds
id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    _anchors = pd.read_csv(fname, sep="\t")
    id_anchors_dict[id_name] = _anchors
    print(f"loaded {len(_anchors):5d} ID anchors {id_name:>20} in BED format ...")

### To stackups now ...

In [None]:
# Stackup window: the flanks are +- 100 Kb, split into 200 bins in total
# (i.e. (2 * _flank) / _nbins = 1 Kb per bin):
_flank = 100_000 # Length of flank to one side from the boundary, in basepairs
_nbins = 200   # Number of bins to split the region

## Predefine a few bigWig and bigBed derivatives of Hi-C for stacking ...

In [None]:
# ! ls ev_bigwig
# Add Hi-C derived tracks (EV bigwigs, ID coverage, dots anchors) to the
# `bws` catalog that was imported from data_catalog.
bws.update({
    "evm": "ev_bigwig/m5hR1R2.10kb.bw",
    "evp": "ev_bigwig/p5hR1R2.10kb.bw",
    "idcov": "pix_clust_cov.bw",
    "dots": "mega_final_dots_anchors.bb",  # was: "mega_dots_anchors.bb"
})

# # plug insulation there just for fun ...
# bws["evm"] = "ranGAP1-0-G1s-R1R2.hg38.mapq_30.1000.b10000.insul.w50000.bw"
# bws["evp"] = "ranGAP1-7-G1s-R1R2.hg38.mapq_30.1000.b10000.insul.w50000.bw"

# Create a dictionary with sets of anchors for the Fig 3D plot
## Cyto-specific and G1-specific ones ...
### the definitive step by step procedure ...

In [None]:
# mcytodf = id_anchors_dict["mCyto_2X_enrichment_signal"]
# p10hrdf = id_anchors_dict["10hr_2X_enrichment_signal"]
# n93p5hrdf = id_anchors_dict["N93p5_2X_enrichment_signal"]
# megadf = id_anchors_dict["MEGA_2X_enrichment"]
# megaN93df = id_anchors_dict["MEGAN93_2X_enrichment"]
# megaminusdf = id_anchors_dict["MEGAminus_2X_enrichment"]

# cooler files that we'll work on - every entry of telo_dict is opened
# at the same resolution (bin size in basepairs):
binsize = 10_000
telo_clrs = {
    sample_name: cooler.Cooler(f"{mcool_path}::/resolutions/{binsize}")
    for sample_name, mcool_path in telo_dict.items()
}

In [None]:
! wc -l "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed"

In [None]:
# look the table up in id_anchors_dict: `_Cyto_anchors` is only assigned in a
# later cell, so referencing it here fails on a fresh kernel
len(id_anchors_dict["pCyto_2X_enrichment_signal"])

In [None]:
# look the table up in id_anchors_dict: `_G1_anchors` is only assigned in a
# later cell, so referencing it here fails on a fresh kernel
len(id_anchors_dict["5hr_2X_enrichment_signal"])

In [None]:
def _annotate_bins(anchors_df, clr, cols=("chrom", "peak_start", "peak_end")):
    """
    Add bin coordinates of a reference cooler to a table of anchors (in place).

    Parameters
    ----------
    anchors_df : pd.DataFrame
        table of anchors; must contain the columns listed in `cols`
    clr : cooler.Cooler
        reference cooler used to translate genomic coordinates into bins
    cols : sequence of str
        names of the (chrom, start, end) columns that are passed, row by row,
        to `clr.offset` and `clr.extent`

    Returns
    -------
    anchors_df : pd.DataFrame
        the same object, with two extra columns:
        "bin_id" - result of `clr.offset` for the row,
        "bin_width" - difference between the two values returned by `clr.extent`
    """
    _regions = anchors_df[list(cols)]
    anchors_df["bin_id"] = _regions.apply(clr.offset, axis=1, result_type="expand")
    anchors_df["bin_width"] = (
        _regions
        .apply(clr.extent, axis=1, result_type="expand")
        .apply(np.diff, axis=1, result_type="expand")[0]
    )
    return anchors_df


_Cyto_anchors = id_anchors_dict["pCyto_2X_enrichment_signal"]
_G1_anchors = id_anchors_dict["5hr_2X_enrichment_signal"]

_peak_cols = ["chrom", "peak_start", "peak_end"]

# Cyto anchors (peaks padded by 100 bp on each side) that are left after
# subtracting the G1 anchors, i.e. the Cyto-only ones ...
cyto_not_in_g1 = bioframe.setdiff(
    _Cyto_anchors.eval(
        """
        peak_start = peak_start - 100
        peak_end = peak_end + 100
        """
    ),
    _G1_anchors,
    cols1=_peak_cols,
    cols2=_peak_cols,
).reset_index(drop=True)

# Cyto anchors minus the (padded) Cyto-only ones from above.
# NOTE(review): despite the "_spec" in the name, what remains here is presumably
# the set of Cyto anchors that DO overlap a G1 anchor - confirm that this is the
# intended "pcyto" group before renaming anything.
_Cyto_spec_anchors = bioframe.setdiff(
    _Cyto_anchors,
    cyto_not_in_g1,
    cols1=_peak_cols,
    cols2=_peak_cols,
).reset_index(drop=True)

# G1 5hr anchors minus the Cyto anchors, i.e. the G1-only ones ...
_G1_spec_anchors = bioframe.setdiff(
    _G1_anchors,
    _Cyto_anchors,
    cols1=_peak_cols,
    cols2=_peak_cols,
).reset_index(drop=True)

# we need bins/pixels for plotting - instead of genomic coordinates ...
# need a reference cooler for that to do genomic coords -> bins:
_clr = telo_clrs["m5hR1R2"]
# doing that using clr.offset and clr.extent functionality:
_annotate_bins(_G1_spec_anchors, _clr, cols=_peak_cols)
_annotate_bins(_Cyto_spec_anchors, _clr, cols=_peak_cols)

# ...
anchor_dict = {
    "pcyto" : _Cyto_spec_anchors,
    "p5_wo_pCyto" : _G1_spec_anchors,
}
# just print numbers ...
(
    len(anchor_dict["pcyto"]),
    len(anchor_dict["p5_wo_pCyto"]),
    len(anchor_dict["pcyto"])+len(anchor_dict["p5_wo_pCyto"]),
)

In [None]:
# Keys of the `bws` track catalog to pull stackups for (one job per key).
# The two strand-specific tracks "rg0r1fwd"/"rg0r1rev" are merged into a single
# "rna" stack right after the stackups are computed (see the loop below).
bw_kyes_to_use = [
    'mG.atac',
    'H3K4me3',
    'H3K27ac',
    'H3K27me3',
    # 'ctcf',
    "evm",
    "evp",
    # "ids",
    "dots",
    "rg0r1fwd",
    "rg0r1rev",
]


def _job(packed_data, bw_sample):
    """
    Worker job: build the stackup of a single track.

    Parameters
    ----------
    packed_data : tuple
        shared data as (flank, nbins, track_dict, bed_df)
    bw_sample : str
        key of the track in track_dict to stack

    Returns
    -------
    (bw_sample, stack) : tuple
        the key it was called with, and whatever `get_stack` returned for it
    """
    flank, nbins, track_dict, bed_df = packed_data
    track_path = track_dict[bw_sample]
    # imported inside of the job - presumably so that the name is available in
    # the worker process; TODO confirm that it is still required
    from helper_func import get_stack
    stack = get_stack(track_path, bed_df, kind="mid", flank=flank, nbins=nbins)
    return bw_sample, stack

# Pull stackups of every requested track, for every group of anchors:
# stackups_anchor_dict[anchor_name][track_key] -> stack returned by _job
stackups_anchor_dict = {}
for anchor_name, anchors_df in anchor_dict.items():
    print(f"pulling {len(anchors_df)} anchors {anchor_name} ...")

    # data handed over to every worker (unpacked in _job in this exact order)
    _shared = (
        _flank,
        _nbins,
        bws,
        anchors_df,
    )

    with WorkerPool(
        # one worker per requested track: only bw_kyes_to_use is mapped below,
        # sizing the pool by len(bws) would start workers that never get a task
        n_jobs=len(bw_kyes_to_use),
        shared_objects=_shared,
        start_method="forkserver",  # little faster than spawn, fork is the fastest
        use_dill=True,
    ) as wpool:
        results = wpool.map(_job, bw_kyes_to_use, progress_bar=True)

    # sort out the results ...
    stackups_anchor_dict[anchor_name] = {sample: _pstack for sample, _pstack in results}

    # Combine RNA + and - into a single stack ...
    ## Let's combine RNA-seq in a special way: fwd + rev[::-1]
    # i.e. the reverse-strand stack is flipped left-to-right before it is added
    # to the forward one; the two original entries are removed from the dict
    _rna_fwd = stackups_anchor_dict[anchor_name].pop("rg0r1fwd")
    _rna_rev = stackups_anchor_dict[anchor_name].pop("rg0r1rev")
    #
    stackups_anchor_dict[anchor_name]["rna"] = _rna_fwd + _rna_rev[:,::-1]

In [None]:
# Color limits (vmin, vmax) per track - this replaces the `bws_vlim` that was
# imported from data_catalog. Only the values that are actually in effect are
# listed, one line per track.
bws_vlim = {
    _track: dict(vmin=_lo, vmax=_hi)
    for _track, (_lo, _hi) in {
        # ATAC-seq
        "mM.atac": (0.1, 1.),
        "mG.atac": (0.15, 1.),
        "pG.atac": (0.1, 1.),
        # ChIP-like tracks
        "H3K36me3": (0.03, 0.08),
        "H3K4me1": (0.04, 0.1),
        "MED1": (0.04, 0.15),
        "H3K27me3": (0.15, 0.85),
        "H3K9me3": (0.15, 1),
        "H3K4me3": (0.5, 14),
        "H3K27ac": (0.9, 18),
        "ctcf": (1.8, 6),
        # some rna-seq - Async rangap control and depletion ...
        "rg0r1fwd": (0.5, 35),
        "rg0r1rev": (0.5, 35),
        "rg8r1fwd": (0.5, 35),
        "rg8r1rev": (0.5, 35),
        # Hi-C derived tracks
        "evm": (-1.3, 1.3),
        "evp": (-1.3, 1.3),
        "dots": (0, 0.3),
        # combined (fwd + flipped rev) RNA-seq stack
        "rna": (0.5, 40),
    }.items()
}


from palettable.scientific import sequential
def _scientific_cmap(palette_name):
    """Return the matplotlib colormap of a palette from palettable's `sequential` module."""
    return getattr(sequential, palette_name).get_mpl_colormap()


# every colormap comes as a pair: the palette itself and its reversed ("_r") version
lajolla, lajolla_r = _scientific_cmap("LaJolla_10"), _scientific_cmap("LaJolla_10_r")
bilbao, bilbao_r = _scientific_cmap("Bilbao_20"), _scientific_cmap("Bilbao_20_r")
batlow, batlow_r = _scientific_cmap("Batlow_20"), _scientific_cmap("Batlow_20_r")
devon, devon_r = _scientific_cmap("Devon_20"), _scientific_cmap("Devon_20_r")
davos, davos_r = _scientific_cmap("Davos_20"), _scientific_cmap("Davos_20_r")
imola, imola_r = _scientific_cmap("Imola_20"), _scientific_cmap("Imola_20_r")
nuuk, nuuk_r = _scientific_cmap("Nuuk_20"), _scientific_cmap("Nuuk_20_r")
hawaii, hawaii_r = _scientific_cmap("Hawaii_20"), _scientific_cmap("Hawaii_20_r")
oleron, oleron_r = _scientific_cmap("Oleron_20"), _scientific_cmap("Oleron_20_r")
oslo, oslo_r = _scientific_cmap("Oslo_20"), _scientific_cmap("Oslo_20_r")
bamako, bamako_r = _scientific_cmap("Bamako_20"), _scientific_cmap("Bamako_20_r")
grayc, grayc_r = _scientific_cmap("GrayC_20"), _scientific_cmap("GrayC_20_r")
lapaz, lapaz_r = _scientific_cmap("LaPaz_20"), _scientific_cmap("LaPaz_20_r")

In [None]:
from helper_func import _get_norms, _get_hms_nested_shape, _fillmissing_hms, _get_profiles


import matplotlib.lines as lines
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch, Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.inset_locator import BboxConnector
import matplotlib as mpl


def _get_norm(scale, vlims):
    """
    Given a scale and vlims - return a matplotlib norm.

    Parameters
    ----------
    scale : str
        "log" yields a LogNorm, anything else yields a linear Normalize
    vlims : (vmin, vmax) or None
        limits of the norm; when it is None, or cannot be unpacked into
        exactly two values, both limits are left as None

    Returns
    -------
    norm : mpl.colors.LogNorm or mpl.colors.Normalize
    """
    # try to extract vmin , vmax ...
    # only the failures of the unpacking itself are absorbed here (None or other
    # non-iterables -> TypeError, wrong number of items -> ValueError),
    # everything else is allowed to surface
    try:
        vmin, vmax = vlims
    except (TypeError, ValueError):
        vmin, vmax = None, None
    # depending on scale ...
    norm_cls = mpl.colors.LogNorm if scale == "log" else mpl.colors.Normalize
    return norm_cls(vmin, vmax)

# profile_height=0.35
# margin_h=0.2
# margin_v=0.2
# spacing_h=0.02
# spacing_v=0.15
# cbarh_spacing_v = 0.05
# fig_fontsize=6
# cbarh = 0.1

def plot_stackups_sets_new(
    num_extra_plots,
    hms_dict_dict,  # heatmaps dict, that controls the order, the amount etc etc ...
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    extra_plots_position="left",
    len_per_thousand=0.75,
    width_per_stack=0.35,
    profile_height=0.35,
    cbar_height = 0.1,
    spacing_v=.5,  # fixed distance between axes (vertically)
    spacing_h=.2,  # fixed distance between axes (horizontally)
    # **plot_kwargs,
    fig_fontsize=6,
    **imshow_kwargs,
):
    """
    Plot a bunch of stackups, organized in groups that are drawn on top of each other.

    The figure is a fixed-size grid (all sizes in inches): one column per
    sample (plus optional empty "extra" columns), and within a column, bottom
    to top: a horizontal colorbar, one heatmap per group (in the order of
    `hms_dict_dict`), and a panel with the profiles of all of the groups.

    Parameters
    ----------
    num_extra_plots : int or None
        number of extra columns of empty axes to create (falsy -> 0)
    hms_dict_dict : dict
        {group: {sample: heatmap}} - nested dict of stackups; its shape is
        inspected with `_get_hms_nested_shape`
    scales : dict
        {sample: "log" or anything else for linear}, used for the norms and
        passed to `_get_profiles`
    vlims : dict
        {sample: (vmin, vmax)}, used for the norm, the colorbar ticks and the
        y-limits of the profile
    titles : dict
        {sample: title shown above the profile}
    cmaps : dict
        {sample: colormap of the heatmaps}
    binsizes : dict
        {sample: size of a heatmap column in basepairs}, used for the flank labels
    fillmissing : bool
        fill missing values with `_fillmissing_hms(..., how="col mean")` before
        calculating profiles and plotting
    extra_plots_position : str
        "left" puts the extra columns before the stackups, any other value
        puts them after (on the right)
    len_per_thousand : float
        height of a heatmap per 1000 of whatever `_get_hms_nested_shape`
        reports as the group height (presumably rows of the stack)
    width_per_stack, profile_height, cbar_height, spacing_v, spacing_h : float
        sizes of the panels and of the gaps between them
    fig_fontsize : int
        font size of the titles and tick labels
    **imshow_kwargs
        passed on to `imshow` on top of the defaults (aspect="auto",
        interpolation="antialiased", interpolation_stage="data",
        filternorm=False); `norm` and `cmap` are always set per sample and
        cannot be overridden this way

    Returns
    -------
    ax_xtra : list
        the extra axes, column by column, within a column group by group
    """
    # rewrite everything assuming hms_dict_dict is a dict of stackup groups !
    # groups are plotted on top of each other ...
    num_extra_plots = int(num_extra_plots) if num_extra_plots else 0

    # inspect provided stacks and define figure with all of the panels ...
    num_stackup_groups, stackup_samples, stackup_group_heights, stack_width = _get_hms_nested_shape(hms_dict_dict)
    num_cols = len(stackup_samples) + num_extra_plots
    num_groups = len(hms_dict_dict)

    # in inches
    margin_h = 0.2
    margin_v = 0.2
    profile_spacing_v = 0.075
    cbarh_spacing_v = 0.05
    # defaults go first, so that whatever the caller provided takes precedence
    imshow_kwargs = {
        "aspect": "auto",
        "interpolation": "antialiased",
        "interpolation_stage": "data",
        "filternorm": False,
        **imshow_kwargs,
    }

    # horizontal splitting layout: margin, (stack, gap) x (num_cols-1), stack, margin
    h_split = [Size.Fixed(margin_h)]
    h_split += [Size.Fixed(width_per_stack), Size.Fixed(spacing_h)]*(num_cols-1)
    h_split += [Size.Fixed(width_per_stack), Size.Fixed(margin_h)]

    # vertical splitting layout (bottom to top):
    # margin, cbar, gap, (group, gap) x num_groups, profile, margin
    v_split = [
        Size.Fixed(margin_v),
        Size.Fixed(cbar_height),
        Size.Fixed(cbarh_spacing_v),
    ]
    for _i, _num_row_per_stack in enumerate(stackup_group_heights.values()):
        v_split.append( Size.Fixed(len_per_thousand*(_num_row_per_stack/1_000)) )
        if _i < len(stackup_group_heights)-1:
            v_split.append( Size.Fixed(spacing_v) )
        else:
            # the gap between the last group and the profile is a bit wider
            v_split.append( Size.Fixed(spacing_v+profile_spacing_v) )
    v_split.append( Size.Fixed(profile_height) )
    v_split.append( Size.Fixed(margin_v) )

    # set figsize based on the tiling provided - i.e. post factum ...
    fig_width = sum(_h.fixed_size for _h in h_split)
    fig_height = sum(_v.fixed_size for _v in v_split)
    fig = plt.figure(
        figsize=(fig_width, fig_height),
        layout="none",
    )
    print(f"figure overall is {fig_width=} {fig_height=}")

    divider = Divider(fig, (0, 0, 1, 1), h_split, v_split, aspect=False)
    _div_pos = divider.get_position()

    ax_profile = {}
    ax_stackup = {}
    ax_xtra = []
    ax_cbar = {}

    # Panels sit at the odd indices of the splits (even ones are margins/gaps):
    # column j -> nx=2*j+1 ; cbar -> ny=1 ; i-th group (1-based) -> ny=2*i+1 ;
    # profile -> the next odd index after the last group.
    def _new_axes(col, ny):
        """Create axes in the given column at the vertical index `ny` of the divider."""
        return fig.add_axes(_div_pos, axes_locator=divider.new_locator(nx=2*col+1, ny=ny))

    def _add_extra_columns(first_col):
        """Create `num_extra_plots` columns of empty axes (one per group), starting at `first_col`."""
        for _col in range(first_col, first_col + num_extra_plots):
            for _gidx in range(1, num_groups + 1):
                ax_xtra.append(_new_axes(_col, 2*_gidx+1))

    def _add_stack_columns(first_col):
        """Create cbar, per-group heatmap and profile axes for every sample, starting at `first_col`."""
        for _col, _sample in enumerate(stackup_samples, start=first_col):
            ax_cbar[_sample] = _new_axes(_col, 1)
            ax_stackup[_sample] = {
                _group_k: _new_axes(_col, 2*_gidx+1)
                for _gidx, _group_k in enumerate(hms_dict_dict, start=1)
            }
            # does not rely on a leftover loop variable, i.e. works for any number of groups
            ax_profile[_sample] = _new_axes(_col, 2*(num_groups+1)+1)

    def _align_edge_labels(axis, last_idx, center=None):
        """Left-align the first tick label, right-align the one at `last_idx`, and
        apply `center` (when given) to every other label of the axis."""
        for _tidx, tick in enumerate(axis.get_majorticklabels()):
            if _tidx == 0:
                tick.set_horizontalalignment("left")
            elif _tidx == last_idx:
                tick.set_horizontalalignment("right")
            elif center is not None:
                tick.set_horizontalalignment(center)

    if extra_plots_position == "left":
        _add_extra_columns(0)
        _add_stack_columns(num_extra_plots)
    else:  # anything else is treated as "right"
        _add_stack_columns(0)
        _add_extra_columns(len(stackup_samples))

    # fill missing if needed and calculate profiles (per group) ...
    hms_dict_dict_copy = {}
    profile_hm = {}
    for group_k, hms_dict in hms_dict_dict.items():
        hms_dict_dict_copy[group_k] = _fillmissing_hms(hms_dict, how="col mean") if fillmissing else hms_dict
        # use modified stacks to calculate profiles ...
        profile_hm[group_k] = _get_profiles(hms_dict_dict_copy[group_k], scales)
    # replace hms_dict_dict with the updated copy ...
    hms_dict_dict = hms_dict_dict_copy
    # get norms - they are just per sample - regardless of the group ...
    norms = {k: _get_norm(scales[k], vlims[k]) for k in stackup_samples}

    last_group_k = list(hms_dict_dict.keys())[-1]

    # start plotting ...
    for group_k, hms_dict in hms_dict_dict.items():
        # we've checked that samples go in the same order ...
        for sample, hm in hms_dict.items():
            ax_profile[sample].plot(profile_hm[group_k][sample], linewidth=1)
            _ax_hm = ax_stackup[sample][group_k]
            stack_hm = _ax_hm.imshow(
                hm,
                # norm and cmap are per-sample and always win over imshow_kwargs
                **{**imshow_kwargs, "norm": norms[sample], "cmap": cmaps[sample]},
            )
            # beautify - heatmaps carry no ticks at all ...
            _ax_hm.set_xticks([])
            _ax_hm.set_xticklabels([])
            _ax_hm.set_yticks([])
            _ax_hm.set_yticklabels([])
            _ax_hm.minorticks_off()

            if group_k != last_group_k:
                continue

            # the rest is done once per sample - i.e. for the last group only ...
            _ax_prof = ax_profile[sample]
            # ticks at the left edge, the center and the right edge of the stack
            first_bin = -.5
            center_bin = stack_width/2 - .5
            last_bin = stack_width - .5
            _xticklength = 1
            flank_in_kb = int( (center_bin+.5)*binsizes[sample]/1000 )
            flank_ticks = [first_bin, center_bin, last_bin]
            flank_ticklabels = [-flank_in_kb, "", flank_in_kb]
            # profile: title, x-axis labeled with the flank size ...
            _ax_prof.set_title(titles[sample],fontsize=fig_fontsize, pad=2.5)
            _ax_prof.minorticks_off()
            _ax_prof.set_xlim([first_bin, last_bin])
            _ax_prof.set_xticks(flank_ticks)
            _ax_prof.tick_params(axis="x", length=_xticklength, pad=0.5)
            _ax_prof.set_xticklabels(flank_ticklabels,fontsize=fig_fontsize)
            _align_edge_labels(_ax_prof.xaxis, last_idx=2, center="center")
            # ... and y-axis limited to vlims, labels drawn inside of the panel
            _ax_prof.set_ylim(vlims[sample])
            _ax_prof.tick_params(axis="y", length=0, direction="in", pad=-5)
            _ax_prof.set_yticks(vlims[sample])
            _ax_prof.set_yticklabels(vlims[sample],fontsize=fig_fontsize)
            for _tidx, tick in enumerate(_ax_prof.yaxis.get_majorticklabels()):
                tick.set_horizontalalignment("left")
                if _tidx == 0:
                    tick.set_verticalalignment("bottom")
                elif _tidx == 1:
                    tick.set_verticalalignment("top")

            # colorbar - draw them one time per sample only !
            plt.colorbar(stack_hm,
                        cax=ax_cbar[sample],
                        orientation="horizontal",
                        ticks=vlims[sample])
            ax_cbar[sample].minorticks_off()
            ax_cbar[sample].tick_params(axis="x", length=_xticklength, pad=0.5)
            ax_cbar[sample].set_xticklabels(vlims[sample],fontsize=fig_fontsize)
            _align_edge_labels(ax_cbar[sample].xaxis, last_idx=1)

    return ax_xtra

In [None]:
# number of anchors in every group ...
for k, df in anchor_dict.items():
    print(f"{k} {len(df)}")

In [None]:
# tracks to show (left to right) ...
keys_to_display = [ 'mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "rna", "evm","evp", "dots"]

# one set of heatmaps per group of anchors, rows are ordered by the mean of
# the "rna" stack (ascending) ...
hms_dict = {}
for _group_name in ['p5_wo_pCyto', 'pcyto']:
    g = anchor_dict[_group_name]
    stackups_anchor = stackups_anchor_dict[_group_name]
    # define ordering ...
    _rna_rows = stackups_anchor["rna"][g.index]
    _idx = _rna_rows.mean(axis=1).argsort()
    hms_dict[_group_name] = {k: stackups_anchor[k][_idx] for k in keys_to_display}

# per-track plotting parameters ...
scales = dict.fromkeys(keys_to_display, "log")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = dict(zip(keys_to_display, keys_to_display))
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# Hi-C derived tracks are shown in linear scale; the EV ones also get
# a different colormap, while "dots" keeps davos_r ...
scales.update(dict.fromkeys(["evm", "evp", "dots"], "linear"))
cmaps.update(dict.fromkeys(["evm", "evp"], "RdBu_r"))

axx = plot_stackups_sets_new(
    0,
    hms_dict,  # heatmaps dict, that controls the order, the amount etc etc ...
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=True,
    len_per_thousand=0.73,
    width_per_stack=0.33,
    profile_height=0.33,
    cbar_height = 0.07,
    spacing_v=.03,  # fixed distance between axes (vertically)
    spacing_h=.03,  # fixed distance between axes (horizontally)
    fig_fontsize=6,
)

# saves the current figure, i.e. the one that was just created above
plt.savefig("Fig3d-2way_stackup.svg",dpi=300)

# Legacy stuff ...

In [None]:
# num_extra_plots,
# hms_dict_dict,
# ...
# Per-track plotting parameters for the legacy single-group figure:
# everything is in linear scale here, EV tracks get a different colormap.
scales = {k: "linear" for k in keys_to_display}
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = {k: davos_r for k in keys_to_display}
# cmaps = {k: "Blues" for k in keys_to_display}
binsizes = {k: ((2*_flank) // _nbins) for k in keys_to_display}
# modify EV-related things ...
for _evk in ["evm","evp"]:
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"


fillmissing=True
# fill missing if needed and calculate profiles (per group) ...
# (`hms_dict` is whatever the previous cell left behind: {group: {track: stack}})
hms_dict_copy = {}
profile_hm = {}
for group_k, _hms in hms_dict.items():
    hms_dict_copy[group_k] = _fillmissing_hms(_hms, how="col mean") if fillmissing else _hms
    # use modified stacks to calculate profiles ...
    profile_hm[group_k] = _get_profiles(hms_dict_copy[group_k], scales)
# replace hms_dict_dict with the updated copy ...
# _hms = hms_dict_copy
# # get norms - they are just per sample - regardless of the group ...
# norms = _get_norms(scales, vlims)
# last_group_k = list(hms_dict_dict.keys())[-1]
# first_sample = stackup_samples[0]
# NOTE(review): the single-target unpacking below requires `hms_dict` to hold
# exactly ONE group. The previous cell fills it with two groups
# ('p5_wo_pCyto' and 'pcyto'), in which case this line raises a ValueError -
# decide which group this legacy figure is supposed to show.
_stacks_dict , = list(hms_dict_copy.values())


# norms are per track; the shape of the stacks is inspected by wrapping the
# single group into a one-entry nested dict ("grp1")
norms = {k: _get_norm(scales[k], vlims[k]) for k in keys_to_display}
num_stackup_groups, stackup_samples, stackup_group_heights, stack_width = _get_hms_nested_shape({"grp1": _stacks_dict})
num_stacks = len(stackup_samples)
num_stack_rows = sum(stackup_group_heights.values())

# Fixed-size layout of the legacy single-group figure, all sizes are in inches
len_per_thousand=0.75
width_per_stack=0.35
# len_per_thousand=5.9
# width_per_stack=2.7
profile_height=0.35
margin_h=0.2
margin_v=0.2
spacing_h=0.02
spacing_v=0.15
cbarh_spacing_v = 0.05
fig_fontsize=6
cbarh = 0.1
profile_color = "dimgray"

imshow_kwargs = dict(interpolation="antialiased", interpolation_stage="data", filternorm=False)
# imshow_kwargs = dict(interpolation="antialiased", interpolation_stage="rgba", filternorm=True)

# overall figure size is the sum of all of the panels, gaps and margins
w = num_stacks*width_per_stack + (num_stacks-1)*spacing_h + 2*margin_h
h = 2*margin_v + spacing_v + cbarh_spacing_v + \
    len_per_thousand*(num_stack_rows/1_000) + \
    profile_height + cbarh

print(f"figure size is {w=} {h=}")
fig = plt.figure(
    figsize=(w, h),
    layout="none",
    # facecolor='lightblue',
)

# horizontal splitting layout
h_split = []
h_split.append( Size.Fixed(margin_h) )
h_split += [Size.Fixed(width_per_stack), Size.Fixed(spacing_h)]*(num_stacks-1)
h_split += [Size.Fixed(width_per_stack), Size.Fixed(margin_h)]

# vertical splitting layout (bottom to top): margin, cbar, gap, stack, gap, profile, margin
v_split = []
v_split.append( Size.Fixed(margin_v) )
v_split.append( Size.Fixed(cbarh) )
v_split.append( Size.Fixed(cbarh_spacing_v) )
# heights of the groups come from _get_hms_nested_shape (`group_sizes`, that
# was used here before, is not defined anywhere in the notebook)
for _num_row_per_stack in stackup_group_heights.values():
    v_split.append( Size.Fixed(len_per_thousand*(_num_row_per_stack/1_000)) )
    v_split.append( Size.Fixed(spacing_v) )
v_split.append( Size.Fixed(profile_height) )
v_split.append( Size.Fixed(margin_v) )

divider = Divider(fig, (0, 0, 1, 1), h_split, v_split, aspect=False)

# one column of axes per track; the hard-coded ny indices (1: cbar, 3: stack,
# 5: profile) match the vertical layout above for a single group only
axs_profile = [ fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=5)) for i in range(num_stacks)]
axs_stack = [ fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=3)) for i in range(num_stacks)]
axs_cbar = [ fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2*i+1, ny=1)) for i in range(num_stacks)]

# Draw one column per track: profile on top, heatmap in the middle, colorbar below.
for _i, _sample in enumerate(stackup_samples):
    hm = _stacks_dict[_sample]
    _ax_profile, _ax_stack, _ax_cbar = axs_profile[_i], axs_stack[_i], axs_cbar[_i]

    # average profile (NaNs ignored) and the heatmap itself ...
    _ax_profile.plot(np.nanmean(hm,axis=0), linewidth=1, color=profile_color)
    stack_hm = _ax_stack.imshow(
        hm,
        norm=norms[_sample],
        aspect="auto",
        cmap=cmaps[_sample],
        **imshow_kwargs,
    )

    # ticks at the left edge, the center and the right edge of the stack,
    # labeled with the size of the flank in Kb ...
    first_bin = -.5
    center_bin = stack_width/2 - .5
    last_bin = stack_width - .5
    _xticklength = 1
    flank_in_kb = int( (center_bin+.5)*binsizes[_sample]/1000 )
    flank_ticks = [first_bin, center_bin, last_bin]
    flank_ticklabels = [-flank_in_kb, "", f"+{flank_in_kb}"]

    # profile panel: title and x-axis ...
    _ax_profile.set_title(titles[_sample],fontsize=fig_fontsize, pad=2.5)
    _ax_profile.minorticks_off()
    _ax_profile.set_xlim([first_bin, last_bin])
    _ax_profile.set_xticks(flank_ticks)
    _ax_profile.tick_params(axis="x", length=_xticklength, pad=0.5)
    _ax_profile.set_xticklabels(flank_ticklabels,fontsize=fig_fontsize)
    for _tidx, tick in enumerate(_ax_profile.xaxis.get_majorticklabels()):
        tick.set_horizontalalignment({0: "left", 2: "right"}.get(_tidx, "center"))
    # ... and y-axis limited to vlims, with the labels drawn inside of the panel
    _ax_profile.set_ylim(vlims[_sample])
    _ax_profile.tick_params(axis="y", length=0, direction="in", pad=-3)
    _ax_profile.set_yticks(vlims[_sample])
    _ax_profile.set_yticklabels(vlims[_sample],fontsize=fig_fontsize)
    for _tidx, tick in enumerate(_ax_profile.yaxis.get_majorticklabels()):
        tick.set_horizontalalignment("left")
        if _tidx in (0, 1):
            tick.set_verticalalignment(("bottom", "top")[_tidx])

    # heatmap panel: x-ticks are kept, but without labels; no y-ticks at all
    _ax_stack.set_xticks(flank_ticks)
    _ax_stack.set_xticklabels([])
    _ax_stack.tick_params(axis="x", length=_xticklength)
    _ax_stack.set_yticks([])
    _ax_stack.set_yticklabels([])
    for _tidx, tick in enumerate(_ax_stack.xaxis.get_majorticklabels()):
        tick.set_horizontalalignment({0: "left", 2: "right"}.get(_tidx, "center"))

    # colorbar - one per track, ticks at vmin and vmax only
    plt.colorbar(
        stack_hm,
        cax=_ax_cbar,
        orientation="horizontal",
        ticks=vlims[_sample],
    )
    _ax_cbar.minorticks_off()
    _ax_cbar.tick_params(axis="x", length=_xticklength, pad=0.5)
    _ax_cbar.set_xticklabels(vlims[_sample], fontsize=fig_fontsize)
    for _tidx, tick in enumerate(_ax_cbar.xaxis.get_majorticklabels()):
        if _tidx in (0, 1):
            tick.set_horizontalalignment(("left", "right")[_tidx])

# # fig.savefig("Fig2C_stackup.svg", dpi=300)
# fig.savefig("Fig2C_stackup.svg", format="svg", dpi=300)

In [None]:
# Stackups around the "pcyto" anchors; rows are ordered by their summed "dots" signal.
# Orderings tried earlier (by anchor size):
# _idx_order = id_anchors_dict["pcyto"]["size"].argsort()
# _idx_order = ov["size"].argsort()
# _idx_order = nov["size"].argsort()
stackups_anchor = stackups_anchor_dict["pcyto"]
_idx_order = stackups_anchor["dots"].sum(axis=1).argsort()

# every track below is registered with the same bin size
_binsize = (2*_flank) // _nbins

# --- signal tracks: log scale, colour limits looked up in bws_vlim ---
# fuller list used before: ['mM.atac', 'mG.atac', 'pG.atac', 'H3K27me3', 'H3K4me3', 'H3K27ac', 'ctcf']
kkk = ['mG.atac', 'H3K27me3', 'H3K4me3', 'H3K27ac']
hms_dict = {}
vlims = {}
for _track in kkk:
    hms_dict[_track] = stackups_anchor[_track][_idx_order]
    vlims[_track] = (bws_vlim[_track]["vmin"], bws_vlim[_track]["vmax"])
scales = dict.fromkeys(kkk, "log")
titles = dict(zip(kkk, kkk))
cmaps = dict.fromkeys(kkk, davos_r)  # e.g. cmaps["H3K4me3"] = "Reds" to override a single track
binsizes = dict.fromkeys(kkk, _binsize)

# --- "evm" / "evp" tracks: linear scale, symmetric limits, diverging colormap ---
# when enabled, each row's nan-mean over its first and last 25 bins is subtracted
subtract_shoulders = False
for _evk in ("evm", "evp"):
    kkk.append(_evk)
    _smat = stackups_anchor[_evk][_idx_order]
    hms_dict[_evk] = _smat
    if subtract_shoulders:
        _shoulders = np.nanmean(np.c_[_smat[:,:25], _smat[:,-25:]], axis=1)[:,None]
        hms_dict[_evk] = _smat - _shoulders
    scales[_evk], titles[_evk], cmaps[_evk] = "linear", _evk, "RdBu_r"
    vlims[_evk] = (-1.1,1.1)  # narrower alternative tried: (-0.25,0.25)
    binsizes[_evk] = _binsize

# --- "dots" track: linear scale ---
# (an "idcov" track - linear, vlims (1,25), davos_r - used to be added here; currently disabled)
_kkk = "dots"
kkk.append(_kkk)
hms_dict[_kkk] = stackups_anchor[_kkk][_idx_order]
scales[_kkk] = "linear"
vlims[_kkk] = (0,0.5)
titles[_kkk] = _kkk
cmaps[_kkk] = davos_r
binsizes[_kkk] = _binsize

# the insertion order of hms_dict determines the order of the plotted stacks
plot_stackups_lite(
    None,
    hms_dict,
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=4.9,
    width_per_stack=3.2,
    extra_height=3.,  # height reserved for the profile and the colorbar
    interpolation="gaussian",
);

In [None]:
# Two-way stackup figure: one set of heatmaps per anchor group in `mega_anchor_dict`,
# rows within each group ordered by their mean "rna" signal; saved as Fig3d-2way_stackup.svg.
keys_to_display = [ 'mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "rna", "evm","evp", "dots"]

hms_dict = {}
for name, g in mega_anchor_dict.items():
    stackups_anchor = stackups_anchor_dict[name]
    # define ordering ...
    # NOTE(review): assumes g.index holds positional row numbers into the stackup arrays - confirm
    _rows = np.asarray(g.index)
    # argsort() yields positions *within* the `_rows` subset; translate them back to
    # stackup row numbers before indexing the full (un-subset) arrays below.
    # When g.index is simply 0..N-1 in order this is identical to using the positions directly.
    _order = stackups_anchor["rna"][_rows].mean(axis=1).argsort()
    _idx = _rows[_order]
    # alternative orderings:
    # _idx = g["valency"].argsort()
    # _idx = g["size"].argsort()
    hms_dict[name] = {k: stackups_anchor[k][_idx] for k in keys_to_display}

# per-track display settings: log scale and davos_r by default, limits from bws_vlim
_binsize = (2*_flank) // _nbins  # same bin size is registered for every track
scales = {k: "log" for k in keys_to_display}
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = {k: davos_r for k in keys_to_display}
binsizes = {k: _binsize for k in keys_to_display}

# these three tracks are shown on a linear scale with a diverging colormap ...
for _evk in ["evm","evp","dots"]:
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"

# ... except "dots", which stays linear but goes back to davos_r
cmaps["dots"] = davos_r

_xfsize = 12
axx = plot_stackups_sets(
    None,
    hms_dict,
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=6.9,
    width_per_stack=3.,
    extra_height=3.,  # height that goes toward the profile and colorbar
    delta_h=.3,  # fixed distance between axes (vertically)
    delta_w=.5,  # fixed distance between axes (horizontal)
    fig_fontsize=_xfsize,
    interpolation="gaussian",
)


# saves whatever figure is current in pyplot at this point
plt.savefig("Fig3d-2way_stackup.svg",dpi=300)

In [None]:
# # keys_to_display = [ 'H3K27me3',  'H3K9me3', 'mM.atac', 'mG.atac', 'pG.atac', 'H3K4me3', 'H3K27ac', "ctcf", "rg0r1fwd", "rg0r1rev"]
# keys_to_display = [ 'H3K9me3','H3K27me3','mG.atac', 'H3K4me3', 'H3K27ac', "ctcf", "rg0r1fwd", "rg0r1rev"]
# # hms_dict = {k: stackups_anchor[k][order_by_len_idx] for k in keys_to_display}
# # hms_dict = {k: stackups_anchor[k][_idx_order] for k in keys_to_display}

# hms_dict = {}
# for name, g in mega_anchor_dict.items():
#     stackups_anchor = stackups_anchor_dict[name]
#     # _idx = (stackups_anchor["rg0r1fwd"].mean(axis=1) + stackups_anchor["rg0r1rev"].mean(axis=1)).argsort()
#     _idx = stackups_anchor["dots"][:, 100-25:100+25].mean(axis=1).argsort()
#     # _idx = stackups_anchor["dots"].mean(axis=1).argsort()
#     # _idx = stackups_anchor["evp"].mean(axis=1).argsort()
#     # _idx = g["valency"].argsort()
#     if name == "nov":
#         hms_dict[name] = {k: stackups_anchor[k][_idx] for k in keys_to_display}
#     else:
#         hms_dict[name] = {k: stackups_anchor[k][_idx] for k in keys_to_display}
#     # add ev
#     for _evk in ["evm","evp", "idcov", "dots"]:
#         if name == "nov":
#             _smat = stackups_anchor[_evk][_idx]
#         else:
#             _smat = stackups_anchor[_evk][_idx]
#         hms_dict[name][_evk] = _smat

# # ...
# scales = {k: "log" for k in keys_to_display}
# vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
# titles = {k: k for k in keys_to_display}
# cmaps = {k: davos_r for k in keys_to_display}
# binsizes = {k: ((2*_flank) // _nbins) for k in keys_to_display}

# for _evk in ["evm","evp"]:
#     keys_to_display.append(_evk)
#     scales[_evk] = "linear"
#     vlims[_evk] = (-1.2,1.2)
#     titles[_evk] = _evk
#     cmaps[_evk] = "RdBu_r"
#     binsizes[_evk] = ((2*_flank) // _nbins)


# keys_to_display.append("idcov")
# keys_to_display.append("dots")
# _kkk = "idcov"
# scales[_kkk] = "linear"
# vlims[_kkk] = (1,25)
# titles[_kkk] = _kkk
# cmaps[_kkk] = davos_r
# binsizes[_kkk] = ((2*_flank) // _nbins)
# _kkk = "dots"
# scales[_kkk] = "linear"
# vlims[_kkk] = (0,0.5)
# titles[_kkk] = _kkk
# cmaps[_kkk] = davos_r
# binsizes[_kkk] = ((2*_flank) // _nbins)

# _xfsize = 12
# axx = plot_stackups_sets(
#     None,
#     hms_dict,
#     scales,
#     vlims,
#     titles,
#     cmaps,
#     binsizes,
#     fillmissing=False,
#     len_per_thousand=6.9,
#     width_per_stack=3.,
#     extra_height=3.,  # height that goes toward the profile and colorbar
#     delta_h=.3,  # fixed distance between axes (vertically)
#     delta_w=.5,  # fixed distance between axes (horizontal)
#     fig_fontsize=_xfsize,
#     interpolation="gaussian",
# )

# plt.savefig("3way_stackup.svg",dpi=300)