In [1]:
# -.-|m { input: false, output: false, input_fold: show}

import tomlkit
import scanpy as sc
from anndata import AnnData
import pandas as pd
import numpy as np
import seaborn as sns
from pandas import DataFrame

from os import path
import session_info
import logging
from tempfile import TemporaryDirectory
from os import system
import torch

logging.basicConfig(level=logging.ERROR)

sc.set_figure_params(figsize=(6, 6), frameon=False)

In [2]:
# | echo: false
# | output: false
# | warning: false

## Pipeline parameters
with open("../config.toml", "r") as f:
    config = tomlkit.parse(f.read())

In [3]:
ROOT_DIR = config["basic"]["ANALYSIS_DIR"]
DIR_SAVE = path.join(ROOT_DIR, config["basic"]["DIR_SAVE"])
COUNTS_LAYER = config["normalization"]["COUNTS_LAYER"]
CLUSTERING_COL = config["clustering"]["CLUSTERING_COL"]
TISSUE = config["basic"]["TISSUE"]
ANNOTATION_METHOD = config["annotation"]["ANNOTATION_METHOD"]
NORMAMALIZATION_LAYER = config["normalization"]["NORMALIZATION_METHOD"]

In [4]:
# TODO: Make the download process more robust to errors
def get_scTAB_resources():
    import os
    import pandas as pd
    from huggingface_hub import hf_hub_download

    resources_path = "../resources/scTAB/"

    weights = hf_hub_download(
        "MohamedMabrouk/scTab",
        "val_f1_macro_epoch=41_val_f1_macro=0.847.ckpt",
        local_dir=resources_path,
    )

    genes = hf_hub_download(
        "MohamedMabrouk/scTab",
        "var.parquet",
        subfolder="merlin_cxg_2023_05_15_sf-log1p_minimal",
        local_dir=resources_path,
    )

    hyperparams = hf_hub_download(
        "MohamedMabrouk/scTab", "hparams.yaml", local_dir=resources_path
    )

    cell_type = hf_hub_download(
        "MohamedMabrouk/scTab",
        "cell_type.parquet",
        subfolder="merlin_cxg_2023_05_15_sf-log1p_minimal/categorical_lookup",
        local_dir=resources_path,
    )

    return weights, hyperparams, genes, cell_type


def scTAB_data_loader(adata: AnnData, genes_path: str, batchsize: int = 2048):
    from scipy.sparse import csc_matrix
    from utils import streamline_count_matrix, dataloader_factory

    genes_from_model = pd.read_parquet(genes_path)

    # subset gene space only to genes used by the model
    adata = adata[
        :, adata.var.feature_name.isin(genes_from_model.feature_name).to_numpy()
    ]
    # pass the count matrix in csc_matrix to make column slicing efficient
    x_streamlined = streamline_count_matrix(
        csc_matrix(adata.X),
        adata.var.feature_name,  # change this if gene names are stored in different column
        genes_from_model.feature_name,
    )
    loader = dataloader_factory(x_streamlined, batch_size=batchsize)

    return loader


def get_scTAB_model(weights_path: str, hyperparams_path: str):
    from collections import OrderedDict
    import yaml
    from utils import TabNet

    # load checkpoint
    if torch.cuda.is_available():
        ckpt = torch.load(weights_path)
    else:
        # map to cpu if there is not gpu available
        ckpt = torch.load(
            weights_path,
            map_location=torch.device("cpu"),
        )

    # extract state_dict of tabnet model from checkpoint
    # I can do this as well and just send you the updated checkpoint file - I think this would be the best solution
    # I just put this here for completeness
    tabnet_weights = OrderedDict()
    for name, weight in ckpt["state_dict"].items():
        if "classifier." in name:
            tabnet_weights[name.replace("classifier.", "")] = weight

    with open(hyperparams_path) as f:
        model_params = yaml.full_load(f.read())

    # initialzie model with hparams from hparams.yaml file
    tabnet = TabNet(
        input_dim=model_params["gene_dim"],
        output_dim=model_params["type_dim"],
        n_d=model_params["n_d"],
        n_a=model_params["n_a"],
        n_steps=model_params["n_steps"],
        gamma=model_params["gamma"],
        n_independent=model_params["n_independent"],
        n_shared=model_params["n_shared"],
        epsilon=model_params["epsilon"],
        virtual_batch_size=model_params["virtual_batch_size"],
        momentum=model_params["momentum"],
        mask_type=model_params["mask_type"],
    )

    # load trained weights
    tabnet.load_state_dict(tabnet_weights)
    # set model to inference mode
    tabnet.eval()

    return tabnet


