# Filter dataset

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import pegasus as pg
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

In [None]:
# Load the run sheet and collect every input matrix that belongs to one of its runs
runsheet = pd.read_csv(snakemake.config["runsheet"], sep="\t")

# Resolve each input file once: (path, name of the run directory, modality).
# IRescue outputs sit two directory levels below the run directory and are
# tagged "te"; every other input sits three levels below and is tagged "rna".
resolved_inputs = []
for input_name in snakemake.input.keys():
    for location in snakemake.input[input_name]:
        if input_name == "IRescue":
            run_dir = Path(location).parent.parent.name
            modality = "te"
        else:
            run_dir = Path(location).parent.parent.parent.name
            modality = "rna"
        resolved_inputs.append((location, run_dir, modality))

# Build the sample table grouped by run, in run sheet order
d = {"Sample": [], "Location": [], "Modality": []}
for run_id in runsheet["run_id"].unique():
    for location, run_dir, modality in resolved_inputs:
        if run_id == run_dir:
            d["Sample"].append(run_id)
            d["Location"].append(location)
            d["Modality"].append(modality)

data = pg.aggregate_matrices(d, default_ref="GRCh38", mito_prefix="MT-")

In [None]:
# dead cells: compute QC metrics using the thresholds from the workflow config
qc_cfg = snakemake.config["preprocess"]
pg.qc_metrics(
    data,
    min_genes=qc_cfg["min_genes"],
    min_umis=qc_cfg["min_counts"],
    percent_mito=qc_cfg["max_pct_mt"],
)

# log1p-transformed copies of the count metrics, used by the QC violin plots
n_genes = data.obs["n_genes"]
data.obs["log1p_n_genes"] = np.log1p(n_genes)
n_counts = data.obs["n_counts"]
data.obs["log1p_n_counts"] = np.log1p(n_counts)

# doublets: score per channel (no histogram plots), then annotate the calls
pg.infer_doublets(data, channel_attr="Channel", plot_hist=None)
pg.mark_doublets(data)

In [None]:
def qc_plot(df):
    """Plot QC metrics for filtering, with the configured thresholds overlaid.

    Two groups of figures are shown:

    1. One seaborn violin plot per metric (``log1p_n_counts``,
       ``log1p_n_genes``, ``percent_mito``, ``doublet_score``), with one
       violin per ``Channel`` split by ``passed_qc``. Metrics that have a
       threshold in ``snakemake.config["preprocess"]`` get a dashed red line
       at that threshold (log1p-transformed for the log1p metrics).
    2. One plotly scatter of ``n_counts`` vs ``n_genes`` (log-log axes,
       faceted by ``Channel``) coloured in turn by ``percent_mito``,
       ``doublet_score`` and ``passed_qc``, with dashed red lines at the
       ``min_genes`` and ``min_counts`` thresholds.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing the columns ``Channel``, ``passed_qc``, ``n_counts``,
        ``n_genes`` and the four metric columns listed above.

    Raises
    ------
    AssertionError
        If one of the four violin-plot metric columns is missing from ``df``.
    """
    thresholds = snakemake.config["preprocess"]

    # metric column -> config key of its threshold (None: nothing to draw)
    violin_metrics = {
        "log1p_n_counts": "min_counts",
        "log1p_n_genes": "min_genes",
        "percent_mito": "max_pct_mt",
        "doublet_score": None,
    }

    for metric, cfg_key in violin_metrics.items():
        assert metric in df.columns, f"{metric} not found in df.columns"

        plt.clf()
        # Pass the frame by keyword: older seaborn releases interpret the
        # first positional argument as `x`, not `data`.
        ax = sns.violinplot(
            data=df,
            x="Channel",
            y=metric,
            hue="passed_qc",
            dodge=True,
            inner="point",
        )
        if cfg_key is not None:
            cutoff = thresholds[cfg_key]
            if metric.startswith("log1p_"):
                # The metric is plotted on a log1p scale, so the threshold
                # must be transformed the same way.
                cutoff = np.log1p(cutoff)
            ax.axhline(cutoff, linestyle="dashed", color="red")
        sns.despine()
        plt.show()

    # No plt.clf() in this loop: these are plotly figures, and clearing the
    # matplotlib state here would only create an empty matplotlib figure.
    for colour_by in ["percent_mito", "doublet_score", "passed_qc"]:
        fig = px.scatter(
            df,
            x="n_counts",
            y="n_genes",
            color=colour_by,
            facet_col="Channel",
            log_x=True,
            log_y=True,
            # plotly express calls marker transparency `opacity`; `alpha` is
            # not an accepted keyword and raises TypeError.
            opacity=0.5,
        )
        fig.add_hline(
            thresholds["min_genes"],
            line_dash="dash",
            line_color="red",
        )
        fig.add_vline(
            thresholds["min_counts"],
            line_dash="dash",
            line_color="red",
        )
        fig.show()


# Render the QC figures from data.obs, the table that received the
# log1p_n_genes / log1p_n_counts columns earlier in this notebook.
qc_plot(data.obs)

In [None]:
# Recompute the QC flags with singlet selection enabled, using the same
# thresholds as before, then apply the filter to the data in place.
final_qc_cfg = snakemake.config["preprocess"]
pg.qc_metrics(
    data,
    select_singlets=True,
    min_genes=final_qc_cfg["min_genes"],
    min_umis=final_qc_cfg["min_counts"],
    percent_mito=final_qc_cfg["max_pct_mt"],
)
pg.filter_data(data)

In [None]:
# Write the filtered data to the first output declared by the Snakemake rule
output_path = snakemake.output[0]
pg.write_output(data, output_path)