In [None]:
# pin the common numerical backends (OpenMP, OpenBLAS, MKL, vecLib, numexpr)
# to a single thread each; this runs before numpy & co. are imported below ...
from os import environ
N_THREADS = '1'
for _thread_var in (
    'OMP_NUM_THREADS',
    'OPENBLAS_NUM_THREADS',
    'MKL_NUM_THREADS',
    'VECLIB_MAXIMUM_THREADS',
    'NUMEXPR_NUM_THREADS',
):
    environ[_thread_var] = N_THREADS

# https://superfastpython.com/numpy-number-blas-threads/

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
import pandas as pd
import numpy as np
from itertools import chain

# Hi-C utilities imports:
import cooler
import bioframe
import cooltools
from cooltools.lib.numutils import fill_diag

# Visualization imports:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.patches as patches
from matplotlib.ticker import EngFormatter

from itertools import cycle

# from ipywidgets import interact, fixed

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets


# figure export defaults: font settings meant to keep text editable in PDF/SVG
# output, plus a thinner axes frame ...
mpl.rcParams.update({
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
    'axes.linewidth': 0.5,
})

In [None]:
# ! pip install --upgrade --no-cache --no-deps --ignore-install cooler
# ls /home/dekkerlab/dots-test
# import higlass as hg
# import jscatter
import scipy
import logging
import multiprocess as mp
# import mpire for nested multi-processing
from mpire import WorkerPool
# bbi for stackups ...
import bbi
from data_catalog import bws, bws_vlim, telo_dict
from helper_func import (
    get_stack,
    show_stacks,
    plot_stackups_lite,
    plot_stackups_sets,
    plot_stackups_sets_new,
    to_bigbed3,
)


from tqdm import tqdm
from tqdm.notebook import trange, tqdm
import warnings

In [None]:
# Colour limits (vmin/vmax) per track, used when rendering the stackup heatmaps.
# NOTE: this rebinds the `bws_vlim` name imported from `data_catalog` above.
bws_vlim = {
    # ATAC-seq ...
    "mM.atac": {"vmin": 0.1, "vmax": 1.},
    "mG.atac": {"vmin": 0.15, "vmax": 1.},
    "pG.atac": {"vmin": 0.1, "vmax": 1.},
    # histone marks / factors ...
    "H3K36me3": {"vmin": 0.03, "vmax": 0.08},
    "H3K4me1": {"vmin": 0.04, "vmax": 0.1},
    "MED1": {"vmin": 0.04, "vmax": 0.15},
    "H3K27me3": {"vmin": 0.15, "vmax": 0.85},
    "H3K4me3": {"vmin": 0.5, "vmax": 14},
    "H3K27ac": {"vmin": 0.9, "vmax": 18},
    "ctcf": {"vmin": 1.8, "vmax": 6},
    # some rna-seq - Async rangap control and depletion ...
    "rg0r1fwd": {"vmin": 0.5, "vmax": 35},
    "rg0r1rev": {"vmin": 0.5, "vmax": 35},
    "rg8r1fwd": {"vmin": 0.5, "vmax": 35},
    "rg8r1rev": {"vmin": 0.5, "vmax": 35},
    # more marks, eigenvector tracks, dots and the combined rna track ...
    "H3K9me3": {"vmin": 0.15, "vmax": 1},
    "evm": {"vmin": -1.3, "vmax": 1.3},
    "evp": {"vmin": -1.3, "vmax": 1.3},
    "dots": {"vmin": 0, "vmax": 0.5},
    "rna": {"vmin": 0.5, "vmax": 40},
}


from palettable.scientific import sequential


def _sci_cmap(palette_name):
    """Return the matplotlib colormap of the named `palettable.scientific.sequential` palette."""
    return getattr(sequential, palette_name).get_mpl_colormap()


# reversed ("_r") and regular variants of each palette ...
lajolla_r = _sci_cmap("LaJolla_10_r")
bilbao_r = _sci_cmap("Bilbao_20_r")
batlow_r = _sci_cmap("Batlow_20_r")
devon_r = _sci_cmap("Devon_20_r")
davos_r = _sci_cmap("Davos_20_r")

