# Saddleplots — plotting only (saddles are precomputed and read from HDF5)

In [None]:
# import standard python libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import os, subprocess

In [None]:
# Import python package for working with cooler files and tools for analysis
import cooler
import cooltools.lib.plotting

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
# download test data
# this file is 145 Mb, and may take a few seconds to download
import bbi
import cooltools
import bioframe
from matplotlib.colors import LogNorm
from helper_func import saddleplot
from data_catalog import bws, bws_vlim, telo_dict

import saddle


In [None]:
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
import warnings
import seaborn as sns

import warnings
import h5py

import matplotlib.lines as lines
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch, Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.inset_locator import BboxConnector
from matplotlib import cm
import matplotlib as mpl
# from mpl_toolkits.axes_grid1.Size import Fixed


# Keep exported vector figures editable: embed TrueType (Type 42) fonts in PDF
# and emit real text (not glyph paths) in SVG; use thin axes spines.
mpl.rcParams.update({
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
    "axes.linewidth": 0.5,
})

## Calculating per-chromosome compartmentalization

We first load the Hi-C data at 100 kbp resolution. 

Note that the current implementation of eigendecomposition in cooltools assumes that individual regions can be held in memory — for hg38 at 100 kb, this is either a 2422×2422 matrix for chr2 or a 3255×3255 matrix for the full cooler here.

In [None]:
# Genomic view (chromosome arms) used to sort the cCRE annotations.
# Chromsizes and centromeres are fetched from UCSC with bioframe (needs network).
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# Keep only the autosomal arms. They are selected by chromosome name, so the result
# does not depend on bioframe's row order or on how many arms each chromosome yields.
_autosomes = [f"chr{_k}" for _k in range(1, 23)]
hg38_arms = hg38_arms_full[hg38_arms_full["chrom"].isin(_autosomes)].reset_index(drop=True)
included_arms = hg38_arms["name"].to_list()

In [None]:
# cCRE class name (BED `name` column) -> integer code. A higher code means
# higher priority when a bin overlaps several cCREs; pELS and nELS share
# code 5. Code 0 is reserved for bins without any cCRE.
cre_to_code = {
    'pls': 6,
    'pels': 5,
    'nels': 5,
    'dels': 4,
    'openk4': 3,
    'ctcf': 2,
    'justopen': 1,
}

# Ordered categorical dtype over every code 0..max. The range is derived from the
# mapping, so adding a class with a new code cannot silently fall outside the categories.
_cats = pd.CategoricalDtype(categories=list(range(max(cre_to_code.values()) + 1)), ordered=True)

# cCRE annotation tables to read: run label -> BED-like table with
# chrom/start/end/name columns. Alternative threshold runs are kept commented out.
cre_fnames = {
    "atac95_rest95" : "./bigbed/all_cres_encodestyle_pMunion_95atac_95rest.bed",  # control sorta, like before
    # "atac90_rest90" : "./bigbed/all_cres_encodestyle_pMunion_90atac_90rest.bed",
    # "atac70_rest70" : "./bigbed/all_cres_encodestyle_pMunion_70atac_70rest.bed",
    # "atac70_rest95" : "./bigbed/all_cres_encodestyle_pMunion_70atac_95rest.bed",
    # "atac70_rest90" : "./bigbed/all_cres_encodestyle_pMunion_70atac_90rest.bed",
    # "atac75_rest95" : "./bigbed/all_cres_encodestyle_pMunion_75atac_95rest.bed",
    # "atac75_rest90" : "./bigbed/all_cres_encodestyle_pMunion_75atac_90rest.bed",
    # "allana_latest" : "./bigbed/all_cres_encodestyle_pMunion_NEWatac_k2785_rest95.bed",
}

# DataFrame.query string selecting rows whose `status_` code is 0, 4, 5, 6 or 7.
# NOTE(review): no visible cell uses this, and code 7 has no entry in
# `cre_to_code` (max code is 6). Confirm whether it is still needed.
_status_query = "(status_ == 0)|(status_ == 5)|(status_ == 4)|(status_ == 6)|(status_ == 7)"

