# Fine-tuning on Pre-trained Model for Cell-type Annotation
In this tutorial, we demonstrate how to fine-tune a pre-trained model on a new dataset for the cell type annotation task. We use the Multiple Sclerosis dataset as an example and fine-tune on the pre-trained whole-body model. Please download the dataset folder from https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v?usp=sharing

We summarize the fine-tuning pipeline in the following steps, which can be used as a general recipe for finetuning on cell-type annotation tasks and beyond: 

     1. Specify hyper-parameter setup for integration task
     
     2. Load and pre-process data
     
     3. Load the pre-trained scGPT model
     
     4. Finetune scGPT with task-specific objectives
     
     5. Evaluate fine-tuned scGPT

In [None]:
!pip install scgpt scanpy

In [2]:
import torch
import logging

# # Set up a logger for this module
# logger = logging.getLogger(__name__)

#main code logger different - yash
# Create a dedicated logger for API usage tracking
api_usage_logger = logging.getLogger("torch_api_usage_logger")
api_usage_logger.setLevel(logging.INFO)  # Set the desired log level

# Optionally, add a handler if you want to output to a file or change format
handler = logging.StreamHandler()  # or logging.FileHandler('api_usage.log')
formatter = logging.Formatter('%(asctime)s - %(message)s')
handler.setFormatter(formatter)
api_usage_logger.addHandler(handler)

def _log_class_usage(klass):
    """Log class usage for API tracking.
    
    Args:
        klass: The class whose usage is being logged.
    """
    identifier = "custom_class"
    if klass and hasattr(klass, "__name__"):
        identifier += f".{klass.__name__}"

    # Log the usage of the class with Torch's internal logging mechanism
    torch._C._log_api_usage_once(identifier)

    # Optionally log the usage using standard Python logging as well
    logger.info(f"API usage logged for class: {identifier}")

In [3]:
from typing import Dict, List, Optional

import torch
import torch.nn as nn


class Vocab(nn.Module):
    __jit_unused_properties__ = ["is_jitable"]
    r"""Creates a vocab object which maps tokens to indices.

    Args:
        vocab (torch.classes.torchtext.Vocab or torchtext._torchtext.Vocab): a cpp vocab object.
    """
    # Removed 'torchtext' specific reference in docstring
    r"""Creates a vocab object which maps tokens to indices.

    Args:
        vocab (Vocab): A Python-based vocab object that implements necessary methods like `lookup_indices` etc.
    """

    def __init__(self, vocab) -> None:
        super(Vocab, self).__init__()
        self.vocab = vocab
        _log_class_usage(__class__)


    @property
    def is_jitable(self):
        return isinstance(self.vocab, torch._C.ScriptObject)

    @torch.jit.export
    def forward(self, tokens: List[str]) -> List[int]:
        r"""Calls the `lookup_indices` method

        Args:
            tokens: a list of tokens used to lookup their corresponding `indices`.

        Returns:
            The indices associated with a list of `tokens`.
        """
        return self.vocab.lookup_indices(tokens)


    @torch.jit.export
    def __len__(self) -> int:
        r"""
        Returns:
            The length of the vocab.
        """
        return len(self.vocab)


    @torch.jit.export
    def __contains__(self, token: str) -> bool:
        r"""
        Args:
            token: The token for which to check the membership.

        Returns:
            Whether the token is member of vocab or not.
        """
        return self.vocab.__contains__(token)


    def __getitem__(self, token: str) -> int:
        r"""
        Args:
            token: The token used to lookup the corresponding index.

        Returns:
            The index corresponding to the associated token.
        """
        return self.vocab[token]


    @torch.jit.export
    def set_default_index(self, index: Optional[int]) -> None:
        r"""
        Args:
            index: Value of default index. This index will be returned when OOV token is queried.
        """
        self.vocab.set_default_index(index)


    @torch.jit.export
    def get_default_index(self) -> Optional[int]:
        r"""
        Returns:
            Value of default index if it is set.
        """
        return self.vocab.get_default_index()


    @torch.jit.export
    def insert_token(self, token: str, index: int) -> None:
        r"""
        Args:
            token: The token used to lookup the corresponding index.
            index: The index corresponding to the associated token.
        Raises:
            RuntimeError: If `index` is not in range [0, Vocab.size()] or if `token` already exists in the vocab.
        """
        self.vocab.insert_token(token, index)


    @torch.jit.export
    def append_token(self, token: str) -> None:
        r"""
        Args:
            token: The token used to lookup the corresponding index.

        Raises:
            RuntimeError: If `token` already exists in the vocab
        """
        self.vocab.append_token(token)


    @torch.jit.export
    def lookup_token(self, index: int) -> str:
        r"""
        Args:
            index: The index corresponding to the associated token.

        Returns:
            token: The token used to lookup the corresponding index.

        Raises:
            RuntimeError: If `index` not in range [0, itos.size()).
        """
        return self.vocab.lookup_token(index)


    @torch.jit.export
    def lookup_tokens(self, indices: List[int]) -> List[str]:
        r"""
        Args:
            indices: The `indices` used to lookup their corresponding`tokens`.

        Returns:
            The `tokens` associated with `indices`.

        Raises:
            RuntimeError: If an index within `indices` is not int range [0, itos.size()).
        """
        return self.vocab.lookup_tokens(indices)


    @torch.jit.export
    def lookup_indices(self, tokens: List[str]) -> List[int]:
        r"""
        Args:
            tokens: the tokens used to lookup their corresponding `indices`.

        Returns:
            The 'indices` associated with `tokens`.
        """
        return self.vocab.lookup_indices(tokens)


    @torch.jit.export
    def get_stoi(self) -> Dict[str, int]:
        r"""
        Returns:
            Dictionary mapping tokens to indices.
        """
        return self.vocab.get_stoi()


    @torch.jit.export
    def get_itos(self) -> List[str]:
        r"""
        Returns:
            List mapping indices to tokens.
        """
        return self.vocab.get_itos()


    # def __prepare_scriptable__(self):
    #     r"""Return a JITable Vocab."""
    #     if not self.is_jitable:
    #         cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.default_index_)
    #         return Vocab(cpp_vocab)
    #     return self

    def __prepare_scriptable__(self):
        r"""Return a JITable Vocab."""
        if not self.is_jitable:
            # Assuming the `vocab` is a Python-based vocab object with necessary methods
            cpp_vocab = Vocab(self.vocab.itos_, self.vocab.default_index_)
            return Vocab(cpp_vocab)
        return self

VocabPybind is not required if you're using your Python-based Vocab implementation.
VocabPybind would work correctly only if you are utilizing torchtext's C++ Vocab, which would involve Pybind11 integration, not applicable to your current code.

In [4]:
# %%
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
# from . import asyn
import pickle
import torch
from anndata import AnnData
import scanpy as sc
# import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
# from torchtext.vocab import Vocab
# from torchtext._torchtext import (
#     Vocab as VocabPybind,
# )
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

  IPython.display.set_matplotlib_formats(*ipython_format)


## Step1: Specify hyper-parameter setup for cell-type annotation task
Listed below are some hyper-parameter recommendations for the cell-type task. Note that the CLS objective is on to facilitate cell-type classification.

In [5]:
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="cite_seq",
    do_train=True,
    load_model="/kaggle/input/scgpt_human/keras/default/1/scGPT_human",
    mask_ratio=0.0,
    epochs=10, #10
    n_bins=51,
    MVC=False, # Masked value prediction for cell embedding
    ecs_thres=0.0, # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-4, #1e-4
    batch_size=32, #32
    layer_size=128,
    nlayers=4,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder #4
    nhead=4,  # number of heads in nn.MultiheadAttention #4
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=1, #5
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene = False,
    freeze = False, #freeze
    DSBN = False,  # Domain-spec batchnorm
)

