In [None]:
# -.-|m { input: false, output: false, input_fold: show}

import subprocess
from os import path, system
from pathlib import Path
from typing import List, Union

import GPUtil
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import seaborn as sns
import session_info
import tomlkit
from anndata import AnnData
from pandas import DataFrame
from tomlkit.items import Array, String
from tomlkit.toml_document import TOMLDocument
from utils.util_funcs import cell_typist_annotate, get_marker_genes, score_markers

In [None]:
# Tunable notebook parameters.

# CellTypist model name(s) used by the "celltypist" annotation method.
CELL_TYPIST_MODELS: List[str] = []

# Batch size forwarded to the scGPT annotation command (its -b option).
BATCH_SIZE: int = 128

# Known-marker annotation parameters.
# Path to the marker-gene table handed to get_marker_genes; empty by default,
# so it must be set before running the "known_markers" method.
MARKER_GENES_PATH: Union[str, Path] = ""
# Score threshold intended for known-marker annotation. The misspelled name is
# kept unchanged so any externally supplied parameter overrides still match it.
ANNOTATION_THERESHOLD: float = 0.5

In [None]:
# | echo: false
# | output: false
# | warning: false

## Pipeline parameters
# Parse the shared pipeline configuration, resolved relative to the current
# working directory.
# NOTE(review): the external annotators are later handed
# ROOT_DIR/config.toml instead of this path — confirm both are the same file.
config: TOMLDocument = tomlkit.parse(Path("../config.toml").read_text())

In [None]:
# Directories: the analysis root and the folder where the AnnData object lives.
basic_section = config["basic"]
ROOT_DIR: String = basic_section["ANALYSIS_DIR"]
DIR_SAVE: String = path.join(ROOT_DIR, basic_section["DIR_SAVE"])

# Per-step settings read by the rest of this notebook.
normalization_section = config["normalization"]
COUNTS_LAYER: String = normalization_section["COUNTS_LAYER"]
CLUSTERING_COL: String = config["clustering"]["CLUSTERING_COL"]
# Either a single method name or a TOML array of method names.
ANNOTATION_METHODS: Union[String, Array] = config["annotation"]["ANNOTATION_METHOD"]
NORMAMALIZATION_LAYER: String = normalization_section["NORMALIZATION_METHOD"]

In [None]:
# Load the AnnData object saved in DIR_SAVE by the earlier pipeline steps.
adata: AnnData = sc.read_h5ad(filename=path.join(DIR_SAVE, "adata.h5ad"))

In [None]:
# TODO: Keep track of annotation columns, add UMAP for each method.
# TODO: Save the columns where the annotation is being done.
# TODO: Add a majority-voting option to marker annotation.
# TODO: Add more visualization for the annotation.


def _run_annotation_tool(command: List[str]) -> None:
    """Run an external annotation CLI and raise if it exits with a non-zero status.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters are forwarded verbatim.

    Raises:
        RuntimeError: if the command's exit status is non-zero.
    """
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        raise RuntimeError(
            f"Annotation command failed with exit status {completed.returncode}: "
            + " ".join(command)
        )


def annotation_dispatcher(method: str) -> None:
    """Run one cell-type annotation method and keep the global ``adata`` in sync.

    Supported methods:
    - ``"celltypist"``: annotate in-process with ``CELL_TYPIST_MODELS``, then save.
    - ``"scGPT"``: run the ``scgpt_annotate`` pixi task on the saved h5ad and reload
      it; skipped with a message when ``GPUtil`` reports no available GPU.
    - ``"scTAB"``: run the ``sctab_annotate`` pixi task and reload the result.
    - ``"known_markers"``: score markers from ``MARKER_GENES_PATH`` using
      ``ANNOTATION_THERESHOLD`` as the threshold, then save.
    Any other name is reported and skipped.

    Args:
        method: Annotation method name, as listed in the config.

    Raises:
        RuntimeError: if an external annotation command exits non-zero.
    """
    # External tools rewrite the h5ad on disk, so the module-level object must be
    # rebound to the reloaded file. (No annotation here: annotating a global name
    # is a SyntaxError.)
    global adata
    adata_path = path.join(DIR_SAVE, "adata.h5ad")
    config_path = path.join(ROOT_DIR, "config.toml")

    if method == "celltypist":
        cell_typist_annotate(adata, CELL_TYPIST_MODELS)
        adata.write_h5ad(adata_path)
    elif method == "scGPT":
        if not GPUtil.getAvailable():
            print("CUDA is not available, scGPT will not be run efficiently on CPU")
            return
        _run_annotation_tool(
            [
                "pixi", "run", "-e", "scgpt", "scgpt_annotate",
                "-i", adata_path,
                "--config", config_path,
                "-b", str(BATCH_SIZE),
            ]
        )
        # NOTE(review): reloading drops in-memory-only changes (e.g. the counts
        # layer / X reset done below) unless an earlier method already saved them.
        adata = sc.read_h5ad(adata_path)
    elif method == "scTAB":
        _run_annotation_tool(
            [
                "pixi", "run", "-e", "sctab", "sctab_annotate",
                "--input", adata_path,
                "--config", config_path,
            ]
        )
        adata = sc.read_h5ad(adata_path)
    elif method == "known_markers":
        marker_df: DataFrame = get_marker_genes(MARKER_GENES_PATH, adata)
        score_markers(marker_df, adata, ANNOTATION_THERESHOLD)
        adata.write_h5ad(adata_path)
    else:
        print(f"Unknown annotation method '{method}', skipping.")


