In [21]:
!pip install git+https://github.com/deepmind/dm-haiku
!pip install optax
#!pip install jax==0.2.11
!pip install dm-acme
!dm-tree
#!pip install jax
#!pip install -U jaxlib
#numpy>=1.16

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-25ci6yfa
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-25ci6yfa
/bin/bash: dm-tree: command not found


# Utils

In [22]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions."""

from typing import Optional, Text
from absl import logging
import jax
import jax.numpy as jnp


def topk_accuracy(
    logits: jnp.ndarray,
    labels: jnp.ndarray,
    topk: int,
    ignore_label_above: Optional[int] = None,
) -> jnp.ndarray:
  """Top-num_codes accuracy."""
  assert len(labels.shape) == 1, 'topk expects 1d int labels.'
  assert len(logits.shape) == 2, 'topk expects 2d logits.'

  if ignore_label_above is not None:
    logits = logits[labels < ignore_label_above, :]
    labels = labels[labels < ignore_label_above]

  prds = jnp.argsort(logits, axis=1)[:, ::-1]
  prds = prds[:, :topk]
  total = jnp.any(prds == jnp.tile(labels[:, jnp.newaxis], [1, topk]), axis=1)

  return total


def softmax_cross_entropy(
    logits: jnp.ndarray,
    labels: jnp.ndarray,
    reduction: Optional[Text] = 'mean',
) -> jnp.ndarray:
  """Computes softmax cross entropy given logits and one-hot class labels.

  Args:
    logits: Logit output values.
    labels: Ground truth one-hot-encoded labels.
    reduction: Type of reduction to apply to loss.

  Returns:
    Loss value. If `reduction` is `none`, this has the same shape as `labels`;
    otherwise, it is scalar.

  Raises:
    ValueError: If the type of `reduction` is unsupported.
  """
  loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
  if reduction == 'sum':
    return jnp.sum(loss)
  elif reduction == 'mean':
    return jnp.mean(loss)
  elif reduction == 'none' or reduction is None:
    return loss
  else:
    raise ValueError(f'Incorrect reduction mode {reduction}')


def l2_normalize(
    x: jnp.ndarray,
    axis: Optional[int] = None,
    epsilon: float = 1e-12,
) -> jnp.ndarray:
  """l2 normalize a tensor on an axis with numerical stability."""
  square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
  x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
  return x * x_inv_norm


def l2_weight_regularizer(params):
  """Helper to do lasso on weights.

  Args:
    params: the entire param set.

  Returns:
    Scalar of the l2 norm of the weights.
  """
  l2_norm = 0.
  for mod_name, mod_params in params.items():
    if 'norm' not in mod_name:
      for param_k, param_v in mod_params.items():
        if param_k != 'b' not in param_k:  # Filter out biases
          l2_norm += jnp.sum(jnp.square(param_v))
        else:
          logging.warning('Excluding %s/%s from optimizer weight decay!',
                          mod_name, param_k)
    else:
      logging.warning('Excluding %s from optimizer weight decay!', mod_name)

  return 0.5 * l2_norm


def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Byol's regression loss. This is a simple cosine similarity."""
  normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
  return jnp.sum((normed_x - normed_y)**2, axis=-1)


def bcast_local_devices(value):
  """Broadcasts an object to all local devices."""
  devices = jax.local_devices()

  def _replicate(x):
    """Replicate an object on each device."""
    x = jnp.array(x)
    return jax.device_put_sharded(len(devices) * [x], devices)

  return jax.tree_util.tree_map(_replicate, value)


def get_first(xs):
  """Gets values from the first device."""
  return jax.tree_map(lambda x: x[0], xs)


