In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import colors
from datetime import datetime
from pathlib import Path
from random import choices

from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy.figures import sfs as sfs_fig
from hscpy import sfs, mitchell, abc, parse_path2folder_xdoty_years, variant

# folder containing the compiled `hsc` binary (used by the `--version` cell below)
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
assert PATH2BIN.is_dir(), f"cannot find the hsc release folder {PATH2BIN}"

In [None]:
%%bash -s "$PATH2BIN" --out version
# Print the hsc version; stdout is stored in the Python variable `version`.
# Quote the path so a PATH2BIN containing spaces is not split into words.
"$1/hsc" --version

In [None]:
# number of cells per simulated sample; selects the `{SAMPLE}cells` folders
SAMPLE = 368
# read simulations from /data/scratch/ (True) or /data/home/ (False)
USE_SCRATCH = True
# forwarded to `PlotOptions(save=...)`, which gates writing figures to disk
SAVEFIG = True
# plot the prior distributions and the Wasserstein histograms
SHOW_PRIORS = False

In [None]:
LATEST = True
# either use the version reported by the local `hsc` binary or pin one
VERSION = parse_version(version) if LATEST else "v1.3.0"
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

_sims_root = Path("/data/scratch/") if USE_SCRATCH else Path("/data/home/")
PATH2SIMS = _sims_root / f"hfx923/hsc-draft/{VERSION}"

# machine-specific location of the Mitchell data; fall back to the home folder
_mitchell_dir_by_host = {
    "5X9ZYD3": Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc"),
    "LAPTOP-CEKCHJ4C": Path("/mnt/c/Users/fra_t/Documents/PhD/hsc"),
}
PATH2MITCHELL = _mitchell_dir_by_host.get(socket.gethostname())
if PATH2MITCHELL is None:
    PATH2MITCHELL = Path("~").expanduser()

In [None]:
options = PlotOptions(figsize=(7, 6), extension="pdf", save=SAVEFIG)
# exclude donors for different reasons:
# 1. exclude KX007 because the same donor was uploaded twice
# 2. exclude CB001 because it maps to the same timepoint as CB002 (both aged 0)
summary = mitchell.load_and_process_mitchell(
    PATH2MITCHELL / "Summary_cut.csv", drop_donor_KX007=True
)
print(summary.shape)
# reassign rather than drop in place (pandas idiom, no hidden mutation)
summary = summary.drop(index=summary.index[summary.donor_id == "CB001"])
print(summary.shape)
ages = summary.age.unique().tolist()

## Remove runs
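Remove all runs that did not finish. A run counts as unfinished when its SFS file exists at only one of the two checked timepoints (0 and 81 years). This step is required because the runs are loaded into a NumPy array of fixed size.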

In [None]:
# A run is incomplete when its SFS file stem is present in only one of the
# two checked timepoint folders.
path2sfs = PATH2SIMS / f"{SAMPLE}cells/sfs/"
timepoint1, timepoint2 = (
    {sfs_file.stem for sfs_file in (path2sfs / folder).iterdir()}
    for folder in ("0dot0years", "81dot0years")
)
files2remove = timepoint1 ^ timepoint2
print(f"{len(files2remove)} files to remove")
runs2remove = " ".join(files2remove)

In [None]:
%%bash -s "$runs2remove" "$PATH2SIMS"
# $1: space-separated run stems to delete (word splitting is intended)
# $2: root folder of the simulations, searched recursively
# Every file under $2 whose name contains a stem is deleted. The pattern is
# quoted so the shell cannot glob-expand it against the current directory
# before `find` receives it.
echo "removing files"
for file in $1
do
    find "$2" -name "*${file}*" -exec rm {} \;
done

In [None]:
# Re-check after the deletion: both timepoint folders must now list the same
# run stems, otherwise loading into a fixed-size array would fail later.
path2sfs = Path(f"{PATH2SIMS}/{SAMPLE}cells/sfs/")
timepoint1 = {file.stem for file in (path2sfs / "0dot0years").iterdir()}
timepoint2 = {file.stem for file in (path2sfs / "81dot0years").iterdir()}
files2remove = timepoint1.symmetric_difference(timepoint2)
print(f"{len(files2remove)} files to remove")
assert not files2remove, f"{len(files2remove)} incomplete runs are still present"
runs2remove = " ".join(files2remove)

## Load simulated data: both SFS and variant fractions

In [None]:
%%time
# load the sfs from sims by age, considering the age of the donors
# in the Mitchell data `summary`
path2sfs = Path(PATH2SIMS / f"{SAMPLE}cells/sfs/")
ages_mitchell = sorted(summary.age.unique())
ages_sims = sorted([parse_path2folder_xdoty_years(path) for path in path2sfs.iterdir()])
assert ages_sims == ages_mitchell
# load data
sfs_sims = sfs.load_all_sfs_by_age(path2sfs)

In [None]:
%%time
counts = variant.load_all_detected_var_counts_by_age(
    PATH2SIMS / f"{SAMPLE}cells/variant_fraction", 0.01
)
counts = variant.variant_counts_detected_df(counts)

In [None]:
fig, ax = plt.subplots(1, 1)
# min-max band across simulated runs at each age
sns.lineplot(
    counts,
    x="age",
    y="variant counts detected",
    errorbar=lambda x: (np.min(x), np.max(x)),
    ax=ax,
    label="min-max",
)
# same estimate with a standard-deviation band; labelled so the two overlapping
# bands can be told apart
sns.lineplot(
    counts,
    x="age",
    y="variant counts detected",
    errorbar="sd",
    ax=ax,
    color="orange",
    label="sd",
)
ax.set_title("Detected variant counts in the simulations")
ax.legend()
plt.show()

In [None]:
# summary statistics of the detected variant counts for each age
counts.loc[:, ["variant counts detected", "age"]].groupby(by="age").describe()

## Run ABC on the real data

### Load the data
We have excluded two donors from the ABC:
1. exclude KX007 because the same donor was uploaded twice
2. exclude CB001 because it maps to the same timepoint as CB002 (both aged 0)

In [None]:
%%time
names_mitchell = [
    summary.loc[summary.age == age, ["donor_id", "age"]]
    .drop_duplicates()
    .donor_id.squeeze()
    for age in ages_mitchell
]
target_sfs = {
    age: mitchell.sfs_donor_mitchell(donor, PATH2MITCHELL, remove_indels=False)
    for age, donor in zip(ages_mitchell, names_mitchell)
}

### Compute the summary statistics (Wasserstein metric) and add the number of clones

In [None]:
%%time
abc_mitchell = abc.sfs_summary_statistic_wasserstein(sfs_sims, target_sfs, "mitchell")
abc_mitchell

# add information about clones from Mitchell's fig 5a
abc_mitchell = abc_mitchell.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
)
assert (
    not abc_mitchell.isna().any().any()
), "cannot match the nb of clones data to the abc results"
abc_mitchell = pd.DataFrame.from_records(
    [
        {"age": 0.0, "clones": 0},
        {"age": 29.0, "clones": 0},
        {"age": 38.0, "clones": 1},
        {"age": 48.0, "clones": 0},
        {"age": 63.0, "clones": 1},
        {"age": 76.0, "clones": 12},
        {"age": 77.0, "clones": 15},
        {"age": 81.0, "clones": 13},
    ]
).merge(right=abc_mitchell, how="right", on="age", validate="one_to_many")
abc_mitchell["clones diff"] = (
    abc_mitchell["clones"] - abc_mitchell["variant counts detected"]
).abs()

### Show priors

In [None]:
if SHOW_PRIORS:
    # unique combinations of the four simulation parameters
    priors = abc_mitchell[["mu", "u", "s", "std"]].drop_duplicates()
    # per-parameter options forwarded to `plot_prior`, in plotting order
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": True},
        "u": {},
    }
    for param, plot_kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **plot_kwargs)
        plt.show()

    # distribution of the summary statistic over all runs and timepoints
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_mitchell["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

### Run ABC considering all timepoints at the same time

In [None]:
assert abc_mitchell.timepoint.unique().shape[0] == len(ages_mitchell)
quantile = 0.9
prop2discard = 0.1  # the higher, the less precise and thus the more runs
# presumably the number of timepoints a run must be accepted at — confirm in
# `abc_fig.run_abc_filtering_on_clones`
nb_ages = len(ages_mitchell)
minimum_runs = nb_ages - round(nb_ages * prop2discard)
print(f"{minimum_runs} vs {len(ages_mitchell)}")
results_mitchell, g1, g2, g3 = abc_fig.run_abc_filtering_on_clones(
    abc_mitchell, quantile, nb_clones_diff=4, minimum_runs=minimum_runs
)
if options.save:
    for grid, stem in zip(
        (g1, g2, g3), ("posterior_mu_s", "posterior_mu_std", "posterior_s_std")
    ):
        grid.savefig(f"{stem}.{options.extension}")

In [None]:
# For each timepoint, compare the Mitchell donor SFS with the accepted run
# selected by `get_idx_smaller_distance_clones_from_results`, plotted with
# `sfs_fig.plot_sfs_cdf`.
verbose = False

for t in sorted(abc_mitchell.timepoint.unique()):
    abc_at_t = abc_mitchell[abc_mitchell.timepoint == t]
    if verbose:
        threshold = abc_at_t["wasserstein"].quantile(quantile)
        title = f"age: {t} years, quantile threshold: {threshold:.2f}"
    else:
        title = f"age: {round(t)} years"
    idx2show = [
        abc_fig.get_idx_smaller_distance_clones_from_results(
            abc_at_t, results_mitchell
        )
    ]

    fig = sfs_fig.plot_sfs_cdf(
        idx2show, target_sfs[t], sfs_sims[t], verbose=verbose, alpha=0.7
    )
    fig.suptitle(title, x=0.4)
    if options.save:
        # honour the configured extension instead of hard-coding "pdf"
        fig.savefig(f"sfs_{t}years.{options.extension}")
    plt.show()

In [None]:
# Timepoints of the accepted runs: `accepted` ("accepted filtered") vs
# `accepted_quantile` ("accepted").
fig, ax = plt.subplots(1, 1)
# NOTE(review): the "accepted filtered" timepoints are shifted by +2 years,
# presumably to offset the two discrete histograms visually — confirm.
for timepoints, label in (
    (results_mitchell.accepted.timepoint + 2, "accepted filtered"),
    (results_mitchell.accepted_quantile.timepoint, "accepted"),
):
    sns.histplot(data=timepoints, discrete=True, stat="density", ax=ax, label=label)
ax.legend()
plt.show()

# Wasserstein distance over age for the same two sets of accepted runs
fig, ax = plt.subplots(1, 1)
for accepted_idx, label in (
    (results_mitchell.get_idx(), "accepted filtered"),
    (results_mitchell.accepted_quantile.idx.unique(), "accepted"),
):
    wasserstein_by_age = abc_mitchell.loc[
        abc_mitchell.idx.isin(accepted_idx), ["timepoint", "wasserstein"]
    ].rename(columns={"timepoint": "age"})
    sns.lineplot(
        x="age",
        y="wasserstein",
        data=wasserstein_by_age,
        errorbar="sd",
        label=label,
        ax=ax,
    )
ax.set_ylim([abc_mitchell.wasserstein.min(), abc_mitchell.wasserstein.max()])
plt.show()

## Run ABC on subsampled simulated data

### High fitness

In [None]:
%%time
target_stem = "7dot2964754mu0_0dot000018241096768178977u_0dot39135662mean_0dot045755856std_1b0_200000cells_14230idx"
target_sfs_simulated = {
    t: sfs_.sfs
    for t, sfs_donor in sfs_sims.items()
    for sfs_ in sfs_donor
    if sfs_.parameters.path.stem == target_stem
}
abc_simulated = abc.sfs_summary_statistic_wasserstein(
    sfs_sims, target_sfs_simulated, target_stem
)

abc_simulated["target"] = abc_simulated.path.map(lambda x: Path(x).stem) == target_stem

abc_simulated = abc_simulated.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
)

abc_simulated = abc_simulated.merge(
    right=abc_simulated.loc[
        abc_simulated.target, ["variant counts detected", "timepoint"]
    ].rename({"variant counts detected": "clones"}, axis=1),
    how="left",
    on="timepoint",
    validate="many_to_one",
)

abc_simulated["clones diff"] = (
    abc_simulated["clones"] - abc_simulated["variant counts detected"]
).abs()

abc_simulated

In [None]:
if SHOW_PRIORS:
    # unique combinations of the four simulation parameters
    priors = abc_simulated[["mu", "u", "s", "std"]].drop_duplicates()
    # per-parameter options forwarded to `plot_prior`, in plotting order
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": True},
        "u": {},
    }
    for param, plot_kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **plot_kwargs)
        plt.show()

    # distribution of the summary statistic over all runs and timepoints
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_simulated["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
quantile = 0.5
prop2discard = 0.1  # the higher, the less precise and thus the more runs
minimum_runs = len(ages_mitchell) - round(len(ages_mitchell) * prop2discard)
print(f"{minimum_runs} vs {len(ages_mitchell)}")
results, g1, g2, g3 = abc_fig.run_abc_filtering_on_clones(
    abc_simulated, quantile, nb_clones_diff=2, minimum_runs=minimum_runs
)
# The target run has one row per timepoint, so `.squeeze()` on its parameter
# columns yields a multi-row Series rather than a scalar. Extract the single
# parameter set explicitly.
target_params = abc_simulated.loc[
    abc_simulated.target, ["mu", "s", "std"]
].drop_duplicates()
assert len(target_params) == 1, "expected a single parameter set for the target run"
mu_target, s_target, std_target = target_params.iloc[0]
# mark the target's true parameters on each posterior plot
g1.ax_joint.plot(mu_target, s_target, marker="x", color="black", mew=2)
g2.ax_joint.plot(mu_target, std_target, marker="x", color="black", mew=2)
g3.ax_joint.plot(s_target, std_target, marker="x", color="black", mew=2)

plt.show()

### Low fitness

In [None]:
%%time
target_stem = "3dot2131548mu0_0dot000008032846380956471u_0dot097829mean_0dot057576984std_1b0_200000cells_66310idx"
target_sfs_simulated = {
    t: sfs_.sfs
    for t, sfs_donor in sfs_sims.items()
    for sfs_ in sfs_donor
    if sfs_.parameters.path.stem == target_stem
}
abc_simulated = abc.sfs_summary_statistic_wasserstein(
    sfs_sims, target_sfs_simulated, target_stem
)

abc_simulated["target"] = abc_simulated.path.map(lambda x: Path(x).stem) == target_stem

abc_simulated = abc_simulated.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
)

abc_simulated = abc_simulated.merge(
    right=abc_simulated.loc[
        abc_simulated.target, ["variant counts detected", "timepoint"]
    ].rename({"variant counts detected": "clones"}, axis=1),
    how="left",
    on="timepoint",
    validate="many_to_one",
)

abc_simulated["clones diff"] = (
    abc_simulated["clones"] - abc_simulated["variant counts detected"]
).abs()

abc_simulated

In [None]:
if SHOW_PRIORS:
    # unique combinations of the four simulation parameters
    priors = abc_simulated[["mu", "u", "s", "std"]].drop_duplicates()
    # per-parameter options forwarded to `plot_prior`, in plotting order
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": True},
        "u": {},
    }
    for param, plot_kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **plot_kwargs)
        plt.show()

    # distribution of the summary statistic over all runs and timepoints
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_simulated["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
quantile = 0.9
prop2discard = 0.1  # the higher, the less precise and thus the more runs
minimum_runs = len(ages_mitchell) - round(len(ages_mitchell) * prop2discard)
print(f"{minimum_runs} vs {len(ages_mitchell)}")
results, g1, g2, g3 = abc_fig.run_abc_filtering_on_clones(
    abc_simulated, quantile, nb_clones_diff=0, minimum_runs=minimum_runs
)
# The target run has one row per timepoint, so `.squeeze()` on its parameter
# columns yields a multi-row Series rather than a scalar. Extract the single
# parameter set explicitly.
target_params = abc_simulated.loc[
    abc_simulated.target, ["mu", "s", "std"]
].drop_duplicates()
assert len(target_params) == 1, "expected a single parameter set for the target run"
mu_target, s_target, std_target = target_params.iloc[0]
# mark the target's true parameters on each posterior plot
g1.ax_joint.plot(mu_target, s_target, marker="x", color="black", mew=2)
g2.ax_joint.plot(mu_target, std_target, marker="x", color="black", mew=2)
g3.ax_joint.plot(s_target, std_target, marker="x", color="black", mew=2)

plt.show()