In [2]:
%load_ext pretty_jupyter

In [3]:
# -.-|m { input: false, output: false, input_fold: show}

import tomlkit
import scanpy as sc
import anndata as ad
from scipy.stats import median_abs_deviation

import numpy as np
import pandas as pd
import requests

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc_context
import patchworklib as pw
import seaborn.objects as so

from typing import List, Dict, Callable
from anndata import AnnData

from IPython.display import display
from os import walk, path, mkdir, listdir
import session_info
import logging

logging.basicConfig(level=logging.ERROR)


In [19]:
# -.-|m { input: false, output: false, input_fold: show}
"in the following cell, override the default pipeline parameters if needed"

#QC Params
CONCAT_SAMPLES: bool = True         # Concatenate all samples in one object, default: true
NMADS: int = 5                      # Number of median absolute deviations for read and gene counts.    
NMADS_MITO: int = 3                 # Number of median absolute deviations for mitochondrial genes percentage.
CORRECT_AMBIENT_RNA: bool = False   # Correct ambient RNA, uses DecontX, Currently causes multiple erros.
FILTER_DOUBLETS: bool = True        # Filter doublets using Scrublet
CELL_CYCLE_SCORE: bool = True       # Calculate cell cycle scores, based on scanpy implementation.
REGRESS: bool = False               # Regress out unwanted variables. Not recommended.
VARS_TO_REGRESS: List[str] = []     # list of regress (pct_counts_mt, pct_counts_ribo).

'in the following cell, override the default pipeline parameters if needed'

In [4]:
# Configs
## Utility Functions
def get_sample_name(file_path: str, black_list: list[str], n = 3):
    """"Function to return probable sample name from a path, it recurselvy goes through the path and returns the first element not in the black list."""
    if n == 0:
        return ""

    tmp = path.basename(file_path)
    _d = path.dirname(file_path)

    if all(entry not in tmp for entry in black_list):
        return tmp
    else:
        res = get_sample_name(_d, black_list, n-1)
    return res


def is_outlier(adata: AnnData, metric: str, nmads: int):
    
    M = adata.obs[metric]
    outlier = (M < np.median(M) - nmads * median_abs_deviation(M)) | (
        np.median(M) + nmads * median_abs_deviation(M) < M
    )
    return outlier


def read_parsebio(data_path: str):
    """Reads ParseBio

    Args:
        data_path (str): _description_

    Returns:
        _type_: _description_
    """    

    adata = sc.read_mtx(data_path + 'count_matrix.mtx')

    # reading in gene and cell data
    gene_data = pd.read_csv(data_path + 'all_genes.csv')
    cell_meta = pd.read_csv(data_path + 'cell_metadata.csv')

    # find genes with nan values and filter
    gene_data = gene_data[gene_data.gene_name.notnull()]
    notNa = gene_data.index
    notNa = notNa.to_list()

    # remove genes with nan values and assign gene names
    adata = adata[:,notNa]
    adata.var = gene_data
    adata.var.set_index('gene_name', inplace=True)
    adata.var.index.name = None
    adata.var_names_make_unique()

    # add cell meta data to anndata object
    adata.obs = cell_meta
    adata.obs.set_index('bc_wells', inplace=True)
    adata.obs.index.name = None
    adata.obs_names_make_unique()

    return adata



def human2mouse(genes: List[str]) -> List[str]:

    r = requests.post(
        url='https://biit.cs.ut.ee/gprofiler/api/orth/orth/',
        json={
            'organism':'hsapiens',
            'target':'mmusculus',
            'query':genes,
        }
        )
    df = pd.DataFrame(r.json()['result'], )
    return df.name.replace("N/A", pd.NA).dropna().to_list()




In [5]:
## Technology components

