In [None]:
# automatically reload all imported modules before executing code, so edits to
# the project packages (hscpy, futils) are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import random
from matplotlib import ticker
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from pathlib import Path
from scipy import stats
from typing import Union, List
from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy.figures import sfs as sfs_fig
from hscpy import realisation, mitchell, abc, variant

# Directory holding the compiled `hsc` release binary (queried by the %%bash cell below).
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
assert PATH2BIN.is_dir(), f"hsc release directory not found: {PATH2BIN}"

In [None]:
%%bash -s "$PATH2BIN" --out version
# $1 is PATH2BIN (passed with -s); --out stores stdout in the Python variable `version`
$1/hsc --version

In [None]:
USE_SCRATCH = True  # read simulations from /data/scratch instead of /data/home
SAVEFIG = True
SHOW_PRIORS = True
LATEST = True  # use the version reported by the hsc binary instead of a pinned one
SEED = 26
# set to True the first time this notebook is run: it removes the runs
# that did not finish running up to the last timepoint
FIRST_TIME = False

# ABC tolerances: the higher the quantiles, the less precise the ABC and thus
# the more runs are accepted.
# Alternative settings that were tried:
# PROPORTION_RUNS_DISCARDED, QUANTILE_SFS, QUANTILE_CLONES = 0.2, 0.25, 0.4
# PROPORTION_RUNS_DISCARDED, QUANTILE_SFS, QUANTILE_CLONES = 0.2, 0.35, 0.35
PROPORTION_RUNS_DISCARDED, QUANTILE_SFS, QUANTILE_CLONES = 0.2, 0.25, 0.25

options = PlotOptions(figsize=(7, 6), extension="svg", save=SAVEFIG)
# Seed both global RNGs: `random` drives the random choice of runs further
# down, while pandas' `DataFrame.sample` without `random_state` draws from
# numpy's global RNG.
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# hsc version: the one reported by the binary, or a pinned one
VERSION = parse_version(version) if LATEST else "v2.2.12"
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

# simulations are stored either on scratch or on home storage
PATH2SIMS = (
    Path("/data/scratch") if USE_SCRATCH else Path("/data/home")
) / f"hfx923/hsc-draft/{VERSION}"

# the location of Mitchell's data depends on the machine; default to the home dir
PATH2MITCHELL = {
    "5X9ZYD3": Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc"),
    "LAPTOP-CEKCHJ4C": Path("/mnt/c/Users/fra_t/Documents/PhD/hsc"),
}.get(socket.gethostname(), Path("~").expanduser())

In [None]:
# Donors used in the ABC:
# 1. we don't use the summary `Summary_cut.csv` because
#    it contains errors (i.e. duplicated and missing entries)
# 2. we don't use donor CB001 because we already have CB002 at age 0
# 3. we don't use donor KX007 because its genotype matrix is wrong:
#    the same donor was uploaded twice
EXCLUDED_DONORS = ["CB001", "KX007"]

donors = mitchell.donors()
# `isin` makes this a no-op for any excluded name not present in the table
donors = donors.loc[~donors.name.isin(EXCLUDED_DONORS)].copy()
# simulations and target data are keyed by age below, so ages must be unique
assert donors.age.is_unique, "donors must have distinct ages"
donors

In [None]:
# Histogram bin edges shared by the prior/posterior plots of the ABC.
bins_mu, bins_tau, bins_s, bins_std = (
    np.arange(0, 31, 2),  # mu: 0, 2, ..., 30
    np.arange(0, 10.5, 0.5),  # tau: 0, 0.5, ..., 10
    np.arange(0, 0.41, 0.02),  # s / eta: 0, 0.02, ..., 0.40
    np.arange(0, 0.105, 0.005),  # std / sigma: 0, 0.005, ..., 0.10
)

# Load simulated data: both SFS and variant fractions

## Remove runs
Remove all the runs that did not finish running, i.e. that did not reach the last timepoint. This is required because the runs are loaded into a numpy array of fixed size.

