## Inference of the parameter using ABC

We take 650 samples of the posterior distributions by generating 65'000 realisations of the [HSC dynamics](https://github.com/fraterenz/hsc), i.e. we keep 1% of the runs.

1. run `priors.ipynb`
2. `qsub particles.sh parameters.txt`
3. run this notebook

Note: `qsub` is the command to submit jobs via the Univa Grid Engine available at QMUL. Another command might be needed with other job schedulers (e.g. Slurm, Apache Hadoop...).

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import json
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import random
from scipy import stats
from matplotlib import ticker
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from pathlib import Path
from typing import Union, List, Tuple # rm Tuple, Union
from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy.figures import sfs as sfs_fig
from hscpy import realisation, mitchell, abc, variant, sfs, COLORS

# Directory expected to contain the compiled `hsc` binary (it is invoked as
# `$1/hsc --version` in the bash cell below). Fail early if it is missing.
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
assert PATH2BIN.is_dir()

In [None]:
%%bash -s "$PATH2BIN" --out version
# $1 is PATH2BIN (passed through `-s`); the stdout of this command is captured
# into the Python variable `version` by the `--out version` option of the magic.
$1/hsc --version

In [None]:
# Run-level switches.
USE_SCRATCH = True  # read the simulations from /data/scratch instead of /data/home
SAVEFIG = True  # forwarded to PlotOptions(save=...), which gates the savefig calls
SHOW_PRIORS = True  # plot the prior distributions
LATEST = False  # if True use the version reported by the hsc binary, else a pinned one
SEED = 36

# the higher, the less precise and thus the more runs
# we set to these values because we aim to keep approx 1% of the runs
# NOTE(review): both values are only forwarded to
# abc_fig.posterior_mitchell_quantile; their exact meaning is defined there.
PROPORTION_RUNS_DISCARDED, QUANTILE = 0.1, 0.3173

FONTSIZE = 12
options = PlotOptions(figsize=(3.306, 2.639), extension="svg", save=SAVEFIG)
plt.rcParams["figure.figsize"] = options.figsize

# Seeds Python's `random` module only.
random.seed(SEED)
# Passed to variant.load_all_detected_var_counts_by_age and used to mask the
# simulated variant fractions that fall below it.
DETECTION_THRESHOLD = 0.01

In [None]:
# Version of hsc whose simulations are analysed: the one reported by the local
# binary when LATEST is set, a pinned release otherwise.
VERSION = parse_version(version) if LATEST else "v4.3.6"
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

# Root directory of the simulation outputs for this version.
PATH2SIMS = Path("/data/scratch" if USE_SCRATCH else "/data/home")
PATH2SIMS /= f"hfx923/hsc-draft/{VERSION}"

# Location of Mitchell's data, selected by machine name; any other host falls
# back to the home directory.
_MITCHELL_DIR_BY_HOST = {
    "5X9ZYD3": "/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc",
    "LAPTOP-CEKCHJ4C": "/mnt/c/Users/fra_t/Documents/PhD/hsc",
}
_mitchell_dir = _MITCHELL_DIR_BY_HOST.get(socket.gethostname())
PATH2MITCHELL = (
    Path("~").expanduser() if _mitchell_dir is None else Path(_mitchell_dir)
)

In [None]:
# Donor selection:
# 1. the summary `Summary_cut.csv` is not used because it contains errors
#    (i.e. duplicated and missing entries)
# 2. donor CB001 is not used because CB002 already covers age 0
# 3. donor KX007 is not used because its genotype matrix is wrong
#    (the same donor has been uploaded twice)
all_donors = mitchell.donors()
is_cb001 = all_donors.name == "CB001"
donors = all_donors.drop(index=all_donors.index[is_cb001])
donors

In [None]:
# bins for abc
# Histogram bin edges, one array per inferred parameter.
bins_mu = np.arange(0, 31, 2)
bins_tau = np.arange(0, 11, 1)
bins_s = np.arange(0, 0.41, 0.02)
bins_std = np.arange(0, 0.105, 0.01)

# Arguments are passed in the order (s, std, mu, tau); the resulting object is
# later indexed as bins.bins["eta"], ["sigma"], ["mu"] and ["tau"].
bins = abc_fig.Bins(bins_s, bins_std, bins_mu, bins_tau)

# Load simulated data: both SFS and variant fractions

## Load data

In [None]:
%%time
# Simulation outputs, filled donor by donor. The simulations are stored per
# subsample size (`<cells>cells`), so each donor selects its directory through
# its number of cells.
sfs_sims, counts_sims, variant_frac_sims = {}, {}, {}

for donor in donors[["name", "age", "cells"]].itertuples():
    sims_dir = PATH2SIMS / f"{donor.cells}cells"
    fractions_dir = sims_dir / "variant_fraction"

    print(f"loading data for donor {donor.name}")
    print(f"\tloading sims SFS for donor {donor.name} with {donor.cells} cells")
    sfs_sims.update(realisation.load_all_sfs_by_age(sims_dir / "sfs"))

    print(
        f"\tloading sims detected variant counts for donor {donor.name} with {donor.cells} cells"
    )
    detected_counts = variant.load_all_detected_var_counts_by_age(
        fractions_dir, DETECTION_THRESHOLD
    )
    counts_sims.update(detected_counts)

    print(f"\tloading sims variant counts for donor {donor.name} with {donor.cells} cells")
    variant_frac_sims.update(variant.load_all_var_frac_by_age(fractions_dir))

In [None]:
# load neutral sims if present: can be generated with command
# `qsub -t 1:100 hsc-draft/simulations.sh hsc-draft/parameters_neutral.txt`
# with hsc-draft/parameters_neutral.txt being
# -c 100000 -y 82 -r 1 --sequential --neutral --subsamples=390,407,380,362,361,367,451,328 --snapshots=0,29,38,48,63,76,77,81 /data/home/hfx923/hsc-draft/v4.3.7/neutral exp-moran --tau-exp 0.03619 --mu-exp 21.728217763711537 --mu-division-exp 1.14 --mu-background-exp 35.64 --tau 1 --mu 21.728217763711537 --mu-division 1.14 --mu-background 14.203133756172337
# repeated 100 times
NEUTRAL_SIMS_ROOT = Path('/data/home/hfx923/hsc-draft/v4.3.7/neutral')
sfs_sims_neutral = {}

for donor in donors[["name", "age", "cells"]].itertuples():
    print(f"loading data for donor {donor.name}")
    print(f"\tloading sims SFS for donor {donor.name} with {donor.cells} cells")
    neutral_sfs_dir = NEUTRAL_SIMS_ROOT / f"{donor.cells}cells/sfs/"
    try:
        neutral_sfs = realisation.load_all_sfs_by_age(neutral_sfs_dir)
    except AssertionError:
        # the loader signals missing simulations with an assertion: stop here
        print("Neutral sims not found")
        break
    sfs_sims_neutral.update(neutral_sfs)

