In [1]:
# | echo: false
# | output: false
# | warning: false


import logging
from os import cpu_count, listdir, mkdir, path, walk
from typing import Dict, List

import anndata as ad
import hvplot
import hvplot.pandas
import itables
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import panel as pn
import patchworklib as pw
import scanpy as sc
import seaborn as sns
import seaborn.objects as so
import session_info
import tomlkit
from IPython.display import Markdown, display
from matplotlib import rc_context
from utils.preprocessing_funcs import *

logging.basicConfig(level=logging.ERROR)
sc.set_figure_params(dpi=100, dpi_save=300, format="png")
sc.settings.n_jobs = int(cpu_count() * 0.8)
pn.extension()

<Figure size 100x100 with 0 Axes>

In [2]:
# | echo: true
# | output: false
# | warning: false


"in the following cell, override the default pipeline parameters if needed"

# QC Params

# check QC and filter cells for each sample separately, default: True.
QC_PER_SAMPLE: bool = True

# Correct ambient RNA, uses DecontX.
CORRECT_AMBIENT_RNA: bool = False

# Filter doublets using Scrublet
FILTER_DOUBLETS: bool = False

# Calculate cell cycle scores, based on scanpy implementation.
CELL_CYCLE_SCORE: bool = True


############################################################
####################     QC Dict      ######################
############################################################

"""
Can either be a flat-dict for global threshold or a dict of dicts for each sample 
Entries in the form of col_name: [low, high] Ex: 'pct_counts_mt':[0, 20] which will be used as (low, high) pair.
Also entries can be a single number Ex: 'pct_counts_mt': 3 which be used as number of MADS as follows media -+ nmad * MAD. 
You can mix and match as needed in global config or per sample.
"""

# by default, filtering on read counts, number of features, and mitochondrial content
qc_dict: Dict[str, List] | Dict[str, Number] = {
    "n_genes_by_counts": 5,
    "total_counts": 5,
    "pct_counts_mt": 3,
}

# MODIFY_ME, Specifiy only after running diagnotics on the samples
# global config
# qc_dict: Dict[str, List] | Dict[str, List] = {
#     "pct_counts_mt": [0, 10],
#     "pct_counts_rb": [0, 10],
#     "n_genes_by_counts": [1500, 9000],
#     "total_counts": [3_000, 40_000]
# }


# # per-sample config
# qc_dict: Dict[str, List] | Dict[str, List] = {
#     "sample1": {
#         "pct_counts_mt": [0, 10],
#         "pct_counts_rb": 5,
#         "total_counts": [3_000, 40_000]
#     },
#     "sample2": {
#         "pct_counts_rb": 3,
#         "n_genes_by_counts": [1500, 9000],
#         "total_counts": [3_000, 40_000]
#     },
#     "sample3": {
#         "pct_counts_mt": [0, 10],
#         "n_genes_by_counts": [1500, 9000],
#         "total_counts": [3_000, 40_000]
#     },
#     "sample4": {
#         "pct_counts_mt": 4
#         "pct_counts_rb": 5,
#         "n_genes_by_counts": 6,
#     }
# }

In [3]:
# | echo: false
# | output: false
# | warning: false

## Pipeline parameters
with open("../config.toml", "r") as f:
    config = tomlkit.parse(f.read())

In [4]:
# | echo: false
# | output: false
# | warning: false


# Directories
ROOT_DIR = config["basic"]["ANALYSIS_DIR"]
DIR_SAVE = path.join(ROOT_DIR, config["basic"]["DIR_SAVE"])
DIR_samples = config["basic"]["DIR_SAMPLES"]

# Basic information
TECHNOLOGY: str = config["basic"]["TECHNOLOGY"]
ORGANISM: str = config["basic"]["ORGANISM"]
AUTODISCOVER: bool = config["basic"]["auto_find"]
samples: Dict[str, str] = config["basic"]["samples"]

