In [1]:
import sys
from pathlib import Path

# Resolve the project root: two directory levels above this tutorial's folder.
# It is put on sys.path so that project-local packages (presumably the `models`
# package imported by the next cell - confirm the layout) can be found.
# `__file__` is not defined inside a notebook kernel, hence the explicit path;
# when running this as a script use: project_root = Path(__file__).parent.parent
project_root = Path("/home/lxz/scmamba/KCellFM_tutorial/T_cancer_cell").parent.parent
# Re-running this cell must not keep stacking duplicate entries on sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
import scanpy as sc
import numpy as np
from scipy import sparse
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from models.model import MambaModel
from models.gene_tokenizer import GeneVocab

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Applies the same seed to Python's ``random`` module, NumPy's global
    generator, and PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``), then switches cuDNN to deterministic
    mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    import random

    # Same order as a straight-line sequence of calls: random, numpy, torch, cuda.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at notebook start so the cells below run from a known RNG state.
set_seed(42)

In [4]:
# --- Optimisation ---------------------------------------------------------
epochs = 10            # number of passes over the training split
batch_size = 32        # used by both the training and validation DataLoader
lr = 2e-4              # AdamW learning rate (original note: 1e-5)
weight_decay = 1e-3    # AdamW weight decay (original note: 1e-4)
val_split = 0.3        # fraction passed as test_size to train_test_split

# --- Model constructor arguments -------------------------------------------
# embsize, nhead, d_hid and nlayers are passed positionally to
# MambaModel(ntokens, embsize, nhead, d_hid, nlayers, ...).
embsize = 512
nhead = 8
d_hid = 512
nlayers = 6
dropout = 0.1
input_emb_style = "continuous"
cell_emb_style = "cls"

# Number of trailing encoder layers left trainable during fine-tuning; the
# first (nlayers - fine_tune_layers) layers are frozen in train().
fine_tune_layers = 2

# --- Tokenisation / padding -------------------------------------------------
pad_token = "<pad>"
max_seq_len = 4096     # fixed sequence length per cell, including the <cls> slot
pad_value = -2         # expression value written at padded positions
mask_value = -1        # not referenced by the visible cells
# Directory that receives the checkpoints written by train(); it is created
# up front (including any missing parents) and left alone if it already exists.
model_save_dir = "/home/lxz/scmamba/model_state/"
Path(model_save_dir).mkdir(parents=True, exist_ok=True)

In [5]:
# Load the gene vocabulary from its JSON file. `vocab` is consulted by
# SingleCellDataset to map gene names to token ids, and its length is passed
# as the first argument to the MambaModel constructor.
vocab_file = "/home/lxz/scmamba/vocab.json"
vocab = GeneVocab.from_file(vocab_file)
ntokens = len(vocab)

In [6]:
# Cell-type labels listed in class-index order: a label's position in this
# tuple is the integer class id the classifier is trained on.
_celltype_names = (
    'T cell',
    'CD8-positive, alpha-beta cytotoxic T cell',
    'naive thymus-derived CD4-positive, alpha-beta T cell',
    'effector CD8-positive, alpha-beta T cell',
    'effector memory CD8-positive, alpha-beta T cell',
    'central memory CD4-positive, alpha-beta T cell',
    'gamma-delta T cell',
)
celltype_to_id = {name: index for index, name in enumerate(_celltype_names)}
# Number of output classes handed to the model constructor.
class_num = len(celltype_to_id)

