In [1]:
import sys
from pathlib import Path

# Make the repository root importable: it sits two directory levels above this
# tutorial folder (presumably where the `models` package imported below lives).
# When running as a plain script rather than a notebook, the equivalent is:
#     project_root = Path(__file__).parent.parent
project_root = Path("/home/lxz/scmamba/KCellFM_tutorial/spatial_transcriptomics").parents[1]
sys.path.append(str(project_root))

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
import pandas as pd
import numpy as np
import os
import sys
import random
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from models.model import MambaModel 
from models.gene_tokenizer import GeneVocab

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with
    auto-tuning (benchmark) turned off.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

In [4]:
# configuration parameters
class Config:
    # Data parameters
    train_data_path = "/home/lxz/scmamba/空转/Hubmap_SB_cross/train.csv"
    model_save_dir = "/home/lxz/scmamba/model_state/"
    os.makedirs(model_save_dir, exist_ok=True) 

    # Model parameters
    embsize = 512
    nhead = 8
    d_hid = 512
    nlayers = 6
    dropout = 0.1
    pad_token = "<pad>"
    max_seq_len = 50
    input_emb_style = "continuous"
    cell_emb_style = "cls"
    mask_value = -1
    pad_value = -2
    vocab_path = "/home/lxz/scmamba/vocab.json"
    pretrained_model_path = "/home/lxz/scmamba/model_state/cell_cls_3loss_6layer_final.pth"

    # Training parameters
    epochs = 10
    batch_size = 96  
    lr = 2e-4  
    weight_decay = 1e-3 
    val_split = 0.3  
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Category Mapping
    celltype_to_id = {
        "B": 0,
        "CD4T": 1,
        "CD7_Immune": 2,
        "CD8T": 3,
        "DC": 4,
        "Endothelial": 5,
        "Enterocyte_ITLN1p": 6,
        "Goblet": 7,
        "ICC": 8,
        "Lymphatic": 9,
        "Macrophage": 10,
        "Nerve": 11,
        "Neuroendocrine": 12,
        "Neutrophil": 13,
        "Plasma": 14,
        "Stroma": 15,
        "TA": 16
    }
    class_num = len(celltype_to_id)

    # Gene Mapping
    gene_to_id = {
        "MUC2": 17183, "SOX9": 32052, "MUC1": 17175, "CD31": 19330, "Synapto": 32742,
        "CD49f": 12272, "CD15": 9687, "CHGA": 4894, "CDX2": 4568, "ITLN1": 12308,
        "CD4": 4380, "CD127": 12051, "Vimentin": 35192, "HLADR": 11044, "CD8": 4412,
        "CD11c": 12283, "CD44": 4383, "CD16": 9286, "BCL2": 3080, "CD123": 36627,
        "CD38": 4376, "CD90": 33320, "aSMA": 1391, "CD21": 5523, "NKG2D": 12911,
        "CD66": 4589, "CD57": 2956, "CD206": 16892, "CD68": 4397, "CD34": 4373,
        "aDef5": 7546, "CD7": 4399, "CD36": 4374, "CD138": 30796, "Cytokeratin": 41736,
        "CK7": 12989, "CD117": 12801, "CD19": 4335, "Podoplanin": 19298, "CD45": 20664,
        "CD56": 17477, "CD69": 4398, "Ki67": 16711, "CD49a": 12264, "CD163": 4329,
        "CD161": 12901
    }


config = Config()

