# Create Water Segmentation Masks with SAM 2

This notebook uses SAM 2 to create segmentation masks for water surfaces across a sequence of river camera images.

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd

## Configuration

In [None]:
# Camera to process; every path below lives under this camera's folder
camera_id = 'WI_East_Branch_Pecatonica_River_near_Blanchardville_Bullet'
csv_path = f'{camera_id}/images_and_data.csv'
images_dir = f'{camera_id}/images'

# SAM 2 weights and the matching model config (tiny variant)
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_tiny.pt"
# For better quality (but slower), switch both values to the large model.
# The config path follows the same "configs/sam2.1/" layout as the tiny one:
# sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
# model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Where the generated masks are written
output_dir = f'{camera_id}/masks'
# Scratch directory holding the sequentially numbered frames handed to SAM 2
sam_video_dir = f'{camera_id}/SAM'

# Create both directories up front; exist_ok keeps re-runs harmless
os.makedirs(output_dir, exist_ok=True)
os.makedirs(sam_video_dir, exist_ok=True)

## Device Setup

In [None]:
# Ask PyTorch to fall back to another device for operations MPS does not support.
# NOTE(review): torch is already imported in the first cell; PyTorch may only read
# this variable at import time, so confirm the fallback actually takes effect here.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "mps":
    print("\nNote: MPS support is preliminary. SAM 2 may give different results on MPS vs CUDA.")
elif device.type == "cuda":
    # Enter a bfloat16 autocast context that is deliberately never exited,
    # so it stays active for the rest of the notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability 8+ (Ampere and newer) supports TF32 matmul/cuDNN kernels
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

## Prepare Images for SAM 2

SAM 2 requires images to be sequentially numbered (e.g., 00000.jpg, 00001.jpg, ...). We'll create symbolic links with sequential names instead of copying files to save space.

