# Fine-tune Pre-trained Nicheformer Model for Downstream Tasks

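This notebook fine-tunes a pre-trained Nicheformer model on a downstream task (in this run, cell-type classification on a bladder dataset). It evaluates the model on a held-out test split and stores the test-set predictions and metrics in an AnnData object.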

In [1]:
import os
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torch.utils.data import DataLoader
import anndata as ad
from typing import Optional, Dict, Any

from nicheformer.models import Nicheformer
from nicheformer.models._nicheformer_fine_tune import NicheformerFineTune
from nicheformer.data.dataset import NicheformerDataset

## Configuration

Set up the configuration parameters for fine-tuning.

In [2]:
# All tunable settings for this run; later cells read their values from this dict.
config = {
    'data_path': '/home/ubuntu/dev/tsv2_bladder.h5ad',  # Path to your AnnData file
    'technology_mean_path': '/home/ubuntu/dev/sandbox/nicheformer/data/model_means/dissociated_mean_script.npy',  # Path to technology mean file
    'checkpoint_path': '/home/ubuntu/dev/nicheformer.ckpt',  # Path to pre-trained model
    'output_path': 'output/predictions.h5ad',  # Where to save results
    'output_dir': 'output/checkpoints',  # Directory for checkpoints
    'seed': 42,  # Global RNG seed (python/numpy/torch and dataloader workers)

    # Training parameters
    'batch_size': 32,
    'max_seq_len': 1500,
    'aux_tokens': 30,
    'chunk_size': 1000,
    'num_workers': 4,
    'precision': 32,
    'max_epochs': 100,
    'lr': 1e-4,
    'warmup': 10,
    'gradient_clip_val': 1.0,
    'accumulate_grad_batches': 10,

    # Model parameters
    'supervised_task': 'niche_classification',  # task name passed to NicheformerFineTune
    'extract_layers': [11],  # Which layers to extract features from
    'function_layers': "mean",  # Architecture of prediction head
    'dim_prediction': 33,  # dim of the output vector
    'n_classes': 1,  # placeholder; the data-loading cell overwrites it from the label column
    'freeze': False,  # Whether to freeze backbone
    'reinit_layers': False,
    'extractor': False,
    'regress_distribution': True,
    'pool': 'mean',
    'predict_density': False,
    'ignore_zeros': False,
    'organ': 'brain',  # NOTE(review): data_path points at a bladder file -- confirm 'brain' is intended
    'label': 'cell_type'  # The target variable to predict
}

# Seed before the fine-tuning head is constructed so weight init and shuffling are reproducible.
pl.seed_everything(config['seed'], workers=True)

## Load Data and Create Datasets

In [3]:
# Load the AnnData object and compute a technology mean from this dataset.
adata = ad.read_h5ad(config['data_path'])
# Pre-computed mean from disk; kept for reference only -- the datasets below use the
# dataset-specific `technology_mean` computed from `adata`.
original_technology_mean = np.load(config['technology_mean_path'])
from nicheformer.data.dataset import compute_technology_mean, create_splits

LABEL_COL = config['label']
# Classes with fewer cells than this are removed by create_splits before stratifying.
MIN_CELLS_PER_CLASS = 3

# Compute new technology_mean
technology_mean = compute_technology_mean(adata)
print(f"Technology mean shape: {technology_mean.shape}")

# Check label distribution before filtering
print(f"\nTotal cells: {len(adata):,}")
print("Cell type counts:")
counts = adata.obs[LABEL_COL].value_counts()
print(f"Rare classes (< {MIN_CELLS_PER_CLASS} cells): {(counts < MIN_CELLS_PER_CLASS).sum()}")
print(counts.head(10))

# Create stratified train/val/test splits, dropping rare classes.
adata = create_splits(
    adata,
    train_frac=0.7,
    val_frac=0.15,
    test_frac=0.15,
    random_state=config.get('seed', 42),
    stratify_col=LABEL_COL,
    min_cells_per_class=MIN_CELLS_PER_CLASS,
)

# Drop categories emptied by the rare-class filter (a no-op if create_splits already did).
# NOTE(review): guards against label encodings based on category codes producing
# indices >= n_classes -- how NicheformerDataset encodes labels is not visible here.
if adata.is_view:
    adata = adata.copy()
labels = adata.obs[LABEL_COL]
if hasattr(labels, 'cat'):
    adata.obs[LABEL_COL] = labels.cat.remove_unused_categories()

# The classification head size follows the classes that survived filtering.
n_unique_cell_types = len(adata.obs[LABEL_COL].unique())
config['n_classes'] = n_unique_cell_types
print(f"Final n_classes: {n_unique_cell_types}")


Technology mean shape: (33502,)

Total cells: 36,715
Cell type counts:
Rare classes (< 3 cells): 1
cell_type
bladder urothelial cell            22111
fibroblast                          7227
CD8-positive, alpha-beta T cell     2284
monocyte                            1125
myofibroblast cell                  1095
macrophage                           725
CD4-positive, alpha-beta T cell      421
T cell                               412
mast cell                            310
smooth muscle cell                   271
Name: count, dtype: int64
  - endothelial cell: 1 cells
Removing 1 cells from rare classes...
Remaining cells: 36,714
Remaining cell types: 20
Created splits:
  Train: 25,699 cells (70.0%)
  Val:   5,507 cells (15.0%)
  Test:  5,508 cells (15.0%)

Split distribution by 'cell_type':
cell_type          fibroblast  T cell  mast cell  myofibroblast cell  \
nicheformer_split                                                      
test                     1084      62         47      

In [4]:
# Build one dataset and one dataloader per split from `adata` and `technology_mean`
# (both produced by the data-loading cell).

def make_dataset(split: str) -> NicheformerDataset:
    """Create a NicheformerDataset for `split` ('train', 'val' or 'test') using config settings."""
    return NicheformerDataset(
        adata=adata,
        technology_mean=technology_mean,
        split=split,
        max_seq_len=config.get('max_seq_len', 1500),
        aux_tokens=config.get('aux_tokens', 30),
        chunk_size=config.get('chunk_size', 1000),
        metadata_fields={
            'obs': [config.get('label', 'cell_type')],
        },
    )


def make_loader(dataset: NicheformerDataset, shuffle: bool) -> DataLoader:
    """Wrap `dataset` in a DataLoader; only the training loader should shuffle."""
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config.get('num_workers', 4),
        pin_memory=True,
    )


# Create datasets
train_dataset = make_dataset('train')
val_dataset = make_dataset('val')
test_dataset = make_dataset('test')

# Create dataloaders. The test loader must not shuffle: predictions are written back
# to the test rows of `adata` in iteration order.
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

Using train split with 25699 cells


100%|██████████| 26/26 [00:40<00:00,  1.55s/it]


Using val split with 5507 cells


100%|██████████| 6/6 [00:07<00:00,  1.28s/it]


Using test split with 5508 cells


100%|██████████| 6/6 [00:07<00:00,  1.29s/it]


## Load Model and Set Up Fine-tuning

In [5]:
# Restore the pre-trained Nicheformer backbone. strict=False tolerates checkpoint keys
# that do not match the module exactly instead of raising.
model = Nicheformer.load_from_checkpoint(
    strict=False,
    checkpoint_path=config['checkpoint_path'],
)
print("Model loaded")


Model loaded


In [6]:
# Wrap the pre-trained backbone with a task-specific prediction head. Every setting
# comes from `config`, including n_classes as updated by the data-loading cell.
fine_tune_model = NicheformerFineTune(
    backbone=model,
    supervised_task=config['supervised_task'],
    extract_layers=config['extract_layers'],
    function_layers=config['function_layers'],
    lr=config['lr'],
    warmup=config['warmup'],
    max_epochs=config['max_epochs'],
    dim_prediction=config['dim_prediction'],
    n_classes=config['n_classes'],
    freeze=config['freeze'],
    reinit_layers=config['reinit_layers'],
    extractor=config['extractor'],
    regress_distribution=config['regress_distribution'],
    pool=config['pool'],
    predict_density=config['predict_density'],
    ignore_zeros=config['ignore_zeros'],
    organ=config.get('organ', 'unknown'),
    label=config['label'],
    without_context=True,  # hard-coded for this run; not exposed in config
)

print("Fine-tuning model created")

Fine-tuning model created


In [7]:
# Log a few backbone hyperparameters for reference.
print("Backbone model hyperparameters:")
print(f"  modality: {model.hparams.modality}")
print(f"  assay: {model.hparams.assay}")
print(f"  specie: {model.hparams.specie}")

# Save checkpoints directly into config['output_dir']. With no `monitor`, ModelCheckpoint
# keeps the most recent epoch; save_last also writes a stable `last.ckpt`.
checkpoint_callback = ModelCheckpoint(dirpath=config['output_dir'], save_last=True)

# Configure trainer
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
    gradient_clip_val=config.get('gradient_clip_val', 1.0),
    accumulate_grad_batches=config.get('accumulate_grad_batches', 10),
    callbacks=[checkpoint_callback],
)

print("Trainer created")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Backbone model hyperparameters:
  modality: True
  assay: True
  specie: True
Trainer created


/home/ubuntu/miniforge3/envs/nicheformer_env/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
# Show GPU model and current memory usage before training (IPython shell escape).
! nvidia-smi

Tue Aug 19 22:55:52 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   48C    P0    42W / 300W |   1609MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

: 

## Train and Evaluate Model

In [None]:
def collate_predictions(batch_outputs):
    """Concatenate per-batch prediction pairs returned by `trainer.predict`.

    Each batch output is indexed as `out[0]` and `out[1]` (tensors). Returns a
    two-element list `[first, second]` of numpy arrays on the CPU.
    """
    first = torch.cat([out[0] for out in batch_outputs]).cpu().numpy()
    second = torch.cat([out[1] for out in batch_outputs]).cpu().numpy()
    return [first, second]


# Errors are intentionally not caught: a failed fit/test/predict stops the notebook here
# instead of leaving `test_results` / `predictions` undefined for the save cell.
print("Training the model...")
trainer.fit(
    model=fine_tune_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

print("Testing the model...")
test_results = trainer.test(
    model=fine_tune_model,
    dataloaders=test_loader,
)

print("Getting predictions...")
predictions = collate_predictions(trainer.predict(fine_tune_model, dataloaders=test_loader))
if 'regression' in config['supervised_task']:
    predictions = predictions[0]  # For regression both values are the same

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training the model...



  | Name        | Type             | Params | Mode 
---------------------------------------------------------
0 | backbone    | Nicheformer      | 49.3 M | train
1 | linear_head | Linear           | 10.2 K | train
2 | cls_loss    | CrossEntropyLoss | 0      | train
---------------------------------------------------------
49.3 M    Trainable params
0         Non-trainable params
49.3 M    Total params
197.249   Total estimated model params size (MB)
144       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

## Save Results

In [None]:
def store_test_values(adata, mask, key, values):
    """Write per-test-cell `values` into `adata`, leaving non-test cells as NaN.

    1-D values go to `adata.obs[key]`. Values with extra dimensions (e.g. per-class
    probabilities or a multi-dimensional regression output) cannot fit in one obs
    column, so they go to `adata.obsm[key]` as an (n_obs, ...) float array.

    Assumes `values` rows follow the order of the True entries in `mask`, i.e. the
    (unshuffled) test loader iterates cells in `adata` order -- NOTE(review): confirm
    for NicheformerDataset.

    Raises:
        ValueError: if the number of rows in `values` differs from the number of test cells.
    """
    values = np.asarray(values)
    mask_arr = np.asarray(mask, dtype=bool)
    n_test = int(mask_arr.sum())
    if values.ndim == 0 or values.shape[0] != n_test:
        raise ValueError(
            f"Expected {n_test} rows for '{key}', got shape {values.shape}; "
            "predictions and the test split are out of sync."
        )
    if values.ndim == 1:
        adata.obs.loc[mask_arr, key] = values
    else:
        full = np.full((adata.n_obs,) + values.shape[1:], np.nan, dtype=float)
        full[mask_arr] = values
        adata.obsm[key] = full


# Store predictions in AnnData object
prediction_key = f"predictions_{config.get('label', 'X_niche_1')}"
test_mask = adata.obs.nicheformer_split == 'test'

if 'classification' in config['supervised_task']:
    store_test_values(adata, test_mask, f"{prediction_key}_class", predictions[0])
    store_test_values(adata, test_mask, f"{prediction_key}_class_probs", predictions[1])
else:
    # For regression tasks `predictions` is already a single array (see the training cell).
    store_test_values(adata, test_mask, prediction_key, predictions)

# Store test metrics
for metric_name, value in test_results[0].items():
    adata.uns[f"{prediction_key}_metrics_{metric_name}"] = value

# Save updated AnnData, creating the parent directory if needed.
results_dir = os.path.dirname(config['output_path'])
if results_dir:
    os.makedirs(results_dir, exist_ok=True)
adata.write_h5ad(config['output_path'])

print(f"Results saved to {config['output_path']}")