In [None]:
def files2remove(
    years: List[int],
    cells: List[int],
    root: Union[Path, str, None] = None,
) -> List[str]:
    """Find the simulation runs that did not finish for every timepoint.

    Each run writes one SFS file per timepoint under
    ``<root>/<cells>cells/sfs/<year>dot0years``. A run whose file stem is
    missing from at least one of these directories did not complete.

    Args:
        years: ages (in years) of the timepoints, paired element-wise with ``cells``.
        cells: number of sampled cells for each timepoint.
        root: root directory of the simulations; defaults to ``PATH2SIMS``.

    Returns:
        Sorted list of the file stems (run names) to remove.

    Raises:
        AssertionError: if ``years`` and ``cells`` differ in length.
        FileNotFoundError: if one of the timepoint directories does not exist.
    """
    from collections import Counter

    assert len(years) == len(cells), "years and cells must have the same length"
    if root is None:
        root = PATH2SIMS

    paths = [
        Path(f"{root}/{cell}cells/sfs/{year}dot0years")
        for year, cell in zip(years, cells)
    ]
    # number of timepoint directories each run appears in (single O(n) pass,
    # instead of calling list.count for every distinct stem)
    occurrences = Counter(file.stem for path in paths for file in path.iterdir())
    incomplete = sorted(
        stem for stem, count in occurrences.items() if count < len(paths)
    )
    print(f"{len(incomplete)} files to remove")
    return incomplete

In [None]:
%%time
if FIRST_TIME:
    runs2remove = files2remove(
        years=donors.age.tolist(),
        cells=donors.cells.tolist(),
    )
    print(f"found {len(runs2remove)} runs to remove")

    for r in donors[["age", "cells"]].itertuples():
        print(r)
        for f in runs2remove:
            myf = f"{PATH2SIMS}/{r.cells}cells/sfs/{r.age}dot0years/{f}"
            try:
                Path(myf).with_suffix(".json").unlink()
            except FileNotFoundError:
                pass
            try:
                Path(myf.replace("sfs", "variant_fraction")).with_suffix(
                    ".csv"
                ).unlink()
            except FileNotFoundError:
                pass

## Load data

In [None]:
%%time
sfs_sims = dict()
counts_sims = dict()

for r in donors[["name", "age", "cells"]].itertuples():
    print(f"\tloading sims SFS for donor {r.name} with {r.cells} cells")
    path2sfs = Path(PATH2SIMS / f"{r.cells}cells/sfs/")
    sfs_sims.update(realisation.load_all_sfs_by_age(path2sfs))

    print(f"\tloading sims variant counts for donor {r.name} with {r.cells} cells")
    counts_sims.update(
        variant.load_all_detected_var_counts_by_age(
            PATH2SIMS / f"{r.cells}cells/variant_fraction", 0.01
        )
    )

In [None]:
counts = variant.variant_counts_detected_df(counts_sims)

# Detected variant counts vs age over all simulations, drawn twice with two
# different error bands: the full min-max range and one standard deviation.
fig, ax = plt.subplots(1, 1)
for errorbar, style in (
    (lambda x: (np.min(x), np.max(x)), dict(label="min-max")),
    ("sd", dict(color="orange", label="std")),
):
    sns.lineplot(
        counts,
        x="age",
        y="variant counts detected",
        errorbar=errorbar,
        ax=ax,
        **style,
    )
ax.legend()
plt.show()
print(counts[["variant counts detected", "age"]].groupby("age").describe())

# Run ABC on the real data

## Load the data from Mitchell's paper
We have excluded two donors from the ABC:
1. KX007, because the same donor was uploaded twice (its genotype matrix is wrong);
2. CB001, because it maps to the same timepoint as CB002 (both are age 0).

In [None]:
%%time
target_sfs = {
    r.age: mitchell.sfs_donor_mitchell(
        r.name, r.age, PATH2MITCHELL, remove_indels=False
    )
    for r in donors[["name", "age"]].itertuples()
}
assert [
    ele[2] for ele in target_sfs.values()
] == donors.cells.tolist(), (
    "cells found in genotype matrices do not match the extected nb of cells"
)

### Compute the summary statistics 
1. wasserstein metric
2. the number of clones
3. the KS stat (not implemented yet)

After having computed the statistics from the simulations and Mitchell's data, we won't use the data anymore but just the summary statistic dataframe `abc_mitchell`: we are going to filter (and keep) the runs by selecting from this dataframe.

In [None]:
%%time
abc_mitchell = abc.compute_abc_results(
    sims_sfs=sfs_sims,
    sims_clones=counts,
    target_sfs={
        k: ele[-1] for k, ele in target_sfs.items()
    },  # ele[-1] means the SFS, k is the age
    target_clones=donors,
    experiment="mitchell",
)
# get names from donors
abc_mitchell = abc_mitchell.merge(
    right=donors[["age", "name"]],
    how="left",
    right_on="age",
    left_on="timepoint",
    validate="many_to_one"
).rename(columns={"age_x": "age"}).drop(columns=["age_y"])

