[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/harpreetsahota204/finetune_sam3_with_fo_and_transformers/blob/main/how_to_finetune_sam3.ipynb)

In [None]:
!pip install fiftyone umap-learn
!pip install git+https://github.com/huggingface/transformers.git#egg=transformers
!pip install shapely
!pip install trackio

In this tutorial we'll make use of the [RIS-LAD](https://huggingface.co/datasets/Voxel51/RIS-LAD) dataset. [RIS-LAD is the first fine-grained benchmark](https://arxiv.org/abs/2507.20920) designed specifically for low-altitude drone image segmentation.

The dataset features 13,871 annotations with image-text-mask triplets captured from real drone footage at 30-100 meter altitudes with oblique viewing angles. Unlike existing remote sensing datasets that rely on high-altitude satellite imagery, RIS-LAD focuses on the visual complexities of low-altitude drone perception. These challenges include perspective changes, densely packed tiny objects, variable lighting conditions, and the notorious problems of **category drift** (tiny targets causing confusion with larger, semantically similar objects) and **object drift** (difficulty distinguishing among crowded same-class instances) that plague crowded aerial scenes.

# Download the Dataset

This benchmark addresses the gap in understanding how Visual AI systems see the world from a drone's perspective.

You can download the dataset from the Hugging Face Hub as follows:

In [None]:
import fiftyone as fo
from fiftyone.utils.huggingface import load_from_hub

dataset = load_from_hub(
    "Voxel51/RIS-LAD",
    overwrite=True,
    persistent=True
)

# Explore the Dataset

This dataset is in [FiftyOne format](https://docs.voxel51.com/user_guide/using_datasets.html). 

FiftyOne provides powerful functionality to inspect, search, and modify it from a [Dataset](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset)-wide down to a [Sample](https://docs.voxel51.com/api/fiftyone.utils.data.html#fiftyone.utils.data.Sample) level.

To see the schema of this dataset, you can simply display the Dataset:

In [None]:
dataset

A FiftyOne dataset is composed of [Samples](https://docs.voxel51.com/api/fiftyone.utils.data.html#fiftyone.utils.data.Sample).

Samples store all information associated with a particular piece of data in a dataset, including basic metadata about the data, one or more sets of labels, and additional features associated with subsets of the data and/or label sets.

The attributes of a Sample are called [Fields](https://docs.voxel51.com/api/fiftyone.core.fields.html#fiftyone.core.fields.Field), which store information about the Sample. When a new Field is assigned to a Sample in a Dataset, it is automatically added to the dataset’s schema and thus accessible on all other samples in the dataset.

To see the schema of a single Sample and the contents of its Fields, you can call the [`first()` method](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset.first):

In [None]:
dataset.first()

You can use the FiftyOne SDK to quickly compute some high-level statistics about your dataset with its [built-in Aggregation methods](https://docs.voxel51.com/user_guide/using_aggregations.html).

For example, you can use the [`count()` aggregation](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.count) to compute the number of non-None field values in a collection:

In [None]:
dataset.count("ground_truth.detections.label")

In [None]:
dataset.count("ground_truth.detections.referring_expression")

You can use the [`count_values()` aggregation](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.count_values) to compute the occurrences of field values in a collection:

In [None]:
dataset.count_values("ground_truth.detections.label")

You can use the [`distinct()` aggregation](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.distinct) to compute the distinct values of a field in a collection:

In [None]:
len(dataset.distinct("ground_truth.detections.referring_expression"))

### Adding a new Field to the Dataset

A useful piece of information to have about a sample is the number of detection labels in that sample.  You can easily add this to each sample in your Dataset using a `ViewField` expression.  

[`ViewField`](https://docs.voxel51.com/api/fiftyone.core.expressions.html#fiftyone.core.expressions.ViewField) and [`ViewExpression`](https://docs.voxel51.com/api/fiftyone.core.expressions.html#fiftyone.core.expressions.ViewExpression) classes allow you to use native Python operators to define expressions. Simply wrap the target field of your sample in a `ViewField` and then apply comparison, logic, arithmetic or array operations to it to create a `ViewExpression`.

The idiomatic FiftyOne way to count the number of instance labels in a sample is to use a `ViewField` expression to access the list of labels and then use `.length()` to count them.

To add the number of instances per image as a field on each sample in your dataset, you can use FiftyOne's [`set_values()`](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset.set_values) method. This will efficiently compute and store the count for each sample.

You can learn more about creating Dataset Views [in these docs](https://docs.voxel51.com/user_guide/using_views.html).

In [None]:
import fiftyone as fo
from fiftyone import ViewField as F

num_instances = dataset.values(F("ground_truth.detections").length())

dataset.set_values("num_instances", num_instances)

dataset.save()

In a similar manner, you can count the number of unique instance types for each sample in your Dataset:

In [None]:
from fiftyone import ViewField as F

labels_per_sample = dataset.values("ground_truth.detections.label")

num_distinct_labels_per_sample = [len(set(labels)) if labels else 0 for labels in labels_per_sample]

dataset.set_values("num_unique_instances", num_distinct_labels_per_sample)

dataset.save()

You can then combine these values to create a complexity score for each Sample in your Dataset. As a simple example, you can define the complexity score as the number of instances plus the number of unique labels. Note that the [`.values()` method](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.Dataset.values) efficiently extracts the values of a field across all Samples in a Dataset.

In [None]:
unique_instance_counts = dataset.values("num_unique_instances")

num_instances_values = dataset.values("num_instances")

# Compute complexity scores for all samples
complexity_scores = [nd + nul for nd, nul in zip(num_instances_values, unique_instance_counts)]

# Set the values
dataset.set_values("complexity_score", complexity_scores)

dataset.save()

There are a lot of interesting and non-trivial things, like those shown above, that you can do with FiftyOne. Here are some additional resources for you to check out later:

- For those familiar with `pandas`, you may want to check out this [pandas v FiftyOne cheat sheet](https://docs.voxel51.com/cheat_sheets/pandas_vs_fiftyone.html) to learn how you can translate common pandas operations into FiftyOne syntax.

- How to [create Views of your Dataset](https://docs.voxel51.com/cheat_sheets/views_cheat_sheet.html) 

- [Filtering cheat sheet docs](https://docs.voxel51.com/cheat_sheets/filtering_cheat_sheet.html)

Of course, the most interesting part of FiftyOne is [the FiftyOne App](https://docs.voxel51.com/user_guide/app.html#using-the-fiftyone-app) (which runs locally on your machine). Something that can help us in exploring our Dataset in the App is [the Dashboard plugin](https://docs.voxel51.com/plugins/plugins_ecosystem/dashboard.html). You can install the Plugin as follows:

In [None]:
!fiftyone plugins download https://github.com/voxel51/fiftyone-plugins --plugin-names @voxel51/dashboard

FiftyOne is open-source and hackable, and it has a robust framework for [building Plugins](https://docs.voxel51.com/plugins/developing_plugins.html), which allow you to extend and customize the functionality of the core tool to suit your specific needs. FiftyOne has integrations with various computer vision models and other popular AI tools; [browse this curated collection of plugins](https://docs.voxel51.com/plugins/) to see how you can transform FiftyOne into a bespoke visual AI development workbench.

To launch the FiftyOne App, all you need to do is run the following:

In [None]:
session = fo.launch_app(dataset, auto=False)
session.url

<img src="ris_lad_in_fo_1.gif">


Of course, you can go deeper in the analysis of your dataset by [visualizing image embeddings](https://docs.voxel51.com/brain.html#visualizing-embeddings) in the App. You can use one of the models from the [FiftyOne Model Zoo](https://docs.voxel51.com/model_zoo/overview.html), or a custom model which you can integrate as a [Remote Zoo Model](https://docs.voxel51.com/model_zoo/remote.html#remotely-sourced-zoo-models).

One example of a Remote Zoo Model is the integration of [SigLIP2](https://docs.voxel51.com/plugins/plugins_ecosystem/siglip2.html), which you can use to visualize image embeddings, perform zero-shot classification, and perform image retrieval by [searching via natural language](https://docs.voxel51.com/brain.html#text-similarity) in the App.

Let's start by registering the Remote Zoo Model source:

In [None]:
import fiftyone.zoo as foz

# Register this custom model source
foz.register_zoo_model_source(
    "https://github.com/harpreetsahota204/siglip2", 
    overwrite=True
    )

Then instantiate the model:

In [None]:
import fiftyone.zoo as foz

siglip_model = foz.load_zoo_model(
    "google/siglip2-giant-opt-patch16-256"
)

You can then use the [`compute_embeddings()` method](https://docs.voxel51.com/api/fiftyone.core.models.html#fiftyone.core.models.compute_embeddings) of the Dataset:

In [None]:
dataset.compute_embeddings(
    model=siglip_model,
    embeddings_field="siglip2_embeddings",
)

Then use the [`compute_visualization()` method](https://docs.voxel51.com/api/fiftyone.brain.html#fiftyone.brain.compute_visualization) to generate low-dimensional representations of the samples (and/or individual objects) in your Dataset.

In [None]:
import fiftyone.brain as fob

results = fob.compute_visualization(
    dataset,
    embeddings="siglip2_embeddings",
    method="umap",
    brain_key="siglip2_viz",
    num_dims=2,
)


You can then use the [`compute_similarity()` method](https://docs.voxel51.com/api/fiftyone.brain.html#fiftyone.brain.compute_similarity) to build a similarity index over the images in your dataset, which allows you to sort by similarity or search with natural language.

In [None]:
# Build a similarity index
text_img_index = fob.compute_similarity(
    dataset,
    model="google/siglip2-giant-opt-patch16-256",
    embeddings="siglip2_embeddings",
    brain_key="siglip2_similarity",
)

With the embeddings computed, you can do a lot of non-trivial analysis, such as computing scores for [uniqueness](https://docs.voxel51.com/brain.html#image-uniqueness), [representativeness](https://docs.voxel51.com/brain.html#image-representativeness), and [identifying near duplicates](https://docs.voxel51.com/brain.html#near-duplicates) with simple function calls.


We can use the same SigLIP2 model to perform zero-shot classification and further enrich our Dataset with information it didn't have before:

In [None]:
siglip_model.text_prompt = "Low altitude drone footage taken at "
siglip_model.classes = ["day", "night", "dusk"]

dataset.apply_model(
    siglip_model,
    label_field="time_of_day"
)

In [None]:
siglip_model.text_prompt = "The scene in this low altitude drone footage is in a "
siglip_model.classes = ["urban area", "near water", "highway", "pedestrian area"]

dataset.apply_model(
    siglip_model,
    label_field="location"
)



Let's launch the App again and see what we can uncover by inspecting [the Embeddings panel](https://docs.voxel51.com/user_guide/app.html#embeddings-panel).

In [None]:
session = fo.launch_app(dataset, auto=False)
session.url

<img src="ris_lad_in_fo_2.gif">

# SAM 3 Initial Results (Before Fine-Tuning)

We can use [SAM 3 in FiftyOne](https://docs.voxel51.com/plugins/plugins_ecosystem/sam3_images.html) as a Remote Zoo Model. The pattern is the same as before:

In [None]:
import fiftyone.zoo as foz

# Register the remote model source
foz.register_zoo_model_source(
    "https://github.com/harpreetsahota204/sam3_images",
    overwrite=True
)

# Load the model
sam3_model = foz.load_zoo_model("facebook/sam3")

The FiftyOne implementation also lets us compute image embeddings with SAM 3:

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob

sam3_model.pooling_strategy = "max"  # or "mean", "cls"

dataset.compute_embeddings(
    sam3_model,
    embeddings_field="sam_embeddings",
    batch_size=32
)

# Visualize with UMAP
fob.compute_visualization(
    dataset,
    method="umap",
    brain_key="sam_viz",
    embeddings="sam_embeddings",
    num_dims=2
)

To run the SAM 3 model on the dataset, all we have to do is set a few attributes on the model and call the Dataset's [`apply_model()` method](https://docs.voxel51.com/api/fiftyone.core.dataset.html#fiftyone.core.dataset.dataset.apply_model):

In [None]:
sam3_model.operation = "concept_segmentation"
sam3_model.threshold = 0.5
sam3_model.mask_threshold = 0.5

sam3_model.prompt = dataset.distinct("ground_truth.detections.label")

dataset.apply_model(
    sam3_model,
    label_field="sam3_not_finetuned",
    batch_size=32,
    num_workers=8,
    skip_failures=False
)

We can view the embeddings and the predictions in the App as well:

<img src="ris_lad_in_fo_3.gif">

We can then use [FiftyOne's evaluation API](https://docs.voxel51.com/user_guide/evaluation.html) to see how well SAM 3 performs before fine-tuning. You can use the [`evaluate_detections()` method](https://docs.voxel51.com/user_guide/evaluation.html#detections) to evaluate the predictions of an object detection model stored in a [`Detections`](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.Detections), [`Polylines`](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.Polylines), or [`Keypoints`](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.Keypoints) field of your dataset or of a temporal detection model stored in a [`TemporalDetections`](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.TemporalDetection) field of your dataset.

In [None]:
results = dataset.evaluate_detections(
    "sam3_not_finetuned",          # Detections with masks
    gt_field="ground_truth",   # Detections with masks
    eval_key="initial_sam3_eval",
    use_masks=True,            # use instance masks for IoU
    compute_mAP=True,
    tolerance=2
)


The `evaluate_detections()` method returns a [`DetectionResults` instance](https://docs.voxel51.com/api/fiftyone.utils.eval.detection.html#fiftyone.utils.eval.detection.DetectionResults) that provides a variety of methods for generating various aggregate evaluation reports about your model.

In addition, when you specify an `eval_key` parameter, a number of helpful fields will be populated on each sample and its predicted/ground truth objects that you can leverage via the FiftyOne App to interactively explore the strengths and weaknesses of your model on individual samples.

You can print the report to get a high-level picture of the model performance:

In [None]:
results.print_report()
print(results.mAP())

You can also open the [Model Evaluation Panel](https://docs.voxel51.com/api/fiftyone.utils.eval.detection.html#fiftyone.utils.eval.detection.DetectionResults) to visualize and interactively explore the evaluation results in the App:

<img src="ris_lad_in_fo_4.gif">


You can use [Scenario Analysis](https://docs.voxel51.com/user_guide/app.html#scenario-analysis-sub-new) for a deep dive into model behavior across different scenarios.

This evaluation technique helps uncover edge cases, identify annotation errors, and understand performance variations in different contexts. It gives you better insight into your model's strengths and weaknesses while enabling meaningful comparisons of performance under varying input conditions.

Ultimately, this detailed analysis helps improve training data quality and builds intuition about when and why your model succeeds or fails.

<img src="ris_lad_in_fo_5.gif">

#### Check for Data Leakage between Train and Validation Splits

We're almost ready to fine-tune the model, but before we do we should check if there is any data leakage between the train and validation sets of the dataset.

Our dataset has [sample-level tags](https://docs.voxel51.com/user_guide/basics.html#tags) which indicate which split each sample belongs to:

In [None]:
dataset.distinct("tags")

Despite our best efforts, duplicates and other forms of non-IID samples show up in our data. 

When these samples end up in different splits, [this can have consequences when evaluating a model](https://voxel51.com/blog/on-leaky-datasets-and-a-clever-horse): it becomes easy to overestimate model capability. The FiftyOne Brain offers a way to identify such cases in dataset splits.

The leaks of a dataset can be computed directly without the need for the predictions of a pre-trained model via the [`compute_leaky_splits()`](https://docs.voxel51.com/brain.html#leaky-splits) method:



In [None]:
import fiftyone.brain as fob

split_tags = ["train", "val"]

index = fob.compute_leaky_splits(
    dataset, 
    splits=split_tags,
    embeddings="sam_embeddings",
    )
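
Conceptually, leak detection comes down to finding validation samples that sit almost on top of a training sample in embedding space. The sketch below illustrates that idea in plain NumPy. It assumes cosine similarity and a hypothetical `threshold` of 0.95. `find_leaks` is a made-up helper, and this is not how `compute_leaky_splits()` is implemented internally.

```python
import numpy as np

def find_leaks(train_emb, val_emb, threshold=0.95):
    """Return indices of val samples whose closest train sample has
    cosine similarity >= threshold (hypothetical helper)."""
    # L2-normalize rows so that a dot product equals cosine similarity
    train = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    val = val_emb / np.linalg.norm(val_emb, axis=1, keepdims=True)
    sims = val @ train.T        # [num_val, num_train] cosine similarities
    best = sims.max(axis=1)     # similarity to the closest train sample
    return np.flatnonzero(best >= threshold)

rng = np.random.default_rng(0)
train_emb = rng.normal(size=(100, 16))   # random stand-ins for image embeddings
val_emb = rng.normal(size=(20, 16))
# Plant a near-duplicate of train sample 7 at val index 3
val_emb[3] = train_emb[7] + 1e-3 * rng.normal(size=16)

print(find_leaks(train_emb, val_emb))
```

Only the planted near-duplicate should be flagged. In practice, the threshold trades missed leaks against false alarms.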

The [`leaks_view()` method](https://docs.voxel51.com/api/fiftyone.brain.internal.core.leaky_splits.html#fiftyone.brain.internal.core.leaky_splits.LeakySplitsIndex.leaks_view) returns a view that contains only the leaks in the input splits. Once you have these leaks, it is wise to look through them. You may gain some insight into the source of the leaks:

In [None]:
leaks = index.leaks_view()

You can launch the app on this view like so:

```python
session = fo.launch_app(leaks)
```

Fortunately for us, there are no leaks between our splits, but it's always a good idea to check.

We're now ready to fine-tune SAM 3.

# SAM 3 Fine-tuning, Part 1: PyTorch Dataset from FiftyOne Dataset

First, we need to convert the FiftyOne Dataset to a torch Dataset.

In this section we will convert a FiftyOne dataset with Detection masks into a PyTorch 
dataset compatible with SAM fine-tuning.

Key insight: FiftyOne's [`to_patches()`](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.to_patches) method creates a view where each 
detection becomes its own sample. This eliminates the need to manually 
flatten detections - FiftyOne handles it for us.

The pipeline:
1. Define a [`GetItem`](https://docs.voxel51.com/api/fiftyone.utils.torch.html#fiftyone.utils.torch.GetItem) to extract and transform each patch

2. Define a Collate Function

3. Split the dataset by tag and flatten it into a patches view (one sample per detection)

4. Use [`to_torch()`](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.to_torch) to create PyTorch datasets, then wrap them in DataLoaders


### Step 1: Define a GetItem subclass

FiftyOne's `GetItem` class is the bridge between FiftyOne and PyTorch. It tells 
FiftyOne:

 1. What fields to extract from each sample (via `required_keys`)
 
 2. How to transform them into your desired format (via `__call__`)

The `field_mapping` parameter is important when working with patches. In a 
patches view, the detection data lives in the original field name (e.g., 
"ground_truth"), but we want to access it with a generic name in our code.

`field_mapping={"detection": "ground_truth"}` means:
 - In our code, we write `d["detection"]` 
 - FiftyOne knows to pull from the "ground_truth" field

This makes our `GetItem` reusable across datasets with different field names.
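
To make the remapping concrete, here is a toy sketch of the lookup that `field_mapping` implies. `apply_field_mapping` is a hypothetical helper written for illustration, not FiftyOne's internal code. Each required key is read from its mapped field name if one is given, otherwise from a field of the same name.

```python
def apply_field_mapping(sample_fields, required_keys, field_mapping=None):
    """Build the dict a GetItem would receive (hypothetical helper).

    Keys listed in field_mapping are read from the mapped field;
    all other keys are read from a field with the same name.
    """
    field_mapping = field_mapping or {}
    return {key: sample_fields[field_mapping.get(key, key)] for key in required_keys}

# A fake patch sample: the detection lives under "ground_truth"
patch = {
    "filepath": "/data/img_001.jpg",
    "ground_truth": {"label": "car"},
    "metadata": {"width": 1920, "height": 1080},
}
d = apply_field_mapping(
    patch,
    required_keys=["filepath", "detection", "metadata"],
    field_mapping={"detection": "ground_truth"},
)
print(d["detection"])  # {'label': 'car'}
```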

In [None]:
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F

from fiftyone.utils.torch import GetItem

class SAMPatchGetItem(GetItem):
    """
    Extracts and transforms patch data for SAM training.
    
    Each patch sample contains:
    - filepath: path to the full image
    - detection: the Detection object (bbox, mask, label, etc.)
    - metadata: image dimensions
    
    We transform this into SAM's expected format:
    - pixel_values: processed image tensor
    - mask_labels: ground truth mask resized to match model input
    """
    
    def __init__(self, processor, field_mapping=None):
        self.processor = processor
        # Must call super().__init__() with field_mapping - this sets up
        # the internal mapping that FiftyOne uses to pull the right fields
        super().__init__(field_mapping=field_mapping)

    @property
    def required_keys(self):
        # These are the keys we'll access in __call__.
        # 'detection' is a virtual key that gets mapped to the real field
        # via field_mapping. 'filepath' and 'metadata' are standard fields
        # that exist on all FiftyOne samples.
        return ["filepath", "detection", "metadata"]

    def __call__(self, d):
        """
        Transform a FiftyOne sample dict into SAM training format.
        
        This is where the FiftyOne → SAM conversion happens:
        - Cropped mask → full-image mask → resized to model input size
        - Raw image + text label → processed tensors
        """
        # Load full image (patches still reference the original image file)
        image = Image.open(d["filepath"]).convert("RGB")
        detection = d["detection"]
        metadata = d["metadata"]

        # Get image dimensions (cast to int for safety)
        w = int(metadata.width)
        h = int(metadata.height)

        # --- Bounding Box Extraction ---
        # FiftyOne stores bboxes as [x, y, width, height] with values in [0, 1].
        # We only need the top-left corner (x0, y0) to position the mask.
        rx, ry, rw, rh = detection.bounding_box
        x0 = int(rx * w)  # Top-left x in pixels
        y0 = int(ry * h)  # Top-left y in pixels

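        # --- Mask Conversion ---
        # FiftyOne stores masks cropped to the bounding box to save space.
        # We expand the cropped mask back to full image size by placing it
        # at the correct position (x0, y0). The crop is clipped to the image
        # bounds in case pixel rounding pushes it past the right/bottom edge,
        # which would otherwise raise a shape mismatch on assignment.
        full_mask = np.zeros((h, w), dtype=np.uint8)
        m = detection.mask
        mh = min(m.shape[0], h - y0)
        mw = min(m.shape[1], w - x0)
        full_mask[y0 : y0 + mh, x0 : x0 + mw] = m[:mh, :mw].astype(np.uint8)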

        # --- Text Prompt ---
        # Use the detection's class label as the text prompt for SAM
        text = detection["label"]

        # --- SAM Processor ---
        # The processor handles image preprocessing and text tokenization.
        # For text-prompted SAM variants, we pass both image and text.
        inputs = self.processor(images=image, text=text, return_tensors="pt")
        
        # Remove batch dimension added by processor (we batch later in DataLoader)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # --- Resize Mask to Match Model Input ---
        # The processor resizes the image to the model's expected input size.
        # We need to resize our ground truth mask to match, so the loss
        # computation compares tensors of the same shape.
        hm = int(inputs["pixel_values"].shape[-2])  # Model input height
        wm = int(inputs["pixel_values"].shape[-1])  # Model input width
        
        # Convert mask to tensor and add batch+channel dims for interpolate
        mask_t = torch.from_numpy(full_mask).float()[None, None, ...]  # (1, 1, H, W)
        
        # Resize using nearest neighbor to preserve binary mask values
        mask_rs = F.interpolate(mask_t, size=(hm, wm), mode="nearest").squeeze(0)  # (1, H, W)

        # Add resized mask as the training label
        inputs["mask_labels"] = mask_rs
        
        return inputs
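
The two mask steps in `__call__` can be checked in isolation. The first pastes the bbox-cropped mask onto a full canvas, and the second resizes it with nearest-neighbour sampling. The sketch below is a NumPy stand-in that assumes FiftyOne's relative `[x, y, width, height]` bounding boxes. `paste_cropped_mask` and `resize_nearest` are hypothetical helpers. `resize_nearest` is a simple nearest-neighbour resize in the spirit of `F.interpolate(..., mode="nearest")`, not a reimplementation of it.

```python
import numpy as np

def paste_cropped_mask(mask, bbox, width, height):
    """Place a bbox-cropped mask onto a full (height, width) canvas.

    bbox is [x, y, w, h] in relative [0, 1] coordinates (FiftyOne convention).
    """
    x0 = int(bbox[0] * width)
    y0 = int(bbox[1] * height)
    full = np.zeros((height, width), dtype=np.uint8)
    # Clip the crop in case rounding pushes it past the image border
    mh = min(mask.shape[0], height - y0)
    mw = min(mask.shape[1], width - x0)
    full[y0:y0 + mh, x0:x0 + mw] = mask[:mh, :mw].astype(np.uint8)
    return full

def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbour resize: each output pixel copies one input pixel."""
    rows = np.arange(out_h) * mask.shape[0] // out_h
    cols = np.arange(out_w) * mask.shape[1] // out_w
    return mask[rows[:, None], cols[None, :]]

crop = np.ones((2, 3), dtype=bool)          # a 2x3 object mask
full = paste_cropped_mask(crop, [0.5, 0.25, 0.3, 0.25], width=10, height=8)
small = resize_nearest(full, 4, 5)          # e.g. a smaller model input size

print(full.sum(), small.shape, np.unique(small))
```

Nearest-neighbour sampling copies existing pixel values instead of blending them, so the resized mask stays strictly binary. Bilinear interpolation would introduce fractional values along object edges.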

### Step 2: Collate function for DataLoader

When PyTorch's DataLoader batches samples together, it needs to know how to combine them.

The default collate works for fixed-shape tensors such as `pixel_values`, but one field needs different handling:

 - `mask_labels`: must stay as a list because SAM3 expects variable-length 
   targets per image (even though we have one mask per sample here, the model
   interface expects a list)

A custom collate function tells the DataLoader exactly how to handle each field.

In [None]:
def collate_fn(batch):
    """
    Collate function for SAM3 training.

    Handles batching of samples from SAMPatchGetItem:
    - Stacks tensor fields (pixel_values, input_ids, attention_mask, etc.)
    - Keeps mask_labels as a list of tensors (SAM3's expected format)
    
    Args:
        batch: List of dicts from SAMPatchGetItem.__call__
        
    Returns:
        Dict with batched tensors and mask_labels as list
    """
    result = {}
    
    # --- Fields that need special handling ---
    # mask_labels stays as a list because SAM3 expects targets in list format,
    # allowing for variable numbers of masks per image during training
    list_keys = {"mask_labels"}

    # --- Stack all standard tensor fields ---
    # These include pixel_values, input_ids, attention_mask, etc.
    # All samples have the same shape for these, so we can stack them
    keys = [k for k in batch[0].keys() if k not in list_keys]
    for key in keys:
        values = [item[key] for item in batch]
        if isinstance(values[0], torch.Tensor):
            # Stack tensors along new batch dimension
            result[key] = torch.stack(values)
        else:
            # Non-tensor fields (e.g., strings) stay as lists
            result[key] = values

    # --- Keep mask_labels as list of tensors ---
    # Each element is shape (1, H, W) - one mask per sample
    # SAM3's loss function iterates over this list
    result["mask_labels"] = [item["mask_labels"] for item in batch]
    
    return result
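
To see why the list matters, here is a stripped-down version of the same idea on dummy tensors. The two fake samples are hypothetical. They share an image size but carry different numbers of masks, which `torch.stack` could not combine into one tensor. `collate_keep_lists` is an illustrative helper, not part of any library.

```python
import torch

def collate_keep_lists(batch, list_keys=("mask_labels",)):
    """Stack tensor fields; keep fields named in list_keys as plain lists."""
    out = {}
    for key in batch[0]:
        values = [item[key] for item in batch]
        if key in list_keys or not isinstance(values[0], torch.Tensor):
            out[key] = values                # variable-shape or non-tensor data
        else:
            out[key] = torch.stack(values)   # same shape across the batch
    return out

# Two fake samples: same image size, but 1 vs. 2 masks
batch = [
    {"pixel_values": torch.zeros(3, 8, 8), "mask_labels": torch.zeros(1, 8, 8)},
    {"pixel_values": torch.zeros(3, 8, 8), "mask_labels": torch.zeros(2, 8, 8)},
]
out = collate_keep_lists(batch)
print(out["pixel_values"].shape)                  # torch.Size([2, 3, 8, 8])
print([m.shape[0] for m in out["mask_labels"]])   # [1, 2]
```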

### Step 3: Split and "Flatten" the Dataset


#### Filter by split tags

[`match_tags()`](https://docs.voxel51.com/api/fiftyone.core.collections.html#fiftyone.core.collections.SampleCollection.match_tags) returns a view containing only samples with that tag. Your dataset should already have "train"/"val" tags on each sample.

In [None]:
train_view = dataset.match_tags("train")
val_view = dataset.match_tags("val")

print(f"Samples - train: {len(train_view)}, val: {len(val_view)}")

`to_patches(field)` creates a view where each detection in that field becomes its own sample. If you have 100 images with 5 detections each, `to_patches` gives you 500 patch samples. This is perfect for instance-level training like SAM.
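
In plain Python, the flattening that `to_patches()` performs looks roughly like the list comprehension below. The image names and labels are hypothetical. A real patches view carries the full `Detection` object and a reference to its source sample, not just a label.

```python
# Hypothetical per-image detections: image filepath -> list of labels
images = {
    "img_000.jpg": ["car", "car", "person"],
    "img_001.jpg": ["bus"],
    "img_002.jpg": [],
}

# One "patch" per detection, each remembering which image it came from
patches = [
    {"filepath": path, "label": label}
    for path, labels in images.items()
    for label in labels
]
print(len(patches))  # 4
```

Images without detections, like `img_002.jpg` here, contribute no patches.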

In [None]:
train_patches = train_view.to_patches("ground_truth")
val_patches = val_view.to_patches("ground_truth")

print(f"Patches - train: {len(train_patches)}, val: {len(val_patches)}")

In the patches view, each sample's detection data lives in the original field (e.g., `"ground_truth"`). `field_mapping` lets us access it with a generic name in our `GetItem` code.

This makes `SAMPatchGetItem` reusable: it always reads `d["detection"]`, and `field_mapping` tells FiftyOne which actual field that refers to.

In [None]:
field_mapping = {"detection": "ground_truth"}

# we need to instantiate the SAM 3 processor
from transformers.models.sam3 import Sam3Processor

processor = Sam3Processor.from_pretrained("facebook/sam3")

train_getter = SAMPatchGetItem(processor, field_mapping=field_mapping)
val_getter = SAMPatchGetItem(processor, field_mapping=field_mapping)

### Step 4: Create DataLoaders

`to_torch()` converts a FiftyOne view to a PyTorch Dataset using your `GetItem` class to define how each sample is loaded and transformed.


In [None]:
train_dataset = train_patches.to_torch(train_getter)

val_dataset = val_patches.to_torch(val_getter)


Now we can instantiate standard PyTorch DataLoaders with our custom collate function.

In [None]:
# adjust based on your resources
batch_size = 8
num_workers = 0

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
)


Let's take a look at what the data loader yields.

In [None]:
batch = next(iter(train_dataloader))
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
    elif isinstance(v, list):
        print(k, f"list of {len(v)} items")
        if len(v) > 0 and isinstance(v[0], torch.Tensor):
            print(v[0].shape)
    else:
        print(k, type(v))

# SAM 3 Fine-tuning, Part 2: Optimizer, Loss Function, and Training Loop

Let's define the learning rate, number of training epochs, and instantiate the model. Note that we have already instantiated the processor before, but for completeness it is instantiated again:

In [None]:
lr = 1e-5
num_epochs = 1
log_every = 200

from transformers.models.sam3 import Sam3Model, Sam3Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Sam3Processor.from_pretrained("facebook/sam3")
model = Sam3Model.from_pretrained("facebook/sam3")
model.to(device)

We can now define the optimizer and the loss function.


In [12]:
from torch.optim import Adam

optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=0)

### We'll implement a simple Hungarian matcher and loss for SAM 3. 

The [**Hungarian algorithm** (also called the Kuhn-Munkres algorithm)](https://cp-algorithms.com/graph/hungarian-algorithm.html) solves the **optimal assignment problem**: given N workers and M jobs with different costs for each worker-job pairing, find the assignment that minimizes total cost.
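
On a tiny cost matrix, you can verify the optimal assignment by brute force: try every one-to-one pairing and keep the cheapest. This takes O(n!) time, which is why a polynomial-time solver such as SciPy's `linear_sum_assignment` is used in practice. The matrix values and the `brute_force_assignment` helper are illustrative only.

```python
import itertools
import numpy as np

# cost[i, j] = cost of assigning worker i to job j
cost = np.array([
    [4, 1, 3],
    [2, 0, 5],
    [3, 2, 2],
])

def brute_force_assignment(cost):
    """Try every permutation; only feasible for tiny square matrices."""
    n = cost.shape[0]
    best_perm, best_total = None, float("inf")
    for perm in itertools.permutations(range(n)):
        total = sum(cost[i, perm[i]] for i in range(n))
        if total < best_total:
            best_perm, best_total = perm, total
    return best_perm, best_total

perm, total = brute_force_assignment(cost)
print(perm, total)  # (1, 0, 2) 5
```

Worker 0 takes job 1, worker 1 takes job 0, and worker 2 takes job 2, for a total cost of 5. For this matrix, `linear_sum_assignment(cost)` returns the same pairing.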

##### In the Context of Segmentation

When SAM 3 predicts **Q masks** but you have **T ground truth masks**, you need to figure out *which prediction corresponds to which target* before you can compute loss. This is the matching problem.

**What the matcher does:**

1. **Computes a cost matrix**: for every (prediction, target) pair, it calculates `1 - IoU` (so lower cost = better overlap)

2. **Finds optimal 1-to-1 assignment**: `linear_sum_assignment` finds the matching that minimizes total cost across all pairs

3. **Returns indices**: `(row_ind, col_ind)` tells you that prediction `row_ind[k]` is matched to target `col_ind[k]`
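
The three steps can be traced on toy data. The sketch below uses hypothetical 4x4 binary masks (3 predictions, 2 targets) and hard masks rather than the sigmoid probabilities the matcher below works with. It computes intersection with a matrix product in the same way as the matcher. It needs SciPy, which the matcher cell imports anyway.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def mask_iou_matrix(pred, target):
    """IoU between every pair of binary masks; pred [Q,H,W], target [T,H,W]."""
    p = pred.reshape(len(pred), -1).astype(float)      # [Q, H*W]
    t = target.reshape(len(target), -1).astype(float)  # [T, H*W]
    inter = p @ t.T                                    # [Q, T] overlap counts
    union = p.sum(1)[:, None] + t.sum(1)[None, :] - inter
    return inter / np.maximum(union, 1e-6)

targets = np.zeros((2, 4, 4))
targets[0, :2, :2] = 1   # target A: top-left 2x2 block
targets[1, 2:, 2:] = 1   # target B: bottom-right 2x2 block

preds = np.zeros((3, 4, 4))
preds[0, :2, :2] = 1     # exact copy of A
preds[1, 2:, 1:] = 1     # covers B plus one extra column (IoU 4/6)
preds[2, :2, :] = 1      # whole top half, overlaps A (IoU 4/8)

cost = 1 - mask_iou_matrix(preds, targets)   # [3, 2]
row_ind, col_ind = linear_sum_assignment(cost)
print(np.round(cost, 3))
print(row_ind, col_ind)
```

Prediction 0 is paired with target A and prediction 1 with target B. Prediction 2 is left unmatched because there are more predictions than targets.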

##### Why We Do This

Without matching, the loss is ambiguous.

If SAM 3 outputs 5 masks and you have 3 ground truth objects, which predictions do you penalize?

Random assignment would create noisy gradients. The Hungarian matcher pairs each target with exactly one prediction and chooses the pairing that minimizes the total cost, so every matched prediction is supervised by a target it actually overlaps with.

##### The Loss Pipeline

```
Predictions (Q masks)  ──┐
                         ├──► Hungarian Match ──► Paired masks ──► Dice + BCE Loss
Targets (T masks)      ──┘
```

This pattern (Hungarian matching + set-based loss) was popularized by **DETR** and is standard in transformer-based detection/segmentation where you have unordered set predictions rather than anchor-based outputs.
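
For reference, one common way to write the "Dice + BCE on matched pairs" step is sketched below in PyTorch. It is an assumption-laden illustration, not the loss this notebook defines. The smoothing constant `eps`, the equal weighting of the two terms, and the helper names `dice_loss` and `matched_mask_loss` are all hypothetical, and the tensors are random placeholders. The functional module is imported as `nnf` so it does not shadow the notebook's existing `F` aliases.

```python
import torch
import torch.nn.functional as nnf

def dice_loss(logits, targets, eps=1.0):
    """Soft Dice loss averaged over masks; logits and targets are [N, H, W]."""
    probs = logits.sigmoid().flatten(1)
    targets = targets.flatten(1)
    inter = (probs * targets).sum(1)
    denom = probs.sum(1) + targets.sum(1)
    return (1 - (2 * inter + eps) / (denom + eps)).mean()

def matched_mask_loss(pred_masks, target_masks, row_ind, col_ind):
    """Dice + BCE computed only on the matched (prediction, target) pairs."""
    p = pred_masks[row_ind]              # [K, H, W] logits of matched predictions
    t = target_masks[col_ind].float()    # [K, H, W] their assigned targets
    bce = nnf.binary_cross_entropy_with_logits(p, t)
    return dice_loss(p, t) + bce

torch.manual_seed(0)
pred_masks = torch.randn(5, 16, 16)           # Q = 5 predicted mask logits
target_masks = torch.rand(3, 16, 16) > 0.5    # T = 3 ground-truth masks
row_ind = torch.tensor([4, 0, 2])             # placeholder matcher output
col_ind = torch.tensor([0, 1, 2])
loss = matched_mask_loss(pred_masks, target_masks, row_ind, col_ind)
print(loss.item())
```

Dice rewards overlap regardless of object size, which helps with tiny targets. BCE gives a per-pixel signal that is well behaved early in training.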

In [13]:
from scipy.optimize import linear_sum_assignment
import torch.nn as nn
import torch


@torch.no_grad()
def hungarian_matcher(pred_masks, target_masks):
    """Match predictions to targets using Hungarian assignment on mask IoU.

    pred_masks: [Q, Hp, Wp] logits
    target_masks: [T, Ht, Wt] binary/float

    SAM3 often outputs masks at a lower internal resolution than `pixel_values`.
    We resize targets to (Hp, Wp) before matching.
    """
    # Edge case: no ground truth masks to match against
    if target_masks.shape[0] == 0:
        return (
            torch.tensor([], dtype=torch.int64, device=pred_masks.device),
            torch.tensor([], dtype=torch.int64, device=pred_masks.device),
        )

    # Resize targets to match prediction resolution (SAM3 often outputs at lower res)
    if pred_masks.shape[-2:] != target_masks.shape[-2:]:
        tgt = target_masks[:, None, ...].float()  # [T,1,Ht,Wt] - add channel dim for interpolate
        tgt = nn.functional.interpolate(tgt, size=pred_masks.shape[-2:], mode="nearest")
        target_masks = tgt[:, 0, ...]  # [T,Hp,Wp] - remove channel dim

    # Convert to probabilities and flatten spatial dims for efficient matrix ops
    pred_sigmoid = pred_masks.float().sigmoid().flatten(1)  # [Q, Hp*Wp] - fp32 so the matmul and .numpy() also work if the model runs in bf16/fp16
    target_flat = target_masks.float().flatten(1)   # [T, Hp*Wp] - T targets flattened

    # ---- Compute IoU between all (prediction, target) pairs ----
    # Matrix multiply gives us sum(pred * target) for each pair = intersection
    intersection = torch.matmul(pred_sigmoid, target_flat.T)  # [Q, T]
    pred_area = pred_sigmoid.sum(1, keepdim=True)             # [Q, 1] - area of each pred
    target_area = target_flat.sum(1, keepdim=True).T          # [1, T] - area of each target
    # Union = A + B - intersection (broadcasting gives [Q, T] matrix)
    union = (pred_area + target_area - intersection).clamp_min(1e-6)

    # IoU matrix: entry [i,j] = IoU between prediction i and target j
    iou = intersection / union
    # Cost = 1 - IoU (lower cost = better match, Hungarian minimizes cost)
    cost = (1 - iou).detach().cpu().numpy()

    # ---- Hungarian algorithm finds optimal 1-to-1 assignment ----
    # Returns indices: prediction row_ind[k] matches target col_ind[k]
    row_ind, col_ind = linear_sum_assignment(cost)
    return (
        torch.as_tensor(row_ind, dtype=torch.int64, device=pred_masks.device),
        torch.as_tensor(col_ind, dtype=torch.int64, device=pred_masks.device),
    )


def compute_loss(pred_masks, target_masks):
    """Dice + BCE on matched masks.

    pred_masks: [N, Hp, Wp] logits (N = number of matched pairs)
    target_masks: [N, Ht, Wt] binary {0,1}
    """
    # No predictions to supervise - return zero loss with grad enabled
    if pred_masks.numel() == 0:
        return torch.tensor(0.0, device=pred_masks.device, requires_grad=True)

    # Align target resolution to prediction resolution if needed
    if pred_masks.shape[-2:] != target_masks.shape[-2:]:
        tgt = target_masks[:, None, ...].float()  # [N,1,Ht,Wt]
        tgt = nn.functional.interpolate(tgt, size=pred_masks.shape[-2:], mode="nearest")
        target_masks = tgt[:, 0, ...]

    # Flatten spatial dimensions: each mask becomes a 1D vector
    pred_flat = pred_masks.flatten(1)              # [N, Hp*Wp] - raw logits
    target_flat = target_masks.float().flatten(1)  # [N, Hp*Wp] - binary targets

    # ---- Dice Loss ----
    # Measures overlap; works well for imbalanced masks (few fg pixels)
    pred_sigmoid = pred_flat.sigmoid()  # Convert logits to probabilities
    # Dice = 2 * |A ∩ B| / (|A| + |B|), so loss = 1 - Dice
    numerator = 2 * (pred_sigmoid * target_flat).sum(1)   # 2 * intersection per mask
    denominator = pred_sigmoid.sum(1) + target_flat.sum(1)  # sum of areas
    dice_loss = 1 - (numerator + 1) / (denominator + 1)  # +1 smoothing avoids 0/0 when both masks are empty
    dice_loss = dice_loss.mean()  # Average over all matched pairs

    # ---- BCE Loss ----
    # Per-pixel classification loss; complements Dice with pixel-level gradients
    bce_loss = nn.functional.binary_cross_entropy_with_logits(
        pred_flat, target_flat, reduction="mean"
    )

    # Combined loss: Dice handles global overlap, BCE handles pixel accuracy
    return dice_loss + bce_loss
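
To see why Dice and BCE are combined, here is a NumPy mirror of `compute_loss` evaluated on a 100-pixel mask with only 5 foreground pixels. The function name `dice_bce` and the toy mask are illustrative assumptions. The BCE term uses the standard numerically stable form of binary cross-entropy with logits.

```python
import numpy as np


def dice_bce(pred_logits, targets):
    """Hypothetical NumPy mirror of compute_loss: mean smoothed Dice loss and mean BCE-with-logits."""
    p = 1.0 / (1.0 + np.exp(-pred_logits))
    numerator = 2 * (p * targets).sum(1)
    denominator = p.sum(1) + targets.sum(1)
    dice = (1 - (numerator + 1) / (denominator + 1)).mean()
    # Stable BCE with logits: max(x, 0) - x * y + log(1 + exp(-|x|))
    bce = (np.maximum(pred_logits, 0) - pred_logits * targets
           + np.log1p(np.exp(-np.abs(pred_logits)))).mean()
    return dice, bce


target = np.zeros((1, 100))
target[0, :5] = 1.0                        # 5 foreground pixels out of 100

all_background = np.full((1, 100), -4.0)   # predicts "nothing here" everywhere
good = np.where(target > 0, 4.0, -4.0)     # confident and correct

dice_bg, bce_bg = dice_bce(all_background, target)
dice_good, bce_good = dice_bce(good, target)
print(f"all background: dice={dice_bg:.3f} bce={bce_bg:.3f}")
print(f"correct:        dice={dice_good:.3f} bce={bce_good:.3f}")
```

The all-background guess gets a fairly small BCE (about 0.22) because 95% of its pixels are right, but its Dice loss stays high (about 0.85). The correct prediction drops to roughly 0.14 Dice and 0.02 BCE. Dice punishes missing the few foreground pixels, while BCE supplies dense per-pixel gradients.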

### Evaluation Loop

This function measures how well the model is performing on held-out data **without updating weights**. For each batch:

1. **Forward pass** — Run images + text prompts through SAM3 to get predicted masks

2. **Match predictions to ground truth** — Use Hungarian matching to find the one-to-one pairing of predictions and targets with the lowest total `1 - IoU` cost

3. **Compute loss on matched pairs** — Calculate Dice + BCE loss to quantify prediction quality

4. **Accumulate** — Average the per-image losses over every image that has at least one ground-truth mask (across up to `max_batches` batches) to get a single validation metric

The `@torch.no_grad()` decorator disables gradient computation since we're only measuring performance, not training. This saves memory and speeds things up.

Evaluation mirrors training logic (forward pass → match → loss) but skips the backward pass. This gives you a comparable metric to your training loss, letting you detect overfitting (training loss drops but val loss plateaus/rises).
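
Spotting a plateau in the validation loss can also be automated. Below is a hypothetical helper, `val_loss_stalled`. It implements a simple patience rule (a swapped-in heuristic, not something this notebook defines) that flags when none of the most recent validation losses improves on the best earlier value.

```python
def val_loss_stalled(val_history, patience=3, min_delta=0.0):
    """Hypothetical patience check: True if the last `patience` val losses
    never beat the best earlier one by at least `min_delta`."""
    if len(val_history) <= patience:
        return False  # not enough history to judge yet
    best_before = min(val_history[:-patience])
    return min(val_history[-patience:]) > best_before - min_delta


# Validation losses recorded every `log_every` steps (made-up numbers)
val_losses = [0.95, 0.80, 0.72, 0.74, 0.75, 0.78]
print(val_loss_stalled(val_losses, patience=3))
```

Here the best value (0.72) came before the last three evaluations, so the helper reports a stall. If the training loss is still dropping at that point, the model is likely starting to overfit.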

In [14]:
@torch.no_grad()
def evaluate(dataloader, max_batches=20):
    model.eval()

    total = 0.0
    n = 0

    for b_idx, batch in enumerate(dataloader):
        if max_batches is not None and b_idx >= max_batches:
            break

        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch.get("attention_mask", None).to(device) if batch.get("attention_mask", None) is not None else None,
        )

        predicted_masks = outputs.pred_masks  # [B,Q,H,W]
        mask_labels = [m.to(device) for m in batch["mask_labels"]]  # list([T,H,W])

        for i in range(len(mask_labels)):
            src_idx, tgt_idx = hungarian_matcher(predicted_masks[i], mask_labels[i])
            if len(src_idx) > 0:
                loss = compute_loss(predicted_masks[i][src_idx], mask_labels[i][tgt_idx])
                total += float(loss.item())
                n += 1

    model.train()
    return (total / max(n, 1))


We need to visualize predictions during training so we can monitor how the model improves. The helpers below use matplotlib to overlay the ground-truth mask and the best-matching predicted mask on the input image.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

def show_mask(mask, ax, color=None, alpha=0.6):
    """Overlay a binary mask on a matplotlib axis."""
    if color is None:
        color = np.array([0.12, 0.56, 1.0])
    mask = mask.detach().float().cpu().numpy() if isinstance(mask, torch.Tensor) else np.asarray(mask)
    if mask.ndim == 3:
        mask = mask[0]
    h, w = mask.shape[-2:]
    rgba = np.concatenate([np.asarray(color).reshape(3), [alpha]])
    ax.imshow(mask.reshape(h, w, 1) * rgba.reshape(1, 1, 4))


def denorm_pixel_values(pixel_values, processor):
    """Denormalize for visualization."""
    x = pixel_values.detach().float().cpu()
    if x.ndim == 4:
        x = x[0]

    mean = None
    std = None
    ip = getattr(processor, "image_processor", None)
    if ip is not None:
        # Normalization stats used by the processor; needed to undo (x - mean) / std
        mean = getattr(ip, "image_mean", None)
        std = getattr(ip, "image_std", None)

    if mean is not None and std is not None:
        mean = torch.tensor(mean).view(3, 1, 1)
        std = torch.tensor(std).view(3, 1, 1)
        x = x * std + mean

    x = x.clamp(0, 1)
    return x.permute(1, 2, 0).numpy()


def _mask_iou(pred_bin: torch.Tensor, tgt_bin: torch.Tensor) -> torch.Tensor:
    pred = pred_bin.bool()
    tgt = tgt_bin.bool()
    inter = (pred & tgt).sum().float()
    union = (pred | tgt).sum().float().clamp_min(1.0)
    return inter / union


@torch.no_grad()
def visualize_batch(batch, step, processor, model, max_items=2, thresh=0.5, title_prefix=""):
    """Visualize input + GT + best predicted mask.

    We use raw `outputs.pred_masks` (mask queries) to keep a clean 1-to-1 matching
    story (query -> target). `processor.post_process_instance_segmentation(...)`
    is great for inference-style results, but it can filter/merge/re-rank masks,
    which breaks straightforward query-wise matching.

    SAM3 `pred_masks` are typically lower-res than `pixel_values`, so we:
    - downsample GT for IoU selection
    - upsample the chosen prediction for overlay
    """
    model.eval()

    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch.get("attention_mask", None).to(device) if batch.get("attention_mask", None) is not None else None,
    )
    pred_masks = outputs.pred_masks
    bsz = pred_masks.shape[0]
    n = min(max_items, bsz)
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n))
    if n == 1:
        axes = axes.reshape(1, -1)

    for i in range(n):
        pv = batch["pixel_values"][i]
        img = denorm_pixel_values(pv, processor)
        hm, wm = pv.shape[-2], pv.shape[-1]

        tgt = batch["mask_labels"][i].to(pred_masks.device)
        tgt1 = tgt[0]

        pm = pred_masks[i]
        hp, wp = pm.shape[-2], pm.shape[-1]

        # Downsample GT to pred resolution for fair IoU
        tgt_ds = F.interpolate(tgt1[None, None, ...].float(), size=(hp, wp), mode="nearest")[0, 0]

        best_j = 0
        best_iou = -1.0
        for j in range(pm.shape[0]):
            iou = _mask_iou(pm[j].sigmoid() > thresh, tgt_ds > 0.5).item()
            if iou > best_iou:
                best_iou = iou
                best_j = j

        # Upsample chosen pred to pixel_values resolution for overlay
        pred_up = F.interpolate(pm[best_j][None, None, ...].sigmoid(), size=(hm, wm), mode="bilinear", align_corners=False)[0, 0]
        pred_bin = (pred_up > thresh).float().detach().cpu()

        axes[i, 0].imshow(img)
        axes[i, 0].set_title("Input")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(img)
        show_mask(tgt1.detach().cpu(), axes[i, 1], color=np.array([0.0, 1.0, 0.0]))
        axes[i, 1].set_title("GT")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(img)
        show_mask(pred_bin, axes[i, 2], color=np.array([1.0, 0.0, 0.0]))
        axes[i, 2].set_title(f"Pred (best IoU={best_iou:.3f})")
        axes[i, 2].axis("off")

    plt.suptitle(f"{title_prefix}Step {step}".strip())
    plt.tight_layout()
    plt.show()

    model.train()
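
`show_mask` relies on NumPy broadcasting. An `[H, W, 1]` mask multiplied by a `[1, 1, 4]` RGBA color yields an `[H, W, 4]` image whose alpha is `alpha` on mask pixels and `0` elsewhere, so `imshow` tints only the masked region and leaves the rest of the image visible. The standalone sketch below reproduces that step without matplotlib. `mask_to_rgba` is a hypothetical name used only for illustration.

```python
import numpy as np


def mask_to_rgba(mask, color=(0.12, 0.56, 1.0), alpha=0.6):
    """Hypothetical helper: same broadcasting trick as show_mask."""
    mask = np.asarray(mask, dtype=float)
    h, w = mask.shape[-2:]
    rgba = np.concatenate([np.asarray(color, dtype=float).reshape(3), [alpha]])
    return mask.reshape(h, w, 1) * rgba.reshape(1, 1, 4)


overlay = mask_to_rgba(np.array([[0, 1], [1, 0]]))
print(overlay.shape)     # (2, 2, 4)
print(overlay[..., 3])   # alpha channel: 0.6 on mask pixels, 0.0 elsewhere
```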


We can now write the training loop and train the model, tracking experiments with [`trackio`](https://huggingface.co/docs/trackio/en/index).

`trackio` is a lightweight, free experiment-tracking Python library built on top of Hugging Face Datasets and Spaces.

In [None]:
from tqdm import tqdm
from statistics import mean
import trackio

trackio.init(project="sam3-rislad")
_trackio_enabled = True  # the training loop checks this flag before calling trackio.log

model.train()
step = 0

## The Training Loop

Each epoch processes the full training set, and each batch update follows the standard deep learning recipe:

1. **Forward pass** — Feed images and text prompts through SAM3 to get mask predictions

2. **Hungarian matching** — For each image, optimally pair predictions with ground truth masks

3. **Loss computation** — Calculate Dice + BCE loss on matched pairs, averaged across images in the batch

4. **Backward pass** — Compute gradients of the loss with respect to model parameters

5. **Optimizer step** — Update weights in the direction that reduces loss

The loop also includes periodic validation (every `log_every` steps) to monitor generalization, plus visualization to qualitatively inspect predictions. A few subtle points worth noting:

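- **`total_loss = total_loss + loss_i`** — Out-of-place addition avoids modifying a tensor of the autograd graph in place. A plain `+=` would also work here, because the first addition turns the Python float `0.0` into a tensor. What would actually break backprop is accumulating `loss_i.item()`, which returns a Python float detached from the graph.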

- **`set_to_none=True`** — Sets `.grad` to `None` instead of filling it with zeros, which saves a memory write; it has been the default for `zero_grad` since PyTorch 2.0.

- **Visualizing both train and val** — Lets you visually compare whether the model generalizes or just memorizes training examples.
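
The accumulation point can be checked directly in PyTorch. In this minimal sketch, a single scalar parameter `w` stands in for the model weights and `make_losses` is a hypothetical stand-in for the per-image `loss_i` tensors. Both tensor-accumulation styles give the same gradient, while summing `.item()` values leaves only a plain float with nothing to backpropagate through.

```python
import torch


def make_losses(w):
    # Hypothetical stand-ins for the per-image loss_i tensors
    return [w * 1.0, w * 3.0]


w = torch.tensor(2.0, requires_grad=True)

# 1) Out-of-place accumulation, as written in the training loop
total = 0.0
for loss_i in make_losses(w):
    total = total + loss_i
total.backward()
grad_out_of_place = w.grad.clone()

# 2) `+=` starting from the float 0.0: the first step rebinds `total` to a new tensor,
#    later steps add in place to that tensor, and gradients still reach w
w.grad = None
total = 0.0
for loss_i in make_losses(w):
    total += loss_i
total.backward()
grad_in_place = w.grad.clone()

# 3) Accumulating .item(): a plain Python float with no autograd graph behind it
total_float = 0.0
for loss_i in make_losses(w):
    total_float += loss_i.item()

print(grad_out_of_place, grad_in_place, type(total_float))
```

Both gradients come out as 4 (the sum of the coefficients 1 and 3), whereas `total_float` has no `.backward()` method at all.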


In [None]:
for epoch in range(num_epochs):
    epoch_losses = []  # Track losses for this epoch (for logging/averaging)

    for batch in tqdm(train_dataloader, desc=f"epoch {epoch}"):
        # ---- Forward pass: generate mask predictions ----
        outputs = model(
            pixel_values=batch["pixel_values"].to(device),      # Images [B, C, H, W]
            input_ids=batch["input_ids"].to(device),            # Tokenized text prompts
            attention_mask=batch.get("attention_mask", None).to(device) 
                if batch.get("attention_mask", None) is not None else None,
        )

        predicted_masks = outputs.pred_masks  # [B, Q, H, W] - Q candidate masks per image
        mask_labels = [m.to(device) for m in batch["mask_labels"]]  # List of [T, H, W] targets

        # ---- Compute loss for each image in the batch ----
        total_loss = 0.0
        num_images = 0
        for i in range(len(mask_labels)):
            # Find optimal 1-to-1 assignment between predictions and targets
            src_idx, tgt_idx = hungarian_matcher(predicted_masks[i], mask_labels[i])
            
            if len(src_idx) > 0:  # Only compute loss if we have valid matches
                # Index matched pairs and compute Dice + BCE loss
                loss_i = compute_loss(predicted_masks[i][src_idx], mask_labels[i][tgt_idx])
                total_loss = total_loss + loss_i  # Accumulate (keeps gradient graph intact)
                num_images += 1

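        # Skip the update if no image in this batch had a ground-truth mask to match:
        # total_loss would still be the float 0.0, and calling .backward() on it fails
        if num_images == 0:
            continue

        # Average loss across the images that contributed a matched loss
        loss = total_loss / num_images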

        # ---- Backward pass + weight update ----
        optimizer.zero_grad(set_to_none=True)  # Clear old gradients (set_to_none is slightly faster)
        loss.backward()                         # Compute gradients via backpropagation
        optimizer.step()                        # Update model weights

        # ---- Logging ----
        epoch_losses.append(float(loss.item()))
        step += 1

        # Optional: log to trackio if enabled
        if globals().get("_trackio_enabled", False):
            trackio.log({"train/loss": float(loss.item()), "epoch": int(epoch), "step": int(step)})

        # ---- Periodic validation + visualization ----
        if step % log_every == 0:
            # Run evaluation on validation set (limited batches for speed)
            eval_loss = evaluate(val_dataloader, max_batches=10)
            print(f"step {step} | train_loss={mean(epoch_losses[-log_every:]):.4f} | val_loss={eval_loss:.4f}")

            if globals().get("_trackio_enabled", False):
                trackio.log({"val/loss": float(eval_loss), "epoch": int(epoch), "step": int(step)})

            # Visualize predictions on training and validation samples
            # Helps catch issues like mode collapse, poor localization, etc.
            visualize_batch(batch, step, processor, model, max_items=2, title_prefix="[train] ")
            viz_batch = next(iter(val_dataloader))
            visualize_batch(viz_batch, step, processor, model, max_items=2, title_prefix="[val] ")

    print(f"Epoch {epoch} done | mean train loss: {mean(epoch_losses):.4f}")