In [None]:
# Detected variant counts of the simulations over age, drawn twice on the same
# axes: once with a min-max band and once with a standard-deviation band.
counts = variant.variant_counts_detected_df(counts_sims)
fig, ax = plt.subplots(1, 1)
band_styles = (
    {"errorbar": lambda x: (np.min(x), np.max(x)), "label": "min-max"},
    {"errorbar": "sd", "color": "orange", "label": "std"},
)
for band_style in band_styles:
    sns.lineplot(
        counts, x="age", y="variant counts detected", ax=ax, **band_style
    )
ax.legend()
plt.show()
print(counts[["variant counts detected", "age"]].groupby("age").describe())

# Run ABC on the real data

## Load the data from Mitchell's paper
We have excluded two donors from the ABC:
1. exclude KX007 because the same donor has been uploaded twice
2. exclude CB001 because it maps to the same timepoint as CB002 (same age 0)

In [None]:
%%time
# Observed data of every donor, keyed by the donor age. The values are indexed
# positionally throughout this notebook: [2] is compared with `donors.cells`
# below, [3] and [-1] are used as the SFS, and [0] / [1] are printed as the
# donor and its age -- NOTE(review): confirm the exact tuple layout against
# mitchell.sfs_donor_mitchell.
target_sfs = {
    r.age: mitchell.sfs_donor_mitchell(
        r.name, r.age, PATH2MITCHELL, remove_indels=False
    )
    for r in donors[["name", "age"]].itertuples()
}

# Disabled code, kept as a bare string literal (it is not executed).
"""
largest_vaf = {
    r.age: largest_vaf_donor_mitchell(
        r.name, r.age, PATH2MITCHELL, remove_indels=False
    )
    for r in donors[["name", "age"]].itertuples()
}
"""
    
# Sanity check: the number of cells found in each genotype matrix must match
# the number of cells listed in `donors`, in the same donor order.
assert [
    ele[2] for ele in target_sfs.values()
] == donors.cells.tolist(), (
    "cells found in genotype matrices do not match the extected nb of cells"
)

In [None]:
# Clones of Mitchell's donors, read from `expanded_clades.csv`.
clones_csv = PATH2MITCHELL / "expanded_clades.csv"
clones = mitchell.load_clones(clones_csv)
clones

In [None]:
# Summed clonal fraction `cf` of each donor, donors ordered by increasing age.
donors_by_age = (
    clones[["age", "donor_id"]].drop_duplicates().sort_values(by="age").donor_id
)
cf_per_donor = clones[["donor_id", "cf"]].groupby("donor_id").sum()
cf_per_donor = cf_per_donor.reindex(donors_by_age).reset_index()

fig, ax = plt.subplots(1, 1)
sns.barplot(data=cf_per_donor, x="donor_id", y="cf", ax=ax)
ax.legend().set_visible(False)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=12, rotation=45)
ax.set_ylabel("Clonal fraction [%]")
plt.show()

### Compute the summary statistics 
1. wasserstein metric
2. the number of clones
3. the KS stat (not implemented yet)

After having computed the statistics from the simulations and Mitchell's data, we won't use the data anymore but just the summary statistic dataframe `abc_mitchell`: we are going to filter (and keep) the runs by selecting from this dataframe.

In [None]:
donors

In [None]:
%%time
# Summary statistics of every simulated run compared with Mitchell's data.
# `entry[-1]` is the SFS of a donor and the dictionary key is the donor age.
observed_sfs_by_age = {age: entry[-1] for age, entry in target_sfs.items()}
abc_mitchell = abc.compute_abc_results(
    sims_sfs=sfs_sims,
    sims_clones=counts,
    target_sfs=observed_sfs_by_age,
    target_clones=donors,
    experiment="mitchell",
)

# Attach the donor names: `timepoint` of the results matches `age` of donors.
donor_names = donors[["age", "name"]]
abc_mitchell = abc_mitchell.merge(
    right=donor_names,
    how="left",
    left_on="timepoint",
    right_on="age",
    validate="many_to_one",
)
# the merge suffixes the two `age` columns: keep the one from the results
abc_mitchell = abc_mitchell.rename(columns={"age_x": "age"}).drop(columns=["age_y"])

abc_mitchell.shape

### Show priors

In [None]:
# TODO priors are now s /tau and sigma/ tau not s and sigma
if SHOW_PRIORS:
    priors = abc_mitchell[["mu", "eta", "sigma", "tau"]].drop_duplicates()

    # one histogram per parameter, each with its own bins
    prior_panels = (
        ("eta", {"bins": bins_s}),
        ("sigma", {"bins": bins_std}),
        ("tau", {"bins": bins_tau}),
        ("mu", {"discrete": False, "bins": bins_mu}),
    )
    for parameter, prior_kwargs in prior_panels:
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[parameter], ax=ax, **prior_kwargs)
        plt.show()

    # distribution of the two summary statistics over all the runs
    for statistic in ("wasserstein", "rel clones diff"):
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        sns.histplot(abc_mitchell[statistic], binwidth=0.01, ax=ax)
        plt.show()

    sns.pairplot(priors[["eta", "sigma", "tau", "mu"]], kind="hist")
    if options.save:
        plt.savefig(f"priors.{options.extension}")
    plt.show()
    print(priors.eta.min(), priors.eta.max())

In [None]:
# Accepted parameters of the Fabre analysis, read from a pre-computed CSV; the
# columns are renamed to the parameter names used in this notebook.
FABRE_POSTERIOR_CSV = "abcAcceptedParamsQuantile_5D_normedFitness_biasedSampling_3000particles_ranked.csv"
posterior_fabre = pd.read_csv(FABRE_POSTERIOR_CSV).rename(
    columns={"s": "eta", "σ": "sigma", "μ": "mu", "τ": "tau"}
)
posterior_fabre

### Run abc considering all timepoints at the same time

In [None]:
def color_mapping(eta: float, sigma: float) -> str:
    """Return the colour used to draw a run with fitness parameters (eta, sigma).

    "orange" when `sigma > 0.016` and `eta < 0.03`; otherwise "green" when
    `sigma > 0` and `eta > 0.15`; "blue" in every other case.
    """
    broad_low_fitness = sigma > 0.016 and eta < 0.03
    high_fitness = sigma > 0 and eta > 0.15
    return "orange" if broad_low_fitness else ("green" if high_fitness else "blue")

In [None]:
# posteriors distr for Mitchell
# fine tune the q to keep 1% of the runs, i.e. 638 runs
runs2keep = abc_fig.posterior_mitchell_quantile(
    abc_mitchell, QUANTILE, QUANTILE, PROPORTION_RUNS_DISCARDED, bins
)
is_accepted = abc_mitchell.idx.isin(runs2keep)
posterior_mitchell = abc_mitchell.loc[is_accepted, :]

