# DIP Answers

- **Answer Set**: Final Project
- **Full Name**: Mohammad Hosein Nemati
- **Student Code**: `610300185`

---

## Introduction

In this project, we tune several parameters to improve the accuracy of the [**Nested Hierarchical Transformer**](https://github.com/google-research/nested-transformer) model on the **CIFAR10** dataset.  
We then compare the resulting metrics with those reported for the previously trained models in the article.

First, we install the required libraries.

In [1]:
!pip install absl-py
!pip install clu==0.0.3
!pip install flax==0.3.4
!pip install jax==0.2.14
!pip install jaxlib==0.1.67
!pip install ml_collections
!pip install tensorflow-cpu==2.5.0
!pip install tensorflow-datasets==4.3.0
!pip install tensorflow_addons==0.13.0

!git clone https://github.com/ckoliber/dipexercises

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting clu==0.0.3
  Downloading clu-0.0.3-py3-none-any.whl (73 kB)
     |████████████████████████████████| 73 kB 1.4 MB/s 
Collecting flax
  Downloading flax-0.5.3-py3-none-any.whl (202 kB)
     |████████████████████████████████| 202 kB 19.0 MB/s 
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
     |████████████████████████████████| 77 kB 5.8 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.22-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.5 MB)
     |████████████████████████████████| 7.5 MB 53.3 MB/s 
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
     |████████████████████████████████| 145 kB 56.0 MB/s 
Collecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.m

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-datasets==4.3.0
  Downloading tensorflow_datasets-4.3.0-py3-none-any.whl (3.9 MB)
     |████████████████████████████████| 3.9 MB 7.5 MB/s 
Installing collected packages: tensorflow-datasets
  Attempting uninstall: tensorflow-datasets
    Found existing installation: tensorflow-datasets 4.6.0
    Uninstalling tensorflow-datasets-4.6.0:
      Successfully uninstalled tensorflow-datasets-4.6.0
Successfully installed tensorflow-datasets-4.3.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_addons==0.13.0
  Downloading tensorflow_addons-0.13.0-cp37-cp37m-manylinux2010_x86_64.whl (679 kB)
     |████████████████████████████████| 679 kB 8.3 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.13.0
Cloning into 'dipexercises'...
remote: Enumerating objec
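
Because several of the packages above are pinned to exact versions, it can help to confirm what actually ended up installed before training. The following is a minimal sketch, not part of the original notebook: the `PINNED` table simply mirrors the `pip install` lines above, and `check_versions` is a hypothetical helper name. It assumes Python 3.8+ for `importlib.metadata`; on Python 3.7 it falls back to the `importlib_metadata` backport, which would have to be installed separately.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: requires the importlib_metadata backport
    import importlib_metadata as metadata

# Versions copied from the pip commands in the install cell.
PINNED = {
    "clu": "0.0.3",
    "flax": "0.3.4",
    "jax": "0.2.14",
    "jaxlib": "0.1.67",
    "tensorflow-cpu": "2.5.0",
    "tensorflow-datasets": "4.3.0",
    "tensorflow_addons": "0.13.0",
}


def check_versions(pinned):
    """Maps each package name to (wanted, installed, matches)."""
    report = {}
    for name, wanted in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # Package is not installed at all.
        report[name] = (wanted, installed, installed == wanted)
    return report


for name, (wanted, installed, ok) in check_versions(PINNED).items():
    status = "OK" if ok else "MISMATCH"
    print(f"{name}: wanted {wanted}, installed {installed} [{status}]")
```

A `MISMATCH` line does not necessarily mean training will fail, but it is a reasonable first thing to look at if imports break later.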

---

## Model

In this section, we change some of the model's parameters to improve accuracy.
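
The model code in this section registers each variant by function name and lets a `nest` entry in the run configuration override that variant's defaults. The sketch below reproduces only that pattern so the override behaviour can be tried in isolation. It is an illustration under assumptions: plain dictionaries stand in for `ml_collections.ConfigDict`, and `toy_nest` and `build_net` are hypothetical names that do not exist in the project.

```python
import functools

MODELS = {}


def register(fn):
    """Stores a model constructor under its function name."""
    MODELS[fn.__name__] = fn
    return fn


def build_net(num_classes, config):
    """Hypothetical stand-in for NestNet: just returns its arguments."""
    return {"num_classes": num_classes, "config": config}


@register
def toy_nest(config):
    # Defaults for this variant (values borrowed from the tiny CIFAR model).
    nest = {"patch_size": 4, "embedding_dim": 192, "num_heads": 3}
    # Any keys under `nest` in the run config replace the defaults.
    if config.get("nest"):
        nest.update(config["nest"])
    return functools.partial(build_net, config=nest)


def create_model(name, config):
    if name not in MODELS:
        raise ValueError(f"Model {name} does not exist.")
    return MODELS[name](config)


# Override two parameters and leave the third at its default.
run_config = {"nest": {"embedding_dim": 384, "num_heads": 6}}
model_fn = create_model("toy_nest", run_config)
model = model_fn(num_classes=10)
print(model["config"])
```

Only the keys listed under `nest` change; everything else keeps the variant's default, which is why a parameter sweep can be expressed as small overrides rather than new model functions.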

In [None]:
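import functools
import math
from typing import Any

import flax.linen as nn
import jax.numpy as jnp
import ml_collections
import numpy as np

# `attn_utils`, `self_attention`, and `default_kernel_init` come from the
# project's own modules and are not shown in this cell.


class NestNet(nn.Module):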
  """Nested Transformer Net."""
  num_classes: int
  config: ml_collections.ConfigDict
  train: bool = False
  dtype: Any = jnp.float32
  activation_fn: Any = nn.gelu

  @nn.compact
  def __call__(self, inputs):
    config = self.config
    num_layers_per_block = config.num_layers_per_block
    num_blocks = len(num_layers_per_block)
    # Here we assume the image and the patches are square.
    assert inputs.shape[1] == inputs.shape[2]
    assert inputs.shape[1] % config.init_patch_embed_size == 0
    input_size_after_patch = inputs.shape[1] // config.init_patch_embed_size
    assert input_size_after_patch % config.patch_size == 0
    down_sample_ratio = input_size_after_patch // config.patch_size
    # There are 4 child nodes for each node.
    assert num_blocks == int(np.log(down_sample_ratio) / np.log(2) + 1)

    # If `scale_hidden_dims` is provided, the hidden dimension and num_heads
    # are multiplied by `scale_hidden_dims` at every block. Setting
    # `scale_hidden_dims=2` throughout is a common design, so we do not offer
    # layer-wise control of `scale_hidden_dims`; this keeps the architecture simple.
    scale_hidden_dims = config.get("scale_hidden_dims", None)

    norm_fn = attn_utils.get_norm_layer(
        self.train, self.dtype, norm_type=config.norm_type)
    conv_fn = functools.partial(
        nn.Conv, dtype=self.dtype, kernel_init=default_kernel_init)
    dense_fn = functools.partial(
        nn.Dense, dtype=self.dtype, kernel_init=default_kernel_init)
    encoder_dict = dict(
        num_heads=config.num_heads,
        norm_fn=norm_fn,
        mlp_ratio=config.mlp_ratio,
        attn_type=config.attn_type,
        dense_fn=dense_fn,
        activation_fn=self.activation_fn,
        qkv_bias=config.qkv_bias,
        attn_drop=config.attn_drop,
        proj_drop=config.proj_drop,
        train=self.train,
        dtype=self.dtype)
    x = self_attention.PatchEmbedding(
        conv_fn=conv_fn,
        patch_size=(config.init_patch_embed_size, config.init_patch_embed_size),
        embedding_dim=config.embedding_dim)(
            inputs)
    x = attn_utils.block_images(x, (config.patch_size, config.patch_size))
    block_idx = 0
    total_block_num = np.sum(num_layers_per_block)
    path_drop = np.linspace(0, config.stochastic_depth_drop, total_block_num)
    for i in range(num_blocks):
      x = self_attention.PositionEmbedding()(x)
      if scale_hidden_dims and i != 0:
        # Overwrite the original num_heads value in encoder_dict so that
        # num_heads is multiplied by scale_hidden_dims at each successive block.
        encoder_dict.update(
            {"num_heads": encoder_dict["num_heads"] * scale_hidden_dims})
      for _ in range(num_layers_per_block[i]):
        x = self_attention.EncoderNDBlock(
            **encoder_dict, path_drop=path_drop[block_idx])(
                x)
        block_idx = block_idx + 1
      if i < num_blocks - 1:
        grid_size = int(math.sqrt(x.shape[1]))
        if scale_hidden_dims:
          output_dim = x.shape[-1] * scale_hidden_dims
        else:
          output_dim = None

        x = self_attention.ConvPool(
            grid_size=(grid_size, grid_size),
            patch_size=(config.patch_size, config.patch_size),
            conv_fn=conv_fn,
            dtype=self.dtype,
            output_dim=output_dim)(
                x)
    assert x.shape[1] == 1
    assert x.shape[2] == config.patch_size**2

    x = norm_fn()(x)
    x_pool = jnp.mean(x, axis=(1, 2))
    out = dense_fn(self.num_classes)(x_pool)
    return out


MODELS = {}


def register(f):
  MODELS[f.__name__] = f
  return f


def default_config():
  """Shared configs for models."""
  nest = ml_collections.ConfigDict()
  nest.norm_type = "LN"
  nest.attn_type = "local_multi_head"
  nest.mlp_ratio = 4
  nest.qkv_bias = True
  nest.attn_drop = 0.0
  nest.proj_drop = 0.0
  nest.stochastic_depth_drop = 0.1
  return nest


@register
def nest_tiny_s16_32(config):
  """NesT tiny version with sequence length 16 for 32x32 inputs."""
  nest = default_config()
  # Encode one pixel as a word vector.
  nest.init_patch_embed_size = 1
  # Default max sequence length is 4x4=16, so it has 4 layers.
  nest.patch_size = 4
  nest.num_layers_per_block = [3, 3, 3, 3]
  nest.embedding_dim = 192
  nest.num_heads = 3
  nest.attn_type = "local_multi_query"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


@register
def nest_small_s16_32(config):
  """NesT small version with sequence length 16 for 32x32 inputs."""
  nest = default_config()
  nest.init_patch_embed_size = 1
  nest.patch_size = 4
  nest.num_layers_per_block = [3, 3, 3, 3]
  nest.embedding_dim = 384
  nest.num_heads = 6
  nest.attn_type = "local_multi_query"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


@register
def nest_base_s16_32(config):
  """NesT base version with sequence length 16 for 32x32 inputs."""
  nest = default_config()
  nest.init_patch_embed_size = 1
  nest.patch_size = 4
  nest.num_layers_per_block = [3, 3, 3, 3]
  nest.embedding_dim = 768
  nest.num_heads = 12
  nest.attn_type = "local_multi_query"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


@register
def nest_tiny_s196_224(config):
  """NesT tiny version with sequence length 196 for 224x224 inputs."""
  nest = default_config()
  # Encode each 4x4 pixel patch as a word vector.
  nest.init_patch_embed_size = 4
  # Default max sequence length is 14x14=196, so it has 3 layers:
  # Spatial image size: [56, 28, 14]
  nest.patch_size = 14
  nest.num_layers_per_block = [2, 2, 8]
  nest.embedding_dim = 96
  nest.num_heads = 3
  nest.scale_hidden_dims = 2
  nest.stochastic_depth_drop = 0.2
  nest.attn_type = "local_multi_head"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


@register
def nest_small_s196_224(config):
  """NesT small version with sequence length 196 for 224x224 inputs."""
  nest = default_config()
  nest.init_patch_embed_size = 4
  nest.patch_size = 14
  nest.num_layers_per_block = [2, 2, 20]
  nest.embedding_dim = 96
  nest.num_heads = 3
  nest.scale_hidden_dims = 2
  nest.stochastic_depth_drop = 0.3
  nest.attn_type = "local_multi_head"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


@register
def nest_base_s196_224(config):
  """NesT base version with sequence length 196 for 224x224 inputs."""
  nest = default_config()
  nest.init_patch_embed_size = 4
  nest.patch_size = 14
  nest.num_layers_per_block = [2, 2, 20]
  nest.embedding_dim = 128
  nest.num_heads = 4
  nest.scale_hidden_dims = 2
  nest.stochastic_depth_drop = 0.5
  nest.attn_type = "local_multi_head"

  if config.get("nest"):
    nest.update(config.nest)
  return functools.partial(NestNet, config=nest)


def create_model(name, config):
  """Creates model partial function."""
  if name not in MODELS:
    raise ValueError(f"Model {name} does not exist.")
  return MODELS[name](config)
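
Two pieces of arithmetic in `NestNet.__call__` are worth checking by hand when changing parameters: the assertion that ties the number of blocks to the image and patch sizes, and the linearly increasing stochastic-depth rate assigned to each layer. The sketch below re-derives both in plain Python. The function names are hypothetical, and `path_drop_schedule` is a stand-in for the `np.linspace` call in the model rather than the model's own code.

```python
import math


def expected_num_blocks(image_size, init_patch_embed_size, patch_size):
    """Number of blocks implied by the size assertions in NestNet."""
    assert image_size % init_patch_embed_size == 0
    size_after_patch = image_size // init_patch_embed_size
    assert size_after_patch % patch_size == 0
    down_sample_ratio = size_after_patch // patch_size
    # Each pooling step halves the grid side, so log2(ratio) poolings are
    # needed, plus one for the first block. Assumes the ratio is a power of 2.
    return int(math.log2(down_sample_ratio)) + 1


def path_drop_schedule(num_layers_per_block, stochastic_depth_drop):
    """Evenly spaced drop rates from 0 to the maximum, one per layer."""
    total = sum(num_layers_per_block)
    if total == 1:
        return [0.0]
    step = stochastic_depth_drop / (total - 1)
    return [i * step for i in range(total)]


# CIFAR variants: 32x32 input, 1-pixel embedding, 4x4 patches.
print(expected_num_blocks(32, 1, 4))    # 4, matching [3, 3, 3, 3]
# 224x224 variants: 4-pixel embedding, 14x14 patches.
print(expected_num_blocks(224, 4, 14))  # 3, matching [2, 2, 8]

schedule = path_drop_schedule([3, 3, 3, 3], 0.1)
print(len(schedule), schedule[0], round(schedule[-1], 4))
```

If `patch_size` or `init_patch_embed_size` is changed for CIFAR10, the length of `num_layers_per_block` has to change with it, otherwise the assertion in the model fails.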

---

## Training

Run the command below to train the `NesT` model and report evaluation metrics.

In [3]:
!python ./dipexercises/project/main.py --config dipexercises/project/configs/cifar_nest.py --workdir="./dipexercises/project/checkpoints/nest_cifar"

I0808 19:27:07.338225 139645475628928 main.py:49] Using JAX backend target local
I0808 19:27:07.338693 139645475628928 main.py:52] Using JAX XLA backend 
I0808 19:27:07.341539 139645475628928 tpu_client.py:54] Starting the local TPU driver.
I0808 19:27:07.341964 139645475628928 xla_bridge.py:212] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
I0808 19:27:07.342265 139645475628928 xla_bridge.py:212] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I0808 19:27:07.342418 139645475628928 xla_bridge.py:212] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W0808 19:27:07.342489 139645475628928 xla_bridge.py:215] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0808 19:27:07.342578 139645475628928 main.py:54] JAX host: 0 / 1
I0808 19:27:07.342785 1396454
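
The log above shows JAX failing to find a GPU or TPU and falling back to CPU. A quick way to see which platform JAX will use before launching a long run is sketched below. The helper name `jax_platform` is hypothetical; the only JAX call used is `jax.devices()`, and the import is guarded so the snippet also runs where JAX is not installed.

```python
def jax_platform():
    """Returns the platform of JAX's first device, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.devices()[0].platform


platform = jax_platform()
if platform is None:
    print("JAX is not installed in this environment.")
elif platform == "cpu":
    print("JAX is running on CPU; training is likely to be slow.")
else:
    print(f"JAX is using an accelerator: {platform}")
```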

---