In [7]:
class SingleCellDataset(Dataset):
    """Torch dataset that converts each cell of an AnnData-like object into model inputs.

    Every item is a fixed-length token sequence: a leading ``<cls>`` token,
    then the vocabulary ids of the cell's non-zero genes (randomly subsampled
    when there are more than ``max_seq_len - 1``), then ``<pad>`` tokens.

    Relies on the notebook-level globals ``vocab``, ``max_seq_len``,
    ``pad_value`` and ``celltype_to_id``.

    Args:
        adata: Object exposing ``X`` (dense array or scipy sparse matrix,
            cells x genes), ``obs_names``, ``obs['cell_type']`` and
            ``var.feature_name``.

    Attributes kept for callers: ``adata``, ``cell_ids``, ``gene_names`` and
    ``nonzero_indices`` (cell id -> array of non-zero column indices).
    """

    def __init__(self, adata):
        self.adata = adata
        self.cell_ids = adata.obs_names.tolist()
        self.gene_names = adata.var.feature_name.tolist()
        # Labels are cached by position so lookups stay cheap and do not depend
        # on obs_names being unique (``obs.loc`` returns a Series on duplicates).
        self.cell_types = adata.obs['cell_type'].tolist()

        # Row-major sparse storage makes per-cell slicing cheap; dense input is
        # used as is. Holding the matrix here also avoids re-resolving
        # ``adata.X`` on every item access.
        self._is_sparse = sparse.issparse(adata.X)
        self._expr = adata.X.tocsr() if self._is_sparse else adata.X

        # Locate non-zero genes one row at a time instead of densifying the
        # whole matrix, which can need far more memory than the sparse data.
        self._nonzero_by_row = [
            np.where(self._dense_row(row) != 0)[0]
            for row in range(len(self.cell_ids))
        ]
        # Id-keyed view of the same arrays, kept for backward compatibility.
        self.nonzero_indices = dict(zip(self.cell_ids, self._nonzero_by_row))

    def _dense_row(self, row):
        """Return the expression values of one cell as a 1-D dense array."""
        if self._is_sparse:
            return self._expr[row].toarray().ravel()
        return self._expr[row]

    def __len__(self):
        """Number of cells in the dataset."""
        return len(self.cell_ids)

    def __getitem__(self, idx):
        """Build the padded model input for the cell at position ``idx``.

        Returns:
            dict with
              ``src``          LongTensor of token ids, length ``max_seq_len``;
              ``values``       FloatTensor of expression values aligned with
                               ``src`` (0.0 at ``<cls>``, ``pad_value`` at pads);
              ``padding_mask`` BoolTensor, True where ``src`` is ``<pad>``;
              ``celltype``     scalar LongTensor with the class id.

        Raises:
            KeyError: If the cell's type is not a key of ``celltype_to_id``.
        """
        cell_type = self.cell_types[idx]
        nonzero_idx = self._nonzero_by_row[idx]
        expr_values = self._dense_row(idx)[nonzero_idx]

        # Keep only expressed genes that the vocabulary knows.
        gene_ids = []
        filtered_expr = []
        for col, value in zip(nonzero_idx, expr_values):
            gene = self.gene_names[col]
            if gene in vocab:
                gene_ids.append(vocab[gene])
                filtered_expr.append(float(value))

        # Reserve one slot for <cls>; subsample without replacement if too long.
        if len(gene_ids) > max_seq_len - 1:
            selected = np.random.choice(len(gene_ids), max_seq_len - 1, replace=False)
            gene_ids = [gene_ids[i] for i in selected]
            filtered_expr = [filtered_expr[i] for i in selected]

        gene_ids = [vocab["<cls>"]] + gene_ids
        filtered_expr = [0.0] + filtered_expr  # the <cls> slot carries value 0

        pad_id = vocab["<pad>"]
        padding_len = max_seq_len - len(gene_ids)
        if padding_len > 0:
            gene_ids += [pad_id] * padding_len
            filtered_expr += [pad_value] * padding_len

        padding_mask = [token == pad_id for token in gene_ids]

        return {
            'src': torch.tensor(gene_ids, dtype=torch.long),
            'values': torch.tensor(filtered_expr, dtype=torch.float32),
            'padding_mask': torch.tensor(padding_mask, dtype=torch.bool),
            'celltype': torch.tensor(celltype_to_id[cell_type], dtype=torch.long)
        }