axd = bins.plot_posterior(posterior_mitchell, False, Path("./posterior"))
plt.show()

In [None]:
# show the fancy posteriors Fabre and Mitchell
# One figure per parameter (eta, sigma, tau, mu) plus one for the inferred
# gamma; both posteriors are drawn on the same axes, one colour each.
colors, names = (COLORS["blue"], COLORS["orange"]), ("Fabre", "Mitchell")

fig_eta, ax_eta = plt.subplots(1, 1, layout="constrained",  figsize=options.figsize)
fig_sigma, ax_sigma = plt.subplots(1, 1, layout="constrained",  figsize=options.figsize)
fig_tau, ax_tau = plt.subplots(1, 1, layout="constrained",  figsize=options.figsize)
fig_mu, ax_mu = plt.subplots(1, 1, layout="constrained",  figsize=options.figsize)
fig_gamma, ax_gamma = plt.subplots(1, 1, layout="constrained",  figsize=options.figsize)

# 400 for fabre is 1%
for i, (name, posterior, color) in enumerate(
    zip(names, (posterior_fabre.iloc[:400], posterior_mitchell), colors)
):
    # NOTE: the `estimate_*` values are computed but not used in this cell.
    estimate_eta = bins.bins["eta"].compute_estimate(posterior.loc[:, "eta"])
    _ = abc_fig.plot_posteriors_with_estimate(
        ax_eta,
        posterior.eta, 
        r"$\eta$", 
        bins.bins["eta"], 
        color, 
        # the second coordinate is shifted with `i`, so the two posteriors get
        # different values -- presumably the position of their annotation
        (0.95, 0.85 - i * 0.15),
    )
    ax_eta.set_xlabel(r"Mean fitness per year $\eta$", fontsize=FONTSIZE)
    ax_eta.set_ylabel("Pdf", fontsize=FONTSIZE)
    ax_eta.set_xlim(0, 0.4)
    ax_eta.set_ylim(0, 20)

    estimate_sigma = bins.bins["sigma"].compute_estimate(posterior.loc[:, "sigma"])
    _ = abc_fig.plot_posteriors_with_estimate(
        ax_sigma,
        posterior.sigma, 
        r"$\sigma$", 
        bins.bins["sigma"], 
        color, 
        (0.95, 0.85 - i * 0.15),
    )
    ax_sigma.set_xlabel(r"Standard deviation fitness per year $\sigma$", fontsize=FONTSIZE)
    ax_sigma.set_ylabel("Pdf", fontsize=FONTSIZE)
    ax_sigma.set_xlim(0, 0.1)
    ax_sigma.set_ylim(0, 80)

    estimate_tau = bins.bins["tau"].compute_estimate(posterior.loc[:, "tau"])
    _ = abc_fig.plot_posteriors_with_estimate(
        ax_tau,
        posterior.tau, 
        r"$\tau$", 
        bins.bins["tau"], 
        color, 
        (0.95, 0.85 - i * 0.15),
    )
    ax_tau.set_xlabel(r"Wild-type inter-division time per year $\tau$", fontsize=FONTSIZE)
    ax_tau.set_ylabel("Pdf", fontsize=FONTSIZE)
    # anchor both axes at 0 and keep the automatically chosen upper limits
    ax_tau.set_xlim(0, ax_tau.get_xlim()[-1])
    ax_tau.set_ylim(0, ax_tau.get_ylim()[-1])

    estimate_mu = bins.bins["mu"].compute_estimate(posterior.loc[:, "mu"])
    _ = abc_fig.plot_posteriors_with_estimate(
        ax_mu,
        posterior.mu, 
        r"$\mu$", 
        bins.bins["mu"], 
        color, 
        (0.95, 0.85 - i * 0.15),
    )
    ax_mu.set_xlabel(r"Fit mutants arrival rate per year $\mu$", fontsize=FONTSIZE)
    ax_mu.set_ylabel(r"Pdf", fontsize=FONTSIZE)
    ax_mu.set_xlim(0, 30)
    ax_mu.set_ylim(0, 0.08)

    # NOTE(review): the meaning of the trailing positional arguments (SEED, 30)
    # is defined by abc_fig.plot_gamma_inferred -- confirm there.
    abc_fig.plot_gamma_inferred(
        ax_gamma, 
        posterior, 
        name, 
        color, 
        bins.bins["eta"], 
        bins.bins["sigma"],
        (0.98, 0.85 - i * 0.1),
        SEED,
        30,
    )
    ax_gamma.set_xlim(0, ax_gamma.get_xlim()[-1])
    ax_gamma.set_ylim([0, 30])
    ax_gamma.set_xlabel(ax_gamma.get_xlabel(), fontsize=FONTSIZE)
    ax_gamma.set_ylabel("Pdf", fontsize=FONTSIZE)
# Files are numbered in the order eta, sigma, tau, mu, gamma.
if options.save:
    for i, figg in enumerate([fig_eta, fig_sigma, fig_tau, fig_mu, fig_gamma]):
        figg.savefig(f"figure4_{i}.{options.extension}", transparent=True)
plt.show()

In [None]:
# show the SFS for Mitchell
# For every timepoint: pick one accepted run with
# abc_fig.get_idx_smaller_distance_clones_idx, plot its SFS and CDF against the
# donor's, and annotate the panel with the p-value of a two-sample KS test.
verbose = False
idx2show = dict()
for t in sorted(abc_mitchell.timepoint.unique()):
    at_t = abc_mitchell.timepoint == t
    name, cells, age = (
        abc_mitchell.loc[at_t, "name"].unique()[0],
        abc_mitchell.loc[at_t, "sample"].unique()[0],
        abc_mitchell.loc[at_t, "age"].unique()[0],
    )
    if verbose:
        # Fixed: this used the undefined name `quantile` (NameError as soon as
        # `verbose` is switched on); the threshold is the QUANTILE constant.
        threshold = abc_mitchell.loc[at_t, "wasserstein"].quantile(QUANTILE)
        title = f"age: {t} years, quantile threshold: {threshold:.2f}"
    else:
        title = f"age: {round(t)} years"

    idx2show[t] = abc_fig.get_idx_smaller_distance_clones_idx(
        posterior_mitchell[posterior_mitchell.timepoint == t], runs2keep
    )

    # exactly one simulation must match the selected run at this timepoint
    best_fit_all = [ele.sfs for ele in sfs_sims[t] if ele.parameters.idx == idx2show[t]]
    best_fit = best_fit_all.pop()
    assert len(best_fit_all) == 0

    fig, [ax_sfs, ax_cdf] = plt.subplots(
        2,
        1,
        figsize=(options.figsize[0] * 4 / 3, options.figsize[1] * 1.5),
        layout="constrained",
        sharex=True,
    )

    ax_sfs, ax_cdf = sfs_fig.plot_sfs_cdf(
        ax_sfs,
        ax_cdf,
        [idx2show[t]],
        target_sfs[t][3],
        sfs_sims[t],
        t,
        verbose=verbose,
        alpha=0.7,
        donor_name=name,
        donor_cells=cells,
        plot_options=options,
    )
    # only f_obs and f_exp are used below
    (
        f_obs,
        f_exp,
        idx_lower_bound,
        idx_upper_bound,
        mean_squared_log_error,
        rmsre,
        mape,
    ) = sfs.prepare_sfs_with_uniformisation_for_test(target_sfs[t][3], best_fit)
    res = stats.ks_2samp(f_obs, f_exp)
    ax_sfs.text(
        x=0.62,
        y=0.75,
        s=f"donor {age} y.o.\n{donors.loc[donors.name == name, 'cells'].squeeze()} cells",
        transform=ax_sfs.transAxes,
        fontsize=12,
    )
    ax_sfs.text(
        x=0.62,
        y=0.64,
        s=r"$\mathregular{{p_{{KS}}={{{:.2f}}}}}$".format(res.pvalue),
        transform=ax_sfs.transAxes,
        fontsize=12,
    )
    # TODO: not sure about the +1 here
    # ax_sfs.axvline((idx_upper_bound + 1) / cells)

    if options.save:
        fig.savefig(f"sfs_{t}years.{options.extension}", transparent=True)
    plt.show()