lajolla = _sci_cmap("LaJolla_10")
bilbao = _sci_cmap("Bilbao_20")
batlow = _sci_cmap("Batlow_20")
devon = _sci_cmap("Devon_20")
davos = _sci_cmap("Davos_20")

imola_r, imola = _sci_cmap("Imola_20_r"), _sci_cmap("Imola_20")
nuuk_r, nuuk = _sci_cmap("Nuuk_20_r"), _sci_cmap("Nuuk_20")
hawaii_r, hawaii = _sci_cmap("Hawaii_20_r"), _sci_cmap("Hawaii_20")
oleron_r, oleron = _sci_cmap("Oleron_20_r"), _sci_cmap("Oleron_20")
oslo_r, oslo = _sci_cmap("Oslo_20_r"), _sci_cmap("Oslo_20")
bamako_r, bamako = _sci_cmap("Bamako_20_r"), _sci_cmap("Bamako_20")
grayc_r, grayc = _sci_cmap("GrayC_20_r"), _sci_cmap("GrayC_20")
lapaz_r, lapaz = _sci_cmap("LaPaz_20_r"), _sci_cmap("LaPaz_20")

### arms here just in case ...

In [None]:
# genomic view (chromosome arms) that will be used to call dots and pre-compute expected

# chromosome sizes and centromeres are fetched with bioframe, arms are derived from them ...
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)

# keep the first 44 arms only - intended as "autosomal only"
# (assumes the arms come ordered chr1..chr22 first - TODO confirm)
included_arms = list(hg38_arms_full["name"])[:44]
_is_included = hg38_arms_full["name"].isin(included_arms)
hg38_arms = hg38_arms_full.loc[_is_included].reset_index(drop=True)

# Read pre-called Intrinsic Domains (IDs) ...

In [None]:
# BEDPE files with pre-called ID interactions, keyed by the name of the call set ...
ids_fnames = {
    "mega_2X_enrichment": "native_comps_10kb/fourth_mega.bedpe",
    "5hr_2X_enrichment_old": "native_comps_10kb/second_bulk.bedpe",
    "5hr_2X_enrichment": "native_comps_10kb/5hr_2X_second.bedpe",
    "mega_3X_enrichment": "native_comps_10kb/fifth_mega3x.bedpe",
    "cyto_2x_enrichment": "native_comps_10kb/third_mCyto.bedpe",
}

# read every (tab-separated) file into a DataFrame and report how many rows it has ...
ids_dict = dict()
for id_name, fname in ids_fnames.items():
    _ids_table = pd.read_csv(fname, sep="\t")
    ids_dict[id_name] = _ids_table
    print(f"loaded {len(ids_dict[id_name]):5d} ID interactions {id_name:>20} in BEDPE format ...")

# Read pre-called Intrinsic Domain anchors - ID anchors

In [None]:
# BED files with pre-called ID anchors, keyed by the name of the call set ...
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

# read every (tab-separated) file into a DataFrame and report how many rows it has ...
id_anchors_dict = dict()
for id_name, fname in id_anchor_fnames.items():
    _anchors_table = pd.read_csv(fname, sep="\t")
    id_anchors_dict[id_name] = _anchors_table
    print(f"loaded {len(id_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")

### To stackups now ...

In [None]:
# Stackup geometry: the flanks are +- 100 Kb (a 200 Kb window in total), split
# into 100 bins, i.e. (2*_flank) // _nbins = 2 Kb per bin :
_flank = 100_000 # Length of flank to one side from the boundary, in basepairs
_nbins = 100   # Number of bins to split the region

In [None]:
# ! ls ev_bigwig

In [None]:
# register extra bigwig tracks in the `bws` catalog (imported from `data_catalog`) ...
bws.update({
    "evm": "ev_bigwig/m5hR1R2.10kb.bw",
    "evp": "ev_bigwig/p5hR1R2.10kb.bw",
    "idcov": "pix_clust_cov.bw",
    "nevm": "ev_bigwig/N93m5.10kb.bw",
    "nevp": "ev_bigwig/N93p5.10kb.bw",
})

# # plug insulation there just for fun ...
# bws["evm"] = "ranGAP1-0-G1s-R1R2.hg38.mapq_30.1000.b10000.insul.w50000.bw"
# bws["evp"] = "ranGAP1-7-G1s-R1R2.hg38.mapq_30.1000.b10000.insul.w50000.bw"


