In [None]:
#@title License
# Copyright 2022 The Pix2Seq Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## A Unified Sequence Interface for Vision Tasks
<a href="https://colab.research.google.com/github/google-research/pix2seq/blob/master/colabs/pix2seq_inference_multitask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


This colab presents a demo of multi-task inference with Pix2Seq. The table below summarizes the models fine-tuned on the MS-COCO dataset and gives the location of each checkpoint; this demo loads the 640x640 checkpoint.

Backbone       | Total params (M) | Image size | COCO AP   | Google Cloud Storage location
-------------: | ---------------: | ---------: | --------: | -----------:
ViT-B          | 115.2            | 640x640    | 44.2      | [gs://pix2seq/multi_task/ckpt/vit_b_640x640](https://console.cloud.google.com/storage/browser/pix2seq/multi_task/ckpt/vit_b_640x640)
ViT-B          | 115.2            | 1024x1024  | 46.5      | [gs://pix2seq/multi_task/ckpt/vit_b_1024x1024](https://console.cloud.google.com/storage/browser/pix2seq/multi_task/ckpt/vit_b_1024x1024)

In [None]:
#@title Imports.
import os
import sys

# Install the runtime dependencies and fetch the pix2seq source tree.
# NOTE(review): package versions are unpinned, so a future release may break
# this colab; consider pinning known-good versions. Re-running this cell makes
# `git clone` report that the `pix2seq` directory already exists.
!pip install tensorflow
!pip install ml_collections
!pip install tensorflow-addons
!pip install tensorflow-text
!git clone https://github.com/google-research/pix2seq.git
sys.path.append(os.getcwd())
root_dir = os.getcwd()
# Put the cloned repo on the import path so that `models`, `tasks`, `metrics`,
# `data`, `utils` and `configs` below resolve to pix2seq's packages.
sys.path.insert(1, 'pix2seq')

import tensorflow as tf
from PIL import Image
import numpy as np
import google.colab
import ml_collections
import json
import copy

# Several of these task/metric modules are not referenced directly in this
# notebook; presumably importing them registers the tasks with
# `task_lib.TaskRegistry` (looked up by name later) — confirm.
from models import ar_model as model_lib
from tasks import instance_segmentation
from tasks import keypoint_detection
from tasks import object_detection
from tasks import captioning
from metrics import coco_metrics
from tasks import task as task_lib
from data import data_utils
import utils
from tasks.visualization import vis_utils
from configs import config_multi_task

In [None]:
## Download coco annotations
!mkdir /tmp/coco_annotations
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/captions_train2017_eval_compatible.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/captions_val2017_eval_compatible.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/instances_train2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/instances_val2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/person_keypoints_train2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/person_keypoints_val2017.json -P /tmp/coco_annotations/

In [None]:
#@title Load pix2seq multitask model.
# A single model serves all four tasks; the config string lists every
# task@dataset pair followed by the backbone ('vit-b').
config = config_multi_task.get_config('object_detection@coco/2017_object_detection+instance_segmentation@coco/2017_instance_segmentation+keypoint_detection@coco/2017_keypoint_detection+captioning@coco/2017_captioning,vit-b')
config.training = False

# Restore checkpoint.
model = model_lib.Model(config)
checkpoint = tf.train.Checkpoint(
    model=model, global_step=tf.Variable(0, dtype=tf.int64))
model_dir = 'gs://pix2seq/multi_task/ckpt/vit_b_640x640'
ckpt = tf.train.latest_checkpoint(model_dir)
if ckpt is None:
  # `latest_checkpoint` returns None when nothing is found (e.g. the bucket is
  # unreachable); `restore(None)` would then silently keep the randomly
  # initialized weights and every prediction below would be garbage.
  raise FileNotFoundError(f'No checkpoint found under {model_dir!r}.')
# expect_partial() silences warnings about checkpoint values that this
# inference-only model does not restore.
checkpoint.restore(ckpt).expect_partial()
global_step = checkpoint.global_step

In [None]:
#@title Get task and infer_fn.
def get_task_config_and_infer_fn(config, task_name):
  """Builds a single-task view of the multi-task config plus a jitted infer fn.

  Args:
    config: the multi-task config; a deep copy of it is specialized for
      `task_name` (but see the NOTE below about the task/dataset entries).
    task_name: name of one entry in `config.tasks` (and `config.datasets`).

  Returns:
    A `(task_config, task, infer)` tuple, where `infer(model,
    preprocessed_outputs)` is a `tf.function` that calls `task.infer`.
  """
  task_config = copy.deepcopy(config)
  # NOTE(review): the task/dataset entries are taken from the *original*
  # `config`, not from the deep copy. If the config stores them by reference,
  # the overrides below also show up in `config.tasks` — confirm before
  # relying on either behavior.
  for entry in config.tasks:
    if entry['name'] == task_name:
      task_config.task = entry
  for entry in config.datasets:
    if entry['name'] == task_name:
      task_config.dataset = entry
  task_config.model_dir = ''

  assert task_config.task['name'] == task_name
  task = task_lib.TaskRegistry.lookup(task_name)(task_config)

  # Inference-time overrides of per-task settings.
  # NOTE(review): these are applied after the task object is constructed, so
  # they only take effect if the task reads `config.task` lazily — confirm.
  per_task_overrides = {
      'object_detection': {
          'max_instances_per_image_test': 10,
      },
      'instance_segmentation': {
          'use_gt_box_at_test': True,
          # For faster inference use a smaller number of samples (e.g. 1).
          'ensemble_num_samples': 8,
          'ensemble_threshold': 0.5,
          'max_instances_per_image_test': 1,
      },
      'keypoint_detection': {
          'use_gt_box_at_test': True,
          'eval_suppress_invisible_token': True,
          'unbatch': True,
          'crop_to_bbox': True,
          'crop_to_bbox_pad_scale': 0.5,
          'keypoint_score_weight': 0.1,
          'points_score_weight': 0.1,
          'max_instances_per_image_test': 1,
          'top_p': 0.2,
      },
      'captioning': {
          'captions_per_image': 1,
          'max_instances_per_image': 1,
      },
  }
  for key, value in per_task_overrides.get(task_name, {}).items():
    setattr(task_config.task, key, value)
  # Every task except captioning is evaluated one image at a time.
  if task_name in ('object_detection', 'instance_segmentation',
                   'keypoint_detection'):
    task_config.eval.batch_size = 1

  @tf.function
  def infer(model, preprocessed_outputs):
    return task.infer(model, preprocessed_outputs)

  return task_config, task, infer

config_det, task_det, infer_det = get_task_config_and_infer_fn(config, "object_detection")
config_seg, task_seg, infer_seg = get_task_config_and_infer_fn(config, "instance_segmentation")
config_key, task_key, infer_key = get_task_config_and_infer_fn(config, "keypoint_detection")
config_cap, task_cap, infer_cap = get_task_config_and_infer_fn(config, "captioning")

In [None]:
#@title Functions to construct inference examples and visualizations.
def get_preprocessed_outputs_detection(image, image_id, config):
  """Builds a batch-of-one object-detection input from a single image.

  The labels are all-zero placeholders sized to
  `config.task.max_instances_per_image_test`.

  Args:
    image: a PIL image (as returned by `get_image`) or an equivalent array;
      assumes uint8 pixels, which `convert_image_dtype` rescales to [0, 1].
    image_id: COCO image id, stored under `image/id`.
    config: the object-detection task config (e.g. `config_det`).

  Returns:
    `(images, None, (features, labels))` where every tensor has a leading
    batch dimension of 1; this tuple is what the task's infer fn receives.
  """
  pixels = np.array(image)
  image_hw = tf.shape(pixels)[0:2]
  features = {
      'image': tf.image.convert_image_dtype(pixels, tf.float32),
      'image/id': image_id,
      'orig_image_size': image_hw,
      'image_size_before_cropping': image_hw,
  }
  num_slots = config.task.max_instances_per_image_test
  labels = {
      'label': tf.zeros([num_slots], tf.int64),
      'bbox': tf.zeros([num_slots, 4], tf.float32),
      'is_crowd': tf.zeros([num_slots], tf.bool),
      'area': tf.zeros([num_slots]),
  }

  features, labels = data_utils.preprocess_eval(
      features, labels,
      max_image_size=config.model.image_size,
      max_instances_per_image=num_slots,
      object_coordinate_keys=('bbox',))

  def add_batch_dim(tensors):
    # Prepend a batch dimension of size 1 to every entry.
    return {name: tf.expand_dims(t, 0) for name, t in tensors.items()}

  batched_features = add_batch_dim(features)
  batched_labels = add_batch_dim(labels)
  return (batched_features['image'], None, (batched_features, batched_labels))


def visualize_detection(results, image, categeory_names):
  """Draws the detections of the first image in a batch onto `image`.

  Args:
    results: the tuple returned by the detection task's `postprocess_tpu`.
    image: the original image as an array. Boxes are drawn from
      `pred_bboxes_rescaled` with absolute coordinates, presumably in this
      image's pixel space.
    categeory_names: category index forwarded to the visualizer as
      `category_index`. The misspelled name is kept for backward compatibility.

  Returns:
    The uint8 image array with boxes, class labels and scores drawn on it.
  """
  (images, image_ids, pred_bboxes, pred_bboxes_rescaled, pred_classes,
   scores, gt_classes, gt_bboxes, gt_bboxes_rescaled, area, is_crowd) = results
  # Group the predictions per image by the batch size of `images`. This avoids
  # reading `max_instances_per_image_test` from the global multi-task
  # `config`, which is not the detection config that produced `results`.
  # Assumes `images` has the same leading batch dim as the predictions.
  batch_size = tf.shape(images)[0]
  bboxes_ = tf.reshape(pred_bboxes_rescaled, [batch_size, -1, 4]).numpy()
  classes_ = tf.reshape(pred_classes, [batch_size, -1]).numpy()
  scores_ = tf.reshape(scores, [batch_size, -1]).numpy()
  images_ = np.copy(tf.image.convert_image_dtype(image, tf.uint8))

  # Only predictions with a positive class id are drawn (0 presumably marks
  # padding / no object).
  keep_indices = np.where(classes_[0] > 0)[0]
  vis = vis_utils.visualize_boxes_and_labels_on_image_array(
      image=images_,
      boxes=bboxes_[0][keep_indices],
      classes=classes_[0][keep_indices],
      scores=scores_[0][keep_indices],
      category_index=categeory_names,
      use_normalized_coordinates=False,
      min_score_thresh=0.9,
      skip_labels=False,
      skip_scores=False,
      line_thickness=4)

  return vis


def get_preprocessed_outputs_segmentation(image, image_id, bbox, config, label_id=1):
  """Builds a batch-of-one, box-conditioned instance-segmentation input.

  The only label instance is `bbox`, normalized by the image height/width,
  with class `label_id` and score 1. The segmentation config in this notebook
  sets `use_gt_box_at_test`, presumably so the model segments this box.

  Args:
    image: a PIL image (as returned by `get_image`) or an equivalent array.
    image_id: COCO image id, stored under `image/id`.
    bbox: (ymin, xmin, ymax, xmax) in absolute pixel coordinates.
    config: the instance-segmentation task config (e.g. `config_seg`).
    label_id: class id attached to the box.

  Returns:
    `(images, None, (features, labels))` where every tensor has a leading
    batch dimension of 1.
  """
  pixels = np.array(image)
  height, width = pixels.shape[0:2]
  box_ymin, box_xmin, box_ymax, box_xmax = bbox
  image_hw = tf.shape(pixels)[0:2]
  features = {
      'image': tf.image.convert_image_dtype(pixels, tf.float32),
      'image/id': image_id,
      'orig_image_size': image_hw,
      'image_size_before_cropping': image_hw,
  }
  normalized_box = [box_ymin / height, box_xmin / width,
                    box_ymax / height, box_xmax / width]
  labels = {
      'label': tf.convert_to_tensor([label_id], tf.int64),
      'bbox': tf.convert_to_tensor([normalized_box], tf.float32),
      'scores': tf.ones([1]),
      'is_crowd': tf.zeros([1], tf.bool),
      'area': tf.zeros([1]),
  }

  features, labels = data_utils.preprocess_eval(
      features, labels,
      max_image_size=config.model.image_size,
      max_instances_per_image=config.task.max_instances_per_image_test,
      object_coordinate_keys=('bbox',))

  # Prepend a batch dimension of size 1 to every tensor.
  batched_features = {name: tf.expand_dims(t, 0) for name, t in features.items()}
  batched_labels = {name: tf.expand_dims(t, 0) for name, t in labels.items()}
  return (batched_features['image'], None, (batched_features, batched_labels))

def visualize_segmentation(results, config, category_names):
  """Draws the ensembled instance masks of the first image in a batch.

  Args:
    results: the tuple returned by the instance-segmentation task's
      `postprocess_tpu`.
    config: the instance-segmentation task config; `task.ensemble_num_samples`
      and `task.ensemble_threshold` are passed to the mask ensembling.
    category_names: category index forwarded to the visualizer.

  Returns:
    The uint8 image (at the resolution of `images` in `results`) with boxes,
    class labels and masks drawn on it.
  """
  (images, image_ids, orig_image_size, unpadded_image_size, n_instances,  # pylint: disable=unbalanced-tuple-unpacking
    pred_points, pred_points_rescaled, pred_classes, pred_bboxes,
    scores) = results
  n_instances = n_instances[0]
  num_samples = config.task.ensemble_num_samples
  threshold = config.task.ensemble_threshold
  bsz = tf.shape(images)[0]

  # Turn the predicted points into masks at the resolution of `images`, the
  # same image the (normalized) boxes are drawn on below.
  image_size = images.shape[1:3].as_list()
  pred_points_rescaled = utils.scale_points(pred_points, image_size)
  pred_masks_rle = instance_segmentation.segment_to_mask_rle(
      pred_points_rescaled.numpy(), image_size, is_single_shape=True)
  pred_masks_np = instance_segmentation.mask_rle_to_mask_np(
      pred_masks_rle, image_size, is_single_shape=True)
  # Combine the `num_samples` mask samples of each instance using `threshold`
  # (see instance_segmentation.ensemble_mask_np).
  pred_masks_np = instance_segmentation.ensemble_mask_np(
      pred_masks_np, (bsz * n_instances).numpy(), num_samples, threshold)
  pred_masks_np = np.asarray(pred_masks_np, np.uint8)
  # Group masks, boxes, classes and scores per image: [bsz, n_instances, ...].
  mask_new_shape = [pred_masks_np.shape[0]//n_instances, n_instances] + list(
      pred_masks_np.shape[1:])
  pred_masks_np = np.reshape(pred_masks_np, mask_new_shape)
  pred_bboxes = tf.reshape(pred_bboxes, [-1, n_instances, 4])
  pred_classes = tf.reshape(pred_classes, [-1, n_instances])
  scores = tf.reshape(scores, [-1, n_instances])

  bboxes_ = pred_bboxes.numpy()
  classes_ = pred_classes.numpy()
  scores_ = scores.numpy()
  images_ = np.copy(tf.image.convert_image_dtype(images, tf.uint8))

  # Only instances with a positive class id are drawn.
  keep_indices = np.where(classes_[0] > 0)[0]
  vis = vis_utils.visualize_boxes_and_labels_on_image_array(
      image=images_[0],
      boxes=bboxes_[0][keep_indices],
      classes=classes_[0][keep_indices],
      scores=scores_[0][keep_indices],
      category_index=category_names,
      instance_masks=pred_masks_np[0][keep_indices],
      use_normalized_coordinates=True,
      max_boxes_to_draw=20,
      min_score_thresh=0.9,
      skip_boxes=False,
      skip_scores=True,
      skip_labels=False)

  return vis

def get_preprocessed_outputs_keypoint(image, image_id, bbox, config, label_id=1):
  """Builds a batch-of-one keypoint-detection input cropped around a box.

  The image is cropped to `bbox`, enlarged on each side by
  `config.task.crop_to_bbox_pad_scale` times the box size and clipped to the
  image. The crop is then preprocessed as a single-instance example.

  Args:
    image: a PIL image (as returned by `get_image`) or an equivalent array.
    image_id: COCO image id, stored under `image/id`.
    bbox: (ymin, xmin, ymax, xmax) in absolute pixel coordinates.
    config: the keypoint-detection task config (e.g. `config_key`). Its
      `task.crop_to_bbox_pad_scale` controls the crop padding.
    label_id: class id attached to the box.

  Returns:
    `(images, None, (features, labels))` where every tensor has a leading
    batch dimension of 1.
  """
  im = image
  ymin, xmin, ymax, xmax = bbox
  h, w = np.shape(im)[0:2]
  features = {
      'image': tf.image.convert_image_dtype(np.array(im), tf.float32),
      'image/id': image_id,
      'orig_image_size': tf.shape(im)[0:2],
      'image_size_before_cropping': tf.shape(im)[0:2],
  }
  labels = {
      'label': tf.convert_to_tensor([label_id], tf.int64),
      'bbox': tf.convert_to_tensor([[
          ymin/h, xmin/w, ymax/h, xmax/w
      ]], tf.float32),
      'scores': tf.ones([1]),
      'is_crowd': tf.zeros([1], tf.bool),
      'area': tf.zeros([1]),
      # All-zero placeholder; presumably unused at inference — confirm.
      'keypoints': tf.zeros([1, 14])
  }

  @tf.autograph.experimental.do_not_convert
  def crop_to_bbox(features, labels, config):
    """Crops `features['image']` to the padded first box in `labels['bbox']`."""
    tconfig = config.task
    per_instance_points_keys=('polygon', 'keypoints')
    bbox = labels['bbox'][0]
    image_shape = tf.shape(features['image'])[:2]

    # Normalized bbox coords.
    ymin, xmin, ymax, xmax = [tf.squeeze(t) for t in tf.split(bbox, 4)]
    image_h, image_w = [
        tf.squeeze(t) for t in tf.split(utils.tf_float32(image_shape), 2)
    ]

    def i32(t):
      return tf.cast(t, tf.int32)
    def f32(t):
      return tf.cast(t, tf.float32)

    # Convert to integer pixel indices, using floor/ceil so the crop always
    # covers the whole box.
    floor = tf.math.floor
    ceil = tf.math.ceil
    ymin = i32(floor(ymin * (image_h - 1)))
    xmin = i32(floor(xmin * (image_w - 1)))
    ymax = i32(ceil(ymax * (image_h - 1)))
    xmax = i32(ceil(xmax * (image_w - 1)))

    # Enlarge the box by `crop_to_bbox_pad_scale` x its size on every side,
    # clipped to the image bounds.
    crop_to_bbox_pad_scale = tconfig.crop_to_bbox_pad_scale
    bbox_h = ymax - ymin + 1
    bbox_w = xmax - xmin + 1
    ypad = i32(f32(bbox_h) * crop_to_bbox_pad_scale)
    xpad = i32(f32(bbox_w) * crop_to_bbox_pad_scale)
    ymin = tf.math.maximum(0, ymin - ypad)
    ymax = tf.math.minimum(image_shape[0] - 1, ymax + ypad)
    xmin = tf.math.maximum(0, xmin - xpad)
    xmax = tf.math.minimum(image_shape[1] - 1, xmax + xpad)

    # Region is [top, left, height, width].
    region = [ymin, xmin, ymax - ymin + 1, xmax - xmin + 1]
    points_orig = {
        key: labels[key] for key in per_instance_points_keys if key in labels
    }
    features, labels = data_utils.crop(features, labels, region)
    # Restore reserved-token entries of the point labels after cropping
    # (see utils.preserve_reserved_tokens).
    for key in points_orig:
      labels[key] = utils.preserve_reserved_tokens(
          labels[key], points_orig[key])
    return features, labels

  # Crop with the config passed by the caller (previously the global
  # `config_key` was used here and this argument was ignored).
  features, labels = crop_to_bbox(features, labels, config)
  # From here on the crop is treated as the original image.
  features['orig_image_size'] = tf.shape(features['image'])[0:2]

  features, labels = data_utils.preprocess_eval(
      features, labels,
      max_image_size=config.model.image_size,
      max_instances_per_image=1,
      object_coordinate_keys=('bbox',))

  # Batch features and labels.
  features = {
      k: tf.expand_dims(v, 0) for k, v in features.items()
  }
  labels = {
      k: tf.expand_dims(v, 0) for k, v in labels.items()
  }

  preprocessed_outputs = (features['image'], None, (features, labels))
  return preprocessed_outputs

def visualize_keypoint(results, image):
  """Draws the predicted person box and its skeleton onto `image`.

  Args:
    results: the tuple returned by the keypoint task's `postprocess_tpu`.
    image: the image to draw on, as a float or uint8 array/tensor. Points and
      boxes are drawn with absolute coordinates (`pred_points_rescaled`).

  Returns:
    The uint8 image array with at most one person box and its keypoints.
  """
  (images, _, _, pred_points, pred_points_rescaled,  # pylint: disable=unbalanced-tuple-unpacking
   pred_classes, pred_bboxes, scores) = results
  # The batch dimension of `images` is used as the per-image instance count
  # (presumably one crop per instance with `unbatch=True`).
  num_instances = tf.shape(images)[0]

  bboxes_ = tf.reshape(pred_bboxes, [-1, num_instances, 4]).numpy()
  classes_ = tf.reshape(pred_classes, [-1, num_instances]).numpy()
  scores_ = tf.reshape(scores, [-1, num_instances]).numpy()
  num_points = pred_points.shape[-1] // 2  # Points are stored as flat pairs.
  keypoints_ = tf.reshape(
      pred_points_rescaled, [-1, num_instances, num_points, 2]).numpy()
  images_ = np.copy(tf.image.convert_image_dtype(image, tf.uint8))

  # 1-based COCO person skeleton, converted to 0-based keypoint indices.
  one_based_edges = (
      (16, 14), (14, 12), (17, 15), (15, 13), (12, 13), (6, 12), (7, 13),
      (6, 7), (6, 8), (7, 9), (8, 10), (9, 11), (2, 3), (1, 2), (1, 3),
      (2, 4), (3, 5), (4, 6), (5, 7))
  keypoint_edges = [[start - 1, end - 1] for start, end in one_based_edges]

  person_class_id = 1
  keep_indices = np.where(classes_[0] == person_class_id)[0]
  return vis_utils.visualize_boxes_and_labels_on_image_array(
      image=images_,
      boxes=bboxes_[0][keep_indices],
      classes=classes_[0][keep_indices],
      scores=scores_[0][keep_indices],
      category_index={},
      keypoints=keypoints_[0][keep_indices],
      keypoint_edges=keypoint_edges,
      use_normalized_coordinates=False,
      max_boxes_to_draw=1,
      min_score_thresh=0.9,
      skip_labels=True,
      skip_scores=True,
      line_thickness=4)

def get_preprocessed_outputs_captioning(image, image_id, config):
  """Builds a batch-of-one captioning input from a single image.

  The labels are placeholders only: a zero box and the caption string
  "dummy". The demo cell replaces `labels['captions']` with tokenized ids
  before running inference. Unlike the box-conditioned tasks, this function
  takes no `bbox`.

  Args:
    image: a PIL image (as returned by `get_image`) or an equivalent array.
    image_id: COCO image id, stored under `image/id`.
    config: the captioning task config (e.g. `config_cap`).

  Returns:
    `(images, None, (features, labels))` where every tensor has a leading
    batch dimension of 1.
  """
  # The image is converted once. The previous version also unpacked a global
  # `bbox` here, which raised NameError unless a later cell had already run,
  # and computed an unused height/width.
  pixels = np.array(image)
  features = {
      'image': tf.image.convert_image_dtype(pixels, tf.float32),
      'image/id': image_id,
      'orig_image_size': tf.shape(pixels)[0:2],
  }
  labels = {
      'label': tf.zeros([1], tf.int64),
      'bbox': tf.zeros([1, 4]),
      'area': tf.zeros([1]),
      'is_crowd': tf.zeros([1]),
      'captions': tf.convert_to_tensor(["dummy"], tf.string)
  }

  features, labels = data_utils.preprocess_eval(
      features,
      labels,
      max_image_size=config.model.image_size,
      max_instances_per_image=1)

  # Batch features and labels.
  features = {
      k: tf.expand_dims(v, 0) for k, v in features.items()
  }
  labels = {
      k: tf.expand_dims(v, 0) for k, v in labels.items()
  }

  preprocessed_outputs = (features['image'], None, (features, labels))
  return preprocessed_outputs

def print_captioning_result(outputs, tokenizer, text_vocab_shift=None):
  """Decodes and prints the first predicted caption in `outputs`.

  Args:
    outputs: the tuple returned by the captioning task's `postprocess_tpu`.
      Element 3 holds the predicted token sequences.
    tokenizer: object with a `detokenize` method (e.g.
      `task_cap._tokenizer.tokenizer`).
    text_vocab_shift: offset subtracted from non-zero predicted ids to map
      them back to tokenizer ids. Defaults to
      `config_cap.model.text_vocab_shift` for backward compatibility.
  """
  if text_vocab_shift is None:
    text_vocab_shift = config_cap.model.text_vocab_shift
  pred_seq = outputs[3]
  # Id 0 is kept as-is; every other id is shifted back by the vocab offset.
  pred_seq = tf.where(pred_seq == 0, 0, pred_seq - text_vocab_shift)
  token_ids = [int(token) for token in pred_seq[0].numpy()]
  print(tokenizer.detokenize(token_ids).numpy())

In [None]:
# Download image.
import requests
def get_image(image_id, train=False):
  image_id = "{:0>6d}".format(image_id)
  split = 'train' if train else 'val'
  url = f'http://images.cocodataset.org/{split}2017/000000{image_id}.jpg'
  with tf.io.gfile.GFile(url) as f:
    im = Image.open(requests.get(url, stream=True).raw)
  return im

In [None]:
# The COCO val2017 image that every task below is run on.
image_id = 230983
im = get_image(image_id=image_id, train=False)
im

In [None]:
# Object detection: build a batch-of-one input, run the jitted infer fn,
# post-process the predicted sequence and draw the boxes on the original image.
preprocessed_outputs = get_preprocessed_outputs_detection(
    image=im, image_id=image_id, config=config_det)
infer_outputs_det = infer_det(model, preprocessed_outputs)
results_det = task_det.postprocess_tpu(*infer_outputs_det)

vis_det = visualize_detection(
    results_det, np.asarray(im), task_det._category_names)
Image.fromarray(vis_det)

In [None]:
# Get a bbox for person and use it for keypoint detection and instance segmentation.
pred_bboxes_rescaled = results_det[3].numpy().reshape((-1, 4))
pred_classes = results_det[4].numpy().reshape((-1))
person_idx = np.where(pred_classes == 1)[0][0]
bbox = list(pred_bboxes_rescaled[person_idx])
# You can also specify a custom box in (ymin, xmin, ymax, xmax) format.
# bbox = [148.5505782847104, 149.26714806698848, 325.032441173339, 294.49976727245297]

In [None]:
# Instance segmentation.
preprocessed_outputs = get_preprocessed_outputs_segmentation(
    im, image_id, bbox, config_seg)
infer_outputs_seg = infer_seg(model, preprocessed_outputs)
results_seg = task_seg.postprocess_tpu(*infer_outputs_seg)

vis_seg = visualize_segmentation(results_seg, config_seg, task_seg._category_names)
Image.fromarray(vis_seg)

In [None]:
# Keypoint detection.
preprocessed_outputs = get_preprocessed_outputs_keypoint(
    im, image_id, bbox, config_key)
infer_outputs_key = infer_key(model, preprocessed_outputs)
results_key = task_key.postprocess_tpu(*infer_outputs_key)

vis_key = visualize_keypoint(results_key, tf.image.convert_image_dtype(np.array(im), tf.float32))
Image.fromarray(vis_key)

In [None]:
# Captioning.
preprocessed_outputs = get_preprocessed_outputs_captioning(
    im, image_id, config_cap)
captions = tf.expand_dims(task_cap._tokenizer.string_to_ids(["dummy"])[0],0)
preprocessed_outputs[2][1]['captions']=captions
infer_outputs = infer_cap(model, preprocessed_outputs)
examples={**infer_outputs[0][0],**infer_outputs[0][1]}
outputs = task_cap.postprocess_tpu(examples,infer_outputs[1],infer_outputs[2])
print_captioning_result(outputs, task_cap._tokenizer.tokenizer)