sample_components = inputs[TECHNOLOGY]["files"]
black_list = inputs[TECHNOLOGY]["black_list"]
read_function = inputs[TECHNOLOGY]["function"]
raw_name = inputs[TECHNOLOGY]["raw_name"]


MAX_ONLY: list[str] = ["pct_counts_mt", "pct_counts_rb", "pct_counts_hb"]
LOG_TRANSFORMED: list[str] = ["total_counts", "n_genes_by_counts"]

In [5]:
# | echo: false
# | output: false
# | warning: false

# Keys

plotting_qc_keys = get_keys(qc_dict)

if len(get_keys(qc_dict)) == 0:
    raise ValueError("Couldn't get valid qc variables from QC dict.")

filter_qc_keys = plotting_qc_keys

plotting_qc_keys = plotting_qc_keys + ["pct_counts_rb"]

if FILTER_DOUBLETS and "doublet_score" not in plotting_qc_keys:
    plotting_qc_keys = plotting_qc_keys + ["doublet_score"]

if CORRECT_AMBIENT_RNA and "decontX_contamination" not in plotting_qc_keys:
    plotting_qc_keys = plotting_qc_keys + ["decontX_contamination"]

if CELL_CYCLE_SCORE and "S_score" not in plotting_qc_keys:
    plotting_qc_keys = plotting_qc_keys + ["S_score"]

if CELL_CYCLE_SCORE and "G2M_score" not in plotting_qc_keys:
    plotting_qc_keys += ["G2M_score"]

In [None]:
# | echo: false
# | warning: false


# Diagnosic pipeline
## Reading files

# TODO: Refactor to a better design, use a custom function for discovery.
if AUTODISCOVER and len(samples) == 0:
    files = walk(DIR_samples)
    for root, dir, files in files:
        if (
            len(set(sample_components).difference(set(files))) == 0
            and path.basename(root) != raw_name
        ):
            samples[get_sample_name(root, black_list, 5)] = root
else:
    samples = config["basic"]["samples"]
    if len(samples) == 0:
        raise RuntimeError(
            "No samples paths were provided, provide sample paths as a dictionary in 'config.toml'"
        )

# Samples

In [None]:
# | warning: false
itables.show(pd.DataFrame(samples, index=["sample path"]).T)

In [None]:
# | echo: false
# | output: false
# | warning: false

adatas = {}
raw_h5 = {}
for sample_id, filename in samples.items():
    sample_adata = read_function(filename)
    sample_adata.var_names_make_unique()
    adatas[sample_id] = sample_adata


if TECHNOLOGY == "10x" and CORRECT_AMBIENT_RNA:
    for sample_id, filename in samples.items():
        files = listdir(path.dirname(filename))
        raw_file = [
            file for file in files if "raw_feature_bc_matrix" in file and ".h5" in file
        ]
        if len(raw_file) == 1:
            adata_raw = sc.read_10x_h5(path.join(path.dirname(filename), raw_file[0]))
        else:
            raise ValueError("No/Multiple raw files meeting condition were found")

        adata_raw.var_names_make_unique()
        raw_h5[sample_id] = adata_raw


adata = ad.concat(adatas, label="sample", join="outer", merge="same")
adata.obs_names_make_unique()
del samples

if QC_PER_SAMPLE and TECHNOLOGY == "10x" and CORRECT_AMBIENT_RNA:
    adata_raw = ad.concat(raw_h5, label="sample", join="outer", merge="same")
    adata_raw.obs_names_make_unique()
    del raw_h5

In [34]:
## Remove

adata = sc.read_h5ad("../save/adata.h5ad")
adata

AnnData object with n_obs × n_vars = 31087 × 36601
    obs: 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_rb', 'log1p_total_counts_rb', 'pct_counts_rb', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'n_genes_by_counts_outlier', 'total_counts_outlier', 'pct_counts_mt_outlier', 'outlier', 'S_score', 'G2M_score', 'phase', 'groups', 'leiden_1.0', 'scTAB_annotation', 'scTAB_annotation_majority_voting'
    var: 'gene_ids', 'feature_types', 'mt', 'rb', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'mean', 'std', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'CLUSTERING_COL', 'FINAL_CLUSTERED', 'groups', 'hvg', 'leiden_0.1', 'leiden_0.1_colors', 'leiden_0.2', 'leiden_0.3', 'leiden_0.4', 'leiden_0.5', 'le

