In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from scipy import stats
from typing import Tuple

# Figure switches: SAVE writes figures to disk; PDF picks the file format.
SAVE = False
BIGLABELS = False
if BIGLABELS:
    # smaller canvas when BIGLABELS is enabled
    FIGSIZE = [5, 3]
else:
    FIGSIZE = [6.4, 4.8]  # default matplotlib
PDF = True
if PDF:
    EXTENSION = ".pdf"
else:
    EXTENSION = ".png"

In [None]:
# Load the per-colony summary table, trying the relative path first.
# NOTE(review): the fallback is an absolute, machine-specific path; consider a
# configurable data directory instead.
try:
    summary = pd.read_csv("../Summary_cut.csv", index_col=0)
except FileNotFoundError:
    summary = pd.read_csv(
        "/mnt/c/Users/fra_t/Documents/PhD/Summary_cut.csv", index_col=0
    )
summary = (
    summary.astype({"cell_type": "category", "sample_type": "category"})
    .sort_values(by="age")
    # without drop=True the CSV index column is kept as a regular column
    .reset_index()
)
# Duplicated colonies (e.g. summary.colony_ID == "11_E07") are not removed here,
# so they also count towards `cells`.
# `cells` = number of rows per donor. `transform("size")` counts every row,
# whereas `count()` on "age" would silently skip rows whose age is missing.
summary["cells"] = summary.groupby("donor_id")["age"].transform("size")
summary.dtypes

In [None]:
# Summary statistics (count, mean, std, quartiles, ...) of the numeric columns
summary.describe()

In [None]:
# number of rows per cell type
summary["cell_type"].value_counts()

In [None]:
# number of rows per sample type
summary["sample_type"].value_counts()

In [None]:
# number of rows per timepoint
summary["timepoint"].value_counts()

In [None]:
# distinct (age, donor) pairs; used later to title the per-donor SFS plots
ages = summary.loc[:, ["age", "donor_id"]].drop_duplicates()
ages

In [None]:
# Inspect the prepared table (sorted by age, with the per-donor `cells` column)
summary

In [None]:
# per-donor sample size (`cells`) next to the donor's age
summary.loc[:, ["donor_id", "cells", "age"]].drop_duplicates()

In [None]:
# total number of mutations per donor (DataFrame indexed by donor_id)
summary.groupby("donor_id")[["number_mutations"]].sum()

In [None]:
# mutation burden of every row against age, with seaborn's linear fit
sns.regplot(
    x="age",
    y="number_mutations",
    data=summary,
    scatter_kws=dict(marker="x"),
)

In [None]:
# Average number of mutations per donor, paired with the donor's age and
# ordered by age.
mean_mutations = (
    summary.groupby("donor_id", as_index=False)[["number_mutations"]]
    .mean()
    .merge(
        summary.loc[:, ["donor_id", "age"]].drop_duplicates(),
        on="donor_id",
        how="inner",
        # raises if a donor appears with more than one age
        validate="one_to_one",
    )
    .sort_values(by="age")
)

In [None]:
# Ordinary least-squares line through the per-donor means:
# number_mutations ~= m * age + c
x = mean_mutations["age"].to_numpy()
y = mean_mutations["number_mutations"].to_numpy()
# design matrix with columns [age, 1] -> coefficients [slope, intercept]
A = np.column_stack((x, np.ones(len(x))))
fit = np.linalg.lstsq(A, y, rcond=None)
m, c = fit[0]

In [None]:
# Per-donor mean burden against age, with the least-squares line (m, c).
fig, ax = plt.subplots(1, 1, figsize=FIGSIZE)
ax.plot(x, y, "o", label="donor mean")
ax.plot(x, m * x + c, "b", linestyle="--", label="least-squares fit")
# ax.plot(x, 16*x + c, 'r', linestyle="--")
ax.set_xlabel("age [years]")
ax.set_ylabel("avg number of SNVs")
ax.set_title(f"y=mx+c with m={m:.2f}, c={c:.2f}")
ax.legend()
# `fig.show()` warns on non-GUI backends such as Jupyter's inline backend;
# use `plt.show()` like the other plotting cells.
plt.show()

In [None]:
# Histogram (+ KDE) of the mutation burden, one colour per donor.
fig, ax = plt.subplots(1, 1, figsize=FIGSIZE, tight_layout=True)
sns.histplot(
    summary,
    x="number_mutations",
    hue="donor_id",
    stat="count",
    binwidth=10,
    kde=True,
    ax=ax,
)
# put the legend outside the axes, on the right, without a frame
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1.01, 1), frameon=False)
if SAVE:
    fig.savefig(f"./mitchell_burden{EXTENSION}")
plt.show()

In [None]:
# One burden histogram per donor. Uses the shared FIGSIZE (instead of a
# hard-coded size) so the BIGLABELS switch applies here as well.
for donor in summary.donor_id.unique():
    fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=FIGSIZE)
    sns.histplot(
        data=summary[summary.donor_id == donor],
        x="number_mutations",
        hue="donor_id",  # single hue level: the legend names the donor
        kde=True,
        bins=50,
        ax=ax,
        stat="count",
    )
    if SAVE:
        fig.savefig(f"./{donor}_burden{EXTENSION}")
    plt.show()

