In [1]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random

# Path to the segment-anything-2 repository. Relative paths used later in the
# notebook (SAM2 checkpoints, YOLO weights, video frames) are resolved from it.
# Set the SAM2_DIR environment variable to override the default instead of
# editing this hard-coded local path.
main_directory = os.environ.get("SAM2_DIR", "/home/asdasd/segment-anything-2")

# Change the current working directory to the main directory
os.chdir(main_directory)

In [2]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so all
    # subsequent cells in the notebook run under bfloat16 autocast.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 is only available on compute capability >= 8 (Ampere and newer), see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

#### Utility functions to show results of SAM2 segmentations, using Matplotlib

In [3]:
from utility_functions import show_mask, show_points, show_box, show_masks

#### Utility functions to show results of SAM2 segmentations, using OpenCV

In [4]:
from utility_functions import draw_masks_on_image, draw_points, draw_boxes

### Load the YOLO and SAM2 Models

In [5]:
from ultralytics import YOLO

# YOLOv8 weights file, relative to `main_directory` (the cwd set in the first cell).
yolo_checkpoint = "yolo-sam-2/yolo_weights/Salmons_YOLOv8.pt"
# The weights path is YOLO's first positional parameter (`model`); the task is
# given explicitly as segmentation rather than inferred from the checkpoint.
yolo_segmentator = YOLO(yolo_checkpoint, task="segment")

In [6]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

sam2_checkpoint = "checkpoints/sam2_hiera_large.pt" # sam2_hiera_tiny, sam2_hiera_small, sam2_hiera_base_plus, sam2_hiera_large
model_cfg = "sam2_hiera_l.yaml"                     # sam2_hiera_t, sam2_hiera_s, sam2_hiera_b+, sam2_hiera_l
# Build the model on the device selected in the device cell rather than a
# hard-coded "cuda", which fails on machines without a CUDA GPU (MPS/CPU).
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=str(device))

# Image-level predictor: encodes one image via set_image(), then answers prompts.
predictor = SAM2ImagePredictor(sam2_model)

### Load images from video
Each video frame is stored as a JPEG file, and the frames are collected into a list of file paths. This is not necessary when using the `SAM2ImagePredictor` class, but it is required by `SAM2VideoPredictor`, so the code is shown here.

In [7]:
# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "yolo-sam-2/videos/SHORT_azul"

# Scan all JPEG frame names in this directory. The extension check is
# case-insensitive, so ".jpg", ".JPG", ".Jpeg", ... are all accepted.
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1].lower() in (".jpg", ".jpeg")
]
# Sort numerically by frame index so "10.jpg" comes after "9.jpg".
# A file whose stem is not an integer raises ValueError here.
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
frame_paths = [os.path.join(video_dir, frame_name) for frame_name in frame_names]

# Fail loudly instead of letting the video loop silently process zero frames.
if not frame_paths:
    raise FileNotFoundError(f"No JPEG frames found in {video_dir!r}")

#### Obtain bounding boxes from YOLO detections

In [8]:
def get_bbox(yolo_model: "YOLO", image):
    """Detect objects in ``image`` with YOLO tracking and return their boxes.

    Args:
        yolo_model: Ultralytics YOLO model. ``track(..., persist=True)`` is used
            so the tracker treats successive calls as frames of one sequence.
        image: Frame passed unchanged to ``yolo_model.track`` as ``source``
            (the notebook passes the BGR array returned by ``cv2.imread``).

    Returns:
        ``np.ndarray`` of dtype int16 with shape ``(N, 4)``, one
        ``[x1, y1, x2, y2]`` row per detection of the first result. The shape is
        ``(0, 4)`` when nothing is detected. Float coordinates are truncated
        toward zero by the integer cast.
    """
    segmentation_result = yolo_model.track(source=image, persist=True)
    # Convert the whole xyxy tensor at once instead of row by row. The reshape
    # keeps empty results 2-D, i.e. (0, 4) rather than (0,).
    xyxy = segmentation_result[0].boxes.xyxy.cpu().numpy()
    return xyxy.reshape(-1, 4).astype(np.int16)

#### Now on Video
We run inference frame by frame. This is rather slow, because each image must be loaded separately and inference is done at the frame level. The results are already better than those obtained with SAM, MobileSAM and FastSAM, but we are still not using SAM2 to its full capability. For that, see the `video_predictor_with_prompt` notebook, where the `SAM2VideoPredictor` class is tested with prompts in each frame.

In [9]:
# Video segmentation using YOLO bboxes as SAM2 prompts, frame by frame.
# `mask_input` is never updated below, so SAM2 is prompted by the boxes alone;
# it is kept as a hook for experimenting with mask prompts.
mask_input = None
# Downscale factor applied to frames before showing them in the preview window.
DISPLAY_SCALE = 2

try:
    for frame in frame_paths:
        # Load the image using OpenCV (BGR). imread returns None instead of
        # raising when a file cannot be decoded, so skip such frames.
        image = cv2.imread(frame)
        if image is None:
            print(f"warning: could not read frame {frame}, skipping")
            continue

        # Get the bboxes with YOLO (on the BGR frame, as before)
        input_boxes = get_bbox(yolo_segmentator, image)

        if len(input_boxes) == 0:
            # No detections means no prompt for SAM2: skip the expensive image
            # encoding / prediction and just show the raw frame.
            image_with_masks = image
        else:
            # Convert the image to RGB (OpenCV loads images in BGR by default)
            image_RGB = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Set the SAM2 predictor to the image
            predictor.set_image(image_RGB)

            # Do inference with SAM2, one mask per box prompt
            masks, scores, logits = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_boxes,  # Using bboxes as prompt
                mask_input=mask_input,
                multimask_output=False,
            )

            # Draw the boxes and mask on the image
            image_with_boxes = draw_boxes(image, input_boxes)
            image_with_masks = draw_masks_on_image(image_with_boxes, masks, random_color=True, borders=True)

        # Show the (downscaled) image with masks
        height, width = image_with_masks.shape[:2]
        cv2.imshow('Image with Masks', cv2.resize(image_with_masks, (width // DISPLAY_SCALE, height // DISPLAY_SCALE)))

        # Press 'q' with the output window focused to exit; waits 1 ms every
        # loop to process key presses. Masking with 0xFF keeps only the key
        # code's low byte, the usual OpenCV idiom for comparing with ord().
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always close the preview window, also on errors or kernel interrupts.
    cv2.destroyAllWindows()


0: 384x640 10 salmons, 61.6ms
Speed: 2.8ms preprocess, 61.6ms inference, 83.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 salmons, 18.7ms
Speed: 1.6ms preprocess, 18.7ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 salmons, 17.6ms
Speed: 1.5ms preprocess, 17.6ms inference, 2.4ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 salmons, 17.0ms
Speed: 1.7ms preprocess, 17.0ms inference, 2.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 11 salmons, 17.0ms
Speed: 1.5ms preprocess, 17.0ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 salmons, 17.2ms
Speed: 1.5ms preprocess, 17.2ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 11 salmons, 17.1ms
Speed: 1.5ms preprocess, 17.1ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 11 salmons, 16.8ms
Speed: 1.5ms preprocess, 16.8ms inference, 2.0ms postprocess per image a