In [None]:
!pip install fiftyone

In [None]:
# # already have this dataset locally
# import fiftyone as fo

# dataset = fo.load_dataset("ris-lad")

You are running the oldest supported major version of MongoDB. Please refer to https://deprecation.voxel51.com for deprecation notices. You can suppress this exception by setting your `database_validation` config parameter to `False`. See https://docs.voxel51.com/user_guide/config.html#configuring-a-mongodb-connection for more information


In [None]:
import fiftyone as fo
from fiftyone.utils.huggingface import load_from_hub

# Hugging Face Hub repo id of the dataset used throughout this notebook.
HUB_REPO_ID = "Voxel51/RIS-LAD"

# Fetch the dataset from the Hub and bind it as a FiftyOne dataset.
dataset = load_from_hub(HUB_REPO_ID)

## SAM Fine-tuning Dataset from FiftyOne

This notebook converts a FiftyOne dataset with Detection masks into a PyTorch 
dataset compatible with SAM fine-tuning.

Key insight: FiftyOne's `to_patches()` method creates a view where each 
detection becomes its own sample. This eliminates the need to manually 
flatten detections - FiftyOne handles it for us.

The pipeline:
1. Convert dataset to patches view (one sample per detection)
2. Define a `GetItem` to extract and transform each patch
3. Use `to_torch()` to create the PyTorch dataset


### Step 1: Define a GetItem subclass

FiftyOne's `GetItem` class is the bridge between FiftyOne and PyTorch. It tells FiftyOne:
 1. What fields to extract from each sample (via `required_keys`)
 2. How to transform them into your desired format (via `__call__`)

The `field_mapping` parameter is important when working with patches. In a patches view, the detection data lives in the original field name (e.g., `ground_truth`), but we want to access it with a generic name in our code.

`field_mapping={"detection": "ground_truth"}` means:
 - In our code, we write `d.get("detection")` 
 - FiftyOne knows to pull from the `ground_truth` field

This makes our `GetItem` reusable across datasets with different field names.


In [None]:
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
import torch
from fiftyone.utils.torch import GetItem