In [35]:
# | echo: false
# | output: false
# | warning: false

## Adding quality metrics
mt_features = qc_features_rules[ORGANISM]["mito"]
rb_features = qc_features_rules[ORGANISM]["ribo"]
hb_features = qc_features_rules[ORGANISM]["hb"]


# mitochondrial genes, "MT-" for human, "Mt-" for mouse
adata.var["mt"] = adata.var_names.str.startswith(tuple(mt_features))
# ribosomal genes
adata.var["rb"] = adata.var_names.str.startswith(tuple(rb_features))
# hemoglobin genes
adata.var["hb"] = adata.var_names.str.contains(
    tuple(hb_features)[0]
)  # Only regex is accepted

sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "rb", "hb"], percent_top=[20], inplace=True, log1p=False
)

In [None]:
# | echo: false
# | output: false
# | warning: false

if all(map(lambda x: isinstance(x, (list, Number)), qc_dict.values())):
    compute_outliers_bulk(adata, qc_dict)
elif all(map(lambda x: isinstance(x, dict), qc_dict.values())):
    compute_outlier_sample(adata, qc_dict)
else:
    raise ValueError("Please provide a QC-Dict in a valid format.")

In [None]:
# | echo: false
# | output: false
# | warning: false


## Ambient RNA correction
## TODO: Check if the Ambient RNA can be improved by using Batch information?

if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":
    import tempfile
    import urllib.request
    from os import path, remove, system

    with tempfile.TemporaryDirectory(dir=".") as tmpdirname:
        # Workaround failure to install GenomeInfoDbData using pixi
        dn_path = path.join(tmpdirname, "GenomeInfoDbData_1.2.11.tar.gz")
        dn_url = "https://bioconductor.org/packages/3.18/data/annotation/src/contrib/GenomeInfoDbData_1.2.11.tar.gz"
        urllib.request.urlretrieve(dn_url, filename=dn_path)
        system(f"R CMD INSTALL {dn_path}")

        # Define paths for temporary files
        sce_path = path.join(tmpdirname, "sce.h5ad")
        raw_path = path.join(tmpdirname, "raw.h5ad")
        decontx_path = path.join(tmpdirname, "decontX.h5ad")

        # Save adata and adata_raw to the temporary directory
        adata.write_h5ad(sce_path)
        adata_raw.write_h5ad(raw_path)

        # Execute R scripts with temporary file paths
        system(
            f"Rscript ./utils/deconx.R -s {sce_path} -r {raw_path} -o {decontx_path}"
        )

        # Read the result back from the temporary directory
        adata = sc.read_h5ad(decontx_path)

In [None]:
# | echo: false
# | output: false
# | warning: false

## Doublet Detection

# TODO: Check real-life performance
# TODO: Check Interop with R to convert object to R & vice-versa
if FILTER_DOUBLETS and TECHNOLOGY == "10x":
    sc.pp.scrublet(adata, batch_key="sample")

In [None]:
# | echo: false
# | output: false
# | warning: false

## Cell cycle Scoring
# **Not reliable, do via interop later**

adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.scale(adata)


if CELL_CYCLE_SCORE:
    if ORGANISM in ["human", "mouse"]:
        s_genes = [x.strip() for x in open("../resources/s_genes.txt")]
        g2m_genes = [x.strip() for x in open("../resources/s_genes.txt")]

        if ORGANISM == "mouse":
            s_genes = human2mouse(s_genes)
            g2m_genes = human2mouse(g2m_genes)

        # Cell cycle scoring is not reliable and not similair to Seurat
        sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)
    else:
        logging.error("Organism must be either human or mouse.")