abc_mitchell.shape

### Show priors

In [None]:
# TODO priors are now s/tau and sigma/tau, not s and sigma
if SHOW_PRIORS:
    priors = abc_mitchell[["mu", "u", "eta", "sigma", "tau"]].drop_duplicates()

    # one prior histogram per parameter, with the bins used for the posteriors
    prior_plot_kwargs = {
        "eta": dict(bins=bins_s),
        "sigma": dict(bins=bins_std),
        "tau": dict(bins=bins_tau),
        "mu": dict(discrete=False, bins=bins_mu),
        "u": dict(),
    }
    for param, kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **kwargs)
        plt.show()

    # distributions of the two summary statistics over all runs and timepoints
    for metric in ("wasserstein", "rel clones diff"):
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        sns.histplot(abc_mitchell[metric], binwidth=0.01, ax=ax)
        plt.show()

    sns.pairplot(priors[["eta", "sigma", "tau", "mu"]], kind="hist")
    if options.save:
        plt.savefig(f"priors.{options.extension}")
    plt.show()
    print(priors.eta.min(), priors.eta.max())

In [None]:
# Accepted ABC parameters of the Fabre posterior, with the Greek column names
# renamed to the ones used in this notebook.
posterior_fabre = pd.read_csv(
    "abcAcceptedParams_Intermediate_5D_normedFitness.csv"
).rename(columns={"s": "eta", "σ": "sigma", "μ": "mu", "τ": "tau"})
posterior_fabre

### Run ABC considering all timepoints at the same time

In [None]:
# Combined ABC over all timepoints, with the SFS and clone quantiles and the
# proportion of discarded runs set in the configuration cell.
runs2keep = abc.run_abc_sfs_clones(
    abc_mitchell, QUANTILE_SFS, QUANTILE_CLONES, PROPORTION_RUNS_DISCARDED
)
# one row per accepted run
posterior_mitchell = abc_mitchell[abc_mitchell.idx.isin(runs2keep)].drop_duplicates(
    subset="idx"
)
assert not posterior_mitchell.empty, "empty posterior"

n_runs_total = len(abc_mitchell.idx.unique())
print(f"ABC combined kept {len(runs2keep)} runs over a total of {n_runs_total} runs")

# Pairwise posterior plots of the combined ABC. Files are only written when
# saving is enabled, as for every other figure in this notebook.
gs = []
for x_param, y_param, bins_x, bins_y in (
    ("eta", "mu", bins_s, bins_mu),
    ("eta", "sigma", bins_s, bins_std),
    ("eta", "tau", bins_s, bins_tau),
    ("mu", "tau", bins_mu, bins_tau),
    ("mu", "sigma", bins_mu, bins_std),
    ("tau", "sigma", bins_tau, bins_std),
):
    gs.append(
        abc_fig.plot_results(posterior_mitchell, x_param, y_param, bins_x, bins_y)
    )
    if options.save:
        plt.savefig(
            f"posterior{len(gs)}.{options.extension}", bbox_inches="tight", pad_inches=0
        )
    plt.show()

# TODO: maybe iterate over gs to remove the grid in the marginal plots?
# For each timepoint, plot Mitchell's SFS against the run returned by
# `get_idx_smaller_distance_clones_idx` among the accepted runs.
verbose = False
idx2show = dict()

for t in sorted(abc_mitchell.timepoint.unique()):
    at_t = abc_mitchell[abc_mitchell.timepoint == t]
    name, cells = at_t["name"].unique()[0], at_t["sample"].unique()[0]

    idx2show[t] = abc_fig.get_idx_smaller_distance_clones_idx(at_t, runs2keep)

    fig = sfs_fig.plot_sfs_cdf(
        [idx2show[t]],
        target_sfs[t][3],
        sfs_sims[t],
        t,
        verbose=verbose,
        alpha=0.7,
        donor_name=name,
        donor_cells=cells,
        plot_options=options,
    )
    if options.save:
        fig.savefig(f"sfs_{t}years.{options.extension}", transparent=True)
    plt.show()


# Figure 4: posteriors of the combined ABC on Mitchell's data overlaid with the
# Fabre posterior, plus gamma fitness distributions sampled from each.
fig = abc_fig.create_posteriors_grid_eta_sigma_tau_mu()
colors, names = ("#0570B0", "#d95f0e"), ("Fabre", "Mitchell")
custom_lines = [
    Patch(facecolor=colors[0], alpha=0.7, edgecolor="black", label=names[0]),
    Patch(facecolor=colors[1], alpha=0.7, edgecolor="black", label=names[1]),
]

