# HSC
Markov process for a fixed-size population of cells of k types, where type 0 is the wild type with division rate `B0`.

Cells can acquire a mutation conferring a proliferative advantage upon cell division. We model this process with a Bernoulli trial with success probability `u`, in units of [1 mutation / division]. For the symmetric division case, `u` is computed as `u = MU0 / (2 * B0 * NCELLS)`.

For now, all clones of type k > 0 have the same proliferative advantage.

In [None]:
import json
import socket
import sys
from pathlib import Path
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

from futils import parse_version, snapshot
from hscpy import sfs, variant

# Directory holding the compiled `hsc` release binary.
PATH2BIN = Path("~").expanduser() / "hsc/target/release"
# fail early if the binary directory is missing
assert PATH2BIN.is_dir()
FIGSIZE = [7, 3]
# save figures as PDF when True, PNG otherwise
PDF = True
EXTENSION = ".pdf" if PDF else ".png"
# simulated years on the two known hosts (see the host-selection cell)
YEARS_FAST = 100  # TODO replace by 70
# number of independent simulations, passed to `hsc -r`
RUNS = 64
# number of saved snapshots, passed to `hsc --snapshots`
NB_TIMEPOINTS = 21
# variant fraction above which a clone counts as detected
DETECTION_THRESH = 0.01
# number of subclones read by `variant.load_variant_fractions`
SUBCLONES = 60
# write figures to disk when True
SAVE = False
# store simulations under /data/scratch (True) or under ./<version> (False)
USE_SCRATCH = True

In [None]:
# number of cells in the fixed-size population
NCELLS = 200_000
# mean of the Bernoulli trial (prob of success) to get an asymmetric
# division upon cell division, units are [1 asymmetric division / division]
P_ASYMMETRIC = 0

## NEUTRAL RATES
# division rate for the wild-type in units of [division / (year * cell)]
# Welch, J.S. et al. (2012) ‘The Origin and Evolution of Mutations in Acute Myeloid Leukemia’,
# Cell, 150(2), pp. 264–278
B0 = 1  # TODO: double check this, should be between 2 and 20?
# neutral mutation rate, passed to `hsc --neutral-rate`
# Abascal, F. et al. (2021) ‘Somatic mutation landscapes at single-molecule resolution’,
# Nature, 593(7859), pp. 405–410. fig. 2b
# see also fig 1b of Mitchell, E. et al.
# (2022) ‘Clonal dynamics of haematopoiesis across the human lifespan’,
# Nature, 606(7913), pp. 343–350
NEUTRAL_RATE = 20 # [mut/(year * cell)]

## FIT CLONES
# avg fit mutations arising in 1 year, units are [mutations/year]
# from ABC's inference
MU0 = 2
# proliferative advantage conferred by fit mutations, all clones
# have the same proliferative advantage for now.
# NOTE(review): units were previously given as [mutation / division], which
# are the units of `u` below; S looks like a relative proliferative
# advantage — confirm its units
S = 0.11
# mean of the Bernoulli trial (prob of success) to get a fit variant upon
# cell division, units are [1 mutation/division]
# NOTE: this is a truthiness test, so any non-zero P_ASYMMETRIC selects the
# second formula (without the factor 2 used for purely symmetric divisions)
if not P_ASYMMETRIC:
    u = MU0 / (2 * B0 * NCELLS)
else:
    u = MU0 / (B0 * NCELLS)
# should be 2.0 × 10^-3 per HSC per year according to Mitchell, E. et al.
# (2022) ‘Clonal dynamics of haematopoiesis across the human lifespan’,
# Nature, 606(7913), pp. 343–350
# driver mutations enter the HSC compartment at 2.0 × 10^-3 per HSC per year
print(f"average sucess rate of occurence of 1 fit mutation upon cell division u={u}")

In [None]:
# Select where the reference simulations (`totalVariantFracTime.csv`) live and
# how many years to simulate, based on the machine running the notebook.
if socket.gethostname() == "5X9ZYD3":
    PATH2SIMS = Path("/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/")
    YEARS = YEARS_FAST
elif socket.gethostname() == "LAPTOP-CEKCHJ4C":
    PATH2SIMS = Path("/mnt/c/Users/fra_t/Documents/PhD/")
    # NOTE(review): an older comment said "+ 1" was needed to save the last
    # timepoint, but no "+ 1" is applied here — confirm which is intended
    YEARS = YEARS_FAST
