# Comparing simulations against the data from Mitchell

**How to use this**:
1. generate the parameters: run the notebook `parameters.ipynb`
1. generate the data with `qsub -t 1:4500 hsc-draft/simulations.sh hsc-draft/parameters.txt`
1. run this notebook

Note: `qsub` is the command used to submit jobs via the Univa Grid Engine available at QMUL. Other job schedulers (e.g. Slurm, Apache Hadoop...) may require a different command.

In [None]:
# Automatically reload all imported modules (e.g. the local `hscpy` package)
# before executing each cell, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib as mpl
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import socket
import seaborn as sns
import json
from pathlib import Path
from typing import List
from scipy import stats

from hscpy import mitchell, realisation, parameters, sfs, variant, COLORS
from hscpy.figures import PlotOptions, simulations, ToCellFrequency
from hscpy.figures import burden as burden_fig
from hscpy.figures import sfs as sfs_fig

from futils import parse_version, snapshot

# Notebook-wide settings, read by the cells below.
LATEST = True  # True: use the version printed by the local `hsc` binary; False: the pinned VERSION
SAVEFIG = True  # forwarded to PlotOptions(save=...): figures are written to disk when True
BIGLABELS = False  # only used to choose FIGSIZE below
FIGSIZE = [5, 3] if BIGLABELS else [6.4, 4.8]  # default matplotlib
EXTENSION = ".svg"  # file extension of the saved figures
PATH2HSC = Path("~").expanduser() / "hsc"  # directory containing target/release/hsc
PATH2DATA = Path("~").expanduser() / "hsc-draft/data"  # simulations are read from PATH2DATA/simulations/<version>
PLOT_OPTIONS = PlotOptions(figsize=FIGSIZE, extension=EXTENSION, save=SAVEFIG)
NCELLS = 100_000  # passed as `pop_size` to the sampling correction of the SFS

In [None]:
# Root directory of Mitchell's data: a machine-specific path on the known
# hosts, the user's home directory on any other machine.
PATH2MITCHELL = {
    "5X9ZYD3": Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc"),
    "LAPTOP-CEKCHJ4C": Path("/mnt/c/Users/fra_t/Documents/PhD/hsc"),
}.get(socket.gethostname(), Path("~").expanduser())

In [None]:
%%bash -s "$PATH2HSC" --out version
# Print the version of the locally built `hsc` binary ($1 is PATH2HSC);
# `--out version` stores the output in the Python variable `version`.
$1/target/release/hsc  --version

In [None]:
# Simulator version: parsed from the local binary's output when LATEST is set,
# otherwise a pinned one. The version selects the folders PATH2SAVE/PATH2SIMS.
VERSION = parse_version(version) if LATEST else "sfs/v3.0.6"
PATH2SAVE = Path(f"./{VERSION}")
PATH2SIMS = PATH2DATA / Path(f"./simulations/{VERSION}")
print("Running hsc with version:", VERSION)

In [None]:
# Table of Mitchell's donors; later cells use its `name`, `age`, `cells` and
# `clones` columns.
donors = mitchell.donors()
donors

## Single-cell mutational burden

In [None]:
%%time
# mitchell's donors
# Load the single-cell mutational burden of every donor. Later cells index each
# entry `b` as b[0]=name, b[1]=age, b[3]=burden histogram (number of mutations
# -> number of cells).
# NOTE(review): the meaning of the positional `False` flag is defined in
# hscpy.mitchell.burden_donor_mitchell — confirm.
burden_donors = list()
for donor in donors.itertuples():
    print("loading burden for donor", donor.name)
    burden_donors.append(
        mitchell.burden_donor_mitchell(donor.name, donor.age, PATH2MITCHELL, False)
    )

In [None]:
# DONORS only
# Plot the normalised single-cell burden of every donor on the same axes and
# collect (name, age, mean) and (name, age, variance) of each donor's burden.
fig, ax = plt.subplots(1, 1)
means, variances = list(), list()
# NOTE(review): zip stops at the shorter input, so donors beyond the number of
# TABLEAU_COLORS would be silently skipped.
for b, c in zip(burden_donors, mcolors.TABLEAU_COLORS.values()):
    tot_cells = sum(b[3].values())
    # the histogram counts must add up to the donor's number of cells
    assert tot_cells == donors[donors.name == b[0]].cells.iloc[0]
    # one entry per cell: the burden of that cell
    array = snapshot.array_from_hist(b[3])
    means.append((b[0], b[1], array.mean()))
    variances.append((b[0], b[1], array.var()))

    burden_fig.plot_burden(
        ax,
        b[3],
        normalise=True,
        options=PLOT_OPTIONS,
        ls="-",
        # marker=".",
        lw=1,
        # alpha=0.5,
        color=c,
        label=f"{donors.loc[donors.name == b[0], 'age'].iloc[0]} y.o.",
    )
