# COME15K: Download the first 1000 RGB + Ground-Truth samples

This notebook streams the **COME15K** dataset from Hugging Face (`RGBD-SOD/COME15K`, `train` split) and saves the first **1000** samples (set by `COUNT`) with:
- **RGB images** in `rgb/`
- **ground-truth masks** in `gt/`

Depth maps are **ignored**.

> **Defaults**
> - Output folder: `data/` inside the notebook's directory (or the current working directory when `__file__` is not defined)
> - Sample count: `1000`

If you need different values, edit the constants in the first code cell.


In [10]:
# --- Constants (edit if needed) ---
from pathlib import Path

# Base folder: the directory of this file when `__file__` is defined,
# otherwise the current working directory.
if '__file__' in globals():
    NOTEBOOK_DIR = Path(__file__).parent
else:
    NOTEBOOK_DIR = Path.cwd()

OUT_DIR = NOTEBOOK_DIR / "data"   # output root, relative to NOTEBOOK_DIR
COUNT = 1000                      # number of samples to save
SPLIT = "train"                   # use train split only
MIN_RANK = 3                      # minimum rank to include

# Candidate field names, tried in order (robust to schema variants)
RGB_KEYS = [
    "rgb",
    "image",
    "img",
    "rgb_image",
    "color",
]
GT_KEYS = [
    "gt",
    "mask",
    "gt_image",
    "saliency",
    "label",
    "gt_mask",
]


In [11]:
# Install (if needed) and import dependencies
# Note: If running in a managed environment, you may skip the install lines
import importlib
import subprocess
import sys


