In [None]:
!pip install fiftyone

In [1]:
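# # Alternative: if this dataset already exists in your local FiftyOne
# # database, load it directly instead of downloading it from the Hub
# import fiftyone as fo

# dataset = fo.load_dataset("ris-lad")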

You are running the oldest supported major version of MongoDB. Please refer to https://deprecation.voxel51.com for deprecation notices. You can suppress this exception by setting your `database_validation` config parameter to `False`. See https://docs.voxel51.com/user_guide/config.html#configuring-a-mongodb-connection for more information


In [None]:
import fiftyone as fo
from fiftyone.utils.huggingface import load_from_hub

# Reuse the local copy when it exists so the notebook can be re-run without
# downloading the dataset again; otherwise fetch it from the Hugging Face Hub
# and persist it under DATASET_NAME so the next run can reuse it.
DATASET_NAME = "ris-lad"

if fo.dataset_exists(DATASET_NAME):
    dataset = fo.load_dataset(DATASET_NAME)
else:
    dataset = load_from_hub("Voxel51/RIS-LAD", name=DATASET_NAME, persistent=True)

## SAM Fine-tuning Dataset from FiftyOne

This notebook converts a FiftyOne dataset with Detection masks into a PyTorch
dataset compatible with SAM fine-tuning.

The key challenge: FiftyOne stores masks efficiently by cropping them to the 
bounding box region. SAM expects full-image masks. We handle that conversion here.

The pipeline follows FiftyOne's recommended pattern:
1. Define a `GetItem` subclass to specify which fields to extract
2. Use `dataset.to_torch()` to create an intermediate torch dataset  
3. Wrap/flatten to handle multiple detections per sample


### Step 1: Define a GetItem subclass

- `GetItem` tells FiftyOne which fields to pull from each sample and how to transform them. This is FiftyOne's bridge to PyTorch.

- `required_keys`: list of field names to extract from each sample

- `__call__`: transforms the extracted fields into your desired format


In [2]:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from fiftyone.utils.torch import GetItem

class SAMDataGetter(GetItem):
    """
    Pull per-detection training records out of FiftyOne samples for SAM.

    Reads three sample fields -- ``filepath``, the detections field, and
    ``metadata`` -- and emits one record per detection that carries a mask.
    """

    def __init__(self, detection_field="ground_truth"):
        # detection_field is assigned before super().__init__(), which sets up
        # GetItem's internal field mapping (presumably from required_keys,
        # which depends on detection_field).
        self.detection_field = detection_field
        super().__init__()

    @property
    def required_keys(self):
        """Sample fields that FiftyOne should load for every sample."""
        return ["filepath", self.detection_field, "metadata"]

    def __call__(self, d):
        """
        Convert one FiftyOne sample dict into ``{"items": [record, ...]}``.

        Each record describes a single masked detection, because SAM is
        trained on one object mask at a time. Detections without a mask are
        dropped; a missing detections field, or one without a ``detections``
        attribute, yields an empty item list.
        """
        dets = d.get(self.detection_field)
        meta = d.get("metadata")

        if dets is None or not hasattr(dets, "detections"):
            return {"items": []}

        path = d.get("filepath")
        records = [
            {
                "filepath": path,
                "bounding_box": det.bounding_box,  # relative [x, y, w, h]
                "mask": det.mask,                  # relative to the bounding box
                "label": det.label,
                "width": meta.width if meta else None,
                "height": meta.height if meta else None,
                "referring_expression": getattr(det, "referring_expression", None),
            }
            for det in dets.detections
            if det.mask is not None
        ]
        return {"items": records}


### Step 2: Flatten and process for SAM

The `GetItem` above returns multiple detections per sample. We need to flatten this so each detection becomes its own training example.

We also handle the FiftyOne → SAM format conversion:
- FiftyOne bbox: `[x, y, width, height]` in relative `[0,1]` coordinates
- SAM bbox: `[x_min, y_min, x_max, y_max]` in absolute pixel coordinates
- FiftyOne mask: cropped to bounding box region
- SAM mask: full image size with object in correct location


In [3]:
class FlattenedSAMDataset(Dataset):
    """
    Flattens the FiftyOne torch dataset so each detection is a separate item,
    then processes each item into the format SAM expects.

    Args:
        fo_torch_dataset: iterable of ``{"items": [record, ...]}`` dicts (as
            produced by ``to_torch()`` with a per-detection GetItem). Each
            record needs the keys ``filepath``, ``bounding_box``, ``mask``,
            ``width``, ``height`` and ``referring_expression``.
        processor: callable invoked as
            ``processor(image, input_boxes=[[bbox]], return_tensors="pt")``;
            every value it returns must support ``.squeeze(0)``.
    """

    def __init__(self, fo_torch_dataset, processor):
        self.processor = processor
        self.items = []
        num_skipped = 0

        # Flatten: one training sample per detection. Records without image
        # dimensions cannot be converted to pixel coordinates, so drop them.
        for sample in fo_torch_dataset:
            for item in sample["items"]:
                if item["width"] and item["height"]:
                    self.items.append(item)
                else:
                    num_skipped += 1

        message = f"FlattenedSAMDataset created with {len(self.items)} items"
        if num_skipped:
            message += f" (skipped {num_skipped} without width/height metadata)"
        print(message)

    def __len__(self):
        return len(self.items)

    @staticmethod
    def _to_full_mask(mask, bbox, height, width):
        """
        Paste a bbox-relative instance mask into a (height, width) uint8 frame.

        The mask is binarized (nonzero -> 1) and, if its shape differs from
        the box's pixel size, resampled to it with nearest-neighbour indexing.
        The paste is clipped to the frame so boxes touching or crossing the
        image border do not raise a broadcasting error. Degenerate boxes
        (zero width/height) yield an all-zero mask.

        Args:
            mask: 2D array-like mask defined relative to the bounding box.
            bbox: ``[x_min, y_min, x_max, y_max]`` in integer pixel coords.
            height, width: frame size in pixels.

        Returns:
            ``np.ndarray`` of shape (height, width), dtype uint8, values 0/1.
        """
        full_mask = np.zeros((height, width), dtype=np.uint8)
        x_min, y_min, x_max, y_max = bbox
        box_w, box_h = x_max - x_min, y_max - y_min

        mask = np.asarray(mask) > 0
        if box_w <= 0 or box_h <= 0 or mask.size == 0:
            return full_mask

        mask_h, mask_w = mask.shape
        if (mask_h, mask_w) != (box_h, box_w):
            # Nearest-neighbour resample of the mask onto the box's pixel grid.
            rows = np.arange(box_h) * mask_h // box_h
            cols = np.arange(box_w) * mask_w // box_w
            mask = mask[rows[:, None], cols]

        # Intersect the box with the frame and copy the matching mask window.
        x0, y0 = max(x_min, 0), max(y_min, 0)
        x1, y1 = min(x_max, width), min(y_max, height)
        if x0 >= x1 or y0 >= y1:
            return full_mask
        full_mask[y0:y1, x0:x1] = mask[y0 - y_min:y1 - y_min, x0 - x_min:x1 - x_min]
        return full_mask

    def __getitem__(self, idx):
        item = self.items[idx]
        w, h = item["width"], item["height"]

        # Load image (context manager closes the underlying file handle)
        with Image.open(item["filepath"]) as img:
            image = img.convert("RGB")

        # --- Bounding Box Conversion ---
        # FiftyOne: [x, y, width, height] as fractions of image size (0 to 1)
        # SAM: [x_min, y_min, x_max, y_max] in absolute pixels
        rx, ry, rw, rh = item["bounding_box"]
        bbox = [
            int(rx * w),           # x_min
            int(ry * h),           # y_min
            int((rx + rw) * w),    # x_max
            int((ry + rh) * h)     # y_max
        ]

        # --- Mask Conversion ---
        # FiftyOne stores masks relative to the bounding box; SAM expects a
        # full-image mask, so expand it back to the frame.
        full_mask = self._to_full_mask(item["mask"], bbox, h, w)

        # --- SAM Processor ---
        # The processor handles image resizing, normalization, and
        # preparing the bbox prompt in the format SAM expects
        inputs = self.processor(image, input_boxes=[[bbox]], return_tensors="pt")

        # Remove batch dimension added by processor (we batch later)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Add our ground truth mask for loss computation
        inputs["ground_truth_mask"] = full_mask

        # Add referring expression for text-prompted segmentation
        inputs["referring_expression"] = item["referring_expression"]

        return inputs


### Step 3: Collate function for DataLoader

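When batching samples, the ground truth masks need special handling: they are full-image masks, so their sizes differ whenever the images' dimensions differ. They are stacked into a single tensor only when every mask in the batch has the same shape; otherwise they are returned as a list. Referring expressions are kept as a plain Python list.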

In [4]:
def collate_fn(batch):
    """
    Custom collate function for SAM training.

    Stacks the processor's tensor outputs along a new batch dimension and
    passes any other processor output through as a per-item list. Ground
    truth masks are converted to float32 tensors and stacked only when every
    mask in the batch has the same shape (images of different sizes give
    masks of different sizes); otherwise they are returned as a list that the
    training loop must handle. Referring expressions are kept as a list.

    Args:
        batch: list of dicts as returned by ``FlattenedSAMDataset.__getitem__``;
            every dict must have the same keys, including
            ``"ground_truth_mask"`` (array-like) and ``"referring_expression"``.

    Returns:
        dict mapping each key to a batched tensor or a per-item list.
    """
    special_keys = {"ground_truth_mask", "referring_expression"}
    result = {}

    # Processor outputs: stack tensors, keep anything else as a list.
    for key in batch[0]:
        if key in special_keys:
            continue
        values = [item[key] for item in batch]
        result[key] = torch.stack(values) if isinstance(values[0], torch.Tensor) else values

    # Check shapes explicitly instead of catching RuntimeError from
    # torch.stack, which would also hide unrelated failures.
    masks = [
        torch.as_tensor(item["ground_truth_mask"], dtype=torch.float32)
        for item in batch
    ]
    if all(mask.shape == masks[0].shape for mask in masks):
        result["ground_truth_mask"] = torch.stack(masks)
    else:
        # Variable sizes - training loop will need to handle this
        result["ground_truth_mask"] = masks

    result["referring_expression"] = [item["referring_expression"] for item in batch]
    return result


### Putting it all together: `create_dataloaders`

This function orchestrates the full pipeline:

1. Filter dataset by train/val tags

2. Create `GetItem` and call `to_torch()`

3. Wrap in `FlattenedSAMDataset`

4. Create `DataLoaders`

In [5]:
def create_dataloaders(
    fo_dataset,
    processor,
    batch_size=2,
    detection_field="ground_truth",
    num_workers=0,
    train_tag="train",
    val_tag="val",
):
    """
    Create train and validation DataLoaders from a FiftyOne dataset.

    Args:
        fo_dataset: FiftyOne dataset with Detection masks
        processor: SamProcessor instance
        batch_size: Batch size for training
        detection_field: Name of the Detections field
        num_workers: DataLoader workers
        train_tag: Tag identifying training samples
        val_tag: Tag identifying validation samples

    Returns:
        train_dataloader, val_dataloader

    Raises:
        ValueError: if the training split yields no usable detections (the
            tag matches nothing, no detection has a mask, or samples lack
            width/height metadata).
    """
    # Filter to train/val using existing tags on samples
    train_view = fo_dataset.match_tags(train_tag)
    val_view = fo_dataset.match_tags(val_tag)

    print(f"Using existing tags: {len(train_view)} train samples, {len(val_view)} val samples")

    # A sample carrying both tags would be used for training AND validation.
    num_overlap = len(train_view.match_tags(val_tag))
    if num_overlap:
        print(
            f"Warning: {num_overlap} samples are tagged both '{train_tag}' and "
            f"'{val_tag}'; validation metrics will be optimistic"
        )

    # Step 1 & 2: Use GetItem pattern with to_torch()
    # This creates a torch-compatible dataset that lazily loads FiftyOne data
    data_getter = SAMDataGetter(detection_field=detection_field)
    train_torch_dataset = train_view.to_torch(data_getter)
    val_torch_dataset = val_view.to_torch(data_getter)

    print(f"Intermediate datasets: train={len(train_torch_dataset)}, val={len(val_torch_dataset)}")

    # Step 3: Flatten (one item per detection) and process for SAM
    train_dataset = FlattenedSAMDataset(train_torch_dataset, processor)
    val_dataset = FlattenedSAMDataset(val_torch_dataset, processor)

    # Fail early with an actionable message instead of the sampler's opaque
    # "num_samples should be a positive integer" error.
    if len(train_dataset) == 0:
        raise ValueError(
            f"No usable training items for tag '{train_tag}': check that samples "
            f"carry this tag, that '{detection_field}' detections have masks, and "
            f"that sample metadata (width/height) is populated, e.g. via "
            f"dataset.compute_metadata()"
        )

    def _make_loader(torch_dataset, shuffle):
        # Shared DataLoader settings; only shuffling differs between splits.
        return DataLoader(
            torch_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )

    train_dataloader = _make_loader(train_dataset, shuffle=True)
    val_dataloader = _make_loader(val_dataset, shuffle=False)

    return train_dataloader, val_dataloader


In [6]:
from transformers.models.sam3 import Sam3Model, Sam3Processor

# Hugging Face checkpoint whose processor resizes/normalizes images and
# prepares the box prompts consumed by FlattenedSAMDataset.
SAM3_CHECKPOINT = "facebook/sam3"
processor = Sam3Processor.from_pretrained(SAM3_CHECKPOINT)

In [7]:
# Items per batch; each item is a single detection (one object mask).
BATCH_SIZE = 2

train_dataloader, val_dataloader = create_dataloaders(
    dataset,
    processor,
    batch_size=BATCH_SIZE,
)

Using existing tags: 1682 train samples, 421 val samples
Intermediate datasets: train=1682, val=421
FlattenedSAMDataset created with 11156 items
FlattenedSAMDataset created with 2715 items


In [8]:
# Sanity-check one training batch: tensors report their shape, lists report
# their length and element types (e.g. referring expressions are strings).
batch = next(iter(train_dataloader))
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {value.shape}")
    else:
        elem_types = sorted({type(v).__name__ for v in value})
        print(f"{key}: list of {len(value)} ({', '.join(elem_types)})")

pixel_values: torch.Size([2, 3, 1008, 1008])
original_sizes: torch.Size([2, 2])
input_ids: torch.Size([2, 32])
attention_mask: torch.Size([2, 32])
input_boxes: torch.Size([2, 1, 4])
ground_truth_mask: torch.Size([2, 1080, 1080])
referring_expression: list of 2 tensors