# Diagnostic Plots (prior to filtering)

## Basic QC plots & metrics

In [None]:
ncols = 2
nrows = len(plotting_qc_keys) // ncols + len(plotting_qc_keys) % ncols

figsize = 3
wspace = 0.5
hspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(
        ncols * figsize + figsize * wspace * (ncols - 1),
        nrows * figsize + hspace * (nrows - 1),
    ),
)
plt.subplots_adjust(wspace=wspace, hspace=hspace)

axs = axs.flatten()

# Prevent the subplots from showing
plt.close(fig)

for i, key in enumerate(plotting_qc_keys):
    sc.pl.violin(
        adata, keys=[key], groupby="sample", stripplot=False, inner="box", ax=axs[i]
    )
    axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=30)

plt.tight_layout()

if len(plotting_qc_keys) < nrows * ncols:
    for i in range(len(plotting_qc_keys), nrows * ncols):
        fig.delaxes(axs[i])

display(fig)

## Table of basic QC metrics

In [None]:
df1 = adata.obs.groupby("sample")[plotting_qc_keys].agg(["mean", "median"]).round(3)
df2 = adata.obs.groupby("sample")[["sample"]].agg(["size"])
itables.show(pd.concat([df1, df2], axis=1))

## Histograms

### Interactive

In [None]:
# | echo: false
# | warning: false


from operator import add

sample_fig_list = []
hvplot.extension("bokeh")
ls = []
df = adata.obs[plotting_qc_keys]
for key in df.columns:
    fig = (
        adata.obs[[key, "sample"]]
        .hvplot(
            kind="hist",
            bins=300,
            width=400,
            height=350,
            line_color="#A3D5FF",
            bgcolor="white",
        )
        .opts(axiswise=True)
    )
    ls.append(fig)

layout = reduce(add, ls)
layout = layout.cols(2)
sample_fig_list.append(("All samples", layout))

samples = adata.obs["sample"].unique()
samples = np.sort(samples)
for sample in samples:
    ls = []
    df = adata.obs.loc[adata.obs["sample"] == sample, plotting_qc_keys]
    for key in df.columns:
        fig = (
            df[[key]]
            .hvplot(
                kind="hist",
                bins=300,
                width=400,
                height=350,
                line_color="#A3D5FF",
                bgcolor="white",
            )
            .opts(axiswise=True)
        )
        ls.append(fig)
    layout = reduce(add, ls)
    layout = layout.cols(2)
    sample_fig_list.append((sample, layout))


pn.Tabs(*sample_fig_list)

### Static

In [None]:
import base64
import io

import holoviews as hv
import ipywidgets as widgets
from IPython.display import Markdown, display

n = len(plotting_qc_keys)
ncols = 2
nrows = n // ncols + (n % ncols > 0)


quarto_markdown = "::: {.panel-tabset}\n\n"


# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(
        ncols * figsize + figsize * wspace * (ncols - 1),
        nrows * figsize + wspace * (nrows - 1),
    ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)
# Flatten the axes
axs = axs.flatten()
outlier_dict = {}

# Plot a histogram on each subplot
for i, col in enumerate(plotting_qc_keys):
    sns.histplot(data=df, x=col, ax=axs[i], bins=300)

# Remove any unused subplots
if len(df.columns) < nrows * ncols:
    for i in range(len(df.columns), nrows * ncols):
        fig.delaxes(axs[i])

plt.tight_layout()

buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
plt.close()
buf.seek(0)
# Encode the image to base64
img_str = base64.b64encode(buf.read()).decode("utf-8")
# Add the image to the markdown
quarto_markdown += (
    f"### All samples \n\n![All samples](data:image/png;base64,{img_str})\n\n"
)


