In [None]:
import os
import time
import base64
import tempfile
from enum import Enum
from pathlib import Path
from typing import Union, List, Dict, Any, Tuple, cast

import av
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

from vision_agent_tools.models.florence2_sam2 import Florence2SAM2

# Helper functions: video I/O, RLE mask decoding, and visualization

In [2]:
def denormalize_bbox(
    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
    r"""Convert a normalized ``[x1, y1, x2, y2]`` box to absolute pixel coordinates.

    A box is treated as normalized when every coordinate lies in ``[0, 1]``; in
    that case x values are scaled by ``image_size[1]`` and y values by
    ``image_size[0]`` (i.e. ``image_size`` is ordered like ``ndarray.shape``:
    height first, then width) and rounded. Any other box is returned unchanged
    (the very same object).

    Raises:
        ValueError: if ``bbox`` does not have exactly four entries.
    """
    if len(bbox) != 4:
        raise ValueError("Bounding box must be of length 4.")

    coords = np.array(bbox)
    looks_normalized = np.all((coords >= 0) & (coords <= 1))
    if not looks_normalized:
        return bbox

    img_h = image_size[0]
    img_w = image_size[1]
    # Per-coordinate scale factors for x1, y1, x2, y2.
    scales = (img_w, img_h, img_w, img_h)
    return [round(value * scale) for value, scale in zip(bbox, scales)]

# RGB palette for mask overlays. Each distinct label is assigned a color by
# cycling through this list (COLORS[i % len(COLORS)]), and an alpha value is
# appended when the overlay is composited. The values appear to be matplotlib's
# "tab20" palette in reverse order.
COLORS = [
    (158, 218, 229),
    (219, 219, 141),
    (23, 190, 207),
    (188, 189, 34),
    (199, 199, 199),
    (247, 182, 210),
    (127, 127, 127),
    (227, 119, 194),
    (196, 156, 148),
    (197, 176, 213),
    (140, 86, 75),
    (148, 103, 189),
    (255, 152, 150),
    (152, 223, 138),
    (214, 39, 40),
    (44, 160, 44),
    (255, 187, 120),
    (174, 199, 232),
    (255, 127, 14),
    (31, 119, 180),
]

def _get_text_coords_from_mask(
    mask: np.ndarray, v_gap: int = 10, h_gap: int = 10
) -> Tuple[int, int]:
    mask = mask.astype(np.uint8)
    if np.sum(mask) == 0:
        return (0, 0)

    rows, cols = np.nonzero(mask)
    top = rows.min()
    bottom = rows.max()
    left = cols.min()
    right = cols.max()

    if top - v_gap < 0:
        if bottom + v_gap > mask.shape[0]:
            top = top
        else:
            top = bottom + v_gap
    else:
        top = top - v_gap

    return left + (right - left) // 2 - h_gap, top

def _load_label_font(font_path: str | None, fontsize: int):
    """Load the font for label text, falling back to PIL's built-in font."""
    if font_path is not None:
        try:
            return ImageFont.truetype(font_path, fontsize)
        except OSError:
            import warnings

            warnings.warn(
                f"Could not load font {font_path!r}; using PIL's default font.",
                stacklevel=3,
            )
    try:
        # Pillow >= 10.1 can render the built-in font at a requested size.
        return ImageFont.load_default(size=fontsize)
    except TypeError:
        # Older Pillow only has a fixed-size bitmap default font.
        return ImageFont.load_default()


def overlay_segmentation_masks(
    medias: Union[np.ndarray, List[np.ndarray]],
    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
    draw_label: bool = True,
    secondary_label_key: str = "tracking_label",
    font_path: str | None = "/home/ec2-user/code/vision-agent-tools/tmp/default_font_ch_en.ttf",
) -> Union[np.ndarray, List[np.ndarray]]:
    """'overlay_segmentation_masks' is a utility function that displays segmentation
    masks.

    Every mask is alpha-blended (alpha 127/255) onto its frame with a color chosen
    per label, and optionally the label text is drawn on a colored box next to it.

    Parameters:
        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
            the masks on.
        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
            dictionaries (single image) or a list of lists of dictionaries (one list
            per frame). Each dictionary needs a 'label' and a 'mask' whose pixels
            > 0 mark the object. An empty list means there is nothing to draw.
        draw_label (bool, optional): If True, the labels will be displayed on the image.
        secondary_label_key (str, optional): The key to use for the secondary
            tracking label which is needed in videos to display tracking information.
            When present and truthy it is drawn instead of 'label'.
        font_path (str | None, optional): TrueType font used for the label text. If
            it is None or cannot be loaded, PIL's built-in default font is used
            (with a warning in the latter case).

    Returns:
        Union[np.ndarray, List[np.ndarray]]: A single RGBA array when there is
        exactly one frame (even if it was passed inside a list), otherwise a list
        with one RGBA array per frame.

    Example
    -------
        >>> image_with_masks = overlay_segmentation_masks(
            image,
            [{
                'score': 0.99,
                'label': 'dinosaur',
                'mask': array([[0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0],
                    ...,
                    [0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
            }],
        )
    """
    medias_int: List[np.ndarray] = (
        [medias] if isinstance(medias, np.ndarray) else medias
    )
    if len(masks) == 0:
        # No detections at all (e.g. an object left the frame): draw nothing
        # instead of failing on masks[0].
        masks_int = [[] for _ in medias_int]
    elif isinstance(masks[0], dict):
        masks_int = [masks]
    else:
        masks_int = masks
    masks_int = cast(List[List[Dict[str, Any]]], masks_int)

    # dict.fromkeys keeps first-seen order, so the label -> color mapping is
    # deterministic; iterating a set of str depends on the per-process hash seed.
    labels = dict.fromkeys(
        mask_info["label"] for frame_masks in masks_int for mask_info in frame_masks
    )
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

    # shape[:2] is (rows, cols) == (height, width) for 2-D and 3-D frames alike.
    height, width = medias_int[0].shape[:2]
    fontsize = max(12, int(min(width, height) / 40))
    font = _load_label_font(font_path, fontsize)

    frame_out = []
    for i, frame in enumerate(medias_int):
        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
        for elt in masks_int[i]:
            mask = elt["mask"]
            label = elt["label"]
            tracking_lbl = elt.get(secondary_label_key, None)

            # Transparent RGBA layer, colored (alpha 127) only where the mask is set.
            mask_layer = np.zeros(
                (pil_image.size[1], pil_image.size[0], 4), dtype=np.uint8
            )
            mask_layer[mask > 0] = (*color[label], 127)
            pil_image = Image.alpha_composite(pil_image, Image.fromarray(mask_layer))

            # Only place a label when the mask actually covers some pixels; the old
            # `x != 0 and y != 0` check also dropped labels touching row/column 0.
            if draw_label and np.any(mask > 0):
                draw = ImageDraw.Draw(pil_image)
                text = str(tracking_lbl if tracking_lbl else label)
                text_box = draw.textbbox((0, 0), text=text, font=font)
                x, y = _get_text_coords_from_mask(
                    mask,
                    v_gap=(text_box[3] - text_box[1]) + 10,
                    h_gap=(text_box[2] - text_box[0]) // 2,
                )
                text_box = draw.textbbox((x, y), text=text, font=font)
                draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
                draw.text((x, y), text, fill="black", font=font)
        frame_out.append(np.array(pil_image))
    return frame_out[0] if len(frame_out) == 1 else frame_out

def load_video(
    video_path: str, convert_to_rgb: bool = True
) -> Tuple[np.ndarray, float]:
    """Decode every frame of a local video file into memory.

    Parameters:
        video_path (str): Path to a local video file.
        convert_to_rgb (bool, optional): OpenCV decodes frames in BGR channel
            order; when True (default) they are converted to RGB, which is what the
            PIL-based overlay, the ``rgb24`` video writer and ``plt.imshow`` expect.

    Returns:
        Tuple[np.ndarray, float]: The frames stacked with ``np.array`` (shape
        ``(num_frames, height, width, channels)`` when all frames share a size) and
        the frame rate reported by OpenCV.

    Raises:
        FileNotFoundError: if ``video_path`` is not an existing file.
        ValueError: if OpenCV cannot open the file or decodes no frames.
    """
    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    # OpenCV can read the file directly; no need to copy it to a temp file first.
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)

        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if convert_to_rgb:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
    finally:
        # Release the decoder even when reading fails part-way.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from: {video_path}")
    return np.array(frames), fps