ax.legend(fontsize="x-small", ncols=2)
if PLOT_OPTIONS.save:
    fig.savefig(f"./burden{PLOT_OPTIONS.extension}")
fig.show()

In [None]:
# Simulated burdens per donor name: load the burdens of the simulations run
# with the donor's number of cells and keep the ones at the donor's age.
burden_sims = dict()
for donor in donors.itertuples():
    burden_sims[donor.name] = realisation.load_all_burden_by_age(
        PATH2SIMS / Path(f"{donor.cells}cells/burden")
    )[donor.age]

In [None]:
# check that the burden in the sims match the data
# For each donor, overlay the pooled burden of all its simulations (grey) on
# the measured burden (orange), and store the mean/variance of the pooled
# simulated burden in `means_var_s` as (name, age, mean, variance).
means_var_s = list()
for b in burden_donors:
    print(b[0])
    fig, ax = plt.subplots(1, 1)
    # sims: pool the burden histograms of all runs of this donor
    pooled = snapshot.Uniformise.pooled_distribution(
        [bur.burden for bur in burden_sims[b[0]]]
    )
    m_, v_ = realisation.compute_mean_variance(pooled)
    burden_fig.plot_burden(
        ax,
        pooled,
        normalise=False,
        color=COLORS["grey_dark"],
        marker=".",
        alpha=0.5,
        label=f"{len(burden_sims[b[0]])} sims",
        options=PLOT_OPTIONS,
    )
    means_var_s.append((b[0], b[1], m_, v_))
    # data: the histogram counts must add up to the donor's number of cells
    tot_cells = sum(b[3].values())
    assert tot_cells == donors[donors.name == b[0]].cells.iloc[0]
    # NOTE(review): the sims are plotted with normalise=False and the data with
    # normalise=True; presumably `pooled_distribution` is already normalised — confirm.
    burden_fig.plot_burden(
        ax,
        b[3],
        normalise=True,
        color=COLORS["orange"],
        marker=".",
        bins=10,
        alpha=0.5,
        label=f"{b[1]} y.o.",
        options=PLOT_OPTIONS,
    )
    ax.legend(fontsize="small")
    if PLOT_OPTIONS.save:
        fig.savefig(f"./burden_{b[0]}{PLOT_OPTIONS.extension}")
    fig.show()

In [None]:
# Mean and variance of the single-cell burden vs age: data (orange) vs sims
# (grey), on a log y-axis.
fig, ax = plt.subplots(1, 1, layout="tight")
means_var_df_s = pd.DataFrame(means_var_s, columns=["name", "age", "mean", "variance"])
means_df = pd.DataFrame(means, columns=["name", "age", "mean"])
variances_df = pd.DataFrame(variances, columns=["name", "age", "variance"])

# means: solid lines
means_df.plot(
    ax=ax, x="age", y="mean", marker=".", color=COLORS["orange"], label="mean data", lw=1
)
means_var_df_s.plot(ax=ax, x="age", y="mean", color=COLORS["grey_dark"], label="mean sims", lw=1)

# variances: dash-dot (data) and dashed (sims) lines
variances_df.plot(
    ax=ax, x="age", y="variance", ls="-.", color=COLORS["orange"], label="variance data", lw=1
)
means_var_df_s.plot(ax=ax, x="age", y="variance", color=COLORS["grey_dark"], ls="--", label="variance sims", lw=1)
ax.set_yscale("log")
ax.set_xlabel("Age (years)", fontsize=11)
ax.set_ylabel("Single-cell burden", fontsize=11)
ax.tick_params(axis='both', which='major', labelsize=11)
# legend placed above the axes and expanded over their width
ax.legend(
    fontsize="x-small",
    bbox_to_anchor=(-0.05, 1.0, 1.07, .10), 
    loc='lower left', 
    mode="expand", 
    ncols=2, 
    # frameon=False,
    bbox_transform=ax.transAxes
)

# ax.legend().set_visible(False)
if PLOT_OPTIONS.save:
    fig.savefig(f"./burden_mean_var{PLOT_OPTIONS.extension}")
