# Remove empty droplets, lowly expressed genes, and low-quality and dead cells; mark multiplets

## Load data

In [None]:
from pathlib import Path
from mudata import MuData
import scanpy as sc
import scanpy.external as sce

# Load gene counts from whichever upstream quantifier was configured:
# CellBender yields a 10x-style .h5 file; STARsolo yields a 10x-style MTX
# directory (the parent directory of the declared input file is read).
if "CellBender" in snakemake.input.keys():
    adata_gene = sc.read_10x_h5(snakemake.input["CellBender"], var_names="gene_symbols")
elif "STARsolo" in snakemake.input.keys():
    adata_gene = sc.read_10x_mtx(
        Path(snakemake.input["STARsolo"]).parent, var_names="gene_symbols"
    )
else:
    # Without this, adata_gene would be unbound and the script would fail
    # further down with an opaque NameError.
    raise ValueError(
        "Expected a 'CellBender' or 'STARsolo' input for gene counts, got: "
        f"{list(snakemake.input.keys())}"
    )

# Gene symbols are not guaranteed unique; deduplicate before building the
# MuData so its global var index is derived from unique names.
adata_gene.var_names_make_unique()

# Load TE counts if available and wrap both modalities in one MuData.
if "IRescue" in snakemake.input.keys():
    adata_te = sc.read_10x_mtx(Path(snakemake.input["IRescue"][0]).parent)
    # NOTE(review): assumes read_10x_mtx produced exactly two var columns
    # (ids + feature types) for the IRescue output — confirm if its format changes.
    adata_te.var.columns = ["te_symbols", "feature_types"]
    adata_te.var_names_make_unique()
    mdata = MuData({"gene": adata_gene, "te": adata_te})
else:
    mdata = MuData({"gene": adata_gene})

## Define plotting function