def rle_decode_array(rle: dict[str, list[int]]) -> np.ndarray:
    r"""Expand a run-length encoded mask into a dense ``uint8`` array (1 = mask, 0 = background).

    ``rle["counts"]`` lists alternating run lengths, starting with a background
    (0) run; pixels are laid out in column-major (``order="F"``) order and the
    result is reshaped to ``rle["size"]``. This appears to match the layout of
    COCO's uncompressed RLE.

    Parameters:
        rle: The run-length encoded mask, with "size" and "counts" entries.
    """
    shape = rle["size"]
    run_lengths = rle["counts"]

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    cursor = 0
    is_foreground = False  # runs alternate, beginning with background
    for run_length in run_lengths:
        run_end = cursor + run_length
        if is_foreground:
            flat[cursor:run_end] = 1
        cursor = run_end
        is_foreground = not is_foreground

    return flat.reshape(shape, order="F")

def _resize_frame(frame: np.ndarray) -> np.ndarray:
    height, width = frame.shape[:2]
    new_width = width - (width % 2)
    new_height = height - (height % 2)
    return cv2.resize(frame, (new_width, new_height))

class FileSerializer:
    """Wrap a local file so that IPython's display machinery can embed it as a
    base64 string.

    The path is checked with ``assert`` statements at construction time (so the
    checks are skipped under ``python -O``).
    """

    def __init__(self, file_uri: str):
        is_local_file = os.path.isfile(file_uri)
        assert is_local_file, f"Only support local files currently: {file_uri}"
        assert Path(file_uri).exists(), f"File not found: {file_uri}"
        self.video_uri = file_uri

    def __repr__(self) -> str:
        return "FileSerializer({})".format(self.video_uri)

    def base64(self) -> str:
        """Return the file's raw bytes encoded as a base64 UTF-8 string."""
        raw_bytes = Path(self.video_uri).read_bytes()
        return base64.b64encode(raw_bytes).decode("utf-8")