fig.show()

In [None]:
# Compare mean and variance of the single-cell burden of the newborn donors
# (age 0) in the data with those of the corresponding simulations.
# Newborns are selected by age (b[1] is the donor's age) rather than by their
# position in `burden_donors`, so the cell does not depend on the donors' order.
neonate_burdens = [
    snapshot.array_from_hist(b[3]) for b in burden_donors if b[1] == 0
]
newborns_sims = means_var_df_s.loc[means_var_df_s["age"] == 0]

data_means = ", ".join(f"{arr.mean():.2f}" for arr in neonate_burdens)
print(
    f"The mean single-cell mut burden of the {len(neonate_burdens)} newborns computed from the genotype matrix is: {data_means}",
)
print("from the sims:", newborns_sims["mean"].to_numpy())

data_variances = ", ".join(f"{arr.var():.2f}" for arr in neonate_burdens)
print(
    f"The variance single-cell mut burden of the {len(neonate_burdens)} newborns computed from the genotype matrix is: {data_variances}",
)
print("from the sims:", newborns_sims["variance"].to_numpy())

In [None]:
# regress neutral donors
# Least-squares line through the mean single-cell burden vs age of the neutral
# donors, plotted with the burden of every cell (dots) and the mean burden of
# every donor (black crosses).
fig, ax = plt.subplots(1, 1)
# neutral donors have no detected exp clone
neutral_donors = {"CB002", "KX001", "SX001"}
x, y = (
    means_df.loc[means_df.name.isin(neutral_donors), "age"],
    means_df.loc[means_df.name.isin(neutral_donors), "mean"],
)
A = np.vstack([x, np.ones(len(x))]).T
m, c = np.linalg.lstsq(A, y, rcond=None)[0]  # slope, intercept
ax.plot(donors.age, m * donors.age + c, "black", linewidth=1.5, linestyle="--")
for donor in donors.itertuples():
    d_burden = [d for d in burden_donors if d[0] == donor.name][0]
    # one entry per cell: the burden of that cell
    array = snapshot.array_from_hist(d_burden[3])
    ax.plot([d_burden[1]] * array.shape[0], array, ls="", marker=".", alpha=0.1)
    ax.plot([d_burden[1]], array.mean(), ls="", marker="x", mew=2, color="black")
print(m, c)
ax.set_ylabel("Single-cell burden")
ax.set_xlabel("Age (years)")
# NOTE: position in data coordinates
ax.text(x=1, y=1500, s=f"m={m:.2f}")
if PLOT_OPTIONS.save:
    fig.savefig(f"./burden_regression{PLOT_OPTIONS.extension}")
fig.show()

## Expanded clones

In [None]:
%%time
# Detected variant counts of the simulations, loaded from the simulations run
# with each donor's number of cells. The second argument (0.01) is presumably
# the detection frequency threshold — confirm in hscpy.variant.
# NOTE(review): `update` merges the whole dict returned for each donor, so keys
# shared by several donors (e.g. ages, if the dict is keyed by age) end up with
# the values of the donor loaded last — unlike the burden/SFS cells, which keep
# only `[donor.age]`. Confirm this is intended.
counts_sims = dict()

for donor in donors[["name", "age", "cells"]].itertuples():
    print(
        f"\tloading sims variant counts for donor {donor.name} with {donor.cells} cells"
    )
    counts_sims.update(
        variant.load_all_detected_var_counts_by_age(
            PATH2SIMS / Path(f"{donor.cells}cells/variant_fraction"), 0.01
        )
    )

In [None]:
# Number of detected variants vs age: the simulations are shown as a min-max
# band across runs, the data as the number of clones of each donor.
counts = variant.variant_counts_detected_df(counts_sims)
fig, ax = plt.subplots(1, 1)
sns.lineplot(
    counts,
    x="age",
    y="variant counts detected",
    errorbar=lambda x: (np.min(x), np.max(x)),
    color="grey",
    alpha=0.3,
    ax=ax,
    label="min-max",
)
ax.plot(donors.age, donors.clones, marker=".", color="#d95f0e", label="Mitchell")
ax.legend()
plt.show()
# summary statistics of the simulated counts at each age
print(counts[["variant counts detected", "age"]].groupby("age").describe())

## SFS
Combine different data for this plot:
1. 1/f^2 sampled prediction (computed here in python)
2. Mitchell's SFS (loaded and computed here in python)
3. 1/f sampled prediction from Nate's paper (loaded from an external file)
4. SFS from simulations (need to generate them)