for i, (name, posterior, color) in enumerate(
    zip(names, (posterior_fabre, posterior_mitchell), colors)
):
    fig, _, estimates = abc_fig.plot_posteriors_grid_eta_sigma_tau_mu(
        posterior,
        name,
        fig,
        color,
        bins_eta=bins_s,
        bins_sigma=bins_std,
        bins_tau=bins_tau,
        bins_mu=bins_mu,
    )
    # point estimates of the first two parameters, one text line per dataset;
    # raw f-string so the LaTeX spacing `\;` is not an invalid escape sequence
    fig.axes[0].text(
        0.95,
        0.85 - i * 0.15,
        rf"${estimates[0].name.replace('$', '')}={estimates[0].point_estimate:.2f}\;\;{estimates[1].name.replace('$', '')}={estimates[1].point_estimate:.2f}$",
        fontsize=13,
        color=color,
        transform=fig.axes[0].transAxes,
        horizontalalignment="right",
    )
    fig.axes[0].set_xlabel(r"Innate clonal fitness $s$", fontsize="small")
    # overlay up to 40 gamma distributions drawn from the posterior;
    # DataFrame.sample raises if asked for more rows than are available
    gamma_params = posterior.loc[posterior.eta < 0.2, ["eta", "sigma"]]
    n_gammas = min(40, len(gamma_params))
    for row in gamma_params.sample(n_gammas, random_state=SEED).itertuples():
        abc_fig.Gamma(row.eta, row.sigma).plot(fig.axes[0], color=color, alpha=0.1)
    fig.axes[0].set_ylim([0, 36])

    # relabel the marginal axes and annotate each with its estimate
    for j, (ax_, estimate) in enumerate(zip(fig.axes[2:], estimates)):
        symbol = ax_.get_xlabel().replace("$", "")
        if j == 0:
            frmt, description = estimate.to_string("two"), "Mean fitness per year"
        elif j == 1:
            frmt, description = estimate.to_string("two"), "Variance fitness per year"
        elif j == 2:
            frmt, description = (
                estimate.to_string("one"),
                "Wild-type interdivision time in years",
            )
        else:
            frmt, description = estimate.to_string("zero"), "Fit mutants per year"
        ax_.set_xlabel(f"{description} ${symbol}$", fontsize="small")
        ax_.text(
            0.95,
            0.85 - i * 0.15,
            f"${symbol}={frmt}$",
            fontsize=13,
            color=color,
            transform=ax_.transAxes,
            horizontalalignment="right",
        )
fig.axes[1].legend(handles=custom_lines, frameon=False, fontsize=16)

if options.save:
    fig.savefig(f"figure4.{options.extension}", transparent=True)
plt.show()

In [None]:
# Detected clones over age: Mitchell's data against the min-max envelope of
# the accepted and of the rejected runs.
fig, ax = plt.subplots(1, 1, layout="tight")
is_kept = abc_mitchell.idx.isin(runs2keep)
unselected = abc_mitchell.loc[
    ~is_kept, ["age", "sims clones", "idx"]
].drop_duplicates()

sns.lineplot(
    data=abc_mitchell[["age", "clones"]].drop_duplicates(),
    x="age",
    y="clones",
    mew=0.2,
    marker="o",
    color="#d95f0e",
    label="Mitchell data",
    ax=ax,
)

# Min/max per age. x is taken from the groupby index so it stays aligned with
# y: Series.unique() keeps order of appearance, while groupby sorts by age.
accepted_band = (
    abc_mitchell.loc[is_kept].groupby("age")["sims clones"].agg(["min", "max"])
)
ax.fill_between(
    x=accepted_band.index,
    y1=accepted_band["max"],
    y2=accepted_band["min"],
    color="#fdbf6f",
    alpha=0.5,
    label="accepted runs",
)

rejected_band = unselected.groupby("age")["sims clones"].agg(["min", "max"])
ax.fill_between(
    x=rejected_band.index,
    y1=rejected_band["max"],
    y2=rejected_band["min"],
    color="grey",
    alpha=0.2,
    label="rejected runs",
)

ax.legend(
    ncols=3,
    mode="expand",
    bbox_to_anchor=(-0.2, 1, 1.2, 1),
    loc="lower left",
    fontsize="x-small",
    handletextpad=0.5,
)
ax.set_ylabel("Expanded clones", fontsize="medium")
ax.set_xlabel("Age [years]")

