In [None]:
from pathlib import Path
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

import numpy as np
from scipy import stats
import pandas as pd
import pyranges as pr
import seaborn as sns
import matplotlib.pyplot as plt
from pyslavseq.preprocessing import collate_labels

# Ordering of the peak label categories for plots that colour by `label`
# (passed to seaborn as `hue_order`).
HUE_ORDER = ["KNRGL", "OTHER", "KRGL"]

## 1. Read data

In [None]:
def parse_flagstat(file):
    """Extract total and duplicate read counts from a `samtools flagstat` report.

    Rows are located by their description ("in total", "duplicates") instead
    of by line position: flagstat reports that include extra rows such as
    "primary" shift the duplicates row away from line 4. When a description
    cannot be found, the original fixed positions (line 1 for the total,
    line 4 for duplicates) are used as a fallback.

    Parameters
    ----------
    file : str or path-like
        Path to the flagstat text report.

    Returns
    -------
    dict
        ``{"total": int, "duplicates": int}``, each taken from the first
        field of the matching row.
    """
    with open(file) as f:
        lines = [line.strip() for line in f]

    def _count(description, fallback_idx):
        # Text-format rows look like "<passed> + <failed> <description ...>".
        for line in lines:
            fields = line.split()
            if (
                len(fields) > 3
                and fields[1] == "+"
                and " ".join(fields[3:]).startswith(description)
            ):
                return int(fields[0])
        # Description not found: fall back to the historical row position.
        return int(lines[fallback_idx].split()[0])

    return {"total": _count("in total", 0), "duplicates": _count("duplicates", 3)}

In [None]:
# Donor metadata table (tab-separated). donor_id is forced to str, presumably so
# it matches the directory-name IDs assigned to the peak tables before merging.
meta = pd.read_csv(snakemake.config["donors"], sep="\t", dtype={"donor_id": str})  # type: ignore

In [None]:
bulk = sorted(snakemake.input.bulk)  # type: ignore
flagstat = sorted(snakemake.input.flagstat)  # type: ignore

# Every input file is tagged with the name of its parent directory as donor_id.
# The two input lists are loaded independently (they are only joined later on
# donor_id), so a length mismatch between them cannot silently drop files.
bulk_frames = []
for b in bulk:
    bdf = pd.read_parquet(b)
    bdf["donor_id"] = Path(b).parent.name
    bulk_frames.append(bdf)

flagstat_records = []
for f in flagstat:
    fdf = parse_flagstat(f)
    fdf["donor_id"] = Path(f).parent.name
    flagstat_records.append(fdf)

# Peaks annotated with donor metadata (inner join on donor_id).
bdata = pd.concat(bulk_frames).merge(meta, on="donor_id")

# One row of read counts per flagstat file.
fdata = pd.DataFrame(flagstat_records)
fdata["non-duplicates"] = fdata["total"] - fdata["duplicates"]

print(f"Loaded {len(bdata)} peaks from {bdata['donor_id'].nunique()} donors")
peaks_per_donor = bdata.groupby("donor_id").size()
avg_peaks = peaks_per_donor.mean()
sd_peaks = peaks_per_donor.std()
# The sample standard deviation is NaN when only one donor is loaded, and
# int(NaN) raises, so report it as "NA" in that case.
sd_text = "NA" if pd.isna(sd_peaks) else str(int(sd_peaks))
print(f"{int(avg_peaks)} ± {sd_text} peaks per donor")

## 2. Peak summary stats - unlabelled

In [None]:
# Number of peaks for each donor/metadata combination, joined with that
# donor's flagstat read counts.
data = pd.merge(
    bdata.groupby(["libd_id", "race", "diagnosis", "donor_id", "age"])
    .size()
    .reset_index(name="n_peaks"),
    fdata,
    on="donor_id",
)

In [None]:
# Donor-level variables compared against the per-donor peak counts: metadata
# columns plus the flagstat read counts.
features = ["diagnosis", "race", "age", "total", "duplicates", "non-duplicates"]
# 2 x 3 grid: one panel per entry in `features`.
g, axs = plt.subplots(2, 3, figsize=(18, 12))