### Generate/load/compute the data
#### 1. 1/f^2 predictions


In [None]:
%%time
# compute the correction for the sims' SFS with sampled distributions from
# https://www.biorxiv.org/content/10.1101/2022.11.07.515470v2
# For each donor: the 1/f^2 prediction for a population of NCELLS cells,
# sampled to the donor's number of cells. Only NCELLS and donor.cells are
# passed to compute_variants; the donor's age is only printed.
corrected_variants_one_over_1_squared = dict()
for donor in donors.itertuples():
    print(
        f"apply sampling correction to SFS of donor {donor.name} with age {donor.age} with sample size {donor.cells}"
    )
    corrected_variants_one_over_1_squared[donor.name] = realisation.compute_variants(
        realisation.Correction.ONE_OVER_F_SQUARED,
        pop_size=NCELLS,
        sample_size=donor.cells,
    )

#### 2. Mitchell's SFS

In [None]:
%%time
# there are two donors with the same age 0
# hence the SFS are keyed by donor name. Later cells index each entry as
# [0]=name, [1]=age, [2]=number of cells, [3]=SFS histogram.
mitchell_sfs = {
    donor.name: mitchell.sfs_donor_mitchell(
        donor.name, donor.age, PATH2MITCHELL, remove_indels=False
    )
    for donor in donors.itertuples()
}

In [None]:
# Sanity check: the number of cells stored with each donor's SFS (m[2]) must
# match the number of cells in the donors table.
assert all(
    m[2] == donors.loc[donors.name == name, "cells"].squeeze()
    for name, m in mitchell_sfs.items()
), "number of cells loaded for the SFS do not match those in donors"

In [None]:
# Shannon entropy of each donor's SFS vs the donor's age.
# NOTE(review): `array_from_hist` expands the histogram into one entry per
# counted item, and `stats.entropy` normalises its input into probabilities, so
# this is the entropy of the expanded array rather than of the SFS bin counts
# (`list(sfs.values())`). Confirm this is the intended quantity.
ages, entropies = list(), list()
for don in mitchell_sfs.values():
    ag, sfs_ = don[1], snapshot.array_from_hist(don[-1])
    entrop = stats.entropy(sfs_)
    print(f"{don[0]} age {ag} with entropy {entrop}")
    ages.append(ag)
    entropies.append(entrop)
    
fig, ax = plt.subplots(1, 1)
ax.plot(ages, entropies, marker=".")
ax.set_ylabel("Entropy")
ax.set_xlabel("Age (years)")
plt.show()

#### 3. 1/f sampled predictions

In [None]:
# theoretical homeostatic neutral SFS data, from Nate's paper in Elife: for each patient (skipping the neonates)
# I evolved until their specific age, and then sampled to the same size as in the data
# Map each (non-newborn) age to its CSV file; the file ids start at pid3.
# NOTE(review): skipping with `[1:]` assumes the newborns' age 0 is the first
# value of donors.age.unique(), i.e. the donors table is sorted by age — confirm.
mapping = {
    age: f"predictions_1_over_f/homeostasisSFS_pid{i}.csv"
    for i, age in enumerate(donors.age.unique().tolist()[1:], 3)
}
mapping

#### 4. SFS from simulations
The data have been generated with the command `sfs.sh parameters.txt`.

In [None]:
%%time
# Simulated SFS per donor name: load the SFS of the simulations run with the
# donor's number of cells and keep the ones at the donor's age.
sfs_sims = dict()
for donor in donors.itertuples():
    sfs_sims[donor.name] = realisation.load_all_sfs_by_age(
        PATH2SIMS / Path(f"{donor.cells}cells/sfs")
    )[donor.age]

#### Plots

