#### Copyright 2021 Google LLC.

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hourglass: ImageNet32/64 evaluation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/research/examples/hourglass_downsampled_imagenet.ipynb)

### Install dependencies

In [None]:
!pip install -q --upgrade jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -q --upgrade jax==0.2.21
!pip install -q git+https://github.com/google/trax.git
!pip install -q pickle5
!pip install -q gin

In [None]:
# Execute this for a proper TPU setup!
# Make sure the Colab Runtime is set to Accelerator: TPU.
# Points JAX at the Colab TPU through the "tpu_driver" backend and shows the
# devices JAX sees. os.environ['COLAB_TPU_ADDR'] raises KeyError when the
# runtime has no TPU attached.
import jax
import requests
import os
# Presumably asks the TPU host (port 8475) to switch to the
# tpu_driver0.1-dev20200416 driver version. TPU_DRIVER_MODE is defined
# afterwards, so re-running this cell skips the request.
# NOTE(review): the POST response (`resp`) is never checked, so a failed
# request goes unnoticed here.
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
# NOTE(review): presumably these flags must be set before JAX initializes a
# backend (i.e. before the jax.devices() call below) — keep this order.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# Print the backend target; the cell's output is the device list.
print(config.FLAGS.jax_backend_target)
jax.devices()

### Download ImageNet32/64 data

Downloading the datasets for evaluation requires some workarounds because the download URLs used by `tensorflow_datasets` no longer work. The two cells below download the data for ImageNet32 and ImageNet64, respectively; run the one that matches the checkpoint you want to evaluate (or both, to run the whole notebook, which evaluates both checkpoints).

In [None]:
# Download ImageNet32 data (the url in tfds is down)
!gdown https://drive.google.com/uc?id=1OV4lBnuIcbqeuoiK83jWtlnQ9Afl6Tsr
!tar -zxf /content/im32.tar.gz

# tfds hack for imagenet32
import json
json_path = '/content/content/drive/MyDrive/imagenet/downsampled_imagenet/32x32/2.0.0/dataset_info.json'
with open(json_path, mode='r') as f:
    ds_info = json.load(f)
    if 'moduleName' in ds_info:
        del ds_info['moduleName']
with open(json_path, mode='w') as f:
    json.dump(ds_info, f)

!mkdir -p /root/tensorflow_datasets/downsampled_imagenet/32x32
!cp -r /content/content/drive/MyDrive/imagenet/downsampled_imagenet/32x32/2.0.0 /root/tensorflow_datasets/downsampled_imagenet/32x32

In [None]:
# Download and set up ImageNet64 (validation split only) data.
!gdown https://drive.google.com/uc?id=1ZoI3ZKMUXfrIlqPfIBCcegoe0aJHchpo