inputs: Dict[str, List|Callable] = {
          "10x":{
                 "files": ['features.tsv.gz', 'barcodes.tsv.gz', 'matrix.mtx.gz'],
                 "black_list": ["filtered_feature_bc", "raw_feature_bc", "count", "outs"],
                 "raw_name": "raw_feature_bc_matrix",
                 "function": sc.read_10x_mtx
                 },

          "ParseBio":{
                    "files": ["all_genes.csv", "cell_metadata.csv", "count_matrix.mtx"],
                    "black_list": ["DGE_filtered", "DGE_unfiltered"],
                    "function": read_parsebio
                    }
          }


qc_features_fac: Dict[str, List[str]] = {"human": {
                         "mito": ["MT-"],
                         "ribo": ["RBS", "RPL"],
                         "hb": ["^HB[^(P)]"]
                         },
               "mouse": {
                        "mito": ["mt"],
                        "ribo": ["Rps", "Rpl"],
                        "hb": ["^Hb[^(p)]"] # Validate this later
               }
                         }



In [6]:
## Pipeline parameters

with open("../config.toml", "r") as f:
    config = tomlkit.parse(f.read())

In [7]:
#Directories 
ROOT_DIR = config["basic"]["ANALYSIS_DIR"]
DIR_SAVE = path.join(ROOT_DIR, config["basic"]["DIR_SAVE"])
DIR_samples = config["basic"]["DIR_SAMPLES"]

#Basic information
TECHNOLOGY: str = config["basic"]["TECHNOLOGY"]
ORGANISM: str = config["basic"]["ORGANISM"]
AUTODISCOVER: bool = config["basic"]["auto_find"]
samples: Dict[str, str] = config["basic"]["samples"]

sample_components = inputs[TECHNOLOGY]["files"]
black_list = inputs[TECHNOLOGY]["black_list"]
read_function = inputs[TECHNOLOGY]["function"]
raw_name = inputs[TECHNOLOGY]["raw_name"]

In [8]:
# Diagnosic pipeline
## Reading files

if AUTODISCOVER and len(samples) == 0:
    files = walk(DIR_samples)
    for root, dir, files in files:
        if len(set(sample_components).difference(set(files))) == 0 and path.basename(root) != raw_name:
            samples[get_sample_name(root, black_list, 5)] = root
else:
    samples = config["basic"]["samples"]
    if len(samples) > 0:
        pass
    else:
        raise RuntimeError("No samples paths were provided, provide sample paths as a dictionary in 'config.toml'")


# Samples (auto-discovered or manually added)

In [9]:
pd.DataFrame(samples, index = list(range(len(samples)))).T

Unnamed: 0,0,1,2,3
sample4,../test/sample4/filtered_feature_bc_matrix,../test/sample4/filtered_feature_bc_matrix,../test/sample4/filtered_feature_bc_matrix,../test/sample4/filtered_feature_bc_matrix
sample3,../test/sample3/filtered_feature_bc_matrix,../test/sample3/filtered_feature_bc_matrix,../test/sample3/filtered_feature_bc_matrix,../test/sample3/filtered_feature_bc_matrix
sample2,../test/sample2/filtered_feature_bc_matrix,../test/sample2/filtered_feature_bc_matrix,../test/sample2/filtered_feature_bc_matrix,../test/sample2/filtered_feature_bc_matrix
sample1,../test/sample1/filtered_feature_bc_matrix,../test/sample1/filtered_feature_bc_matrix,../test/sample1/filtered_feature_bc_matrix,../test/sample1/filtered_feature_bc_matrix


In [10]:
adatas = {}
raw_h5 = {}
for sample_id, filename in samples.items():
    sample_adata = read_function(filename)
    sample_adata.var_names_make_unique()
    adatas[sample_id] = sample_adata


#TODO: Improve the unfiltered matrix detection heuristic
if TECHNOLOGY == "10x" and CORRECT_AMBIENT_RNA:
    for sample_id, filename in samples.items():
        files = listdir(path.dirname(filename))
        raw_file = [file for file in files if "raw_feature_bc_matrix" in file and ".h5" in file]
        if len(raw_file) == 1:
            adata_raw = sc.read_10x_h5(path.join(path.dirname(filename), raw_file[0]))
        else:
            raise ValueError("No/Multiple raw files meeting condition were found")

        adata_raw.var_names_make_unique()
        raw_h5[sample_id] = adata_raw


