In [None]:
# automatically re-import edited modules (e.g. hscpy, futils) before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union, List
from futils import parse_version, snapshot
from hscpy.figures import PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy.figures import sfs as sfs_fig
from hscpy import realisation, mitchell, abc, variant

# Directory holding the compiled `hsc` simulator binary used below.
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
assert PATH2BIN.is_dir(), f"cannot find the hsc binaries in {PATH2BIN}: build hsc in release mode first"

In [None]:
%%bash -s "$PATH2BIN" --out version
# print the version of the hsc binary found in $1 (= PATH2BIN); `--out version`
# stores this cell's stdout in the Python variable `version`
$1/hsc --version

In [None]:
# --- run configuration ------------------------------------------------------
USE_SCRATCH = True  # read the simulations from /data/scratch instead of /data/home
SAVEFIG = True  # save figures to disk (forwarded to `options.save`)
SHOW_PRIORS = True  # plot the prior distributions of the parameters
LATEST = True  # use the version reported by the hsc binary instead of a pinned one
# Set to True the first time this notebook is run: runs that did not reach the
# last timepoint are then deleted from disk.
FIRST_TIME = False

# ABC tolerances: the higher the value, the less precise the ABC and thus the
# more runs are accepted.
# NOTE(review): the original comment said these were chosen "to keep approx X
# runs" but the target number of runs was never filled in.
PROPORTION_RUNS_DISCARDED = 0.2  # fraction of timepoints at which a run may be rejected
QUANTILE_SFS = 0.3  # quantile passed to the ABC on the Wasserstein metric
QUANTILE_CLONES = 0.2  # quantile passed to the ABC on the relative clones difference

options = PlotOptions(figsize=(7, 6), extension="svg", save=SAVEFIG)

In [None]:
# Resolve the simulator version and the folders holding the simulations and
# Mitchell's data.
VERSION = parse_version(version) if LATEST else "v2.0.0"
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

_sims_root = Path("/data/scratch") if USE_SCRATCH else Path("/data/home")
PATH2SIMS = _sims_root / f"hfx923/hsc-draft/{VERSION}"

# Machine-specific location of Mitchell's data; any other host falls back to
# the home directory.
_mitchell_dir_by_host = {
    "5X9ZYD3": Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc"),
    "LAPTOP-CEKCHJ4C": Path("/mnt/c/Users/fra_t/Documents/PhD/hsc"),
}
PATH2MITCHELL = _mitchell_dir_by_host.get(socket.gethostname(), Path("~").expanduser())

In [None]:
# Donors from Mitchell's data used in the ABC. Notes:
# 1. we don't use the summary `Summary_cut.csv` because it contains errors
#    (i.e. duplicated and missing entries);
# 2. we don't use donor CB001 because we already have CB002 at age 0;
# 3. we don't use donor KX007 because its genotype matrix is wrong: the same
#    donor was uploaded twice.
# `cells` must match the number of cells in each donor's genotype matrix and
# `clones` is the number of clones reported for the donor.
donors = pd.DataFrame(
    {
        "name": ["CB002", "KX001", "KX002", "SX001", "AX001", "KX008", "KX004", "KX003"],
        "age": [0, 29, 38, 48, 63, 76, 77, 81],
        "cells": [390, 407, 380, 362, 361, 367, 451, 328],
        "clones": [0, 0, 1, 0, 1, 12, 15, 13],
    }
)
donors

In [None]:
# bins for abc
bins_s = np.arange(0, 0.44, 0.02)
bins_mu = np.arange(0, 20, 1)
bins_tau = np.arange(0, 5.2, 0.2)
bins_std = np.arange(0, 0.12, 0.01)

# Load simulated data: both SFS and variant fractions

## Remove runs
Remove all the runs that did not run up to the last timepoint. This is required because the runs are loaded into a numpy array of fixed size.

