## Evaluate V-MoE-S/32 (`last2` variant), with 8 experts, pre-trained on ILSVRC2012 and fine-tuned on ILSVRC2012

In [1]:
# Hide every GPU from this process so the notebook runs on CPU only.
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA gets
# initialised, so it is set here, ahead of importing jax / tensorflow (and the
# vmoe modules that import them), rather than after those imports.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
from jax import numpy as jnp

import tensorflow as tf

from vmoe.nn import models
from vmoe.data import input_pipeline
from vmoe.checkpoints import partitioned

# The model / dataset configuration used throughout this notebook comes from
# the config module imported below; change the configuration in that file.
from vmoe.configs.vmoe_paper.vmoe_s32_last2_ilsvrc2012_randaug_light1_ft_ilsvrc2012 import get_config, IMAGE_SIZE, BATCH_SIZE

# Provenance and expected directory layout. Assigned to `_` so that the string
# is not echoed as the cell's output.
_ = """
Adapted from vmoe/notebooks/demo_eee_CIFAR100.ipynb by Michael Li
Structure:
vmoe
    vmoe/
    this notebook
    vit_jax/ (from vision_transformer)
    vmoe_s32_last2_ilsvrc2012_randaug_light1_ft_ilsvrc2012.data-00000-of-00001
    vmoe_s32_last2_ilsvrc2012_randaug_light1_ft_ilsvrc2012.index
"""

  from .autonotebook import tqdm as notebook_tqdm


### Construct model

In [2]:
# Load the experiment configuration. Its `model` section provides both the
# name of the model class (looked up in vmoe.nn.models) and the keyword
# arguments that class is constructed with.
model_config = get_config()
model_cls = getattr(models, model_config.model.name)
model = model_cls(**model_config.model, deterministic=True)

### Load weights

In [3]:
# Checkpoint in use:
#   'gs://vmoe_checkpoints/vmoe_s32_last2_ilsvrc2012_randaug_light1_ft_ilsvrc2012'
# The prefix below is a relative path, so the checkpoint's .index / .data
# files are looked up relative to the working directory.
checkpoint_prefix = 'vmoe_s32_last2_ilsvrc2012_randaug_light1_ft_ilsvrc2012'
checkpoint = partitioned.restore_checkpoint(tree=None, prefix=checkpoint_prefix)

# Show the top-level entries of the restored checkpoint.
print(checkpoint.keys())

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


dict_keys(['Encoder', 'cls', 'embedding', 'head'])


### Create dataset

In [4]:
# Build the test input pipeline from the `dataset.test` section of the config.
dataset_config_test = model_config.dataset.test
dataset_test = input_pipeline.get_dataset(
    process=dataset_config_test.process,
    batch_size=dataset_config_test.batch_size,
    split=dataset_config_test.split,
    name=dataset_config_test.name,
    variant='test',
)

# dataset_config_test_real = model_config.dataset.test_real
# dataset_test_real = input_pipeline.get_dataset(
#     variant='test',
#     name=dataset_config_test_real.name, 
#     split=dataset_config_test_real.split, 
#     batch_size=dataset_config_test_real.batch_size, 
#     process=dataset_config_test_real.process
# )

def process_indices(indices_distr, class_lbl, fdir, fname):
    """Placeholder hook for post-processing the model's routing indices.

    The function is currently a stub: every argument is ignored, nothing is
    written anywhere, and ``None`` is returned.

    Args:
        indices_distr: third value returned by ``model.apply`` for a batch.
        class_lbl: ground-truth class labels of that batch.
        fdir: output directory (unused for now).
        fname: output file name (unused for now).
    """
    # NOTE(review): the original author recorded that indices_distr has shape
    # (8, 55808, 512) for batch_size = 1024 -- not verifiable from this stub.
    return None


