# SAM 3 CoralScapes Fine-tuning Notebook

This notebook fine-tunes the SAM 3 image model for semantic segmentation on the CoralScapes dataset, reusing the dataloader and the combined cross-entropy (CE) + Dice loss from the reference training snippet. Update the paths and hyperparameters in the **Config** cell before running it in your environment (e.g., on Perlmutter under `pscratch`).


In [None]:
# Optional: install extra dependencies for notebooks/training (uncomment if needed)
# !pip install -e "./[notebooks,train]"  # from the repo root
# !pip install segmentation-models-pytorch==0.3.3  # only if you want the reference model for comparison


In [None]:
import os
from pathlib import Path
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm.auto import tqdm
from torchmetrics.classification import Accuracy, JaccardIndex

from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor


In [None]:
# ----------------------------
# Config
# ----------------------------
# Root where your CoralScapes data lives
root = "/pscratch/sd/k/kevinval/coralscapes"  # change if different

# Label space and optimisation hyper-parameters
NUM_CLASSES = 40
IGNORE_INDEX = 0  # label value skipped by the loss and the validation metrics
BATCH_SIZE = 2  # adjust for your GPU memory
NUM_EPOCHS = 10
LR = 1e-4

# Paths for SAM 3 (requires HF auth + checkpoint access)
# If you already set HF_TOKEN and have access, the defaults work.
bpe_path = None  # optional path to BPE file; keep None to use default shipped with the package
sam3_checkpoint = None  # optional local checkpoint path; keep None to use the default

# Reproducibility / device selection
seed = 42
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed the CPU generator unconditionally and all CUDA devices when present.
torch.manual_seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)


In [None]:
# ----------------------------
# Dataset (matches the reference snippet)
# ----------------------------
class CoralScapesTiled(Dataset):
    """Tiled CoralScapes split stored as ``<root>/<split>/{images,labels}/``.

    Every image file (``.png``/``.jpg``/``.jpeg``, case-insensitive) must have a
    mask with the identical file name in the ``labels`` directory. This is
    checked eagerly in ``__init__``.

    Args:
        root: dataset root directory.
        split: split sub-directory, e.g. ``"train"`` or ``"validation"``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries
            (the albumentations calling convention).

    Raises:
        FileNotFoundError: if any image has no matching mask file.

    Each item is ``(image, mask)``. ``mask`` is always a ``torch.long`` tensor.
    ``image`` is whatever the transform returns, or the raw RGB ``numpy`` array
    when no transform is given.
    """

    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root, split="train", transform=None):
        self.root = root
        self.split = split
        self.transform = transform

        self.images_dir = os.path.join(root, split, "images")
        self.masks_dir = os.path.join(root, split, "labels")

        self.image_files = sorted(
            f for f in os.listdir(self.images_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        # Fail fast on the first image without a same-named mask.
        for f in self.image_files:
            if not os.path.exists(os.path.join(self.masks_dir, f)):
                raise FileNotFoundError(f"Mask not found for {f}")

        print(f"[{split}] Loaded {len(self.image_files)} tiles")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]

        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, img_name)

        # Context managers release the file handles deterministically. This
        # matters with many DataLoader workers reading thousands of tiles.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        # NOTE(review): the mask is read as stored (no mode conversion); this
        # assumes single-channel index masks — confirm for your label files.
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Loss/metrics expect integer class indices.
        if isinstance(mask, np.ndarray):
            mask = torch.as_tensor(mask, dtype=torch.long)
        else:
            mask = mask.long()

        return image, mask


In [None]:
# ----------------------------
# Augmentations
# ----------------------------
# We keep the geometric/color augmentations from the reference script.
# The normalization matches the Sam3Processor defaults (mean=0.5, std=0.5).
# Training pipeline: random geometric + photometric augmentation, then the
# shared normalization and HWC-array -> CHW-tensor conversion.
train_transform = A.Compose(
    [
        # geometric
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.1,
            rotate_limit=10,
            border_mode=0,  # 0 == cv2.BORDER_CONSTANT (constant padding)
            p=0.5,
        ),
        # photometric
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.GaussianBlur(blur_limit=3, p=0.2),
        # normalization + tensor conversion
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
)

# Validation pipeline: no augmentation, only the same normalization/conversion.
val_transform = A.Compose(
    [
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
)


In [None]:
# ----------------------------
# Datasets & DataLoaders
# ----------------------------
import random

train_ds = CoralScapesTiled(root, split="train", transform=train_transform)
val_ds = CoralScapesTiled(root, split="validation", transform=val_transform)


def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker process.

    PyTorch seeds each worker's torch RNG with ``base_seed + worker_id``. The
    ``base_seed`` is redrawn from the loader's generator every time the loader
    is iterated, i.e. every epoch. Deriving the NumPy / ``random`` seeds from
    ``torch.initial_seed()`` keeps runs reproducible while still giving every
    epoch fresh augmentations. A constant ``seed + worker_id`` would instead
    restart every worker on the same RNG stream at the start of each epoch.

    Args:
        worker_id: index of the worker. It is unused here because PyTorch
            already folds it into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy requires a 32-bit seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Dedicated generator, so shuffling and per-epoch worker base seeds are
# reproducible from `seed` regardless of other users of the global torch RNG.
loader_generator = torch.Generator()
loader_generator.manual_seed(seed)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=(device == "cuda"),  # pinned memory only helps host->GPU copies
    worker_init_fn=seed_worker,
    generator=loader_generator,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=(device == "cuda"),
    worker_init_fn=seed_worker,
    generator=loader_generator,
)


In [None]:
# ----------------------------
# Build SAM 3 model + lightweight segmentation head
# ----------------------------
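# We reuse SAM 3 as a strong visual backbone and add a per-pixel classifier on top
# of one of its FPN feature maps (the last entry of `backbone_fpn`; NOTE(review):
# confirm that level's resolution — SAM-2-style necks usually order this list from
# highest to lowest resolution). All parameters are trainable by default; to train
# only the head, uncomment the freezing loop that follows the model construction.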

# Pretrained SAM 3 image model. `None` paths fall back to the package defaults
# (see the Config cell); the model is then moved to the selected device.
sam3_model = build_sam3_image_model(bpe_path=bpe_path, ckpt_path=sam3_checkpoint)
sam3_model = sam3_model.to(device)

# NOTE(review): `processor` is not used by the training/validation code below.
processor = Sam3Processor(sam3_model, device=device)

class Sam3SemanticSegmentation(nn.Module):
    """Per-pixel classifier on top of one SAM 3 backbone FPN feature map.

    Args:
        sam3_model: SAM 3 image model. Only ``sam3_model.backbone.forward_image``
            is called. It must return a dict whose ``"backbone_fpn"`` entry is a
            list of ``(B, C, h, w)`` feature maps.
        num_classes: number of output channels (semantic classes).
        in_channels: channel count ``C`` of the selected FPN level
            (defaults to 256).
        feature_level: index into the ``"backbone_fpn"`` list (defaults to
            ``-1``, the last entry). NOTE(review): which resolution each level
            has depends on the SAM 3 neck. SAM-2-style necks order the list from
            highest to lowest resolution, which would make ``-1`` the coarsest
            map, so confirm before relying on it.
    """

    def __init__(self, sam3_model, num_classes, in_channels=256, feature_level=-1):
        super().__init__()
        self.sam3 = sam3_model
        # Plain attribute, not a buffer: checkpoint state_dict keys are unchanged.
        self.feature_level = feature_level
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.GroupNorm(8, 256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, images):
        """Map normalized ``(B, 3, H, W)`` images to ``(B, num_classes, H, W)`` logits."""
        backbone_out = self.sam3.backbone.forward_image(images)
        feat = backbone_out["backbone_fpn"][self.feature_level]
        logits = self.classifier(feat)
        # Resize to the input resolution so logits align pixel-for-pixel with masks.
        return F.interpolate(
            logits, size=images.shape[-2:], mode="bilinear", align_corners=False
        )

# Wrap the SAM 3 model with the segmentation head and place it on `device`.
model = Sam3SemanticSegmentation(sam3_model, NUM_CLASSES)
model = model.to(device)

# To freeze the SAM 3 backbone and train only the head, uncomment the two lines below
# for p in model.sam3.parameters():
#     p.requires_grad = False


In [None]:
# ----------------------------
# Loss: CE + Dice (same structure as the reference)
# ----------------------------
class CombinedCELossDiceLoss(nn.Module):
    """Weighted sum of cross-entropy and multi-class soft Dice loss.

    Pixels whose target equals ``ignore_index`` are excluded from both terms.
    The Dice score is computed per class with the whole batch pooled, then
    averaged over all ``C`` logit channels.

    Args:
        weight_ce: multiplier for the cross-entropy term.
        weight_dice: multiplier for the Dice term.
        ignore_index: target value to ignore. It may lie outside ``[0, C)``
            (e.g. ``-100`` or ``255``).
        eps: smoothing constant added to the Dice numerator and denominator.

    ``forward`` takes logits ``(B, C, H, W)`` and integer targets ``(B, H, W)``
    and returns a scalar tensor.
    """

    def __init__(self, weight_ce=0.5, weight_dice=0.5, ignore_index=0, eps=1e-6):
        super().__init__()
        self.weight_ce = weight_ce
        self.weight_dice = weight_dice
        self.ignore_index = ignore_index
        self.eps = eps
        # Kept as a public attribute for existing callers. `forward` computes the
        # same mean through a sum so that a fully-ignored batch yields 0, not NaN.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, targets):
        """Return the combined CE + Dice loss of ``logits`` against ``targets``."""
        valid = targets != self.ignore_index  # (B, H, W) bool

        # Mean CE over valid pixels. Mean reduction divides by the number of valid
        # pixels and returns NaN when there are none. Clamping the divisor keeps the
        # loss finite (and 0, still attached to the graph) for such batches.
        ce_sum = F.cross_entropy(
            logits, targets, ignore_index=self.ignore_index, reduction="sum"
        )
        ce = ce_sum / valid.sum().clamp(min=1)

        num_classes = logits.shape[1]
        # Map ignored pixels to class 0 so one_hot accepts any ignore_index; those
        # pixels are zeroed out by valid_mask below, so the choice does not matter.
        safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
        targets_one_hot = F.one_hot(safe_targets, num_classes=num_classes)
        targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()

        probs = F.softmax(logits, dim=1)
        valid_mask = valid.unsqueeze(1).float()
        probs = probs * valid_mask
        targets_one_hot = targets_one_hot * valid_mask

        intersection = (probs * targets_one_hot).sum(dim=(0, 2, 3))
        union = probs.sum(dim=(0, 2, 3)) + targets_one_hot.sum(dim=(0, 2, 3))
        dice_loss = 1.0 - ((2.0 * intersection + self.eps) / (union + self.eps)).mean()

        return self.weight_ce * ce + self.weight_dice * dice_loss

# Dice-heavy mix: 30% cross-entropy + 70% Dice, skipping IGNORE_INDEX pixels.
loss_fn = CombinedCELossDiceLoss(
    weight_ce=0.3,
    weight_dice=0.7,
    ignore_index=IGNORE_INDEX,
)


In [None]:
# ----------------------------
# Optimizer, Scheduler, Metrics
# ----------------------------
# AdamW over all model parameters. The cosine schedule is stepped once per
# epoch by the training loop, so it anneals over NUM_EPOCHS steps.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Gradient scaling for mixed precision; with enabled=False it is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

# Validation metrics; both skip pixels labelled IGNORE_INDEX.
accuracy_metric = Accuracy(task="multiclass", num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX).to(device)
miou_metric = JaccardIndex(task="multiclass", num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX).to(device)

# Best validation mIoU seen so far; a checkpoint is written whenever it improves.
best_miou = 0.0


In [None]:
# ----------------------------
# Training Loop
# ----------------------------
for epoch in range(NUM_EPOCHS):
    # ---- train ----
    model.train()
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    running_loss = 0.0
    # Count the samples actually seen. tqdm's `n` excludes the current batch (it
    # is 0 on the first one) and is only refreshed periodically. The last batch
    # may also be smaller than BATCH_SIZE.
    samples_seen = 0

    for images, masks in train_bar:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True).long()

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            logits = model(images)
            loss = loss_fn(logits, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n_batch = images.size(0)
        running_loss += loss.item() * n_batch
        samples_seen += n_batch
        avg_loss = running_loss / samples_seen
        train_bar.set_postfix(loss=avg_loss)

    scheduler.step()

    # ---- validate ----
    model.eval()
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True).long()

            with torch.cuda.amp.autocast(enabled=(device == "cuda")):
                logits = model(images)

            preds = logits.argmax(dim=1)

            accuracy_metric.update(preds.view(-1), masks.view(-1))
            miou_metric.update(preds, masks)

    val_accuracy = accuracy_metric.compute().item()
    val_miou = miou_metric.compute().item()

    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} | Val Accuracy: {val_accuracy:.4f} | Val mIoU: {val_miou:.4f}")

    # Metrics accumulate across update() calls, so clear them for the next epoch.
    accuracy_metric.reset()
    miou_metric.reset()

    if val_miou > best_miou:
        best_miou = val_miou
        torch.save(model.state_dict(), "best_sam3_coralscapes_semantic.pth")
        print(f"  -> Saved best checkpoint with mIoU: {best_miou:.4f}\n")