samples = adata.obs["sample"].unique()
samples = np.sort(samples)
for sample in samples:
    ls = []
    df = adata.obs.loc[adata.obs["sample"] == sample, plotting_qc_keys]
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(
            ncols * figsize + figsize * wspace * (ncols - 1),
            nrows * figsize + wspace * (nrows - 1),
        ),
    )

    plt.subplots_adjust(wspace=wspace, hspace=wspace)

    # Flatten the axes
    axs = axs.flatten()
    outlier_dict = {}

    # Plot a histogram on each subplot
    for i, col in enumerate(plotting_qc_keys):
        sns.histplot(data=df, x=col, ax=axs[i], bins=300)

    # Remove any unused subplots
    if len(df.columns) < nrows * ncols:
        for i in range(len(df.columns), nrows * ncols):
            fig.delaxes(axs[i])

    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    # Encode the image to base64
    img_str = base64.b64encode(buf.read()).decode("utf-8")
    # Add the image to the markdown
    quarto_markdown += (
        f"### {sample} \n\n![{sample}](data:image/png;base64,{img_str})\n\n"
    )

quarto_markdown += ":::\n"
# Display the generated markdown
display(Markdown(quarto_markdown))

## Scatter plots of confounders

In [None]:
# | echo: false
# | warning: false

# Assuming 'df' is your DataFrame
df = adata.obs[plotting_qc_keys + ["sample"]]

# Determine the number of rows and columns for your subplot grid
n = len(df.columns)
ncols = 2
nrows = n // ncols
figwidth = 5
figheight = 4

# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(
        ncols * figwidth + figwidth * wspace * (ncols - 1),
        nrows * figheight + wspace * (nrows - 1),
    ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)

# Flatten the axes
axs = axs.flatten()

sns.scatterplot(
    df, x="total_counts", y="n_genes_by_counts", hue="sample", alpha=0.4, s=6, ax=axs[0]
)
axs[0].legend(fancybox=True, framealpha=0.8)
sns.scatterplot(
    df, x="total_counts", y="pct_counts_mt", hue="sample", alpha=0.4, s=6, ax=axs[1]
)
axs[1].legend(fancybox=True, framealpha=0.8)
sns.scatterplot(
    df, x="total_counts", y="pct_counts_rb", hue="sample", alpha=0.4, s=6, ax=axs[2]
)
axs[2].legend(fancybox=True, framealpha=0.8)
sns.move_legend(axs[0], "upper left", bbox_to_anchor=(1, 1))
sns.move_legend(axs[1], "upper left", bbox_to_anchor=(1, 1))
sns.move_legend(axs[2], "upper left", bbox_to_anchor=(1, 1))

if FILTER_DOUBLETS:
    sns.scatterplot(
        df, x="total_counts", y="doublet_score", hue="sample", alpha=0.4, s=6, ax=axs[3]
    )
    sns.move_legend(axs[3], "upper left", bbox_to_anchor=(1, 1))

if CORRECT_AMBIENT_RNA:
    sns.scatterplot(
        df,
        x="total_counts",
        y="decontX_contamination",
        hue="sample",
        alpha=0.4,
        s=6,
        ax=axs[4],
    )
    sns.move_legend(axs[4], "upper left", bbox_to_anchor=(1, 1))


[fig.delaxes(ax) for ax in axs.flatten() if not ax.has_data()]


plt.tight_layout()

## DeconX contamination

In [None]:
# | echo: false
# | warning: false

if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":
    ax1 = pw.Brick(figsize=(6, 6))
    ax2 = pw.Brick(figsize=(6, 6))

    scatter = sns.scatterplot(
        pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis=1),
        x="DecontX_UMAP_1",
        y="DecontX_UMAP_2",
        hue="decontX_contamination",
        s=0.8,
        ax=ax1,
        palette="inferno",
    )
    sns.move_legend(
        scatter, "center right", bbox_to_anchor=(1.1, 0.5), title=None, frameon=False
    )

    scatter = sns.scatterplot(
        pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis=1),
        x="DecontX_UMAP_1",
        y="DecontX_UMAP_2",
        hue="decontX_clusters",
        s=0.8,
        ax=ax2,
    )
    sns.move_legend(
        scatter, "center right", bbox_to_anchor=(1.15, 0.5), title=None, frameon=False
    )

    ax12 = ax1 + ax2
    display(ax12)