In [None]:
def files2remove(
    years: List[int], cells: List[int], root: Union[Path, str, None] = None
) -> List[str]:
    """Find the runs that are missing from at least one timepoint directory.

    A run is identified by the stem of its file in
    ``<root>/<cells>cells/sfs/<year>dot0years``. Runs whose stem does not
    appear in every one of the requested directories did not finish and
    should be removed.

    Args:
        years: ages (timepoints) to check, one directory per entry.
        cells: number of sampled cells for each entry of ``years``.
        root: folder holding the simulations; defaults to ``PATH2SIMS``.

    Returns:
        The sorted stems of the incomplete runs.

    Raises:
        AssertionError: if ``years`` and ``cells`` have different lengths.
        FileNotFoundError: if one of the timepoint directories does not exist.
    """
    from collections import Counter

    assert len(years) == len(cells), "years and cells must have the same length"
    root = Path(PATH2SIMS if root is None else root)

    paths = [
        root / f"{cell}cells" / "sfs" / f"{year}dot0years"
        for year, cell in zip(years, cells)
    ]

    # Number of timepoint directories each run appears in. A finished run must
    # appear in all of the directories checked here (not in a global table).
    occurrences = Counter(file.stem for path in paths for file in path.iterdir())
    incomplete = sorted(stem for stem, n in occurrences.items() if n < len(paths))
    print(f"{len(incomplete)} files to remove")
    return incomplete

In [None]:
%%time
if FIRST_TIME:
    # Delete, at every timepoint, the SFS and variant-fraction files of the
    # runs that are missing from at least one timepoint (see `files2remove`).
    runs2remove = files2remove(
        years=donors.age.tolist(),
        cells=donors.cells.tolist(),
    )
    print(f"found {len(runs2remove)} runs to remove")

    for r in donors[["age", "cells"]].itertuples():
        print(f"removing unfinished runs for age {r.age} with {r.cells} cells")
        # Build both folders explicitly. Replacing "sfs" in the full path string
        # would also rewrite any other "sfs" substring (e.g. in the run name),
        # and `with_suffix` would truncate run names that contain a dot.
        sfs_dir = PATH2SIMS / f"{r.cells}cells" / "sfs" / f"{r.age}dot0years"
        counts_dir = (
            PATH2SIMS / f"{r.cells}cells" / "variant_fraction" / f"{r.age}dot0years"
        )
        for run in runs2remove:
            for path in (sfs_dir / f"{run}.json", counts_dir / f"{run}.csv"):
                try:
                    path.unlink()
                except FileNotFoundError:
                    pass  # nothing to delete for this run at this timepoint

## Load data

In [None]:
%%time
# Load, for every donor's number of sampled cells, the simulated SFS and the
# simulated detected-variant counts; both dicts are filled via `update` with
# what the loaders return for that folder.
sfs_sims = {}
counts_sims = {}

for donor in donors[["name", "age", "cells"]].itertuples():
    sims_dir = PATH2SIMS / f"{donor.cells}cells"

    print(f"\tloading sims SFS for donor {donor.name} with {donor.cells} cells")
    sfs_sims.update(realisation.load_all_sfs_by_age(sims_dir / "sfs"))

    print(f"\tloading sims variant counts for donor {donor.name} with {donor.cells} cells")
    # NOTE(review): 0.01 is presumably the frequency threshold for a variant
    # to count as detected — confirm against hscpy.variant.
    counts_sims.update(
        variant.load_all_detected_var_counts_by_age(sims_dir / "variant_fraction", 0.01)
    )

In [None]:
counts = variant.variant_counts_detected_df(counts_sims)

# Detected variant counts of all simulated runs per age, drawn twice: once with
# a min-max band and once with a standard-deviation band.
band_styles = [
    {"errorbar": lambda x: (np.min(x), np.max(x)), "label": "min-max"},
    {"errorbar": "sd", "color": "orange", "label": "std"},
]
fig, ax = plt.subplots(1, 1)
for band in band_styles:
    sns.lineplot(counts, x="age", y="variant counts detected", ax=ax, **band)
ax.legend()
plt.show()
print(counts[["variant counts detected", "age"]].groupby("age").describe())

# Run ABC on the real data

