# 🦉 OWL-ViT inference playground

OWL-ViT is an **open-vocabulary object detector**. Given a free-text query, it will find objects matching that query. It can also do **one-shot object detection**, i.e. detect objects based on a single example image.

This Colab allows you to query the model interactively, to get a feeling for its capabilities. For details on the model, check out the [paper](https://arxiv.org/abs/2205.06230) or the [code](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).

> ❗ Note: The free public Colab runtime has enough memory for the ViT-B/16 model. For optimal results, use a Pro or local runtime and the ViT-L/14 model.

> ❗ Note: This Colab is optimized for fast interactive exploration. It does not apply some of the optimizations and augmentations that would be used in a rigorous evaluation setting, so results from this Colab may not match the paper.

## How to use this Colab
1. Use a GPU or TPU Colab runtime.
2. Run all cells in the Colab from top to bottom.
3. Go to the cells for [Text-conditioned object detection](#scrollTo=aNzcyP1sbJ9w&uniqifier=1) or [Image-conditioned object detection](#scrollTo=TFlZhrDTQbiY&uniqifier=1) and have fun!

**If you run into any problems, please [file an issue](https://github.com/google-research/scenic/issues/new?title=OWL-ViT+inference+playround:+[add+title]) on GitHub.**



# Download and install OWL-ViT

OWL-ViT is implemented in [Scenic](https://github.com/google-research/scenic). The cell below installs the Scenic codebase from GitHub and imports it.

In [None]:
!rm -rf *
!rm -rf .config
!rm -rf .git
!git clone https://github.com/google-research/scenic.git .
!python -m pip install -q .
!python -m pip install -r ./scenic/projects/owl_vit/requirements.txt

# Also install big_vision, which is needed for the mask head:
!mkdir /big_vision
!git clone https://github.com/google-research/big_vision.git /big_vision
!python -m pip install -r /big_vision/big_vision/requirements.txt
import sys
sys.path.append('/big_vision/')
!echo "Done."

In [None]:
import os

from bokeh import io as bokeh_io
import jax
from google.colab import output as colab_output 
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import models
from scenic.projects.owl_vit.configs import clip_b16 as config_module
from scenic.projects.owl_vit.notebooks import inference
from scenic.projects.owl_vit.notebooks import interactive
from scenic.projects.owl_vit.notebooks import plotting
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf

tf.config.experimental.set_visible_devices([], 'GPU')
bokeh_io.output_notebook(hide_banner=True)

# Set up the model
This takes a minute or two.

In [None]:
config = config_module.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
model = inference.Model(config, module, variables)
model.warm_up()

# Load example images

Please provide a path to a directory containing example images. Google Cloud Storage and local storage are supported.

In [None]:
IMAGE_DIR = 'gs://scenic-bucket/owl_vit/example_images'  # @param {"type": "string"}
%matplotlib inline

images = {}

for i, filename in enumerate(tf.io.gfile.listdir(IMAGE_DIR)):
  with tf.io.gfile.GFile(os.path.join(IMAGE_DIR, filename), 'rb') as f:
    image = mpl.image.imread(
        f, format=os.path.splitext(filename)[-1])[..., :3]
  if np.max(image) <= 1.:
    image *= 255
  images[i] = image

cols = 5
rows = max((len(images) + cols - 1) // cols, 1)  # Ceiling division, so no image is dropped.
fig, axs = plt.subplots(rows, cols, figsize=(16, 8 * rows))

for ax in axs.ravel():
  ax.set_visible(False)

for ax, (ind, image) in zip(axs.ravel(), images.items()):
  ax.set_visible(True)
  ax.imshow(image)
  ax.set_xticks([])
  ax.set_yticks([])
  ax.set_title(f'Image ID: {ind}')

fig.tight_layout()
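
The directory may also contain files that are not images, which the loading loop above would try to read as images. One way to guard against that is to filter the listing before loading. This is a minimal sketch: `select_image_files` and `IMAGE_EXTENSIONS` are hypothetical names for illustration, not part of Scenic.

```python
import os

# Hypothetical set of extensions; adjust to the formats your images use.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def select_image_files(filenames):
  """Returns the filenames that look like images, in sorted order."""
  return sorted(
      name for name in filenames
      if os.path.splitext(name)[-1].lower() in IMAGE_EXTENSIONS)


filenames = ['cat.PNG', 'notes.txt', 'dog.jpg', '.DS_Store']
print(select_image_files(filenames))  # ['cat.PNG', 'dog.jpg']
```

The helper could wrap the result of `tf.io.gfile.listdir(IMAGE_DIR)`. Sorting the names also makes the assignment of image IDs deterministic.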

# Text-conditioned detection
Enter comma-separated queries into the text box above the image to detect objects. If nothing happens, try running the cell first (<kbd>Ctrl</kbd>+<kbd>Enter</kbd>).

In [None]:
#@title { run: "auto" }
IMAGE_ID = 2  # @param {"type": "number"}
image = images[IMAGE_ID]
_, _, boxes = model.embed_image(image)
plotting.create_text_conditional_figure(
    image=model.preprocess_image(image), boxes=boxes, fig_size=900)
interactive.register_text_input_callback(model, image, colab_output)
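
To make the role of the comma-separated queries concrete, here is a conceptual sketch in plain Python. It is not the code used by the `interactive` module. It assumes the model yields one confidence score per predicted box and per query. The helper names and all numbers are made up for illustration.

```python
def parse_queries(text):
  """Splits a comma-separated string into a tuple of non-empty queries."""
  return tuple(q.strip() for q in text.split(',') if q.strip())


def best_query_per_box(scores, queries, min_confidence):
  """Returns (box_index, query, score) for boxes whose top score passes the threshold.

  scores[i][j] is the (hypothetical) confidence that box i matches query j.
  """
  detections = []
  for box_index, box_scores in enumerate(scores):
    top = max(range(len(queries)), key=lambda j: box_scores[j])
    if box_scores[top] >= min_confidence:
      detections.append((box_index, queries[top], box_scores[top]))
  return detections


queries = parse_queries('cat, remote control, , sofa ')
# Made-up scores for three boxes (rows) and three queries (columns).
scores = [[0.9, 0.1, 0.2], [0.05, 0.7, 0.1], [0.1, 0.2, 0.15]]
print(queries)  # ('cat', 'remote control', 'sofa')
print(best_query_per_box(scores, queries, min_confidence=0.5))
# [(0, 'cat', 0.9), (1, 'remote control', 0.7)]
```

Empty entries and surrounding whitespace are dropped, and the third box is not reported because none of its scores reaches the threshold.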

# Image-conditioned detection

In image-conditioned detection, the model detects objects that match a given example image. In the cell below, you choose the example by drawing a bounding box around an object in the left image. The model then detects similar objects in the right image.

In [None]:
#@title { run: "auto" }

#@markdown The *query image* is used to select example objects:
QUERY_IMAGE_ID = 1  # @param {"type": "number"}

#@markdown Objects will be detected in the *target image* :
TARGET_IMAGE_ID = 0  # @param {"type": "number"}

#@markdown Threshold for the minimum confidence that a detection must have to
#@markdown be displayed (higher values mean fewer boxes will be shown):
MIN_CONFIDENCE = 0.6 #@param { type: "slider", min: 0.0, max: 1.0, step: 0.05}


#@markdown Threshold for non-maximum suppression of overlapping boxes (higher
#@markdown values mean more boxes will be shown):
NMS_THRESHOLD = 0.3 #@param { type: "slider", min: 0.05, max: 1.0, step: 0.05}

interactive.IMAGE_COND_MIN_CONF = MIN_CONFIDENCE
interactive.IMAGE_COND_NMS_IOU_THRESHOLD = NMS_THRESHOLD

query_image = images[QUERY_IMAGE_ID]
target_image = images[TARGET_IMAGE_ID]
_, _, boxes = model.embed_image(target_image)
plotting.create_image_conditional_figure(
    query_image=model.preprocess_image(query_image), 
    target_image=model.preprocess_image(target_image), 
    target_boxes=boxes, fig_size=600)
interactive.register_box_selection_callback(model, query_image, target_image, colab_output)
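
The effect of the two sliders can be illustrated with a small sketch of confidence filtering followed by standard greedy non-maximum suppression. This is an assumption about the general technique, written in plain Python. The implementation in the `interactive` module may differ. The function names and example boxes are hypothetical.

```python
def box_iou(a, b):
  """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
  ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
  iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
  intersection = ix * iy
  area_a = (a[2] - a[0]) * (a[3] - a[1])
  area_b = (b[2] - b[0]) * (b[3] - b[1])
  union = area_a + area_b - intersection
  return intersection / union if union > 0 else 0.0


def filter_detections(boxes, scores, min_confidence, nms_threshold):
  """Returns indices of boxes kept after confidence filtering and greedy NMS."""
  candidates = [i for i, s in enumerate(scores) if s >= min_confidence]
  candidates.sort(key=lambda i: scores[i], reverse=True)
  kept = []
  for i in candidates:
    # Keep a box only if it does not overlap too much with a better-scoring one.
    if all(box_iou(boxes[i], boxes[j]) <= nms_threshold for j in kept):
      kept.append(i)
  return kept


# Made-up boxes: the first two overlap heavily (IoU of about 0.68).
boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30), (0, 0, 5, 5)]
scores = [0.9, 0.8, 0.7, 0.3]
print(filter_detections(boxes, scores, min_confidence=0.6, nms_threshold=0.3))  # [0, 2]
print(filter_detections(boxes, scores, min_confidence=0.6, nms_threshold=0.8))  # [0, 1, 2]
```

Raising the NMS threshold from 0.3 to 0.8 lets the second, overlapping box survive, which is why higher values show more boxes. The last box is always dropped because its score is below `min_confidence`.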