In [1]:
import torch
# turn on tfloat32 for Ampere GPUs
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# use bfloat16 for the entire notebook. If your card doesn't support it, try float16 instead
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# Utilities

## Plotting

This section contains simple utilities to plot masks and bounding boxes on top of an image.

In [3]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from matplotlib.colors import to_rgb
from PIL import Image


def draw_box_on_image(image, box, color=(0, 255, 0)):
    """
    Draws a 3-pixel-thick rectangle outline on an RGB copy of a given PIL image.
    :param image: PIL.Image - The image on which to draw the rectangle. It is converted to RGB
        first (``convert`` returns a new image), and the drawing happens on that converted image.
    :param box: tuple - A tuple (x, y, w, h) representing the top-left corner, width, and height
        of the rectangle, in pixels. Values are truncated with ``int``.
    :param color: tuple - A tuple (R, G, B) representing the color of the rectangle. Default is green.
    :return: PIL.Image - The RGB image with the rectangle drawn on it.

    Outline pixels that fall outside the image are skipped, so a box that touches or crosses the
    image border is drawn clipped. Without this, PIL raises IndexError past the right or bottom
    edge, and negative coordinates can wrap around to the opposite edge.
    """
    # Ensure the image is in RGB mode
    image = image.convert("RGB")
    img_w, img_h = image.size
    # Unpack the box coordinates as integer pixel positions
    x, y, w, h = (int(v) for v in box)
    # Get the pixel data
    pixels = image.load()

    def _put(px, py):
        # Write one pixel, ignoring coordinates outside [0, img_w) x [0, img_h).
        if 0 <= px < img_w and 0 <= py < img_h:
            pixels[px, py] = color

    # Top and bottom edges: three rows around y and three rows around y + h - 1.
    for i in range(x, x + w):
        for row in (y - 1, y, y + 1, y + h - 2, y + h - 1, y + h):
            _put(i, row)
    # Left and right edges: three columns around x and three columns around x + w - 1.
    for j in range(y, y + h):
        for col in (x - 1, x, x + 1, x + w - 2, x + w - 1, x + w):
            _put(col, j)
    return image


def show_img_tensor(img_batch, vis_img_idx=0, mean=None, std=None):
    """
    Un-normalize one image of a normalized image batch and display it with ``plt.imshow``.

    The image is mapped back as ``img * std + mean`` (per channel) and clipped to [0, 1].

    :param img_batch: tensor batch of normalized images; ``img_batch[vis_img_idx]`` must be a
        3-D (C, H, W) tensor.
    :param vis_img_idx: int - index of the image in the batch to display.
    :param mean: per-channel mean used when normalizing (scalar or length-C sequence).
        Defaults to the ImageNet mean (0.485, 0.456, 0.406). NOTE: the ``transform`` defined in
        this notebook normalizes with mean=std=0.5, so pass ``mean=0.5, std=0.5`` for its batches.
    :param std: per-channel std used when normalizing. Defaults to the ImageNet std
        (0.229, 0.224, 0.225).
    """
    mean_img = np.array([0.485, 0.456, 0.406] if mean is None else mean, dtype=np.float64)
    std_img = np.array([0.229, 0.224, 0.225] if std is None else std, dtype=np.float64)
    im_tensor = img_batch[vis_img_idx].detach().cpu()
    assert im_tensor.dim() == 3
    # .float() first: numpy has no bfloat16 dtype, so .numpy() would fail on bf16 tensors.
    im_array = im_tensor.float().numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    im_array = np.clip(im_array * std_img + mean_img, 0, 1)
    plt.imshow(im_array)


def show_points_with_labels(coords, labels, ax=None, marker_size=200):
    """
    Scatter point prompts: points labelled 1 as green stars, points labelled 0 as red stars.

    :param coords: (N, 2) array-like of (x, y) point coordinates (lists, numpy arrays and CPU
        tensors are accepted).
    :param labels: length-N array-like of labels; only entries equal to 1 or 0 are drawn.
    :param ax: matplotlib axes to draw on; defaults to ``plt.gca()``.
    :param marker_size: marker area, passed to ``scatter`` as ``s``.
    """
    if ax is None:
        ax = plt.gca()
    # Boolean-mask indexing below needs arrays; plain Python lists would not mask correctly.
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    if coords.size == 0:
        # Keep empty input 2-D so the [:, 0] / [:, 1] column slices stay valid.
        coords = coords.reshape(0, 2)
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25)


