In [None]:
from IPython.display import Markdown as md
from pathlib import Path
import numpy as np
import pandas as pd
import pegasus as pg
import matplotlib.pyplot as plt
import seaborn as sns
import os, sys

# Make the workflow's `src` package importable. The workflow directory is looked
# up one level above the working directory for test runs, and directly inside
# the working directory otherwise.
workflow_dir = "../workflow" if snakemake.config["istest"] else "workflow"
sys.path.append(os.path.abspath(workflow_dir))

from src.plot_utils import pretty_table

In [None]:
# Report heading, named after the `soloFeatures` wildcard of the current job.
report_title = f"# Filter report for {snakemake.wildcards.soloFeatures}"
md(report_title)

## Read in data

In [None]:
# Re-shape the Snakemake inputs into one DataFrame per tool, with a `run`
# column (name of an ancestor directory of the file) and a `file` column (the
# path exactly as Snakemake provided it).


def _ancestor_name(path, levels):
    """Return the name of the directory `levels` steps above `path`."""
    for _ in range(levels):
        path = path.parent
    return path.name


# Number of parent directories to climb from an input file to reach the
# directory whose name is used as the run identifier, per tool.
run_dir_levels = {"IRescue": 2, "STARsolo": 3, "CellBender": 3, "Demuxlet": 1}

samples = {}
for i in snakemake.input.keys():
    files = snakemake.input[i]
    # Snakemake hands back a single string (not a list) when a named input
    # holds exactly one file; wrap it so the loop below walks over paths
    # instead of over the characters of that one path.
    if isinstance(files, str):
        files = [files]

    table = {"run": [], "file": []}
    levels = run_dir_levels.get(i)
    # Tools without a known directory layout still get an (empty) table.
    if levels is not None:
        for f in files:
            table["run"].append(_ancestor_name(Path(f), levels))
            table["file"].append(f)
    samples[i] = pd.DataFrame(table)

In [None]:
# Describe every count matrix for Pegasus: the run it belongs to, its location
# on disk, and the modality label it is loaded under.
modality_by_tool = {"IRescue": "te", "STARsolo": "rna", "CellBender": "rna"}

matrix_spec = {"Sample": [], "Location": [], "Modality": []}
for tool in snakemake.input.keys():
    modality = modality_by_tool.get(tool)
    if modality is None:
        # Tools without a modality mapping contribute no rows.
        continue
    for row in samples[tool].itertuples():
        matrix_spec["Sample"].append(row.run)
        matrix_spec["Location"].append(row.file)
        matrix_spec["Modality"].append(modality)

data = pg.aggregate_matrices(matrix_spec, default_ref="GRCh38", mito_prefix="MT-")

In [None]:
# Attach Demuxlet's per-barcode calls to the cell metadata when a "Demuxlet"
# input was provided.
def _read_demuxlet(run, path):
    """Load one Demuxlet table, indexed by barcode prefixed with `run` and "-"."""
    return (
        pd.read_csv(path, sep="\t")
        # Prefix barcodes with the run name before they become the index.
        .assign(BARCODE=lambda table: run + "-" + table["BARCODE"])
        .loc[
            :,
            ["BARCODE", "DROPLET.TYPE", "SNG.BEST.GUESS", "DBL.BEST.GUESS", "BEST.LLK"],
        ]
        .set_index("BARCODE")
        .rename(
            columns={
                "DROPLET.TYPE": "assignment",
                "SNG.BEST.GUESS": "Demuxlet Singlet Best Guess",
                "DBL.BEST.GUESS": "Demuxlet Doublet Best Guess",
                "BEST.LLK": "Demuxlet Best Log Likelihood",
            }
        )
    )


if "Demuxlet" in snakemake.input.keys():
    demux = pd.concat(
        [_read_demuxlet(s.run, s.file) for s in samples["Demuxlet"].itertuples()]
    )
    # Index-on-index left join: cells without a Demuxlet row get NaN in the new columns.
    data.obs = data.obs.join(demux)

## Calculate QC metrics

In [None]:
# Infer doublets using the "Channel" attribute, then mark them on `data`.
pg.infer_doublets(data, channel_attr="Channel", plot_hist=None)
pg.mark_doublets(data)