In [6]:
wandb.finish()
!wandb login cc1eea9dcc7b7e94199b8a914fdd5edce065d074

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [7]:
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    resume = False, #yash
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

# Assuming `set_seed` is defined elsewhere
set_seed(config.seed)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myashshri148[0m ([33myashshri148-indian-institute-of-technology-mandi[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 0, 'dataset_name': 'cite_seq', 'do_train': True, 'load_model': '/kaggle/input/scgpt_human/keras/default/1/scGPT_human', 'mask_ratio': 0.0, 'epochs': 2, 'n_bins': 51, 'MVC': False, 'ecs_thres': 0.0, 'dab_weight': 0.0, 'lr': 0.0002, 'batch_size': 8, 'layer_size': 128, 'nlayers': 3, 'nhead': 3, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 1, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'include_zero_gene': False, 'freeze': False, 'DSBN': False}


In [8]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # for masked values, now it should always be auto

include_zero_gene = config.include_zero_gene  # if True, include zero genes among hvgs in the training
max_seq_len = 3001
n_bins = config.n_bins

# input/output representation
input_style = "binned"  # "normed_raw", "log1p", or "binned"
output_style = "binned"  # "normed_raw", "log1p", or "binned"

# settings for training
MLM = False  # whether to use masked language modeling, currently it is always on.
CLS = True  # celltype classification objective
ADV = False  # Adversarial training for batch correction
CCE = False  # Contrastive cell embedding objective
MVC = config.MVC  # Masked value prediction for cell embedding
ECS = config.ecs_thres > 0  # Elastic cell similarity objective
DAB = False  # Domain adaptation by reverse backpropagation, set to 2 for separate optimizer
INPUT_BATCH_LABELS = False  # TODO: have these help MLM and MVC, while not to classifier
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
cell_emb_style = "cls"  # "avg-pool" or "w-pool" or "cls"
adv_E_delay_epochs = 0  # delay adversarial training on encoder for a few epochs
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight

explicit_zero_prob = MLM and include_zero_gene  # whether explicit bernoulli for zeros
do_sample_in_train = False and explicit_zero_prob  # sample the bernoulli in training

per_seq_batch_sample = False

# settings for optimizer
lr = config.lr  # TODO: test learning rate ratio between two tasks
lr_ADV = 1e-3  # learning rate for discriminator, used when ADV is True
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1

# settings for the model
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"  # "linear" or "flash"
embsize = config.layer_size  # embedding dimension
d_hid = config.layer_size  # dimension of the feedforward network in TransformerEncoder
nlayers = config.nlayers  # number of TransformerEncoderLayer in TransformerEncoder
nhead = config.nhead  # number of heads in nn.MultiheadAttention
dropout = config.dropout  # dropout probability

# logging
log_interval = 100  # iterations
save_eval_interval = config.save_eval_interval  # epochs
do_eval_scib_metrics = True

In [9]:
# %% validate settings
assert input_style in ["normed_raw", "log1p", "binned"]
assert output_style in ["normed_raw", "log1p", "binned"]
assert input_emb_style in ["category", "continuous", "scaling"]
if input_style == "binned":
    if input_emb_style == "scaling":
        raise ValueError("input_emb_style `scaling` is not supported for binned input.")
elif input_style == "log1p" or input_style == "normed_raw":
    if input_emb_style == "category":
        raise ValueError(
            "input_emb_style `category` is not supported for log1p or normed_raw input."
        )

if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins  # for padding gene expr values
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

if ADV and DAB:
    raise ValueError("ADV and DAB cannot be both True.")
DAB_separate_optim = True if DAB > 1 else False

In [10]:
dataset_name = config.dataset_name
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}_logfile/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to save/dev_cite_seq-Mar08-09-43_logfile


## Step 2: Load and pre-process data
We follow the standard scGPT data pre-processing pipelines for the cell-type annotation task. Note that since now we have two datasets at hand (i.e., reference and query data), the same pre-prpocessing steps need to be applied to both of them.

In [11]:
import scanpy as sc

# Load the dataset
adata = sc.read("/kaggle/input/citeseqscrnaseqproteins-challenge-neurips2021/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")

# Print basic summary of the dataset
print(adata)

# Check the available columns in the observations (cell-level metadata)
print(adata.obs.columns)

# Check the available columns in the variables (gene or protein-level metadata)
print(adata.var.columns)

# If you want to inspect some of the data
print(adata.obs.head())  # Preview the cell-level metadata
print(adata.var.head())  # Preview the gene/protein-level metadata


AnnData object with n_obs × n_vars = 90261 × 14087
    obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'
    var: 'feature_types', 'gene_id'
    uns: 'dataset_id', 'genome', 'organism'
    obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'
    layers: 'counts'
Index(['GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors',
       'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts',
       'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order',
       'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality',
       'VendorLot', 'DonorID', 'DonorAge', 'DonorBM

In [12]:
import scanpy as sc
import numpy as np

def balanced_downsample(adata, percentile=50, rare_threshold_ratio=0.5, min_cells=200, random_state=42):
    """
    Downsample dataset to balance cell types without extreme reduction.

    Parameters:
    - adata: AnnData object containing single-cell data.
    - percentile: Percentile to determine downsampling target.
    - rare_threshold_ratio: Threshold for rare classes (relative to mean count).
    - min_cells: Minimum number of cells per class.
    - random_state: Seed for reproducibility.

    Returns:
    - Balanced AnnData object.
    """
    np.random.seed(random_state)

    # Compute stats
    cell_counts = adata.obs["celltype"].value_counts()
    avg_count = cell_counts.mean()
    median_count = cell_counts.median()
    q1_count = int(np.percentile(cell_counts, percentile))

    # Determine adaptive target
    target_count = max(q1_count, median_count, int(rare_threshold_ratio * avg_count), min_cells)

    # Ensure target_count is an integer
    target_count = int(target_count)

    print(f"Avg: {avg_count:.2f}, Median: {median_count}, Q1 ({percentile}%): {q1_count}, Target: {target_count}")

    balanced_indices = []
    for cell_type, count in cell_counts.items():
        cell_indices = adata.obs[adata.obs["celltype"] == cell_type].index

        # Keep rare classes untouched, only downsample large ones
        if count <= target_count:
            sampled_indices = cell_indices
        else:
            sampled_indices = np.random.choice(cell_indices, target_count, replace=False)

        balanced_indices.extend(sampled_indices)

    # Create balanced dataset
    balanced_adata = adata[balanced_indices].copy()
    
    print(f"New balanced dataset shape: {balanced_adata.shape}")
    return balanced_adata


In [13]:
# Define dataset
dataset_name = "cite_seq"
if dataset_name == "cite_seq":
    # data_dir = Path("../data/cite_seq")
    
    # Load the dataset
    adata = sc.read("/kaggle/input/citeseqscrnaseqproteins-challenge-neurips2021/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")
    # Splitting into train and test based on 'is_train' column
    adata_train = adata[adata.obs["is_train"] == "train"].copy()
    adata_test = adata[adata.obs["is_train"] == "test"].copy()

    print(adata_train.var.index.duplicated().sum())  # Number of duplicated genes in train
    print(adata_test.var.index.duplicated().sum())  # Number of duplicated genes in test
    adata_train = adata_train[:, ~adata_train.var.index.duplicated()].copy()
    adata_test = adata_test[:, ~adata_test.var.index.duplicated()].copy()

    # Assign cell types and batch IDs
    adata_train.obs["celltype"] = adata_train.obs["cell_type"].astype("category")
    adata_test.obs["celltype"] = adata_test.obs["cell_type"].astype("category")
    # Encode batch IDs properly
    adata_train.obs["batch_id"] = adata_train.obs["is_train"].astype("category")
    adata_test.obs["batch_id"] = adata_test.obs["is_train"].astype("category")
    # Convert batch to numerical encoding
    adata_train.obs["str_batch"] = adata_train.obs["batch_id"].cat.codes.astype(str)
    adata_test.obs["str_batch"] = adata_test.obs["batch_id"].cat.codes.astype(str)

    # Align genes in train & test before merging
    common_genes = adata_train.var.index.intersection(adata_test.var.index)
    adata_train = adata_train[:, common_genes].copy()
    adata_test = adata_test[:, common_genes].copy()

    # Apply balancing function to train dataset
    adata_train = balanced_downsample(adata_train)    
    # Apply balancing function to test dataset
    adata_test = balanced_downsample(adata_test)
    
    # Flags for downstream use
    data_is_raw = False
    filter_gene_by_counts = False
    # Keep raw test dataset
    adata_test_raw = adata_test.copy()
    # Merge datasets
    adata = adata_train.concatenate(adata_test, batch_key="str_batch")

# Assign batch and cell type IDs for further processing
batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
adata.obs["batch_id"] = batch_id_labels
# Store unique cell types and mapping
celltype_id_labels = adata.obs["celltype"].astype("category").cat.codes.values
celltypes = adata.obs["celltype"].unique()
num_types = len(np.unique(celltype_id_labels))
id2type = dict(enumerate(adata.obs["celltype"].astype("category").cat.categories))
adata.obs["celltype_id"] = celltype_id_labels
# Store gene names
adata.var["gene_id"] = adata.var.index.tolist()

36
36
📊 Avg: 1470.56, Median: 725.0, Q1 (50%): 725, Target: 735
✅ New balanced dataset shape: (22813, 14051)
📊 Avg: 376.65, Median: 235.5, Q1 (50%): 235, Target: 235
✅ New balanced dataset shape: (7291, 14051)


In [14]:
if config.load_model is not None:
    model_dir = "/kaggle/input/scgpt_human/keras/default/1/scGPT_human"
    model_config_file = model_dir +"/"+ "args.json"
    model_file = model_dir +"/"+ "best_model.pt"
    vocab_file = model_dir +"/"+ "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    shutil.copy(vocab_file, save_dir / "vocab.json")
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_id"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]

scGPT - INFO - match 12588/14051 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /kaggle/input/scgpt_human/keras/default/1/scGPT_human/best_model.pt, the model args will override the config /kaggle/input/scgpt_human/keras/default/1/scGPT_human/args.json.


In [None]:
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
import scanpy as sc

# set up the preprocessor
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=filter_gene_by_counts,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. normalize the raw data to this sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=False,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # 6. bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

# Subsetting adata
adata_test = adata[adata.obs["str_batch"] == "1"]
adata = adata[adata.obs["str_batch"] == "0"]

# Preprocessing the data
preprocessor(adata, batch_key=None)
preprocessor(adata_test, batch_key=None)

# Feature Selection: Selecting Highly Variable Genes (HVGs)
sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=2000, subset=True)

# Handling Missing Data: Imputation using SimpleImputer
imputer = SimpleImputer(strategy='mean')  # Or use 'median' or 'most_frequent' based on the dataset characteristics
adata.X = imputer.fit_transform(adata.X)

# Normalize features (standardize genes' expression)
scaler = StandardScaler()
adata.X = scaler.fit_transform(adata.X)

# Run preprocessing steps for adata_test (the test batch)
preprocessor(adata_test, batch_key=None)

In [15]:
# # set up the preprocessor, use the args to config the workflow
# preprocessor = Preprocessor(
#     use_key="X",  # the key in adata.layers to use as raw data
#     filter_gene_by_counts=filter_gene_by_counts,  # step 1
#     filter_cell_by_counts=False,  # step 2
#     normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
#     result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
#     log1p=data_is_raw,  # 4. whether to log1p the normalized data
#     result_log1p_key="X_log1p",
#     subset_hvg=False,  # 5. whether to subset the raw data to highly variable genes
#     hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
#     binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
#     result_binned_key="X_binned",  # the key in adata.layers to store the binned data
# )


# adata_test = adata[adata.obs["str_batch"] == "1"]
# adata = adata[adata.obs["str_batch"] == "0"]

# preprocessor(adata, batch_key=None)
# preprocessor(adata_test, batch_key=None)

scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Binning data ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Binning data ...


In [16]:
input_layer_key = {  # the values of this map coorespond to the keys in preprocessing
    "normed_raw": "X_normed",
    "log1p": "X_normed",
    "binned": "X_binned",
}[input_style]
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_id"].tolist()

celltypes_labels = adata.obs["celltype_id"].tolist()  # make sure count from 0
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
)

In [17]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [18]:
tokenized_train = tokenize_and_pad_batch(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)
tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
logger.info(
    f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
)
logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)

scGPT - INFO - train set number of samples: 20531, 
	 feature length: 3001
scGPT - INFO - valid set number of samples: 2282, 
	 feature length: 3001


In [19]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    tensor_celltype_labels_train = torch.from_numpy(train_celltype_labels).long()
    tensor_celltype_labels_valid = torch.from_numpy(valid_celltype_labels).long()

    if sort_seq_batch:  # TODO: update to random pick seq source in each traning batch
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]
        tensor_celltype_labels_train = tensor_celltype_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]
        tensor_celltype_labels_valid = tensor_celltype_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
        "celltype_labels": tensor_celltype_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
        "celltype_labels": tensor_celltype_labels_valid,
    }

    return train_data_pt, valid_data_pt


# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    if num_workers == 0:
        num_workers = min(len(os.sched_getaffinity(0)), batch_size // 2)

    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


## Step 3: Load the pre-trained scGPT model

 # For ScBERT

In [None]:
# import torch
# import torch.nn as nn
# from transformers import BertModel, BertTokenizer

# # Detect available GPUs
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# num_gpus = torch.cuda.device_count()

# # Load pre-trained scBERT (replace 'scBERT' with the actual model name, if different)
# scbert_model_name = 'allenai/scibert_scivocab_uncased'  # You can change it based on your preference or the model you're using
# scbert = BertModel.from_pretrained(scbert_model_name).to(device)
# tokenizer = BertTokenizer.from_pretrained(scbert_model_name)

# # Define your Transformer model (which uses the scBERT embedding as the encoder)
# ntokens = len(vocab)  # size of vocabulary
# model = TransformerModel(
#     ntokens,
#     embsize,
#     nhead,
#     d_hid,
#     nlayers,
#     nlayers_cls=3,
#     n_cls=num_types if CLS else 1,
#     vocab=vocab,
#     dropout=dropout,
#     pad_token=pad_token,
#     pad_value=pad_value,
#     do_mvc=MVC,
#     do_dab=DAB,
#     use_batch_labels=INPUT_BATCH_LABELS,
#     num_batch_labels=num_batch_types,
#     domain_spec_batchnorm=config.DSBN,
#     input_emb_style=input_emb_style,
#     n_input_bins=n_input_bins,
#     cell_emb_style=cell_emb_style,
#     mvc_decoder_style=mvc_decoder_style,
#     ecs_threshold=ecs_threshold,
#     explicit_zero_prob=explicit_zero_prob,
#     use_fast_transformer=fast_transformer,
#     fast_transformer_backend=fast_transformer_backend,
#     pre_norm=config.pre_norm,
# )

# # Integrate scBERT into your model
# class scBERTTransformerModel(nn.Module):
#     def __init__(self, scbert, transformer_model):
#         super(scBERTTransformerModel, self).__init__()
#         self.scbert = scbert  # Pre-trained scBERT
#         self.transformer = transformer_model  # Your custom transformer model

#     def forward(self, input_ids, attention_mask):
#         # Get embeddings from scBERT
#         scbert_output = self.scbert(input_ids=input_ids, attention_mask=attention_mask)
#         embeddings = scbert_output.last_hidden_state  # You can use the [CLS] token embeddings or the full sequence

#         # Pass scBERT embeddings to the transformer model
#         output = self.transformer(embeddings)
#         return output

# # Create an integrated model
# scbert_transformer_model = scBERTTransformerModel(scbert, model).to(device)

# # Move model to device first
# scbert_transformer_model.to(device)

# # Load model weights if specified
# if config.load_model is not None:
#     try:
#         scbert_transformer_model.load_state_dict(torch.load(model_file, map_location=device))
#         logger.info(f"Loading all model params from {model_file}")
#     except:
#         # Load matching parameters only
#         model_dict = scbert_transformer_model.state_dict()
#         pretrained_dict = torch.load(model_file, map_location=device)
#         pretrained_dict = {
#             k: v
#             for k, v in pretrained_dict.items()
#             if k in model_dict and v.shape == model_dict[k].shape
#         }
#         for k, v in pretrained_dict.items():
#             logger.info(f"Loading params {k} with shape {v.shape}")
#         model_dict.update(pretrained_dict)
#         scbert_transformer_model.load_state_dict(model_dict)

# # Enable Multi-GPU if more than 1 GPU is available
# if num_gpus > 1:
#     print(f"Using {num_gpus} GPUs with DataParallel!")
#     scbert_transformer_model = nn.DataParallel(scbert_transformer_model)

# # Freeze encoder parameters if required
# pre_freeze_param_count = sum(
#     dict((p.data_ptr(), p.numel()) for p in scbert_transformer_model.parameters() if p.requires_grad).values()
# )

# for name, para in scbert_transformer_model.named_parameters():
#     print("-"*20)
#     print(f"name: {name}")
#     if config.freeze and "encoder" in name and "transformer_encoder" not in name:
#         print(f"Freezing weights for: {name}")
#         para.requires_grad = False

# post_freeze_param_count = sum(
#     dict((p.data_ptr(), p.numel()) for p in scbert_transformer_model.parameters() if p.requires_grad).values()
# )

# logger.info(f"Total Pre freeze Params {pre_freeze_param_count}")
# logger.info(f"Total Post freeze Params {post_freeze_param_count}")
# wandb.log(
#     {
#         "info/pre_freeze_param_count": pre_freeze_param_count,
#         "info/post_freeze_param_count": post_freeze_param_count,
#     },
# )

# # Move model again after freezing parameters
# scbert_transformer_model.to(device)

# # Enable Weights & Biases logging
# if isinstance(scbert_transformer_model, torch.nn.DataParallel):
#     wandb.watch(scbert_transformer_model.module)  # Access the original model inside DataParallel
# else:
#     wandb.watch(scbert_transformer_model)

# # If adversarial training (ADV) is enabled, use multiple GPUs for the discriminator too
# if ADV:
#     discriminator = AdversarialDiscriminator(
#         d_model=embsize,
#         n_cls=num_batch_types,
#     ).to(device)

#     if num_gpus > 1:
#         discriminator = nn.DataParallel(discriminator)


# For scGPT

In [20]:
import torch
import torch.nn as nn

# Detect available GPUs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()

# Define model
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=config.DSBN,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=config.pre_norm,
)

