<a href="https://colab.research.google.com/github/bdtranter/fine-tune-efficientViT/blob/main/Tuned_EfficientViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# cityscapes_for_sam.py
"""
Cityscapes wrapper that produces:
    - image tensor
    - one binary mask
    - one positive point prompt
for training a SAM-style model.
"""

from __future__ import annotations

import random
from typing import Dict, Any, Optional, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets import Cityscapes
import torchvision.transforms as T


class CityscapesForSAM(Dataset):
    """
    Wrap torchvision's ``Cityscapes`` (fine annotations, semantic target) so
    that every item is a single-point SAM training example.

    Each item is a dict with:
        image:        (3,H,W) float in [0,1]
        gt_mask:      (H,W)   float {0,1}, the region of one randomly chosen class
        point_coords: (1,2)   float [[x, y]] in pixel coords, inside gt_mask
        point_labels: (1,)    long  [1]  (positive)
    """

    # Semantic label ids that are treated as background / ignore and are
    # therefore never picked as the target region.
    ignore_labels: Tuple[int, ...] = (0, 255)

    # Upper bound on how many images __getitem__ tries before giving up.
    max_resample_attempts: int = 10

    def __init__(self, root: str, split: str = "train"):
        super().__init__()
        self.base = Cityscapes(
            root=root,
            split=split,
            mode="fine",
            target_type="semantic",
        )
        self.img_transform = T.ToTensor()
        # The semantic target arrives as a PIL image; PILToTensor keeps the
        # integer label ids as-is (no rescaling to [0,1] like ToTensor does).
        self.mask_transform = T.PILToTensor()

    def __len__(self) -> int:
        return len(self.base)

    def _mask_to_tensor(self, sem: Any) -> Tensor:
        """
        Convert the semantic target to an (H,W) long tensor.

        ``torch.as_tensor`` cannot consume a PIL image directly, so PIL inputs
        go through ``PILToTensor``; tensors are passed through unchanged.
        """
        if not isinstance(sem, Tensor):
            sem = self.mask_transform(sem)  # (1,H,W) for a single-channel image
        if sem.dim() == 3:
            sem = sem.squeeze(0)
        return sem.to(torch.long)

    def _sample_prompt_from_mask(
        self, sem_mask: Tensor
    ) -> Optional[Tuple[Tensor, Tensor, Tensor]]:
        """
        Pick one non-ignored class at random and one random pixel inside it.

        Args:
            sem_mask: (H,W) long semantic labels.

        Returns:
            binary_mask:  (H,W) float {0,1}, 1 where sem_mask equals the chosen class
            point_coords: (1,2) float [[x, y]]
            point_labels: (1,)  long [1]
        or None if the mask holds no class outside ``ignore_labels``.
        """
        classes = torch.unique(sem_mask)
        keep = torch.ones_like(classes, dtype=torch.bool)
        for label in self.ignore_labels:
            keep &= classes != label
        classes = classes[keep]
        if classes.numel() == 0:
            return None

        # `classes` only holds values that occur in sem_mask, so the selected
        # region always contains at least one pixel.
        cls = classes[torch.randint(classes.numel(), (1,))]
        region = sem_mask == cls
        ys, xs = region.nonzero(as_tuple=True)

        pick = int(torch.randint(xs.numel(), (1,)))
        x, y = float(xs[pick]), float(ys[pick])

        binary_mask = region.float()
        point_coords = torch.tensor([[x, y]], dtype=torch.float32)
        point_labels = torch.tensor([1], dtype=torch.int64)
        return binary_mask, point_coords, point_labels

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Return the training example for image ``idx``.

        If no prompt can be sampled from that image, a different random image
        is tried instead, up to ``max_resample_attempts`` images in total.

        Raises:
            RuntimeError: if none of the attempted images yields a prompt.
        """
        index = idx
        for _ in range(self.max_resample_attempts):
            img, sem = self.base[index]
            sample = self._sample_prompt_from_mask(self._mask_to_tensor(sem))
            if sample is not None:
                binary_mask, point_coords, point_labels = sample
                return {
                    "image": self.img_transform(img),  # (3,H,W)
                    "gt_mask": binary_mask,            # (H,W)
                    "point_coords": point_coords,      # (1,2)
                    "point_labels": point_labels,      # (1,)
                }
            # Nothing usable in this image: fall back to a random one.
            index = random.randrange(len(self))

        raise RuntimeError(
            f"Could not sample a prompt after {self.max_resample_attempts} images "
            f"(started at index {idx})."
        )


def cityscapes_collate_fn(batch: list[Dict[str, Any]]) -> Dict[str, Tensor]:
    """
    Collate per-sample dicts into one batched dict for SAM-style training.

    Every field is stacked along a new leading batch dimension, so all
    samples in the batch must share the same tensor shapes:
        image        -> (B,3,H,W)
        gt_mask      -> (B,H,W)
        point_coords -> (B,1,2)
        point_labels -> (B,1)
    """
    field_names = ("image", "gt_mask", "point_coords", "point_labels")
    return {
        name: torch.stack([sample[name] for sample in batch], dim=0)
        for name in field_names
    }


In [None]:
# losses_sam.py
"""
Loss functions for SAM-style training:
- binary dice loss
- binary focal loss
- combined loss for multiple masks with best-of-N strategy.
"""

from __future__ import annotations

import torch
from torch import Tensor
import torch.nn.functional as F


def dice_loss(pred_probs: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
    """
    Soft Dice loss, averaged over everything left after the spatial reduction.

    Args:
        pred_probs: (B,1,H,W) predicted probabilities in [0,1]
        target:     (B,1,H,W) ground truth in [0,1]
        eps:        smoothing term added to numerator and denominator, so an
                    empty prediction against an empty target scores Dice = 1.

    Returns:
        Scalar tensor: 1 - mean Dice coefficient.
    """
    spatial_dims = (2, 3)
    overlap = torch.sum(pred_probs * target, dim=spatial_dims)
    # Sum of the two areas (not a set union) -- the standard Dice denominator.
    total_area = torch.sum(pred_probs, dim=spatial_dims) + torch.sum(target, dim=spatial_dims)
    dice_score = (2.0 * overlap + eps) / (total_area + eps)
    return 1.0 - torch.mean(dice_score)


def focal_loss(
    pred_logits: Tensor,
    target: Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    eps: float = 1e-6,
) -> Tensor:
    """
    Binary focal loss, averaged over every element.

    Args:
        pred_logits: (B,1,H,W) raw logits
        target:      (B,1,H,W) in {0,1}
        alpha:       weight given to positive pixels (negatives get 1 - alpha)
        gamma:       focusing exponent applied to (1 - p_t)
        eps:         added inside the log to keep it finite

    Returns:
        Scalar tensor with the mean focal loss.
    """
    probs = pred_logits.sigmoid()
    background = 1 - target

    # Probability assigned to the true class of each pixel.
    p_true = probs * target + (1 - probs) * background
    class_weight = alpha * target + (1 - alpha) * background

    # Down-weights pixels that are already classified confidently.
    modulator = (1 - p_true) ** gamma
    per_pixel = -(class_weight * modulator) * (p_true + eps).log()
    return per_pixel.mean()


def _per_mask_focal(
    logits: Tensor,
    target: Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    eps: float = 1e-6,
) -> Tensor:
    """Focal loss averaged over H and W only; returns one value per (sample, mask)."""
    prob = torch.sigmoid(logits)
    pt = prob * target + (1 - prob) * (1 - target)
    w = alpha * target + (1 - alpha) * (1 - target)
    loss = -w * (1 - pt) ** gamma * torch.log(pt + eps)
    return loss.mean(dim=(-2, -1))


def _per_mask_dice(probs: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
    """Soft Dice loss reduced over H and W only; returns one value per (sample, mask)."""
    intersection = (probs * target).sum(dim=(-2, -1))
    total = probs.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return 1.0 - (2.0 * intersection + eps) / (total + eps)


def sam_mask_loss(multi_mask_logits: Tensor, gt_mask: Tensor) -> Tensor:
    """
    Compute SAM-style best-of-N loss given multiple masks (e.g., 3) and a
    single GT mask per sample.

    Args:
        multi_mask_logits: (B, M, H, W) predicted logits for M masks
        gt_mask:           (B, H, W)     float {0,1}

    Returns:
        Scalar loss = mean over batch of:
            min_m [ 20 * focal_loss + 1 * dice_loss ]
        so only the best candidate mask of every sample receives gradient.
        For M == 1 this equals 20 * focal + dice averaged over the batch.

    Raises:
        ValueError: if the shapes of the two inputs do not line up.
    """
    if multi_mask_logits.dim() != 4:
        raise ValueError(
            f"multi_mask_logits must be (B,M,H,W), got {tuple(multi_mask_logits.shape)}"
        )
    B, M, H, W = multi_mask_logits.shape
    if tuple(gt_mask.shape) != (B, H, W):
        raise ValueError(
            f"gt_mask must be (B,H,W) = {(B, H, W)}, got {tuple(gt_mask.shape)}"
        )

    # Keep gt as (B,1,H,W) and let broadcasting pair it with each of the M
    # masks. Calling .view() on an expand()-ed tensor fails for M > 1 because
    # the expanded dimension has stride 0.
    gt = gt_mask.unsqueeze(1).to(multi_mask_logits.dtype)

    focal = _per_mask_focal(multi_mask_logits, gt)               # (B,M)
    dice = _per_mask_dice(torch.sigmoid(multi_mask_logits), gt)  # (B,M)
    per_mask = 20.0 * focal + 1.0 * dice                         # (B,M)

    # Best-of-N: score each sample by its lowest-loss candidate mask.
    best_per_sample = per_mask.min(dim=1).values                 # (B,)
    return best_per_sample.mean()


In [None]:
# sam_forward_train.py (optional) OR inside train script
from __future__ import annotations

from typing import Dict

import torch
from torch import Tensor

# Adjust imports to match the EfficientViT-SAM repo
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor


def sam_forward_train(
    predictor: "EfficientViTSamPredictor",
    batch: Dict[str, Tensor],
) -> Tensor:
    """
    Placeholder for the differentiable forward pass used during training.

    Args:
        predictor: EfficientViTSamPredictor instance wrapping the EfficientViTSam model.
        batch: dict that must contain:
            - "image":        (B,3,H,W)
            - "point_coords": (B,1,2) in original pixel coords
            - "point_labels": (B,1)

    Returns:
        Intended result: multi_mask_logits of shape (B, M, H, W) at the
        original image resolution. The current stub never returns.

    Raises:
        KeyError: if one of the three batch keys is missing.
        NotImplementedError: always, until the body is wired to the predictor.

    NOTE:
        The plan is to reuse the predictor's **existing internal methods**
        instead of re-implementing preprocessing, embedding and postprocessing.
        A finished implementation is expected to:
          1. call predictor code that
             - resizes & pads the images
             - maps point_coords into the resized space
             - runs image_encoder + prompt_encoder + mask_decoder
             - upscales the masks back to the original size
          2. turn probabilities into logits with torch.logit.

        The exact method names depend on the repo version, so this is a TODO.
    """
    # Read the inputs up front so a malformed batch fails with a KeyError.
    images, point_coords, point_labels = (
        batch["image"],
        batch["point_coords"],
        batch["point_labels"],
    )

    # TODO: replace the NotImplementedError below with real predictor calls.
    # Sketch (pseudo-code, method names are NOT guaranteed to exist in your repo):
    #
    #   low_res_masks, iou_preds = predictor._predict_low_res_batch(
    #       images, point_coords, point_labels
    #   )
    #   upscaled_probs = predictor.postprocess_masks_to_original_size(
    #       low_res_masks, original_sizes=images.shape[-2:]
    #   )   # (B,M,H,W)
    #   logits = torch.logit(
    #       upscaled_probs.clamp(1e-6, 1 - 1e-6),
    #       eps=1e-6,
    #   )
    #   return logits

    raise NotImplementedError("Hook this up to EfficientViTSamPredictor internals.")


In [None]:
# train_efficientvit_sam_cityscapes.py
"""
Head-only fine-tuning of EfficientViT-SAM on Cityscapes.