if CONCAT_SAMPLES:
    adata = ad.concat(adatas, label="sample", join="outer", merge="same")
    adata.obs_names_make_unique()
    del samples

if CONCAT_SAMPLES and TECHNOLOGY == "10x" and CORRECT_AMBIENT_RNA:
    adata_raw = ad.concat(raw_h5, label="sample", join="outer", merge="same")
    adata_raw.obs_names_make_unique()
    del raw_h5


[]


ValueError: No/Multiple raw files meeting condition were found

In [None]:
## Adding quality metrics


"""**TODO:**
- [ ] Handle also multple samples at once
- [ ] Make the Oulier function parameterized on sample
"""


mt_features = qc_features_fac[ORGANISM]["mito"]
rb_features = qc_features_fac[ORGANISM]["ribo"]
hb_features = qc_features_fac[ORGANISM]["hb"]


# mitochondrial genes, "MT-" for human, "Mt-" for mouse
adata.var["mt"] = adata.var_names.str.startswith(tuple(mt_features))
# ribosomal genes
adata.var["ribo"] = adata.var_names.str.startswith(tuple(rb_features))
# hemoglobin genes
adata.var["hb"] = adata.var_names.str.contains(tuple(hb_features)[0]) #Only regex is accepted

sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], percent_top=[20],  inplace=True, log1p=True
)


In [None]:
#TODO: Add additional Critera for filtering based on Absolute thresholds, can it be per sample?

adata.obs["outlier"] = (
    is_outlier(adata, "log1p_total_counts", NMADS)
    | is_outlier(adata, "log1p_n_genes_by_counts", NMADS)
    #| is_outlier(adata, "pct_counts_in_top_20_genes", NMADS)
)

adata.obs["mt_outlier"] = is_outlier(adata, "pct_counts_mt", NMADS_MITO) 

In [None]:
## Ambient RNA correction
## TODO: Check if the Ambient RNA can be improved by using Batch information?

if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":

    from os import system, remove, path
    import tempfile

    with tempfile.TemporaryDirectory(dir=".") as tmpdirname:
        # Define paths for temporary files
        sce_path = path.join(tmpdirname, "sce.h5ad")
        raw_path = path.join(tmpdirname, "raw.h5ad")
        decontx_path = path.join(tmpdirname, "decontX.h5ad")
        
        # Save adata and adata_raw to the temporary directory
        adata.write_h5ad(sce_path)
        adata_raw.write_h5ad(raw_path)
        
        # Execute R scripts with temporary file paths
        system(f"Rscript ./utils/deconx.R -s {sce_path} -r {raw_path} -o {decontx_path}")

        # Read the result back from the temporary directory
        adata = sc.read_h5ad(decontx_path)

In [None]:
## Doublet Detection

#TODO: Check real-life performance
#TODO: Check Interop with R to convert object to R & vice-versa 
if FILTER_DOUBLETS:
    sc.pp.scrublet(adata, batch_key="sample")

In [None]:
## Cell cycle Scoring
# **Not reliable, do via interop later**

adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.scale(adata)


if CELL_CYCLE_SCORE:
    if ORGANISM in ["human", "mouse"]:

        s_genes = [x.strip() for x in open('../resources/s_genes.txt')]
        g2m_genes = [x.strip() for x in open('../resources/s_genes.txt')]


        if ORGANISM == "mouse":
            s_genes = human2mouse(s_genes)
            g2m_genes = human2mouse(g2m_genes)


        # Cell cycle scoring is not reliable and not similair to Seurat
        sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)
    else:
        logging.error('Organism must be either human or mouse.')

# Quality Plots

In [None]:
# sc.set_figure_params(dpi=300, color_map="viridis_r")
# sc.settings.verbosity = 0

## Basic QC plots & metrics

In [None]:
keys = ["n_genes_by_counts", "total_counts", "pct_counts_mt", "pct_counts_ribo"]