# Scatter of the accepted (eta, sigma) pairs, coloured with `color_mapping`.
fig, ax = plt.subplots(1, 1)
# `posterior_mitchell` is a `.loc` selection of `abc_mitchell`: the `hue`
# column is added with `assign`, which returns a new frame, instead of item
# assignment on the selection (that raises pandas' SettingWithCopyWarning).
posterior_mitchell = posterior_mitchell.assign(
    hue=[
        color_mapping(eta, sigma)
        for eta, sigma in zip(posterior_mitchell.eta, posterior_mitchell.sigma)
    ]
)
sns.scatterplot(
    data=posterior_mitchell[["eta", "sigma", "hue"]].drop_duplicates(),
    x="eta",
    y="sigma",
    hue="hue",
    legend=None,
    ax=ax,  # draw explicitly on the axes created above
)
ax.set_xlabel(r"$\eta$")
ax.set_ylabel(r"$\sigma$")
plt.show()

# One gamma curve per distinct accepted (eta, sigma) pair, coloured with the
# same mapping as the scatter plot.
fig, ax = plt.subplots(1, 1)
unique_pairs = posterior_mitchell[["eta", "sigma"]].drop_duplicates()
for pair in unique_pairs.itertuples():
    curve_color = color_mapping(pair.eta, pair.sigma)
    abc_fig.Gamma(pair.eta, pair.sigma).plot(ax, c=curve_color, lw=1, alpha=0.2)

ax.set_xlim([0, 0.5])
ax.set_ylim([0, 150])
ax.set_xlabel(r"$s$")
ax.set_ylabel("pdf")
plt.show()

In [None]:
# Number of expanded clones over age: range spanned by the rejected runs,
# range spanned by the accepted runs, and the values of Mitchell's data.
is_accepted = abc_mitchell.idx.isin(runs2keep)
unselected = abc_mitchell.loc[
    ~is_accepted, ["age", "sims clones", "idx"]
].drop_duplicates()

# find max/min per timepoint
# The x values are taken from the group index rather than from `age.unique()`:
# groupby sorts the ages while `unique()` keeps the order of appearance, so
# the two could disagree and pair an age with the range of another one.
rejected_range = unselected.groupby("age")["sims clones"].agg(["min", "max"])
fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
ax.fill_between(
    x=rejected_range.index.to_numpy(),
    y1=rejected_range["max"].to_numpy(),
    y2=rejected_range["min"].to_numpy(),
    color=COLORS["grey_dark"],
    alpha=0.9,
    # label="rejected runs",
)
sns.lineplot(
    data=abc_mitchell[["age", "clones"]].drop_duplicates(),
    x="age",
    y="clones",
    mew=0,
    markersize=8,
    marker=".",
    linewidth=1.5,
    color=COLORS["orange"],
    # label="Mitchell data",
    ax=ax,
)

accepted_range = (
    abc_mitchell.loc[is_accepted, ["age", "sims clones"]]
    .groupby("age")["sims clones"]
    .agg(["min", "max"])
)
ax.fill_between(
    x=accepted_range.index.to_numpy(),
    y1=accepted_range["max"].to_numpy(),
    y2=accepted_range["min"].to_numpy(),
    color=COLORS["yellow"],
    alpha=0.8,
    # label="accepted runs",
    edgecolor=None,
)
ax.set_ylim([-1, 25])
ax.set_yticks([0, 5, 10, 15, 20, 25])
ax.set_xlim([-1, donors.age.max() + 1])
ax.set_ylabel("Expanded clones", fontsize=12)
ax.set_xlabel("Age (years)", fontsize=12)

if options.save:
    fig.savefig(f"variants_abc.{options.extension}", transparent=True)
plt.show()

In [None]:
from matplotlib.patches import ConnectionPatch

# Strip of horizontal histograms, one panel per age: distribution of the number
# of expanded clones in the accepted runs, with its mean (grey diamond) and the
# donor's value (orange square).
yticks = [0, 5, 10, 15, 20]
# the first age is not shown; one panel for each of the remaining ages
ages_shown = abc_mitchell.age.unique()[1:]

fig, axes = plt.subplots(
    1,
    len(ages_shown),  # derived from the data instead of a hard-coded 7
    figsize=(4.5, 2),
    layout="constrained",
    sharey=True,
    gridspec_kw={"wspace": 0},
)
axes = np.atleast_1d(axes)  # a single panel is returned as a bare Axes
fig.patch.set_visible(False)
for age, ax in zip(ages_shown, axes):
    ax.set_ylim([0, 22])
    data = abc_mitchell.loc[
        (abc_mitchell.idx.isin(runs2keep)) & (abc_mitchell.age == age), "sims clones"
    ]
    hist, edges = np.histogram(data.tolist(), bins=range(0, 22), density=True)
    ax.barh(
        y=edges[:-1],
        height=1,
        width=hist,
        color=COLORS["yellow"],
        zorder=20,
    )
    if age > 30:
        ax.spines[['right', 'left', 'top', 'bottom']].set_visible(False)
        ax.yaxis.set_tick_params(width=0)
    else:
        ax.spines[['right', 'top', 'bottom']].set_visible(False)
        ax.set_ylabel("Expanded clones")
    ax.set_yticks(yticks)
    ax.set_xticks([0], labels=[age])
    ax.grid(axis="y", visible=False)
    ax.plot(0, data.mean(), marker="D", color="grey", zorder=24)
    # x=0 is now explicit (a single positional value is plotted at x=0 anyway),
    # matching the marker of the mean above
    ax.plot(
        0,
        donors.loc[donors.age == age, "clones"].squeeze(),
        marker="s",
        color=COLORS["orange"],
        zorder=24,
    )