# Define a function to calculate and annotate correlation coefficient and p-value
def annotate_correlation(ax, data, x, y):
    """Write the Pearson correlation between two columns onto an axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the text annotation (placed in axes coordinates).
    data : pandas.DataFrame
        Table holding both columns. It is not modified.
    x : str
        Name of a numeric column.
    y : str
        Name of the second column. Non-numeric columns (object, string,
        categorical, ...) are converted to their integer category codes
        before correlating.

    Notes
    -----
    Category codes follow the category order (sorted order for plain strings)
    and missing values receive the code -1, so the coefficient is mainly
    interpretable for two-level or ordered variables.
    """
    y_values = data[y]
    if not pd.api.types.is_numeric_dtype(y_values):
        # Encode on a local copy so the caller's column keeps its dtype.
        y_values = pd.Categorical(y_values).codes.astype(np.float64)
    correlation_coefficient, p_value = stats.pearsonr(data[x], y_values)
    ax.text(
        0.1,
        0.7,
        f"Pearson: {correlation_coefficient:.2f}\nP-value: {p_value:.2e}",
        transform=ax.transAxes,
    )


# One panel per donor-level feature, each plotted against the donor's peak count.
for panel, feature in zip(axs.flat, features):
    sns.scatterplot(data=data, x="n_peaks", y=feature, alpha=0.5, ax=panel)
    annotate_correlation(panel, data, "n_peaks", feature)
    # The feature name goes in the title, so the y-axis label is blanked.
    panel.set(ylabel="", title=feature)

In [None]:
# (peak column, draw on a log-scaled x-axis?) for each per-peak feature
features_scale = [
    ("width", False),
    ("n_reads", True),
    ("rpm", True),
    ("n_ref_reads", True),
    ("n_unique_5end", True),
    ("frac_unique_5end", False),
    ("n_duplicates", True),
    ("frac_duplicates", False),
    ("n_contigs", True),
    ("frac_contigs", False),
    ("min_mapq", False),
    ("max_mapq", False),
]

# 3 x 4 grid, flattened so panels can be paired with features in order
g, axs = plt.subplots(nrows=3, ncols=4, figsize=(24, 18))
axs = axs.flatten()

# Shared seaborn keyword arguments: colour every curve by the donor's race
colors = sns.color_palette("tab10", n_colors=2)
opts = dict(
    hue_order=["CAUC", "AA"],
    hue="race",
    palette=dict(zip(["CAUC", "AA"], colors)),
    alpha=0.5,
)

# One ECDF curve per donor (libd_id) in each feature's panel
for panel, (feature, use_log) in zip(axs, features_scale):
    for donor, donor_peaks in bdata.groupby("libd_id"):
        sns.ecdfplot(data=donor_peaks, x=feature, log_scale=use_log, ax=panel, **opts)

## 3. Peak summary stats - labelled by annotation

In [None]:
# Per-peak annotation columns (used as boolean masks further down)
labels = [
    "primer_sites",
    "l1hs",
    "l1pa2",
    "l1pa3",
    "l1pa4",
    "l1pa5",
    "l1pa6",
    "megane_gaussian",
    "megane_breakpoints",
    "graffite",
    "xtea",
]
# A peak is "other" when none of the annotation columns above is set for it
bdata["other"] = bdata[labels].sum(axis=1).eq(0)
labels += ["other"]

# Number of peaks carrying each label, per donor/metadata combination, joined
# with the donor's flagstat read counts
data = pd.merge(
    bdata.groupby(["libd_id", "race", "diagnosis", "donor_id", "age"])[labels]
    .sum()
    .reset_index(),
    fdata,
    on="donor_id",
)

In [None]:
# Grid with one row per donor-level feature and one column per label;
# y-axes are shared along rows and x-axes along columns.
g, axs = plt.subplots(
    nrows=len(features),
    ncols=len(labels),
    figsize=(6 * len(labels), 6 * len(features)),
    sharex="col",
    sharey="row",
)