class SAMPatchGetItem(GetItem):
    """
    Extracts and transforms patch data for SAM training.
    
    Each patch sample contains:
    - filepath: path to the full image
    - detection: the Detection object (bbox, mask, label, etc.)
    - metadata: image dimensions
    
    We transform this into SAM's expected format:
    - pixel_values: processed image tensor
    - input_boxes: bbox in absolute pixel coords
    - ground_truth_mask: full-image binary mask
    - referring_expression: text prompt for the object

    Args:
        processor: callable invoked as
            ``processor(image, input_boxes=[[bbox]], return_tensors="pt")``
            that returns a mapping of tensors with a leading batch dimension
        field_mapping: optional dict forwarded to ``GetItem.__init__`` that
            maps the keys used here (e.g. ``"detection"``) to the dataset's
            real field names
    """
    
    def __init__(self, processor, field_mapping=None):
        self.processor = processor
        # Must call super().__init__() with field_mapping - this sets up
        # the internal mapping that FiftyOne uses to pull the right fields
        super().__init__(field_mapping=field_mapping)
    
    @property
    def required_keys(self):
        # These are the keys we'll access in __call__.
        # 'detection' is a virtual key that gets mapped to the real field
        # via field_mapping. 'filepath' and 'metadata' are standard fields
        # that exist on all FiftyOne samples.
        return ['filepath', 'detection', 'metadata']

    @staticmethod
    def _paste_mask(mask, x_start, y_start, height, width):
        """
        Place a bbox-cropped mask into a zeroed ``(height, width)`` uint8 canvas.

        The paste window is clipped to the canvas, so a mask that overhangs
        the image edge (e.g. because the bbox origin was truncated to an int)
        is cropped instead of raising a broadcast ``ValueError``.

        Args:
            mask: 2D array cropped to the detection's bounding box, or None
            x_start: column of the mask's top-left corner in image pixels
            y_start: row of the mask's top-left corner in image pixels
            height: canvas height in pixels
            width: canvas width in pixels

        Returns:
            ``np.uint8`` array of shape ``(height, width)``; all zeros when
            ``mask`` is None or lies entirely outside the canvas
        """
        full_mask = np.zeros((height, width), dtype=np.uint8)
        if mask is None:
            return full_mask

        mask_h, mask_w = mask.shape

        # Destination window, clamped to the canvas on all four sides
        y0, x0 = max(y_start, 0), max(x_start, 0)
        y1 = min(y_start + mask_h, height)
        x1 = min(x_start + mask_w, width)

        if y1 > y0 and x1 > x0:
            # Take the matching sub-window of the source mask so both sides
            # of the assignment always have identical shapes
            full_mask[y0:y1, x0:x1] = mask[
                y0 - y_start:y1 - y_start,
                x0 - x_start:x1 - x_start,
            ]
        return full_mask
    
    def __call__(self, d):
        """
        Transform a FiftyOne sample dict into SAM training format.
        
        This is where the FiftyOne → SAM conversion happens:
        - Relative bbox coords → absolute pixel coords
        - Cropped mask → full-image mask
        - Raw image → processed tensor

        Args:
            d: mapping with the keys listed in ``required_keys``

        Returns:
            dict of the processor outputs (batch dimension squeezed away)
            plus ``ground_truth_mask`` (uint8 array of shape ``(h, w)``) and
            ``referring_expression`` (the detection's attribute, or None)
        """
        # Load full image (patches still reference the original image file).
        # The context manager closes the underlying file handle as soon as
        # the pixels have been decoded by convert().
        with Image.open(d["filepath"]) as raw_image:
            image = raw_image.convert("RGB")
        detection = d.get("detection")
        metadata = d.get("metadata")
        
        # Prefer stored metadata, but fall back to the decoded image size
        # when metadata was never computed for this sample.
        if metadata is not None and metadata.width and metadata.height:
            w = metadata.width
            h = metadata.height
        else:
            w, h = image.size
        
        # --- Bounding Box Conversion ---
        # FiftyOne stores bboxes as [x, y, width, height] with values in [0, 1]
        # representing fractions of the image dimensions. This is great for
        # resolution-independent storage, but SAM needs absolute pixel coords
        # in [x_min, y_min, x_max, y_max] format.
        rx, ry, rw, rh = detection.bounding_box
        bbox = [
            int(rx * w),           # x_min in pixels
            int(ry * h),           # y_min in pixels
            int((rx + rw) * w),    # x_max in pixels
            int((ry + rh) * h)     # y_max in pixels
        ]
        
        # --- Mask Conversion ---
        # FiftyOne stores instance masks efficiently by cropping them to the
        # bounding box region. A mask for a 50x50 pixel object is stored as a
        # 50x50 array, not a full 1080x1080 array. This saves significant space.
        #
        # SAM expects full-image masks where the object pixels are in their
        # actual location. We "unpack" the cropped mask by placing it into
        # a full-size zero array at the correct position.
        full_mask = self._paste_mask(detection.mask, bbox[0], bbox[1], h, w)
        
        # --- SAM Processor ---
        # The HuggingFace SamProcessor handles all the image preprocessing:
        # - Resizes image to SAM's expected input size
        # - Normalizes pixel values
        # - Formats the bbox prompt correctly
        # We pass input_boxes as a nested list [[bbox]] because SAM supports
        # multiple prompts per image (we just have one).
        inputs = self.processor(image, input_boxes=[[bbox]], return_tensors="pt")
        
        # The processor adds a batch dimension (for single-image inference).
        # We remove it here since DataLoader will batch multiple samples later.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        
        # Add our ground truth mask for loss computation during training
        inputs["ground_truth_mask"] = full_mask
        
        # Add referring expression for text-prompted segmentation (e.g., SAM3)
        # getattr with default handles cases where the field doesn't exist
        inputs["referring_expression"] = getattr(detection, 'referring_expression', None)
        
        return inputs

