In [None]:
from IPython.display import Markdown as md
from pathlib import Path
import tempfile
import pandas as pd
import numpy as np
import seaborn as sns
import pegasus as pg
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import colors
import logging

# Raise the threshold of the "pegasus" logger to WARNING so that
# lower-severity messages from that logger are suppressed.
logger = logging.getLogger("pegasus")
logger.setLevel(logging.WARNING)

In [None]:
# Render the report title as Markdown, naming the `soloFeatures` wildcard.
# NOTE(review): `snakemake` is not defined in this notebook — presumably it is
# injected by Snakemake when the notebook is executed; confirm.
md(f"# STARsolo report for {snakemake.wildcards.soloFeatures}")

In [None]:
# Load the tab-separated run sheet located at the "runsheet" config entry.
# Later cells iterate over the unique values of its "run_id" column.
runsheet = pd.read_csv(snakemake.config["runsheet"], sep="\t")

## STARsolo metrics

In [None]:
def read_metrics(fn: str) -> pd.DataFrame:
    """Read and clean a STARsolo metrics file, return a pandas dataframe.

    The file is parsed as a header-less CSV whose first column holds the
    metric name; the table is then transposed so that every metric becomes
    a column.

    Metric names are normalised so that feature-specific labels line up:
    names containing ``GeneFull`` have it replaced by ``Genes``; all other
    names have ``Gene`` replaced by ``Genes``.

    Parameters
    ----------
    fn : str
        Path to the metrics CSV file.

    Returns
    -------
    pandas.DataFrame
        One column per (renamed) metric. The count-like metrics listed in
        ``int_columns`` below are cast to ``int``; all other columns keep
        the dtype pandas inferred.

    Raises
    ------
    KeyError
        If one of the expected count-like metrics is absent after renaming.
    """

    def _normalise(name: str) -> str:
        # Mirror the two-way rule: "GeneFull" wins when present, otherwise
        # plain "Gene" is pluralised.
        if "GeneFull" in name:
            return name.replace("GeneFull", "Genes")
        return name.replace("Gene", "Genes")

    int_columns = [
        "Estimated Number of Cells",
        "Unique Reads in Cells Mapped to Genes",
        "Number of Reads",
        "UMIs in Cells",
        "Total Genes Detected",
        "Median UMI per Cell",
        "Median Genes per Cell",
        "Median Reads per Cell",
    ]

    metrics = (
        pd.read_csv(fn, header=None)
        .set_index(0)
        .transpose()
        # Rename every column in one pass instead of mutating the frame in
        # place once per column.
        .rename(columns=_normalise)
    )

    return metrics.astype({column: int for column in int_columns})

In [None]:
# Collect the summary metrics of every run: a summary file belongs to a run
# when the name of the file's grandparent directory equals the run_id.
summ_rows = []
for r in runsheet["run_id"].unique():
    for f in snakemake.input["summary"]:
        if Path(f).parent.parent.name != r:
            continue
        # Append inside the match branch: a run without a summary file must
        # not re-append the frame left over from a previous run.
        summ_rows.append(read_metrics(f).assign(run_id=r))

if not summ_rows:
    raise ValueError("No summary file matched any run_id in the runsheet")

summ = pd.concat(summ_rows).set_index(["run_id"])

In [None]:
# Colormap used to shade the summary table.
cm = sns.color_palette("flare", as_cmap=True)


def background_gradient(s, cmap="PuBu"):
    """Return one CSS ``background-color`` declaration per value of ``s``.

    Values are scaled with a power-law norm (exponent 2) starting at 0.
    The upper bound is 1 when every value lies strictly between 0 and 1,
    otherwise it is the maximum of ``s``. The scaled values are then mapped
    through ``cmap`` and converted to hex colours.

    Parameters
    ----------
    s : pandas.Series
        Values to colour (``min``, ``max`` and ``values`` are used).
    cmap : str or Colormap
        Colormap handed to ``plt.colormaps.get_cmap``.

    Returns
    -------
    list of str
        ``"background-color: <hex>"`` strings, in the order of ``s``.
    """
    within_unit_interval = s.min() > 0 and s.max() < 1
    upper = 1 if within_unit_interval else s.max()
    scaled = colors.PowerNorm(2, vmin=0, vmax=upper)(s.values)
    rgba_rows = plt.colormaps.get_cmap(cmap)(scaled)
    return [f"background-color: {colors.rgb2hex(rgba)}" for rgba in rgba_rows]


# Show the summary table with two-decimal formatting; the gradient is handed
# to Styler.apply, so each column gets its own normalisation.
summ.style.format(precision=2).apply(background_gradient, cmap=cm)