In [None]:
# For every donor, plot the sampled 1/f^2 and 1/f predictions against
# Mitchell's SFS, without simulations (`sfs_sims_donor=None`).
scaling = 1  # multiplies figure size, line widths, markers and fonts; reused by later cells
for name in donors.name.unique().tolist()[::-1]:
    age = donors.loc[donors.name == name, "age"].squeeze()
    fig, ax = plt.subplots(
        1,
        1,
        layout="tight",
        figsize=[ele * scaling for ele in mpl.rcParams['figure.figsize']],
    )
    sfs_fig.plot_ax_sfs_predictions_data_sims(
        ax,
        donor=donors[donors.name == name].squeeze(),
        corrected_one_over_1_squared=corrected_variants_one_over_1_squared[name],
        sfs_sims_donor=None,
        mitchell_sfs=mitchell_sfs[name][3],
        # None for the newborns, which are not in `mapping`
        one_over_f_csv=mapping.get(age),
        idx_sim2plot=None,
        plot_options=PLOT_OPTIONS,
        mew=1.3 * scaling,
        lw=1.5 * scaling,
        markersize=4 * scaling,
    )

    ax.text(
        x=0.55,
        y=0.8,
        s=f"donor {age} y.o.",
        fontsize=(mpl.rcParams['font.size'] - 6) * scaling,
        transform=ax.transAxes,
    )
    ax.set_xlabel(ax.get_xlabel(), fontsize=(mpl.rcParams['font.size'] - 5) * scaling)
    ax.set_ylabel(ax.get_ylabel(), fontsize=(mpl.rcParams['font.size'] - 5) * scaling)
    ax.tick_params(axis='both', which='major', labelsize=8 * scaling, width=0.6, length=4)
    ax.tick_params(axis='both', which='minor', labelsize=8 * scaling, width=0.5, length=2)
    if PLOT_OPTIONS.save:
        # Distinct file name: the SFS figures with simulations further down are
        # saved as `sfs_age{age}_{name}` and would otherwise overwrite this one.
        fig.savefig(f"./sfs_predictions_age{age}_{name}{PLOT_OPTIONS.extension}")
    fig.show()

In [None]:
# Stand-alone legend strip (axis hidden), presumably for the SFS prediction
# figures above.
fig, ax = plt.subplots(1, 1, figsize=(10, 1))
legend_elements = [
    Line2D([1], [1], color="black", alpha=0.8, lw=4, label="growth scaling law"),
    Line2D(
        [0],
        [0],
        color="black",
        ls="--",
        alpha=0.8,
        lw=4,
        label="homeostatic scaling law",
    ),
    Line2D(
        [0],
        [0],
        marker="x",
        ls="",
        mew=4,
        color="#d95f0e",
        label="data",
        markerfacecolor="g",
        markersize=13,
    ),
]

# `mode` accepts only "expand" or None; "expand" spreads the entries over the
# width of the axes.
ax.legend(
    handles=legend_elements,
    mode="expand",
    ncols=5,
    handletextpad=0.4,
)
ax.axis("off")
plt.show()

In [None]:
def f_test(f_obs, f_exp):
    """Return the p-value of a univariate, uncentered F-test between two arrays.

    Adapted from scikit-learn's ``f_regression`` with ``center=False``:
    https://github.com/scikit-learn/scikit-learn/blob/f07e0138bfee41cd2c0a5d0251dc3fe03e6e1084/sklearn/feature_selection/_univariate_selection.py#L501C1-L507C62

    The uncentered correlation coefficient
    ``r = <f_obs, f_exp> / (||f_obs|| * ||f_exp||)`` is turned into the F
    statistic ``r**2 / (1 - r**2) * (n - 1)``. The p-value is the survival
    function of the F(1, n - 1) distribution at that statistic.

    The F-test assumes normally distributed and uncorrelated errors, which
    doesn't make sense in our case: treat the p-value as indicative only.

    Parameters
    ----------
    f_obs, f_exp : array-like of float
        1-D arrays of the same length.

    Returns
    -------
    float
        The p-value. It is 0.0 when the arrays are perfectly correlated, and
        nan when one array is all zeros or has fewer than 2 elements.
    """
    f_obs = np.asarray(f_obs, dtype=float)
    f_exp = np.asarray(f_exp, dtype=float)
    # not centered: only one degree of freedom is lost (sklearn: n - 1)
    deg_of_freedom = f_obs.size - 1

    with np.errstate(divide="ignore", invalid="ignore"):
        corr_coef = np.dot(f_obs, f_exp) / (np.linalg.norm(f_obs) * np.linalg.norm(f_exp))
    corr_coef_squared = corr_coef**2
    if corr_coef_squared >= 1:
        # perfect (anti)correlation: infinite F statistic (also guards round-off)
        return 0.0

    f_statistic = corr_coef_squared / (1 - corr_coef_squared) * deg_of_freedom
    return float(stats.f.sf(f_statistic, 1, deg_of_freedom))