### Step 2: Collate function for DataLoader


When PyTorch's DataLoader batches samples together, it needs to know how to combine them. The default collate works for simple tensors, but we have:
 - `ground_truth_mask`: might be different sizes if images have different dims
 - `referring_expression`: strings, not tensors

A custom collate function tells DataLoader exactly how to handle each field.

In [None]:
def collate_fn(batch):
    """
    Merge per-sample dicts into a single batch dict for SAM training.

    Every key of the first sample other than the two special ones is
    gathered across the batch; tensor columns are stacked along a new
    leading dimension, anything else is returned as a plain list.

    Args:
        batch: List of dicts from SAMPatchGetItem.__call__

    Returns:
        Dict containing, in order: the generic fields, then
        ``ground_truth_mask`` (a stacked float32 tensor, or a list of
        float32 tensors when the masks differ in size), then
        ``referring_expression`` (list with one entry per sample)
    """
    mask_key = "ground_truth_mask"
    text_key = "referring_expression"

    collated = {}

    # Generic fields (pixel_values, input_boxes, ...): stack when tensors
    for name in batch[0]:
        if name in (mask_key, text_key):
            continue
        column = [sample[name] for sample in batch]
        if isinstance(column[0], torch.Tensor):
            collated[name] = torch.stack(column)
        else:
            collated[name] = column

    # Masks arrive as arrays; convert each one to a float32 tensor
    mask_tensors = []
    for sample in batch:
        mask_tensors.append(torch.tensor(sample[mask_key], dtype=torch.float32))

    # Stacking only succeeds when every mask has the same shape; otherwise
    # hand back the per-sample list and let the training loop deal with it
    try:
        stacked_masks = torch.stack(mask_tensors)
    except RuntimeError:
        collated[mask_key] = mask_tensors
    else:
        collated[mask_key] = stacked_masks

    # Text prompts stay as an ordinary Python list
    collated[text_key] = [sample[text_key] for sample in batch]

    return collated

### Putting it all together: `create_dataloaders`

This function ties everything together. The key FiftyOne concepts:

1. Views: A "view" is a filtered/transformed lens on your dataset. The
   underlying data doesn't change - you're just looking at it differently.
   `match_tags("train")` gives you a view of only training samples.

2. Patches: `to_patches(field)` creates a view where each detection in that
   field becomes its own sample. If you have 100 images with 5 detections
   each, `to_patches` gives you 500 patch samples. This is perfect for
   instance-level training like SAM.

3. `to_torch()`: Converts a FiftyOne view to a PyTorch Dataset using your
   `GetItem` class to define how each sample is loaded and transformed.