## Load the data from Mitchell's paper
We have excluded two donors from the ABC:
1. KX007, because the same donor was uploaded twice;
2. CB001, because it maps to the same timepoint as CB002 (both have age 0).

In [None]:
%%time
# Mitchell's data for each donor, keyed by age. As used in this notebook, each
# entry is a tuple whose element 2 is the number of cells found in the
# genotype matrix and whose last element is the SFS.
target_sfs = {
    r.age: mitchell.sfs_donor_mitchell(
        r.name, r.age, PATH2MITCHELL, remove_indels=False
    )
    for r in donors[["name", "age"]].itertuples()
}
cells_found = [ele[2] for ele in target_sfs.values()]
assert cells_found == donors.cells.tolist(), (
    "cells found in genotype matrices do not match the expected nb of cells: "
    f"found {cells_found}, expected {donors.cells.tolist()}"
)

### Compute the summary statistics 
1. wasserstein metric
2. the number of clones
3. the KS stat (not implemented yet)

After having computed the statistics from the simulations and Mitchell's data, we won't use the data anymore but just the summary statistic dataframe `abc_mitchell`: we are going to filter (and keep) the runs by selecting from this dataframe.

In [None]:
# Summary statistic 1: Wasserstein metric between the simulated SFS and
# Mitchell's SFS at each age.
print("wasserstein metric")
abc_mitchell = abc.sfs_summary_statistic_wasserstein(
    sfs_sims,
    {
        k: ele[-1] for k, ele in target_sfs.items()
    },  # ele[-1] means the SFS, k is the age
    "mitchell",
)

# Summary statistic 2: attach the variant counts detected in each simulated
# run at each timepoint, to compare with the clones from Mitchell's fig 5a.
print("clones metric")
abc_mitchell = abc_mitchell.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
)
# Rows without matching counts cannot be scored: report how many are dropped
# instead of dropping them silently, and fail if nothing could be matched.
nb_incomplete = int(abc_mitchell.isna().any(axis=1).sum())
print(f"dropping {nb_incomplete} out of {abc_mitchell.shape[0]} rows with missing values")
abc_mitchell = abc_mitchell.dropna()
assert not abc_mitchell.empty, "cannot match the nb of clones data to the abc results"
abc_mitchell = donors.merge(
    right=abc_mitchell, how="right", on="age", validate="one_to_many"
)
abc_mitchell.shape

In [None]:
# 2. number of clones
# Summary statistic 2: absolute difference between Mitchell's clones and the
# variants detected in each simulated run; the abc helper then derives the
# relative version of this difference.
abc_mitchell["clones diff"] = (
    abc_mitchell["clones"].sub(abc_mitchell["variant counts detected"]).abs()
)
abc_mitchell = abc.summary_statistic_relative_diff_clones(abc_mitchell)

### Show priors