# Read every cCRE annotation table, order it along the autosomal arm view and
# encode its class with `cre_to_code`.
# NOTE: the per-bin annotation at 10 kb / 25 kb that used to follow (commented
# out) needed `clr_bins10` / `clr_bins25`, which this plotting-only notebook does
# not define (as far as visible here), so it was removed. The two dicts below
# therefore stay empty; they are kept so any code referencing them still resolves.
binned_cre10_dfs = {}
binned_cre25_dfs = {}
for name, cre_fname in cre_fnames.items():
    df = pd.read_table(cre_fname, usecols=['chrom', 'start', 'end', 'name'])
    print(f"read {name=} with {len(df)} items ...")
    # sort by coordinate along the arm view
    df = bioframe.sort_bedframe(df, view_df=hg38_arms)
    df["status"] = df["name"].map(cre_to_code)
    # `.map` yields NaN for any class name missing from `cre_to_code`; report
    # such names instead of letting those rows silently lose their class.
    _unmapped = sorted(df.loc[df["status"].isna(), "name"].astype(str).unique())
    if _unmapped:
        warnings.warn(f"{name}: cCRE classes with no entry in cre_to_code: {_unmapped}")



# Legend labels for the cCRE codes, listed from the highest code (6, PLS) down
# to 0 ("none"). Codes 1..6 are cCRE classes; 0 is everything else.
ticklabels = "PLS pELS dELS opK4 CTCF open none".split()

In [None]:

# fr.items()
def print_attrs(name, obj):
    """Print one HDF5 tree entry, indented by depth, followed by its attributes.

    Used as an ``h5py`` ``visititems`` callback: ``name`` is the slash-separated
    path of the visited object and ``obj`` the group/dataset itself. Always
    returns None, so ``visititems`` keeps walking the tree.

    Objects without an ``attrs`` mapping are printed without attributes. If the
    attributes cannot be read, a short note is printed instead of raising, so one
    bad attribute does not abort the listing. Unlike a bare ``except``, this
    never hides KeyboardInterrupt or SystemExit.
    """
    depth_indent = name.count('/') * '    '
    print(depth_indent + name.split("/")[-1])
    attrs = getattr(obj, "attrs", None)
    if attrs is None:
        return None
    try:
        for key, val in attrs.items():
            print(depth_indent + '    ' + f"{key}: {val}")
    except Exception as err:  # best-effort listing: report instead of silently skipping
        print(depth_indent + '    ' + f"<unreadable attrs: {err}>")
    return None


# Precomputed saddle results. Alternative run kept for reference:
# SADDLES_FNAME = "saddles_cre_by_distance_Allana_latest_cre.hdf5"
SADDLES_FNAME = "saddles_cre_by_distance.hdf5"


def _read_sample_group(h5file, group_name):
    """Return ``{sample: ndarray}`` for every dataset in ``h5file[group_name]``.

    ``[]`` is used instead of ``.get`` so a missing group raises a KeyError that
    names it, rather than a confusing AttributeError on ``None``.
    """
    return {sample: dset[()] for sample, dset in h5file[group_name].items()}


with h5py.File(SADDLES_FNAME, 'r') as fr:
    # list every group/dataset together with its attributes
    fr.visititems(print_attrs)
    # general metadata (`cis_binsize` is used below to define distance ranges)
    _saddle_meta = dict(fr.attrs)
    # per-sample saddle sums/counts for trans contacts ...
    interaction_counts_trans = _read_sample_group(fr, "counts_trans")
    interaction_sums_trans = _read_sample_group(fr, "sums_trans")
    # ... and for cis contacts. The first axis is sliced by `distances` below,
    # presumably one saddle per diagonal. TODO confirm the layout.
    interaction_counts = _read_sample_group(fr, "counts")
    interaction_sums = _read_sample_group(fr, "sums")