# Move model to device first
model.to(device)

# Load model weights if specified
if config.load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file, map_location=device))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # Load matching parameters only
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file, map_location=device)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

# Enable Multi-GPU if more than 1 GPU is available
if num_gpus > 1:
    print(f"Using {num_gpus} GPUs with DataParallel!")
    model = nn.DataParallel(model)

# Freeze encoder parameters if required
pre_freeze_param_count = sum(
    dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values()
)

for name, para in model.named_parameters():
    print("-"*20)
    print(f"name: {name}")
    if config.freeze and "encoder" in name and "transformer_encoder" not in name:
        print(f"Freezing weights for: {name}")
        para.requires_grad = False

post_freeze_param_count = sum(
    dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values()
)

logger.info(f"Total Pre freeze Params {pre_freeze_param_count}")
logger.info(f"Total Post freeze Params {post_freeze_param_count}")
wandb.log(
    {
        "info/pre_freeze_param_count": pre_freeze_param_count,
        "info/post_freeze_param_count": post_freeze_param_count,
    },
)

# Move model again after freezing parameters
model.to(device)

# # Enable Weights & Biases logging
# wandb.watch(model)

#Modify wandb.watch(model) for Multi-GPU
if isinstance(model, torch.nn.DataParallel):
    wandb.watch(model.module)  # Access the original model inside DataParallel
else:
    wandb.watch(model)


# If adversarial training (ADV) is enabled, use multiple GPUs for the discriminator too
if ADV:
    discriminator = AdversarialDiscriminator(
        d_model=embsize,
        n_cls=num_batch_types,
    ).to(device)

    if num_gpus > 1:
        discriminator = nn.DataParallel(discriminator)


scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Si

# For SingleCellNet

In [1]:
# import torch
# import torch.nn as nn
# from transformers import BertModel, BertTokenizer

# # Detect available GPUs
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# num_gpus = torch.cuda.device_count()

# # Load pre-trained scBERT (as before)
# scbert_model_name = 'allenai/scibert_scivocab_uncased'  # Change to your desired scBERT model
# scbert = BertModel.from_pretrained(scbert_model_name).to(device)
# tokenizer = BertTokenizer.from_pretrained(scbert_model_name)

# # Load SingleCellNet model (assuming it's a PyTorch model)
# # You need to modify this part based on how SingleCellNet is saved or loaded.
# # If it's available as a PyTorch model, load it similarly to scBERT

# # Assuming SingleCellNet is a PyTorch model (this could be different if it's TensorFlow or another framework)
# singlecellnet_model_path = "path/to/singlecellnet_model.pth"  # Path to the pre-trained SingleCellNet model file
# singlecellnet_model = torch.load(singlecellnet_model_path, map_location=device)

# # Example SingleCellNet Model Architecture (this will vary based on the model's actual definition)
# class SingleCellNet(nn.Module):
#     def __init__(self):
#         super(SingleCellNet, self).__init__()
#         # Define the architecture (replace with the actual model details)
#         self.layer1 = nn.Linear(512, 256)
#         self.layer2 = nn.Linear(256, 128)
#         self.output_layer = nn.Linear(128, 10)  # Example for 10 classes

#     def forward(self, x):
#         x = torch.relu(self.layer1(x))
#         x = torch.relu(self.layer2(x))
#         x = self.output_layer(x)
#         return x

# # Assuming the above class is the correct architecture, replace with actual loading if different
# singlecellnet = SingleCellNet().to(device)

# # Load pre-trained weights if available
# singlecellnet.load_state_dict(torch.load(singlecellnet_model_path, map_location=device))