def sf_log1p_norm(x):
    """Normalize each cell to have 10000 counts and apply log(x+1) transform."""

    counts = torch.sum(x, dim=1, keepdim=True)
    # avoid zero division error
    counts += counts == 0.0
    scaling_factor = 10000.0 / counts

    return torch.log1p(scaling_factor * x)

In [5]:
adata = sc.read_h5ad(path.join(DIR_SAVE, "adata.h5ad"))
adata

AnnData object with n_obs × n_vars = 33131 × 36601
    obs: 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_rb', 'log1p_total_counts_rb', 'pct_counts_rb', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'n_genes_by_counts_outlier', 'total_counts_outlier', 'pct_counts_mt_outlier', 'outlier', 'decontX_contamination', 'decontX_clusters', 'n_genes', 'doublet_score', 'predicted_doublet', 'S_score', 'G2M_score', 'phase', 'groups', 'leiden_0.2', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1.0', 'leiden_1.2', 'leiden_1.4', 'leiden_1.6', 'leiden_1.8', 'leiden_2.0', 'leiden_2.2', 'leiden_2.4', 'leiden_2.6', 'leiden_2.8', 'cluster'
    var: 'gene_ids', 'feature_types', 'mt', 'rb', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'mean', 'std', 'highly_

In [6]:
# Getting a stable counts layer to be used later, setting X to be raw count values.
if COUNTS_LAYER == "X":
    adata.layers["counts"] = adata.X.copy()
    COUNTS_LAYER = "counts"
elif COUNTS_LAYER in adata.layers.keys():
    adata.X = adata.layers[COUNTS_LAYER].copy()
else:
    raise ValueError("{COUNTS_LAYER} layer can't be found in the object")


if ANNOTATION_METHOD == "celltypist":
    exit(code=0)


if ANNOTATION_METHOD == "scGPT":
    print(
        "please use the accelerated_annotation notebook with a GPU, TPU, or HPU present."
    )
    exit(code=0)

if ANNOTATION_METHOD == "scTAB":
    import tqdm

    # BUG in cellnet
    weights, hyperparams, genes, cell_type = get_scTAB_resources()
    tabnet = get_scTAB_model(weights, hyperparams)
    loader = scTAB_data_loader(adata, genes, batchsize=2048)

    preds = []

    with torch.no_grad():
        for batch in tqdm(loader):
            # normalize data
            x_input = sf_log1p_norm(batch[0]["X"])
            logits, _ = tabnet(x_input)
            preds.append(torch.argmax(logits, dim=1).numpy())

    preds = np.hstack(preds)

    cell_type_mapping = pd.read_parquet(cell_type)
    preds = cell_type_mapping.loc[preds]["label"].to_numpy()
    adata.obs["scTAB_label"] = pd.Categorical(preds)

  ckpt = torch.load(weights_path)


AttributeError: 'DataFrame' object has no attribute 'feature_name'

In [14]:
adata.var_keys()

['gene_ids',
 'feature_types',
 'mt',
 'rb',
 'hb',
 'n_cells_by_counts',
 'mean_counts',
 'log1p_mean_counts',
 'pct_dropout_by_counts',
 'total_counts',
 'log1p_total_counts',
 'mean',
 'std',
 'highly_variable',
 'highly_variable_rank',
 'means',
 'variances',
 'variances_norm']