# Download Model Checkpoint

Download the **ViT-Huge (vit_h)** SAM model checkpoint (~2.4 GB). The download is skipped if the file is already present.

In [5]:
import os

import matplotlib.pyplot as plt
import numpy as np
# Download the model checkpoint using curl instead of wget
!curl -L -o sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2445M  100 2445M    0     0  48.2M      0  0:00:50  0:00:50 --:--:-- 48.8M0:00:55  0:00:01  0:00:54 43.6M0:02 45.0M


# Load the Model

This cell loads the model into memory and prepares the `SamPredictor` object.

In [7]:
import torch
from segment_anything import SamPredictor, sam_model_registry

# Pick the execution device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Instantiate the SAM variant named by MODEL_TYPE, load its weights from SAM_CHECKPOINT,
# and place it on the selected device.
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)

# The predictor is what the processing loop talks to: set_image() once per image,
# then predict() per box prompt.
predictor = SamPredictor(sam)

print("\n✅ SAM Model and Predictor loaded successfully.")

Using device: cpu

✅ SAM Model and Predictor loaded successfully.


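# Data Setup

Configure where input images and YOLO label files are read from and where the generated masks are written.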


In [8]:
# --- Create Directory Structure ---
# Dataset root. Set the AKS_CONES_DIR environment variable to point at a different location
# (another machine, a mounted drive, ...) instead of editing this absolute path.
DATA_ROOT = os.environ.get("AKS_CONES_DIR", "/Users/cortesb/vip-amp-perception/AKS_cones")

IMAGE_DIR = os.path.join(DATA_ROOT, 'images')          # input images
LABEL_DIR = os.path.join(DATA_ROOT, 'masks')           # YOLO .txt label files (input)
MASK_DIR = os.path.join(DATA_ROOT, 'cropped_images')   # generated SAM masks are written here
# NOTE(review): labels are read from a folder named "masks" while generated masks go to
# "cropped_images" — confirm these two directory names are not swapped.

for directory in (IMAGE_DIR, LABEL_DIR, MASK_DIR):
    os.makedirs(directory, exist_ok=True)

# Helper Functions

Define two sets of helpers:
1.  `parse_yolo_bbox`: To read `.txt` files and convert the normalized YOLO coordinates `[x_center, y_center, w, h]` into the absolute pixel coordinates `[x_min, y_min, x_max, y_max]` that SAM requires.
2.  Plotting functions (`show_mask`, `show_box`, taken from the official SAM examples) to visualize the results later.

In [None]:
def parse_yolo_bbox(label_path, image_shape):
    """Read a YOLO-format label file and convert its boxes to absolute pixel corners.

    Each line is expected to be ``class_id x_center y_center width height``, with the four
    box values normalized by the image width/height. Lines that do not have exactly five
    whitespace-separated fields, or whose fields are not numeric, are skipped.

    Args:
        label_path: Path to the ``.txt`` label file.
        image_shape: ``(H, W)`` of the image the labels belong to, in pixels.

    Returns:
        A list of ``(class_id, [x_min, y_min, x_max, y_max])`` tuples with integer pixel
        coordinates clipped to ``[0, W]`` / ``[0, H]`` (the xyxy format SAM's box prompt
        takes). Returns ``[]`` if ``label_path`` does not exist.
    """
    if not os.path.exists(label_path):
        return []

    H, W = image_shape
    boxes = []

    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 5:
                continue

            try:
                # Some exporters write the class id as a float (e.g. "1.0").
                class_id = int(float(parts[0]))
                x_c_norm, y_c_norm, w_norm, h_norm = (float(p) for p in parts[1:])
            except ValueError:
                # One malformed line should not abort processing of the whole dataset.
                continue

            # Convert from normalized [x_c, y_c, w, h] to [x_min, y_min, x_max, y_max]
            box_w = w_norm * W
            box_h = h_norm * H
            x_min = (x_c_norm * W) - (box_w / 2)
            y_min = (y_c_norm * H) - (box_h / 2)
            x_max = x_min + box_w
            y_max = y_min + box_h

            # Clip to the image so labels that overhang the border do not produce
            # negative or out-of-range pixel coordinates.
            x_min, x_max = (min(max(v, 0.0), W) for v in (x_min, x_max))
            y_min, y_max = (min(max(v, 0.0), H) for v in (y_min, y_max))

            # Append as (class_id, [x1, y1, x2, y2])
            boxes.append((class_id, [int(x_min), int(y_min), int(x_max), int(y_max)]))

    return boxes

# --- Visualization Helper Functions (from SAM examples) ---

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` onto ``ax`` as a semi-transparent colored overlay.

    The overlay color is random RGB when ``random_color`` is true and a fixed blue
    otherwise; alpha is 0.6 either way. Only the last two dimensions of ``mask`` are used
    as the overlay's height and width.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline ``box`` = [x_min, y_min, x_max, y_max] on ``ax`` as an unfilled green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

print("✅ Helper functions defined.")

✅ Helper functions defined.


# Main Processing Loop

This is the core of the notebook. It will:
1.  Loop through every image in `IMAGE_DIR`.
2.  Find its matching `.txt` file in `LABEL_DIR`.
3.  Parse all bounding boxes from the `.txt` file.
4.  **Filter for the target class** (`TARGET_CLASS_ID = 1`, which corresponds to cones in the example `.txt` label files).
5.  Tell the `SamPredictor` to process the image.
6.  For each target bounding box, ask the `Predictor` for a mask.
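7.  Save that mask as a black-and-white PNG (0 = background, 255 = mask) in `MASK_DIR`. Masks that already exist on disk are skipped, so an interrupted run can be resumed.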

In [None]:
from tqdm import tqdm
import cv2
import numpy as np

# Only boxes of this YOLO class are turned into masks (1 = cone in this dataset's labels).
TARGET_CLASS_ID = 1

print(f"Starting mask generation... Target Class ID: {TARGET_CLASS_ID}")
print(f"Looking for images in: {IMAGE_DIR}")
print(f"Saving masks to: {MASK_DIR}")

# Sorted for a reproducible processing order; the extension check is case-insensitive so
# files such as "IMG_001.JPG" are not silently ignored.
image_files = sorted(
    f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)

if not image_files:
    print("\n--- WARNING: No images found in directory! ---")

num_written = 0
for image_name in tqdm(image_files):
    image_path = os.path.join(IMAGE_DIR, image_name)

    # The YOLO label file shares the image's stem: foo.jpg -> LABEL_DIR/foo.txt
    base_name = os.path.splitext(image_name)[0]
    label_path = os.path.join(LABEL_DIR, base_name + ".txt")

    image = cv2.imread(image_path)
    if image is None:
        print(f"Warning: Could not read image {image_path}. Skipping.")
        continue

    # De-normalizing the YOLO boxes needs the image size (H, W).
    boxes_info = parse_yolo_bbox(label_path, image.shape[:2])

    # Mask filenames are indexed by position among this image's target-class boxes, so the
    # index must count every target box, including those whose mask already exists.
    pending = []
    target_boxes = (box for class_id, box in boxes_info if class_id == TARGET_CLASS_ID)
    for box_index, box in enumerate(target_boxes):
        mask_filename = f"{base_name}_mask_class{TARGET_CLASS_ID}_{box_index}.png"
        mask_save_path = os.path.join(MASK_DIR, mask_filename)
        if not os.path.exists(mask_save_path):
            pending.append((box, mask_save_path))

    # Nothing to do (no labels, no target-class boxes, or every mask is already on disk):
    # skip set_image() entirely so a resumed run moves through finished images quickly.
    if not pending:
        continue

    # set_image() is the one-time per-image step; each predict() call below reuses it.
    predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    for box, mask_save_path in pending:
        masks, _, _ = predictor.predict(
            box=np.array(box)[None, :],  # [None, :] adds a batch dimension
            multimask_output=False,      # only the single best mask, shape (1, H, W)
        )
        # Boolean mask -> 8-bit grayscale image (0 = black, 255 = white).
        mask_image_8bit = (masks[0] * 255).astype(np.uint8)
        if cv2.imwrite(mask_save_path, mask_image_8bit):
            num_written += 1
        else:
            print(f"Warning: Could not write mask {mask_save_path}.")

    # Reset the predictor before moving on to the next image.
    predictor.reset_image()

print(f"\n✅ All images processed. {num_written} new masks saved in {MASK_DIR}")

Starting mask generation... Target Class ID: 1
Looking for images in: /content/drive/MyDrive/AMP/Synthetic_Data/AKS_cones/images
Saving masks to: /content/drive/MyDrive/AMP/Synthetic_Data/AKS_cones/masks


100%|██████████| 6948/6948 [2:18:03<00:00,  1.19s/it]


✅ All images processed. Masks are saved in /content/dataset/masks/





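# Visualize a Result

Overlay the saved masks for one example image on top of its YOLO boxes as a quick sanity check of the output.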


In [None]:
import cv2
import matplotlib.pyplot as plt  # pyplot, not the bare `matplotlib` package, provides figure/imshow/show


def visualize_example(image_name):
    """Plot the YOLO boxes and the saved SAM masks for one image in IMAGE_DIR.

    Reads the image from IMAGE_DIR, its label file (same stem, ``.txt``) from LABEL_DIR, and
    the masks written by the processing loop from MASK_DIR
    (``<stem>_mask_class<TARGET_CLASS_ID>_<i>.png``). Prints a message and returns early if
    the image cannot be loaded or no masks are found.

    Args:
        image_name: File name of the image inside IMAGE_DIR.
    """
    base_name = os.path.splitext(image_name)[0]
    image_path = os.path.join(IMAGE_DIR, image_name)
    label_path = os.path.join(LABEL_DIR, base_name + ".txt")

    if not os.path.exists(image_path):
        print(f"Cannot run visualization: Example image not found at {image_path}")
        return
    image = cv2.imread(image_path)
    if image is None:
        print(f"Cannot run visualization: Could not read example image {image_path}")
        return
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    H, W = image_rgb.shape[:2]

    # Load the boxes
    boxes_info = parse_yolo_bbox(label_path, (H, W))
    target_boxes = [box for class_id, box in boxes_info if class_id == TARGET_CLASS_ID]

    # Mask i belongs to the i-th target-class box in label-file order, matching the
    # filenames produced by the processing loop.
    generated_masks = []
    for i in range(len(target_boxes)):
        mask_path = os.path.join(MASK_DIR, f"{base_name}_mask_class{TARGET_CLASS_ID}_{i}.png")
        mask_8bit = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_path) else None
        if mask_8bit is None:
            print(f"Warning: Could not find expected mask {mask_path}")
            continue
        generated_masks.append(mask_8bit.astype(bool))

    if not generated_masks:
        print("No masks were generated for the example image. Cannot visualize.")
        return

    print(f"Visualizing results for {image_name}...")

    # Left: input + YOLO boxes. Right: input + generated SAM masks + boxes.
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(20, 10))

    ax_in.imshow(image_rgb)
    for box in target_boxes:
        show_box(box, ax_in)
    ax_in.set_title(f"Input Image + YOLO BBoxes (Class {TARGET_CLASS_ID})")
    ax_in.axis('off')

    ax_out.imshow(image_rgb)
    for mask in generated_masks:
        show_mask(mask, ax_out, random_color=True)
    for box in target_boxes:
        show_box(box, ax_out)
    ax_out.set_title("Output: Image + Generated SAM Masks")
    ax_out.axis('off')

    fig.tight_layout()
    plt.show()

    # Optional: the raw binary masks exactly as saved to disk.
    fig, axes = plt.subplots(1, len(generated_masks), figsize=(15, 7), squeeze=False)
    for i, (ax, mask) in enumerate(zip(axes[0], generated_masks)):
        ax.imshow(mask, cmap='gray')
        ax.set_title(f"Mask {i}")
        ax.axis('off')
    fig.suptitle("Raw Generated Masks (Saved as PNGs)")
    plt.show()


EXAMPLE_IMAGE_NAME = "0_Cross-Street-road-closure_jpg.rf.b268533f2d69236b4bda24587fe8eba8.jpg"
visualize_example(EXAMPLE_IMAGE_NAME)