In [1]:
import sys
from pathlib import Path

# Put the project root on the import path so that project-local packages can be
# imported by later cells (presumably this is where `models/` lives -- verify).
# The root is two directory levels above the tutorial folder. The path is
# spelled out explicitly because `__file__` is not defined in a notebook kernel.
project_root = Path("/home/lxz/scmamba/KCellFM_tutorial/T_cancer_cell").parent.parent
# Only append when missing, so re-running this cell does not pile up duplicate
# entries in sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, accuracy_score
import scanpy as sc
import numpy as np
from scipy import sparse
from tqdm import tqdm
import pickle
from models.model import MambaModel
from models.gene_tokenizer import GeneVocab

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Evaluation settings. These have to stay identical to the values that were
# used when the model was trained.

# --- data loading ---
batch_size = 32
max_seq_len = 4096  # tokens per cell after adding <cls> and padding

# --- model architecture ---
embsize = 512
d_hid = 512
nhead = 8
nlayers = 6
dropout = 0.1
input_emb_style = "continuous"
cell_emb_style = "cls"

# --- padding conventions ---
pad_token = "<pad>"
pad_value = -2  # expression value written at padded positions

In [4]:
# Cell type mapping
celltype_to_id = {
    'T cell': 0,
    'CD8-positive, alpha-beta cytotoxic T cell': 1,
    'naive thymus-derived CD4-positive, alpha-beta T cell': 2,
    'effector CD8-positive, alpha-beta T cell': 3,
    'effector memory CD8-positive, alpha-beta T cell': 4,
    'central memory CD4-positive, alpha-beta T cell': 5,
    'gamma-delta T cell': 6
}
id_to_celltype = {v: k for k, v in celltype_to_id.items()}
class_num = len(celltype_to_id)

In [5]:
# Load the gene vocabulary from disk. It is used below to translate gene names
# into token ids, and `ntokens` (the vocabulary size) is passed to the model
# constructor.
vocab = GeneVocab.from_file("/home/lxz/scmamba/vocab.json")
ntokens = len(vocab)