In [None]:
# Matched m/p sample names: both lists share the same timepoint suffixes.
_sub_sample_suffixes = ["Mito", "Telo", "Cyto", "5hR1R2", "10hR1R2"]
sub_samples_m = [f"m{_suffix}" for _suffix in _sub_sample_suffixes]
sub_samples_p = [f"p{_suffix}" for _suffix in _sub_sample_suffixes]


# Cis distance ranges as slices over the diagonal index, in units of the cis
# bin size from the HDF5 metadata. Each stop is one past the upper boundary
# bin, so neighbouring ranges share their boundary diagonal. "trans" takes
# everything from the trans arrays.
_cis_bs = _saddle_meta["cis_binsize"]
_distance_edges_bp = [0, 250_000, 1_000_000, 10_000_000, 1_000_000_000]
_distance_labels = ["<0.25MB", "0.25-1Mb", "1-10Mb", ">10Mb"]
distances = {
    _label: slice(int(_lo / _cis_bs), int(_hi / _cis_bs) + 1)
    for _label, _lo, _hi in zip(_distance_labels, _distance_edges_bp[:-1], _distance_edges_bp[1:])
}
distances["trans"] = slice(None)


# Shared imshow settings for the saddles: log colour scale from 0.5x to 2x.
imshow_kwargs = {
    "norm": LogNorm(vmin=1/2, vmax=2),
    "cmap": "RdBu_r",
    "interpolation": "nearest",
}

# Settings for the small cCRE colour strips drawn next to the saddles.
legend_kwargs = {"aspect": "auto", "interpolation": "nearest"}

# Colour per cCRE code: 6 PLS, 5 pELS, 4 dELS, 3 opK4, 2 CTCF, 1 open, 0 none.
# The key order (6 -> 0) matters: the keys are used as legend tick positions
# paired with `ticklabels`.
cre_cmap = dict(zip(
    range(6, -1, -1),
    ["darkred", "red", "orangered", "pink", "blue", "#ffee99", "#D9E2EF"],
))

# Common sizes (in inches) for the saddle layouts: outer margin, saddle width
# and the thickness of the cCRE strips / colorbar.
margin, matw, cbarh = 0.2, 0.6, 0.08


## Draw the actual figures with a semi-manual custom layout (`mpl_toolkits.axes_grid1.Divider`)

In [None]:
# Fig. 2F / Ext. Fig. 2D layout, sizes in inches.
# `h` runs left -> right: margin, one (saddle, gap) pair per distance range,
# vertical cCRE strip, margin.
h = [Size.Fixed(margin)] + len(distances) * [Size.Fixed(matw), Size.Fixed(0.2 * margin)] + [Size.Fixed(cbarh), Size.Fixed(margin)]
# `v` runs bottom -> top: margin, colorbar, (gap, p row), (gap, m row), gap,
# horizontal cCRE strip, top margin. The top margin makes room for the distance
# titles drawn above the strips; without it they fall outside the figure.
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + 2 * [Size.Fixed(0.2 * margin), Size.Fixed(matw)] + [Size.Fixed(0.2 * margin), Size.Fixed(cbarh), Size.Fixed(margin)]

_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# Divider grid indices: saddle column k sits at nx = 2k+1, and the vertical strips
# sit one column after the last saddle. Rows: ny = 1 colorbar, 3 p, 5 m, 7 strips.
_nx_last = 2 * (len(distances) - 1) + 1
_nx_strip = _nx_last + 2