## Effects of confounders

In [None]:
# | echo: false
# | output: false
# | warning: false

sc.pp.pca(adata, n_comps=30)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, key_added="groups", flavor="leidenalg")
sc.tl.umap(adata)

In [None]:
# | echo: false
# | warning: false
figs = sc.pl.umap(
    adata,
    size=7,
    color=["sample"] + plotting_qc_keys,
    show=False,
    ncols=2,
    color_map="inferno",
    sort_order=False,
    alpha=0.8,
    hspace=0.2,
    wspace=0.3,
)


# Filtering 

## Filtering thresholds

In [None]:
# | echo: false
# | warning: false

ls = {}
for key in filter_qc_keys:
    ls[key] = (
        adata.obs.query(f"{key}_outlier == False").groupby("sample")[key].agg("min")
    )
min_df = pd.concat(ls, axis=1)

for key in filter_qc_keys:
    ls[key] = (
        adata.obs.query(f"{key}_outlier == False").groupby("sample")[key].agg("max")
    )
max_df = pd.concat(ls, axis=1)

pd.concat([min_df, max_df], keys=["min", "max"], axis=0)

## Number of outliers based on provided criteria

In [None]:
# | echo: false
# | warning: false

# TODO: Fix
df_l = []
col_l = []
for key in filter_qc_keys:
    try:
        df_l.append(
            pd.DataFrame(
                adata.obs[[f"{key}_outlier", "sample"]].value_counts().loc[(True), :]
            )
        )
        col_l.append(f"{key}_outlier")
    except:
        pass

col_l.append(f"aggregate outliers")
df_l.append(
    pd.DataFrame(adata.obs[[f"outlier", "sample"]].value_counts().loc[(True), :])
)

df_l.append(pd.DataFrame(adata.obs[["sample"]].value_counts()))
df_l[-1].index = df_l[-1].index.get_level_values(0)
col_l.append(f"Total Cells")

df = pd.concat(df_l, axis=1)
df.columns = [col_l]
df = df.loc[np.sort(df.index), :]
itables.show(df.fillna(0))

## Cell filtering based on outlier function

In [None]:
# | echo: false
# | output: false
# | warning: false

adata.obs = adata.obs.infer_objects()

# Saving The object at the last step before subsseting
if path.exists(DIR_SAVE):
    adata.write_h5ad(path.join(DIR_SAVE, "raw_adata.h5ad"))
else:
    mkdir(DIR_SAVE)
    adata.write_h5ad(path.join(DIR_SAVE, "raw_adata.h5ad"))

# Cell Filtering based on threshold


adata = adata[(~adata.obs.outlier)].copy()

# QC plots (post filtering)

## Table of basic QC metrics

In [None]:
# | echo: false
# | warning: false

df1 = adata.obs.groupby("sample")[plotting_qc_keys].agg(["mean", "median"]).round(3)
df2 = adata.obs.groupby("sample")[["sample"]].agg(["size"])
df = pd.concat([df1, df2], axis=1)
df

## Histograms

### Interactive

In [None]:
# | echo: false
# | warning: false


from operator import add

sample_fig_list = []
hvplot.extension("bokeh")
ls = []
df = adata.obs[plotting_qc_keys]
for key in df.columns:
    fig = (
        adata.obs[[key, "sample"]]
        .hvplot(
            kind="hist",
            bins=300,
            width=400,
            height=350,
            line_color="#A3D5FF",
            bgcolor="white",
        )
        .opts(axiswise=True)
    )
    ls.append(fig)

layout = reduce(add, ls)
layout = layout.cols(2)
sample_fig_list.append(("All samples", layout))

