In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before code is executed and edits to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import json
import csv
import numpy as np
import pandas as pd
from pathlib import Path

from futils import snapshot

from hscpy import burden, variant
from hscpy.figures import burden as burden_fig
from hscpy.figures import options

In [None]:
# --- Simulation output location and time grid ---------------------------------
PATH2SAVE = Path("../hsc/test/")
NB_TIMEPOINTS = 10
YEARS = 70
SUBCLONES = 60

# --- Figure appearance and saving ---------------------------------------------
SAVE = True
BIGLABELS = False
if BIGLABELS:
    FIGSIZE = [5, 3]
else:
    FIGSIZE = [6.4, 4.8]  # default matplotlib
PDF = False
if PDF:
    EXTENSION = ".pdf"
else:
    EXTENSION = ".png"

# Bundle of the plotting settings above, handed to the plotting code below.
PLOT_OPTIONS = options.PlotOptions(save=SAVE, figsize=FIGSIZE, extension=EXTENSION)

# --- Population ---------------------------------------------------------------
RUNS = 1
# Disabled sample sizes, kept for reference:
# SAMPLE_STRONG = int(summary.cells.mean())
# SAMPLE_WEAK = SAMPLE_STRONG * 10  # TODO * 100
NCELLS = 200_000

B0 = 1  # TODO: double check this, should be between 2 and 20?
NEUTRAL_RATE = 20  # [mut/(year * cell)]

## FIT CLONES
MU0 = 2
S = 0
# MU0 rescaled by B0 and by the population size.
# NOTE(review): presumably a per-cell, per-division rate — confirm.
u = MU0 / (B0 * NCELLS)

In [None]:
# Simulation options built from the configuration constants above.
sim_options_population = options.SimulationOptions(
    runs=RUNS,
    cells=NCELLS,
    sample=NCELLS,  # the sample size is set equal to the number of cells
    path2save=PATH2SAVE,
    neutral_rate=NEUTRAL_RATE,
    nb_timepoints=NB_TIMEPOINTS,
    last_timepoint_years=YEARS,
    nb_subclones=SUBCLONES,
    s=S,
)

In [None]:
# Ages associated with the saved timepoints, ordered from first (0.0) to last
# (70.0). The values are written out literally because they are interpolated
# as-is into output file names further down.
# NOTE(review): looks like NB_TIMEPOINTS evenly spaced values in [0, YEARS]
# stored at single precision — confirm against the simulator output.
ages = [
    0.0, 7.7777777, 15.555555, 23.333332, 31.11111,
    38.88889, 46.666668, 54.444447, 62.222225, 70.0,
]

In [None]:
# Show the burden plots for id "0". The first entry of `ages` (0.0) is left
# out of the ages passed to the plotting helper.
burden_fig.show_burden_plots(
    sim_options_population, PLOT_OPTIONS, ages=ages[1:], id2plot="0", verbosity=False
)

In [None]:
# Draw one 20-bin histogram per file found in the "rates" directory of the
# simulation output. Each file is read as CSV and every non-empty field is
# parsed as a float.
rates_dir = sim_options_population.path2save / "rates"
# Sorted so the figures come out in the same order on every run
# (iterdir() yields entries in arbitrary order).
for rates_path in sorted(rates_dir.iterdir()):
    # newline="" is how the csv module expects its input files to be opened.
    with open(rates_path, "r", newline="") as rates_file:
        rates = [
            float(field)
            for row in csv.reader(rates_file)
            for field in row
            if field  # skip empty fields
        ]
    fig, ax = plt.subplots(1, 1)
    ax.hist(rates, bins=20)
    # Label each histogram with its source file so the figures can be told apart.
    ax.set_title(rates_path.name)
    plt.show()
    # Release the figure explicitly so figures do not pile up across iterations.
    plt.close(fig)

In [None]:
# Plot the site frequency spectrum (SFS) stored for each timepoint.
#
# Timepoint directories are numbered from 1 and are paired with the ages in
# reverse order, i.e. directory 1 goes with the last age.
# NOTE(review): this assumes the simulator numbers its timepoints backwards in
# time — confirm against the simulator output layout.
for t, age in enumerate(reversed(ages), 1):
    sfs_dir = (
        sim_options_population.path2save
        / f"{sim_options_population.sample}cells/sfs/{t}"
    )
    sfs = []
    # Sorted so that "the first SFS" below is the same file on every run
    # (iterdir() yields entries in arbitrary order).
    for sfs_path in sorted(sfs_dir.iterdir()):
        try:
            with open(sfs_path, "r") as sfs_file:
                sfs.append(
                    snapshot.Histogram(
                        {int(x): int(y) for x, y in json.load(sfs_file).items()}
                    )
                )
        except json.JSONDecodeError as e:
            print(f"Error in opening {sfs_path} {e}")
            # Re-raise the decoding error: this stops the cell and keeps the
            # original traceback (the previous `sys.exit` call relied on a
            # module that is never imported in this notebook).
            raise
    if not sfs:
        raise FileNotFoundError(f"No SFS file found in {sfs_dir}")

    # NOTE: only the first SFS is plotted; pooling over runs with
    # snapshot.Uniformise.pooled_distribution(sfs) is currently disabled, and
    # the counts are not normalised.
    avg = sfs[0]

    fig, ax = plt.subplots(1, 1)
    ax.scatter(list(avg.keys()), list(avg.values()))
    ax.set_title(f"SFS at age {age:.1f} avg over {sim_options_population.runs} runs")
    ax.set_xlabel("nb of cells")
    ax.set_ylabel("density of variants")
    ax.set_xscale("log")
    ax.set_yscale("log")
    if PLOT_OPTIONS.save:
        fig.savefig(
            sim_options_population.path2save / f"sfs_{age}{PLOT_OPTIONS.extension}"
        )
    plt.show()
    # Release the figure explicitly so figures do not pile up across iterations.
    plt.close(fig)

