Note: The evals here were run on GPU, so they may not exactly match the results reported in the paper, which were run on TPUs; however, the difference in accuracy should not be more than 0.1%.

# Setup

In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [5]:
# Fraction of the image kept along the less-cropped side by the eval center
# crop. Standard for ImageNet (a 224 crop out of 256 is 224 / 256 = 0.875).
CROP_PROPORTION = 0.875
# Spatial size of the images fed to every model below.
HEIGHT, WIDTH = 224, 224

def _compute_crop_shape(
    image_height, image_width, aspect_ratio, crop_proportion):
  """Computes an aspect-ratio-preserving crop size for a central crop.

  Whichever image side limits the requested aspect ratio keeps exactly
  `crop_proportion` of its length; the other side keeps a proportion no larger
  than `crop_proportion`.

  Args:
    image_height: Height of the image to be cropped.
    image_width: Width of the image to be cropped.
    aspect_ratio: Desired aspect ratio (width / height) of the crop.
    crop_proportion: Proportion of the image to retain along the less-cropped
      side.

  Returns:
    A `(crop_height, crop_width)` pair of int32 scalar tensors.
  """
  width_f = tf.cast(image_width, tf.float32)
  height_f = tf.cast(image_height, tf.float32)

  def _round_to_int(value):
    # Round to the nearest integer (ties to even) before casting to int32.
    return tf.cast(tf.math.rint(value), tf.int32)

  def _limited_by_width():
    # Requested ratio is wider than the image: width is the binding side.
    return (_round_to_int(crop_proportion / aspect_ratio * width_f),
            _round_to_int(crop_proportion * width_f))

  def _limited_by_height():
    # Image is at least as wide as requested: height is the binding side.
    return (_round_to_int(crop_proportion * height_f),
            _round_to_int(crop_proportion * aspect_ratio * height_f))

  requested_is_wider = aspect_ratio > width_f / height_f
  return tf.cond(requested_is_wider, _limited_by_width, _limited_by_height)


def center_crop(image, height, width, crop_proportion):
  """Crops to center of image and rescales to desired size.

  The crop has the aspect ratio of the requested output (`width / height`) and
  keeps `crop_proportion` of the image along the less-cropped side. It is then
  resized with bicubic interpolation to `height` x `width`.

  Args:
    image: Image Tensor to crop, indexed as [height, width, channels].
    height: Height of the output image.
    width: Width of the output image.
    crop_proportion: Proportion of image to retain along the less-cropped side.

  Returns:
    A `height` x `width` x channels Tensor holding a central crop of `image`.
  """
  shape = tf.shape(image)
  image_height = shape[0]
  image_width = shape[1]
  # `_compute_crop_shape` takes the aspect ratio as width / height. Passing
  # height / width would crop non-square outputs at the inverted ratio and
  # the resize below would then distort them (square outputs are unaffected).
  crop_height, crop_width = _compute_crop_shape(
      image_height, image_width, width / height, crop_proportion)
  # Center the crop; any odd leftover pixel goes to the leading edge offset.
  offset_height = ((image_height - crop_height) + 1) // 2
  offset_width = ((image_width - crop_width) + 1) // 2
  image = tf.image.crop_to_bounding_box(
      image, offset_height, offset_width, crop_height, crop_width)
  image = tf.image.resize(image, [height, width],
                          method=tf.image.ResizeMethod.BICUBIC)
  return image

def preprocess_for_eval(image, height, width):
  """Applies the evaluation preprocessing to a single image.

  The image is center-cropped with `CROP_PROPORTION` and resized to
  `height` x `width`, given a static 3-channel shape, and clamped to [0, 1].

  Args:
    image: `Tensor` holding an image of arbitrary size.
    height: Height of the output image.
    width: Width of the output image.

  Returns:
    A preprocessed `height` x `width` x 3 image `Tensor`.
  """
  cropped = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
  # Fix the static shape so downstream batching knows the image dimensions.
  shaped = tf.reshape(cropped, [height, width, 3])
  # Bicubic resizing can overshoot the input range; clamp back to [0, 1].
  return tf.clip_by_value(shaped, 0., 1.)

