# Saddleplots - plotting only !

In [None]:
# import standard python libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import os, subprocess

In [None]:
# Import python package for working with cooler files and tools for analysis
import cooler
import cooltools.lib.plotting

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
# additional analysis/plotting libraries and local helper modules
# (nothing is downloaded in this cell)
import bbi
import cooltools
import bioframe
from matplotlib.colors import LogNorm
from helper_func import saddleplot
from data_catalog import bws, bws_vlim, telo_dict

import saddle


In [None]:
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
import warnings
import seaborn as sns

import warnings
import h5py

import matplotlib.lines as lines
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch, Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.inset_locator import BboxConnector
from matplotlib import cm
import matplotlib as mpl
# from mpl_toolkits.axes_grid1.Size import Fixed


# keep text as real text in vector output (TrueType fonts in PDF, <text>
# elements in SVG) so it stays editable, and draw thin axes frames
mpl.rcParams.update(
    {
        "pdf.fonttype": 42,
        "svg.fonttype": "none",
        "axes.linewidth": 0.5,
    }
)

## Calculating per-chromosome compartmentalization

We first load the Hi-C data at 100 kbp resolution. 

Note that the current implementation of eigendecomposition in cooltools assumes that individual regions can be held in memory — for hg38 at 100 kb this is either a 2422x2422 matrix for chr2, or a 3255x3255 matrix for the full cooler here.

In [None]:
# define the genomic view (autosomal chromosome arms of hg38) that is used
# below to align EV tracks with the coolers

# Use bioframe to fetch chromosome sizes and centromeres for hg38 from UCSC,
# and split every chromosome into arms at its centromere.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# keep autosomal arms only (chr1..chr22): select them by chromosome name
# instead of by position (`[:44]`), so the view does not depend on the row
# order returned by bioframe ...
_is_autosomal_arm = hg38_arms_full["chrom"].str.fullmatch(r"chr\d+")
included_arms = hg38_arms_full.loc[_is_autosomal_arm, "name"].to_list()
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

In [None]:
# ! ls *.hdf5

In [None]:
# IPython shell escape: list the HDF5 files in the working directory
# (the saddle HDF5 files are read in the next cell)
! ls *.hdf5

In [None]:
def print_attrs(name, obj):
    """
    Print an object's name, indented by its depth, followed by its attributes.

    Meant as a callback for ``h5py.File.visititems`` (see the commented-out
    ``fr.visititems(print_attrs)`` calls below), which passes the object's
    '/'-separated path and the object itself.

    Parameters
    ----------
    name : str
        '/'-separated path of the object; each '/' adds one indentation level.
    obj : object
        The object; when it has an ``attrs`` mapping, every ``key: value`` pair
        is printed one level deeper. Objects without ``attrs`` print only
        their name.
    """
    # Create indent
    shift = name.count('/') * '    '
    item_name = name.split("/")[-1]
    print(shift + item_name)
    try:
        attrs = obj.attrs
    except AttributeError:
        # object carries no attributes - nothing more to print
        return
    try:
        for key, val in attrs.items():
            print(shift + '    ' + f"{key}: {val}")
    except OSError as err:
        # best-effort dump: report unreadable attributes instead of hiding
        # them (a bare `except:` used to swallow every error, even Ctrl-C)
        print(shift + '    ' + f"<could not read attributes: {err}>")

def _load_saddle_hdf5(path):
    """
    Load pre-computed saddle stacks and their metadata from an HDF5 file.

    Reads the file-level attributes and every dataset of the "counts" and
    "sums" groups (one per sample). The "track" attribute of each "counts"
    dataset names the sample whose EV track was used for that saddle.

    Parameters
    ----------
    path : str
        HDF5 file to read.

    Returns
    -------
    meta : dict
        File-level attributes.
    sums, counts : dict
        sample -> array read from the "sums" / "counts" group.
    track_map : dict
        sample -> track sample name (from the "counts" datasets).

    Raises
    ------
    KeyError
        If the "counts" or "sums" group, or a "track" attribute, is missing.
    """
    sums, counts, track_map = {}, {}, {}
    with h5py.File(path, 'r') as fr:
        # fr.visititems(print_attrs)
        meta = dict(fr.attrs)
        # index groups directly so a missing group fails with its name
        # (instead of an AttributeError on `None` from `fr.get(...)`) ...
        for _sample, _cds in fr["counts"].items():
            # extracting sample to track_sample mapping from HDF5 itself ...
            track_map[_sample] = _cds.attrs["track"]
            counts[_sample] = _cds[()]
        for _sample, _sds in fr["sums"].items():
            sums[_sample] = _sds[()]
    return meta, sums, counts, track_map


# load EV saddles asis ...
print("loading EV saddles, along with some metadata ...")
print()

# sample -> track sample map is read from the HDF5 itself ...
_saddle_meta, interaction_sums, interaction_counts, track_sample_map = _load_saddle_hdf5(
    "saddles_EV_by_distance.hdf5"
)
print(f"loaded {_saddle_meta=}")
print(f"loaded sample to track sample map {track_sample_map=}")

# unpack metadata ...
Q_LO = _saddle_meta["Q_LO"]
Q_HI = _saddle_meta["Q_HI"]
N_GROUPS = _saddle_meta["N_GROUPS"]
binsize = _saddle_meta["cis_binsize"]


#######################################
#   load control saddles here ...
########################################
_saddle_ctrl_meta, interaction_ctrl_sums, interaction_ctrl_counts, track_ctrl_sample_map = _load_saddle_hdf5(
    "saddles_ctrlEV_by_distance.hdf5"
)
print(f"loaded {_saddle_ctrl_meta=}")
print(f"loaded sample to track sample map {track_ctrl_sample_map=}")

# Make sure control saddles were generated using the same parameters -
# comparing them against the EV saddles is meaningless otherwise ...
_ctrl_param_checks = {
    "Q_LO": Q_LO,
    "Q_HI": Q_HI,
    "N_GROUPS": N_GROUPS,
    "cis_binsize": binsize,
}
_mismatched_params = []
for _key, _value in _ctrl_param_checks.items():
    _same = bool(_value == _saddle_ctrl_meta[_key])
    print(f"{_key}: {_same}")
    if not _same:
        _mismatched_params.append(_key)
if _mismatched_params:
    raise ValueError(
        f"control saddles were generated with different parameters: {_mismatched_params}"
    )




## Pre-load pre-computed EV1 tracks and coolers at the saddles' binsize ...

In [None]:
# pre-computed cis EV1 tracks (bedGraph), one per sample key of `telo_dict`,
# at the binsize the saddles were computed with ...
telo_cis_evs = {}
for k in telo_dict:
    # input file name is derived from the sample key and the binsize;
    # the `telo_dict` values are cooler paths (see below), not EV tracks
    _fname = f"ev_bedraph/{k}.{binsize//1_000}kb.bed"
    print(f"reading {_fname} ...")
    telo_cis_evs[k] = bioframe.read_table(_fname, schema="bedGraph", header=0)

