In [None]:
# Enable the autoreload extension in mode 2 so that edits made to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import colors
from datetime import datetime
from pathlib import Path
from random import choices

from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy import sfs, mitchell, abc

# Folder holding the release build of the `hsc` binary; the notebook cannot
# run without it, hence the early check.
PATH2BIN = Path("~/hsc/target/release").expanduser()
assert PATH2BIN.is_dir()

In [None]:
%%bash -s "$PATH2BIN" --out version
# $1 is the folder containing the hsc binary; quoted so that a path with
# spaces or glob characters is passed through untouched.
"$1/hsc" --version

In [None]:
# Notebook-wide settings.

# sample size, used to build the name of the `<SAMPLE>cells` folder
SAMPLE = 368
# total number of timepoints, used to derive the minimum number of timepoints
# required by the ABC on the simulated target
NB_TIMEPOINTS = 16

# read the simulations from /data/scratch/ (True) or from /data/home/ (False)
USE_SCRATCH = True
# plot the prior distributions before running the ABC
SHOW_PRIORS = True

In [None]:
# Version of hsc: either the one reported by the binary (`version` is captured
# from the bash cell running `hsc --version`) or a pinned one.
LATEST = False
VERSION = parse_version(version) if LATEST else "v1.3.0"
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

# The simulations are stored in a version-specific folder, on scratch or home.
sims_root = "/data/scratch/" if USE_SCRATCH else "/data/home/"
PATH2SIMS = Path(sims_root) / f"hfx923/hsc-draft/{VERSION}"

# The folder with the Mitchell data depends on the machine running the
# notebook; unknown machines fall back to the home directory.
mitchell_dirs = {
    "5X9ZYD3": "/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc",
    "LAPTOP-CEKCHJ4C": "/mnt/c/Users/fra_t/Documents/PhD/hsc",
}
hostname = socket.gethostname()
if hostname in mitchell_dirs:
    PATH2MITCHELL = Path(mitchell_dirs[hostname])
else:
    PATH2MITCHELL = Path("~").expanduser()

In [None]:
options = PlotOptions(figsize=(7, 6), extension="pdf", save=False)

# Two donors are excluded, for different reasons:
# 1. KX007, because the same donor has been uploaded twice
# 2. CB001, because it maps to the same timepoint as CB002 (both have age 0)
summary = mitchell.load_and_process_mitchell(
    PATH2MITCHELL / "Summary_cut.csv", drop_donor_KX007=True
)
print(summary.shape)
is_cb001 = summary.donor_id == "CB001"
summary = summary.drop(index=summary.index[is_cb001])
print(summary.shape)
ages = summary.age.unique().tolist()

## Remove runs
Remove all the runs that didn't finish running. This is required because we load runs in a numpy array with fixed size.

In [None]:
# A run is flagged for removal when its SFS file exists in only one of the
# two folders `0dot0years` and `81dot0years`.
path2sfs = PATH2SIMS / f"{SAMPLE}cells" / "sfs"
timepoint1 = set(entry.stem for entry in (path2sfs / "0dot0years").iterdir())
timepoint2 = set(entry.stem for entry in (path2sfs / "81dot0years").iterdir())
files2remove = timepoint1 ^ timepoint2
print(f"{len(files2remove)} files to remove")
# space-separated list of run names, handed over to the bash cell
runs2remove = " ".join(files2remove)

In [None]:
%%bash -s "$runs2remove" "$PATH2SIMS"
echo "removing files"
# $1: space-separated names of the runs to remove
# $2: root folder searched for the files of those runs
# $1 is deliberately left unquoted so that it splits into one word per run.
for file in $1
do
    # The pattern is quoted so that find, and not the shell, expands the
    # wildcards; the root folder is quoted to survive spaces in the path.
    find "$2" -name "*${file}*" -exec rm {} \;
done

In [None]:
# Same comparison as before the removal: list the runs whose SFS file exists
# in only one of the two folders `0dot0years` and `81dot0years`.
path2sfs = Path(PATH2SIMS, f"{SAMPLE}cells", "sfs")
timepoint1, timepoint2 = (
    {entry.stem for entry in (path2sfs / folder).iterdir()}
    for folder in ("0dot0years", "81dot0years")
)
files2remove = timepoint1.symmetric_difference(timepoint2)
print(f"{len(files2remove)} files to remove")
runs2remove = " ".join(files2remove)

