In [None]:
from pathlib import Path
import pyarrow.parquet as pq
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

from tqdm.notebook import tqdm
import pandas as pd

tqdm.pandas()

import pyranges as pr
import seaborn as sns
from pyslavseq.preprocessing import collate_labels
from pyslavseq.plotting import joint_ecdfplot

# Label categories in the fixed order that is passed as `hue_order` to the
# plotting calls below, so every figure lists the labels in the same order.
HUE_ORDER = ["KNRGL", "OTHER", "KRGL"]

In [None]:
def npeaks_per_cell(df: pd.DataFrame, meta: pd.DataFrame, hue=None):
    """Box-plot the number of peaks per cell, by donor, one panel per diagnosis.

    Parameters
    ----------
    df : pd.DataFrame
        Peak table with at least ``donor_id`` and ``cell_id`` columns, plus the
        ``hue`` column when ``hue`` is a peak-level attribute such as ``"label"``.
    meta : pd.DataFrame
        Metadata providing ``donor_id`` and ``diagnosis``. When it also has a
        ``cell_id`` column it is joined on both keys, so each cell matches
        exactly one metadata row.
    hue : str, optional
        Column used to split the boxes. Peak counts are only broken down by
        this column when it is given; with ``hue=None`` every cell contributes
        a single total count.

    Returns
    -------
    seaborn.FacetGrid
        The grid created by ``sns.catplot``, for further customisation.
    """
    group_keys = ["donor_id", "cell_id"]
    # Split the counts by `hue` only when it is a column of the peak table that
    # is not supplied by `meta` (metadata columns arrive through the merge).
    if hue is not None and hue in df.columns and hue not in meta.columns:
        group_keys.append(hue)
    plot_df = df.groupby(group_keys).size().reset_index(name="n_peaks")

    # Join on every shared key: merging per-cell metadata on donor_id alone
    # repeats each row once per cell of that donor and yields cell_id_x/_y.
    merge_keys = [key for key in ("donor_id", "cell_id") if key in meta.columns]
    plot_df = plot_df.merge(meta, on=merge_keys)

    g = sns.catplot(
        data=plot_df,
        y="n_peaks",
        x="donor_id",
        hue=hue,
        kind="box",
        col="diagnosis",
        # HUE_ORDER lists label categories, so it only applies to hue="label".
        hue_order=HUE_ORDER if hue == "label" else None,
        fliersize=3,
        sharex=False,
    )

    g.set_axis_labels("Donor ID", "Number of peaks")
    return g

## Read metadata

In [None]:
# Donor sheet first, then the sample sheet joined onto it; all ID columns are
# read as strings and `sample_id` is renamed to `cell_id`.
donors = pd.read_csv(
    snakemake.config["donors"],  # type: ignore
    sep="\t",
    dtype={"donor_id": str},
)
meta = (
    pd.read_csv(
        snakemake.config["samples"],  # type: ignore
        sep="\t",
        dtype=dict.fromkeys(["sample_id", "tissue_id", "donor_id"], str),
    )
    .merge(donors, on="donor_id")
    .rename(columns={"sample_id": "cell_id"})
)

## Read bulk SLAVseq peaks

In [None]:
# Bulk SLAVseq peak table. Later cells filter it by `donor_id` and `label`
# and overlap it with the single-cell peaks.
bulk = pd.read_parquet(snakemake.input.bulk)  # type: ignore

## Read single-cell SLAVseq peaks

In [None]:
# File-name suffix that follows the cell ID in each per-cell parquet file.
CELL_FILE_SUFFIX = "_labelled.pqt"

cell_frames = []
for f in tqdm(snakemake.input.cells):  # type: ignore
    df = pq.read_table(f).to_pandas()
    # Remove the suffix explicitly. str.rstrip() treats its argument as a set
    # of characters, so it would also eat trailing ID characters such as
    # "a", "b", "d", "e", "l", "p", "q", "t", "_" or ".", and the truncated
    # IDs would then silently fail to match `meta` in the merge below.
    fname = Path(f).name
    if fname.endswith(CELL_FILE_SUFFIX):
        fname = fname[: -len(CELL_FILE_SUFFIX)]
    df["cell_id"] = fname
    cell_frames.append(df)