if FILTER_DOUBLETS:
    keys = keys + ["doublet_score"]

if CELL_CYCLE_SCORE:
    keys = keys + ["S_score", "G2M_score"]

ncols = 2
nrows = len(keys) // ncols + len(keys) % ncols


figsize = 4
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * figsize + figsize * wspace * (ncols - 1), nrows * figsize),
)

plt.subplots_adjust(wspace=wspace)
# Prevent the subplots from showing 
plt.close(fig)


for i, key in enumerate(keys):
    row = i // ncols
    col = i % ncols
    sc.pl.violin(adata, keys=[key], groupby="sample", stripplot=False, inner="box", ax=axs[row, col])

display(fig)


### Table of basic QC metrics

In [None]:
df1 = adata.obs.groupby("sample")[keys].agg(["mean", "median"]).round(3)
df2 = adata.obs.groupby("sample")[["sample"]].agg(["size"])
pd.concat([df1, df2], axis =1)

## Histograms




In [None]:
# Assuming 'df' is your DataFrame
df = adata.obs[keys]

# Determine the number of rows and columns for your subplot grid
n = len(df.columns)
ncols = 2
nrows = n // ncols + (n % ncols > 0)

# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * figsize + figsize * wspace * (ncols - 1), nrows * figsize + wspace * (nrows - 1) ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)

# Flatten the axes
axs = axs.flatten()

outlier_dict = {}


# Plot a histogram on each subplot
for i, col in enumerate(df.columns):
    sns.histplot(data=df, x=col, ax=axs[i], bins=300)

    if col in ["n_genes_by_counts", "total_counts"]:
        max_line = min(df[col].agg("mean")+ median_abs_deviation(df[col]) * NMADS,  df[col].agg("max")) 
        min_line = max(df[col].agg("mean") - median_abs_deviation(df[col]) * NMADS, df[col].agg("min"))
        axs[i].axvline(x=max_line, color='red', linestyle='--', linewidth = 0.7)
        axs[i].axvline(x=min_line, color='red', linestyle='--', linewidth = 0.7)

        outlier_dict[col] = (min_line, max_line)

    if col is "pct_counts_mt":
        max_line = min(df[col].agg("mean")+ median_abs_deviation(df[col]) * NMADS_MITO,  df[col].agg("max")) 
        min_line = max(df[col].agg("mean") - median_abs_deviation(df[col]) * NMADS_MITO, df[col].agg("min"))
        axs[i].axvline(x=max_line, color='red', linestyle='--', linewidth = 0.7)
        axs[i].axvline(x=min_line, color='red', linestyle='--', linewidth = 0.7)

        outlier_dict[col] = (min_line, max_line)

# Remove any unused subplots
if len(df.columns) < nrows*ncols:
    for i in range(len(df.columns), nrows*ncols):
        fig.delaxes(axs[i])

plt.tight_layout()
plt.show()

In [None]:
manual_filters = {}

## Filtering Thresholds

In [None]:
pd.DataFrame(outlier_dict, index= ("Min", "Max")).T

## Number of outliers based on provided criteria

In [None]:
df1 = pd.DataFrame(adata.obs[["outlier", "mt_outlier", "sample"]].value_counts(subset=["outlier",  "sample"]))
df2 = pd.DataFrame(adata.obs[["outlier", "mt_outlier", "sample"]].value_counts(subset=["mt_outlier",  "sample"]))

df1.columns = ["Outlier"]
df2.columns = ["MT-Outlier"]

pd.concat([df1,df2], axis=1)

## Cell filtering based on outlier function

In [None]:
# Saving The object at the last step before subsseting
if path.exists(DIR_SAVE):
    adata.write_h5ad(path.join(DIR_SAVE, "raw_adata.h5ad"))
else:
    mkdir(DIR_SAVE)
    adata.write_h5ad(path.join(DIR_SAVE, "raw_adata.h5ad"))


# Cell Filtering based on threshold
adata = adata[(~adata.obs.outlier) & (~adata.obs.mt_outlier)].copy()

