# Example: Segment Anything 2

This notebook demonstrates how to create a model adapter for ATEK data samples. The adapter converts ATEK data into a format compatible with the SAM2 model using the ATEK library. We will walk you through three key steps: loading the dataset, adapting it for compatibility, and performing inference with the SAM2 model.

Segment Anything 2: https://github.com/facebookresearch/segment-anything-2

## Set up environment

1. Install SAM2 by following the official guide in its GitHub repo: https://github.com/facebookresearch/segment-anything-2
SAM2 requires python>=3.10, as well as torch>=2.3.1 and torchvision>=0.18.1.
2. Download a checkpoint; we use sam2_hiera_large: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
Put the checkpoint in a directory of your choice; you will need to set its path in the code below.
3. Activate the ATEK environment. Assuming you have already installed it, select `atek` as the kernel for this Jupyter notebook.

## Import Required Libraries
First, we import all necessary libraries that will be used throughout the notebook.

In [1]:
import os
from typing import List

import cv2

import matplotlib.pyplot as plt
import numpy as np

import torch
import webdataset as wds
from atek.data_loaders.atek_wds_dataloader import load_atek_wds_dataset
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from webdataset.filters import pipelinefilter

## Configuration and Initialization
Define the paths and configuration parameters that will be used to load the data and the model. You can also set the maximum number of boxes to segment per image with `NUM_BOXES_TO_SEG`.

In [2]:
# --- Inputs -----------------------------------------------------------------
# Directory holding the ATEK WebDataset tar shards that `main()` reads.
wds_dir = (
    "/Users/ariak/coding/dataset/atek_exp/inference/"
    "20240808_inference_test/wds_output"
)
# SAM2 checkpoint file and the matching model config name passed to build_sam2.
checkpoint = (
    "/Users/ariak/coding/segment-anything-2/"
    "checkpoints/sam2_hiera_large.pt"
)
model_cfg = "sam2_hiera_l.yaml"
# Maximum number of GT boxes per image that are used as SAM2 box prompts.
NUM_BOXES_TO_SEG = 10

# --- Outputs ----------------------------------------------------------------
# NOTE(review): these two paths are not referenced by the cells shown here.
output_video_path = "/Users/ariak/Downloads/SAM_20.mp4"
output_image_folder_path = "/Users/ariak/Downloads/SAM_20_images"

## Model Adaptor from ATEK to SAM2
We need to adapt the ATEK data sample to be compatible with SAM2. This class handles the conversion.

In [None]:
class Sam2Adaptor:
    """Convert ATEK WebDataset samples into inputs for SAM2's image predictor.

    Every converted sample is a dict with two entries:
      * ``"image"``: the camera-rgb frame, turned from ``[1, C, H, W]`` into an
        ``[H, W, C]`` numpy array.
      * ``"boxes"``: at most ``num_boxes`` ground-truth 2D boxes as a ``[K, 4]``
        numpy array, reordered from ``xxyy`` to ``xyxy``.
    """

    def __init__(
        self,
        num_boxes: int = 5,
    ):
        # Upper bound on how many GT boxes are kept from each sample.
        self.num_boxes = num_boxes

    @staticmethod
    def get_dict_key_mapping_all():
        """Return the ATEK-key -> adaptor-key renaming used by the data loader."""
        return {
            "mfcd#camera-rgb+images": "image",
            "gt_data": "gt_data",
        }

    @staticmethod
    def _to_hwc_numpy(image_tensor):
        """Convert a ``[1, C, H, W]`` image tensor into an ``[H, W, C]`` numpy array."""
        chw = image_tensor.clone().detach().squeeze(0)
        return chw.permute(1, 2, 0).numpy()

    def _first_boxes_xyxy(self, rgb_obb2):
        """Return the first ``num_boxes`` boxes, column-reordered xxyy -> xyxy."""
        keep = min(self.num_boxes, len(rgb_obb2["category_names"]))
        # Column order [0, 2, 1, 3] maps (x0, x1, y0, y1) to (x0, y0, x1, y1).
        return rgb_obb2["box_ranges"][0:keep, [0, 2, 1, 3]].numpy()

    def atek_to_sam2(self, data):
        """Yield one SAM2 sample dict for every ATEK sample in ``data``."""
        for atek_sample in data:
            rgb_obb2 = atek_sample["gt_data"]["obb2_gt"]["camera-rgb"]
            yield {
                "image": self._to_hwc_numpy(atek_sample["image"]),
                "boxes": self._first_boxes_xyxy(rgb_obb2),
            }

## Data Loading Function
load_atek_wds_dataset_as_sam2 loads the ATEK dataset and applies the adaptor to make it compatible with SAM2.

In [None]:
def simple_collation_fn(batch):
    """Collate a batch by materialising it as a plain Python list (no stacking)."""
    return [*batch]