# Getting a stable counts layer to be used later, setting X to be raw count values.
if COUNTS_LAYER == "X":
    adata.layers["counts"] = adata.X.copy()
    COUNTS_LAYER = "counts"
elif COUNTS_LAYER in adata.layers.keys():
    adata.X = adata.layers[COUNTS_LAYER].copy()
else:
    raise ValueError(f"{COUNTS_LAYER} layer can't be found in the object")

# ANNOTATION_METHOD may be a single string or a TOML array of strings.
if isinstance(ANNOTATION_METHODS, str):
    methods_to_run = [ANNOTATION_METHODS]
elif isinstance(ANNOTATION_METHODS, list):
    methods_to_run = list(ANNOTATION_METHODS)
else:
    raise TypeError(
        "config['annotation']['ANNOTATION_METHOD'] must be a string or a list of "
        f"strings, got {type(ANNOTATION_METHODS).__name__}"
    )

for method_name in methods_to_run:
    annotation_dispatcher(method_name)

# Annotated cell identities

In [None]:
# Regex selecting obs columns whose names contain one of the annotation keys
# (presumably the columns written by each annotation method).
annotation_keys = (
    r"scTAB_annotation|scGPT_annotation|celltypist_annotation|marker_annotation"
)

selected_annotation_columns = adata.obs.filter(
    regex=annotation_keys, axis=1
).columns.to_list()

MAX_BARS = 20  # most frequent labels shown per annotation column
ncols = 2
# Ceiling division gives exactly enough rows for all panels; keep at least one
# row so plt.subplots still gets a valid grid when no columns matched.
nrows = max(1, (len(selected_annotation_columns) + ncols - 1) // ncols)
fig_width = ncols * 9
fig_height = nrows * 6
fig, axs = plt.subplots(nrows, ncols, figsize=(fig_width, fig_height), squeeze=False)
axs = axs.flatten()
fig.subplots_adjust(hspace=0.8, wspace=0.1)

for i, key in enumerate(selected_annotation_columns):
    adata.obs[key].value_counts().head(MAX_BARS).plot(kind="bar", ax=axs[i])
    axs[i].set_title(key)
    axs[i].set_ylabel("Number of cells")

# Remove the unused trailing panels of the grid.
for unused_ax in axs[len(selected_annotation_columns):]:
    fig.delaxes(unused_ax)

## UMAP after annotation

In [None]:
MAX_CLASSES = 10  # highest-frequency labels highlighted per UMAP panel
ncols = 2
# Ceiling division gives exactly enough rows for all panels; keep at least one
# row so plt.subplots still gets a valid grid when no columns matched.
nrows = max(1, (len(selected_annotation_columns) + ncols - 1) // ncols)
fig_width = ncols * 7
fig_height = nrows * 6
fig, axs = plt.subplots(nrows, ncols, figsize=(fig_width, fig_height), squeeze=False)
axs = axs.flatten()
fig.subplots_adjust(hspace=0.5, wspace=0.8)

for i, col in enumerate(selected_annotation_columns):
    label_counts = adata.obs[col].value_counts()
    extra_kwargs = {}
    if len(label_counts) > MAX_CLASSES:
        # Too many labels for a readable legend: highlight only the most frequent.
        extra_kwargs = {
            "groups": label_counts.index[:MAX_CLASSES].to_list(),
            "title": f"{col}, Top {MAX_CLASSES} classes",
        }
    sc.pl.umap(
        adata,
        color=col,
        ncols=1,
        show=False,
        ax=axs[i],
        **extra_kwargs,
    )

# Remove the unused trailing panels of the grid.
for unused_ax in axs[len(selected_annotation_columns):]:
    fig.delaxes(unused_ax)

# Session Information

In [None]:
# Display the Python/package environment (via session_info) to aid reproducing this run.
session_info.show()