In [None]:
def _first_zero_index(values: np.ndarray) -> int:
    """Return the index of the first 0 in ``values``, or ``len(values)`` if there is none."""
    zeros = np.flatnonzero(values == 0)
    return int(zeros[0]) if zeros.size else int(values.size)


def prepare_sfs_with_uniformisation_for_chisquare(sfs_target: snapshot.Histogram, sfs_sim: snapshot.Histogram):
    """Align two SFS on the same bins and compute goodness-of-fit summaries.

    The two histograms are put on a common support with
    ``snapshot.Uniformise`` and bin 0 is dropped. The comparison is then
    restricted to the bins ``[idx_lower_bound, idx_upper_bound)`` of the
    remaining arrays, where ``idx_upper_bound`` is the first bin that is empty
    in either SFS. The retained counts are log10-transformed before the
    metrics are computed.

    Parameters
    ----------
    sfs_target : snapshot.Histogram
        Observed SFS (e.g. the data).
    sfs_sim : snapshot.Histogram
        Expected SFS (e.g. a simulation).

    Returns
    -------
    tuple
        ``(f_obs, f_exp, idx_lower_bound, idx_upper_bound,
        mean_squared_log_error, rmsre, mape, f_test_pvalue)``. ``f_obs`` and
        ``f_exp`` are the log10 counts that were compared.
    """
    f_obs, f_exp = snapshot.Uniformise.uniformise_histograms([sfs_target, sfs_sim]).make_histograms()
    f_obs, f_exp = np.fromiter(f_obs.values(), dtype=float), np.fromiter(f_exp.values(), dtype=float)
    # rm state 0
    f_obs, f_exp = f_obs[1:], f_exp[1:]
    idx_lower_bound = 1
    # Stop at the first empty bin of either SFS. np.argmin would instead return
    # the position of the smallest bin when no bin is empty.
    idx_upper_bound = min(_first_zero_index(f_obs), _first_zero_index(f_exp))
    f_obs, f_exp = f_obs[idx_lower_bound:idx_upper_bound], f_exp[idx_lower_bound:idx_upper_bound]
    f_obs, f_exp = np.log10(f_obs), np.log10(f_exp)
    assert len(f_obs) == len(f_exp)

    # NOTE: the metrics are computed on the log10 counts. Bins whose observed
    # count is 1 give log10(1) == 0 in f_obs, which makes rmsre/mape inf or nan.
    mean_squared_log_error = np.mean(np.power(np.log(f_obs + 1) - np.log(f_exp + 1), 2))
    # mean squared relative error (no square root is taken despite the name)
    rmsre = np.mean(np.power((f_obs - f_exp) / f_obs, 2))
    mape = np.mean(np.abs(f_obs - f_exp) / f_obs)
    # f_test assumes that the errors are normally distributed and uncorrelated,
    # which doesn't make sense in our case
    return f_obs, f_exp, idx_lower_bound, idx_upper_bound, mean_squared_log_error, rmsre, mape, f_test(f_obs, f_exp)