if options.save:
    fig.savefig(f"variants_abc.{options.extension}", transparent=True)
plt.show()

In [None]:
%%time
for name in abc_mitchell.name.unique():
    age, cells = abc_mitchell.loc[abc_mitchell.name == name, "age"].unique()[0],  abc_mitchell.loc[abc_mitchell.name == name, "sample"].unique()[0]
    fig, ax = plt.subplots(1, 1, layout="tight")

    for sfs_s in [s for s in sfs_sims[age][:40_000]]:
        cdf_x_sim, cdf_y_sim = realisation.cdf_from_dict(sfs_s.sfs)
        if sfs_s.parameters.idx in runs2keep:
            continue
        else:
            ax.plot(cdf_x_sim / cells, cdf_y_sim, color="#bdbdbd", alpha=0.01)

    for sfs_s in [s for s in sfs_sims[age] if s.parameters.idx in runs2keep]:
        cdf_x_sim, cdf_y_sim = realisation.cdf_from_dict(sfs_s.sfs)
        ax.plot(cdf_x_sim / cells, cdf_y_sim, alpha=0.05, color="#fdbf6f")

    cdf_x_target, cdf_y_target = realisation.cdf_from_dict(target_sfs[age][3])
    ax.plot(cdf_x_target / cells, cdf_y_target, marker="o", color="#d95f0e")
    ax.set_xscale("log")
    ax.set_ylabel("Cumulative distribution")
    ax.set_xlabel(r"Variant frequency $f$")
    ax.text(
        x=0.65,
        y=0.1,
        s=f"donor {age} y.o.",
        transform=ax.transAxes,
    )
    ax.set_ylim([0.8, 1])
    if options.save:
        fig.savefig(f"cdf_fits_{name}.{options.extension}", transparent=True)
    plt.show()

In [None]:
# Stand-alone legend (data, accepted runs, rejected runs), e.g. for the
# per-donor CDF figures.
fig, ax = plt.subplots(1, 1, figsize=(20, 1))
legend_elements = [
    Line2D(
        [0],
        [0],
        marker=".",
        mew=4,
        color="#d95f0e",
        label="data",
        markersize=13,
    ),
    Line2D(
        [0],
        [0],
        mew=1,
        color="#fdbf6f",
        label="accepted runs",
        markersize=12,
    ),
    Line2D([0], [0], color="grey", alpha=0.6, lw=4, label="rejected runs"),
]

ax.legend(
    handles=legend_elements,
    mode="expand",  # valid values are "expand" or None ("extend" is not one)
    ncols=5,
    handletextpad=0.4,
)
ax.axis("off")
plt.show()

In [None]:
# Detected clones over age: a reproducible sample of rejected runs as
# individual trajectories, the min-max envelope of the accepted runs, and
# Mitchell's data.
fig, ax = plt.subplots(1, 1, layout="tight")
colors = ["grey", "#d95f0e", "#d95f0e"]

custom_lines = [
    Line2D([0], [0], color=colors[0], label="rejected particles", lw=2),
    Line2D([0], [0], color=colors[1], alpha=0.3, label="accepted particles", lw=2),
    Line2D([0], [0], color=colors[2], label="Mitchell data", marker=".", mew=3, lw=2),
]

is_kept = abc_mitchell.idx.isin(runs2keep)
# seeded so the plotted subset of rejected runs is the same on every run
sampled_idx = abc_mitchell.sample(1000, random_state=SEED).idx
unselected = abc_mitchell.loc[
    abc_mitchell.idx.isin(sampled_idx) & ~is_kept,
    ["age", "sims clones", "idx"],
].drop_duplicates()

# one line per rejected run (groupby avoids re-filtering the frame for each run)
for _, trajectory in unselected.groupby("idx"):
    ax.plot(trajectory["age"], trajectory["sims clones"], c=colors[0], alpha=0.01)

# min/max per age; x from the groupby index keeps x and y aligned (sorted by age)
accepted_band = (
    abc_mitchell.loc[is_kept].groupby("age")["sims clones"].agg(["min", "max"])
)
ax.fill_between(
    x=accepted_band.index,
    y1=accepted_band["max"],
    y2=accepted_band["min"],
    color="#d95f0e",
    alpha=0.3,
    label="accepted runs",
)

