In [None]:
import os
import time
from typing import Union, List, Dict, Any, Tuple, cast

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from vision_agent_tools.models.florence2_sam2 import Florence2SAM2

# Visualization tools

In [2]:
# RGB palette for segmentation overlays: overlay_segmentation_masks assigns
# COLORS[i % len(COLORS)] to the i-th distinct label, so the palette repeats
# once there are more than len(COLORS) labels.
COLORS = [
    (158, 218, 229),
    (219, 219, 141),
    (23, 190, 207),
    (188, 189, 34),
    (199, 199, 199),
    (247, 182, 210),
    (127, 127, 127),
    (227, 119, 194),
    (196, 156, 148),
    (197, 176, 213),
    (140, 86, 75),
    (148, 103, 189),
    (255, 152, 150),
    (152, 223, 138),
    (214, 39, 40),
    (44, 160, 44),
    (255, 187, 120),
    (174, 199, 232),
    (255, 127, 14),
    (31, 119, 180),
]

def _get_text_coords_from_mask(
    mask: np.ndarray, v_gap: int = 10, h_gap: int = 10
) -> Tuple[int, int]:
    mask = mask.astype(np.uint8)
    if np.sum(mask) == 0:
        return (0, 0)

    rows, cols = np.nonzero(mask)
    top = rows.min()
    bottom = rows.max()
    left = cols.min()
    right = cols.max()

    if top - v_gap < 0:
        if bottom + v_gap > mask.shape[0]:
            top = top
        else:
            top = bottom + v_gap
    else:
        top = top - v_gap

    return left + (right - left) // 2 - h_gap, top

def _load_label_font(font_path: Union[str, None], fontsize: int):
    """Return the font used to draw mask labels.

    An explicit ``font_path`` is loaded with ``ImageFont.truetype`` and any error is
    propagated, so a wrong path is not silently ignored. Without one, Pillow's bundled
    default font is used, so the notebook does not depend on a machine-specific file.
    """
    if font_path is not None:
        return ImageFont.truetype(font_path, fontsize)
    try:
        # Pillow >= 10.1 accepts a size for its default font.
        return ImageFont.load_default(size=fontsize)
    except TypeError:
        # Older Pillow: fixed-size bitmap default font.
        return ImageFont.load_default()


def overlay_segmentation_masks(
    medias: Union[np.ndarray, List[np.ndarray]],
    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
    draw_label: bool = True,
    secondary_label_key: str = "tracking_label",
    font_path: Union[str, None] = None,
) -> Union[np.ndarray, List[np.ndarray]]:
    """'overlay_segmentation_masks' is a utility function that displays segmentation
    masks.

    Every mask is alpha-blended (alpha 127) onto its frame using one color per
    distinct label. Colors are assigned in first-seen label order, so they are
    stable across runs. Optionally, the label text is drawn next to each mask.

    Parameters:
        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
            the masks on. Each frame is cast to uint8 before drawing.
        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
            dictionaries (one image) or a list of lists of dictionaries (one inner
            list per frame). Each dictionary needs "mask" (an array with the frame's
            height and width; pixels > 0 are painted) and "label". Other keys such
            as "score" are ignored. An empty list means "nothing to overlay".
        draw_label (bool, optional): If True, the labels will be displayed on the image.
        secondary_label_key (str, optional): The key to use for the secondary
            tracking label which is needed in videos to display tracking information.
            When present and truthy, it is shown instead of "label".
        font_path (str, optional): Path to a TrueType font for the label text. When
            None (default), Pillow's built-in default font is used.

    Returns:
        Union[np.ndarray, List[np.ndarray]]: A single RGBA array when exactly one frame
        is produced, otherwise a list of RGBA arrays (one per frame).

    Raises:
        IndexError: If there are fewer per-frame mask lists than frames.

    Example
    -------
        >>> image_with_masks = overlay_segmentation_masks(
            image,
            [{
                'score': 0.99,
                'label': 'dinosaur',
                'mask': array([[0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0],
                    ...,
                    [0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
            }],
        )
    """
    medias_int: List[np.ndarray] = (
        [medias] if isinstance(medias, np.ndarray) else medias
    )
    if len(masks) == 0:
        # No detections: still return every frame, just without overlays.
        masks_int: List[List[Dict[str, Any]]] = [[] for _ in medias_int]
    elif isinstance(masks[0], dict):
        masks_int = [cast(List[Dict[str, Any]], masks)]
    else:
        masks_int = cast(List[List[Dict[str, Any]]], masks)

    # dict.fromkeys keeps first-seen order. Iterating a set of strings would make
    # the label -> color mapping change between runs (hash randomization).
    unique_labels = dict.fromkeys(
        elt["label"] for frame_masks in masks_int for elt in frame_masks
    )
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(unique_labels)}

    font = None
    if draw_label and medias_int:
        # Scale the text with the smaller image side, but never below 12 px.
        height, width = medias_int[0].shape[:2]
        fontsize = max(12, int(min(width, height) / 40))
        font = _load_label_font(font_path, fontsize)

    frame_out = []
    for i, frame in enumerate(medias_int):
        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
        for elt in masks_int[i]:
            label = elt["label"]
            tracking_lbl = elt.get(secondary_label_key, None)
            painted = np.asarray(elt["mask"]) > 0

            # Half-transparent color layer (alpha 127) composited over the frame.
            overlay = np.zeros((pil_image.size[1], pil_image.size[0], 4), dtype=np.uint8)
            overlay[painted, :] = color[label] + (127,)
            pil_image = Image.alpha_composite(pil_image, Image.fromarray(overlay))

            # Skip labels only for truly empty masks; (0, 0) is a legitimate anchor.
            if draw_label and painted.any():
                # alpha_composite returns a new image, so a fresh Draw is required.
                draw = ImageDraw.Draw(pil_image)
                text = str(tracking_lbl) if tracking_lbl else str(label)
                text_box = draw.textbbox((0, 0), text=text, font=font)
                x, y = _get_text_coords_from_mask(
                    painted,
                    v_gap=(text_box[3] - text_box[1]) + 10,
                    h_gap=(text_box[2] - text_box[0]) // 2,
                )
                text_box = draw.textbbox((x, y), text=text, font=font)
                draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
                draw.text((x, y), text, fill="black", font=font)
        frame_out.append(np.array(pil_image))
    return frame_out[0] if len(frame_out) == 1 else frame_out