else:
    # any other host: home directory and a hard-coded 100 years, which only
    # agrees with the branches above while YEARS_FAST == 100
    PATH2SIMS = Path("~").expanduser()
    YEARS = 100

PATH2SIMS /= Path("totalVariantFracTime.csv")
# evenly spaced times in years, used as the time axis of the NB_TIMEPOINTS snapshots
x = np.linspace(0, YEARS, NB_TIMEPOINTS)
assert PATH2SIMS.is_file()

In [None]:
%%bash -s "$PATH2BIN" --out version
# $1 is PATH2BIN; stdout is captured into the Python variable `version`
$1/hsc --version

In [None]:
# `version` is the stdout captured from `hsc --version` in the previous cell
VERSION = parse_version(version)
# simulation outputs are written to a per-version directory
if USE_SCRATCH:
    PATH2SAVE = Path(f"/data/scratch/hfx923/hsc-draft/{VERSION}")
else:
    PATH2SAVE = Path(f"./{VERSION}")
print("Running hsc with version:", VERSION)

In [None]:
%%bash -s "$PATH2BIN" "$PATH2SAVE" "$B0" "$MU0" "$NEUTRAL_RATE" "$S" "$P_ASYMMETRIC" "$RUNS" "$NCELLS" "$YEARS" "$NB_TIMEPOINTS"
# Positional args: $1 PATH2BIN, $2 PATH2SAVE, $3 B0, $4 MU0, $5 NEUTRAL_RATE,
# $6 S, $7 P_ASYMMETRIC, $8 RUNS, $9 NCELLS, ${10} YEARS, ${11} NB_TIMEPOINTS.
# Every expansion is quoted so a path containing spaces cannot be split into
# several arguments (in particular for `rm -rf`).
# Remove any previous results for this version before re-running.
rm -rf "$2"
"$1"/hsc -c "$9" -y "${10}" -r "$8" --b0 "$3" --mu0 "$4" --neutral-rate "$5" -s "$6" --p-asymmetric "$7" --snapshots "${11}" "$2"

## The site frequency spectrum (SFS)

In [None]:
# Load the SFS of every run before displaying it, so this cell also works when
# the notebook is executed top to bottom (`sfs_all` is otherwise only
# assigned in a later cell, which made this line raise NameError).
sfs_all = sfs.load_sfs(PATH2SAVE, runs=RUNS)
sfs_all

In [None]:
# one SFS per run, keyed by simulation id (the plot below indexes it with
# string ids such as "7")
sfs_all = sfs.load_sfs(PATH2SAVE, runs=RUNS)
# NOTE(review): presumably the j-cells axis and the run-averaged SFS — see
# `hscpy.sfs.average_sfs`
jcells, avg_sfs = sfs.average_sfs(sfs_all)

In [None]:
# Three individual SFS realisations (simulation ids "7", "1" and "2"), then the
# SFS averaged over all runs, both on log-log axes.
fig, ax = plt.subplots(1, 1)
for k in ("7", "1", "2"):
    sfs.plot_sfs(sfs_all[k], ax, k)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_ylabel("# of mutations with j cells")
ax.set_xlabel("j cells")
ax.legend(title="sim id")
ax.set_title(f"individual realisations of the SFS after {YEARS} years")
plt.show()

fig, ax = plt.subplots(1, 1)
# NOTE(review): only the first element of `jcells` (from `sfs.average_sfs`)
# is used as x values — confirm this matches the shape of `avg_sfs`
ax.plot(
    jcells[0],
    avg_sfs,
    alpha=0.45,
    marker="x",
    linestyle="",
    c="grey",
)
ax.set_ylabel("# of mutations with j cells")
ax.set_xlabel("j cells")
# NOTE(review): x_ and y_ (a 1/j curve scaled by 2 * NEUTRAL_RATE * NCELLS) are
# computed but never plotted — presumably meant as a theory overlay
x_ = np.arange(1, 2 * NCELLS)
y_ = 2 * NEUTRAL_RATE * NCELLS / x_
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_title(f"avg of the SFS over {RUNS} runs after {YEARS} years")
plt.show()

## Single-cell mutational burden

