## Evaluate E$^3$-S/32, with 8 experts, pre-trained on ILSVRC2012 and fine-tuned on CIFAR100

In [None]:
import jax
from jax import numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
from tqdm.auto import tqdm

from vmoe.nn import models
from vmoe.data import input_pipeline
from vmoe.checkpoints import partitioned

### Construct model

In [None]:
# --- Data / evaluation settings ---
BATCH_SIZE = 1024  # Images consumed in every evaluation step.
NUM_CLASSES = 100  # CIFAR100 has 100 target classes.
IMAGE_SIZE = 128  # Size of the images given to the model.
PATCH_SIZE = 32  # Size of each image patch.

# --- Architecture settings ---
NUM_LAYERS = 8  # Encoder blocks in the transformer.
NUM_EXPERTS = 8  # Experts available in every MoE layer.
NUM_SELECTED_EXPERTS = 1  # Upper bound on the experts selected per token.
ENSEMBLE_SIZE = 2  # Passed to the MoE layers as 'ensemble_size'.
NUM_EXPERTS_PER_ENS_MEMBER = NUM_EXPERTS // ENSEMBLE_SIZE

# --- Derived token / group sizes ---
NUM_DEVICES = 8  # Device count used to size the per-device token groups.
# One token per patch of the (IMAGE_SIZE // PATCH_SIZE)-per-side grid, plus one
# extra token (presumably the classification token, since the model is
# configured with classifier='token' -- TODO confirm).
NUM_TOKENS_PER_IMAGE = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1
# Tokens coming from one device's share of the batch.
GROUP_SIZE_PER_ENS_MEMBER = (BATCH_SIZE // NUM_DEVICES) * NUM_TOKENS_PER_IMAGE

# The configuration is assembled from the innermost section outwards; the
# resulting `model_config` has the same keys, values and key order as a single
# nested literal would.

# How tokens are dispatched to the experts.
dispatcher_config = {
    'name': 'einsum',
    'bfloat16': True,
    # With a balanced router and group_size tokens per group, each expert is
    # expected to receive
    #   group_size * num_selected_experts / num_experts
    # tokens. The capacity factor multiplies that expectation to leave some
    # slack for deviations from the average.
    'capacity_factor': 1.5,
    # Hint for pjit describing how the data is distributed at the input and
    # output of each MoE layer. This value means that the tokens are
    # partitioned across all devices in the mesh (i.e. full data parallelism).
    'partition_spec': (('expert', 'replica'),),
    # Batch priority is not used for training/fine-tuning.
    'batch_priority': False,
}

# Router choosing the experts for every token.
router_config = {
    'num_selected_experts': NUM_SELECTED_EXPERTS,
    'noise_std': 1.0,  # This is divided by NUM_EXPERTS.
    'importance_loss_weight': 0.005,
    'load_loss_weight': 0.005,
    'dispatcher': dispatcher_config,
}

# Mixture-of-experts settings for the encoder.
moe_config = {
    'ensemble_size': ENSEMBLE_SIZE,
    'num_experts': NUM_EXPERTS,
    # The per-member group size is scaled by the ensemble size.
    'group_size': GROUP_SIZE_PER_ENS_MEMBER * ENSEMBLE_SIZE,
    'layers': (5, 7),
    'dropout_rate': 0.0,
    'router': router_config,
}

# Transformer encoder settings.
encoder_config = {
    'num_layers': NUM_LAYERS,
    'num_heads': 8,
    'mlp_dim': 2048,
    'dropout_rate': 0.0,
    'attention_dropout_rate': 0.0,
    'moe': moe_config,
}

# Top-level model settings. 'name' selects the class in `vmoe.nn.models`; the
# remaining entries are forwarded to that class as keyword arguments.
model_config = {
    'name': 'VisionTransformerMoeEnsemble',
    'num_classes': NUM_CLASSES,
    'patch_size': (32, 32),
    'hidden_size': 512,
    'classifier': 'token',
    'representation_size': 512,
    'head_bias_init': -10.0,
    'encoder': encoder_config,
}

In [None]:
# Take the class name out of the config (this removes the 'name' entry from
# `model_config`), look the class up in `vmoe.nn.models`, and pass the
# remaining entries as constructor keyword arguments together with
# deterministic=True.
model_name = model_config.pop('name')
model_cls = getattr(models, model_name)
model = model_cls(
    deterministic=True,
    **model_config,
)

### Load weights

In [None]:
# Path to the fine-tuned checkpoint.
checkpoint_prefix = 'gs://vmoe_checkpoints/eee_s32_last2_ilsvrc2012_ft_cifar100'
mesh = partitioned.Mesh(np.asarray(jax.devices()), ('d',))
checkpoint = partitioned.restore_checkpoint(
    prefix=checkpoint_prefix, tree=None, axis_resources=None, mesh=mesh)

### Create dataset

In [None]:
# Preprocessing spec: the individual steps are chained with '|' in this order.
process = '|'.join([
    'keep("image", "label")',
    'decode',
    f'resize({IMAGE_SIZE}, inkey="image")',
    'value_range(-1,1)',
])

# CIFAR100 test split, batched with BATCH_SIZE examples per batch.
dataset = input_pipeline.get_dataset(
    variant='test', name='cifar100', split='test',
    batch_size=BATCH_SIZE, process=process)

### Run evaluation loop

In [None]:
 ncorrect = 0
 ntotal = 0
 for batch in dataset:
  # The final batch has been padded with fake examples so that the batch size is
  # the same as all other batches. The mask tells us which examples are fake.
  mask = batch['__valid__']

  logits, _ = model.apply({'params': checkpoint}, batch['image'])
  # logits shape: (BATCH_SIZE * ENSEMBLE_SIZE, NUM_CLASSES).
  logits = jnp.reshape(logits, (-1, ENSEMBLE_SIZE, NUM_CLASSES))
  # Note: In the paper, we describe the implementation of E^3 with a jnp.tile
  # mechanism. In this open-sourced version of the code, because of the pjit
  # backend, we use jnp.repeat instead of jnp.tile for efficiency reasons. This
  # explains the reshaping above that, in the paper implementation, would have
  # been jnp.reshape(logits, (ENSEMBLE_SIZE, -1, NUM_CLASSES) followed by
  # jax.nn.logsumexp(log_p, axis=0).
  log_p = jax.nn.log_softmax(logits)
  mean_log_p = jax.nn.logsumexp(log_p, axis=1) - jnp.log(ENSEMBLE_SIZE)

  preds = jnp.argmax(mean_log_p, axis=1)
  ncorrect += jnp.sum((preds == batch['label']) * mask)
  ntotal += jnp.sum(mask)

print(f'Test accuracy: {ncorrect / ntotal * 100:.2f}%')  # Should be 81.26%.