In [1]:
import sys
from pathlib import Path

# Make the project root importable so that `models.*` (imported in the next cell) resolves.
# NOTE: this is a hard-coded absolute path from the original author's machine — adjust it locally.
project_root = Path("/home/lxz/scmamba/KCellFM_tutorial/spatial_transcriptomics").parent.parent
# Append only once, so re-running this cell does not keep growing sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
import pandas as pd
import numpy as np
import os
import sys
import random
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from models.model import MambaModel  
from models.gene_tokenizer import GeneVocab

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def set_seed(seed=42):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU plus all CUDA
    devices), then turns on ``cudnn.deterministic`` and turns off cuDNN
    benchmarking (autotuning).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

In [4]:
# configuration parameters
class Config:
    # Data parameters
    train_data_path = "/home/lxz/scmamba/空转/Hubmap_SB_intra/train.csv"
    model_save_dir = "/home/lxz/scmamba/model_state/"
    os.makedirs(model_save_dir, exist_ok=True)  

    # Model parameters
    embsize = 512
    nhead = 8
    d_hid = 512
    nlayers = 6
    dropout = 0.1
    pad_token = "<pad>"
    max_seq_len = 50
    input_emb_style = "continuous"
    cell_emb_style = "cls"
    mask_value = -1
    pad_value = -2
    vocab_path = "/home/lxz/scmamba/vocab.json"
    pretrained_model_path = "/home/lxz/scmamba/model_state/cell_cls_3loss_6layer_final.pth"

    # Training parameters
    epochs = 10
    batch_size = 96  
    lr = 5e-5  
    weight_decay = 1e-4
    val_split = 0.3  
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Category Mapping
    celltype_to_id = {
        "CD4T": 0,
        "CD7_Immune": 1,
        "CD8T": 2,
        "DC": 3,
        "Endothelial": 4,
        "Goblet": 5,
        "ICC": 6,
        "Lymphatic": 7,
        "Macrophage": 8,
        "Nerve": 9,
        "Neuroendocrine": 10,
        "Neutrophil": 11,
        "Plasma": 12,
        "Stroma": 13,
        "TA": 14
    }
    class_num = len(celltype_to_id)

    # Gene Mapping
    gene_to_id = {
        "MUC2": 17183, "SOX9": 32052, "MUC1": 17175, "CD31": 19330, "Synapto": 32742,
        "CD49f": 12272, "CD15": 9687, "CHGA": 4894, "CDX2": 4568, "ITLN1": 12308,
        "CD4": 4380, "CD127": 12051, "Vimentin": 35192, "HLADR": 11044, "CD8": 4412,
        "CD11c": 12283, "CD44": 4383, "CD16": 9286, "BCL2": 3080, "CD123": 36627,
        "CD38": 4376, "CD90": 33320, "aSMA": 1391, "CD21": 5523, "NKG2D": 12911,
        "CD66": 4589, "CD57": 2956, "CD206": 16892, "CD68": 4397, "CD34": 4373,
        "aDef5": 7546, "CD7": 4399, "CD36": 4374, "CD138": 30796, "Cytokeratin": 41736,
        "CK7": 12989, "CD117": 12801, "CD19": 4335, "Podoplanin": 19298, "CD45": 20664,
        "CD56": 17477, "CD69": 4398, "Ki67": 16711, "CD49a": 12264, "CD163": 4329,
        "CD161": 12901
    }


config = Config()