def plot_bbox(
    img_height,
    img_width,
    box,
    box_format="XYXY",
    relative_coords=True,
    color="r",
    linestyle="solid",
    text=None,
    ax=None,
):
    """
    Draw a bounding box, and optionally a text label at its top-left corner, on matplotlib axes.

    :param img_height: image height; y and h are multiplied by it when ``relative_coords`` is True.
    :param img_width: image width; x and w are multiplied by it when ``relative_coords`` is True.
    :param box: four numbers (list/tuple/array/tensor) laid out according to ``box_format``.
    :param box_format: "XYXY" (x1, y1, x2, y2), "XYWH" (x, y, w, h) or "CxCyWH" (cx, cy, w, h).
    :param relative_coords: if True, ``box`` is expressed as fractions of the image size.
    :param color: matplotlib color for the box edge and the label text.
    :param linestyle: matplotlib line style for the box edge.
    :param text: optional label text drawn next to the box.
    :param ax: axes to draw on; defaults to ``plt.gca()``.
    :raises RuntimeError: if ``box_format`` is not one of the supported formats.
    """
    if box_format not in ("XYXY", "XYWH", "CxCyWH"):
        raise RuntimeError(f"Invalid box_format {box_format}")

    # Copy the coordinates into plain Python floats. Unpacking a torch tensor yields 0-d views
    # into it, so scaling them in place (``x *= img_width``) used to rescale the caller's box.
    # float() also accepts numpy scalars and single-element tensors on any device.
    a, b, c, d = (float(v) for v in box)
    if box_format == "XYXY":
        x, y, w, h = a, b, c - a, d - b
    elif box_format == "XYWH":
        x, y, w, h = a, b, c, d
    else:  # "CxCyWH"
        x, y, w, h = a - c / 2, b - d / 2, c, d

    if relative_coords:
        x, w = x * img_width, w * img_width
        y, h = y * img_height, h * img_height

    if ax is None:
        ax = plt.gca()
    rect = patches.Rectangle(
        (x, y), w, h, linewidth=1.5, edgecolor=color, facecolor="none", linestyle=linestyle,
    )
    ax.add_patch(rect)

    if text is not None:
        # Label at y - 13 data units: above the box on an image axes, where y grows downward.
        ax.text(
            x, y - 13, text, color=color, weight="bold", fontsize=8,
            bbox={"facecolor": "w", "alpha": 0.75, "pad": 2},
        )


def plot_mask(mask, color="r"):
    """
    Overlay ``mask`` on the current matplotlib axes as a single-color RGBA layer.

    :param mask: 2-D (H, W) array-like; each value times 0.5 becomes the overlay's alpha
        (presumably values lie in [0, 1] — confirm for soft masks).
    :param color: any matplotlib color spec; fills the RGB channels of the overlay.
    """
    height, width = mask.shape
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    red, green, blue = to_rgb(color)
    overlay[..., 0] = red
    overlay[..., 1] = green
    overlay[..., 2] = blue
    overlay[..., 3] = 0.5 * mask
    plt.imshow(overlay)

## Batching

This section contains utility functions to create datapoints. They are optional, but they show how datapoints should be built.

In [30]:
from sam3.train.data.sam3_image_dataset import InferenceMetadata, FindQueryLoaded, Image as SAMImage, Datapoint, QueryType
# Next query id handed out by add_text_prompt; it is incremented by one for every text query added.
GLOBAL_COUNTER = 1
def create_empty_datapoint():
    """Return a fresh Datapoint that has no images and no find queries yet.

    A datapoint is a single image on which several queries can be applied at once.
    """
    empty = Datapoint(images=[], find_queries=[])
    return empty

def set_image(datapoint, pil_image):
    """Make ``pil_image`` the only image of ``datapoint``.

    Replaces ``datapoint.images`` with a one-element list and records the size as [height, width].
    """
    width, height = pil_image.size  # PIL reports (width, height)
    sam_image = SAMImage(data=pil_image, objects=[], size=[height, width])
    datapoint.images = [sam_image]

def add_text_prompt(datapoint, text_query):
    """
    Append a text query to ``datapoint`` and return the id assigned to it.

    The id is taken from the module-level ``GLOBAL_COUNTER`` (which is then incremented). It is
    stored as both ``coco_image_id`` and ``original_image_id`` of the query's inference metadata.

    :param datapoint: Datapoint whose image has already been set with ``set_image``.
    :param text_query: str - the text prompt, e.g. "cat".
    :return: int - the id assigned to this query.
    :raises AssertionError: if the datapoint does not hold exactly one image.
    """
    global GLOBAL_COUNTER
    # In this function, we require that the image is already set, because its size is used to
    # figure out what dimension to resize masks and boxes to.
    # In practice you're free to set any size you want, just edit the rest of the function.
    assert len(datapoint.images) == 1, "please set the image first"

    # set_image stores the size as [height, width], so unpack it in that order. (Unpacking it as
    # ``w, h`` swapped the two and recorded a transposed original_size for non-square images.)
    h, w = datapoint.images[0].size
    query_id = GLOBAL_COUNTER
    datapoint.find_queries.append(
        FindQueryLoaded(
            query_type=QueryType.FindQuery,
            query_text=text_query,
            image_id=0,
            object_ids_output=[],  # unused for inference
            is_exhaustive=True,  # unused for inference
            query_processing_order=0,
            inference_metadata=InferenceMetadata(
                coco_image_id=query_id,
                original_image_id=query_id,
                original_category_id=1,
                original_size=[h, w],
                object_id=0,
                frame_index=0,
            ),
        )
    )
    GLOBAL_COUNTER += 1
    return query_id