In [None]:
# For every donor, plot the predictions, Mitchell's SFS and one simulation, and
# annotate the figure with the KS p-value between data and that simulation.
for name in donors.name.unique().tolist():
    idx_available = [sfs_.parameters.idx for sfs_ in sfs_sims[name]]
    # TODO change with best fit from ABC!
    best_idx = idx_available[0]

    print(f"there are {len(idx_available)} runs for {name}")
    age = donors.loc[donors.name == name, "age"].squeeze()
    fig, ax = plt.subplots(1, 1, layout="tight")
    sfs_fig.plot_ax_sfs_predictions_data_sims(
        ax,
        donor=donors[donors.name == name].squeeze(),
        corrected_one_over_1_squared=corrected_variants_one_over_1_squared[name],
        sfs_sims_donor=sfs_sims[name],
        mitchell_sfs=mitchell_sfs[name][3],
        one_over_f_csv=mapping.get(age),
        idx_sim2plot=best_idx,
        plot_options=PLOT_OPTIONS,
        mew=1.3 * scaling,
        lw=1.5 * scaling,
        markersize=4 * scaling,
    )
    best_fit = [ele for ele in sfs_sims[name] if ele.parameters.idx == best_idx][0].sfs
    # Uniformise data and simulation; returns the compared arrays, the bin
    # bounds used and error metrics (see hscpy.sfs).
    (
        f_obs,
        f_exp,
        idx_lower_bound,
        idx_upper_bound,
        mean_squared_log_error,
        rmsre,
        mape,
    ) = sfs.prepare_sfs_with_uniformisation_for_test(mitchell_sfs[name][3], best_fit)
    # two-sample Kolmogorov-Smirnov test between data and simulation
    res = stats.ks_2samp(f_obs, f_exp)
    ax.text(
        x=0.6,
        y=0.8,
        s=f"donor {age} y.o.\n{donors.loc[donors.name == name, 'cells'].squeeze()} cells",
        transform=ax.transAxes,
        fontsize=(mpl.rcParams['font.size'] - 5) * scaling,
    )
    ax.text(
        x=0.6,
        y=0.725,
        s=r"$\mathregular{{p_{{KS}}={{{:.2f}}}}}$".format(res.pvalue),
        transform=ax.transAxes,
        fontsize=(mpl.rcParams['font.size'] - 5) * scaling,
    )
    ax.set_xlabel(ax.get_xlabel(), fontsize=(mpl.rcParams['font.size'] - 5) * scaling)
    ax.set_ylabel(ax.get_ylabel(), fontsize=(mpl.rcParams['font.size'] - 5) * scaling)
    ax.tick_params(axis='both', which='major', labelsize=8 * scaling, width=0.6, length=4)
    ax.tick_params(axis='both', which='minor', labelsize=8 * scaling, width=0.5, length=2)
    print(res.pvalue, res.statistic, idx_lower_bound, idx_upper_bound, mean_squared_log_error, rmsre)
    if PLOT_OPTIONS.save:
        # Distinct file name: the ABC figures further down are saved as
        # `sfs_age{age}_{name}` and would otherwise overwrite these.
        fig.savefig(f"./sfs_sims_age{age}_{name}{PLOT_OPTIONS.extension}")
    fig.show()

In [None]:
# Compare every donor's SFS with the SFS of a fixed reference donor, using the
# same uniformisation and KS test as above, and plot both uniformised SFS.
# The reference is the last donor: this is what the loop variable `name`
# leaking from the previous cell resolved to. Naming it explicitly makes the
# cell independent of execution order.
reference = donors.name.unique().tolist()[-1]
for n in donors.name.unique():
    fit = sfs.prepare_sfs_with_uniformisation_for_test(
        mitchell_sfs[reference][3], mitchell_sfs[n][3]
    )
    res = stats.ks_2samp(fit[0], fit[1])
    print(f"{reference} vs {n}: p={res.pvalue:.2f}")

    fig, ax = plt.subplots(1, 1, figsize=(5, 4))
    uniformised = snapshot.Uniformise.uniformise_histograms(
        [mitchell_sfs[reference][3], mitchell_sfs[n][3]]
    ).make_histograms()
    sfs_fig.plot_sfs(ax, uniformised[0], normalise=True, options=PLOT_OPTIONS, marker=".", label=reference)
    sfs_fig.plot_sfs(ax, uniformised[1], normalise=True, options=PLOT_OPTIONS, marker=".", label=n)
    ax.legend()
    # bin bounds used by the test (fit[2]: lower, fit[3]: upper)
    # NOTE(review): these are bin indices; confirm they match the x-axis units of plot_sfs
    ax.axvline(fit[3])
    ax.axvline(fit[2] + 1)
    plt.show()

In [None]:
# Stand-alone legend strip (axis hidden), presumably for the SFS figures that
# include simulations.
fig, ax = plt.subplots(1, 1, figsize=(20, 1))
legend_elements = [
    Line2D([1], [1], color="black", alpha=0.8, lw=4, label="growth scaling law"),
    Line2D(
        [0],
        [0],
        color="black",
        ls="--",
        alpha=0.8,
        lw=4,
        label="homeostatic scaling law",
    ),
    Line2D(
        [0],
        [0],
        marker="x",
        ls="",
        mew=4,
        color="#d95f0e",
        label="data",
        markerfacecolor="g",
        markersize=13,
    ),
    Line2D(
        [0],
        [0],
        marker="o",
        ls="",
        mew=1,
        color="grey",
        label="single simulation",
        markersize=12,
    ),
    Line2D([0], [0], color="grey", alpha=0.6, lw=4, label="simulation average"),
]

# `mode` accepts only "expand" or None; "expand" spreads the entries over the
# width of the axes.
ax.legend(
    handles=legend_elements,
    mode="expand",
    ncols=5,
    handletextpad=0.4,
)
ax.axis("off")
plt.show()