mitchell_clones = abc_mitchell[["age", "clones"]].drop_duplicates()
ax.plot(
    mitchell_clones.age,
    mitchell_clones.clones,
    marker=".",
    mew=1,
    c=colors[2],
)
ax.legend(handles=custom_lines, loc="upper left", frameon=False, fontsize=16)
ax.set_xlabel("age")
ax.set_ylabel("detected clones")
if options.save:
    # distinct name: "variants_abc" is already written by the envelope figure
    fig.savefig(f"variants_abc_trajectories.{options.extension}")
plt.show()

In [None]:
ts = list(idx2show.keys())
# rows of the run shown for each timepoint in the per-timepoint SFS figures
best_runs = [
    abc_mitchell[(abc_mitchell.timepoint == t) & (abc_mitchell.idx == idx)]
    for t, idx in idx2show.items()
]
mus = [run.mu.iloc[0] for run in best_runs]
ss = [run.s.iloc[0] for run in best_runs]

# distribution of mu and s across timepoints ...
for values, width in ((mus, 0.5), (ss, 0.005)):
    sns.histplot(values, binwidth=width)
    plt.show()

# ... and their trend with age
for values in (mus, ss):
    plt.plot(ts, values, marker="x", mew=2)
    plt.show()

In [None]:
# pairwise correlations between the accepted s, tau and std
posterior_mitchell.loc[:, ["s", "tau", "std"]].corr()

In [None]:
# Marginal posterior of tau with its mean (black) and median (red).
fig, ax = plt.subplots(1, 1, layout="constrained")
mean, median = posterior_mitchell.tau.mean(), posterior_mitchell.tau.median()
posterior_mitchell.tau.hist(ax=ax, bins=15)
ax.axvline(mean, color="black", label=f"mean={mean:.2f}")
ax.axvline(median, color="red", label=f"median={median:.2f}")
ax.legend()
ax.set_xlim([bins_tau[0], bins_tau[-1]])
plt.show()

In [None]:
def color_mapping(eta: float, sigma: float) -> str:
    """Assign a plotting color to a posterior sample from its fitness parameters.

    Returns "orange" for a small mean fitness with a large spread
    (eta < 0.03 and sigma > 0.016), "green" for a large mean fitness with a
    moderate spread (eta > 0.15 and sigma > 0.01), and "blue" otherwise.
    The two colored regions are disjoint in eta.
    """
    low_eta_wide = eta < 0.03 and sigma > 0.016
    high_eta_spread = eta > 0.15 and sigma > 0.01
    if low_eta_wide:
        return "orange"
    return "green" if high_eta_spread else "blue"

In [None]:
# Posterior (eta, sigma) samples colored by `color_mapping`. The colors go into
# a local frame (posterior_mitchell is not modified) and are passed through an
# explicit palette; otherwise seaborn maps the hue levels to its own default
# colors instead of the named ones.
fig, ax = plt.subplots(1, 1)
colored_posterior = posterior_mitchell[["eta", "sigma"]].drop_duplicates()
colored_posterior = colored_posterior.assign(
    hue=[
        color_mapping(eta, sigma)
        for eta, sigma in zip(colored_posterior.eta, colored_posterior.sigma)
    ]
)
sns.scatterplot(
    data=colored_posterior,
    x="eta",
    y="sigma",
    hue="hue",
    palette={c: c for c in colored_posterior["hue"].unique()},
    legend=None,
    ax=ax,
)
ax.set_xlabel(r"$\eta$")
ax.set_ylabel(r"$\sigma$")
plt.show()

In [None]:
# A few gamma fitness distributions drawn from the posterior (eta, sigma) pairs.
fig, ax = plt.subplots(1, 1)
unique_eta_sigma = posterior_mitchell[["eta", "sigma"]].drop_duplicates()
# DataFrame.sample raises if asked for more rows than are available
n_samples = min(8, len(unique_eta_sigma))
for row in unique_eta_sigma.sample(n_samples, random_state=SEED).itertuples():
    gamma = abc_fig.Gamma(row.eta, row.sigma)
    # raw f-string: `\eta`, `\;`, `\sigma` are LaTeX, not Python escapes
    gamma.plot(
        ax,
        lw=2,
        alpha=0.8,
        label=rf"${{\eta}}={row.eta:.2f}\;{{\sigma}}={row.sigma:.2f}$",
    )


ax.legend(fontsize="x-small")
ax.set_ylim([0, 40])
ax.set_xlim([0, 0.2])
ax.set_ylabel("Probability density")
ax.set_xlabel(r"Distribution of fitness effects $s$")
plt.show()