def gen_data(model, dataset, checkpoint, max_batches=1):
    """Evaluate ``model`` on full batches of ``dataset`` and print the accuracy.

    Args:
        model: object whose ``apply({'params': checkpoint}, images)`` returns a
            3-tuple ``(logits, <ignored>, indices_distr)``.
        dataset: iterable of batches. Each batch is a mapping with the keys
            ``'image'``, ``'labels'`` (reduced to a class id with argmax over
            axis 1) and ``'__valid__'`` (per-example validity mask).
        checkpoint: parameters handed to the model under the ``'params'`` key.
        max_batches: maximum number of batches to evaluate. The default of 1
            keeps the previous behaviour, where the loop always stopped after
            the first batch. Pass ``None`` to evaluate every full batch.

    Returns:
        The ``indices_distr`` produced for the last evaluated batch, or
        ``None`` when no batch was evaluated (empty dataset, or the very first
        batch already contained padding).

    Raises:
        ValueError: if ``max_batches`` is neither ``None`` nor at least 1.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError('max_batches must be None or a positive integer')

    ncorrect = 0
    ntotal = 0
    # Stays None if the loop never reaches model.apply, so the return below is
    # always well defined.
    indices_distr = None

    for batch_idx, batch in enumerate(dataset, start=1):
        # The final batch has been padded with fake examples so that the batch size is
        # the same as all other batches. The mask tells us which examples are fake.
        mask = batch['__valid__']
        # Only full batches are evaluated: stop at the first batch that holds
        # any padded (fake) example. Checking the mask itself keeps this
        # independent of any externally configured batch size.
        if not bool(jnp.all(mask)):
            break
        logits, _, indices_distr = model.apply({'params': checkpoint}, batch['image'])

        log_p = jax.nn.log_softmax(logits)
        preds = jnp.argmax(log_p, axis=1)
        true_lbl = jnp.argmax(batch['labels'], axis=1)
        process_indices(indices_distr, class_lbl=true_lbl, fdir='', fname='')

        ncorrect += jnp.sum((preds == true_lbl) * mask)
        ntotal += jnp.sum(mask)

        # Checked at the end of the body so that no extra batch is pulled from
        # the dataset once the limit is reached (never true for None).
        if batch_idx == max_batches:
            break

    if int(ntotal) == 0:
        # Nothing was evaluated; avoid dividing by zero in the report.
        print('Test accuracy: no full batch was evaluated')
        return indices_distr

    print(f'Test accuracy, iteration: {ncorrect / ntotal * 100:.2f}%')
    return indices_distr

# Run the evaluation on the test dataset with the restored checkpoint and keep
# the indices_distr value that gen_data returns for later inspection.
ind_dist = gen_data(model, dataset_test, checkpoint)

2024-04-09 17:42:14.062093: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


self.hidden_size 512, self.patch_size (32, 32), self.patch_size (32, 32)
VisionTransformerMoe.__call__, x.shape (32, 144, 512)
x.shape (32, 144, 512) before jnp.concatenate([cls, x], axis=1)
self.param('cls', nn.initializers.zeros, (1, 1, x.shape[-1])), cls.shape (1, 1, 512)
jnp.tile(cls, [x.shape[0], 1, 1]) has shape (32, 1, 512)
x.shape (32, 145, 512) immediately before self.encoder_cls
in EncoderMoe, inputs.shape (32, 145, 512)
in EncoderMoe, x.shape (32, 145, 512)
block in EncoderMoe 0, x.shape = (32, 145, 512)
block in EncoderMoe 1, x.shape = (32, 145, 512)
block in EncoderMoe 2, x.shape = (32, 145, 512)
block in EncoderMoe 3, x.shape = (32, 145, 512)
block in EncoderMoe 4, x.shape = (32, 145, 512)
block in EncoderMoe 5, x.shape = (32, 145, 512)
gates.shape (1160, 8)
buffer_idx.shape (1160, 8)
dispatch_weights.shape (1160, 8, 436)
indices.shape (4, 1160, 512), inputs[0].shape (4, 1160, 512) before dispatch in transformed
self.dispatch_weights.shape (4, 1160, 8, 436) in EinsumDispa

In [5]:
# i, j, k = ind_dist['idx_5'].shape
# dd = {}
# for ii in range(i):
#     for jj in range(j):
#         if int(ind_dist['idx_5'][ii, jj, 0]) != 0:
#             mul = int(ind_dist['idx_5'][ii, jj, 0] - 1) * 145 + int(ind_dist['idx_5'][ii, jj, 1] - 1)
#             if mul not in dd.keys():
#                 dd[mul] = 1
#             else:
#                 dd[mul] += 1

# import numpy as np
# import matplotlib.pyplot as plt
# # for key in dd.keys():
# #     if dd[key] == 66:
# #         print(key)
# # print(max(dd.values()))
# plt.bar(np.array(list(dd.keys())), np.array(list(dd.values())))
# plt.xlabel("patch index")
# plt.ylabel("number of times the patch is assigned")

In [6]:
# print(np.average(list(dd.values())))