# Walk the grid in row-major order: rows follow `features`, columns follow `labels`
for (row, col), panel in np.ndenumerate(axs):
    feature, label = features[row], labels[col]
    sns.scatterplot(data=data, x=label, y=feature, alpha=0.5, ax=panel)
    annotate_correlation(panel, data, label, feature)

In [None]:
# Grid with one row per per-peak feature and one column per label.
g, axs = plt.subplots(
    len(features_scale),
    len(labels),
    figsize=(len(labels) * 6, len(features_scale) * 6),
)

for i, (f, s) in enumerate(features_scale):
    for j, l in enumerate(labels):
        ax = axs[i, j]
        # The title depends only on the column, so set it once per panel
        # rather than once per donor.
        ax.set_title(l)
        # `donor` (not `g`) so the figure handle above is not overwritten.
        for donor, df in bdata.groupby("libd_id"):
            # Peaks of this donor that carry label `l`. A separate name is used
            # so the per-donor summary held in `data` is not overwritten, which
            # would break re-running the scatter-grid cell above.
            labelled = df[df[l]]
            sns.ecdfplot(data=labelled, x=f, log_scale=s, ax=ax, **opts)

In [None]:
# For every per-peak feature except n_reads itself, draw one 2D histogram of
# n_reads (x) against that feature (y) per label.
for f1, s1 in features_scale:
    if f1 == "n_reads":
        continue
    print(f"Plotting {f1}")
    # One panel per label; sized from `labels` instead of a hard-coded 12 so the
    # figure stays in sync if the label list changes. squeeze=False keeps `axs`
    # a 2D array even for a single label.
    g, axs = plt.subplots(1, len(labels), figsize=(len(labels) * 6, 6), squeeze=False)
    for l, ax in zip(labels, axs.flatten()):
        # Separate name so the per-donor summary in `data` is not overwritten.
        subset = bdata[bdata[l]]
        if s1:
            # Non-positive values cannot be placed on a log-scaled axis.
            subset = subset[subset[f1] > 0]
        sns.histplot(
            data=subset, x="n_reads", y=f1, log_scale=(True, s1), bins=50, ax=ax
        )
        ax.set_title(l)
    plt.show()
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(g)

In [None]:
# filter: keep peaks with more than 100 reads, more than 100 duplicates and a
# maximum MAPQ of exactly 60. `.copy()` makes the result an independent frame so
# the column assignment below never targets a view of the unfiltered table
# (avoids SettingWithCopyWarning).
bdata = bdata.query("n_reads > 100 and n_duplicates > 100 and max_mapq == 60").copy()

# label: collate_labels is applied to each row (axis=1) to produce one label per peak
bdata["label"] = bdata.apply(collate_labels, axis=1)

# save the filtered, labelled peaks to the first Snakemake output
bdata.to_parquet(snakemake.output[0])  # type: ignore

## 4. Annotation stats

Inspect coverage of germline calls

In [None]:
# mdata = pd.concat(mdata).merge(meta, on="donor_id")
# mdata["bulk peak"] = mdata["bulk peak"].astype(bool)
# print(f"Loaded {len(mdata)} WGS calls from {mdata['donor_id'].nunique()} donors")
# avg_wgs = mdata.groupby("donor_id").size().mean()
# sd_wgs = mdata.groupby("donor_id").size().std()
# print(f"{int(avg_wgs)} ± {int(sd_wgs)} WGS calls per donor")

# mdata = pr.PyRanges(mdata).cluster().df
# m_ndonors_call = (
#     mdata.groupby(["Cluster", "bulk peak"], observed=True)["donor_id"]
#     .nunique()
#     .reset_index(name="ndonors")
# )
# m_ncalls_donor = (
#     mdata.groupby(["donor_id", "bulk peak", "race"])
#     .size()
#     .reset_index(name="ncalls")
#     .sort_values("race")
# )

