# OWL-ViT minimal example

This Colab shows how to **load a pre-trained OWL-ViT checkpoint** and use it to
**get object detection predictions** for an image.

# Download and install OWL-ViT

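OWL-ViT is implemented in [Scenic](https://github.com/google-research/scenic). The first cell below clones and installs the Scenic codebase from GitHub; its commands are commented out, so uncomment them when setting up a fresh environment. **Warning:** that cell starts with `rm -rf *`, which deletes everything in the current working directory. The second cell imports the required modules.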

In [1]:
# !rm -rf *
# !rm -rf .config
# !rm -rf .git
# !git clone https://github.com/google-research/scenic.git .
# !python -m pip install -q .
# !python -m pip install -r scenic/projects/baselines/clip/requirements.txt
# !echo "Done."

In [2]:
import os

import jax
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import models
from scenic.projects.owl_vit.configs import clip_b32, clip_l14
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform

Using TensorFlow backend.


# Choose a model config (CLIP B/32 or L/14 backbone)

In [3]:
# Select the OWL-ViT backbone config. clip_b32 is much faster at inference than
# clip_l14 (see the timing note in the prediction cell); swap it in if needed.
config_module = clip_l14  # alternative: clip_b32
config = config_module.get_config()

# Load the model and variables

In [4]:
# Build the TextZeroShotDetectionModule from the `model` section of the config.
model_config = config.model
module = models.TextZeroShotDetectionModule(
    box_bias=model_config.box_bias,
    normalize=model_config.normalize,
    body_configs=model_config.body,
)

In [5]:
# Show the checkpoint location recorded in the selected config (a GCS path).
config.init_from.checkpoint_path

'gs://scenic-bucket/owl_vit/checkpoints/clip_vit_l14_d83d374'

In [6]:
# Load weights from a local copy of the checkpoint. The directory name is taken
# from the config's checkpoint path, so the weights always match the backbone
# selected above (e.g. ./clip_vit_l14_d83d374 for clip_l14).
# Assumes the checkpoint has already been copied from
# config.init_from.checkpoint_path into CHECKPOINT_DIR.
CHECKPOINT_DIR = '.'
checkpoint_path = os.path.join(
    CHECKPOINT_DIR,
    os.path.basename(config.init_from.checkpoint_path.rstrip('/')))

variables = module.load_variables(checkpoint_path)

# Prepare image

In [7]:
# Image to run detection on: an http(s) URL or a local file path.
# Other inputs that have been used with this notebook:
#   os.path.join(skimage.data_dir, 'astronaut.png')
#   './images/dogpark.jpg', './images/trees.jpeg', './images/cows.jpeg',
#   './images/straw.jpeg', './images/peas.jpeg'
#   'https://miro.medium.com/max/480/1*yuWSTpGIelzw_rlmZyLnlA@2x.jpeg'
#   'http://yangjie.org/peas.jpeg'
filename = 'https://static4.depositphotos.com/1004288/279/i/450/depositphotos_2799328-stock-photo-green-peas-served-on-a.jpg'
print(filename)

https://static4.depositphotos.com/1004288/279/i/450/depositphotos_2799328-stock-photo-green-peas-served-on-a.jpg


In [8]:
from PIL import Image
import requests
from io import BytesIO

# Open the image as a PIL object for a quick visual check (evaluate `img` to
# display it). Only http(s) sources are fetched over the network; local paths
# are opened directly instead of being passed to requests (which would raise).
if filename.lower().startswith('http'):
    response = requests.get(filename, timeout=30)
    # Fail with a clear HTTP error (e.g. 404/403) instead of a PIL decode error.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
else:
    img = Image.open(filename)
#img

In [9]:
# Load the image as a uint8 RGB array, from a URL or from a local file.
if filename.lower().startswith('http'):
    response = requests.get(filename, timeout=30)
    response.raise_for_status()
    # convert('RGB') drops any alpha channel and expands grayscale/palette images.
    image_uint8 = np.array(Image.open(BytesIO(response.content)).convert('RGB'))
else:
    image_uint8 = skimage_io.imread(filename)
    if image_uint8.ndim == 2:  # grayscale -> replicate into 3 channels
        image_uint8 = np.stack([image_uint8] * 3, axis=-1)
    image_uint8 = image_uint8[..., :3]  # drop the alpha channel of RGBA files

# NOTE(review): dividing by 255 assumes 8-bit input; 16-bit files would need
# a different scale.
image = image_uint8.astype(np.float32) / 255.0

# Pad to square with gray pixels on bottom and right, so the square resize
# below does not distort the aspect ratio.
h, w, _ = image.shape
size = max(h, w)
image_padded = np.pad(
    image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)

# Resize to model input size:
input_size = config.dataset_configs.input_size
input_image = skimage_transform.resize(
    image_padded, (input_size, input_size), anti_aliasing=True)
assert input_image.shape == (input_size, input_size, 3), input_image.shape

# Prepare text queries

In [10]:
# Free-text descriptions of the objects to detect; the model produces one
# logit per query for every predicted box.
# Other query sets that have been used with this notebook:
#   ['human face', 'rocket', 'nasa badge', 'star-spangled banner']
#   ['human', 'dog', 'tree']
#   ['house', 'tree']
#   ['cow']
#   ['green strawberry', 'red strawberry', 'flower', 'tree', 'sky']
#   ['pea', 'fork', 'knife', 'plate']
#   ['cloud', 'flower', 'plant', 'plant community', 'leaf', 'sky']
text_queries = ['bean']

In [11]:
import copy
# Queries are padded to a fixed count so that changing the number of queries
# does not trigger a JAX recompilation. Downstream code slices the logits back
# to len(text_queries), so the padding rows (token id 0) are ignored.
MAX_QUERIES = 100
if not text_queries:
    raise ValueError('text_queries must contain at least one query.')
if len(text_queries) > MAX_QUERIES:
    raise ValueError(
        f'At most {MAX_QUERIES} text queries are supported, '
        f'got {len(text_queries)}.')

tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])
# Unpadded copy: one row per entry of text_queries.
tokenized_queries_raw = copy.copy(tokenized_queries)

tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, MAX_QUERIES - len(text_queries)), (0, 0)),
    constant_values=0)

100%|█████████████████████████████████████| 1.29M/1.29M [00:00<00:00, 3.80MiB/s]


# Get predictions
The first execution is slow because JAX compiles the model; it can take a minute or more (the L/14 model took about 100 s in the recorded run below). Subsequent executions are faster.

In [12]:
# Note: The model expects a batch dimension.
import time

# Run the detector on a batch of one image with the padded query set.
start_ts = time.perf_counter()
predictions = module.apply(
    variables,
    input_image[None, ...],
    tokenized_queries[None, ...],
    train=False)

# Remove the batch dimension and convert to numpy. np.array() blocks until
# JAX's asynchronous computation has finished, so the timing below is complete.
# jax.tree_util.tree_map is used because the jax.tree_map alias is deprecated
# and has been removed in recent JAX releases.
predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions)
print('inference time: %.2f' % (time.perf_counter() - start_ts))  # 5.31 for b32; 91.45 for l14



inference time: 101.58


In [13]:
%matplotlib inline
# https://i.imgur.com/1IWZX69.jpg

import cv2
def show(score_threshold, w_threshold, h_threshold, nms_threshold=0.5,
         show_labels=False):
    """Plot detections that survive NMS and the size filters over the input image.

    Reads the globals `predictions`, `text_queries` and `input_image` produced
    by earlier cells. Each kept box is drawn in red and tagged with a running
    count (1, 2, ...), so the final count is the number of boxes shown; the
    total is also written in the figure title.

    Args:
      score_threshold: Minimum sigmoid score for a box to be considered by NMS.
      w_threshold: Boxes wider than this (in normalized [0, 1] units) are skipped.
      h_threshold: Boxes taller than this (in normalized [0, 1] units) are skipped.
      nms_threshold: IoU threshold passed to cv2.dnn.NMSBoxes.
      show_labels: If True, tag boxes with "<query>: <score>" instead of the count.
    """
    # Drop the logits of the padding queries before scoring and labelling, so
    # labels always index into text_queries.
    logits = predictions['pred_logits'][..., :len(text_queries)]
    scores = sigmoid(np.max(logits, axis=-1))
    labels = np.argmax(logits, axis=-1)
    boxes = predictions['pred_boxes']  # rows of (cx, cy, w, h)

    # NMSBoxes takes top-left (x, y, w, h) rectangles and plain float scores.
    rects = [(float(cx - bw / 2.0), float(cy - bh / 2.0), float(bw), float(bh))
             for cx, cy, bw, bh in boxes]
    kept = cv2.dnn.NMSBoxes(rects, scores.astype(float).tolist(),
                            score_threshold=score_threshold,
                            nms_threshold=nms_threshold)
    # Depending on the OpenCV version the result is shaped (N,), (N, 1) or is an
    # empty tuple; flatten it and iterate in ascending box order.
    kept = sorted(set(np.asarray(kept, dtype=int).ravel().tolist()))

    fig, ax = plt.subplots(1, 1, figsize=(16, 16))
    ax.imshow(input_image, extent=(0, 1, 1, 0))
    ax.set_axis_off()

    count = 0
    for index in kept:
        cx, cy, bw, bh = boxes[index]
        if bw > w_threshold or bh > h_threshold:
            continue
        count += 1
        left, right = cx - bw / 2, cx + bw / 2
        top, bottom = cy - bh / 2, cy + bh / 2
        ax.plot([left, right, right, left, left],
                [top, top, bottom, bottom, top], 'r')
        if show_labels:
            tag = f'{text_queries[labels[index]]}: {scores[index]:1.2f}'
        else:
            tag = f'{count}'
        ax.text(
            left + 0.005,
            top + 0.005,
            tag,
            ha='left',
            va='top',
            color='red',
            bbox={
                'alpha': 0,
                'boxstyle': 'square,pad=.3'
            })
    ax.set_title(f'{count} detections')
    plt.show()

In [16]:
import ipywidgets as widgets
#from ipywidgets import interact

def _threshold_slider(value, description, max_value=1):
    """Return a FloatSlider spanning [0, max_value] in steps of 0.01."""
    return widgets.FloatSlider(value=value, min=0, max=max_value, step=0.01,
                               description=description)


# One slider per threshold parameter of `show`; moving any slider re-runs it.
score_threshold = _threshold_slider(0.1, 'score threshold', max_value=0.5)
nms_threshold = _threshold_slider(0.2, 'nms threshold')
w_threshold = _threshold_slider(0.05, 'W threshold')
h_threshold = _threshold_slider(0.05, 'H threshold')

_ = widgets.interact(
    show,
    score_threshold=score_threshold,
    w_threshold=w_threshold,
    h_threshold=h_threshold,
    nms_threshold=nms_threshold,
)

interactive(children=(FloatSlider(value=0.1, description='score threshold', max=0.5, step=0.01), FloatSlider(v…

In [None]:
# Fixed (non-interactive) thresholds for calling `show` directly.
# Note: this rebinds the slider names from the widget cell to plain floats.
score_threshold, nms_threshold, w_threshold, h_threshold = 0.05, 0.2, 0.05, 0.05

# indicies = show(score_threshold, w_threshold, h_threshold, nms_threshold)