In [None]:
from pathlib import Path
from collections import defaultdict
import warnings, sys, logging

from tqdm.notebook import tqdm
import pandas as pd
import numpy as np


import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns
import pyranges as pr
import pickle as pkl
from pyslavseq.dataset import SLAVseqDataSet

# Register tqdm's progress-bar methods (e.g. `.progress_apply`) on pandas objects
tqdm.pandas()

# Notebook-wide logging; force=True removes handlers configured by earlier runs
logging.basicConfig(level=logging.INFO, force=True)
logger = logging.getLogger(__name__)

# pyranges emits FutureWarnings; silence that whole category
warnings.filterwarnings("ignore", category=FutureWarning)

# Legend/plot order of peak labels
HUE_ORDER = ["KNRGL", "OTHER", "KRGL"]

# --- Filtering thresholds -----------------------------------------------------
# NOTE: all three are applied as lower bounds (`>=`) by the filter cells,
# despite the MAX_ prefix on two of them.
MIN_READS = 10
MAX_MAPQ = 60  # minimum `max_mapq` a peak must reach
MAX_GERMLINE_DISTANCE = 40000  # minimum `germline_distance` a peak must reach

## Load data

1. metadata
1. bulk peaks
1. single-cell peaks

Peaks are then filtered to `n_reads >= MIN_READS` and `max_mapq >= MAX_MAPQ`.

In [None]:
# One dataset per peak-calling algorithm, keyed by the name of the directory
# that holds its regions pickle.
algorithms = [Path(f).parent.name for f in snakemake.input.regions]
if len(set(algorithms)) != len(algorithms):
    # Duplicate directory names would silently overwrite entries in `data`.
    raise ValueError(f"Algorithm names are not unique: {algorithms}")

# Plain dict: a missing algorithm should raise KeyError rather than silently
# insert an empty placeholder (as defaultdict(dict) would).
data = {}
for alg, region_file in zip(algorithms, snakemake.input.regions):
    logger.info(f"Loading {alg} data")
    # NOTE: pickle.load can execute arbitrary code; only load pipeline outputs.
    with open(region_file, "rb") as fh:
        data[alg] = pkl.load(fh)
    # The return value of `filter` is not used, so it presumably filters the
    # dataset in place — confirm against SLAVseqDataSet.
    data[alg].filter(f"n_reads >= {MIN_READS}")
    data[alg].filter(f"max_mapq >= {MAX_MAPQ}")

### Peaks per cell

In [None]:
def _peak_counts(peaks: pd.DataFrame, by: list) -> pd.DataFrame:
    """Count rows (peak calls) per `by` group; returns the `by` columns plus `n_peaks`."""
    return peaks.groupby(by).size().reset_index(name="n_peaks")


