# DIP Answers

- **Answer Set**: Final Project
- **Full Name**: Mohammad Hosein Nemati
- **Student Code**: `610300185`

---

## Introduction

In this problem, we tune several training parameters to improve the accuracy of the [**Nested Hierarchical Transformer**](https://github.com/google-research/nested-transformer) model on the **CIFAR10** dataset.  
We then compare the resulting metrics with those reported for the previously trained models in the article.

First, we install the pinned dependencies, clone the model repository, and import the required libraries.

In [None]:
!pip install absl-py
!pip install clu==0.0.3
!pip install flax==0.3.4
!pip install jax==0.2.14
!pip install jaxlib==0.1.67
!pip install ml_collections
!pip install tensorflow-cpu==2.8.0
!pip install tensorflow-datasets==4.3.0
!pip install tensorflow_addons==0.13.0

![ -d nested-transformer ] || git clone --depth=1 https://github.com/google-research/nested-transformer
!cd nested-transformer && git pull

In [None]:
import sys
sys.path.append('./nested-transformer')

import os
import time
import flax
from flax import nn
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import functools
from absl import logging


from libml import input_pipeline 
from libml import preprocess
from models import nest_net  
import train  
from configs import cifar_nest 
from configs import imagenet_nest  

jax.tools.colab_tpu.setup_tpu()

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], "GPU")
logging.set_verbosity(logging.INFO)

print("JAX devices:\n" + "\n".join([repr(d) for d in jax.devices()]))
print('Current folder content', os.listdir())

---

## Loading

First, we load the dataset and shuffle its records. The cell below creates the TFDS builder for **CIFAR10**.

In [None]:
cifar_builder = tfds.builder("cifar10")
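
The shuffling step mentioned above can be sketched without downloading anything. This is a minimal illustration, assuming in-memory NumPy arrays; the helper `shuffle_records` and the synthetic arrays are hypothetical stand-ins and not part of the nested-transformer repository. In a `tf.data` pipeline the same step is usually expressed with `Dataset.shuffle`.

```python
import numpy as np


def shuffle_records(images, labels, seed=0):
    """Reorder images and labels with one shared random permutation."""
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(images))
    # Indexing both arrays with the same permutation keeps each image
    # paired with its own label.
    return images[perm], labels[perm]


# Synthetic stand-in shaped like CIFAR10 records (32x32 RGB, 10 classes).
# The first pixel of each image stores the record index so that the
# image/label pairing can be checked after shuffling.
images = np.zeros((16, 32, 32, 3), dtype=np.uint8)
images[:, 0, 0, 0] = np.arange(16)
labels = np.arange(16) % 10

shuffled_images, shuffled_labels = shuffle_records(images, labels, seed=0)
```

Using a single permutation for both arrays, rather than shuffling each one separately, is what keeps the records intact. Fixing the seed makes the order reproducible across runs.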

---

## Training

In [None]:
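# Start from the repository's CIFAR configuration and override a few fields.
config = cifar_nest.get_config()
# Short run: one training step, one evaluation step, one epoch.
config.num_train_steps = 1
config.num_eval_steps = 1
config.num_epochs = 1
# No warmup epochs.
config.warmup_epochs = 0
# Number of examples processed by each device per step.
config.per_device_batch_size = 128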

info, train_ds, eval_ds = input_pipeline.create_datasets(
    config, jax.random.PRNGKey(0)
)

workdir = f"./nested-transformer/checkpoints/cifar_nest_colab_{int(time.time())}"

train.train_and_evaluate(config, workdir)
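
To reason about `config.per_device_batch_size`, it helps to see how it scales with the number of devices. The sketch below is plain arithmetic under two assumptions: a Colab TPU exposing 8 devices, and the common convention that one epoch covers the training split once with incomplete batches dropped. Whether `train.train_and_evaluate` derives its step count in exactly this way is not confirmed here, and the cell above fixes `config.num_train_steps = 1` regardless. The helper names are hypothetical.

```python
CIFAR10_TRAIN_EXAMPLES = 50_000  # size of the CIFAR10 training split
NUM_DEVICES = 8                  # assumption: Colab TPU with 8 devices


def global_batch_size(per_device_batch_size, num_devices):
    """Total number of examples processed per step across all devices."""
    return per_device_batch_size * num_devices


def steps_per_epoch(num_examples, per_device_batch_size, num_devices):
    """Full batches needed to cover the split once (remainder dropped)."""
    return num_examples // global_batch_size(per_device_batch_size, num_devices)


batch = global_batch_size(128, NUM_DEVICES)
steps = steps_per_epoch(CIFAR10_TRAIN_EXAMPLES, 128, NUM_DEVICES)
print(f"global batch size: {batch}, steps per epoch: {steps}")
```

Under these assumptions, a single training step sees only a small fraction of the training split, so the settings above serve as a quick end-to-end check rather than a full training run.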

---

## Testing

---