# Loading

In [42]:
from sam3 import build_sam3_image_model
import os

# Paths to the BPE vocabulary and the model checkpoint. They default to absolute paths on a
# shared filesystem; set SAM3_BPE_PATH / SAM3_CHECKPOINT_PATH to point at your own copies.
bpe_path = os.environ.get(
    "SAM3_BPE_PATH",
    "/fsx-onevision/shared/dvc_cache_v3/files/md5/93/3b7abbbbde62c36f02f0e6ccde464f",
)

# Alternative checkpoint without a presence token (requires `sam3_root` to be defined):
# checkpoint_path = f"{sam3_root}/assets/checkpoints/sam3_prod_v12_interactive_5box_image_only.pt"
# has_presence_token = False

checkpoint_path = os.environ.get(
    "SAM3_CHECKPOINT_PATH",
    "/fsx-onevision/shared/ckpts_fair_sc/checkpoint/sam3/shuangrui/omnivision_onevision/config/experiments/shuangrui/checkpoint_presence_0.5_completed.pt",
)
# NOTE(review): has_presence_token is not passed to build_sam3_image_model in this cell —
# confirm whether the builder needs it for this checkpoint.
has_presence_token = True

model = build_sam3_image_model(bpe_path=bpe_path, checkpoint_path=checkpoint_path).cuda()

In [15]:
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
# Preprocessing pipeline: resize (presumably to a 1008x1008 square, given sizes=1008 and
# square=True), convert to a tensor, then normalize every channel with mean=0.5 and std=0.5.
_transform_steps = [
    RandomResizeAPI(
        sizes=1008,
        max_size=1009,
        square=True,
        consistent_transform=False,
    ),
    ToTensorAPI(),
    NormalizeAPI(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = ComposeAPI(transforms=_transform_steps)
        

# Inference

In [31]:
# Image 1
from PIL import Image
import requests
from io import BytesIO

# Download with a timeout and fail loudly on HTTP errors instead of handing an error page to
# PIL. Convert to RGB so the model always receives 3-channel input (some COCO images are grayscale).
response1 = requests.get("http://images.cocodataset.org/val2017/000000077595.jpg", timeout=30)
response1.raise_for_status()
img1 = Image.open(BytesIO(response1.content)).convert("RGB")
datapoint1 = create_empty_datapoint()
set_image(datapoint1, img1)
id1 = add_text_prompt(datapoint1, "cat")
id2 = add_text_prompt(datapoint1, "laptop")

datapoint1 = transform(datapoint1)

In [32]:
# Same download hardening as for image 1: timeout, HTTP error check, and RGB conversion.
response2 = requests.get(
    "https://s3.us-east-1.amazonaws.com/images.cocodataset.org/val2017/000000136466.jpg",
    timeout=30,
)
response2.raise_for_status()
img2 = Image.open(BytesIO(response2.content)).convert("RGB")
datapoint2 = create_empty_datapoint()
set_image(datapoint2, img2)
id3 = add_text_prompt(datapoint2, "oven")

datapoint2 = transform(datapoint2)

In [39]:
from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.data_misc import (
    BatchedDatapoint,
    BatchedPointer,
    convert_my_tensors,
    FindStage,
    recursive_to,
)


In [54]:
# Collate both datapoints into a single batch and move it to the GPU.
cuda_device = torch.device("cuda")
batch = recursive_to(
    collate([datapoint1, datapoint2], dict_key="dummy")["dummy"],
    cuda_device,
    non_blocking=True,
)
# NOTE(review): img_batch is moved explicitly as well — presumably recursive_to does not reach
# it; confirm before removing this line.
batch.img_batch = batch.img_batch.to(cuda_device)

In [55]:
# Inference only: disable autograd so no graph or saved activations are kept for a backward pass.
# NOTE(review): also call model.eval() if build_sam3_image_model does not already do so.
with torch.no_grad():
    output = model(batch)



In [None]:
# todo plotting code