for _fig_fname, (sample_m, sample_p) in {
    "Fig2F.svg": ("m5hR1R2", "p5hR1R2"),
    "FigE2D.svg": ("N93m5", "N93p5"),
}.items():
    # figure size follows exactly from the tiling
    fig_width = sum(_h.fixed_size for _h in h)
    fig_height = sum(_v.fixed_size for _v in v)
    fig = plt.figure(figsize=(fig_width, fig_height))
    print(f"figure size {fig_width=} {fig_height=}")

    divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

    axs_m = {}
    axs_p = {}
    cax_h = {}
    for i, _dist_key in enumerate(distances):
        _nx = 2 * i + 1
        axs_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=3))
        axs_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=5))
        cax_h[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=7))

    caxp_v = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_strip, ny=3))
    caxm_v = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_strip, ny=5))
    # single colorbar underneath the last distance column
    cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_last, ny=1))

    for _ax in [*cax_h.values(), caxm_v, caxp_v, cbar_ax]:
        _ax.set_xticks([])
        _ax.set_yticks([])

    for jj, (_dist_key, _dist) in enumerate(distances.items()):
        axm, axp, cax = axs_m[_dist_key], axs_p[_dist_key], cax_h[_dist_key]
        # "trans" uses the trans arrays, every other range the cis ones
        if _dist_key == "trans":
            _sums_src, _counts_src = interaction_sums_trans, interaction_counts_trans
        else:
            _sums_src, _counts_src = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums_src[sample_m][_dist], axis=0) / np.nanmean(_counts_src[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums_src[sample_p][_dist], axis=0) / np.nanmean(_counts_src[sample_p][_dist], axis=0)
        for _ax, _C in ((axm, Cm), (axp, Cp)):
            _hm = _ax.imshow(_C, **imshow_kwargs)
            # values above vmax get a very dark red
            _hm.cmap.set_over("#300000")
            _ax.set_xticks([])
            _ax.set_yticks([])
        cax.set_title(f"{_dist_key}", fontsize=8)
        if jj == 0:
            axm.set_ylabel(sample_m, fontsize=8)
            axp.set_ylabel(sample_p, fontsize=8)

    # cCRE colour strips: one RGBA pixel per code, row i = code i
    _size, _ = Cm.shape
    _frgba = np.array([mpl.colors.to_rgba(cre_cmap[_code]) for _code in range(_size)]).reshape(_size, 1, 4)
    _frgbaT = _frgba.reshape(1, _size, 4)

    caxp_v.imshow(_frgba, **legend_kwargs)
    caxm_v.imshow(_frgba, **legend_kwargs)
    for _dist_key in distances:
        cax_h[_dist_key].imshow(_frgbaT, **legend_kwargs)

    # one shared colorbar, with ticks only at vmin, 1 and vmax
    fig.colorbar(
        cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
        cax=cbar_ax,
        orientation="horizontal",
    )
    _cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
    cbar_ax.set_xticks(_cbar_ticks)
    cbar_ax.set_xticklabels(_cbar_ticks, fontsize=6)
    cbar_ax.minorticks_off()
    cbar_ax.tick_params(length=1.5, pad=1)
    # keep the outer tick labels inside the colorbar's horizontal extent
    for _tidx, tick in enumerate(cbar_ax.xaxis.get_majorticklabels()):
        tick.set_horizontalalignment({0: "left", 2: "right"}.get(_tidx, "center"))

    fig.savefig(_fig_fname, dpi=300)


## Do the same for Extended Figure 3 (all timepoints, m and p samples side by side)

In [None]:
# Ext. Fig. 3B: one saddle row per timepoint and one column per distance range.
# The m samples fill the left block of columns and the p samples the right block.
samples_m = [
    "mMito",
    "mTelo",
    "mCyto",
    "m5hR1R2",
]
samples_p = [
    "pMito",
    "pTelo",
    "pCyto",
    "p5hR1R2",
]
assert len(samples_m) == len(samples_p)
_num_timepoints = len(samples_m)
_num_dist = len(distances)

# Sizes in inches. `h` runs left -> right: margin, m block of (saddle, gap) pairs
# (the last saddle followed by a full margin as a separator), p block of
# (saddle, gap) pairs, vertical cCRE strip, margin.
h = (
    [Size.Fixed(margin)]
    + (_num_dist - 1) * [Size.Fixed(matw), Size.Fixed(0.2 * margin)]
    + [Size.Fixed(matw), Size.Fixed(margin)]
    + _num_dist * [Size.Fixed(matw), Size.Fixed(0.2 * margin)]
    + [Size.Fixed(cbarh), Size.Fixed(margin)]
)
# `v` runs bottom -> top: margin, colorbar, (gap, saddle row) per timepoint,
# gap, horizontal cCRE strip, margin.
v = (
    [Size.Fixed(margin), Size.Fixed(cbarh)]
    + _num_timepoints * [Size.Fixed(0.2 * margin), Size.Fixed(matw)]
    + [Size.Fixed(0.2 * margin), Size.Fixed(cbarh)]
    + [Size.Fixed(margin)]
)

_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

# figure size follows exactly from the tiling
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(figsize=(fig_width, fig_height))
print(f"figure size {fig_width=} {fig_height=}")

divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

# Grid indices, derived from the sizes rather than from leaked loop variables:
# the top strip row sits above the last saddle row, and the right-most column
# holds the vertical strips.
_ny_top = 2 * _num_timepoints + 3
_nxright = 4 * _num_dist + 1

axs_p = {}
axs_m = {}
axl_m = {}
axl_p = {}
axlv_m = {}

for i, _dist_key in enumerate(distances):
    _nxm = 2 * i + 1                 # column inside the m block
    _nxp = 2 * (i + _num_dist) + 1   # same distance inside the p block
    axs_p[_dist_key] = {}
    axs_m[_dist_key] = {}
    # rows are filled bottom -> top, so the first sample ends up on top
    for j, (_m, _p) in enumerate(zip(reversed(samples_m), reversed(samples_p))):
        _ny = 2 * j + 3
        axs_m[_dist_key][_m] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxm, ny=_ny))
        axs_p[_dist_key][_p] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxp, ny=_ny))
    axl_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxm, ny=_ny_top))
    axl_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxp, ny=_ny_top))