def load_atek_wds_dataset_as_sam2(
    urls: List,
    batch_size: int,
    repeat_flag: bool,
    shuffle_flag: bool = False,
    num_boxes: int = 5,
):
    """Load ATEK WebDataset shards and convert every sample into SAM2 inputs.

    Args:
        urls: Shard paths/URLs forwarded to ``load_atek_wds_dataset``.
        batch_size: Forwarded to ``load_atek_wds_dataset``.
        repeat_flag: Forwarded to ``load_atek_wds_dataset``.
        shuffle_flag: Forwarded to ``load_atek_wds_dataset``.
        num_boxes: Maximum number of GT boxes kept per sample by ``Sam2Adaptor``.

    Returns:
        The object returned by ``load_atek_wds_dataset``. Batches are collated
        by ``simple_collation_fn``, i.e. as lists of SAM2 sample dicts.
    """
    # Wrap the adaptor's generator so webdataset can apply it as a pipeline stage.
    sam2_stage = pipelinefilter(Sam2Adaptor(num_boxes=num_boxes).atek_to_sam2)

    return load_atek_wds_dataset(
        urls,
        batch_size=batch_size,
        dict_key_mapping=Sam2Adaptor.get_dict_key_mapping_all(),
        data_transform_fn=sam2_stage(),
        collation_fn=simple_collation_fn,
        repeat_flag=repeat_flag,
        shuffle_flag=shuffle_flag,
    )

## Visualization from SAM2
These functions are used to visualize the results of the SAM2 model predictions. These functions are copied from a SAM2 example notebook.

In [None]:
def show_mask(mask, ax, random_color=False, borders=True):
    """Overlay ``mask`` on the axes ``ax`` as a translucent RGBA image.

    Args:
        mask: Array whose last two dims are ``(H, W)``; it must hold exactly
            ``H * W`` elements because it is reshaped to ``(H, W, 1)``.
        ax: Target axes; only ``ax.imshow`` is called on it.
        random_color: Use a random RGB colour (alpha 0.6) instead of the
            default blue.
        borders: Also draw the mask's external contours on the overlay.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    mask_u8 = mask.astype(np.uint8)
    overlay = mask_u8.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2

        outer, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Slightly simplify each contour so the drawn outline looks smoother.
        smoothed = [
            cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outer
        ]
        overlay = cv2.drawContours(
            overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2
        )

    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars: label 1 in green, label 0 in red.

    Args:
        coords: ``[N, 2]`` array of (x, y) point coordinates.
        labels: Length-``N`` array of point labels compared against 1 and 0.
        ax: Target axes; ``ax.scatter`` is called once per label group.
        marker_size: Marker size passed as ``s`` to ``scatter``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = (
        (coords[labels == 1], "green"),
        (coords[labels == 0], "red"),
    )
    for points, colour in groups:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=colour,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )


def show_box(box, ax):
    """Draw ``box`` = ``[x0, y0, x1, y1]`` on ``ax`` as an unfilled green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor="green",
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)


def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
):
    """Render each mask in its own 10x10 figure on top of ``image``.

    Args:
        image: Image drawn as the background of every figure.
        masks: Masks to overlay, paired one-to-one with ``scores``.
        scores: Per-mask scores; shown in the title only when more than one.
        point_coords: Optional point prompts to draw (requires ``input_labels``).
        box_coords: Optional single box prompt ``[x0, y0, x1, y1]`` to draw.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.
    """
    for number, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {number}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.show()


## Main Inference Function
This function sets up the model, loads the data, performs inference, and visualizes the results.

In [None]:
def main() -> None:
    """Run box-prompted SAM2 segmentation on the ATEK shards and plot the results.

    For every adapted sample, the selected GT boxes are used as SAM2 box
    prompts. The predicted masks and the prompt boxes are drawn over the RGB
    image, and the loop waits for Enter before showing the next sample.
    """
    import contextlib

    # build_sam2 defaults to device="cuda"; fall back to CPU so this also runs
    # on machines without a CUDA GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=device))

    # Shards are named with 4-digit zero padding: shards-0000.tar ... shards-0004.tar.
    tar_list = [os.path.join(wds_dir, f"shards-{i:04d}.tar") for i in range(5)]
    sam2_dataset = load_atek_wds_dataset_as_sam2(
        tar_list,
        batch_size=1,
        repeat_flag=False,
        shuffle_flag=False,
        num_boxes=NUM_BOXES_TO_SEG,
    )

    # bfloat16 autocast only makes sense on CUDA; on CPU use a no-op context.
    autocast_ctx = (
        torch.autocast("cuda", dtype=torch.bfloat16)
        if device == "cuda"
        else contextlib.nullcontext()
    )

    # Perform inference
    with torch.inference_mode(), autocast_ctx:
        for sam_dict_list in sam2_dataset:
            for sam_dict in sam_dict_list:
                image = sam_dict["image"]  # [H, W, C] numpy array (see Sam2Adaptor)
                boxes = sam_dict["boxes"]  # [K, 4] numpy array, xyxy order
                if len(boxes) == 0:
                    # No GT boxes means there is nothing to prompt SAM2 with.
                    print("Skipping sample without ground-truth boxes")
                    continue

                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=boxes,
                    multimask_output=False,
                )

                # Visualize results (taken from SAM2's own visualization code)
                print(f" SAM2 resulting mask shapes are {masks.shape}")
                # predict() drops the batch dim when there is a single box, so
                # masks may be [K, 1, H, W] or [1, H, W]; flatten to [K, H, W]
                # so every iteration yields one 2D mask. The previous
                # mask.squeeze(0) raised ValueError in the single-box case.
                masks_2d = masks.reshape(-1, *masks.shape[-2:])
                plt.figure(figsize=(10, 10))
                plt.imshow(image)
                for mask in masks_2d:
                    show_mask(mask, plt.gca(), random_color=True, borders=False)
                for box in boxes:
                    show_box(box, plt.gca())
                plt.axis("off")
                plt.show()
                input("Press Enter to continue...")


if __name__ == "__main__":
    main()