samples = adata.obs["sample"].unique()
samples = np.sort(samples)
for sample in samples:
    ls = []
    df = adata.obs.loc[adata.obs["sample"] == sample, plotting_qc_keys]
    for key in df.columns:
        fig = (
            df[[key]]
            .hvplot(
                kind="hist",
                bins=300,
                width=400,
                height=350,
                line_color="#A3D5FF",
                bgcolor="white",
            )
            .opts(axiswise=True)
        )
        ls.append(fig)
    layout = reduce(add, ls)
    layout = layout.cols(2)
    sample_fig_list.append((sample, layout))


pn.Tabs(*sample_fig_list)

### Static

In [None]:
import base64
import io

import holoviews as hv
import ipywidgets as widgets
from IPython.display import Markdown, display

df = adata.obs[plotting_qc_keys]


n = len(plotting_qc_keys)
ncols = 2
nrows = n // ncols + (n % ncols > 0)


quarto_markdown = "::: {.panel-tabset}\n\n"


# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(
        ncols * figsize + figsize * wspace * (ncols - 1),
        nrows * figsize + wspace * (nrows - 1),
    ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)
# Flatten the axes
axs = axs.flatten()
outlier_dict = {}

# Plot a histogram on each subplot
for i, col in enumerate(plotting_qc_keys):
    sns.histplot(data=df, x=col, ax=axs[i], bins=300)

# Remove any unused subplots
if len(df.columns) < nrows * ncols:
    for i in range(len(df.columns), nrows * ncols):
        fig.delaxes(axs[i])

plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
plt.close()
buf.seek(0)
# Encode the image to base64
img_str = base64.b64encode(buf.read()).decode("utf-8")
# Add the image to the markdown
quarto_markdown += (
    f"### All samples \n\n![All samples](data:image/png;base64,{img_str})\n\n"
)


samples = adata.obs["sample"].unique()
samples = np.sort(samples)
for sample in samples:
    ls = []
    df = adata.obs.loc[adata.obs["sample"] == sample, plotting_qc_keys]
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(
            ncols * figsize + figsize * wspace * (ncols - 1),
            nrows * figsize + wspace * (nrows - 1),
        ),
    )

    plt.subplots_adjust(wspace=wspace, hspace=wspace)

    # Flatten the axes
    axs = axs.flatten()
    outlier_dict = {}

    # Plot a histogram on each subplot
    for i, col in enumerate(plotting_qc_keys):
        sns.histplot(data=df, x=col, ax=axs[i], bins=300)

    # Remove any unused subplots
    if len(df.columns) < nrows * ncols:
        for i in range(len(df.columns), nrows * ncols):
            fig.delaxes(axs[i])

    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    # Encode the image to base64
    img_str = base64.b64encode(buf.read()).decode("utf-8")
    # Add the image to the markdown
    quarto_markdown += (
        f"### {sample} \n\n![{sample}](data:image/png;base64,{img_str})\n\n"
    )

quarto_markdown += ":::\n"
# Display the generated markdown
display(Markdown(quarto_markdown))

## Scatter plots of confounders after filtering

In [None]:
# | echo: false
# | warning: false

# Assuming 'df' is your DataFrame
df = adata.obs[plotting_qc_keys + ["sample"]]

# Determine the number of rows and columns for your subplot grid
n = len(df.columns)
ncols = 2
nrows = n // ncols
figwidth = 5
figheight = 4

# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(
        ncols * figwidth + figwidth * wspace * (ncols - 1),
        nrows * figheight + wspace * (nrows - 1),
    ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)

# Flatten the axes
axs = axs.flatten()

sns.scatterplot(
    df, x="total_counts", y="n_genes_by_counts", hue="sample", alpha=0.4, s=6, ax=axs[0]
)
axs[0].legend(fancybox=True, framealpha=0.8)
sns.scatterplot(
    df, x="total_counts", y="pct_counts_mt", hue="sample", alpha=0.4, s=6, ax=axs[1]
)
axs[1].legend(fancybox=True, framealpha=0.8)
sns.scatterplot(
    df, x="total_counts", y="pct_counts_rb", hue="sample", alpha=0.4, s=6, ax=axs[2]
)
axs[2].legend(fancybox=True, framealpha=0.8)
sns.move_legend(axs[0], "upper left", bbox_to_anchor=(1, 1))
sns.move_legend(axs[1], "upper left", bbox_to_anchor=(1, 1))
sns.move_legend(axs[2], "upper left", bbox_to_anchor=(1, 1))