In [None]:
def get_idx_timepoint_from_age(age: int, years: int, nb_timepoints: int) -> Tuple[int, int]:
    """Find the idx of the timepoint associated to `age`.

    Rust saves timepoints in the reverse order, that is idx of 1 corresponds
    to the older timepoint (greater age).

    Snapshots are taken to be evenly spaced between 0 and `years`
    (``np.linspace(0, years, nb_timepoints)``). When `age` does not fall
    exactly on a snapshot, the closest one is used (ties go to the older
    snapshot) and a message is printed.

    Args:
        age: age in years; it is rounded to an int first.
        years: number of simulated years, i.e. the age of the oldest snapshot.
        nb_timepoints: number of saved snapshots.

    Returns:
        ``(idx, closest_age)``: the 1-based timepoint idx in the Rust ordering
        and the age of that snapshot rounded to an int.

    Raises:
        TypeError: if `age` cannot be rounded.
    """
    try:
        age = round(age)
    except TypeError as err:
        # raise instead of sys.exit so a notebook kernel is not terminated
        raise TypeError(f"arg `age` must be int found {type(age)} instead") from err
    # reversed so that position 0 is the oldest snapshot, i.e. Rust's idx 1
    timepoints = np.linspace(0, years, nb_timepoints)[::-1]
    # Locate the closest snapshot by position. Looking up round(closest) in the
    # list failed whenever snapshot ages are not integers (e.g. years=70 and
    # nb_timepoints=21 give a spacing of 3.5). argmin returns the first
    # minimum, so ties keep going to the older snapshot.
    position = int(np.argmin(np.abs(timepoints - age)))
    found = position + 1
    closest_age = round(float(timepoints[position]))
    if timepoints[position] != age:
        print(f"age {age} cannot be mapped, found mapping of timepoint {found} for the closest age of {closest_age}")
    return found, closest_age


def array_of_single_cell_mutations(sfs: Dict[int, int]) -> np.ndarray:
    """Expand a histogram into one entry per cell.

    Every key of `sfs` is repeated as many times as its value, in the dict's
    iteration order, and the result is returned as an int array. The keys are
    read as mutation counts and the values as cell counts.

    NOTE(review): the entropy cells read the keys of an SFS as cell counts
    instead — confirm which convention `sfs.load_sfs` follows.
    """
    per_cell = [
        n_mutations
        for n_mutations, n_cells in sfs.items()
        for _ in range(n_cells)
    ]
    return np.array(per_cell, dtype=int)

In [None]:
# TODO: move this to other nb
# Presumably the donor ages of Mitchell et al. (2022). Each age is mapped to the
# closest simulated snapshot, and every simulated cell contributes one row
# (run id, snapshot age, requested age, value from the SFS expansion).
mitchell_ages = (0, 29, 38, 48, 63, 75, 81)
simulated = list()
for age in mitchell_ages:
    # The snapshots span YEARS years (`hsc -y $YEARS`), so the age -> timepoint
    # mapping must use YEARS. YEARS_FAST only coincides with YEARS while both
    # equal 100.
    idx_timepoint, closest_age = get_idx_timepoint_from_age(age, YEARS, NB_TIMEPOINTS)
    for idx_sim, simulation in sfs.load_sfs(PATH2SAVE, runs=RUNS, timepoint=idx_timepoint).items():
        # NOTE(review): one tuple per cell, i.e. roughly RUNS * NCELLS rows per
        # age; this can be memory-heavy
        for cell in array_of_single_cell_mutations(simulation):
            simulated.append((int(idx_sim), closest_age, age, cell))

In [None]:
# one row per simulated cell, built by the age-mapping loop
simulated = pd.DataFrame(simulated, columns=["id", "age", "mitchell_age", "single nucleotide variant"])
simulated.dtypes

In [None]:
# Fraction of all simulated cells (RUNS runs x NCELLS cells, pooled) that carry
# a given number of single nucleotide variants, per age.
# The frequency column is named explicitly. `value_counts` names its result
# "count" in pandas >= 2.0 and leaves it unnamed (column `0`) before, so the
# previous `[["age", 0]]` lookup only worked on old pandas versions.
cells_per_variants_per_age = (
    simulated.groupby(["age", "single nucleotide variant"])
    .size()
    .div(RUNS * NCELLS)
    .rename("frequency")
    .reset_index()
)
# add entry with no variants: the fraction of cells per age not accounted for
# by the rows above
no_variants = (
    1 - cells_per_variants_per_age[["age", "frequency"]].groupby("age").sum()
).reset_index()
no_variants["single nucleotide variant"] = 0
cells_per_variants_per_age = pd.concat(
    [cells_per_variants_per_age, no_variants], ignore_index=True
)
cells_per_variants_per_age["single nucleotide variant"] = cells_per_variants_per_age["single nucleotide variant"].astype(int)

fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(7, 3))
for age in cells_per_variants_per_age.age.unique():
    at_age = cells_per_variants_per_age.age == age
    ax.bar(
        x=cells_per_variants_per_age.loc[at_age, "single nucleotide variant"],
        height=cells_per_variants_per_age.loc[at_age, "frequency"],
        label=age,
        alpha=0.2,
    )
ax.legend(title="age", bbox_to_anchor=(1.15, 1), frameon=False)
ax.set_ylabel("frequency of variants")
ax.set_xlabel("single nucleotide variant")
ax.set_title(f"frequency of variants pooled from {RUNS} runs")
ax.set_yscale("log")
plt.show()

In [None]:
# Histogram of the per-cell "single nucleotide variant" values of run 1 only,
# one colour per age, with a symlog count axis.
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(7, 3))
sns.histplot(
    data=simulated[simulated.id == 1],# .sample(n=4000, replace=False),
    x="single nucleotide variant",
    hue="age",
    kde=False,
    binwidth=1,
    ax=ax,
    stat="count",
    alpha=0.5,
    palette="Dark2"
)
ax.set_yscale("symlog")
sns.move_legend(ax, bbox_to_anchor=(1.01, 1), loc="upper left", frameon=False)
plt.show()

In [None]:
# NOTE(review): leftover scratch cell — loads seaborn's example "diamonds"
# dataset, which is only displayed in a later cell and not used in the analysis
diamonds = sns.load_dataset("diamonds")

In [None]:
# box plot of the per-cell "single nucleotide variant" values at each age,
# pooled over all runs
sns.catplot(
    data=simulated,
    x="age", y="single nucleotide variant", kind="box",
)

In [None]:
# NOTE(review): displays the seaborn example dataset; unrelated to the simulations
diamonds

In [None]:
# one histogram per age of the per-cell "single nucleotide variant" values,
# pooled over all runs
for age in simulated.age.unique():
    fig, ax = plt.subplots(1, 1)
    sns.histplot(
        data=simulated.loc[simulated.age == age, ["single nucleotide variant", "age"]], 
        x="single nucleotide variant",
        ax=ax, 
        discrete=True
    )
    ax.set_title(age)
    plt.show()

In [None]:
# mean per-cell "single nucleotide variant" value for each (age, run id) pair
yo = simulated[["age", "id", "single nucleotide variant"]].groupby(["age", "id"]).mean().reset_index()
yo

In [None]:
# One figure per age, with one box per run placed at x = run id.
# `Axes.boxplot` takes the data as its first argument and the x locations via
# `positions`; it has no `y` keyword, so `ax.boxplot(x=i, y=...)` raised a
# TypeError.
# NOTE(review): `yo` holds a single mean per (age, run), so each box collapses
# to a line; the per-cell values in `simulated` may be what was intended.
for age in yo.age.unique():
    fig, ax = plt.subplots(1, 1)
    at_age = yo[yo.age == age]
    run_ids = at_age.id.unique()
    ax.boxplot(
        [at_age.loc[at_age.id == i, "single nucleotide variant"].to_numpy() for i in run_ids],
        positions=run_ids,
    )
    ax.set_title(age)
    plt.show()

## Entropy
Compute the entropy of the SFS for all patients (simulation runs).

In [None]:
# load all sfs over time
entropies_avg, entropies_std = [], []
entropies = []
for t in range(1, NB_TIMEPOINTS + 1):
    sfs_t = sfs.load_sfs(PATH2SAVE, neutral=False, timepoint=t, runs=RUNS)

    entropy = list()
    # compute the entropy for the run
    for sfs_patient in sfs_t.values():
        jcells, jmuts = np.fromiter(sfs_patient.keys(), dtype=float), np.fromiter(sfs_patient.values(), dtype=float)
        jcells /= jcells.sum()
        pi = jcells * jmuts 
        entropy.append(stats.entropy(pi))
    for i, e in enumerate(entropy):
        entropies.append((x[t - 1], e, i))
    # average and std of all runs for this timepoint
    entropies_avg.append(np.mean(entropy, axis=-1))
    entropies_std.append(np.std(entropy, axis=-1))

fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
ax.errorbar(
    x,
    entropies_avg,
    yerr=entropies_std,
    fmt="o",
    alpha=0.8,
    label=f"ABM, avg of {RUNS} runs",
)
ax.set_xlabel("time [years]")
ax.set_ylabel("avg entropy")
ax.set_title(f"avg of the entropy over {RUNS} runs")
plt.show()