In [None]:
# Gamma pdf of every unique posterior (eta, sigma), colored by `color_mapping`.
fig, ax = plt.subplots(1, 1)
unique_pairs = posterior_mitchell[["eta", "sigma"]].drop_duplicates()
for eta, sigma in unique_pairs.itertuples(index=False):
    abc_fig.Gamma(eta, sigma).plot(ax, c=color_mapping(eta, sigma), lw=1, alpha=0.2)

ax.set_ylim([0, 150])
ax.set_xlim([0, 0.5])
ax.set_ylabel("pdf")
ax.set_xlabel(r"$s$")
plt.show()

### ABC for each timepoint separately

In [None]:
# plot all posteriors for all timepoints separately, for all metrics
for t in sorted(target_sfs.keys()):
    # for metric in ("rel clones diff", ):
    # for metric in ("clones diff", ):
    # for metric in ("wasserstein", "rel clones diff"):
    for metric in ("wasserstein",):
        fig, axes = plt.subplots(2, 2, layout="tight", sharey=True)
        print(f"Running ABC on {metric} metric at age {t}")
        idx = abc.run_abc_per_single_timepoint(
            abc_mitchell,
            t,
            quantile=0.01,
            metric=metric,
        ).get_idx()
        print(len(idx))
        view_tt = abc_mitchell[abc_mitchell.idx.isin(idx)]

        for i, (theta, lim) in enumerate(
            zip(
                ["s", "std", "mu", "tau"],
                [
                    (bins_s[0], bins_s[-1]),
                    (bins_std[0], bins_std[-1]),
                    (bins_mu[0], bins_mu[-1]),
                    (bins_tau[0], bins_tau[-1]),
                ],
            )
        ):
            ax = axes[np.unravel_index(i, (2, 2))]
            sns.histplot(view_tt[theta], ax=ax)  # stat="percent")
            ax.set_xlim(lim)
        fig.suptitle(f"{t} years with {metric}", y=0.9, fontsize="small")
        plt.show()

In [None]:
# Independent ABC for each donor (timepoint), with its own tolerances. The
# accepted runs go into a per-donor variable so the global `runs2keep` of the
# combined ABC, used by the figures above, is not overwritten.
color = "#d95f02"
# same positional order as abc.run_abc_sfs_clones in the combined ABC
QUANTILE_SFS_PER_DONOR, QUANTILE_CLONES_PER_DONOR, PROPORTION_DISCARDED_PER_DONOR = (
    0.05,
    0.1,
    0,
)
gammas = dict()

for age in sorted(abc_mitchell.age.unique()):
    name = f"{age} year"
    timepoint = abc_mitchell[abc_mitchell.age == age]
    runs2keep_t = abc.run_abc_sfs_clones(
        timepoint,
        QUANTILE_SFS_PER_DONOR,
        QUANTILE_CLONES_PER_DONOR,
        PROPORTION_DISCARDED_PER_DONOR,
    )
    view_t = timepoint[timepoint.idx.isin(runs2keep_t)].drop_duplicates(subset="idx")

    assert not view_t.empty, "empty posterior"
    print(
        f"ABC at age {age} kept {len(runs2keep_t)} runs over a total of {timepoint.idx.unique().shape[0]} runs"
    )

    custom_lines = [
        Patch(facecolor=color, alpha=0.7, edgecolor="black", label=name),
    ]

    fig = abc_fig.create_posteriors_grid_eta_sigma_tau_mu()
    fig, gamma, estimates = abc_fig.plot_posteriors_grid_eta_sigma_tau_mu(
        view_t,
        name,
        fig,
        color,
        bins_eta=bins_s,
        bins_sigma=bins_std,
        bins_tau=bins_tau,
        bins_mu=bins_mu,
        fancy=False,
    )
    fig.axes[1].legend(handles=custom_lines, frameon=False, fontsize=16)
    # NOTE(review): axis 5 is assumed to be the one needing two-digit ticks — confirm
    fig.axes[5].yaxis.set_major_formatter(ticker.FuncFormatter(abc_fig.fmt_two_digits))

    gammas[age] = gamma
    # annotate each marginal with its point estimate and 90% credible interval
    for ax_, estimate in zip(fig.axes[2:], estimates):
        ax_.text(
            0.95,
            0.85,
            f"${ax_.get_xlabel().replace('$', '')}={estimate.point_estimate:.2f}^{{{estimate.credible_interval_90[1]:.2f}}}_{{{estimate.credible_interval_90[0]:.2f}}}$",
            fontsize=11,
            color=color,
            transform=ax_.transAxes,
            horizontalalignment="right",
        )
    if options.save:
        fig.savefig(f"posteriors_per_patient_{age}.{options.extension}")
    plt.show()

