# transformers: OWL-ViT for object detection

In [None]:
# IPython magic: show matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import draw_bounding_boxes
import requests
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

## Load image

In [None]:
# load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# Bound the wait with a timeout and raise on HTTP error statuses so that a
# failed download surfaces here instead of as an obscure PIL decoding error.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces the (lazy) decode while the stream is still open and
    # guarantees a 3-channel image, which the (H, W, C) tensor code relies on.
    image = Image.open(response.raw).convert('RGB')

In [None]:
# display the input image on a single axes
fig = plt.figure(figsize=(6, 4))
ax = fig.subplots()
ax.imshow(np.asarray(image))
ax.set_aspect(aspect='equal', adjustable='box')
fig.tight_layout()

## Load model

In [None]:
# set model name (checkpoint identifier passed to `from_pretrained` for both
# the processor and the detection model)
model_name = 'google/owlvit-base-patch32'

In [None]:
# create text and image processors
processor = OwlViTProcessor.from_pretrained(model_name)

# load model and put it in eval (inference) mode
model = OwlViTForObjectDetection.from_pretrained(model_name, device_map='auto')
model = model.eval()

print(f'Model device: {model.device}')
print(f'Model dtype: {model.dtype}')
# get_memory_footprint() is in bytes; divide by 2**30 so the number really is
# GiB (multiplying by 1e-9 would give decimal GB and mismatch the label)
print(f'Memory footprint: {model.get_memory_footprint() / 2**30:.2f} GiB')

print(f'\nEmbedding dim.: {model.config.projection_dim}')

## Run model

In [None]:
# set candidate labels (free-text queries the detector will look for)
candidate_labels = ['cat', 'dog', 'car', 'remote', 'blanket']

# preprocess inputs: tokenize the text queries and prepare the image as
# PyTorch tensors
inputs = processor(
    text=candidate_labels,
    images=image,
    return_tensors='pt',
    padding=True
)

# Use double quotes for the f-string so the single-quoted keys inside the
# replacement fields also parse on Python < 3.12 (quote reuse needs PEP 701).
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Pixel values shape: {inputs['pixel_values'].shape}")

In [None]:
# run model (gradients are not needed for inference)
with torch.no_grad():
    outputs = model(**inputs.to(model.device))

# Raw predicted boxes are normalized (center_x, center_y, width, height);
# absolute (xmin, ymin, xmax, ymax) boxes are produced by the processor's
# post-processing step, not here.
bboxes = outputs.pred_boxes.cpu() # (batch_size, num_boxes, 4)
logits = outputs.logits.cpu() # (batch_size, num_boxes, num_labels)
# OWL-ViT scores every text query independently, so per-label probabilities
# come from a sigmoid (as in the library's post-processing), not a softmax
# that would force the labels to compete and sum to one.
probs = logits.sigmoid()

print(f'Bounding boxes shape: {bboxes.shape}')
print(f'Logits shape: {logits.shape}')

In [None]:
# postprocess outputs: keep detections scoring above the threshold
detections = processor.post_process_grounded_object_detection(
    outputs,
    threshold=0.1,
    target_sizes=[image.size[::-1]],  # PIL size is (width, height) -> (height, width)
    text_labels=[candidate_labels]
)

# summarize detections for the (single) image in the batch
first = detections[0]
for name, conf, coords in zip(first['text_labels'], first['scores'], first['boxes']):
    rounded = [round(value, 2) for value in coords.tolist()]
    print(f'{name} ({conf:.2f}) in {rounded}')

In [None]:
# set colors for bounding boxes; every entry is distinct so two different
# labels only share a color when more than len(colors) labels are detected
colors = ['green', 'orange', 'purple', 'blue', 'red']
unique_label_ids = torch.unique(detections[0]['labels']).tolist()
color_dict = {lidx: colors[idx % len(colors)] for idx, lidx in enumerate(unique_label_ids)}

# add bounding boxes to the image
# convert('RGB') guarantees a 3-channel array so the permute below is valid
# even for grayscale / palette / RGBA inputs
image_tensor = torch.as_tensor(np.array(image.convert('RGB'))) # (H, W, C)

image_tensor = draw_bounding_boxes(
    image_tensor.permute(2, 0, 1), # (C, H, W)
    boxes=detections[0]['boxes'],
    labels=detections[0]['text_labels'],
    colors=[color_dict[lidx] for lidx in detections[0]['labels'].tolist()]
).permute(1, 2, 0) # back to (H, W, C) for matplotlib

# show predictions
fig, ax = plt.subplots(figsize=(6, 4))
ax.imshow(image_tensor.numpy())
ax.set_aspect('equal', adjustable='box')
ax.set_title('Predictions')
fig.tight_layout()