In [23]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing and augmentation."""

import functools
from typing import Any, Mapping, Text

import jax
import jax.numpy as jnp

# typing
JaxBatch = Mapping[Text, jnp.ndarray]
ConfigDict = Mapping[Text, Any]

augment_config = dict(
    view1=dict(
        random_flip=True,  # Random left/right flip
        color_transform=dict(
            apply_prob=1.0,
            # Range of jittering
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
            # Probability of applying color jittering
            color_jitter_prob=0.8,
            # Probability of converting to grayscale
            to_grayscale_prob=0.2,
            # Shuffle the order of color transforms
            shuffle=True),
        gaussian_blur=dict(
            apply_prob=1.0,
            # Kernel size ~ image_size / blur_divider
            blur_divider=10.,
            # Kernel distribution
            sigma_min=0.1,
            sigma_max=2.0),
        solarize=dict(apply_prob=0.0, threshold=0.5),
    ),
    view2=dict(
        random_flip=True,
        color_transform=dict(
            apply_prob=1.0,
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
            color_jitter_prob=0.8,
            to_grayscale_prob=0.2,
            shuffle=True),
        gaussian_blur=dict(
            apply_prob=0.1, blur_divider=10., sigma_min=0.1, sigma_max=2.0),
        solarize=dict(apply_prob=0.2, threshold=0.5),
    ))


def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
  """Apply the image augmentations to crops in inputs (view1 and view2)."""

  def _postprocess_image(
      images: jnp.ndarray,
      rng: jnp.ndarray,
      presets: ConfigDict,
  ) -> JaxBatch:
    """Applies augmentations in post-processing.

    Args:
      images: an NHWC tensor (with C=3), with float values in [0, 1].
      rng: a single PRNGKey.
      presets: a dict of presets for the augmentations.

    Returns:
      A batch of augmented images with shape NHWC, with keys view1, view2
      and labels.
    """
    flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
    out = images
    if presets['random_flip']:
      out = random_flip(out, flip_rng)
    if presets['color_transform']['apply_prob'] > 0:
      out = color_transform(out, color_rng, **presets['color_transform'])
    if presets['gaussian_blur']['apply_prob'] > 0:
      out = gaussian_blur(out, blur_rng, **presets['gaussian_blur'])
    if presets['solarize']['apply_prob'] > 0:
      out = solarize(out, solarize_rng, **presets['solarize'])
    out = jnp.clip(out, 0., 1.)
    return jax.lax.stop_gradient(out)

  rng1, rng2 = jax.random.split(rng, num=2)
  view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1'])
  view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2'])
  return dict(view1=view1, view2=view2, labels=inputs['labels'])


def _maybe_apply(apply_fn, inputs, rng, apply_prob):
  should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
  return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)


def _depthwise_conv2d(inputs, kernel, strides, padding):
  """Computes a depthwise conv2d in Jax.

  Args:
    inputs: an NHWC tensor with N=1.
    kernel: a [H", W", 1, C] tensor.
    strides: a 2d tensor.
    padding: "SAME" or "VALID".

  Returns:
    The depthwise convolution of inputs with kernel, as [H, W, C].
  """
  return jax.lax.conv_general_dilated(
      inputs,
      kernel,
      strides,
      padding,
      feature_group_count=inputs.shape[-1],
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))


def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
  """Applies gaussian blur to a single image, given as NHWC with N=1."""
  radius = int(kernel_size / 2)
  kernel_size_ = 2 * radius + 1
  x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
  blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
  blur_filter = blur_filter / jnp.sum(blur_filter)
  blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
  blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
  num_channels = image.shape[-1]
  blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
  blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
  expand_batch_dim = len(image.shape) == 3
  if expand_batch_dim:
    image = image[jnp.newaxis, ...]
  blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
  blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
  blurred = jnp.squeeze(blurred, axis=0)
  return blurred


def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
                          sigma_max, apply_prob):
  """Applies a random gaussian blur."""
  apply_rng, transform_rng = jax.random.split(rng)

  def _apply(image):
    sigma_rng, = jax.random.split(transform_rng, 1)
    sigma = jax.random.uniform(
        sigma_rng,
        shape=(),
        minval=sigma_min,
        maxval=sigma_max,
        dtype=jnp.float32)
    return _gaussian_blur_single_image(image, kernel_size, padding, sigma)

  return _maybe_apply(_apply, image, apply_rng, apply_prob)


def rgb_to_hsv(r, g, b):
  """Converts R, G, B  values to H, S, V values.

  Reference TF implementation:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
  Only input values between 0 and 1 are guaranteed to work properly, but this
  function complies with the TF implementation outside of this range.

  Args:
    r: A tensor representing the red color component as floats.
    g: A tensor representing the green color component as floats.
    b: A tensor representing the blue color component as floats.

  Returns:
    H, S, V values, each as tensors of shape [...] (same as the input without
    the last dimension).
  """
  vv = jnp.maximum(jnp.maximum(r, g), b)
  range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
  sat = jnp.where(vv > 0, range_ / vv, 0.)
  norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9)

  hr = norm * (g - b)
  hg = norm * (b - r) + 2. / 6.
  hb = norm * (r - g) + 4. / 6.

  hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
  hue = hue * (range_ > 0)
  hue = hue + (hue < 0)

  return hue, sat, vv


def hsv_to_rgb(h, s, v):
  """Converts H, S, V values to an R, G, B tuple.

  Reference TF implementation:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
  Only input values between 0 and 1 are guaranteed to work properly, but this
  function complies with the TF implementation outside of this range.

  Args:
    h: A float tensor of arbitrary shape for the hue (0-1 values).
    s: A float tensor of the same shape for the saturation (0-1 values).
    v: A float tensor of the same shape for the value channel (0-1 values).

  Returns:
    An (r, g, b) tuple, each with the same dimension as the inputs.
  """
  c = s * v
  m = v - c
  dh = (h % 1.) * 6.
  fmodu = dh % 2.
  x = c * (1 - jnp.abs(fmodu - 1))
  hcat = jnp.floor(dh).astype(jnp.int32)
  rr = jnp.where(
      (hcat == 0) | (hcat == 5), c, jnp.where(
          (hcat == 1) | (hcat == 4), x, 0)) + m
  gg = jnp.where(
      (hcat == 1) | (hcat == 2), c, jnp.where(
          (hcat == 0) | (hcat == 3), x, 0)) + m
  bb = jnp.where(
      (hcat == 3) | (hcat == 4), c, jnp.where(
          (hcat == 2) | (hcat == 5), x, 0)) + m
  return rr, gg, bb


def adjust_brightness(rgb_tuple, delta):
  return jax.tree_map(lambda x: x + delta, rgb_tuple)


def adjust_contrast(image, factor):
  def _adjust_contrast_channel(channel):
    mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
    return factor * (channel - mean) + mean
  return jax.tree_map(_adjust_contrast_channel, image)


def adjust_saturation(h, s, v, factor):
  return h, jnp.clip(s * factor, 0., 1.), v


def adjust_hue(h, s, v, delta):
  # Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
  # conversions) when running on GPU. When running on CPU, the results will be
  # different if all RGB values for a pixel are outside of the [0, 1] range.
  return (h + delta) % 1.0, s, v


def _random_brightness(rgb_tuple, rng, max_delta):
  delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
  return adjust_brightness(rgb_tuple, delta)


def _random_contrast(rgb_tuple, rng, max_delta):
  factor = jax.random.uniform(
      rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
  return adjust_contrast(rgb_tuple, factor)


def _random_saturation(rgb_tuple, rng, max_delta):
  h, s, v = rgb_to_hsv(*rgb_tuple)
  factor = jax.random.uniform(
      rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
  return hsv_to_rgb(*adjust_saturation(h, s, v, factor))


def _random_hue(rgb_tuple, rng, max_delta):
  h, s, v = rgb_to_hsv(*rgb_tuple)
  delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
  return hsv_to_rgb(*adjust_hue(h, s, v, delta))


def _to_grayscale(image):
  rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
  grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
  return jnp.tile(grayscale, (1, 1, 3))  # Back to 3 channels.


def _color_transform_single_image(image, rng, brightness, contrast, saturation,
                                  hue, to_grayscale_prob, color_jitter_prob,
                                  apply_prob, shuffle):
  """Applies color jittering to a single image."""
  apply_rng, transform_rng = jax.random.split(rng)
  perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
      transform_rng, 7)

  # Whether the transform should be applied at all.
  should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
  # Whether to apply grayscale transform.
  should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
  # Whether to apply color jittering.
  should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob

  # Decorator to conditionally apply fn based on an index.
  def _make_cond(fn, idx):

    def identity_fn(x, unused_rng, unused_param):
      return x

    def cond_fn(args, i):
      def clip(args):
        return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
      out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
                         lambda a: clip(fn(*a)), args,
                         lambda a: identity_fn(*a))
      return jax.lax.stop_gradient(out)

    return cond_fn

  random_brightness_cond = _make_cond(_random_brightness, idx=0)
  random_contrast_cond = _make_cond(_random_contrast, idx=1)
  random_saturation_cond = _make_cond(_random_saturation, idx=2)
  random_hue_cond = _make_cond(_random_hue, idx=3)

  def _color_jitter(x):
    rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
    if shuffle:
      order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
    else:
      order = range(4)
    for idx in order:
      if brightness > 0:
        rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
      if contrast > 0:
        rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
      if saturation > 0:
        rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
      if hue > 0:
        rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
    return jnp.stack(rgb_tuple, axis=-1)

  out_apply = _color_jitter(image)
  out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
                           _to_grayscale, out_apply, lambda x: x)
  return jnp.clip(out_apply, 0., 1.)


def _random_flip_single_image(image, rng):
  _, flip_rng = jax.random.split(rng)
  should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
  image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
  return image


def random_flip(images, rng):
  rngs = jax.random.split(rng, images.shape[0])
  return jax.vmap(_random_flip_single_image)(images, rngs)


def color_transform(images,
                    rng,
                    brightness=0.8,
                    contrast=0.8,
                    saturation=0.8,
                    hue=0.2,
                    color_jitter_prob=0.8,
                    to_grayscale_prob=0.2,
                    apply_prob=1.0,
                    shuffle=True):
  """Applies color jittering and/or grayscaling to a batch of images.

  Args:
    images: an NHWC tensor, with C=3.
    rng: a single PRNGKey.
    brightness: the range of jitter on brightness.
    contrast: the range of jitter on contrast.
    saturation: the range of jitter on saturation.
    hue: the range of jitter on hue.
    color_jitter_prob: the probability of applying color jittering.
    to_grayscale_prob: the probability of converting the image to grayscale.
    apply_prob: the probability of applying the transform to a batch element.
    shuffle: whether to apply the transforms in a random order.

  Returns:
    A NHWC tensor of the transformed images.
  """
  rngs = jax.random.split(rng, images.shape[0])
  jitter_fn = functools.partial(
      _color_transform_single_image,
      brightness=brightness,
      contrast=contrast,
      saturation=saturation,
      hue=hue,
      color_jitter_prob=color_jitter_prob,
      to_grayscale_prob=to_grayscale_prob,
      apply_prob=apply_prob,
      shuffle=shuffle)
  return jax.vmap(jitter_fn)(images, rngs)


def gaussian_blur(images,
                  rng,
                  blur_divider=10.,
                  sigma_min=0.1,
                  sigma_max=2.0,
                  apply_prob=1.0):
  """Applies gaussian blur to a batch of images.

  Args:
    images: an NHWC tensor, with C=3.
    rng: a single PRNGKey.
    blur_divider: the blurring kernel will have size H / blur_divider.
    sigma_min: the minimum value for sigma in the blurring kernel.
    sigma_max: the maximum value for sigma in the blurring kernel.
    apply_prob: the probability of applying the transform to a batch element.

  Returns:
    A NHWC tensor of the blurred images.
  """
  rngs = jax.random.split(rng, images.shape[0])
  kernel_size = images.shape[1] / blur_divider
  blur_fn = functools.partial(
      _random_gaussian_blur,
      kernel_size=kernel_size,
      padding='SAME',
      sigma_min=sigma_min,
      sigma_max=sigma_max,
      apply_prob=apply_prob)
  return jax.vmap(blur_fn)(images, rngs)


def _solarize_single_image(image, rng, threshold, apply_prob):

  def _apply(image):
    return jnp.where(image < threshold, image, 1. - image)

  return _maybe_apply(_apply, image, rng, apply_prob)


def solarize(images, rng, threshold=0.5, apply_prob=1.0):
  """Applies solarization.

  Args:
    images: an NHWC tensor (with C=3).
    rng: a single PRNGKey.
    threshold: the solarization threshold.
    apply_prob: the probability of applying the transform to a batch element.

  Returns:
    A NHWC tensor of the transformed images.
  """
  rngs = jax.random.split(rng, images.shape[0])
  solarize_fn = functools.partial(
      _solarize_single_image, threshold=threshold, apply_prob=apply_prob)
  return jax.vmap(solarize_fn)(images, rngs)


In [24]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint saving and restoring utilities."""

import os
import time
from typing import Mapping, Text, Tuple, Union

from absl import logging
import dill
import jax
import jax.numpy as jnp

#from byol.utils import helpers


class Checkpointer:
  """A checkpoint saving and loading class."""

  def __init__(
      self,
      use_checkpointing: bool,
      checkpoint_dir: Text,
      save_checkpoint_interval: int,
      filename: Text):
    if (not use_checkpointing or
        checkpoint_dir is None or
        save_checkpoint_interval <= 0):
      self._checkpoint_enabled = False
      return

    self._checkpoint_enabled = True
    self._checkpoint_dir = checkpoint_dir
    os.makedirs(self._checkpoint_dir, exist_ok=True)
    self._filename = filename
    self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)
    self._last_checkpoint_time = 0
    self._checkpoint_every = save_checkpoint_interval

  def maybe_save_checkpoint(
      self,
      experiment_state: Mapping[Text, jnp.ndarray],
      step: int,
      rng: jnp.ndarray,
      is_final: bool):
    """Saves a checkpoint if enough time has passed since the previous one."""
    current_time = time.time()
    if (not self._checkpoint_enabled or
        jax.host_id() != 0 or  # Only checkpoint the first worker.
        (not is_final and
         current_time - self._last_checkpoint_time < self._checkpoint_every)):
      return
    checkpoint_data = dict(
        experiment_state=jax.tree_map(
            lambda x: jax.device_get(x[0]), experiment_state),
        step=step,
        rng=rng)
    with open(self._checkpoint_path + '_tmp', 'wb') as checkpoint_file:
      dill.dump(checkpoint_data, checkpoint_file, protocol=2)
    try:
      os.rename(self._checkpoint_path, self._checkpoint_path + '_old')
      remove_old = True
    except FileNotFoundError:
      remove_old = False  # No previous checkpoint to remove
    os.rename(self._checkpoint_path + '_tmp', self._checkpoint_path)
    if remove_old:
      os.remove(self._checkpoint_path + '_old')
    self._last_checkpoint_time = current_time

  def maybe_load_checkpoint(
      self) -> Union[Tuple[Mapping[Text, jnp.ndarray], int, jnp.ndarray], None]:
    """Loads a checkpoint if any is found."""
    checkpoint_data = load_checkpoint(self._checkpoint_path)
    if checkpoint_data is None:
      logging.info('No existing checkpoint found at %s', self._checkpoint_path)
      return None
    step = checkpoint_data['step']
    rng = checkpoint_data['rng']
    experiment_state = jax.tree_map(
        helpers.bcast_local_devices, checkpoint_data['experiment_state'])
    del checkpoint_data
    return experiment_state, step, rng


def load_checkpoint(checkpoint_path):
  try:
    with open(checkpoint_path, 'rb') as checkpoint_file:
      checkpoint_data = dill.load(checkpoint_file)
      logging.info('Loading checkpoint from %s, saved at step %d',
                   checkpoint_path, checkpoint_data['step'])
      return checkpoint_data
  except FileNotFoundError:
    return None