# Horizontal guides spanning all the panels: default style at y=0, a light
# grid-like line at every other tick.
for ytick in yticks:
    guide_style = (
        {} if ytick == 0 else {"color": "#b0b0b0", "alpha": 0.5, "lw": 0.45, "zorder": 0}
    )
    con = ConnectionPatch(
        xyA=(0, ytick),
        coordsA=axes[0].transData,
        xyB=(axes[-1].get_xlim()[-1], ytick),
        coordsB=axes[-1].transData,
        **guide_style,
    )
    fig.add_artist(con)

if options.save:
    fig.savefig(f"variants_over_time.{options.extension}", transparent=True)
plt.show()

In [None]:
def compute_entropy(sfs_s: dict, subset: set, ages: list) -> np.ndarray:
    """Compute the entropy of the simulated SFS of the runs in `subset`, per age.

    Args:
        sfs_s: maps an age to the simulations at that age; every simulation
            exposes `.sfs` and `.parameters.idx`.
        subset: identifiers (`parameters.idx`) of the runs to keep.
        ages: ages (keys of `sfs_s`) to process; they define the row order.

    Returns:
        Array of shape `(len(ages), len(subset))`. Row `i` holds the entropies
        at `ages[i]`, in the order the runs are stored in `sfs_s[ages[i]]`.

    Raises:
        ValueError: from the reshape, when the number of matching runs is not
            `len(ages) * len(subset)` (e.g. a run is missing at some age).
    """
    entropies_sims = [
        stats.entropy(snapshot.array_from_hist(sim.sfs))
        for age in ages
        for sim in sfs_s[age]
        if sim.parameters.idx in subset
    ]
    # The number of rows comes from the `ages` argument; it used to be read
    # from the global `abc_mitchell`, which tied the function to notebook state.
    return np.array(entropies_sims, dtype=float).reshape(len(ages), len(subset))

In [None]:
%%time
# Entropy of the SFS over age: rejected runs (grey), accepted runs (yellow),
# mean of the neutral simulations when available (dashed black) and Mitchell's
# donors (orange).
ages = abc_mitchell.age.unique()
if sfs_sims_neutral:
    entropies_sims_neutral = compute_entropy(
        sfs_sims_neutral,
        {ele.parameters.idx for ele in sfs_sims_neutral[0.0]},  # all of them
        ages,
    )
    fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
    for i in range(entropies_sims_neutral.shape[-1]):
        ax.plot(ages, entropies_sims_neutral[..., i], color="grey", alpha=0.05)
    ax.plot(ages, entropies_sims_neutral.mean(axis=1))

    ax.set_ylabel("Entropy")
    ax.set_xlabel("Age (years)")
    ax.set_ylim([6, 14])
    plt.show()

print("computing the entropy for the accepted runs")
entropies_sims = compute_entropy(
    sfs_sims,
    {ele.parameters.idx for ele in sfs_sims[0.0] if ele.parameters.idx in runs2keep},
    ages,
)
print("computing the entropy for the rejected runs")
# only the rejected runs among the first 2000 simulations are drawn
entropies_sims_rejected = compute_entropy(
    sfs_sims,
    {ele.parameters.idx for ele in sfs_sims[0.0][:2000] if ele.parameters.idx not in runs2keep},
    ages,
)

print("nb of rejected runs plotted:", entropies_sims_rejected.shape[-1])

fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
for i in range(entropies_sims_rejected.shape[-1]):
    ax.plot(ages, entropies_sims_rejected[..., i], color=COLORS["grey_dark"], alpha=0.05)
for i in range(entropies_sims.shape[-1]):
    # the colour is passed by keyword: as third positional argument it was
    # parsed as a matplotlib format string
    ax.plot(ages, entropies_sims[..., i], color=COLORS["yellow"], alpha=0.5)
if sfs_sims_neutral:
    ax.plot(ages, entropies_sims_neutral.mean(axis=1), lw=1, ls=(0, (4, 3)), color="black")
ax.set_ylim([5.5, 14])
ax.set_xlim([-1, 82])

# Entropy of the observed SFS of each donor. Separate names are used so that
# the `ages` array above is not shadowed by a list halfway through the cell.
donor_ages, donor_entropies = list(), list()
for don in target_sfs.values():
    ag, sfs_ = don[1], snapshot.array_from_hist(don[-1])
    entrop = stats.entropy(sfs_)
    print(f"{don[0]} age {ag} with entropy {entrop}")
    donor_ages.append(ag)
    donor_entropies.append(entrop)

ax.plot(
    donor_ages, donor_entropies, lw=1.2, marker=".", color=COLORS["orange"], label="Mitchell"
)
ax.set_ylabel("Entropy")
ax.set_xlabel("Age (years)")
# ax.legend(fontsize=10, loc=4)
if options.save:
    fig.savefig("sfs_entropy.png", transparent=True)
plt.show()

In [None]:
# Entropy of the observed SFS of each donor, plotted against the donor age.
ages, entropies = [], []
for donor_entry in target_sfs.values():
    donor_age = donor_entry[1]
    donor_entropy = stats.entropy(snapshot.array_from_hist(donor_entry[-1]))
    print(f"{donor_entry[0]} age {donor_age} with entropy {donor_entropy}")
    ages.append(donor_age)
    entropies.append(donor_entropy)

fig, ax = plt.subplots(1, 1)
ax.plot(ages, entropies, marker=".")
ax.set_xlabel("Age (years)")
ax.set_ylabel("Entropy")
plt.show()

In [None]:
%%time
# One figure per donor: CDF of the SFS of the rejected runs (grey, taken from
# the first 20'000 simulations), of the accepted runs (yellow) and of the donor
# (orange), with the y-axis zoomed on the upper tail.
NARROW_YLIM_DONORS = {'KX001', 'KX002', 'SX001', 'AX001'}

for name in abc_mitchell.name.unique():
    donor_rows = abc_mitchell.loc[abc_mitchell.name == name]
    age = donor_rows["age"].unique()[0]
    cells = donor_rows["sample"].unique()[0]
    fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)

    for sfs_s in sfs_sims[age][:20_000]:
        cdf_x_sim, cdf_y_sim = realisation.cdf_from_dict(sfs_s.sfs)
        if sfs_s.parameters.idx not in runs2keep:
            ax.plot(cdf_x_sim / cells, cdf_y_sim, color=COLORS["grey_dark"], alpha=0.01)

    for sfs_s in sfs_sims[age]:
        if sfs_s.parameters.idx not in runs2keep:
            continue
        cdf_x_sim, cdf_y_sim = realisation.cdf_from_dict(sfs_s.sfs)
        ax.plot(cdf_x_sim / cells, cdf_y_sim, alpha=0.1, color=COLORS["yellow"])

    cdf_x_target, cdf_y_target = realisation.cdf_from_dict(target_sfs[age][3])
    ax.plot(cdf_x_target / cells, cdf_y_target, color=COLORS["orange"])
    ax.set_xscale("log")
    ax.set_ylabel("Cumulative distribution", fontsize=12)
    ax.set_xlabel(r"Variant frequency $f$", fontsize=12)
    ax.set_xlim(ax.get_xlim()[0], 1)
    ax.text(
        x=0.55,
        y=0.1,
        s=f"donor {age} y.o.",
        fontsize=12,
        transform=ax.transAxes,
    )

    ax.set_ylim([0.95, 1] if name in NARROW_YLIM_DONORS else [0.85, 1])
    if options.save:
        # too many datapoints to save this in svg
        fig.savefig(f"cdf_fits_{name}.png", transparent=True)
    plt.show()