# cooler files that we'll work on - same resolution as the saddles :
telo_clrs = { _k: cooler.Cooler(f"{_path}::/resolutions/{binsize}") for _k, _path in telo_dict.items() }

In [None]:
# let's check how digitization goes here ...
from cooltools.api.saddle import align_track_with_cooler, digitize
from cooltools.api.saddle import saddle_stack

print(f"we use {Q_LO=} {Q_HI=} and split the rest in {N_GROUPS=} ...")

n_bins = int(N_GROUPS)
qrange=(Q_LO,Q_HI)

# samples to inspect, with the bar color and legend label for each ...
_ev_samples = {
    "m5hR1R2": ("blue", "ctrl@5h"),
    "p5hR1R2": ("red", "depletion@5h"),
    "N93m5": ("green", "ctrlN@5h"),
    "N93p5": ("black", "depletionN@5h"),
}

# mean EV1 value within each digitized EV1 group, per sample ...
_mean_ev_bins = {}
for _sample in _ev_samples:
    _track_sample = track_sample_map[_sample]
    # track and contact map ...
    _ev_track = telo_cis_evs[_track_sample]
    _clr = telo_clrs[_sample]
    # align the EV track with the cooler bins of the view, masking the
    # cooler's bad bins (see cooltools `align_track_with_cooler`) ...
    track = align_track_with_cooler(
        _ev_track,
        _clr,
        view_df=hg38_arms,
        clr_weight_name="weight",
        mask_clr_bad_bins=True,
        drop_track_na=False,  # this adds check for chromosomes that have all missing values
    )
    # digitize continuous track into N_GROUPS ...
    digitized_track, binedges = digitize(
        track.iloc[:, :4],
        n_bins,
        qrange=qrange,
        digitized_suffix=".d",
    )
    # grouping "track" with "digitized_track" - since their index matches
    _mean_ev_bins[_sample] = track.groupby(digitized_track["value.d"])["value"].mean()


for _sample, (_color, _label) in _ev_samples.items():
    _mean_ev_bins[_sample].plot(
        kind="bar",
        width=1,
        color=_color,
        label=_label,
        alpha=0.5,
    )
ax = plt.gca()
ax.axhline(0, color="grey")
plt.legend(frameon=False)

In [None]:




def saddle_strength(k, sums_stack, counts_stack, dist_range=None):
    """
    Saddle strength (aka compartment contrast) from stacked saddle sums/counts.

    Both stacks are summed over their first axis (ignoring NaNs), optionally
    restricted to `dist_range`. The outermost row/column of the resulting
    square saddle are dropped as outlier groups. Strengths are the pooled
    interaction within the k x k corners (BB: first k groups, AA: last k
    groups). Each is divided by the pooled interaction between those corners
    (AB + BA).

    Parameters
    ----------
    k : int
        Corner size in saddle groups, counted after the outer groups are
        dropped; must satisfy ``1 <= k <= n`` (n = trimmed saddle size).
    sums_stack, counts_stack : array-like
        Stacks of saddle sums and counts, reduced with
        ``np.nansum(..., axis=0)``.
    dist_range : index, optional
        Index (e.g. a slice) applied to the first axis of both stacks before
        summation; ``None`` uses the whole stacks.

    Returns
    -------
    dict
        ``{"AA": ..., "BB": ..., "AA_BB": ...}`` - AA, BB and pooled AA+BB
        interaction, each relative to the pooled AB/BA interaction.

    Raises
    ------
    ValueError
        If the trimmed saddle is not square, or `k` is outside ``[1, n]``.
    """
    if dist_range is not None:
        S = np.nansum(sums_stack[dist_range], axis=0)
        C = np.nansum(counts_stack[dist_range], axis=0)
    else:
        S = np.nansum(sums_stack, axis=0)
        C = np.nansum(counts_stack, axis=0)

    # exclude extremes - the outliers
    S = S[1:-1,1:-1]
    C = C[1:-1,1:-1]

    m, n = S.shape
    if m != n:
        raise ValueError("`saddledata` should be square.")
    # for k > n, `slice(n-k, n)` would get a negative start and silently pick
    # a smaller AA corner than the BB one; k == 0 would only produce NaNs
    if not 1 <= k <= n:
        raise ValueError(f"corner size k={k} must be between 1 and {n}")

    # _b corner indices (first k groups) ...
    _b = slice(0, k)
    # _a corner indices (last k groups) ...
    _a = slice(n-k, n)

    intra_BB = np.nansum(S[_b, _b]) / np.nansum(C[_b, _b])
    intra_AA = np.nansum(S[_a, _a]) / np.nansum(C[_a, _a])
    intra_AA_BB = (
        (np.nansum(S[_b, _b]) + np.nansum(S[_a, _a])) /
        (np.nansum(C[_b, _b]) + np.nansum(C[_a, _a]))
    )
    inter_AB_BA = (
        (np.nansum(S[_b, _a]) + np.nansum(S[_a, _b])) /
        (np.nansum(C[_b, _a]) + np.nansum(C[_a, _b]))
    )
    return {"AA" : intra_AA/inter_AB_BA, "BB" : intra_BB/inter_AB_BA, "AA_BB": intra_AA_BB/inter_AB_BA}

# # https://stackoverflow.com/questions/48625475/python-shifted-logarithmic-colorbar-white-color-offset-to-center
# class MidPointLogNorm(LogNorm):
#     def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
#         LogNorm.__init__(self,vmin=vmin, vmax=vmax, clip=clip)
#         self.midpoint=midpoint
#     def __call__(self, value, clip=None):
#         # I'm ignoring masked values and all kinds of edge cases to make a
#         # simple example...
#         x, y = [np.log(self.vmin), np.log(self.midpoint), np.log(self.vmax)], [0, 0.5, 1]
#         return np.ma.masked_array(np.interp(np.log(value), x, y))



