In [1]:
from transformers import Owlv2ForObjectDetection, AutoProcessor
import torch

owlv2 = Owlv2ForObjectDetection.from_pretrained(
    "google/owlv2-large-patch14-ensemble", dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(owlv2.name_or_path, use_fast=True)

interpolate_pos_encoding = True


In [2]:
from torchvision.io import decode_image
from pathlib import Path
from transformers import TensorType
import torch

ref_paths = list(Path("ref").glob("*_obj.png"))

query_inputs = processor(
    query_images=[str(path) for path in ref_paths],
    # query_images=[
    #     decode_image(path, apply_exif_orientation=True) for path in ref_paths
    # ],
    return_tensors=TensorType.PYTORCH,
    device=owlv2.device,
)

with torch.inference_mode():
    query_feature_maps = owlv2.image_embedder(
        query_inputs.query_pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )[0]

In [3]:
img_path = "frame_003387_02_15.jpg"
target_inputs = processor(
    images=img_path,
    return_tensors=TensorType.PYTORCH,
    device=owlv2.device,
)

query_feature_map = query_feature_maps.mean(0)[None]

with torch.inference_mode():
    target_feature_map = owlv2.image_embedder(
        target_inputs.pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
    )[0]

    batch_size = query_feature_map.size(0)
    output_dim = query_feature_map.size(-1)

    query_feats = query_feature_map.reshape(batch_size, -1, output_dim)
    objectness_logits = owlv2.objectness_predictor(query_feats)
    best_box_indices = torch.argmax(objectness_logits, dim=1).view(-1, 1, 1)

    query_class_embeds = owlv2.class_predictor(query_feats)[1]
    hidden_size = query_class_embeds.size(-1)
    # B 1 hidden_size
    query_embeds = torch.gather(
        query_class_embeds,
        dim=1,
        index=best_box_indices.expand(-1, 1, hidden_size),
    )

    target_feats = target_feature_map.reshape(batch_size, -1, output_dim)
    pred_boxes = owlv2.box_predictor(
        target_feats,
        target_feature_map,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )
    pred_scores = torch.sigmoid(owlv2.class_predictor(target_feats, query_embeds)[0])

In [None]:
from torchvision.transforms.v2.functional import convert_bounding_box_format
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat
from torchvision.ops import nms
from PIL import Image, ImageDraw

score_mask = pred_scores.squeeze(-1) > 0.1
selected_boxes = pred_boxes[score_mask]
selected_scores = pred_scores[score_mask]

with Image.open(img_path) as im:
    boxes = BoundingBoxes(
        selected_boxes, format=BoundingBoxFormat.CXCYWH, canvas_size=im.size
    )
    xyxy_boxes = convert_bounding_box_format(boxes, new_format=BoundingBoxFormat.XYXY)

    indices = nms(xyxy_boxes, selected_scores.squeeze(), iou_threshold=0.001)
    selected_boxes = xyxy_boxes[indices] * max(im.size)
    selected_scores = selected_scores[indices]

    # selected_boxes = selected_boxes[indices] * max(im.size)
    _draw = ImageDraw.Draw(im)
    for _box_tensor, _score in zip(pixel_boxes, selected_scores):
        _box = _box_tensor.tolist()
        _draw.text(_box[:2], str(_score.item()))
        _draw.rectangle(_box)
im

RuntimeError: a Tensor with 4 elements cannot be converted to Scalar