In [None]:
import tomlkit
import scanpy as sc
import anndata as ad
from anndata import AnnData
from scipy.stats import median_abs_deviation
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import patchworklib as pw
from typing import List, Dict, Callable

**TODO**
- [x] Add function factory  
- [x] Add component factory  
- [ ] Add a way to generate reports


# Utility Functions

In [None]:
def get_sample_name(file_path: str, black_list: list[str], n: int = 3) -> str:
    """Return the probable sample name for a directory path.

    Walks up ``file_path`` one path component at a time, inspecting at most
    ``n`` components, and returns the first component name that is not in
    ``black_list`` (e.g. generic matrix directory names such as
    ``filtered_feature_bc_matrix``).

    Args:
        file_path: Directory path to inspect. A trailing separator is ignored.
        black_list: Directory names that are never accepted as a sample name.
        n: Maximum number of path components to inspect. ``n <= 0`` returns "".

    Returns:
        The first non-blacklisted component name, or "" if none was found
        within ``n`` components.
    """
    from os import path

    # Strip trailing separators so "a/S1/" resolves to "S1" instead of "".
    current = file_path.rstrip(path.sep) if file_path != path.sep else file_path
    for _ in range(n):
        name = path.basename(current)
        if name not in black_list:
            return name
        current = path.dirname(current)
    return ""


def is_outlier(adata: AnnData, metric: str, nmads: float):
    """Flag cells whose ``metric`` lies more than ``nmads`` MADs from the median.

    Args:
        adata: Object whose ``.obs`` table contains the column ``metric``.
        metric: Name of the ``adata.obs`` column to test.
        nmads: Number of median absolute deviations allowed on either side
            of the median before a value counts as an outlier.

    Returns:
        Boolean Series aligned with ``adata.obs``; True where the value is
        strictly below ``median - nmads * MAD`` or strictly above
        ``median + nmads * MAD``.
    """
    values = adata.obs[metric]
    # Compute the robust centre and spread once instead of once per bound.
    median = np.median(values)
    mad = median_abs_deviation(values)
    lower = median - nmads * mad
    upper = median + nmads * mad
    return (values < lower) | (values > upper)


def read_parsebio(data_path: str):
    """Read a ParseBio DGE output directory into an AnnData object.

    Args:
        data_path (str): Directory containing ``count_matrix.mtx``,
            ``all_genes.csv`` and ``cell_metadata.csv``. A trailing path
            separator is optional (directories produced by ``os.walk`` have none).

    Returns:
        AnnData: Matrix read from ``count_matrix.mtx`` with genes (columns)
        restricted to those with a non-null ``gene_name``. ``var`` is indexed by
        (made-unique) gene name and ``obs`` by (made-unique) ``bc_wells``.
    """
    import pandas as pd
    from os import path

    adata = sc.read_mtx(path.join(data_path, 'count_matrix.mtx'))

    # reading in gene and cell data
    gene_data = pd.read_csv(path.join(data_path, 'all_genes.csv'))
    cell_meta = pd.read_csv(path.join(data_path, 'cell_metadata.csv'))

    # Rows of all_genes.csv correspond positionally to the matrix columns;
    # drop genes that have no name.
    has_name = gene_data['gene_name'].notnull().to_numpy()

    # .copy() materialises the subset so the metadata below is set on a real
    # AnnData rather than on a view.
    adata = adata[:, has_name].copy()
    adata.var = gene_data.loc[has_name].set_index('gene_name').rename_axis(None)
    adata.var_names_make_unique()

    # add cell meta data to anndata object
    adata.obs = cell_meta.set_index('bc_wells').rename_axis(None)
    adata.obs_names_make_unique()

    return adata



# Constants & data

In [None]:
DIR_base = "PROJECT_PROJECT_PATH/"
# Base directory that is searched recursively for each sample's count matrix, feature and cell-metadata files.
# NOTE: machine-specific absolute path; adjust for your environment.
DIR_samples = "/home/mohamed/Documents/Bioinformatics/scRNA_pipeline_testing/preprocessing/test" # CellRanger standard output
# ParseBio: "PROJECT_PROCESSING_PATH/results_combined"


