# CTMAP Tutorial (Detailed Training Log)

**C**ell **T**ype **M**apping with **A**dversarial **P**rofile alignment

This notebook provides a complete, step-by-step walkthrough of CTMAP on the MERFISH dataset, with **full training logs** displayed (progress bar, loss values, stage transitions, and final metrics).

Run each cell sequentially to reproduce the results.

## 1. Import libraries

In [None]:
import os
# These environment variables are set before torch / scanpy are imported so
# they are already in place when CUDA and HDF5 get initialized.
# Disable HDF5 file locking (typically needed to read .h5ad files from
# network or shared filesystems that do not support file locks).
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
# Make CUDA kernel launches synchronous so errors are reported at the call
# that caused them. NOTE(review): this is a debugging aid that slows GPU
# training noticeably; consider removing it for normal runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import numpy as np
import scanpy as sc
from sklearn.metrics import accuracy_score, normalized_mutual_info_score, adjusted_rand_score

# Project-local helpers (dataprocess.py / model.py in the CTMAP repository).
from dataprocess import cell_type_encoder, anndata_preprocess, generate_dataloaders
from model import CTMAP

## 2. Load and preprocess data

In [None]:
# Paths to the scRNA-seq reference and the MERFISH spatial dataset.
rna_path = "dataset/MERFISH/adata_rna.h5ad"
spatial_path = "dataset/MERFISH/adata_merfish.h5ad"

adata_rna = sc.read(rna_path)
adata_spatial = sc.read(spatial_path)

# Restrict both modalities to the genes they share so their feature columns
# refer to the same genes in the same order.
common_genes = adata_rna.var_names.intersection(adata_spatial.var_names)
if common_genes.empty:
    # Fail early with an actionable message instead of an obscure
    # shape/dimension error later in preprocessing or model construction.
    raise ValueError(
        f"No shared genes between {rna_path} and {spatial_path}; "
        "check that both files use the same gene identifiers."
    )
adata_rna = adata_rna[:, common_genes].copy()
adata_spatial = adata_spatial[:, common_genes].copy()

# Tag every cell with its modality of origin.
adata_rna.obs['source'] = 'RNA'
adata_spatial.obs['source'] = 'MERFISH'

# Only the third return value is kept: it is used in the evaluation cell to
# map integer labels back to cell-type names.
_, _, cell_types = cell_type_encoder(adata_rna, adata_spatial)
# NOTE: these helpers take and return (spatial, rna) in that order.
adata_spatial, adata_rna = anndata_preprocess(adata_spatial, adata_rna)

rna_train_loader, st_train_loader, rna_test_loader, st_test_loader = generate_dataloaders(adata_spatial, adata_rna)

print("Preprocessing completed.")
print(f"RNA cells: {len(adata_rna)}, Spatial spots: {len(adata_spatial)}, Genes: {adata_rna.shape[1]}")

## 3. Initialize model

In [None]:
# Prefer the first GPU; fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Input widths come from the preprocessed AnnData objects; the number of
# output classes is the number of distinct reference cell-type labels.
n_rna_features = adata_rna.shape[1]
n_st_features = adata_spatial.shape[1]
n_classes = len(np.unique(adata_rna.obs['cell_type']))

model = CTMAP(
    rna_dim=n_rna_features,
    st_dim=n_st_features,
    latent_dim=64,
    hidden_dim=256,
    mha_heads_1=4,
    mha_dim_1=256,
    mha_dim_2=128,
    mha_heads_2=4,
    class_num=n_classes,
    device=device,
)

## 4. Train the model (full training logs are displayed)

In [None]:
# Run training. Per this tutorial, CTMAP.train itself prints the progress bar,
# loss values and stage transitions while it runs.
train_kwargs = dict(
    lr=5e-4,
    maxiter=4000,
    miditer1=3000,  # presumably the stage-transition iteration — confirm in model.py
    log_interval=100,
)
spatial_coords = adata_spatial.obs[["X", "Y"]]

(
    truth_label,
    pred_label,
    truth_rna,
    rna_embeddings,
    st_embeddings,
) = model.train(
    rna_n_cells=len(adata_rna),
    st_n_cells=len(adata_spatial),
    rna_train_loader=rna_train_loader,
    st_train_loader=st_train_loader,
    spatial_coor=spatial_coords,
    rna_test_loader=rna_test_loader,
    st_test_loader=st_test_loader,
    **train_kwargs,
)

print("\nTraining completed successfully!")

## 5. Evaluation

In [None]:
# Map integer labels returned by training back to cell-type names, using the
# `cell_types` list obtained from cell_type_encoder during preprocessing.
truth_ct = [cell_types[i] for i in truth_label]
pred_ct = [cell_types[i] for i in pred_label]

# Accuracy compares labels directly; NMI and ARI are clustering-agreement
# scores that are invariant to a relabeling of the classes.
acc = accuracy_score(truth_ct, pred_ct)
nmi = normalized_mutual_info_score(truth_ct, pred_ct)
ari = adjusted_rand_score(truth_ct, pred_ct)

print("Final Results:")
print(f"Accuracy: {acc:.4f}")
print(f"NMI:      {nmi:.4f}")
print(f"ARI:      {ari:.4f}")

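## Expected Results

Reference values obtained with the settings above. Because this notebook does not fix a random seed, your numbers may differ slightly between runs and hardware.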

| Metric    | Value  |
|-----------|--------|
| Accuracy  | 0.8907 |
| NMI       | 0.7833 |
| ARI       | 0.7888 |