In [None]:
import random
import numpy as np
import socket
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 
from scipy import stats
from matplotlib import colors
from datetime import datetime
from pathlib import Path
from futils import parse_version, snapshot
from hscpy.figures import simulations, mitchell, PlotOptions
from hscpy.figures import abc as abc_fig
from hscpy import abc, sfs

# Directory expected to contain the compiled `hsc` release binary; the
# `%%bash` cell that follows runs `hsc --version` from this directory.
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
# Fail early, and say which path was checked, when the binary has not been built.
assert PATH2BIN.is_dir(), f"hsc release directory not found: {PATH2BIN}"

In [None]:
%%bash -s "$PATH2BIN" --out version
# Ask the hsc binary for its version; `--out version` stores the cell's stdout
# in the Python variable `version`.
# "$1" is quoted so that a release directory containing whitespace still resolves.
"$1"/hsc --version

In [None]:
# Notebook-wide configuration. Most values are forwarded to
# `simulations.SimulationOptions` when `sim_options_abc` is built.

# Passed as `last_timepoint_years`.
YEARS = 100
# Not referenced by the cells visible in this notebook section.
YEARS_ENTROPY = 1
# Passed as `runs`.
RUNS = 2737
# Passed as `nb_timepoints`.
NB_TIMEPOINTS = 20
# Not referenced by the cells visible in this notebook section.
DETECTION_THRESH = 0.01
# Passed as `nb_subclones`.
SUBCLONES = 60
# True -> simulations are read from /data/scratch/, False -> from /data/home/.
USE_SCRATCH = True

# Passed as `cells`.
NCELLS = 200_000
# Passed as `sample`; also part of the "<SAMPLE>cells" directory name used
# when looking for unfinished runs.
SAMPLE = 368

# Passed as `neutral_rate`.
MU_BACKGROUND = 15.5
# Passed as `s`.
S = 0.11

# When True, the prior histograms and the Wasserstein histogram are drawn.
SHOW_PRIORS = True

In [None]:
# Version string reported by the hsc binary (captured by the bash cell above).
VERSION = parse_version(version)
PATH2SAVE = Path(f"./{VERSION}")

print("Running hsc with version:", VERSION)

# Simulation output lives either on scratch or in the home area of the cluster.
PATH2SIMS = Path("/data/scratch/") if USE_SCRATCH else Path("/data/home/")
PATH2SIMS = PATH2SIMS / f"hfx923/hsc-draft/{VERSION}"

# Location of the Mitchell data depends on the machine running the notebook;
# unknown hosts fall back to the user's home directory.
hostname = socket.gethostname()
mitchell_dirs = {
    "5X9ZYD3": "/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc",
    "LAPTOP-CEKCHJ4C": "/mnt/c/Users/fra_t/Documents/PhD/hsc",
}
if hostname in mitchell_dirs:
    PATH2MITCHELL = Path(mitchell_dirs[hostname])
else:
    PATH2MITCHELL = Path("~").expanduser()

In [None]:
# Simulation settings shared by the ABC steps below, built from the notebook constants.
sim_options_abc = simulations.SimulationOptions(
    path2save=PATH2SIMS,
    runs=RUNS,
    cells=NCELLS,
    sample=SAMPLE,
    nb_subclones=SUBCLONES,
    nb_timepoints=NB_TIMEPOINTS,
    last_timepoint_years=YEARS,
    neutral_rate=MU_BACKGROUND,
    s=S,
)
options = PlotOptions(figsize=(7, 6), extension="pdf", save=False)

summary = mitchell.load_and_process_mitchell(PATH2MITCHELL / "Summary_cut.csv")
# Donor "KX007" is dropped: the same donor was uploaded twice.
donors = [
    candidate
    for candidate in simulations.donors_from_mitchell(summary, sim_options_abc)
    if candidate.name != "KX007"
]

## Remove runs
Remove all the runs that didn't finish running. This is required because we load runs in a numpy array with fixed size.

