Copyright 2021 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### Composable Augmentation Encoding for Video Representation Learning (CATE)

[arXiv](https://arxiv.org/abs/2104.00616)

[Project page](https://sites.google.com/corp/brown.edu/cate-iccv2021/)

This colab demonstrates how to load pretrained CATE models from hub modules and run inference on video frames. It also includes an example nearest-neighbor classification experiment on the UCF-101 dataset.

The checkpoints are available in the following Google Cloud Storage directories:

  - gs://gresearch/cate-iccv2021/kinetics400/
  
  - gs://gresearch/cate-iccv2021/something_v1/

  - gs://gresearch/cate-iccv2021/something_v2/

In [2]:
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
import tensorflow_hub as hub
import tensorflow_datasets as tfds

In [3]:
#@title Define video preprocessing functions for inference.

def sample_linspace_frames(frames, num_frames, num_windows):
  """Gathers `num_windows` evenly spaced windows of `num_frames` frames.

  Args:
    frames: Tensor whose first dimension indexes the frames of one video.
    num_frames: Python int, number of consecutive frames per window.
    num_windows: Python int, number of windows to extract.

  Returns:
    A tensor with `num_windows * num_frames` entries along the first
    dimension: the windows concatenated in order of increasing start offset.
  """
  frame_count = tf.shape(frames)[0]

  # Repeat the index sequence often enough that even a video shorter than
  # `num_frames` yields at least `num_frames` indices (the video loops).
  tile_factor = tf.cast(
      tf.ceil(tf.cast(num_frames, tf.float32) /
              tf.cast(frame_count, tf.float32)),
      tf.int32)
  frame_ids = tf.tile(tf.range(frame_count), [tile_factor])

  # Window start offsets are spread linearly between the first frame and the
  # last position at which a full window still fits, then truncated to int.
  last_start = tf.cast(
      tf.maximum(frame_count, num_frames) - num_frames, tf.float32)
  window_starts = tf.cast(
      tf.linspace(0.0, last_start, num_windows), tf.int32)

  windows = [
      tf.slice(frame_ids, [window_starts[w]], [num_frames])
      for w in range(num_windows)
  ]
  return tf.gather(frames, tf.concat(windows, axis=0))

def _compute_crop_shape(
    image_height, image_width, aspect_ratio, crop_proportion):
  """Computes the (height, width) of a crop with a requested aspect ratio.

  Args:
    image_height: Scalar giving the height of the source image.
    image_width: Scalar giving the width of the source image.
    aspect_ratio: Python float. It is compared against
      `image_width / image_height`, i.e. it is treated as width / height.
    crop_proportion: Python float, fraction of the limiting image dimension
      that the crop should cover.

  Returns:
    A `(crop_height, crop_width)` pair of int32 scalar tensors.
  """
  height_f = tf.cast(image_height, tf.float32)
  width_f = tf.cast(image_width, tf.float32)

  def _round_to_int(value):
    # Round to the nearest integer before casting to int32.
    return tf.cast(tf.rint(value), tf.int32)

  def _limited_by_width():
    # Requested ratio is wider than the image: the width is the constraint.
    return (_round_to_int(crop_proportion / aspect_ratio * width_f),
            _round_to_int(crop_proportion * width_f))

  def _limited_by_height():
    # Image is at least as wide as requested: the height is the constraint.
    return (_round_to_int(crop_proportion * height_f),
            _round_to_int(crop_proportion * aspect_ratio * height_f))

  requested_is_wider = aspect_ratio > width_f / height_f
  return tf.cond(requested_is_wider, _limited_by_width, _limited_by_height)

def center_crop(images, height, width, crop_proportion):
  """Crops the center of a batch of images and resizes it bicubically.

  Args:
    images: 4-D tensor; dimensions 1 and 2 are read as image height and
      width.
    height: Python int, output height.
    width: Python int, output width.
    crop_proportion: Python float, proportion of the image to keep.

  Returns:
    The cropped images resized to `[height, width]`.
  """
  # NOTE(review): `height / width` is handed to a helper that compares its
  # argument against image_width / image_height. The two conventions agree
  # only for square targets (as used in this colab) -- confirm before
  # requesting a non-square output.
  dims = tf.shape(images)
  src_height, src_width = dims[1], dims[2]
  crop_height, crop_width = _compute_crop_shape(
      src_height, src_width, height / width, crop_proportion)

  # The "+ 1" rounds the offset up when the margin is odd.
  top = ((src_height - crop_height) + 1) // 2
  left = ((src_width - crop_width) + 1) // 2

  cropped = tf.image.crop_to_bounding_box(
      images, top, left, crop_height, crop_width)
  return tf.image.resize_bicubic(cropped, [height, width])

def preprocess_video(video, num_frames, height, width, num_windows):
  """Turns one raw video into a stack of fixed-size clips for inference.

  Args:
    video: Tensor of frames for one video (frames along the first dimension).
    num_frames: Python int, frames per clip.
    height: Python int, output frame height.
    width: Python int, output frame width.
    num_windows: Python int, number of clips sampled from the video.

  Returns:
    A float32 tensor of shape `[num_windows, num_frames, height, width, 3]`
    with values clipped to [0, 1].
  """
  clips = sample_linspace_frames(video, num_frames, num_windows)
  clips = tf.image.convert_image_dtype(clips, dtype=tf.float32)
  cropped = center_crop(clips, height, width, crop_proportion=0.875)
  # Presumably removes values that bicubic resizing pushed outside [0, 1].
  cropped = tf.clip_by_value(cropped, 0., 1.)
  return tf.reshape(cropped, [num_windows, num_frames, height, width, 3])

In [4]:
#@title Load tfhub checkpoint from Cloud storage.

# Directory of the hub module to load (Kinetics-400 checkpoint).
hub_path = 'gs://gresearch/cate-iccv2021/kinetics400/'
# The module is used for inference only, so its variables are not trainable.
module = hub.Module(hub_path, trainable=False)

# A single session is shared by all later cells.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [5]:
#@title Set up data loader for UCF-101.

# Averages the representation of 8 sliding windows of 32 frames.
num_windows = 8
num_frames = 32
dataset_name = 'ucf101/ucf101_1'

train_dataset, train_info = tfds.load(dataset_name, split='train', with_info=True)
test_dataset, test_info = tfds.load(dataset_name, split='test', with_info=True)

num_train_examples = train_info.splits['train'].num_examples
num_test_examples = test_info.splits['test'].num_examples
num_classes = train_info.features['label'].num_classes


def _preprocess(x):
  """Replaces the raw 'video' entry of an example with 224x224 clips."""
  x['video'] = preprocess_video(x['video'], num_frames, 224, 224, num_windows)
  return x


def _next_example_tensors(dataset):
  """Returns the `get_next()` tensors of a one-shot iterator over `dataset`.

  Every example is preprocessed and wrapped in a batch of size one.
  """
  batched = dataset.map(_preprocess).batch(1)
  return tf.data.make_one_shot_iterator(batched).get_next()


# Each evaluation of these tensors in a session yields the next example.
x_train = _next_example_tensors(train_dataset)
x_test = _next_example_tensors(test_dataset)

print(num_train_examples)
print(num_test_examples)

9537
3783


In [6]:
#@title Compute the final activations from a pretrained ResNet-3D-50 network.

# Collapse the size-one batch dimension so the module receives one clip per
# window: [num_windows, num_frames, 224, 224, 3].
x_train['video'] = tf.reshape(x_train['video'],
                              [num_windows, num_frames, 224, 224, 3])
output_train = module(inputs=x_train['video'], signature='default', as_dict=True)

# Fetch only what is stored. The preprocessed video is still computed inside
# the graph as the module input, but it is no longer copied back to Python on
# every step (it was fetched and then discarded). Label and features come
# from the same `sess.run` call, so they belong to the same example.
train_fetches = (x_train['label'], output_train['hiddens'])

train_data = []
for i in range(num_train_examples):
  label, hiddens = sess.run(train_fetches)
  train_data.append((label, hiddens))
  if i % 100 == 0:
    print('%d out of %d examples processed.' % (i, num_train_examples))

0 out of 9537 examples processed.
100 out of 9537 examples processed.
200 out of 9537 examples processed.
300 out of 9537 examples processed.
400 out of 9537 examples processed.
500 out of 9537 examples processed.
600 out of 9537 examples processed.
700 out of 9537 examples processed.
800 out of 9537 examples processed.
900 out of 9537 examples processed.
1000 out of 9537 examples processed.
1100 out of 9537 examples processed.
1200 out of 9537 examples processed.
1300 out of 9537 examples processed.
1400 out of 9537 examples processed.
1500 out of 9537 examples processed.
1600 out of 9537 examples processed.
1700 out of 9537 examples processed.
1800 out of 9537 examples processed.
1900 out of 9537 examples processed.
2000 out of 9537 examples processed.
2100 out of 9537 examples processed.
2200 out of 9537 examples processed.
2300 out of 9537 examples processed.
2400 out of 9537 examples processed.
2500 out of 9537 examples processed.
2600 out of 9537 examples processed.
2700 out of 9

In [7]:
# Collapse the size-one batch dimension so the module receives one clip per
# window: [num_windows, num_frames, 224, 224, 3].
x_test['video'] = tf.reshape(x_test['video'],
                             [num_windows, num_frames, 224, 224, 3])
output_test = module(inputs=x_test['video'], signature='default', as_dict=True)

# Fetch only what is stored. The preprocessed video is still computed inside
# the graph as the module input, but it is no longer copied back to Python on
# every step (it was fetched and then discarded). Label and features come
# from the same `sess.run` call, so they belong to the same example.
test_fetches = (x_test['label'], output_test['hiddens'])

test_data = []
for i in range(num_test_examples):
  label, hiddens = sess.run(test_fetches)
  test_data.append((label, hiddens))
  if i % 100 == 0:
    print('%d out of %d examples processed.' % (i, num_test_examples))

0 out of 3783 examples processed.
100 out of 3783 examples processed.
200 out of 3783 examples processed.
300 out of 3783 examples processed.
400 out of 3783 examples processed.
500 out of 3783 examples processed.
600 out of 3783 examples processed.
700 out of 3783 examples processed.
800 out of 3783 examples processed.
900 out of 3783 examples processed.
1000 out of 3783 examples processed.
1100 out of 3783 examples processed.
1200 out of 3783 examples processed.
1300 out of 3783 examples processed.
1400 out of 3783 examples processed.
1500 out of 3783 examples processed.
1600 out of 3783 examples processed.
1700 out of 3783 examples processed.
1800 out of 3783 examples processed.
1900 out of 3783 examples processed.
2000 out of 3783 examples processed.
2100 out of 3783 examples processed.
2200 out of 3783 examples processed.
2300 out of 3783 examples processed.
2400 out of 3783 examples processed.
2500 out of 3783 examples processed.
2600 out of 3783 examples processed.
2700 out of 3

In [8]:
#@title Run nearest neighbor classification.

# We follow the standard setup to evaluate nearest neighbor video retrieval:
# For each video in the test set, we query its K nearest neighbors from the
# training set. If any label of the retrieved training examples matches the
# ground truth label of the query example, we deem it a correct match.

def prepare_knn_data(d, windows_per_video=None, feature_dim=2048):
  """Builds L2-normalized per-video features and a flat label vector.

  Args:
    d: Sequence of `(label, hiddens)` entries, one per video. `hiddens` is
      assumed to hold `windows_per_video * feature_dim` values per video
      (one feature vector per window) -- verify against the module output.
    windows_per_video: Number of windows whose features are averaged for
      each video. Defaults to the notebook-level `num_windows`.
    feature_dim: Size of one feature vector. Defaults to 2048.

  Returns:
    A tuple `(x, y)`: `x` has one unit-norm row of length `feature_dim` per
    video, and `y` is the 1-D array of the corresponding labels.
  """
  if windows_per_video is None:
    windows_per_video = num_windows

  features = [entry[1] for entry in d]
  labels = [entry[0] for entry in d]

  x = np.concatenate(features, 0).reshape(
      (-1, windows_per_video, feature_dim))
  # Average pool the features from the windows of each video.
  x = np.mean(x, 1)
  # L2 normalize the features. The lower bound on the norm keeps an all-zero
  # feature row at zero instead of turning it into NaN.
  norms = np.linalg.norm(x, axis=1, keepdims=True)
  x = x / np.maximum(norms, np.finfo(x.dtype).tiny)
  y = np.concatenate(labels, 0).reshape((-1))
  return x, y

def compute_cosine_dist(x_train, x_test):
  """Returns the cosine distance from every test row to every train row.

  The result has one row per entry of `x_test` and one column per entry of
  `x_train`.
  """
  similarity = cosine_similarity(x_test, x_train)
  return 1 - similarity

def knn_evaluation(dist, y_train, y_test, k):
  """Prints and returns the top-k nearest neighbor recall.

  A test example counts as a hit when at least one of its `k` nearest
  training examples (smallest distance) has the same label.

  Args:
    dist: Array where `dist[i][j]` is the distance from test example `i` to
      training example `j`.
    y_train: Labels of the training examples, indexed like the columns of
      `dist`.
    y_test: Labels of the test examples, indexed like the rows of `dist`.
    k: Number of nearest neighbors to inspect. Values larger than the number
      of training examples are clamped instead of raising IndexError.

  Returns:
    The fraction of test examples that were hit, as a float.

  Raises:
    ZeroDivisionError: If `dist` has no rows.
  """
  y_train = np.asarray(y_train)
  num_samples = dist.shape[0]
  # Never look at more neighbors than there are training examples.
  top_k = min(k, y_train.shape[0])
  hit = 0
  for i in range(num_samples):
    nearest = np.argsort(dist[i])[:top_k]
    if np.any(y_train[nearest] == y_test[i]):
      hit += 1
  recall = hit / num_samples
  print('top %d recall: %f' % (k, recall))
  return recall

# NOTE: from here on `x_train` / `x_test` hold NumPy feature matrices instead
# of the dataset tensors created by the data loader cell, so the feature
# extraction cells cannot be re-run after this cell without re-running the
# data loader cell first.
x_train, y_train = prepare_knn_data(train_data)
x_test, y_test = prepare_knn_data(test_data)
dist = compute_cosine_dist(x_train, x_test)

# Report recall for an increasing number of retrieved neighbors.
for k in (1, 5, 10, 20, 50):
  knn_evaluation(dist, y_train, y_test, k)

top 1 recall: 0.548771
top 5 recall: 0.683320
top 10 recall: 0.750991
top 20 recall: 0.822892
top 50 recall: 0.898758
