# TerraTorch: 4 Levels of Abstraction to Use TerraTorch's Model Registry in a Notebook

This notebook teaches you how to use TerraTorch models at different levels:

| Level | What You Get | Method to Call | When to Use |
|-------|--------------|----------------|-------------|
| 1. Backbone | Feature extractor only | `BACKBONE_REGISTRY.build("model_name")` | Research, custom pipelines |
| 2. Full Model | Backbone + Decoder + Head | `EncoderDecoderFactory().build_model(...)` | Inference |
| 3. Task | Model + training logic | `SemanticSegmentationTask(...)` | Custom training loops |
| 4. Task + DataModule + Trainer | Complete pipeline | `trainer.fit(task, datamodule)` | Full training runs |

**Key files to study:**
- Level 1: `terratorch/registry/registry.py`
- Level 2: `terratorch/models/encoder_decoder_factory.py`
- Level 3: `terratorch/tasks/base_task.py`
- Level 4: `terratorch/datamodules/generic_pixel_wise_data_module.py`

## Level 1: Backbone Only

A **backbone** is a neural network that extracts features from images. It doesn't make predictions; it just "sees" patterns.



In [None]:
# Level 1: Load a backbone
from terratorch.registry import BACKBONE_REGISTRY
import torch

# Build a backbone
backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_100_tl", pretrained=True)
print(f"Type: {type(backbone).__name__}")

In [None]:
# See how the backbone processes an image
fake_image = torch.randn(1, 6, 224, 224)  # [batch, channels, height, width]

backbone.eval()
with torch.no_grad():
    features = backbone(fake_image)

# Print output shapes
print("Features shape:")
for i, f in enumerate(features):
    print(f"  tensor {i}: {f.shape}")
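
To sanity-check the shapes printed above, it can help to work out how many tokens a ViT-style backbone should produce for a given input size. The sketch below is plain arithmetic, not a TerraTorch call. The helper name `vit_token_count` is hypothetical, and the patch size of 16 and the presence of a class token are assumptions to verify against the backbone's own configuration.

```python
def vit_token_count(height: int, width: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT-style encoder emits for one image (hypothetical helper)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    patches = (height // patch_size) * (width // patch_size)
    # Many ViTs prepend one class token to the patch tokens.
    return patches + (1 if cls_token else 0)


# patch_size=16 is an assumption; read the real value from the model config.
print(vit_token_count(224, 224, patch_size=16))                   # 197
print(vit_token_count(224, 224, patch_size=16, cls_token=False))  # 196
```

If the second dimension of the printed feature tensors differs from this count, the backbone likely uses a different patch size or drops the class token.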

In [None]:
# Backbone Families Summary
# Note: `_sources` is a private attribute, so this cell may break in future TerraTorch versions.
print("\n🏗️ TERRATORCH BACKBONE FAMILIES\n")
print(f"{'Family':<35} {'Count':<10} {'Example Models'}")
print("-" * 80)

for source_name, source in BACKBONE_REGISTRY._sources.items():
    try:
        models = list(source)
        examples = ", ".join(models[:3])
        if len(models) > 3:
            examples += "..."
        print(f"{source_name:<35} {len(models):<10} {examples}")
    except Exception:  # some sources cannot be enumerated up front
        print(f"{source_name:<35} {'dynamic':<10} (loaded on demand)")

In [None]:
# TerraTorch native models (geospatial-focused)
print("🛰️ TERRATORCH MODELS (Geospatial)")
print("-" * 40)
terratorch_models = list(BACKBONE_REGISTRY._sources['terratorch'])
for m in sorted(terratorch_models)[:15]:
    print(f"  • {m}")
print(f"  ... Total: {len(terratorch_models)}")

# Timm models (general vision)
print("\n📷 TIMM MODELS (General Vision)")
print("-" * 40)
timm_models = list(BACKBONE_REGISTRY._sources['timm'])
for m in sorted(timm_models)[:15]:
    print(f"  • {m}")
print(f"  ... Total: {len(timm_models)}")
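
The cells above iterate over several named sources behind a single registry. The toy below sketches that idea in plain Python so the lookup logic is visible. It is not TerraTorch's implementation: `ToySource`, `ToyMultiSourceRegistry`, and the model names are hypothetical, and the optional `<source>_` prefix handling is an illustrative assumption.

```python
from typing import Callable, Dict, Iterator


class ToySource:
    """One named source: maps model names to constructor callables."""

    def __init__(self) -> None:
        self._models: Dict[str, Callable[..., object]] = {}

    def register(self, name: str, constructor: Callable[..., object]) -> None:
        self._models[name] = constructor

    def __iter__(self) -> Iterator[str]:
        return iter(self._models)

    def __contains__(self, name: str) -> bool:
        return name in self._models

    def build(self, name: str, **kwargs) -> object:
        return self._models[name](**kwargs)


class ToyMultiSourceRegistry:
    """Tries an explicit '<source>_' prefix first, then each source in order."""

    def __init__(self) -> None:
        self._sources: Dict[str, ToySource] = {}

    def add_source(self, prefix: str, source: ToySource) -> None:
        self._sources[prefix] = source

    def build(self, name: str, **kwargs) -> object:
        # 1) A name like "vision_toy_resnet" targets the "vision" source directly.
        for prefix, source in self._sources.items():
            stripped = name[len(prefix) + 1:]
            if name.startswith(prefix + "_") and stripped in source:
                return source.build(stripped, **kwargs)
        # 2) Otherwise, the first source that knows the name wins.
        for source in self._sources.values():
            if name in source:
                return source.build(name, **kwargs)
        raise KeyError(f"model {name!r} not found in any source")


geo = ToySource()
geo.register("toy_vit", lambda pretrained=False: f"geo toy_vit (pretrained={pretrained})")
vision = ToySource()
vision.register("toy_resnet", lambda pretrained=False: f"vision toy_resnet (pretrained={pretrained})")

registry = ToyMultiSourceRegistry()
registry.add_source("geo", geo)
registry.add_source("vision", vision)

print(registry.build("toy_vit", pretrained=True))  # unprefixed lookup
print(registry.build("vision_toy_resnet"))         # prefixed lookup
print({prefix: list(source) for prefix, source in registry._sources.items()})
```

Because each source is iterable, `list(source)` yields its model names, which is the same pattern the listing cells rely on.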

## Level 2: Full Model (Backbone + Decoder + Head)

A **full model** combines:
- **Backbone**: Extracts features
- **Decoder**: Transforms features (e.g., upsamples)
- **Head**: Makes predictions

Now we can get actual outputs (e.g., segmentation masks).

In [None]:
# Level 2: Build a full model
from terratorch.models import EncoderDecoderFactory
from terratorch.datasets import HLSBands
import torch

# Create fake image for testing
fake_image = torch.randn(1, 6, 224, 224)  # [batch, channels, height, width]

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_100_tl",
    decoder="UperNetDecoder",
    backbone_bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED,
                    HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2],
    num_classes=5,
    backbone_pretrained=True,
)

In [None]:
# Use the full model
model.eval()
with torch.no_grad():
    output = model(fake_image)

print(f"Input: {fake_image.shape}")
print(f"Output: {output.output.shape}")  # [batch, num_classes, H, W]
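
The output above holds one score per class for every pixel, not a finished mask. Turning scores into a segmentation mask means picking the highest-scoring class at each pixel. The sketch below shows that step on a tiny hand-written example using nested lists; `logits_to_mask` and the toy values are hypothetical and only illustrate the logic.

```python
def logits_to_mask(logits):
    """Convert per-class scores [num_classes][H][W] into a class-id mask [H][W]."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_classes), key=lambda c: logits[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Toy scores for a 2x2 image and 3 classes.
toy_logits = [
    [[2.0, 0.1], [0.3, 0.2]],  # class 0
    [[0.5, 1.5], [0.1, 0.9]],  # class 1
    [[0.1, 0.2], [3.0, 0.4]],  # class 2
]
print(logits_to_mask(toy_logits))  # [[0, 1], [2, 1]]
```

On the real model output, the equivalent tensor operation would be `output.output.argmax(dim=1)`, which yields a `[batch, H, W]` tensor of class ids.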

## Level 3: Task Class

A **Task** wraps the model with training logic:
- Loss function
- Metrics (accuracy, mIoU, etc.)
- `training_step()`, `validation_step()`

It's a PyTorch Lightning module: ready for training, but it still needs data.

In [None]:
# Level 3: Create a Task
from terratorch.tasks import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model_args={
        "backbone": "prithvi_eo_v2_100_tl",
        "decoder": "UperNetDecoder",
        "num_classes": 5,
        "backbone_pretrained": True,
    },
    loss="ce",
    model_factory="EncoderDecoderFactory",
)

print(f"Task type: {type(task).__name__}")
print(f"Model inside: {type(task.model).__name__}")
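
The `loss="ce"` argument selects cross-entropy. As a rough illustration of what a training step computes for segmentation, the sketch below evaluates mean cross-entropy over a handful of pixels in plain Python and skips any pixel whose label equals `ignore_index` (the same argument name the Level 4 cell passes). This is not TerraTorch's code; `pixel_cross_entropy` is a hypothetical helper, and returning 0.0 when every pixel is ignored is a simplification.

```python
import math


def pixel_cross_entropy(logits, targets, ignore_index=-1):
    """Mean cross-entropy over pixels.

    logits:  list of per-pixel score lists, one score per class
    targets: list of integer class ids, one per pixel
    """
    losses = []
    for scores, label in zip(logits, targets):
        if label == ignore_index:
            continue  # ignored pixels contribute nothing to the loss
        # Numerically stable log-sum-exp of the scores.
        top = max(scores)
        log_sum_exp = top + math.log(sum(math.exp(s - top) for s in scores))
        losses.append(log_sum_exp - scores[label])
    # Simplification: 0.0 when every pixel is ignored.
    return sum(losses) / len(losses) if losses else 0.0


pixel_logits = [[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
pixel_targets = [1, 0, -1]  # the last pixel is ignored
print(pixel_cross_entropy(pixel_logits, pixel_targets))
```

A uniform prediction over three classes costs `log(3)` per pixel, while the confident and correct second pixel costs very little, so the mean lands well below `log(3)`.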

## Level 4: Task + DataModule + Trainer

A **DataModule** handles data loading (train/val/test splits, transforms, batching).

A **Trainer** runs the training loop.

Together: complete pipeline.

In [None]:
# Level 4: Working example with fake data (plain DataLoaders stand in for a DataModule)
from torch.utils.data import Dataset, DataLoader
from lightning import Trainer
import torch

# Create a fake dataset
class FakeSegmentationDataset(Dataset):
    def __init__(self, num_samples=4):
        self.num_samples = num_samples
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Fake image: 6 channels, 224x224
        image = torch.randn(6, 224, 224)
        # Fake mask: class labels 0-4 for each pixel
        mask = torch.randint(0, 5, (224, 224))
        return {"image": image, "mask": mask}

# Create dataloaders
train_dataset = FakeSegmentationDataset(num_samples=4)
train_loader = DataLoader(train_dataset, batch_size=2)

val_dataset = FakeSegmentationDataset(num_samples=2)
val_loader = DataLoader(val_dataset, batch_size=2)

# Create a fresh task
from terratorch.tasks import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model_args={
        "backbone": "prithvi_eo_v2_100_tl",
        "decoder": "UperNetDecoder",
        "num_classes": 5,
        "backbone_pretrained": True,
    },
    loss="ce",
    ignore_index=-1,  # Required: index to ignore in loss (-1 = none ignored)
    model_factory="EncoderDecoderFactory",
)

# Create trainer and run 1 epoch
# Note: Using CPU because MPS has issues with adaptive pooling on certain sizes
trainer = Trainer(
    max_epochs=1,
    accelerator="cpu",  # Force CPU to avoid MPS pooling issues
    enable_progress_bar=True,
    enable_model_summary=False,
    logger=False,  # Disable logging for this demo
)

print("🚀 Starting training for 1 epoch...")
trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("✅ Training complete!")
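
To interpret the progress bar from the run above, it helps to know how many batches each loader yields per epoch. The sketch below is plain arithmetic with a hypothetical helper name; it assumes the loaders keep a final partial batch, which matches the PyTorch `DataLoader` default of `drop_last=False`.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader yields in one pass over the dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(4, 2))  # training loader above: 2 batches
print(batches_per_epoch(2, 2))  # validation loader above: 1 batch
```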

## Troubleshooting: What if a model URL changes?

Sometimes HuggingFace or timm model URLs change (e.g., repos get renamed or moved).

**Example:** Clay v1 models moved from `made-with-clay/Clay` to `made-with-clay/Clay-legacy`.

Here's how to debug and find where URLs are configured:

In [None]:
# Example: Loading a Clay model
from terratorch.registry import BACKBONE_REGISTRY

# Try loading Clay v1 backbone
try:
    clay_backbone = BACKBONE_REGISTRY.build("clay_v1_base", pretrained=True)
    print("✅ Clay v1 loaded successfully!")
    print(f"Type: {type(clay_backbone).__name__}")
except Exception as e:
    print(f"❌ Error loading Clay: {e}")
    print("\n🔧 If you see a 404 or connection error, the HuggingFace URL may have changed.")

In [None]:
# How to find and fix broken model URLs
# 
# Step 1: Find where the model is defined
#   - Search the codebase for the model name
#   - For Clay: terratorch/models/backbones/clay_v1/embedder.py
#
# Step 2: Look for `default_cfgs` or `hf_hub_id`
#   - This is where HuggingFace URLs are configured

import inspect
from terratorch.models.backbones.clay_v1 import embedder

# Find the file location
print("📁 FILE TO EDIT:")
print(f"   {inspect.getfile(embedder)}")
print()

# Show the current HuggingFace configuration
print("🔗 CURRENT HUGGINGFACE CONFIG:")
if hasattr(embedder, 'default_cfgs'):
    for model_name, cfg in embedder.default_cfgs.items():
        hf_url = getattr(cfg.default, 'hf_hub_id', 'Not set')
        print(f"   {model_name}: {hf_url}")
else:
    print("   (Check the file manually for hf_hub_id or url settings)")
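
When a module does not expose `default_cfgs`, the manual check can be partly automated by scanning the source text for `hf_hub_id` entries. The sketch below is one possible approach using a regular expression. The helper name `find_hf_hub_ids` and the sample source string are hypothetical, and the pattern only catches simple string literals, so treat the result as a starting point.

```python
import re

# Matches both keyword style (hf_hub_id="org/repo") and dict style ("hf_hub_id": "org/repo").
HF_ID_PATTERN = re.compile(r"""["']?hf_hub_id["']?\s*[=:]\s*["']([^"']+)["']""")


def find_hf_hub_ids(source_text: str) -> list:
    """Return every hub id assigned to hf_hub_id in the given source text."""
    return HF_ID_PATTERN.findall(source_text)


# Hypothetical file contents, for illustration only.
sample_source = '''
default_cfgs = generate_default_cfgs({
    "example_model": _cfg(hf_hub_id="example-org/example-model"),
    "other_model": {"hf_hub_id": "example-org/other-model"},
})
'''
print(find_hf_hub_ids(sample_source))
```

To scan a real module, read its file first, for example with `pathlib.Path(inspect.getfile(embedder)).read_text()`, and pass the text to `find_hf_hub_ids`.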

### Quick Reference: Where to fix model URLs

| Model Family | Config File | Look For |
|--------------|-------------|----------|
| **Clay v1** | `terratorch/models/backbones/clay_v1/embedder.py` | `default_cfgs`, `hf_hub_id` |
| **Prithvi** | `terratorch/models/backbones/prithvi_vit.py` | `default_cfgs`, `hf_hub_id` |
| **ScaleMAE** | `terratorch/models/backbones/scalemae/scalemae.py` | `default_cfgs` |
| **timm models** | timm library (external) | Update timm version |