In [8]:
def evaluate(model, loader, device, criterion):
    """Run the model over ``loader`` without gradients and score its predictions.

    Args:
        model: Callable taking ``src``, ``values`` and ``src_key_padding_mask``
            keyword arguments and returning a mapping with a ``"cls_output"``
            entry of per-class scores.
        loader: Iterable of batches (dicts with ``src``, ``values``,
            ``padding_mask`` and ``celltype`` tensors) exposing ``.dataset``.
        device: Device the batch tensors are moved to.
        criterion: Loss callable applied to ``(cls_output, celltype)``.

    Returns:
        Tuple ``(avg_loss, accuracy, all_preds, all_labels)``: the summed
        per-batch loss (weighted by batch size) divided by
        ``len(loader.dataset)``, the fraction of correct predictions, and the
        flat lists of predicted and true class ids.
    """
    model.eval()
    loss_sum = 0
    predictions, targets = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = {
                "src": batch['src'].to(device),
                "values": batch['values'].to(device),
                "src_key_padding_mask": batch['padding_mask'].to(device),
            }
            labels = batch['celltype'].to(device)

            with torch.cuda.amp.autocast(enabled=True):
                logits = model(**inputs)["cls_output"]
                batch_loss = criterion(logits, labels)
                # Weight the batch-mean loss by batch size so the final
                # average is per sample, not per batch.
                loss_sum += batch_loss.item() * inputs["src"].size(0)

                predicted = torch.argmax(logits, dim=1)
                predictions.extend(predicted.cpu().numpy())
                targets.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    n_correct = np.sum(np.array(predictions) == np.array(targets))
    return mean_loss, n_correct / len(targets), predictions, targets

