# Saddleplots

## This notebook only generates the saddle data; a separate notebook reuses it for plotting.

Welcome to the compartments and saddleplot notebook! 

This notebook illustrates cooltools functions used for investigating chromosomal compartments, visible as plaid patterns in mammalian interphase contact frequency maps.

These plaid patterns reflect the tendency of chromosome regions to contact regions of the same type more frequently: active regions have increased contact frequency with other active regions, and inactive regions tend to contact other inactive regions more frequently. The strength of compartmentalization has been shown to vary through the cell cycle, across cell types, and after degradation of components of the cohesin complex.

In this notebook we:

* obtain (sub)compartment assignments from a precomputed BED track
* calculate and visualize strength of compartmentalization using saddleplots

In [None]:
# import standard python libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import os, subprocess

In [None]:
# Import python package for working with cooler files and tools for analysis
import cooler
import cooltools.lib.plotting
import bioframe
import multiprocess as mp

In [None]:
%load_ext autoreload
%autoreload 2
# from saddle import saddleplot

In [None]:
# additional imports; helper_func, data_catalog and saddle are local modules
import bbi
import cooltools
from matplotlib.colors import LogNorm
from helper_func import saddleplot
from data_catalog import bws, bws_vlim, telo_dict

import saddle


In [None]:
import warnings
import seaborn as sns
from tqdm.notebook import trange, tqdm  # notebook-friendly progress bars
from mpire import WorkerPool


## Calculating per-chromosome compartmentalization

We first load the Hi-C data at 50 kb resolution (`binsize = 50_000`).

Note that the current implementation of eigendecomposition in cooltools assumes that individual regions can be held in memory: for hg38 at 100 kb, chr2 alone is a 2422x2422 matrix, and at 50 kb each side doubles.
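Memory grows with the square of the number of bins per region. Below is a minimal sketch of that arithmetic. It assumes a dense float64 matrix (8 bytes per value) and uses the hg38 length of chr2 (242,193,529 bp). The helper `dense_matrix_footprint` is hypothetical:

```python
import math


def dense_matrix_footprint(length_bp, binsize, bytes_per_value=8):
    """Return (bins per side, bytes) for a region stored as a dense matrix.

    Assumes float64 values (8 bytes each); illustrative only.
    """
    n_bins = math.ceil(length_bp / binsize)
    return n_bins, n_bins * n_bins * bytes_per_value


CHR2_HG38_BP = 242_193_529  # length of chr2 in hg38

for bs in (100_000, 50_000):
    n, nbytes = dense_matrix_footprint(CHR2_HG38_BP, bs)
    print(f"{bs // 1000} kb: {n} x {n} bins, ~{nbytes / 1e6:.0f} MB")
```

Halving the bin size roughly quadruples the memory, from about 47 MB to about 188 MB for a whole chr2. The view used here is made of chromosome arms, so each region is smaller than that.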

In [None]:
# define the genomic view (chromosome arms) used to pre-compute expected and saddles

# Use bioframe to fetch the genomic features from the UCSC.
hg38_chromsizes = bioframe.fetch_chromsizes('hg38')
hg38_cens = bioframe.fetch_centromeres('hg38')
hg38_arms_full = bioframe.make_chromarms(hg38_chromsizes, hg38_cens)
# keep only the 44 autosomal arms (p and q arms of chr1-chr22); this drops chrX, chrY and chrM
included_arms = hg38_arms_full["name"].to_list()[:44]
hg38_arms = hg38_arms_full[hg38_arms_full["name"].isin(included_arms)].reset_index(drop=True)

### pre-load coolers and pre-calculate expected ...

In [None]:
# cooler files that we'll work on :
binsize = 50_000
telo_clrs = { _k: cooler.Cooler(f"{_path}::/resolutions/{binsize}") for _k, _path in telo_dict.items() }

In [None]:
def _job(packed_data, sample):
    # packed data -> exp_kwargs and a dict with coolers for each sample
    exp_kwargs, clr_dict = packed_data
    _clr = clr_dict[sample]
    # in order to use spawn/forkserver we have to import for worker
    from cooltools import expected_cis
    _exp = expected_cis( _clr, **exp_kwargs)
    return (sample, _exp)

# define expected parameters in the form of kwargs-dict:
exp_kwargs = dict(
    view_df=hg38_arms,
    intra_only=False,
    nproc=12
)

# have to use daemon=False, because _job is multiprocessing-based already ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=( exp_kwargs, telo_clrs ),
    start_method="forkserver",  # little faster than spawn, fork is the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs, progress_bar=True)

# sort out the results ...
telo_exps_cis = {sample: _exp for sample, _exp in results}
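
`expected_cis` summarizes how contact frequency decays with genomic separation within each region of the view. The sketch below is a rough illustration only, not the cooltools implementation. Cooltools works on sparse, balanced pixels, and can ignore the first diagonals and smooth the result. In this simplified version, the expected at diagonal offset `d` of a dense matrix is the mean over that diagonal. Observed/expected then divides each pixel by the expected at its offset. Both function names are hypothetical:

```python
import warnings

import numpy as np


def expected_by_diagonal(mat):
    """Mean contact frequency at each diagonal offset of a square matrix (NaNs ignored)."""
    n = mat.shape[0]
    with warnings.catch_warnings():
        # an all-NaN diagonal yields NaN; silence the "mean of empty slice" warning
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.array([np.nanmean(np.diagonal(mat, offset=d)) for d in range(n)])


def observed_over_expected(mat):
    """Divide each pixel (i, j) by the expected value at offset |i - j|."""
    exp = expected_by_diagonal(mat)
    i, j = np.indices(mat.shape)
    return mat / exp[np.abs(i - j)]


mat = np.array([[2.0, 1.0],
                [1.0, 4.0]])
print(expected_by_diagonal(mat))  # diagonal means: 3.0 and 1.0
print(observed_over_expected(mat))
```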

### trans-expected second

In [None]:
def _job(packed_data, sample):
    # unpack data
    clr_dict, = packed_data
    exp_kwargs = dict(chunksize=1000000, nproc=12)
    from cooltools import expected_trans
    _clr = clr_dict[sample]
    _exp = expected_trans( _clr, **exp_kwargs).set_index(["region1", "region2"]).sort_index()
    return (sample, _exp)

# have to use daemon=False, because _job is multiprocessing-based already ...
with WorkerPool(
    n_jobs=8,
    daemon=False,
    shared_objects=(telo_clrs, ),
    start_method="forkserver",
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs, progress_bar=True)

# sort out the results ...
telo_exps_trans = {sample: _exp for sample, _exp in results}
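
In trans there is no distance dependence, so the expected for a pair of regions is a single number. The sketch below computes it as the mean over all valid pixels of the region1 x region2 block. It assumes a dense balanced matrix and hypothetical per-region bin ranges, and `block_expected` is a hypothetical helper, not the cooltools code:

```python
import numpy as np


def block_expected(mat, regions):
    """Mean contact frequency for every pair of distinct regions (NaNs ignored).

    regions: dict mapping region name -> (start_bin, end_bin), end exclusive.
    Returns {(region1, region2): mean} with region1 listed before region2.
    """
    names = list(regions)
    out = {}
    for a, r1 in enumerate(names):
        for r2 in names[a + 1:]:
            s1, e1 = regions[r1]
            s2, e2 = regions[r2]
            out[(r1, r2)] = float(np.nanmean(mat[s1:e1, s2:e2]))
    return out


mat = np.array([[5.0, 5.0, 1.0],
                [5.0, 5.0, 3.0],
                [1.0, 3.0, 5.0]])
print(block_expected(mat, {"chrA": (0, 2), "chrB": (2, 3)}))
```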

# Load the partition of the genome into clusters and create an assignment track for the saddle



## Saddleplots

A common way to visualize preferences captured by the eigenvector is by using saddleplots.

To generate a saddleplot, we first use the eigenvector to stratify genomic regions into groups with similar values of the eigenvector. These groups are then averaged over to create the saddleplot.
This process is called "digitizing".
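
In terms of arrays, the averaging takes three steps:

- Accumulate observed/expected values over all pixels whose row bin falls in category `i` and whose column bin falls in category `j`.
- Count the contributing pixels for each `(i, j)`.
- Divide the sums by the counts.

Below is a simplified sketch for a single dense O/E matrix. It assumes integer category labels with `-1` for missing data, and `saddle_sums_counts` is a hypothetical helper, not a cooltools function:

```python
import numpy as np


def saddle_sums_counts(oe, cats, n_cats):
    """Accumulate O/E pixels into n_cats x n_cats sums and counts.

    oe   : (N, N) observed/expected matrix of one region; NaNs are skipped
    cats : (N,) integer category per bin; -1 (missing data) is skipped
    """
    oe = np.asarray(oe, dtype=float)
    cats = np.asarray(cats)
    keep = cats >= 0
    oe = oe[np.ix_(keep, keep)]
    cats = cats[keep]
    # category of the row bin and of the column bin for every pixel
    row_cat, col_cat = np.meshgrid(cats, cats, indexing="ij")
    finite = np.isfinite(oe)
    sums = np.zeros((n_cats, n_cats))
    counts = np.zeros((n_cats, n_cats))
    np.add.at(sums, (row_cat[finite], col_cat[finite]), oe[finite])
    np.add.at(counts, (row_cat[finite], col_cat[finite]), 1)
    return sums, counts


oe = np.array([[1.0, 2.0, np.nan],
               [2.0, 4.0, 6.0],
               [np.nan, 6.0, 8.0]])
sums, counts = saddle_sums_counts(oe, cats=[0, 1, 1], n_cats=2)
saddle = sums / counts  # mean O/E for each pair of categories
print(saddle)
```

To combine regions, add their `sums` and `counts` first and divide only at the end. This way every pixel carries the same weight.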

Cooltools operates on a `digitized` bedgraph-like track with four columns. The fourth (value) column is categorical, with the following encoding:

- `1..n` <-> values assigned to bins defined by `vrange` or `qrange`
- `0` <-> left outlier values
- `n+1` <-> right outlier values
- `-1` <-> missing data (NaNs)

A continuous track can be digitized either by value, by passing `vrange`, or by quantiles, by passing `qrange`. Here the track is already categorical (subcompartment labels), so no further digitizing is needed.
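
For a continuous track, the encoding above can be reproduced with `np.digitize`. This is a minimal sketch assuming `n` equal-width bins over `vrange`, and `digitize_values` is a hypothetical helper. The `digitize` function in cooltools also supports quantile edges via `qrange`, and its edge handling may differ:

```python
import numpy as np


def digitize_values(values, n_bins, vrange):
    """Map values to 1..n_bins, with 0 / n_bins + 1 for outliers and -1 for NaN."""
    values = np.asarray(values, dtype=float)
    edges = np.linspace(vrange[0], vrange[1], n_bins + 1)
    # right=False: category i holds edges[i-1] <= v < edges[i];
    # v < edges[0] -> 0 and v >= edges[-1] -> n_bins + 1
    cats = np.digitize(values, edges, right=False)
    cats[np.isnan(values)] = -1
    return cats


print(digitize_values([0.1, 0.5, 0.9, np.nan, -1.0, 2.0], n_bins=2, vrange=(0, 1)))
```

With `right=False`, a value exactly equal to the upper edge of `vrange` lands in the right-outlier category `n+1`.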

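To create saddles in cis, cooltools requires a cooler, a table of expected contact frequency as a function of genomic distance, and a digitized track (or parameters for digitizing one). Here the digitized track is built from the subcompartment BED file: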

In [None]:
nezar_df = bioframe.read_table("GSM7990272_DLD1.360.NT.50000.E1-E128.comp_10.kmeans9_5.bed", schema="bed9")

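# Make a saddle-compatible track, i.e. digitized into integer categories 0, 1, 2, ...
# Subcompartment codes are shifted by 3 to reserve 1-3 for the ID-overlapping parts
# of A1, A2 and V+VI (assigned further below); 0 marks bins without a label,
# and B2/B3 and B4 share a single B category.
_comp_dict = {
    'A1':(1 + 3),
    'A2':(2 + 3),
    'V+VI':(3 + 3),
    'B2/B3':(4 + 3),
    'B4':(4 + 3),
}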

_num_cats = max(_comp_dict.values()) + 1
_cats = pd.CategoricalDtype(categories=list(range(_num_cats)), ordered=True)

k = "name"
_track = nezar_df[["chrom","start","end",k]].replace({k: _comp_dict})
_track[k].unique()


In [None]:
# genome-wide bins of size `binsize`, autosomes only
clr_bins = cooler.binnify(hg38_chromsizes, binsize)
clr_bins = clr_bins[~clr_bins["chrom"].isin(["chrX","chrY","chrM"])]

# annotate each bin with the subcompartment category it overlaps (0 = no label)
_bin_assigned = bioframe.overlap(clr_bins, _track)
_bin_assigned = _bin_assigned.drop(columns=["chrom_","start_","end_"])
_bin_assigned["name_"] = _bin_assigned["name_"].fillna(0).astype(int)
_bin_assigned["name_"] = _bin_assigned["name_"].astype(_cats)
# ...
_track_bins = _bin_assigned.astype({ "chrom" : str })
_track_bins = _track_bins.rename(columns={"name_":"name"})
# show intermediate results ...
display(_track_bins.head(2))
display(_track_bins.tail(3))
_track_bins["name"].value_counts().sort_index()

# Load ID/MCD anchors, choose one set, and check subcompartment enrichment within it

In [None]:
id_anchor_fnames = {
    "mega_2X_enrichment": "ID_anchors/mega_2X_enrichment.fourth_mega.max_size.bed",
    "5hr_2X_enrichment_old": "ID_anchors/5hr_2X_enrichment.second_bulk.max_size.bed",
    "5hr_2X_enrichment": "ID_anchors/5hr_2X_enrichment.pixel_derived.bed",
    "5hr_2X_enrichment_nosing": "ID_anchors/5hr_2X_enrichment.pixel_derived.no_singletons.bed",
    "5hr_notinCyto_2X_enrichment_signal": "ID_anchors/p5notin_pCyto_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "5hr_2X_enrichment_signal": "ID_anchors/5hr_2X_enrichment.pixel_derived.signal_peaks.bed",
    "10hr_2X_enrichment_signal": "ID_anchors/10hrs_2X_enrichment.pixel_derived.signal_peaks.bed",
    "N93p5_2X_enrichment_signal": "ID_anchors/N93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "pCyto_2X_enrichment_signal": "ID_anchors/pCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mCyto_2X_enrichment_signal": "ID_anchors/mCyto_2X_enrichment.pixel_derived.signal_peaks.bed",
    "mega_3X_enrichment": "ID_anchors/mega_3X_enrichment.fifth_mega3x.max_size.bed",
    "MEGA_2X_enrichment": "ID_anchors/MEGAp5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGA_weaker_2X_enrichment": "ID_anchors/MEGA_plus_weak_anchors_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAN93_2X_enrichment": "ID_anchors/MEGAN93p5_2X_enrichment.pixel_derived.signal_peaks.bed",
    "MEGAminus_2X_enrichment": "ID_anchors/MEGA_minus_ctrl_2X_enrichment.pixel_derived.signal_peaks.bed",
    "cyto_2x_enrichment": "ID_anchors/cyto_2x_enrichment.third_mCyto.max_size.bed",
}

id_anchors_dict = {}
for id_name, fname in id_anchor_fnames.items():
    id_anchors_dict[id_name] = pd.read_csv(fname, sep="\t")
    # ...
    print(f"loaded {len(id_anchors_dict[id_name]):5d} ID anchors {id_name:>20} in BED format ...")


# indices of compartments back to their names ...
name_to_descr = dict((i, name) for name, i in _comp_dict.items())

_chosen_idname = "5hr_2X_enrichment_signal"
# # _chosen_idname = "MEGAminus_2X_enrichment"
# # # _chosen_idname = "MEGA_2X_enrichment"
# _chosen_idname = "pCyto_2X_enrichment_signal"
_anchors = id_anchors_dict[_chosen_idname]
_anchors = _anchors.drop(columns=["size.1", ])

# overlap
_IPG_ID = bioframe.overlap( _track_bins, _anchors).dropna()
# check whether one IPG overlaps multiple IDs (and vice versa)
_d = _IPG_ID.duplicated(subset=["chrom","start","end"]).sum()
print(f"there are {_d} IPGs that overlap more than 1 ID anchor ...")
_d = _IPG_ID.duplicated(subset=["chrom_","start_","end_"]).sum()
print(f"there are {_d} IDs that overlap more than 1 IPG ...")
# rename
_IPG_ID["name"] = _IPG_ID["name"].map(name_to_descr)
_IPG_ID_counts = _IPG_ID["name"].value_counts()
# IPG counts themselves to normalize ...
_IPG_counts = _track_bins["name"].map(name_to_descr).value_counts()

# total ID nucleotides divided by the total _track_bins nucleotides ...
_factor = _anchors.eval("end - start").sum() / _track_bins.eval("end - start").sum()

# print enrichments ...
(_IPG_ID_counts / _IPG_counts) / _factor
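
The enrichment computed above compares two fractions:

- the fraction of bins in each subcompartment that overlap an anchor;
- the fraction of the binned genome covered by anchors.

Below is a pandas-only sketch of the same normalization on toy intervals. A naive same-chromosome overlap stands in for `bioframe.overlap`, and the helper name and data are hypothetical:

```python
import pandas as pd


def overlap_enrichment(bins, anchors):
    """Fraction of bins per category overlapping an anchor, relative to anchor coverage.

    bins    : DataFrame with chrom, start, end, name (category label)
    anchors : DataFrame with chrom, start, end
    """
    # naive all-pairs overlap within each chromosome; fine for toy-sized inputs
    pairs = bins.merge(anchors[["chrom", "start", "end"]], on="chrom", suffixes=("", "_a"))
    hits = pairs[(pairs["start"] < pairs["end_a"]) & (pairs["start_a"] < pairs["end"])]
    # as in the notebook, a bin overlapping two anchors is counted twice
    hit_counts = hits["name"].value_counts()
    bin_counts = bins["name"].value_counts()
    # fraction expected if anchors were spread uniformly over the binned genome
    factor = (anchors["end"] - anchors["start"]).sum() / (bins["end"] - bins["start"]).sum()
    return (hit_counts / bin_counts).fillna(0) / factor


bins = pd.DataFrame({
    "chrom": ["chr1"] * 4,
    "start": [0, 10, 20, 30],
    "end": [10, 20, 30, 40],
    "name": ["A", "A", "B", "B"],
})
anchors = pd.DataFrame({"chrom": ["chr1"], "start": [5], "end": [10]})
print(overlap_enrichment(bins, anchors))
```

A value of 1 means a category overlaps anchors as often as uniformly placed anchors would. Values above 1 indicate enrichment.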

# Split active subcompartments into their ID-overlapping part and the remainder

In [None]:
_track_bins_id = _track_bins.copy()
# subcompartment codes from _comp_dict
a1_name = _comp_dict["A1"]
a2_name = _comp_dict["A2"]
vv_name = _comp_dict["V+VI"]
b_name = _comp_dict["B4"]
# bins overlapping an ID anchor; the "index" column holds row labels of _track_bins
_id_bins_all = bioframe.overlap( _track_bins, _anchors, return_index=True).dropna()
_id_bins_A1 = _id_bins_all.query(f"name == {a1_name}")["index"]
_id_bins_A2 = _id_bins_all.query(f"name == {a2_name}")["index"]
_id_bins_VVI = _id_bins_all.query(f"name == {vv_name}")["index"]
_id_bins_B = _id_bins_all.query(f"name == {b_name}")["index"]  # not used: B bins are not split
# move ID-overlapping active bins into the reserved categories 1-3
_track_bins_id.loc[_id_bins_A1, "name"] = 1
_track_bins_id.loc[_id_bins_A2, "name"] = 2
_track_bins_id.loc[_id_bins_VVI, "name"] = 3
_track_bins_id["name"].value_counts()
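
The relabeling relies on the offset of 3 in `_comp_dict`. An ID-overlapping bin of category `c` in {A1, A2, V+VI} moves to `c - 3`, and all other bins keep their label. Below is a numpy sketch of that rule on toy labels. The helper `split_by_anchor` is hypothetical; the notebook assigns the new codes explicitly with `.loc`:

```python
import numpy as np

OFFSET = 3               # shift applied to subcompartment codes in _comp_dict
SPLIT_CODES = (4, 5, 6)  # A1, A2, V+VI; B (7) and unlabeled bins (0) are not split


def split_by_anchor(labels, overlaps_anchor):
    """Move anchor-overlapping A1/A2/V+VI bins to the reserved codes 1-3."""
    labels = np.array(labels, copy=True)
    move = np.asarray(overlaps_anchor, dtype=bool) & np.isin(labels, SPLIT_CODES)
    labels[move] -= OFFSET
    return labels


print(split_by_anchor([0, 4, 5, 6, 7, 4], [True, True, True, True, True, False]))
```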

## IPG/ID counts before and after splitting ...

In [None]:
ticklabels = [
    'none',
    "ID-A1",
    "ID-A2",
    "ID-VVI",
    'A1',
    'A2',
    'V+VI',
    'B',
]

In [None]:
ipg_counts_before = _track_bins["name"].value_counts().sort_index()
ipg_counts_before.index = ticklabels
ipg_counts_before

In [None]:
ipg_counts_after = _track_bins_id["name"].value_counts().sort_index()
ipg_counts_after.index = ticklabels
ipg_counts_after

## stack-saddles for cis ...

In [None]:
def _job(packed_data, sample):
    clr_dict, exp_dict, _atrack, view_df = packed_data
    _clr = clr_dict[sample]
    _exp = exp_dict[sample]
    from cooltools.api.saddle import saddle_stack
    _sum, _count = saddle_stack(
        _clr,
        _exp,
        _atrack,
        'cis',
        n_bins=None,
        drop_track_na=True,
        view_df=view_df
    )
    return sample, _sum, _count

# daemon=True here: saddle_stack is called without nproc, so workers presumably do not start child processes
with WorkerPool(
    n_jobs=16,
    daemon=True,
    shared_objects=( telo_clrs, telo_exps_cis, _track_bins_id, hg38_arms ),
    start_method="fork",  # little faster than spawn, fork is the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job, telo_clrs, progress_bar=True)

# sort out the results ...
interaction_sums = {}
interaction_counts = {}
for sample, _sum, _counts in results:
    interaction_sums[sample] = _sum
    interaction_counts[sample] = _counts

In [None]:
# drop chrX, chrY and chrM from the trans expected and from the view
telo_trans_filt_exps = {}
for _k, _clr in tqdm(telo_clrs.items()):
    _df = telo_exps_trans[_k].reset_index()
    m2 = _df["region2"].isin(["chrX","chrY","chrM"])
    m1 = _df["region1"].isin(["chrX","chrY","chrM"])
    telo_trans_filt_exps[_k] = _df[~(m1 | m2)]

# a view without M,X and Y chromosomes ...
sub_chrom_view = bioframe.make_viewframe(hg38_chromsizes)
bad_chroms = ["chrX","chrY","chrM"]
sub_chrom_view = sub_chrom_view[~sub_chrom_view["name"].isin(bad_chroms)]
##########################################################################

# trans saddles
def _job_trans(packed_data, sample):
    clr_dict, exp_dict, _atrack, view_df = packed_data
    _clr = clr_dict[sample]
    _exp = exp_dict[sample]
    from cooltools.api.saddle import saddle_stack
    _sum, _count = saddle_stack(
        _clr,
        _exp,
        _atrack,
        'trans',
        n_bins=None,
        drop_track_na=True,
        view_df=view_df,
    )
    return sample, _sum, _count

# daemon=True here: saddle_stack is called without nproc, so workers presumably do not start child processes
with WorkerPool(
    n_jobs=16,
    daemon=True,
    shared_objects=( telo_clrs, telo_trans_filt_exps, _track_bins_id, sub_chrom_view ),
    start_method="fork",  # little faster than spawn, fork is the fastest
    use_dill=True,
) as wpool:
    results = wpool.map(_job_trans, telo_clrs, progress_bar=True)

# sort out the results ...
interaction_sums_trans = {}
interaction_counts_trans = {}
for sample, _sum, _counts in results:
    interaction_sums_trans[sample] = _sum
    interaction_counts_trans[sample] = _counts

# Let's save these results to HDF5 for convenience and practice

In [None]:
import h5py

In [None]:
# store the four dictionaries of per-sample arrays along with some metadata;
# mode 'x' fails if the file already exists, so earlier results are never overwritten
with h5py.File("saddles_IPG_by_distance_wIDs.hdf5", 'x') as f:
    # metadata
    f.attrs["cis_binsize"] = binsize
    f.attrs["trans_binsize"] = binsize
    f.attrs["cre_fname"] = "GSM7990272_DLD1.360.NT.50000.E1-E128.comp_10.kmeans9_5.bed"
    f.attrs["mcd_fname"] = id_anchor_fnames[_chosen_idname]
    # one group per quantity (cis and trans), one dataset per sample
    for _grp_name, _arrs in [
        ("sums", interaction_sums),
        ("counts", interaction_counts),
        ("sums_trans", interaction_sums_trans),
        ("counts_trans", interaction_counts_trans),
    ]:
        _grp = f.create_group(_grp_name)
        for _sample, _arr in _arrs.items():
            _grp.create_dataset(_sample, data=_arr)


In [None]:
! ls -lah *.hdf5
# ! rm saddles_cre_by_distance.hdf5
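
The plotting happens in a separate notebook, so the file has to be read back there. Below is a minimal sketch with h5py. It assumes the group layout written above: `sums`, `counts`, `sums_trans` and `counts_trans`, with one dataset per sample. `load_saddle_stacks` is a hypothetical helper:

```python
import h5py

# group names as written by the saving cell above
SADDLE_GROUPS = ("sums", "counts", "sums_trans", "counts_trans")


def load_saddle_stacks(path):
    """Read saddle stacks back as {group: {sample: ndarray}} plus the file attributes."""
    data = {}
    with h5py.File(path, "r") as f:
        attrs = dict(f.attrs)
        for grp_name in SADDLE_GROUPS:
            grp = f[grp_name]
            # iterating over a group yields dataset names, i.e. sample names
            data[grp_name] = {sample: grp[sample][()] for sample in grp}
    return data, attrs
```

In the plotting notebook, `stacks, meta = load_saddle_stacks("saddles_IPG_by_distance_wIDs.hdf5")` makes `stacks["sums"]` available in place of `interaction_sums`. The other three dictionaries are recovered the same way.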

# legacy plotting infrastructure ...

In [None]:
# distance ranges as slices over the leading axis (diagonals, in 50 kb bins):
# 0-1Mb: 0:21 bins
# 1-7Mb: 21:141 bins
# 7-50Mb: 141:1001 bins
distances = {
    "short:<1Mb": slice(0, int(1_000_000/binsize)+1),
    "mid:1Mb-7Mb": slice(int(1_000_000/binsize)+1, int(7_000_000/binsize)+1),
    "long:7Mb-50Mb": slice(int(7_000_000/binsize)+1, int(50_000_000/binsize)+1),
    "all-cis": slice(None),
    "trans": slice(None),
}


imshow_kwargs = dict(
        norm=LogNorm(vmin=1/3, vmax=3),
        cmap="RdBu_r",
        interpolation="nearest",
)

# only these two ranges are used below (this overrides the dict above)
distances = {
    "all-cis": slice(None),
    "trans": slice(None),
}

ticklabels = [
    'none',
    "ID-A1",
    "ID-A2",
    "ID-VVI",
    'A1',
    'A2',
    'V+VI',
    'B',
    # 'B2/B3',
    # 'B4',
]

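def get_saddle_data(sample, dist_name, dist_range=None):
    """
    Convenience function: collapse the stacked interaction_sums and
    interaction_counts of `sample` into a saddle (sum / count), pooling over
    `dist_range` along the leading axis (cis only; ignored for "trans").
    """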
    if dist_name == "trans":
        _sum = np.nansum(interaction_sums_trans[sample], axis=0)
        _count = np.nansum(interaction_counts_trans[sample], axis=0)
    else:
        if dist_range is not None:
            _sum = np.nansum(interaction_sums[sample][dist_range], axis=0)
            _count = np.nansum(interaction_counts[sample][dist_range], axis=0)
        else:
            _sum = np.nansum(interaction_sums[sample], axis=0)
            _count = np.nansum(interaction_counts[sample], axis=0)
    return _sum / _count
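
`get_saddle_data` sums the stacked sums and counts over the selected slice before dividing. As a result, every pixel carries equal weight, and slices containing many pixels dominate. Within a region, shorter-distance diagonals contain more pixels. Averaging per-slice saddles instead would give every slice equal weight.

The toy comparison below assumes a stack shaped `(n_slices, n_cat, n_cat)`. `pooled_saddle` is a hypothetical stand-alone version of the same pooling:

```python
import numpy as np


def pooled_saddle(sums, counts, sel=slice(None)):
    """Pool a stack of per-slice saddle sums/counts, then divide (NaNs ignored)."""
    s = np.nansum(sums[sel], axis=0)
    c = np.nansum(counts[sel], axis=0)
    with np.errstate(invalid="ignore", divide="ignore"):
        return s / c


# two slices of a 1x1 saddle: 10 pixels averaging 1.0, and 1 pixel of value 4.0
sums = np.array([[[10.0]], [[4.0]]])
counts = np.array([[[10.0]], [[1.0]]])

pooled = pooled_saddle(sums, counts)             # 14 / 11, pixel-weighted
per_slice_mean = np.mean(sums / counts, axis=0)  # (1.0 + 4.0) / 2, slice-weighted
print(pooled, per_slice_mean)
```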


In [None]:
sub_samples_m = [
    "mMito",
    "mTelo",
    "mCyto",
    "m5hR1R2",
    "m10hR1R2",
]
sub_samples_p = [
    "pMito",
    "pTelo",
    "pCyto",
    "p5hR1R2",
    "p10hR1R2",
]


_dfs_m = {}
_dfs_p = {}
for _dist in distances:
    _dfs_m[_dist] = {}
    _dfs_p[_dist] = {}

for sample_m, sample_p in zip(sub_samples_m, sub_samples_p):
    for _dist_name, _dist in distances.items():
        if _dist_name == "trans":
            Cm = np.nanmean(interaction_sums_trans[sample_m][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums_trans[sample_p][_dist], axis=0) / np.nanmean(interaction_counts_trans[sample_p][_dist], axis=0)
        else:
            Cm = np.nanmean(interaction_sums[sample_m][_dist], axis=0) / np.nanmean(interaction_counts[sample_m][_dist], axis=0)
            Cp = np.nanmean(interaction_sums[sample_p][_dist], axis=0) / np.nanmean(interaction_counts[sample_p][_dist], axis=0)
        #
        _dfs_m[_dist_name][sample_m] = pd.DataFrame(Cm, index=ticklabels, columns=ticklabels).stack()
        _dfs_p[_dist_name][sample_p] = pd.DataFrame(Cp, index=ticklabels, columns=ticklabels).stack()

for _dist in distances:
    print(_dist)
    # control one ...
    _dfs_m[_dist] = pd.DataFrame(_dfs_m[_dist]) \
    .reset_index() \
    .query("(level_0 == level_1) & (level_0 != 'none')") \
    .drop(columns=["level_1"]) \
    .rename(columns={"level_0":"type"}) \
    .set_index("type").T
    # the depletion one ...
    _dfs_p[_dist] = pd.DataFrame(_dfs_p[_dist]) \
    .reset_index() \
    .query("(level_0 == level_1) & (level_0 != 'none')") \
    .drop(columns=["level_1"]) \
    .rename(columns={"level_0":"type"}) \
    .set_index("type").T
    #
    # ...
    #
    _dfs_m[_dist] = _dfs_m[_dist].add_prefix("ctrl:")
    _dfs_m[_dist].index = [s.lstrip("m") for s in _dfs_m[_dist].index]
    _dfs_m[_dist]["cell cycle"] = [0,1,2,5, 6]
    # ...
    _dfs_p[_dist] = _dfs_p[_dist].add_prefix("delta:")
    _dfs_p[_dist].index = [s.lstrip("p") for s in _dfs_p[_dist].index]
    _dfs_p[_dist]["cell cycle"] = [0,1,2,5, 6]

# pd.concat([_dfs_m["trans"], _dfs_p["trans"]],axis=1).to_csv("trans_IPG_asis.tsv",sep="\t")
# pd.concat([_dfs_m["all-cis"], _dfs_p["all-cis"]],axis=1).to_csv("allcis_IPG_asis.tsv",sep="\t")

_dfs_p[_dist]
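
The `stack`/`query`/`set_index` chain above does two things. It keeps only the diagonal of each saddle, i.e. interactions within the same category, and it drops the `none` category. The same values can be read directly with `np.diag`. Below is a sketch with a hypothetical helper and a toy 3x3 saddle:

```python
import numpy as np
import pandas as pd


def saddle_diagonal(C, labels, drop=("none",)):
    """Within-category (diagonal) entries of a saddle, labelled by category."""
    within = pd.Series(np.diag(C), index=list(labels))
    return within.drop(index=list(drop), errors="ignore")


C = np.arange(9.0).reshape(3, 3)
print(saddle_diagonal(C, ["none", "A1", "B"]))
```

Collecting one such Series per sample gives one row per sample and one column per category, as in `_dfs_m` and `_dfs_p` before the prefixes are added. For example: `pd.DataFrame({s: saddle_diagonal(C_s, ticklabels) for s, C_s in ...}).T`.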

In [None]:
# _dist_name = 'trans'
_dist_name = 'all-cis'

ylims = {
    'trans': (0.93, 4.9),
    'all-cis': (0.9999999, 3.1),
}

yticks = {
    'trans': [1,2,3,4],
    'all-cis': [1,2,3],
}

ctrl_kwargs = dict(
    marker="o",
    lw=1,
    linestyle="-",
)
delta_kwargs = dict(
    marker="x",
    lw=1.25,
    linestyle=":",
)
_id_colors = ["darkred","orangered","darkgoldenrod"]
_sub_colors = _id_colors + ["cornflowerblue"]


# # plotting ...
# f, ax = plt.subplots(nrows=1, ncols=1, figsize=(6,5))

# _dfs_m[_dist_name].plot(
#     x="cell cycle",
#     y=["ctrl:A1", "ctrl:A2", "ctrl:V+VI"],
#     color=_id_colors,
#     **ctrl_kwargs,
#     ax=ax,
# )
# _dfs_p[_dist_name].plot(
#     x="cell cycle",
#     y=["delta:A1", "delta:A2", "delta:V+VI"],
#     color=_id_colors,
#     **delta_kwargs,
#     ax=ax,
# )
# ax.legend(frameon=False)
# ax.set_xticks([0,1,2,5, 6])
# ax.set_xticklabels(["Mito","Telo","Cyto","G1", "G1@10h"])
# ax.set_yscale("log")
# ax.set_title(f"ID portion of subcompartments: {_dist_name}")
# ax.set_ylim(*ylims[_dist_name])

# ax.set_yticks(yticks[_dist_name])
# # ax.set_yticklabels(yticks[_dist_name])
# # ax.yaxis.set_major_formatter(ScalarFormatter())


f, ax = plt.subplots(nrows=1, ncols=1, figsize=(6,5))

_dfs_m[_dist_name].plot(
    x="cell cycle",
    y=["ctrl:ID-A1", "ctrl:ID-A2", "ctrl:ID-VVI"],
    color=_id_colors,
    **ctrl_kwargs,
    ax=ax,
)
_dfs_p[_dist_name].plot(
    x="cell cycle",
    y=["delta:ID-A1", "delta:ID-A2", "delta:ID-VVI"],
    color=_id_colors,
    **delta_kwargs,
    ax=ax,
)
ax.legend(frameon=False)
ax.set_xticks([0,1,2,5])
ax.set_xticklabels(["Mito","Telo","Cyto","G1"])
ax.set_yscale("log")
ax.set_title(f"ID portion of subcompartments: {_dist_name}")
ax.set_ylim(*ylims[_dist_name])

ax.set_yticks(yticks[_dist_name])
# ax.set_yticklabels(yticks[_dist_name])
# ax.yaxis.set_major_formatter(ScalarFormatter())

In [None]:
f, ax = plt.subplots(nrows=1, ncols=1, figsize=(6,5))

_dfs_m[_dist_name].plot(
    x="cell cycle",
    y=["ctrl:A1", "ctrl:A2", "ctrl:V+VI"],
    color=_id_colors,
    **ctrl_kwargs,
    ax=ax,
)
_dfs_p[_dist_name].plot(
    x="cell cycle",
    y=["delta:A1", "delta:A2", "delta:V+VI"],
    color=_id_colors,
    **delta_kwargs,
    ax=ax,
)
ax.legend(frameon=False)
ax.set_xticks([0,1,2,5])
ax.set_xticklabels(["Mito","Telo","Cyto","G1"])
ax.set_yscale("log")
ax.set_title(f"Subcompartments w/o IDs: {_dist_name}")
ax.set_ylim(*ylims[_dist_name])

ax.set_yticks(yticks[_dist_name])
# ax.set_yticklabels(yticks[_dist_name])
# ax.yaxis.set_major_formatter(ScalarFormatter())

In [None]:
f, ax = plt.subplots(nrows=1, ncols=1, figsize=(6,5))

_dfs_m[_dist_name].plot(
    x="cell cycle",
    y=["ctrl:A1", "ctrl:A2", "ctrl:V+VI", "ctrl:B"],
    color=_sub_colors,
    **ctrl_kwargs,
    ax=ax,
)
_dfs_p[_dist_name].plot(
    x="cell cycle",
    y=["delta:A1", "delta:A2", "delta:V+VI", "delta:B"],
    color=_sub_colors,
    **delta_kwargs,
    ax=ax,
)
ax.legend(frameon=False)
# introduce a fake time to convey the delay between Cyto and G1 ...
ax.set_xticks([0,1,2,5])
ax.set_xticklabels(["Mito","Telo","Cyto","G1"])
ax.set_yscale("log")
ax.set_title(_dist_name)
ax.set_ylim(*ylims[_dist_name])

ax.set_yticks(yticks[_dist_name])
# # ax.set_yticklabels(yticks[_dist_name])
# ax.yaxis.set_major_formatter(ScalarFormatter())

In [None]:
sub_samples_m = [
    "mMito",
    "mTelo",
    "mCyto",
    "m5hR1R2",
    "m10hR1R2",
]
sub_samples_p = [
    "pMito",
    "pTelo",
    "pCyto",
    "p5hR1R2",
    "p10hR1R2",
]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

for sample_m, sample_p, (i, _axs) in zip(sub_samples_m, sub_samples_p, enumerate(axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _axs[jj], _axs[len(distances) + jj]
        Cm = get_saddle_data(sample_m, _dist_name, _dist)
        Cp = get_saddle_data(sample_p, _dist_name, _dist)
        axm.imshow(Cm[1:,1:], **imshow_kwargs)
        axp.imshow(Cp[1:,1:], **imshow_kwargs)

# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    # m ...
    axs[0, jj].set_title(f"m-{_dist_name}")
    axs[-1,jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
    # p ...
    axs[0, len(distances) + jj].set_title(f"p-{_dist_name}")
    axs[-1,len(distances) + jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,len(distances) + jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
for ii, _sample in enumerate(sub_samples_m):
    axs[ii,0].set_ylabel(_sample.lstrip("m"))
    axs[ii,0].set_yticks([])


In [None]:
fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

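# interleave each ID subset with the rest of its subcompartment: ID-A1, A1, ID-A2, A2, ID-VVI, V+VI, B
_reidxs = [0,3,1,4,2,5,6]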

for sample_m, sample_p, (i, _axs) in zip(sub_samples_m, sub_samples_p, enumerate(axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _axs[jj], _axs[len(distances) + jj]
        Cm = get_saddle_data(sample_m, _dist_name, _dist)
        Cp = get_saddle_data(sample_p, _dist_name, _dist)
        axm.imshow(Cm[1:,1:][_reidxs][:,_reidxs], **imshow_kwargs)
        axp.imshow(Cp[1:,1:][_reidxs][:,_reidxs], **imshow_kwargs)

# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    # m ...
    axs[0, jj].set_title(f"m-{_dist_name}")
    axs[-1,jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,jj].set_xticklabels(np.asarray(ticklabels)[1:][_reidxs], rotation="vertical")
    # p ...
    axs[0, len(distances) + jj].set_title(f"p-{_dist_name}")
    axs[-1,len(distances) + jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,len(distances) + jj].set_xticklabels(np.asarray(ticklabels)[1:][_reidxs], rotation="vertical")
for ii, _sample in enumerate(sub_samples_m):
    axs[ii,0].set_ylabel(_sample.lstrip("m"))
    axs[ii,0].set_yticks([])

plt.savefig("saddles_Nezar_wID.pdf", dpi=300)

In [None]:
# N93 samples: m (control) ...
sub_samples_m = [
    "N93m5",
    "N93m10",
]
# p ...
sub_samples_p = [
    "N93p5",
    "N93p10",
]

fig, axs = plt.subplots(
    nrows=len(sub_samples_m),
    ncols=2*len(distances),
    figsize=(4*len(distances),2*len(sub_samples_m)),
    sharex=True,
    sharey=True,
)

for sample_m, sample_p, (i, _axs) in zip(sub_samples_m, sub_samples_p, enumerate(axs)):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        axm, axp = _axs[jj], _axs[len(distances) + jj]
        Cm = get_saddle_data(sample_m, _dist_name, _dist)
        Cp = get_saddle_data(sample_p, _dist_name, _dist)
        axm.imshow(Cm[1:,1:], **imshow_kwargs)
        axp.imshow(Cp[1:,1:], **imshow_kwargs)


# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    # m ...
    axs[0, jj].set_title(f"m-{_dist_name}")
    axs[-1,jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
    # p ...
    axs[0, len(distances) + jj].set_title(f"p-{_dist_name}")
    axs[-1,len(distances) + jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,len(distances) + jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
for ii, _sample in enumerate(sub_samples_m):
    axs[ii,0].set_ylabel(_sample.replace("m","-"))
    axs[ii,0].set_yticks([])


In [None]:
sub_samples =[
        "m10hR1R2",
        "p10hR1R2",
        "mp10hR1R2",
    ]

fig, axs = plt.subplots(
    nrows=len(sub_samples),
    ncols=len(distances),
    figsize=(2*len(distances),2*len(sub_samples)),
    sharex=True,
    sharey=True,
)

for ii, _sample in enumerate(sub_samples):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        ax = axs[ii, jj]
        saddle_data = get_saddle_data(_sample, _dist_name, _dist)
        ax.imshow(saddle_data[1:,1:], **imshow_kwargs)
        ax.set_xticks([])
        ax.set_yticks([])

# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    axs[0,jj].set_title(f"{_dist_name}")
    axs[-1,jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
for ii, _sample in enumerate(sub_samples):
    axs[ii,0].set_ylabel(_sample)


In [None]:
sub_samples = [
        "N93m10",
        "N93p10",
        "N93mp10",
    ]

fig, axs = plt.subplots(
    nrows=len(sub_samples),
    ncols=len(distances),
    figsize=(2*len(distances),2*len(sub_samples)),
    sharex=True,
    sharey=True,
)

for ii, _sample in enumerate(sub_samples):
    for jj, (_dist_name, _dist) in enumerate(distances.items()):
        ax = axs[ii, jj]
        saddle_data = get_saddle_data(_sample, _dist_name, _dist)
        ax.imshow(saddle_data[1:,1:], **imshow_kwargs)
        ax.set_xticks([])
        ax.set_yticks([])

# annotate labels and titles ...
for jj, _dist_name in enumerate(distances):
    axs[0,jj].set_title(f"{_dist_name}")
    axs[-1,jj].set_xticks(np.arange(len(ticklabels)-1))
    axs[-1,jj].set_xticklabels(np.asarray(ticklabels)[1:], rotation="vertical")
for ii, _sample in enumerate(sub_samples):
    axs[ii,0].set_ylabel(_sample)