In [6]:
# Dataset loader
class SingleCellDataset:
    """Map-style dataset that converts each cell of an AnnData object into model inputs.

    The class relies on the notebook-level globals ``vocab``, ``max_seq_len``,
    ``pad_value`` and ``celltype_to_id``.

    Attributes:
        adata: The AnnData object passed to the constructor.
        cell_ids: ``adata.obs_names`` as a list; defines the item order.
        gene_names: ``adata.var.feature_name`` as a list (one name per column).
        nonzero_indices: dict mapping cell id -> sorted array of column indices
            whose expression value is non-zero for that cell.
    """

    def __init__(self, adata):
        """Index the non-zero genes of every cell in ``adata``.

        Args:
            adata: AnnData-like object exposing ``obs_names``,
                ``var.feature_name`` and an expression matrix ``X`` (scipy
                sparse or dense array).
        """
        self.adata = adata
        self.cell_ids = adata.obs_names.tolist()
        self.gene_names = adata.var.feature_name.tolist()

        # Pre-calculate the index of non-zero expressed genes
        self.nonzero_indices = {}
        if sparse.issparse(adata.X):
            # Read the column indices straight from the CSR buffers instead of
            # densifying the whole matrix, which can exhaust memory on large
            # datasets. Explicitly stored zeros are filtered out and the result
            # is sorted, matching what np.where on the dense row would return.
            csr = adata.X.tocsr()
            for row, cell_id in enumerate(self.cell_ids):
                start, stop = csr.indptr[row], csr.indptr[row + 1]
                cols = csr.indices[start:stop]
                stored = csr.data[start:stop]
                self.nonzero_indices[cell_id] = np.unique(cols[stored != 0])
        else:
            for row, cell_id in enumerate(self.cell_ids):
                self.nonzero_indices[cell_id] = np.where(adata.X[row] != 0)[0]

    def __len__(self):
        """Return the number of cells."""
        return len(self.cell_ids)

    def __getitem__(self, idx):
        """Build the token/value sequence for the cell at position ``idx``.

        Args:
            idx: Integer position of the cell in ``cell_ids``.

        Returns:
            dict with
              * ``'src'``: LongTensor of token ids, ``<cls>`` first, padded
                with ``<pad>`` up to ``max_seq_len``;
              * ``'values'``: FloatTensor of expression values aligned with
                ``src`` (0.0 for ``<cls>``, ``pad_value`` for padding);
              * ``'padding_mask'``: BoolTensor, True where ``src`` is ``<pad>``;
              * ``'celltype'``: scalar LongTensor with the class id.

        Raises:
            KeyError: if the cell's ``cell_type`` is not in ``celltype_to_id``.
        """
        cell_id = self.cell_ids[idx]
        cell_type = self.adata.obs.loc[cell_id, 'cell_type']

        nonzero_idx = self.nonzero_indices[cell_id]
        if sparse.issparse(self.adata.X):
            expr_values = self.adata.X[idx, nonzero_idx].toarray().flatten()
        else:
            expr_values = self.adata.X[idx, nonzero_idx]

        # Keep only the expressed genes that exist in the vocabulary.
        gene_ids = []
        filtered_expr = []
        for col, value in zip(nonzero_idx, expr_values):
            gene = self.gene_names[col]
            if gene in vocab:
                gene_ids.append(vocab[gene])
                filtered_expr.append(value)

        # Leave one slot for <cls>; if the cell expresses more genes than fit,
        # draw a random subset without replacement.
        # NOTE(review): this uses the global NumPy RNG, so results for such
        # cells are only reproducible if the caller seeds it.
        if len(gene_ids) > max_seq_len - 1:
            selected = np.random.choice(len(gene_ids), max_seq_len - 1, replace=False)
            gene_ids = [gene_ids[i] for i in selected]
            filtered_expr = [filtered_expr[i] for i in selected]

        # Look the special tokens up once instead of once per position.
        cls_id = vocab["<cls>"]
        pad_id = vocab["<pad>"]

        gene_ids = [cls_id] + gene_ids
        filtered_expr = [0.0] + filtered_expr

        padding_len = max_seq_len - len(gene_ids)
        if padding_len > 0:
            gene_ids += [pad_id] * padding_len
            filtered_expr += [pad_value] * padding_len

        src = torch.LongTensor(gene_ids)

        return {
            'src': src,
            'values': torch.FloatTensor(filtered_expr),
            # Element-wise comparison yields the same boolean mask as checking
            # every id against <pad> in Python.
            'padding_mask': src == pad_id,
            'celltype': torch.tensor(celltype_to_id[cell_type], dtype=torch.long)
        }