TECHNOLOGY: str = "10x"  # key of `inputs`: "10x" or "ParseBio"
AUTODISCOVER: bool = True  # search DIR_samples for sample directories
CONCAT_SAMPLES: bool = True  # merge all samples into a single AnnData
ORGANISM: str = "human"  # key of `qc_features_fac`: "human" or "mouse"
NMADS: int = 5  # MAD multiplier for total-count / gene-count / top-20-gene outliers
NMADS_MITO: int = 3  # MAD multiplier for mitochondrial-fraction outliers
FILTER_DOUBLETS: bool = False
REGRESS: bool = False
CELL_CYCLE_SCORE: bool = True
VARS_TO_REGRESS: List[str]|None = None # Add variables to regress pct_counts_mt, pct_counts_ribo

# Per-technology reader configuration:
#   files      - file names that must all be present for a directory to count as a sample
#   black_list - directory names that are never used as the sample name
#   function   - reader called with the sample directory
inputs: Dict[str, Dict[str, List[str] | Callable]] = {
    "10x": {
        "files": ['features.tsv.gz', 'barcodes.tsv.gz', 'matrix.mtx.gz'],
        "black_list": ["filtered_feature_bc_matrix", "raw_feature_bc_matrix"],
        "function": sc.read_10x_mtx,
    },
    "ParseBio": {
        "files": ["all_genes.csv", "cell_metadata.csv", "count_matrix.mtx"],
        "black_list": ["DGE_filtered", "DGE_unfiltered"],
        "function": read_parsebio,
    },
}

# Gene-symbol patterns used for QC. "mito" and "ribo" are prefixes;
# "hb" entries are regular expressions.
qc_features_fac: Dict[str, Dict[str, List[str]]] = {
    "human": {
        "mito": ["MT-"],
        "ribo": ["RPS", "RPL"],
        "hb": ["^HB[^(P)]"],
    },
    "mouse": {
        "mito": ["mt-"],
        "ribo": ["Rps", "Rpl"],
        "hb": ["^Hb[^(p)]"],  # Validate this later
    },
}

# Name of the unfiltered 10x matrix directory; skipped during sample discovery.
raw_name = "raw_feature_bc_matrix"



# Pipeline

## Reading files

In [None]:
from os import walk, path

sample_components = inputs[TECHNOLOGY]["files"]
black_list = inputs[TECHNOLOGY]["black_list"]
read_function = inputs[TECHNOLOGY]["function"]

# Unfiltered matrices sit next to the filtered ones and would resolve to the
# same sample name, silently overwriting it, so they are skipped.
skip_dirs = {raw_name, "DGE_unfiltered"}
required_files = set(sample_components)

samples = {}
if AUTODISCOVER:
    for root, _subdirs, filenames in walk(DIR_samples):
        if path.basename(root) in skip_dirs or not required_files.issubset(filenames):
            continue
        sample_name = get_sample_name(root, black_list, 4)
        if sample_name in samples:
            raise ValueError(
                f"Sample name {sample_name!r} resolved for both {samples[sample_name]!r} and {root!r}; "
                "adjust the black list or directory layout."
            )
        samples[sample_name] = root

samples

In [None]:
adatas = {}
raw_h5 = {}
for sample_id, sample_dir in samples.items():
    sample_adata = read_function(sample_dir)
    sample_adata.var_names_make_unique()
    adatas[sample_id] = sample_adata

if not adatas:
    raise ValueError(f"No samples found under {DIR_samples!r}; check DIR_samples, TECHNOLOGY and AUTODISCOVER.")

# if TECHNOLOGY == "10x":
#     for sample_id, filename in samples.items():
#         adata_raw = sc.read_10x_h5(path.join(path.dirname(filename), "raw_feature_bc_matrix.h5"))
#         adata_raw.var_names_make_unique()
#         raw_h5[sample_id] = adata_raw


if CONCAT_SAMPLES:
    adata = ad.concat(adatas, label="sample", join="outer", merge="same")
    adata.obs_names_make_unique()
elif len(adatas) == 1:
    # Downstream cells expect a single `adata` with a "sample" column.
    (sample_id, adata), = adatas.items()
    adata.obs["sample"] = sample_id
    adata.obs["sample"] = adata.obs["sample"].astype("category")
else:
    raise ValueError("CONCAT_SAMPLES is False but several samples were found; downstream cells need a single `adata`.")

# `samples` is kept so this cell can be re-run; the per-sample objects are no
# longer needed once `adata` exists.
del adatas

# if CONCAT_SAMPLES and TECHNOLOGY == "10x":
#     adata_raw = ad.concat(raw_h5, label="sample", join="outer", merge="same")
#     adata_raw.obs_names_make_unique()
#     del raw_h5


**TODO: Also handle multiple samples at once**

## Adding quality metrics

