# SAM2 ISIC Notebook Demo

This notebook demonstrates the core flow used in the Streamlit app:
- load an image
- run SAM2 (or a classical fallback) with a point/box prompt
- get a binary mask
- compute ABCD-style metrics
- visualize and export the mask/overlay

Note: If the SAM 2 package or its checkpoint weights are not available, the notebook falls back to a classical segmentation method for demonstration purposes.

In [None]:
import os
import json
from pathlib import Path
import sys

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch

# Make sure we can import project modules (sam2_infer, metrics, viz) whether we
# run from the repo root or from notebooks/. The repo root is whichever of
# cwd / cwd.parent contains sam2_infer.py. If neither does, repo_root stays at
# cwd and sys.path is left untouched.
repo_root = Path.cwd()
for _candidate in (repo_root, repo_root.parent):
    if (_candidate / 'sam2_infer.py').exists():
        repo_root = _candidate
        _root_str = str(repo_root)
        # Drop any earlier copy before prepending. Re-running this cell then
        # keeps a single entry, and that entry still has top priority.
        if _root_str in sys.path:
            sys.path.remove(_root_str)
        sys.path.insert(0, _root_str)
        break

from sam2_infer import Sam2Wrapper, sam2_available
from metrics import compute_metrics
from viz import overlay_mask_with_contour

print('Repo root:', repo_root)
print('CUDA available:', torch.cuda.is_available())
print('SAM2 package detected:', sam2_available())


In [None]:
# Paths and configuration.
# The checkpoint is resolved against repo_root rather than the working
# directory. It is then found whether the notebook runs from the repo root or
# from notebooks/; a cwd-relative path would silently force the fallback.
ckpt_path = repo_root / 'models' / 'sam2_hiera_large.pt'
samples_dir = repo_root / 'data' / 'samples'

# Extensions picked up from samples_dir (matched case-insensitively).
image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
sample_images = sorted(
    p for p in samples_dir.glob('*')
    if p.is_file() and p.suffix.lower() in image_exts
)
if sample_images:
    image_path = sample_images[0]
else:
    # Set this to a local image path if you don't have samples
    image_path = Path('path_to_your_image.jpg')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('Checkpoint path:', ckpt_path)
print('Image path:', image_path)
print('Device:', device)


In [None]:
# Load and display image.
if not image_path.exists():
    raise FileNotFoundError(f'Image not found: {image_path}. Put images in {samples_dir} or update image_path.')

# Open inside a context manager so the file handle is released. convert()
# returns a new image object, so `image` stays usable after the file is closed.
with Image.open(image_path) as _src:
    image = _src.convert('RGB')
W, H = image.size  # PIL reports (width, height)

plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.title('Input image')
plt.axis('off')
plt.show()


In [None]:
# Initialize SAM2 (if available).
# Pass the checkpoint only when the file exists on disk; otherwise pass None.
# `sam2.ready` then reports whether the SAM2 backend can be used.
sam2 = Sam2Wrapper(
    device=device,
    ckpt_path=(str(ckpt_path) if ckpt_path.exists() else None),
)
print('SAM backend ready:', sam2.ready)


In [None]:
# Define a prompt: either a point or a box.
# Edit these values as needed. Coordinates are pixel positions derived from
# the image width W and height H computed when the image was loaded.
use_box = True  # set to False to use a single point

if not use_box:
    # Single click at the image centre: (x, y).
    point = (W * 0.5, H * 0.5)
    prompt = {'point': point}
else:
    # Box spanning the middle half of each axis: (x0, y0, x1, y1).
    box = (W * 0.25, H * 0.25, W * 0.75, H * 0.75)
    prompt = {'box': box}

print('Prompt:', prompt)


In [None]:
# Run segmentation: try SAM2 first, then fall back to the classical method.
np_img = np.array(image)
mask = None
used_backend = 'fallback'