In [5]:
class SpatialDataset(Dataset):
    """Per-cell token/expression samples read from a CSV file.

    The CSV must contain the columns ``cell_type_A``, ``x`` and ``y``; every
    other column whose name is a key of ``gene_to_id`` is used as an
    expression feature (in CSV column order). Rows whose ``cell_type_A`` is
    not a key of ``celltype_to_id`` are dropped.

    Each sample is a fixed-length sequence ``[CLS, gene_1, ..., gene_k, PAD...]``
    truncated or padded to ``max_seq_len``.

    The token ids and the padding mask are the same for every cell, and the
    expression matrix, labels and coordinates are plain numeric columns, so
    all of them are converted to NumPy arrays once in ``__init__`` instead of
    being rebuilt from a pandas row on every ``__getitem__`` call. As a
    consequence, changes made to ``self.data`` after construction are not
    reflected in the returned samples.

    Args:
        csv_file: Path (or buffer) accepted by ``pandas.read_csv``.
        gene_to_id: Mapping from CSV column name to vocabulary token id.
        celltype_to_id: Mapping from cell-type name to integer class id.
        max_seq_len: Output sequence length. ``None`` (default) uses the
            notebook-level ``config.max_seq_len``.
        pad_value: Expression value written at padded positions. ``None``
            (default) uses the notebook-level ``config.pad_value``.
        cls_token_id: Vocabulary id placed at position 0 (assumed to be the
            CLS token of the pretrained vocabulary — verify against vocab.json).
        pad_token_id: Vocabulary id used for padded positions.

    Raises:
        KeyError: If ``cell_type_A``, ``x`` or ``y`` is missing from the CSV.
        ValueError: If a gene or coordinate column is not numeric.
    """

    def __init__(self, csv_file, gene_to_id, celltype_to_id,
                 max_seq_len=None, pad_value=None,
                 cls_token_id=60695, pad_token_id=60694):
        self.data = pd.read_csv(csv_file)
        self.gene_to_id = gene_to_id
        self.celltype_to_id = celltype_to_id

        # Explicit arguments win; otherwise fall back to the global config so
        # existing three-argument calls keep their behaviour.
        self.max_seq_len = config.max_seq_len if max_seq_len is None else max_seq_len
        self.pad_value = config.pad_value if pad_value is None else pad_value
        self.cls_token_id = cls_token_id
        self.pad_token_id = pad_token_id

        self.genes = [col for col in self.data.columns if col in gene_to_id]

        valid_cell_types = set(celltype_to_id.keys())
        self.data = self.data[self.data['cell_type_A'].isin(valid_cell_types)].reset_index(drop=True)

        self._build_cache()

    def _build_cache(self):
        """Precompute the per-dataset arrays that ``__getitem__`` slices."""
        seq_len = self.max_seq_len
        n_cells = len(self.data)

        # Token sequence: CLS + gene ids, truncated, then right-padded.
        token_ids = [self.cls_token_id] + [self.gene_to_id[gene] for gene in self.genes]
        token_ids = token_ids[:seq_len]
        n_real = len(token_ids)
        token_ids += [self.pad_token_id] * (seq_len - n_real)
        self._token_ids = np.asarray(token_ids, dtype=np.int64)
        # True marks positions the model should ignore.
        self._padding_mask = self._token_ids == self.pad_token_id

        # Expression matrix laid out like the token sequence:
        # column 0 is the CLS slot (0.0), then the kept genes, then pad_value.
        values = np.full((n_cells, seq_len), self.pad_value, dtype=np.float32)
        if n_real > 0:
            values[:, 0] = 0.0
        if n_real > 1:
            kept_genes = self.genes[:n_real - 1]
            values[:, 1:n_real] = self.data[kept_genes].to_numpy(dtype=np.float32)
        self._values = values

        self._labels = np.asarray(
            [self.celltype_to_id[name] for name in self.data['cell_type_A']],
            dtype=np.int64,
        )
        self._coords = self.data[['x', 'y']].to_numpy(dtype=np.float32)

    def __len__(self):
        """Number of cells that survived the cell-type filter."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one cell as a dict of freshly allocated tensors.

        Keys: ``src`` (long token ids), ``values`` (float expression values),
        ``padding_mask`` (bool, True at padded positions), ``celltype``
        (0-dim long class id) and ``coordinates`` (float ``[x, y]``).
        """
        # torch.tensor copies, so callers may modify the result freely without
        # touching the cached arrays.
        return {
            'src': torch.tensor(self._token_ids, dtype=torch.long),
            'values': torch.tensor(self._values[idx], dtype=torch.float),
            'padding_mask': torch.tensor(self._padding_mask, dtype=torch.bool),
            'celltype': torch.tensor(self._labels[idx], dtype=torch.long),
            'coordinates': torch.tensor(self._coords[idx], dtype=torch.float)
        }