# QC thresholds are read from the `preprocess` section of the workflow config.
preprocess_cfg = snakemake.config["preprocess"]
qc_kwargs = {
    "select_singlets": True,
    "min_genes": preprocess_cfg["min_genes"],
    "min_umis": preprocess_cfg["min_counts"],
    "percent_mito": preprocess_cfg["max_pct_mt"],
}
# Only pass `subset_string` when an `assignment` column exists in the cell metadata.
if "assignment" in data.obs.columns:
    qc_kwargs["subset_string"] = "SNG"
pg.qc_metrics(data, **qc_kwargs)

# log1p-transformed copies of the per-cell gene and count totals.
for source_col, log_col in (
    ("n_genes", "log1p_n_genes"),
    ("n_counts", "log1p_n_counts"),
):
    data.obs[log_col] = np.log1p(data.obs[source_col])

## Visualize

In [None]:
# Violin plot of each QC metric per channel, split by whether the cell passed QC.
# Each metric maps to the `preprocess` config key holding its cutoff; None means
# no cutoff is configured for that metric, so no threshold line is drawn.
violin_cutoff_keys = {
    "log1p_n_counts": "min_counts",
    "log1p_n_genes": "min_genes",
    "percent_mito": "max_pct_mt",
    "doublet_score": None,
}
# Metrics plotted on a log1p scale: their configured cutoffs are given on the
# raw scale and must be transformed to line up with the axis.
log1p_metrics = {"log1p_n_counts", "log1p_n_genes"}

for n, c in violin_cutoff_keys.items():
    assert n in data.obs.columns, f"{n} not found in data.obs.columns"

    # Explicit figure/axes instead of relying on pyplot's "current figure".
    fig, ax = plt.subplots()
    sns.violinplot(
        data=data.obs,
        x="Channel",
        y=n,
        hue="passed_qc",
        split=True,
        inner=None,
        ax=ax,
    )
    if c is not None:
        cutoff = snakemake.config["preprocess"][c]
        if n in log1p_metrics:
            cutoff = np.log1p(cutoff)
        ax.axhline(cutoff, linestyle="dashed", color="red")
    ax.set_title(f"{n} by channel")
    sns.despine(fig=fig)
    plt.show()
    # Release the figure so repeated plotting does not accumulate open figures.
    plt.close(fig)

# Scatter of genes vs. counts per cell, one panel per channel, coloured in turn
# by each QC annotation. The dashed red lines mark the configured minimum
# gene and count thresholds.
# (Named `hue_columns` rather than `vars` to avoid shadowing the builtin.)
hue_columns = ["percent_mito", "doublet_score"]
if "assignment" in data.obs.columns:
    hue_columns.append("assignment")
hue_columns.append("passed_qc")

for n in hue_columns:
    # relplot always creates its own figure, so no plt.clf() beforehand:
    # clearing here would only create an extra, empty figure.
    grid = sns.relplot(
        data=data.obs,
        x="n_counts",
        y="n_genes",
        hue=n,
        col="Channel",
        kind="scatter",
        alpha=0.5,
    )
    grid.set(xscale="log", yscale="log")
    # `.flat` visits every facet regardless of how the grid is laid out.
    for ax in grid.axes.flat:
        ax.axhline(
            snakemake.config["preprocess"]["min_genes"],
            linestyle="dashed",
            color="red",
        )
        ax.axvline(
            snakemake.config["preprocess"]["min_counts"],
            linestyle="dashed",
            color="red",
        )
    sns.despine()
    plt.show()
    # Release the grid's figure so the loop does not accumulate open figures.
    plt.close(grid.figure)

## Filter stats

In [None]:
# Filtering statistics reported by Pegasus, rendered as a styled table using
# seaborn's "flare" colormap.
flare_cmap = sns.color_palette("flare", as_cmap=True)
df_qc = pg.get_filter_stats(data)
pretty_table(df_qc, cmap=flare_cmap)

## Do the filtering

In [None]:
# The return value is not captured, so this call is relied on to modify `data`
# in place.
# NOTE(review): presumably drops the cells not flagged `passed_qc` by
# pg.qc_metrics above — confirm against the Pegasus documentation.
pg.filter_data(data)

## Save

In [None]:
# Write `data` to the first output path declared for this Snakemake job.
pg.write_output(data, snakemake.output[0])