# Translate the prompt dict into the keyword argument that both
# Sam2Wrapper.segment and Sam2Wrapper.segment_fallback accept (point= or box=).
if 'point' in prompt:
    px, py = prompt['point']
    seg_kwargs = {'point': (px, py)}
else:
    seg_kwargs = {'box': prompt['box']}

if sam2.ready:
    try:
        mask = sam2.segment(np_img, **seg_kwargs)
    except Exception as e:  # broad on purpose: any SAM2 failure -> fallback
        print('SAM2 inference failed:', e)
    # Credit SAM2 only when it actually returned a mask. A None result falls
    # through to the fallback below, and the label must say so.
    if mask is not None:
        used_backend = 'sam2'

if mask is None:
    mask = Sam2Wrapper.segment_fallback(np_img, **seg_kwargs)

if mask is None:
    raise RuntimeError('Segmentation failed.')

print('Backend used:', used_backend)
# count_nonzero counts foreground pixels correctly for bool, 0/1 and 0/255
# masks; sum() would inflate the count 255x for a 0/255 mask.
print('Mask shape:', mask.shape, 'dtype:', mask.dtype, 'foreground pixels:', int(np.count_nonzero(mask)))


In [None]:
# Visualize the overlay and export it as a PNG under <repo_root>/exports.
exports_dir = repo_root / 'exports'
exports_dir.mkdir(parents=True, exist_ok=True)

# Fill colour (0, 255, 0) and contour colour (255, 0, 0) give a green fill
# with a red outline, assuming the helper takes RGB tuples. alpha sets the
# fill opacity.
overlay = overlay_mask_with_contour(
    image,
    mask,
    mask_color=(0, 255, 0),
    contour_color=(255, 0, 0),
    alpha=0.35,
)

plt.figure(figsize=(6, 6))
plt.imshow(overlay)
plt.title('Overlay (mask + contour)')
plt.axis('off')
plt.show()

overlay_path = exports_dir / 'notebook_overlay.png'
overlay.save(overlay_path)
print('Saved overlay to:', overlay_path)


In [None]:
# Compute ABCD-style metrics for the mask, show them, and save them as JSON.
metrics = compute_metrics(mask, np_img)

# Serialize once; the same text is both printed and written to disk.
metrics_json = json.dumps(metrics, indent=2)
print(metrics_json)

metrics_path = exports_dir / 'notebook_metrics.json'
metrics_path.write_text(metrics_json, encoding='utf-8')
print('Saved metrics to:', metrics_path)


In [None]:
# Save the binary mask and a masked view of the image.
# Any nonzero mask value counts as foreground, so bool, 0/1 and 0/255 masks
# all work. Multiplying a 0/255 uint8 mask by 255 directly would wrap around
# in uint8 arithmetic and yield a near-black PNG.
fg_mask = mask.astype(bool)

# A 2-D uint8 array maps to Pillow's 'L' (8-bit grayscale) mode automatically.
# The deprecated `mode=` argument of Image.fromarray is therefore not needed.
mask_img = Image.fromarray(fg_mask.astype(np.uint8) * 255)
mask_path = exports_dir / 'notebook_mask.png'
mask_img.save(mask_path)

# Zero out every background pixel (all channels) of a copy of the image.
masked = np_img.copy()
masked[~fg_mask] = 0
masked_img = Image.fromarray(masked)
masked_path = exports_dir / 'notebook_masked.png'
masked_img.save(masked_path)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(mask_img, cmap='gray')
plt.title('Mask')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(masked_img)
plt.title('Masked image')
plt.axis('off')
plt.show()

print('Saved mask to:', mask_path)
print('Saved masked image to:', masked_path)


### Notes
- If `models/sam2_hiera_large.pt` is missing, the notebook will still produce a mask via classical fallback.
- To use SAM 2, install the `sam2` package, download the checkpoint, and update `ckpt_path` if you store it somewhere else.
- Set `use_box = False` in the prompt cell to use a single click point instead of a box.
- The device is auto-selected based on CUDA availability; to override it, set `device` in the configuration cell.