In [None]:
# Cumulative distribution of the SFS vs variant frequency (cells / sample size)
# for every simulation of each donor (light grey) and the donor's data (orange),
# zoomed on the CDF range [0.9, 1].
for donor in donors.itertuples():
    fig, ax = plt.subplots(1, 1, layout="tight")
    print(donor.name)

    for sfs_s in sfs_sims[donor.name]:
        cdf_x_sim, cdf_y_sim = realisation.cdf_from_dict(sfs_s.sfs)
        ax.plot(cdf_x_sim / donor.cells, cdf_y_sim, color="#bdbdbd", alpha=0.1)

    cdf_x_target, cdf_y_target = realisation.cdf_from_dict(mitchell_sfs[donor.name][3])
    ax.plot(cdf_x_target / donor.cells, cdf_y_target, color="#d95f0e")
    ax.set_xscale("log")
    ax.set_ylabel("Cdf")
    ax.set_xlabel(r"Variant frequency $f$")
    ax.text(
        x=0.5,
        y=0.2,
        s=f"donor {donor.age} y.o.",
        fontsize=12,
        transform=ax.transAxes,
    )
    ax.set_ylim([0.9, 1])
    plt.show()

### Redo the same plot with simulations from ABC

In [None]:
# Simulations selected by ABC: append them to the simulations already in
# `sfs_sims`. The second argument of load_all_sfs_by_age presumably restricts
# loading to the run indices in `idx2load` — confirm in hscpy.realisation.
idx2load = [320490, 274050, 619150, 73200, 295810, 434910, 309270, 94630]
# load the data from abc
# NOTE: user-specific scratch path; PATH2SIMS is redefined here for the ABC runs.
PATH2SIMS = Path("/data/scratch") / f"hfx923/hsc-draft/{VERSION}"

# NOTE: appending is not idempotent; re-running this cell duplicates the runs.
for r in donors[["name", "age", "cells"]].itertuples():
    path2sfs_abc = PATH2SIMS / f"{r.cells}cells/sfs/"
    print(f"\tloading sims SFS for donor {r.name} with {r.cells} cells")
    sfs_sims[r.name].extend(
        realisation.load_all_sfs_by_age(path2sfs_abc, idx2load)[r.age]
    )

In [None]:
def prepare_sfs_for_ks(sfs_: snapshot.Histogram):
    """Expand an SFS histogram into a flat sample array for ``stats.ks_2samp``.

    The histogram is passed unchanged (no normalisation, no filtering) to
    ``snapshot.array_from_hist``, and its result is returned.
    """
    return snapshot.array_from_hist(sfs_)

In [None]:
# Same SFS figures as above, now plotting for each donor the ABC simulation
# `idx2plot` next to the predictions and Mitchell's data.
for i, (name, idx2plot) in enumerate(
    zip(donors.name.unique().tolist(), [idx2load[0]] + idx2load)
):  # trick for using twice the donor 0
    # (the first index is repeated, presumably so that both newborn donors,
    # listed first, use the same simulation)
    print(name, idx2plot)
    age = donors.loc[donors.name == name, "age"].squeeze()
    fig, ax = plt.subplots(1, 1, layout="constrained", figsize=PLOT_OPTIONS.figsize)
    sfs_fig.plot_ax_sfs_predictions_data_sims(
        ax,
        donor=donors[donors.name == name].squeeze(),
        corrected_one_over_1_squared=corrected_variants_one_over_1_squared[name],
        sfs_sims_donor=sfs_sims[name],
        mitchell_sfs=mitchell_sfs[name][3],
        one_over_f_csv=mapping.get(age),
        idx_sim2plot=idx2plot,
        plot_options=PLOT_OPTIONS,
    )

    best_fit = [ele for ele in sfs_sims[name] if ele.parameters.idx == idx2plot][0].sfs
    # NOTE(review): debugging leftover. It compares the first loaded simulation
    # with the plotted one, and `res` is not shown in the figure.
    best_fit2remove = sfs_sims[name][0].sfs
    res = stats.ks_2samp(prepare_sfs_for_ks(best_fit2remove), prepare_sfs_for_ks(best_fit))

    ax.text(
        x=0.5,
        y=0.9,
        s=f"donor {age} y.o., {donors.loc[donors.name == name, 'cells'].squeeze()} cells",
        transform=ax.transAxes,
        fontsize="small"
    )
    if PLOT_OPTIONS.save:
        fig.savefig(f"./sfs_age{age}_{name}{PLOT_OPTIONS.extension}")
    fig.show()