In [9]:
# -.-|m { input: false, output: false, input_fold: show}

import tomlkit
import scanpy as sc
from anndata import AnnData
import pandas as pd
import numpy as np
import seaborn as sns
from pandas import DataFrame

from os import path
import session_info
import logging
from tempfile import TemporaryDirectory
from os import system
import torch

# Root logger: emit only ERROR-level (and more severe) records.
logging.basicConfig(level=logging.ERROR)

# Default scanpy plotting style for this notebook: square (6, 6) figures, no frames.
sc.set_figure_params(frameon=False, figsize=(6, 6))

In [10]:
# CellTypist model name(s) used when ANNOTATION_METHOD is "celltypist".
# Must be non-empty in that case; every name has to appear in CellTypist's model list.
CELL_TYPIST_MODELS: list[str] = list()

In [11]:
# | echo: false
# | output: false
# | warning: false

## Pipeline parameters
# Load the pipeline configuration. TOML is UTF-8 by specification, so decode
# explicitly rather than relying on the platform's default encoding.
with open("../config.toml", "r", encoding="utf-8") as f:
    config = tomlkit.parse(f.read())

In [12]:
# Pipeline settings read from config.toml.
ROOT_DIR = config["basic"]["ANALYSIS_DIR"]
DIR_SAVE = path.join(ROOT_DIR, config["basic"]["DIR_SAVE"])
COUNTS_LAYER = config["normalization"]["COUNTS_LAYER"]
CLUSTERING_COL = config["clustering"]["CLUSTERING_COL"]
TISSUE = config["basic"]["TISSUE"]
ANNOTATION_METHOD = config["annotation"]["ANNOTATION_METHOD"]
# NOTE(review): named "...LAYER" but read from NORMALIZATION_METHOD — confirm intent.
NORMALIZATION_LAYER = config["normalization"]["NORMALIZATION_METHOD"]
# Misspelled legacy name, kept as an alias so existing references keep working.
NORMAMALIZATION_LAYER = NORMALIZATION_LAYER

In [13]:
def cell_typist_annotate(adata: AnnData, models: list[str], inplace=True):
    """Annotate cells with one or more CellTypist models.

    For each model name ``m`` in ``models`` two ``obs`` columns are written:
    ``celltypist_<m>_label`` (majority-voting label) and
    ``celltypist_<m>_conf_score`` (CellTypist confidence score).

    Parameters
    ----------
    adata : AnnData
        Object to annotate. Raw counts are read from
        ``adata.layers[COUNTS_LAYER]`` (module-level setting); that layer is
        never modified.
    models : list[str]
        CellTypist model names; each must be listed by
        ``celltypist.models.models_description()``.
    inplace : bool, default True
        If True, add the columns to ``adata`` and return None. If False,
        annotate a copy of ``adata`` and return it, leaving ``adata`` unchanged.

    Returns
    -------
    AnnData or None
        The annotated copy when ``inplace`` is False, otherwise None.

    Raises
    ------
    ValueError
        If ``models`` is empty or names a model CellTypist does not provide.
    """
    import celltypist
    from celltypist import models as ctypist_models
    from scipy.sparse import issparse

    if len(models) == 0:
        raise ValueError("The models list are empty, enter valid model names.")

    all_models = ctypist_models.models_description().model.to_list()

    for model in models:
        if model not in all_models:
            raise ValueError(f"{model} not found in supported cell typist models.")

    # force_update=True re-downloads the requested models on every call.
    ctypist_models.download_models(force_update=True, model=models)

    target = adata if inplace else adata.copy()

    # Build a log1p / 10,000-counts-per-cell matrix for CellTypist. The counts
    # layer is copied so the in-place normalization below cannot alter the
    # caller's raw counts.
    adata_celltypist = target.copy()
    adata_celltypist.X = target.layers[COUNTS_LAYER].copy()
    # Unlike the deprecated normalize_per_cell, normalize_total does not drop
    # zero-count cells, so every cell of `target` gets a prediction below.
    sc.pp.normalize_total(adata_celltypist, target_sum=10**4)
    sc.pp.log1p(adata_celltypist)
    # Hand CellTypist a dense matrix (as before), but only densify when sparse.
    if issparse(adata_celltypist.X):
        adata_celltypist.X = adata_celltypist.X.toarray()

    for model in models:
        loaded_model = ctypist_models.Model.load(model=model)
        predictions = celltypist.annotate(
            adata_celltypist, model=loaded_model, majority_voting=True
        )
        predicted_obs = predictions.to_adata().obs
        prefix = f"celltypist_{model}"
        # Align by cell name so column order matches `target.obs`.
        target.obs[f"{prefix}_label"] = predicted_obs.loc[
            target.obs.index, "majority_voting"
        ]
        target.obs[f"{prefix}_conf_score"] = predicted_obs.loc[
            target.obs.index, "conf_score"
        ]

    if not inplace:
        return target