def plot_peaks_per_cell(dt: dict, show_krgl: bool = True):
    """
    Plot peaks per cell by algorithm.

    Shows four figures:
      1. histogram of peak calls per cell, one layer per algorithm
         (legend: total calls, mean calls per cell);
      2. boxplots of peak calls per cell by donor, hue = algorithm;
      3. per algorithm, histograms of peak calls per cell, one layer per label
         (legend: mean calls per cell);
      4. per algorithm, boxplots of peak calls per cell by donor, hue = label.

    dt: dict, {algorithm: pd.DataFrame}; each frame needs `cell_id`,
        `donor_id` and `label` columns; every row is counted as one peak call
    show_krgl: bool, whether to show KRGL peaks (figures 3 and 4)
    """
    n_algs = len(dt)
    hue_order = HUE_ORDER if show_krgl else ["KNRGL", "OTHER"]
    per_cell = ["cell_id", "donor_id"]
    per_cell_label = ["cell_id", "donor_id", "label"]

    # PLOT 1: peaks per cell by algorithm
    _, ax = plt.subplots(1, 1, figsize=(13, 5))
    for alg, peaks in dt.items():
        counts = _peak_counts(peaks, per_cell)
        mean = counts["n_peaks"].mean()
        # An empty frame has a NaN mean, which int() cannot convert.
        mean_str = str(int(mean)) if pd.notna(mean) else "n/a"
        legend_label = f"{alg.upper()}: {counts['n_peaks'].sum()}, {mean_str}"
        sns.histplot(data=counts, x="n_peaks", binwidth=10, ax=ax, label=legend_label)
    ax.set(title="Peaks per cell by algorithm", xlabel="# peak calls", ylabel="# cells")
    ax.legend(title="Algorithm: Total, Mean", loc="upper right")
    plt.show()

    # PLOT 2: peaks per cell by algorithm by donor
    _, ax = plt.subplots(1, 1, figsize=(13, 5))
    by_algorithm = pd.concat(
        [_peak_counts(peaks, per_cell).assign(algorithm=alg) for alg, peaks in dt.items()],
        ignore_index=True,
    )
    sns.boxplot(
        data=by_algorithm, y="n_peaks", x="donor_id", hue="algorithm", ax=ax, fliersize=1
    )
    ax.set(title="Peaks per cell by donor", ylabel="# peak calls", xlabel="donor")
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
    plt.show()

    # PLOT 3: peaks per cell by label by algorithm
    # squeeze=False always yields a 2-D array, so one algorithm needs no special case.
    _, axs = plt.subplots(n_algs, 1, figsize=(13, n_algs * 5), sharex=True, squeeze=False)
    for (alg, peaks), ax in zip(dt.items(), axs[:, 0]):
        counts = _peak_counts(peaks, per_cell_label)
        for lab in hue_order:
            subset = counts[counts["label"] == lab]
            legend_label = f"{lab}: {subset['n_peaks'].mean():.2f}"
            sns.histplot(data=subset, x="n_peaks", binwidth=2, ax=ax, label=legend_label)
        ax.legend(title="Label: Average", loc="upper right")
        ax.set(title=alg.upper(), xlabel="# peak calls", ylabel="# cells")
    plt.show()

    # PLOT 4: peaks per cell by label by algorithm by donor
    _, axs = plt.subplots(n_algs, 1, figsize=(13, n_algs * 5), sharex=True, squeeze=False)
    for (alg, peaks), ax in zip(dt.items(), axs[:, 0]):
        counts = _peak_counts(peaks, per_cell_label)
        sns.boxplot(
            data=counts,
            y="n_peaks",
            x="donor_id",
            hue="label",
            hue_order=hue_order,
            ax=ax,
            fliersize=1,
        )
        # x holds donors and y holds counts (previously labelled the other way round)
        ax.set(title=alg.upper(), xlabel="donor", ylabel="# peak calls")
        plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
    plt.show()


def _plot_ecdf_panels(dt: dict, features, linear_features, hue_order: list) -> None:
    """
    For each feature, show one figure with one complementary-ECDF panel per
    algorithm, coloured by `label`.

    dt: dict, {algorithm: pd.DataFrame}; each frame needs a `label` column
        and every column named in `features`
    features: column names to plot, one figure each, in order
    linear_features: columns drawn on a linear x-axis; all others use a log x-axis
    hue_order: labels to show, in legend order
    """
    n_algs = len(dt)
    for feature in features:
        log_scale = feature not in linear_features
        # squeeze=False keeps `axs` 2-D even when there is a single algorithm
        _, axs = plt.subplots(
            1, n_algs, figsize=(n_algs * 7, 5), sharey=True, squeeze=False
        )
        for ax, (alg, table) in zip(axs[0], dt.items()):
            sns.ecdfplot(
                data=table,
                x=feature,
                hue="label",
                hue_order=hue_order,
                complementary=True,
                ax=ax,
                log_scale=log_scale,
            )
            ax.set(title=alg.upper())
        plt.show()


# TODO plot log1p violin instead of ecdf
def plot_features(
    dt: dict,
    show_krgl: bool = True,
    features: tuple = (
        "width",
        "n_reads",
        "germline_distance",
        "n_unique_5end",
        "n_unique_3end",
        "n_unique_clipped_3end",
        "n_duplicates",
        "three_end_clippedA_q1",
    ),
    linear_features: tuple = ("three_end_clippedA_q1", "width"),
):
    """
    Plot complementary ECDFs of peak-level features, one figure per feature.

    dt: dict, {algorithm: pd.DataFrame} of per-peak tables
    show_krgl: bool, whether to show KRGL peaks
    features: columns to plot (defaults to the standard peak features)
    linear_features: columns drawn on a linear x-axis; the rest use a log x-axis
    """
    hue_order = HUE_ORDER if show_krgl else ["KNRGL", "OTHER"]
    _plot_ecdf_panels(dt, features, linear_features, hue_order)