def load_image(image_path: str) -> Image.Image:
    """Load an image file as an RGB PIL image.

    ``Image.open`` is lazy and keeps the file open. Using it as a context manager
    closes the handle once ``convert`` has produced a new, fully loaded RGB image,
    which is safe to return.

    Parameters:
        image_path: Path of the image file to read.

    Returns:
        The image converted to RGB mode.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")

def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
    """Expand an uncompressed run-length encoding into a binary uint8 mask.

    Runs in ``rle["counts"]`` alternate between background and foreground,
    starting with background: the first run is 0s, the second is 1s, and so on.
    Pixels are laid out column-major (Fortran order) and reshaped to
    ``rle["size"]``, presumably ``[height, width]``; confirm against the producer.
    Runs extending past ``size[0] * size[1]`` pixels are truncated, and pixels
    not covered by any run stay 0.

    Parameters:
        rle: Mapping with "size" (mask dimensions) and "counts" (run lengths).

    Returns:
        np.ndarray of dtype uint8: 1 for mask, 0 for background.
    """
    shape = rle["size"]
    runs = rle["counts"]

    pixels = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    cursor = 0
    in_foreground = False
    for run_length in runs:
        run_end = cursor + run_length
        if in_foreground:
            pixels[cursor:run_end] = 1
        cursor = run_end
        in_foreground = not in_foreground

    return pixels.reshape(shape, order="F")

def save_image(image: np.ndarray, file_path: str) -> None:
    """'save_image' is a utility function that displays an image inline and saves it
    to a file path.

    Parameters:
        image (np.ndarray): The image to save, shaped (H, W) or (H, W, C). It is cast
            to uint8 and converted to RGB before being displayed and saved.
        file_path (str): The path to save the image file.

    Raises:
        ValueError: If ``image`` is not a NumPy array with at least two dimensions
            and a non-zero height and width.

    Example
    -------
        >>> save_image(image, "output.jpg")
    """
    # Reject a non-array, a 1-D array, or an empty height *or* width.
    # The previous `and` check let e.g. (0, W, C) images through.
    if (
        not isinstance(image, np.ndarray)
        or image.ndim < 2
        or image.shape[0] == 0
        or image.shape[1] == 0
    ):
        raise ValueError("The image is not a valid NumPy array with shape (H, W, C)")

    # Imported after validation so bad input is rejected without needing IPython.
    from IPython.display import display

    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
    display(pil_image)
    pil_image.save(file_path)

# Run inference on the image

In [3]:
# Sample image from the repository's shared test data; point image_path at any
# other local file to segment something else.
image_path = "../tests/shared_data/images/tomatoes.jpg"
image = load_image(image_path)  # PIL image in RGB mode

In [None]:
# Create the Florence2SAM2 instance (default constructor arguments).
# NOTE(review): construction presumably loads model weights, so this cell may be
# slow; run it once and reuse `florence2_sam2` in the cells below.
florence2_sam2 = Florence2SAM2()

In [None]:
# Run one inference call on the loaded image, prompting for "tomato", and report
# its wall-clock duration. time.perf_counter is monotonic and high-resolution,
# unlike time.time(), which can jump if the system clock is adjusted.
start = time.perf_counter()
results = florence2_sam2(images=[image], prompt="tomato")
end = time.perf_counter()
print(f"Time taken: {end - start} seconds")

In [None]:
# Inspect the first frame's detections. The RLE "mask" payload (long lists of run
# lengths) is omitted so the cell output stays readable; the masks are decoded
# and visualized below.
[
    {key: value for key, value in annotation.items() if key != "mask"}
    for annotation in results[0]
]

# Visualize the results

In [14]:
# Convert the model output into the list-of-dicts format expected by
# overlay_segmentation_masks: decode each RLE mask and build an "id: label" caption.
# NOTE: detections from all frames are merged into one flat list, which is only
# right for single-image results like this one. "score" is a placeholder that the
# overlay does not read. The comprehension keeps loop variables (mask, label, ...)
# out of the notebook's global namespace.
return_frame_data = [
    {
        "label": str(annotation["id"]) + ": " + annotation["label"],
        "mask": rle_decode_array(annotation["mask"]),
        "score": 1.0,
    }
    for frame_annotations in results
    for annotation in frame_annotations
]

In [15]:
# Blend every decoded mask (with its "id: label" caption) onto the original image.
# A single frame is passed, so the result is one RGBA NumPy array.
image_with_masks = overlay_segmentation_masks(np.asarray(image), return_frame_data)

In [None]:
# Display the annotated image and write it to the scratch directory,
# creating that directory first if it does not exist yet.
filepath = "../tmp/segmented_tomato.jpg"
os.makedirs(os.path.dirname(filepath), exist_ok=True)
save_image(image_with_masks, filepath)