In [14]:
def _fetch_and_extract(
    url: str, archive_path: str, extracted_dir: str, dest_dir: str
) -> None:
    """Make sure the tar.gz at ``url`` ends up unpacked in ``dest_dir``.

    Downloads ``url`` to ``archive_path`` unless either that archive or its
    unpacked ``extracted_dir`` already exists. Any archive present is then
    extracted into ``dest_dir`` and deleted. Errors propagate, so a failed
    extraction leaves the archive in place instead of silently deleting it.
    """
    import tarfile
    from os import remove, replace
    from os.path import exists
    from urllib.request import urlretrieve

    if not exists(archive_path) and not exists(extracted_dir):
        partial_path = archive_path + ".part"
        # Download under a temporary name so an interrupted transfer is never
        # mistaken for a complete archive on the next run.
        urlretrieve(url, partial_path)
        replace(partial_path, archive_path)

    if exists(archive_path):
        with tarfile.open(archive_path, "r:gz") as archive:
            # The "data" filter rejects path-traversal members; it exists on
            # Python 3.12+ and recent security releases of older versions.
            if hasattr(tarfile, "data_filter"):
                archive.extractall(dest_dir, filter="data")
            else:
                archive.extractall(dest_dir)
        remove(archive_path)


def dn_scTAB_resources(resources_path: str = "../resources/scTAB/"):
    """Download and unpack the scTab resources, then return the model gene table.

    Fetches the scTab checkpoints archive and the minimal CELLxGENE metadata
    archive into ``resources_path`` (skipping any already unpacked).

    Parameters
    ----------
    resources_path : str
        Directory that receives the archives and their unpacked contents.

    Returns
    -------
    pandas.DataFrame
        Contents of ``merlin_cxg_2023_05_15_sf-log1p_minimal/var.parquet``.
    """
    import os

    import pandas as pd

    check_point_file = "scTab-checkpoints.tar.gz"
    check_point_directory = "scTab-checkpoints"
    check_point_url = (
        "https://pklab.med.harvard.edu/felix/data/scTab-checkpoints.tar.gz"
    )
    cxg_minimal_file = "merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz"
    cxg_minimal_directory = "merlin_cxg_2023_05_15_sf-log1p_minimal"
    cxg_minimal_url = "https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz"

    os.makedirs(resources_path, exist_ok=True)

    _fetch_and_extract(
        check_point_url,
        os.path.join(resources_path, check_point_file),
        os.path.join(resources_path, check_point_directory),
        resources_path,
    )
    _fetch_and_extract(
        cxg_minimal_url,
        os.path.join(resources_path, cxg_minimal_file),
        os.path.join(resources_path, cxg_minimal_directory),
        resources_path,
    )

    genes_from_model = pd.read_parquet(
        os.path.join(resources_path, cxg_minimal_directory, "var.parquet")
    )

    return genes_from_model