In [None]:
# same plot as above but using sns api
# (per time: mean and standard deviation over runs of the entropy)
entropies_df = pd.DataFrame(entropies, columns=["time [years]", "entropy", "run"])
sns.relplot(
    data=entropies_df[["time [years]", "entropy"]],
    x="time [years]",
    y="entropy",
    kind="line",
    errorbar="sd",
    aspect=2,
    height=3,
);

In [None]:
# Entropy of each run's SFS in `sfs_all` (loaded without a timepoint,
# presumably the final snapshot, as the title states), one bar per run.
# NOTE(review): this normalises only the SFS keys, whereas the time-resolved
# entropy cell weights the normalised keys by the SFS values, and `sfs_all`
# was loaded without `neutral=False`. Confirm which definition is meant.
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)

prob_jcells, idx = list(), list()
for patient, sfs_patient in sfs_all.items():
    # A local name distinct from the notebook-level `jcells`, which the SFS
    # plot reads and which this loop used to overwrite.
    run_jcells = np.fromiter(sfs_patient.keys(), dtype=float)
    prob_jcells.append(run_jcells / run_jcells.sum())
    idx.append(patient)

for xx, pk in zip(idx, prob_jcells):
    ax.bar(xx, stats.entropy(pk), color="grey", alpha=0.8)
ax.set_ylabel("entropy")
ax.set_xlabel("run idx")
ax.set_xticks(range(0, RUNS))
ax.set_title(f"entropy for {RUNS} runs after {YEARS} years")
plt.show()

## Total variant
The total variant fraction is the fraction of cells that belong to any selected clone, i.e. anything except the wild type, averaged over all patients (simulation runs).

In [None]:
# Total variant fraction (sum over all subclones) through time: ABM mean ± std
# over runs, against the theory and simulations stored in PATH2SIMS.
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
# keep rows whose index label is <= YEARS (assumes one row per year — TODO confirm)
other_sims = pd.read_csv(PATH2SIMS).loc[:YEARS, :]

ax.set_xlabel("time [years]")
ax.set_ylabel("avg total variant fraction")
# indexed as [timepoint, run, subclone] (see the record-format cell)
variant_fraction = variant.load_variant_fractions(PATH2SAVE, NB_TIMEPOINTS, RUNS, SUBCLONES)
total_fraction = variant_fraction.sum(axis=-1)
ax.errorbar(
    x,
    total_fraction.mean(axis=-1),
    yerr=total_fraction.std(axis=-1),
    fmt="o",
    alpha=0.8,
    label=f"ABM, avg of {RUNS} runs",
)
ax.plot(other_sims.t, other_sims["Expected total variant fraction"], label="theory")
ax.plot(
    other_sims.t,
    other_sims["Average total variant fraction"],
    linestyle="--",
    label="sims, avg of ?? runs",
)
ax.legend(loc="upper left")
if SAVE:
    path2figure = PATH2SAVE / "figures"
    # Create the folder (and any missing parent) up front instead of calling
    # savefig again after catching FileNotFoundError.
    path2figure.mkdir(parents=True, exist_ok=True)
    plt.savefig(path2figure / f"total_variant{EXTENSION}")
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
# Average number of subclones with a non-zero variant fraction per timepoint:
# count over the last two axes (runs and subclones), then divide by RUNS.
is_present = variant_fraction > 0.0
clones_abm = is_present.sum(axis=(-2, -1)) / RUNS
ax.scatter(x, clones_abm, label=f"ABM, avg {RUNS} runs")
ax.plot(
    other_sims.t,
    other_sims["Average number of existing clones"],
    label="sims",
)
# ax.set_yscale("log")
ax.set_xlabel("time [years]")
ax.set_ylabel("clones")
ax.legend()
ax.set_title("avg # of clones")
fig.show()

In [None]:
# Average number of subclones above DETECTION_THRESH per timepoint (ABM vs sims).
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
clones_abm = (
    np.sum(
        np.sum(np.where(variant_fraction > DETECTION_THRESH, 1, 0), axis=-1), axis=-1
    )
    / RUNS
)
ax.scatter(x, clones_abm, label=f"ABM, avg {RUNS} runs")
# The reference column is named after its threshold. Derive it from
# DETECTION_THRESH (previously hard-coded to 0.01) so both curves always use the
# same threshold. A KeyError means PATH2SIMS has no column for that threshold.
ax.plot(
    other_sims.t,
    other_sims[f"Average number of clones above threshold {DETECTION_THRESH}"],
    label="sims",
)
# ax.set_yscale("log")
ax.set_xlabel("time [years]")
ax.set_ylabel("clones")
ax.legend()
ax.set_title(f"avg # of clones above frequency threshold of {DETECTION_THRESH}")
fig.show()