In [5]:
def create_dataloaders(
    fo_dataset,
    processor,
    batch_size=2,
    detection_field="ground_truth",
    num_workers=0,
    train_tag="train",
    val_tag="val",
):
    """
    Create train and validation DataLoaders from a FiftyOne dataset.
    
    Uses FiftyOne's patches view to automatically flatten detections,
    then converts to PyTorch format.
    
    Args:
        fo_dataset: FiftyOne dataset with Detection masks
        processor: SamProcessor instance
        batch_size: Batch size for training
        detection_field: Name of the Detections field containing your masks
        num_workers: DataLoader workers (0 = main process only). When > 0,
            FiftyOne's worker initializer is attached to both loaders.
        train_tag: Tag identifying training samples
        val_tag: Tag identifying validation samples
    
    Returns:
        train_dataloader, val_dataloader: Ready for training loop
    """
    # --- Step 1: Filter by split tags ---
    # match_tags returns a view containing only samples with that tag.
    # Your dataset should already have "train"/"val" tags on each sample.
    train_view = fo_dataset.match_tags(train_tag)
    val_view = fo_dataset.match_tags(val_tag)
    
    print(f"Samples - train: {len(train_view)}, val: {len(val_view)}")
    
    # --- Step 2: Convert to patches ---
    # This is the key step that makes everything cleaner. to_patches() creates
    # a view where each detection becomes its own sample. 
    #
    # Before: 1 sample with 6 detections
    # After:  6 patch samples, each with 1 detection
    #
    # This means we don't need custom flattening logic - FiftyOne handles it.
    train_patches = train_view.to_patches(detection_field)
    val_patches = val_view.to_patches(detection_field)
    
    print(f"Patches - train: {len(train_patches)}, val: {len(val_patches)}")
    
    # --- Step 3: Set up field mapping ---
    # In the patches view, each sample's detection data lives in the original
    # field (e.g., "ground_truth"). field_mapping lets us access it with a
    # generic name in our GetItem code.
    #
    # This makes SAMPatchGetItem reusable - it always uses d.get("detection"),
    # and field_mapping tells FiftyOne which actual field that refers to.
    field_mapping = {"detection": detection_field}
    
    train_getter = SAMPatchGetItem(processor, field_mapping=field_mapping)
    val_getter = SAMPatchGetItem(processor, field_mapping=field_mapping)
    
    # --- Step 4: Convert to PyTorch datasets ---
    # to_torch() wraps the FiftyOne view in a PyTorch Dataset interface.
    # When you access dataset[i], it calls your GetItem to load and transform
    # that sample on-the-fly.
    train_dataset = train_patches.to_torch(train_getter)
    val_dataset = val_patches.to_torch(val_getter)
    
    # --- Step 5: Worker initialization ---
    # With worker subprocesses, each worker must run FiftyOne's initializer
    # before it reads samples. Only wired in when num_workers > 0, so the
    # default single-process path is unchanged.
    # NOTE(review): FiftyOne's docs also appear to require a persistent
    # dataset for multi-worker loading - confirm before raising num_workers.
    worker_init_fn = None
    if num_workers > 0:
        from fiftyone.utils.torch import FiftyOneTorchDataset

        worker_init_fn = FiftyOneTorchDataset.worker_init
    
    # --- Step 6: Create DataLoaders ---
    # Standard PyTorch DataLoaders with our custom collate function.
    # shuffle=True for training to randomize batch composition each epoch.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )
    
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )
    
    return train_dataloader, val_dataloader

In [6]:
from transformers.models.sam3 import Sam3Model, Sam3Processor

# Checkpoint id whose preprocessing configuration is loaded below.
SAM_CHECKPOINT = "facebook/sam3"

# The processor is handed to SAMPatchGetItem via create_dataloaders.
processor = Sam3Processor.from_pretrained(SAM_CHECKPOINT)

In [7]:
# Build both loaders; the remaining arguments keep their defaults
# (detection_field="ground_truth", train_tag="train", val_tag="val").
train_dataloader, val_dataloader = create_dataloaders(
    fo_dataset=dataset,
    processor=processor,
    batch_size=2,
)

Samples - train: 1682, val: 421
Patches - train: 11156, val: 2715


In [None]:
# Verify a batch: pull one batch from the training loader and summarize
# each field - tensors by shape, everything else by element count.
batch = next(iter(train_dataloader))
for field_name, field_value in batch.items():
    is_tensor = isinstance(field_value, torch.Tensor)
    print(
        f"{field_name}: {field_value.shape}"
        if is_tensor
        else f"{field_name}: list of {len(field_value)} items"
    )

pixel_values: torch.Size([2, 3, 1008, 1008])
original_sizes: torch.Size([2, 2])
input_ids: torch.Size([2, 32])
attention_mask: torch.Size([2, 32])
input_boxes: torch.Size([2, 1, 4])
ground_truth_mask: torch.Size([2, 1080, 1080])
referring_expression: list of 2 tensors
