#### Copyright 2021 Google LLC.

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hourglass: ImageNet32/64 evaluation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/research/examples/hourglass_downsampled_imagenet.ipynb)

### Install dependencies

In [None]:
!pip install -q --upgrade jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -q --upgrade jax==0.2.21
!pip install -q git+https://github.com/google/trax.git
!pip install -q pickle5
!pip install -q gin

In [None]:
# Execute this for a proper TPU setup!
# Make sure the Colab Runtime is set to Accelerator: TPU.
#
# Points JAX at the Colab TPU via the TPU driver. First it POSTs a request for
# a specific tpu_driver version to the TPU host (the host part of
# COLAB_TPU_ADDR, port 8475). Then it sets JAX's XLA backend flags to reach
# that host over gRPC. os.environ['COLAB_TPU_ADDR'] raises KeyError when the
# runtime has no TPU attached.
import jax
import requests
import os
# `TPU_DRIVER_MODE` is a notebook-global guard: once set, re-running this cell
# skips the version request.
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'
    # NOTE(review): the response status is never checked, so a failed request
    # presumably only surfaces later as a backend error. Consider
    # resp.raise_for_status().
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
# Display the visible devices as a sanity check that the TPU backend is live.
jax.devices()

### Download ImageNet32/64 data

Downloading the datasets for evaluation requires some workarounds because the download URLs built into `tensorflow_datasets` no longer work. The two cells below download data for ImageNet32 and ImageNet64, respectively. Run the one matching the checkpoint you want to evaluate, or run both if you want to execute every evaluation in this notebook.

In [None]:
# Download ImageNet32 data (the url in tfds is down)
!gdown https://drive.google.com/uc?id=1OV4lBnuIcbqeuoiK83jWtlnQ9Afl6Tsr
!tar -zxf /content/im32.tar.gz

# tfds hack for imagenet32
import json
json_path = '/content/content/drive/MyDrive/imagenet/downsampled_imagenet/32x32/2.0.0/dataset_info.json'
with open(json_path, mode='r') as f:
    ds_info = json.load(f)
    if 'moduleName' in ds_info:
        del ds_info['moduleName']
with open(json_path, mode='w') as f:
    json.dump(ds_info, f)

!mkdir -p /root/tensorflow_datasets/downsampled_imagenet/32x32
!cp -r /content/content/drive/MyDrive/imagenet/downsampled_imagenet/32x32/2.0.0 /root/tensorflow_datasets/downsampled_imagenet/32x32

In [None]:
# Download and set up ImageNet64 (validation only) data.
# The archive contains only the validation files (im64_valid).
!gdown https://drive.google.com/uc?id=1ZoI3ZKMUXfrIlqPfIBCcegoe0aJHchpo