data = pd.concat(cell_frames).reset_index(drop=True)
data["KNRGL"] = data[snakemake.params.pos_label]  # type: ignore
# Inner join: peaks from cells that are absent from the metadata are dropped.
data = data.merge(meta, on="cell_id")
# Cells whose ID contains "ush" (case-insensitive) are assigned to HIP,
# everything else to DLPFC.
data["tissue"] = (
    data["cell_id"]
    .str.lower()
    .str.contains("ush", regex=False)
    .map({True: "HIP", False: "DLPFC"})
)

ndonors = data["donor_id"].nunique()
ncells = data["cell_id"].nunique()
print(f"Loaded {len(data)} peaks from {ncells} cells from {ndonors} donors.")

## Label

In [None]:
# Derive one `label` per peak by applying the project helper `collate_labels`
# to every row (axis=1), with a tqdm progress bar.
print("Collating Labels")
data["label"] = data.progress_apply(collate_labels, axis=1)

In [None]:
# Flag peaks that overlap KRGL-labelled peaks, per donor.
def add_KRGL(d, df):
    """Add boolean ``ref`` and ``bulk_ref`` columns to one donor's peaks.

    ``ref`` is True where a peak overlaps any peak of ``df`` labelled
    ``"KRGL"``. ``bulk_ref`` is True where it overlaps a ``"KRGL"`` peak of the
    global ``bulk`` table belonging to donor ``d``. The ``label`` column itself
    is not modified here. The number of flagged peaks not already labelled
    ``"KRGL"`` is printed.

    Parameters
    ----------
    d : donor identifier compared against ``bulk["donor_id"]``.
    df : pd.DataFrame of that donor's peaks in PyRanges-compatible layout.

    Returns
    -------
    pd.DataFrame
        The PyRanges round-tripped frame with the two extra columns.
    """
    own_krgl = pr.PyRanges(df[df["label"] == "KRGL"])
    df = pr.PyRanges(df).count_overlaps(own_krgl, overlap_col="ref").df
    df["ref"] = df["ref"].gt(0)

    donor_bulk_krgl = bulk[(bulk["donor_id"] == d) & (bulk["label"] == "KRGL")]
    df = (
        pr.PyRanges(df)
        .count_overlaps(pr.PyRanges(donor_bulk_krgl), overlap_col="bulk_ref")
        .df
    )
    df["bulk_ref"] = df["bulk_ref"].gt(0)

    # Flagged by either source but not yet carrying the KRGL label.
    flagged = df["bulk_ref"] | df["ref"]
    newly_flagged = flagged & (df["label"] != "KRGL")
    print(f"Added {newly_flagged.sum()} KRGL labels to donor {d}")
    return df


# Run add_KRGL once per donor (x.name is the group's donor_id) and concatenate
# the per-donor results under a fresh integer index.
# NOTE(review): this relies on groupby.apply handing the grouping column
# (`donor_id`) to the function; recent pandas versions deprecate that
# behaviour — confirm against the pinned pandas version.
data = (
    data.groupby("donor_id")
    .progress_apply(lambda x: add_KRGL(x.name, x))
    .reset_index(drop=True)
)

In [None]:
# Recompute `label` row-wise now that the `ref` / `bulk_ref` columns exist.
# NOTE(review): presumably collate_labels reads those columns — confirm.
print("Collating Labels again")
data["label"] = data.progress_apply(collate_labels, axis=1)

## First filter

In [None]:
# Step 1: keep only peaks whose max_mapq is at least 30.
n_peaks = len(data)
data = data[data["max_mapq"] >= 30].reset_index(drop=True)
print(f"Removed {n_peaks - len(data)}/{n_peaks} peaks with MAX MAPQ < 30")

# Step 2: keep only cells that still have at least 20 peaks labelled KNRGL.
n_cells = data["cell_id"].nunique()
knrgl_per_cell = data.loc[data["label"] == "KNRGL", "cell_id"].value_counts()
passing_cells = knrgl_per_cell.index[knrgl_per_cell >= 20]
data = data[data["cell_id"].isin(passing_cells)]
print(
    f"Removed {n_cells - data['cell_id'].nunique()}/{n_cells} cells with less than 20 KNRGL peaks"
)

## Compute distance to nearest germline peak

In [None]:
print("Computing distance to nearest germline peak")