# # Define your Transformer model as before
# ntokens = len(vocab)  # size of vocabulary
# model = TransformerModel(
#     ntokens,
#     embsize,
#     nhead,
#     d_hid,
#     nlayers,
#     nlayers_cls=3,
#     n_cls=num_types if CLS else 1,
#     vocab=vocab,
#     dropout=dropout,
#     pad_token=pad_token,
#     pad_value=pad_value,
#     do_mvc=MVC,
#     do_dab=DAB,
#     use_batch_labels=INPUT_BATCH_LABELS,
#     num_batch_labels=num_batch_types,
#     domain_spec_batchnorm=config.DSBN,
#     input_emb_style=input_emb_style,
#     n_input_bins=n_input_bins,
#     cell_emb_style=cell_emb_style,
#     mvc_decoder_style=mvc_decoder_style,
#     ecs_threshold=ecs_threshold,
#     explicit_zero_prob=explicit_zero_prob,
#     use_fast_transformer=fast_transformer,
#     fast_transformer_backend=fast_transformer_backend,
#     pre_norm=config.pre_norm,
# )

# # Combine scBERT and SingleCellNet into a unified model, if necessary
# class UnifiedModel(nn.Module):
#     def __init__(self, scbert, singlecellnet, transformer_model):
#         super(UnifiedModel, self).__init__()
#         self.scbert = scbert  # Pre-trained scBERT
#         self.singlecellnet = singlecellnet  # Pre-trained SingleCellNet
#         self.transformer = transformer_model  # Custom transformer model

#     def forward(self, input_ids, attention_mask, cell_features):
#         # Get embeddings from scBERT
#         scbert_output = self.scbert(input_ids=input_ids, attention_mask=attention_mask)
#         embeddings = scbert_output.last_hidden_state  # Use the [CLS] token embeddings or the full sequence

#         # Pass scBERT embeddings to the transformer model
#         transformer_output = self.transformer(embeddings)

#         # Optionally, use SingleCellNet for further classification
#         cell_classification = self.singlecellnet(cell_features)

#         return transformer_output, cell_classification

# # Create an integrated model
# unified_model = UnifiedModel(scbert, singlecellnet, model).to(device)

# # Move model to device first
# unified_model.to(device)

# # Load model weights if specified
# if config.load_model is not None:
#     try:
#         unified_model.load_state_dict(torch.load(model_file, map_location=device))
#         logger.info(f"Loading all model params from {model_file}")
#     except:
#         # Load matching parameters only
#         model_dict = unified_model.state_dict()
#         pretrained_dict = torch.load(model_file, map_location=device)
#         pretrained_dict = {
#             k: v
#             for k, v in pretrained_dict.items()
#             if k in model_dict and v.shape == model_dict[k].shape
#         }
#         for k, v in pretrained_dict.items():
#             logger.info(f"Loading params {k} with shape {v.shape}")
#         model_dict.update(pretrained_dict)
#         unified_model.load_state_dict(model_dict)

# # Enable Multi-GPU if more than 1 GPU is available
# if num_gpus > 1:
#     print(f"Using {num_gpus} GPUs with DataParallel!")
#     unified_model = nn.DataParallel(unified_model)

# # Freeze encoder parameters if required
# pre_freeze_param_count = sum(
#     dict((p.data_ptr(), p.numel()) for p in unified_model.parameters() if p.requires_grad).values()
# )

# for name, para in unified_model.named_parameters():
#     print("-"*20)
#     print(f"name: {name}")
#     if config.freeze and "encoder" in name and "transformer_encoder" not in name:
#         print(f"Freezing weights for: {name}")
#         para.requires_grad = False

# post_freeze_param_count = sum(
#     dict((p.data_ptr(), p.numel()) for p in unified_model.parameters() if p.requires_grad).values()
# )

# logger.info(f"Total Pre freeze Params {pre_freeze_param_count}")
# logger.info(f"Total Post freeze Params {post_freeze_param_count}")
# wandb.log(
#     {
#         "info/pre_freeze_param_count": pre_freeze_param_count,
#         "info/post_freeze_param_count": post_freeze_param_count,
#     },
# )

# # Move model again after freezing parameters
# unified_model.to(device)

# # Enable Weights & Biases logging
# if isinstance(unified_model, torch.nn.DataParallel):
#     wandb.watch(unified_model.module)  # Access the original model inside DataParallel
# else:
#     wandb.watch(unified_model)

# # If adversarial training (ADV) is enabled, use multiple GPUs for the discriminator too
# if ADV:
#     discriminator = AdversarialDiscriminator(
#         d_model=embsize,
#         n_cls=num_batch_types,
#     ).to(device)

#     if num_gpus > 1:
#         discriminator = nn.DataParallel(discriminator)


In [21]:
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, schedule_interval, gamma=config.schedule_ratio
)
if DAB_separate_optim:
    optimizer_dab = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler_dab = torch.optim.lr_scheduler.StepLR(
        optimizer_dab, schedule_interval, gamma=config.schedule_ratio
    )
if ADV:
    criterion_adv = nn.CrossEntropyLoss()  # consider using label smoothing
    optimizer_E = torch.optim.Adam(model.parameters(), lr=lr_ADV)
    scheduler_E = torch.optim.lr_scheduler.StepLR(
        optimizer_E, schedule_interval, gamma=config.schedule_ratio
    )
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_ADV)
    scheduler_D = torch.optim.lr_scheduler.StepLR(
        optimizer_D, schedule_interval, gamma=config.schedule_ratio
    )

scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

