<a href="https://colab.research.google.com/github/mayureshagashe2105/GSoC-22-TensorFlow-Resources-and-Notebooks/blob/main/JAX/Vision_Transformers_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install/upgrade pip, JAX and jaxlib, then Flax straight from its GitHub repository.
# NOTE(review): nothing here is version-pinned, so re-running this cell at a
# later date may pull different library versions than the rest of the notebook
# was written against.
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git

[K     |████████████████████████████████| 2.0 MB 4.3 MB/s 
[K     |████████████████████████████████| 1.0 MB 47.5 MB/s 
[K     |████████████████████████████████| 72.0 MB 136 kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/145.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m217.3/217.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m596.3/596.3 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.1/51.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone
[0m

In [57]:
import jax
from jax import lax, random, numpy as jnp, jit

import flax
from flax import linen as nn
from flax.training import train_state
from flax.core.lift import vmap
# from flax.linen.transforms import jit

import optax

import numpy as np
import matplotlib.pyplot as plt

from typing import Sequence

import tensorflow as tf
from tensorflow.keras.datasets import mnist

from tqdm.auto import tqdm

In [70]:
# Hyper-parameters shared by the model, data loader and training loop below.
PATCH_SIZE = (7, 7)        # (height, width) of each image patch
STRIDE = 7                 # step between patches; equal to the patch size here
IMAGE_SIZE = (28, 28, 1)   # (height, width, channels) of one input image
PROJECTION_DIMS = 8        # width of the token embeddings
SELFA_HEADS = 2            # number of self-attention heads
# NOTE(review): stored on the model as `transformer_layers`, but
# VisionTransformer in this file never reads it (one encoder block is applied).
TRANSFORMER_LAYERS = 8
BATCH_SIZE = 64            # batch size for the data loader and the dummy init input
NUM_CLASSES = 10           # number of output classes
LEARNING_RATE = 0.001      # learning rate handed to optax.adam
MOMENTUM = 0.9             # forwarded to optax.adam as its `b1` argument
SEED = 0                   # seed for jax.random.PRNGKey
EPOCHS = 5                 # number of passes over the training generator

In [71]:
class DataLoader(tf.keras.utils.Sequence):
  """Serve fixed-size mini-batches of ``(images, labels)`` as JAX arrays.

  Only full batches are produced: ``len(loader)`` is
  ``X.shape[0] // batch_size``, so trailing samples that do not fill a whole
  batch are never returned.

  Args:
    batch_size: number of samples per batch; must be positive.
    X: image array indexed along its first axis.
    y: label array with the same first-axis length as ``X``.
    is_training: stored on the instance; not consulted by this class.

  Raises:
    ValueError: if ``batch_size`` is not positive or ``X`` and ``y`` disagree
      on the number of samples.
  """

  def __init__(self, batch_size, X, y, is_training):
    # Let the Keras base class initialise whatever state it needs.
    super().__init__()
    if batch_size <= 0:
      raise ValueError(f"batch_size must be positive, got {batch_size}")
    if X.shape[0] != y.shape[0]:
      raise ValueError(
          f"X and y must have the same number of samples, "
          f"got {X.shape[0]} and {y.shape[0]}")
    self.X = X
    self.y = y
    self.batch_size = batch_size
    self.is_training = is_training
    self.indices = range(X.shape[0])

  def __len__(self):
    """Return the number of complete batches."""
    return self.X.shape[0] // self.batch_size

  def __getitem__(self, idx):
    """Return batch number ``idx`` as ``(images, labels)``.

    Negative indices count from the end. Images with three dimensions get a
    trailing axis of size one appended.

    Raises:
      IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Without
        this, an out-of-range index silently yielded an empty batch and
        iteration through the ``__getitem__`` protocol never terminated.
    """
    num_batches = len(self)
    if idx < 0:
      idx += num_batches
    if not 0 <= idx < num_batches:
      raise IndexError(f"batch index out of range for {num_batches} batches")

    start = idx * self.batch_size
    batch_indices = self.indices[start: start + self.batch_size]
    batch_images = self.X[batch_indices]
    batch_labels = self.y[batch_indices]

    if batch_images.ndim == 3:
      batch_images = jnp.expand_dims(batch_images, -1)

    return jnp.array(batch_images), jnp.array(batch_labels)


# Fetch MNIST, then convert images to float32 divided by 255.0 and labels to int32.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = tf.cast(X_train, tf.float32).numpy() / 255.0
y_train = tf.cast(y_train, tf.int32).numpy()
X_test = tf.cast(X_test, tf.float32).numpy() / 255.0
y_test = tf.cast(y_test, tf.int32).numpy()


# Training generator over the first 2000 training samples only.
gen = DataLoader(batch_size=BATCH_SIZE, X=X_train[:2000], y=y_train[:2000],
                 is_training=True)

In [72]:
class MLP(nn.Module):
  """Dense layers with a GELU activation, sized by ``hidden_layer_nodes``.

  NOTE(review): ``__call__`` returns from inside its loop, so only the first
  Dense layer in ``hidden_layer_nodes`` is ever applied, and each layer is fed
  the original ``input`` rather than the previous layer's output. The callers
  in this file (``VisionTransformer.setup``) put the width they actually need
  first in the list, so correcting the loop must be done together with those
  call sites, otherwise the token/logit shapes change.
  """
  
  # Output width of each Dense layer, in order.
  hidden_layer_nodes: Sequence[int]

  def setup(self):
    """Create one ``nn.Dense`` per entry of ``hidden_layer_nodes``."""

    self.layers = [nn.Dense(n) for n in self.hidden_layer_nodes]
  
  @nn.compact
  def __call__(self, input):
    """Apply the first Dense layer and GELU to ``input`` and return the result.

    Returns ``None`` when ``hidden_layer_nodes`` is empty, because the loop
    body never runs.
    """
    for layer in self.layers:
      # Always reads `input`, not the running value `x`.
      x = layer(input)
      x = self.apply_activation(x)
      # Returns on the first iteration; later layers are never reached.
      return x
  

  def apply_activation(self, input):
    """Return ``nn.gelu(input)``."""
    return nn.gelu(input)

In [73]:
class PatchExtractor(nn.Module):
  """Cut a batch of images into flattened patches.

  Attributes:
    patch_size: (height, width) of each patch.
    stride: step between consecutive patches along both spatial axes.
  """
  patch_size: Sequence[int]
  stride: int

  @nn.compact
  def __call__(self, images):
    """Extract patches from ``images``.

    Args:
      images: array of shape (batch, height, width, channels).

    Returns:
      Array of shape (batch, n_patches, patch_h * patch_w * channels). Patches
      are listed row by row over the image, and each row of the result holds
      the pixels of exactly one patch.
    """
    # Using NHWC for both input and output keeps the flattened patch features
    # on the LAST axis. With the default (channels-first) layout they land on
    # axis 1, and a plain reshape to (batch, n_patches, patch_dims) would mix
    # pixels from different patches into the same row.
    patches = jax.lax.conv_general_dilated_patches(
        images,
        filter_shape=(self.patch_size[0], self.patch_size[1]),
        window_strides=(self.stride, self.stride),
        padding="VALID",
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
    )

    # patches: (batch, out_h, out_w, patch_dims). Merge the two spatial axes;
    # the patch count is inferred so it stays right when stride != patch size.
    batch, patch_dims = patches.shape[0], patches.shape[-1]
    image_patches = patches.reshape(batch, -1, patch_dims)

    return image_patches

In [74]:
class PatchEncoder(nn.Module):
  """Linear projection of patches plus a learned per-position embedding.

  Attributes:
    num_patches: number of positions in the embedding table.
    projection_dims: output width of both the projection and the embedding.
  """
  num_patches: int
  projection_dims: int

  def setup(self):
    """Create the Dense projection and the position embedding table."""
    self.projection = nn.Dense(features=self.projection_dims)
    self.positional_encodings = nn.Embed(
        num_embeddings=self.num_patches, features=self.projection_dims)

  @nn.compact
  def __call__(self, patch):
    """Return ``projection(patch)`` plus the embeddings of positions 0..num_patches-1."""
    position_ids = jnp.arange(self.num_patches)
    projected = self.projection(patch)
    position_embedding = self.positional_encodings(position_ids)
    return projected + position_embedding

In [75]:
class PositionalEncodings(nn.Module):
  """Fixed sinusoidal position table.

  Attributes:
    seq_len: number of positions (rows of the table).
    num_heads: number of columns of the table. Despite the name, the caller in
      this file (``VisionTransformer``) passes the embedding width here.
  """
  seq_len: int
  num_heads: int

  def __call__(self):
    """Return the table with a leading axis of size one.

    Entry ``[0, i, j]`` is ``sin(i / 10000 ** (j / num_heads))`` for even ``j``
    and ``cos(i / 10000 ** ((j - 1) / num_heads))`` for odd ``j``.

    Returns:
      Array of shape (1, seq_len, num_heads).
    """
    # Build the whole table in one vectorised NumPy computation instead of
    # seq_len * num_heads separate scalar `.at[].set` updates on every call.
    positions = np.arange(self.seq_len, dtype=np.float64)[:, None]
    columns = np.arange(self.num_heads)[None, :]
    is_even = columns % 2 == 0
    # Odd columns reuse the exponent of the even column just before them.
    paired_even = columns - (columns % 2)
    angles = positions / np.power(10000.0, paired_even / self.num_heads)
    table = np.where(is_even, np.sin(angles), np.cos(angles))

    return jnp.expand_dims(jnp.asarray(table), 0)

In [76]:
class VisionTransformer(nn.Module):
  """Small Vision Transformer classifier with a single encoder block.

  Attributes:
    patch_size: (height, width) of each patch.
    stride: step between patches.
    image_size: (height, width, channels) of one input image.
    projection_dims: width of the token embeddings.
    atten_heads: number of self-attention heads.
    transformer_layers: NOTE(review): currently unused; exactly one encoder
      block is applied regardless of this value.
    batch_size: batch size read by the training code in this file
      (``model.batch_size``); the forward pass itself accepts any batch size.
    num_classes: number of output classes.
  """
  patch_size: Sequence[int]
  stride: int
  image_size: Sequence[int]
  projection_dims: int
  atten_heads: int
  transformer_layers: int
  batch_size: int
  num_classes: int


  def setup(self):
    """Create all sub-modules and the learnable class token."""
    self.patchify = PatchExtractor(self.patch_size, self.stride)
    self.patch_dims = self.patch_size[0] * self.patch_size[1] * self.image_size[-1]
    # Single-entry width list: the patch embedding must come out with
    # `projection_dims` features so it can be joined with the class token.
    self.tokens = MLP([self.projection_dims])
    self.class_token = self.param("class_token", random.normal, (1, self.projection_dims))

    self.num_patches = ((self.image_size[0] - self.patch_size[0]) // self.stride + 1) * ((self.image_size[1] - self.patch_size[1]) // self.stride + 1)
    # +1 position for the class token that is prepended to every sequence.
    self.pos_encodings = PositionalEncodings(self.num_patches + 1, self.projection_dims)

    self.norm1 = nn.LayerNorm(epsilon=1e-6)
    self.self_attention = nn.SelfAttention(self.atten_heads, qkv_features=self.projection_dims)

    self.norm2 = nn.LayerNorm(epsilon=1e-6)
    self.enc_mlp = MLP([self.projection_dims, self.projection_dims])

    # Single-entry width list: the head must emit `num_classes` values.
    self.logits_mlp = MLP([self.num_classes])

  # No @nn.compact here: every sub-module and parameter is created in setup().
  def __call__(self, inputs):
    """Classify a batch of images.

    Args:
      inputs: array of shape (batch, height, width, channels).

    Returns:
      Array of shape (batch, num_classes) with unnormalised class scores
      (logits). No softmax is applied here, because the loss used by the
      training code (``optax.softmax_cross_entropy``) normalises its input
      itself; applying softmax in both places distorts the loss and gradients.
    """
    image_patches = self.patchify(inputs)
    tokens = self.tokens(image_patches)                      # (batch, N, D)

    # Prepend the shared class token to every sequence in one vectorised step.
    batch = tokens.shape[0]
    class_tokens = jnp.broadcast_to(
        self.class_token[None, ...], (batch,) + self.class_token.shape)
    tokens = jnp.concatenate([class_tokens, tokens], axis=1)  # (batch, N+1, D)

    # The table has shape (1, N+1, D) and broadcasts over the batch axis, so
    # the forward pass no longer depends on the configured `batch_size`.
    tokens = tokens + self.pos_encodings()
    tokens = self.norm1(tokens)

    # Self-attention with a residual connection.
    out = tokens + self.self_attention(tokens)

    # Feed-forward sub-block with a residual connection.
    out_temp = self.norm2(out)
    out_temp = self.enc_mlp(out_temp)
    out_temp = nn.relu(out_temp)
    out = out + out_temp

    # Classify from the class-token position only.
    out = out[:, 0]

    logits = self.logits_mlp(out)
    return logits


  @staticmethod
  @jit
  def layer_add(x, y):
    """Return ``x + y`` (jit-compiled); not called by this class itself."""
    return jnp.add(x, y)

In [77]:
# Build the Vision Transformer from the notebook-level hyper-parameters.
model = VisionTransformer(
    patch_size=PATCH_SIZE,
    stride=STRIDE,
    image_size=IMAGE_SIZE,
    projection_dims=PROJECTION_DIMS,
    atten_heads=SELFA_HEADS,
    transformer_layers=TRANSFORMER_LAYERS,
    batch_size=BATCH_SIZE,
    num_classes=NUM_CLASSES,
)

In [80]:
class TrainingLoop:
  """Train a Flax model with Adam on batches from a generator.

  Args:
    model: Flax module exposing ``batch_size``, ``image_size`` and
      ``num_classes`` attributes plus ``init``/``apply``.
    train_gen: iterable yielding ``(images, labels)`` batches.
    seed: integer seed for ``jax.random.PRNGKey``.
    epochs: number of passes over ``train_gen``.
    learning_rate: Adam learning rate.
    momentum: Adam ``b1`` (first-moment decay).
    test_gen: optional evaluation generator; stored but not used by this class.
  """

  def __init__(self, model, train_gen, seed, epochs, learning_rate, momentum, test_gen=None):
    self.model = model
    self.train_gen = train_gen
    self.test_gen = test_gen
    self.key = seed
    self.rng = jax.random.PRNGKey(self.key)
    self.main_rng, self.init_rng = random.split(self.rng, 2)
    self.epochs = epochs
    self.lr = learning_rate
    self.momentum = momentum

    # Shape of the dummy batch used to initialise the parameters.
    self.full_batch_size = (self.model.batch_size, self.model.image_size[0],
                            self.model.image_size[1], self.model.image_size[2])

    self.init_train_state()

  def init_train_state(self):
    """Initialise parameters, the optimizer and the ``TrainState``."""
    self.variables = self.model.init({'params': self.init_rng}, jnp.ones(self.full_batch_size))['params']
    # Name the argument so it is clear `momentum` is Adam's b1.
    self.optimizer = optax.adam(self.lr, b1=self.momentum)
    self.train_state = train_state.TrainState.create(apply_fn = self.model.apply, tx=self.optimizer, params=self.variables)

  @staticmethod
  def apply_model(state, model, images, labels, num_classes):
    """Compute gradients, loss and accuracy for one batch.

    The model output is passed to ``optax.softmax_cross_entropy`` as
    ``logits``, which expects unnormalised scores.
    NOTE(review): make sure the model does not apply its own softmax.

    Args:
      state: train state whose ``params`` are differentiated.
      model: module whose ``apply`` produces the class scores.
      images: batch of input images.
      labels: integer class labels for the batch.
      num_classes: number of classes used for the one-hot targets.

    Returns:
      Tuple ``(grads, loss, accuracy)``.
    """

    def loss_fn(params):
      logits = model.apply({'params': params}, images)
      # Use the caller-supplied class count rather than a hard-coded 10.
      one_hot = jax.nn.one_hot(labels, num_classes)
      loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
      return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

  @staticmethod
  @jax.jit
  def update_model(state, grads):
    """Return a new train state with ``grads`` applied (jit-compiled)."""
    return state.apply_gradients(grads=grads)

  @staticmethod
  def train_epoch(state, model, gen, batch_size, rng, num_classes):
    """Run one pass over ``gen`` and return the updated state and mean metrics.

    ``batch_size`` and ``rng`` are accepted for interface compatibility and
    are not used.

    Returns:
      Tuple ``(state, train_loss, train_accuracy)`` where the two metrics are
      the means over all batches of the epoch.
    """
    epoch_loss = []
    epoch_accuracy = []

    for (batch_images, batch_labels) in tqdm(gen, desc='Training', leave=False):
      grads, loss, accuracy = TrainingLoop.apply_model(
          state, model, batch_images, batch_labels, num_classes)
      state = TrainingLoop.update_model(state, grads)
      epoch_loss.append(loss)
      epoch_accuracy.append(accuracy)

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy

  @staticmethod
  def train(obj):
    """Train ``obj`` (a ``TrainingLoop``) for ``obj.epochs`` epochs.

    Prints the loss and accuracy after every epoch and returns the final
    train state, which is also stored on ``obj.train_state``.
    """
    for epoch in tqdm(range(1, obj.epochs + 1), desc="Epoch"):
      obj.train_state, train_loss, train_accuracy = TrainingLoop.train_epoch(obj.train_state,
                                                      obj.model,
                                                      obj.train_gen,
                                                      obj.model.batch_size,
                                                      obj.main_rng,
                                                      obj.model.num_classes)

      print(
          'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f'
          % (epoch, train_loss, train_accuracy * 100)
          )

    return obj.train_state


In [None]:
# Wire the model and training generator into a trainer, then run training.
trainer = TrainingLoop(
    model=model,
    train_gen=gen,
    seed=SEED,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    momentum=MOMENTUM,
    test_gen=None,
)
final_state = TrainingLoop.train(trainer)