In [None]:
# Load the variant fractions produced by the simulation. The cells below index
# the result as variants[timepoint, run, clone].
variants = variant.load_variant_fractions(
    sim_options_population.path2save,
    sim_options_population.nb_timepoints,
    sim_options_population.sample,
    sim_options_population.runs,
    sim_options_population.nb_subclones,
)

In [None]:
# Reference total variant fractions over time, read from an external CSV.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere until it is made configurable.
NATE_VARIANTS_CSV = (
    "/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc/totalVariantFracTime.csv"
)
variants_nate = pd.read_csv(NATE_VARIANTS_CSV)
variants_nate

In [None]:
# Compare the total variant fraction over time from the simulation ("ABM")
# with the two reference curves stored in `variants_nate`.
fig, ax = plt.subplots(1, 1)

# Summing over the last axis adds up the fractions of all clones; `variants`
# is indexed as [timepoint, run, clone] elsewhere in this notebook.
total_abm = variants.sum(axis=-1)
ax.plot(ages, total_abm, label="ABM")

ax.plot(
    variants_nate.t, variants_nate["Expected total variant fraction"], label="theory"
)
ax.plot(variants_nate.t, variants_nate["Average total variant fraction"], label="SDE")

# Axis labels so the figure can be read on its own.
# NOTE(review): assumes the `t` column of the CSV is in years, like `ages` — confirm.
ax.set_xlabel("time [years]")
ax.set_ylabel("total variant fraction")
ax.legend()
plt.show()

In [None]:
# record-format: one row per (timepoint, run, clone), holding the age rounded
# to whole years, the run index, the clone index and the variant fraction.
records = [
    (int(round(ages[t], 0)), r, c, variants[t, r, c])
    for t in range(sim_options_population.nb_timepoints)
    for r in range(sim_options_population.runs)
    for c in range(variants.shape[-1])
]
df = pd.DataFrame(
    records, columns=["time [years]", "run", "clone_id", "variant fraction"]
)

In [None]:
import seaborn as sns

In [None]:
# Shape of the per-run total variant fraction from the fourth timepoint on.
variants[3:].sum(axis=-1).shape

In [None]:
# Bar plot of the variant fraction of every detected clone at each timepoint,
# overlaid with the total variant fraction (crosses).
fig, ax = plt.subplots(1, 1)

# A clone counts as detected when its variant fraction exceeds this value.
min_fraction = 0.001
detected_all = df.loc[df["variant fraction"] > min_fraction, :]

# Timepoints (rounded ages) that have at least one detected clone. They are
# passed as the explicit category order so that every run is drawn on the same
# x positions and the overlay below can be aligned with the bars.
shown_times = sorted(int(year) for year in detected_all["time [years]"].unique())

for run in range(sim_options_population.runs):
    detected = detected_all.loc[detected_all.run == run, :]
    sns.barplot(
        detected,
        x="time [years]",
        y="variant fraction",
        hue="clone_id",
        order=shown_times,
        ax=ax,
        palette="Dark2",
    )

# Overlay the total variant fraction at exactly the timepoints that have bars.
# The bars sit at positions 0..len(shown_times)-1, so the totals are selected
# by timepoint index instead of relying on a fixed offset into the time axis.
# The rounding is the same one used to fill the "time [years]" column of `df`.
# NOTE(review): assumes the rounded ages are unique — confirm if the time grid
# becomes finer than one year.
rounded_ages = [int(round(age, 0)) for age in ages]
shown_idx = [rounded_ages.index(year) for year in shown_times]
ax.plot(variants.sum(axis=-1)[shown_idx], linestyle="", marker="x", mew=2)

sns.move_legend(ax, loc=2)
plt.xticks(fontsize=18)
plt.yticks(fontsize=22)
ax.xaxis.label.set_size(22)
ax.yaxis.label.set_size(22)
ax.minorticks_on()
# Minor ticks make no sense on the categorical x axis.
ax.xaxis.set_tick_params(which="minor", bottom=False)
plt.tight_layout()
if PLOT_OPTIONS.save:
    # NOTE(review): the file name says "run0" although every run is drawn.
    plt.savefig(
        sim_options_population.path2save / f"variant_run0{PLOT_OPTIONS.extension}"
    )
plt.show()