In [22]:
def train(model: nn.Module, loader: DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    (
        total_loss,
        total_mse,
        total_cls,
        total_cce,
        total_mvc,
        total_ecs,
        total_dab,
        total_adv_E,
        total_adv_D,
        total_zero_log_prob,
        total_mvc_zero_log_prob,
    ) = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    total_error = 0.0
    start_time = time.time()

    num_batches = len(loader)

    # Define the number of gradient accumulation steps
    gradient_accumulation_steps = 2  # Change this as needed
    total_loss=0.0
    # # Initialize the accumulator
    # optimizer.zero_grad()

    accumulation_step = 0  # Initialize for each epoch
    for batch, batch_data in enumerate(loader):

        # Initialize accumulation_step counter at the start of training or batch loop
       
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)
        celltype_labels = batch_data["celltype_labels"].to(device)

        
        # Check If GPU Memory Is Utilized Properly
        print("Check If GPU Memory Is Utilized Properly")
        #before forward pass
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: Memory Allocated: {torch.cuda.memory_allocated(i) / 1e9:.2f} GB")


        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
                #generative_training=False
            )

            masked_positions = input_values.eq(mask_value)  # the postions to predict
            loss = 0.0
            metrics_to_log = {}
            if MLM:
                loss_mse = criterion(
                    output_dict["mlm_output"], target_values, masked_positions
                )
                loss = loss + loss_mse
                metrics_to_log = {"train/mse": loss_mse.item()}
            if explicit_zero_prob:
                loss_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_zero_log_prob
                metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
            if CLS:
                loss_cls = criterion_cls(output_dict["cls_output"], celltype_labels)
                loss = loss + loss_cls
                metrics_to_log.update({"train/cls": loss_cls.item()})

                error_rate = 1 - (
                    (output_dict["cls_output"].argmax(1) == celltype_labels)
                    .sum()
                    .item()
                ) / celltype_labels.size(0)
            if CCE:
                loss_cce = 10 * output_dict["loss_cce"]
                loss = loss + loss_cce
                metrics_to_log.update({"train/cce": loss_cce.item()})
            if MVC:
                loss_mvc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                loss = loss + loss_mvc
                metrics_to_log.update({"train/mvc": loss_mvc.item()})
            if MVC and explicit_zero_prob:
                loss_mvc_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_mvc_zero_log_prob
                metrics_to_log.update({"train/mvc_nzlp": loss_mvc_zero_log_prob.item()})
            if ECS:
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss = loss + loss_ecs
                metrics_to_log.update({"train/ecs": loss_ecs.item()})
            if DAB:
                # try weighting and separate optimizer
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
                loss = loss + dab_weight * loss_dab
                metrics_to_log.update({"train/dab": loss_dab.item()})

        # # Gradient accumulation
        # loss = loss / gradient_accumulation_steps  # Scale loss to accumulate gradients
        # scaler.scale(loss).backward()

        # model.zero_grad()

        # Accumulate gradients (without updating the weights)
        loss = loss / gradient_accumulation_steps  # Divide by accumulation steps to average loss
        scaler.scale(loss).backward()
        # scaler.unscale_(optimizer)

        # Increment the accumulation step counter
        accumulation_step += 1

        # Perform optimization step after the desired number of steps
        if (batch + 1) % gradient_accumulation_steps == 0:
            accumulation_step = 0
            scaler.unscale_(optimizer)
            with warnings.catch_warnings(record=True) as w:
                warnings.filterwarnings("always")
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                    error_if_nonfinite=False if scaler.is_enabled() else True,
                )
                if len(w) > 0:
                    logger.warning(
                        f"Found infinite gradient. This may be caused by the gradient "
                        f"scaler. The current scale is {scaler.get_scale()}. This warning "
                        "can be ignored if no longer occurs after autoscaling of the scaler."
                    )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()  # Reset gradients after optimization step

            # Log progress after accumulation step
            print(f"Accumulation Step: {accumulation_step}/{len(loader) // gradient_accumulation_steps}, Loss: {loss.item():.4f}")
            logger.info(f"Accumulation Step: {accumulation_step}/{len(loader) // gradient_accumulation_steps}, Loss: {loss.item():.4f}")

            #model.xero_grad() not used here but works 
        
        # with warnings.catch_warnings(record=True) as w:
        #     warnings.filterwarnings("always")
        #     torch.nn.utils.clip_grad_norm_(
        #         model.parameters(),
        #         1.0,
        #         error_if_nonfinite=False if scaler.is_enabled() else True,
        #     )
        #     if len(w) > 0:
        #         logger.warning(
        #             f"Found infinite gradient. This may be caused by the gradient "
        #             f"scaler. The current scale is {scaler.get_scale()}. This warning "
        #             "can be ignored if no longer occurs after autoscaling of the scaler."
        #         )
        # scaler.step(optimizer)
        # scaler.update()

        if ADV:
            # rerun the model for adversarial training
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
                #generative_training=False
            )

            # TRAINING DISCRIMINATOR
            loss_adv_D = criterion_adv(
                discriminator(output_dict["cell_emb"].detach()), batch_labels
            )
            if epoch > adv_D_delay_epochs:
                discriminator.zero_grad()
                loss_adv_D.backward()
                optimizer_D.step()

            # TRAINING ENCODER
            loss_adv_E = -criterion_adv(
                discriminator(output_dict["cell_emb"]), batch_labels
            )
            # NOTE: the loss is negative here because we want to maximize
            # the cross_entropy_loss, in other words, disguise against the discriminator
            if epoch > adv_E_delay_epochs:
                model.zero_grad()
                discriminator.zero_grad()
                loss_adv_E.backward()
                optimizer_E.step()

        wandb.log(metrics_to_log)

        total_loss += loss.item()
        total_mse += loss_mse.item() if MLM else 0.0
        total_cls += loss_cls.item() if CLS else 0.0
        total_cce += loss_cce.item() if CCE else 0.0
        total_mvc += loss_mvc.item() if MVC else 0.0
        total_ecs += loss_ecs.item() if ECS else 0.0
        total_dab += loss_dab.item() if DAB else 0.0
        total_adv_E += loss_adv_E.item() if ADV else 0.0
        total_adv_D += loss_adv_D.item() if ADV else 0.0
        total_zero_log_prob += loss_zero_log_prob.item() if explicit_zero_prob else 0.0
        total_mvc_zero_log_prob += (
            loss_mvc_zero_log_prob.item() if MVC and explicit_zero_prob else 0.0
        )
        total_error += error_rate
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_cls = total_cls / log_interval if CLS else 0.0
            cur_cce = total_cce / log_interval if CCE else 0.0
            cur_mvc = total_mvc / log_interval if MVC else 0.0
            cur_ecs = total_ecs / log_interval if ECS else 0.0
            cur_dab = total_dab / log_interval if DAB else 0.0
            cur_adv_E = total_adv_E / log_interval if ADV else 0.0
            cur_adv_D = total_adv_D / log_interval if ADV else 0.0
            cur_zero_log_prob = (
                total_zero_log_prob / log_interval if explicit_zero_prob else 0.0
            )
            cur_mvc_zero_log_prob = (
                total_mvc_zero_log_prob / log_interval
                if MVC and explicit_zero_prob
                else 0.0
            )
            cur_error = total_error / log_interval
            # ppl = math.exp(cur_loss)
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | "
                + (f"mse {cur_mse:5.2f} | mre {cur_error:5.2f} |" if MLM else "")
                + (f"cls {cur_cls:5.2f} | " if CLS else "")
                + (f"err {cur_error:5.2f} | " if CLS else "")
                + (f"cce {cur_cce:5.2f} |" if CCE else "")
                + (f"mvc {cur_mvc:5.2f} |" if MVC else "")
                + (f"ecs {cur_ecs:5.2f} |" if ECS else "")
                + (f"dab {cur_dab:5.2f} |" if DAB else "")
                + (f"adv_E {cur_adv_E:5.2f} |" if ADV else "")
                + (f"adv_D {cur_adv_D:5.2f} |" if ADV else "")
                + (f"nzlp {cur_zero_log_prob:5.2f} |" if explicit_zero_prob else "")
                + (
                    f"mvc_nzlp {cur_mvc_zero_log_prob:5.2f} |"
                    if MVC and explicit_zero_prob
                    else ""
                )
            )
            total_loss = 0
            total_mse = 0
            total_cls = 0
            total_cce = 0
            total_mvc = 0
            total_ecs = 0
            total_dab = 0
            total_adv_E = 0
            total_adv_D = 0
            total_zero_log_prob = 0
            total_mvc_zero_log_prob = 0
            total_error = 0
            start_time = time.time()


def define_wandb_metrcis():
    wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
    wandb.define_metric("valid/mre", summary="min", step_metric="epoch")
    wandb.define_metric("valid/dab", summary="min", step_metric="epoch")
    wandb.define_metric("valid/sum_mse_dab", summary="min", step_metric="epoch")
    wandb.define_metric("test/avg_bio", summary="max")


def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)
            celltype_labels = batch_data["celltype_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,  # evaluation does not need CLS or CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                    #generative_training = False,
                )
                output_values = output_dict["cls_output"]
                loss = criterion_cls(output_values, celltype_labels)

                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == celltype_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)
            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/err": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + dab_weight * total_dab) / total_num,
            "epoch": epoch,
        },
    )

    if return_raw:
        return np.concatenate(predictions, axis=0)

    return total_loss / total_num, total_error / total_num


## Step 4: Finetune scGPT with task-specific objectives

