# Create SAM → YOLOv12 Dataset

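This notebook samples images from the `dataset/` folder and runs the Segment Anything Model (SAM, via the `segment_anything` package; ViT-B checkpoint by default) to generate masks. For each image it keeps the largest mask that passes an area threshold and converts it to a bounding box. It maps species labels from `image_categories_cleaned.json` to integer class IDs and writes a YOLOv12-style dataset (`images/`, `labels/`, `classes.txt`) into the folder set by `OUTPUT_DIR`.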

Usage notes:
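- Edit the parameters cell below to change the sample size, mask-area threshold, SAM model type/checkpoint, and output directory. The ViT-B checkpoint is downloaded automatically if it is missing and `DOWNLOAD_CHECKPOINT = True`.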
- Make sure SAM and required libraries are installed in the notebook kernel environment (see `requirements-sam-yolo.txt`).

In [16]:
# Imports and helper functions
import json
import random
from pathlib import Path
import shutil
import sys
import os

import numpy as np
from PIL import Image
import cv2
import torch

# SAM import (may raise if not installed)
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except Exception as e:
    print('Warning: segment_anything import failed. Install the package to run segmentation cells.')
    raise


def xyxy_to_yolo(box, img_w, img_h):
    """Convert a corner-format box to normalized YOLO ``[xc, yc, w, h]``.

    Args:
        box: 4-item sequence ``(x_min, y_min, x_max, y_max)`` in pixels.
        img_w: Image width in pixels; divides the x-center and width.
        img_h: Image height in pixels; divides the y-center and height.

    Returns:
        ``[x_center, y_center, width, height]`` as fractions of the image
        size. Values are not clipped to ``[0, 1]``.
    """
    left, top, right, bottom = box
    centre_x, centre_y = (left + right) / 2.0, (top + bottom) / 2.0
    span_x, span_y = right - left, bottom - top
    return [centre_x / img_w, centre_y / img_h, span_x / img_w, span_y / img_h]

print('Loaded helper functions')

Loaded helper functions


In [17]:
# Parameters (edit these as needed)

# Inputs
DATASET_DIR = Path('dataset')  # folder scanned for .jpg/.jpeg/.png images
LABELS_JSON = Path('image_categories_cleaned.json')  # image file name -> species label (str) or list of labels

# Output: receives images/, labels/ and classes.txt
OUTPUT_DIR = Path('output_sam_yolo_notebook')

# Sampling and mask filtering
SEED = 42  # seeds both `random` and `numpy` before sampling
NUM_SAMPLES = 200  # upper bound; capped at the number of labelled images found
AREA_THRESHOLD = 0.01  # masks covering less than this fraction of the image are ignored

# SAM model and checkpoint
SAM_MODEL_TYPE = 'vit_b'  # key into segment_anything's sam_model_registry
# Default checkpoint filename; fetched from SAM_VIT_B_URL when missing and DOWNLOAD_CHECKPOINT is True.
SAM_CHECKPOINT = Path('./sam_vit_b_01ec64.pth')
# Set to False to provide the checkpoint file manually instead.
DOWNLOAD_CHECKPOINT = True
# Official public URL of the vit_b checkpoint. NOTE: the download always uses this URL,
# so it only matches SAM_MODEL_TYPE == 'vit_b' -- change both together.
SAM_VIT_B_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'

print('\n'.join([
    'Parameters set:',
    f'DATASET_DIR={DATASET_DIR}',
    f'LABELS_JSON={LABELS_JSON}',
    f'OUTPUT_DIR={OUTPUT_DIR}',
    f'NUM_SAMPLES={NUM_SAMPLES}, AREA_THRESHOLD={AREA_THRESHOLD}, SAM_MODEL_TYPE={SAM_MODEL_TYPE}',
    f'SAM_CHECKPOINT={SAM_CHECKPOINT} (DOWNLOAD_CHECKPOINT={DOWNLOAD_CHECKPOINT})',
]))

Parameters set:
DATASET_DIR=dataset
LABELS_JSON=image_categories_cleaned.json
OUTPUT_DIR=output_sam_yolo_notebook
NUM_SAMPLES=200, AREA_THRESHOLD=0.01, SAM_MODEL_TYPE=vit_b
SAM_CHECKPOINT=sam_vit_b_01ec64.pth (DOWNLOAD_CHECKPOINT=True)


In [None]:
# Main processing cell: sample images, run SAM, create YOLO labels
#
# For every sampled image:
#   1. run SamAutomaticMaskGenerator on the RGB image,
#   2. drop masks smaller than AREA_THRESHOLD of the image area,
#   3. keep the single largest remaining mask (by pixel count) and take its box,
#   4. copy the image to OUTPUT_DIR/images and write the YOLO label line
#      (class id + normalized box) to OUTPUT_DIR/labels/<stem>.txt.
# Images without a qualifying mask or without a known species still get an
# empty label file, so YOLO treats them as background images.


def _primary_species(cats):
    """Return the species label used for one image, or ``None``.

    ``cats`` is the raw labels-JSON value for the image. A non-empty list
    yields its first element and a string is used as-is. Anything else
    (missing, empty, other types) yields ``None``, so it is never used as a
    class key.
    """
    if not cats:
        return None
    if isinstance(cats, list):
        return cats[0]
    if isinstance(cats, str):
        return cats
    return None


def _largest_mask_bbox(masks, img_h, img_w, area_threshold):
    """Return ``[[x_min, y_min, x_max, y_max]]`` of the largest mask, or ``[]``.

    Args:
        masks: SAM automatic-mask records; only each record's boolean
            ``'segmentation'`` array is read.
        img_h, img_w: Image size in pixels.
        area_threshold: Minimum mask area as a fraction of ``img_h * img_w``.

    Returns:
        A list with at most one box. Coordinates are integer pixel edges with
        exclusive maxima (last covered column/row + 1). A mask spanning the
        whole image therefore maps to a normalized width/height of exactly 1.0.
    """
    best_mask, best_pixels = None, 0
    for record in masks:
        mask = record['segmentation']
        n_pixels = int(mask.sum())
        # Skip empty masks and masks below the area threshold.
        if n_pixels == 0 or n_pixels / (img_h * img_w) < area_threshold:
            continue
        if n_pixels > best_pixels:  # strict '>' keeps the first mask on ties
            best_mask, best_pixels = mask, n_pixels
    if best_mask is None:
        return []
    ys, xs = np.nonzero(best_mask)
    return [[int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1]]


# Auto-download SAM checkpoint if requested
if DOWNLOAD_CHECKPOINT and not SAM_CHECKPOINT.exists():
    print(f"SAM checkpoint {SAM_CHECKPOINT} not found locally. Attempting to download from {SAM_VIT_B_URL} ...")
    import urllib.request
    # Download to a side file and rename only on success. An interrupted
    # download then never leaves a truncated .pth that the exists() checks
    # below would accept.
    partial_checkpoint = SAM_CHECKPOINT.with_name(SAM_CHECKPOINT.name + '.part')
    try:
        SAM_CHECKPOINT.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(SAM_VIT_B_URL, str(partial_checkpoint))
        partial_checkpoint.replace(SAM_CHECKPOINT)
    except Exception as e:
        partial_checkpoint.unlink(missing_ok=True)
        raise RuntimeError(f"Failed to download SAM checkpoint automatically: {e}. Please download manually and set SAM_CHECKPOINT.") from e
    print(f"Downloaded SAM checkpoint to: {SAM_CHECKPOINT}")

# Validate input paths early and provide clear error messages
if not DATASET_DIR.exists():
    raise FileNotFoundError(f"Dataset directory not found: {DATASET_DIR}. Make sure the path is correct and the folder exists.")
if not LABELS_JSON.exists():
    raise FileNotFoundError(f"Labels JSON not found: {LABELS_JSON}. Make sure `image_categories_cleaned.json` exists in the notebook folder or update LABELS_JSON.")
if not SAM_CHECKPOINT.exists():
    raise FileNotFoundError(f"SAM checkpoint not found: {SAM_CHECKPOINT}. Set `SAM_CHECKPOINT` to a valid .pth checkpoint path before running this cell.")

output_images_dir = OUTPUT_DIR / 'images'
output_labels_dir = OUTPUT_DIR / 'labels'
output_images_dir.mkdir(parents=True, exist_ok=True)
output_labels_dir.mkdir(parents=True, exist_ok=True)

# Load labels
with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    image_categories = json.load(f)

# Gather available images. Sorting matters: iterdir() order is
# filesystem-dependent, so without it the seeded sample is not reproducible.
all_images = sorted(
    p.name for p in DATASET_DIR.iterdir()
    if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')
)
available_images = [img for img in all_images if img in image_categories]

if len(available_images) == 0:
    raise RuntimeError('No images available that match labels. Please check dataset directory and labels JSON.')

NUM_SAMPLES_ACTUAL = min(NUM_SAMPLES, len(available_images))
random.seed(SEED)
np.random.seed(SEED)
sampled = random.sample(available_images, NUM_SAMPLES_ACTUAL)
print(f'Sampled {NUM_SAMPLES_ACTUAL} images')

# Map species to ids (sorted, so ids are stable for a given sample)
sampled_species = (_primary_species(image_categories.get(img)) for img in sampled)
species_list = sorted({sp for sp in sampled_species if sp is not None})
species_to_id = {sp: i for i, sp in enumerate(species_list)}
with open(OUTPUT_DIR / 'classes.txt', 'w', encoding='utf-8') as f:
    for sp in species_list:
        f.write(sp + '\n')

print(f'Found {len(species_list)} species')

# Load SAM. The registry lookup and the checkpoint load are guarded
# separately, so a KeyError raised while loading weights is not misreported
# as an unknown model type.
print('Loading SAM model...')
try:
    build_sam = sam_model_registry[SAM_MODEL_TYPE]
except KeyError:
    raise KeyError(f"SAM model type '{SAM_MODEL_TYPE}' not found in sam_model_registry. Available keys: {list(sam_model_registry.keys())}") from None
try:
    sam = build_sam(checkpoint=str(SAM_CHECKPOINT))
except Exception as e:
    # Provide a helpful message when checkpoint load fails
    raise RuntimeError(f"Failed to initialize SAM model. Error: {e}. Check that the checkpoint file is valid and compatible with the selected SAM model type.") from e

device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

processed = 0
for img_name in sampled:
    src_path = DATASET_DIR / img_name
    dst_img_path = output_images_dir / img_name
    img_bgr = cv2.imread(str(src_path))
    if img_bgr is None:
        print(f'Warning: failed to read {src_path}, skipping')
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    h, w = img_rgb.shape[:2]

    masks = mask_generator.generate(img_rgb)
    keep_bboxes = _largest_mask_bbox(masks, h, w, AREA_THRESHOLD)
    species_label = _primary_species(image_categories.get(img_name))

    shutil.copy2(src_path, dst_img_path)

    label_lines = []
    if species_label is not None and species_label in species_to_id and keep_bboxes:
        class_id = species_to_id[species_label]
        for box in keep_bboxes:
            yolo_box = xyxy_to_yolo(box, w, h)
            label_lines.append(f"{class_id} {' '.join(f'{v:.6f}' for v in yolo_box)}")

    label_file = output_labels_dir / (Path(img_name).stem + '.txt')
    label_file.write_text('\n'.join(label_lines), encoding='utf-8')

    processed += 1
    if processed % 10 == 0:
        print(f'Processed {processed}/{NUM_SAMPLES_ACTUAL}')

print(f'Done. Processed {processed} images. Output at: {OUTPUT_DIR}')

SAM checkpoint sam_vit_b_01ec64.pth not found locally. Attempting to download from https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth ...
Downloaded SAM checkpoint to: sam_vit_b_01ec64.pth
Sampled 200 images
Found 9 species
Loading SAM model...
Downloaded SAM checkpoint to: sam_vit_b_01ec64.pth
Sampled 200 images
Found 9 species
Loading SAM model...
Processed 10/200
Processed 10/200
Processed 20/200
Processed 20/200
Processed 30/200
Processed 30/200
Processed 40/200
Processed 40/200
Processed 50/200
Processed 50/200
Processed 60/200
Processed 60/200
Processed 70/200
Processed 70/200