def _pip_install(package: str) -> None:
    """Install ``package`` with pip into the interpreter running this kernel.

    ``sys.executable -m pip`` targets the same environment the notebook runs
    in. After the install the import caches are invalidated so the package
    that was just written to disk is visible to the following ``import``
    without restarting the kernel.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    print(f"[Setup] Installing {package}...")
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
    importlib.invalidate_caches()


try:
    import datasets  # type: ignore
    from datasets import load_dataset
except ImportError:
    _pip_install('datasets')
    from datasets import load_dataset

try:
    from PIL import Image
except ImportError:
    _pip_install('pillow')
    from PIL import Image

import numpy as np
from pathlib import Path
from typing import Any, Dict, Optional


In [12]:
# Utility helpers

def ensure_dir(p: Path):
    """Create directory ``p``, including missing parents, if it does not exist.

    Args:
        p: directory to create. A ``Path`` is expected; a ``str`` or other
            path-like value is also accepted because it is wrapped in
            ``Path`` before use.

    Returns:
        None. Calling it on an existing directory is a no-op.
    """
    Path(p).mkdir(parents=True, exist_ok=True)

def to_png(image_like: Any, out_path: Path):
    """Save an image-like object to PNG using Pillow.

    Accepts PIL.Image, numpy arrays, encoded image bytes, or any other object
    that ``np.array`` turns into a 2-D or 3-D array (e.g. HF decoded Image
    features). Grayscale ('L') and 'RGB' images are saved as they are; every
    other mode is converted to RGB first.

    Args:
        image_like: the image to save.
        out_path: destination file. Parent directories are not created here.

    Raises:
        ValueError: if a numpy array is neither 2-D nor 3-D.
        TypeError: if a non-array, non-bytes object does not convert to a
            2-D or 3-D array.
    """
    from io import BytesIO

    if isinstance(image_like, Image.Image):
        img = image_like
    elif isinstance(image_like, np.ndarray):
        arr = image_like
        if arr.dtype != np.uint8:
            # Min-max rescale non-uint8 data to the 0..255 range.
            arr = arr.astype(np.float32)
            mn, mx = float(arr.min()), float(arr.max())
            if mx > mn:
                arr = (255.0 * (arr - mn) / (mx - mn)).clip(0, 255).astype(np.uint8)
            else:
                # Constant array: nothing to stretch, emit all zeros.
                arr = np.zeros_like(arr, dtype=np.uint8)
        if arr.ndim == 2:
            img = Image.fromarray(arr, mode="L")
        elif arr.ndim == 3:
            if arr.shape[2] == 1:
                # Single-channel (H, W, 1) is stored as grayscale.
                img = Image.fromarray(arr.squeeze(-1), mode="L")
            else:
                # Keep the first three channels, dropping e.g. alpha.
                img = Image.fromarray(arr[:, :, :3])
        else:
            raise ValueError(f'Unsupported array shape: {arr.shape}')
    elif isinstance(image_like, (bytes, bytearray)):
        # Decode only. The mode normalization below converts anything that is
        # not 'RGB' or 'L', so an encoded grayscale mask stays grayscale
        # instead of being unconditionally expanded to RGB.
        img = Image.open(BytesIO(image_like))
    else:
        arr = np.array(image_like)
        if arr.ndim == 2:
            img = Image.fromarray(arr.astype(np.uint8), mode="L")
        elif arr.ndim == 3:
            img = Image.fromarray(arr[:, :, :3].astype(np.uint8), mode="RGB")
        else:
            raise TypeError(f'Unsupported image-like type: {type(image_like)}')

    if img.mode not in ("RGB", "L"):
        try:
            img = img.convert("RGB")
        except Exception:
            # Best-effort fallback when Pillow cannot convert the mode
            # directly: rebuild the image from its raw pixel array.
            arr = np.array(img)
            if arr.ndim == 2:
                img = Image.fromarray(arr.astype(np.uint8), mode="L")
            else:
                img = Image.fromarray(arr[..., :3].astype(np.uint8), mode="RGB")

    img.save(out_path, format="PNG")

def pick_field(example: Dict[str, Any], candidates):
    """Return the value of the first candidate key that is present and not None.

    Keys are tried in the order given by ``candidates``. Returns None when no
    candidate key is present in ``example`` or all present ones hold None.
    """
    present_values = (example[name] for name in candidates if name in example)
    return next((value for value in present_values if value is not None), None)


In [9]:
# Stream dataset and save RGB + GT masks
# Streams the split lazily and stops once COUNT pairs are written. Samples
# whose rank is below MIN_RANK are skipped; samples without a rank field are
# kept. Pairs are written as <index>.png under OUT_DIR/rgb and OUT_DIR/gt,
# where <index> is the running count of saved pairs.
print(f'Output directory: {OUT_DIR}')
print(f'Filtering for rank >= {MIN_RANK}')
root = Path(OUT_DIR)
rgb_dir = root / 'rgb'
gt_dir  = root / 'gt'
ensure_dir(rgb_dir)
ensure_dir(gt_dir)

ds = load_dataset('RGBD-SOD/COME15K', split=SPLIT, streaming=True)

saved = 0
skipped_rank = 0
skipped_missing = 0
for i, ex in enumerate(ds):
    if saved >= COUNT:
        break

    # Compare against None explicitly: image-like values (e.g. numpy arrays)
    # cannot be combined with `or`, and 'image' is already part of RGB_KEYS.
    rgb = pick_field(ex, RGB_KEYS)
    gt  = pick_field(ex, GT_KEYS)

    # Report the schema on the very first streamed example, before any
    # filtering, so the hint is shown even if that example is skipped by rank.
    if i == 0:
        print(f'[Info] First example keys: {list(ex.keys())}')
        if rgb is None:
            print('[Hint] RGB field not found under common keys; check dataset card.')
        if gt is None:
            print('[Hint] GT field not found; verify exact key name (e.g., "gt" or "mask").')

    # Check rank filter ('rank' first, then 'difficulty'; a None value counts as absent)
    rank = pick_field(ex, ['rank', 'difficulty'])
    if rank is not None and rank < MIN_RANK:
        skipped_rank += 1
        continue

    if rgb is None or gt is None:
        skipped_missing += 1
        continue

    rgb_path = rgb_dir / f'{saved:05d}.png'
    gt_path  = gt_dir  / f'{saved:05d}.png'

    try:
        to_png(rgb, rgb_path)
        to_png(gt,  gt_path)
    except Exception as e:
        print(f'[Warn] Failed on sample {i}: {e}')
        # Remove whatever was written for this index so rgb/ and gt/ never
        # contain an unpaired file; the index is reused by the next sample.
        for partial in (rgb_path, gt_path):
            if partial.exists():
                partial.unlink()
        continue

    saved += 1
    if saved % 50 == 0:
        print(f'[Progress] Saved {saved} samples (skipped {skipped_rank} with rank < {MIN_RANK})...')

print(f'[Done] Saved {saved} RGB+GT pairs to: {root.resolve()}')
print(f'[Stats] Skipped {skipped_rank} samples with rank < {MIN_RANK}')
print(f'[Stats] Skipped {skipped_missing} samples missing an RGB or GT field')


Output directory: c:\Users\jashim\OneDrive - Microsoft\Desktop\class2025\project\data
Filtering for rank >= 3
[Info] First example keys: ['name', 'rgb', 'depth', 'gt']
[Progress] Saved 50 samples (skipped 0 with rank < 3)...
[Progress] Saved 100 samples (skipped 0 with rank < 3)...
[Progress] Saved 150 samples (skipped 0 with rank < 3)...
[Progress] Saved 200 samples (skipped 0 with rank < 3)...
[Progress] Saved 250 samples (skipped 0 with rank < 3)...
[Progress] Saved 300 samples (skipped 0 with rank < 3)...
[Progress] Saved 350 samples (skipped 0 with rank < 3)...
[Progress] Saved 400 samples (skipped 0 with rank < 3)...
[Progress] Saved 450 samples (skipped 0 with rank < 3)...
[Progress] Saved 500 samples (skipped 0 with rank < 3)...
[Progress] Saved 550 samples (skipped 0 with rank < 3)...
[Progress] Saved 600 samples (skipped 0 with rank < 3)...
[Progress] Saved 650 samples (skipped 0 with rank < 3)...
[Progress] Saved 700 samples (skipped 0 with rank < 3)...
[Progress] Saved 750 