# MinimalSAM Training

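This notebook trains one of the MinimalSAM-family segmentation models (`MicroSAM`, `PicoSAM` or `MinimalSAM`) on filtered COCO annotations and logs metrics and checkpoints to Weights & Biases.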

## Setup Environment

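⚠️ Set `using_colab = True` in the cell below when running on Google Colab. The Colab branch clones the repositories, downloads the checkpoints and the dataset, and reads the filtered annotations from Google Drive.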

In [None]:
import importlib
import os
import subprocess
import sys
from pathlib import Path

import torch
import torchvision

using_colab = False


def _pip_install(spec: str) -> None:
    """Install ``spec`` with pip into the interpreter that runs this kernel.

    ``sys.executable -m pip`` guarantees the package lands in the kernel's
    environment. A bare ``pip`` on PATH may belong to another Python.
    Raises ``subprocess.CalledProcessError`` if pip fails.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", spec], check=True)
    # Let later cells in this same kernel see the freshly installed package.
    importlib.invalidate_caches()


# check if minimal_sam and sam2 are installed (only a missing/broken import triggers an install)
try:
    from sam2.build_sam import build_sam2
    print("SAM2 is already installed.")
except ImportError:
    _pip_install("git+https://github.com/facebookresearch/sam2.git")
try:
    from minimal_sam.models import MinimalSAM
    print("MinimalSAM is already installed.")
except ImportError:
    _pip_install("git+https://github.com/laggezhd/giga-sam.git")

# setup environment
if using_colab:
    for repo_url, repo_dir in (
        ("https://github.com/facebookresearch/sam2.git", "sam2"),
        ("https://github.com/laggezhd/giga-sam.git", "giga-sam"),
    ):
        # `git clone` fails on an existing directory, so skip it when the cell is re-run.
        if not Path(repo_dir).is_dir():
            subprocess.run(["git", "clone", repo_url], check=True)
    subprocess.run(["./download_ckpts.sh"], cwd="giga-sam/checkpoints", check=True)
    subprocess.run(["./download_dataset.sh"], cwd="giga-sam/dataset", check=True)

    os.environ["BASE_PATH"] = os.getcwd()
    os.environ["ANNS_PATH"] = "/content/drive/MyDrive/dataset/"
else:
    # Locally the notebook is expected to run one directory below BASE_PATH.
    os.environ["BASE_PATH"] = os.path.dirname(os.getcwd())
    os.environ["ANNS_PATH"] = os.path.join(os.environ["BASE_PATH"], "configs/annotations")

print(f"BASE_PATH: {os.environ['BASE_PATH']}")
print(f"ANNS_PATH: {os.environ['ANNS_PATH']}")
print(f"PyTorch version:     {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"CUDA is available:   {torch.cuda.is_available()}")

SAM2 is already installed.
BASE_PATH: /home/sschwartz/Dokumente/ml_mcu_hs_25/giga-sam
PyTorch version:     2.9.1+cu128
Torchvision version: 0.24.1+cu128
CUDA is available:   False


In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
from pycocotools.coco import COCO
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from minimal_sam.models import MicroSAM, MinimalSAM, PicoSAM
from minimal_sam.utils.data import MinimalSamDataset
from minimal_sam.utils.loss import bce_dice_loss
from minimal_sam.utils.metrics import compute_iou

The next cell requires manual authentication for both Google Drive and Weights & Biases (wandb). Mounting Google Drive only works on Google Colab.

In [None]:
# Google Drive is only available on Colab, where ANNS_PATH points into /content/drive.
if using_colab:
    from google.colab import drive
    drive.mount('/content/drive')

# Prompts for an API key only if this machine is not already logged in to wandb.
wandb.login()

## Training Loop

In [None]:
MODEL = "MicroSAM"  # choose between "MicroSAM", "PicoSAM" and "MinimalSAM"
IMG_SIZE = 96

# === Files and Directories ===
BASE_DIR = Path(os.environ["BASE_PATH"])
# On Colab the repo is cloned into BASE_PATH/giga-sam. Locally BASE_PATH is the parent of the
# working directory, which is presumably the repo root itself (see the printed BASE_PATH).
# Use the nested clone only when it actually exists, so both layouts resolve to <repo>/dataset.
REPO_DIR = BASE_DIR / "giga-sam" if (BASE_DIR / "giga-sam").is_dir() else BASE_DIR

COCO_ANN_FILE = REPO_DIR / "dataset/annotations/instances_train2017.json"
COCO_FILTERED = Path(os.environ["ANNS_PATH"]) / f"filtered_anns_{IMG_SIZE}x{IMG_SIZE}.json"

OUTPUT_DIR = BASE_DIR / "outputs"
DATASET_DIR = REPO_DIR / "dataset"

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
assert DATASET_DIR.is_dir(), f"Dataset directory not found at {DATASET_DIR}"
assert COCO_ANN_FILE.exists(), f"COCO annotation file not found at {COCO_ANN_FILE}"
assert COCO_FILTERED.exists(), f"Filtered COCO annotations not found at {COCO_FILTERED}"

# === Training Hyperparameters ===
# DataLoader worker processes speed up training, but never request more workers than CPUs.
NUM_WORKERS = min(12, os.cpu_count() or 1)
NUM_EPOCHS = 30
BATCH_SIZE = 64
# NOTE: OneCycleLR overwrites the optimizer's lr and starts at MAX_LR / div_factor (25 by default),
# so LEARNING_RATE does not shape the schedule; it is only passed to AdamW and logged to wandb.
LEARNING_RATE = 3e-4
MAX_LR = 1e-3
WEIGHT_DECAY = 1e-4

def train():

    # Set training device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Initialize wandb
    wandb.init(
        project="MLonMCU",
        name=MODEL,
        config={
            "img_size": IMG_SIZE,
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "max_lr": MAX_LR,
            "weight_decay": WEIGHT_DECAY
        }
    )

    # Initialize model
    match MODEL:
        case "MicroSAM":
            model = MicroSAM().to(device)
        case "PicoSAM":
            model = PicoSAM().to(device)
        case "MinimalSAM":
            model = MinimalSAM().to(device)
        case _:
            raise ValueError(f"Unknown model type: {MODEL}! Choose between 'MicroSAM', 'PicoSAM' and 'MinimalSAM'.")

    # Initialize dataset and dataloaders
    dataset = MinimalSamDataset(IMG_SIZE, DATASET_DIR, COCO_ANN_FILE, COCO_FILTERED)

    train_len = int(len(dataset) * 0.95)
    train_ds, val_ds = random_split(dataset, [train_len, len(dataset) - train_len])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    # Initialize optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=MAX_LR, steps_per_epoch=len(train_loader), epochs=NUM_EPOCHS)

    scaler = torch.amp.GradScaler("cuda")

    best_val_iou = 0.0

    # === Training Loop ===
    for epoch in range(NUM_EPOCHS):
        model.train()

        total_loss, total_iou, samples = 0, 0, 0  # logging params

        for batch_idx, (images, masks) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} - Train")):
            images, masks = images.to(device), masks.to(device)

            # forward pass
            with torch.amp.autocast("cuda"):
                logits = model(images)
                loss = bce_dice_loss(logits, masks)

            # backpropagation
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            
            scaler.unscale_(optimizer)  # unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)  # gradient clipping

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # logging
            batch_iou = compute_iou(logits, masks)

            total_loss += loss.item() * images.size(0)
            total_iou += batch_iou * images.size(0)
            samples += images.size(0)

            wandb.log({
                "batch_loss": loss.item(),
                "batch_mIoU": batch_iou,
                "epoch": epoch + 1
            })

        wandb.log({
            "train_loss": total_loss / samples,
            "train_mIoU": total_iou / samples,
            "lr": scheduler.get_last_lr()[0],
            "epoch": epoch + 1
        })

        # === Validation Loop ===
        model.eval()
        val_loss, val_iou, val_samples = 0, 0, 0
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch + 1} - Val"):
                images, masks = images.to(device), masks.to(device)

                logits = model(images)
                loss = bce_dice_loss(logits, masks)

                val_loss += loss.item() * images.size(0)
                val_iou += compute_iou(logits, masks) * images.size(0)
                val_samples += images.size(0)

        wandb.log({
            "val_loss": val_loss / val_samples,
            "val_mIoU": val_iou / val_samples,
            "epoch": epoch + 1
        })

        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"{MODEL}_epoch{epoch + 1}.pt"))

        # Save best model
        avg_val_iou = val_iou / val_samples

        if avg_val_iou > best_val_iou:
            best_val_iou = avg_val_iou
            best_model_path = os.path.join(OUTPUT_DIR, f"{MODEL}_best.pt")
            torch.save(model.state_dict(), best_model_path)
            wandb.save(best_model_path)
            print(f"New best model saved with mIoU: {best_val_iou:.4f}")

    print("Training completed.")

In [None]:
# Launch training with the configuration from the cell above (logs to wandb, writes checkpoints to OUTPUT_DIR).
train()