Reuses:
    - EfficientViTSam model
    - EfficientViTSamPredictor for preprocessing & inference
and only adds:
    - Cityscapes dataset wrapper
    - training loop
    - SAM loss
"""

from __future__ import annotations

import os
from typing import Dict, Any

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from cityscapes_for_sam import CityscapesForSAM, cityscapes_collate_fn
from losses_sam import sam_mask_loss
from sam_forward_train import sam_forward_train  # or local function

# Adjust to your repo entry point
from efficientvit.sam_model_zoo import create_sam_model
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor


def build_model_and_predictor(
    device: str,
    model_name: str = "l0",
    weights_path: str | None = None,
) -> tuple[nn.Module, EfficientViTSamPredictor]:
    """
    Create the EfficientViT-SAM model, configure it for head-only
    fine-tuning, and wrap it in a predictor.

    Args:
        device: device the model is moved to.
        model_name: forwarded to ``create_sam_model`` as ``name``.
        weights_path: forwarded to ``create_sam_model`` as ``weight_url``.

    Returns:
        (model, predictor), where the predictor wraps that same model.
    """
    # Example call; adjust the arguments to match your repo version.
    model = create_sam_model(name=model_name, weight_url=weights_path)
    model.to(device)

    # (sub-module, trainable?) -- the image encoder stays frozen, the mask
    # decoder (head) is trained, and the prompt encoder is trained as well.
    training_plan = (
        (model.image_encoder, False),
        (model.mask_decoder, True),
        (model.prompt_encoder, True),
    )
    for submodule, trainable in training_plan:
        for param in submodule.parameters():
            param.requires_grad = trainable

    return model, EfficientViTSamPredictor(model)


def _batch_to_device(batch: Dict[str, Any], device: str) -> Dict[str, Any]:
    """Move every tensor value of ``batch`` to ``device``; other values pass through."""
    return {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}


def _run_epoch(
    predictor: EfficientViTSamPredictor,
    loader: DataLoader,
    device: str,
    desc: str,
    optimizer: torch.optim.Optimizer | None = None,
) -> float:
    """Run one pass over ``loader`` and return the mean per-batch loss; updates weights only when ``optimizer`` is given."""
    is_training = optimizer is not None
    loss_accum = 0.0

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(loader, desc=desc):
            batch = _batch_to_device(batch, device)

            # Forward through SAM predictor (you must implement sam_forward_train)
            multi_mask_logits = sam_forward_train(predictor, batch)  # (B,M,H,W)
            loss = sam_mask_loss(multi_mask_logits, batch["gt_mask"])

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_accum += loss.item()

    # max(1, ...) avoids a division by zero on an empty loader.
    return loss_accum / max(1, len(loader))


def train(
    cityscapes_root: str,
    output_path: str = "efficientvit_sam_head_finetuned_cityscapes.pt",
    model_name: str = "l0",
    batch_size: int = 2,
    num_epochs: int = 10,
    lr: float = 1e-4,
    num_workers: int = 4,
    weights_path: str | None = None,
) -> None:
    """
    Fine-tune the trainable parts of EfficientViT-SAM on Cityscapes and save
    the weights with the lowest validation loss.

    Args:
        cityscapes_root: root directory of the Cityscapes dataset.
        output_path: where the best ``state_dict`` is written.
        model_name: model variant forwarded to ``build_model_and_predictor``.
        batch_size: batch size for both the train and val loaders.
        num_epochs: number of train + validation passes.
        lr: AdamW learning rate.
        num_workers: DataLoader worker processes.
        weights_path: optional checkpoint forwarded to
            ``build_model_and_predictor``; ``None`` keeps the previous behaviour.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Pinned host memory only helps host -> GPU copies.
    pin_memory = device == "cuda"

    # Datasets & loaders
    train_ds = CityscapesForSAM(cityscapes_root, split="train")
    val_ds = CityscapesForSAM(cityscapes_root, split="val")

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=cityscapes_collate_fn,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=cityscapes_collate_fn,
        pin_memory=pin_memory,
    )

    # Model + predictor
    model, predictor = build_model_and_predictor(
        device=device,
        model_name=model_name,
        weights_path=weights_path,
    )

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=1e-2,
    )

    # Make sure the checkpoint directory exists before the first save.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        # ----------------- Train -----------------
        model.train()
        # The image encoder is frozen. model.train() would still switch any
        # BatchNorm / Dropout layers inside it to training behaviour, so put
        # it back into eval mode.
        model.image_encoder.eval()

        train_loss = _run_epoch(
            predictor,
            train_loader,
            device,
            desc=f"Epoch {epoch+1}/{num_epochs} [train]",
            optimizer=optimizer,
        )
        print(f"[Epoch {epoch+1}] train loss: {train_loss:.4f}")

        # ----------------- Validation -----------------
        model.eval()
        val_loss = _run_epoch(
            predictor,
            val_loader,
            device,
            desc=f"Epoch {epoch+1}/{num_epochs} [val]",
        )
        print(f"[Epoch {epoch+1}] val loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), output_path)
            print(f"  -> Saved new best model to {output_path}")

    print(f"Training complete. Best val loss = {best_val_loss:.4f}")


if __name__ == "__main__":
    # Placeholder path: point this at your local Cityscapes root before running.
    cityscapes_root = "/path/to/cityscapes"
    train(cityscapes_root=cityscapes_root)