# bdata = pr.PyRanges(bdata).cluster().df
# b_ndonors_call = (
#     bdata.groupby(["Cluster", "label"], observed=True)["donor_id"]
#     .nunique()
#     .reset_index(name="ndonors")
# )
# b_ncalls_donor = (
#     bdata.groupby(["donor_id", "label", "race"])
#     .size()
#     .reset_index(name="ncalls")
#     .sort_values("race")
# )

In [None]:
# # TODO sort donors by race
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 14))

# sns.barplot(data=m_ncalls_donor, x="ncalls", y="donor_id", hue="bulk peak", ax=ax1).set(
#     xlabel="LINE1 insertions detected from WGS"
# )
# sns.ecdfplot(data=m_ndonors_call, x="ndonors", hue="bulk peak", ax=ax2).set(
#     ylabel="LINE1 insertions detected from WGS", xlabel="# donors"
# )
# # retitle legends
# ax1.legend_.set_title("Covered by Bulk SLAVseq peak")
# ax2.legend_.set_title("Covered by Bulk SLAVseq peak")


# sns.barplot(
#     data=b_ncalls_donor,
#     x="ncalls",
#     y="donor_id",
#     hue="label",
#     hue_order=HUE_ORDER,
#     ax=ax3,
# ).set(xlabel="Bulk SLAVseq peaks", xscale="log", xlim=(1, None))
# sns.ecdfplot(
#     data=b_ndonors_call, x="ndonors", hue="label", hue_order=HUE_ORDER, ax=ax4
# ).set(ylabel="Bulk SLAVseq peaks", xlabel="# donors")

# plt.tight_layout()

In [None]:
# # look for clonal insertions labelled "OTHER"
# other = bdata.query("label == 'OTHER'").copy()
# other["Cluster"] = other["Cluster"].astype("category")
# ndonors = other.groupby("Cluster", observed=True)["donor_id"].nunique()
# avg_reads = other.groupby("Cluster", observed=True)["n_reads"].mean()
# avg_n_unique_5end = other.groupby("Cluster", observed=True)["n_unique_5end"].mean()
# plot_df = pd.concat([ndonors, avg_reads, avg_n_unique_5end], axis=1).rename(
#     columns={
#         "donor_id": "n_donors",
#         "n_reads": "avg_reads",
#         "n_unique_5end": "avg_n_unique_5end",
#     }
# )

# fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))
# sns.ecdfplot(data=plot_df, x="n_donors", stat="count", ax=ax1)
# sns.scatterplot(data=plot_df, x="n_donors", y="avg_reads", alpha=0.5, ax=ax2)
# ax2.set_yscale("log")
# ax2.axhline(30, color="red", linestyle="--")
# sns.scatterplot(data=plot_df, x="n_donors", y="avg_n_unique_5end", alpha=0.5, ax=ax3)
# ax3.set_yscale("log")

## 5. PCA

In [None]:
# from sklearn.preprocessing import StandardScaler
# from sklearn.decomposition import PCA


# features = snakemake.config["features"]  # type: ignore


# scaler = StandardScaler()
# pca = PCA(n_components=2)

# X = bdata[features].values
# X = scaler.fit_transform(X)
# X = pca.fit_transform(X)

# bdata["PC1"] = X[:, 0]
# bdata["PC2"] = X[:, 1]

# bdata["log_n_reads"] = np.log10(bdata["n_reads"])
# bdata["log_rpm"] = np.log10(bdata["rpm"])

# hues = [
#     "label",
#     "log_n_reads",
#     "max_mapq",
#     "min_mapq",
#     "frac_unique_3end",
#     "frac_unique_5end",
# ]

# fig, axes = plt.subplots(
#     1, len(hues), figsize=(5 * len(hues), 5), sharey=True, sharex=True
# )
# plt.subplots_adjust(wspace=0)

# for ax, hue in zip(axes, hues):
#     sns.scatterplot(data=bdata, x="PC1", y="PC2", hue=hue, s=3, alpha=0.7, ax=ax)

In [None]:
# # get PCA loadings
# loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
# pd.DataFrame(loadings, columns=["PC1", "PC2"], index=features)