In [None]:
# record-format
# One row per (timepoint, run, clone) holding that clone's variant fraction.
# Time is x[t]; this assumes `variant_fraction` is chronological along axis 0
# — TODO confirm against `variant.load_variant_fractions`.
df = list()
for t in range(0, NB_TIMEPOINTS):
    for r in range(RUNS):
        for c in range(0, variant_fraction.shape[-1]):
            df.append((x[t], r, c, variant_fraction[t, r, c]))
df = pd.DataFrame(
    df, columns=["time [years]", "run", "clone_id", "avg tot variant fraction"]
)
df

In [None]:
# per time: mean ± sd of the variant fraction over all (run, clone) entries
# above DETECTION_THRESH
rl = sns.relplot(
    data=df.loc[
        df["avg tot variant fraction"] > DETECTION_THRESH,
        ["time [years]", "avg tot variant fraction"],
    ],
    x="time [years]",
    y="avg tot variant fraction",
    kind="line",
    errorbar="sd",
    aspect=2,
    height=3,
)
rl.fig.suptitle("tot avg fraction for detectable clones")
rl.fig.show()

In [None]:
# Total variant fraction of each run at each timepoint: the sum of its
# subclones' variant fractions, one row per (run, time).
per_run_time = df.groupby(["run", "time [years]"])["avg tot variant fraction"]
grouped = per_run_time.sum().reset_index(name="tot variant fraction")
del per_run_time
grouped

In [None]:
# Attach each run's total variant fraction to every clone row, then compute the
# effective fitness S * (clone fraction / total variant fraction). NaN results
# (e.g. 0/0 when no variant is present) are set to 0.
# NOTE(review): re-running this cell would merge `grouped` again and yield
# suffixed duplicate columns
df = df.merge(grouped, on=["run", "time [years]"], how="left", validate="many_to_one")
df["effective fitness"] = (
    S * df["avg tot variant fraction"] / df["tot variant fraction"]
).fillna(0)
df

In [None]:
# For every run with at least one clone above DETECTION_THRESH, plot:
# 1. the detected clones' variant fractions as bars,
# 2. the same fractions as lines over time,
# 3. their effective fitness over time.
for run in range(RUNS):
    detected = df.loc[
        (df["avg tot variant fraction"] > DETECTION_THRESH) & (df.run == run), :
    ]
    # Skip runs without detectable clones before creating any figure. The old
    # try/except created the figure first and left empty axes behind.
    if detected.empty:
        continue
    detected_clones = set(detected.clone_id.tolist())

    fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
    sns.barplot(
        detected,
        x="time [years]",
        y="avg tot variant fraction",
        hue="clone_id",
        ax=ax,
        palette="Dark2",
    )
    ax.set_ylabel("variant fraction")
    ax.legend(loc="center left", title="clone id")
    ax.set_title(
        f"variant fraction of clones above frequency threshold of {DETECTION_THRESH}"
    )
    plt.show()

    # times at which a clone is below the threshold show up as 0 here
    pivoted = detected.pivot(
        columns="clone_id", index="time [years]", values="avg tot variant fraction"
    ).fillna(0)

    fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
    pivoted.plot(ax=ax, color=sns.color_palette("Dark2"))
    ax.set_ylabel("variant fraction")
    ax.legend(loc="center left", title="clone id")
    ax.set_title(
        f"variant fraction of clones above frequency threshold of {DETECTION_THRESH}"
    )
    ax.set_xlim([0, YEARS])

    fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
    # effective fitness at all timepoints for the clones detected at least once
    pivoted = (
        df.loc[
            (df.run == run) & (df.clone_id.isin(detected_clones)),
            ["clone_id", "time [years]", "effective fitness"],
        ]
        .pivot(columns="clone_id", index="time [years]", values="effective fitness")
        .fillna(0)
    )
    pivoted.plot(ax=ax, color=sns.color_palette("Dark2"))
    ax.set_ylabel("effective fitness")
    ax.legend(loc="center left", title="clone id")
    ax.set_title(
        f"effective fitness of clones above frequency threshold of {DETECTION_THRESH}"
    )

    plt.show()