In [23]:
from tqdm import tqdm
# Initialize list to track epoch times for progress estimation
epoch_times = []

best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None
define_wandb_metrcis()

# for epoch in range(1, epochs + 1):
for epoch in tqdm(range(1, epochs + 1), desc="Training Epochs", unit="epoch"):
    epoch_start_time = time.time()
    train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=batch_size,
        shuffle=False,
        intra_domain_shuffle=True,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=eval_batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    if config.do_train:
        train(
            model,
            loader=train_loader,
        )
    val_loss, val_err = evaluate(
        model,
        loader=valid_loader,
    )
    
    elapsed = time.time() - epoch_start_time
    epoch_times.append(elapsed)
    
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | err {val_err:5.4f}"
    )
    
        # Estimate remaining time
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    remaining_epochs = epochs - epoch
    total_remaining_time = avg_epoch_time * remaining_epochs
    total_estimated_time = sum(epoch_times) + total_remaining_time

    logger.info(
        f"Estimated time remaining: {remaining_epochs * avg_epoch_time / 60:5.2f} min | "
        f"Total estimated time: {total_estimated_time / 3600:5.2f} hours"
    )

    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        logger.info(f"Best model with score {best_val_loss:5.4f}")

    scheduler.step()
    if DAB_separate_optim:
        scheduler_dab.step()
    if ADV:
        scheduler_D.step()
        scheduler_E.step()

Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.27 GB
GPU 1: Memory Allocated: 0.08 GB
Accumulation Step: 0/1283, Loss: 1.9509
scGPT - INFO - Accumulation Step: 0/1283, Loss: 1.9509
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.07 GB
GPU 1: Memory Allocated: 0.08 GB
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.27 GB
GPU 1: Memory Allocated: 0.08 GB
Accumulation Step: 0/1283, Loss: 1.9108
scGPT - INFO - Accumulation Step: 0/1283, Loss: 1.9108
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.07 GB
GPU 1: Memory Allocated: 0.08 GB
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.27 GB
GPU 1: Memory Allocated: 0.08 GB
Accumulation Step: 0/1283, Loss: 1.9058
scGPT - INFO - Accumulation Step: 0/1283, Loss: 1.9058
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.07 GB
GPU 1: Memory Allocated: 0.08 GB
Check If GPU Memory Is Utilized Properly
GPU 0: Memory Allocated: 1.27 GB

Training Epochs:  50%|█████     | 1/2 [45:31<45:31, 2731.47s/epoch]


KeyboardInterrupt: 

In [1]:
# torch.save(best_model.state_dict(), "/kaggle/working/best_model_1100.pth")

NameError: name 'torch' is not defined

In [None]:
# model.load_state_dict(torch.load("best_model_0700.pth"))
# model.eval()  # Set to evaluation mode