In [6]:
def evaluate(model, loader, device, criterion_cls):
    """Run ``model`` over ``loader`` without gradients and score its predictions.

    Args:
        model: Module called as ``model(src=..., values=...,
            src_key_padding_mask=...)`` that returns a mapping containing the
            class logits under ``"cls_output"``.
        loader: Iterable of batches; each batch is a dict with the keys
            ``'src'``, ``'values'``, ``'padding_mask'`` and ``'celltype'``.
        device: Device the batch tensors are moved to.
        criterion_cls: Loss called as ``criterion_cls(logits, labels)`` and
            expected to return the batch-mean loss.

    Returns:
        ``(avg_loss, accuracy, all_preds, all_labels)``. ``avg_loss`` is the
        sample-weighted mean over the samples actually iterated, so it stays
        correct when the loader drops or subsamples data. For an empty loader
        the result is ``(0.0, 0.0, [], [])`` instead of a division by zero.

    Note:
        The model is left in eval mode.
    """
    model.eval()
    # CUDA autocast is meaningless on CPU (PyTorch only warns and disables it),
    # so request it only when actually evaluating on a CUDA device.
    use_amp = torch.device(device).type == "cuda"

    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            src = batch['src'].to(device)
            values = batch['values'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            cell_types = batch['celltype'].to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                model_output = model(
                    src=src,
                    values=values,
                    src_key_padding_mask=padding_mask
                )
                logits = model_output["cls_output"]
                loss = criterion_cls(logits, cell_types)

            # Weight the batch-mean loss by batch size so a smaller final
            # batch does not skew the average.
            batch_size = src.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(cell_types.cpu().numpy())

    if n_samples == 0:
        return 0.0, 0.0, all_preds, all_labels

    avg_loss = total_loss / n_samples
    accuracy = np.sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)
    return avg_loss, accuracy, all_preds, all_labels

