# Descriptive analysis of Mitchell's paper
We plot the mutational burden and the site frequency spectrum (SFS) for the data published in [Mitchell et al., Nature 2022](https://www.nature.com/articles/s41586-022-04786-y).
There are no simulations here.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import socket
import seaborn as sns
from pathlib import Path

from hscpy import mitchell, realisation
from hscpy.figures import sfs as sfs_figures
from hscpy.figures import PlotOptions, simulations

from futils import parse_version, snapshot

# Population size passed as `pop_size` when the sampling correction of the
# SFS is computed further down.
NCELLS = 100_000

# Output settings shared by the figures of this notebook.
SAVEFIG = True
BIGLABELS = False
if BIGLABELS:
    FIGSIZE = [5, 3]
else:
    FIGSIZE = [6.4, 4.8]  # default matplotlib
EXTENSION = ".svg"
PLOT_OPTIONS = PlotOptions(save=SAVEFIG, extension=EXTENSION, figsize=FIGSIZE)

In [None]:
# Directory holding Mitchell's data: two known machines have a dedicated
# location, every other host falls back to the user's home directory.
known_hosts = {
    "5X9ZYD3": "/mnt/c/Users/terenz01/Documents/SwitchDrive/PhD/hsc",
    "LAPTOP-CEKCHJ4C": "/mnt/c/Users/fra_t/Documents/PhD/hsc",
}
hostname = socket.gethostname()
if hostname in known_hosts:
    PATH2MITCHELL = Path(known_hosts[hostname])
else:
    PATH2MITCHELL = Path("~").expanduser()

## Mitchell's data

In [None]:
# Load Mitchell's summary table; the loader is asked to drop donor KX007.
summary_csv = PATH2MITCHELL / "Summary_cut.csv"
summary = mitchell.load_and_process_mitchell(summary_csv, drop_donor_KX007=True)
# Show the column types as the cell output.
summary.dtypes

In [None]:
# Textual overview of the summary table: global statistics, then the counts of
# the categorical columns, then one line per donor.
print(summary.describe())

cell_type_counts = summary.cell_type.value_counts()
print(f"\n\ncell types: \n{cell_type_counts}")

sample_type_counts = summary.sample_type.value_counts()
print(f"\n\nsample types: \n{sample_type_counts}")

timepoint_counts = summary.timepoint.value_counts()
print(f"\n\ntimepoints: \n{timepoint_counts}")

donor_overview = summary[["donor_id", "cells", "age"]].drop_duplicates()
print(f"\n\nages and cells: \n{donor_overview}")

mutations_per_donor = (
    summary[["donor_id", "number_mutations"]].groupby("donor_id").sum()
)
print(f"\n\nmutations per donor: \n{mutations_per_donor}")

In [None]:
# One histogram of `number_mutations` per donor, each on its own figure.
for donor_id in summary.donor_id.unique():
    donor_rows = summary[summary.donor_id == donor_id]
    fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(6, 4))
    sns.histplot(
        data=donor_rows,
        x="number_mutations",
        hue="donor_id",
        stat="count",
        bins=50,
        kde=True,
        ax=ax,
    )
    if PLOT_OPTIONS.save:
        plt.savefig(f"./{donor_id}_burden{EXTENSION}")
    plt.show()

In [None]:
# Per-donor summary statistics of `number_mutations`, restricted to the rows
# with age 0.
is_age_zero = summary.age == 0
age_zero_burden = summary.loc[is_age_zero, ["donor_id", "number_mutations"]]
descr = age_zero_burden.groupby("donor_id").describe()
descr

In [None]:
# Histogram of `number_mutations` for all donors on one axis, coloured by donor.
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
histogram_options = {
    "x": "number_mutations",
    "hue": "donor_id",
    "stat": "percent",
    "binwidth": 10,
    "kde": True,
}
sns.histplot(data=summary, ax=ax, **histogram_options)
# Compact two-column legend in the upper right corner, without a frame.
sns.move_legend(ax, loc="upper right", ncol=2, frameon=False, fontsize="small")
if PLOT_OPTIONS.save:
    plt.savefig(f"./mitchell_burden{EXTENSION}")
plt.show()

In [None]:
# Average of the per-donor mean burdens (age-0 donors only), divided by
# 2 * ln(200_000 - 2).
# NOTE(review): 200_000 is hard-coded while NCELLS is 100_000 -- confirm which
# population size is intended here.
mean_burden_age_zero = descr[("number_mutations", "mean")].mean()
mean_burden_age_zero / (2 * np.log(200_000 - 2))

In [None]:
# Per-donor variance of the burden, obtained by squaring the standard deviation.
descr[("number_mutations", "std")].pow(2)

## Plot the SFS with the correction for sampling

In [None]:
%%time
# Sampling correction for the sims' SFS, based on the sampled distributions of
# https://www.biorxiv.org/content/10.1101/2022.11.07.515470v2
# One entry per donor, processed from the oldest to the youngest.
donors_oldest_first = (
    summary[["donor_id", "age", "cells"]]
    .drop_duplicates()
    .sort_values(by="age", ascending=False)
)
corrected_variants_one_over_1_squared = {}
for row in donors_oldest_first.itertuples():
    print(
        f"apply sampling correction to SFS of donor {row.donor_id} with age {row.age}"
    )
    # the sample size is the donor's `cells` value
    corrected = realisation.compute_variants(
        realisation.Correction.ONE_OVER_F_SQUARED,
        pop_size=NCELLS,
        sample_size=row.cells,
    )
    corrected_variants_one_over_1_squared[row.donor_id] = corrected