for j, _p in enumerate(reversed(samples_p)):
    axlv_m[_p] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxright, ny=2 * j + 3))

# single colorbar underneath the last p column
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nxright - 2, ny=1))

# no ticks on saddles, strips or colorbar (colorbar ticks are set again below)
_all_saddle_axes = [
    _ax
    for _axs_by_dist in (axs_p, axs_m)
    for _axs_by_sample in _axs_by_dist.values()
    for _ax in _axs_by_sample.values()
]
for ax in _all_saddle_axes + list(axl_m.values()) + list(axl_p.values()) + list(axlv_m.values()) + [cbar_ax]:
    ax.set_xticks([])
    ax.set_yticks([])

# row labels go on the left-most distance column only
_first_dist_key = next(iter(distances))
for _dist_key, _dist in distances.items():
    # "trans" uses the trans arrays, every other range the cis ones
    if _dist_key == "trans":
        _sums_src, _counts_src = interaction_sums_trans, interaction_counts_trans
    else:
        _sums_src, _counts_src = interaction_sums, interaction_counts
    for _m, _p in zip(samples_m, samples_p):
        axm = axs_m[_dist_key][_m]
        axp = axs_p[_dist_key][_p]
        Cm = np.nanmean(_sums_src[_m][_dist], axis=0) / np.nanmean(_counts_src[_m][_dist], axis=0)
        Cp = np.nanmean(_sums_src[_p][_dist], axis=0) / np.nanmean(_counts_src[_p][_dist], axis=0)
        for _ax, _C in ((axm, Cm), (axp, Cp)):
            _hm = _ax.imshow(_C, **imshow_kwargs)
            # values above vmax get a very dark red
            _hm.cmap.set_over("#300000")
        if _dist_key == _first_dist_key:
            # strip the leading "m" prefix (a true prefix, unlike str.lstrip)
            axm.set_ylabel(_m[1:] if _m.startswith("m") else _m, fontsize=8, labelpad=1)
    axl_m[_dist_key].set_title(f"{_dist_key}", fontsize=8, pad=2)
    axl_p[_dist_key].set_title(f"{_dist_key}", fontsize=8, pad=2)

# cCRE colour strips: one RGBA pixel per code, row i = code i
_size, _ = Cm.shape
_frgba = np.array([mpl.colors.to_rgba(cre_cmap[_code]) for _code in range(_size)]).reshape(_size, 1, 4)
_frgbaT = _frgba.reshape(1, _size, 4)