def scTAB_data_loader(
    adata: AnnData,
    genes_from_model: DataFrame,
    batchsize: int = 2048,
    gene_col: str = "feature_name",
):
    """Build a cellnet dataloader over ``adata`` restricted to the model's genes.

    Parameters
    ----------
    adata : AnnData
        Cells to classify. ``adata.X`` is fed as-is; the scTab cell in this
        notebook sets it to raw counts beforehand.
    genes_from_model : DataFrame
        Model gene table with a ``feature_name`` column.
    batchsize : int, default 2048
        Number of cells per batch.
    gene_col : str, default "feature_name"
        Column of ``adata.var`` holding gene names comparable to
        ``genes_from_model.feature_name``.

    Returns
    -------
    The loader produced by ``cellnet.utils.data_loading.dataloader_factory``.
    """
    from scipy.sparse import csc_matrix
    from cellnet.utils.data_loading import dataloader_factory, streamline_count_matrix

    # Subset the gene space to genes used by the model (plain boolean mask).
    keep_genes = adata.var[gene_col].isin(genes_from_model.feature_name).to_numpy()
    adata_model = adata[:, keep_genes]

    # Pass the count matrix as CSC to make column slicing efficient.
    # NOTE(review): streamline_count_matrix presumably aligns columns to the
    # model's gene order — confirm against cellnet docs.
    x_streamlined = streamline_count_matrix(
        csc_matrix(adata_model.X),
        adata_model.var[gene_col],
        genes_from_model.feature_name,
    )
    loader = dataloader_factory(x_streamlined, batch_size=batchsize)

    return loader


def get_scTAB_model(resources_path: str = "../resources/scTAB/"):
    """Build the scTab TabNet classifier and load its pretrained run5 weights.

    Parameters
    ----------
    resources_path : str, default "../resources/scTAB/"
        Directory containing the unpacked ``scTab-checkpoints`` archive.

    Returns
    -------
    TabNet
        Model with the checkpoint's classifier weights, in eval mode, on the CPU.
    """
    from collections import OrderedDict
    from os.path import join

    import yaml
    from cellnet.tabnet.tab_network import TabNet

    run_dir = join(resources_path, "scTab-checkpoints", "scTab", "run5")

    # The TabNet built below is created on the CPU, so map the checkpoint
    # tensors there whether or not a GPU is present.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True; if loading this
    # checkpoint fails, pass weights_only=False (trusted source only).
    ckpt = torch.load(
        join(run_dir, "val_f1_macro_epoch=41_val_f1_macro=0.847.ckpt"),
        map_location=torch.device("cpu"),
    )

    # Keep only the TabNet parameters (stored under "classifier.") and strip
    # that prefix. Done unconditionally so it works on GPU machines too.
    tabnet_weights = OrderedDict(
        (name.replace("classifier.", ""), weight)
        for name, weight in ckpt["state_dict"].items()
        if "classifier." in name
    )

    # hparams.yaml ships in the same run directory as the checkpoint.
    with open(join(run_dir, "hparams.yaml")) as f:
        model_params = yaml.full_load(f.read())

    # Initialize the model with the hyperparameters from hparams.yaml.
    tabnet = TabNet(
        input_dim=model_params["gene_dim"],
        output_dim=model_params["type_dim"],
        n_d=model_params["n_d"],
        n_a=model_params["n_a"],
        n_steps=model_params["n_steps"],
        gamma=model_params["gamma"],
        n_independent=model_params["n_independent"],
        n_shared=model_params["n_shared"],
        epsilon=model_params["epsilon"],
        virtual_batch_size=model_params["virtual_batch_size"],
        momentum=model_params["momentum"],
        mask_type=model_params["mask_type"],
    )

    # Load trained weights and switch to inference mode.
    tabnet.load_state_dict(tabnet_weights)
    tabnet.eval()

    return tabnet


def sf_log1p_norm(x):
    """Scale each row (cell) of ``x`` to sum to 10,000, then apply log(1 + x).

    Rows whose sum is exactly zero use a divisor of 1 instead, avoiding a
    division by zero. ``x`` itself is not modified.
    """
    row_totals = x.sum(dim=1, keepdim=True)
    safe_totals = torch.where(
        row_totals == 0.0, torch.ones_like(row_totals), row_totals
    )
    return (x * (10000.0 / safe_totals)).log1p()

In [7]:
# Load the AnnData object stored as adata.h5ad under DIR_SAVE.
adata_path = path.join(DIR_SAVE, "adata.h5ad")
adata = sc.read_h5ad(adata_path)
# Alternative input kept for local experimentation:
# adata = sc.read_h5ad("../save/marcelo_ref.h5ad")