class MimeType(str, Enum):
    """MIME-type keys for the raw display bundles passed to ``IPython.display.display``.

    Members subclass ``str`` so they can be used directly as bundle keys.
    """

    TEXT_PLAIN = "text/plain"
    TEXT_HTML = "text/html"
    TEXT_MARKDOWN = "text/markdown"
    IMAGE_SVG = "image/svg+xml"
    IMAGE_PNG = "image/png"
    IMAGE_JPEG = "image/jpeg"
    # NOTE(review): not a registered MIME type; used as the bundle key for
    # base64-encoded MP4 data (see _save_video_to_result). Presumably read by a
    # custom front end — confirm.
    VIDEO_MP4_B64 = "video/mp4/base64"
    APPLICATION_PDF = "application/pdf"
    TEXT_LATEX = "text/latex"
    APPLICATION_JSON = "application/json"
    APPLICATION_JAVASCRIPT = "application/javascript"
    # NOTE(review): also non-standard; its consumer is not visible here.
    APPLICATION_ARTIFACT = "application/artifact"

def video_writer(
    frames: List[np.ndarray], fps: float = 1.0, filename: str | None = None
) -> str:
    """Encode ``frames`` into an H.264 (yuv420p) MP4 file using PyAV.

    Parameters:
        frames (List[np.ndarray]): Frames of shape (H, W, C); only the first three
            channels are encoded as RGB (an alpha channel is dropped). Odd frame
            dimensions are trimmed to even ones, as yuv420p requires.
        fps (float): Stream frame rate passed to PyAV.
        filename (str | None): Output path; a new temporary ``.mp4`` file is
            created when None.

    Returns:
        str: The path of the written video.
    """
    # Read the target size first so an empty `frames` fails before any file is created.
    height, width = frames[0].shape[:2]
    if filename is None:
        fd, filename = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)

    container = av.open(filename, mode="w")
    try:
        stream = container.add_stream("h264", rate=fps)
        stream.height = height - (height % 2)
        stream.width = width - (width % 2)
        stream.pix_fmt = "yuv420p"
        stream.options = {"crf": "10"}  # low CRF => near-lossless quality
        for frame in frames:
            # Drop the alpha channel (RGBA -> RGB) and make dimensions even.
            frame_rgb = _resize_frame(frame[:, :, :3])
            av_frame = av.VideoFrame.from_ndarray(frame_rgb, format="rgb24")
            for packet in stream.encode(av_frame):
                container.mux(packet)

        # Flush frames still buffered inside the encoder.
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Always close the container so the file handle is not leaked on errors.
        container.close()
    return filename