In [None]:
# Collect every supported image, sorted by filename so the frame order is deterministic.
# The position of a file in this list is its SAM 2 frame index.
image_files = sorted(
    f for f in os.listdir(images_dir)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

print(f"Found {len(image_files)} images")

# Drop symlinks left over from an earlier run that had more images. Otherwise the
# frame scan below would pick up stale frames that no longer match image_files.
# Only symlinks are removed here; regular files in the directory are left alone.
expected_links = {f"{i:05d}.jpg" for i in range(len(image_files))}
for entry in os.listdir(sam_video_dir):
    entry_path = os.path.join(sam_video_dir, entry)
    if entry not in expected_links and os.path.islink(entry_path):
        os.remove(entry_path)

# Create symbolic links with sequential names (instead of copying).
# Symlinks are like shortcuts - they don't duplicate the data.
# Every link gets a ".jpg" name, including links to PNG sources.
for i, filename in enumerate(image_files):
    # Absolute target so the link resolves regardless of the working directory
    src_path = os.path.abspath(os.path.join(images_dir, filename))
    new_filename = f"{i:05d}.jpg"  # 00000.jpg, 00001.jpg, etc.
    dst_path = os.path.join(sam_video_dir, new_filename)

    # lexists is also True for a broken symlink, which os.path.exists would miss
    if os.path.lexists(dst_path):
        os.remove(dst_path)

    os.symlink(src_path, dst_path)

print(f"Created symbolic links in {sam_video_dir}")
print("(No data duplication - symlinks point to original files)")

In [None]:
# List the JPEG frames in the SAM video directory, ordered by their numeric stem
# (e.g. "00012.jpg" -> 12) so the list index equals the frame index
frame_names = sorted(
    (
        name for name in os.listdir(sam_video_dir)
        if os.path.splitext(name)[1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
    ),
    key=lambda name: int(os.path.splitext(name)[0]),
)

print(f"Found {len(frame_names)} frames for SAM 2")

In [None]:
# Show frame 0 so click coordinates for the prompts can be read off the axes
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"Frame {frame_idx}")
first_frame = Image.open(os.path.join(sam_video_dir, frame_names[frame_idx]))
plt.imshow(first_frame)
plt.axis('off')
plt.show()

## Load SAM 2 Model

In [None]:
# Factory for the SAM 2 video predictor (imported here rather than in the top import cell)
from sam2.build_sam import build_sam2_video_predictor

# Build the predictor from the config and checkpoint chosen in the Configuration
# cell, placing it on the device selected in the Device Setup cell
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
print("SAM 2 model loaded successfully")

In [None]:
# Initialize inference state
# The predictor is pointed at the directory of sequentially numbered frame links.
# The returned state object is passed to the later prompt and propagation calls.
inference_state = predictor.init_state(video_path=sam_video_dir)
print("Inference state initialized")

## Helper Functions for Visualization

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a segmentation mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being colored.
        ax: Axes-like object providing ``imshow``.
        obj_id: Index into the "tab10" colormap; ``None`` selects index 0.
        random_color: If True, draw a random RGB color instead of using the colormap.
    """
    opacity = 0.6
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(obj_id if obj_id is not None else 0)[:3]
        rgba = np.array([*rgb, opacity])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([opacity])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4): pixels outside the mask become fully transparent
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=200):
    """Plot prompt clicks on ``ax`` as stars: green for label 1, red for label 0.

    Args:
        coords: Array of (x, y) points, indexed as ``coords[:, 0]`` / ``coords[:, 1]``.
        labels: Array with one entry per point, used as a boolean selector.
        ax: Axes-like object providing ``scatter``.
        marker_size: Marker size passed to ``scatter`` as ``s``.
    """
    # Positive clicks are drawn first, then negative clicks
    for label_value, point_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` on ``ax``.

    Args:
        box: Indexable as (x0, y0, x1, y1); width and height are computed
            as ``x1 - x0`` and ``y1 - y0``.
        ax: Axes-like object providing ``add_patch``.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle((left, top), right - left, bottom - top, edgecolor='green',
                            facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

## Interactive Water Segmentation

### Step 1: Add initial clicks to identify water

Look at the first frame above and identify coordinates where the water is visible.
Add positive clicks (label=1) on the water surface.

**Instructions:**
1. Look at the image above to identify water coordinates
2. Update the `points` array below with (x, y) coordinates on the water
3. If needed, add negative clicks (label=0) to exclude non-water regions

In [None]:
# Frame to annotate (typically start with frame 0)
ann_frame_idx = 0
ann_obj_id = 1  # Object ID for water (can be any unique integer)

# Click coordinates as (x, y) pixels - MODIFY THESE COORDINATES based on your image.
# You may need to run this cell multiple times with different points to refine.
points = np.array([
    [750, 800],  # First click on water - MODIFY THESE COORDINATES
    # [250, 850],  # Uncomment and modify to add more clicks
], dtype=np.float32)

# One label per point: 1 = positive click (add region), 0 = negative click (remove region)
labels = np.array([1], np.int32)

# labels is sized by hand, so check it here and fail with a clear message
# rather than with an error from inside the predictor
if len(points) != len(labels):
    raise ValueError(
        f"points has {len(points)} entries but labels has {len(labels)}; "
        "provide exactly one label per point"
    )

# Send the prompts to SAM 2; the call returns the object ids and mask logits for this frame
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Show the clicks and the resulting mask (logits > 0 are treated as water)
plt.figure(figsize=(12, 8))
plt.title(f"Frame {ann_frame_idx} - Water Segmentation")
plt.imshow(Image.open(os.path.join(sam_video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.axis('off')
plt.show()

print("\nIf the mask doesn't look right, modify the points above and re-run this cell.")
print("You can add more points or add negative clicks (label=0) to refine.")

### Step 2 (Optional): Refine with additional clicks

If the mask above isn't perfect, add more clicks to refine it.
You can skip this step if the mask looks good.

In [None]:
# Add refinement clicks here if needed
# Remember to include ALL previous clicks plus new ones

# Example with additional click:
points = np.array([
    [250, 800],  # Original click
    [1700, 850],  # Additional positive click
    [1100, 400],  # Negative click to remove unwanted region (label=0)
], dtype=np.float32)
labels = np.array([
    1, 
    1, 
    0
    ], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Show the refined results
plt.figure(figsize=(12, 8))
plt.title(f"Frame {ann_frame_idx} - Refined Water Segmentation")
plt.imshow(Image.open(os.path.join(sam_video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.axis('off')
plt.show()

### Step 3: Propagate segmentation across all frames

In [None]:
# Propagate the prompts through every frame and collect boolean masks.
# video_segments maps frame index -> {object id -> mask array}
video_segments = {}

print("Propagating masks across all frames...")
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    frame_result = {}
    for i, out_obj_id in enumerate(out_obj_ids):
        # Threshold the logits at 0 and move the result to a NumPy array on the CPU
        frame_result[out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
    video_segments[out_frame_idx] = frame_result

print(f"Propagation complete! Generated masks for {len(video_segments)} frames.")

### Step 4: Visualize results on sample frames

In [None]:
# Spot-check the propagation on roughly six evenly spaced frames
total_frames = len(frame_names)
vis_frame_stride = max(1, total_frames // 6)  # stride of at least 1 for short sequences

plt.close("all")  # discard figures left open by earlier cells
for out_frame_idx in range(0, total_frames, vis_frame_stride):
    plt.figure(figsize=(10, 7))
    plt.title(f"Frame {out_frame_idx} / {len(frame_names)}")
    frame_image = Image.open(os.path.join(sam_video_dir, frame_names[out_frame_idx]))
    plt.imshow(frame_image)
    # Overlay every object's mask stored for this frame
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

print("\nIf the masks don't look good throughout, you may need to:")
print("1. Add refinement clicks on problem frames (see Step 5)")
print("2. Or reset and start over with better initial clicks")

### Step 5 (Optional): Refine masks on specific problematic frames

If you notice issues on specific frames, you can add clicks to refine them.

In [None]:
# Uncomment and modify to refine a specific frame

# # Choose a frame that needs refinement
# problem_frame_idx = 30  # MODIFY THIS
# ann_obj_id = 1

# # Show current mask on that frame
# plt.figure(figsize=(12, 8))
# plt.title(f"Frame {problem_frame_idx} - Before Refinement")
# plt.imshow(Image.open(os.path.join(sam_video_dir, frame_names[problem_frame_idx])))
# show_mask(video_segments[problem_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
# plt.axis('off')
# plt.show()

# # Add refinement clicks on this frame
# points = np.array([[400, 500]], dtype=np.float32)  # MODIFY COORDINATES
# labels = np.array([0], np.int32)  # 0 = negative click to remove region

# _, _, out_mask_logits = predictor.add_new_points_or_box(
#     inference_state=inference_state,
#     frame_idx=problem_frame_idx,
#     obj_id=ann_obj_id,
#     points=points,
#     labels=labels,
# )

# # Show refined mask
# plt.figure(figsize=(12, 8))
# plt.title(f"Frame {problem_frame_idx} - After Refinement")
# plt.imshow(Image.open(os.path.join(sam_video_dir, frame_names[problem_frame_idx])))
# show_points(points, labels, plt.gca())
# show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
# plt.axis('off')
# plt.show()

# # Re-run propagation to update all masks
# print("Re-running propagation with refinements...")
# video_segments = {}
# for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
#     video_segments[out_frame_idx] = {
#         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
#         for i, out_obj_id in enumerate(out_obj_ids)
#     }
# print("Propagation complete!")

## Save Masks

Save the water segmentation masks as NumPy arrays (`.npy` files), one file per frame and object.

In [None]:
# Masks are written directly into output_dir.
# (Alternative: use a subfolder instead.)
# masks_dir = os.path.join(output_dir, 'masks_npy')
masks_dir = output_dir
os.makedirs(masks_dir, exist_ok=True)

print(f"Saving masks to {masks_dir}...")

# Write one .npy file per (frame, object). Files are named after the sequential
# frame name (e.g. "00012_obj1.npy"), not after the original image filename.
saved_count = 0
for out_frame_idx, frame_name in enumerate(frame_names):
    frame_stem = os.path.splitext(frame_name)[0]

    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        mask_filename = f"{frame_stem}_obj{out_obj_id}.npy"
        mask_path = os.path.join(masks_dir, mask_filename)

        # asarray is a no-op for arrays and converts anything else
        np.save(mask_path, np.asarray(out_mask))
        saved_count += 1

# Report the number of files actually written; this differs from the number
# of frames when more than one object is tracked
print(f"\nSaved {saved_count} masks successfully!")
print(f"Masks saved to: {masks_dir}")

## Add Masks to Images Data CSV

Add the mask filenames to the existing `images_and_data.csv` file as a new `mask_filename` column.

In [None]:
# Load the existing images and data CSV
df = pd.read_csv(csv_path)
print(f"Loaded {len(df)} rows from {csv_path}")

# Map each original image filename to the mask saved for the water object.
# The index of a file in image_files is its sequential frame number, and the
# object id must match the one used when prompting (ann_obj_id), because the
# saved mask files are named "<frame>_obj<object id>.npy".
filename_to_mask = {
    original_filename: f"{idx:05d}_obj{ann_obj_id}.npy"
    for idx, original_filename in enumerate(image_files)
}

# Add the mask_filename column by looking up the 'image_names' column of the CSV
df['mask_filename'] = df['image_names'].map(filename_to_mask)

# Rows whose image name was not among the processed image files get no mask (NaN)
missing_masks = df['mask_filename'].isna().sum()
if missing_masks > 0:
    print(f"\nWarning: {missing_masks} rows have no corresponding mask")
else:
    print(f"\nSuccessfully mapped all {len(df)} images to their masks")

# Save back to the same CSV file (re-running simply overwrites the column)
df.to_csv(csv_path, index=False)
print(f"\nUpdated {csv_path} with mask_filename column")

print("\nFirst few rows:")
display(df.head())

## Summary

Water segmentation complete! The masks have been saved and can now be used for further analysis.

In [None]:
# Recap of what was processed and where the results were written
banner = "=" * 60
summary_lines = [
    banner,
    "WATER SEGMENTATION SUMMARY",
    banner,
    f"Camera ID: {camera_id}",
    f"Number of frames processed: {len(frame_names)}",
    f"Masks saved to: {masks_dir}",
    f"Updated CSV: {csv_path}",
    f"\nYou can now proceed to notebook 03 for elevation map creation.",
    banner,
    f"\nNote: Symbolic links are still in {sam_video_dir}",
    "You can delete them to clean up if needed (they don't use much space).",
]
for summary_line in summary_lines:
    print(summary_line)

## Optional: Cleanup Symbolic Links

If you want to remove the temporary symbolic links directory to clean up your workspace:

In [None]:
# Cleanup is opt-in: set this to True to delete the symbolic links directory.
# It stays off by default so that "Run All" does not remove the links that the
# summary above reports as still present.
cleanup_symlinks = False

if cleanup_symlinks:
    import shutil

    if os.path.exists(sam_video_dir):
        # The directory holds only links, so the original images are not touched
        shutil.rmtree(sam_video_dir)
        print(f"Deleted {sam_video_dir}")
    else:
        print("Directory already deleted or doesn't exist")
else:
    print(f"Cleanup skipped; symbolic links kept in {sam_video_dir}")