In [None]:
def _aggregate_count_matrices(kind):
    """Aggregate the count matrices listed under ``snakemake.input[kind]``.

    A temporary "Sample,Location" CSV is written that pairs every run_id of
    the run sheet with each input path whose great-grandparent directory is
    named after that run; the CSV is then handed to ``pg.aggregate_matrices``.
    """
    lines = ["Sample,Location\n"]
    for run_id in runsheet["run_id"].unique():
        lines.extend(
            f"{run_id},{path}\n"
            for path in snakemake.input[kind]
            if run_id == Path(path).parent.parent.parent.name
        )

    with tempfile.NamedTemporaryFile(mode="w") as handle:
        handle.writelines(lines)
        # Rewinding flushes the buffered text, so the rows are on disk
        # before the file is re-opened by name.
        handle.seek(0)
        return pg.aggregate_matrices(handle.name)


# make CSVs for raw and filtered 10x runs
raw = _aggregate_count_matrices("raw")
filtered = _aggregate_count_matrices("filtered")

# Per-barcode table for the barcode rank plot: only the raw matrix's
# observation table is kept, the matrix itself is released.
df = raw.obs
del raw
df = (
    df.loc[df["n_counts"] > 0, :]  # remove zeros
    .sort_values(["Channel", "n_counts"], ascending=False)
    .assign(
        # barcodes absent from the filtered matrix are flagged as empty
        isEmpty=lambda d: ~d.index.isin(filtered.obs.index),
        # rank within each channel, highest count first, ties by order
        rank=lambda d: d.groupby("Channel")["n_counts"].rank(
            method="first", ascending=False
        ),
    )
)

In [None]:
# Barcode rank plot: n_counts against within-channel rank on log-log axes,
# one line group per Channel, coloured by the isEmpty flag.
barcode_rank_opts = dict(
    x="rank",
    y="n_counts",
    line_group="Channel",
    color="isEmpty",
    color_discrete_sequence=["purple", "gray"],
    log_x=True,
    log_y=True,
    width=800,
    height=600,
    title="Barcode Rank Plot",
)
px.line(df, **barcode_rank_opts)

## Cell QC: Total Counts, Total Genes, Percent MT, and doublet scores

In [None]:
# Compute QC metrics on the filtered matrix, passing "MT-" as the
# mitochondrial gene prefix.
pg.qc_metrics(filtered, mito_prefix="MT-")

# Add log1p-transformed copies of the count and gene totals for plotting.
for metric in ("n_counts", "n_genes"):
    filtered.obs[f"log1p_{metric}"] = np.log1p(filtered.obs[metric])

# From here on `df` refers to the filtered observation table.
df = filtered.obs

In [None]:
# Reshape the three QC columns to long form (one row per barcode and metric)
# so that each metric can be drawn in its own row of violin plots.
qc_columns = ["log1p_n_counts", "log1p_n_genes", "percent_mito"]
plot_df = pd.melt(
    df,
    id_vars=["Channel"],
    value_vars=qc_columns,
    var_name="metric",
)

# One violin per Channel; rows do not share a y axis.
sns.catplot(
    data=plot_df,
    kind="violin",
    x="Channel",
    y="value",
    row="metric",
    sharey=False,
)

In [None]:
# n_genes against n_counts, one facet per Channel, wrapping after at most
# four facet columns.
facet_wrap = min(df["Channel"].nunique(), 4)
px.scatter(
    df,
    x="n_counts",
    y="n_genes",
    facet_col="Channel",
    facet_col_wrap=facet_wrap,
)

In [None]:
# log1p_n_genes against log1p_n_counts, one facet per Channel, wrapping
# after at most four facet columns.
facet_wrap = min(df["Channel"].nunique(), 4)
px.scatter(
    df,
    x="log1p_n_counts",
    y="log1p_n_genes",
    facet_col="Channel",
    facet_col_wrap=facet_wrap,
)

In [None]:
# handle test data: use 3 principal components instead of 30 when the
# filtered matrix has fewer than 100 rows
nPCs = 3 if filtered.shape[0] < 100 else 30

# infer doublets using scrublet, separately for each Channel
pg.infer_doublets(filtered, channel_attr="Channel", n_prin_comps=nPCs)

# Discard any figures still open after the call above; the saved PNGs are
# displayed explicitly below.
# NOTE(review): presumably infer_doublets leaves one figure open per channel
# — confirm.
plt.close("all")

for c in df["Channel"].unique():
    # NOTE(review): file name pattern presumably follows infer_doublets'
    # default output prefix ("sample") — confirm.
    png = Path(f"sample.{c}.dbl.png")
    if not png.exists():
        # A channel without a diagnostic plot should not abort the report.
        print(f"No scrublet plot found for channel {c} ({png}); skipping.")
        continue

    # Create the figure *before* drawing so that figsize applies to the
    # image and no extra empty figure is produced.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(mpimg.imread(str(png)))
    ax.axis("off")
    ax.set_title(f"Scrublet for channel {c}")
    plt.show()
    plt.close(fig)