if FILTER_DOUBLETS:
    sns.scatterplot(
        df, x="total_counts", y="doublet_score", hue="sample", alpha=0.4, s=6, ax=axs[3]
    )
    sns.move_legend(axs[3], "upper left", bbox_to_anchor=(1, 1))

if CORRECT_AMBIENT_RNA:
    sns.scatterplot(
        df,
        x="total_counts",
        y="decontX_contamination",
        hue="sample",
        alpha=0.4,
        s=6,
        ax=axs[4],
    )
    sns.move_legend(axs[4], "upper left", bbox_to_anchor=(1, 1))


[fig.delaxes(ax) for ax in axs.flatten() if not ax.has_data()]


plt.tight_layout()

## DecontX contamination

In [None]:
# | echo: false
# | warning: false

if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":
    ax1 = pw.Brick(figsize=(6, 6))
    ax2 = pw.Brick(figsize=(6, 6))

    scatter = sns.scatterplot(
        pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis=1),
        x="DecontX_UMAP_1",
        y="DecontX_UMAP_2",
        hue="decontX_contamination",
        s=0.8,
        ax=ax1,
        palette="inferno",
    )
    sns.move_legend(
        scatter, "center right", bbox_to_anchor=(1.1, 0.5), title=None, frameon=False
    )

    scatter = sns.scatterplot(
        pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis=1),
        x="DecontX_UMAP_1",
        y="DecontX_UMAP_2",
        hue="decontX_clusters",
        s=0.8,
        ax=ax2,
    )
    sns.move_legend(
        scatter, "center right", bbox_to_anchor=(1.15, 0.5), title=None, frameon=False
    )

    ax12 = ax1 + ax2
    display(ax12)

    # | echo: false
# | warning: false

In [None]:
if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":
    with rc_context(
        {
            "figure.figsize": (
                adata.obs["decontX_clusters"].astype("int32").max() * 0.7,
                5,
            )
        }
    ):
        sns.violinplot(adata.obs, x="decontX_clusters", y="decontX_contamination")
        sns.stripplot(
            adata.obs, x="decontX_clusters", y="decontX_contamination", s=1, c="black"
        )

## Clustering after to cell filtering

In [None]:
# | echo: false
# | output: false
# | warning: false

sc.pp.pca(adata, n_comps=30)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, key_added="groups", flavor="leidenalg")
sc.tl.umap(adata)

In [None]:
# | echo: false
# | warning: false
with rc_context({"figure.figsize": (5, (len(plotting_qc_keys) + 1) * 0.6)}):
    figs = sc.pl.umap(
        adata,
        size=8,
        color=["sample"] + plotting_qc_keys,
        show=False,
        ncols=2,
        color_map="inferno",
        sort_order=False,
        alpha=0.8,
        hspace=0.2,
        wspace=0.3,
    )

In [None]:
# #| echo: false
# #| output: false
# #| warning: false

# ## Regression of Variables

# # - [ ] Add error handling if the vars to regress is empty or contain non-keys
# if REGRESS:
#     sc.pp.regress_out(adata, keys= VARS_TO_REGRESS)

In [None]:
# | echo: false
# | output: false
# | warning: false

adata.obs = adata.obs.infer_objects()
## Save Result
if path.exists(DIR_SAVE):
    adata.write_h5ad(path.join(DIR_SAVE, "adata.h5ad"))
else:
    mkdir(DIR_SAVE)
    adata.write_h5ad(path.join(DIR_SAVE, "adata.h5ad"))

# Session Information

In [None]:
session_info.show()