In [None]:
mt_features = qc_features_fac[ORGANISM]["mito"]
rb_features = qc_features_fac[ORGANISM]["ribo"]
hb_features = qc_features_fac[ORGANISM]["hb"]


# Mitochondrial and ribosomal genes are matched by gene-symbol prefix (see qc_features_fac).
adata.var["mt"] = adata.var_names.str.startswith(tuple(mt_features))
adata.var["ribo"] = adata.var_names.str.startswith(tuple(rb_features))
# Hemoglobin genes are matched by regex; all configured patterns are combined with alternation.
adata.var["hb"] = adata.var_names.str.contains("|".join(hb_features), regex=True)

sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], percent_top=[20], inplace=True, log1p=True
)



In [None]:
#TODO: Add additional criteria for filtering based on absolute thresholds; can they be per sample?

# A cell is a general outlier if any of the three count-based metrics is extreme.
counts_flag = is_outlier(adata, "log1p_total_counts", NMADS)
genes_flag = is_outlier(adata, "log1p_n_genes_by_counts", NMADS)
top20_flag = is_outlier(adata, "pct_counts_in_top_20_genes", NMADS)
adata.obs["outlier"] = counts_flag | genes_flag | top20_flag

# Mitochondrial fraction uses its own, typically stricter, MAD multiplier.
adata.obs["mt_outlier"] = is_outlier(adata, "pct_counts_mt", NMADS_MITO)

## Ambient RNA correction

### Prepare object for SoupX

In [None]:
# Disabled: preparation of inputs for SoupX ambient-RNA correction, kept for reference.
# NOTE(review): `data_tod` below uses `adata_raw`, which is only created by the
# commented-out raw_feature_bc_matrix.h5 loading in the "Reading files" section.
# adata_pp = adata.copy()
# sc.pp.normalize_per_cell(adata_pp)
# sc.pp.log1p(adata_pp)
# sc.pp.pca(adata_pp)
# sc.pp.neighbors(adata_pp)
# sc.tl.leiden(adata_pp, key_added="soupx_groups")

# # Preprocess variables for SoupX
# soupx_groups = adata_pp.obs["soupx_groups"]

# del adata_pp

# cells = adata.obs_names
# genes = adata.var_names
# data = adata.X.T

# data_tod = adata_raw.X.T


**Ambient RNA correction is not reliable from Python yet; run SoupX via R interop later.**

## Cell cycle Scoring

In [None]:
# Load the S-phase and G2/M marker gene lists (one gene symbol per line; blank lines ignored).
# NOTE: expects ../Resources/g2m_genes.txt next to s_genes.txt.
with open('../Resources/s_genes.txt') as fh:
    s_genes = [line.strip() for line in fh if line.strip()]
with open('../Resources/g2m_genes.txt') as fh:
    g2m_genes = [line.strip() for line in fh if line.strip()]

In [None]:
# Keep raw counts in a layer for count-based steps later on (e.g. doublet detection).
# If the layer already exists (cell re-run), restore X from it instead of
# overwriting the counts with already-transformed data.
if "counts" in adata.layers:
    adata.X = adata.layers["counts"].copy()
    adata.uns.pop("log1p", None)
else:
    adata.layers["counts"] = adata.X.copy()

# normalize_per_cell (deprecated) also dropped zero-count cells; keep that explicitly.
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.scale(adata)

In [None]:
# Scanpy's cell-cycle scoring is not fully equivalent to Seurat's and may be unreliable; interpret with care.
if CELL_CYCLE_SCORE:
    sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)

## Quality Plots

### Violin plots

In [None]:
# Per-sample distributions of the main QC metrics.
violin_metrics = ["n_genes_by_counts", "total_counts", "pct_counts_mt", "pct_counts_ribo"]
sc.pl.violin(
    adata,
    violin_metrics,
    groupby="sample",
    jitter=0.4,
    stripplot=True,
    multi_panel=True,
)


### Histograms

In [None]:
# One patchworklib brick per QC metric, with the legend moved outside each panel.
hist_metrics = ["total_counts", "n_genes_by_counts", "pct_counts_ribo", "pct_counts_mt"]
bricks = [pw.Brick(figsize=(3, 2)) for _ in hist_metrics]

for brick, metric in zip(bricks, hist_metrics):
    sns.histplot(data=adata.obs, x=metric, hue="sample", bins=300, ax=brick)
    brick.move_legend(new_loc='upper left', bbox_to_anchor=(1.05, 1.0))
    for label in brick.legend_.get_texts():
        label.set_fontsize(8)