def _save_video_to_result(video_uri: str) -> None:
    """Show the video at ``video_uri`` as a rich output of the current cell.

    The file's bytes are base64-encoded under ``MimeType.VIDEO_MP4_B64`` and the
    serializer's repr is added as the ``text/plain`` fallback; the bundle is sent
    with ``display(..., raw=True)``.
    """
    from IPython.display import display

    file_adaptor = FileSerializer(video_uri)
    mime_bundle = {
        MimeType.VIDEO_MP4_B64: file_adaptor.base64(),
        MimeType.TEXT_PLAIN: str(file_adaptor),
    }
    display(mime_bundle, raw=True)

def save_video(
    frames: List[np.ndarray], output_video_path: str | None = None, fps: float = 1
) -> str:
    """'save_video' is a utility function that saves a list of frames as a mp4 video file on disk.

    Parameters:
        frames (list[np.ndarray]): A list of frames to save.
        output_video_path (str): The path to save the video file. If not provided, a temporary file will be created.
        fps (float): The number of frames composes a second in the video.

    Returns:
        str: The path to the saved video file.

    Example
    -------
        >>> save_video(frames)
        "/tmp/tmpvideo123.mp4"
    """
    if fps <= 0:
        raise ValueError(f"fps must be greater than 0 got {fps}")

    if not isinstance(frames, list) or len(frames) == 0:
        raise ValueError("Frames must be a list of NumPy arrays")

    for frame in frames:
        if not isinstance(frame, np.ndarray) or (
            frame.shape[0] == 0 and frame.shape[1] == 0
        ):
            raise ValueError("A frame is not a valid NumPy array with shape (H, W, C)")

    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(
            delete=False, suffix=".mp4"
        ).name

    output_video_path = video_writer(frames, fps, output_video_path)
    _save_video_to_result(output_video_path)
    return output_video_path

# Run inference on the video

In [3]:
# Input clip. For the car-tracking example use
# "../tests/shared_data/videos/tracking_car.mp4" (and prompt="car" below).
video_path = "../tests/shared_data/videos/shark_10fps.mp4"

# Load every frame into memory together with the frame rate reported by OpenCV.
frames, fps = load_video(video_path)

# Fail here, not several cells later, if the video could not be read.
assert len(frames) > 0, f"No frames decoded from {video_path}"
assert fps > 0, f"Invalid FPS ({fps}) reported for {video_path}"

In [None]:
# Build the Florence2SAM2 model wrapper once; the same instance is reused by the
# inference cell below.
florence2_sam2: Florence2SAM2 = Florence2SAM2()

In [None]:
# Run prompted segmentation over every frame of the clip.
# For the car-tracking example use prompt="car" instead.
# perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
start = time.perf_counter()
results = florence2_sam2(video=frames, prompt="shark, person")
elapsed = time.perf_counter() - start
print(f"Time taken: {elapsed:.2f} seconds")

# Visualize the results

In [6]:
def _to_overlay_annotation(annotation: Dict[str, Any]) -> Dict[str, Any]:
    """Turn one model annotation into the dict shape used by overlay_segmentation_masks.

    The RLE mask is decoded to a dense array and the label is prefixed with the
    annotation's id (e.g. "3: shark"); the score is fixed at 1.0.
    """
    decoded_mask = rle_decode_array(annotation["mask"])
    display_label = str(annotation["id"]) + ": " + annotation["label"]
    return {"label": display_label, "mask": decoded_mask, "score": 1.0}


# One list of overlay dicts per frame, in the same order as `results`.
return_data = [
    [
        _to_overlay_annotation(results[frame_idx][annotation_idx])
        for annotation_idx in range(len(results[frame_idx]))
    ]
    for frame_idx in range(len(results))
]

In [7]:
# Draw each frame's masks and labels. overlay_segmentation_masks does not modify
# its input (it converts via astype), so no defensive copy of the frame is needed.
processed_frames = []
for video_frame, detections in zip(frames, return_data):
    frame_with_overlays = overlay_segmentation_masks(video_frame, detections)
    processed_frames.append(frame_with_overlays)

In [None]:
# Write the annotated clip to ../tmp (created if missing) and render it inline.
os.makedirs("../tmp", exist_ok=True)
# For the car example use "../tmp/florence2sam2_track_car.mp4".
filepath = "../tmp/florence2sam2_shark.mp4"
# Round rather than truncate the reported frame rate: int(29.97) == 29 (and a
# reported 9.99 would become 9), which makes the written clip play too slowly.
output_path = save_video(processed_frames, filepath, fps=round(fps))

In [None]:
%matplotlib inline
print("out_shape", frame_with_overlays.shape)
data = frame_with_overlays
plt.imshow(data, interpolation='nearest')
plt.show()