def germline_distance(cell_df: pd.DataFrame) -> pd.DataFrame:
    """Annotate one cell's peaks with the distance to the nearest germline peak.

    Germline peaks are the rows of ``cell_df`` whose ``label`` is not
    ``"OTHER"``. For each peak, the absolute distance to the nearest
    non-overlapping germline peak is stored in ``germline_dist``. Peaks for
    which PyRanges reports no nearest interval get ``NaN``.

    Parameters
    ----------
    cell_df : pd.DataFrame
        Peaks of a single cell in PyRanges-compatible layout with a ``label``
        column.

    Returns
    -------
    pd.DataFrame
        The PyRanges round-tripped frame (fresh RangeIndex) with a
        ``germline_dist`` column added or overwritten.
    """
    germline = pr.PyRanges(cell_df[cell_df["label"] != "OTHER"])
    # Round-trip through PyRanges so the returned frame has PyRanges row order.
    cell_df = pr.PyRanges(cell_df).df

    # Tag every query row before calling nearest(). PyRanges drops query rows
    # that have no nearest interval (e.g. chromosomes with no germline peak),
    # so assigning the result by position would shift distances onto the wrong
    # peaks. The tag lets us map each distance back to its own row.
    row_id = "__row_id"
    cell_df[row_id] = range(len(cell_df))

    nearest = pr.PyRanges(cell_df).nearest(germline, overlap=False).df
    if "Distance" in nearest.columns:
        # min() keeps the mapping unique even if a query row were reported twice.
        dist = nearest["Distance"].abs().groupby(nearest[row_id]).min()
        cell_df["germline_dist"] = cell_df[row_id].map(dist)
    else:
        # Empty result: no peak in this cell has a germline neighbour.
        cell_df["germline_dist"] = float("nan")

    return cell_df.drop(columns=row_id)


# Compute germline distances cell by cell, concatenate the results under a
# fresh index, then order rows by genomic position. The sort happens after
# reset_index, so the resulting index is no longer sequential.
data = (
    data.groupby("cell_id")
    .progress_apply(germline_distance)
    .reset_index(drop=True)
    .sort_values(["Chromosome", "Start"])
)

## Find peaks shared across cells/donors

In [None]:
# Assign every peak a `Cluster` id via PyRanges.cluster() and store it as a
# categorical column.
data = pr.PyRanges(data).cluster().df.astype({"Cluster": "category"})

# Number of distinct cells / donors contributing to each cluster, broadcast
# back onto the individual peaks.
cells_per_peak = data.groupby("Cluster", observed=True)["cell_id"].nunique()
donors_per_peak = data.groupby("Cluster", observed=True)["donor_id"].nunique()
data = data.assign(
    n_cells=data["Cluster"].map(cells_per_peak),
    n_donors=data["Cluster"].map(donors_per_peak),
)
data["cells_per_donor"] = data["n_cells"].div(data["n_donors"])

# NOTE(review): `n_bulk_donors` is the count of overlapping bulk intervals;
# it equals a donor count only if each donor contributes at most one
# overlapping bulk peak — confirm.
data = (
    pr.PyRanges(data)
    .count_overlaps(pr.PyRanges(bulk), overlap_col="n_bulk_donors")
    .df
)

## Visualize

In [None]:
# Box plot of peak counts by donor and diagnosis, without a hue split.
npeaks_per_cell(data, meta)

In [None]:
# Box plot of peak counts by donor and diagnosis, split by label.
npeaks_per_cell(data, meta, hue="label")
# Per-cell sum of the KNRGL column, averaged over cells.
mean_knrgls = data.groupby("cell_id", observed=True)["KNRGL"].sum().mean()
# Message fixed: the label is spelled "KNRGL" (was printed as "KRNGL").
print(f"Mean KNRGL peaks per cell: {mean_knrgls:.2f}")

In [None]:
# ECDF of the distance to the nearest germline peak, one curve per label,
# drawn on a logarithmic x axis.
g = sns.ecdfplot(data=data, x="germline_dist", hue="label", hue_order=HUE_ORDER)
g.set(xscale="log", xlabel="Distance to nearest germline peak (bp)")

In [None]:
# Plot n_cells against n_donors per label with the project's joint_ecdfplot helper.
g = joint_ecdfplot(data, x="n_cells", y="n_donors", hue="label", hue_order=HUE_ORDER)

## Label additional KNRGLs using the bulk SLAVseq data

In [None]:
# ECDF of the bulk-overlap count per label, shown before relabelling.
g = sns.ecdfplot(data=data, x="n_bulk_donors", hue="label", hue_order=HUE_ORDER)
g.set_xlabel("Number of overlaps with donor bulk peaks")