for _dist_key in distances:
    axl_m[_dist_key].imshow(_frgbaT, **legend_kwargs)
    axl_p[_dist_key].imshow(_frgbaT, **legend_kwargs)
for _p in samples_p:
    axlv_m[_p].imshow(_frgba, **legend_kwargs)

# cCRE class names to the right of the last sample's (bottom-row) strip;
# tick position = cCRE code
axlv_m[samples_p[-1]].yaxis.tick_right()
axlv_m[samples_p[-1]].set_yticks(
    list(cre_cmap.keys()),
    labels=ticklabels,
    fontsize=6,
)
axlv_m[samples_p[-1]].tick_params(length=1.5, pad=1)

# one shared colorbar, with ticks only at vmin, 1 and vmax
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
cbar_ax.set_xticks(_cbar_ticks)
cbar_ax.set_xticklabels(_cbar_ticks, fontsize=6)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.5, pad=1)
# keep the outer tick labels inside the colorbar's horizontal extent
for _tidx, tick in enumerate(cbar_ax.xaxis.get_majorticklabels()):
    tick.set_horizontalalignment({0: "left", 2: "right"}.get(_tidx, "center"))

_fig_fname = "FigExt3B.svg"
fig.savefig(_fig_fname, dpi=300)

# Vector PDF copy via the cairosvg CLI. check=True makes a failed conversion
# raise instead of passing unnoticed.
subprocess.run(["cairosvg", "--format", "pdf", "-o", "FigExt3B.pdf", _fig_fname], check=True)

## Draw Extended Figure 7I — the 10-hour mp experiment

In [None]:
# Ext. Fig. 7I layout, sizes in inches: three saddle rows (m, mp, p) and one
# column per distance range.
# `h` runs left -> right: margin, (saddle, gap) per distance, vertical strip, margin.
h = [Size.Fixed(margin)] + len(distances) * [Size.Fixed(matw), Size.Fixed(0.2 * margin)] + [Size.Fixed(cbarh), Size.Fixed(margin)]
# `v` runs bottom -> top: margin, colorbar, 3 x (gap, saddle row), gap,
# horizontal cCRE strip, top margin for the distance titles.
v = [Size.Fixed(margin), Size.Fixed(cbarh)] + 3 * [Size.Fixed(0.2 * margin), Size.Fixed(matw)] + [Size.Fixed(0.2 * margin), Size.Fixed(cbarh), Size.Fixed(margin)]

_stickingout_bit = sum(_h.fixed_size for _h in h[-3:])
print(f"{_stickingout_bit=}")

sample_m, sample_mp, sample_p = "m10hR1R2", "mp10hR1R2", "p10hR1R2"
# figure size follows exactly from the tiling
fig_width = sum(_h.fixed_size for _h in h)
fig_height = sum(_v.fixed_size for _v in v)
fig = plt.figure(figsize=(fig_width, fig_height))
print(f"figure size {fig_width=} {fig_height=}")

divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False)

# Divider grid indices: saddle column k sits at nx = 2k+1, and the vertical strips
# one column after the last saddle.
# Rows: ny = 1 colorbar, 3 p, 5 mp, 7 m, 9 horizontal strips.
_nx_last = 2 * (len(distances) - 1) + 1
_nx_strip = _nx_last + 2

axs_m = {}
axs_mp = {}
axs_p = {}
cax_h = {}

for i, _dist_key in enumerate(distances):
    _nx = 2 * i + 1
    axs_m[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=7))
    axs_mp[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=5))
    axs_p[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=3))
    cax_h[_dist_key] = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx, ny=9))

caxm_v = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_strip, ny=7))
caxmp_v = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_strip, ny=5))
caxp_v = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_strip, ny=3))
# single colorbar underneath the last distance column
cbar_ax = fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=_nx_last, ny=1))

for ax in [*cax_h.values(), caxm_v, caxmp_v, caxp_v, cbar_ax]:
    ax.set_xticks([])
    ax.set_yticks([])