In [None]:
# Stand-alone legend shared by the figures above: an empty axes that only
# carries the three handles (data, accepted runs, rejected runs).
fig, ax = plt.subplots(1, 1, figsize=(20, 1))
legend_elements = [
    Line2D(
        [0],
        [0],
        marker=".",
        mew=4,
        color="#d95f0e",
        label="data",
        markersize=13,
    ),
    Line2D(
        [0],
        [0],
        mew=1,
        color="#fdbf6f",
        label="accepted runs",
        markersize=12,
    ),
    Line2D([0], [0], color="grey", alpha=0.6, lw=4, label="rejected runs"),
]

ax.legend(
    handles=legend_elements,
    # was "extend", which is not a legend mode and was silently ignored;
    # "expand" stretches the legend horizontally over the axes
    mode="expand",
    ncols=5,
    handletextpad=0.4,
)
ax.axis("off")
plt.show()

In [None]:
# Pairwise correlation between the accepted parameters. The fitness columns
# are named `eta` and `sigma` in this frame (as selected for the priors and
# the posteriors above); the former `s` / `std` labels raise a KeyError.
posterior_mitchell[["eta", "tau", "sigma"]].corr()

In [None]:
# Histogram of the accepted tau values with their mean and median.
fig, ax = plt.subplots(1, 1, layout="constrained")
mean, median = posterior_mitchell.tau.mean(), posterior_mitchell.tau.median()
posterior_mitchell.tau.hist(ax=ax, bins=15)
ax.axvline(mean, color="black", label=f"mean={mean:.2f}")
# Fixed: the red line was drawn at `mean` while being labelled as the median.
ax.axvline(median, color="red", label=f"median={median:.2f}")
ax.legend()
ax.set_xlim([bins_tau[0], bins_tau[-1]])
plt.show()

In [None]:
# Gamma curves of 8 accepted (eta, sigma) pairs, sampled reproducibly
# with SEED.
fig, ax = plt.subplots(1, 1)
for row in (
    posterior_mitchell[["eta", "sigma"]]
    .drop_duplicates()
    .sample(8, random_state=SEED)
    .itertuples()
):
    gamma = abc_fig.Gamma(row.eta, row.sigma)
    # c = color_mapping(row.eta, row.sigma)
    gamma.plot(
        ax,
        lw=2,
        alpha=0.8,
        # raw f-string: `\e`, `\;` and `\s` are invalid escape sequences in a
        # normal string literal (the resulting label text is unchanged)
        label=rf"${{\eta}}={row.eta:.2f}\;{{\sigma}}={row.sigma:.2f}$",
    )


ax.legend(fontsize="x-small")
ax.set_ylim([0, 40])
ax.set_xlim([0, 0.5])
ax.set_ylabel("Probability density")
ax.set_xlabel(r"Distribution of fitness effects $s$")
if options.save:
    fig.savefig(f"distribution_fitness_effects_per_donor.{options.extension}")
plt.show()

In [None]:
# Total variant fraction over age in the accepted runs (mean and min-max band)
# against the values of Mitchell's paper.
ages = abc_mitchell.age.unique()
# data from the paper, by eye from fig 5a
mitchell_tot_variant_frac = [0, 0, 0.015, 0, 0.04, 0.43, 0.59, 0.55]

# per age, one array (runs x variants) with the fractions of the accepted runs
fractions_by_age = (
    np.array(
        [s.variant_fractions for s in variant_frac_sims[age] if s.parameters.idx in runs2keep]
    )
    for age in ages
)
# fractions below the detection threshold do not contribute to the total
max_ = np.asarray(
    [
        np.where(fractions >= DETECTION_THRESHOLD, fractions, 0).sum(axis=1)
        for fractions in fractions_by_age
    ]
)
print(max_.shape)

fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
lowest, highest = max_.min(axis=1), max_.max(axis=1)
ax.plot(ages, max_.mean(axis=1), color=COLORS["grey_dark"])
ax.fill_between(ages, lowest, highest, alpha=0.45, color=COLORS["grey_dark"])
ax.plot(ages, mitchell_tot_variant_frac, marker=".", color=COLORS["orange"])
ax.set_xlabel("Age (years)")
ax.set_ylabel("Total variant fraction")
ax.set_xlim([0, 81])
if options.save:
    fig.savefig(f"total_variant_fraction.{options.extension}")
plt.show()

In [None]:
# Total clone frequency over age: band spanned by the accepted runs against
# the summed clonal fraction of the donors.
ages = abc_mitchell.age.unique()

# for all the runs2keep, sum the variant fractions of each run, for all timepoints
max_ = np.asarray(
    [
        [sum(s.variant_fractions) for s in variant_frac_sims[age] if s.parameters.idx in runs2keep]
        for age in ages
    ]
)
print(max_.shape)

# observed values: `cf` summed per age without donor KX007, converted from
# percent to a fraction
clones_used = clones[["age", "cf"]].drop(index=clones[clones.donor_id == "KX007"].index)
observed_total_cf = clones_used.groupby("age").sum().to_numpy().ravel() / 100

fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
ylim = 1
#ax.plot(ages, max_.mean(axis=1), color=COLORS["yellow"], marker=".")
ax.fill_between(
    ages, max_.min(axis=1), max_.max(axis=1), alpha=1, color=COLORS["yellow"], edgecolor=None
)
ax.plot(ages[1:], observed_total_cf, color=COLORS["orange"], marker=".")
ax.set_xlabel("Age (years)")
ax.set_ylabel("Total clone frequency")
ax.set_xlim([-1, ages[-1] + 1])
ax.set_ylim([-0.05, ylim])
plt.show()

In [None]:
# Largest clone frequency over age: band spanned by the rejected runs (grey),
# band spanned by the accepted runs (yellow) and the largest observed clone of
# each donor (orange).
ages = abc_mitchell.age.unique()

# for all the runs2keep, select the largest clone, for all timepoints
max_ = np.asarray(
    [
        [max(s.variant_fractions) for s in variant_frac_sims[age] if s.parameters.idx in runs2keep]
        for age in ages
    ]
)