In [47]:
def evaluate():
    """Evaluate the trained cell-type classifier on the held-out test set.

    Loads the test ``.h5ad`` file, runs the model in inference mode over all
    cells, saves the cell embeddings (``.npy``) and the true/predicted labels
    (``.pkl``) to the hard-coded output paths, and prints the accuracy together
    with a per-class classification report.

    Relies on the notebook-level configuration (``batch_size``, ``ntokens``,
    ``vocab``, ``celltype_to_id`` ...) and on ``SingleCellDataset``.

    Returns:
        None. Results are printed and written to disk.
    """
    # Prefer the second GPU as before, but fall back to GPU 0 on single-GPU
    # hosts and to the CPU when CUDA is unavailable. torch.cuda.set_device
    # rejects a CPU device, so it is only called for CUDA devices.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f"cuda:{gpu_index}")
        # NOTE(review): presumably required so custom CUDA kernels run on the
        # selected GPU -- confirm whether the model needs it.
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Load test data
    print("Load test set h5ad file...")
    adata_test = sc.read("/mnt/HHD16T/DATA/lxz/cancer/T_test.h5ad")
    print(f"Test set loading completed, cell count: {adata_test.n_obs}")

    # Create a test dataset
    test_dataset = SingleCellDataset(adata_test)
    if len(test_dataset) == 0:
        # np.concatenate below would fail on an empty list of batches.
        print("Test set is empty; nothing to evaluate.")
        return

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Initialize model
    model = MambaModel(
        ntokens, embsize, nhead, d_hid, nlayers,
        vocab=vocab, dropout=dropout, pad_token=pad_token,
        pad_value=pad_value, input_emb_style=input_emb_style,
        cell_emb_style=cell_emb_style, class_num=class_num
    ).to(device)

    # Load the weights of the trained model
    model_path = '/home/lxz/scmamba/model_state/cancer_Tcell_2_layers_best_final_ipynb.pth'
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True if the installed torch
    # version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Model weights have been loaded: {model_path}")

    # Initialize the collection of variables
    all_preds = []
    all_labels = []
    all_cell_embs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            src = batch['src'].to(device)
            values = batch['values'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            cell_types = batch['celltype'].to(device)

            # forward propagation
            model_output = model(
                src=src,
                values=values,
                src_key_padding_mask=padding_mask
            )

            # Obtain predicted categories and cell embeddings
            preds = torch.argmax(model_output["cls_output"], dim=1)
            cell_embs = model_output["cell_emb"].cpu().numpy()

            # collect results
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(cell_types.cpu().numpy())
            all_cell_embs.append(cell_embs)

    # Splicing the embeddings of all batches together along the cell axis
    # (presumably shape [N, embsize] -- TODO confirm against the model).
    all_cell_embs = np.concatenate(all_cell_embs, axis=0)

    # Save cell embedding; create the output directory first so a missing
    # folder does not discard the finished inference run.
    embedding_path = "/home/lxz/scmamba/aicase1/cell_embeddings_ipynb.npy"
    Path(embedding_path).parent.mkdir(parents=True, exist_ok=True)
    np.save(embedding_path, all_cell_embs)
    print(f"Cell embedding has been saved to: {embedding_path}")

    # Save labels and preds to a .pkl file
    pkl_path = "/home/lxz/scmamba/aicase1/labels_preds_ipynb.pkl"
    Path(pkl_path).parent.mkdir(parents=True, exist_ok=True)
    with open(pkl_path, "wb") as f:
        pickle.dump({
            "labels": all_labels,  # True label
            "preds": all_preds  # prediction label
        }, f)
    print(f"The labels and predicted results have been saved to: {pkl_path}")

    # Compute accuracy
    acc = accuracy_score(all_labels, all_preds)
    print(f"\ntest accuracy: {acc:.4f}")

    # Print classification report. Passing `labels` explicitly keeps the class
    # names aligned with their ids and avoids a length-mismatch error when a
    # class does not occur in the test split.
    classes_by_id = sorted(celltype_to_id.items(), key=lambda item: item[1])
    print("\nDetailed classification report:")
    print(classification_report(
        all_labels,
        all_preds,
        labels=[class_id for _, class_id in classes_by_id],
        target_names=[class_name for class_name, _ in classes_by_id],
        digits=4
    ))

In [49]:
# Entry point: run the evaluation when this code is executed directly. Inside
# a notebook kernel `__name__` is "__main__", so running this cell always
# triggers the evaluation.
if __name__ == "__main__":
    evaluate()

Load test set h5ad file...
Test set loading completed, cell count: 6383
Model weights have been loaded: /home/lxz/scmamba/model_state/cancer_Tcell_2_layers_best_final_ipynb.pth


Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [02:26<00:00,  1.36it/s]

Cell embedding has been saved to: /home/lxz/scmamba/aicase1/cell_embeddings_ipynb.npy
The labels and predicted results have been saved to: /home/lxz/scmamba/aicase1/labels_preds_ipynb.pkl

test accuracy: 0.8123

Detailed classification report:
                                                      precision    recall  f1-score   support

                                              T cell     0.9479    0.9639    0.9558      1662
           CD8-positive, alpha-beta cytotoxic T cell     0.7909    0.8476    0.8183      1575
naive thymus-derived CD4-positive, alpha-beta T cell     0.8557    0.8711    0.8633      1055
            effector CD8-positive, alpha-beta T cell     0.7434    0.5253    0.6156       910
     effector memory CD8-positive, alpha-beta T cell     0.6572    0.7623    0.7059       732
      central memory CD4-positive, alpha-beta T cell     0.6650    0.6769    0.6709       390
                                  gamma-delta T cell     0.6905    0.4915    0.5743        59

  