In [8]:
# Getting a stable counts layer to be used later, setting X to be raw count values.
if COUNTS_LAYER == "X":
    adata.layers["counts"] = adata.X.copy()
    COUNTS_LAYER = "counts"
elif COUNTS_LAYER in adata.layers.keys():
    adata.X = adata.layers[COUNTS_LAYER].copy()
else:
    raise ValueError(f"{COUNTS_LAYER} layer can't be found in the object")


if ANNOTATION_METHOD == "celltypist":
    cell_typist_annotate(adata, CELL_TYPIST_MODELS)

elif ANNOTATION_METHOD == "scGPT":
    print(
        "please use the accelerated_annotation notebook with a GPU, TPU, or HPU present."
    )
    # `exit(code=0)` raises TypeError inside an IPython kernel (its `exit`
    # takes no `code` argument); SystemExit stops execution in both plain
    # Python and IPython.
    raise SystemExit(0)

elif ANNOTATION_METHOD == "scTAB":
    # The tqdm *function* (calling the `tqdm` module itself raises TypeError).
    from tqdm.auto import tqdm

    genes_from_model = dn_scTAB_resources()
    tabnet = get_scTAB_model()
    loader = scTAB_data_loader(adata, genes_from_model, batchsize=2048)

    preds = []

    with torch.no_grad():
        for batch in tqdm(loader):
            # Normalize each batch the same way scTab expects (10k + log1p).
            x_input = sf_log1p_norm(batch[0]["X"])
            logits, _ = tabnet(x_input)
            preds.append(torch.argmax(logits, dim=1).cpu().numpy())

    preds = np.hstack(preds)
    cell_type_mapping = pd.read_parquet(
        "../resources/scTAB/merlin_cxg_2023_05_15_sf-log1p_minimal/categorical_lookup/cell_type.parquet"
    )
    # Map predicted class indices to cell-type names.
    # NOTE(review): assumes the loader yields cells in adata's order — confirm.
    predicted_labels = cell_type_mapping.loc[preds, "label"].to_numpy()
    adata.obs["scTAB_label"] = pd.Categorical(predicted_labels)

In [7]:
import os

# Sanity check: True when the checkpoint archive is downloaded but its
# directory has not been unpacked yet. (The previous expression compared the
# archive path with itself and was therefore always False.)
os.path.exists("../resources/scTAB/scTab-checkpoints.tar.gz") and not os.path.exists(
    "../resources/scTAB/scTab-checkpoints"
)

False

In [None]:
# Fetch and unpack the scTab resources into `resources_path`.
# NOTE(review): this duplicates the download logic of dn_scTAB_resources;
# consider calling that function instead of keeping this cell.
import tarfile
from os import makedirs, remove, replace
from os.path import exists, join
from urllib.request import urlretrieve

resources_path = "../resources/scTAB/"
check_point_file = "scTab-checkpoints.tar.gz"
check_point_directory = "scTab-checkpoints"
check_point_url = "https://pklab.med.harvard.edu/felix/data/scTab-checkpoints.tar.gz"
cxg_minimal_file = "merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz"
cxg_minimal_directory = "merlin_cxg_2023_05_15_sf-log1p_minimal"
cxg_minimual_url = "https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz"

makedirs(resources_path, exist_ok=True)

for archive_name, directory_name, url in (
    (check_point_file, check_point_directory, check_point_url),
    (cxg_minimal_file, cxg_minimal_directory, cxg_minimual_url),
):
    archive_path = join(resources_path, archive_name)
    # Download only when neither the archive nor its unpacked directory exists.
    # Write to a ".part" file first so an interrupted download is not mistaken
    # for a complete archive on the next run.
    if not exists(archive_path) and not exists(join(resources_path, directory_name)):
        urlretrieve(url, archive_path + ".part")
        replace(archive_path + ".part", archive_path)
    # Extract, then delete the archive; failures raise instead of deleting it.
    if exists(archive_path):
        with tarfile.open(archive_path, "r:gz") as archive:
            if hasattr(tarfile, "data_filter"):
                archive.extractall(resources_path, filter="data")
            else:
                archive.extractall(resources_path)
        remove(archive_path)