In [None]:
from IPython.display import Markdown as md
from pathlib import Path
import numpy as np
import pandas as pd
import pegasus as pg
import matplotlib.pyplot as plt
import seaborn as sns
import os, sys

sys.path.append((os.path.abspath("../workflow")))
from src.plot_utils import pretty_table

# set custom tempdir to handle parallel writing of tempfiles
# Each soloFeatures run gets its own subdirectory of the inherited TMPDIR so
# that jobs running side by side do not write temp files into the same place.
tmp = os.environ["TMPDIR"]
feature_tmpdir = os.path.join(tmp, snakemake.wildcards.soloFeatures)
# The directory has to exist before it is advertised through TMPDIR: Python's
# tempfile module silently skips an unusable TMPDIR and falls back to another
# location, which would defeat the per-feature isolation.
os.makedirs(feature_tmpdir, exist_ok=True)
os.environ["TMPDIR"] = feature_tmpdir

In [None]:
# Render the report heading as Markdown, labelled with this run's soloFeatures wildcard.
report_title = f"# Filter report for {snakemake.wildcards.soloFeatures}"
md(report_title)

## Read in data

In [None]:
# Load the runsheet, then gather every input matrix that belongs to one of its runs.
runsheet = pd.read_csv(snakemake.config["runsheet"], sep="\t")

# Resolve each input file once into (path, sample directory name, modality).
# The sample name is read from the directory tree: two levels above the file
# for the "IRescue" input, three levels above it for every other input.
candidates = []
for input_name in snakemake.input.keys():
    for matrix_path in snakemake.input[input_name]:
        matrix_file = Path(matrix_path)
        if input_name == "IRescue":
            candidates.append((matrix_path, matrix_file.parent.parent.name, "te"))
        else:
            candidates.append(
                (matrix_path, matrix_file.parent.parent.parent.name, "rna")
            )

# Build the sample table run by run so entries stay grouped by run_id; files
# whose directory name matches no run_id are left out.
d = {"Sample": [], "Location": [], "Modality": []}
for run_id in runsheet["run_id"].unique():
    for matrix_path, sample_dir, modality in candidates:
        if sample_dir != run_id:
            continue
        d["Sample"].append(run_id)
        d["Location"].append(matrix_path)
        d["Modality"].append(modality)

data = pg.aggregate_matrices(d, default_ref="GRCh38", mito_prefix="MT-")

## Calculate QC metrics

In [None]:
# Per-cell QC flags, using the thresholds from the workflow's "preprocess" config.
preprocess_cfg = snakemake.config["preprocess"]
pg.qc_metrics(
    data,
    select_singlets=True,
    min_genes=preprocess_cfg["min_genes"],
    min_umis=preprocess_cfg["min_counts"],
    percent_mito=preprocess_cfg["max_pct_mt"],
)

# Log-transformed copies of the gene/UMI totals, used by the violin plots.
for log_col, raw_col in {
    "log1p_n_genes": "n_genes",
    "log1p_n_counts": "n_counts",
}.items():
    data.obs[log_col] = np.log1p(data.obs[raw_col])

# Doublet inference per channel (no histogram plots), then annotate the calls.
pg.infer_doublets(data, channel_attr="Channel", plot_hist=None)
pg.mark_doublets(data)

## Visualize

In [None]:
qc_cfg = snakemake.config["preprocess"]

# Violin plots: one figure per QC metric, split by whether the cell passed QC.
# Each metric maps to the threshold drawn as a dashed red line (None = no line).
# Count/gene thresholds are log1p-transformed to match the plotted columns.
violin_thresholds = {
    "log1p_n_counts": np.log1p(qc_cfg["min_counts"]),
    "log1p_n_genes": np.log1p(qc_cfg["min_genes"]),
    "percent_mito": qc_cfg["max_pct_mt"],
    "doublet_score": None,
}

for metric, threshold in violin_thresholds.items():
    assert metric in data.obs.columns, f"{metric} not found in data.obs.columns"

    # Explicit figure/axes instead of the pyplot state machine, so each
    # iteration draws on its own fresh figure.
    fig, ax = plt.subplots()
    sns.violinplot(
        data.obs,
        x="Channel",
        y=metric,
        hue="passed_qc",
        split=True,
        inner=None,
        ax=ax,
    )
    if threshold is not None:
        ax.axhline(threshold, linestyle="dashed", color="red")
    sns.despine(fig=fig)
    plt.show()
    plt.close(fig)

# Scatter plots of genes vs. counts, one panel per channel, coloured by each
# annotation in turn. relplot creates its own figure, so no plt.clf() here:
# calling it first would open an extra empty figure that plt.show() renders.
for hue_col in ["percent_mito", "doublet_score", "passed_qc"]:
    grid = sns.relplot(
        data.obs,
        x="n_counts",
        y="n_genes",
        hue=hue_col,
        col="Channel",
        kind="scatter",
        alpha=0.5,
    )
    grid.set(xscale="log", yscale="log")
    # .flat visits every panel regardless of how the facet grid is laid out.
    for ax in grid.axes.flat:
        ax.axhline(qc_cfg["min_genes"], linestyle="dashed", color="red")
        ax.axvline(qc_cfg["min_counts"], linestyle="dashed", color="red")
    grid.despine()
    plt.show()
    plt.close(grid.figure)

## Filter stats

In [None]:
# Show the filter statistics pegasus reports for `data` as a table styled
# with seaborn's "flare" colormap.
flare_cmap = sns.color_palette("flare", as_cmap=True)
df_qc = pg.get_filter_stats(data)
pretty_table(df_qc, cmap=flare_cmap)

## Do the filtering

In [None]:
# Apply the QC filter to `data`; the return value is not used, so this
# presumably modifies `data` in place — TODO confirm against the pegasus docs.
# NOTE(review): the QC cell calls pg.qc_metrics before pg.mark_doublets. If the
# filter relies only on the flags computed by qc_metrics, cells annotated as
# doublets afterwards may be kept here — confirm this is intended.
pg.filter_data(data)

## Save

In [None]:
# Write the filtered data to the first output path declared by the Snakemake rule.
pg.write_output(data, snakemake.output[0])