for jj, (_dist_key, _dist) in enumerate(distances.items()):
    axm, axmp, axp, cax = axs_m[_dist_key], axs_mp[_dist_key], axs_p[_dist_key], cax_h[_dist_key]
    # "trans" uses the trans arrays, every other range the cis ones
    if _dist_key == "trans":
        _sums_src, _counts_src = interaction_sums_trans, interaction_counts_trans
    else:
        _sums_src, _counts_src = interaction_sums, interaction_counts
    Cm = np.nanmean(_sums_src[sample_m][_dist], axis=0) / np.nanmean(_counts_src[sample_m][_dist], axis=0)
    Cmp = np.nanmean(_sums_src[sample_mp][_dist], axis=0) / np.nanmean(_counts_src[sample_mp][_dist], axis=0)
    Cp = np.nanmean(_sums_src[sample_p][_dist], axis=0) / np.nanmean(_counts_src[sample_p][_dist], axis=0)
    for _ax, _C in ((axm, Cm), (axmp, Cmp), (axp, Cp)):
        _hm = _ax.imshow(_C, **imshow_kwargs)
        # values above vmax get a very dark red
        _hm.cmap.set_over("#300000")
        _ax.set_xticks([])
        _ax.set_yticks([])
    cax.set_title(f"{_dist_key}", fontsize=8)
    if jj == 0:
        # Drop the trailing "R1R2" suffix (e.g. "m10hR1R2" -> "m10h"). This removes
        # the suffix itself; str.rstrip would strip any trailing R/1/2 characters.
        for _ax, _sample in ((axm, sample_m), (axmp, sample_mp), (axp, sample_p)):
            _label = _sample[:-len("R1R2")] if _sample.endswith("R1R2") else _sample
            _ax.set_ylabel(_label, fontsize=8, labelpad=1)

# cCRE colour strips: one RGBA pixel per code, row i = code i
_size, _ = Cm.shape
_frgba = np.array([mpl.colors.to_rgba(cre_cmap[_code]) for _code in range(_size)]).reshape(_size, 1, 4)
_frgbaT = _frgba.reshape(1, _size, 4)

caxp_v.imshow(_frgba, **legend_kwargs)
caxm_v.imshow(_frgba, **legend_kwargs)
caxmp_v.imshow(_frgba, **legend_kwargs)
for _dist_key in distances:
    cax_h[_dist_key].imshow(_frgbaT, **legend_kwargs)

# cCRE class names to the right of the bottom (p) strip; tick position = cCRE code
caxp_v.yaxis.tick_right()
caxp_v.set_yticks(
    list(cre_cmap.keys()),
    labels=ticklabels,
    fontsize=6,
)
caxp_v.tick_params(length=1.5, pad=1)

# one shared colorbar, with ticks only at vmin, 1 and vmax
fig.colorbar(
    cm.ScalarMappable(norm=imshow_kwargs["norm"], cmap=imshow_kwargs["cmap"]),
    cax=cbar_ax,
    orientation="horizontal",
)
_cbar_ticks = [imshow_kwargs["norm"].vmin, 1, imshow_kwargs["norm"].vmax]
cbar_ax.set_xticks(_cbar_ticks)
cbar_ax.set_xticklabels(_cbar_ticks, fontsize=6)
cbar_ax.minorticks_off()
cbar_ax.tick_params(length=1.5, pad=1)
# keep the outer tick labels inside the colorbar's horizontal extent
for _tidx, tick in enumerate(cbar_ax.xaxis.get_majorticklabels()):
    tick.set_horizontalalignment({0: "left", 2: "right"}.get(_tidx, "center"))

_fig_fname = "FigExt7I_mp.svg"
fig.savefig(_fig_fname, dpi=300)


# Older exploratory plots (not publication-ready)

In [None]:
# Older overview of the N93 5 h / 10 h samples: one row per m/p pair, with m
# saddles in the left half of the columns and p saddles in the right half.
sub_samples_m = [
    "N93m5",
    "N93m10",
]
sub_samples_p = [
    "N93p5",
    "N93p10",
]

# squeeze=False keeps `axs` 2-D even if there is only a single row
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    squeeze=False,
)