In [2]:
def qc_plot(adata, min_genes, max_pct_mt, min_counts, kind="gene", **kwargs):
    """Draw QC distributions with the filtering cut-offs overlaid.

    Panels for ``kind="gene"`` (left to right): n_genes_by_counts violin,
    total_counts violin, pct_counts_mt violin, total_counts vs pct_counts_mt
    scatter, total_counts vs n_genes_by_counts scatter.
    Panels for ``kind="te"``: n_tes_by_counts violin, total_counts violin,
    total_counts vs n_tes_by_counts scatter.

    Parameters
    ----------
    adata : AnnData
        Must already carry the ``obs`` columns written by
        ``sc.pp.calculate_qc_metrics`` for the chosen ``kind``.
    min_genes : number
        Lower gene-count cut-off (blue dashed line); unused for ``"te"``.
    max_pct_mt : number
        Upper mitochondrial-percentage cut-off (red dashed line); unused for
        ``"te"``.
    min_counts : number
        Lower total-count cut-off (blue dashed line).
    kind : {"gene", "te"}
        Which panel layout and obs columns to use.
    **kwargs
        Accepted and ignored, so a whole config section (including unrelated
        keys such as ``activate``) can be splatted in.

    Raises
    ------
    AssertionError
        If ``kind`` is invalid or ``total_counts`` is missing from ``adata.obs``.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    assert kind in ["gene", "te"], "kind must be gene or te"
    assert (
        "total_counts" in adata.obs.columns
    ), "total_counts not found in adata.obs.columns"

    def _distribution(ax, column, cutoff=None):
        # Violin of one obs column with every cell jittered on top; `cutoff`
        # is an optional (value, colour) pair drawn as a dashed horizontal line.
        values = adata.obs[column]
        sns.violinplot(y=values, ax=ax, inner=None)
        strip_ax = sns.stripplot(y=values, ax=ax, color="black", alpha=0.5, size=2)
        if cutoff is not None:
            level, colour = cutoff
            strip_ax.axhline(level, linestyle="dashed", color=colour)

    def _versus_total_counts(ax, column):
        # Scatter of `column` against total_counts drawn into `ax`.
        return sc.pl.scatter(adata, x="total_counts", y=column, show=False, ax=ax)

    if kind == "gene":
        _, panels = plt.subplots(
            1, 5, figsize=(12, 3), width_ratios=[0.3, 0.3, 0.3, 1, 1]
        )
        genes_ax, counts_ax, mt_ax, mt_scatter_ax, genes_scatter_ax = panels

        _distribution(genes_ax, "n_genes_by_counts", (min_genes, "blue"))
        _distribution(mt_ax, "pct_counts_mt", (max_pct_mt, "red"))

        drawn = _versus_total_counts(mt_scatter_ax, "pct_counts_mt")
        drawn.axhline(max_pct_mt, linestyle="dashed", color="red")
        drawn.axvline(min_counts, linestyle="dashed", color="blue")

        drawn = _versus_total_counts(genes_scatter_ax, "n_genes_by_counts")
        drawn.axhline(min_genes, linestyle="dashed", color="blue")
        drawn.axvline(min_counts, linestyle="dashed", color="blue")

    elif kind == "te":
        _, panels = plt.subplots(1, 3, figsize=(8, 3), width_ratios=[0.5, 0.5, 1])
        tes_ax, counts_ax, tes_scatter_ax = panels

        # No TE-specific cut-off is passed in, so this violin has no line.
        _distribution(tes_ax, "n_tes_by_counts")

        drawn = _versus_total_counts(tes_scatter_ax, "n_tes_by_counts")
        drawn.axvline(min_counts, linestyle="dashed", color="blue")

    # The total_counts violin sits in the second panel for both layouts.
    _distribution(counts_ax, "total_counts", (min_counts, "blue"))

    plt.subplots_adjust(wspace=0.5)

## 1. Remove empty droplets and lowly expressed genes

In [None]:
# Tab-separated "<step>\t<count>" log, one record per line, of how many
# cells/features each filtering step removes. Closed in the "Save output" cell.
stats = open(snakemake.output["stats"], "w", encoding="utf-8")

# Remove empty droplets: cells with fewer than `min_genes` detected genes.
before = mdata["gene"].n_obs
sc.pp.filter_cells(
    mdata["gene"], min_genes=snakemake.config["preprocess"]["empty"]["min_genes"]
)
after = mdata["gene"].n_obs
# write() with an explicit newline: writelines() on a bare string adds no
# separator, which previously glued every stats record onto a single line.
stats.write(f"Empty droplets removed\t{before - after}\n")

# Remove genes detected in fewer than `min_cells` cells.
before = mdata["gene"].n_vars
sc.pp.filter_genes(
    mdata["gene"], min_cells=snakemake.config["preprocess"]["lowexp"]["min_cells"]
)
after = mdata["gene"].n_vars
stats.write(f"Lowly expressed genes removed \t{before - after}\n")

# Flag mitochondrial genes as 'mt'. The match is case-insensitive so both the
# human ("MT-") and mouse/zebrafish ("mt-") symbol conventions are caught.
mdata["gene"].var["mt"] = mdata["gene"].var_names.str.upper().str.startswith("MT-")

# Per-cell QC metrics (n_genes_by_counts, total_counts, pct_counts_mt, ...)
# stored in mdata["gene"].obs.
sc.pp.calculate_qc_metrics(
    mdata["gene"],
    qc_vars=["mt"],
    percent_top=None,
    log1p=False,
    inplace=True,
    var_type="genes",
)

# Visualize the distributions against the low-quality thresholds.
qc_plot(mdata["gene"], **snakemake.config["preprocess"]["low_quality"], kind="gene")

### Explore transposable element (TE) expression

In [None]:
if "IRescue" in snakemake.input.keys():
    # Drop TEs detected in fewer than `min_cells` cells (same threshold as genes).
    before = mdata["te"].n_vars
    sc.pp.filter_genes(
        mdata["te"], min_cells=snakemake.config["preprocess"]["lowexp"]["min_cells"]
    )
    after = mdata["te"].n_vars
    # Newline-terminated record: writelines() on a bare string adds none.
    stats.write(f"Lowly expressed TEs removed \t{before - after}\n")

    # Per-cell TE QC metrics; var_type="tes" yields columns such as
    # n_tes_by_counts alongside total_counts.
    sc.pp.calculate_qc_metrics(
        mdata["te"], percent_top=None, log1p=False, inplace=True, var_type="tes"
    )

    # Visualize; only min_counts is drawn for the TE layout.
    qc_plot(mdata["te"], **snakemake.config["preprocess"]["low_quality"], kind="te")

## 2. Remove low-quality and dead cells

In [None]:
# Sync the global mdata.obs with the modalities; this is what exposes the
# "gene:"-prefixed QC columns used by the filters below.
mdata.update()
low_quality = snakemake.config["preprocess"]["low_quality"]

# Only run if enabled in the config.
if low_quality["activate"]:
    # (obs column, cells-to-keep predicate, stats label). Bounds are inclusive,
    # matching the labels and scanpy's min_genes convention: only cells strictly
    # below a minimum or strictly above a maximum are removed. Cells with NaN
    # (e.g. barcodes absent from the gene modality) compare False and are dropped.
    cell_filters = [
        (
            "gene:n_genes_by_counts",
            lambda values: values >= low_quality["min_genes"],
            "Cells below min_genes removed \t",
        ),
        (
            "gene:total_counts",
            lambda values: values >= low_quality["min_counts"],
            "Cells below min_counts removed \t",
        ),
        (
            "gene:pct_counts_mt",
            lambda values: values <= low_quality["max_pct_mt"],
            "Cells above max_pct_mt removed \t",
        ),
    ]

    # Applied sequentially: each count is the number of cells that filter
    # removed from those that survived the previous ones.
    for column, keep, label in cell_filters:
        before = mdata["gene"].n_obs
        mdata = mdata[keep(mdata.obs[column]), :].copy()
        after = mdata["gene"].n_obs
        stats.write(f"{label}{before - after}\n")

    # Visualize the surviving cells again.
    qc_plot(mdata["gene"], **low_quality, kind="gene")

    if "IRescue" in snakemake.input.keys():
        qc_plot(mdata["te"], **low_quality, kind="te")

## 3. Mark multiplets

In [None]:
# Tiny inputs (the bundled test data) get only 3 principal components,
# presumably because PCA cannot extract 30 from < 100 cells or genes;
# real datasets use 30.
if min(mdata["gene"].shape) < 100:
    nPCs = 3
else:
    nPCs = 30

# Score every cell for being a multiplet with Scrublet.
sce.pp.scrublet(
    mdata["gene"],
    n_prin_comps=nPCs,
    expected_doublet_rate=snakemake.params.expected_multiplet_rate,
)

# Plot the simulated vs observed doublet-score distributions.
sce.pl.scrublet_score_distribution(mdata["gene"], save=False)

## Save output

In [None]:
# Sync the global MuData obs/var with all in-place modality edits above.
mdata.update()

# Every stats record has been written by now; close (and flush) the file first
# so it is complete even if validation or the write below fails.
stats.close()

# Check the real suffix (a substring test would accept e.g. "h5mu_dir/x.h5ad")
# and raise rather than assert so the check survives `python -O`.
if not str(snakemake.output[0]).endswith(".h5mu"):
    raise ValueError("Output file must be an h5mu file")

mdata.write(snakemake.output[0])