In [None]:
# #Debug GPU Memory Usage
# import torch
# print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# %% inference
def test(model: nn.Module, adata: DataLoader) -> float:
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )

    celltypes_labels = adata.obs["celltype_id"].tolist()  # make sure count from 0
    celltypes_labels = np.array(celltypes_labels)

    batch_ids = adata.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    tokenized_test = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=include_zero_gene,
    )

    input_values_test = random_mask_value(
        tokenized_test["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

    test_data_pt = {
        "gene_ids": tokenized_test["genes"],
        "values": input_values_test,
        "target_values": tokenized_test["values"],
        "batch_labels": torch.from_numpy(batch_ids).long(),
        "celltype_labels": torch.from_numpy(celltypes_labels).long(),
    }

    test_loader = DataLoader(
        dataset=SeqDataset(test_data_pt),
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=min(len(os.sched_getaffinity(0)), eval_batch_size // 2),
        pin_memory=True,
    )

    model.eval()
    predictions = evaluate(
        model,
        loader=test_loader,
        return_raw=True,
    )

    # compute accuracy, precision, recall, f1
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    accuracy = accuracy_score(celltypes_labels, predictions)
    precision = precision_score(celltypes_labels, predictions, average="macro")
    recall = recall_score(celltypes_labels, predictions, average="macro")
    macro_f1 = f1_score(celltypes_labels, predictions, average="macro")

    logger.info(
        f"Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, "
        f"Macro F1: {macro_f1:.3f}"
    )

    results = {
        "test/accuracy": accuracy,
        "test/precision": precision,
        "test/recall": recall,
        "test/macro_f1": macro_f1,
    }

    return predictions, celltypes_labels, results

## Step 5: Inference with fine-tuned scGPT model
In the cell-type annotation task, the fine-tuned scGPT predicts cell-type labels for query set as inference. The model performance is evaluated on standard classificaton metrics. Here we visualize the predicted labels over the scGPT cell embeddings, and present the confusion matrix for detailed classification performance on the cell-group level.

In [None]:
# import time
# import tqdm
# import pickle
# import wandb
# import scanpy as sc
# import matplotlib.pyplot as plt

# # Start time for estimation
# test_start_time = time.time()

# # Reduce test size by a factor of 10
# total_test_samples = len(adata_test) // 10
# adata_test_subset = adata_test[:total_test_samples]  # Select only 10% of data

# # Run inference with tqdm progress bar
# predictions, labels, results = [], [], {}
# for i in tqdm.tqdm(range(total_test_samples), desc="Testing", unit="sample"):
#     pred, label, res = test(best_model, adata_test_subset[i])  # Assuming test() handles single sample
#     predictions.append(pred)
#     labels.append(label)
#     results = res  # Assuming results structure is updated per sample

# # End time and compute time per sample
# elapsed_time = time.time() - test_start_time
# avg_time_per_sample = elapsed_time / total_test_samples
# estimated_total_time = avg_time_per_sample * total_test_samples

# # Store predictions in AnnData object
# adata_test_raw.obs["predictions"] = [id2type[p] for p in predictions]

# # Define color palette
# palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# palette_ = palette_ * 3  # Extend palette
# palette_ = {c: palette_[i] for i, c in enumerate(celltypes)}

# # Plot results
# with plt.rc_context({"figure.figsize": (6, 4), "figure.dpi": (300)}):
#     sc.pl.umap(
#         adata_test_raw,
#         color=["celltype", "predictions"],
#         palette=palette_,
#         show=False,
#     )
#     plt.savefig(save_dir / "results.png", dpi=300)

# # Save results
# save_dict = {
#     "predictions": predictions,
#     "labels": labels,
#     "results": results,
#     "id_maps": id2type
# }
# with open(save_dir / "results.pkl", "wb") as f:
#     pickle.dump(save_dict, f)

# # Log results to wandb
# results["test/cell_umap"] = wandb.Image(
#     str(save_dir / "results.png"),
#     caption=f"Predictions Macro F1: {results['test/macro_f1']:.3f}",
# )
# results["test/estimated_time_min"] = estimated_total_time / 60  # Convert to minutes

# wandb.log(results)

# # Print estimated time
# print(f"Testing completed in {elapsed_time:.2f} seconds")
# print(f"Estimated total testing time: {estimated_total_time / 60:.2f} minutes")
# print(f"Reduced test set size: {total_test_samples} samples")

In [None]:
print(adata_test_subset.shape)  # Check the number of features
print(len(gene_ids))  # Check the number of gene IDs

In [None]:
# # Print first few gene names from adata_test_subset
# print(adata_test_subset.var_names[:10])

# # Print first few gene_ids
# print(gene_ids[:10])


In [None]:
# # Check how many genes in adata_test_subset match gene_ids
# matching_genes = [gene for gene in adata_test_subset.var_names if gene in gene_ids]
# print(f"Number of matching genes: {len(matching_genes)}")

# # Check a few unmatched genes
# unmatched_genes = [gene for gene in adata_test_subset.var_names if gene not in gene_ids]
# print(f"Some unmatched genes: {unmatched_genes[:10]}")


In [None]:
# # Filter adata_test_subset to only include the genes that are in gene_ids
# filtered_genes = [gene for gene in adata_test_subset.var_names if gene in gene_ids]

# # Subset the data to only include the valid genes
# adata_test_subset_filtered = adata_test_subset[:, filtered_genes].copy()

# # Check the new shape
# print(adata_test_subset_filtered.shape)  # This should have 12,588 features now


In [None]:
print(type(gene_ids))

In [None]:
# # Reduce test size by a factor of 10
# total_test_samples = len(adata_test) // 10
# adata_test_subset = adata_test[:total_test_samples]  # Select only 10% of data

predictions, labels, results = test(best_model, adata_test_subset)
adata_test_raw_subset = adata_test_raw[:total_test_samples]
adata_test_raw_subset.obs["predictions"] = [id2type[int(p)] for p in predictions]

# plot
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] 
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"]
palette_ = {c: palette_[i] for i, c in enumerate(celltypes)}

with plt.rc_context({"figure.figsize": (6, 4), "figure.dpi": (300)}):
    sc.pl.umap(
        adata_test_raw,
        color=["celltype", "predictions"],
        palette=palette_,
        show=False,
    )
    plt.savefig(save_dir / "results.png", dpi=300)

save_dict = {
    "predictions": predictions,
    "labels": labels,
    "results": results,
    "id_maps": id2type
}
with open(save_dir / "results.pkl", "wb") as f:
    pickle.dump(save_dict, f)

results["test/cell_umap"] = wandb.Image(
    str(save_dir / "results.png"),
    caption=f"predictions macro f1 {results['test/macro_f1']:.3f}",
)
wandb.log(results)

In [None]:
from sklearn.metrics import confusion_matrix
celltypes = list(celltypes)
for i in set([id2type[p] for p in predictions]):
    if i not in celltypes:
        celltypes.remove(i)
cm = confusion_matrix(labels, predictions)
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
cm = pd.DataFrame(cm, index=celltypes[:cm.shape[0]], columns=celltypes[:cm.shape[1]])
plt.figure(figsize=(10, 10))
sns.heatmap(cm, annot=True, fmt=".1f", cmap="Blues")
plt.savefig(save_dir / "confusion_matrix.png", dpi=300)

results["test/confusion_matrix"] = wandb.Image(
    str(save_dir / "confusion_matrix.png"),
    caption=f"confusion matrix",
)

In [None]:
# save the model into the save_dir
torch.save(best_model.state_dict(),"/kaggle/working/model-0700.pt")

In [None]:
from sklearn.metrics import accuracy_score, f1_score, adjusted_rand_score, normalized_mutual_info_score

# Load datasets based on the name and source type
def load_dataset(dataset_name, data_dir):
    """
    Load scRNA-seq dataset based on the name and data directory.
    
    Parameters:
    - dataset_name: str, name of the dataset to load ('covid19', 'healthy_lung', 'lung_cancer').
    - data_dir: str, the directory where the dataset files are stored.
    
    Returns:
    - Anndata object containing the dataset.
    """
    # Ensure the directory exists
    if not os.path.exists(data_dir):
        raise ValueError(f"Data directory {data_dir} does not exist.")
    if dataset_name == "covid19":
        covid19_file = os.path.join(data_dir, "covid19_data.h5")
        if os.path.exists(covid19_file):
            data = sc.read_10x_h5(covid19_file)
        else:
            raise ValueError("COVID-19 dataset not found.")
    elif dataset_name == "healthy_lung":
        healthy_lung_file = os.path.join(data_dir, "healthy_lung_data.h5")
        if os.path.exists(healthy_lung_file):
            data = sc.read_10x_h5(healthy_lung_file)
        else:
            raise ValueError("Healthy Lung Tissue dataset not found.")
    elif dataset_name == "lung_cancer":
        lung_cancer_file = os.path.join(data_dir, "lung_cancer_data.h5")
        if os.path.exists(lung_cancer_file):
            data = sc.read_10x_h5(lung_cancer_file)
        else:
            raise ValueError("Lung Cancer dataset not found.")
    else:
        raise ValueError(f"Dataset {dataset_name} not recognized.")
    # Return the loaded dataset (AnnData object)
    return data

# Preprocessing function
def preprocess_data(adata):
    """
    Preprocess the scRNA-seq data: filtering, normalization, log transformation.
    
    Parameters:
    - adata: AnnData object containing the raw scRNA-seq data.
    
    Returns:
    - adata: Processed AnnData object.
    """
    # Filter genes and cells based on counts (e.g., minimum number of counts per gene and cell)
    sc.pp.filter_genes(adata, min_counts=1)  # Keep genes with at least 1 count in any cell
    sc.pp.filter_cells(adata, min_genes=1)   # Keep cells with at least 1 gene expression

    # Normalize the data (total count normalization)
    sc.pp.normalize_total(adata, target_sum=1e4)  # Normalize each cell to 10,000 counts

    # Log transformation of the normalized data
    sc.pp.log1p(adata)

    # Highly Variable Genes (optional but recommended for downstream analysis)
    sc.pp.highly_variable_genes(adata, min_mean=0.1, max_mean=10, min_disp=0.5)
    adata = adata[:, adata.var.highly_variable]

    return adata

# Evaluate model performance using various metrics
def evaluate_model(true_labels, pred_labels):
    """
    Evaluate the performance of a model using accuracy, F1 score, ARI, and NMI.
    
    Parameters:
    - true_labels: array, ground truth labels.
    - pred_labels: array, predicted labels.
    
    Returns:
    - metrics: dict, containing the evaluation metrics.
    """
    metrics = {}
    metrics['accuracy'] = accuracy_score(true_labels, pred_labels)
    metrics['f1_score'] = f1_score(true_labels, pred_labels, average='weighted')
    metrics['ari'] = adjusted_rand_score(true_labels, pred_labels)
    metrics['nmi'] = normalized_mutual_info_score(true_labels, pred_labels, average_method='arithmetic')
    return metrics

data_dir = "/kaggle/input/test_data/"
model_path= "/kaggle/working/model-0700.pt"
try:
    covid19_data = load_dataset("covid19", data_dir)
    print(f"Loaded COVID-19 dataset with shape {covid19_data.shape}")
except ValueError as e:
    print(e)
try:
    healthy_lung_data = load_dataset("healthy_lung", data_dir)
    print(f"Loaded Healthy Lung Tissue dataset with shape {healthy_lung_data.shape}")
except ValueError as e:
    print(e)
try:
    lung_cancer_data = load_dataset("lung_cancer", data_dir)
    print(f"Loaded Lung Cancer dataset with shape {lung_cancer_data.shape}")
except ValueError as e:
    print(e)

covid19_data = preprocess_data(covid19_data)
healthy_lung_data = preprocess_data(healthy_lung_data)
lung_cancer_data = preprocess_data(lung_cancer_data)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(model_path, device)
inputs = torch.tensor(covid19_data.X, dtype=torch.float32).to(device)
with torch.no_grad():  # No gradient calculation needed during inference
    pred_labels = model(inputs)
pred_labels = torch.argmax(pred_labels, dim=1).cpu().numpy() 

metrics = evaluate_model(true_labels, pred_labels)
print("Model evaluation metrics:", metrics)
