##### Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<a href="https://colab.research.google.com/github/eemlcommunity/PracticalSessions2021/blob/main/vision/vision_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ViT Tutorial

This is the Colab for the "ViT Tutorial" at [EEML 2021](https://virtual.eeml.eu/).

See also slides at [ViT Tutorial (EEML21)](https://github.com/eemlcommunity/PracticalSessions2021/blob/main/vision/vision_transformers.pdf).

Note: This Colab assumes that you have already walked through the EEML introduction Colab at
https://colab.sandbox.google.com/github/eemlcommunity/PracticalSessions2021/blob/main/intro/intro_tutorial.ipynb

**Exercises**:
Whenever you see a `# YOUR ACTION REQUIRED`, please spend a couple of minutes
trying to solve the task at hand. You will get the most out of this tutorial if
you play around with the provided code and see how your changes modify the
behavior (i.e. in addition to the suggested `# YOUR ACTION REQUIRED` tasks).

**Solutions**:
This Colab also contains solutions. By default the solutions are
hidden and you only see the text "solution". In order to see the code for the
solutions, you need to double click on those cells. At any point during the
tutorial, you can also ask questions (and answer them!) in the EEML Slack
channel `#vit`.

Before you start, you probably want to **make a copy** of this Colab so your
changes are not lost:

![save a copy](https://i.imgur.com/Ws5KfqV.png)

### Setup

In [None]:
# Make sure to use a runtime with GPU!
!nvidia-smi -L

In [None]:
!pip install -q flax

In [None]:
from typing import Optional

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import optax
import pandas as pd
from tqdm import notebook as tqdm

### 0 - Dataset

In [None]:
# Let's say we want to predict garment type from the Fashion-MNIST dataset.

import tensorflow_datasets as tfds
import tensorflow as tf

ds, info = tfds.load('fashion_mnist', with_info=True)

# Show some images, for exploring the dataset check out KYD:
# https://knowyourdata-tfds.withgoogle.com/#tab=STATS&dataset=fashion_mnist
# (Not too interesting on this toy dataset, but really useful for larger image
# datasets...)
tfds.show_examples(ds['train'], info, rows=1, cols=5)

# Note we have same splits & num_classes as in classical MNIST:
info

In [None]:
# For convenience.
num_classes = info.features['label'].num_classes
image_shape = info.features['image'].shape
num_test = info.splits['test'].num_examples
num_train = info.splits['train'].num_examples

def pp(iter):
  """Preprocesses images/labels for use with JAX."""
  for batch in iter:
    yield (
        jnp.array(batch['image']) / 255.,
        jax.nn.one_hot(batch['label'], num_classes),
    )

# Loading entire dataset makes demonstration code simpler, but only works with
# small models/datasets...
train_images, train_labels = next(pp(iter(ds['train'].batch(num_train))))
test_images, test_labels = next(pp(iter(ds['test'].batch(num_test))))

### 1 - Flax Models

This section's content is partly new, partly a repetition of what was
already covered in
[intro-tutorial: Flax -- alternative library on top of JAX](https://colab.sandbox.google.com/github/eemlcommunity/PracticalSessions2021/blob/main/intro/intro_tutorial.ipynb#scrollTo=ifFR1Iq9YChf)

Since you're expected to be new to JAX/Flax, this section explains a couple of
core concepts. Of course, this is only a very shallow introduction! If you're
interested in better understanding what's going on please check out the
[JAX 101 Colabs](https://jax.readthedocs.io/en/latest/jax-101/)
and the
[Flax docs](https://flax.readthedocs.io/),
maybe starting with the
[Annotated MNIST example](https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html)
-- after the tutorial.

#### 1.1 - Model

In [None]:
# Defining a model is quite simple:

class Model(nn.Module):

  def setup(self):
    self.dense = nn.Dense(features=10)

  def __call__(self, x):
    batch_size = x.shape[0]
    # Flatten every image, only keeping batch dimension.
    x = x.reshape([batch_size, -1])
    x = self.dense(x)
    # We return normalized logits ("log_probs").
    return nn.log_softmax(x)

# Note how similar the model definition looks to what you would do in PyTorch.
# But be aware that the models behave actually quite differently (more on that
# below).

In [None]:
# 1. Initialize the model - this only sets immutable parameters (none in our
# simple example).
# Note in particular that the `model` instance does NOT include any weights.
model = Model()
model

In [None]:
# 2. Initialize model weights - we need fake input for shape inference and
# PRNGKey for initializing the weights in a deterministic way.
rng = jax.random.PRNGKey(0)
variables = model.init(rng, train_images[:1])

# Show shape of every parameter.
# Note that `nn.Dense()` parameters are stored in the parameter collection
# "dense" because the field is called `self.dense`
jax.tree_util.tree_map(jnp.shape, variables)

In [None]:
# YOUR ACTION REQUIRED:
# Use `jax.tree_util.tree_flatten()` to count the total number of parameters of
# the model.

In [None]:
#@markdown solution - double click to expand
sum(jax.tree_util.tree_flatten(
    # Every leaf is a `jnp.ndarray` - first compute the number of parameters in
    # every ndarray, then sum them up.
    jax.tree_util.tree_map(lambda p: np.prod(p.shape), variables)
# jax.tree_util.tree_flatten returns both the values and the "treedef" that can
# be used to reconstruct the tree - we only need the values here.
)[0])
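
As a sanity check on the expected count, here is a minimal sketch on a hand-built tree. The `toy_variables` dict is a hypothetical stand-in with the same shapes as the `Dense` layer above (28 * 28 = 784 flattened pixels, 10 classes), and the sketch assumes a JAX version that provides `jax.tree_util.tree_leaves`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree mirroring the shapes of the Dense layer above.
toy_variables = {
    'params': {
        'dense': {
            'kernel': jnp.zeros((784, 10)),
            'bias': jnp.zeros((10,)),
        }
    }
}

# `tree_leaves` returns only the leaf arrays (no treedef).
toy_leaves = jax.tree_util.tree_leaves(toy_variables)
toy_num_params = sum(leaf.size for leaf in toy_leaves)
print(toy_num_params)  # 784 * 10 + 10 = 7850
```

The same one-liner should work on the real `variables` from the cell above.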

In [None]:
# As you can see, our simple model only has a single variable collection called
# "params". So let's make this explicit:
params = variables['params']

# In an advanced use case we might have different variable collections (e.g.
# batch norm), and then we would need to treat them differently. In particular,
# we would not want our optimizer to change anything other than "params". So
# it's better for readability to split out "params" even in this simple case
# where we could have treated all `variables` as params.

In [None]:
# 3. Forward pass.
log_probs, = model.apply({'params': params}, train_images[:1])
# Note that the model returns log probabilities.
plt.bar(range(10), jnp.exp(log_probs))
jnp.exp(log_probs).sum()

In [None]:
# YOUR ACTION REQUIRED:
# Rewrite the model to use `nn.compact`.
# Make sure that the model output is identical.
# Make sure that the model weights are identical.

In [None]:
#@markdown solution
# Defining the same model, this time in "compact" style.

class Model2(nn.Module):

  @nn.compact
  def __call__(self, x):
    batch_size = x.shape[0]
    x = x.reshape([batch_size, -1])
    # If we didn't specify `name` we would get "Dense_0" by default.
    # But we want identical names to compare weights below.
    x = nn.Dense(features=10, name='dense')(x)
    # We return normalized logits ("log_probs").
    return nn.log_softmax(x)

# It's exactly the same, both variables & outputs:
model2 = Model2()
variables2 = model2.init(rng, train_images[:1])
log_probs2 = model2.apply(variables2, train_images[:1])

# Verify weights & output are identical.
# See what happens when you change the `rng` variable (by calling
# `jax.random.split()` to derive a new key and use that one).
sum(jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(lambda x, x2: jnp.abs(x - x2).sum(), variables, variables2)
)[0]), sum(jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(lambda x, x2: jnp.abs(x - x2).sum(), log_probs, log_probs2)
)[0])

#### 1.2 - Train + evaluate

In [None]:
def evaluate(params):
  log_probs = model.apply({'params': params}, test_images)
  # Computes accuracy over entire test set in a single go.
  return (log_probs.argmax(axis=-1) == test_labels.argmax(axis=-1)).mean()

# Not surprisingly we perform rather badly without training:
evaluate(params)

In [None]:
def loss(params, images, labels):
  log_probs = model.apply({'params': params}, images)
  return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))

# Loss is a single scalar value.
loss(params, train_images, train_labels)
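
To get a feel for the scale of this loss, the sketch below recomputes the same formula with plain NumPy on made-up data (the `toy_*` names are illustrative and not part of the notebook). A model that predicts a uniform distribution over 10 classes lands at log(10), about 2.30, which is a useful reference point when looking at the loss of an untrained model.

```python
import numpy as np

num_toy_classes = 10
# Three one-hot labels, shape [3, 10].
toy_labels = np.eye(num_toy_classes)[[3, 7, 1]]
# Log probabilities of a uniform prediction: log(1 / 10) everywhere.
toy_log_probs = np.full((3, num_toy_classes), -np.log(num_toy_classes))

# Same cross-entropy formula as in `loss()` above.
toy_loss = -np.mean(np.sum(toy_labels * toy_log_probs, axis=-1))
print(toy_loss, np.log(num_toy_classes))
```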

In [None]:
# Note how we compose the function transformations. `jax.jit()` doesn't change
# the output, but it will compile the function when it's used for the first time
# (and every time it's re-used with different shapes) and make it a lot faster.
%time grads_fn = jax.jit(jax.grad(loss))

# Beware that we don't specify `model` as an input parameter, so if we later
# change the `model` variable, we need to make sure that the function is
# recompiled (e.g. by re-executing the cell above), otherwise the compiled
# version will still be using the old definition of `model`.
# We will be fixing this issue further down under "2.3 - Training reloaded".

In [None]:
# Compute the gradients. The function will be compiled the first time this is
# executed. You can try re-executing this and the last cell alone or in tandem.
# Note how the initial compilation is orders of magnitude slower than simply
# executing the compiled function.
%time grads = grads_fn(params, train_images, train_labels)

# Gradients have same shape as parameters.
jax.tree_util.tree_map(jnp.shape, grads)

In [None]:
%%time

def train(params, *, epochs):
  accuracies = [evaluate(params)]
  for epoch in tqdm.trange(epochs):
    grads = grads_fn(params, train_images, train_labels)
    # Manually implement gradient descent.
    params = jax.tree_util.tree_map(
        lambda param, grad: param - 0.01 * grad,
        params, grads
    )
    accuracies.append(evaluate(params))
  plt.plot(accuracies)
  return params

# This is very fast (~2 secs). Note that the final accuracy is still pretty low
# because we're using a very simple optimization and because we use the entire
# training dataset in every update step (instead of batching).
# We'll write a more verbose training loop further down.
trained_params = train(params, epochs=100)

#### 1.3 - Inference

In [None]:
# YOUR ACTION REQUIRED:
# Use `trained_params` to infer labels from images and visualize them.

In [None]:
#@markdown solution
i0, n = 999, 7
for image, logit, axs in zip(
    test_images[i0: i0+n],
    model.apply({'params': trained_params}, test_images[i0: i0+n]),
    zip(*plt.subplots(2, n, figsize=(3*n, 5))[1]),
):
  axs[0].imshow(image[:, :, 0])
  plt.xticks(rotation=90)
  axs[1].bar(list(map(info.features['label'].int2str, range(num_classes))), jnp.exp(logit))
  axs[1].tick_params(labelrotation=90)

### 2 - Vision Transformer in Flax

In this section we will construct our own Vision Transformer. Note that our
version is a bit simplified for readability (e.g. not including dropout).

And of course we know that the Vision Transformer will fail when trained from
scratch on a tiny dataset because it is lacking inductive bias. So the results
we will get by training on `fashion_mnist` will be pretty bad, but the code in
this section still illustrates the basic functioning.

For actually using a Vision Transformer in practice, we would in most cases
start with a pre-trained checkpoint -- and that's exactly what we will be doing
in section
[3 - Exploring pre-trained VITs](#scrollTo=d5IqJjkjg87a)
below...

![Vision Transformer](https://github.com/google-research/vision_transformer/raw/master/vit_figure.png)

Figure 1 from paper
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)

#### 2.1 - Patch embedding

In [None]:
class PatchEmbedding(nn.Module):

  # Note that the "patch size" is a model parameter. It determines the leading
  # dimension of the linear projection of the patches.
  patch_size: int
  hidden_dim: int
  fake_init: Optional[bool] = False  # Useful to visualize the reshaping.

  def get_init_kw(self):
    return dict(
        kernel_init=nn.initializers.ones, bias_init=nn.initializers.zeros
    ) if self.fake_init else {}

  @nn.compact
  def __call__(self, x):
    n, h, w, c = x.shape
    assert w % self.patch_size == h % self.patch_size == 0
    # 1. Compute the grid - that would be 3x3 in above "Figure 1".
    gw, gh = w // self.patch_size, h // self.patch_size
    # 2. Reshape width and height into grid x patch_size.
    x = x.reshape([n, gh, self.patch_size, gw, self.patch_size, c])
    # 3. Transpose so we have [batch, gh, gw, patch_size, patch_size, channels].
    x = x.transpose([0, 1, 3, 2, 4, 5])
    # 4. Reshape so we're left with a batch of patches.
    x = x.reshape([n, gh * gw, self.patch_size * self.patch_size * c])
    # 5. Project to `hidden_dim`.
    x = nn.Dense(features=self.hidden_dim, **self.get_init_kw())(x)
    # 6. Add learnable position embeddings.
    posembed = self.param('posembed', nn.initializers.zeros, (1,) + x.shape[1:])
    return x + posembed


patch_embedding = PatchEmbedding(patch_size=4, hidden_dim=3, fake_init=True)
params = patch_embedding.init(rng, jnp.ones([1, 8, 8, 1]))['params']

# Note on weight matrix of linear projection: (16, 3) because we have a single
# channel and patch_size=4, so a patch has 4*4*1 scalar values as input. Every
# patch will be projected into a single vector of size 3.
# Note on the position embeddings: (1, 4, 3) because we have a grid with two
# rows and two columns. A single position embedding is a vector of `hidden_dim`.
jax.tree_util.tree_map(jnp.shape, params)

In [None]:
# Create a fake image with 4 patches of constant value.
checkered = jnp.repeat(jnp.repeat(jnp.array([[1., 2], [3, 4]]), 4, axis=0), 4, axis=1)
checkered

In [None]:
# Thanks to `fake_init=True` every value in hidden_dim will be the summation
# of the individual 16 pixel values above.
# Note that embedded is now a sequence (of "tokens") with length grid width *
# grid height.
embedded = patch_embedding.apply({'params': params}, checkered[None, ..., None])
embedded

In [None]:
# YOUR ACTION REQUIRED:
# Rewrite `PatchEmbedding` to use a single convolution instead of the reshape +
# linear projection. Inspect the weights and make sure you get the same values
# when applying the module.
# You'll want to use the `nn.Conv()` module:
# https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Conv

In [None]:
#@markdown solution

class PatchEmbedding(nn.Module):

  patch_size: int
  hidden_dim: int
  fake_init: Optional[bool] = False  # Useful to visualize the reshaping.

  def get_init_kw(self):
    return dict(
        kernel_init=nn.initializers.ones, bias_init=nn.initializers.zeros
    ) if self.fake_init else {}

  @nn.compact
  def __call__(self, x):
    n, h, w, c = x.shape
    x = nn.Conv(
        features=self.hidden_dim,
        # Using kernel_size=strides=patch_size will apply the same linear
        # projection for every patch.
        kernel_size=(self.patch_size, self.patch_size),
        strides=(self.patch_size, self.patch_size),
        **self.get_init_kw(),
    )(x)
    x = x.reshape([n, -1, self.hidden_dim])
    posembed = self.param('posembed', nn.initializers.zeros, (1,) + x.shape[1:])
    return x + posembed


patch_embedding = PatchEmbedding(patch_size=4, hidden_dim=3, fake_init=True)
params = patch_embedding.init(rng, jnp.ones([1, 8, 8, 1]))['params']
print(jax.tree_util.tree_map(jnp.shape, params))
patch_embedding.apply({'params': params}, checkered[None, ..., None])
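
The reshape and transpose steps of the first `PatchEmbedding` variant can be checked in isolation. The following is a minimal NumPy sketch (the `demo_*` names are illustrative); it rebuilds the checkered image and mimics `fake_init=True` by summing each flattened patch, which is what an all-ones kernel with zero bias computes.

```python
import numpy as np

demo_patch = 4
# Same checkered 8x8 image as above: four constant 4x4 patches.
demo_image = np.repeat(
    np.repeat(np.array([[1., 2.], [3., 4.]]), 4, axis=0), 4, axis=1)
demo_x = demo_image[None, ..., None]  # [batch, height, width, channels]

n, h, w, c = demo_x.shape
gh, gw = h // demo_patch, w // demo_patch
# Split height and width into (grid, patch) pairs...
demo_patches = demo_x.reshape([n, gh, demo_patch, gw, demo_patch, c])
# ...bring both grid axes to the front...
demo_patches = demo_patches.transpose([0, 1, 3, 2, 4, 5])
# ...and flatten every patch into a single vector.
demo_patches = demo_patches.reshape([n, gh * gw, demo_patch * demo_patch * c])

# An all-ones kernel with zero bias reduces to a plain sum per patch.
demo_sums = demo_patches.sum(axis=-1)
print(demo_patches.shape)  # (1, 4, 16)
print(demo_sums)           # [[16. 32. 48. 64.]]
```

Each of the four tokens sums 16 identical pixel values, so the patches come out in row-major grid order.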

#### 2.2 - Transformer

In [None]:
class TransformerLayer(nn.Module):

  mlp_dim: int
  num_heads: int

  @nn.compact
  def __call__(self, inputs):
    hidden_dim = inputs.shape[-1]

    x = nn.LayerNorm()(inputs)
    # This is the crucial operation. It allows every token to attend to all
    # other tokens (in the same example).
    # It's a very powerful and generic operation, unfortunately its runtime
    # complexity is also quadratic wrt sequence length (which in turn is
    # quadratic wrt both image_size and 1/patch_size).
    x = nn.MultiHeadDotProductAttention(self.num_heads, deterministic=True)(x, x)

    x = x + inputs  # Residual.

    # MLP with single hidden layer:
    y = nn.LayerNorm()(x)
    y = nn.Dense(self.mlp_dim)(y)
    y = nn.gelu(y)
    y = nn.Dense(hidden_dim)(y)

    return x + y  # Residual.


transformer_layer = TransformerLayer(mlp_dim=128, num_heads=4)

# One sequence, of length 16, 64 hidden dimensions.
seqs = jnp.ones([1, (28 // 7) ** 2, 64])
params = transformer_layer.init(rng, seqs)['params']
transformed_seqs = transformer_layer.apply({'params': params}, seqs)

jax.tree_util.tree_map(jnp.shape, params)
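
To make the quadratic cost mentioned in the comments above concrete, here is a simplified NumPy sketch of scaled dot-product self-attention. It is a single head without the learned query/key/value/output projections, so it is not what `nn.MultiHeadDotProductAttention` computes exactly; the function and variable names are illustrative.

```python
import numpy as np

def toy_attention(q, k, v):
  """Single-head scaled dot-product attention on [seq, dim] arrays."""
  scores = q @ k.T / np.sqrt(q.shape[-1])  # [seq, seq]
  scores = scores - scores.max(axis=-1, keepdims=True)  # Numerical stability.
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)  # Softmax over keys.
  return weights @ v, weights

np_rng = np.random.default_rng(0)
toy_tokens = np_rng.normal(size=(16, 64))  # 4x4 grid of tokens, 64 dimensions.
toy_out, toy_weights = toy_attention(toy_tokens, toy_tokens, toy_tokens)
print(toy_out.shape, toy_weights.shape)  # (16, 64) (16, 16)
```

The `[seq, seq]` weight matrix is the part that grows quadratically with the sequence length.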

In [None]:
class VisionTransformer(nn.Module):

  patch_size: int
  hidden_dim: int

  mlp_dim: int
  num_heads: int

  layers: int
  num_classes: int

  @nn.compact
  def __call__(self, x):

    # Embeds patches & adds position embedding.
    x = PatchEmbedding(
        patch_size=self.patch_size,
        hidden_dim=self.hidden_dim
    )(x)

    # Transformer encoder.
    for layer in range(self.layers):
      x = TransformerLayer(
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
      )(x)

    # Different classification heads are possible. Here we use the simplest
    # possible, by first taking the average across the sequence dimension, and
    # then adding a single linear projection on top of it.
    # This is different from the classification approach used in the original
    # ViT paper, where a special "cls" token is prepended to the sequence and a
    # classification MLP is applied on top of that special token at the end of
    # the model. It has been shown (see e.g. Figure 9 in the appendix of
    # https://arxiv.org/abs/2010.11929v2) that simple average pooling can
    # perform as well, if the learning rate is tuned appropriately.
    x = x.mean(axis=-2)
    x = nn.Dense(self.num_classes)(x)
    return nn.log_softmax(x)


# Given that we have image_size=28, a patch_size=7 gives us a 4x4 grid.
model = VisionTransformer(patch_size=7, hidden_dim=64, mlp_dim=128, num_heads=4, layers=4, num_classes=10)
params = model.init(rng, train_images[:1])['params']

In [None]:
# Let's examine the parameters in some more details:
df = pd.DataFrame([
    dict(path1=k[0], path2=k[1], path3='/'.join(k[2:]), shape=v.shape, params=np.prod(v.shape))
    for k, v in flax.traverse_util.flatten_dict(flax.core.unfreeze(params)).items()
])
print(f'total {df.params.sum():,}')
df.set_index(['path1', 'path2', 'path3']).sort_index().T.style.set_table_styles([
    # Improve display for multi-level column headers:
    {'selector': 'th', 'props': [('background', '#eee')]}
])

In [None]:
# YOUR ACTION REQUIRED:
# What happens to the parameters when you change to patch_size=4?

In [None]:
#@markdown solution
model2 = VisionTransformer(patch_size=4, hidden_dim=64, mlp_dim=128, num_heads=4, layers=4, num_classes=10)
params2 = model2.init(rng, train_images[:1])['params']

df2 = pd.DataFrame([
    dict(path1=k[0], path2=k[1], path3='/'.join(k[2:]), shape=v.shape, params=np.prod(v.shape))
    for k, v in flax.traverse_util.flatten_dict(flax.core.unfreeze(params2)).items()
])
# Same overall number of parameters...
print(f'total {df2.params.sum():,}')
# ...because the change in parameters of the position embedding and the
# embedding projection matrix exactly cancel out. That's because 28==7*4. Note
# that changing to e.g. patch_size=14 does indeed show a (small) change in
# number of parameters.
df.merge(df2, on=['path1', 'path2', 'path3']).query('params_x!=params_y')

Let me reiterate this: Number of parameters does not actually matter that much.
It's the total compute that matters for Vision Transformers.

You can see this for example in Figure 8 from paper
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929):

![Figure 8](https://i.imgur.com/NZIkCQk.png)

We get almost the same increase in performance when changing `patch_size` as we
get by scaling other hyperparameters. But as we show above the number of trainable
parameters hardly changes!
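
The arithmetic behind this can be spelled out in a few lines. The sketch below is pure Python with illustrative helper names; it counts only the patch-embedding parameters (projection kernel, bias, position embeddings) for the 28x28 single-channel images and `hidden_dim=64` used above, and uses the squared sequence length as a rough proxy for the size of one attention matrix.

```python
def embedding_params(image_size, patch_size, channels, hidden_dim):
  """Parameters of the patch projection plus the position embeddings."""
  grid = image_size // patch_size
  kernel = patch_size * patch_size * channels * hidden_dim
  bias = hidden_dim
  posembed = grid * grid * hidden_dim
  return kernel + bias + posembed

def sequence_length(image_size, patch_size):
  return (image_size // patch_size) ** 2

for demo_patch_size in (14, 7, 4):
  seq = sequence_length(28, demo_patch_size)
  print(
      f'patch_size={demo_patch_size:2d}',
      f'embedding params={embedding_params(28, demo_patch_size, 1, 64):,}',
      f'tokens={seq}',
      f'attention matrix entries={seq ** 2:,}',
  )
```

Going from `patch_size=7` to `patch_size=4` leaves the parameter count at 4,224 but grows the attention matrix from 256 to 2,401 entries.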

Another paper to bookmark to read after this tutorial:
[Rethinking Parameter Counting in Deep Models: Effective Dimensionality Revisited](https://arxiv.org/abs/2003.02139)

#### 2.3 - Training, reloaded

In [None]:
# Our simple training loop from above "1.2 - Train + evaluate" won't actually
# work for training a Vision Transformer. Check out the new improved training
# loop:

def train(model, params, *, ds, epochs, batch_size=128, lr=0.01):

  # Using a learning rate schedule with a linear warmup stabilizes training
  # a lot.
  steps = int(epochs * num_train / batch_size)
  warmup_steps = int(.1 * steps)
  lr_schedule = optax.join_schedules([
      optax.linear_schedule(0, lr, warmup_steps),
      optax.cosine_decay_schedule(lr, steps - warmup_steps),
  ], boundaries=[warmup_steps])

  tx = optax.chain(
      # Gradient clipping stabilizes training. Not really needed for these small
      # toy models, but useful trick for training larger Vision Transformers.
      optax.clip_by_global_norm(1.0),
      optax.adam(lr_schedule),
  )
  opt_state = tx.init(params)

  # We redefine `loss_fn()` to make sure that it always uses the correct model
  # (which is now explicitly provided as a function argument instead of using
  # a notebook global as before).
  def loss_fn(params, images, labels):
    # The model returns normalized logits ("log_probs").
    log_probs = model.apply({'params': params}, images)
    return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))

  # We want to JIT compile as large a part of our program as possible. Usually
  # the `train_step()` is perfectly suited for this, as it only involves pure
  # computation, and all data loading/logging is done outside of this function.
  @jax.jit
  def train_step(params, opt_state, images, labels):
    grads = jax.grad(loss_fn)(params, images, labels)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

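  # Like `loss_fn()`, the evaluation uses the `model` argument and not the
  # notebook global that `evaluate()` from section 1.2 relies on.
  @jax.jit
  def evaluate_(params):
    log_probs = model.apply({'params': params}, test_images)
    return (log_probs.argmax(axis=-1) == test_labels.argmax(axis=-1)).mean()

  accuracies = [evaluate_(params)]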
  # Most important change: Train on minibatches instead of using the entire
  # training dataset. This improves the performance a lot, while also allowing
  # for larger models.
  for step, (images, labels) in zip(
      tqdm.trange(steps),
      pp(ds.repeat().batch(batch_size).shuffle(1000)),
  ):
    params, opt_state = train_step(params, opt_state, images, labels)
    if (step + 1) % (steps // 10) == 0 or step + 1 == steps:
      accuracies.append(evaluate_(params))

  plt.plot(accuracies)
  print('final accuracy', accuracies[-1])
  return params

In [None]:
%%time
# So let's try a Vision Transformer on this toy dataset.
# As said in the introduction of this section, this is far too little data to
# get to any useful performance! Please don't train Vision Transformers from
# scratch with MNIST-scaled problems...
model = VisionTransformer(patch_size=7, hidden_dim=64, mlp_dim=128, num_heads=4, layers=4, num_classes=10)
params = model.init(rng, train_images[:1])['params']
trained_params = train(model, params, ds=ds['train'], epochs=1)

In [None]:
# YOUR ACTION REQUIRED:
# Modify hyperparameters and the training loop to get a feel of how the
# result changes.

In [None]:
#@markdown solution

# There's no solution for this exercise.
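
When experimenting with the hyperparameters, it can help to see the learning rate schedule from `train()` written out by hand. This is a pure-Python sketch, assuming the cosine decay ends at zero and that the joined schedule restarts its step count at the warmup boundary; `warmup_cosine` is an illustrative name, not an optax function.

```python
import math

def warmup_cosine(step, *, lr, steps, warmup_steps):
  """Linear warmup from 0 to `lr`, then cosine decay from `lr` to 0."""
  if step < warmup_steps:
    return lr * step / warmup_steps
  progress = min((step - warmup_steps) / (steps - warmup_steps), 1.0)
  return lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# One epoch of Fashion-MNIST with batch_size=128, as in the cell above.
demo_steps = int(1 * 60000 / 128)
demo_warmup = int(.1 * demo_steps)
for demo_step in (0, demo_warmup // 2, demo_warmup, demo_steps // 2, demo_steps):
  value = warmup_cosine(
      demo_step, lr=0.01, steps=demo_steps, warmup_steps=demo_warmup)
  print(demo_step, round(value, 5))
```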

### 3 - Exploring pre-trained ViTs

In this section we will explore pre-trained checkpoints from the official ViT
repository

https://github.com/google-research/vision_transformer/

Note that the repository has two Colabs for further study of Vision Transformers
(maybe after this tutorial?)

- [vit_jax.ipynb](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb) -
  Mimics the main training loop in Colab, with lots of annotations. Useful to
  understand single host data-parallelism using JAX.
- [vit_jax_augreg.ipynb](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) -
  Allows you to interactively explore the 50k ViT checkpoints and has code to
  fine-tune them on your own data.

#### 3.0 - Download repository

In [None]:
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!pip install -qr vision_transformer/vit_jax/requirements.txt
import sys
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

#### 3.1 - Exploring checkpoints

For a detailed description of the checkpoints, refer to the paper

[How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)

We'll use Pandas to explore the checkpoints. You might want to have their
[Cheatsheet](https://pandas.pydata.org/Pandas_Cheat_Sheet.pdf)
ready...

In [None]:
from vit_jax import checkpoint
df = checkpoint.get_augreg_df()
print(checkpoint.get_augreg_df.__doc__)

In [None]:
# Note that columns NOT starting with "adapt_" are about the pre-training.
# For example, we have 756 pretrained checkpoints.
len(df.filename.unique())

In [None]:
# We can select a single pre-trained checkpoint by its AugReg settings:
filename = df.query(
    'name=="B/32"'  # Choose a single model.
    ' and ds=="i21k"'  # Upstream dataset -- i21k is best.
    ' and aug=="light1"'  # Data augmentation.
    ' and wd==0.1 and sd==0.0'  # Model regularization.
).filename.unique()[0]
filename

In [None]:
# That's exactly the basename of the checkpoint inside the storage bucket
# gs://vit_models/augreg
!gsutil ls -lh gs://vit_models/augreg/{filename}.npz

Wondering how to select the "best" model? That really depends on what you want
to achieve. In general larger (=slower) models lead to better results:

![VTAB results different models](https://i.imgur.com/TlM0u9F.png)

Figure 3 from paper
[How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)

In [None]:
# Note by the way that above checkpoint is the setting yielding the best
# upstream validation accuracy:
df.query(
    # Only focus on a single model
    'name=="B/32"'
    ' and ds=="i21k"'
).sort_values('final_val').iloc[-1].filename

In [None]:
# Note that we still have lots of rows in the dataframe with this one upstream
# checkpoint because the dataframe also includes data for the fine-tunings:
len(df.query(f'filename=="{filename}"'))

In [None]:
# Now let's find the best downstream adaptation of this checkpoint on cifar100:
adapt_filename = df.query(
    f'filename=="{filename}" and adapt_ds=="cifar100"'
).sort_values('adapt_final_val').iloc[-1].adapt_filename
# Note that adaptation parameters are encoded in the filename after the "--"
adapt_filename

#### 3.2 - Loading checkpoints

In [None]:
# From now on we're going to use code from the repository that we downloaded
# above.
from vit_jax import checkpoint

In [None]:
# Let's open that file in an editor
from google.colab import files
files.view('vision_transformer/vit_jax/checkpoint.py')

In [None]:
# And let's auto-reload any changes you make in that editor.
# Note though that the changes WILL NOT be persisted because the local hard
# drive will be reset when the VM is shut down (after some idle time).
%load_ext autoreload
%autoreload 2

In [None]:
params = checkpoint.load(f'gs://vit_models/augreg/{adapt_filename}.npz')

In [None]:
# YOUR ACTION REQUIRED:
# Check out the library code in the right hand editor while the checkpoint is
# being loaded

In [None]:
# YOUR ACTION REQUIRED:
# How can you deduce what was the resolution used for fine-tuning this model by
# looking only at the parameters?
# Note that the parameter structure is slightly different. You can read the
# models.py file, or simply inspect `params[...].keys()` to find the telling
# shapes...

# files.view('vision_transformer/vit_jax/models.py')

In [None]:
#@markdown solution

patch_size, patch_size, channels, hidden_dims = (
    params['embedding']['kernel'].shape)
_, sequence_length, hidden_dims = (
    params['Transformer']['posembed_input']['pos_embedding'].shape)
grid_size = (sequence_length - 1) ** .5  # Subtract the "cls" token.
grid_size * patch_size
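
The same reasoning works without loading a checkpoint. The sketch below is pure Python with a hypothetical helper name; the sequence lengths are example values, chosen to match a 7x7 grid (plus "cls" token) and the 5x5 grid used for fine-tuning further down.

```python
def infer_resolution(sequence_length, patch_size, has_cls_token=True):
  """Recovers the square input resolution from the position embedding length."""
  num_patches = sequence_length - 1 if has_cls_token else sequence_length
  grid_size = round(num_patches ** 0.5)
  assert grid_size * grid_size == num_patches, 'expected a square grid'
  return grid_size * patch_size

print(infer_resolution(50, 32))  # 7x7 grid of 32px patches -> 224
print(infer_resolution(26, 32))  # 5x5 grid of 32px patches -> 160
```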

In [None]:
# Visualization of the similarity between the learned position embeddings.

def figure9(params):
  """Plots embedding pairwise cosine similarity, see Figure 9 in ViT paper."""
  patches = params['Transformer']['posembed_input']['pos_embedding'][0]
  patches = patches[1:]  # Remove cls token.
  width = height = int(len(patches) ** .5)
  _, axs = plt.subplots(height, width, figsize=(width, height))
  for patch1, ax in zip(patches, np.array(axs).flatten()):
    ax.matshow(np.array([
        patch1.dot(patch2) / np.linalg.norm(patch1) / np.linalg.norm(patch2)
        for patch2 in patches
    ]).reshape([height, width]))
    ax.set_xticks([]); ax.set_yticks([])

figure9(params)
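
The pairwise similarities plotted by `figure9()` can also be computed in one matrix product. This is a NumPy sketch on random stand-in embeddings (the real ones live in `params`); the function name and the toy shapes are assumptions made for illustration.

```python
import numpy as np

def cosine_similarity_grids(posemb):
  """Returns one [side, side] similarity map per position embedding."""
  normed = posemb / np.linalg.norm(posemb, axis=-1, keepdims=True)
  sim = normed @ normed.T  # [num_patches, num_patches]
  side = int(round(len(posemb) ** 0.5))
  return sim.reshape([len(posemb), side, side])

# Random stand-in for 7x7 position embeddings with 8 dimensions each.
toy_posemb = np.random.default_rng(0).normal(size=(49, 8))
toy_grids = cosine_similarity_grids(toy_posemb)
print(toy_grids.shape)  # (49, 7, 7)
```

Every map has a value of 1 at its own grid position, since each embedding is perfectly similar to itself.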

In [None]:
ds, info = tfds.load('cifar100', with_info=True)
num_classes = info.features['label'].num_classes
int2str = info.features['label'].int2str
d = next(iter(ds['train']))
image = d['image']
plt.matshow(image)
image.shape, int2str(d['label'])

In [None]:
from vit_jax.configs import models as models_config  # Model configurations.
from vit_jax import models  # Actual model code.
files.view('vision_transformer/vit_jax/configs/models.py')
files.view('vision_transformer/vit_jax/models.py')

In [None]:
# Instantiate the matching model from the repo.
config = models_config.AUGREG_CONFIGS[adapt_filename.split('-')[0]]
vit = models.VisionTransformer(**config, num_classes=num_classes)

In [None]:
# YOUR ACTION REQUIRED:
# Modify `image` according to input_pipeline.py and predict logits using `vit`.
files.view('vision_transformer/vit_jax/input_pipeline.py')

In [None]:
#@markdown solution
# Expected input resolution: 224 pixels.
# Input is expected to be normalized to the -1..1 range.
image_ = jnp.array(tf.image.resize(image, [224, 224]) / 127.5 - 1)
logits = vit.apply({'params': params}, image_[None], train=False)

plt.figure(figsize=(17, 3))
plt.bar(list(map(int2str, range(num_classes))), logits[0])
plt.xticks(rotation=90);

#### 3.3 - Fine-tuning checkpoints

In [None]:
# Instead of scaling the image as we did in "3.2 - Loading checkpoints", we can
# re-generate the position embeddings to match the new grid size. We provide
# the function `checkpoint.load_pretrained()` for doing this:
config = models_config.AUGREG_CONFIGS[filename.split('-')[0]]
vit = models.VisionTransformer(**config, num_classes=num_classes)

# We need "template" parameters that depend on model & image size.
init_params = vit.init(rng, np.ones([1, 5*32, 5*32, 3]), train=False)['params']

In [None]:
from absl import logging
logging.set_verbosity(logging.INFO)  # Show logging messages for rescaling etc.

params = checkpoint.load_pretrained(
    pretrained_path=f'gs://vit_models/augreg/{filename}.npz',
    init_params=init_params,
    model_config=config,
)

In [None]:
# The rescaled position embeddings still have the expected similarity structure.
figure9(params)

In [None]:
del params  # Only used for demonstrating scaling.

In [None]:
# For the fine-tuning let's do it a bit differently from above: Instead of using
# our hand-rolled `train()` we'll be using the function `train_and_evaluate()`
# provided in the repo:
files.view('vision_transformer/vit_jax/train.py')
from vit_jax import train as train_lib

In [None]:
# Let's load a config that also contains default parameters for training...
files.view('vision_transformer/vit_jax/configs/augreg.py')
from vit_jax.configs import augreg as augreg_config
config = augreg_config.get_config(filename)

In [None]:
# ...and then adapt some of those config parameters:

# Note: We're not using fashion_mnist here because it's really a toy dataset
# and its "images" are not quite representative.
# Check out the TFDS catalog for more image classification datasets:
# https://www.tensorflow.org/datasets/catalog/overview
config.dataset = 'oxford_flowers102'
config.pp.train = 'train'
config.pp.test = 'test'

# Some more parameters that you will often want to set manually.
# For example for VTAB we used steps={500, 2500} and lr={.001, .003, .01, .03}
config.base_lr = 0.01
config.shuffle_buffer = 1000
config.total_steps = 100
config.warmup_steps = 10
config.accum_steps = 4  # Might need to be adjusted depending on model.
config.pp['crop'] = 224

In [None]:
# Launch tensorboard before training - maybe click "reload" during training.
%load_ext tensorboard
%tensorboard --logdir=./workdirs

In [None]:
# Create a new temporary workdir.
import time
workdir = f'./workdirs/{int(time.time())}'
workdir

In [None]:
# This cell takes ~10 min to finish (using B/32 & flowers102).
opt = train_lib.train_and_evaluate(config, workdir)

In [None]:
# Note that the function returns a flax.optim Optimizer storing params as "target".
trained_params = opt.target

In [None]:
# Create model with same settings used for training above.
vit = models.VisionTransformer(
    **config.model, num_classes=len(trained_params['head']['bias']))

In [None]:
# Create input pipeline with same settings as used for training above.
from vit_jax import input_pipeline

ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
int2str = tfds.builder(config.dataset).info.features['label'].int2str

batch = next(iter(ds_test))

In [None]:
# Plot top predictions from learned parameters.
i0, n = 12, 7
images = batch['image'][0, i0: i0+n]
for image, logit, axs in zip(
    images,
    vit.apply({'params': trained_params}, images, train=False),
    zip(*plt.subplots(2, n, figsize=(3*n, 5))[1]),
):
  axs[0].imshow(image / 2 + 0.5)
  idx = logit.argsort()[::-1][:10]
  plt.xticks(rotation=90)
  axs[1].bar(list(map(int2str, idx)), logit[idx])
  axs[1].tick_params(labelrotation=90)

In [None]:
# YOUR ACTION REQUIRED:
# While the model is training, you can check out the training code in the
# repository (in file "train.py" on the right side).
# You could try to adapt our `train()` function from above to make it work with
# the pre-trained checkpoints.
# 1. In particular, you'll need to handle `train` correctly and provide PRNG
#    keys for dropout.
# 2. Note that the returned values are simply logits, not log_probs.

In [None]:
#@markdown solution

# There's no solution for this exercise.

If you're interested in fine-tuning for real, you probably want to have a look
at the
[vit_jax_augreg.ipynb](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb)
Colab that also contains code to fine-tune on data that is not in tfds format.

Do you prefer PyTorch even after finishing this tutorial? The good news is that
you can directly import any of these checkpoints using the `timm` package. Check
out the Colab above for details.

### 4 - Wrapup

It would be great if you could tell me what you liked (or not) about this
tutorial, so I can make it better for the next iteration!

[**Feedback form** (1 minute version)](https://forms.gle/aritdqKJVaMPDYxD6)