In [9]:
def train():
    """Fine-tune the pretrained MambaModel on the T-cell dataset and report validation metrics.

    Pipeline, driven by the notebook-level configuration globals:
      1. Read the training h5ad file and wrap it in ``SingleCellDataset``.
      2. Make a stratified train/validation split (``val_split``, random_state=42).
      3. Build the model, load every shape-compatible pretrained weight except
         the classification output layer, and re-initialise that layer.
      4. Freeze the first ``nlayers - fine_tune_layers`` encoder layers.
      5. Train for ``epochs`` epochs with AdamW and gradient clipping (mixed
         precision on CUDA only), checkpointing on best validation accuracy.
      6. Reload the best checkpoint, print a per-class report and save a final copy.

    Raises:
        ValueError: If ``model.mamba_encoder`` does not contain ``nlayers`` layers.
        AttributeError: If the model lacks a ``mamba_encoder`` ``nn.ModuleList``
            or a ``cls_decoder``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    # torch.cuda.set_device only accepts CUDA devices; skip it on CPU-only hosts.
    if use_cuda:
        torch.cuda.set_device(device)

    # Load dataset
    print("Start loading h5ad file...")
    adata_train = sc.read("/mnt/HHD16T/DATA/lxz/cancer/T_train.h5ad")
    print("H5ad file loading completed, cell count:", adata_train.n_obs)

    print("Start building the dataset...")
    full_dataset = SingleCellDataset(adata_train)
    print("Dataset construction completed, total sample size:", len(full_dataset))

    # Labels for stratified sampling, read straight from obs (same order as the
    # dataset) instead of materialising every padded sample just for its label.
    all_labels = [celltype_to_id[cell_type] for cell_type in adata_train.obs['cell_type']]

    # Divide the training set and validation set
    train_idx, val_idx = train_test_split(
        np.arange(len(full_dataset)),
        test_size=val_split,
        stratify=all_labels,
        random_state=42
    )

    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    # Create dataset loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    # Initialize model
    model = MambaModel(
        ntokens, embsize, nhead, d_hid, nlayers,
        vocab=vocab, dropout=dropout, pad_token=pad_token,
        pad_value=pad_value, input_emb_style=input_emb_style,
        cell_emb_style=cell_emb_style, class_num=class_num
    ).to(device)

    # Load pretrained weights (skip the classification head). This is best
    # effort: on failure the message is printed and training continues with
    # whatever weights the model currently holds.
    try:
        pretrained_dict = torch.load("/home/lxz/scmamba/model_state/cell_cls_3loss_6layer_final.pth",
                                     map_location=device)
        model_dict = model.state_dict()

        # Keep only tensors whose name and shape match the current model.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and v.shape == model_dict[k].shape
                           and not k.startswith('cls_decoder.out_layer')}

        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print("Successfully loaded pre training weights (excluding classification head weights)")

        # Reinitialize the classification head
        print("Initialize classification header weights...")
        nn.init.kaiming_normal_(model.cls_decoder.out_layer.weight, mode='fan_in', nonlinearity='relu')
        if model.cls_decoder.out_layer.bias is not None:
            nn.init.zeros_(model.cls_decoder.out_layer.bias)
    except Exception as e:
        print(f"Failed to load pre training weights: {str(e)}")

    # Freeze and unfreeze logic
    print(f"Freeze the {nlayers - fine_tune_layers} layer before freezing, and only fine tune the final {fine_tune_layers} layer and classifier")

    # Check that the mamba_encoder structure meets expectations
    if hasattr(model, 'mamba_encoder') and isinstance(model.mamba_encoder, nn.ModuleList):
        # Verify that the number of layers matches the configuration
        actual_layers = len(model.mamba_encoder)
        if actual_layers != nlayers:
            raise ValueError(f"The number of layers in model.mamba_encoder does not match the configuration. The actual number is {actual_layers}, while the expected number is {nlayers}")

        # Freeze the leading layers
        for i in range(nlayers - fine_tune_layers):
            layer = model.mamba_encoder[i]
            for param in layer.parameters():
                param.requires_grad = False
            print(f"Freeze layer {i}: Parameters frozen")

        # Unfreeze the trailing `fine_tune_layers` layers
        for i in range(nlayers - fine_tune_layers, nlayers):
            layer = model.mamba_encoder[i]
            for param in layer.parameters():
                param.requires_grad = True
            print(f"Unfreezing layer {i}: parameters trainable")
    else:
        raise AttributeError("mamba_encoder not found in the model or it is not of type nn.ModuleList")

    # Ensure that the classifier is trainable
    if hasattr(model, 'cls_decoder'):
        for param in model.cls_decoder.parameters():
            param.requires_grad = True
        print("The classifier parameters have been set to trainable")
    else:
        raise AttributeError("The cls_decoder classifier was not found in the model")

    # Only optimize the parameters that need to be trained.
    # NOTE(review): parameters outside mamba_encoder (anything other than the
    # frozen layers) keep their default requires_grad and are optimized too -
    # confirm that this is intended.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print(f"Number of trainable parameters: {sum(p.numel() for p in trainable_params)}")

    # Optimizer and loss function. Mixed precision is only meaningful on CUDA,
    # so autocast and the gradient scaler are disabled elsewhere.
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # Training loop
    best_val_acc = 0.0
    best_saved = False  # True once this run has written best_model_path
    best_model_path = os.path.join(model_save_dir, "cancer_Tcell_2_layers_best_ipynb.pth")

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch in pbar:
            src = batch['src'].to(device)
            values = batch['values'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            cell_types = batch['celltype'].to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_cuda):
                output = model(src=src, values=values, src_key_padding_mask=padding_mask)
                loss = criterion(output["cls_output"], cell_types)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item() * src.size(0)
            pbar.set_postfix({'loss': loss.item()})

        # Calculate the average training loss
        avg_train_loss = total_train_loss / len(train_dataset)

        # Validate
        val_loss, val_acc, val_preds, val_labels = evaluate(model, val_loader, device, criterion)

        # Print epoch result
        print(f"\nEpoch {epoch + 1} result:")
        print(f"Training loss: {avg_train_loss:.4f} | Validation loss: {val_loss:.4f} | Validation accuracy: {val_acc:.4f}")

        # Save the best model. The first epoch always saves, so the checkpoint
        # reloaded below is guaranteed to come from this run and not a stale file.
        if val_acc > best_val_acc or not best_saved:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            best_saved = True
            print(f"Save the best mdoel (Validation accuracy: {best_val_acc:.4f}) to {best_model_path}")

    # Print the best results after the training is completed
    print(f"\nTraining completed! Best Verification Accuracy: {best_val_acc:.4f} (The model is saved in {best_model_path})")

    # Load the best model and print a detailed report of the validation set
    if best_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    _, _, val_preds, val_labels = evaluate(model, val_loader, device, criterion)
    print("\nDetailed report of the best model on the validation set:")
    # Explicit labels keep rows aligned with target_names even if some class
    # never appears among the validation labels or predictions.
    print(classification_report(
        val_labels,
        val_preds,
        labels=list(celltype_to_id.values()),
        target_names=list(celltype_to_id.keys()),
        digits=4,
        zero_division=0
    ))

    # Save the final model (the best checkpoint's weights when one was written)
    final_model_path = os.path.join(model_save_dir, "cancer_Tcell_2_layers_best_final_ipynb.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"The final model has been saved as {final_model_path}")

In [10]:
# Entry point: start fine-tuning when this cell (or the exported script) is
# executed directly; nothing runs if the code is imported as a module.
if __name__ == "__main__":
    train()

Start loading h5ad file...
H5ad file loading completed, cell count: 30717
Start building the dataset...
Dataset construction completed, total sample size: 30717
Training set size: 21501, Validation set size: 9216
Successfully loaded pre training weights (excluding classification head weights)
Initialize classification header weights...
Freeze the 4 layer before freezing, and only fine tune the final 2 layer and classifier
Freeze layer 0: Parameters frozen
Freeze layer 1: Parameters frozen
Freeze layer 2: Parameters frozen
Freeze layer 3: Parameters frozen
Unfreezing layer 4: parameters trainable
Unfreezing layer 5: parameters trainable
The classifier parameters have been set to trainable
Number of trainable parameters: 35853512


Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:31<00:00,  1.49it/s, loss=0.425]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 1 result:
Training loss: 0.5457 | Validation loss: 0.4610 | Validation accuracy: 0.8296
Save the best mdoel (Validation accuracy: 0.8296) to /home/lxz/scmamba/model_state/cancer_Tcell_2_layers_best_ipynb.pth


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:13<00:00,  1.55it/s, loss=0.389]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 2 result:
Training loss: 0.3658 | Validation loss: 0.4778 | Validation accuracy: 0.8242


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:12<00:00,  1.55it/s, loss=0.0444]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 3 result:
Training loss: 0.2393 | Validation loss: 0.5103 | Validation accuracy: 0.8379
Save the best mdoel (Validation accuracy: 0.8379) to /home/lxz/scmamba/model_state/cancer_Tcell_2_layers_best_ipynb.pth


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:11<00:00,  1.56it/s, loss=0.119]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 4 result:
Training loss: 0.1250 | Validation loss: 0.6819 | Validation accuracy: 0.8356


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:12<00:00,  1.55it/s, loss=0.0144]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.79it/s]



Epoch 5 result:
Training loss: 0.0663 | Validation loss: 0.8277 | Validation accuracy: 0.8346


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:13<00:00,  1.55it/s, loss=0.0146]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 6 result:
Training loss: 0.0345 | Validation loss: 1.1379 | Validation accuracy: 0.8279


Epoch 7/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:25<00:00,  1.51it/s, loss=0.000978]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.79it/s]



Epoch 7 result:
Training loss: 0.0255 | Validation loss: 1.1546 | Validation accuracy: 0.8338


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:09<00:00,  1.56it/s, loss=0.0492]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:00<00:00,  4.80it/s]



Epoch 8 result:
Training loss: 0.0211 | Validation loss: 1.2757 | Validation accuracy: 0.8315


Epoch 9/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:18<00:00,  1.53it/s, loss=0.000494]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:07<00:00,  4.28it/s]



Epoch 9 result:
Training loss: 0.0252 | Validation loss: 1.2131 | Validation accuracy: 0.8371


Epoch 10/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 672/672 [07:14<00:00,  1.54it/s, loss=0.202]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:06<00:00,  4.36it/s]



Epoch 10 result:
Training loss: 0.0242 | Validation loss: 1.2684 | Validation accuracy: 0.8256

Training completed! Best Verification Accuracy: 0.8379 (The model is saved in /home/lxz/scmamba/model_state/cancer_Tcell_2_layers_best_ipynb.pth)


Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288 [01:01<00:00,  4.69it/s]



Detailed report of the best model on the validation set:
                                                      precision    recall  f1-score   support

                                              T cell     0.9457    0.9812    0.9631      3831
           CD8-positive, alpha-beta cytotoxic T cell     0.8194    0.7510    0.7837      1450
naive thymus-derived CD4-positive, alpha-beta T cell     0.7758    0.9114    0.8382      1219
            effector CD8-positive, alpha-beta T cell     0.7594    0.6542    0.7029      1206
     effector memory CD8-positive, alpha-beta T cell     0.6317    0.6359    0.6338       758
      central memory CD4-positive, alpha-beta T cell     0.6794    0.4884    0.5683       473
                                  gamma-delta T cell     0.7722    0.9355    0.8460       279

                                            accuracy                         0.8379      9216
                                           macro avg     0.7691    0.7654    0.7623      9216