# Individual names are used by the layout cell below.
ax1, ax2, ax3, ax4 = bricks

In [None]:
# Combine the four histogram bricks into one patchworklib figure: (ax1 + ax2) over (ax3 + ax4).
(ax1+ax2)/(ax3+ax4)

### Scatter plots of confounders

In [None]:
import seaborn.objects as so

# Total counts vs. mitochondrial percentage, coloured by the MAD-based mito-outlier flag.
mito_dots = so.Dot(pointsize=5, alpha=0.4)
f1 = so.Plot(adata.obs, x="total_counts", y="pct_counts_mt", color="mt_outlier").add(mito_dots)
f1

In [None]:

# Total counts vs. detected genes, coloured by the general (count-based) outlier flag.
count_dots = so.Dot(pointsize=5, alpha=0.4)
f2 = so.Plot(adata.obs, x="total_counts", y="n_genes_by_counts", color="outlier").add(count_dots)
f2


## Clustering prior to cell filtering

In [None]:
# Embed and cluster all cells before QC filtering, to see where outlier cells fall.
# Operates on the current adata.X (at this point: normalized, log1p-transformed and scaled).
sc.pp.pca(adata, n_comps=20)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, key_added="groups", flavor="igraph")  # cluster labels -> adata.obs["groups"]
sc.tl.umap(adata)

In [None]:
# UMAP overview before filtering; the trailing ";" hides the repr of the Axes list returned with show=False.
sc.pl.umap(adata, size=2, color=["sample", "total_counts", "n_genes_by_counts", "pct_counts_mt"], show=False, ncols=2);

## Cell filtering based on outlier function

In [None]:
# Keep only cells that are neither general nor mitochondrial outliers.
keep_cells = ~(adata.obs.outlier | adata.obs.mt_outlier)
adata = adata[keep_cells].copy()

In [None]:
# Counts vs. detected genes for the retained cells, coloured by mitochondrial percentage.
# NOTE(review): sc.pl.scatter is called without show=False, so p1 is presumably None; verify before relying on it.
p1 = sc.pl.scatter(adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

## Regression of Variables

In [None]:
# Optionally regress unwanted sources of variation (e.g. pct_counts_mt) out of adata.X.
if REGRESS:
    if not VARS_TO_REGRESS:
        raise ValueError(
            "REGRESS is True but VARS_TO_REGRESS is empty; set it to e.g. ['pct_counts_mt', 'pct_counts_ribo']."
        )
    sc.pp.regress_out(adata, keys=VARS_TO_REGRESS)

## Doublet Detection

In [None]:
#TODO: Check real-life performance
#TODO: Check Interop with R to convert object to R & vice-versa 
if FILTER_DOUBLETS:
    # Scrublet expects un-normalised counts, but adata.X is scaled at this point,
    # so it runs on a temporary object built from the raw-counts layer.
    scrub_input = ad.AnnData(X=adata.layers["counts"].copy(), obs=adata.obs[["sample"]].copy())
    # With copy=False, scrublet annotates scrub_input.obs in place and returns None.
    sc.pp.scrublet(scrub_input, batch_key="sample")
    adata.obs["doublet_score"] = scrub_input.obs["doublet_score"]
    adata.obs["predicted_doublet"] = scrub_input.obs["predicted_doublet"].astype(bool)
    del scrub_input
    # Scrublet only flags doublets; remove them explicitly.
    adata = adata[~adata.obs["predicted_doublet"]].copy()

## Clustering After Cell filtering 

In [None]:
# Recompute the embedding and clusters on the filtered cells; overwrites the pre-filtering
# X_pca / neighbors / "groups" / X_umap results stored on adata.
sc.pp.pca(adata, n_comps=20)
sc.pp.neighbors(adata)
sc.tl.leiden(adata, key_added="groups", flavor="igraph")  # cluster labels -> adata.obs["groups"]
sc.tl.umap(adata)

In [None]:
# UMAP overview after filtering; the trailing ";" hides the repr of the Axes list returned with show=False.
sc.pl.umap(adata, size=2, color=["sample", "total_counts", "n_genes_by_counts", "pct_counts_mt"], show=False, ncols=2);

## Save Result

In [None]:
# Save the processed object to adata.h5ad (relative to the notebook's working directory).
# adata.X holds the normalized, log1p-transformed, scaled (and, if REGRESS, regressed) matrix;
# raw counts are kept in adata.layers["counts"].
adata.write_h5ad("adata.h5ad")