In [None]:
#@title LICENSE
# Licensed under the Apache License, Version 2.0

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/sparse_model_vit.ipynb)

In [None]:
import jax
import jax.numpy as jnp
import functools
from jax.experimental import sparse
import flax

from scenic.projects.baselines.configs.imagenet import imagenet_vit_config
from scenic.model_lib import models
from scenic.train_lib import pretrain_utils
from scenic.train_lib import train_utils

In [None]:
# Build the ImageNet ViT config and instantiate the matching scenic model.
config = imagenet_vit_config.get_config()

# ImageNet-1k dataset metadata handed to the model constructor below.
dataset_meta_data = dict(
    input_dtype=jnp.float32,
    input_shape=(-1, 224, 224, 3),
    num_classes=1000,
    num_eval_examples=50000,
    num_train_examples=1281167,
    target_is_onehot=False,
)

model_cls = models.get_model_cls(config.model_name)
model = model_cls(config, dataset_meta_data)

In [None]:
# Initialize model.
# NOTE: the seed-8 key is split twice and the first `init_rng` is discarded.
# This is kept as is so the resulting initialization stays reproducible.
rng, init_rng = jax.random.split(jax.random.PRNGKey(8))
rng, init_rng = jax.random.split(rng)
placeholder_input = jnp.ones((1, 224, 224, 3))

@functools.partial(jax.jit, backend='cpu')
def _initialize_model(rngs):
  """Initialization function to be jitted.

  Runs `model.flax_model.init` on `placeholder_input` (with `train=False`,
  `debug=False`) and returns only the parameter collection.

  Args:
    rngs: Dict of PRNG keys passed to `init` (called here with `'params'`).

  Returns:
    The `'params'` collection of the initialized variables.
  """
  init_variables = model.flax_model.init(
      rngs, placeholder_input, train=False, debug=False)
  # Index rather than `.pop('params')`. On a flax FrozenDict `.pop` returns a
  # `(remaining, value)` tuple instead of the value. Indexing returns the
  # params for both plain dicts and FrozenDicts.
  return init_variables['params']

init_params = _initialize_model({'params': init_rng})

In [None]:
# Wrap the freshly initialized params in a scenic TrainState. This state is
# what the checkpoint-restore step receives, and its params are otherwise
# used as the shape template for the synthetic sparse model.
initial_train_state = train_utils.TrainState(
    params=init_params, rng=rng, global_step=0, model_state={})

In [None]:
# @title Restore model from given checkpoint
init_checkpoint_path = "" # @param {type:"string"}
# Leave empty to skip restoring: the notebook then builds random parameters
# in which roughly 90% of the elements are zero.

In [None]:
# Dense parameters for the comparison. They are restored from the checkpoint
# if one is given. Otherwise they are a synthetic random model in which about
# SPARSITY of the elements are 0.
SPARSITY = 0.9  # Fraction of elements zeroed in the synthetic model.

if init_checkpoint_path:
  restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
      init_checkpoint_path, initial_train_state, assert_exist=True)
  dense_params = restored_train_state.params
  dense_dict = flax.traverse_util.flatten_dict(dense_params)
else:
  # If checkpoint is not given, replace every param with uniform [0, 1) noise
  # and zero out entries below SPARSITY, so ~SPARSITY of elements are 0.
  # A dedicated key name is used so this cell does not rebind the
  # notebook-level `rng` that is stored in `initial_train_state`.
  dense_dict = {}
  sparsity_rng = jax.random.PRNGKey(0)
  for k, p in flax.traverse_util.flatten_dict(initial_train_state.params).items():
    sparsity_rng, cur_rng = jax.random.split(sparsity_rng)
    noise = jax.random.uniform(cur_rng, shape=p.shape, dtype=p.dtype)
    dense_dict[k] = jnp.where(noise < SPARSITY, 0, noise)
  dense_params = flax.traverse_util.unflatten_dict(dense_dict)

In [None]:
# Sparsify model.
def filtered_bcoo_simple(key, param):
  """Converts rank-2/3 `kernel` leaves to BCOO and leaves the rest untouched.

  Args:
    key: Flattened parameter path (a tuple of names). Only its last element
      is inspected.
    param: The parameter array.

  Returns:
    `sparse.BCOO.fromdense(param)` when `key` ends in 'kernel' and `param`
    has 2 or 3 dimensions, otherwise `param` itself.
  """
  is_kernel = key[-1] == 'kernel'
  if not (is_kernel and param.ndim in (2, 3)):
    return param
  return sparse.BCOO.fromdense(param)

# Convert the flat dense-param dict leaf by leaf, then restore the nesting.
sparse_dict = {
    path: filtered_bcoo_simple(path, leaf)
    for path, leaf in dense_dict.items()
}
sparse_params = flax.traverse_util.unflatten_dict(sparse_dict)

In [None]:
# Variable dicts for the two forward passes: dense params vs. the same params
# with selected kernels converted to BCOO.
variables = {'params': dense_params}
sp_variables = {'params': sparse_params}

# `sparse.sparsify` transforms `apply` so it can consume the BCOO arrays
# held in `sp_variables`.
sparse_apply = sparse.sparsify(model.flax_model.apply)

def dense_model_fwd(x):
  """Applies the model to `x` using the dense parameters (`train=False`)."""
  return model.flax_model.apply(variables, x, train=False)

def sparse_model_fwd(x):
  """Applies the sparsified model to `x` using BCOO parameters (`train=False`)."""
  return sparse_apply(sp_variables, x, train=False)

In [None]:
x = jnp.ones((1, 224, 224, 3))

# Execution time comparison
%timeit dense_res = dense_model_fwd(x).block_until_ready()
%timeit sparse_res = sparse_model_fwd(x).block_until_ready()

# Max numerical diff
jnp.max(jnp.abs(dense_model_fwd(x) - sparse_model_fwd(x)))