In [7]:
def train():
    """Fine-tune the pretrained model as a cell-type classifier and report results.

    Steps, all driven by the module-level ``config``:

    1. Load the CSV into a ``SpatialDataset`` and make a stratified
       train/validation split (``config.val_split``, fixed random_state=42).
    2. Build ``MambaModel`` and copy over every pretrained tensor whose name and
       shape match, except the classification head, which is re-initialised.
       A failure while loading is reported and training continues from the
       model's own initialisation.
    3. Train for ``config.epochs`` epochs with AdamW and gradient clipping
       (mixed precision only on CUDA), evaluating after every epoch and saving
       the weights with the best validation accuracy.
    4. Reload the best checkpoint written during this run (if any) and print a
       per-class classification report on the validation set.

    Returns:
        None. Results are printed; the best weights are written to
        ``config.model_save_dir``.
    """
    # Device setting
    device = torch.device(config.device)
    on_cuda = device.type == "cuda"
    if on_cuda:
        # set_device only accepts CUDA devices; calling it with the CPU
        # fallback from Config would raise.
        torch.cuda.set_device(device)
    print(f"Using device: {device}")

    # Load dataset
    full_dataset = SpatialDataset(
        csv_file=config.train_data_path,
        gene_to_id=config.gene_to_id,
        celltype_to_id=config.celltype_to_id
    )
    print(f"Total number of training data samples: {len(full_dataset)}")

    # Labels for stratified sampling, read straight from the (already filtered)
    # label column instead of materialising every sample's tensors.
    all_labels = full_dataset.data['cell_type_A'].map(config.celltype_to_id).to_numpy()

    # Divide the training set and validation set
    train_idx, val_idx = train_test_split(
        np.arange(len(full_dataset)),
        test_size=config.val_split,
        stratify=all_labels,
        random_state=42
    )

    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    # Create data loaders (pinned host memory only helps host->GPU copies)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=on_cuda,
        drop_last=False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=on_cuda,
        drop_last=False
    )

    # Initialize model
    vocab = GeneVocab.from_file(config.vocab_path)
    ntokens = len(vocab)

    model = MambaModel(
        ntokens, config.embsize, config.nhead, config.d_hid, config.nlayers,
        dropout=config.dropout, pad_token=config.pad_token,
        pad_value=config.pad_value, input_emb_style=config.input_emb_style,
        cell_emb_style=config.cell_emb_style, class_num=config.class_num
    ).to(device)

    # Load pre-training weights (best effort: any failure falls back to the
    # model's own initialisation, as announced by the message below).
    try:
        # NOTE(review): torch.load unpickles the file; only point
        # pretrained_model_path at checkpoints you trust.
        pretrained_dict = torch.load(config.pretrained_model_path, map_location=device)
        model_dict = model.state_dict()

        # Keep only tensors that exist in this model with the same shape and
        # do not belong to the classification head.
        pretrained_dict_filtered = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape and 'cls_decoder' not in k
        }

        model_dict.update(pretrained_dict_filtered)
        model.load_state_dict(model_dict)

        print(f"Successfully loaded {len(pretrained_dict_filtered)} pre training layer weights")
        print("Initialize classification header weights...")
        # Reinitialize the classification header
        nn.init.kaiming_normal_(model.cls_decoder.out_layer.weight, mode='fan_in', nonlinearity='relu')
        if model.cls_decoder.out_layer.bias is not None:
            nn.init.zeros_(model.cls_decoder.out_layer.bias)
    except Exception as e:
        print(f"Failed to load pre training weights: {str(e)}，Random initialization weights will be used")

    # Optimizer and loss function
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config.lr,
        weight_decay=config.weight_decay
    )
    criterion_cls = nn.CrossEntropyLoss()
    # Loss scaling / autocast are CUDA features; disabled they become no-ops.
    scaler = torch.cuda.amp.GradScaler(enabled=on_cuda)

    # Training loop
    best_val_acc = 0.0
    best_model_path = os.path.join(config.model_save_dir, "spatial_classifier_SB_cross_best.pth")
    saved_best = False  # True once THIS run has written best_model_path

    for epoch in range(config.epochs):
        model.train()
        total_train_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config.epochs}")
        for batch in pbar:
            src = batch['src'].to(device)
            values = batch['values'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            cell_types = batch['celltype'].to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=on_cuda):
                model_output = model(
                    src=src,
                    values=values,
                    src_key_padding_mask=padding_mask
                )
                loss = criterion_cls(model_output["cls_output"], cell_types)

            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item() * src.size(0)
            pbar.set_postfix(train_loss=loss.item())

        # Calculate the average training loss
        avg_train_loss = total_train_loss / len(train_dataset)

        # Verify
        val_loss, val_acc, val_preds, val_labels = evaluate(model, val_loader, device, criterion_cls)

        # Print epoch result
        print(f"\nEpoch {epoch + 1} result:")
        print(f"Training loss: {avg_train_loss:.4f} | Validation loss: {val_loss:.4f} | Validation accuracy: {val_acc:.4f}")

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            saved_best = True
            print(f"Save the best model (Validation accuracy: {best_val_acc:.4f}) to {best_model_path}")

    # Print the best results after the training is completed
    print(f"\nTraining completed! Best Verification Accuracy: {best_val_acc:.4f} (The model is saved in {best_model_path})")

    # Load the best model and print a detailed report of the validation set.
    # Only reload a checkpoint written by this run: a file left over from an
    # earlier run would silently report a different model.
    if saved_best:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; reporting on the current model weights.")
    _, _, val_preds, val_labels = evaluate(model, val_loader, device, criterion_cls)
    print("\nDetailed report of the best model on the validation set:")
    # Pass `labels` explicitly so the report still works (with zero support)
    # when a configured class never occurs in the validation labels/predictions;
    # without it scikit-learn raises on the target_names length mismatch.
    print(classification_report(
        val_labels,
        val_preds,
        labels=list(config.celltype_to_id.values()),
        target_names=list(config.celltype_to_id.keys()),
        digits=4,
        zero_division=0
    ))