# https://stackoverflow.com/questions/48625475/python-shifted-logarithmic-colorbar-white-color-offset-to-center
class MidPointLogNorm(LogNorm):
    """
    Log-scale normalization with a movable midpoint.

    In log space, [vmin, midpoint] maps linearly onto [0, 0.5] and
    [midpoint, vmax] onto [0.5, 1]. `midpoint` therefore lands on the center
    of a diverging colormap even when it is not the geometric mean of vmin
    and vmax.

    Values outside [vmin, vmax] are extrapolated with the slope of the nearest
    segment. They map below 0 / above 1, so the colormap's "under"/"over"
    colors apply, unless clipping is requested (``clip=True`` here or in the
    call). With clipping, results are limited to [0, 1]. Non-positive and NaN
    values are masked.

    Parameters
    ----------
    vmin, vmax : float
        Data values mapped to 0 and 1; both must be > 0.
    midpoint : float
        Data value mapped to 0.5; must satisfy ``vmin < midpoint < vmax``.
    clip : bool, default False
        Default clipping behavior of ``__call__``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        LogNorm.__init__(self,vmin=vmin, vmax=vmax, clip=clip)
        self.midpoint=midpoint

    def _log_knots(self):
        """Return ``log(vmin), log(midpoint), log(vmax)``."""
        return np.log(self.vmin), np.log(self.midpoint), np.log(self.vmax)

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip
        lo, mid, hi = self._log_knots()
        # np.ma.log masks non-positive (and non-finite) results
        log_value = np.ma.log(np.ma.asarray(value, dtype=float))
        # piecewise-linear in log space; unlike np.interp this does not clamp,
        # so out-of-range data can reach the colormap's over/under colors
        result = np.ma.where(
            log_value <= mid,
            0.5 * (log_value - lo) / (mid - lo),
            0.5 + 0.5 * (log_value - mid) / (hi - mid),
        )
        if clip:
            result = np.ma.clip(result, 0.0, 1.0)
        return result

    def inverse(self, value):
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        lo, mid, hi = self._log_knots()
        normed = np.asarray(value, dtype=float)
        # exact inverse of the piecewise-linear mapping in __call__
        log_value = np.where(
            normed <= 0.5,
            lo + 2.0 * normed * (mid - lo),
            mid + 2.0 * (normed - 0.5) * (hi - mid),
        )
        return np.exp(log_value)






def get_saddle_data(sample, dist_name, dist_range=None, sums=None, counts=None):
    """
    Turn the stacked saddle sums and counts of one sample into an average saddle.

    The sums and counts stacks of `sample` are summed over their first axis
    (ignoring NaNs), optionally restricted to `dist_range`, and then divided
    elementwise.

    Parameters
    ----------
    sample : str
        Sample key into `sums` / `counts`.
    dist_name : str
        Name of the distance range (e.g. a key of `distances`). "trans" is
        rejected: only cis saddle stacks are loaded in this notebook.
    dist_range : index, optional
        Index (e.g. a slice) applied to the first axis of both stacks;
        ``None`` uses the whole stacks.
    sums, counts : dict, optional
        sample -> stack mappings; default to the notebook-level
        `interaction_sums` / `interaction_counts`.

    Returns
    -------
    numpy.ndarray
        Elementwise ratio of the summed sums to the summed counts.

    Raises
    ------
    ValueError
        If `dist_name` is "trans".
    """
    if dist_name == "trans":
        # trans stacks are not loaded (their loading is commented out), so
        # refuse instead of silently returning cis data under a trans label
        raise ValueError("trans saddles are not loaded - only cis saddle data is available")
    if sums is None:
        sums = interaction_sums
    if counts is None:
        counts = interaction_counts
    _sum_stack = sums[sample]
    _count_stack = counts[sample]
    if dist_range is not None:
        _sum_stack = _sum_stack[dist_range]
        _count_stack = _count_stack[dist_range]
    return np.nansum(_sum_stack, axis=0) / np.nansum(_count_stack, axis=0)


# shared styling for the saddle heatmaps: log scale centered at 1,
# between 1/5 and 3, on a diverging colormap
imshow_kwargs = {
    "norm": MidPointLogNorm(vmin=1/5, vmax=3, midpoint=1),
    "cmap": "RdBu_r",
    "interpolation": "nearest",
}

# layout dimensions in inches (outer margin, saddle width, colorbar width)
margin, matw, cbarh = 0.2, 0.75, 0.1

# distance ranges of the saddle stacks to plot; slice(None) keeps the whole
# stack (trans saddles are not loaded in this notebook)
distances = {"all-cis": slice(None)}


In [None]:

# horizontal tiling (inches), left to right: margin | vertical EV1-quantile
# strip | spacer | ctrl saddle | spacer | depletion saddle | spacer | colorbar | margin
h = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw), #ctrl
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),  # depletion
    Size.Fixed(0.2*margin),
    Size.Fixed(cbarh),
    Size.Fixed(margin),
]

# vertical tiling (inches), goes from bottom to the top: margin ("EV1 quantiles"
# label) | horizontal EV1-quantile strips | spacer | saddles | spacer | mean-EV1 bars | margin
v = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(2.5*cbarh),
    Size.Fixed(margin),
]
# width of the right-most columns (spacer, colorbar, margin) ...
_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")
# ...
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)


samples = ["m5hR1R2", "p5hR1R2"]
mev_colors = ["tab:blue", "tab:red"]
# corner size (in EV1 groups) used for the saddle-strength annotations ...
_strength_k = 11

axs = {}
axq_hor = {}
axm = {}
for i, _sample in enumerate(samples):
    _nx = 2*(i+1)+1
    _ny = 3
    axs[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny))
    axq_hor[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny-2))
    axm[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny+2))
# label row spanning from the first through the last saddle column ...
axlow = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, nx1=_nx+1, ny=0))
axq_ver = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=_ny))
# colorbar ...
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx+2, ny=_ny))

# traverse nested dict to access the axes ...
for _ax in (
    list(axs.values()) + list(axq_hor.values()) + list(axm.values()) + [axq_ver, cbar_ax]
):
    _ax.set_xticks([])
    _ax.set_yticks([])

for _sample, _mcolor in zip(samples, mev_colors):
    for _dist_key, _dist in distances.items():
        # only cis saddle stacks are loaded in this notebook, so fail with a
        # clear message instead of a NameError on undefined *_trans stacks
        if _dist_key == "trans":
            raise ValueError("trans saddles are not loaded in this notebook")
        C = np.nanmean(interaction_sums[_sample][_dist], axis=0) / np.nanmean(interaction_counts[_sample][_dist], axis=0)
        _strength_dict = saddle_strength(_strength_k, interaction_sums[_sample], interaction_counts[_sample], dist_range=_dist)
        # NOTE: every distance range is drawn into the same axes - with more
        # than one entry in `distances` only the last one stays visible.
        # The outermost (outlier) EV1 groups are dropped, as in saddle_strength.
        _h = axs[_sample].imshow(C[1:-1,1:-1], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        _h.cmap.set_under("black")
        _sa = _strength_dict["AA"]
        _sb = _strength_dict["BB"]
        # annotate BB strength next to the first groups, AA next to the last ...
        _b_pos = 1
        _a_pos = C.shape[0]-2-1
        axs[_sample].text(_b_pos, _b_pos, f"{_sb:.2f}", fontsize=8, ha="left", va="top")
        axs[_sample].text(_a_pos, _a_pos, f"{_sa:.2f}", fontsize=8, ha="right", va="bottom")
    # mean EV1 per EV1 group, drawn above the saddle ...
    data = _mean_ev_bins[_sample].loc[1:n_bins]
    assert len(data) == n_bins
    axm[_sample].fill_between(data.index, data, 0, color=_mcolor, ec="grey", step="mid", linewidth=0.25)
    axm[_sample].set_ylim(-1.2,1.2)
    axm[_sample].set_xlim(1-0.1,n_bins+0.1)
    axm[_sample].spines[:].set_visible(False)
    axm[_sample].axhline(0,color="grey",lw=0.5)
    axm[_sample].spines["left"].set_visible(True)
    _track_sample = track_sample_map[_sample]
    # drop a literal "R1R2" suffix (str.rstrip would strip any trailing R/1/2 characters)
    if _track_sample.endswith("R1R2"):
        _track_sample = _track_sample[:-len("R1R2")]
    axm[_sample].set_title(f"{_sample}@{_track_sample}", fontsize=6, pad=0)
    # ev quantiles ...
    axq_hor[_sample].fill_between(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
    axq_hor[_sample].set_ylim(0, n_bins+1)
    axq_hor[_sample].set_xlim(1-0.1,n_bins+0.1)
    axq_hor[_sample].spines[:].set_visible(False)
    axq_hor[_sample].invert_yaxis()
axq_ver.fill_betweenx(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
axq_ver.set_xlim(0, n_bins+1)
axq_ver.set_ylim(1-0.1,n_bins+0.1)
axq_ver.spines[:].set_visible(False)
axq_ver.invert_yaxis()
axq_ver.invert_xaxis()
axq_ver.set_ylabel(
    "EV1 quantiles",
    fontsize=6,
    labelpad=0,
)

axlow.text(
    0.5, 0.9,
    "EV1 quantiles",
    ha='center', va='top',
    transform = axlow.transAxes,
    fontsize=6,
)
axlow.axis("off")


# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="vertical",
)
_vmin = imshow_kwargs["norm"].vmin
_midpoint = imshow_kwargs["norm"].midpoint
_vmax = imshow_kwargs["norm"].vmax
cbar_ax.set_yticks(
    [_vmin, _midpoint, _vmax],
    labels=[f"{v:.1f}" for v in [_vmin, _midpoint, _vmax]],
    fontsize=6,
)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.0, pad=1)
# keep the end tick labels inside the colorbar's extent ...
for _tidx, tick in enumerate(cbar_ax.yaxis.get_majorticklabels()):
    if _tidx == 0:
        tick.set_verticalalignment("bottom")
    elif _tidx == 2:
        tick.set_verticalalignment("top")
    else:
        tick.set_verticalalignment("center")

fig.savefig("Fig6B.svg", dpi=300)

In [None]:

# horizontal tiling (inches), left to right: margin | vertical EV1-quantile
# strip | spacer | ctrl saddle | spacer | depletion saddle | spacer | colorbar | margin
h = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw), #ctrl
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),  # depletion
    Size.Fixed(0.2*margin),
    Size.Fixed(cbarh),
    Size.Fixed(margin),
]

# vertical tiling (inches), goes from bottom to the top: margin ("EV1 quantiles"
# label) | horizontal EV1-quantile strips | spacer | saddles | spacer | mean-EV1 bars | margin
v = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(2.5*cbarh),
    Size.Fixed(margin),
]
# width of the right-most columns (spacer, colorbar, margin) ...
_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")
# ...
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)


samples = ["N93m5", "N93p5"]
mev_colors = ["tab:blue", "tab:red"]
# corner size (in EV1 groups) used for the saddle-strength annotations ...
_strength_k = 11

axs = {}
axq_hor = {}
axm = {}
for i, _sample in enumerate(samples):
    _nx = 2*(i+1)+1
    _ny = 3
    axs[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny))
    axq_hor[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny-2))
    axm[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny+2))
# label row spanning from the first through the last saddle column ...
axlow = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, nx1=_nx+1, ny=0))
axq_ver = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=_ny))
# colorbar ...
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx+2, ny=_ny))

# traverse nested dict to access the axes ...
for _ax in (
    list(axs.values()) + list(axq_hor.values()) + list(axm.values()) + [axq_ver, cbar_ax]
):
    _ax.set_xticks([])
    _ax.set_yticks([])


for _sample, _mcolor in zip(samples, mev_colors):
    for _dist_key, _dist in distances.items():
        # only cis saddle stacks are loaded in this notebook, so fail with a
        # clear message instead of a NameError on undefined *_trans stacks
        if _dist_key == "trans":
            raise ValueError("trans saddles are not loaded in this notebook")
        C = np.nanmean(interaction_sums[_sample][_dist], axis=0) / np.nanmean(interaction_counts[_sample][_dist], axis=0)
        _strength_dict = saddle_strength(_strength_k, interaction_sums[_sample], interaction_counts[_sample], dist_range=_dist)
        # NOTE: every distance range is drawn into the same axes - with more
        # than one entry in `distances` only the last one stays visible.
        # The outermost (outlier) EV1 groups are dropped, as in saddle_strength.
        _h = axs[_sample].imshow(C[1:-1,1:-1], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        _h.cmap.set_under("black")
        _sa = _strength_dict["AA"]
        _sb = _strength_dict["BB"]
        # annotate BB strength next to the first groups, AA next to the last ...
        _b_pos = 1
        _a_pos = C.shape[0]-2-1
        axs[_sample].text(_b_pos, _b_pos, f"{_sb:.2f}", fontsize=8, ha="left", va="top")
        axs[_sample].text(_a_pos, _a_pos, f"{_sa:.2f}", fontsize=8, ha="right", va="bottom")
    # mean EV1 per EV1 group, drawn above the saddle ...
    data = _mean_ev_bins[_sample].loc[1:n_bins]
    assert len(data) == n_bins
    axm[_sample].fill_between(data.index, data, 0, color=_mcolor, ec="grey", step="mid", linewidth=0.25)
    axm[_sample].set_ylim(-1.2,1.2)
    axm[_sample].set_xlim(1-0.1,n_bins+0.1)
    axm[_sample].spines[:].set_visible(False)
    axm[_sample].axhline(0,color="grey",lw=0.5)
    axm[_sample].spines["left"].set_visible(True)
    _track_sample = track_sample_map[_sample]
    # drop a literal "R1R2" suffix (str.rstrip would strip any trailing R/1/2 characters)
    if _track_sample.endswith("R1R2"):
        _track_sample = _track_sample[:-len("R1R2")]
    axm[_sample].set_title(f"{_sample}@{_track_sample}", fontsize=6, pad=0)
    # ev quantiles ...
    axq_hor[_sample].fill_between(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
    axq_hor[_sample].set_ylim(0, n_bins+1)
    axq_hor[_sample].set_xlim(1-0.1,n_bins+0.1)
    axq_hor[_sample].spines[:].set_visible(False)
    axq_hor[_sample].invert_yaxis()
axq_ver.fill_betweenx(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
axq_ver.set_xlim(0, n_bins+1)
axq_ver.set_ylim(1-0.1,n_bins+0.1)
axq_ver.spines[:].set_visible(False)
axq_ver.invert_yaxis()
axq_ver.invert_xaxis()
axq_ver.set_ylabel(
    "EV1 quantiles",
    fontsize=6,
    labelpad=0,
)

axlow.text(
    0.5, 0.9,
    "EV1 quantiles",
    ha='center', va='top',
    transform = axlow.transAxes,
    fontsize=6,
)
axlow.axis("off")


# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="vertical",
)
_vmin = imshow_kwargs["norm"].vmin
_midpoint = imshow_kwargs["norm"].midpoint
_vmax = imshow_kwargs["norm"].vmax
cbar_ax.set_yticks(
    [_vmin, _midpoint, _vmax],
    labels=[f"{v:.1f}" for v in [_vmin, _midpoint, _vmax]],
    fontsize=6,
)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.0, pad=1)
# keep the end tick labels inside the colorbar's extent ...
for _tidx, tick in enumerate(cbar_ax.yaxis.get_majorticklabels()):
    if _tidx == 0:
        tick.set_verticalalignment("bottom")
    elif _tidx == 2:
        tick.set_verticalalignment("top")
    else:
        tick.set_verticalalignment("center")

fig.savefig("FigExt76C_nup.svg", dpi=300)


In [None]:

# horizontal tiling (inches), left to right: margin | vertical EV1-quantile
# strip | spacer | saddle | spacer | colorbar | margin
h = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    # Size.Fixed(0.2*margin),
    # Size.Fixed(matw), #ctrl
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),  # depletion
    Size.Fixed(0.2*margin),
    Size.Fixed(cbarh),
    Size.Fixed(margin),
]

# vertical tiling (inches), goes from bottom to the top: margin ("EV1 quantiles"
# label) | horizontal EV1-quantile strip | spacer | saddle | spacer | mean-EV1 bars | margin
v = [
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(2.5*cbarh),
    Size.Fixed(margin),
]
# width of the right-most columns (spacer, colorbar, margin) ...
_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")
# ...
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)


# samples = ["N93m5", "N93p5"]
# mev_colors = ["tab:blue", "tab:red"]
samples = ["N93p5",]
mev_colors = ["tab:blue",]
# corner size (in EV1 groups) used for the saddle-strength annotations ...
_strength_k = 11

axs = {}
axq_hor = {}
axm = {}
for i, _sample in enumerate(samples):
    _nx = 2*(i+1)+1
    _ny = 3
    axs[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny))
    axq_hor[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny-2))
    axm[_sample] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=_ny+2))
# label row spanning the saddle column(s) only - a fixed nx1=6 would also
# cover the colorbar column in this narrower layout and off-center the label
axlow = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=3, nx1=_nx+1, ny=0))
axq_ver = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=1, ny=_ny))
# colorbar ...
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx+2, ny=_ny))

# traverse nested dict to access the axes ...
for _ax in (
    list(axs.values()) + list(axq_hor.values()) + list(axm.values()) + [axq_ver, cbar_ax]
):
    _ax.set_xticks([])
    _ax.set_yticks([])


for _sample, _mcolor in zip(samples, mev_colors):
    for _dist_key, _dist in distances.items():
        # only cis control saddle stacks are loaded in this notebook, so fail
        # with a clear message instead of a NameError on undefined *_trans stacks
        if _dist_key == "trans":
            raise ValueError("trans saddles are not loaded in this notebook")
        C = np.nanmean(interaction_ctrl_sums[_sample][_dist], axis=0) / np.nanmean(interaction_ctrl_counts[_sample][_dist], axis=0)
        _strength_dict = saddle_strength(_strength_k, interaction_ctrl_sums[_sample], interaction_ctrl_counts[_sample], dist_range=_dist)
        # NOTE: every distance range is drawn into the same axes - with more
        # than one entry in `distances` only the last one stays visible.
        # The outermost (outlier) EV1 groups are dropped, as in saddle_strength.
        _h = axs[_sample].imshow(C[1:-1,1:-1], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        _h.cmap.set_under("black")
        _sa = _strength_dict["AA"]
        _sb = _strength_dict["BB"]
        # annotate BB strength next to the first groups, AA next to the last ...
        _b_pos = 1
        _a_pos = C.shape[0]-2-1
        axs[_sample].text(_b_pos, _b_pos, f"{_sb:.2f}", fontsize=8, ha="left", va="top")
        axs[_sample].text(_a_pos, _a_pos, f"{_sa:.2f}", fontsize=8, ha="right", va="bottom")
    # the control saddle was digitized with another sample's EV track - show
    # the mean-EV1 profile of that track sample ...
    _track_sample = track_ctrl_sample_map[_sample]
    # drop a literal "R1R2" suffix (str.rstrip would strip any trailing R/1/2 characters)
    if _track_sample.endswith("R1R2"):
        _track_sample_text = _track_sample[:-len("R1R2")]
    else:
        _track_sample_text = _track_sample
    data = _mean_ev_bins[_track_sample].loc[1:n_bins]
    assert len(data) == n_bins
    axm[_sample].fill_between(data.index, data, 0, color=_mcolor, ec="grey", step="mid", linewidth=0.25)
    axm[_sample].set_ylim(-1.2,1.2)
    axm[_sample].set_xlim(1-0.1,n_bins+0.1)
    axm[_sample].spines[:].set_visible(False)
    axm[_sample].axhline(0,color="grey",lw=0.5)
    axm[_sample].spines["left"].set_visible(True)
    axm[_sample].set_title(f"{_sample}@{_track_sample_text}", fontsize=6, pad=0)
    # ev quantiles ...
    axq_hor[_sample].fill_between(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
    axq_hor[_sample].set_ylim(0, n_bins+1)
    axq_hor[_sample].set_xlim(1-0.1,n_bins+0.1)
    axq_hor[_sample].spines[:].set_visible(False)
    axq_hor[_sample].invert_yaxis()
axq_ver.fill_betweenx(data.index, np.asarray(data.index), 0, color="grey", ec="grey", step="mid", linewidth=0.5)
axq_ver.set_xlim(0, n_bins+1)
axq_ver.set_ylim(1-0.1,n_bins+0.1)
axq_ver.spines[:].set_visible(False)
axq_ver.invert_yaxis()
axq_ver.invert_xaxis()
axq_ver.set_ylabel(
    "EV1 quantiles",
    fontsize=6,
    labelpad=0,
)

axlow.text(
    0.5, 0.9,
    "EV1 quantiles",
    ha='center', va='top',
    transform = axlow.transAxes,
    fontsize=6,
)
axlow.axis("off")


# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="vertical",
)
_vmin = imshow_kwargs["norm"].vmin
_midpoint = imshow_kwargs["norm"].midpoint
_vmax = imshow_kwargs["norm"].vmax
cbar_ax.set_yticks(
    [_vmin, _midpoint, _vmax],
    labels=[f"{v:.1f}" for v in [_vmin, _midpoint, _vmax]],
    fontsize=6,
)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.0, pad=1)
# keep the end tick labels inside the colorbar's extent ...
for _tidx, tick in enumerate(cbar_ax.yaxis.get_majorticklabels()):
    if _tidx == 0:
        tick.set_verticalalignment("bottom")
    elif _tidx == 2:
        tick.set_verticalalignment("top")
    else:
        tick.set_verticalalignment("center")

fig.savefig("FigExt76D_nup_ctrl.svg", dpi=300)


# Legacy plotting ...

In [None]:
sub_samples_m = [
    "mMito",
    "mTelo",
    "mCyto",
    "m5hR1R2",
    "m10hR1R2",
]
sub_samples_p = [
    "pMito",
    "pTelo",
    "pCyto",
    "p5hR1R2",
    "p10hR1R2",
]

# one row per m/p sample pair; the left half of the columns holds the "m"
# saddles, the right half the "p" saddles, one column per distance range
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    squeeze=False,  # always a 2D array of axes, even for one row/column
)

for _axs, sample_m, sample_p in zip(axs, sub_samples_m, sub_samples_p):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _axs[jj], _axs[len(distances) + jj]
        Cm = get_saddle_data(sample_m, _dist_name, _dist)
        Cp = get_saddle_data(sample_p, _dist_name, _dist)
        # drop the outermost (outlier) EV1 groups ...
        axm.imshow(Cm[1:-1,1:-1], **imshow_kwargs)
        axp.imshow(Cp[1:-1,1:-1], **imshow_kwargs)

# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    # m ...
    axs[0, jj].set_title(f"m-{_dist_name}")
    # p ...
    axs[0, len(distances) + jj].set_title(f"p-{_dist_name}")
for ii, _sample in enumerate(sub_samples_m):
    # drop the single leading "m" (str.lstrip would strip every leading "m")
    axs[ii,0].set_ylabel(_sample[1:] if _sample.startswith("m") else _sample)
    axs[ii,0].set_yticks([])
    axs[ii,0].set_xticks([])


In [None]:
# control ("m") / depletion ("p") samples for the custom-layout figure below,
# paired by position
sub_samples_m = ["mCyto", "m5hR1R2"]
sub_samples_p = ["pCyto", "p5hR1R2"]

# introduce distance ranges - slice(None) selects the whole stack
distances = {"cis": slice(None), "trans": slice(None)}

# layout dimensions in inches (outer margin, saddle width, colorbar width)
margin, matw, cbarh = 0.2, 0.75, 0.1

# plain log-scale norm, symmetric (in log space) around 1
imshow_kwargs = {
    "norm": LogNorm(vmin=1/3, vmax=3),
    "cmap": "RdBu_r",
    "interpolation": "nearest",
}

# legend colors by category code, plus the matching tick labels
IPG_cmap = {
    3: "cornflowerblue",  # B
    2: "#ffee99",         # V+VI
    1: "orangered",       # A2
    0: "maroon",          # A1
}
ticklabels_ipg = ["B", "VVI", "A2", "A1"]

# same categories with separate "-ID" codes sharing their parent's color
IPGwID_cmap = {
    6: "cornflowerblue",  # B
    5: "#ffee99",         # V+VI
    4: "#ffee99",         # V+VI-ID
    3: "orangered",       # A2
    2: "orangered",       # A2-ID
    1: "maroon",          # A1
    0: "maroon",          # A1-ID
}
ticklabels_ipgid = ["B", "VVI", "VVI-ID", "A2", "A2-ID", "A1", "A1-ID"]


## Draw actual figures with the semi-manual custom layout ...

In [None]:

# horizontal tiling (inches), left to right: margin | 4 saddle columns (two
# pairs, separated by a margin) | spacer | legend/colorbar column | margin
h = [
    Size.Fixed(margin),
    # Cyto
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(margin),
    # 5hr
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    # bar ...
    Size.Fixed(0.2*margin),
    Size.Fixed(cbarh),
    Size.Fixed(margin),
]

# vertical tiling (inches), goes from bottom to the top ...
v = [
    # single color bar at the very bottom
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    # bottom - with IDs
    Size.Fixed(0.5*margin),
    Size.Fixed(cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    # upper - as is ...
    Size.Fixed(margin),
    Size.Fixed(cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(margin),
]
# width of the right-most columns (spacer, legend/colorbar column, margin) ...
_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")
# ...
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)


# row (ny index into `v`) of every per-sample panel, from bottom to top;
# dict order is also the order in which the axes are created ...
_panel_rows = {
    "bottom_legend": 1+2,
    "trans_ID_saddle": 3+2,
    "cis_ID_saddle": 5+2,
    "upper_legend": 7+2,
    "trans_IPG_saddle": 9+2,
    "cis_IPG_saddle": 11+2,
}

axs = {}
# NOTE(review): samples fill columns 1, 3, 5, 7 in the order of
# sub_samples_m + sub_samples_p, while the comments on `h` label the column
# pairs "Cyto" and "5hr" - confirm the intended column order.
for i, _sample in enumerate(sub_samples_m + sub_samples_p):
    _nx_i = 2*i+1
    axs[_sample] = {
        _panel: fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=_panel_ny))
        for _panel, _panel_ny in _panel_rows.items()
    }
# colorbar - under the last sample column ...
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=1))
# very last column with legends ....
_nx_i = 2*(i+1)+1
axs["bottom_right_legend"] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=3+2))
axs["bottom_right_legend2"] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=5+2))
axs["upper_right_legend"] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=9+2))
axs["upper_right_legend2"] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_i, ny=11+2))

# traverse nested dict to access the per-sample axes (flattened without the
# quadratic `sum(list_of_lists, start=[])`) ...
for _ax in [_a for _s in (sub_samples_m+sub_samples_p) for _a in axs[_s].values()]:
    _ax.set_xticks([])
    _ax.set_yticks([])
# do the remaining ones (not nested one as well) ...
for _ax in [ v for k,v in axs.items() if k not in (sub_samples_m+sub_samples_p) ]:
    _ax.set_xticks([])
    _ax.set_yticks([])
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])

# reorder for the ones with IDs ...
_reidxs = [0,3,1,4,2,5,6]


for _sample in (sub_samples_m+sub_samples_p):
    for _dist_key, _dist in distances.items():
        # IPG saddles As is ...
        if _dist_key != "trans":
            C = np.nanmean(interaction_sums_asis[_sample][_dist], axis=0) / np.nanmean(interaction_counts_asis[_sample][_dist], axis=0)
        elif _dist_key == "trans":
            C = np.nanmean(interaction_sums_trans_asis[_sample][_dist], axis=0) / np.nanmean(interaction_counts_trans_asis[_sample][_dist], axis=0)
        else:
            pass
        _h = axs[_sample][f"{_dist_key}_IPG_saddle"].imshow(C[1:,1:], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        # extract size ...
        _num_ipg_colors, _ = C[1:,1:].shape
        # IPG saddles with IDs ...
        if _dist_key != "trans":
            C = np.nanmean(interaction_sums_wids[_sample][_dist], axis=0) / np.nanmean(interaction_counts_wids[_sample][_dist], axis=0)
        elif _dist_key == "trans":
            C = np.nanmean(interaction_sums_trans_wids[_sample][_dist], axis=0) / np.nanmean(interaction_counts_trans_wids[_sample][_dist], axis=0)
        else:
            pass
        _h = axs[_sample][f"{_dist_key}_ID_saddle"].imshow(C[1:,1:][_reidxs][:,_reidxs], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        # extract size ...
        _num_ipg_wIDs_colors, _ = C[1:,1:].shape

# create a "fake" legend for IPGs - for now ...
_fdata = np.reshape(np.arange(_num_ipg_colors), (-1,1))
_fcmap = plt.cm.gray
_fnorm = plt.Normalize()
_frgba = _fcmap(_fnorm(_fdata))
_frgbaT = _fcmap(_fnorm(_fdata.T))
# ...
for i in range(_num_ipg_colors):
    _frgba[i,0] = list(mpl.colors.to_rgb(IPG_cmap[i]))+[1]
    _frgbaT[0,i] = list(mpl.colors.to_rgb(IPG_cmap[i]))+[1]

for _sample in (sub_samples_m+sub_samples_p):
    axs[_sample]["upper_legend"].imshow(_frgbaT, aspect="auto")
    for _ in range(max(IPG_cmap.keys())):
        axs[_sample]["upper_legend"].axvline(_+.5,color="black",lw=0.5)
axs["upper_right_legend"].imshow(_frgba, aspect="auto")
axs["upper_right_legend2"].imshow(_frgba, aspect="auto")
for _ in range(max(IPG_cmap.keys())):
    axs["upper_right_legend"].axhline(_+.5,color="black",lw=0.5)
    axs["upper_right_legend2"].axhline(_+.5,color="black",lw=0.5)
# try adding labels to the right ...
axs["upper_right_legend"].yaxis.tick_right()
axs["upper_right_legend"].set_yticks(
    list(IPG_cmap.keys()),
    labels=ticklabels_ipg,
    fontsize=6,
)
axs["upper_right_legend"].tick_params(length=1.5, pad=1)

# create a "fake" legend for IPGs - for now ...
_fdata = np.reshape(np.arange(_num_ipg_wIDs_colors), (-1,1))
_frgba = _fcmap(_fnorm(_fdata))
_frgbaT = _fcmap(_fnorm(_fdata.T))
# ...
for i in range(_num_ipg_wIDs_colors):
    _frgba[i,0] = list(mpl.colors.to_rgb(IPGwID_cmap[i]))+[1]
    _frgbaT[0,i] = list(mpl.colors.to_rgb(IPGwID_cmap[i]))+[1]

for _sample in (sub_samples_m+sub_samples_p):
    axs[_sample]["bottom_legend"].imshow(_frgbaT, aspect="auto")
    for _ in range(max(IPGwID_cmap.keys())):
        axs[_sample]["bottom_legend"].axvline(_+.5,color="black",lw=0.5)
axs["bottom_right_legend"].imshow(_frgba, aspect="auto")
axs["bottom_right_legend2"].imshow(_frgba, aspect="auto")
for _ in range(max(IPGwID_cmap.keys())):
    axs["bottom_right_legend"].axhline(_+.5,color="black",lw=0.5)
    axs["bottom_right_legend2"].axhline(_+.5,color="black",lw=0.5)
# try adding labels to the right ...
axs["bottom_right_legend"].yaxis.tick_right()
axs["bottom_right_legend"].set_yticks(
    list(IPGwID_cmap.keys()),
    labels=ticklabels_ipgid,
    fontsize=6,
)
axs["bottom_right_legend"].tick_params(length=1.5, pad=1)


# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_vmin = imshow_kwargs["norm"].vmin
_vmax = imshow_kwargs["norm"].vmax
cbar_ax.set_xticks([_vmin, 1, _vmax])
cbar_ax.set_xticklabels([f"{_vmin:.2f}", 1, _vmax], fontsize=6)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.5, pad=1)#,direction='out', length=6, width=2, colors='r', grid_color='r', grid_alpha=0.5)
for _tidx, tick in enumerate(cbar_ax.xaxis.get_majorticklabels()):
    if _tidx == 0:
        tick.set_horizontalalignment("left")
    elif _tidx == 2:
        tick.set_horizontalalignment("right")
    else:
        tick.set_horizontalalignment("center")

fig.savefig("Fig6D.pdf", dpi=300)


In [None]:

# --- Nup93 samples: IPG-with-ID saddles, cis (top row) and trans (bottom row) ---
# One column per entry of `nup_samples`, plus a narrow legend column on the right.

# horizontal tiling, listed from left to right ...
h = [
    Size.Fixed(margin),
    Size.Fixed(matw), #ctrl
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),  # nup93
    Size.Fixed(0.2*margin),
    Size.Fixed(cbarh),
    Size.Fixed(margin),
]

# vertical tiling, listed from the bottom to the top ...
v = [
    # single colorbar at the very bottom (ny=1)
    Size.Fixed(margin),
    Size.Fixed(0.5*cbarh),
    # legend strips (ny=3), trans saddles (ny=5), cis saddles (ny=7)
    Size.Fixed(0.33*margin),
    Size.Fixed(cbarh),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(0.2*margin),
    Size.Fixed(matw),
    Size.Fixed(margin),
]
_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# set figsize based on the tiling provided ...
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(
    figsize=(fig_width, fig_height),
    # facecolor='lightblue'
)
print(f"figure size {fig_width=} {fig_height=}")
divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)


def _add_ax(nx, ny):
    """Add an axes to the current `fig`, placed in cell (nx, ny) of the current `divider`."""
    return fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=nx, ny=ny))


def _mean_saddle(sums, counts, sample, dist):
    """Return nanmean(sums) / nanmean(counts) along axis 0 for `sample` restricted to `dist`."""
    return np.nanmean(sums[sample][dist], axis=0) / np.nanmean(counts[sample][dist], axis=0)


def _legend_rgba(color_map, n):
    """Return an (n, 1, 4) RGBA column whose k-th cell is `color_map[k]` with alpha 1."""
    return np.array(
        [(*mpl.colors.to_rgb(color_map[k]), 1.0) for k in range(n)]
    ).reshape(n, 1, 4)


nup_samples = ["N93m5","N93p5"]

axs = {}  # saddle axes: axs[sample][dist_key]
axl = {}  # horizontal legend strip below each sample column
axv = {}  # vertical legends in the right-most column, one per distance row
for i, _sample in enumerate(nup_samples):
    _nx = 2*i+1
    axs[_sample] = {}
    axl[_sample] = _add_ax(_nx, 3)
    # reversed: the last distance ("trans") gets the lower row ...
    for j, _dist_key in enumerate(reversed(distances)):
        _ny = 2*j+5
        axs[_sample][_dist_key] = _add_ax(_nx, _ny)
        if _sample == nup_samples[-1]:
            axv[_dist_key] = _add_ax(_nx+2, _ny)
# colorbar below the last sample column ...
cbar_ax = _add_ax(_nx, 1)

# hide ticks on every axes ...
for _ax in (
    [_a for _s in nup_samples for _a in axs[_s].values()]
    + list(axl.values()) + list(axv.values()) + [cbar_ax]
):
    _ax.set_xticks([])
    _ax.set_yticks([])

# reorder for the ones with IDs ...
_reidxs = [0,3,1,4,2,5,6]
for _sample in nup_samples:
    for _dist_key, _dist in distances.items():
        # IPG saddles with IDs (first row/column dropped, then reordered) ...
        if _dist_key == "trans":
            C = _mean_saddle(interaction_sums_trans_wids, interaction_counts_trans_wids, _sample, _dist)
        else:
            C = _mean_saddle(interaction_sums_wids, interaction_counts_wids, _sample, _dist)
        _h = axs[_sample][_dist_key].imshow(C[1:,1:][_reidxs][:,_reidxs], **imshow_kwargs)
        _h.cmap.set_over("#300000")
        # row labels on the left-most column, sample titles on the top row ...
        if _sample == nup_samples[0]:
            axs[_sample][_dist_key].set_ylabel(_dist_key, fontsize=8, labelpad=1)
        if _dist_key == "cis":
            axs[_sample][_dist_key].set_title(_sample, fontsize=8, pad=1)
        # extract size ...
        _num_ipg_wIDs_colors, _ = C[1:,1:].shape

# "fake" legend for IPGs with IDs ...
_frgba = _legend_rgba(IPGwID_cmap, _num_ipg_wIDs_colors)
_frgbaT = _frgba.transpose(1, 0, 2)

for _sample in nup_samples:
    axl[_sample].imshow(_frgbaT, aspect="auto")
    for _ in range(max(IPGwID_cmap)):
        axl[_sample].axvline(_+.5,color="black",lw=0.5)

for _dist_key in distances:
    axv[_dist_key].imshow(_frgba, aspect="auto")
    for _ in range(max(IPGwID_cmap)):
        axv[_dist_key].axhline(_+.5,color="black",lw=0.5)

# category labels to the right ...
axv["trans"].yaxis.tick_right()
axv["trans"].set_yticks(
    list(IPGwID_cmap.keys()),
    labels=ticklabels_ipgid,
    fontsize=6,
)
axv["trans"].tick_params(length=1.5, pad=1)


# add a single colorbar ...
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_vmin = imshow_kwargs["norm"].vmin
_vmax = imshow_kwargs["norm"].vmax
cbar_ax.set_xticks([_vmin, 1, _vmax])
cbar_ax.set_xticklabels([f"{_vmin:.2f}", 1, _vmax], fontsize=6)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.5, pad=1)
# keep the outer tick labels inside the colorbar's horizontal extent ...
for _tidx, tick in enumerate(cbar_ax.xaxis.get_majorticklabels()):
    if _tidx == 0:
        tick.set_horizontalalignment("left")
    elif _tidx == 2:
        tick.set_horizontalalignment("right")
    else:
        tick.set_horizontalalignment("center")

fig.savefig("FigExt76E.pdf", dpi=300)


# Older exploratory plots (not publication-ready)

In [None]:
# Older exploratory grid: one row per (m, p) sample pair; the left half of each row
# holds the m-sample saddles, the right half the p-sample saddles, one column per distance.
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    # keep `axs` 2-D even with a single sample pair (row indexing below relies on it) ...
    squeeze=False,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2, vmax=2),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, _row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _row_axs[jj], _row_axs[len(distances) + jj]
        # trans saddles come from separate accumulators ...
        if _dist_name == "trans":
            _sums, _counts = interaction_sums_trans, interaction_counts_trans
        else:
            _sums, _counts = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums[sample_m][_dist], axis=0) / np.nanmean(_counts[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums[sample_p][_dist], axis=0) / np.nanmean(_counts[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            # drop a single leading "m" (str.lstrip("m") would strip all of them) ...
            axm.set_ylabel(sample_m[1:] if sample_m.startswith("m") else sample_m)


In [None]:
# Older exploratory grid (wider colour range): one row per (m, p) sample pair; the left
# half of each row holds the m-sample saddles, the right half the p-sample saddles.
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    # keep `axs` 2-D even with a single sample pair (row indexing below relies on it) ...
    squeeze=False,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.25, vmax=2.25),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, _row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _row_axs[jj], _row_axs[len(distances) + jj]
        # trans saddles come from separate accumulators ...
        if _dist_name == "trans":
            _sums, _counts = interaction_sums_trans, interaction_counts_trans
        else:
            _sums, _counts = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums[sample_m][_dist], axis=0) / np.nanmean(_counts[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums[sample_p][_dist], axis=0) / np.nanmean(_counts[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            # drop a single leading "m" (str.lstrip("m") would strip all of them) ...
            axm.set_ylabel(sample_m[1:] if sample_m.startswith("m") else sample_m)

In [None]:
# # # # the mix one - mp
sub_samples_m = [
    "N93m5",
    "N93m10",
]
# p ...
sub_samples_p = [
    "N93p5",
    "N93p10",
]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.5, vmax=2.5),
        cmap="RdBu_r",
        interpolation="none",
)

for sample_m, sample_p, (i, axs) in zip(sub_samples_m, sub_samples_p, enumerate(axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = axs[jj], axs[len(distances) + jj]
        if _dist_name != "trans":
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        elif _dist_name == "trans":
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        else:
            pass
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m.lstrip("m"))

In [None]:
# 10h samples (left half of each row) next to the matching Nup93 samples (right half);
# column titles keep the historical "m-"/"p-" prefixes of the older grids ...
sub_samples_m =[
        "m10hR1R2",
        "p10hR1R2",
        "mp10hR1R2",
    ]
sub_samples_p = [
        "N93m10",
        "N93p10",
        "N93mp10",
    ]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    # keep `axs` 2-D even with a single sample pair (row indexing below relies on it) ...
    squeeze=False,
)

imshow_kwargs = dict(
        norm=LogNorm(vmin=1/2.5, vmax=2.5),
        cmap="RdBu_r",
        interpolation="none",
)

for i, (sample_m, sample_p, _row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _row_axs[jj], _row_axs[len(distances) + jj]
        # trans saddles come from separate accumulators ...
        if _dist_name == "trans":
            _sums, _counts = interaction_sums_trans, interaction_counts_trans
        else:
            _sums, _counts = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums[sample_m][_dist], axis=0) / np.nanmean(_counts[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums[sample_p][_dist], axis=0) / np.nanmean(_counts[sample_p][_dist], axis=0)
        axm.imshow(Cm, **imshow_kwargs)
        axp.imshow(Cp, **imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m)



# colorbar in a manually placed axes (figure-fraction coordinates, bottom right) ...
cax = fig.add_axes([0.88,0.001,0.1,0.02])
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cax,
    orientation="horizontal",
)
cax.set_xticks([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
cax.set_xticklabels([imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax])
cax.minorticks_off()