In [None]:
# Name of the ID-anchor set (a key of `id_anchors_dict`) that every stackup below is built on.
# _chosen_stackup_name = "mega_2X_enrichment"
_chosen_stackup_name = "5hr_2X_enrichment_signal"

In [None]:
# List every bigwig track registered in the `bws` catalog together with its path.
# (The selection list `bw_kyes_to_use` is only defined in a later cell, so looping over
# it here raises NameError on a fresh kernel - iterate over the catalog itself instead.)
for k, _bw_path in bws.items():
    print(k, _bw_path)

In [None]:
# ! ls /abyss/sergpolly/data_ranger/bigwigs/rnaseq/

In [None]:

# Keys of the `bws` track catalog for which a stackup is computed below.
# Commented-out entries are skipped; both "rg0r1fwd" and "rg0r1rev" must stay
# enabled, because they are merged into the single "rna" stack after the run.
bw_kyes_to_use = [
    # 'mM.atac',
    'mG.atac',
    # 'pG.atac',
    'H3K4me3',
    'H3K27ac',
    'H3K27me3',
    # 'ctcf',
    "evm",
    "evp",
    "nevm",
    "nevp",
    # "ids",
    # "dots",
    "rg0r1fwd",
    "rg0r1rev",
]


def _job(packed_data, bw_sample):
    # unpack shared data
    flank, nbins, track_dict, bed_df = packed_data
    bw_track = track_dict[bw_sample]
    stack_kwargs = dict(kind="mid", flank=flank, nbins=nbins)
    from helper_func import get_stack
    return (
        bw_sample,
        get_stack(bw_track, bed_df, **stack_kwargs),
    )

# Data handed to every worker: stackup geometry, the track catalog and the
# table of anchors of the chosen call set.
_shared = (
    _flank,
    _nbins,
    bws,
    id_anchors_dict[_chosen_stackup_name],
)