In [8]:
# Entry point: start fine-tuning when this cell (or the exported script) is
# executed directly; the guard keeps an `import` of this file from training.
if __name__ == "__main__":
    train()

Using device: cuda:0
Total number of training data samples: 71884
Training set size: 50318, Validation set size: 21566
Successfully loaded 111 pre training layer weights
Initialize classification header weights...


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [01:20<00:00,  6.55it/s, train_loss=0.82]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 32.58it/s]



Epoch 1 result:
Training loss: 0.5215 | Validation loss: 0.3693 | Validation accuracy: 0.8735
Save the best model (Validation accuracy: 0.8735) to /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [00:58<00:00,  9.03it/s, train_loss=0.562]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 33.32it/s]



Epoch 2 result:
Training loss: 0.3079 | Validation loss: 0.3390 | Validation accuracy: 0.8751
Save the best model (Validation accuracy: 0.8751) to /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [01:06<00:00,  7.95it/s, train_loss=0.238]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:12<00:00, 18.04it/s]



Epoch 4 result:
Training loss: 0.2371 | Validation loss: 0.2736 | Validation accuracy: 0.9047
Save the best model (Validation accuracy: 0.9047) to /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [01:21<00:00,  6.41it/s, train_loss=0.142]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 33.40it/s]



Epoch 5 result:
Training loss: 0.2161 | Validation loss: 0.2706 | Validation accuracy: 0.9079
Save the best model (Validation accuracy: 0.9079) to /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth


Epoch 6/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [01:00<00:00,  8.71it/s, train_loss=0.0307]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 32.65it/s]



Epoch 6 result:
Training loss: 0.1969 | Validation loss: 0.2675 | Validation accuracy: 0.9094
Save the best model (Validation accuracy: 0.9094) to /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth


Epoch 7/10:  68%|█████████████████████████████████████████████████████████████████████████████████▌                                      | 357/525 [00:46<00:20,  8.23it/s, train_loss=0.184]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 525/525 [01:02<00:00,  8.41it/s, train_loss=0.386]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 32.47it/s]



Epoch 10 result:
Training loss: 0.1381 | Validation loss: 0.2987 | Validation accuracy: 0.9105

Training completed! Best Verification Accuracy: 0.9139 (The model is saved in /home/lxz/scmamba/model_state/spatial_classifier_SB_cross_best.pth)


Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [00:06<00:00, 32.22it/s]



Detailed report of the best model on the validation set:
                   precision    recall  f1-score   support

                B     0.8869    0.8896    0.8882       661
             CD4T     0.8801    0.8745    0.8773      1578
       CD7_Immune     0.7609    0.6731    0.7143        52
             CD8T     0.9166    0.9233    0.9199      1238
               DC     0.8905    0.7488    0.8135       402
      Endothelial     0.9821    0.9831    0.9826      1951
Enterocyte_ITLN1p     1.0000    0.4516    0.6222        62
           Goblet     0.8827    0.8795    0.8811      2207
              ICC     0.8709    0.9385    0.9034       309
        Lymphatic     0.8693    0.8159    0.8418       766
       Macrophage     0.9459    0.9132    0.9293      2812
            Nerve     0.9114    0.9169    0.9142       999
   Neuroendocrine     0.8750    0.9686    0.9194       159
       Neutrophil     0.9106    0.9333    0.9218       120
           Plasma     0.9328    0.9396    0.9362      34