In [None]:
%%time
# compute SFS from Mitchell's data
names_mitchell = summary.donor_id.unique()
ages_mitchell = summary.age.unique().tolist()
# Read every (donor, age) pair straight from the table. Zipping
# `[0] + ages_mitchell` with `names_mitchell` is only correct when the two
# donors of age 0 happen to be the first rows of the table; any other row
# order silently assigns the wrong age to every following donor.
donor_ages = summary[["donor_id", "age"]].drop_duplicates()
if donor_ages.donor_id.duplicated().any():
    raise ValueError("expected exactly one age per donor in Mitchell's summary")
# Keys are donor ids, in order of first appearance in `summary`.
mitchell_sfs = {
    donor: mitchell.sfs_donor_mitchell(donor, age, PATH2MITCHELL, remove_indels=False)
    for donor, age in zip(donor_ages.donor_id.tolist(), donor_ages.age.tolist())
}

In [None]:
# One figure per donor: the SFS from Mitchell's data (purple crosses) on top of
# the sampled 1/f^2 correction computed for that donor (grey line).
# The (donor, age) pairs are read from the table itself: zipping
# `[0] + unique ages` with the unique donors is only correct when the two
# donors of age 0 are the first rows of the table.
donor_ages = summary[["donor_id", "age"]].drop_duplicates()
for name, age in zip(donor_ages.donor_id.tolist(), donor_ages.age.tolist()):
    # normalise the x axis with the donor's `cells` value as sample size
    normalisation_x = sfs_figures.ToCellFrequency(
        sample_size=summary.loc[summary["donor_id"] == name, "cells"]
        .drop_duplicates()
        .squeeze(),
        to_one=True,
    )
    fig, ax = plt.subplots(1, 1, layout="constrained", figsize=PLOT_OPTIONS.figsize)
    sfs_figures.plot_sfs_correction(
        ax,
        corrected_variants_one_over_1_squared[name],
        normalise=True,
        options=PLOT_OPTIONS,
        normalise_x=normalisation_x,
        linestyle="-",
        color="grey",
        label=r"$1/f^2$ sampled",
        linewidth=2,
    )

    # fourth element of the per-donor result is what gets plotted
    sfs_figures.plot_sfs(
        ax,
        mitchell_sfs[name][3],
        normalise=True,
        normalise_x=normalisation_x,
        options=PLOT_OPTIONS,
        color="purple",
        mew=2,
        linestyle="",
        marker="x",
        label=f"{name}, {age} years",
    )

    ax.legend(
        fontsize="medium",
        loc="upper right",
        frameon=False,
    )

    if PLOT_OPTIONS.save:
        fig.savefig(f"./sfs_age{age}_{name}{PLOT_OPTIONS.extension}")
    # `plt.show()` as in the other cells: `fig.show()` only warns on non-GUI
    # backends.
    plt.show()

In [None]:
def plot_sfs_multiple_data(names):
    """Overlay the SFS of several donors on a single axis.

    For every donor, the SFS stored in `mitchell_sfs` is drawn with purple
    markers and the matching entry of `corrected_variants_one_over_1_squared`
    with a grey line. The x axis of both is normalised with the donor's
    `cells` value taken from `summary`.

    Args:
        names: iterable of donor ids; each must be a key of `mitchell_sfs` and
            of `corrected_variants_one_over_1_squared`, and appear in
            `summary["donor_id"]`.

    Returns:
        The `(fig, ax)` pair holding the plot.
    """
    from itertools import cycle

    fig, ax = plt.subplots(1, 1, layout="constrained", figsize=PLOT_OPTIONS.figsize)
    # Cycle through the styles so that donors beyond the third are still
    # plotted instead of being silently dropped by `zip`.
    line_styles = cycle(("-", "--", "-."))
    markers = cycle(("x", "o", "."))
    for name, ls, mk in zip(names, line_styles, markers):
        normalisation_x = sfs_figures.ToCellFrequency(
            sample_size=summary.loc[summary["donor_id"] == name, "cells"]
            .drop_duplicates()
            .squeeze(),
            to_one=True,
        )

        sfs_figures.plot_sfs(
            ax,
            mitchell_sfs[name][3],
            normalise=True,
            normalise_x=normalisation_x,
            options=PLOT_OPTIONS,
            color="purple",
            mew=2,
            linestyle="",
            marker=mk,
            # label each series with its own donor, not a fixed "CB001"
            label=name,
        )

        sfs_figures.plot_sfs_correction(
            ax,
            corrected_variants_one_over_1_squared[name],
            normalise=True,
            options=PLOT_OPTIONS,
            normalise_x=normalisation_x,
            linestyle=ls,
            color="grey",
            label=r"$1/f^2$ sampled",
            linewidth=2,
        )
    return fig, ax

# Overlay the SFS of three donors on one figure.
names = ["KX008", "KX004", "KX003"]
fig, ax = plot_sfs_multiple_data(names)
# ax.set_ylim([10**(-4), 1.3])
# ax.legend()
# Write the file only when saving is enabled, like every other figure of this
# notebook.
if PLOT_OPTIONS.save:
    fig.savefig("old_sfs.png")
plt.show()