In [5]:
# Dataset
class SpatialDataset(Dataset):
    """Per-cell dataset built from a spatial expression CSV.

    Each CSV row is one cell. Columns whose names are keys of ``gene_to_id`` are
    the expression features, in CSV column order. ``cell_type_A`` holds the label,
    and ``x``/``y`` hold the spatial coordinates. Rows whose ``cell_type_A`` is not
    a key of ``celltype_to_id`` are dropped.

    Every item is a fixed-length sequence: a CLS token (expression 0.0), then one
    token per feature column. The sequence is truncated or padded to
    ``max_seq_len``; padded positions get ``PAD_TOKEN_ID`` and ``pad_value``.

    Args:
        csv_file: path to the CSV table.
        gene_to_id: feature column name -> token id.
        celltype_to_id: cell-type name -> class index.
        max_seq_len: sequence length; defaults to ``config.max_seq_len`` when None.
        pad_value: expression value for padded positions; defaults to
            ``config.pad_value`` when None.
    """

    # Special token ids.
    # NOTE(review): assumed to match the CLS / <pad> entries of the vocabulary — confirm.
    CLS_TOKEN_ID = 60695
    PAD_TOKEN_ID = 60694

    def __init__(self, csv_file, gene_to_id, celltype_to_id, max_seq_len=None, pad_value=None):
        self.data = pd.read_csv(csv_file)
        self.gene_to_id = gene_to_id
        self.celltype_to_id = celltype_to_id
        # Fall back to the notebook-wide config only when no explicit value is given.
        self.max_seq_len = config.max_seq_len if max_seq_len is None else max_seq_len
        self.pad_value = config.pad_value if pad_value is None else pad_value

        self.genes = [col for col in self.data.columns if col in gene_to_id]

        valid_cell_types = set(celltype_to_id.keys())
        self.data = self.data[self.data['cell_type_A'].isin(valid_cell_types)].reset_index(drop=True)

        # Token ids and padding mask are the same for every cell, so build them once.
        token_ids = ([self.CLS_TOKEN_ID] + [gene_to_id[g] for g in self.genes])[:self.max_seq_len]
        n_tokens = len(token_ids)
        token_ids += [self.PAD_TOKEN_ID] * (self.max_seq_len - n_tokens)
        self._src = torch.tensor(token_ids, dtype=torch.long)
        # True marks positions the model must ignore.
        self._padding_mask = self._src == self.PAD_TOKEN_ID

        # Precompute per-cell arrays once; a pandas row lookup plus a Python loop per
        # item is comparatively slow inside a DataLoader.
        values = np.full((len(self.data), self.max_seq_len), self.pad_value, dtype=np.float32)
        if n_tokens > 0:
            values[:, 0] = 0.0  # the CLS slot carries no expression
            values[:, 1:n_tokens] = self.data[self.genes[:n_tokens - 1]].to_numpy(dtype=np.float32)
        self._values = values
        self._labels = self.data['cell_type_A'].map(celltype_to_id).to_numpy(dtype=np.int64)
        self._coords = self.data[['x', 'y']].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one cell as a dict of tensors.

        Keys: ``src`` (long ids), ``values`` (float expression), ``padding_mask``
        (bool), ``celltype`` (0-dim long) and ``coordinates`` (float ``[x, y]``).
        """
        return {
            # Clone the shared tensors so callers cannot mutate the dataset's copy.
            'src': self._src.clone(),
            'values': torch.tensor(self._values[idx], dtype=torch.float),
            'padding_mask': self._padding_mask.clone(),
            'celltype': torch.tensor(self._labels[idx], dtype=torch.long),
            'coordinates': torch.tensor(self._coords[idx], dtype=torch.float)
        }

In [6]:
# Evaluation function
def evaluate(model, loader, device, criterion_cls):
    """Run ``model`` over ``loader`` without gradients and collect classification results.

    Args:
        model: called as ``model(src=..., values=..., src_key_padding_mask=...)``;
            must return a dict whose ``"cls_output"`` entry holds the class logits.
        loader: iterable of batches (dicts with ``src``, ``values``,
            ``padding_mask`` and ``celltype`` tensors).
        device: torch device (or device string) to run on.
        criterion_cls: loss on ``(logits, labels)``. It is assumed to average over
            the batch, so it is re-weighted by batch size before averaging.

    Returns:
        ``(avg_loss, accuracy, all_preds, all_labels)``: the per-sample mean loss,
        the fraction of correct predictions, and the predicted / true class ids in
        loader order. Loss and accuracy are 0.0 when the loader yields no samples.
    """
    device = torch.device(device)
    # Mixed precision only applies on CUDA; on CPU the CUDA autocast would just warn.
    use_amp = device.type == "cuda"
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            src = batch['src'].to(device)
            values = batch['values'].to(device)
            padding_mask = batch['padding_mask'].to(device)
            cell_types = batch['celltype'].to(device)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                model_output = model(
                    src=src,
                    values=values,
                    src_key_padding_mask=padding_mask
                )
                logits = model_output["cls_output"]
                loss = criterion_cls(logits, cell_types)

            batch_size = src.size(0)
            total_loss += loss.item() * batch_size
            # Count samples actually seen (correct even with drop_last or a custom sampler).
            n_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(cell_types.cpu().numpy())

    if n_samples == 0:
        return 0.0, 0.0, all_preds, all_labels
    avg_loss = total_loss / n_samples
    accuracy = float(np.mean(np.asarray(all_preds) == np.asarray(all_labels)))
    return avg_loss, accuracy, all_preds, all_labels

In [7]:
def _stratify_labels(dataset):
    """Return the class id of every row of ``dataset.data`` as a NumPy array.

    Gives the same values as ``[dataset[i]['celltype'].item() for i in range(len(dataset))]``,
    because the dataset has already dropped unknown cell types. It avoids building
    every sample's tensors just to read the label.
    """
    return dataset.data['cell_type_A'].map(dataset.celltype_to_id).to_numpy()


def _load_pretrained_weights(model, checkpoint_path, device):
    """Copy shape-compatible tensors from a state-dict checkpoint into ``model``.

    Keys missing from the model, keys with a different shape, and ``cls_decoder``
    (task-specific head) keys are skipped. Returns the number of tensors copied.
    Errors from ``torch.load`` / ``load_state_dict`` propagate to the caller.
    """
    pretrained_dict = torch.load(checkpoint_path, map_location=device)
    model_dict = model.state_dict()
    transferable = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape and 'cls_decoder' not in k
    }
    model_dict.update(transferable)
    model.load_state_dict(model_dict)
    return len(transferable)


def _reset_classifier_head(model):
    """Re-initialise ``model.cls_decoder.out_layer`` with Kaiming-normal weights and a zero bias."""
    out_layer = model.cls_decoder.out_layer
    nn.init.kaiming_normal_(out_layer.weight, mode='fan_in', nonlinearity='relu')
    if out_layer.bias is not None:
        nn.init.zeros_(out_layer.bias)


def _train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp, desc):
    """Run one optimisation pass over ``loader`` and return the per-sample mean training loss."""
    model.train()
    total_loss = 0.0
    n_samples = 0
    pbar = tqdm(loader, desc=desc)
    for batch in pbar:
        src = batch['src'].to(device)
        values = batch['values'].to(device)
        padding_mask = batch['padding_mask'].to(device)
        cell_types = batch['celltype'].to(device)

        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            model_output = model(
                src=src,
                values=values,
                src_key_padding_mask=padding_mask
            )
            loss = criterion(model_output["cls_output"], cell_types)

        scaler.scale(loss).backward()
        # Unscale first so max_norm is applied to the true (unscaled) gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss * src.size(0)
        n_samples += src.size(0)
        pbar.set_postfix(train_loss=batch_loss)
    return total_loss / max(n_samples, 1)


def train():
    """Fine-tune MambaModel for cell-type classification, driven entirely by ``config``.

    Steps:
      1. Load ``config.train_data_path`` into ``SpatialDataset``. Make a stratified
         train/validation split (``config.val_split``, ``random_state=42``).
      2. Build ``MambaModel``. Best-effort: copy shape-compatible pre-trained
         weights (all but ``cls_decoder``), then re-initialise the classifier head.
      3. Train for ``config.epochs`` epochs with AdamW and gradient clipping, using
         AMP on CUDA. Checkpoint whenever validation accuracy improves.
      4. Reload the best checkpoint and print a per-class report on the validation set.
    """
    device = torch.device(config.device)
    on_cuda = device.type == "cuda"
    if on_cuda and device.index is not None:
        # torch.cuda.set_device rejects CPU devices, so only call it for an indexed GPU.
        torch.cuda.set_device(device)
    print(f"Using device: {device}")

    # load dataset
    full_dataset = SpatialDataset(
        csv_file=config.train_data_path,
        gene_to_id=config.gene_to_id,
        celltype_to_id=config.celltype_to_id
    )
    print(f"Total number of training data samples: {len(full_dataset)}")

    # Stratified split keeps class proportions the same in both subsets.
    train_idx, val_idx = train_test_split(
        np.arange(len(full_dataset)),
        test_size=config.val_split,
        stratify=_stratify_labels(full_dataset),
        random_state=42
    )
    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    # pin_memory only speeds up host->GPU copies; without CUDA it just warns.
    loader_kwargs = dict(batch_size=config.batch_size, num_workers=4, pin_memory=on_cuda, drop_last=False)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Initialize model
    vocab = GeneVocab.from_file(config.vocab_path)
    model = MambaModel(
        len(vocab), config.embsize, config.nhead, config.d_hid, config.nlayers,
        dropout=config.dropout, pad_token=config.pad_token,
        pad_value=config.pad_value, input_emb_style=config.input_emb_style,
        cell_emb_style=config.cell_emb_style, class_num=config.class_num
    ).to(device)

    # Best-effort warm start: any failure falls back to the model's random initialisation.
    try:
        n_loaded = _load_pretrained_weights(model, config.pretrained_model_path, device)
        print(f"Successfully load {n_loaded} pre training layer weights")
        print("Initialize classification header weights...")
        _reset_classifier_head(model)
    except Exception as e:
        print(f"Failed to load pre training weights: {e}, Random initialization weights will be used")

    # Optimizer and loss function
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.lr,
        weight_decay=config.weight_decay
    )
    criterion_cls = nn.CrossEntropyLoss()
    # AMP only on CUDA; with enabled=False the scaler's methods are pass-throughs.
    scaler = torch.cuda.amp.GradScaler(enabled=on_cuda)

    # Start below any reachable accuracy, so the first epoch always writes a fresh
    # checkpoint and a stale file from an earlier run is never reloaded.
    best_val_acc = float("-inf")
    checkpoint_saved = False
    best_model_path = os.path.join(config.model_save_dir, "spatial_classifier_SB_intra_best.pth")

    for epoch in range(config.epochs):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, scaler, criterion_cls, device, on_cuda,
            desc=f"Epoch {epoch + 1}/{config.epochs}"
        )
        val_loss, val_acc, _, _ = evaluate(model, val_loader, device, criterion_cls)

        print(f"\nEpoch {epoch + 1} result:")
        print(f"Training loss: {avg_train_loss:.4f} | Validation loss: {val_loss:.4f} | Validation accuracy: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
            print(f"Save the best model (Validation accuracy: {best_val_acc:.4f}) to {best_model_path}")

    if not checkpoint_saved:
        print("\nNo checkpoint was saved in this run; skipping the final validation report.")
        return

    print(f"\nTraining completed! Best Validation Accuracy: {best_val_acc:.4f} (The model is saved in {best_model_path})")

    # Load the best model and print a detailed report of the validation set
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    _, _, val_preds, val_labels = evaluate(model, val_loader, device, criterion_cls)
    print("\nDetailed report of the best model on the validation set:")
    # Explicit labels keep target_names aligned even if a class is absent from the split.
    print(classification_report(
        val_labels,
        val_preds,
        labels=list(config.celltype_to_id.values()),
        target_names=list(config.celltype_to_id.keys()),
        digits=4,
        zero_division=0
    ))

In [8]:
# Script-style entry point. Inside a Jupyter kernel __name__ is also "__main__",
# so executing this cell starts training right away.
if "__main__" == __name__:
    train()

Using device: cuda:1
Total number of training data samples: 31036
Training set size: 21725, Validation set size: 9311
Successfully load 111 pre training layer weights
Initialize classification header weights...


Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:45<00:00,  4.96it/s, train_loss=0.674]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 32.29it/s]



Epoch 1 result:
Training loss: 0.7145 | Validation loss: 0.3910 | Validation accuracy: 0.8694
Save the best model (Validaiton accuracy: 0.8694) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:27<00:00,  8.37it/s, train_loss=0.692]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 32.29it/s]



Epoch 2 result:
Training loss: 0.3626 | Validation loss: 0.3576 | Validation accuracy: 0.8674


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:27<00:00,  8.20it/s, train_loss=0.388]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:02<00:00, 32.49it/s]



Epoch 3 result:
Training loss: 0.2907 | Validation loss: 0.3207 | Validation accuracy: 0.8852
Save the best model (Validaiton accuracy: 0.8852) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:27<00:00,  8.14it/s, train_loss=0.185]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 31.33it/s]



Epoch 4 result:
Training loss: 0.2515 | Validation loss: 0.2914 | Validation accuracy: 0.8951
Save the best model (Validaiton accuracy: 0.8951) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:25<00:00,  8.85it/s, train_loss=0.193]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 30.32it/s]



Epoch 5 result:
Training loss: 0.2290 | Validation loss: 0.3219 | Validation accuracy: 0.8831


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:25<00:00,  9.01it/s, train_loss=0.235]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:02<00:00, 32.63it/s]



Epoch 6 result:
Training loss: 0.2091 | Validation loss: 0.2942 | Validation accuracy: 0.8968
Save the best model (Validaiton accuracy: 0.8968) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:25<00:00,  8.98it/s, train_loss=0.022]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 32.29it/s]



Epoch 7 result:
Training loss: 0.1865 | Validation loss: 0.3063 | Validation accuracy: 0.8964


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:25<00:00,  8.97it/s, train_loss=0.211]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 31.93it/s]



Epoch 8 result:
Training loss: 0.1705 | Validation loss: 0.3102 | Validation accuracy: 0.8956


Epoch 9/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:26<00:00,  8.63it/s, train_loss=0.0968]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 32.12it/s]



Epoch 9 result:
Training loss: 0.1534 | Validation loss: 0.3164 | Validation accuracy: 0.8983
Save the best model (Validaiton accuracy: 0.8983) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth


Epoch 10/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227/227 [00:26<00:00,  8.45it/s, train_loss=0.0739]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 31.84it/s]



Epoch 10 result:
Training loss: 0.1441 | Validation loss: 0.3094 | Validation accuracy: 0.9010
Save the best model (Validaiton accuracy: 0.9010) to /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth

Training completed! Best Verification Accuracy: 0.9010 (The model is saved in /home/lxz/scmamba/model_state/spatial_classifier_SB_intra_best.pth)


Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 32.17it/s]


Detailed report of the best model on the validation set:
                precision    recall  f1-score   support

          CD4T     0.8454    0.9163    0.8794       800
    CD7_Immune     0.8590    0.8864    0.8725       220
          CD8T     0.8953    0.8844    0.8898      1073
            DC     0.8000    0.8511    0.8247        94
   Endothelial     0.9718    0.9277    0.9492       816
        Goblet     0.9064    0.9397    0.9228      1742
           ICC     0.9508    0.7733    0.8529        75
     Lymphatic     0.8475    0.8475    0.8475       295
    Macrophage     0.9312    0.8808    0.9053       906
         Nerve     0.8712    0.8689    0.8700       366
Neuroendocrine     0.9122    0.9639    0.9373       194
    Neutrophil     0.9189    0.8500    0.8831        40
        Plasma     0.9169    0.9458    0.9311      1108
        Stroma     0.8406    0.7301    0.7815       289
            TA     0.9014    0.8770    0.8891      1293

      accuracy                         0.901