# Unpack the archive and copy its files under
# /root/tensorflow_datasets/downsampled_imagenet/64x64/2.0.0, the 64x64
# counterpart of where the ImageNet32 files are copied.
# NOTE(review): the relative paths assume the current directory is where
# gdown saved the archive (/content on Colab) — confirm if run elsewhere.
!tar -zxf im64_valid.tar.gz
!mkdir -p /root/tensorflow_datasets/downsampled_imagenet/64x64/2.0.0
!cp im64_valid/* /root/tensorflow_datasets/downsampled_imagenet/64x64/2.0.0

In [None]:
# Download gin configs
!wget -q https://raw.githubusercontent.com/google/trax/master/trax/supervised/configs/hourglass_imagenet32.gin
!wget -q https://raw.githubusercontent.com/google/trax/master/trax/supervised/configs/hourglass_imagenet64.gin

### Load the ImageNet32 model

This Colab can evaluate both the ImageNet32 and the ImageNet64 model. We start with our ImageNet32 checkpoint; the ImageNet64 checkpoint is loaded further below.

In [None]:
import gin
import trax  # needed here: on a fresh kernel trax is not imported by earlier cells

# Start from an empty gin config so bindings from a previously parsed config
# (e.g. re-running this cell after the ImageNet64 one) cannot leak in.
gin.clear_config()
gin.parse_config_file('hourglass_imagenet32.gin')

# Build the model in eval mode and initialize it from the checkpoint;
# weights_only=True initializes only the weights (see trax Layer.init_from_file).
model = trax.models.HourglassLM(mode='eval')
model.init_from_file(
    'gs://trax-ml/hourglass/imagenet32/model_470000.pkl.gz',
    weights_only=True,
)

# Chain the model with the weighted cross-entropy loss and wrap the pair in
# trax.layers.Accelerate so it runs on the available accelerator devices.
loss_fn = trax.layers.WeightedCategoryCrossEntropy()
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

### Evaluate on the validation set

In [None]:
import gin
import trax

# Here is the hacky part: build the eval split by hand to remove shuffling of
# the dataset.
def get_eval_dataset(data_dir=None):
    """Returns the gin-configured evaluation split as a numpy iterator.

    Reads `data_streams.dataset_name` and `data_streams.bare_preprocess_fn`
    from the currently parsed gin config and assembles the eval split directly,
    bypassing the regular data stream so that examples are not shuffled.

    Args:
      data_dir: directory handed to `download_and_prepare`; the default None
        keeps the original behaviour.

    Returns:
      `trax.fastmath.dataset_as_numpy` applied to the preprocessed eval split.
    """
    dataset_name = gin.query_parameter('data_streams.dataset_name')
    data_dir = trax.data.tf_inputs.download_and_prepare(dataset_name, data_dir)

    # Private trax helper. Only the eval split is needed; the train split and
    # keys are discarded. NOTE(review): eval_holdout_size=0 presumably selects
    # the dataset's own eval split rather than a holdout carved from train.
    _, eval_data, _ = trax.data.tf_inputs._train_and_eval_dataset(
        dataset_name, data_dir, eval_holdout_size=0)

    # query_parameter yields gin's reference to the configured preprocessing
    # function; scoped_configurable_fn is presumably that function with its
    # gin bindings applied.
    bare_preprocess_fn = gin.query_parameter('data_streams.bare_preprocess_fn')
    eval_data = bare_preprocess_fn.scoped_configurable_fn(eval_data, training=False)

    return trax.fastmath.dataset_as_numpy(eval_data)

In [None]:
from trax import fastmath
from trax.fastmath import numpy as jnp
from tqdm import tqdm


def batched_inputs(data_gen, batch_size):
  """Groups (input, mask) pairs into stacked batches of `batch_size`.

  Examples are collected `batch_size` at a time. When every input and every
  mask in a group has the same shape, the group is yielded as one
  (input_batch, mask_batch) pair stacked along a new leading axis. Otherwise
  the group's examples are yielded one by one, unstacked. The final partial
  group is likewise yielded unstacked, so callers can tell single examples
  apart from batches by rank.

  Args:
    data_gen: iterable of (input_example, mask) pairs.
    batch_size: number of examples per stacked batch; must be positive.

  Yields:
    Stacked (input_batch, mask_batch) pairs, or individual (input, mask) pairs.

  Raises:
    ValueError: if batch_size is not positive (raised on first iteration).
  """
  if batch_size <= 0:
    raise ValueError(f'batch_size must be positive, got {batch_size}')

  inp_stack, mask_stack = [], []

  for input_example, mask in data_gen:
    inp_stack.append(input_example)
    mask_stack.append(mask)
    if len(inp_stack) < batch_size:
      continue
    # Compare full shapes (inputs and masks), not just len(): examples with
    # equal leading but different trailing dims cannot be stacked.
    shapes = {(jnp.shape(x), jnp.shape(m)) for x, m in zip(inp_stack, mask_stack)}
    if len(shapes) > 1:
      yield from zip(inp_stack, mask_stack)
    else:
      yield jnp.stack(inp_stack), jnp.stack(mask_stack)
    inp_stack, mask_stack = [], []

  # Leftover examples that do not fill a whole batch are yielded unstacked.
  yield from zip(inp_stack, mask_stack)


def run_full_evaluation(accelerated_model_with_loss, examples_data_gen,
                        batch_size, pad_to_len=None):
  """Computes the mask-weighted mean loss of a model over a whole dataset.

  Args:
    accelerated_model_with_loss: callable applied to (inputs, targets, weights)
      that returns loss value(s), e.g. Accelerate(Serial(model, loss)).
      NOTE(review): with several local devices this relies on the layer
      returning one mean loss per device (hence batch size per device = 1);
      with one device it may return a scalar mean over the batch — confirm
      against trax.layers.Accelerate.
    examples_data_gen: iterable of (input_example, mask) pairs.
    batch_size: number of examples per model call.
    pad_to_len: if set, inputs and masks are right-padded with zeros along
      axis 1 to at least this length.

  Returns:
    Sum over examples of (loss * number of mask tokens), divided by the total
    number of mask tokens, as a Python float.

  Raises:
    ValueError: if no tokens were evaluated (empty data or all-zero masks).
  """
  n_devices = fastmath.local_device_count()
  # Important: we assume batch size per device = 1, so each loss value the
  # accelerated layer returns belongs to exactly one example.
  assert batch_size % n_devices == 0, (batch_size, n_devices)
  assert n_devices == 1 or batch_size == n_devices, (batch_size, n_devices)

  loss_sum, n_tokens = 0.0, 0

  def pad_right(inp_tensor):
    # Right-pads axis 1 with zeros up to pad_to_len (no-op when already longer).
    if pad_to_len:
      return jnp.pad(inp_tensor,
                     [[0, 0], [0, max(0, pad_to_len - inp_tensor.shape[1])]])
    return inp_tensor

  def batch_leftover_example(input_example, example_mask):
    # Tiles one unbatched example batch_size times so the model always
    # receives a full batch.
    return (jnp.repeat(input_example[None, ...], repeats=batch_size, axis=0),
            jnp.repeat(example_mask[None, ...], repeats=batch_size, axis=0))

  batch_gen = batched_inputs(examples_data_gen, batch_size)

  for i, (inp, mask) in enumerate(tqdm(batch_gen)):
    # Rank-1 items are single examples (ragged groups or the final partial
    # group), not stacked batches.
    leftover_batch = len(inp.shape) == 1
    if leftover_batch:
      inp, mask = batch_leftover_example(inp, mask)

    inp, mask = pad_right(inp), pad_right(mask)

    example_losses = accelerated_model_with_loss((inp, inp, mask))

    if leftover_batch:
      # Only the first copy is real. atleast_1d keeps the slice valid when the
      # layer returns a scalar (single device): the mean over identical copies
      # equals the example's own loss.
      example_losses = jnp.atleast_1d(example_losses)[:1]
      mask = mask[:1]

    example_lengths = mask.sum(axis=-1)

    # Accumulate on the host in float64 / exact ints rather than float32 /
    # int32 device arrays, so precision does not drift over many batches.
    loss_sum += float((example_lengths * example_losses).sum())
    n_tokens += int(mask.sum())

    if i % 200 == 0 and n_tokens:
      tqdm.write(f'Batches: {i}, current loss: {loss_sum / float(n_tokens)}')

  if not n_tokens:
    raise ValueError(
        'No tokens were evaluated: the data is empty or every mask is zero.')
  return loss_sum / float(n_tokens)

### ImageNet32 evaluation

In [None]:
def data_gen(dataset):
    """Yields (image, mask) pairs whose mask is all ones, so every position
    counts in the loss.

    Args:
      dataset: iterable of dicts that each hold an 'image' entry (e.g. the
        iterator returned by get_eval_dataset()).
    """
    for record in dataset:
        image = record['image']
        yield image, jnp.ones_like(image)

# With several local devices, run_full_evaluation requires BATCH_SIZE to equal
# the device count (one example per device, e.g. 8 TPU cores).
BATCH_SIZE = 8
eval_data_gen = data_gen(get_eval_dataset())

loss = run_full_evaluation(model_eval, eval_data_gen, BATCH_SIZE)
# `loss` is the mask-weighted mean cross-entropy per token (dimension) in nats,
# not a perplexity: perplexity is exp(loss) and bits/dim is loss / ln(2).
print(f'Final loss (nats/dim): {loss}, final perplexity: {jnp.exp(loss)}, '
      f'final bpd: {loss / jnp.log(2)}')

### ImageNet64 evaluation

In [None]:
import gin
import trax

# Reset gin before parsing so no binding from the ImageNet32 config parsed
# earlier can carry over into the ImageNet64 setup.
gin.clear_config()
gin.parse_config_file('hourglass_imagenet64.gin')

# Build the model in eval mode and initialize it from the checkpoint;
# weights_only=True initializes only the weights (see trax Layer.init_from_file).
model = trax.models.HourglassLM(mode='eval')
model.init_from_file(
    'gs://trax-ml/hourglass/imagenet64/model_300000.pkl.gz',
    weights_only=True,
)

# Chain the model with the weighted cross-entropy loss and wrap the pair in
# trax.layers.Accelerate so it runs on the available accelerator devices.
loss_fn = trax.layers.WeightedCategoryCrossEntropy()
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

In [None]:
# With several local devices, run_full_evaluation requires BATCH_SIZE to equal
# the device count (one example per device, e.g. 8 TPU cores).
BATCH_SIZE = 8
eval_data_gen = data_gen(get_eval_dataset())

loss = run_full_evaluation(model_eval, eval_data_gen, BATCH_SIZE)
# `loss` is the mask-weighted mean cross-entropy per token (dimension) in nats,
# not a perplexity: perplexity is exp(loss) and bits/dim is loss / ln(2).
print(f'Final loss (nats/dim): {loss}, final perplexity: {jnp.exp(loss)}, '
      f'final bpd: {loss / jnp.log(2)}')