# CTMAP Tutorial

This notebook walks through a complete CTMAP run on the MERFISH dataset, from data loading to evaluation. It includes the full training log, with the progress bar and the loss values printed during training.


## 1. Setup and imports

This cell:
- Adds the project root directory to Python path (necessary for importing local modules)
- Sets environment variables
- Suppresses common warnings
- Imports required libraries and CTMAP modules

In [3]:
import sys
import os
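# Environment settings (set before the libraries that read them are imported)
# Disable HDF5 file locking, which can fail on network or shared filesystems
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
# Make CUDA kernel launches synchronous: easier debugging, but slower execution
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"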

# Suppress warnings for cleaner output
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message="Some cells have zero counts")
warnings.filterwarnings("ignore", category=FutureWarning)

# Core libraries
import torch
import numpy as np
import scanpy as sc
from sklearn.metrics import accuracy_score, normalized_mutual_info_score, adjusted_rand_score

# CTMAP modules
from CTMAP.model import CTMAP
from CTMAP.dataprocess import cell_type_encoder, anndata_preprocess, generate_dataloaders

print("\nAll imports successful!")


All imports successful!
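
The cell above imports `sys` but does not show the path change itself. If `from CTMAP.model import CTMAP` fails with `ModuleNotFoundError`, a minimal sketch like the following can be run first. It assumes the notebook is started from the repository root, which matches the relative `dataset/...` paths used later. The helper name `add_project_root` is hypothetical and not part of CTMAP.

```python
import os
import sys


def add_project_root(root="."):
    """Put `root` at the front of sys.path so local packages can be imported.

    `root` is assumed to be the directory that contains the CTMAP package.
    """
    root = os.path.abspath(root)
    if root not in sys.path:
        sys.path.insert(0, root)
    return root


project_root = add_project_root(".")
print(f"Project root on sys.path: {project_root}")
```

If the notebook lives in a subdirectory of the repository, pass the relative path to the root instead, for example `add_project_root("..")`.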


## 2. Set random seed and device

This cell:
- Sets a fixed random seed for reproducibility
- Detects and selects GPU (if available) or CPU

In [4]:
seed = 0
print(f"\n{'='*70}")
print(f"CTMAP - MERFISH | Single Run with Seed {seed}")
print(f"{'='*70}")

torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


CTMAP - MERFISH | Single Run with Seed 0
Using device: cuda:0
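
The cell above seeds PyTorch and NumPy. If any code in the pipeline also draws from Python's built-in `random` module, that generator needs its own seed. A minimal sketch, where `seed_everything` is a hypothetical helper and the NumPy and PyTorch calls are skipped when those packages are not installed:

```python
import random


def seed_everything(seed):
    """Seed Python's `random`, plus NumPy and PyTorch when they are available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored if CUDA is unavailable
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


seed_everything(0)
first = random.random()
seed_everything(0)
second = random.random()
print(first == second)  # the same seed reproduces the same draw
```

Even with fixed seeds, some GPU operations are nondeterministic, so small run-to-run differences in the loss values remain possible.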


## 3. Load data

This cell loads the example MERFISH dataset:
- scRNA-seq reference (`adata_rna.h5ad`)
- Spatial transcriptomics data (`adata_merfish.h5ad`)

Both files are included in the repository under `dataset/MERFISH/`. The paths in the cell are relative, so run the notebook from the repository root.

In [5]:
rna_path = "dataset/MERFISH/adata_rna.h5ad"
spatial_path = "dataset/MERFISH/adata_merfish.h5ad"

adata_r = sc.read(rna_path)
adata_s = sc.read(spatial_path)

print(f"Loaded scRNA-seq data: {adata_r.n_obs} cells, {adata_r.n_vars} genes")
print(f"Loaded spatial data: {adata_s.n_obs} spots, {adata_s.n_vars} genes")

Loaded scRNA-seq data: 30370 cells, 21043 genes
Loaded spatial data: 64373 spots, 160 genes
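
Because both paths are relative, a wrong working directory is the most likely reason for the loading cell to fail. A small check, sketched here with the hypothetical helper `find_missing`, reports every absent file at once together with the current working directory:

```python
from pathlib import Path


def find_missing(paths):
    """Return the subset of `paths` that do not point to an existing file."""
    return [str(p) for p in paths if not Path(p).is_file()]


missing = find_missing([
    "dataset/MERFISH/adata_rna.h5ad",
    "dataset/MERFISH/adata_merfish.h5ad",
])
if missing:
    print(f"Missing input files (cwd: {Path.cwd()}): {missing}")
else:
    print("Both input files found.")
```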


## 4. Preprocessing

This cell performs:
- Gene intersection between RNA and spatial data
- Adding source labels
- Cell type encoding
- Full preprocessing (normalization, log1p, per-batch MaxAbs scaling)
- Creating PyTorch DataLoaders
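
The exact steps live inside `anndata_preprocess`, which this notebook does not show. As an illustration only, the three operations named in the list (normalization, log1p, MaxAbs scaling) can be sketched on a small dense count matrix with NumPy. The target sum of `1e4`, the per-gene scaling axis, and the function name `normalize_log_maxabs` are assumptions, not CTMAP's documented behavior. "Per-batch" is taken here to mean that each dataset (RNA or spatial) is scaled separately.

```python
import numpy as np


def normalize_log_maxabs(counts, target_sum=1e4):
    """Library-size normalize, log1p-transform, then MaxAbs-scale each gene.

    counts: (cells x genes) array of raw counts for ONE batch.
    """
    counts = np.asarray(counts, dtype=np.float32)
    # Scale every cell to the same total count; all-zero cells stay all zero.
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0
    x = np.log1p(counts / totals * target_sum)
    # MaxAbs scaling: divide each gene by its largest absolute value in this batch.
    max_abs = np.abs(x).max(axis=0, keepdims=True)
    max_abs[max_abs == 0] = 1.0
    return x / max_abs


toy_counts = np.array([[1, 0, 3],
                       [2, 0, 2],
                       [0, 0, 5]])
scaled = normalize_log_maxabs(toy_counts)
print(scaled.round(3))
```

After this scaling every expressed gene has a maximum of 1 within its batch, which puts the RNA and spatial matrices on a comparable range despite their different count depths.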

In [6]:
# Gene intersection
common_genes = adata_r.var_names.intersection(adata_s.var_names)
adata_rna = adata_r[:, common_genes].copy()
adata_spa = adata_s[:, common_genes].copy()

adata_rna.obs['source'] = 'RNA'
adata_spa.obs['source'] = 'MERFISH'
adata_rna.X = adata_rna.X.astype(np.float32)
adata_spa.X = adata_spa.X.astype(np.float32)

# Cell type encoding
_, _, cell_types = cell_type_encoder(adata_rna, adata_spa)

# Full preprocessing
adata_spa, adata_rna = anndata_preprocess(adata_spa, adata_rna)

# Create DataLoaders
rna_train_loader, st_train_loader, rna_test_loader, st_test_loader = generate_dataloaders(adata_spa, adata_rna)

print("\nPreprocessing completed.")
print(f"Final RNA data: {adata_rna.n_obs} cells, {adata_rna.n_vars} genes")
print(f"Final spatial data: {adata_spa.n_obs} spots, {adata_spa.n_vars} genes")

Unique RNA Labels: [ 0  1  2  3  4  5  6  7  8  9 10 11]
Unique SeqFISH Labels: [ 0  1  2  3  5  7 10 11 12]
common Labels: [ 0  1  2  3  5  7 10 11]
Number of classes: 13

Preprocessing completed.
Final RNA data: 30370 cells, 153 genes
Final spatial data: 64373 spots, 153 genes
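
The label summary printed during preprocessing can be reproduced with plain set operations. The two label sets below are copied from the output above. Reading "Number of classes" as the size of their union is an inference from these numbers, not something the log states.

```python
# Integer labels copied from the preprocessing output above.
rna_labels = set(range(12))                      # 0 .. 11
spatial_labels = {0, 1, 2, 3, 5, 7, 10, 11, 12}

common = sorted(rna_labels & spatial_labels)      # present in both datasets
spatial_only = sorted(spatial_labels - rna_labels)
all_labels = sorted(rna_labels | spatial_labels)

print("common:", common)
print("spatial only:", spatial_only)
print("total distinct labels:", len(all_labels))
```

Label 12 appears only in the spatial data, so spatial cells of that type have no counterpart in the scRNA-seq reference.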


## 5. Initialize model

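This cell creates the CTMAP model instance. The input dimensions `rna_dim` and `st_dim` are the number of shared genes (153 here); the remaining arguments set the architecture, the number of classes, and the device.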

In [7]:
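# Number of cells (observations) in each dataset; passed to model.train() below
r_dim = len(adata_rna)
s_dim = len(adata_spa)
# Number of genes (features) shared by both datasets; used as the model's input size
rna_dim = adata_rna.shape[1]
st_dim = adata_spa.shape[1]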

model = CTMAP(
    rna_dim=rna_dim,
    st_dim=st_dim,
    latent_dim=64,
    hidden_dim=256,
    mha_heads_1=4,
    mha_dim_1=256,
    mha_dim_2=128,
    mha_heads_2=4,
    class_num=len(np.unique(adata_rna.obs['cell_type'])),
    device=device
)

print("CTMAP model initialized.")

CTMAP model initialized.


## 6. Train the model

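This cell runs the full two-stage training process: Stage 1 pretraining (reconstruction and classification), followed by Stage 2 domain adversarial training.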

**You will see:**
- Progress bar (tqdm)
- Loss values every 100 iterations
- Stage transition messages
- Final completion message

Training takes about 2 to 3 minutes on a GPU (2 min 40 s in the run shown below).
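
The call below sets `maxiter=4000` and `miditer1=3000`, and the log in this section shows Stage 1 running through iteration 3000 with the domain loss reported as zero. A minimal sketch of such a schedule follows. The helper names `loss_weights` and `weighted_total` are hypothetical, and combining the terms as a plain weighted sum is an assumption about how the weights are used; CTMAP's internal loss may differ.

```python
def loss_weights(iteration, miditer1=3000):
    """Return (stage, weights) for a 1-based iteration, using the tutorial's settings."""
    if iteration <= miditer1:
        # Stage 1: reconstruction + classification only.
        return 1, {"recon": 3.0, "cls": 0.01, "domain": 0.0}
    # Stage 2: the domain term is switched on.
    return 2, {"recon": 4.0, "cls": 0.01, "domain": 0.1}


def weighted_total(losses, weights):
    """Weighted sum of the individual loss terms (assumed combination rule)."""
    return sum(weights[name] * value for name, value in losses.items())


# Loss values copied from the training log at iterations 3000 and 3100.
for it, losses in [
    (3000, {"recon": 0.086352, "cls": 2.484656, "domain": 0.0}),
    (3100, {"recon": 0.081760, "cls": 2.485259, "domain": 1.383129}),
]:
    stage, weights = loss_weights(it)
    print(it, stage, round(weighted_total(losses, weights), 4))
```

The weights mirror the `stage1_*` and `stage2_*` arguments of `model.train`: the reconstruction weight rises from 3.0 to 4.0 in Stage 2, and the domain weight of 0.1 only applies there.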

In [9]:
print("Starting training...\n")

truth_label, pred_label, truth_rna, rna_embeddings, st_embeddings = model.train(
    rna_dim=r_dim,
    spa_dim=s_dim,
    rna_train_loader=rna_train_loader,
    st_train_loader=st_train_loader,
    spatial_coor=adata_spa.obs[["X", "Y"]],
    rna_test_loader=rna_test_loader,
    st_test_loader=st_test_loader,
    lr=5e-4,
    maxiter=4000,
    miditer1=3000,
    log_interval=100,
    stage1_recon_weight=3.0,
    stage1_cls_weight=0.01,
    stage2_recon_weight=4.0,
    stage2_domain_weight=0.1,
    stage2_cls_weight=0.01
)

print("\nTraining completed!")

Starting training...

=== Starting Stage 1: Pretraining (Reconstruction + Classification) ===


Training:   3%|█▌                                                          | 105/4000 [00:05<02:41, 24.18it/s]

#Iter 100: recon_loss: 0.359234, cls loss: 2.482532, domain loss: 0.000000


Training:   5%|███                                                         | 202/4000 [00:09<02:14, 28.14it/s]

#Iter 200: recon_loss: 0.336991, cls loss: 2.486644, domain loss: 0.000000


Training:   8%|████▌                                                       | 304/4000 [00:14<02:55, 21.04it/s]

#Iter 300: recon_loss: 0.316649, cls loss: 2.483740, domain loss: 0.000000


Training:  10%|██████                                                      | 404/4000 [00:18<01:52, 31.84it/s]

#Iter 400: recon_loss: 0.301407, cls loss: 2.485245, domain loss: 0.000000


Training:  12%|███████▌                                                    | 500/4000 [00:21<02:02, 28.49it/s]

#Iter 500: recon_loss: 0.286075, cls loss: 2.484340, domain loss: 0.000000


Training:  15%|█████████                                                   | 607/4000 [00:26<01:52, 30.27it/s]

#Iter 600: recon_loss: 0.270919, cls loss: 2.485596, domain loss: 0.000000


Training:  18%|██████████▋                                                 | 709/4000 [00:29<01:29, 36.77it/s]

#Iter 700: recon_loss: 0.259507, cls loss: 2.484404, domain loss: 0.000000


Training:  20%|████████████                                                | 804/4000 [00:33<02:10, 24.45it/s]

#Iter 800: recon_loss: 0.248520, cls loss: 2.484999, domain loss: 0.000000


Training:  23%|█████████████▌                                              | 903/4000 [00:37<01:25, 36.12it/s]

#Iter 900: recon_loss: 0.231599, cls loss: 2.484593, domain loss: 0.000000


Training:  25%|██████████████▊                                            | 1004/4000 [00:40<01:10, 42.59it/s]

#Iter 1000: recon_loss: 0.218735, cls loss: 2.484868, domain loss: 0.000000


Training:  28%|████████████████▎                                          | 1105/4000 [00:44<01:33, 31.10it/s]

#Iter 1100: recon_loss: 0.208533, cls loss: 2.484675, domain loss: 0.000000


Training:  30%|█████████████████▋                                         | 1203/4000 [00:47<01:21, 34.38it/s]

#Iter 1200: recon_loss: 0.196570, cls loss: 2.484322, domain loss: 0.000000


Training:  33%|███████████████████▏                                       | 1305/4000 [00:52<01:40, 26.84it/s]

#Iter 1300: recon_loss: 0.189481, cls loss: 2.484109, domain loss: 0.000000


Training:  35%|████████████████████▋                                      | 1401/4000 [00:55<01:33, 27.89it/s]

#Iter 1400: recon_loss: 0.178702, cls loss: 2.485225, domain loss: 0.000000


Training:  38%|██████████████████████▏                                    | 1503/4000 [00:59<01:24, 29.67it/s]

#Iter 1500: recon_loss: 0.171051, cls loss: 2.485186, domain loss: 0.000000


Training:  40%|███████████████████████▋                                   | 1606/4000 [01:03<01:12, 32.90it/s]

#Iter 1600: recon_loss: 0.161323, cls loss: 2.484956, domain loss: 0.000000


Training:  43%|█████████████████████████▏                                 | 1705/4000 [01:06<01:17, 29.66it/s]

#Iter 1700: recon_loss: 0.153176, cls loss: 2.484468, domain loss: 0.000000


Training:  45%|██████████████████████████▌                                | 1803/4000 [01:10<01:00, 36.07it/s]

#Iter 1800: recon_loss: 0.145135, cls loss: 2.484408, domain loss: 0.000000


Training:  48%|████████████████████████████▏                              | 1908/4000 [01:13<00:42, 49.33it/s]

#Iter 1900: recon_loss: 0.138371, cls loss: 2.485741, domain loss: 0.000000


Training:  50%|█████████████████████████████▌                             | 2005/4000 [01:15<00:47, 42.31it/s]

#Iter 2000: recon_loss: 0.132464, cls loss: 2.484193, domain loss: 0.000000


Training:  53%|██████████████████████████████▉                            | 2101/4000 [01:19<01:00, 31.25it/s]

#Iter 2100: recon_loss: 0.125699, cls loss: 2.485362, domain loss: 0.000000


Training:  55%|████████████████████████████████▌                          | 2211/4000 [01:22<00:47, 38.00it/s]

#Iter 2200: recon_loss: 0.120433, cls loss: 2.485276, domain loss: 0.000000


Training:  58%|█████████████████████████████████▉                         | 2303/4000 [01:26<01:10, 24.20it/s]

#Iter 2300: recon_loss: 0.113747, cls loss: 2.484219, domain loss: 0.000000


Training:  60%|███████████████████████████████████▍                       | 2404/4000 [01:29<00:38, 40.97it/s]

#Iter 2400: recon_loss: 0.109736, cls loss: 2.485038, domain loss: 0.000000


Training:  63%|████████████████████████████████████▉                      | 2504/4000 [01:33<00:56, 26.59it/s]

#Iter 2500: recon_loss: 0.106162, cls loss: 2.484561, domain loss: 0.000000


Training:  65%|██████████████████████████████████████▍                    | 2607/4000 [01:37<00:36, 38.66it/s]

#Iter 2600: recon_loss: 0.101090, cls loss: 2.484394, domain loss: 0.000000


Training:  68%|███████████████████████████████████████▉                   | 2706/4000 [01:40<00:47, 27.12it/s]

#Iter 2700: recon_loss: 0.096467, cls loss: 2.485038, domain loss: 0.000000


Training:  70%|█████████████████████████████████████████▎                 | 2802/4000 [01:44<00:47, 25.36it/s]

#Iter 2800: recon_loss: 0.094118, cls loss: 2.484840, domain loss: 0.000000


Training:  73%|██████████████████████████████████████████▊                | 2904/4000 [01:48<00:34, 31.40it/s]

#Iter 2900: recon_loss: 0.089586, cls loss: 2.484987, domain loss: 0.000000


Training:  75%|████████████████████████████████████████████▎              | 3003/4000 [01:50<00:31, 31.53it/s]

#Iter 3000: recon_loss: 0.086352, cls loss: 2.484656, domain loss: 0.000000

=== Stage 1 Completed ===
=== Starting Stage 2: Domain Adversarial Training ===



Training:  78%|█████████████████████████████████████████████▊             | 3103/4000 [01:54<00:33, 26.61it/s]

#Iter 3100: recon_loss: 0.081760, cls loss: 2.485259, domain loss: 1.383129


Training:  80%|███████████████████████████████████████████████▎           | 3204/4000 [01:59<00:33, 23.42it/s]

#Iter 3200: recon_loss: 0.077953, cls loss: 2.484393, domain loss: 1.385284


Training:  83%|████████████████████████████████████████████████▋          | 3301/4000 [02:03<00:30, 22.60it/s]

#Iter 3300: recon_loss: 0.073374, cls loss: 2.484798, domain loss: 1.385005


Training:  85%|██████████████████████████████████████████████████▏        | 3401/4000 [02:08<00:28, 20.84it/s]

#Iter 3400: recon_loss: 0.070070, cls loss: 2.485437, domain loss: 1.385204


Training:  88%|███████████████████████████████████████████████████▋       | 3502/4000 [02:12<00:20, 24.35it/s]

#Iter 3500: recon_loss: 0.068012, cls loss: 2.484529, domain loss: 1.385411


Training:  90%|█████████████████████████████████████████████████████▏     | 3603/4000 [02:17<00:12, 32.19it/s]

#Iter 3600: recon_loss: 0.065351, cls loss: 2.484610, domain loss: 1.385383


Training:  92%|██████████████████████████████████████████████████████▌    | 3700/4000 [02:21<00:08, 35.32it/s]

#Iter 3700: recon_loss: 0.061653, cls loss: 2.485136, domain loss: 1.385414


Training:  95%|████████████████████████████████████████████████████████   | 3801/4000 [02:26<00:07, 26.50it/s]

#Iter 3800: recon_loss: 0.060417, cls loss: 2.484929, domain loss: 1.385444


Training:  98%|█████████████████████████████████████████████████████████▌ | 3900/4000 [02:30<00:03, 31.58it/s]

#Iter 3900: recon_loss: 0.057458, cls loss: 2.485142, domain loss: 1.385459


Training: 100%|██████████████████████████████████████████████████████████▉| 3999/4000 [02:33<00:00, 44.62it/s]

#Iter 4000: recon_loss: 0.055800, cls loss: 2.484606, domain loss: 1.385740


Training: 100%|███████████████████████████████████████████████████████████| 4000/4000 [02:40<00:00, 24.91it/s]


Training completed!
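
The loss lines follow a fixed pattern, so they can be collected for plotting or for comparing runs. A small sketch, assuming the log text has been captured as a string; the pattern is inferred from the lines above, `parse_loss_line` is a hypothetical helper, and the progress-bar line in the sample is abbreviated:

```python
import re

# Matches lines such as:
# "#Iter 100: recon_loss: 0.359234, cls loss: 2.482532, domain loss: 0.000000"
LOSS_LINE = re.compile(
    r"#Iter (\d+): recon_loss: ([\d.]+), cls loss: ([\d.]+), domain loss: ([\d.]+)"
)


def parse_loss_line(line):
    """Return a dict of the logged values, or None for non-matching lines."""
    match = LOSS_LINE.search(line)
    if match is None:
        return None
    iteration, recon, cls, domain = match.groups()
    return {
        "iter": int(iteration),
        "recon": float(recon),
        "cls": float(cls),
        "domain": float(domain),
    }


log_text = """\
#Iter 100: recon_loss: 0.359234, cls loss: 2.482532, domain loss: 0.000000
Training:   5%| 202/4000 [00:09<02:14, 28.14it/s]
#Iter 4000: recon_loss: 0.055800, cls loss: 2.484606, domain loss: 1.385740
"""
records = [r for r in map(parse_loss_line, log_text.splitlines()) if r]
print(records[0]["recon"], "->", records[-1]["recon"])
```

In this run the reconstruction loss falls from 0.359 at iteration 100 to 0.056 at iteration 4000.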





## 7. Evaluation

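This cell maps the integer labels back to cell type names, then computes accuracy, normalized mutual information (NMI), and adjusted Rand index (ARI) between the ground-truth and predicted cell types of the spatial data.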

In [10]:
truth_ct = [cell_types[i] for i in truth_label]
pred_ct = [cell_types[i] for i in pred_label]

accuracy = accuracy_score(truth_ct, pred_ct)
nmi = normalized_mutual_info_score(truth_ct, pred_ct)
ari = adjusted_rand_score(truth_ct, pred_ct)

print(f"\n{'='*70}")
print("FINAL RESULTS")
print(f"Accuracy: {accuracy:.4f}, NMI: {nmi:.4f}, ARI: {ari:.4f}")
print(f"{'='*70}")


FINAL RESULTS
Accuracy: 0.8907, NMI: 0.7833, ARI: 0.7888
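
Overall accuracy can hide weak performance on rare cell types. A per-type breakdown, sketched below in plain Python, can be applied to `truth_ct` and `pred_ct` from the cell above. The toy cell type names and the helper `per_class_accuracy` are illustrative and not taken from the dataset or from CTMAP.

```python
from collections import Counter


def per_class_accuracy(truth, pred):
    """Fraction of correctly predicted cells for each ground-truth class."""
    total = Counter(truth)
    correct = Counter(t for t, p in zip(truth, pred) if t == p)
    return {label: correct[label] / total[label] for label in total}


# Toy example; replace with truth_ct and pred_ct from the evaluation cell.
toy_truth = ["Astrocyte", "Astrocyte", "Microglia", "Microglia", "Microglia"]
toy_pred = ["Astrocyte", "Microglia", "Microglia", "Microglia", "Microglia"]

for label, acc in per_class_accuracy(toy_truth, toy_pred).items():
    print(f"{label}: {acc:.2f}")
```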