In [None]:
if SHOW_PRIORS:
    # One row per distinct parameter set sampled from the priors.
    priors = abc_mitchell[["mu", "u", "s", "std", "tau"]].drop_duplicates()

    # Extra plotting options for each parameter's prior histogram.
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "tau": {"binwidth": 0.1},
        "mu": {"discrete": False},
        "u": {},
    }
    for param, plot_kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **plot_kwargs)
        plt.show()

    # Distribution of the Wasserstein metric over all runs and timepoints.
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_mitchell["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

    sns.pairplot(priors[["s", "std", "tau", "mu"]], kind="hist")
    plt.show()

### Run abc considering all timepoints at the same time

In [None]:
def find_map_by_cut(
    view: pd.Series, bins: List[Union[float, int]]
) -> Union[float, None]:
    """Estimate the maximum a posteriori (MAP) of a sample by binning it.

    The values are binned with ``pd.cut`` (right-closed intervals, so values
    equal to ``bins[0]`` or outside the bins are ignored). The function returns
    the midpoint of the most populated bin.

    Args:
        view: posterior sample of one parameter.
        bins: monotonically increasing bin edges.

    Returns:
        The midpoint of the modal bin, or ``None`` when no value falls within
        the bins or when several bins are equally populated.
    """
    modes = pd.cut(view, bins=bins).mode()
    if modes.empty:
        # e.g. an empty posterior: `.iloc[0]` below would raise IndexError
        print("cannot compute MAP because no values fall within the bins")
        return None
    if modes.shape[0] > 1:
        print(
            f"cannot compute MAP because more than one mode have been found {modes}"
        )
        return None
    return modes.iloc[0].mid

In [None]:
# plot all posteriors for all timepoints separately, for all metrics
for t in sorted(target_sfs.keys()):
    # for metric in ("rel clones diff", ):
    # for metric in ("clones diff", ):
    # for metric in ("wasserstein", "rel clones diff"):
    for metric in ("wasserstein",):
        fig, axes = plt.subplots(2, 2, layout="tight", sharey=True)
        print(f"Running ABC on {metric} metric at age {t}")
        idx = abc.run_abc_per_single_timepoint(
            abc_mitchell,
            t,
            quantile=0.01,
            metric=metric,
        ).get_idx()
        print(len(idx))
        view = abc_mitchell[abc_mitchell.idx.isin(idx)]

        for i, (theta, lim) in enumerate(
            zip(
                ["s", "std", "mu", "tau"],
                [
                    (bins_s[0], bins_s[-1]),
                    (bins_std[0], bins_std[-1]),
                    (bins_mu[0], bins_mu[-1]),
                    (bins_tau[0], bins_tau[-1]),
                ],
            )
        ):
            ax = axes[np.unravel_index(i, (2, 2))]
            sns.histplot(view[theta], ax=ax)  # stat="percent")
            ax.set_xlim(lim)
        fig.suptitle(f"{t} years with {metric}", y=0.9, fontsize="small")
        plt.show()

In [None]:
# Combined ABC: run the ABC per timepoint with each metric, keep only the runs
# accepted in at least `minimum_timepoints` timepoints, then intersect the runs
# kept by the two metrics.
nb_timepoints = abc_mitchell["name"].unique().shape[0]
minimum_timepoints = int(
    round(nb_timepoints - nb_timepoints * PROPORTION_RUNS_DISCARDED)
)

print(f"Running ABC with {minimum_timepoints} minimum timepoints over {nb_timepoints}")
wasserstein_idx = set(
    abc.run_abc(
        abc_mitchell,
        quantile=QUANTILE_SFS,
        metric="wasserstein",
    )
    .abc_filter_on_minimum_timepoints(minimum_timepoints)
    .idx.tolist()
)
print(f"ABC wasserstein kept {len(wasserstein_idx)} runs")

clones_idx = set(
    abc.run_abc(
        abc_mitchell,
        quantile=QUANTILE_CLONES,
        metric="rel clones diff",
    )
    .abc_filter_on_minimum_timepoints(minimum_timepoints)
    .idx.tolist()
)
print(f"ABC clones kept {len(clones_idx)} runs")

runs2keep = clones_idx.intersection(wasserstein_idx)
view = abc_mitchell[abc_mitchell.idx.isin(runs2keep)].drop_duplicates(subset="idx")
assert not view.empty, "empty posterior"
print(f"ABC combined kept {len(runs2keep)} runs")

# MAP of mu over unit-width bins (None if the posterior has several modes)
map_mu = find_map_by_cut(view.mu, range(0, 21))
print(f"MAP mu: {map_mu}")

# 2D posteriors for pairs of parameters; saved only when SAVEFIG is on.
posterior_pairs = [
    ("s", "mu", bins_s, bins_mu),
    ("s", "std", bins_s, bins_std),
    ("s", "tau", bins_s, bins_tau),
    ("mu", "tau", bins_mu, bins_tau),
]
gs = []
for x_param, y_param, x_bins, y_bins in posterior_pairs:
    gs.append(abc_fig.plot_results(view, x_param, y_param, x_bins, y_bins))
    if options.save:
        plt.savefig(
            f"posterior{len(gs)}.{options.extension}", bbox_inches="tight", pad_inches=0
        )
    plt.show()

# TODO: maybe iterate over gs to remove the grid in the marginal plots?
# Per timepoint, compare Mitchell's SFS with the simulated SFS of one accepted run.
verbose = False
idx2show = dict()

for t in sorted(abc_mitchell.timepoint.unique()):
    if verbose:
        threshold = abc_mitchell.loc[
            abc_mitchell.timepoint == t, "wasserstein"
        ].quantile(QUANTILE_SFS)
        title = f"age: {t} years, quantile threshold: {threshold:.2f}"
    else:
        title = f"age: {round(t)} years"

    idx2show[t] = abc_fig.get_idx_smaller_distance_clones_idx(
        abc_mitchell[abc_mitchell.timepoint == t], runs2keep
    )

    fig = sfs_fig.plot_sfs_cdf(
        [idx2show[t]], target_sfs[t][3], sfs_sims[t], t, verbose=verbose, alpha=0.7
    )
    # The title is currently not drawn; to add it: fig.suptitle(title, x=0.4)
    if options.save:
        fig.savefig(f"sfs_{t}years.{options.extension}")
    plt.show()

In [None]:
# Detected variant counts of the runs accepted by the combined ABC, per age,
# against the clones reported by Mitchell.
accepted_runs = abc_mitchell[abc_mitchell.idx.isin(runs2keep)]
bands = (
    {"errorbar": lambda x: (np.min(x), np.max(x)), "label": "min-max"},
    {"errorbar": "sd", "color": "orange", "label": "std"},
)

fig, ax = plt.subplots(1, 1, layout="tight")
for band in bands:
    sns.lineplot(accepted_runs, x="age", y="variant counts detected", ax=ax, **band)
sns.scatterplot(
    data=abc_mitchell[["age", "clones"]].drop_duplicates(),
    x="age",
    y="clones",
    marker="x",
    linewidths=2,
    color="purple",
    label="Mitchell",
)
ax.legend()
if options.save:
    fig.savefig(f"variants_abc.{options.extension}")
plt.show()

In [None]:
# Parameters mu and s of the run shown at each timepoint in the SFS plots.
ts = list(idx2show.keys())
mus, ss = [], []
for age, shown_idx in idx2show.items():
    shown = abc_mitchell[(abc_mitchell.timepoint == age) & (abc_mitchell.idx == shown_idx)]
    mus.append(shown.mu.iloc[0])
    ss.append(shown.s.iloc[0])

# histograms of mu and s across timepoints
for values, width in ((mus, 0.5), (ss, 0.005)):
    sns.histplot(values, binwidth=width)
    plt.show()

# mu and s as a function of the age
for values in (mus, ss):
    plt.plot(ts, values, marker="x", mew=2)
    plt.show()

In [None]:
# Ratio s/tau over the runs accepted by the combined ABC, with its mean marked.
s_over_tau = view.s / view.tau
mean = s_over_tau.mean()

fig, ax = plt.subplots(1, 1, layout="constrained")
s_over_tau.hist(ax=ax, bins=30)
ax.set(xlabel=r"$s/\tau$", ylabel="counts")
ax.axvline(mean, color="black")
ax.annotate(
    f"mean = {mean:.3f}", xy=(0.5, 1.01), xycoords=("axes fraction", "axes fraction")
)
plt.show()

In [None]:
# Pearson correlation between s and tau over the accepted runs
view.loc[:, ["s", "tau"]].corr(method="pearson")

In [None]:
# TODO remove?
# Least-squares fit of the line tau = m * s + c over the accepted runs.
s = view["s"].to_numpy(dtype=float)
tau = view["tau"].to_numpy(dtype=float)
design = np.column_stack([s, np.ones_like(s)])
m, c = np.linalg.lstsq(design, tau, rcond=None)[0]

fig, ax = plt.subplots(1, 1)
ax.plot(s, tau, "o", label="Original data", markersize=10)
# label shows the fitted line itself; `:+` keeps a single sign when c < 0
ax.plot(s, m * s + c, "r", label=rf"Fit $\tau = {m:.2f}\,s {c:+.2f}$")
ax.set_xlabel(r"$s$")
ax.set_ylabel(r"$\tau$")
ax.legend()
plt.show()

## Run ABC on subsampled simulated data

### High mu

In [None]:
# Use one simulated realisation as a synthetic target, to check whether the
# ABC recovers the parameters it was simulated with.
TEST_REALISATION = 1  # index of the target run within each sfs_sims[age]
test_sfs = {age: sims[TEST_REALISATION] for age, sims in sfs_sims.items()}
test_params = test_sfs[0.0].parameters
print(f"mu: {test_params.mu}")
test_idx = test_params.idx
test_counts = counts.loc[
    counts.idx == test_idx, ["age", "variant counts detected", "mu"]
].rename(columns={"variant counts detected": "target clones detected"})
assert (test_params.mu == test_counts.mu).all()
test_counts = test_counts.drop(columns=["mu"])

print("wasserstein metric")
abc_mitchell = abc.sfs_summary_statistic_wasserstein(
    sfs_sims,
    {k: ele.sfs for k, ele in test_sfs.items()},  # ele.sfs is the target SFS, k is the age
    "mitchell",
)

print("clones metric")
# Attach each run's detected counts by (run idx, timepoint) and the target's
# counts by age via explicit merges. Assigning columns taken from another
# frame would align on the row index, which does not identify the same run and
# timepoint in both frames.
abc_mitchell = abc_mitchell.merge(
    right=counts[["age", "idx", "variant counts detected"]],
    how="left",
    left_on=["idx", "timepoint"],
    right_on=["idx", "age"],
    validate="one_to_one",
).merge(right=test_counts, on="age", how="left", validate="many_to_one")

abc_mitchell["clones"] = abc_mitchell["target clones detected"]
abc_mitchell["clones diff"] = (
    abc_mitchell["variant counts detected"] - abc_mitchell["clones"]
).abs()

abc_mitchell = abc.summary_statistic_relative_diff_clones(abc_mitchell)

In [None]:
# Combined ABC on the simulated target (same filter as for Mitchell's data),
# then overlay the target's true parameters on each 2D posterior.
nb_timepoints = len(abc_mitchell["sample"].unique())
minimum_timepoints = int(
    round(nb_timepoints - nb_timepoints * PROPORTION_RUNS_DISCARDED)
)

print(f"Running ABC with {minimum_timepoints} minimum timepoints over {nb_timepoints}")
kept_by_metric = {}
for label, metric_name, threshold in (
    ("wasserstein", "wasserstein", QUANTILE_SFS),
    ("clones", "rel clones diff", QUANTILE_CLONES),
):
    accepted = abc.run_abc(abc_mitchell, quantile=threshold, metric=metric_name)
    kept_by_metric[label] = set(
        accepted.abc_filter_on_minimum_timepoints(minimum_timepoints).idx.tolist()
    )
    print(f"ABC {label} kept {len(kept_by_metric[label])} runs")
wasserstein_idx = kept_by_metric["wasserstein"]
clones_idx = kept_by_metric["clones"]

runs2keep = clones_idx.intersection(wasserstein_idx)
view = abc_mitchell[abc_mitchell.idx.isin(runs2keep)].drop_duplicates(subset="idx")
assert not view.empty, "empty posterior"
print(f"ABC combined kept {len(runs2keep)} runs")

# Each pair is (x parameter, y parameter); the true value is drawn as a
# vertical line for x and a horizontal line for y on the joint axes "C".
true_params = test_sfs[0.0].parameters
bins_by_param = {"s": bins_s, "mu": bins_mu, "std": bins_std, "tau": bins_tau}
for x_param, y_param in (("s", "mu"), ("s", "std"), ("s", "tau"), ("mu", "tau")):
    x_bins, y_bins = bins_by_param[x_param], bins_by_param[y_param]
    axd = abc_fig.plot_results(view, x_param, y_param, x_bins, y_bins)
    axd["C"].vlines(getattr(true_params, x_param), ymin=y_bins[0], ymax=y_bins[-1])
    axd["C"].hlines(getattr(true_params, y_param), xmin=x_bins[0], xmax=x_bins[-1])
    plt.show()