!tar -zxf im64_valid.tar.gz
# Place the unpacked files in the tfds directory for
# downsampled_imagenet/64x64, version 2.0.0.
!mkdir -p /root/tensorflow_datasets/downsampled_imagenet/64x64/2.0.0
!cp im64_valid/* /root/tensorflow_datasets/downsampled_imagenet/64x64/2.0.0

In [None]:
# Download gin configs
# NOTE(review): these are fetched from the trax `master` branch rather than a
# pinned commit, so they may drift from the checkpoints evaluated below.
# Re-running this cell does not overwrite existing files. wget saves new
# copies as `*.gin.1`, while `gin.parse_config_file` below keeps reading
# `hourglass_imagenet32.gin` / `hourglass_imagenet64.gin`.
!wget -q https://raw.githubusercontent.com/google/trax/master/trax/supervised/configs/hourglass_imagenet32.gin
!wget -q https://raw.githubusercontent.com/google/trax/master/trax/supervised/configs/hourglass_imagenet64.gin

### Load the ImageNet32 model

This colab can be used to evaluate both ImageNet32 and ImageNet64 models. We start with our ImageNet32 checkpoint.

In [None]:
# `gin` and `trax` are used right away, so import them here. Otherwise the
# cell only works if a later cell has already been run.
import gin
import trax

# Start from an empty gin config so that re-running this cell (e.g. after the
# ImageNet64 section) does not inherit bindings from another config file.
gin.clear_config()
gin.parse_config_file('hourglass_imagenet32.gin')

# Build the Hourglass LM in eval mode and load only the weights from the
# released ImageNet32 checkpoint.
model = trax.models.HourglassLM(mode='eval')
model.init_from_file(
    'gs://trax-ml/hourglass/imagenet32/model_470000.pkl.gz',
    weights_only=True,
)

# Chain model and loss so a single call on (inputs, targets, weights) returns
# the loss. `Accelerate` presumably compiles it for the local devices.
loss_fn = trax.layers.WeightedCategoryCrossEntropy()
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

### Evaluate on the validation set

In [None]:
import gin
import trax

# Here is the hacky part to remove shuffling of the dataset
def get_eval_dataset():
    dataset_name = gin.query_parameter('data_streams.dataset_name')
    data_dir = trax.data.tf_inputs.download_and_prepare(dataset_name, None)

    train_data, eval_data, keys = trax.data.tf_inputs._train_and_eval_dataset(
        dataset_name, data_dir, eval_holdout_size=0)

    bare_preprocess_fn = gin.query_parameter('data_streams.bare_preprocess_fn')

    eval_data = bare_preprocess_fn.scoped_configurable_fn(eval_data, training=False)

    return trax.fastmath.dataset_as_numpy(eval_data)

In [None]:
from trax import fastmath
from trax.fastmath import numpy as jnp
from tqdm import tqdm


def batched_inputs(data_gen, batch_size):
  """Groups `(input, mask)` pairs from `data_gen` into stacked batches.

  Each run of `batch_size` consecutive pairs is handled in one of two ways:
  - If all inputs in the run have the same `len`, it is yielded as a single
    `(jnp.stack(inputs), jnp.stack(masks))` batch.
  - Otherwise its pairs are yielded one by one, unstacked.

  Pairs left over at the end (fewer than `batch_size`) are also yielded one
  by one.
  """
  pending = []  # buffered (input, mask) pairs waiting to form a batch

  for example, example_mask in data_gen:
    pending.append((example, example_mask))
    if len(pending) % batch_size != 0:
      continue

    inputs = [inp for inp, _ in pending]
    masks = [msk for _, msk in pending]
    if len({len(inp) for inp in inputs}) == 1:
      yield jnp.stack(inputs), jnp.stack(masks)
    else:
      # Ragged lengths cannot be stacked; hand them over individually.
      yield from zip(inputs, masks)
    pending = []

  for example, example_mask in pending:
    yield example, example_mask


def run_full_evaluation(accelerated_model_with_loss, examples_data_gen,
                        batch_size, pad_to_len=None):
  """Returns the token-weighted mean loss of a model over a whole dataset.

  Examples are grouped by `batched_inputs`. An example that comes back on its
  own (1-D) is tiled to `batch_size` copies so the model always sees a full
  batch, and only one copy's loss is kept.

  Args:
    accelerated_model_with_loss: callable invoked as `f((inputs, inputs,
      mask))` that returns the loss, e.g. an `Accelerate`-wrapped
      `Serial(model, loss_fn)`. Assumed to return one per-example mean loss
      per device (batch size per device = 1), or a batch-averaged scalar on a
      single device. TODO confirm against the loss layer.
    examples_data_gen: iterable of `(input, mask)` pairs, assumed 1-D arrays.
    batch_size: examples per model call. Must be divisible by the local
      device count and, with more than one device, equal to it.
    pad_to_len: if truthy, inputs and masks are right-padded with zeros along
      axis 1 to at least this length.

  Returns:
    `sum(mask_length * example_loss) / sum(mask)`, i.e. the loss averaged
    over all mask tokens.

  Raises:
    AssertionError: if `batch_size` is incompatible with the device count.
    ValueError: if the data yields no mask tokens at all.
  """
  # Important: we assume batch size per device = 1
  n_devices = fastmath.local_device_count()
  assert batch_size % n_devices == 0
  assert n_devices == 1 or batch_size == n_devices

  loss_sum, n_tokens = 0.0, 0

  def pad_right(inp_tensor):
    """Right-pads axis 1 with zeros up to `pad_to_len` (no-op if unset)."""
    if pad_to_len:
      return jnp.pad(inp_tensor,
                     [[0, 0], [0, max(0, pad_to_len - inp_tensor.shape[1])]])
    return inp_tensor

  def batch_leftover_example(input_example, example_mask):
    """Tiles a single example and its mask to `batch_size` rows."""
    def extend_shape_to_batch_size(tensor):
      return jnp.repeat(tensor, repeats=batch_size, axis=0)

    return map(extend_shape_to_batch_size,
               (input_example[None, ...], example_mask[None, ...]))

  batch_gen = batched_inputs(examples_data_gen, batch_size)

  for i, (inp, mask) in tqdm(enumerate(batch_gen)):
    # A 1-D input was emitted on its own by `batched_inputs` (leftover or
    # ragged-length group); replicate it into a full batch.
    leftover_batch = len(inp.shape) == 1
    if leftover_batch:
      inp, mask = batch_leftover_example(inp, mask)

    inp, mask = map(pad_right, [inp, mask])

    example_losses = accelerated_model_with_loss((inp, inp, mask))

    if leftover_batch:
      # All rows are copies of one example, so keep a single one. If the loss
      # comes back as a 0-d scalar (batch-averaged, single device),
      # `atleast_1d` keeps the slice valid. The mean over identical copies
      # equals that example's loss, so the result is the same either way.
      example_losses = jnp.atleast_1d(example_losses)[:1]
      mask = mask[:1]

    # Weight each example's mean loss by its token count so the final value
    # is a per-token average over the whole dataset.
    example_lengths = mask.sum(axis=-1)

    loss_sum += (example_lengths * example_losses).sum()
    n_tokens += mask.sum()

    if i % 200 == 0:
      print(f'Batches: {i}, current loss: {loss_sum / float(n_tokens)}')

  if float(n_tokens) == 0:
    raise ValueError('examples_data_gen produced no mask tokens to evaluate.')
  return loss_sum / float(n_tokens)

# ImageNet32 evaluation

In [None]:
def data_gen(dataset):
    """Yields `(image, mask)` pairs from `dataset` records.

    `image` is each record's 'image' entry. `mask` is `jnp.ones_like(image)`,
    an all-ones array with the image's shape and dtype, so every position is
    scored.
    """
    images = (record['image'] for record in dataset)
    yield from ((image, jnp.ones_like(image)) for image in images)

# Evaluate the ImageNet32 model on the unshuffled validation split.
# With more than one local device, BATCH_SIZE must equal the device count
# (enforced by the asserts in `run_full_evaluation`).
BATCH_SIZE = 8
eval_data_gen = data_gen(get_eval_dataset())

loss = run_full_evaluation(model_eval, eval_data_gen, BATCH_SIZE)
# `loss` is the mean per-token negative log-likelihood in nats, not a
# perplexity (that would be exp(loss)); dividing by ln(2) gives bits per dim.
print(f'Final loss (nats/dim): {loss}, final bpd: {loss / jnp.log(2)}')

1it [03:59, 239.58s/it]

Batches: 0, current loss: 2.678892135620117


201it [05:37,  2.08it/s]

Batches: 200, current loss: 2.6045994758605957


401it [07:14,  2.05it/s]

Batches: 400, current loss: 2.6076602935791016


601it [08:51,  2.07it/s]

Batches: 600, current loss: 2.596557378768921


801it [10:28,  2.06it/s]

Batches: 800, current loss: 2.5989012718200684


1001it [12:06,  2.03it/s]

Batches: 1000, current loss: 2.5992825031280518


1201it [13:44,  2.02it/s]

Batches: 1200, current loss: 2.5981056690216064


1401it [15:24,  2.05it/s]

Batches: 1400, current loss: 2.596987724304199


1601it [17:01,  2.03it/s]

Batches: 1600, current loss: 2.59686279296875


1801it [18:39,  2.08it/s]

Batches: 1800, current loss: 2.5934829711914062


2001it [20:17,  2.06it/s]

Batches: 2000, current loss: 2.591012716293335


2201it [21:55,  2.03it/s]

Batches: 2200, current loss: 2.5882177352905273


2401it [23:33,  2.03it/s]

Batches: 2400, current loss: 2.5889804363250732


2601it [25:13,  2.04it/s]

Batches: 2600, current loss: 2.591583251953125


2801it [26:52,  2.03it/s]

Batches: 2800, current loss: 2.5910513401031494


3001it [28:30,  2.00it/s]

Batches: 3000, current loss: 2.5904479026794434


3201it [30:08,  2.00it/s]

Batches: 3200, current loss: 2.590895891189575


3401it [31:47,  2.03it/s]

Batches: 3400, current loss: 2.589193105697632


3601it [33:25,  1.99it/s]

Batches: 3600, current loss: 2.5899178981781006


3801it [35:03,  2.01it/s]

Batches: 3800, current loss: 2.5915656089782715


4001it [36:41,  2.04it/s]

Batches: 4000, current loss: 2.591648578643799


4201it [38:20,  2.00it/s]

Batches: 4200, current loss: 2.59226655960083


4401it [40:00,  2.01it/s]

Batches: 4400, current loss: 2.591513156890869


4601it [41:39,  1.99it/s]

Batches: 4600, current loss: 2.591796875


4801it [43:18,  2.02it/s]

Batches: 4800, current loss: 2.5918002128601074


5001it [44:58,  1.98it/s]

Batches: 5000, current loss: 2.5916788578033447


5201it [46:37,  2.02it/s]

Batches: 5200, current loss: 2.5913193225860596


5401it [48:17,  2.02it/s]

Batches: 5400, current loss: 2.591803550720215


5601it [49:55,  1.76it/s]

Batches: 5600, current loss: 2.592107057571411


5801it [51:34,  2.03it/s]

Batches: 5800, current loss: 2.5916154384613037


6001it [53:12,  2.02it/s]

Batches: 6000, current loss: 2.592539072036743


6201it [54:51,  1.99it/s]

Batches: 6200, current loss: 2.592684268951416


6256it [55:18,  1.88it/s]

Final perplexity: 2.5927982330322266, final bpd: 3.740617513656616





# ImageNet64 evaluation

In [None]:
import gin
import trax

# Reset gin first. Otherwise bindings parsed earlier from the ImageNet32
# config would stay active alongside the ImageNet64 ones.
gin.clear_config()
gin.parse_config_file('hourglass_imagenet64.gin')

# Build the Hourglass LM in eval mode and load only the weights from the
# released ImageNet64 checkpoint.
model = trax.models.HourglassLM(mode='eval')
model.init_from_file(
    'gs://trax-ml/hourglass/imagenet64/model_300000.pkl.gz',
    weights_only=True,
)

# Chain model and loss so a single call on (inputs, targets, weights) returns
# the loss. `Accelerate` presumably compiles it for the local devices.
loss_fn = trax.layers.WeightedCategoryCrossEntropy()
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

In [None]:
# Evaluate the ImageNet64 model on the unshuffled validation split.
# With more than one local device, BATCH_SIZE must equal the device count
# (enforced by the asserts in `run_full_evaluation`).
BATCH_SIZE = 8
eval_data_gen = data_gen(get_eval_dataset())

loss = run_full_evaluation(model_eval, eval_data_gen, BATCH_SIZE)
# `loss` is the mean per-token negative log-likelihood in nats, not a
# perplexity (that would be exp(loss)); dividing by ln(2) gives bits per dim.
print(f'Final loss (nats/dim): {loss}, final bpd: {loss / jnp.log(2)}')

1it [02:50, 170.04s/it]

Batches: 0, current loss: 2.3700501918792725


201it [05:45,  1.13it/s]

Batches: 200, current loss: 2.3674206733703613


401it [08:41,  1.14it/s]

Batches: 400, current loss: 2.387157440185547


601it [11:36,  1.14it/s]

Batches: 600, current loss: 2.394706964492798


801it [14:31,  1.13it/s]

Batches: 800, current loss: 2.39194917678833


1001it [17:26,  1.14it/s]

Batches: 1000, current loss: 2.3922457695007324


1201it [20:22,  1.14it/s]

Batches: 1200, current loss: 2.392825126647949


1401it [23:17,  1.12it/s]

Batches: 1400, current loss: 2.392895221710205


1601it [26:13,  1.13it/s]

Batches: 1600, current loss: 2.390683650970459


1801it [29:09,  1.14it/s]

Batches: 1800, current loss: 2.3893117904663086


2001it [32:05,  1.13it/s]

Batches: 2000, current loss: 2.3903186321258545


2201it [35:00,  1.14it/s]

Batches: 2200, current loss: 2.391557216644287


2401it [37:55,  1.14it/s]

Batches: 2400, current loss: 2.389714241027832


2601it [40:51,  1.14it/s]

Batches: 2600, current loss: 2.3874661922454834


2801it [43:45,  1.14it/s]

Batches: 2800, current loss: 2.387363910675049


3001it [46:41,  1.15it/s]

Batches: 3000, current loss: 2.3869481086730957


3201it [49:36,  1.14it/s]

Batches: 3200, current loss: 2.3882715702056885


3401it [52:32,  1.15it/s]

Batches: 3400, current loss: 2.3873984813690186


3601it [55:27,  1.14it/s]

Batches: 3600, current loss: 2.3865926265716553


3801it [58:22,  1.13it/s]

Batches: 3800, current loss: 2.3868348598480225


4001it [1:01:18,  1.13it/s]

Batches: 4000, current loss: 2.387111186981201


4201it [1:04:14,  1.14it/s]

Batches: 4200, current loss: 2.3866043090820312


4401it [1:07:09,  1.14it/s]

Batches: 4400, current loss: 2.38643479347229


4601it [1:10:05,  1.15it/s]

Batches: 4600, current loss: 2.387312889099121


4801it [1:13:01,  1.14it/s]

Batches: 4800, current loss: 2.387418031692505


5001it [1:15:57,  1.14it/s]

Batches: 5000, current loss: 2.3877131938934326


5201it [1:18:53,  1.13it/s]

Batches: 5200, current loss: 2.3879053592681885


5401it [1:21:48,  1.14it/s]

Batches: 5400, current loss: 2.387753963470459


5601it [1:24:44,  1.12it/s]

Batches: 5600, current loss: 2.387204885482788


5801it [1:27:40,  1.13it/s]

Batches: 5800, current loss: 2.386263847351074


6001it [1:30:36,  1.12it/s]

Batches: 6000, current loss: 2.386061906814575


6201it [1:33:32,  1.13it/s]

Batches: 6200, current loss: 2.3861241340637207


6256it [1:34:20,  1.11it/s]

Final perplexity: 2.386467456817627, final bpd: 3.4429450035095215