# One worker per requested track. Sizing the pool by the whole catalog (`len(bws)`)
# starts more workers than there are tasks - the extra ones never get any work.
with WorkerPool(
    n_jobs=len(bw_kyes_to_use),
    shared_objects=_shared,
    start_method="forkserver",  # little faster than spawn, fork is the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job, bw_kyes_to_use, progress_bar=True)

# sort out the results: each task returns a (track key, stack) pair ...
stackups_anchor = dict(results)

# Combine the two RNA-seq strands into a single "rna" stack:
# fwd + rev flipped along the second axis (fwd + rev[:, ::-1]).
# The per-strand stacks are removed from `stackups_anchor` in the process ...
_rna_fwd = stackups_anchor.pop("rg0r1fwd")
_rna_rev = stackups_anchor.pop("rg0r1rev")
stackups_anchor["rna"] = _rna_fwd + np.fliplr(_rna_rev)

In [None]:
# for k in bw_kyes_to_use:
#     print(k, bws[k])

In [None]:
# bws_vlim
# num_stackup_groups, stackup_samples, stackup_group_heights, stack_width = _get_hms_nested_shape(hms_dict)
# stackup_group_heights
# _get_hms_nested_shape?
#  _get_hms_nested_shape({"hui_": _stacks_dict})
# keys_to_display = [ 'mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "rna", "evm", "evp"]
# bws_vlim["mG.atac"] = dict(vmin=0.1,vmax=1)
# bws_vlim["H3K4me3"] = dict(vmin=0.1,vmax=11)
# bws_vlim["H3K27ac"] = dict(vmin=0.1,vmax=10)
# bws_vlim["H3K27me3"] = dict(vmin=0.1,vmax=0.9)
# bws_vlim["rna"] = dict(vmin=0.1,vmax=25)
# bws_vlim["evm"] = dict(vmin=-1.3,vmax=1.3)
# bws_vlim["evp"] = dict(vmin=-1.3,vmax=1.3)
# tracks shown in the main stackup figure, in display order ...
keys_to_display = ['mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "rna", "evm", "evp"]

# colour limits in effect for the figures below (one entry per track) ...
bws_vlim.update({
    "mM.atac": dict(vmin=0.1, vmax=1.),
    "mG.atac": dict(vmin=0.2, vmax=0.7),
    "pG.atac": dict(vmin=0.1, vmax=1.),
    "H3K27me3": dict(vmin=0.2, vmax=0.9),
    "H3K9me3": dict(vmin=0.15, vmax=1),
    "H3K4me3": dict(vmin=1, vmax=12),
    "H3K27ac": dict(vmin=1.5, vmax=12),
    "ctcf": dict(vmin=1.8, vmax=6),
    # eigenvector tracks share one symmetric range ...
    "evm": dict(vmin=-1.3, vmax=1.3),
    "evp": dict(vmin=-1.3, vmax=1.3),
    "nevm": dict(vmin=-1.3, vmax=1.3),
    "nevp": dict(vmin=-1.3, vmax=1.3),
    "dots": dict(vmin=0, vmax=0.5),
    # some rna-seq - Async rangap control and depletion ...
    "rna": dict(vmin=2, vmax=30),
    "rg0r1fwd": dict(vmin=0.5, vmax=35),
    "rg0r1rev": dict(vmin=0.5, vmax=35),
    "rg8r1fwd": dict(vmin=0.5, vmax=35),
    "rg8r1rev": dict(vmin=0.5, vmax=35),
})

# per-track plotting options, all keyed by track name ...
scales = dict.fromkeys(keys_to_display, "linear")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# EV tracks get a diverging colormap on a linear scale ...
for _evk in ("evm", "evp"):
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"

cmaps["dots"] = davos_r

# a single valency class (0, inf]; inside it anchors are ordered by "size" ...
_df = id_anchors_dict[_chosen_stackup_name]
_valency_bins = pd.cut(_df["valency"], [0, np.inf])
_valency_groups = _df.groupby(_valency_bins, observed=True)

hms_dict = {}
for name, g in _valency_groups:
    _idx = g["size"].argsort()
    # index labels are used as row numbers of the stacks
    # (assumes the anchors table has a default 0..N-1 index - TODO confirm)
    _rows = g.index[_idx]
    hms_dict[name] = {k: stackups_anchor[k][_rows] for k in keys_to_display}


In [None]:
# draw the stackups with one extra axis, which is filled in below ...
axx = plot_stackups_sets_new(
    1,
    hms_dict,  # heatmaps dict, that controls the order, the amount etc etc ...
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=True,
    len_per_thousand=0.73,
    width_per_stack=0.33,
    profile_height=0.33,
    cbar_height=0.07,
    spacing_v=.03,  # fixed distance between axes (vertically)
    spacing_h=.03,  # fixed distance between axes (horizontally)
    fig_fontsize=6,
)

# the extra axis shows anchor sizes as a filled profile, in the same order as the stacks ...
(ax,) = axx
for ii, (name, g) in enumerate(_valency_groups):
    _size_id = 0  # not used in this cell
    _idx = g["size"].argsort()
    _sizes_sorted = g["size"].iloc[_idx]
    ax.fill_betweenx(
        np.arange(len(g)),
        _sizes_sorted,
        0,
        linewidth=1,
        edgecolor="gray",
        facecolor="lightgrey",
        interpolate=True,
    )
    # x runs from 250 Kb down to 0, y from the first anchor (top) to the last (bottom)
    ax.set_xlim([250_000, 0])
    ax.set_ylim(len(g), 0)
    # no ticks and no tick labels on either axis
    for _clear in (ax.set_yticks, ax.set_yticklabels, ax.set_xticks, ax.set_xticklabels):
        _clear([])

# saves the current figure ...
plt.savefig("Fig2C_stackup.svg", format="svg", dpi=300)
# # fig.savefig("Fig2C_stackup.svg", dpi=300)
# fig.savefig("Fig2C_stackup.svg", format="svg", dpi=300)

# Extended figure EV stackup

In [None]:
# extended figure: only the "nevm"/"nevp" tracks ...
keys_to_display = ["nevm", "nevp"]

# a single valency class (0, inf]; inside it anchors are ordered by "size" ...
_df = id_anchors_dict[_chosen_stackup_name]
_valency_bins = pd.cut(_df["valency"], [0, np.inf])
_valency_groups = _df.groupby(_valency_bins, observed=True)

hms_dict = {}
for name, g in _valency_groups:
    _idx = g["size"].argsort()
    _rows = g.index[_idx]
    hms_dict[name] = {k: stackups_anchor[k][_rows] for k in keys_to_display}

# per-track plotting options, all keyed by track name ...
scales = dict.fromkeys(keys_to_display, "linear")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# EV tracks get a diverging colormap on a linear scale ...
for _evk in ("nevm", "nevp"):
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"



In [None]:
# draw the stackups with one extra axis, which is filled in below ...
axx = plot_stackups_sets_new(
    1,
    hms_dict,  # heatmaps dict, that controls the order, the amount etc etc ...
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=True,
    len_per_thousand=0.73,
    width_per_stack=0.33,
    profile_height=0.33,
    cbar_height=0.07,
    spacing_v=.03,  # fixed distance between axes (vertically)
    spacing_h=.03,  # fixed distance between axes (horizontally)
    fig_fontsize=6,
)

# the extra axis shows anchor sizes as a filled profile, in the same order as the stacks ...
(ax,) = axx
for ii, (name, g) in enumerate(_valency_groups):
    _size_id = 0  # not used in this cell
    _idx = g["size"].argsort()
    _sizes_sorted = g["size"].iloc[_idx]
    ax.fill_betweenx(
        np.arange(len(g)),
        _sizes_sorted,
        0,
        linewidth=1,
        edgecolor="gray",
        facecolor="lightgrey",
        interpolate=True,
    )
    # x runs from 250 Kb down to 0, y from the first anchor (top) to the last (bottom)
    ax.set_xlim([250_000, 0])
    ax.set_ylim(len(g), 0)
    # no ticks and no tick labels on either axis
    for _clear in (ax.set_yticks, ax.set_yticklabels, ax.set_xticks, ax.set_xticklabels):
        _clear([])

# saves the current figure ...
plt.savefig("FigE2B_stackup.svg", format="svg", dpi=300)

# legacy plotting below ...

In [None]:
# Legacy overview: all anchors of the chosen set in a single stack, ordered by "size".
_idx_order = id_anchors_dict[_chosen_stackup_name]["size"].argsort()

# Tracks this legacy figure was laid out for. Some of them ('mM.atac', 'pG.atac',
# 'ctcf') are currently commented out of `bw_kyes_to_use`, so no stackup is computed
# for them and `stackups_anchor[k]` raises KeyError. Plot only what is available.
_legacy_tracks = ['mM.atac', 'mG.atac', 'pG.atac', 'H3K27me3', 'H3K4me3', 'H3K27ac', 'ctcf', "evm", "evp"]
kkk = [k for k in _legacy_tracks if k in stackups_anchor]
_skipped_tracks = [k for k in _legacy_tracks if k not in stackups_anchor]
if _skipped_tracks:
    print(f"legacy stackup: no stackup computed for {_skipped_tracks} - skipping ...")

# per-track heatmaps and plotting options, all keyed by track name ...
hms_dict = {k: stackups_anchor[k][_idx_order] for k in kkk}
scales = {k: "log" for k in kkk}
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in kkk}
titles = {k: k for k in kkk}
cmaps = {k: davos_r for k in kkk}
binsizes = {k: ((2*_flank) // _nbins) for k in kkk}

# modify EV-related things (only for EV tracks that are actually plotted) ...
for _evk in ["evm","evp"]:
    if _evk in scales:
        scales[_evk] = "linear"
        cmaps[_evk] = "RdBu_r"

plot_stackups_lite(
    None,
    hms_dict,  # heatmaps dict, that controls the order, the amount etc etc ...
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=4.9,
    width_per_stack=3.2,
    extra_height=3.,  # height that goes toward the profile and colorbar
    interpolation="gaussian",
);

### exploring - testing valency related stuff ...

In [None]:
# tracks to show, for two valency classes: (0, 1] and (1, inf] ...
keys_to_display = ['H3K27me3', 'mG.atac', 'H3K4me3', 'H3K27ac', "evm", "evp"]

_df = id_anchors_dict[_chosen_stackup_name]
_valency_bins = pd.cut(_df["valency"], [0, 1, np.inf])
_valency_groups = _df.groupby(_valency_bins, observed=True)

# one set of heatmaps per valency class, anchors ordered by "size" inside each class ...
hms_dict = {}
for name, g in _valency_groups:
    _idx = g["size"].argsort()
    _rows = g.index[_idx]
    hms_dict[name] = {k: stackups_anchor[k][_rows] for k in keys_to_display}

# per-track plotting options, all keyed by track name ...
scales = dict.fromkeys(keys_to_display, "log")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# EV tracks get a diverging colormap on a linear scale ...
for _evk in ("evm", "evp"):
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"


_xfsize = 12
axx = plot_stackups_sets(
    2,
    hms_dict,
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=4.9,
    width_per_stack=3.,
    extra_height=3.,  # height that goes toward the profile and colorbar
    delta_h=.3,  # fixed distance between axes (vertically)
    delta_w=.5,  # fixed distance between axes (horizontal
    fig_fontsize=_xfsize,
    interpolation="gaussian",
)

# fill the two extra axes of every class: axx[0] - valency, axx[1] - size ...
for ii, (name, g) in enumerate(_valency_groups):
    _idx = g["size"].argsort()
    _rank = np.arange(len(g))
    _ax_valency = axx[0][ii]
    _ax_size = axx[1][ii]

    _ax_valency.scatter(x=g["valency"].iloc[_idx], y=_rank, marker="s", alpha=0.8)
    _ax_valency.set_ylim(len(g), 0)
    _ax_valency.set_xlim(0,150)

    _ax_size.scatter(x=g["size"].iloc[_idx], y=_rank, marker="s", alpha=0.8)
    _ax_size.set_xlim([0,250_000])
    _ax_size.set_ylim(len(g), 0)

    # no y ticks on either extra axis
    for _ax in (_ax_valency, _ax_size):
        _ax.set_yticks([])
        _ax.set_yticklabels([])

    # x ticks are kept (with adjusted label alignment) only for the group with ii == 2
    # NOTE(review): the bins above yield at most two groups, so that branch is never taken
    for _ax in (_ax_valency, _ax_size):
        if ii == 2:
            for _tidx, tick in enumerate(_ax.xaxis.get_majorticklabels()):
                tick.set_horizontalalignment("left" if _tidx == 0 else "right")
            _ax.tick_params(axis="x", length=6)
        else:
            _ax.set_xticks([])
            _ax.set_xticklabels([])


## Old Fig 2B ...

In [None]:
# tracks to show, for a single valency class (0, inf] ...
keys_to_display = ['mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "evm", "evp"]

_df = id_anchors_dict[_chosen_stackup_name]
_valency_bins = pd.cut(_df["valency"], [0, np.inf])
_valency_groups = _df.groupby(_valency_bins, observed=True)

# heatmaps per valency class, anchors ordered by "size" ...
hms_dict = {}
for name, g in _valency_groups:
    _idx = g["size"].argsort()
    _rows = g.index[_idx]
    hms_dict[name] = {k: stackups_anchor[k][_rows] for k in keys_to_display}

# per-track plotting options, all keyed by track name ...
scales = dict.fromkeys(keys_to_display, "log")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# EV tracks get a diverging colormap on a linear scale ...
for _evk in ("evm", "evp"):
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"


_xfsize = 12
axx = plot_stackups_sets(
    2,
    hms_dict,
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=4.9,
    width_per_stack=3.,
    extra_height=3.,  # height that goes toward the profile and colorbar
    delta_h=.3,  # fixed distance between axes (vertically)
    delta_w=.5,  # fixed distance between axes (horizontal
    fig_fontsize=_xfsize,
    interpolation="gaussian",
)

# fill the two extra axes: row 1 - valency (scatter), row 0 - size (filled profile) ...
for ii, (name, g) in enumerate(_valency_groups):
    _valency_id = 1
    _size_id = 0

    _idx = g["size"].argsort()
    _rank = np.arange(len(g))
    _ax_valency = axx[_valency_id][ii]
    _ax_size = axx[_size_id][ii]

    _ax_valency.scatter(x=g["valency"].iloc[_idx], y=_rank, marker="s", alpha=0.8)
    _ax_valency.set_ylim(len(g), 0)
    _ax_valency.set_xlim(0,150)

    _ax_size.fill_betweenx(
        _rank,
        g["size"].iloc[_idx],
        0,
        linewidth=5,
        edgecolor="gray",
        facecolor="lightgrey",
        interpolate=True,
    )
    _ax_size.set_xlim([0,250_000])
    _ax_size.set_ylim(len(g), 0)

    # no y ticks on either extra axis
    for _ax in (_ax_valency, _ax_size):
        _ax.set_yticks([])
        _ax.set_yticklabels([])

    # x ticks are kept (with adjusted label alignment) only for the group with ii == 2
    # NOTE(review): a single valency bin yields at most one group, so that branch is never taken
    for _ax in (_ax_valency, _ax_size):
        if ii == 2:
            for _tidx, tick in enumerate(_ax.xaxis.get_majorticklabels()):
                tick.set_horizontalalignment("left" if _tidx == 0 else "right")
            _ax.tick_params(axis="x", length=6)
        else:
            _ax.set_xticks([])
            _ax.set_xticklabels([])

# newer Fig2B with genes ...

In [None]:
# tracks to show (now including the combined "rna" stack), single valency class (0, inf] ...
keys_to_display = ['mG.atac', 'H3K4me3', 'H3K27ac', 'H3K27me3', "rna", "evm", "evp"]

_df = id_anchors_dict[_chosen_stackup_name]
_valency_bins = pd.cut(_df["valency"], [0, np.inf])
_valency_groups = _df.groupby(_valency_bins, observed=True)

# heatmaps per valency class, anchors ordered by "size" ...
hms_dict = {}
for name, g in _valency_groups:
    _idx = g["size"].argsort()
    _rows = g.index[_idx]
    hms_dict[name] = {k: stackups_anchor[k][_rows] for k in keys_to_display}

# per-track plotting options, all keyed by track name ...
scales = dict.fromkeys(keys_to_display, "log")
vlims = {k: (bws_vlim[k]["vmin"], bws_vlim[k]["vmax"]) for k in keys_to_display}
titles = {k: k for k in keys_to_display}
cmaps = dict.fromkeys(keys_to_display, davos_r)
binsizes = dict.fromkeys(keys_to_display, (2*_flank) // _nbins)

# EV tracks get a diverging colormap on a linear scale ...
for _evk in ("evm", "evp"):
    scales[_evk] = "linear"
    cmaps[_evk] = "RdBu_r"


_xfsize = 12
axx = plot_stackups_sets(
    1,
    hms_dict,
    scales,
    vlims,
    titles,
    cmaps,
    binsizes,
    fillmissing=False,
    len_per_thousand=4.9,
    width_per_stack=3.,
    extra_height=3.,  # height that goes toward the profile and colorbar
    delta_h=.3,  # fixed distance between axes (vertically)
    delta_w=.5,  # fixed distance between axes (horizontal
    fig_fontsize=_xfsize,
    interpolation="gaussian",
)

# fill the single extra axis row with the anchor sizes (filled profile) ...
for ii, (name, g) in enumerate(_valency_groups):
    _size_id = 0

    _idx = g["size"].argsort()
    _ax_size = axx[_size_id][ii]

    _ax_size.fill_betweenx(
        np.arange(len(g)),
        g["size"].iloc[_idx],
        0,
        linewidth=5,
        edgecolor="gray",
        facecolor="lightgrey",
        interpolate=True,
    )
    _ax_size.set_xlim([0,250_000])
    _ax_size.set_ylim(len(g), 0)

    # no y ticks
    _ax_size.set_yticks([])
    _ax_size.set_yticklabels([])

    # x ticks are kept (with adjusted label alignment) only for the group with ii == 2
    # NOTE(review): a single valency bin yields at most one group, so that branch is never taken
    if ii == 2:
        for _tidx, tick in enumerate(_ax_size.xaxis.get_majorticklabels()):
            tick.set_horizontalalignment("left" if _tidx == 0 else "right")
        _ax_size.tick_params(axis="x", length=6)
    else:
        _ax_size.set_xticks([])
        _ax_size.set_xticklabels([])

# NOTE(review): the file name says "Nup93", but the tracks plotted here are evm/evp,
# not the nevm/nevp ones - confirm that the name is intended
plt.savefig("FigE2B_stackup_Nup93.svg",dpi=300)