## Load simulated SFS

In [None]:
%%time
# Load the simulated SFS grouped by age. The ages found in the simulation
# folders must match the ages of the donors in the Mitchell data `summary`.
path2sfs = PATH2SIMS / f"{SAMPLE}cells/sfs/"
ages_mitchell = sorted(summary.age.unique())
ages_sims = sorted(map(mitchell.parse_path2folder_xdoty_years, path2sfs.iterdir()))
assert ages_sims == ages_mitchell
# load data
sfs_sims = mitchell.load_all_sfs_by_age(path2sfs)

## Run ABC on subsampled simulated data

In [None]:
%%time
# File stem of the simulated run that plays the role of the target data.
target_stem = "7dot2964754mu0_0dot000018241096768178977u_0dot39135662mean_0dot045755856std_1b0_200000cells_14230idx"
# SFS of the target run at every timepoint, keyed by timepoint.
target_sfs_simulated = {
    timepoint: run.sfs
    for timepoint, runs in sfs_sims.items()
    for run in runs
    if run.parameters.path.stem == target_stem
}
# Summary statistic of all the simulated runs against the target.
abc_simulated = abc.sfs_summary_statistic_wasserstein(
    sfs_sims, target_sfs_simulated, target_stem
)
abc_simulated

In [None]:
# show priors
priors = abc_simulated[["mu", "u", "s", "std"]].drop_duplicates()