def plot_locus_features(
    dt: dict,
    show_krgl: bool = True,
    features: tuple = (
        "width",
        "n_reads_max",
        "germline_distance",
        "n_unique_5end_max",
        "n_cells",
        "n_donors",
        "en_score",
        "three_end_clippedA_max",
    ),
    linear_features: tuple = (
        "three_end_clippedA_max",
        "width",
        "n_donors",
        "en_score",
    ),
):
    """
    Plot complementary ECDFs of locus-level features, one figure per feature.

    dt: dict, {algorithm: pd.DataFrame} of per-locus tables
    show_krgl: bool, whether to show KRGL loci
    features: columns to plot (defaults to the standard locus features)
    linear_features: columns drawn on a linear x-axis; the rest use a log x-axis
    """
    hue_order = HUE_ORDER if show_krgl else ["KNRGL", "OTHER"]
    _plot_ecdf_panels(dt, features, linear_features, hue_order)

In [None]:
# Peak-level tables (the `.data` attribute) for every algorithm
dt = {alg: dataset.data for alg, dataset in data.items()}
plot_peaks_per_cell(dt, show_krgl=False)

In [None]:
# Label composition of the greedy callset's locus table (`ad.var`)
data["greedy"].ad.var.loc[:, "label"].value_counts()

In [None]:
# ECDFs of peak-level features for every algorithm
dt = {alg: dataset.data for alg, dataset in data.items()}
plot_features(dt, show_krgl=False)

In [None]:
# ECDFs of locus-level features (`ad.var`) for every algorithm
dt = {alg: dataset.ad.var for alg, dataset in data.items()}
plot_locus_features(dt, show_krgl=False)

Compare peak widths before merging (per-cell peaks) and after merging into clusters, for each callset.

In [None]:
# merge and plot width
n_algs = len(algorithms)
g, axs = plt.subplots(n_algs, 2, figsize=(10, n_algs * 5), sharey=True, sharex="row")
plt.subplots_adjust(wspace=0.1, hspace=0.2)
axs = axs if n_algs > 1 else [axs]
for ax, a in zip(axs, data.keys()):
    sns.ecdfplot(
        data=data[a].data,
        x="width",
        hue="label",
        hue_order=HUE_ORDER,
        complementary=True,
        ax=ax[0],
    )
    ax[0].set_title(a.upper() + " (raw)")
    sns.ecdfplot(
        data=data[a].ad.var,
        x="width",
        hue="label",
        hue_order=HUE_ORDER,
        complementary=True,
        ax=ax[1],
    )
    ax[1].set_title(a.upper() + " (merged)")

In [None]:
# filter by germline distance and revisualize
for a in algorithms:
    logger.info(f"Processing {a.upper()} ...")
    data[a].filter(f"germline_distance >= {MAX_GERMLINE_DISTANCE}")

data["greedy"].ad.var["label"].value_counts()

In [None]:
# Re-plot peak-level summaries after the germline-distance filter ...
dt = {alg: dataset.data for alg, dataset in data.items()}
plot_peaks_per_cell(dt, show_krgl=False)
plot_features(dt, show_krgl=False)

# ... and locus-level feature distributions
dt = {alg: dataset.ad.var for alg, dataset in data.items()}
plot_locus_features(dt, show_krgl=False)

In [None]:
# Greedy callset label counts after filtering
data["greedy"].ad.var["label"].value_counts(dropna=True)

## Analyze peaks in cells/donors/tissues

In [None]:
# For each algorithm: non-bulk OTHER loci, how many donors/cells they appear
# in, and the subset seen in a single donor. Results are kept per algorithm.
other_loci_by_alg = {}
single_donor_loci_by_alg = {}
for a in algorithms:
    print(f"\n{a.upper()}")
    other_loci = data[a].ad.var.query("label == 'OTHER' and not bulk")
    plot = sns.jointplot(
        data=other_loci,
        x="n_donors",
        y="n_cells",
        marginal_kws=dict(discrete=True, fill=True),
        marginal_ticks=True,
        alpha=0.5,
    )
    plot.set_axis_labels("# donors", "# cells")
    plot.ax_marg_x.set_yscale("log")
    plot.ax_marg_y.set_xscale("log")
    plot.fig.set_size_inches(8, 8)
    plot.fig.suptitle(f"{a.upper()}: # donors vs # cells")
    plt.show()

    single_donor_loci = other_loci.query("n_donors == 1")
    other_loci_by_alg[a] = other_loci
    single_donor_loci_by_alg[a] = single_donor_loci

    g, ax = plt.subplots(1, 1, figsize=(9, 5))
    sns.histplot(
        data=single_donor_loci, x="n_cells", discrete=True, hue="n_tissues", ax=ax
    )
    ax.set_yscale("log")
    ax.set_title(
        f"Cells per peak: {len(single_donor_loci)} peaks present in only one donor"
    )
    ax.set_ylabel("# peaks")
    ax.set_xlabel("# cells")
    plt.show()