In [None]:
# A run is considered unfinished when its SFS file exists for only one of the
# first and the last timepoint (symmetric difference of the two file sets).
path2sfs = Path(f"{sim_options_abc.path2save}/{sim_options_abc.sample}cells/sfs/")
# The last timepoint directory is derived from the simulation options instead of
# a hard-coded "20" (identical for NB_TIMEPOINTS = 20).
# NOTE(review): assumes timepoint directories are named 1..nb_timepoints — confirm.
last_timepoint_dir = str(sim_options_abc.nb_timepoints)
timepoint1 = {file.stem for file in (path2sfs / "1").iterdir()}
timepoint2 = {file.stem for file in (path2sfs / last_timepoint_dir).iterdir()}
files2remove = timepoint1.symmetric_difference(timepoint2)
print(f"{len(files2remove)} files to remove")
# Space-separated file stems, passed as "$1" to the %%bash cell that follows;
# sorted so the removal order is reproducible.
runs2remove = " ".join(sorted(files2remove))

In [None]:
%%bash -s "$runs2remove" "$sim_options_abc.path2save"
# $1: space-separated run identifiers, $2: root directory of the simulations.
# Deletes every file under $2 whose name contains one of the identifiers.
echo "removing files"
# $1 is deliberately left unquoted so it is split into one word per identifier.
for file in $1
do
    # The pattern is quoted so that `find`, not the shell, expands the wildcards
    # (an unquoted *$file* is globbed against the current directory first).
    find "$2" -name "*${file}*" -exec rm {} \;
done

In [None]:
# Same consistency check as for the SFS files, now on the burden files.
# Only the count is printed here; nothing is removed by this cell.
path2burden = Path(f"{sim_options_abc.path2save}/{sim_options_abc.sample}cells/burden/")
# Last timepoint directory taken from the simulation options instead of a
# hard-coded "20" (identical for NB_TIMEPOINTS = 20).
# NOTE(review): assumes timepoint directories are named 1..nb_timepoints — confirm.
last_timepoint_dir = str(sim_options_abc.nb_timepoints)
timepoint1 = {file.stem for file in (path2burden / "1").iterdir()}
timepoint2 = {file.stem for file in (path2burden / last_timepoint_dir).iterdir()}
files2remove = timepoint1.symmetric_difference(timepoint2)
print(f"{len(files2remove)} files to remove")

## Load simulated SFS

In [None]:
%%time
sfs_sims = sfs.load_sfs_timepoints(
    sim_options_abc.path2save, 
    sim_options_abc.nb_timepoints, 
    sim_options_abc.sample, 
    sim_options_abc.runs
)

## Run ABC on subsampled simulated data

In [None]:
# summarise
# One simulated run is used as the target: its SFS at each timepoint is
# compared against all simulated runs.
donor_idx = "7mu0_0dot000017499913155916147u_0dot11840355mean_0dot05043838std_1b0_200000cells_21430idx"
target_sfs = {
    timepoint: runs[donor_idx]
    for timepoint, runs in sfs_sims.items()
}
abc_results = abc.sfs_summary_statistic_wasserstein(sfs_sims, target_sfs, donor_idx)
abc_results

In [None]:
# show priors
priors = abc_results[["mu", "u", "s", "std"]].drop_duplicates()

