# Model Integration Guide

How to add a new foundation model to TumorImagingBench.

## Learning Objectives

After this tutorial, you will be able to:
- Understand the BaseModel architecture
- Implement a new model extractor
- Register your model with the framework
- Test your model
- Use your model in feature extraction pipelines

## Part 1: BaseModel Architecture

### Abstract Base Class

All models in TumorImagingBench inherit from `BaseModel` and must implement three abstract methods:

```python
from abc import ABC, abstractmethod
import torch.nn as nn

class BaseModel(ABC, nn.Module):
    """Base class for all foundation model feature extractors."""

    @abstractmethod
    def load(self, weights_path: str):
        """Load model weights from file."""
        pass

    @abstractmethod
    def preprocess(self, x):
        """Preprocess input data before forward pass."""
        pass

    @abstractmethod
    def forward(self, x):
        """Forward pass of the model."""
        pass
```

### Key Requirements

1. **Inherit from BaseModel** - Your class must inherit from `BaseModel`
2. **Implement load()** - Load pre-trained weights
3. **Implement preprocess()** - Convert input dict to tensor
4. **Implement forward()** - Return features as numpy array
5. **PyTorch Module** - `BaseModel` already subclasses `nn.Module`, so call `super().__init__()` in your constructor to get GPU support via `.to(device)`
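
Python itself enforces the three abstract methods: a subclass that leaves one unimplemented cannot be instantiated. The sketch below illustrates this with `SketchBase`, a hypothetical stand-in for `BaseModel` that deliberately omits `nn.Module` so it runs without PyTorch. The same check should apply to the real `BaseModel`, but that is an assumption rather than something verified against the framework.

```python
from abc import ABC, abstractmethod


class SketchBase(ABC):
    """Hypothetical stand-in for BaseModel (no nn.Module, illustration only)."""

    @abstractmethod
    def load(self, weights_path: str): ...

    @abstractmethod
    def preprocess(self, x): ...

    @abstractmethod
    def forward(self, x): ...


class IncompleteExtractor(SketchBase):
    """Implements load() and preprocess() but forgets forward()."""

    def load(self, weights_path: str = None):
        pass

    def preprocess(self, x):
        return x


def try_instantiate(cls):
    """Return the error message if cls cannot be instantiated, else None."""
    try:
        cls()
    except TypeError as exc:
        return str(exc)
    return None


error = try_instantiate(IncompleteExtractor)
print(error)  # the message names the missing method, forward
```

Catching this at instantiation time is useful because the mistake surfaces before any weights are loaded or any scan is read.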

## Part 2: Input/Output Specifications

### Input to preprocess()

In [1]:
# Input format to preprocess() method
example_input = {
    'image_path': '/path/to/scan.nii.gz',  # Path to NIFTI file
    'coordX': 100.5,                        # X centroid in physical coordinates (mm)
    'coordY': 150.3,                        # Y centroid in physical coordinates (mm)
    'coordZ': 200.1,                        # Z centroid in physical coordinates (mm)
    # Additional fields are optional but preserved
    'label': 1,
    'patient_id': 'P001'
}

print("Input format to preprocess():")
for key, value in example_input.items():
    print(f"  {key}: {value}")

Input format to preprocess():
  image_path: /path/to/scan.nii.gz
  coordX: 100.5
  coordY: 150.3
  coordZ: 200.1
  label: 1
  patient_id: P001
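
Because `preprocess()` receives a plain dict, a malformed row only fails once the transforms try to use it. A minimal sketch of an up-front check follows; `validate_sample` and `REQUIRED_KEYS` are hypothetical names introduced here for illustration and are not part of TumorImagingBench.

```python
REQUIRED_KEYS = ("image_path", "coordX", "coordY", "coordZ")


def validate_sample(sample: dict) -> dict:
    """Check that a sample has the fields preprocess() expects (hypothetical helper)."""
    missing = [key for key in REQUIRED_KEYS if key not in sample]
    if missing:
        raise KeyError(f"Sample is missing required fields: {missing}")
    for axis in ("coordX", "coordY", "coordZ"):
        value = sample[axis]
        # bool is a subclass of int, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"{axis} must be a number in mm, got {type(value).__name__}")
    return sample


sample = validate_sample({
    "image_path": "/path/to/scan.nii.gz",
    "coordX": 100.5, "coordY": 150.3, "coordZ": 200.1,
    "label": 1,  # extra fields pass through untouched
})
print(sorted(sample))
```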


### Output of preprocess() and Input to forward()

In [2]:
import torch
import numpy as np

# Output of preprocess(): a single sample, no batch dimension
preprocessed_tensor = torch.randn(1, 48, 48, 48)
print(f"Tensor shape from preprocess(): {preprocessed_tensor.shape}")
print(f"  - Dimension 0: C = channels (usually 1 for CT)")
print(f"  - Dimensions 1-3: (H, W, D) = spatial dimensions (usually 48×48×48)")

# Input to forward(): samples stacked along a new batch dimension
batch_tensor = torch.randn(4, 1, 48, 48, 48)
print(f"\nTensor shape to forward(): {batch_tensor.shape}")
print(f"  - Dimension 0: 4 = batch size")
print(f"  - Dimensions 1-4: (C, H, W, D) = channels and spatial")

Tensor shape from preprocess(): torch.Size([1, 48, 48, 48])
  - Dimension 0: C = channels (usually 1 for CT)
  - Dimensions 1-3: (H, W, D) = spatial dimensions (usually 48×48×48)

Tensor shape to forward(): torch.Size([4, 1, 48, 48, 48])
  - Dimension 0: 4 = batch size
  - Dimensions 1-4: (C, H, W, D) = channels and spatial


### Output of forward()

In [3]:
# Output of forward() method
batch_size = 4
feature_dim = 512

features = np.random.randn(batch_size, feature_dim).astype(np.float32)
print(f"Output of forward(): {features.shape}")
print(f"  - Type: {type(features)}")
print(f"  - dtype: {features.dtype}")
print(f"  - Dimension 0: {batch_size} = batch size")
print(f"  - Dimension 1: {feature_dim} = feature dimension")
print(f"\nIMPORTANT: Must be numpy array on CPU, not torch tensor on GPU")

Output of forward(): (4, 512)
  - Type: <class 'numpy.ndarray'>
  - dtype: float32
  - Dimension 0: 4 = batch size
  - Dimension 1: 512 = feature dimension

IMPORTANT: Must be numpy array on CPU, not torch tensor on GPU


## Part 3: Complete Example - DummyResNetExtractor

Let's examine `DummyResNetExtractor`, which serves as a reference implementation:

In [4]:
import sys
# Adjust this path to the src/ directory of your local TumorImagingBench checkout
sys.path.insert(0, '/home/suraj/Repositories/TumorImagingBench/src')

from tumorimagingbench.models import get_extractor

# Get the DummyResNetExtractor class
DummyResNet = get_extractor('DummyResNetExtractor')

# Instantiate
model = DummyResNet()
print(f"Model class: {model.__class__.__name__}")
print(f"Is PyTorch module: {isinstance(model, torch.nn.Module)}")





✓ Registered extractor: CTFMExtractor
✓ Registered extractor: FMCIBExtractor
✓ Registered extractor: MerlinExtractor
✓ Registered extractor: ModelsGenExtractor
✓ Registered extractor: PASTAExtractor
✓ Registered extractor: SUPREMExtractor
✓ Registered extractor: VISTA3DExtractor
✓ Registered extractor: VocoExtractor
✓ Registered extractor: DummyResNetExtractor
Model class: DummyResNetExtractor
Is PyTorch module: True




## Part 4: Step-by-Step Tutorial - Creating Your Own Model

Let's create a simplified custom model step by step.

### Step 1: Import Required Libraries

In [5]:
import torch
import torch.nn as nn
from tumorimagingbench.models import BaseModel
from tumorimagingbench.models.utils import get_transforms

print("✓ Imports successful")

✓ Imports successful


### Step 2: Create the Model Class

In [6]:
class SimpleModelExtractor(BaseModel):
    """
    A simple model for educational purposes.
    
    This model demonstrates:
    - Inheriting from BaseModel
    - Implementing required abstract methods
    - Using MONAI preprocessing
    - Extracting fixed-size features
    """

    def __init__(self):
        """Initialize the model."""
        super().__init__()
        
        # Simple feature extractor
        # Input: (1, 48, 48, 48)
        # Output: (256,)
        self.backbone = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),
        )
        
        # Linear projection to feature space
        self.head = nn.Linear(16, 256)
        
        # Preprocessing pipeline
        self.transforms = get_transforms(
            orient="RAS",
            scale_range=(-1024, 2048),
            spatial_size=(48, 48, 48),
            spacing=(1, 1, 1),
        )
        
        self.feature_dim = 256

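    def load(self, weights_path=None):
        """Load model weights and switch to inference mode."""
        # This example ships no pre-trained weights, so weights_path is optional
        if weights_path is not None:
            state_dict = torch.load(weights_path, map_location="cpu")
            self.load_state_dict(state_dict)
        self.eval()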

    def preprocess(self, x):
        """Preprocess input."""
        return self.transforms(x)

    def forward(self, x):
        """Extract features."""
        with torch.no_grad():
            # Backbone
            features = self.backbone(x)  # (batch, 16, 1, 1, 1)
            features = features.view(features.shape[0], -1)  # (batch, 16)
            
            # Head
            features = self.head(features)  # (batch, 256)
            
            # Move to CPU and convert to numpy
            return features.cpu().numpy()

print("✓ Model class defined")

✓ Model class defined
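
The shape comments in `forward()` can be checked by hand. The sketch below walks the spatial size of a 48×48×48 patch through the backbone using the standard convolution/pooling output-size formula (dilation 1); `conv_out` is a helper introduced here for illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 48
after_conv1 = conv_out(size, kernel=3, padding=1)        # Conv3d(1, 8, 3, padding=1)
after_pool = conv_out(after_conv1, kernel=2, stride=2)   # MaxPool3d(2): stride defaults to kernel size
after_conv2 = conv_out(after_pool, kernel=3, padding=1)  # Conv3d(8, 16, 3, padding=1)
after_avgpool = 1                                        # AdaptiveAvgPool3d(1) always yields 1

print(after_conv1, after_pool, after_conv2, after_avgpool)  # 48 24 24 1
```

The final tensor is therefore `(batch, 16, 1, 1, 1)`, which flattens to `(batch, 16)` and matches the `nn.Linear(16, 256)` head.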


### Step 3: Register the Model

In [7]:
from tumorimagingbench.models import register_extractor, get_available_extractors

# Register the model
register_extractor('SimpleModelExtractor', SimpleModelExtractor)

# Verify it's registered
available = get_available_extractors()
print(f"Registered models: {available}")
assert 'SimpleModelExtractor' in available, "Model not registered!"
print("✓ Model successfully registered")

✓ Registered extractor: SimpleModelExtractor
Registered models: ['CTFMExtractor', 'FMCIBExtractor', 'MerlinExtractor', 'ModelsGenExtractor', 'PASTAExtractor', 'SUPREMExtractor', 'VISTA3DExtractor', 'VocoExtractor', 'DummyResNetExtractor', 'SimpleModelExtractor']
✓ Model successfully registered
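
Conceptually, a registry like this is a mapping from names to classes. The sketch below is a simplified, hypothetical illustration of the pattern and not TumorImagingBench's actual implementation; in particular, whether `register_extractor` rejects duplicate names is an assumption made here for the example.

```python
_REGISTRY = {}  # hypothetical module-level store: name -> class


def register(name: str, cls: type) -> None:
    """Add cls under name; refuse to silently overwrite a different class."""
    if name in _REGISTRY and _REGISTRY[name] is not cls:
        raise ValueError(f"Extractor name already in use: {name}")
    _REGISTRY[name] = cls


def get(name: str) -> type:
    """Return the class registered under name, with a helpful error if absent."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise KeyError(f"Unknown extractor {name!r}; available: {sorted(_REGISTRY)}") from None


class ToyExtractor:
    """Placeholder class standing in for a real extractor."""


register("ToyExtractor", ToyExtractor)
model_cls = get("ToyExtractor")
print(model_cls.__name__)  # ToyExtractor
```

Storing classes rather than instances keeps registration cheap: no weights are loaded until a caller actually instantiates the model.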


### Step 4: Test Your Model

In [8]:
# Instantiate your model
model = SimpleModelExtractor()
print(f"✓ Model instantiated: {model.__class__.__name__}")

# Load weights
model.load()
print("✓ Weights loaded")

# Move to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
print(f"✓ Model moved to {device}")



✓ Model instantiated: SimpleModelExtractor
✓ Weights loaded
✓ Model moved to cuda


### Step 5: Extract Features

In [9]:
# Create dummy input
batch_size = 4
dummy_input = torch.randn(batch_size, 1, 48, 48, 48, device=device)
print(f"Input shape: {dummy_input.shape}")
print(f"Input device: {dummy_input.device}")

# Extract features
features = model.forward(dummy_input)

print(f"\nOutput shape: {features.shape}")
print(f"Output type: {type(features)}")
print(f"Output dtype: {features.dtype}")
print(f"\nFeature statistics:")
print(f"  Mean: {features.mean():.6f}")
print(f"  Std: {features.std():.6f}")
print(f"  Min: {features.min():.6f}")
print(f"  Max: {features.max():.6f}")

Input shape: torch.Size([4, 1, 48, 48, 48])
Input device: cuda:0

Output shape: (4, 256)
Output type: <class 'numpy.ndarray'>
Output dtype: float32

Feature statistics:
  Mean: 0.010754
  Std: 0.229787
  Min: -0.633219
  Max: 0.545278


### Step 6: Verify Output Format

In [10]:
import numpy as np

# Verify output format
checks = [
    (isinstance(features, np.ndarray), "Output is numpy array"),
    (features.shape[0] == batch_size, f"Batch size preserved ({batch_size})"),
    (features.shape[1] == model.feature_dim, f"Feature dimension correct ({model.feature_dim})"),
    (features.dtype in [np.float32, np.float64], "Output dtype is float"),
]

print("Output format verification:")
for passed, description in checks:
    status = "✓" if passed else "✗"
    print(f"  {status} {description}")

if all(check[0] for check in checks):
    print("\n✓ All checks passed!")
else:
    print("\n✗ Some checks failed!")

Output format verification:
  ✓ Output is numpy array
  ✓ Batch size preserved (4)
  ✓ Feature dimension correct (256)
  ✓ Output dtype is float

✓ All checks passed!


## Part 5: Best Practices

### Best Practice 1: Comprehensive Documentation

In [11]:
# Example of well-documented model
class WellDocumentedModel(BaseModel):
    """
    Well-documented model extractor.
    
    Architecture:
        - Input: 3D volumetric image (1, 48, 48, 48)
        - Backbone: ResNet-50 with 3D convolutions
        - Output: 2048-dimensional feature vector
        - Feature aggregation: Global average pooling
    
    Pre-training:
        - Dataset: ImageNet-1k (with 2D-to-3D adaptation)
        - Task: Image classification
        - Weights: Publicly available from pytorch/vision
    
    Input Requirements:
        - NIFTI format (.nii.gz)
        - CT intensities: -1024 to 2048 HU
        - Physical coordinates in mm
    
    Output:
        - Feature dimension: 2048
        - Output range: Unbounded (ReLU features)
        - Suitable for: Classification, clustering, similarity
    
    Examples:
        >>> model = WellDocumentedModel()
        >>> model.load()
        >>> dummy = torch.randn(1, 1, 48, 48, 48)
        >>> features = model.forward(dummy)
        >>> print(features.shape)
        (1, 2048)
    """
    pass

print(WellDocumentedModel.__doc__)


    Well-documented model extractor.

    Architecture:
        - Input: 3D volumetric image (1, 48, 48, 48)
        - Backbone: ResNet-50 with 3D convolutions
        - Output: 2048-dimensional feature vector
        - Feature aggregation: Global average pooling

    Pre-training:
        - Dataset: ImageNet-1k (with 2D-to-3D adaptation)
        - Task: Image classification
        - Weights: Publicly available from pytorch/vision

    Input Requirements:
        - NIFTI format (.nii.gz)
        - CT intensities: -1024 to 2048 HU
        - Physical coordinates in mm

    Output:
        - Feature dimension: 2048
        - Output range: Unbounded (ReLU features)
        - Suitable for: Classification, clustering, similarity

    Examples:
        >>> model = WellDocumentedModel()
        >>> model.load()
        >>> dummy = torch.randn(1, 1, 48, 48, 48)
        >>> features = model.forward(dummy)
        >>> print(features.shape)
        (1, 2048)
    


### Best Practice 2: Use Standard Preprocessing

In [12]:
# Use get_transforms() from utils for consistency
from tumorimagingbench.models.utils import get_transforms

# This ensures consistency across all models
transforms = get_transforms(
    orient="RAS",              # Standard orientation
    scale_range=(-1024, 2048), # CT intensity range
    spatial_size=(48, 48, 48), # Standard patch size
    spacing=(1, 1, 1),         # Standard spacing
)

print("Standard preprocessing pipeline created")
print("Benefits:")
print("  - Consistency across models")
print("  - MONAI-based (standard in medical imaging)")
print("  - Handles NIFTI loading, orientation, resampling, cropping")

Standard preprocessing pipeline created
Benefits:
  - Consistency across models
  - MONAI-based (standard in medical imaging)
  - Handles NIFTI loading, orientation, resampling, cropping
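
To make the `scale_range` argument concrete, the sketch below assumes it clips intensities to the given range and maps them linearly to [0, 1], similar in spirit to MONAI's `ScaleIntensityRange` with clipping enabled. Whether `get_transforms` does exactly this is not confirmed by this guide, so treat `scale_intensity` as a hypothetical illustration.

```python
def scale_intensity(hu: float, a_min: float = -1024.0, a_max: float = 2048.0) -> float:
    """Clip a Hounsfield value to [a_min, a_max] and map it linearly to [0, 1]."""
    clipped = min(max(hu, a_min), a_max)
    return (clipped - a_min) / (a_max - a_min)


for hu in (-2000.0, -1024.0, 512.0, 2048.0, 3000.0):
    print(f"{hu:>8.1f} HU -> {scale_intensity(hu):.3f}")
```

Under this assumption, 512 HU sits exactly halfway through the range and maps to 0.5, while values outside the range saturate at 0 or 1.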




### Best Practice 3: GPU Efficiency

In [13]:
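# Good forward pass implementation
def efficient_forward(model, x, device=None):
    """
    Efficient forward pass for GPU inference.

    Assumes the model has already been moved to `device`.
    """
    if device is None:
        # Fall back to CPU when no GPU is available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with torch.no_grad():  # Disable gradients
        x = x.to(device)  # Move input to the model's device
        features = model.backbone(x)  # Process on GPU
        return features.cpu().numpy()  # Move to CPU only at end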

print("Efficient GPU usage:")
print("  - Disable gradients: with torch.no_grad()")
print("  - Keep data on GPU during computation")
print("  - Move to CPU only at the end")
print("  - Convert to numpy for downstream tasks")

Efficient GPU usage:
  - Disable gradients: with torch.no_grad()
  - Keep data on GPU during computation
  - Move to CPU only at the end
  - Convert to numpy for downstream tasks


## Part 6: Testing Your Model

### Unit Tests

In [14]:
# Test suite for your model
def test_model():
    """Comprehensive test suite for model."""
    
    # Test 1: Instantiation
    model = SimpleModelExtractor()
    assert model is not None
    print("✓ Test 1: Instantiation")
    
    # Test 2: Loading
    model.load()
    print("✓ Test 2: Weight loading")
    
    # Test 3: Forward shape
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    dummy = torch.randn(2, 1, 48, 48, 48, device=device)
    output = model.forward(dummy)
    assert output.shape == (2, model.feature_dim)
    print(f"✓ Test 3: Output shape {output.shape}")
    
    # Test 4: Output type
    assert isinstance(output, np.ndarray)
    print("✓ Test 4: Output is numpy array")
    
    # Test 5: Batch processing
    for batch_size in [1, 2, 4, 8]:
        dummy = torch.randn(batch_size, 1, 48, 48, 48, device=device)
        output = model.forward(dummy)
        assert output.shape[0] == batch_size
    print("✓ Test 5: Batch processing")
    
    print("\n✓ All tests passed!")

test_model()

✓ Test 1: Instantiation
✓ Test 2: Weight loading
✓ Test 3: Output shape (2, 256)
✓ Test 4: Output is numpy array
✓ Test 5: Batch processing

✓ All tests passed!




## Part 7: Use in Feature Extraction

In [15]:
# Now your registered model can be used in feature extraction
from tumorimagingbench.models import get_extractor

# Retrieve your registered model
ModelClass = get_extractor('SimpleModelExtractor')
print(f"Retrieved model: {ModelClass.__name__}")

# Use in feature extraction pipeline
# python nsclc_radiomics_feature_extractor.py \
#   --output features/nsclc.pkl \
#   --models SimpleModelExtractor DummyResNetExtractor

Retrieved model: SimpleModelExtractor


## Summary

You have learned how to:

1. **Understand BaseModel** - The abstract base class all models inherit from
2. **Implement a Model** - Create your own model extractor with required methods
3. **Register a Model** - Make it available in the framework
4. **Test a Model** - Verify correct input/output formats
5. **Use in Pipelines** - Integrate with feature extraction

## Next Steps

- Check [02_feature_extractor_guide.ipynb](./02_feature_extractor_guide.ipynb) to learn how to add datasets
- Review the DummyResNetExtractor: `src/tumorimagingbench/models/dummy_resnet.py`
- Explore other models in `src/tumorimagingbench/models/`