## Entropy
Based on the code they [developed](https://github.com/emily-mitchell/normal_haematopoiesis/blob/23d221e8d125d78c1e8bcbe05d41d0f3594b0cfb/4_phylogeny_analysis/scripts/shannon_diversity.Rmd#L147), I think they define entropy as [here](http://math.bu.edu/people/mkon/J6A.pdf), using the phylogenetic tree.
Instead, we compute the entropy directly from the number of cells. A class is the set of cells carrying the same number of mutations, and we compute the abundance of each class, i.e. the number of cells that share that mutation count.

## SFS using the genotype matrix

In [None]:
# Theoretical SFS for 1/f and 1/f^2 spectra on a population of N cells, plus the
# "sampled" versions, presumably for a sample of S cells (see SamplingCorrection(N, S)).
# NOTE(review): compute_frequencies, SamplingCorrection, Correction and
# compute_variants are neither defined nor imported in the visible notebook;
# they must be brought into scope (e.g. from a helper module) for this cell to
# run on a fresh kernel.
N = 20_000
S = 100
_f = compute_frequencies(N)  # / N
correction = SamplingCorrection(N, S)
# Distinct names so the regression `y` of the linear-fit cell is not clobbered
# when cells are re-run out of order.
sampled_f, y_one_over_f = compute_variants(correction, Correction.ONE_OVER_F, S)
sampled_f_squared, y_one_over_f_squared = compute_variants(
    correction, Correction.ONE_OVER_F_SQUARED, S
)

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.plot(_f[:S], sampled_f, label="$1/f$ sampled", marker="x", alpha=0.4)
ax.plot(
    _f[:S],
    sampled_f_squared / sampled_f_squared.max(),
    label="$1/f^2$ sampled",
    marker="x",
    alpha=0.4,
)
ax.plot(_f, y_one_over_f, label="$1/f$", alpha=0.4)
ax.plot(_f, y_one_over_f_squared, label="$1/f^2$", alpha=0.4)

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_title(f"N={N}, S={S}")
ax.legend()
plt.show()

In [None]:
%%time
# Compare each donor's empirical SFS, taken from the filtered genotype matrix,
# with the sampled 1/f and 1/f^2 spectra for that donor's number of cells.
# pop size
N = 200_000
_f = compute_frequencies(N)

for donor in summary.donor_id.unique():
    # NOTE(review): this donor is excluded without an explanation; document why.
    if donor == "KX003":
        continue
    print(f"donor {donor}")
    filtered_matrix = filter_mutations(*load_patient(donor))
    # Fraction of rows whose row-sum equals j. This assumes rows are mutations
    # and columns are cells (TODO confirm).
    # NOTE(review): the normalisation includes mutations found in 0 cells,
    # which are dropped right after, so the plotted fractions sum to < 1.
    sfs = filtered_matrix.sum(axis=1).value_counts(normalize=True)
    sfs = sfs.drop(index=0, errors="ignore")
    if sfs.empty:
        # nothing to plot, and .min() below would fail on an empty array
        print(f"donor {donor}: no mutation found in any cell, skipping")
        continue
    x_sfs = sfs.index.to_numpy(dtype=int)
    y_sfs = sfs.to_numpy()
    # sample size
    cells = summary.loc[summary.donor_id == donor, "cells"].unique()[0]
    # validate the sample size before building anything that depends on it
    assert cells <= 1000, f"donor {donor}: {cells} cells exceeds the limit of 1000"
    correction = SamplingCorrection(N, cells)

    sampled_f, y = compute_variants(correction, Correction.ONE_OVER_F, cells)
    sampled_f_squared, y_squared = compute_variants(
        correction, Correction.ONE_OVER_F_SQUARED, cells
    )

    fig, ax = plt.subplots(1, 1, figsize=FIGSIZE)
    ax.plot(
        _f[:cells],
        sampled_f,
        label="$1/f$ sampled",
        alpha=0.4,
        linestyle="--",
        c="black",
    )
    ax.plot(
        _f[:cells],
        sampled_f_squared / sampled_f_squared.max(),
        label="$1/f^2$ sampled",
        alpha=0.4,
        c="black",
    )
    ax.plot(
        x_sfs, y_sfs, label=f"{donor}", linestyle="", marker="x", c="purple", alpha=0.7
    )
    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.set_xlabel("j cells")
    ax.set_ylabel("normalised nb of muts in j cells")
    ax.set_ylim([y_sfs.min() - y_sfs.min() * 0.2, 2])
    ax.set_xlim([0.8, x_sfs.max() + x_sfs.max() * 0.2])
    ax.legend()
    ax.set_title(f"age {ages[ages.donor_id == donor].age.tolist()[0]}")
    plt.show()