if SHOW_PRIORS:
    # one figure per parameter, each with its own histogram options
    prior_settings = [
        ("s", {"binwidth": 0.01}),
        ("std", {"binwidth": 0.001}),
        ("mu", {"discrete": True}),
        ("u", {}),
    ]
    for param, hist_kwargs in prior_settings:
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **hist_kwargs)
        plt.show()

    # histogram of the `wasserstein` column over all runs and timepoints
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_simulated["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
# Settings of the ABC against the simulated target.
quantile = 0.06
prop2discard = 0.55  # the higher, the less precise and thus the more runs
# number of timepoints handed to `run_abc` as the minimum requirement
nb_discarded = round(NB_TIMEPOINTS * prop2discard)
minimum_runs = NB_TIMEPOINTS - nb_discarded
results = abc.run_abc(abc_simulated, quantile, minimum_runs, verbose=True)

In [None]:
# Runs accepted by the ABC against the simulated target.
accepted_idx = results.get_idx()

# When only a few runs are accepted, overlay the SFS of every accepted run on
# the SFS of the target, one figure per timepoint.
if len(accepted_idx) < 20:
    # idx of the run used as target, identified with the same stem comparison
    # that was used to build `target_sfs_simulated`
    idx_target = next(
        sfs_.parameters.idx
        for sfs_donor in sfs_sims.values()
        for sfs_ in sfs_donor
        if sfs_.parameters.path.stem == target_stem
    )

    for t in abc_simulated.timepoint.unique():
        fig, ax = plt.subplots(1, 1)
        target = sfs.process_sfs(
            target_sfs_simulated[t], normalise=False, log_transform=True
        )
        ax.plot(
            list(target.keys()),
            list(target.values()),
            marker="x",
            linestyle="",
            color="b",
            label=f"target {idx_target}",
            mew=2,
        )
        for s_id in accepted_idx:
            # the target itself is already drawn above
            if s_id == idx_target:
                continue
            # first (expected to be the only) run with this idx at timepoint t
            run = next(ele for ele in sfs_sims[t] if ele.parameters.idx == s_id)
            sim = sfs.process_sfs(run.sfs, normalise=False, log_transform=True)
            ax.plot(
                list(sim.keys()),
                list(sim.values()),
                marker="o",
                linestyle="",
                alpha=0.4,
                label=f"sim {run.parameters.idx}",
            )
        ax.legend()
        ax.set_title(f"age: {t} years")
        plt.show()

# Heatmaps of the distance at the last timepoint for the accepted runs, as a
# function of `s` and of either `mu` or `std`.
last_timepoint = abc_simulated.loc[
    (abc_simulated.idx.isin(accepted_idx))
    & (abc_simulated.timepoint == abc_simulated.timepoint.max())
]
for param in ("mu", "std"):
    sns.heatmap(
        last_timepoint[["s", param, "wasserstein"]]
        .drop_duplicates()
        .pivot(index="s", columns=param, values="wasserstein")
    )
    plt.show()

In [None]:
# Parameters of the runs accepted by the ABC against the simulated target.
is_accepted = abc_simulated.idx.isin(results.get_idx())
selected = abc_simulated.loc[is_accepted, ["mu", "s", "std"]].drop_duplicates()

# Every axis goes from 0 to the second value returned by `lims` for the prior.
upper = {param: abc_fig.lims(priors, param)[1] for param in ("mu", "s", "std")}

# One joint plot per pair of parameters.
abc_fig.plot_results(
    selected,
    ["mu", "s"],
    [0, upper["mu"]],
    [0, upper["s"]],
    {"discrete": True},
    {"binwidth": 0.01},
)
abc_fig.plot_results(
    selected,
    ["mu", "std"],
    [0, upper["mu"]],
    [0, upper["std"]],
    {"discrete": True},
    {"binwidth": 0.005},
)
abc_fig.plot_results(
    selected,
    ["s", "std"],
    [0, upper["s"]],
    [0, upper["std"]],
    {"binwidth": 0.01},
    {"binwidth": 0.005},
)

In [None]:
# Histogram of the timepoints found in `results.accepted` and in
# `results.accepted_quantile`, one figure each.
for accepted_runs, legend in (
    (results.accepted, "accepted filtered"),
    (results.accepted_quantile, "accepted"),
):
    fig, ax = plt.subplots(1, 1)
    sns.histplot(data=accepted_runs.timepoint, discrete=True, ax=ax, label=legend)
    ax.legend()
    plt.show()
print(
    f"{len(results.get_idx())} runs accepted with minimum sims {minimum_runs} over total of {NB_TIMEPOINTS}"
)


# Distance against age for the two sets of accepted runs, on the same axes.
fig, ax = plt.subplots(1, 1)
for kept_idx, legend in (
    (results.get_idx(), "accepted filtered"),
    (results.accepted_quantile.idx.unique(), "accepted"),
):
    distances = abc_simulated.loc[
        abc_simulated.idx.isin(kept_idx), ["timepoint", "wasserstein"]
    ].rename(columns={"timepoint": "age"})
    sns.lineplot(
        x="age",
        y="wasserstein",
        data=distances,
        errorbar="sd",
        label=legend,
        ax=ax,
    )
# the y-axis covers the full range of distances, accepted or not
ax.set_ylim([abc_simulated.wasserstein.min(), abc_simulated.wasserstein.max()])
plt.show()

In [None]:
# Rows whose path has the same stem as the name stored in `donor_name`.
target_name = abc_simulated.donor_name.iloc[0]
abc_simulated[abc_simulated.path.map(lambda p: Path(p).stem == target_name)]

## Run ABC on the real data
We have excluded two donors from the ABC:
1. exclude KX007 because the same donor was uploaded twice
2. exclude CB001 because it maps to the same timepoint as CB002 (both have age 0)

In [None]:
%%time
# Donor id for every age, in the same order as `ages_mitchell`.
# NOTE(review): squeeze() yields a scalar only if a single donor has that
# age — presumably guaranteed by the donors excluded earlier; confirm.
names_mitchell = []
for age in ages_mitchell:
    donors_at_age = summary.loc[
        summary.age == age, ["donor_id", "age"]
    ].drop_duplicates()
    names_mitchell.append(donors_at_age.donor_id.squeeze())

# SFS of every donor, keyed by the age of the donor.
target_sfs = {
    donor_age: mitchell.sfs_donor_mitchell(
        donor_name, PATH2MITCHELL, remove_indels=False
    )
    for donor_age, donor_name in zip(ages_mitchell, names_mitchell)
}

In [None]:
%%time
# Summary statistic of the simulated runs in `sfs_sims` against the donors'
# SFS in `target_sfs`; "mitchell" is the name given to this target.
# The cells below read the `idx`, `timepoint`, `wasserstein` and parameter
# (`mu`, `u`, `s`, `std`) columns from the result.
abc_mitchell = abc.sfs_summary_statistic_wasserstein(sfs_sims, target_sfs, "mitchell")
# display the table
abc_mitchell

In [None]:
# Unique combinations of parameters found in the runs compared to the data.
priors = abc_mitchell[["mu", "u", "s", "std"]].drop_duplicates()

if SHOW_PRIORS:
    # histogram options of every parameter, plotted in this order
    prior_hist_options = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": True},
        "u": {},
    }
    for param, hist_options in prior_hist_options.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **hist_options)
        plt.show()

    # histogram of the `wasserstein` column over all runs and timepoints
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_mitchell["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
# Every donor age must appear as a timepoint in the summary statistics.
nb_ages = len(ages_mitchell)
assert len(abc_mitchell.timepoint.unique()) == nb_ages

# Settings of the ABC against the Mitchell data.
quantile = 0.5
prop2discard = 0  # the higher, the less precise and thus the more runs
minimum_runs = nb_ages - round(nb_ages * prop2discard)
results_mitchell = abc.run_abc(abc_mitchell, quantile, minimum_runs, verbose=True)
nb_accepted = len(results_mitchell.get_idx())
print(
    f"{nb_accepted} runs accepted with minimum timepoints of {minimum_runs} out of {nb_ages} timepoints"
)

In [None]:
# Parameters of the runs accepted by the ABC against the Mitchell data.
is_accepted = abc_mitchell.idx.isin(results_mitchell.get_idx())
selected = abc_mitchell.loc[is_accepted, ["mu", "s", "std"]].drop_duplicates()


def prior_range(param):
    """Axis range from 0 to the second value returned by `lims` for `param`."""
    return [0, abc_fig.lims(priors, param)[1]]


# One joint plot per pair of parameters.
abc_fig.plot_results(
    selected,
    ["mu", "s"],
    prior_range("mu"),
    prior_range("s"),
    {"discrete": True},
    {"binwidth": 0.01},
)

abc_fig.plot_results(
    selected,
    ["mu", "std"],
    prior_range("mu"),
    prior_range("std"),
    {"discrete": True},
    {"binwidth": 0.005},
)
abc_fig.plot_results(
    selected,
    ["s", "std"],
    prior_range("s"),
    prior_range("std"),
    {"binwidth": 0.01},
    {"binwidth": 0.005},
)

In [None]:
# NOTE(review): this cell repeats the previous one (same selection, same three
# plots) — consider dropping one of the two.
accepted_mask = abc_mitchell.idx.isin(results_mitchell.get_idx())
selected = abc_mitchell.loc[accepted_mask, ["mu", "s", "std"]].drop_duplicates()

# mu against s
abc_fig.plot_results(
    selected,
    ["mu", "s"],
    [0, abc_fig.lims(priors, "mu")[1]],
    [0, abc_fig.lims(priors, "s")[1]],
    dict(discrete=True),
    dict(binwidth=0.01),
)

# mu against std
abc_fig.plot_results(
    selected,
    ["mu", "std"],
    [0, abc_fig.lims(priors, "mu")[1]],
    [0, abc_fig.lims(priors, "std")[1]],
    dict(discrete=True),
    dict(binwidth=0.005),
)
# s against std
abc_fig.plot_results(
    selected,
    ["s", "std"],
    [0, abc_fig.lims(priors, "s")[1]],
    [0, abc_fig.lims(priors, "std")[1]],
    dict(binwidth=0.01),
    dict(binwidth=0.005),
)

In [None]:
# For every donor age, compare the SFS of the donor with the SFS of a few
# accepted runs: SFS on top, cdf below, legend in a separate panel.

# number of accepted runs drawn at random when there are too many to show
k = 2
# named `run_colors` so that the `colors` module imported from matplotlib is
# not overwritten
run_colors = ["cyan", "black", "yellowgreen"]
# a list and not a set: markers are assigned to the runs in a fixed order
markers = ["o", "<", "*"]
alpha = 0.45

accepted_idx = results_mitchell.get_idx()

for t in sorted(abc_mitchell.timepoint.unique()):
    # distance at the `quantile` level for this timepoint, shown in the title
    threshold = abc_mitchell.loc[
        abc_mitchell.timepoint == t, "wasserstein"
    ].quantile(quantile)
    fig = plt.figure(layout="constrained", figsize=(7, 4))
    fig.suptitle(
        f"age: {t} years, quantile threshold: {threshold:.2f}",
        x=0.4,
    )

    subfigs = fig.subfigures(1, 2, wspace=-0.1, width_ratios=[2.4, 1])
    axes = subfigs[0].subplots(2, 1, height_ratios=[1.4, 1])
    ax3 = subfigs[1].subplots(1, 1)

    target = sfs.process_sfs(target_sfs[t], normalise=False, log_transform=True)
    u_values, u_weights = list(target.keys()), list(target.values())

    axes[0].plot(
        u_values,
        u_weights,
        marker="x",
        linestyle="",
        color="purple",
        label="Mitchell",
        mew=2,
    )
    axes[1].plot(*snapshot.cdf_from_histogram(target), color="purple", label="Mitchell")

    if len(accepted_idx) < 20:
        idx2show = accepted_idx
    else:
        # draw without replacement so that the same run is never shown twice
        candidates = list(accepted_idx)
        idx2show = random.sample(candidates, min(k, len(candidates)))

    # zip stops at the shortest input: at most len(markers) runs are drawn
    for s_id, marker, color in zip(idx2show, markers, run_colors):
        # first (expected to be the only) run with this idx at timepoint t
        run = next(ele for ele in sfs_sims[t] if ele.parameters.idx == s_id)
        sim = sfs.process_sfs(run.sfs, normalise=False, log_transform=True)
        wasserstein = abc_mitchell.loc[
            (abc_mitchell.idx == s_id) & (abc_mitchell.timepoint == t), "wasserstein"
        ].squeeze()
        v_values, v_weights = list(sim.keys()), list(sim.values())
        # sanity check: the stored distance must agree with scipy's; compared
        # with a tolerance since both are floating point numbers
        wasserstein_scipy = stats.wasserstein_distance(
            u_values, v_values, u_weights, v_weights
        )
        assert np.isclose(wasserstein, wasserstein_scipy)
        axes[0].plot(
            v_values,
            v_weights,
            marker=marker,
            linestyle="",
            mew=1,
            alpha=alpha,
            color=color,
            label=f"id: {s_id}, dist: {wasserstein:.2f}",
        )
        axes[1].plot(
            *snapshot.cdf_from_histogram(sim),
            alpha=alpha,
            color=color,
            linestyle="--",
            label=f"{run.parameters.idx}, metric: {wasserstein:.2f}",
        )
    axes[0].set_ylabel("log10 nb of mutants")
    axes[1].set_ylabel("cdf")
    axes[1].set_xlabel("log10 nb of cells")

    # the third panel only hosts the legend: hide ticks and frame
    ax3.legend(*axes[0].get_legend_handles_labels(), loc=6, frameon=False)
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.spines.right.set_visible(False)
    ax3.spines.left.set_visible(False)
    ax3.spines.top.set_visible(False)
    ax3.spines.bottom.set_visible(False)
    fig.savefig(f"sfs_{t}years.pdf")
    plt.show()

In [None]:
# Histogram of the timepoints found in `results_mitchell.accepted`.
fig, ax = plt.subplots(1, 1)
sns.histplot(
    data=results_mitchell.accepted.timepoint,
    discrete=True,
    ax=ax,
    label="accepted filtered",
)
ax.legend()
plt.show()

# Distance against age for the two sets of accepted runs, on the same axes.
fig, ax = plt.subplots(1, 1)
sns.lineplot(
    x="age",
    y="wasserstein",
    data=abc_mitchell.loc[
        abc_mitchell.idx.isin(results_mitchell.get_idx()),
        ["timepoint", "wasserstein"],
    ].rename({"timepoint": "age"}, axis=1),
    errorbar="sd",
    label="accepted filtered",
    ax=ax,
)

sns.lineplot(
    x="age",
    y="wasserstein",
    data=abc_mitchell.loc[
        abc_mitchell.idx.isin(results_mitchell.accepted_quantile.idx.unique()),
        ["timepoint", "wasserstein"],
    ].rename({"timepoint": "age"}, axis=1),
    errorbar="sd",
    label="accepted",
    ax=ax,
)
# The y-axis covers the full range of the distances computed against the
# Mitchell data, i.e. the data plotted here (not the simulated target).
ax.set_ylim([abc_mitchell.wasserstein.min(), abc_mitchell.wasserstein.max()])
plt.show()

# Histogram of the timepoints found in `results_mitchell.accepted_quantile`.
fig, ax = plt.subplots(1, 1)
sns.histplot(
    data=results_mitchell.accepted_quantile.timepoint,
    discrete=True,
    ax=ax,
    label="accepted",
)
ax.legend()
plt.show()