# Same for the rejected runs found among the first 100 simulations of each age.
# The rows are built per age: the former flat list followed by a hard-coded
# `reshape(8, -1)` assumed 8 ages and could silently mix ages whenever the
# number of rejected runs differed between them.
max_rejected = np.asarray(
    [
        [
            max(va.variant_fractions)
            for va in list(variant_frac_sims[age])[:100]
            if va.parameters.idx not in runs2keep
        ]
        for age in ages
    ]
)
print(max_.shape)

fig, ax = plt.subplots(1, 1, layout="constrained", figsize=options.figsize)
ylim = 1
#ax.fill_between(ages, 0.04, ylim, alpha=0.2, color="grey")
ax.fill_between(
    ages,
    max_rejected.min(axis=1),
    max_rejected.max(axis=1),
    alpha=0.9,
    color=COLORS["grey_dark"],
)
# ax.plot(ages, max_.mean(axis=1), color=COLORS["yellow"], marker=".")
ax.fill_between(
    ages, max_.min(axis=1), max_.max(axis=1), alpha=0.8, color=COLORS["yellow"], edgecolor=None
)
ax.set_xlabel("Age (years)")
ax.set_ylabel("Largest clone frequency")
ax.set_xlim([-1, ages[-1] + 1])
ax.set_ylim([-0.05, ylim])

# observed values: largest `cf` per age without donor KX007, converted from
# percent to a fraction
clones_used = clones[["age", "cf"]].drop(index=clones[clones.donor_id == "KX007"].index)
observed_largest_cf = clones_used.groupby("age").max().to_numpy().ravel() / 100
ax.plot(ages[1:], observed_largest_cf, color=COLORS["orange"], marker=".")
if options.save:
    fig.savefig(f"largest_clone_frequency.{options.extension}")
plt.show()

#### Sensitivity analysis Fabre

In [None]:
# Sensitivity of the Fabre posterior to the fraction of runs kept:
# plot the posterior built from the first `nb_runs` rows of `posterior_fabre`.
# NOTE(review): this presumes `posterior_fabre` is already ordered from best to worst run — confirm.
TOT_FABRE_RUNS = 40_000
percentages = np.array([0.25, 0.5, 5]) / 100
# Guard the in-place reset so the cell can be re-run: calling reset_index(names="idx")
# a second time raises because the "idx" column already exists.
if "idx" not in posterior_fabre.columns:
    posterior_fabre.reset_index(names="idx", inplace=True)

for percent in percentages:
    print(f"sensitivity with {percent:.2%} runs")
    # round before converting: a bare int() would truncate e.g. 99.999... down to 99
    nb_runs = int(round(percent * TOT_FABRE_RUNS))
    print(f"{percent:.2%}", nb_runs)
    assert posterior_fabre[:nb_runs].shape[0] == nb_runs, "fewer runs available than requested"

    axd = bins.plot_posterior(posterior_fabre[:nb_runs], False, Path(f"./posterior_fabre_{percent * 100}perc_"))
    plt.show()

#### Sensitivity analysis Mitchell

In [None]:
# Quantile paired with each entry of `percentages`; with the git tag (rust) of hsc
# at biorxiv these correspond to keeping 0.25%, 0.5% and 5% of the runs.
quantiles = dict(zip(percentages, [0.2225, 0.272, 0.543]))

for perc, quant in quantiles.items():
    print(f"sensitivity with {perc:.2%} runs and {quant} quantile")
    # the same quantile is passed for both quantile arguments
    runs2keep = abc_fig.posterior_mitchell_quantile(abc_mitchell, quant, quant, 0.1, bins)
    is_kept = abc_mitchell.idx.isin(runs2keep)
    posterior_mitchell = abc_mitchell.loc[is_kept, :]

    save_prefix = Path(f"./posterior_mitchell_{perc * 100}perc_")
    axd = bins.plot_posterior(posterior_mitchell, False, save_prefix)
    plt.show()

### ABC for each timepoint separately

In [None]:
# For every timepoint and metric, run the single-timepoint ABC and draw one
# histogram per parameter of the runs it returns, on a 2x2 grid.
# x-limits of each panel: first and last edge of the matching bins
theta_limits = {
    "s": (bins_s[0], bins_s[-1]),
    "std": (bins_std[0], bins_std[-1]),
    "mu": (bins_mu[0], bins_mu[-1]),
    "tau": (bins_tau[0], bins_tau[-1]),
}

for t in sorted(target_sfs.keys()):
    for metric in ("wasserstein",):
        fig, axes = plt.subplots(2, 2, layout="tight", sharey=True)
        print(f"Running ABC on {metric} metric at age {t}")
        abc_result = abc.run_abc_per_single_timepoint(abc_mitchell, t, quantile=0.01, metric=metric)
        idx = abc_result.get_idx()
        print(len(idx))
        view_tt = abc_mitchell[abc_mitchell.idx.isin(idx)]

        # panels are filled row by row
        for ax, (theta, lim) in zip(axes.flat, theta_limits.items()):
            sns.histplot(view_tt[theta], ax=ax)
            ax.set_xlim(lim)
        fig.suptitle(f"{t} years with {metric}", y=0.9, fontsize="small")
        plt.show()

In [None]:
# Run the ABC independently for each age and plot the resulting posteriors.
color = "#d95f02"
gammas = dict()
# pair of quantiles used at each age (original note: between 0.8% and 1.6%)
quantiles = {
    0: (0.015, 0.015),
    29: (0.03, 0.03),
    38: (0.05, 0.05),
    48: (0.08, 0.08),
    63: (0.08, 0.08),
    76: (0.06, 0.06),
    77: (0.07, 0.07),
    81: (0.09, 0.09),
}

for age in sorted(abc_mitchell.age.unique()):
    name = f"{age} year"
    timepoint = abc_mitchell[abc_mitchell.age == age]
    q_first, q_second = quantiles[age][0], quantiles[age][1]
    runs2keep = abc.run_abc_sfs_clones(timepoint, q_first, q_second, 0)
    # one row per kept run
    view_t = timepoint[timepoint.idx.isin(runs2keep)].drop_duplicates(subset="idx")
    assert not view_t.empty, "empty posterior"

    n_kept = len(runs2keep)
    n_total = timepoint.idx.unique().shape[0]
    print(f"ABC combined kept {n_kept/n_total:.2%} of runs: {n_kept} runs over a total of {n_total}")

    fig = abc_fig.create_posteriors_grid_eta_sigma_tau_mu()
    fig, gamma, estimates = abc_fig.plot_posteriors_grid_eta_sigma_tau_mu(
        view_t,
        name,
        fig,
        color,
        bins_eta=bins.bins["eta"],
        bins_sigma=bins.bins["sigma"],
        bins_tau=bins.bins["tau"],
        bins_mu=bins.bins["mu"],
        fancy=True,
    )
    # legend with a single patch carrying the age label
    custom_lines = [Patch(facecolor=color, alpha=0.7, edgecolor="black", label=name)]
    fig.axes[1].legend(handles=custom_lines, frameon=False, fontsize=16)
    fig.axes[5].yaxis.set_major_formatter(ticker.FuncFormatter(abc_fig.fmt_two_digits))

    gammas[age] = gamma

    # annotate the axes from the third one onwards with the point estimate,
    # upper bound as superscript and lower bound as subscript
    for ax_, estimate in zip(fig.axes[2:], estimates):
        symbol = ax_.get_xlabel().replace('$', '')
        lower, upper = estimate.credible_interval_90[0], estimate.credible_interval_90[1]
        ax_.text(
            0.95,
            0.85,
            f"${symbol}={estimate.point_estimate:.2f}^{{{upper:.2f}}}_{{{lower:.2f}}}$",
            fontsize=11,
            color=color,
            transform=ax_.transAxes,
            horizontalalignment="right",
        )
    if options.save:
        fig.savefig(f"posteriors_per_patient_{age}.{options.extension}")
    plt.show()