# Local colour settings (wider range) kept under their own name, so running this
# exploratory cell does not overwrite the shared `imshow_kwargs` used by the
# publication figures.
_overview_imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap="RdBu_r",
    interpolation="none",
)

for i, (sample_m, sample_p, _row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _row_axs[jj], _row_axs[len(distances) + jj]
        # "trans" uses the trans arrays, every other range the cis ones
        if _dist_name == "trans":
            _sums_src, _counts_src = interaction_sums_trans, interaction_counts_trans
        else:
            _sums_src, _counts_src = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums_src[sample_m][_dist], axis=0) / np.nanmean(_counts_src[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums_src[sample_p][_dist], axis=0) / np.nanmean(_counts_src[sample_p][_dist], axis=0)
        axm.imshow(Cm, **_overview_imshow_kwargs)
        axp.imshow(Cp, **_overview_imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            # cCRE class labels on the bottom row, from code 0 ("none") to 6 ("PLS")
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            # drop a leading "m" prefix, if any (a true prefix, unlike str.lstrip)
            axm.set_ylabel(sample_m[1:] if sample_m.startswith("m") else sample_m)

In [None]:
# Older comparison at 10 h: each row pairs a sample (left half of the columns)
# with its N93 counterpart (right half), for every distance range.
sub_samples_m = [
    "m10hR1R2",
    "p10hR1R2",
    "mp10hR1R2",
]
sub_samples_p = [
    "N93m10",
    "N93p10",
    "N93mp10",
]

# squeeze=False keeps `axs` 2-D even if there is only a single row
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
    squeeze=False,
)

# Local colour settings (wider range) kept under their own name, so running this
# exploratory cell does not overwrite the shared `imshow_kwargs` used by the
# publication figures.
_overview_imshow_kwargs = dict(
    norm=LogNorm(vmin=1/2.5, vmax=2.5),
    cmap="RdBu_r",
    interpolation="none",
)

for i, (sample_m, sample_p, _row_axs) in enumerate(zip(sub_samples_m, sub_samples_p, axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _row_axs[jj], _row_axs[len(distances) + jj]
        # "trans" uses the trans arrays, every other range the cis ones
        if _dist_name == "trans":
            _sums_src, _counts_src = interaction_sums_trans, interaction_counts_trans
        else:
            _sums_src, _counts_src = interaction_sums, interaction_counts
        Cm = np.nanmean(_sums_src[sample_m][_dist], axis=0) / np.nanmean(_counts_src[sample_m][_dist], axis=0)
        Cp = np.nanmean(_sums_src[sample_p][_dist], axis=0) / np.nanmean(_counts_src[sample_p][_dist], axis=0)
        axm.imshow(Cm, **_overview_imshow_kwargs)
        axp.imshow(Cp, **_overview_imshow_kwargs)
        for _ax in [axp, axm]:
            _ax.set_xticks([])
            _ax.set_yticks([])
        if i == 0:
            axm.set_title(f"m-{_dist_name}")
            axp.set_title(f"p-{_dist_name}")
        if i == len(sub_samples_m)-1:
            # cCRE class labels on the bottom row, from code 0 ("none") to 6 ("PLS")
            for _ax in [axm, axp]:
                _ax.set_xticks(np.arange(len(ticklabels)))
                _ax.set_xticklabels(np.asarray(ticklabels[::-1]), rotation="vertical")
        if jj == 0:
            axm.set_ylabel(sample_m)

# colorbar in a manually placed axes (figure-fraction coordinates, bottom right)
cax = fig.add_axes([0.88,0.001,0.1,0.02])
fig.colorbar(
    cm.ScalarMappable(norm=_overview_imshow_kwargs["norm"], cmap=_overview_imshow_kwargs["cmap"]),
    cax=cax,
    orientation="horizontal",
)
_cbar_ticks = [_overview_imshow_kwargs["norm"].vmin, 1, _overview_imshow_kwargs["norm"].vmax]
cax.set_xticks(_cbar_ticks)
cax.set_xticklabels(_cbar_ticks)
cax.minorticks_off()