def preprocess_image(features):
  """Preprocesses the "image" entry of a dataset example for evaluation.

  Args:
    features: Dict of example features whose "image" entry is an image
      `Tensor` of arbitrary size. Other entries (e.g. "label") are passed
      through unchanged.

  Returns:
    A new dict equal to `features`, except that "image" is replaced by a
    `HEIGHT` x `WIDTH` x 3 float32 `Tensor` in the range [0, 1].
  """
  # convert_image_dtype rescales integer images into [0, 1] floats.
  image = tf.image.convert_image_dtype(features["image"], dtype=tf.float32)
  image = preprocess_for_eval(image, HEIGHT, WIDTH)
  # Return a new dict rather than mutating the caller's mapping.
  return {**features, "image": image}

Load dataset

In [6]:
BATCH_SIZE = 50
# ImageNet validation split, preprocessed per example, then batched; one batch
# is prefetched to overlap input processing with model execution.
ds = (
    tfds.load(name='imagenet2012', split='validation')
    .map(preprocess_image)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

In [7]:
def eval(model_path, log=False, dataset=None):
  """Computes top-1 accuracy of a SavedModel on a batched evaluation dataset.

  Note: this name shadows Python's builtin `eval` within the notebook.

  Args:
    model_path: Path (local or gs://) of a SavedModel directory, loaded with
      `tf.saved_model.load`.
    log: If True, print loading status and progress every 50 batches.
    dataset: Optional batched `tf.data.Dataset` yielding dicts with "image"
      and "label" entries. Defaults to the module-level validation set `ds`.

  Returns:
    Top-1 accuracy as a float in [0, 1].
  """
  if dataset is None:
    # Resolved at call time so the default follows the notebook's `ds`.
    dataset = ds
  if log:
    print("Loading model from %s" % model_path)
  model = tf.saved_model.load(model_path)
  if log:
    print("Loaded model!")
  top_1_accuracy = tf.keras.metrics.Accuracy('top_1_accuracy')
  num_examples = 0
  for i, features in enumerate(dataset):
    labels = features["label"]
    # The loaded model returns a dict of outputs; "logits_sup" holds the
    # classification logits scored here.
    logits = model(features["image"], trainable=False)["logits_sup"]
    top_1_accuracy.update_state(labels, tf.argmax(logits, axis=-1))
    # Count the real batch size so the progress log is correct for any
    # dataset and batch size, not just BATCH_SIZE.
    num_examples += int(tf.shape(labels)[0])
    if log and (i + 1) % 50 == 0:
      print("Finished %d examples" % num_examples)
  return top_1_accuracy.result().numpy().astype(float)

# SimCLR v2

Finetuned models

In [8]:
path_pat = "gs://simclr-checkpoints-tf2/simclrv2/finetuned_{pct}pct/r{depth}_{width_multiplier}x_sk{sk}/saved_model/"
results = {}

# Every (depth, width multiplier, SK) combination for 1x/2x widths, followed by
# the ResNet-152 3x SK model. Listing the configs once avoids duplicating the
# eval/print code for the extra model.
configs = [(depth, wm, s)
           for depth in (50, 101, 152) for wm in (1, 2) for s in (0, 1)]
configs.append((152, 3, 1))

for resnet_depth, width_multiplier, sk in configs:
  for pct in (1, 10, 100):
    path = path_pat.format(pct=pct, depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)
    results[path] = eval(path)
    print(path)
    print("Top-1: %.1f" % (results[path] * 100))

gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_1x_sk0/saved_model/
Top-1: 58.0
gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_1x_sk0/saved_model/
Top-1: 68.4
gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk0/saved_model/
Top-1: 76.3
gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_1x_sk1/saved_model/
Top-1: 64.5
gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_1x_sk1/saved_model/
Top-1: 72.0
gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk1/saved_model/
Top-1: 78.6
gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_2x_sk0/saved_model/
Top-1: 66.2
gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_2x_sk0/saved_model/
Top-1: 73.9
gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_2x_sk0/saved_model/
Top-1: 79.1
gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_2x_sk1/saved_model/
Top-1: 70.7
gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_2x_sk1/saved_model/
Top-1: 77.0
gs://simclr-checkpoint

Supervised

In [10]:
path_pat = "gs://simclr-checkpoints-tf2/simclrv2/supervised/r{depth}_{width_multiplier}x_sk{sk}/saved_model/"
results = {}

# Every (depth, width multiplier, SK) combination for 1x/2x widths, followed by
# the ResNet-152 3x SK model. Listing the configs once avoids duplicating the
# eval/print code for the extra model.
configs = [(depth, wm, s)
           for depth in (50, 101, 152) for wm in (1, 2) for s in (0, 1)]
configs.append((152, 3, 1))

for resnet_depth, width_multiplier, sk in configs:
  path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)
  results[path] = eval(path)
  print(path)
  print("Top-1: %.1f" % (results[path] * 100))

gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_1x_sk0/saved_model/
Top-1: 76.6
gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_1x_sk1/saved_model/
Top-1: 78.5
gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_2x_sk0/saved_model/
Top-1: 77.8
gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_2x_sk1/saved_model/
Top-1: 79.3
gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_1x_sk0/saved_model/
Top-1: 78.0
gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_1x_sk1/saved_model/
Top-1: 79.6
gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_2x_sk0/saved_model/
Top-1: 78.8
gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_2x_sk1/saved_model/
Top-1: 80.1
gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_1x_sk0/saved_model/
Top-1: 78.2
gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_1x_sk1/saved_model/
Top-1: 80.0
gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_2x_sk0/saved_model/
Top-1: 79.1
gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_2x_sk1/saved_mod

Pretrained with linear eval

In [9]:
path_pat = "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r{depth}_{width_multiplier}x_sk{sk}/saved_model/"
results = {}

# Every (depth, width multiplier, SK) combination for 1x/2x widths, followed by
# the ResNet-152 3x SK model. Listing the configs once avoids duplicating the
# eval/print code for the extra model.
configs = [(depth, wm, s)
           for depth in (50, 101, 152) for wm in (1, 2) for s in (0, 1)]
configs.append((152, 3, 1))

for resnet_depth, width_multiplier, sk in configs:
  path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)
  results[path] = eval(path)
  print(path)
  print("Top-1: %.1f" % (results[path] * 100))

gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/
Top-1: 71.7
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk1/saved_model/
Top-1: 74.6
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_2x_sk0/saved_model/
Top-1: 75.4
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_2x_sk1/saved_model/
Top-1: 77.8
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_1x_sk0/saved_model/
Top-1: 73.7
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_1x_sk1/saved_model/
Top-1: 76.3
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_2x_sk0/saved_model/
Top-1: 77.0
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_2x_sk1/saved_model/
Top-1: 79.1
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_1x_sk0/saved_model/
Top-1: 74.6
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_1x_sk1/saved_model/
Top-1: 77.3
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_2x_sk0/saved_model/
Top-1: 77.4
gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_2x_sk1/saved_mod

# SimCLR v1

Finetuned

In [13]:
path_pat = "gs://simclr-checkpoints-tf2/simclrv1/finetune_{pct}pct/{width_multiplier}x/saved_model/"
results = {}

# Depth is fixed at 50 for these v1 checkpoints; it is not part of the path.
resnet_depth = 50
# Label fractions (outer) by width multipliers (inner), in evaluation order.
runs = [(p, w) for p in (10, 100) for w in (1, 2, 4)]
for pct, width_multiplier in runs:
  path = path_pat.format(pct=pct, width_multiplier=width_multiplier)
  top_1 = eval(path)
  results[path] = top_1
  print(path)
  print("Top-1: %.1f" % (top_1 * 100))

gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/1x/saved_model/
Top-1: 65.8
gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/2x/saved_model/
Top-1: 71.6
gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/4x/saved_model/
Top-1: 74.5
gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/1x/saved_model/
Top-1: 75.6
gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/2x/saved_model/
Top-1: 79.2
gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/4x/saved_model/
Top-1: 80.8


Pretrained with linear eval

In [12]:
path_pat = "gs://simclr-checkpoints-tf2/simclrv1/pretrain/{width_multiplier}x/saved_model/"
results = {}

# Depth is fixed at 50 for these v1 checkpoints; it is not part of the path.
resnet_depth = 50
for width_multiplier in (1, 2, 4):
  path = path_pat.format(width_multiplier=width_multiplier)
  top_1 = eval(path)
  results[path] = top_1
  print(path)
  print("Top-1: %.1f" % (top_1 * 100))

gs://simclr-checkpoints-tf2/simclrv1/pretrain/1x/saved_model/
Top-1: 69.0
gs://simclr-checkpoints-tf2/simclrv1/pretrain/2x/saved_model/
Top-1: 74.2
gs://simclr-checkpoints-tf2/simclrv1/pretrain/4x/saved_model/
Top-1: 76.6
