In [None]:
# pick up edits to project modules (e.g. hscpy, futils) without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import random
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import colors
from datetime import datetime
from pathlib import Path
from random import choices
from typing import Union, List
from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy.figures import sfs as sfs_fig
from hscpy import realisation, mitchell, abc, parse_path2folder_xdoty_years, variant

# folder holding the compiled `hsc` executable (release build)
PATH2BIN = Path.home() / "hsc" / "target" / "release"
assert PATH2BIN.is_dir()

In [None]:
%%bash -s "$PATH2BIN" --out version
# $1 is PATH2BIN; stdout is captured into the Python variable `version`.
# Quoted so that a path containing spaces still resolves to the binary.
"$1/hsc" --version

In [None]:
# --- run configuration ---------------------------------------------------
USE_SCRATCH = True  # read simulations from /data/scratch/ instead of /data/home/
SAMPLE = 368  # selects the `{SAMPLE}cells` folder of the simulations
SAVEFIG = True
SHOW_PRIORS = True
LATEST = True  # use the version reported by the hsc binary instead of a pinned one

# ABC thresholds: the higher, the less precise and thus the more runs
# NOTE(review): QUANTILE is not passed to the ABC on Mitchell's data, which
# hard-codes quantile=0.9 — confirm which value is intended.
QUANTILE = 0.85
PROPORTION_RUNS_DISCARDED = 0.2
NB_CLONES_DIFF = 3

options = PlotOptions(figsize=(7, 6), extension="svg", save=SAVEFIG)

In [None]:
VERSION = parse_version(version) if LATEST else "v1.3.0"
# NOTE(review): PATH2SAVE is not used by the figure-saving cells below,
# which write into the current working directory.
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

sims_root = Path("/data/scratch/" if USE_SCRATCH else "/data/home/")
PATH2SIMS = sims_root / f"hfx923/hsc-draft/{VERSION}"

# location of Mitchell's data depends on the machine running the notebook;
# unknown hosts fall back to the home folder
_PATH2MITCHELL_BY_HOST = {
    "5X9ZYD3": Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc"),
    "LAPTOP-CEKCHJ4C": Path("/mnt/c/Users/fra_t/Documents/PhD/hsc"),
}
PATH2MITCHELL = _PATH2MITCHELL_BY_HOST.get(
    socket.gethostname(), Path("~").expanduser()
)

In [None]:
# Exclude donors for different reasons:
# 1. exclude KX007 because the same donor was uploaded twice
# 2. exclude CB001 because it maps to the same timepoint as CB002 (both aged 0)
summary = mitchell.load_and_process_mitchell(
    PATH2MITCHELL / "Summary_cut.csv", drop_donor_KX007=True
)
print(summary.shape)
# Boolean mask instead of `drop(index=..., inplace=True)`: dropping by label
# would also remove rows of other donors if the index had duplicate labels.
summary = summary[summary.donor_id != "CB001"]
print(summary.shape)
ages = summary.age.unique().tolist()

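## Remove runs
Remove all runs that did not finish, i.e. runs that have an SFS file at one of the two checked timepoints (0 and 81 years) but not at the other. This is required because the runs are loaded into a numpy array of fixed size.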

In [None]:
# Runs are identified by their file stem; compare the runs present at the
# first (0 years) and at the last (81 years) timepoint.
path2sfs = Path(f"{PATH2SIMS}/{SAMPLE}cells/sfs/")
stems_by_timepoint = {
    folder: {file.stem for file in (path2sfs / folder).iterdir()}
    for folder in ("0dot0years", "81dot0years")
}
timepoint1 = stems_by_timepoint["0dot0years"]
timepoint2 = stems_by_timepoint["81dot0years"]
# runs present at only one of the two timepoints
files2remove = timepoint1 ^ timepoint2
print(f"{len(files2remove)} files to remove")
runs2remove = " ".join(files2remove)

In [None]:
%%bash -s "$runs2remove" "$PATH2SIMS"
# $1: space-separated stems of the unfinished runs (left unquoted below on
#     purpose, so that it is word-split into one stem per iteration)
# $2: root folder of the simulations
echo "removing files"
for file in $1
do
    # Quote the pattern: unquoted, the shell would glob-expand `*$file*`
    # against the current directory before `find` ever sees it.
    # NOTE(review): this is a substring match, so a stem contained in another
    # run's file name would delete that file too — confirm stems cannot overlap.
    find "$2" -name "*${file}*" -exec rm {} \;
done

In [None]:
# Sanity check after the removal: both timepoints must now contain the same
# runs, otherwise the fixed-size loading below would fail.
path2sfs = Path(f"{PATH2SIMS}/{SAMPLE}cells/sfs/")
timepoint1 = {file.stem for file in (path2sfs / "0dot0years").iterdir()}
timepoint2 = {file.stem for file in (path2sfs / "81dot0years").iterdir()}
files2remove = timepoint1.symmetric_difference(timepoint2)
print(f"{len(files2remove)} files to remove")
runs2remove = " ".join(files2remove)
assert not files2remove, (
    f"unfinished runs still present after removal: {sorted(files2remove)[:10]}"
)

## Load simulated data: both SFS and variant fractions

In [None]:
%%time
# load the sfs from sims by age, considering the age of the donors
# in the Mitchell data `summary`
path2sfs = Path(PATH2SIMS / f"{SAMPLE}cells/sfs/")
ages_mitchell = sorted(summary.age.unique())
ages_sims = sorted([parse_path2folder_xdoty_years(path) for path in path2sfs.iterdir()])
assert ages_sims == ages_mitchell
# load data
sfs_sims = realisation.load_all_sfs_by_age(path2sfs)

In [None]:
%%time
counts = variant.load_all_detected_var_counts_by_age(
    PATH2SIMS / f"{SAMPLE}cells/variant_fraction", 0.01
)
counts = variant.variant_counts_detected_df(counts)

In [None]:
# Number of detected variants per age over all simulated runs, drawn twice on
# the same axes: once with a min-max band, once with a standard-deviation band.
fig, ax = plt.subplots(1, 1)
for errorbar, line_style in (
    (lambda x: (np.min(x), np.max(x)), {"label": "min-max"}),
    ("sd", {"color": "orange", "label": "std"}),
):
    sns.lineplot(
        counts,
        x="age",
        y="variant counts detected",
        errorbar=errorbar,
        ax=ax,
        **line_style,
    )
ax.legend()
plt.show()

In [None]:
# summary statistics of the number of detected variants, per age
counts.groupby("age")[["variant counts detected"]].describe()

## Run ABC on the real data

### Load the data
We have excluded two donors from the ABC:
1. KX007, because the same donor was uploaded twice;
2. CB001, because it maps to the same timepoint as CB002 (both aged 0).

In [None]:
%%time
names_mitchell = [
    summary.loc[summary.age == age, ["donor_id", "age"]]
    .drop_duplicates()
    .donor_id.squeeze()
    for age in ages_mitchell
]
target_sfs = {
    age: mitchell.sfs_donor_mitchell(donor, PATH2MITCHELL, remove_indels=False)
    for age, donor in zip(ages_mitchell, names_mitchell)
}

### Compute the summary statistics (Wasserstein metric) and add the number of clones

In [None]:
%%time
abc_mitchell = abc.sfs_summary_statistic_wasserstein(sfs_sims, target_sfs, "mitchell")
abc_mitchell

# add information about clones from Mitchell's fig 5a
abc_mitchell = abc_mitchell.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
)
assert (
    not abc_mitchell.isna().any().any()
), "cannot match the nb of clones data to the abc results"
abc_mitchell = pd.DataFrame.from_records(
    [
        {"age": 0.0, "clones": 0},
        {"age": 29.0, "clones": 0},
        {"age": 38.0, "clones": 1},
        {"age": 48.0, "clones": 0},
        {"age": 63.0, "clones": 1},
        {"age": 76.0, "clones": 12},
        {"age": 77.0, "clones": 15},
        {"age": 81.0, "clones": 13},
    ]
).merge(right=abc_mitchell, how="right", on="age", validate="one_to_many")
abc_mitchell["clones diff"] = (
    abc_mitchell["clones"] - abc_mitchell["variant counts detected"]
).abs()

### Show priors

In [None]:
if SHOW_PRIORS:
    # one row per simulated parameter set
    priors = abc_mitchell[["mu", "u", "s", "std"]].drop_duplicates()

    # one histogram per parameter, with plotting options tuned to its range
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": False},
        "u": {},
    }
    for param, plot_kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **plot_kwargs)
        plt.show()

    # distribution of the summary statistic over all runs and ages
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_mitchell["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

### Run ABC considering all timepoints at the same time

In [None]:
def find_map_by_cut(
    view: pd.Series, bins: List[Union[float, int]]
) -> Union[float, None]:
    """Estimate the maximum a posteriori (MAP) of a sample by binning it.

    The values in ``view`` are binned with :func:`pandas.cut` and the midpoint
    of the most populated bin is returned.

    Parameters
    ----------
    view:
        Sample of one parameter (e.g. its values over the accepted ABC runs).
    bins:
        Bin edges passed to :func:`pandas.cut`. Bins are right-closed (pandas
        default), so a value equal to the first edge falls in no bin and is
        ignored, as are values outside the edges.

    Returns
    -------
    The midpoint of the modal bin, or ``None`` when the mode is not unique or
    when no value falls inside ``bins`` (e.g. ``view`` is empty).
    """
    modes = pd.cut(view, bins=bins).mode()
    if modes.empty:
        # without this guard `modes.iloc[0]` raises IndexError
        print("cannot compute MAP because no value falls inside the bins")
        return None
    if modes.shape[0] > 1:
        print(
            f"cannot compute MAP because more than one mode have been found {modes}"
        )
        return None
    return modes.iloc[0].mid

In [None]:
# ABC over all timepoints at once, with the thresholds of `abc.AbcThresholds`.
# NOTE(review): quantile is hard-coded to 0.9 here while the config cell defines
# QUANTILE = 0.85, which this cell does not use — confirm which is intended.
thresholds_mitchell = abc.AbcThresholds(
    quantile=0.9,
    proportion_runs_to_discard=PROPORTION_RUNS_DISCARDED,
    nb_clones_diff=NB_CLONES_DIFF,
)
results_mitchell, g1, g2, g3 = abc_fig.run_abc_filtering_on_clones(
    abc_mitchell, thresholds_mitchell
)

# rows of the accepted runs; the MAP of each parameter is estimated by binning
view = abc_mitchell[abc_mitchell.idx.isin(results_mitchell.get_idx())]
mu_bins = range(0, 21)
s_bins = np.arange(0, 0.45, 0.01).tolist()
std_bins = np.arange(0, 0.15, 0.002).tolist()
map_mu = find_map_by_cut(view.mu, mu_bins)
map_s = find_map_by_cut(view.s, s_bins)
map_std = find_map_by_cut(view["std"], std_bins)

# Draw the MAP estimates on the joint posteriors: g1 is (mu, s), g2 is
# (mu, std) and g3 is (s, std), as implied by the axes used below.
# `find_map_by_cut` returns None when the MAP is undefined; compare against
# None explicitly so that a legitimate MAP of 0.0 is still drawn.
if map_mu is not None:
    g1.ax_joint.vlines(
        map_mu,
        view.s.min(),
        view.s.max(),
        color="yellowgreen",
        label=r"MAP $\mu=$" + str(map_mu),
    )
    g2.ax_joint.vlines(
        map_mu,
        view["std"].min(),
        view["std"].max(),
        color="yellowgreen",
        label=r"MAP $\mu=$" + str(map_mu),
    )
if map_s is not None:
    g1.ax_joint.hlines(
        map_s,
        view.mu.min(),
        view.mu.max(),
        color="yellowgreen",
        linestyle="--",
        label=r"MAP $s=$" + f" {map_s:.2f}",
    )
    g3.ax_joint.vlines(
        map_s,
        view["std"].min(),
        view["std"].max(),
        color="yellowgreen",
        label=r"MAP $s=$" + f" {map_s:.2f}",
    )

if map_std is not None:
    g2.ax_joint.hlines(
        map_std,
        view.mu.min(),
        view.mu.max(),
        color="yellowgreen",
        linestyle="--",
        label=r"MAP $\sigma=$" + f"{map_std:.2f}",
    )
    g3.ax_joint.hlines(
        map_std,
        view.s.min(),
        view.s.max(),
        color="yellowgreen",
        linestyle="--",
        label=r"MAP $\sigma=$" + f"{map_std:.2f}",
    )

# Pearson correlation between the accepted mu and s, drawn as a regression
# line (without scatter points) on the (mu, s) joint posterior.
result = stats.pearsonr(view.mu, view.s)
correlation_label = f"r={result.statistic:.2f}, p={result.pvalue:.2e}"
regression_line_style = {"color": "purple", "linewidth": 1}

sns.regplot(
    data=view[["mu", "s"]],
    x="mu",
    y="s",
    label=correlation_label,
    scatter=False,
    line_kws=regression_line_style,
    ax=g1.ax_joint,
)
g1.ax_joint.set_xlabel(r"$\mu$")

# Shared tick styling for the three posterior joint plots: ticks only on the
# bottom and left sides. Major ticks of the joint axes are longer (5) than all
# other ticks (3).
tick_style = dict(
    bottom=True,
    top=False,
    left=True,
    right=False,
    width=1.1,
    labelsize=14,
)
for g_ in (g1, g2, g3):
    g_.ax_joint.tick_params(which="major", length=5, **tick_style)
    g_.ax_joint.tick_params(which="minor", length=3, **tick_style)
    for marginal_ax in (g_.ax_marg_x, g_.ax_marg_y):
        marginal_ax.tick_params(which="both", length=3, **tick_style)
    g_.ax_joint.legend(fontsize=14)

# save the three joint posteriors in the current working directory
if options.save:
    for grid, stem in (
        (g1, "posterior_mu_s"),
        (g2, "posterior_mu_std"),
        (g3, "posterior_s_std"),
    ):
        grid.savefig(f"{stem}.{options.extension}")

verbose = False
# Quantile used only in the verbose title. It must match the `quantile` passed
# to `abc.AbcThresholds` for `results_mitchell`. The original referenced an
# undefined name `quantile` here, which raised NameError when verbose=True.
quantile_threshold = 0.9
idx2show = dict()

for t in sorted(abc_mitchell.timepoint.unique()):
    if verbose:
        distance_threshold = abc_mitchell.loc[
            abc_mitchell.timepoint == t, "wasserstein"
        ].quantile(quantile_threshold)
        title = f"age: {t} years, quantile threshold: {distance_threshold:.2f}"
    else:
        title = f"age: {round(t)} years"

    # one accepted run per age to display (presumably the one with the
    # smallest distance, as the helper's name suggests — see abc_fig)
    idx2show[t] = abc_fig.get_idx_smaller_distance_clones_from_results(
        abc_mitchell[abc_mitchell.timepoint == t], results_mitchell
    )

    fig = sfs_fig.plot_sfs_cdf(
        [idx2show[t]], target_sfs[t], sfs_sims[t], t, verbose=verbose, alpha=0.7
    )
    # `title` is currently not drawn; to add it: fig.suptitle(title, x=0.4)
    if options.save:
        fig.savefig(f"sfs_{t}years.{options.extension}")
    plt.show()

In [None]:
# Number of detected variants per age for the runs shown in the SFS plots
# above (`idx2show`), compared with the number of clones reported by Mitchell.
fig, ax = plt.subplots(1, 1)
view = counts[counts.idx.isin(list(idx2show.values()))]
sns.lineplot(
    view,
    x="age",
    y="variant counts detected",
    errorbar=lambda x: (np.min(x), np.max(x)),
    ax=ax,
    label="min-max",
)
sns.lineplot(
    view,
    x="age",
    y="variant counts detected",
    errorbar="sd",
    ax=ax,
    color="orange",
    label="std",
)
# pass `ax` explicitly instead of relying on pyplot's implicit current axes
sns.scatterplot(
    data=abc_mitchell[["age", "clones"]].drop_duplicates(),
    x="age",
    y="clones",
    marker="x",
    linewidths=2,
    color="purple",
    label="Mitchell",
    ax=ax,
)
ax.legend()
plt.show()

In [None]:
# (mu, s) of the run selected at each timepoint (see `idx2show`)
ts = list(idx2show.keys())
selected_runs = [
    abc_mitchell[(abc_mitchell.timepoint == t) & (abc_mitchell.idx == idx)]
    for t, idx in idx2show.items()
]
mus = [run.mu.iloc[0] for run in selected_runs]
ss = [run.s.iloc[0] for run in selected_runs]

# distribution of the selected parameters across timepoints
for values, binwidth in ((mus, 0.5), (ss, 0.005)):
    sns.histplot(values, binwidth=binwidth)
    plt.show()

# selected parameters as a function of age
for values in (mus, ss):
    plt.plot(ts, values, marker="x", mew=2)
    plt.show()

In [None]:
# Timepoints of the accepted runs: "accepted" comes from `accepted_quantile`,
# "accepted filtered" from `accepted` (presumably after the clones filter).
# NOTE(review): the filtered timepoints are shifted by +2, presumably so that
# their bars do not overlap the others; their x positions are not true ages.
fig, ax = plt.subplots(1, 1)
for timepoints, label in (
    (results_mitchell.accepted.timepoint + 2, "accepted filtered"),
    (results_mitchell.accepted_quantile.timepoint, "accepted"),
):
    sns.histplot(data=timepoints, discrete=True, stat="density", ax=ax, label=label)
ax.legend()
plt.show()

# Wasserstein distance per age of the accepted runs, with a standard-deviation
# band; y-limits span the distances of all runs.
fig, ax = plt.subplots(1, 1)
for accepted_idx, label in (
    (results_mitchell.get_idx(), "accepted filtered"),
    (results_mitchell.accepted_quantile.idx.unique(), "accepted"),
):
    distances = abc_mitchell.loc[
        abc_mitchell.idx.isin(accepted_idx),
        ["timepoint", "wasserstein"],
    ].rename({"timepoint": "age"}, axis=1)
    sns.lineplot(
        x="age",
        y="wasserstein",
        data=distances,
        errorbar="sd",
        label=label,
        ax=ax,
    )
ax.set_ylim([abc_mitchell.wasserstein.min(), abc_mitchell.wasserstein.max()])
plt.show()

## Run ABC on subsampled simulated data

### Low mu $\mu = 1.097$

In [None]:
# runs with mu close to 1.1, one row per (mu, s) pair, ordered by increasing s;
# pick the second-lowest, the middle and the second-highest s as targets
low_mu = abc_mitchell[(abc_mitchell.mu - 1.1).abs() < 0.01]
low_mu = low_mu.drop_duplicates(subset={"mu", "s"}).sort_values(by=["s", "mu"])
idx_low_low = low_mu["idx"].iloc[1]
idx_low_high = low_mu["idx"].iloc[-2]
idx_low_medium = low_mu["idx"].iloc[len(low_mu) // 2]

#### Low mu and low s, $\mu=1.097$ and $s=0.017$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_low_low]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

#### Low mu and medium s, $\mu=1.096$ and $s=0.189$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_low_medium]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

#### Low mu and high s, $\mu=1.102$ and $s=0.367$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_low_high]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

### Medium mu $\mu = 10$

In [None]:
# runs with mu close to 10, one row per (mu, s) pair, ordered by increasing s;
# pick the second-lowest, the middle and the second-highest s as targets
medium_mu = abc_mitchell[(abc_mitchell.mu - 10).abs() < 0.01]
medium_mu = medium_mu.drop_duplicates(subset={"mu", "s"}).sort_values(by=["s", "mu"])
idx_medium_low = medium_mu["idx"].iloc[1]
idx_medium_high = medium_mu["idx"].iloc[-2]
idx_medium_medium = medium_mu["idx"].iloc[len(medium_mu) // 2]

#### Medium mu and low s, $\mu=9.991$ and $s=0.027$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_medium_low]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

#### Medium mu and medium s, $\mu=10.004$ and $s=0.193$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_medium_medium]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

### High mu $\mu=19.091$

In [None]:
# runs with mu close to 19.1, one row per (mu, s) pair, ordered by increasing s;
# pick the second-lowest, the middle and the second-highest s as targets
high_mu = abc_mitchell[(abc_mitchell.mu - 19.1).abs() < 0.01]
high_mu = high_mu.drop_duplicates(subset={"mu", "s"}).sort_values(by=["s", "mu"])
idx_high_low = high_mu["idx"].iloc[1]
idx_high_high = high_mu["idx"].iloc[-2]
idx_high_medium = high_mu["idx"].iloc[len(high_mu) // 2]

#### High mu and low s, $\mu=19.091$ and $s=0.022$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_high_low]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

#### High mu and medium s, $\mu=19.105$ and $s=0.234$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_high_medium]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()

#### High mu and high s, $\mu=19.095$ and $s=0.390$

In [None]:
# rows of the selected run (one per timepoint); its file stem is the ABC target
view = abc_mitchell.loc[abc_mitchell.idx == idx_high_high]
target_stem = view["path"].iloc[0].stem
view

In [None]:
%%time
_, g1, g2, g3 = abc_fig.abc_simulated_validation(
    target_stem,
    sfs_sims,
    counts,
    abc.AbcThresholds(quantile=0.4, proportion_runs_to_discard=0, nb_clones_diff=2),
    show_priors=False,
)
plt.show()