In [None]:
#@title License
# Copyright 2022 The Pix2Seq Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## A Unified Sequence Interface for Vision Tasks
<a href="https://colab.research.google.com/github/google-research/pix2seq/blob/master/colabs/pix2seq_inference_multitask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


This Colab demonstrates multi-task inference with Pix2seq. The table below summarizes the models fine-tuned on the MSCOCO dataset and gives their checkpoint locations.

Backbone       | Total params (M) | Image size | COCO AP   | Google cloud storage location
-------------: | ---------------: | ---------: | --------: | -----------:
ViT-B          | 115.2            | 640x640    | 44.2      | [gs://pix2seq/multi_task/ckpt/vit_b_640x640](https://console.cloud.google.com/storage/browser/pix2seq/multi_task/ckpt/vit_b_640x640)
ViT-B          | 115.2            | 1024x1024  | 46.5      | [gs://pix2seq/multi_task/ckpt/vit_b_1024x1024](https://console.cloud.google.com/storage/browser/pix2seq/multi_task/ckpt/vit_b_1024x1024)

In [None]:
#@title Imports.
import os
import sys

!pip install tensorflow
!pip install ml_collections
!pip install tensorflow-addons
!pip install tensorflow-text
!git clone https://github.com/google-research/pix2seq.git
sys.path.append(os.getcwd())
root_dir = os.getcwd()
sys.path.insert(1, 'pix2seq')

import tensorflow as tf
from PIL import Image
import numpy as np
import google.colab
import ml_collections
import json
import copy
import einops

from models import ar_model as model_lib
from tasks import instance_segmentation
from tasks import keypoint_detection
from tasks import object_detection
from tasks import captioning
from metrics import coco_metrics
from tasks import task as task_lib
from data import data_utils
import utils
from tasks.visualization import vis_utils
from configs import config_multi_task

In [None]:
# NOTE(review): `einops` is imported in the imports cell above, so on a fresh
# runtime it must already be installed before that cell runs.
!pip install einops

In [None]:
## Download coco annotations
!mkdir /tmp/coco_annotations
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/captions_train2017_eval_compatible.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/captions_val2017_eval_compatible.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/instances_train2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/instances_val2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/person_keypoints_train2017.json -P /tmp/coco_annotations/
!wget https://storage.googleapis.com/pix2seq/multi_task/data/coco/json/person_keypoints_val2017.json -P /tmp/coco_annotations/

In [None]:
#@title Load pix2seq multitask model.
# Multi-task config: object detection, instance segmentation, keypoint
# detection and captioning on COCO 2017, with the `vit-b` backbone option.
config = config_multi_task.get_config('object_detection@coco/2017_object_detection+instance_segmentation@coco/2017_instance_segmentation+keypoint_detection@coco/2017_keypoint_detection+captioning@coco/2017_captioning,vit-b')
config.training = False

# Restore checkpoint.
model = model_lib.Model(config)
checkpoint = tf.train.Checkpoint(
    model=model, global_step=tf.Variable(0, dtype=tf.int64))
# The 640x640 ViT-B checkpoint listed in the table above.
model_dir = 'gs://pix2seq/multi_task/ckpt/vit_b_640x640'
ckpt = tf.train.latest_checkpoint(model_dir)
# latest_checkpoint returns None when no checkpoint is found, and restoring
# None would silently leave the model at its freshly initialized weights.
if ckpt is None:
  raise FileNotFoundError(f'No checkpoint found in {model_dir}')
checkpoint.restore(ckpt).expect_partial()
global_step = checkpoint.global_step

In [None]:
#@title Get task and infer_fn.
def get_task_config_and_infer_fn(config, task_name):
  """Builds a single-task config, task object and inference fn from `config`.

  `config` is deep-copied, and the entries of the copy's `tasks` / `datasets`
  whose `name` equals `task_name` become `tconfig.task` / `tconfig.dataset`.
  Because the selection is made from the copy, the inference overrides applied
  below never leak back into the caller's multi-task `config`.

  Args:
    config: multi-task config with `tasks` and `datasets` sequences whose
      entries expose a 'name' field.
    task_name: name of the task to extract, e.g. "object_detection".

  Returns:
    (tconfig, task, infer): the single-task config, the task instance built
    via `task_lib.TaskRegistry.lookup(task_name)(tconfig)`, and a
    `tf.function` that calls `task.infer(model, preprocessed_outputs)`.

  Raises:
    AssertionError: if no entry of `config.tasks` is named `task_name`.
  """
  tconfig = copy.deepcopy(config)
  # Select from the deep copy (not `config`) so `tconfig.task` and
  # `tconfig.dataset` do not alias entries of the caller's config.
  for t in tconfig.tasks:
    if t['name'] == task_name:
      tconfig.task = t
  for d in tconfig.datasets:
    if d['name'] == task_name:
      tconfig.dataset = d
  tconfig.model_dir = ''

  assert tconfig.task['name'] == task_name, (
      f'No task named {task_name!r} found in config.tasks.')

  # Task-specific inference overrides. They are applied *before* the task is
  # constructed so they are in effect even if the constructor reads them.
  if task_name == "object_detection":
    tconfig.eval.batch_size = 1
    tconfig.task.max_instances_per_image_test = 10
  elif task_name == "instance_segmentation":
    tconfig.task.use_gt_box_at_test = True
    tconfig.task.ensemble_num_samples = 8
    # For faster inference use smaller number of samples.
    # tconfig.task.ensemble_num_samples = 1
    tconfig.task.ensemble_threshold = 0.5
    tconfig.eval.batch_size = 1
    tconfig.task.max_instances_per_image_test = 1
  elif task_name == "keypoint_detection":
    tconfig.task.use_gt_box_at_test = True
    tconfig.task.eval_suppress_invisible_token = True
    tconfig.task.unbatch = True
    tconfig.task.crop_to_bbox = True
    tconfig.task.crop_to_bbox_pad_scale = 0.5
    tconfig.task.keypoint_score_weight = 0.1
    tconfig.task.points_score_weight = 0.1
    tconfig.task.max_instances_per_image_test = 1
    tconfig.eval.batch_size = 1
    tconfig.task.top_p = 0.2
  elif task_name == "captioning":
    tconfig.task.captions_per_image = 1
    tconfig.task.max_instances_per_image = 1

  task = task_lib.TaskRegistry.lookup(task_name)(tconfig)

  @tf.function
  def infer(model, preprocessed_outputs):
    return task.infer(model, preprocessed_outputs)

  return tconfig, task, infer

# One (config, task, tf.function infer) triple per task in the multi-task
# config loaded above.
config_det, task_det, infer_det = get_task_config_and_infer_fn(config, "object_detection")
config_seg, task_seg, infer_seg = get_task_config_and_infer_fn(config, "instance_segmentation")
config_key, task_key, infer_key = get_task_config_and_infer_fn(config, "keypoint_detection")
config_cap, task_cap, infer_cap = get_task_config_and_infer_fn(config, "captioning")

In [None]:
#@title Functions to construct inference examples and visualizations.

def get_preprocessed_outputs_captioning(image, image_id, config):
  """Builds a batch-of-one captioning example for `task.infer`.

  Args:
    image: a PIL image (converted to RGB first) or an array-like image.
    image_id: value stored under the 'image/id' feature.
    config: task config; `config.model.image_size` is passed to
      `data_utils.preprocess_eval` as `max_image_size`.

  Returns:
    `(image, None, (features, labels))`, where every tensor in `features` and
    `labels` has a leading batch dimension of 1 and `image` is
    `features['image']`.
  """
  if isinstance(image, Image.Image):
    # Grayscale, palette and RGBA PIL images would otherwise not yield an
    # (H, W, 3) array.
    image = image.convert('RGB')
  image_array = np.asarray(image)

  features = {
      'image': tf.image.convert_image_dtype(image_array, tf.float32),
      'image/id': image_id,
      'orig_image_size': tf.shape(image_array)[0:2],
  }
  # Placeholder labels describing a single dummy instance.
  # TODO: no ground-truth 'captions' entry is supplied; confirm whether the
  # captioning task's inference path needs one.
  labels = {
      'label': tf.zeros([1], tf.int64),
      'bbox': tf.zeros([1, 4]),
      'area': tf.zeros([1]),
      'is_crowd': tf.zeros([1]),
  }

  features, labels = data_utils.preprocess_eval(
      features,
      labels,
      max_image_size=config.model.image_size,
      max_instances_per_image=1)

  # Add a leading batch dimension of 1 to every feature and label.
  features = {k: tf.expand_dims(v, 0) for k, v in features.items()}
  labels = {k: tf.expand_dims(v, 0) for k, v in labels.items()}

  return features['image'], None, (features, labels)

def print_captioning_result(outputs, tokenizer, text_vocab_shift=None):
  """Decodes and prints the first predicted caption contained in `outputs`.

  Args:
    outputs: indexable post-processing output whose element 3 holds the
      predicted token ids with the batch on axis 0 (a tf.Tensor or numpy
      array).
    tokenizer: object whose `detokenize(list_of_ids)` returns a value with a
      `.numpy()` method.
    text_vocab_shift: offset subtracted from every non-zero token id before
      detokenizing. Defaults to `config_cap.model.text_vocab_shift`.
  """
  if text_vocab_shift is None:
    text_vocab_shift = config_cap.model.text_vocab_shift
  pred_seq = np.asarray(outputs[3])
  # Id 0 is kept as-is; every other id is shifted back into the text vocab.
  pred_seq = np.where(pred_seq == 0, 0, pred_seq - text_vocab_shift)
  print(tokenizer.detokenize([int(j) for j in pred_seq[0]]).numpy())

In [None]:
# Download image.
import io
import requests


def coco_image_url(image_id, train=False):
  """Returns the images.cocodataset.org URL of a COCO 2017 image.

  Args:
    image_id: integer COCO image id.
    train: if True use the train2017 split, otherwise val2017.
  """
  split = 'train' if train else 'val'
  # File names are the image id zero-padded to 12 digits.
  return f'http://images.cocodataset.org/{split}2017/{image_id:012d}.jpg'


def get_image(image_id, train=False):
  """Downloads a COCO 2017 image and returns it as a PIL image.

  Args:
    image_id: integer COCO image id.
    train: if True fetch from train2017, otherwise from val2017.

  Raises:
    requests.HTTPError: if the server answers with an error status.
  """
  response = requests.get(coco_image_url(image_id, train=train), timeout=60)
  # Fail with a clear HTTP error instead of an opaque image-decoding error.
  response.raise_for_status()
  return Image.open(io.BytesIO(response.content))

In [None]:
# COCO val2017 image used throughout the demo; the bare `im` on the last line
# renders it inline.
image_id = 230983
im = get_image(image_id, train=False)
im

In [None]:
# Box in (ymin, xmin, ymax, xmax) format; you can substitute a custom box.
# NOTE(review): no value derived from `bbox` reaches the model inputs in this
# notebook; it looks left over from the box-conditioned tasks (instance
# segmentation / keypoints) — confirm before relying on it.
bbox = [148.5505782847104, 149.26714806698848, 325.032441173339, 294.49976727245297]

In [None]:
# Display the captioning task config for inspection.
config_cap

In [None]:
def shape_as_list(t):
  """Returns the shape of `t` as a list mixing static ints and dynamic dims.

  Dimensions known statically are returned as Python ints; dimensions that are
  `None` in the static shape are taken from `tf.shape(t)` instead.
  Assumes the rank of `t` is statically known (`as_list()` requires it).

  Args:
    t: a tensor-like object exposing `t.shape.as_list()`.

  Returns:
    A list with one entry per dimension of `t`.
  """
  static_shape = t.shape.as_list()
  if None not in static_shape:
    # Fully static shape: no dynamic-shape op is needed.
    return static_shape
  dynamic_shape = tf.shape(t)
  return [
      dim if dim is not None else dynamic_shape[i]
      for i, dim in enumerate(static_shape)
  ]

def pad_to_max_len(data, max_len, dim, padding_token=0):
  """Right-pads `data` along axis `dim` with `padding_token` up to `max_len`.

  Args:
    data: tensor to pad.
    max_len: target size of axis `dim`.
    dim: axis to pad.
    padding_token: fill value, cast to `data.dtype`.

  Returns:
    `data` concatenated with a block of `padding_token`s along `dim`, reshaped
    so that axis `dim` has static size `max_len`.
  """
  current_shape = shape_as_list(data)
  pad_shape = list(current_shape)
  target_shape = list(current_shape)
  pad_shape[dim] = max_len - current_shape[dim]
  target_shape[dim] = max_len
  fill_value = tf.cast(padding_token, dtype=data.dtype)
  padded = tf.concat([data, tf.fill(pad_shape, fill_value)], axis=dim)
  return tf.reshape(padded, target_shape)

# Debug: tokenize a placeholder caption and pad it along axis 0 to length 1.
# NOTE(review): if "dummy" maps to more than one id, max_len minus the current
# length is negative and tf.fill will likely fail — confirm the tokenizer
# output length. The `caption` computed here is overwritten by a later cell.
caption = task_cap._tokenizer.string_to_ids("dummy")
print(caption)
caption = pad_to_max_len(caption, 1, 0)

In [None]:
# Captioning: build the batch-of-one model input for image `im`.
preprocessed_outputs = get_preprocessed_outputs_captioning(im, image_id, config_cap)

In [None]:
# Inspect the labels dict of the preprocessed (features, labels) pair.
print(preprocessed_outputs[2][1]) #label

In [None]:
# Debug: inspect how the captioning tokenizer handles a placeholder caption.
cap_str = tf.convert_to_tensor(["dummy"], tf.string)
print(cap_str)

print(task_cap._tokenizer)
tokens = task_cap._tokenizer.tokenizer.tokenize(cap_str)
print(tokens)

# Same placeholder string, mapped to ids via the wrapper's string_to_ids.
caption = task_cap._tokenizer.string_to_ids(cap_str)
print(caption)

In [None]:
# Display the captioning task config again (same object as shown earlier).
config_cap

In [None]:
# Run captioning inference; `infer_cap` is a tf.function, traced on first call.
infer_outputs = infer_cap(model, preprocessed_outputs)

In [None]:
# Post-process the inference output and decode the predicted caption.
# `infer_outputs[0][0]` is indexed by key elsewhere in this notebook (it is a
# dict), so star-unpacking it would pass its *keys* as positional arguments;
# unpack the whole infer() output instead.
# NOTE(review): assumes task.infer returns exactly postprocess_tpu's
# positional arguments — confirm against the captioning task.
outputs = task_cap.postprocess_tpu(*infer_outputs)
print_captioning_result(outputs, task_cap._tokenizer.tokenizer)

In [None]:
# Debug: inspect the 'captions' entry of infer_outputs[0][0].
# NOTE(review): get_preprocessed_outputs_captioning does not put a 'captions'
# key in the features it builds; this lookup relies on something downstream
# adding it — confirm, otherwise it raises KeyError.
print((infer_outputs[0][0]['captions']))

In [None]:
# Display the captioning task's tokenizer wrapper.
task_cap._tokenizer

In [None]:
# NOTE: `!cd` runs in a throwaway subshell, so it does not change the kernel's
# working directory; os.getcwd() below still reports the original directory.
# (The eval cell at the end runs `cd pix2seq && ...` inside one shell command.)
!cd pix2seq/
import os
os.getcwd()

In [None]:
# Extra dependency installed before running run.py below; presumably required
# by its evaluation path — confirm.
!pip install tensorflow_gan

In [None]:
!config=configs/config_multi_task.py:captioning@coco/2017_captioning,vit-b
!model_dir=/tmp/pix2seq_eval_cap

!cd pix2seq && PYTHONPATH='./' python3 run.py --config='configs/config_multi_task.py:captioning@coco/2017_captioning,vit-b' --model_dir='/tmp/pix2seq_eval_cap' --mode=eval