# Create SAM → YOLOv12 Dataset

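This notebook samples images from the `dataset/` folder and runs the Segment Anything Model (SAM) on each one to generate masks. Among the largest masks, it keeps the bounding box closest to the image center. It maps species labels from `image_categories_cleaned.json` to integer class IDs and writes a YOLOv12-style dataset (`images/`, `labels/`, `vis/`, `classes.txt`) into an `output/` folder. A separate cell also copies a seeded random sample of matching images into `yolomodelimages/` for upload to an HPC server.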

Run the cells in order. If you need to install dependencies, run:

```python
%pip install torch torchvision opencv-python pillow numpy tqdm
%pip install git+https://github.com/facebookresearch/segment-anything.git
```

## ⚡ GPU Acceleration Setup

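**Important**: The notebook also runs on CPU, but SAM is only fast on a GPU. GPU use requires a CUDA-enabled PyTorch build *and* a working NVIDIA driver.

If you see `CUDA NOT AVAILABLE` below, first check the printed PyTorch version:
- If it already ends in `+cuXXX` (e.g. `2.7.1+cu118`), PyTorch is fine and the NVIDIA driver or GPU is the problem. `nvidia-smi` should list your GPU.
- Otherwise, install a CUDA build of PyTorch: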

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
```

Or for CUDA 12.1:
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```

**Expected speed**:
- GPU (CUDA): ~3-8 seconds per image ⚡
- CPU only: ~30-60 seconds per image 🐌

Run the cells below to check your GPU status.


In [1]:
# Imports and helper functions
import json
import random
from pathlib import Path
import shutil
import sys
import os
import logging

import numpy as np
from PIL import Image
import cv2
import torch
from tqdm import tqdm

# Attempt to import SAM (segment_anything). If this fails, install in kernel's env:
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except Exception as e:
    print('Warning: failed to import segment_anything. Install with:')
    print('%pip install git+https://github.com/facebookresearch/segment-anything.git')
    raise

def xyxy_to_yolo(box, img_w, img_h):
    """Convert a corner-format box to YOLO center format, normalized by image size.

    Args:
        box: ``[x_min, y_min, x_max, y_max]`` in pixel coordinates.
        img_w: Image width in pixels (divisor for the x terms).
        img_h: Image height in pixels (divisor for the y terms).

    Returns:
        ``[x_center, y_center, width, height]``, each divided by the matching
        image dimension.
    """
    left, top, right, bottom = box
    cx = (left + right) / 2.0
    cy = (top + bottom) / 2.0
    return [cx / img_w, cy / img_h, (right - left) / img_w, (bottom - top) / img_h]

def bbox_center(bbox):
    """Midpoint of a corner-format box.

    Args:
        bbox: ``[x_min, y_min, x_max, y_max]`` (exactly four values).

    Returns:
        Tuple ``(x_center, y_center)``.
    """
    x0, y0, x1, y1 = bbox
    cx = (x0 + x1) / 2.0
    cy = (y0 + y1) / 2.0
    return cx, cy

def distance_to_image_center(bbox, img_w, img_h):
    """Euclidean distance between a box's midpoint and the image's midpoint.

    Args:
        bbox: ``[x_min, y_min, x_max, y_max]`` in pixels.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        Non-negative distance in pixels.
    """
    x_min, y_min, x_max, y_max = bbox
    dx = (x_min + x_max) / 2.0 - img_w / 2.0
    dy = (y_min + y_max) / 2.0 - img_h / 2.0
    return (dx ** 2 + dy ** 2) ** 0.5

def ensure_dir(p: Path):
    """Make sure directory ``p`` exists, creating any missing parents.

    A directory that already exists is left untouched.
    """
    p.mkdir(exist_ok=True, parents=True)


print('Imports and helpers ready')


Imports and helpers ready


In [2]:
# Check CUDA availability and GPU info
print('=== GPU/CUDA Status ===')
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    # Describe the device PyTorch uses by default rather than assuming index 0.
    device_index = torch.cuda.current_device()
    device_props = torch.cuda.get_device_properties(device_index)
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU device: {device_props.name}')
    print(f'GPU memory: {device_props.total_memory / 1e9:.2f} GB')
    print(f'Current device: {device_index}')
else:
    print('⚠️ CUDA NOT AVAILABLE - will run on CPU (very slow)')
    # torch.version.cuda is None for CPU-only builds. If it is set (e.g. a
    # "+cu118" build), reason 1 below is ruled out and the driver/GPU is the suspect.
    print(f'PyTorch built with CUDA: {torch.version.cuda}')
    print('Possible reasons:')
    print('  1. PyTorch CPU-only version installed')
    print('  2. CUDA drivers not installed')
    print('  3. GPU not detected')
    print('\nTo fix, install PyTorch with CUDA:')
    print('  pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118')


=== GPU/CUDA Status ===
PyTorch version: 2.7.1+cu118
CUDA available: False
‚ö†Ô∏è CUDA NOT AVAILABLE - will run on CPU (very slow)
Possible reasons:
  1. PyTorch CPU-only version installed
  2. CUDA drivers not installed
  3. GPU not detected

To fix, install PyTorch with CUDA:
  pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118


In [3]:
# Smoke-test the GPU by creating a tensor directly on CUDA
if torch.cuda.is_available():
    try:
        test_tensor = torch.randn(100, 100, device='cuda')
        # CUDA work is launched asynchronously; synchronize so any device error
        # is raised here, inside the try, instead of at a later unrelated call.
        torch.cuda.synchronize()
        print(f'✅ Successfully created tensor on GPU: {test_tensor.device}')
        del test_tensor
        torch.cuda.empty_cache()
    except Exception as e:
        print(f'❌ Failed to use GPU: {e}')
else:
    print('⚠️ Skipping GPU test - CUDA not available')


‚ö†Ô∏è Skipping GPU test - CUDA not available


In [6]:
# Ask the NVIDIA driver which GPUs it can see. If this fails while the PyTorch
# version reports a CUDA build (e.g. "+cu118"), the problem is likely the
# driver or hardware rather than the PyTorch install.
!nvidia-smi

Unable to determine the device handle for GPU0: 0000:01:00.0: Unknown Error
No devices were found


In [11]:
# Create yolomodelimages/ with a seeded random sample of up to NUM_SAMPLES images
# that contain at least one of the 7 target species (for upload to the HPC server).
# Uses json, random, Path, shutil and tqdm from the imports cell at the top.
# NOTE: DATASET_DIR, LABELS_JSON, NUM_SAMPLES, SEED and ALLOWED_SPECIES defined
# here are also read by the SAM processing cell further down.

# Configuration
DATASET_DIR = Path('dataset')
LABELS_JSON = Path('image_categories_cleaned.json')
OUTPUT_FOLDER = Path('yolomodelimages')
NUM_SAMPLES = 6474  # upper bound; if fewer images match, all of them are copied
SEED = 42

# The 7 species classes
ALLOWED_SPECIES = {
    "Common leopard",
    "Himalayan goral",
    "Rhesus macaque",
    "Himalayan gray langur",
    "Himalayan tahr",
    "Yellow-throated marten",
    "Leopard cat"
}

print(f'Loading labels from {LABELS_JSON}...')
with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    labels_data = json.load(f)

# Filter images that contain at least one of the allowed species
filtered_images = []
species_counts = {sp: 0 for sp in ALLOWED_SPECIES}

for img_name, categories in labels_data.items():
    # Skip if categories is None or empty
    if not categories:
        continue
    # A label entry may be a bare species string rather than a list; wrap it so
    # set() does not split the name into individual characters.
    if isinstance(categories, str):
        categories = [categories]

    # Keep the image if it contains any of our 7 species and exists on disk
    img_species = set(categories) & ALLOWED_SPECIES
    if img_species and (DATASET_DIR / img_name).exists():
        filtered_images.append(img_name)
        for sp in img_species:
            species_counts[sp] += 1

print(f'\n✅ Found {len(filtered_images)} images with the 7 target species')
print('\nSpecies distribution:')
for species, count in sorted(species_counts.items(), key=lambda x: x[1], reverse=True):
    print(f'  {species}: {count} images')

# Sample random images (seeded for reproducibility)
random.seed(SEED)
if len(filtered_images) < NUM_SAMPLES:
    print(f'\n⚠️ Warning: Only {len(filtered_images)} images available, requested {NUM_SAMPLES}')
    sampled_images = filtered_images
else:
    sampled_images = random.sample(filtered_images, NUM_SAMPLES)
    print(f'\n✅ Randomly selected {len(sampled_images)} images')

# Create output folder and copy images
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
print(f'\nCopying images to {OUTPUT_FOLDER}/...')

copied_count = 0
failed_count = 0
for img_name in tqdm(sampled_images, desc='Copying images'):
    src_path = DATASET_DIR / img_name
    dst_path = OUTPUT_FOLDER / img_name
    try:
        shutil.copy2(src_path, dst_path)
        copied_count += 1
    except Exception as e:
        print(f'Failed to copy {img_name}: {e}')
        failed_count += 1

print(f'\n✅ Successfully copied {copied_count} images to {OUTPUT_FOLDER}/')
if failed_count > 0:
    print(f'⚠️ Failed to copy {failed_count} images')

print('\nFolder ready for upload to HPC server!')
print(f'Total size: {sum(f.stat().st_size for f in OUTPUT_FOLDER.glob("*") if f.is_file()) / (1024**2):.2f} MB')

Loading labels from image_categories_cleaned.json...

‚úÖ Found 6474 images with the 7 target species

Species distribution:
  Himalayan goral: 1984 images
  Himalayan tahr: 1592 images
  Himalayan gray langur: 1526 images
  Common leopard: 784 images
  Rhesus macaque: 326 images
  Leopard cat: 197 images
  Yellow-throated marten: 101 images

‚úÖ Randomly selected 6474 images

Copying images to yolomodelimages/...


Copying images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6474/6474 [00:24<00:00, 264.60it/s]



‚úÖ Successfully copied 6474 images to yolomodelimages/

Folder ready for upload to HPC server!
Total size: 10444.14 MB


In [None]:
# Main processing cell: sample images, run SAM, select bbox closest to center, create YOLO labels

# Fail fast if the configuration has not been defined in an earlier cell.
# Otherwise some names (e.g. AREA_THRESHOLD) are first read only after SAM has
# been loaded and the first image has been segmented.
_REQUIRED_CONFIG = (
    'DATASET_DIR', 'LABELS_JSON', 'NUM_SAMPLES', 'SEED', 'ALLOWED_SPECIES',
    'OUTPUT_DIR', 'SAM_CHECKPOINT', 'DOWNLOAD_CHECKPOINT', 'SAM_MODEL_TYPE',
    'SAM_DEVICE', 'SAM_POINTS_PER_SIDE', 'SAM_PRED_IOU_THRESH',
    'SAM_STABILITY_SCORE_THRESH', 'MAX_IMAGE_LONG_SIDE', 'AREA_THRESHOLD',
)
_missing_config = [name for name in _REQUIRED_CONFIG if name not in globals()]
if _missing_config:
    raise NameError(
        'Missing configuration variables (define them in a config cell before this one): '
        + ', '.join(_missing_config)
    )

# Ensure inputs
if DOWNLOAD_CHECKPOINT and not SAM_CHECKPOINT.exists():
    print(f'SAM checkpoint {SAM_CHECKPOINT} not found locally. Attempting download...')
    import urllib.request
    ensure_dir(SAM_CHECKPOINT.parent)
    # Download to a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated file at SAM_CHECKPOINT
    # (which the exists() check above would accept on the next run).
    partial_checkpoint = SAM_CHECKPOINT.with_name(SAM_CHECKPOINT.name + '.part')
    try:
        urllib.request.urlretrieve(SAM_VIT_B_URL, str(partial_checkpoint))
        partial_checkpoint.replace(SAM_CHECKPOINT)
        print('Downloaded SAM checkpoint')
    except Exception as e:
        partial_checkpoint.unlink(missing_ok=True)
        raise RuntimeError(f'Failed to download SAM checkpoint: {e}') from e

if not DATASET_DIR.exists():
    raise FileNotFoundError(f'Dataset directory not found: {DATASET_DIR}')
if not LABELS_JSON.exists():
    raise FileNotFoundError(f'Labels JSON not found: {LABELS_JSON}')
if not SAM_CHECKPOINT.exists():
    raise FileNotFoundError(f'SAM checkpoint not found: {SAM_CHECKPOINT}')

ensure_dir(OUTPUT_DIR)
for _subdir in ('images', 'labels', 'vis'):
    ensure_dir(OUTPUT_DIR / _subdir)

# Load labels JSON
with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    image_categories = json.load(f)

# Gather images. Sort the listing: iterdir() yields entries in arbitrary,
# filesystem-dependent order, which would make the seeded random.sample below
# pick different images on different machines.
all_images = sorted(
    p.name for p in DATASET_DIR.iterdir()
    if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')
)
available_images = [img for img in all_images if img in image_categories]
if not available_images:
    raise RuntimeError('No labeled images found in the dataset folder')

# Keep images whose primary (first) label is an allowed species. A label entry
# may be a list of names or a single name.
filtered_images = []
for img in available_images:
    cats = image_categories.get(img)
    if not cats:
        continue
    label = cats[0] if isinstance(cats, list) else cats
    if label in ALLOWED_SPECIES:
        filtered_images.append(img)

print(f'Found {len(filtered_images)} images with allowed species labels')
if not filtered_images:
    raise RuntimeError(f'No images in {DATASET_DIR} have a primary label in ALLOWED_SPECIES')
NUM_SAMPLES_ACTUAL = min(NUM_SAMPLES, len(filtered_images))
random.seed(SEED)
sampled = random.sample(filtered_images, NUM_SAMPLES_ACTUAL)
print(f'Sampled {NUM_SAMPLES_ACTUAL} images')

# Map species to class ids and write classes.txt.
# Ids come from the full, sorted ALLOWED_SPECIES set rather than from whichever
# species happen to appear in this random sample. The id <-> name mapping is
# therefore stable across sample sizes and seeds: a species missing from a
# small sample no longer shifts every later id. Every sampled label is an
# allowed species (see the filter above), so no label is lost.
species_list = sorted(ALLOWED_SPECIES)
species_to_id = {s: i for i, s in enumerate(species_list)}
with open(OUTPUT_DIR / 'classes.txt', 'w', encoding='utf-8') as f:
    f.writelines(s + '\n' for s in species_list)
print(f'Created classes.txt with {len(species_list)} classes')

# Load SAM model. Validate the model type up front so that a KeyError raised
# while building/loading the model is not misreported as an unknown model type.
print('Loading SAM...')
if SAM_MODEL_TYPE not in sam_model_registry:
    raise KeyError(f'SAM model type {SAM_MODEL_TYPE} not found; available: {list(sam_model_registry.keys())}')
try:
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=str(SAM_CHECKPOINT))
except Exception as e:
    raise RuntimeError(f'Failed to load SAM: {e}') from e

# Move to device. Both 'cuda' and 'cuda:N' require a working GPU (fail loudly
# rather than silently falling back); any other value runs on the CPU.
if str(SAM_DEVICE).startswith('cuda'):
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA requested but not available')
    sam = sam.to(SAM_DEVICE)
    # Report the GPU that was actually requested ('cuda' alone means the current device).
    gpu_index = torch.device(SAM_DEVICE).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    print(f'SAM loaded on GPU: {torch.cuda.get_device_name(gpu_index)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(gpu_index).total_memory / 1e9:.2f} GB')
else:
    sam = sam.to('cpu')
    print('SAM loaded on CPU (slow)')

# Create mask generator with PERFORMANCE-TUNED settings
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=SAM_POINTS_PER_SIDE,  # Fewer points = faster (default 32)
    pred_iou_thresh=SAM_PRED_IOU_THRESH,  # Higher = fewer low-quality masks
    stability_score_thresh=SAM_STABILITY_SCORE_THRESH,  # Higher = fewer unstable masks
    crop_n_layers=0,  # Disable crop layers for speed (default 0)
    crop_n_points_downscale_factor=1,  # Not used if crop_n_layers=0
    min_mask_region_area=100,  # Remove tiny masks
)

TOP_N_MASKS = 5  # Consider the 5 largest masks for center-based selection
# Resume support: skip images whose label file (written last, see below) and
# copied image already exist from a previous, possibly interrupted, run.
# Set to False to regenerate everything, e.g. after changing SAM settings.
SKIP_EXISTING = True


def _select_center_mask(masks, img_w, img_h, top_n=TOP_N_MASKS):
    """Among the ``top_n`` largest masks, pick the one whose bbox center is
    closest to the image center.

    Args:
        masks: Mask records from ``mask_generator.generate``. Reads ``'area'``,
            ``'bbox'`` as ``(x, y, w, h)``, and ``'segmentation'`` only when
            ``'bbox'`` is missing.
        img_w: Width of the image the masks were computed on.
        img_h: Height of the image the masks were computed on.
        top_n: Number of largest-area masks to consider.

    Returns:
        ``(mask, [x_min, y_min, x_max, y_max])`` for the chosen mask, or
        ``None`` if no candidate has a usable box.
    """
    largest = sorted(masks, key=lambda m: m.get('area', 0), reverse=True)[:top_n]
    best = None
    for mask in largest:
        bbox = mask.get('bbox')
        if bbox is None:
            ys, xs = np.where(mask.get('segmentation'))
            if len(xs) == 0 or len(ys) == 0:
                continue
            box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
        else:
            x0, y0, bw, bh = bbox
            box = [x0, y0, x0 + bw, y0 + bh]
        dist = distance_to_image_center(box, img_w, img_h)
        # Strict '<' keeps the earlier (larger-area) mask on distance ties.
        if best is None or dist < best[0]:
            best = (dist, mask, box)
    return None if best is None else (best[1], best[2])


processed = 0
skipped_oom = 0
skipped_other = 0
skipped_existing = 0

# Report the device the weights actually live on, not the one that was requested.
on_cuda = next(sam.parameters()).device.type == 'cuda'
print(f"\nStarting processing on {'GPU' if on_cuda else 'CPU (slow)'}...")

progress = tqdm(sampled, desc='Processing')
for img_name in progress:
    src = DATASET_DIR / img_name
    dst_img = OUTPUT_DIR / 'images' / img_name
    label_path = OUTPUT_DIR / 'labels' / (Path(img_name).stem + '.txt')

    if SKIP_EXISTING and label_path.exists() and dst_img.exists():
        skipped_existing += 1
        continue

    # Resolve the class first: it is cheap, while a SAM pass takes seconds to minutes.
    cats = image_categories.get(img_name)
    label_name = cats[0] if isinstance(cats, list) else cats
    if label_name not in species_to_id:
        tqdm.write(f'Label {label_name} not in species mapping; skipping')
        skipped_other += 1
        continue
    class_id = species_to_id[label_name]

    img_bgr = cv2.imread(str(src))
    if img_bgr is None:
        tqdm.write(f'Failed to read {src}; skipping')
        skipped_other += 1
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = img_rgb.shape[:2]
    resized = img_rgb
    h, w = orig_h, orig_w

    # Downscale (only) images whose long side exceeds MAX_IMAGE_LONG_SIDE; None disables this.
    if MAX_IMAGE_LONG_SIDE is not None and max(h, w) > MAX_IMAGE_LONG_SIDE:
        scale = MAX_IMAGE_LONG_SIDE / float(max(h, w))
        resized = cv2.resize(img_rgb, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA)
        h, w = resized.shape[:2]

    # Generate masks. Only a real CUDA OOM is counted as OOM; other runtime
    # errors are reported as generic failures.
    try:
        with torch.no_grad():
            masks = mask_generator.generate(resized)
    except torch.cuda.OutOfMemoryError:
        skipped_oom += 1
        tqdm.write(f'CUDA OOM on {img_name} (skipped {skipped_oom} total). Try reducing MAX_IMAGE_LONG_SIDE or SAM_POINTS_PER_SIDE')
        torch.cuda.empty_cache()  # Release cached blocks before the next image
        continue
    except RuntimeError as e:
        tqdm.write(f'Failed to generate masks for {img_name}: {e}; skipping')
        skipped_other += 1
        continue

    if not masks:
        tqdm.write(f'No masks for {img_name}; skipping')
        skipped_other += 1
        continue

    selection = _select_center_mask(masks, w, h)
    if selection is None:
        tqdm.write(f'No valid masks for {img_name}; skipping')
        skipped_other += 1
        continue
    best_mask, (x_min, y_min, x_max, y_max) = selection

    # Area threshold is a fraction of the image area SAM actually saw (resized coords).
    if best_mask.get('area', 0) < AREA_THRESHOLD * (h * w):
        tqdm.write(f'Best mask too small for {img_name}; skipping')
        skipped_other += 1
        continue

    # Scale the box back to original-image pixels if SAM ran on a resized copy.
    if (h, w) != (orig_h, orig_w):
        scale_x = orig_w / float(w)
        scale_y = orig_h / float(h)
        x_min, x_max = x_min * scale_x, x_max * scale_x
        y_min, y_max = y_min * scale_y, y_max * scale_y

    yolo_box = xyxy_to_yolo([x_min, y_min, x_max, y_max], orig_w, orig_h)

    shutil.copy2(src, dst_img)

    # Visualization: selected bbox in green, image-center crosshair in blue.
    # Colors are RGB because vis is RGB; it is converted to BGR only for imwrite.
    vis = img_rgb.copy()
    cv2.rectangle(vis, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
    center_x, center_y = orig_w // 2, orig_h // 2
    cv2.line(vis, (center_x - 20, center_y), (center_x + 20, center_y), (0, 0, 255), 2)
    cv2.line(vis, (center_x, center_y - 20), (center_x, center_y + 20), (0, 0, 255), 2)
    cv2.imwrite(str(OUTPUT_DIR / 'vis' / img_name), cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))

    # Write the label last: its presence marks this image as fully processed,
    # which SKIP_EXISTING relies on when resuming.
    # Format: "class_id x_center y_center width height", coordinates divided by image size.
    with open(label_path, 'w', encoding='utf-8') as f:
        f.write(f"{class_id} {' '.join(f'{v:.6f}' for v in yolo_box)}\n")

    processed += 1
    progress.set_postfix(saved=processed, skipped=skipped_oom + skipped_other)

print(f'\nDone! Processed {processed} images. Skipped {skipped_oom} due to OOM.')
print(f'Skipped {skipped_other} for other reasons; {skipped_existing} already had output.')
print(f'Output at: {OUTPUT_DIR}')
if on_cuda:
    print(f'Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB')


Found 6471 images with allowed species labels
Sampled 1200 images
Created classes.txt with 7 classes
Loading SAM...
SAM loaded on CPU (slow)

Starting processing with GPU acceleration...
SAM loaded on CPU (slow)

Starting processing with GPU acceleration...


Processing:   0%|          | 0/1200 [00:03<?, ?it/s]



KeyboardInterrupt: 