In [25]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet dataset with typical pre-processing."""

import enum
from typing import Generator, Mapping, Optional, Sequence, Text, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

Batch = Mapping[Text, np.ndarray]


class Split(enum.Enum):
  """Imagenet dataset split."""
  TRAIN = 1
  TRAIN_AND_VALID = 2
  VALID = 3
  TEST = 4

  @classmethod
  def from_string(cls, name: Text) -> 'Split':
    return {
        'TRAIN': Split.TRAIN,
        'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
        'VALID': Split.VALID,
        'VALIDATION': Split.VALID,
        'TEST': Split.TEST
    }[name.upper()]

  @property
  def num_examples(self):
    return {
        Split.TRAIN_AND_VALID: 60000,
        Split.TRAIN: 50000,
        Split.VALID: 0,
        Split.TEST: 10000
    }[self]


class PreprocessMode(enum.Enum):
  """Preprocessing modes for the dataset."""
  PRETRAIN = 1  # Generates two augmented views (random crop + augmentations).
  LINEAR_TRAIN = 2  # Generates a single random crop.
  EVAL = 3  # Generates a single center crop.


def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
  """Normalize the image using ImageNet statistics."""
  mean_rgb = (0.485, 0.456, 0.406)
  stddev_rgb = (0.229, 0.224, 0.225)
  normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
  normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
  return normed_images


def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset."""
  start, end = _shard(split, jax.host_id(), jax.host_count())

  total_batch_size = np.prod(batch_dims)

  tfds_split = tfds.core.ReadInstruction(
      _to_tfds_split(split), from_=start, to=end, unit='abs')
  ds = tfds.load(
      #'imagenet2012:5.*.*',
      'cifar100',
      split=tfds_split,
      decoders={'image': tfds.decode.SkipDecoding()})

  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1

  if preprocess_mode is not PreprocessMode.EVAL:
    options.experimental_deterministic = False
    if jax.host_count() > 1 and allow_caching:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)

  else:
    if split.num_examples % total_batch_size != 0:
      raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  ds = ds.with_options(options)

  def preprocess_pretrain(example):
    view1 = _preprocess_image(example['image'], mode=preprocess_mode)
    view2 = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'view1': view1, 'view2': view2, 'labels': label}

  def preprocess_linear_train(example):
    image = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  def preprocess_eval(example):
    image = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  if preprocess_mode is PreprocessMode.PRETRAIN:
    ds = ds.map(
        preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
    ds = ds.map(
        preprocess_linear_train,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  else:
    ds = ds.map(
        preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def transpose_fn(batch):
    # We use the double-transpose-trick to improve performance for TPUs. Note
    # that this (typically) requires a matching HWCN->NHWC transpose in your
    # model code. The compiler cannot make this optimization for us since our
    # data pipeline and model are compiled separately.
    batch = dict(**batch)
    if preprocess_mode is PreprocessMode.PRETRAIN:
      batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
      batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
    else:
      batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
    return batch

  for i, batch_size in enumerate(reversed(batch_dims)):
    ds = ds.batch(batch_size)
    if i == 0 and transpose:
      ds = ds.map(transpose_fn)  # NHWC -> HWCN

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  yield from tfds.as_numpy(ds)


def _to_tfds_split(split: Split) -> tfds.Split:
  """Returns the TFDS split appropriately sharded."""
  # NOTE: Imagenet did not release labels for the test split used in the
  # competition, we consider the VALID split the TEST split and reserve
  # 10k images from TRAIN for VALID.
  if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):
    return tfds.Split.TRAIN
  else:
    assert split == Split.TEST
    return tfds.Split.VALIDATION


def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
  """Returns [start, end) for the given shard index."""
  assert shard_index < num_shards
  arange = np.arange(split.num_examples)
  shard_range = np.array_split(arange, num_shards)[shard_index]
  start, end = shard_range[0], (shard_range[-1] + 1)
  if split == Split.TRAIN:
    # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
    offset = Split.VALID.num_examples
    start += offset
    end += offset
  return start, end


def _preprocess_image(
    image_bytes: tf.Tensor,
    mode: PreprocessMode,
) -> tf.Tensor:
  """Returns processed and resized images."""
  if mode is PreprocessMode.PRETRAIN:
    image = _decode_and_random_crop(image_bytes)
    # Random horizontal flipping is optionally done in augmentations.preprocess.
  elif mode is PreprocessMode.LINEAR_TRAIN:
    image = _decode_and_random_crop(image_bytes)
    image = tf.image.random_flip_left_right(image)
  else:
    image = _decode_and_center_crop(image_bytes)
  # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
  # clamping overshoots. This means values returned will be outside the range
  # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
  assert image.dtype == tf.uint8
  image = tf.image.resize(image, [224, 224], tf.image.ResizeMethod.BICUBIC)
  image = tf.clip_by_value(image / 255., 0., 1.)
  return image


def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
  """Make a random crop of 224."""
  img_size = tf.image.extract_jpeg_shape(image_bytes)
  area = tf.cast(img_size[1] * img_size[0], tf.float32)
  target_area = tf.random.uniform([], 0.08, 1.0, dtype=tf.float32) * area

  log_ratio = (tf.math.log(3 / 4), tf.math.log(4 / 3))
  aspect_ratio = tf.math.exp(
      tf.random.uniform([], *log_ratio, dtype=tf.float32))

  w = tf.cast(tf.round(tf.sqrt(target_area * aspect_ratio)), tf.int32)
  h = tf.cast(tf.round(tf.sqrt(target_area / aspect_ratio)), tf.int32)

  w = tf.minimum(w, img_size[1])
  h = tf.minimum(h, img_size[0])

  offset_w = tf.random.uniform((),
                               minval=0,
                               maxval=img_size[1] - w + 1,
                               dtype=tf.int32)
  offset_h = tf.random.uniform((),
                               minval=0,
                               maxval=img_size[0] - h + 1,
                               dtype=tf.int32)

  crop_window = tf.stack([offset_h, offset_w, h, w])
  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  return image


def transpose_images(batch: Batch):
  """Transpose images for TPU training.."""
  new_batch = dict(batch)  # Avoid mutating in place.
  if 'images' in batch:
    new_batch['images'] = jnp.transpose(batch['images'], (3, 0, 1, 2))
  else:
    new_batch['view1'] = jnp.transpose(batch['view1'], (3, 0, 1, 2))
    new_batch['view2'] = jnp.transpose(batch['view2'], (3, 0, 1, 2))
  return new_batch


def _decode_and_center_crop(
    image_bytes: tf.Tensor,
    jpeg_shape: Optional[tf.Tensor] = None,
) -> tf.Tensor:
  """Crops to center of image with padding then scales."""
  if jpeg_shape is None:
    jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  image_height = jpeg_shape[0]
  image_width = jpeg_shape[1]

  padded_center_crop_size = tf.cast(
      ((224 / (224 + 32)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)

  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  crop_window = tf.stack([
      offset_height, offset_width, padded_center_crop_size,
      padded_center_crop_size
  ])
  image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  return image


In [26]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Networks used in BYOL."""

from typing import Any, Mapping, Optional, Sequence, Text

import haiku as hk
import jax
import jax.numpy as jnp


class MLP(hk.Module):
  """One hidden layer perceptron, with normalization."""

  def __init__(
      self,
      name: Text,
      hidden_size: int,
      output_size: int,
      bn_config: Mapping[Text, Any],
  ):
    super().__init__(name=name)
    self._hidden_size = hidden_size
    self._output_size = output_size
    self._bn_config = bn_config

  def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
    out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
    out = jax.nn.relu(out)
    out = hk.Linear(output_size=self._output_size, with_bias=False)(out)
    return out


def check_length(length, value, name):
  if len(value) != length:
    raise ValueError(f'`{name}` must be of length 4 not {len(value)}')


class ResNetTorso(hk.Module):
  """ResNet model."""

  def __init__(
      self,
      blocks_per_group: Sequence[int],
      num_classes: Optional[int] = None,
      bn_config: Optional[Mapping[str, float]] = None,
      resnet_v2: bool = False,
      bottleneck: bool = True,
      channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
      use_projection: Sequence[bool] = (True, True, True, True),
      width_multiplier: int = 1,
      name: Optional[str] = None,
  ):
    """Constructs a ResNet model.

    Args:
      blocks_per_group: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of three elements, `decay_rate`, `eps`, and
        `cross_replica_axis`, to be passed on to the `BatchNorm` layers. By
        default the `decay_rate` is `0.9` and `eps` is `1e-5`, and the axis is
        `None`.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
        False.
       bottleneck: Whether the block should bottleneck or not. Defaults to True.
      channels_per_group: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      use_projection: A sequence of length 4 that indicates whether each
        residual block should use projection.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.resnet_v2 = resnet_v2

    bn_config = dict(bn_config or {})
    bn_config.setdefault('decay_rate', 0.9)
    bn_config.setdefault('eps', 1e-5)
    bn_config.setdefault('create_scale', True)
    bn_config.setdefault('create_offset', True)

    # Number of blocks in each group for ResNet.
    check_length(4, blocks_per_group, 'blocks_per_group')
    check_length(4, channels_per_group, 'channels_per_group')

    self.initial_conv = hk.Conv2D(
        output_channels=64 * width_multiplier,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding='SAME',
        name='initial_conv')

    if not self.resnet_v2:
      self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
                                            **bn_config)

    self.block_groups = []
    strides = (1, 2, 2, 2)
    for i in range(4):
      self.block_groups.append(
          hk.nets.ResNet.BlockGroup(
              channels=width_multiplier * channels_per_group[i],
              num_blocks=blocks_per_group[i],
              stride=strides[i],
              bn_config=bn_config,
              resnet_v2=resnet_v2,
              bottleneck=bottleneck,
              use_projection=use_projection[i],
              name='block_group_%d' % (i)))

    if self.resnet_v2:
      self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config)

    self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')

  def __call__(self, inputs, is_training, test_local_stats=False):
    out = inputs
    out = self.initial_conv(out)
    if not self.resnet_v2:
      out = self.initial_batchnorm(out, is_training, test_local_stats)
      out = jax.nn.relu(out)

    out = hk.max_pool(out,
                      window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1),
                      padding='SAME')

    for block_group in self.block_groups:
      out = block_group(out, is_training, test_local_stats)

    if self.resnet_v2:
      out = self.final_batchnorm(out, is_training, test_local_stats)
      out = jax.nn.relu(out)
    out = jnp.mean(out, axis=[1, 2])
    return out


class TinyResNet(ResNetTorso):
  """Tiny resnet for local runs and tests."""

  def __init__(self,
               num_classes: Optional[int] = None,
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(1, 1, 1, 1),
                     channels_per_group=(8, 8, 8, 8),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=False,
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet18(ResNetTorso):
  """ResNet18."""

  def __init__(self,
               num_classes: Optional[int] = None,
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(2, 2, 2, 2),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=False,
                     channels_per_group=(64, 128, 256, 512),
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet34(ResNetTorso):
  """ResNet34."""

  def __init__(self,
               num_classes: Optional[int],
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(3, 4, 6, 3),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=False,
                     channels_per_group=(64, 128, 256, 512),
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet50(ResNetTorso):
  """ResNet50."""

  def __init__(self,
               num_classes: Optional[int] = None,
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(3, 4, 6, 3),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=True,
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet101(ResNetTorso):
  """ResNet101."""

  def __init__(self,
               num_classes: Optional[int],
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(3, 4, 23, 3),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=True,
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet152(ResNetTorso):
  """ResNet152."""

  def __init__(self,
               num_classes: Optional[int],
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(3, 8, 36, 3),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=True,
                     width_multiplier=width_multiplier,
                     name=name)


class ResNet200(ResNetTorso):
  """ResNet200."""

  def __init__(self,
               num_classes: Optional[int],
               bn_config: Optional[Mapping[str, float]] = None,
               resnet_v2: bool = False,
               width_multiplier: int = 1,
               name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to False.
      width_multiplier: An integer multiplying the number of channels per group.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(3, 24, 36, 3),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     resnet_v2=resnet_v2,
                     bottleneck=True,
                     width_multiplier=width_multiplier,
                     name=name)


In [27]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of LARS Optimizer with optax."""

from typing import Any, Callable, List, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
import optax
import tree as nest

# A filter function takes a path and a value as input and outputs True for
# variable to apply update and False not to apply the update
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]


def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
  """Filter to exclude biaises and normalizations weights."""
  del val
  if path[-1] == "b" or "norm" in path[-2]:
    return False
  return True


def _partial_update(updates: optax.Updates,
                    new_updates: optax.Updates,
                    params: optax.Params,
                    filter_fn: Optional[FilterFn] = None) -> optax.Updates:
  """Returns new_update for params which filter_fn is True else updates."""

  if filter_fn is None:
    return new_updates

  wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
  params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)

  def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
    m = m.astype(g.dtype)
    return g * (1. - m) + t * m

  return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)


class ScaleByLarsState(NamedTuple):
  mu: jnp.ndarray


def scale_by_lars(
    momentum: float = 0.9,
    eta: float = 0.001,
    filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
  """Rescales updates according to the LARS algorithm.

  Does not include weight decay.
  References:
    [You et al, 2017](https://arxiv.org/abs/1708.03888)

  Args:
    momentum: momentum coeficient.
    eta: LARS coefficient.
    filter_fn: an optional filter function.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params: optax.Params) -> ScaleByLarsState:
    mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
    return ScaleByLarsState(mu=mu)

  def update_fn(updates: optax.Updates, state: ScaleByLarsState,
                params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:

    def lars_adaptation(
        update: jnp.ndarray,
        param: jnp.ndarray,
    ) -> jnp.ndarray:
      param_norm = jnp.linalg.norm(param)
      update_norm = jnp.linalg.norm(update)
      return update * jnp.where(
          param_norm > 0.,
          jnp.where(update_norm > 0,
                    (eta * param_norm / update_norm), 1.0), 1.0)

    adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
    adapted_updates = _partial_update(updates, adapted_updates, params,
                                      filter_fn)
    mu = jax.tree_multimap(lambda g, t: momentum * g + t,
                           state.mu, adapted_updates)
    return mu, ScaleByLarsState(mu=mu)

  return optax.GradientTransformation(init_fn, update_fn)


class AddWeightDecayState(NamedTuple):
  """Stateless transformation."""


def add_weight_decay(
    weight_decay: float,
    filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
  """Adds a weight decay to the update.

  Args:
    weight_decay: weight_decay coeficient.
    filter_fn: an optional filter function.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_) -> AddWeightDecayState:
    return AddWeightDecayState()

  def update_fn(
      updates: optax.Updates,
      state: AddWeightDecayState,
      params: optax.Params,
  ) -> Tuple[optax.Updates, AddWeightDecayState]:
    new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
                                    params)
    new_updates = _partial_update(updates, new_updates, params, filter_fn)
    return new_updates, state

  return optax.GradientTransformation(init_fn, update_fn)


LarsState = List  # Type for the lars optimizer


def lars(
    learning_rate: float,
    weight_decay: float = 0.,
    momentum: float = 0.9,
    eta: float = 0.001,
    weight_decay_filter: Optional[FilterFn] = None,
    lars_adaptation_filter: Optional[FilterFn] = None,
) -> optax.GradientTransformation:
  """Creates lars optimizer with weight decay.

  References:
    [You et al, 2017](https://arxiv.org/abs/1708.03888)

  Args:
    learning_rate: learning rate coefficient.
    weight_decay: weight decay coefficient.
    momentum: momentum coefficient.
    eta: LARS coefficient.
    weight_decay_filter: optional filter function to only apply the weight
      decay on a subset of parameters. The filter function takes as input the
      parameter path (as a tuple) and its associated update, and return a True
      for params to apply the weight decay and False for params to not apply
      the weight decay. When weight_decay_filter is set to None, the weight
      decay is not applied to the bias, i.e. when the variable name is 'b', and
      the weight decay is not applied to nornalization params, i.e. the
      panultimate path contains 'norm'.
    lars_adaptation_filter: similar to weight decay filter but for lars
      adaptation

  Returns:
    An optax.GradientTransformation, i.e. a (init_fn, update_fn) tuple.
  """

  if weight_decay_filter is None:
    weight_decay_filter = lambda *_: True
  if lars_adaptation_filter is None:
    lars_adaptation_filter = lambda *_: True

  return optax.chain(
      add_weight_decay(
          weight_decay=weight_decay, filter_fn=weight_decay_filter),
      scale_by_lars(
          momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter),
      optax.scale(-learning_rate),
  )


In [28]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedules."""
import jax.numpy as jnp


def target_ema(global_step: jnp.ndarray,
               base_ema: float,
               max_steps: int) -> jnp.ndarray:
  decay = _cosine_decay(global_step, max_steps, 1.)
  return 1. - (1. - base_ema) * decay


def learning_schedule(global_step: jnp.ndarray,
                      batch_size: int,
                      base_learning_rate: float,
                      total_steps: int,
                      warmup_steps: int) -> float:
  """Cosine learning rate scheduler."""
  # Compute LR & Scaled LR
  scaled_lr = base_learning_rate * batch_size / 256.
  learning_rate = (
      global_step.astype(jnp.float32) / int(warmup_steps) *
      scaled_lr if warmup_steps > 0 else scaled_lr)

  # Cosine schedule after warmup.
  return jnp.where(
      global_step < warmup_steps, learning_rate,
      _cosine_decay(global_step - warmup_steps, total_steps - warmup_steps,
                    scaled_lr))


def _cosine_decay(global_step: jnp.ndarray,
                  max_steps: int,
                  initial_value: float) -> jnp.ndarray:
  """Simple implementation of cosine decay from TF1."""
  global_step = jnp.minimum(global_step, max_steps)
  cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
  decayed_learning_rate = initial_value * cosine_decay_value
  return decayed_learning_rate


In [29]:
from acme.jax import utils as acme_utils

In [30]:
"""Utilities for JAX."""

import functools
import itertools
import queue
import threading
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Sequence, Tuple, TypeVar

from absl import logging
from acme import types
import jax
import jax.numpy as jnp
import numpy as np
import tree

F = TypeVar('F', bound=Callable)
N = TypeVar('N', bound=types.NestedArray)
T = TypeVar('T')

def acme_utils_prefetch(iterable: Iterable[T],
             buffer_size: int = 5,
             device=None) -> Generator[T, None, None]:
  """Performs prefetching of elements from an iterable in a separate thread.
  Args:
    iterable: A python iterable. This is used to build the python prefetcher.
      Note that each iterable should only be passed to this function once as
      iterables aren't thread safe
    buffer_size (int): Number of elements to keep in the prefetch buffer.
    device: The device to prefetch the elements to. If none then the elements
      are left on the CPU. The device should be of the type returned by
      `jax.devices()`.
  Yields:
    Prefetched elements from the original iterable.
  Raises:
    ValueError if the buffer_size <= 1.
    Any error thrown by the iterable_function. Note this is not raised inside
      the producer, but after it finishes executing.
  """

  if buffer_size <= 1:
    raise ValueError('the buffer_size should be > 1')
  buffer = queue.Queue(maxsize=(buffer_size - 1))
  producer_error = []
  end = object()

  def producer():
    """Enqueues items from `iterable` on a given thread."""
    try:
      # Build a new iterable for each thread. This is crucial if working with
      # tensorflow datasets because tf.Graph objects are thread local.
      for item in iterable:
        if device:
          item = jax.device_put(item, device)
        buffer.put(item)
    except Exception as e:  # pylint: disable=broad-except
      logging.exception('Error in producer thread for %s', iterable)
      producer_error.append(e)
    finally:
      buffer.put(end)

  # Start the producer thread.
  threading.Thread(target=producer, daemon=True).start()

  # Consume from the buffer.
  while True:
    value = buffer.get()
    if value is end:
      break
    yield value

  if producer_error:
    raise producer_error[0]

# Main Code

In [31]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BYOL pre-training implementation.

Use this experiment to pre-train a self-supervised representation.
"""

import functools
from typing import Any, Generator, Mapping, NamedTuple, Text, Tuple, Union

from absl import logging
from acme.jax import utils as acme_utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

#from byol.utils import augmentations
#from byol.utils import checkpointing
#from byol.utils import dataset
#from byol.utils import helpers
#from byol.utils import networks
#from byol.utils import optimizers
#from byol.utils import schedules


# Type declarations.
LogsDict = Mapping[Text, jnp.ndarray]


class _ByolExperimentState(NamedTuple):
  """Byol's model and optimization parameters and state."""
  online_params: hk.Params
  target_params: hk.Params
  online_state: hk.State
  target_state: hk.State
  opt_state: LarsState


class ByolExperiment:
  """Byol's training and evaluation component definition."""

  def __init__(
      self,
      random_seed: int,
      num_classes: int,
      batch_size: int,
      max_steps: int,
      enable_double_transpose: bool,
      base_target_ema: float,
      network_config: Mapping[Text, Any],
      optimizer_config: Mapping[Text, Any],
      lr_schedule_config: Mapping[Text, Any],
      evaluation_config: Mapping[Text, Any],
      checkpointing_config: Mapping[Text, Any]):
    """Constructs the experiment.

    Args:
      random_seed: the random seed to use when initializing network weights.
      num_classes: the number of classes; used for the online evaluation.
      batch_size: the total batch size; should be a multiple of the number of
        available accelerators.
      max_steps: the number of training steps; used for the lr/target network
        ema schedules.
      enable_double_transpose: see dataset.py; only has effect on TPU.
      base_target_ema: the initial value for the ema decay rate of the target
        network.
      network_config: the configuration for the network.
      optimizer_config: the configuration for the optimizer.
      lr_schedule_config: the configuration for the learning rate schedule.
      evaluation_config: the evaluation configuration.
      checkpointing_config: the configuration for checkpointing.
    """

    self._random_seed = random_seed
    self._enable_double_transpose = enable_double_transpose
    self._num_classes = num_classes
    self._lr_schedule_config = lr_schedule_config
    self._batch_size = batch_size
    self._max_steps = max_steps
    self._base_target_ema = base_target_ema
    self._optimizer_config = optimizer_config
    self._evaluation_config = evaluation_config

    # Checkpointed experiment state.
    self._byol_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    # build the transformed ops
    forward_fn = functools.partial(self._forward, **network_config)
    self.forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
    # training can handle multiple devices, thus the pmap
    self.update_pmap = jax.pmap(self._update_fn, axis_name='i')
    # evaluation can only handle single device
    self.eval_batch_jit = jax.jit(self._eval_batch)

    self._checkpointer = Checkpointer(**checkpointing_config)

  def _forward(
      self,
      inputs: Batch,
      projector_hidden_size: int,
      projector_output_size: int,
      predictor_hidden_size: int,
      encoder_class: Text,
      encoder_config: Mapping[Text, Any],
      bn_config: Mapping[Text, Any],
      is_training: bool,
  ) -> Mapping[Text, jnp.ndarray]:
    """Forward application of byol's architecture.

    Args:
      inputs: A batch of data, i.e. a dictionary, with either two keys,
        (`images` and `labels`) or three keys (`view1`, `view2`, `labels`).
      projector_hidden_size: hidden size of the projector MLP.
      projector_output_size: output size of the projector and predictor MLPs.
      predictor_hidden_size: hidden size of the predictor MLP.
      encoder_class: type of the encoder (should match a class in
        utils/networks).
      encoder_config: passed to the encoder constructor.
      bn_config: passed to the hk.BatchNorm constructors.
      is_training: Training or evaluating the model? When True, inputs must
        contain keys `view1` and `view2`. When False, inputs must contain key
        `images`.

    Returns:
      All outputs of the model, i.e. a dictionary with projection, prediction
      and logits keys, for either the two views, or the image.
    """
    encoder = getattr(networks, encoder_class)
    net = encoder(
        num_classes=None,  # Don't build the final linear layer
        bn_config=bn_config,
        **encoder_config)

    projector = networks.MLP(
        name='projector',
        hidden_size=projector_hidden_size,
        output_size=projector_output_size,
        bn_config=bn_config)
    predictor = networks.MLP(
        name='predictor',
        hidden_size=predictor_hidden_size,
        output_size=projector_output_size,
        bn_config=bn_config)
    classifier = hk.Linear(
        output_size=self._num_classes, name='classifier')

    def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
      images = normalize_images(images)

      embedding = net(images, is_training=is_training)
      proj_out = projector(embedding, is_training)
      pred_out = predictor(proj_out, is_training)

      # Note the stop_gradient: label information is not leaked into the
      # main network.
      classif_out = classifier(jax.lax.stop_gradient(embedding))
      outputs = {}
      outputs['projection' + suffix] = proj_out
      outputs['prediction' + suffix] = pred_out
      outputs['logits' + suffix] = classif_out
      return outputs

    if is_training:
      outputs_view1 = apply_once_fn(inputs['view1'], '_view1')
      outputs_view2 = apply_once_fn(inputs['view2'], '_view2')
      return {**outputs_view1, **outputs_view2}
    else:
      return apply_once_fn(inputs['images'], '')

  def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
    """Build optimizer from config."""
    return lars(
        learning_rate,
        weight_decay_filter=exclude_bias_and_norm,
        lars_adaptation_filter=exclude_bias_and_norm,
        **self._optimizer_config)

  def loss_fn(
      self,
      online_params: hk.Params,
      target_params: hk.Params,
      online_state: hk.State,
      target_state: hk.Params,
      rng: jnp.ndarray,
      inputs: Batch,
  ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
    """Compute BYOL's loss function.

    Args:
      online_params: parameters of the online network (the loss is later
        differentiated with respect to the online parameters).
      target_params: parameters of the target network.
      online_state: internal state of online network.
      target_state: internal state of target network.
      rng: random number generator state.
      inputs: inputs, containing two batches of crops from the same images,
        view1 and view2 and labels

    Returns:
      BYOL's loss, a mapping containing the online and target networks updated
      states after processing inputs, and various logs.
    """
    if self._should_transpose_images():
      inputs = transpose_images(inputs)
    inputs = postprocess(inputs, rng)
    labels = inputs['labels']

    online_network_out, online_state = self.forward.apply(
        params=online_params,
        state=online_state,
        inputs=inputs,
        is_training=True)
    target_network_out, target_state = self.forward.apply(
        params=target_params,
        state=target_state,
        inputs=inputs,
        is_training=True)

    # Representation loss

    # The stop_gradient is not necessary as we explicitly take the gradient with
    # respect to online parameters only in `optax.apply_updates`. We leave it to
    # indicate that gradients are not backpropagated through the target network.
    repr_loss = regression_loss(
        online_network_out['prediction_view1'],
        jax.lax.stop_gradient(target_network_out['projection_view2']))
    repr_loss = repr_loss + regression_loss(
        online_network_out['prediction_view2'],
        jax.lax.stop_gradient(target_network_out['projection_view1']))

    repr_loss = jnp.mean(repr_loss)

    # Classification loss (with gradient flows stopped from flowing into the
    # ResNet). This is used to provide an evaluation of the representation
    # quality during training.

    classif_loss = softmax_cross_entropy(
        logits=online_network_out['logits_view1'],
        labels=jax.nn.one_hot(labels, self._num_classes))

    top1_correct = topk_accuracy(
        online_network_out['logits_view1'],
        inputs['labels'],
        topk=1,
    )

    top5_correct = topk_accuracy(
        online_network_out['logits_view1'],
        inputs['labels'],
        topk=5,
    )

    top1_acc = jnp.mean(top1_correct)
    top5_acc = jnp.mean(top5_correct)

    classif_loss = jnp.mean(classif_loss)
    loss = repr_loss + classif_loss
    logs = dict(
        loss=loss,
        repr_loss=repr_loss,
        classif_loss=classif_loss,
        top1_accuracy=top1_acc,
        top5_accuracy=top5_acc,
    )

    return loss, (dict(online_state=online_state,
                       target_state=target_state), logs)

  def _should_transpose_images(self):
    """Should we transpose images (saves host-to-device time on TPUs)."""
    return (self._enable_double_transpose and
            jax.local_devices()[0].platform == 'tpu')

  def _update_fn(
      self,
      byol_state: _ByolExperimentState,
      global_step: jnp.ndarray,
      rng: jnp.ndarray,
      inputs: Batch,
  ) -> Tuple[_ByolExperimentState, LogsDict]:
    """Update online and target parameters.

    Args:
      byol_state: current BYOL state.
      global_step: current training step.
      rng: current random number generator
      inputs: inputs, containing two batches of crops from the same images,
        view1 and view2 and labels

    Returns:
      Tuple containing the updated Byol state after processing the inputs, and
      various logs.
    """
    online_params = byol_state.online_params
    target_params = byol_state.target_params
    online_state = byol_state.online_state
    target_state = byol_state.target_state
    opt_state = byol_state.opt_state

    # update online network
    grad_fn = jax.grad(self.loss_fn, argnums=0, has_aux=True)
    grads, (net_states, logs) = grad_fn(online_params, target_params,
                                        online_state, target_state, rng, inputs)

    # cross-device grad and logs reductions
    grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
    logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)

    learning_rate = learning_schedule(
        global_step,
        batch_size=self._batch_size,
        total_steps=self._max_steps,
        **self._lr_schedule_config)
    updates, opt_state = self._optimizer(learning_rate).update(
        grads, opt_state, online_params)
    online_params = optax.apply_updates(online_params, updates)

    # update target network
    tau = target_ema(
        global_step,
        base_ema=self._base_target_ema,
        max_steps=self._max_steps)
    target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
                                      target_params, online_params)
    logs['tau'] = tau
    logs['learning_rate'] = learning_rate
    return _ByolExperimentState(
        online_params=online_params,
        target_params=target_params,
        online_state=net_states['online_state'],
        target_state=net_states['target_state'],
        opt_state=opt_state), logs

  def _make_initial_state(
      self,
      rng: jnp.ndarray,
      dummy_input: Batch,
  ) -> _ByolExperimentState:
    """BYOL's _ByolExperimentState initialization.

    Args:
      rng: random number generator used to initialize parameters. If working in
        a multi device setup, this need to be a ShardedArray.
      dummy_input: a dummy image, used to compute intermediate outputs shapes.

    Returns:
      Initial Byol state.
    """
    rng_online, rng_target = jax.random.split(rng)

    if self._should_transpose_images():
      dummy_input = dataset.transpose_images(dummy_input)

    # Online and target parameters are initialized using different rngs,
    # in our experiments we did not notice a significant different with using
    # the same rng for both.
    online_params, online_state = self.forward.init(
        rng_online,
        dummy_input,
        is_training=True,
    )
    target_params, target_state = self.forward.init(
        rng_target,
        dummy_input,
        is_training=True,
    )
    opt_state = self._optimizer(0).init(online_params)
    return _ByolExperimentState(
        online_params=online_params,
        target_params=target_params,
        opt_state=opt_state,
        online_state=online_state,
        target_state=target_state,
    )

  def step(self, *,
           global_step: jnp.ndarray,
           rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
    """Performs a single training step."""
    if self._train_input is None:
      self._initialize_train()

    inputs = next(self._train_input)

    self._byol_state, scalars = self.update_pmap(
        self._byol_state,
        global_step=global_step,
        rng=rng,
        inputs=inputs,
    )

    return helpers.get_first(scalars)

  def save_checkpoint(self, step: int, rng: jnp.ndarray):
    self._checkpointer.maybe_save_checkpoint(
        self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)

  def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
    checkpoint_data = self._checkpointer.maybe_load_checkpoint()
    if checkpoint_data is None:
      return None
    self._byol_state, step, rng = checkpoint_data
    return step, rng

  def _initialize_train(self):
    """Initialize train.

    This includes initializing the input pipeline and Byol's state.
    """
    self._train_input = acme_utils.prefetch(self._build_train_input())

    # Check we haven't already restored params
    if self._byol_state is None:
      logging.info(
          'Initializing parameters rather than restoring from checkpoint.')

      # initialize Byol and setup optimizer state
      inputs = next(self._train_input)
      init_byol = jax.pmap(self._make_initial_state, axis_name='i')

      # Init uses the same RNG key on all hosts+devices to ensure everyone
      # computes the same initial state and parameters.
      init_rng = jax.random.PRNGKey(self._random_seed)
      init_rng = helpers.bcast_local_devices(init_rng)

      self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)

  def _build_train_input(self) -> Generator[Batch, None, None]:
    """Loads the (infinitely looping) dataset iterator."""
    num_devices = jax.device_count()
    global_batch_size = self._batch_size
    per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')

    return load(
        Split.TRAIN_AND_VALID,
        preprocess_mode=PreprocessMode.PRETRAIN,
        transpose=self._should_transpose_images(),
        batch_dims=[jax.local_device_count(), per_device_batch_size])

  def _eval_batch(
      self,
      params: hk.Params,
      state: hk.State,
      batch: Batch,
  ) -> Mapping[Text, jnp.ndarray]:
    """Evaluates a batch.

    Args:
      params: Parameters of the model to evaluate. Typically Byol's online
        parameters.
      state: State of the model to evaluate. Typically Byol's online state.
      batch: Batch of data to evaluate (must contain keys images and labels).

    Returns:
      Unreduced evaluation loss and top1 accuracy on the batch.
    """
    if self._should_transpose_images():
      batch = transpose_images(batch)

    outputs, _ = self.forward.apply(params, state, batch, is_training=False)
    logits = outputs['logits']
    labels = hk.one_hot(batch['labels'], self._num_classes)
    loss = softmax_cross_entropy(logits, labels, reduction=None)
    top1_correct = topk_accuracy(logits, batch['labels'], topk=1)
    top5_correct = topk_accuracy(logits, batch['labels'], topk=5)
    # NOTE: Returned values will be summed and finally divided by num_samples.
    return {
        'eval_loss': loss,
        'top1_accuracy': top1_correct,
        'top5_accuracy': top5_correct,
    }

  def _eval_epoch(self, subset: Text, batch_size: int):
    """Evaluates an epoch."""
    num_samples = 0.
    summed_scalars = None

    params = get_first(self._byol_state.online_params)
    state = get_first(self._byol_state.online_state)
    split = Split.from_string(subset)

    dataset_iterator = load(
        split,
        preprocess_mode=PreprocessMode.EVAL,
        transpose=self._should_transpose_images(),
        batch_dims=[batch_size])

    for inputs in dataset_iterator:
      num_samples += inputs['labels'].shape[0]
      scalars = self.eval_batch_jit(params, state, inputs)

      # Accumulate the sum of scalars for each step.
      scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
      if summed_scalars is None:
        summed_scalars = scalars
      else:
        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)

    mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
    return mean_scalars

  def evaluate(self, global_step, **unused_args):
    """Thin wrapper around _eval_epoch."""

    global_step = np.array(get_first(global_step))
    scalars = jax.device_get(self._eval_epoch(**self._evaluation_config))

    logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
    return scalars


In [32]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear evaluation or fine-tuning pipeline.

Use this experiment to evaluate a checkpoint from byol_experiment.
"""

import functools
from typing import Any, Generator, Mapping, NamedTuple, Optional, Text, Tuple, Union

from absl import logging
from acme.jax import utils as acme_utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

#from byol.utils import checkpointing
#from byol.utils import dataset
#from byol.utils import helpers
#from byol.utils import networks
#from byol.utils import schedules

# Type declarations.
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
LogsDict = Mapping[Text, jnp.ndarray]


class _EvalExperimentState(NamedTuple):
  backbone_params: hk.Params
  classif_params: hk.Params
  backbone_state: hk.State
  backbone_opt_state: Union[None, OptState]
  classif_opt_state: OptState


class EvalExperiment:
  """Linear evaluation experiment."""

  def __init__(
      self,
      random_seed: int,
      num_classes: int,
      batch_size: int,
      max_steps: int,
      enable_double_transpose: bool,
      checkpoint_to_evaluate: Optional[Text],
      allow_train_from_scratch: bool,
      freeze_backbone: bool,
      network_config: Mapping[Text, Any],
      optimizer_config: Mapping[Text, Any],
      lr_schedule_config: Mapping[Text, Any],
      evaluation_config: Mapping[Text, Any],
      checkpointing_config: Mapping[Text, Any]):
    """Constructs the experiment.

    Args:
      random_seed: the random seed to use when initializing network weights.
      num_classes: the number of classes; used for the online evaluation.
      batch_size: the total batch size; should be a multiple of the number of
        available accelerators.
      max_steps: the number of training steps; used for the lr/target network
        ema schedules.
      enable_double_transpose: see dataset.py; only has effect on TPU.
      checkpoint_to_evaluate: the path to the checkpoint to evaluate.
      allow_train_from_scratch: whether to allow training without specifying a
        checkpoint to evaluate (training from scratch).
      freeze_backbone: whether the backbone resnet should remain frozen (linear
        evaluation) or be trainable (fine-tuning).
      network_config: the configuration for the network.
      optimizer_config: the configuration for the optimizer.
      lr_schedule_config: the configuration for the learning rate schedule.
      evaluation_config: the evaluation configuration.
      checkpointing_config: the configuration for checkpointing.
    """

    self._random_seed = random_seed
    self._enable_double_transpose = enable_double_transpose
    self._num_classes = num_classes
    self._lr_schedule_config = lr_schedule_config
    self._batch_size = batch_size
    self._max_steps = max_steps
    self._checkpoint_to_evaluate = checkpoint_to_evaluate
    self._allow_train_from_scratch = allow_train_from_scratch
    self._freeze_backbone = freeze_backbone
    self._optimizer_config = optimizer_config
    self._evaluation_config = evaluation_config

    # Checkpointed experiment state.
    self._experiment_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    backbone_fn = functools.partial(self._backbone_fn, **network_config)
    self.forward_backbone = hk.without_apply_rng(
        hk.transform_with_state(backbone_fn))
    self.forward_classif = hk.without_apply_rng(hk.transform(self._classif_fn))
    self.update_pmap = jax.pmap(self._update_func, axis_name='i')
    self.eval_batch_jit = jax.jit(self._eval_batch)

    self._is_backbone_training = not self._freeze_backbone

    self._checkpointer = Checkpointer(**checkpointing_config)

  def _should_transpose_images(self):
    """Should we transpose images (saves host-to-device time on TPUs)."""
    return (self._enable_double_transpose and
            jax.local_devices()[0].platform == 'tpu')

  def _backbone_fn(
      self,
      inputs: Batch,
      encoder_class: Text,
      encoder_config: Mapping[Text, Any],
      bn_decay_rate: float,
      is_training: bool,
  ) -> jnp.ndarray:
    """Forward of the encoder (backbone)."""
    bn_config = {'decay_rate': bn_decay_rate}
    encoder = getattr(networks, encoder_class)
    model = encoder(
        None,
        bn_config=bn_config,
        **encoder_config)

    if self._should_transpose_images():
      inputs = transpose_images(inputs)
    images = normalize_images(inputs['images'])
    return model(images, is_training=is_training)

  def _classif_fn(
      self,
      embeddings: jnp.ndarray,
  ) -> jnp.ndarray:
    classifier = hk.Linear(output_size=self._num_classes)
    return classifier(embeddings)

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_|  \__,_|_|_| |_|
  #

  def step(self, *,
           global_step: jnp.ndarray,
           rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
    """Performs a single training step."""

    if self._train_input is None:
      self._initialize_train(rng)

    inputs = next(self._train_input)
    self._experiment_state, scalars = self.update_pmap(
        self._experiment_state, global_step, inputs)

    scalars = get_first(scalars)
    return scalars

  def save_checkpoint(self, step: int, rng: jnp.ndarray):
    self._checkpointer.maybe_save_checkpoint(
        self._experiment_state, step=step, rng=rng,
        is_final=step >= self._max_steps)

  def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
    checkpoint_data = self._checkpointer.maybe_load_checkpoint()
    if checkpoint_data is None:
      return None
    self._experiment_state, step, rng = checkpoint_data
    return step, rng

  def _initialize_train(self, rng):
    """BYOL's _ExperimentState initialization.

    Args:
      rng: random number generator used to initialize parameters. If working in
        a multi device setup, this need to be a ShardedArray.
      dummy_input: a dummy image, used to compute intermediate outputs shapes.

    Returns:
      Initial EvalExperiment state.

    Raises:
      RuntimeError: invalid or empty checkpoint.
    """
    self._train_input = acme_utils.prefetch(self._build_train_input())

    # Check we haven't already restored params
    if self._experiment_state is None:

      inputs = next(self._train_input)

      if self._checkpoint_to_evaluate is not None:
        # Load params from checkpoint
        checkpoint_data = load_checkpoint(
            self._checkpoint_to_evaluate)
        if checkpoint_data is None:
          raise RuntimeError('Invalid checkpoint.')
        backbone_params = checkpoint_data['experiment_state'].online_params
        backbone_state = checkpoint_data['experiment_state'].online_state
        backbone_params = bcast_local_devices(backbone_params)
        backbone_state = bcast_local_devices(backbone_state)
      else:
        if not self._allow_train_from_scratch:
          raise ValueError(
              'No checkpoint specified, but `allow_train_from_scratch` '
              'set to False')
        # Initialize with random parameters
        logging.info(
            'No checkpoint specified, initializing the networks from scratch '
            '(dry run mode)')
        backbone_params, backbone_state = jax.pmap(
            functools.partial(self.forward_backbone.init, is_training=True),
            axis_name='i')(rng=rng, inputs=inputs)

      init_experiment = jax.pmap(self._make_initial_state, axis_name='i')

      # Init uses the same RNG key on all hosts+devices to ensure everyone
      # computes the same initial state and parameters.
      init_rng = jax.random.PRNGKey(self._random_seed)
      init_rng = bcast_local_devices(init_rng)
      self._experiment_state = init_experiment(
          rng=init_rng,
          dummy_input=inputs,
          backbone_params=backbone_params,
          backbone_state=backbone_state)

      # Clear the backbone optimizer's state when the backbone is frozen.
      if self._freeze_backbone:
        self._experiment_state = _EvalExperimentState(
            backbone_params=self._experiment_state.backbone_params,
            classif_params=self._experiment_state.classif_params,
            backbone_state=self._experiment_state.backbone_state,
            backbone_opt_state=None,
            classif_opt_state=self._experiment_state.classif_opt_state,
        )

  def _make_initial_state(
      self,
      rng: jnp.ndarray,
      dummy_input: Batch,
      backbone_params: hk.Params,
      backbone_state: hk.Params,
  ) -> _EvalExperimentState:
    """_EvalExperimentState initialization."""

    # Initialize the backbone params
    # Always create the batchnorm weights (is_training=True), they will be
    # overwritten when loading the checkpoint.
    embeddings, _ = self.forward_backbone.apply(
        backbone_params, backbone_state, dummy_input, is_training=True)
    backbone_opt_state = self._optimizer(0.).init(backbone_params)

    # Initialize the classifier params and optimizer_state
    classif_params = self.forward_classif.init(rng, embeddings)
    classif_opt_state = self._optimizer(0.).init(classif_params)

    return _EvalExperimentState(
        backbone_params=backbone_params,
        classif_params=classif_params,
        backbone_state=backbone_state,
        backbone_opt_state=backbone_opt_state,
        classif_opt_state=classif_opt_state,
    )

  def _build_train_input(self) -> Generator[Batch, None, None]:
    """See base class."""
    num_devices = jax.device_count()
    global_batch_size = self._batch_size
    per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')

    return load(
        Split.TRAIN_AND_VALID,
        preprocess_mode=PreprocessMode.LINEAR_TRAIN,
        transpose=self._should_transpose_images(),
        batch_dims=[jax.local_device_count(), per_device_batch_size])

  def _optimizer(self, learning_rate: float):
    """Build optimizer from config."""
    return optax.sgd(learning_rate, **self._optimizer_config)

  def _loss_fn(
      self,
      backbone_params: hk.Params,
      classif_params: hk.Params,
      backbone_state: hk.State,
      inputs: Batch,
  ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
    """Compute the classification loss function.

    Args:
      backbone_params: parameters of the encoder network.
      classif_params: parameters of the linear classifier.
      backbone_state: internal state of encoder network.
      inputs: inputs, containing `images` and `labels`.

    Returns:
      The classification loss and various logs.
    """
    embeddings, backbone_state = self.forward_backbone.apply(
        backbone_params,
        backbone_state,
        inputs,
        is_training=not self._freeze_backbone)

    logits = self.forward_classif.apply(classif_params, embeddings)
    labels = hk.one_hot(inputs['labels'], self._num_classes)
    loss = softmax_cross_entropy(logits, labels, reduction='mean')
    scaled_loss = loss / jax.device_count()

    return scaled_loss, (loss, backbone_state)

  def _update_func(
      self,
      experiment_state: _EvalExperimentState,
      global_step: jnp.ndarray,
      inputs: Batch,
  ) -> Tuple[_EvalExperimentState, LogsDict]:
    """Applies an update to parameters and returns new state."""
    # This function computes the gradient of the first output of loss_fn and
    # passes through the other arguments unchanged.

    # Gradient of the first output of _loss_fn wrt the backbone (arg 0) and the
    # classifier parameters (arg 1). The auxiliary outputs are returned as-is.
    grad_loss_fn = jax.grad(self._loss_fn, has_aux=True, argnums=(0, 1))

    grads, aux_outputs = grad_loss_fn(
        experiment_state.backbone_params,
        experiment_state.classif_params,
        experiment_state.backbone_state,
        inputs,
    )
    backbone_grads, classifier_grads = grads
    train_loss, new_backbone_state = aux_outputs
    classifier_grads = jax.lax.psum(classifier_grads, axis_name='i')

    # Compute the decayed learning rate
    learning_rate = learning_schedule(
        global_step,
        batch_size=self._batch_size,
        total_steps=self._max_steps,
        **self._lr_schedule_config)

    # Compute and apply updates via our optimizer.
    classif_updates, new_classif_opt_state = \
        self._optimizer(learning_rate).update(
            classifier_grads,
            experiment_state.classif_opt_state)

    new_classif_params = optax.apply_updates(experiment_state.classif_params,
                                             classif_updates)

    if self._freeze_backbone:
      del backbone_grads, new_backbone_state  # Unused
      # The backbone is not updated.
      new_backbone_params = experiment_state.backbone_params
      new_backbone_opt_state = None
      new_backbone_state = experiment_state.backbone_state
    else:
      backbone_grads = jax.lax.psum(backbone_grads, axis_name='i')

      # Compute and apply updates via our optimizer.
      backbone_updates, new_backbone_opt_state = \
          self._optimizer(learning_rate).update(
              backbone_grads,
              experiment_state.backbone_opt_state)

      new_backbone_params = optax.apply_updates(
          experiment_state.backbone_params, backbone_updates)

    experiment_state = _EvalExperimentState(
        new_backbone_params,
        new_classif_params,
        new_backbone_state,
        new_backbone_opt_state,
        new_classif_opt_state,
    )

    # Scalars to log (note: we log the mean across all hosts/devices).
    scalars = {'train_loss': train_loss}
    scalars = jax.lax.pmean(scalars, axis_name='i')

    return experiment_state, scalars

  #                  _
  #   _____   ____ _| |
  #  / _ \ \ / / _` | |
  # |  __/\ V / (_| | |
  #  \___| \_/ \__,_|_|
  #

  def evaluate(self, global_step, **unused_args):
    """See base class."""

    global_step = np.array(get_first(global_step))
    scalars = jax.device_get(self._eval_epoch(**self._evaluation_config))

    logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
    return scalars

  def _eval_batch(
      self,
      backbone_params: hk.Params,
      classif_params: hk.Params,
      backbone_state: hk.State,
      inputs: Batch,
  ) -> LogsDict:
    """Evaluates a batch."""
    embeddings, backbone_state = self.forward_backbone.apply(
        backbone_params, backbone_state, inputs, is_training=False)
    logits = self.forward_classif.apply(classif_params, embeddings)
    labels = hk.one_hot(inputs['labels'], self._num_classes)
    loss = softmax_cross_entropy(logits, labels, reduction=None)
    top1_correct = topk_accuracy(logits, inputs['labels'], topk=1)
    top5_correct = topk_accuracy(logits, inputs['labels'], topk=5)
    # NOTE: Returned values will be summed and finally divided by num_samples.
    return {
        'eval_loss': loss,
        'top1_accuracy': top1_correct,
        'top5_accuracy': top5_correct
    }

  def _eval_epoch(self, subset: Text, batch_size: int):
    """Evaluates an epoch."""
    num_samples = 0.
    summed_scalars = None

    backbone_params = get_first(self._experiment_state.backbone_params)
    classif_params = get_first(self._experiment_state.classif_params)
    backbone_state = get_first(self._experiment_state.backbone_state)
    split = Split.from_string(subset)

    dataset_iterator = load(
        split,
        preprocess_mode=PreprocessMode.EVAL,
        transpose=self._should_transpose_images(),
        batch_dims=[batch_size])

    for inputs in dataset_iterator:
      num_samples += inputs['labels'].shape[0]
      scalars = self.eval_batch_jit(
          backbone_params,
          classif_params,
          backbone_state,
          inputs,
      )

      # Accumulate the sum of scalars for each step.
      scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
      if summed_scalars is None:
        summed_scalars = scalars
      else:
        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)

    mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
    return mean_scalars


In [33]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config file for BYOL experiment."""

#from byol.utils import dataset


# Preset values for certain number of training epochs.
_LR_PRESETS = {40: 0.45, 100: 0.45, 300: 0.3, 1000: 0.2}
_WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}


def get_config_byol(num_epochs: int, batch_size: int):
  """Return config object, containing all hyperparameters for training."""
  train_images_per_epoch = Split.TRAIN_AND_VALID.num_examples

  assert num_epochs in [40, 100, 300, 1000]

  config = dict(
      random_seed=0,
      num_classes=1000,
      batch_size=batch_size,
      max_steps=num_epochs * train_images_per_epoch // batch_size,
      enable_double_transpose=True,
      base_target_ema=_EMA_PRESETS[num_epochs],
      network_config=dict(
          projector_hidden_size=4096,
          projector_output_size=256,
          predictor_hidden_size=4096,
          encoder_class='ResNet18',  # Should match a class in utils/networks.
          encoder_config=dict(
              resnet_v2=False,
              width_multiplier=1),
          bn_config={
              'decay_rate': .9,
              'eps': 1e-5,
              # Accumulate batchnorm statistics across devices.
              # This should be equal to the `axis_name` argument passed
              # to jax.pmap.
              'cross_replica_axis': 'i',
              'create_scale': True,
              'create_offset': True,
          }),
      optimizer_config=dict(
          weight_decay=_WD_PRESETS[num_epochs],
          eta=1e-3,
          momentum=.9,
      ),
      lr_schedule_config=dict(
          base_learning_rate=_LR_PRESETS[num_epochs],
          warmup_steps=10 * train_images_per_epoch // batch_size,
      ),
      evaluation_config=dict(
          subset='test',
          batch_size=100,
      ),
      checkpointing_config=dict(
          use_checkpointing=True,
          checkpoint_dir='/tmp/byol',
          save_checkpoint_interval=300,
          filename='pretrain.pkl'
      ),
  )

  return config


In [34]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config file for evaluation experiment."""

from typing import Text

#from byol.utils import dataset


def get_config_eval(checkpoint_to_evaluate: Text, batch_size: int):
  """Return config object for training."""
  train_images_per_epoch = Split.TRAIN_AND_VALID.num_examples

  config = dict(
      random_seed=0,
      enable_double_transpose=True,
      max_steps=80 * train_images_per_epoch // batch_size,
      num_classes=1000,
      batch_size=batch_size,
      checkpoint_to_evaluate=checkpoint_to_evaluate,
      # If True, allows training without loading a checkpoint.
      allow_train_from_scratch=False,
      # Whether the backbone should be frozen (linear evaluation) or
      # trainable (fine-tuning).
      freeze_backbone=True,
      optimizer_config=dict(
          momentum=0.9,
          nesterov=True,
      ),
      lr_schedule_config=dict(
          base_learning_rate=0.2,
          warmup_steps=0,
      ),
      network_config=dict(  # Should match the evaluated checkpoint
          encoder_class='ResNet18',  # Should match a class in utils/networks.
          encoder_config=dict(
              resnet_v2=False,
              width_multiplier=1),
          bn_decay_rate=0.9,
      ),
      evaluation_config=dict(
          subset='test',
          batch_size=100,
      ),
      checkpointing_config=dict(
          use_checkpointing=True,
          checkpoint_dir='/tmp/byol',
          save_checkpoint_interval=300,
          filename='linear-eval.pkl'
      ),
  )

  return config




# Train and Eval Loop

In [35]:
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and evaluation loops for an experiment."""

import time
from typing import Any, Mapping, Text, Type, Union

from absl import app
from absl import flags
from absl import logging
import jax
import numpy as np
import argparse

parser = argparse.ArgumentParser(description='BYOL PARAMETERS')
#from byol import byol_experiment
#from byol import eval_experiment
#from byol.configs import byol as byol_config
#from byol.configs import eval as eval_config
"""
flags.DEFINE_string('experiment_mode',
                    'pretrain', 'The experiment, pretrain or linear-eval')
flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval')
flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use')
flags.DEFINE_integer('pretrain_epochs', 1000, 'Number of pre-training epochs')
flags.DEFINE_integer('batch_size', 4096, 'Total batch size')
flags.DEFINE_string('checkpoint_root', '/tmp/byol',
                    'The directory to save checkpoints to.')
flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.')

FLAGS = flags.FLAGS
"""
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--pretrain-epochs', default=1000, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=4096, type=int,
                    metavar='N',
                    help='mini-batch size (default: 4096), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--log-tensors-interval', default=60, type=int,
                    metavar='LTI', help='Log tensors every n seconds.')
parser.add_argument('--experiment-mode', default='pretrain', type=str,
                    help='The experiment, pretrain or linear-eval')
parser.add_argument('--worker-mode', default='train', type=str,
                    help='path to latest checkpoint (default: none)')
'The mode, train or eval'
parser.add_argument('--worker-tpu-driver', default='', type=str,
                    help='The tpu driver to use')
parser.add_argument('--checkpoint-root', default='/tmp/byol', type=str,
                    help='The directory to save checkpoints to.')

FLAGS, unknown = parser.parse_known_args()
"""
Experiment = Union[
    Type[byol_experiment.ByolExperiment],
    Type[eval_experiment.EvalExperiment]]
"""
Experiment = Union[
    Type[ByolExperiment],
    Type[EvalExperiment]]

def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
  """The main training loop.

  This loop periodically saves a checkpoint to be evaluated in the eval_loop.

  Args:
    experiment_class: the constructor for the experiment (either byol_experiment
    or eval_experiment).
    config: the experiment config.
  """
  experiment = experiment_class(**config)
  rng = jax.random.PRNGKey(0)
  step = 0

  host_id = jax.host_id()
  last_logging = time.time()
  if config['checkpointing_config']['use_checkpointing']:
    checkpoint_data = experiment.load_checkpoint()
    if checkpoint_data is None:
      step = 0
    else:
      step, rng = checkpoint_data

  local_device_count = jax.local_device_count()
  while step < config['max_steps']:
    step_rng, rng = tuple(jax.random.split(rng))
    # Broadcast the random seeds across the devices
    step_rng_device = jax.random.split(step_rng, num=jax.device_count())
    step_rng_device = step_rng_device[
        host_id * local_device_count:(host_id + 1) * local_device_count]
    step_device = np.broadcast_to(step, [local_device_count])

    # Perform a training step and get scalars to log.
    scalars = experiment.step(global_step=step_device, rng=step_rng_device)

    # Checkpointing and logging.
    if config['checkpointing_config']['use_checkpointing']:
      experiment.save_checkpoint(step, rng)
      current_time = time.time()
      if current_time - last_logging > FLAGS.log_tensors_interval:
        logging.info('Step %d: %s', step, scalars)
        last_logging = current_time
    step += 1
  logging.info('Saving final checkpoint')
  logging.info('Step %d: %s', step, scalars)
  experiment.save_checkpoint(step, rng)


def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
  """The main evaluation loop.

  This loop periodically loads a checkpoint and evaluates its performance on the
  test set, by calling experiment.evaluate.

  Args:
    experiment_class: the constructor for the experiment (either byol_experiment
    or eval_experiment).
    config: the experiment config.
  """
  experiment = experiment_class(**config)
  last_evaluated_step = -1
  while True:
    checkpoint_data = experiment.load_checkpoint()
    if checkpoint_data is None:
      logging.info('No checkpoint found. Waiting for 10s.')
      time.sleep(10)
      continue
    step, _ = checkpoint_data
    if step <= last_evaluated_step:
      logging.info('Checkpoint at step %d already evaluated, waiting.', step)
      time.sleep(10)
      continue
    host_id = jax.host_id()
    local_device_count = jax.local_device_count()
    step_device = np.broadcast_to(step, [local_device_count])
    scalars = experiment.evaluate(global_step=step_device)
    if host_id == 0:  # Only perform logging in one host.
      logging.info('Evaluation at step %d: %s', step, scalars)
    last_evaluated_step = step
    if last_evaluated_step >= config['max_steps']:
      return



In [38]:
if FLAGS.worker_tpu_driver:
  jax.config.update('jax_xla_backend', 'tpu_driver')
  jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
  logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())

if FLAGS.experiment_mode == 'pretrain':
  experiment_class = ByolExperiment
  config = get_config_byol(FLAGS.pretrain_epochs, FLAGS.batch_size)
elif FLAGS.experiment_mode == 'linear-eval':
  experiment_class = EvalExperiment
  config = eval_config.get_config_eval(f'{FLAGS.checkpoint_root}/pretrain.pkl',
                                  FLAGS.batch_size)
else:
  raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}')
config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root  # pytype: disable=unsupported-operands  # dict-kwargs


In [None]:
if FLAGS.worker_mode == 'train':
  train_loop(experiment_class, config)
elif FLAGS.worker_mode == 'eval':
  eval_loop(experiment_class, config)