if SHOW_PRIORS:
    # One figure per parameter, each with its own histogram options.
    prior_plot_kwargs = {
        "s": {"binwidth": 0.01},
        "std": {"binwidth": 0.001},
        "mu": {"discrete": True},
        "u": {},
    }
    for param, kwargs in prior_plot_kwargs.items():
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **kwargs)
        plt.show()

    # Distribution of the Wasserstein distances over all rows of abc_results.
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_results["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
# run abc to filter runs TODO
# Settings handed to abc.run_abc together with the distances computed above.
quantile, minimum_runs = 0.24, 16
runs2keep = abc.run_abc(abc_results, quantile, minimum_runs)

In [None]:
# Compare the target SFS with the SFS of every accepted run, one figure per timepoint.
# These lookups do not depend on the timepoint, so they are done once up front
# instead of once per timepoint and run.
idx_target = abc_results.loc[abc_results.filename == donor_idx, "idx"].unique()[0]
filename_by_idx = {
    s_id: abc_results.loc[abc_results.idx == s_id, "filename"].unique()[0]
    for s_id in runs2keep
}

for t in abc_results.timepoint.unique():
    print(t)
    # plot SFS
    fig, ax = plt.subplots(1, 1)
    target = sfs.process_sfs(sfs_sims[t][donor_idx], normalise=True, log_transform=True)
    ax.plot(list(target.keys()), list(target.values()), marker="x", linestyle="", color="b", label=f"target {idx_target}", mew=2)
    for s_id in runs2keep:
        sim = sfs.process_sfs(sfs_sims[t][filename_by_idx[s_id]], normalise=True, log_transform=True)
        ax.plot(list(sim.keys()), list(sim.values()), marker="o", linestyle="", alpha=0.4, label=f"sim {s_id}")
    ax.legend()
    plt.show()

# Heatmaps of the Wasserstein distance of the accepted runs, drawn once after
# the loop, i.e. for the last timepoint iterated above (`t` keeps its final value).
# NOTE(review): confirm that only the last timepoint is wanted here rather
# than one pair of heatmaps per timepoint.
accepted_at_t = abc_results.idx.isin(runs2keep) & (abc_results.timepoint == t)
for column in ("mu", "std"):
    sns.heatmap(
        abc_results.loc[accepted_at_t, ["s", column, "wasserstein"]]
        .drop_duplicates()
        .pivot(index="s", columns=column, values="wasserstein")
    )
    plt.show()

In [None]:
# Parameter combinations of the runs accepted by ABC.
selected = abc_results.loc[abc_results.idx.isin(runs2keep), ["mu", "s", "std"]].drop_duplicates()

# Upper axis limits come from the priors; every axis starts at 0.
mu_max = abc_fig.lims(priors, "mu")[1]
s_max = abc_fig.lims(priors, "s")[1]
std_max = abc_fig.lims(priors, "std")[1]

abc_fig.plot_results(
    selected, ["mu", "s"],
    [0, mu_max], [0, s_max],
    {"discrete": True}, {"binwidth": 0.01},
)
abc_fig.plot_results(
    selected, ["mu", "std"],
    [0, mu_max], [0, std_max],
    {"discrete": True}, {"binwidth": 0.005},
)
abc_fig.plot_results(
    selected, ["s", "std"],
    [0, s_max], [0, std_max],
    {"binwidth": 0.01}, {"binwidth": 0.005},
)

## Run ABC on the real data

In [None]:
%%time
# summarise
# skip the first donor for now since we have another donor with age 0
target_sfs = {donor.id_timepoint: mitchell.sfs_donor_mitchell(donor.name, PATH2SIMS, remove_indels=False) for donor in donors[1:]}
abc_results = abc.sfs_summary_statistic_wasserstein({t: sfs for t, sfs in sfs_sims.items() if t in set(target_sfs.keys())}, target_sfs, "mitchell")
abc_results

In [None]:
# Unique parameter combinations present in the ABC results for the real data.
priors = abc_results[["mu", "u", "s", "std"]].drop_duplicates()

if SHOW_PRIORS:
    # (parameter, histogram options) in the order the figures are shown.
    prior_panels = [
        ("s", {"binwidth": 0.01}),
        ("std", {"binwidth": 0.001}),
        ("mu", {"discrete": True}),
        ("u", {}),
    ]
    for param, hist_options in prior_panels:
        fig, ax = plt.subplots(1, 1, figsize=[7, 6])
        ax = abc_fig.plot_prior(priors[param], ax=ax, **hist_options)
        plt.show()

    # Distribution of the Wasserstein distances over all rows of abc_results.
    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.histplot(abc_results["wasserstein"], binwidth=0.01, ax=ax)
    plt.show()

In [None]:
# run abc
# Settings handed to abc.run_abc together with the distances to the real data.
quantile, minimum_runs = 0.2, 6
runs2keep = abc.run_abc(abc_results, quantile, minimum_runs)

In [None]:
# Parameter combinations of the runs accepted by ABC on the real data.
selected = abc_results.loc[abc_results.idx.isin(runs2keep), ["mu", "s", "std"]].drop_duplicates()

# Upper axis limits come from the priors; every axis starts at 0.
mu_max = abc_fig.lims(priors, "mu")[1]
s_max = abc_fig.lims(priors, "s")[1]
std_max = abc_fig.lims(priors, "std")[1]

abc_fig.plot_results(
    selected, ["mu", "s"],
    [0, mu_max], [0, s_max],
    {"discrete": True}, {"binwidth": 0.01},
)
abc_fig.plot_results(
    selected, ["mu", "std"],
    [0, mu_max], [0, std_max],
    {"discrete": True}, {"binwidth": 0.005},
)
abc_fig.plot_results(
    selected, ["s", "std"],
    [0, s_max], [0, std_max],
    {"binwidth": 0.01}, {"binwidth": 0.005},
)

## Show metric over time

In [None]:
# Unconditionally raises, so a "Run All" stops at this cell and the cell below
# only runs when executed manually.
# NOTE(review): presumably a placeholder/guard for this unfinished section — confirm.
raise NotImplementedError

In [None]:
show_fitness_distributions = False


def plot_wasserstein_heatmap(id_a, id_b):
    """Draw the heatmap returned by `abc.heatmap_wasserstein` for two runs.

    Args:
        id_a: key of the first run in `sfs_sims[timepoint]`.
        id_b: key of the second run (may equal `id_a` for a self-comparison).

    Returns:
        The frame returned by `abc.heatmap_wasserstein`, with its index and
        column names shortened for display.
    """
    heatmap_data = abc.heatmap_wasserstein(
        sfs_sims,
        id_a,
        id_b,
        sim_options_abc.nb_timepoints,
        sim_options_abc.last_timepoint_years,
        normalise=True,
        log_transform=True,
    )
    # Keep only the last "_"-separated token of each axis name and put a space before "id".
    heatmap_data.index.name = heatmap_data.index.name.split("_")[-1].replace("id", " id")
    heatmap_data.columns.name = heatmap_data.columns.name.split("_")[-1].replace("id", " id")

    fig, ax = plt.subplots(1, 1, figsize=[7, 6])
    sns.heatmap(heatmap_data, norm=colors.LogNorm(), ax=ax)
    plt.show()
    return heatmap_data


# Run identifiers are taken from the last timepoint (was a hard-coded 20).
run_ids = list(sfs_sims[sim_options_abc.nb_timepoints].keys())

for _ in range(9):
    # random.sample draws without replacement, so the two runs always differ;
    # random.choices could return the same run twice.
    id1, id2 = random.sample(run_ids, k=2)
    my_stats = plot_wasserstein_heatmap(id1, id2)

    if show_fitness_distributions:
        for idx in (id1, id2):
            fig, ax = plt.subplots(1, 1)
            pd.read_csv(sim_options_abc.path2save / f"rates/{idx}.csv", header=None).squeeze().plot(
                kind="hist", ax=ax, bins=35
            )
            ax.set_xlim(0.95, 1.4)  # TODO?
            ax.set_title(f"simulation id: {idx}")
            plt.show()

# Reference heatmap: the last sampled run compared with itself. The original
# `for ... else` always ran this branch, because the only `break` was in the inner loop.
my_stats = plot_wasserstein_heatmap(id1, id1)