<a href="https://colab.research.google.com/github/jenriver/bonsai/blob/aistack/bonsai/models/unetr/tests/unetr_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Segmentation with UNETR

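This notebook trains a UNETR model for semantic segmentation on the Oxford-IIIT Pet Dataset. For each pixel it predicts one of three classes taken from the dataset's trimaps: foreground (pet), background, or boundary.
It uses the UNETR implementation from `bonsai.models.unetr`.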


In [None]:
import jax

# Accelerators visible to JAX in this runtime.
visible_devices = jax.devices()
visible_devices

## 1. Installation and Imports

In [None]:
!pip install -q git+https://github.com/jax-ml/bonsai
!pip install -U -q opencv-python-headless grain albumentations Pillow
!pip install -U -q optax orbax-checkpoint
!pip install -e -q ./bonsai

In [None]:
import bonsai

# Confirm which installed copy of bonsai the kernel picked up.
bonsai_location = bonsai.__file__
print(f"Bonsai location: {bonsai_location}")

In [None]:
import os
import sys
from typing import Any, Callable

import albumentations as A
import cv2
import grain.python as grain
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import orbax.checkpoint as ocp
from flax import nnx
from PIL import Image

from bonsai.models.unetr.modeling import UNETR, UNETRConfig

print("JAX version: {}".format(jax.__version__))

## 2. Data Preparation

In [None]:
# Download Data using standard libraries to avoid shell timeouts
import shutil
import tarfile
import urllib.request

data_root = "/tmp/data/oxford_pets"
# Keep any previous download rather than wiping it, so re-running the notebook does not
# re-fetch both archives. Extraction overwrites files that already exist.
os.makedirs(data_root, exist_ok=True)


def download_and_extract(url, extract_to):
    """Download the gzip tarball at ``url`` into ``extract_to`` and unpack it there.

    After a successful extraction an empty ``<archive>.extracted`` marker file is
    written next to the archive. Later calls with the same URL skip both the download
    and the extraction. An interrupted run leaves no marker and is simply retried.

    Args:
        url: URL of a ``.tar.gz`` archive. Its last path segment is used as the local
            file name.
        extract_to: Existing directory that receives the archive and its contents.
    """
    filename = url.split("/")[-1]
    filepath = os.path.join(extract_to, filename)
    marker = filepath + ".extracted"
    if os.path.exists(marker):
        print(f"Skipping {filename} (already extracted)")
        return

    print(f"Downloading {filename}...")
    urllib.request.urlretrieve(url, filepath)
    print(f"Extracting {filename}...")
    with tarfile.open(filepath, "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, members escaping
        # `extract_to` and other unsafe entries. It only exists on Pythons that
        # ship tarfile extraction filters, hence the feature check.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_to, filter="data")
        else:
            tar.extractall(path=extract_to)
    # Only mark completion once extraction finished without raising.
    with open(marker, "w"):
        pass
    print(f"Done {filename}")


# Images and annotations (including the trimaps) are shipped as two separate archives.
PETS_ARCHIVES = (
    "https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz",
    "https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz",
)
for archive_url in PETS_ARCHIVES:
    download_and_extract(archive_url, data_root)

In [None]:
# Reuse the download location instead of repeating the path literal, so the two cannot drift.
root_path = data_root
img_path = os.path.join(root_path, "images")
start_mask_path = os.path.join(root_path, "annotations", "trimaps")

all_images = sorted(
    os.path.join(img_path, name) for name in os.listdir(img_path) if name.endswith(".jpg")
)
print(f"Total images: {len(all_images)}")

In [None]:
class OxfordPetsDataset(grain.MapDataset):
    """Random-access Oxford-IIIT Pet dataset yielding ``{"image": ..., "mask": ...}`` dicts.

    ``image`` is the RGB image as a uint8 numpy array. ``mask`` holds uint8 class ids
    0 (foreground), 1 (background) and 2 (boundary), derived from the trimap stored at
    ``mask_path/<name>.png`` for each ``img_path/<name>.jpg``.

    Args:
        root_path: Dataset root. Not used; kept for interface compatibility.
        img_path: Directory containing the ``.jpg`` images.
        mask_path: Directory containing the ``.png`` trimaps.
    """

    def __init__(self, root_path, img_path, mask_path):
        self.img_path = img_path
        self.mask_path = mask_path
        self.all_images = sorted([x for x in os.listdir(img_path) if x.endswith(".jpg")])

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        img_name = self.all_images[idx]
        img_uri = os.path.join(self.img_path, img_name)
        mask_uri = os.path.join(self.mask_path, img_name.replace(".jpg", ".png"))

        # Context managers close the underlying files once the pixels are copied out.
        with Image.open(img_uri) as img:
            image = np.array(img.convert("RGB"))

        try:
            with Image.open(mask_uri) as mask_img:
                mask = np.array(mask_img)
        except OSError as err:
            # Best effort for the demo: keep going with an all-background trimap
            # (raw value 2), but make the substitution visible instead of silent.
            # Missing and unreadable files both raise OSError subclasses.
            import warnings

            warnings.warn(f"Could not read mask {mask_uri!r} ({err}); using an all-background mask.")
            mask = np.full(image.shape[:2], 2, dtype=np.uint8)

        # Trimap encoding: 1 -> foreground, 2 -> background, 3 -> boundary.
        # Shift to class ids 0, 1, 2 in a signed dtype so an unexpected 0 cannot
        # wrap around to 255 (uint8 underflow) before clipping.
        mask = np.clip(mask.astype(np.int16) - 1, 0, 2).astype(np.uint8)

        return {"image": image, "mask": mask}


val_len = 200
train_dataset = OxfordPetsDataset(root_path, img_path, start_mask_path)
# Seeded random split. The dataset lists files sorted by name, and the standard
# Oxford-IIIT Pet file names are breed-prefixed (assumed here). A plain tail slice
# would therefore put only one or two breeds into the validation set.
split_rng = np.random.default_rng(42)
all_indices = split_rng.permutation(len(train_dataset))
train_indices = all_indices[:-val_len]
val_indices = all_indices[-val_len:]


class Subset(grain.MapDataset):
    """Read-only view of ``dataset`` restricted to the positions listed in ``indices``."""

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the subset position into a position in the wrapped dataset.
        source_idx = self.indices[idx]
        return self.dataset[source_idx]


train_set, val_set = (Subset(train_dataset, split) for split in (train_indices, val_indices))
print(f"Train size: {len(train_set)}, Val size: {len(val_set)}")

## 3. Transformations and Data Loading

In [None]:
IMG_SIZE = 256

train_transforms = A.Compose(
    [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.HorizontalFlip(p=0.5),
        # Choose the border mode explicitly instead of relying on the library default.
        # With a constant border (the default in recent albumentations releases, verify
        # for the installed version) the corners uncovered by the rotation get mask
        # value 0, i.e. they are labelled "foreground". Reflection keeps image and mask
        # content consistent.
        A.Rotate(limit=20, border_mode=cv2.BORDER_REFLECT_101, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

val_transforms = A.Compose(
    [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


class DataAugs(grain.MapTransform):
    """Grain map transform that runs an albumentations pipeline on each sample dict.

    The sample's keys (here ``image`` and ``mask``) are forwarded as keyword
    arguments, and the pipeline's output dict becomes the new sample.
    """

    def __init__(self, transforms: Callable):
        self.albu_transforms = transforms

    def map(self, data):
        return self.albu_transforms(**data)


# Parameters
batch_size = 8
seed = 42

# Samplers: one pass over the data per iteration; only the training order is shuffled.
train_sampler = grain.IndexSampler(
    len(train_set),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)
val_sampler = grain.IndexSampler(
    len(val_set),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

# Loaders share the same worker settings. Training drops the last partial batch so
# every jitted step sees the same batch shape.
loader_workers = dict(worker_count=2, worker_buffer_size=2)

train_loader = grain.DataLoader(
    data_source=train_set,
    sampler=train_sampler,
    operations=[DataAugs(train_transforms), grain.Batch(batch_size, drop_remainder=True)],
    **loader_workers,
)
val_loader = grain.DataLoader(
    data_source=val_set,
    sampler=val_sampler,
    operations=[DataAugs(val_transforms), grain.Batch(batch_size)],
    **loader_workers,
)

## 4. Model Setup

In [None]:
# A reduced UNETR for demo purposes (narrower hidden size / MLP than a full model).
config = UNETRConfig(
    in_channels=3,
    out_channels=3,  # one output channel per trimap class
    img_size=IMG_SIZE,
    hidden_size=252,
    num_heads=6,
    mlp_dim=512,
    num_layers=12,
    feature_size=16,
)

model = UNETR(config=config)
nnx.display(model)

## 5. Training Setup

In [None]:
num_epochs = 2
# Matches drop_remainder=True in train_loader: only full batches are trained on.
steps_per_epoch = len(train_set) // batch_size
total_steps = steps_per_epoch * num_epochs
learning_rate = 1e-4

# Linearly decay the learning rate from `learning_rate` to 0 over the whole run.
lr_schedule = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=total_steps)

# Pass the schedule itself (previously the constant `learning_rate` was passed and the
# schedule above was silently unused). optax evaluates it at the optimizer's step count.
optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr_schedule), wrt=nnx.Param)


def compute_softmax_jaccard_loss(logits, masks, reduction="mean"):
    """Soft Jaccard (IoU) loss between softmax probabilities and integer masks.

    Args:
        logits: Class logits with the class axis last, ``(batch, ..., classes)``.
        masks: Integer class ids whose shape equals ``logits.shape[:-1]``.
        reduction: ``"mean"`` averages intersections and unions over batch and
            classes before taking the ratio (scalar result). Any other value returns
            the per-sample, per-class loss with shape ``(batch, classes)``.

    Returns:
        ``1 - intersection / union``.
    """
    batch, num_classes = logits.shape[0], logits.shape[-1]
    probs = jax.nn.softmax(logits, axis=-1).reshape((batch, -1, num_classes))
    target = jax.nn.one_hot(masks, num_classes=num_classes, axis=-1).reshape((batch, -1, num_classes))

    overlap = probs * target
    inter_sum = jnp.sum(overlap, axis=1)
    # The epsilon is added per pixel before summing, so an absent class never divides by zero.
    union_sum = jnp.sum(probs + target - overlap + 1e-8, axis=1)

    if reduction == "mean":
        inter_sum, union_sum = jnp.mean(inter_sum), jnp.mean(union_sum)
    return 1.0 - inter_sum / union_sum


def compute_loss(model, inputs, masks):
    """Total segmentation loss: mean softmax cross-entropy plus soft Jaccard loss.

    Returns ``(total, (cross_entropy, jaccard))`` so it can be differentiated with
    ``has_aux=True``.
    """
    logits = model(inputs)
    xent = optax.softmax_cross_entropy_with_integer_labels(logits, masks).mean()
    jaccard = compute_softmax_jaccard_loss(logits, masks)
    total = xent + jaccard
    return total, (xent, jaccard)


@nnx.jit
def train_step(model, optimizer, batch):
    """Apply one optimizer update on ``batch``; returns ``(total, cross_entropy, jaccard)``."""
    labels = batch["mask"].astype(jnp.int32)
    loss_and_grad = nnx.value_and_grad(compute_loss, has_aux=True)
    (total, (xent, jaccard)), grads = loss_and_grad(model, batch["image"], labels)
    optimizer.update(grads)
    return total, xent, jaccard

## 6. Metrics

In [None]:
class ConfusionMatrix(nnx.Metric):
    """Accumulates an integer ``(num_classes, num_classes)`` confusion matrix over batches.

    Entry ``[t, p]`` counts elements whose label is ``t`` and whose predicted class
    (``argmax`` over the last axis of the logits) is ``p``. Labels outside
    ``[0, num_classes)`` are ignored.

    Args:
        num_classes: Number of classes.
        average: Stored on the instance but not used by this class.
    """

    def __init__(self, num_classes: int, average: str | None = None):
        self.num_classes = num_classes
        self.average = average
        self.conf_mat = nnx.metrics.MetricState(self._zeros())

    def _zeros(self) -> jax.Array:
        """Fresh all-zero int32 matrix of the configured size."""
        n = self.num_classes
        return jnp.zeros((n, n), dtype=jnp.int32)

    def reset(self):
        self.conf_mat.value = self._zeros()

    def update(self, *, logits: jax.Array, labels: jax.Array, **kwargs):
        n = self.num_classes
        true_flat = labels.reshape(-1)
        pred_flat = jnp.argmax(logits, axis=-1).reshape(-1)

        # Elements with an out-of-range label go to one overflow bin (index n * n) that
        # is dropped below. The fixed `length` keeps the bincount shape static under jit.
        in_range = jnp.logical_and(true_flat >= 0, true_flat < n)
        flat_idx = jnp.where(in_range, n * true_flat + pred_flat, n * n)
        counts = jnp.bincount(flat_idx, minlength=n * n + 1, length=n * n + 1)

        self.conf_mat.value += counts[: n * n].reshape(n, n).astype(jnp.int32)

    def compute(self) -> jax.Array:
        return self.conf_mat.value


def compute_iou(cm: jax.Array) -> jax.Array:
    """Per-class intersection-over-union from a square confusion matrix.

    For class k: ``IoU_k = cm[k, k] / (column_sum_k + row_sum_k - cm[k, k])``.
    Classes whose denominator is 0 (neither present nor predicted) get 0.0.
    """
    tp = jnp.diag(cm)
    denom = jnp.sum(cm, axis=0) + jnp.sum(cm, axis=1) - tp
    return jnp.where(denom > 0, tp / denom, 0.0)


def compute_mean_iou(cm: jax.Array) -> jax.Array:
    """Unweighted mean of the per-class IoUs (classes with an empty union count as 0)."""
    per_class = compute_iou(cm)
    return per_class.mean()


def compute_accuracy(cm: jax.Array) -> jax.Array:
    """Overall accuracy: the fraction of counted elements on the diagonal of ``cm``.

    Returns 0.0 for an all-zero matrix (e.g. when no batch was evaluated) instead of
    the NaN that ``0 / 0`` would produce.
    """
    correct = jnp.sum(jnp.diag(cm))
    total = jnp.sum(cm)
    # `maximum(total, 1)` keeps the unselected branch finite; `where` picks 0.0 for empty input.
    return jnp.where(total > 0, correct / jnp.maximum(total, 1), 0.0)

## 7. Eval & Train Functions

In [None]:
# Validation metrics: running mean of the total loss, plus a 3-class confusion matrix
# from which IoU and accuracy are derived.
loss_metric = nnx.metrics.Average("total_loss")
cm_metric = ConfusionMatrix(num_classes=3)
eval_metrics = nnx.MultiMetric(total_loss=loss_metric, confusion_matrix=cm_metric)

# Per-epoch metric history keyed "<split>_<metric>"; each entry starts as an empty list.
eval_metrics_history = {
    f"{split}_{metric}": []
    for split in ("train", "val")
    for metric in ("total_loss", "IoU", "mean_IoU", "accuracy")
}


def compute_losses_and_logits(model, images, masks):
    """Same losses as ``compute_loss``, but also returns the logits for metric updates.

    Returns ``(total, (cross_entropy, jaccard, logits))``.
    """
    logits = model(images)
    xent = optax.softmax_cross_entropy_with_integer_labels(logits, masks).mean()
    jaccard = compute_softmax_jaccard_loss(logits, masks)
    return xent + jaccard, (xent, jaccard, logits)


@nnx.jit
def eval_step(model: nnx.Module, batch, eval_metrics: nnx.MultiMetric):
    """Evaluate one batch and fold its loss and predictions into ``eval_metrics``."""
    labels = batch["mask"].astype(jnp.int32)
    total, (_, _, logits) = compute_losses_and_logits(model, batch["image"], labels)
    eval_metrics.update(total_loss=total, logits=logits, labels=labels)


def train_one_epoch(epoch, model, optimizer, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Every 10th step the latest losses are printed on a carriage-return refreshed
    line. The notebook globals ``num_epochs`` and ``steps_per_epoch`` are read for
    that progress line.
    """
    model.train()
    for it, batch_data in enumerate(train_loader):
        total, xent, jaccard = train_step(model, optimizer, batch_data)
        if it % 10:
            continue
        header = f"\r[train] epoch: {epoch + 1}/{num_epochs}, iteration: {it}/{steps_per_epoch}, "
        print(
            header + f"total loss: {total.item():.4f} ",
            f"xentropy loss: {xent.item():.4f} ",
            f"jaccard loss: {jaccard.item():.4f} ",
            end="",
        )
    # Move the cursor back to column 0 so the next output starts at the line start.
    print("\r", end="")


def evaluate_model(epoch, model, val_loader, eval_metrics, eval_metrics_history):
    """Evaluate on the validation split, append the results to the history and print them.

    Returns the most recent validation mean IoU. The notebook global ``num_epochs`` is
    read for the printed header.
    """
    model.eval()
    # Only the validation split is evaluated, to keep the demo fast.
    splits = [("val", val_loader)]
    for tag, loader in splits:
        eval_metrics.reset()
        for batch in loader:
            eval_step(model, batch, eval_metrics)

        for name, value in eval_metrics.compute().items():
            if name != "confusion_matrix":
                eval_metrics_history[f"{tag}_{name}"].append(value)
                continue
            # Derive IoU, mean IoU and accuracy from the accumulated confusion matrix.
            eval_metrics_history[f"{tag}_IoU"].append(compute_iou(value))
            eval_metrics_history[f"{tag}_mean_IoU"].append(compute_mean_iou(value))
            eval_metrics_history[f"{tag}_accuracy"].append(compute_accuracy(value))

        latest = {
            key: eval_metrics_history[f"{tag}_{key}"][-1]
            for key in ("total_loss", "IoU", "mean_IoU", "accuracy")
        }
        print(
            f"[{tag}] epoch: {epoch + 1}/{num_epochs} "
            f"\n - total loss: {latest['total_loss']:0.4f} "
            f"\n - IoU per class: {latest['IoU'].tolist()} "
            f"\n - Mean IoU: {latest['mean_IoU']:0.4f} "
            f"\n - Accuracy: {latest['accuracy']:0.4f} "
            "\n"
        )
    return eval_metrics_history["val_mean_IoU"][-1]


import os
import shutil

checkpoint_path = "/tmp/output-oxford-model/"
# Start every run from an empty checkpoint directory.
if os.path.exists(checkpoint_path):
    shutil.rmtree(checkpoint_path)
os.makedirs(checkpoint_path, exist_ok=True)

# Keep at most the two most recent checkpoints.
options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
mngr = ocp.CheckpointManager(os.path.abspath(checkpoint_path), options=options)


def save_model(epoch, model, mngr):
    """Save ``nnx.state(model)`` as checkpoint step ``epoch`` and wait until it is written.

    NOTE(review): RNG key state (e.g. for dropout) is saved as-is. Converting typed
    PRNG keys may be needed if restoring fails; restoring is not exercised here.
    """
    model_state = nnx.state(model)
    mngr.save(epoch, args=ocp.args.StandardSave(model_state))
    mngr.wait_until_finished()

## 8. Main Training Loop

In [None]:
print("Starting training...")
best_val_mean_iou = 0.0

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_one_epoch(epoch, model, optimizer, train_loader)

    # Evaluate after every epoch; checkpoint only when the validation mean IoU improves.
    val_mean_iou = evaluate_model(epoch, model, val_loader, eval_metrics, eval_metrics_history)
    if not val_mean_iou > best_val_mean_iou:
        continue
    print(f"New best Mean IoU: {val_mean_iou:.4f} (was {best_val_mean_iou:.4f})")
    save_model(epoch, model, mngr)
    best_val_mean_iou = val_mean_iou

print("Training finished!")




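> **Note:** running the training cell above printed a Flax warning about the line `self.conf_mat.value += batch_conf_mat` in `ConfusionMatrix.update`. Instead of `.value` access, it recommends `variable[...]` for array-valued `Variable`s and `variable.get_value()` for other `Variable` types.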


## 9. Visualization

In [None]:
# Number epochs from 1 to match the "epoch: k/N" lines printed during training.
epochs_list = list(range(1, len(eval_metrics_history["val_total_loss"]) + 1))

fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs_list, eval_metrics_history["val_total_loss"], label="Validation Loss")
loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Loss over Epochs")
loss_ax.set_xticks(epochs_list)
loss_ax.legend()

iou_ax.plot(epochs_list, eval_metrics_history["val_mean_IoU"], label="Validation Mean IoU")
iou_ax.set(xlabel="Epoch", ylabel="Mean IoU", title="Mean IoU over Epochs")
iou_ax.set_xticks(epochs_list)
iou_ax.legend()

plt.show()

In [None]:
def display_image_mask_pred(
    img,
    mask,
    pred,
    label="",
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Show an image, its ground-truth mask and a predicted mask side by side.

    Args:
        img: ``(H, W, 3)`` image. Floating-point images are assumed to be normalized
            as ``(pixel / 255 - mean) / std`` (the ``A.Normalize`` settings used in
            this notebook) and are mapped back to uint8 for display.
        mask: ``(H, W)`` ground-truth class ids in ``{0, 1, 2}``.
        pred: ``(H, W)`` predicted class ids in ``{0, 1, 2}``.
        label: Suffix appended to the per-panel titles.
        mean: Per-channel mean used for normalization, undone before display.
        std: Per-channel std used for normalization, undone before display.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.floating):
        # Undo the per-channel normalization. A global min-max stretch would shift the
        # colours (each channel was scaled differently) and divide by zero on a
        # constant image.
        img = img.astype(np.float32) * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
        img = (np.clip(img, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    mask = np.asarray(mask)
    pred = np.asarray(pred)

    _, axs = plt.subplots(1, 5, figsize=(15, 10))
    axs[0].set_title(f"Image{label}")
    axs[0].imshow(img)
    axs[0].axis("off")

    axs[1].set_title(f"Mask{label}")
    axs[1].imshow(mask, vmin=0, vmax=2)
    axs[1].axis("off")

    axs[2].set_title("Image + Mask")
    axs[2].imshow(img)
    axs[2].imshow(mask, alpha=0.5, vmin=0, vmax=2)
    axs[2].axis("off")

    axs[3].set_title(f"Pred{label}")
    axs[3].imshow(pred, vmin=0, vmax=2)
    axs[3].axis("off")

    axs[4].set_title("Image + Pred")
    axs[4].imshow(img)
    axs[4].imshow(pred, alpha=0.5, vmin=0, vmax=2)
    axs[4].axis("off")
    plt.show()


model.eval()
val_batch = next(iter(val_loader))
images, masks = val_batch["image"], val_batch["mask"]

# Predicted class per pixel = argmax over the class (last) axis of the logits.
preds_logits = model(images)
preds = jnp.argmax(preds_logits, axis=-1)

# Show (up to) the first four validation samples.
for image, gt_mask, pred_mask in zip(images[:4], masks[:4], preds[:4]):
    display_image_mask_pred(image, gt_mask, pred_mask, label=" (val)")