# Demo 4: SAM2 Inference examples
First, install SAM2 from the [official repository](https://github.com/facebookresearch/segment-anything-2). SAM2 requires Python 3.10, so you may need to update the Python version in the `envs/env_atek_*.yaml` files accordingly.  

In [1]:
import faulthandler

import logging
import os
from logging import StreamHandler
import numpy as np
from typing import Dict, List, Optional
import torch
import sys
import subprocess
from tqdm import tqdm

from atek.util.file_io_utils import load_yaml_and_extract_tar_list

import os
from typing import List

import cv2

import matplotlib.pyplot as plt
import numpy as np

import webdataset as wds
from atek.data_loaders.atek_wds_dataloader import (
     load_atek_wds_dataset,
     simple_list_collation_fn
)
from atek.data_loaders.sam2_model_adaptor import (
    create_atek_dataloader_as_sam2
)
from atek.data_loaders.sam2_model_adaptor import (
    create_atek_dataloader_as_depth_anything2
)
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from webdataset.filters import pipelinefilter

from depth_anything_v2.dpt import DepthAnythingV2

# Data path setup: follow Demo 2 to prepare the `streamable_validation_tars.yaml` file.
streamable_atek_yaml_file = "./streamable_validation_tars.yaml"

# Pre-trained SAM2 model: download the checkpoint from the SAM2 repo.
# The checkpoint path and the config name are both passed to `build_sam2` below.
sam2_model_checkpoint = "./sam2_hiera_large.pt"
sam2_model_cfg = "sam2_hiera_l.yaml"

#### Visualization functions from original SAM2 repo
def show_mask(mask, ax, random_color=False, borders=True):
    """Overlay a segmentation mask on ``ax`` as a translucent RGBA image.

    Args:
        mask: NumPy array whose last two dimensions are (height, width). Any
            leading dimensions must be singletons (e.g. ``(1, H, W)``); the
            mask is cast to ``uint8`` and collapsed to ``(H, W)``.
        ax: Object providing ``imshow`` (e.g. a matplotlib Axes).
        random_color: If True, use a random RGB color with alpha 0.6;
            otherwise use a fixed blue with alpha 0.6.
        borders: If True, draw the mask's external contours with OpenCV.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Collapse leading singleton dims up front so that both the overlay and
    # cv2.findContours (which needs a 2-D single-channel image) see (h, w).
    mask = mask.astype(np.uint8).reshape(h, w)
    mask_image = mask[:, :, np.newaxis] * color.reshape(1, 1, -1)
    if borders:
        # Imported lazily so the function works without OpenCV when borders=False
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise against 1 and 0 to select points.
        ax: Object providing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Marker size forwarded to ``scatter`` as ``s``.
    """
    star_style = {
        "marker": "*",
        "s": marker_size,
        "edgecolor": "white",
        "linewidth": 1.25,
    }
    # Select both groups first, then draw positives before negatives
    groups = [
        (coords[labels == flag], tint) for flag, tint in ((1, "green"), (0, "red"))
    ]
    for selected, tint in groups:
        ax.scatter(selected[:, 0], selected[:, 1], color=tint, **star_style)


def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    Args:
        box: Indexable with four entries read as (x0, y0, x1, y1).
        ax: Object providing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor="green",
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)


def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
):
    """Show each mask in its own 10x10 figure on top of ``image``.

    Args:
        image: Image passed to ``plt.imshow`` as the background.
        masks: Iterable of masks, paired element-wise with ``scores``.
        scores: Sized iterable of per-mask scores; shown in the title when
            there is more than one.
        point_coords: Optional prompt points to draw; requires ``input_labels``.
        box_coords: Optional prompt box to draw.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.
    """
    for idx, (single_mask, single_score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        canvas = plt.gca()
        show_mask(single_mask, canvas, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, canvas)
        if box_coords is not None:
            # boxes
            show_box(box_coords, canvas)
        if len(scores) > 1:
            plt.title(f"Mask {idx+1}, Score: {single_score:.3f}", fontsize=18)
        plt.axis("off")
        plt.show()


## SAM2 inference with preprocessed ATEK data from Data Store

### Load SAM2 model, and run inference on streamed ATEK WDS data

In [None]:
from IPython.display import display

# Create the SAM2 image predictor from the configured model config and checkpoint
predictor = SAM2ImagePredictor(build_sam2(sam2_model_cfg, sam2_model_checkpoint))

# Load the ATEK dataset into SAM2 format
tar_list = load_yaml_and_extract_tar_list(streamable_atek_yaml_file)
sam2_dataloader = create_atek_dataloader_as_sam2(tar_list, num_prompt_boxes=10)

# Perform model inference
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    for sam_dict in sam2_dataloader:
        # Run SAM2 on the image, prompted only by the sample's boxes
        image = sam_dict["image"].numpy()  # [H, W, 3]
        predictor.set_image(image)

        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_dict["boxes"],
            multimask_output=False,
        )

        # Visualize results (taken from SAM2's own visualization code)
        print(f" SAM2 resulting mask shapes are {masks.shape}")

        # Use a fresh figure per sample so every image gets the intended size
        # and nothing depends on whether the backend closes figures on show().
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        for mask in masks:
            # squeeze(0) assumes each mask has a leading singleton dim - TODO confirm
            show_mask(mask.squeeze(0), ax, random_color=True, borders=False)
        for box in sam_dict["boxes"]:
            show_box(box, ax)
        ax.axis("off")

        # Render exactly once, then release the figure so it is not shown again
        # (e.g. by the inline backend at the end of the cell) or kept in memory.
        display(fig)
        plt.close(fig)
        input("Press Enter to continue...")