In [None]:
df1 = adata.obs.groupby("sample")[keys].agg(["mean", "median"]).round(3)
df2 = adata.obs.groupby("sample")[["sample"]].agg(["size"])
pd.concat([df1, df2], axis =1)

## Scatter plots of confounders

In [None]:
# Assuming 'df' is your DataFrame
df = adata.obs[keys+["sample"]]

# Determine the number of rows and columns for your subplot grid
n = len(df.columns)
ncols = 2
nrows = 2
figsize= 4

# Create the subplots
wspace = 0.5
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * figsize + figsize * wspace * (ncols - 1), nrows * figsize + wspace * (nrows - 1) ),
)

plt.subplots_adjust(wspace=wspace, hspace=wspace)

# Flatten the axes
axs = axs.flatten()

sns.scatterplot(df, x="total_counts", y="n_genes_by_counts", hue = "sample", alpha = 0.4, s = 6,  ax= axs[0])
sns.scatterplot(df, x="total_counts", y="pct_counts_mt", hue = "sample", alpha = 0.4, s = 6,  ax= axs[1])
sns.scatterplot(df, x="total_counts", y="pct_counts_ribo", hue = "sample", alpha = 0.4, s = 6,  ax= axs[2])

if FILTER_DOUBLETS:
    sns.scatterplot(df, x="total_counts", y="doublet_score", hue = "sample", alpha = 0.4, s = 6,  ax= axs[3])

plt.tight_layout()
plt.show()



## DecontX contamination

In [None]:

if CORRECT_AMBIENT_RNA and TECHNOLOGY == "10x":
    ax1 = pw.Brick(figsize=(6,6))
    ax2 = pw.Brick(figsize=(6,6))

    scatter = sns.scatterplot(pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis = 1), x = "DecontX_UMAP_1", y = "DecontX_UMAP_2", hue= "decontX_contamination" , s= 0.8, ax = ax1, palette="inferno")
    sns.move_legend(scatter, "center right", bbox_to_anchor=(1.1, 0.5), title=None, frameon=False)

    scatter = sns.scatterplot(pd.concat([adata.obsm["decontX_UMAP"], adata.obs], axis = 1), x = "DecontX_UMAP_1", y = "DecontX_UMAP_2", hue= "decontX_clusters" , s= 0.8, ax=ax2)
    sns.move_legend(scatter, "center right", bbox_to_anchor=(1.15, 0.5), title=None, frameon=False)

    ax12 = ax1+ax2
    display(ax12)

## Clustering after to cell filtering

In [None]:
sc.pp.pca(adata, n_comps=20)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, key_added="groups", flavor="igraph")
sc.tl.umap(adata)

In [None]:
# -.-|m { input: false, output: true, input_fold: show}
keys = ["sample", "total_counts", "n_genes_by_counts", "pct_counts_mt"]

if FILTER_DOUBLETS:
    keys = keys + ["predicted_doublet"]

if CELL_CYCLE_SCORE:
    keys = keys + ["phase"]

figs = sc.pl.umap(adata, size= 7, color= keys, show=False, ncols = 2, color_map="inferno", sort_order=False, alpha = 0.8)

# Regression of Variables

In [None]:
# - [ ] Add error handling if the vars to regress is empty or contain non-keys
if REGRESS:
    sc.pp.regress_out(adata, keys= VARS_TO_REGRESS)

# Ambient RNA 

In [None]:
if CORRECT_AMBIENT_RNA:
    sns.violinplot(adata.obs, x = "decontX_clusters", y = "decontX_contamination")
    sns.stripplot(adata.obs, x = "decontX_clusters", y = "decontX_contamination", s = 1, c = "black")

In [None]:
## Save Result
if path.exists(DIR_SAVE):
    adata.write_h5ad(path.join(DIR_SAVE, "adata.h5ad"))
else:
    mkdir(DIR_SAVE)
    adata.write_h5ad(path.join(DIR_SAVE, "adata.h5ad"))

# Session Information

In [None]:
session_info.show()