n_knrgl = (data["label"] == "KNRGL").sum()

# Promote OTHER peaks that overlap more than one bulk peak to KNRGL. The mask
# is built once, before either column is modified.
promote = (data["label"] == "OTHER") & (data["n_bulk_donors"] > 1)
data.loc[promote, "KNRGL"] = True
data.loc[promote, "label"] = "KNRGL"

n_knrgl_filtered = (data["label"] == "KNRGL").sum()

print(
    f"Labeled {n_knrgl_filtered - n_knrgl} additional KNRGL peaks from bulk SLAVseq data"
)

## Recompute germline distance

In [None]:
# Labels changed in the previous step (some OTHER peaks became KNRGL), so the
# per-cell germline distances are recomputed and `germline_dist` is overwritten.
print("Computing distance to nearest germline peak")
data = (
    data.groupby("cell_id")
    .progress_apply(germline_distance)
    .reset_index(drop=True)
    .sort_values(["Chromosome", "Start"])
)

## Remove peaks at reference L1 insertions

In [None]:
# remove reference insertions and peaks with low mapping quality
# Report how many peaks carry each reference-related annotation, then drop
# the peaks labelled KRGL. No mapping-quality filtering happens in this cell.
print("Filtering peaks")
n_with_ref_reads = (data["n_ref_reads"] > 0).sum()
print("Found {} peaks with reference reads.".format(n_with_ref_reads))
print("Found {} peaks with reference clusters.".format(data["ref"].sum()))
n_bulk_ref = data["bulk_ref"].sum()
print("Found {} peaks with bulk SLAVseq reference clusters.".format(n_bulk_ref))
print("Found {} peaks at primer sites.".format(data["primer_sites"].sum()))
for l1_col in ("l1hs", "l1pa2", "l1pa3", "l1pa4", "l1pa5", "l1pa6"):
    print("Found {} peaks at {} sites.".format(data[l1_col].sum(), l1_col))

data = data[data["label"] != "KRGL"].reset_index(drop=True)
print("{} peaks remain after filtering".format(len(data)))

In [None]:
# Box plot of peak counts by donor and diagnosis, split by label, after the
# KRGL peaks were removed.
npeaks_per_cell(data, meta, hue="label")
# Per-cell sum of the KNRGL column, averaged over cells.
mean_knrgls = data.groupby("cell_id", observed=True)["KNRGL"].sum().mean()
# Message fixed: the label is spelled "KNRGL" (was printed as "KRNGL").
print(f"Mean KNRGL peaks per cell: {mean_knrgls:.2f}")

In [None]:
# Plot n_cells against n_donors per label for the filtered table.
g = joint_ecdfplot(data, x="n_cells", y="n_donors", hue="label", hue_order=HUE_ORDER)

## Visualize after filtering

In [None]:
# Box plot of peak counts by donor and diagnosis for the filtered table,
# without a hue split.
npeaks_per_cell(data, meta)

In [None]:
# Box plot of peak counts by donor and diagnosis for the filtered table,
# split by label.
npeaks_per_cell(data, meta, hue="label")
# Per-cell sum of the KNRGL column, averaged over cells.
mean_knrgls = data.groupby("cell_id", observed=True)["KNRGL"].sum().mean()
# Message fixed: the label is spelled "KNRGL" (was printed as "KRNGL").
print(f"Mean KNRGL peaks per cell: {mean_knrgls:.2f}")

In [None]:
# ECDF of the recomputed distance to the nearest germline peak, one curve per
# label, on a logarithmic x axis.
g = sns.ecdfplot(data=data, x="germline_dist", hue="label", hue_order=HUE_ORDER)
g.set_xscale("log")
g.set_xlabel("Distance to nearest germline peak (bp)")

In [None]:
# Two joint plots for the filtered table: n_cells against n_donors, then
# n_cells against the bulk-overlap count. `g` ends up holding the second one.
g = joint_ecdfplot(data, x="n_cells", y="n_donors", hue="label", hue_order=HUE_ORDER)
g = joint_ecdfplot(
    data, x="n_cells", y="n_bulk_donors", hue="label", hue_order=HUE_ORDER
)

## Save

In [None]:
# Write the filtered, annotated peak table to the rule's first output file,
# without storing the DataFrame index.
data.to_parquet(snakemake.output[0], index=False)  # type: ignore