In [None]:
# First figure: the gamma object stored for each age, one marker per age.
markers = (None, ".", "o", "v", "s", "*", "x", 4)

fig, ax = plt.subplots(1, 1)
# zip stops at the shorter input, so at most len(markers) ages are drawn
for marker, (age, gamma) in zip(markers, gammas.items()):
    gamma.plot(ax, marker=marker, label=age, lw=2, mew=1.5, markevery=5, alpha=0.7)
ax.set_ylim([1, 36])
ax.set_xlim([0, 0.25])
ax.set_ylabel("pdf")
ax.set_xlabel("s")
ax.legend(title="years", fontsize="small")
plt.show()

# Second figure: `mean` and `std` of each gamma against age.
colors = ["#7570b3", "#e7298a"]
gamma_ages = list(gammas.keys())
gamma_means = [gamma.mean for gamma in gammas.values()]
gamma_stds = [gamma.std for gamma in gammas.values()]

fig, ax = plt.subplots(1, 1)
ax.plot(gamma_ages, gamma_means, marker=".", label=r"$\eta$", c=colors[0])
# both curves are drawn on `ax`; the secondary axis only carries the sigma label
secax = ax.secondary_yaxis("right")
ax.plot(gamma_ages, gamma_stds, marker="o", label=r"$\sigma$", c=colors[1])
ax.set_ylabel(r"$\eta$", color=colors[0])
ax.set_ylim([0, 0.4])
secax.set_ylabel(r"$\sigma$", color=colors[1])
ax.set_xlabel("age")
ax.legend()
plt.show()

## Run ABC on subsampled simulated data

In [None]:
# switch read by the validation loop: overlay the true values and save the posteriors
SHOW_VAL_POST = False
# distinct run indices found in the first entry of `sfs_sims`
all_idx_from_sims = {sim.parameters.idx for sim in sfs_sims[0]}

In [None]:
# Distinct (s / tau, mu, std / tau, idx) tuples among the first 300 entries of
# sfs_sims[0], displayed sorted by mu first and s / tau second.
param_tuples = {
    (
        sim.parameters.s / sim.parameters.tau,
        sim.parameters.mu,
        sim.parameters.std / sim.parameters.tau,
        sim.parameters.idx,
    )
    for sim in sfs_sims[0][:300]
}
sorted(param_tuples, key=lambda combo: (combo[1], combo[0]))

In [None]:
# run idx used as the target of the synthetic validation
synthetic_target_idx = 511050
validation_idx3 = abc_fig.SyntheticValidation(synthetic_target_idx, sfs_sims, counts)

In [None]:
# histogram of the target's mu with bin edges 3, 4 and 5
mu_target = validation_idx3.params["mu"]
plt.hist(mu_target, bins=[3, 4, 5])

In [None]:
# Posterior for the synthetic target, with the target's own parameter values in red.
quantile_single = QUANTILE - 0.065
axd, estimates = validation_idx3.compute_posteriors(
    quantile_single, quantile_single, PROPORTION_RUNS_DISCARDED, bins,
)
for ax in axd:
    panel = ax["C"]
    # strip "$" and "\" from the axis labels to recover the keys of `params`
    x = panel.get_xlabel().replace("$", "").replace("\\", "")
    y = panel.get_ylabel().replace("$", "").replace("\\", "")
    panel.axvline(validation_idx3.params[x], c="red")
    panel.axhline(validation_idx3.params[y], c="red")
plt.show()

In [None]:
# Synthetic validation: take randomly chosen simulated runs as targets, infer their
# parameters with ABC and store the absolute error of each point estimate.
validation = list()
# sorted() gives random.sample a sequence with a reproducible order
for target_idx in random.sample(sorted(all_idx_from_sims), 2):
    print(target_idx)
    validation_idx3 = abc_fig.SyntheticValidation(target_idx, sfs_sims, counts)
    axd, estimates = validation_idx3.compute_posteriors(
        QUANTILE - 0.1, QUANTILE - 0.1, PROPORTION_RUNS_DISCARDED, bins,
    )
    # |true value - point estimate| per parameter, followed by the target idx
    errors = [
        abs(validation_idx3.params[theta] - estimates[theta].point_estimate)
        for theta in ("eta", "sigma", "mu", "tau")
    ]
    errors.append(target_idx)
    validation.append(errors)

    if SHOW_VAL_POST:
        for ax in axd:
            # strip "$" and "\" from the axis labels to recover the keys of `params`
            x, y = ax["C"].get_xlabel().replace("$", "").replace("\\", ""), ax["C"].get_ylabel().replace("$", "").replace("\\", "")
            ax["C"].axvline(validation_idx3.params[x], c="red")
            ax["C"].axhline(validation_idx3.params[y], c="red")

            if options.save:
                # save the figure owning this panel: plt.savefig would write the *current*
                # figure for every panel instead of each posterior's own figure
                # NOTE(review): the file name does not contain target_idx, so every target
                # overwrites the files of the previous one — confirm this is intended
                ax["C"].figure.savefig(f"posterior_synthetic_{ax['C'].get_xlabel()}_{ax['C'].get_ylabel()}.{options.extension}")

    # show last: figures can no longer be annotated or saved once they have been shown
    plt.show()

In [None]:
# one row per synthetic target: absolute errors on the four parameters, then the target idx
val_columns = ["eta", "sigma", "mu", "tau", "idx"]
val_pd = pd.DataFrame(validation, columns=val_columns)
val_pd

In [None]:
# one histogram per column of `val_pd`, the last column left out
for col in val_pd.columns[:-1]:
    fig, ax = plt.subplots(1, 1)
    val_pd[col].plot.hist(ax=ax)
    # prefix the column name with a backslash to build the LaTeX label
    latex_name = "\\" + col
    ax.set_xlabel(f"${{{latex_name}}}$")
    plt.show()