In [None]:
# Enable interactive matplotlib widgets in Jupyter
%matplotlib widget

# Fish Tracking with SAM2 (Segment Anything Model 2)

This notebook tracks and annotates fish in video using Meta's SAM2 (Segment Anything Model 2).
It is designed to run on the University of Haifa server infrastructure.

## Overview
- Interactive fish annotation and tracking
- Video processing and segmentation
- Export annotations for training/analysis

## 1. Install and Import Dependencies

First, we'll install SAM2 and all required libraries.

In [None]:
# Install SAM2 from GitHub.
# `%pip` (rather than `!pip`) installs into the environment of the running
# kernel, which is not necessarily the `pip` first on the shell PATH.
%pip install -q git+https://github.com/facebookresearch/segment-anything-2.git

# Additional dependencies. `ipympl` provides the `%matplotlib widget` backend
# enabled in the first cell (restart the kernel after installing it).
# NOTE(review): SAM2 itself pulls torch>=2.5.1 / torchvision>=0.20.1. On NGC/CUDA
# images torch is usually preinstalled; re-installing it here may fetch a wheel
# built for a different CUDA version — consider pinning versions or omitting torch.
%pip install -q opencv-python matplotlib ipympl numpy torch torchvision pillow

[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting git+https://github.com/facebookresearch/segment-anything-2.git
  Cloning https://github.com/facebookresearch/segment-anything-2.git to /tmp/pip-req-build-87pifs38
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything-2.git /tmp/pip-req-build-87pifs38
  Resolved https://github.com/facebookresearch/segment-anything-2.git to commit 2b90b9f5ceec907a1c18123530e92e794ad901a4
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting torch>=2.5.1 (from SAM-2==1.0)
  Downloading torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (31 kB)
Collecting torchvision>=0.20.1 (from SAM-2==1.0)
  Downloading torchvision-0.25.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Downloading torch-2.10.0-cp310-cp310-manylinux_2_28_x

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
import os
from PIL import Image

# SAM2 imports
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Report the PyTorch build and whether a CUDA GPU is visible to it
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


KeyboardInterrupt



## 2. Configure SAM2 Model

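Download the SAM2 model checkpoint and select its matching model config. We use the large model for the best accuracy. Note that config names differ from checkpoint names: for example, `sam2_hiera_large.pt` pairs with the `sam2_hiera_l.yaml` config.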

In [None]:
# Configuration paths (relative to the notebook's working directory)
MODEL_DIR = "../models"
DATA_DIR = "../data"
OUTPUT_DIR = "../outputs"

# Create directories if they don't exist
for _directory in (MODEL_DIR, DATA_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)

# Model configuration.
# SAM2 config names do NOT match checkpoint names: the config paired with
# `sam2_hiera_large.pt` is `sam2_hiera_l.yaml` (tiny -> t, small -> s,
# base_plus -> b+, large -> l). There is no `sam2_hiera_large.yaml`, so that
# value makes `build_sam2_video_predictor` fail to find its config.
# Config paths are relative to the installed `sam2` package (SAM2 >= 2.1 layout);
# these are the SAM 2 (07/28/24) models matching the checkpoint download URL.
SAM2_MODELS = {
    "tiny": ("configs/sam2/sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
    "small": ("configs/sam2/sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
    "base_plus": ("configs/sam2/sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
    "large": ("configs/sam2/sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
}
MODEL_SIZE = "large"  # one of: tiny, small, base_plus, large
MODEL_CFG, CHECKPOINT = SAM2_MODELS[MODEL_SIZE]

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Download SAM2 checkpoint if not already present
import urllib.request

checkpoint_path = os.path.join(MODEL_DIR, CHECKPOINT)

if not os.path.exists(checkpoint_path):
    print(f"Downloading {CHECKPOINT}...")
    url = f"https://dl.fbaipublicfiles.com/segment_anything_2/072824/{CHECKPOINT}"
    # Download to a temporary name and rename only after success. Writing
    # straight to `checkpoint_path` means an interrupted download (network
    # error, KeyboardInterrupt) leaves a truncated file that the existence
    # check above would accept on every later run.
    partial_path = checkpoint_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, checkpoint_path)
    finally:
        # Only still present if the download or rename failed
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete!")
else:
    print(f"Checkpoint already exists at {checkpoint_path}")

## 3. Load and Display Sample Fish Video/Images

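Load your fish video or image sequence from the data directory. SAM2's video predictor reads a directory of JPEG frames (`.jpg`/`.jpeg`) and orders them by the integer value of each file name. Frames must therefore be named like `00000.jpg`, `00001.jpg`, …; names such as `frame_00000.jpg` will fail to load.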

In [None]:
# Video processing function
def extract_frames_from_video(video_path, output_dir, max_frames=None):
    """Extract frames from a video file as JPEG images that SAM2 can load.

    Frames are written as ``<output_dir>/<index>.jpg`` with a zero-padded,
    purely numeric name (``00000.jpg``, ``00001.jpg``, ...). SAM2's
    ``init_state(video_path=<dir>)`` reads only JPEG files and orders them by
    ``int(<file stem>)``, so names such as ``frame_00000.jpg`` fail to load.

    Args:
        video_path: Path to a video file readable by OpenCV.
        output_dir: Directory to write frames into (created if missing).
        max_frames: Maximum number of frames to extract; ``None`` extracts all.

    Returns:
        The number of frames written.

    Raises:
        IOError: If the video cannot be opened or a frame cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Without this check a bad path silently "extracts" 0 frames
        raise IOError(f"Could not open video: {video_path}")

    frame_count = 0
    try:
        # Check the limit before reading so no extra frame is decoded, and so
        # max_frames=0 means "no frames" rather than "no limit".
        while max_frames is None or frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = os.path.join(output_dir, f"{frame_count:05d}.jpg")
            if not cv2.imwrite(frame_path, frame):
                raise IOError(f"Failed to write frame: {frame_path}")
            frame_count += 1
    finally:
        cap.release()

    print(f"Extracted {frame_count} frames to {output_dir}")
    return frame_count

# Load frames from directory
def load_frames(frames_dir):
    """Load the video frames in ``frames_dir`` in the same order SAM2 uses.

    SAM2's ``init_state(video_path=<dir>)`` reads only files ending in
    ``.jpg``/``.jpeg``/``.JPG``/``.JPEG`` and sorts them by the integer value
    of the file stem. Loading exactly that set in that order keeps
    ``frames[i]`` aligned with the masks SAM2 reports for frame ``i``. A plain
    string sort (``"10.jpg"`` before ``"2.jpg"``) or also loading ``.png`` files
    would pair frames with the wrong masks.

    Args:
        frames_dir: Directory containing the frame images.

    Returns:
        ``(frames, frame_files)``: a list of RGB ``uint8`` arrays and the
        matching file names, in frame order.

    Raises:
        IOError: If a listed file cannot be decoded by OpenCV.
    """
    jpeg_exts = (".jpg", ".jpeg", ".JPG", ".JPEG")
    frame_files = [
        name for name in os.listdir(frames_dir)
        if os.path.splitext(name)[-1] in jpeg_exts
    ]

    if all(os.path.splitext(name)[0].isdigit() for name in frame_files):
        # Numeric order, as SAM2 does it
        frame_files.sort(key=lambda name: int(os.path.splitext(name)[0]))
    else:
        # SAM2 cannot load non-numeric names; fall back to name order
        frame_files.sort()

    frames = []
    for frame_file in frame_files:
        frame_path = os.path.join(frames_dir, frame_file)
        frame = cv2.imread(frame_path)
        if frame is None:  # cv2.imread returns None instead of raising
            raise IOError(f"Could not read image: {frame_path}")
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    print(f"Loaded {len(frames)} frames")
    return frames, frame_files

In [None]:
# Example: Extract frames from video (uncomment and modify path)
# VIDEO_PATH = os.path.join(DATA_DIR, "fish_video.mp4")
# FRAMES_DIR = os.path.join(DATA_DIR, "frames")
# extract_frames_from_video(VIDEO_PATH, FRAMES_DIR, max_frames=100)

# Or load existing frames
FRAMES_DIR = os.path.join(DATA_DIR, "frames")
print(f"Looking for frames in: {FRAMES_DIR}")
print("Note: Place your video frames in the '../data/frames' directory")

# Defaults keep `frames` / `frame_files` defined even when nothing is loaded
frames, frame_files = [], []
if os.path.exists(FRAMES_DIR) and len(os.listdir(FRAMES_DIR)) > 0:
    frames, frame_files = load_frames(FRAMES_DIR)

# The directory may be non-empty yet contain no loadable frames (e.g. only
# stray files), so check what was actually loaded before indexing frames[0].
if frames:
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(frames[0])
    ax.set_title(f"First frame: {frame_files[0]}")
    ax.axis('off')
    plt.show()
else:
    print("⚠️ No frames found. Please add video frames to the data/frames directory.")

## 4. Initialize SAM2 Video Predictor

Load the SAM2 model and prepare it for video tracking.

In [None]:
# Initialize SAM2 video predictor.
# build_sam2_video_predictor defaults to device="cuda"; pass the configured
# `device` explicitly so CPU-only machines work instead of failing on load.
predictor = build_sam2_video_predictor(MODEL_CFG, checkpoint_path, device=device)

print("✓ SAM2 video predictor initialized successfully!")
print(f"Model: {MODEL_CFG}")
print(f"Device: {device}")

In [None]:
# Build SAM2's per-video inference state from the frames directory,
# or report that there is nothing to track yet.
if not (os.path.exists(FRAMES_DIR) and len(os.listdir(FRAMES_DIR)) > 0):
    print("⚠️ Skipping initialization - no frames available")
else:
    inference_state = predictor.init_state(video_path=FRAMES_DIR)
    print("✓ Inference state initialized for video tracking")

## 5. Add Tracking Points for Fish

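Define initial prompt points on the fish you want to track, either by clicking on the first frame or by entering pixel coordinates. Positive points (label 1) mark the fish; negative points (label 0) mark the background.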

In [None]:
# Interactive point selection helper
def show_points_on_image(image, points, labels):
    """Draw ``image`` with its prompt points overlaid.

    A point whose label equals 1 is drawn as a green marker (foreground, i.e.
    fish); any other label is drawn as a red marker (background).
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    for (px, py), lbl in zip(points, labels):
        marker_style = 'go' if lbl == 1 else 'ro'
        ax.plot(px, py, marker_style, markersize=10, markeredgewidth=2, markeredgecolor='white')

    ax.set_title("Click points: Green = Foreground (fish), Red = Background")
    ax.axis('off')
    plt.show()

# Helper class for interactive point selection
class PointSelector:
    """Collect foreground/background prompt points by clicking on an image.

    Left click adds a positive point (label 1, fish); right click adds a
    negative point (label 0, background). Other mouse buttons and clicks
    outside the image are ignored.

    NOTE: with the ``%matplotlib widget`` (ipympl) backend ``plt.show()`` does
    not block, so ``select_points()`` returns immediately with whatever has
    been clicked so far (usually nothing). Click on the figure, then call
    ``get_points()`` from a later cell to retrieve the selection.
    """

    def __init__(self, image):
        self.image = image
        self.points = []  # [x, y] integer pixel coordinates, in click order
        self.labels = []  # 1 = fish, 0 = background; parallel to self.points
        self.fig, self.ax = plt.subplots(figsize=(12, 8))
        self._cid = None  # id of the button_press_event connection, once made

    def onclick(self, event):
        """Record one mouse click as a prompt point and draw it."""
        if event.xdata is None or event.ydata is None:
            return  # click landed outside the image axes
        if event.button == 1:    # left click -> fish
            label = 1
        elif event.button == 3:  # right click -> background
            label = 0
        else:                    # middle click etc. -> ignore
            return

        x, y = int(event.xdata), int(event.ydata)
        self.points.append([x, y])
        self.labels.append(label)

        # Plot point
        color = 'g' if label == 1 else 'r'
        self.ax.plot(x, y, f'{color}o', markersize=10, markeredgewidth=2, markeredgecolor='white')
        self.fig.canvas.draw_idle()

        print(f"Added {'positive' if label == 1 else 'negative'} point at ({x}, {y})")

    def get_points(self):
        """Return the clicked points as ``(points, labels)`` arrays.

        ``points`` has shape (N, 2) of integer (x, y) pixel coordinates and
        ``labels`` has shape (N,); N is 0 when nothing has been clicked.
        """
        points = np.array(self.points, dtype=int).reshape(-1, 2)
        labels = np.array(self.labels, dtype=int)
        return points, labels

    def select_points(self):
        """Show the image, start collecting clicks and return the points so far.

        With a blocking (desktop GUI) backend this returns after the window is
        closed; with the widget backend it returns immediately — use
        ``get_points()`` afterwards.
        """
        self.ax.imshow(self.image)
        self.ax.set_title("Left click: Fish (foreground) | Right click: Background | Close window when done")
        self.ax.axis('off')
        # Connect only once: a second connection would record every click twice.
        if self._cid is None:
            self._cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        plt.show()

        return self.get_points()

In [None]:
# Prompt points for the fish to track on the first frame.
#
# Manual definition (replace the default below):
#   points = np.array([[x1, y1], [x2, y2], ...])  # pixel (x, y) coordinates on the fish
#   labels = np.array([1, 1, ...])                # 1 = positive (fish), 0 = negative (background)
#
# Or set USE_INTERACTIVE_SELECTION = True to click points on the first frame.
USE_INTERACTIVE_SELECTION = False

if os.path.exists(FRAMES_DIR) and len(os.listdir(FRAMES_DIR)) > 0:
    # Default prompt: one positive point at the centre of the first frame.
    # Derived from the frame size — a fixed (320, 240) is only the centre of a
    # 640x480 frame and can even fall outside smaller frames.
    frame_height, frame_width = frames[0].shape[:2]
    points = np.array([[frame_width // 2, frame_height // 2]])
    labels = np.array([1])

    if USE_INTERACTIVE_SELECTION:
        print("Interactive point selection:")
        print("- Left click on the fish to add positive points")
        print("- Right click on background to add negative points")
        # Created only when used: constructing a PointSelector opens a figure.
        selector = PointSelector(frames[0])
        selector.select_points()
        # The widget backend does not block, so clicks arrive after this cell
        # returns; copy them into `points` / `labels` before the next step.
        print("- When done, run in a new cell: "
              "points, labels = np.array(selector.points), np.array(selector.labels)")
    else:
        show_points_on_image(frames[0], points, labels)
else:
    print("⚠️ No frames available for point selection")

In [None]:
# Add points to SAM2 for tracking
if os.path.exists(FRAMES_DIR) and len(os.listdir(FRAMES_DIR)) > 0:
    # Frame index the prompt points refer to (usually 0 for the first frame)
    frame_idx = 0

    # Object ID for tracking (can track multiple objects with different IDs)
    obj_id = 1

    # `add_new_points` is a deprecated alias in current SAM2 releases;
    # `add_new_points_or_box` is the supported API with the same arguments.
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=frame_idx,
        obj_id=obj_id,
        points=points,
        labels=labels,
    )

    print(f"✓ Added {len(points)} tracking point(s) for object {obj_id} at frame {frame_idx}")
else:
    print("⚠️ Skipping - no frames available")

## 6. Propagate Annotations Across Frames

Run SAM2 tracking to propagate the annotations across all video frames.

In [None]:
# Propagate tracking through video
if os.path.exists(FRAMES_DIR) and len(os.listdir(FRAMES_DIR)) > 0:
    print("Propagating annotations across frames...")

    # frame index -> {object id -> boolean mask}; later cells read mask[0],
    # presumably a (1, H, W) array per object — TODO confirm against SAM2.
    video_segments = {}

    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    print(f"✓ Propagation complete! Tracked across {len(video_segments)} frames")
    # Gather IDs from every frame instead of indexing frame 0, which is absent
    # when prompts are placed on a later frame (or nothing was propagated).
    tracked_ids = sorted({oid for masks in video_segments.values() for oid in masks})
    print(f"Objects tracked: {tracked_ids}")
else:
    print("⚠️ Skipping - no frames available")
    video_segments = {}

## 7. Visualize Tracking Results

Display the tracked fish with masks overlaid on video frames.

In [None]:
# Visualization helper functions
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay one mask on ``ax`` as a translucent RGBA layer (alpha 0.6).

    The colour is random when ``random_color`` is set; otherwise it is entry
    ``obj_id`` of matplotlib's "tab10" colormap (entry 0 when ``obj_id`` is
    None). Only the last two dimensions of ``mask`` are used as height/width.
    """
    alpha = 0.6
    if not random_color:
        palette_index = obj_id if obj_id is not None else 0
        rgb = plt.get_cmap("tab10")(palette_index)[:3]
        color = np.array([*rgb, alpha])
    else:
        color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * color.reshape(1, 1, -1)
    ax.imshow(overlay)

def visualize_frame_with_masks(frame, masks, title="Tracked Fish"):
    """Show ``frame`` with every object's mask drawn on top.

    Args:
        frame: Image drawn underneath the masks.
        masks: Mapping of object id -> mask; element ``[0]`` of each mask is
            what gets drawn, coloured by its object id.
        title: Axes title.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(frame)

    for object_id in masks:
        show_mask(masks[object_id][0], ax, obj_id=object_id)

    ax.set_title(title)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize results on sample frames
if len(video_segments) > 0:
    # First, middle and last *tracked* frames. Use the actual frame indices
    # (propagation need not start at 0) and de-duplicate them so that very
    # short videos do not show the same frame several times.
    tracked_frames = sorted(video_segments)
    sample_indices = sorted({
        tracked_frames[0],
        tracked_frames[len(tracked_frames) // 2],
        tracked_frames[-1],
    })

    for idx in sample_indices:
        visualize_frame_with_masks(
            frames[idx],
            video_segments[idx],
            title=f"Frame {idx}: Tracked Fish"
        )
else:
    print("⚠️ No tracking results to visualize")

## 8. Export Annotations

Export tracking results to standard annotation formats for training or analysis.

In [None]:
import json
from datetime import datetime

def export_to_coco(video_segments, frames, output_path, frame_files=None):
    """Export tracked masks as a COCO-style JSON file and return the dict.

    Args:
        video_segments: Mapping frame index -> {object id -> mask}; ``mask[0]``
            is treated as the 2-D binary mask for that object.
        frames: Sequence indexed by frame index; only ``.shape[:2]``
            (height, width) of each frame is read.
        output_path: Path of the JSON file to write.
        frame_files: Optional sequence of frame file names indexed by frame
            index (e.g. the names returned by ``load_frames``). When omitted,
            names default to ``frame_{index:05d}.jpg``.

    Returns:
        The COCO dictionary that was written.

    Notes:
        * Each image's ``id`` is its frame index; each annotation stores the
          tracker's object id as ``track_id``.
        * Objects whose mask is empty in a frame get no annotation.
        * ``segmentation`` is left empty (bbox and area only).
    """
    coco_output = {
        "info": {
            "description": "Fish Tracking with SAM2",
            "date_created": datetime.now().isoformat(),
        },
        "images": [],
        "annotations": [],
        "categories": [{"id": 1, "name": "fish", "supercategory": "animal"}]
    }

    annotation_id = 1

    for frame_idx, masks_dict in video_segments.items():
        # Add image info
        height, width = frames[frame_idx].shape[:2]
        if frame_files is not None:
            file_name = frame_files[frame_idx]
        else:
            file_name = f"frame_{frame_idx:05d}.jpg"
        coco_output["images"].append({
            "id": int(frame_idx),
            "file_name": file_name,
            "height": int(height),
            "width": int(width)
        })

        # Add annotations for each object
        for obj_id, mask in masks_dict.items():
            ys, xs = np.nonzero(np.asarray(mask[0]).astype(bool))
            if xs.size == 0:
                continue  # object not visible in this frame

            # Box over *all* mask pixels. Boxing only the first contour would
            # cover just one piece of a fish split into several blobs
            # (e.g. partly occluded).
            x_min, x_max = int(xs.min()), int(xs.max())
            y_min, y_max = int(ys.min()), int(ys.max())

            coco_output["annotations"].append({
                "id": annotation_id,
                "image_id": int(frame_idx),
                "category_id": 1,
                "bbox": [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1],
                "area": int(xs.size),
                "segmentation": [],  # Can add polygon points if needed
                "iscrowd": 0,
                "track_id": int(obj_id)
            })
            annotation_id += 1

    # Save to file
    with open(output_path, 'w') as f:
        json.dump(coco_output, f, indent=2)

    print(f"✓ Exported {len(coco_output['annotations'])} annotations to {output_path}")
    return coco_output

In [None]:
# Export the tracking results as a timestamped COCO annotation file
if not video_segments:
    print("⚠️ No annotations to export")
else:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    coco_path = os.path.join(OUTPUT_DIR, "fish_annotations_" + timestamp + ".json")

    coco_annotations = export_to_coco(video_segments, frames, coco_path)
    print(
        "\nAnnotation summary:",
        f"- Total frames: {len(coco_annotations['images'])}",
        f"- Total annotations: {len(coco_annotations['annotations'])}",
        sep="\n",
    )

## 9. Save Tracked Masks

Save segmentation masks to disk for future use and analysis.

In [None]:
# Save masks and overlays
if len(video_segments) > 0:
    # Bare Figure objects (not pyplot) are never registered with the
    # interactive widget backend, so rendering one overlay per frame neither
    # queues figures for display nor accumulates them in pyplot's figure list.
    from matplotlib.figure import Figure

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create output directories
    masks_dir = os.path.join(OUTPUT_DIR, f"masks_{timestamp}")
    overlays_dir = os.path.join(OUTPUT_DIR, f"overlays_{timestamp}")
    os.makedirs(masks_dir, exist_ok=True)
    os.makedirs(overlays_dir, exist_ok=True)

    print("Saving masks and overlays...")

    for frame_idx, masks_dict in video_segments.items():
        # Save binary masks (0 / 255 single-channel PNG per object)
        for obj_id, mask in masks_dict.items():
            mask_binary = (mask[0] * 255).astype(np.uint8)
            mask_path = os.path.join(masks_dir, f"frame_{frame_idx:05d}_obj_{obj_id}.png")
            cv2.imwrite(mask_path, mask_binary)

        # Save overlay visualization
        fig = Figure(figsize=(12, 8))
        ax = fig.subplots()
        ax.imshow(frames[frame_idx])
        for obj_id, mask in masks_dict.items():
            show_mask(mask[0], ax, obj_id=obj_id)
        ax.axis('off')

        overlay_path = os.path.join(overlays_dir, f"frame_{frame_idx:05d}.png")
        fig.savefig(overlay_path, bbox_inches='tight', pad_inches=0, dpi=100)

    print(f"✓ Saved masks to: {masks_dir}")
    print(f"✓ Saved overlays to: {overlays_dir}")

    # Count objects over all frames, not just the first one: an object may be
    # absent from the first propagated frame.
    tracked_ids = sorted({int(oid) for masks in video_segments.values() for oid in masks})

    # Save tracking metadata
    metadata = {
        "timestamp": timestamp,
        "num_frames": len(video_segments),
        "num_objects": len(tracked_ids),
        "object_ids": tracked_ids,
        "model": MODEL_CFG,
        "masks_dir": masks_dir,
        "overlays_dir": overlays_dir
    }

    metadata_path = os.path.join(OUTPUT_DIR, f"tracking_metadata_{timestamp}.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"✓ Saved metadata to: {metadata_path}")
else:
    print("⚠️ No masks to save")