In [None]:
markers = None, ".", "o", "v", "s", "*", "x", 4

# Gamma fitness distribution inferred for each donor age.
fig, ax = plt.subplots(1, 1)

for i, (age, gamma) in enumerate(gammas.items()):
    # reuse markers cyclically: zip() would silently drop every age past the 8th
    marker = markers[i % len(markers)]
    gamma.plot(ax, marker=marker, label=age, lw=2, mew=1.5, markevery=5, alpha=0.7)
ax.set_ylim([1, 60])
ax.set_xlim([0, 0.21])
ax.set_ylabel("pdf")
ax.set_xlabel("s")
ax.legend(title="years", fontsize="small")
plt.show()

# Mean (eta) and standard deviation (sigma) of the inferred gamma per age.
colors = ["#7570b3", "#e7298a"]
fig, ax = plt.subplots(1, 1)
ax.plot(
    list(gammas.keys()),
    [gamma.mean for gamma in gammas.values()],
    marker=".",
    label=r"$\eta$",
    c=colors[0],
)
# right-hand axis on the same scale as the left one, used only to label sigma
secax = ax.secondary_yaxis("right")
ax.plot(
    list(gammas.keys()),
    [gamma.std for gamma in gammas.values()],
    marker="o",
    label=r"$\sigma$",
    c=colors[1],
)
ax.set_ylabel(r"$\eta$", color=colors[0])
secax.set_ylabel(r"$\sigma$", color=colors[1])
ax.set_xlabel("age")
ax.legend()
plt.show()

## Run ABC on subsampled simulated data

In [None]:
# ids of all simulated runs, read from the SFS simulations stored under key 0
all_idx_from_sims = {sfs.parameters.idx for sfs in sfs_sims[0]}

In [None]:
%%time
target_idx = random.sample(sorted(all_idx_from_sims), 1)[0]
print(target_idx)
validation_idx3 = abc_fig.SyntheticValidation(target_idx, sfs_sims, counts)

In [None]:
# Posteriors for the synthetic target. The first two numbers are presumably the
# SFS and clone quantiles (same order as abc.run_abc_sfs_clones) — TODO confirm.
# NOTE(review): the name gs02_02 does not match the 0.3 tolerances used here.
gs02_02 = validation_idx3.compute_posteriors(
    0.3, 0.3, PROPORTION_RUNS_DISCARDED, bins_s, bins_mu, bins_tau, bins_std
)

In [None]:
%%time
target_idx = random.sample(sorted(all_idx_from_sims), 1)[0]
print(target_idx)
validation_idx3 = abc_fig.SyntheticValidation(target_idx, sfs_sims, counts)

In [None]:
# Posteriors for the second synthetic target with tighter tolerances
# (presumably SFS and clone quantiles — TODO confirm).
gs02_02 = validation_idx3.compute_posteriors(
    0.15, 0.15, PROPORTION_RUNS_DISCARDED, bins_s, bins_mu, bins_tau, bins_std
)

In [None]:
# find some values that are close to the values we find on Mitchell's data
# eta=0.05, sigma=0.02, mu<4, tau=2
for idx in random.sample(sorted(all_idx_from_sims), 100):
    params = counts.loc[counts.idx == idx, ["mu", "s", "tau", "std"]].iloc[0]
    # rescale by tau to compare with the eta/sigma estimates
    eta, sigma = params.s / params.tau, params["std"] / params.tau
    is_candidate = params.tau < 5 and params.mu < 5 and eta < 0.07 and sigma < 0.04
    if is_candidate:
        print(idx, eta, sigma, params.tau, params.mu)
        print("\n")
        print("\n")

In [None]:
%%time
# a fixed run used as synthetic target, presumably one printed by the search
# in the previous cell (parameters close to the Mitchell estimates)
target_idx = 443840
validation_idx3 = abc_fig.SyntheticValidation(target_idx, sfs_sims, counts)

In [None]:
# Posteriors for the fixed synthetic target (tolerances presumably SFS and
# clone quantiles — TODO confirm).
gs02_02 = validation_idx3.compute_posteriors(
    0.14, 0.14, PROPORTION_RUNS_DISCARDED, bins_s, bins_mu, bins_tau, bins_std
)