# The following cells index the *greedy* AnnData with these tables, so bind them
# explicitly instead of relying on whichever algorithm the loop visited last.
other_loci = other_loci_by_alg["greedy"]
single_donor_loci = single_donor_loci_by_alg["greedy"]

Analyze single-donor peaks from the greedy callset

In [None]:
# Candidate missed germline insertions: single-donor loci that are seen in
# more than N_CELLS cells and in 2 tissues.
N_CELLS = 10
cand_missed_germline = single_donor_loci.query(
    f"n_cells > {N_CELLS} and n_tissues == 2"
)
print(
    f"Found {len(cand_missed_germline)} peaks present in only 1 donor, 2 tissues, and with more than {N_CELLS} cells"
)

In [None]:
# Long-format read counts (one row per sample x cluster) for the candidate
# missed-germline clusters only. Subsetting the AnnData to those clusters
# *before* `.to_df()` avoids densifying the matrix for every OTHER locus
# (the rows are the same ones the old post-hoc filter kept).
greedy_ad = data["greedy"].ad
reads_2d_data = greedy_ad[:, cand_missed_germline.index]
reads_long_data = (
    reads_2d_data.to_df()
    .stack()
    .reset_index(name="n_reads")
    .set_index("sample_id")
    .join(greedy_ad.obs[["region", "donor_id"]])
    # keep sample_id as a column; set_index("Cluster") alone would drop it
    .reset_index()
    .set_index("Cluster")
    .join(greedy_ad.var[["label", "n_donors", "n_cells", "locus"]])
    .reset_index()
)

In [None]:
# Ten candidates present in the most cells
cand_missed_germline[["n_cells", "n_tissues", "locus"]].sort_values(
    "n_cells", ascending=False
).head(10)

In [None]:
# Cells with at least one read for an example cluster
reads_long_data.query("Cluster == '10245' and n_reads > 0")

In [None]:
# make plots jointly together
sns.catplot(
    reads_long_data,
    x="n_reads",
    y="donor_id",
    hue="region",
    col="Cluster",
    col_wrap=6,
    jitter=0.2,
    alpha=0.5,
    size=4,
    log_scale=True,
)

In [None]:
# Classify single-donor loci. The first matching rule wins:
#   n_tissues == 2                  -> "missed germline"
#   n_cells > CLONAL_MIN_CELLS      -> "clonal"
#   otherwise                       -> "subclonal"
# Built with assign() rather than piecewise .loc writes on a query() result,
# which avoids chained-assignment warnings and keeps the cell re-runnable.
CLONAL_MIN_CELLS = 10
single_donor_loci = single_donor_loci.assign(
    clonality=np.select(
        [
            single_donor_loci["n_tissues"] == 2,
            single_donor_loci["n_cells"] > CLONAL_MIN_CELLS,
        ],
        ["missed germline", "clonal"],
        default="subclonal",
    )
)
single_donor_loci["clonality"].value_counts()

In [None]:
# TODO: unfinished cell. It previously held `for f in []` with no colon and no
# body, a SyntaxError that stopped "Run All". Fill this in or delete the cell.

In [None]:
# Scatter-plot specs: (x feature, y feature, (log-scale x, log-scale y)).
# All germline_distance plots use a log x-axis; only n_reads also uses a log y-axis.
plots = [
    ("germline_distance", y_feature, (True, y_feature == "n_reads"))
    for y_feature in (
        "n_bulk_donors",
        "n_donors",
        "n_cells",
        "cells_per_donor",
        "n_donor_cells",
        "n_HIP_cells",
        "n_DLPFC_cells",
        "n_reads",
        "three_end_clippedA_mean",
        "three_end_clippedA_q0",
        "three_end_clippedA_q1",
        "alignment_score_mean",
        "alignment_score_normed_mean",
        "n_proper_pairs",
        "n_unique_5end",
        "n_unique_clipped_3end",
        "5end_gini",
    )
] + [
    (x_feature, y_feature, (False, False))
    for x_feature, y_feature in (
        ("n_cells", "n_donors"),
        ("n_cells", "n_bulk_donors"),
        ("n_cells", "n_donor_cells"),
        ("n_donors", "n_bulk_donors"),
    )
]

# NOTE: plotting loop disabled; `datashader_plot` is not defined in the cells shown here.
# for i, (x, y, s) in enumerate(plots):
#     datashader_plot(data, x, y, s, plot_width=100, plot_height=100)