<a href="https://colab.research.google.com/github/mayureshagashe2105/GSoC-22-TensorFlow-Resources-and-Notebooks/blob/main/JAX/Vision_Transformers_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git

[K     |████████████████████████████████| 2.0 MB 5.1 MB/s 
[K     |████████████████████████████████| 1.0 MB 68.1 MB/s 
[K     |████████████████████████████████| 72.0 MB 112 kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/145.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m217.3/217.3 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m596.3/596.3 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.1/51.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone
[0m

In [2]:
import jax
from jax import lax, random, numpy as jnp, jit

import flax
from flax import linen as nn
from flax.training import train_state
from flax.core.lift import vmap

import optax

import numpy as np
import matplotlib.pyplot as plt

from typing import Sequence, List, Union, Tuple

import tensorflow as tf
from tensorflow.keras.datasets import cifar10

from tqdm.auto import tqdm

In [3]:
# Hyper-parameters shared by the cells below.
PATCH_SIZE = (7, 7)  # (height, width) of one image patch; passed to VisionTransformer as `patch_size`
STRIDE = 7  # step of the patch-extraction window; passed as `stride`
IMAGE_SIZE = (32, 32, 3)  # (H, W, C) of one input image; passed as `image_size`
PROJECTION_DIMS = 8  # passed as `projection_dims` (width of the model's internal representation)
SELFA_HEADS = 2  # passed as `atten_heads`
TRANSFORMER_LAYERS = 8  # passed as `transformer_layers`
BATCH_SIZE = 64  # batch size used by both the DataLoader and the model
NUM_CLASSES = 10  # passed as `num_classes`
LEARNING_RATE = 0.001  # passed to TrainingLoop as `learning_rate`
MOMENTUM = 0.9  # passed to TrainingLoop as `momentum`
SEED = 0  # passed to TrainingLoop as `seed` (used to build the PRNG key)
EPOCHS = 5  # passed to TrainingLoop as `epochs`

In [4]:
class DataLoader(tf.keras.utils.Sequence):
  """Serves fixed-size batches of images and labels as JAX arrays.

  Only complete batches are served: trailing samples that do not fill a
  whole batch are never returned (see `__len__`).

  Args:
    batch_size: int. Size of a batch to yield.
    X: np.ndarray. Images from the dataset in the form of numpy array.
    y: np.ndarray. Labels for `X`.
  """
  def __init__(self, batch_size: int, X: np.ndarray, y: np.ndarray) -> None:
    # Let the Keras base class initialise whatever state it needs.
    super().__init__()
    self.X = X
    self.y = y
    self.batch_size = batch_size
    self.indices = range(X.shape[0])

  def __len__(self) -> int:
    """Number of batch in the Sequence.

    Returns:
        The number of complete batches in the Sequence.
    """
    return self.X.shape[0] // self.batch_size

  def __getitem__(self, idx: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the batch at position `idx`.

    Args:
        idx: position of the batch in the Sequence. Negative values count
          from the end, as for a list.

    Returns:
        A tuple `(images, labels)` of JAX arrays. Images with no channel
        axis (3-D batch) get a trailing channel axis of size 1.

    Raises:
        IndexError: If `idx` does not address one of the `len(self)` batches.
    """
    n_batches = len(self)
    if idx < 0:
      idx += n_batches
    # Raising here (instead of returning an empty or partial batch) also
    # makes plain `for batch in loader` terminate after the last full batch.
    if not 0 <= idx < n_batches:
      raise IndexError(f"batch index out of range for {n_batches} batches")

    batch_indices = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]
    batch_images = self.X[batch_indices]
    batch_labels = self.y[batch_indices]

    if batch_images.ndim == 3:  # (N, H, W) -> (N, H, W, 1)
      batch_images = jnp.expand_dims(batch_images, -1)

    return jnp.array(batch_images), jnp.array(batch_labels)


# Fetch CIFAR-10 (downloads the archive on first use).
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Pixels: cast to float32 and rescale by 1/255.
X_train = tf.cast(X_train, tf.float32).numpy() / 255.0
X_test = tf.cast(X_test, tf.float32).numpy() / 255.0

# Labels: cast to int32.
y_train = tf.cast(y_train, tf.int32).numpy()
y_test = tf.cast(y_test, tf.int32).numpy()


# Batch generator over the first 500 training samples only.
gen = DataLoader(BATCH_SIZE, X_train[:500], y_train[:500])

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [5]:
class MLP(nn.Module):
  """Multi-layer perceptron dataclass.
  
  Args:
    hidden_layer_nodes: Sequence[int]. Number of nodes in hidden layers.
    activations: Sequence[str]. Activation functions to apply at each layer.
  
  Raises:
    AssertionError: If length of `self.activations` is not same as length of `self.hidden_layer_nodes`.
    ValueError: If any value from `self.activations` is not from allowed activation functions.
  """
  hidden_layer_nodes: Sequence[int]
  activations: Sequence[str]

  def setup(self):
    
    assert len(self.hidden_layer_nodes) == len(self.activations), "Activation function for each layer must be provided."

    self.__whitelist_activations = ['celu', 'elu', 'gelu', 'glu', 'log_sigmoid',
                                    'log_softmax', 'relu', 'sigmoid', 
                                    'soft_sign', 'softmax', 'softplus', 
                                    'swish', 'PRelu', 'Linear']

    self.layers = [(nn.Dense(self.hidden_layer_nodes[n]), self.activations[n]) 
                  for n in range(len(self.hidden_layer_nodes)) 
                  if self.activations[n] in self.__whitelist_activations 
                  ]
    
    if len(self.layers) is not len(self.activations):
      raise ValueError(f'Activation function should be one of the {self.__whitelist_activations}') 

  @nn.compact
  def __call__(self, input):
    for layer, activation in self.layers:
      x = layer(input)
      x = self.apply_activation(x, activation)
      return x
  
  def apply_activation(self, input, activation):
    if activation == 'celu': return nn.celu(input)
    elif activation == 'elu': return nn.elu(input)
    elif activation == 'gelu': return nn.gelu(input)
    elif activation == 'glu': return nn.glu(input)
    elif activation == 'log_sigmoid': return nn.log_sigmoid(input)
    elif activation == 'log_softmax': return nn.log_softmax(input)
    elif activation == 'relu': return nn.relu(input)
    elif activation == 'sigmoid': return nn.sigmoid(input)
    elif activation == 'soft_sign': return nn.soft_sign(input)
    elif activation == 'softmax': return nn.softmax(input)
    elif activation == 'softplus': return nn.softplus(input)
    elif activation == 'swish': return nn.swish(input)
    elif activation == 'PRelu': return nn.PRelu(input)
    elif activation == "Linear": return input


In [6]:
class PatchExtractor(nn.Module):
  """Custom module to extract patches from the images.
  
  Args:
    patch_size: Sequence[int]. Image will be divided into patches of the desired size.
    stride: int. Stride length to slide the window for patch extraction.
  
  Raises:
    AssertionError: If `patch_size` is not a sequence with length = 2.
  """
  patch_size: Sequence[int]
  stride: int

  def setup(self):
    assert len(self.patch_size) == 2, "length of `patch_size` should be equal to 2."


  @nn.compact
  def __call__(self, images):
    patches = jax.lax.conv_general_dilated_patches(images[:, None, None, :], 
                                                   (1, self.patch_size[0], self.patch_size[1], 1), 
                                                   (1, self.stride, self.stride, 1), 
                                                   padding="VALID")
    n_patches = (images.shape[1] // self.patch_size[0]) * (images.shape[2] // self.patch_size[1])
    patch_dims = self.patch_size[0] * self.patch_size[1] * images.shape[3]
    image_patches = patches.reshape(images.shape[0], n_patches, patch_dims)

    return image_patches

In [7]:
class PositionalEncodings(nn.Module):
  """Custom module to return learnable positional encodings (sin-cosin waves).
  
  Args:
    seq_len: int. Lenght of the sequence of patches.
    projection_dims: int. Number of dimensions for internal representation of the model.
  """
  seq_len: int
  projection_dims: int

  def __call__(self):
    res = jnp.ones((self.seq_len, self.projection_dims))
    for i in range(self.seq_len):
      for j in range(self.projection_dims):
        res = res.at[i, j].set(jnp.sin(i / (10000 ** (j / self.projection_dims))) if j % 2 == 0 else jnp.cos(i / (10000 ** ((j - 1) / self.projection_dims))))
    
    return jnp.expand_dims(res, 0)

In [8]:
class VisionTransformer(nn.Module):
  """Vision Transformer module
  
  Args:
    patch_size: Sequence[int]. Image will be divided into patches of the desired size.
    stride: int. Stride length to slide the window for patch extraction.
    image_size: Sequence[int]. Format: (H, W, C). Size of 1 image from the batch.
    projection_dims: int. Number of dimensions for internal representation of the model.
    atten_heads: int. Number of self attention heads to be used.
    transformer_layers: int. Number of transformer encoders to be used.
    batch_size: int. Size of a batch to yield.
    num_classes. int. Number of target classes.
  
  Raises:
    AssertionError: If input images are not in the format (N, H, W, C).
  """
  patch_size: Sequence[int]
  stride: int
  image_size: Sequence[int]
  projection_dims: int
  atten_heads: int
  transformer_layers: int
  batch_size: int
  num_classes: int


  def setup(self):

    self.patchify = PatchExtractor(self.patch_size, self.stride)
    self.patch_dims = self.patch_size[0] * self.patch_size[1] * self.image_size[-1]
    self.tokens = MLP([self.projection_dims, self.patch_dims], activations=['gelu', 'gelu'])
    self.class_token = self.param("class_token", lambda rng, shape: random.normal(rng, shape), (1, self.projection_dims))

    self.num_patches = ((self.image_size[0] - self.patch_size[0]) // self.stride + 1) * ((self.image_size[1] - self.patch_size[1]) // self.stride + 1)
    self.pos_encodings = PositionalEncodings(self.num_patches + 1, self.projection_dims)

    self.norm1 = nn.LayerNorm(epsilon=1e-6)
    self.self_attention = nn.SelfAttention(self.atten_heads, qkv_features=self.projection_dims)

    self.norm2 = nn.LayerNorm(epsilon=1e-6)
    self.enc_mlp = MLP([self.projection_dims, self.projection_dims], activations=['gelu', 'relu'])

    self.logits_mlp = MLP([self.num_classes, self.projection_dims], activations=['gelu', 'softmax'])

  @nn.compact
  def __call__(self, inputs):
    assert len(inputs.shape) == 4, f"ViT encoder expected 4D vector as input in the format (N, H, W, C) but got {len(inputs.shape)}D vector instead."

    image_patches = self.patchify(inputs)
    tokens = self.tokens(image_patches)
    tokens = jnp.stack([jnp.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
    tokens += self.pos_encodings().repeat(self.batch_size, 0)
    
    # Tranformer Encoder
    for i in range(self.transformer_layers):
      tokens = self.norm1(tokens)


      out = tokens + self.self_attention(tokens)

      out_temp = self.norm2(out)
      out_temp = self.enc_mlp(out_temp)
      out += out_temp

    out =  out[:, 0]

    logits = self.logits_mlp(out)
    logits = nn.softmax(logits)
    return logits
    

In [9]:
# Build the Vision Transformer from the notebook-level hyper-parameters.
model = VisionTransformer(
    patch_size=PATCH_SIZE,
    stride=STRIDE,
    image_size=IMAGE_SIZE,
    projection_dims=PROJECTION_DIMS,
    atten_heads=SELFA_HEADS,
    transformer_layers=TRANSFORMER_LAYERS,
    batch_size=BATCH_SIZE,
    num_classes=NUM_CLASSES,
)

In [10]:
class TrainingLoop:
  """OOP wrapper around functional training loop.

  Args:
    model: VisionTransformer. Model architecture for Vision Transformer.
    train_gen: DataLoader. Dataloader to feed the training data to the model dynamically.
    seed: int. Seed value for random number generator to ensure reproducibility.
    epochs: int. Maximum number of iteration for training.
    learning_rate: float. Learning rate for the optimizer.
    momentum: float. Momentum for the optimizer (used as Adam's `b1`).
    val_gen: DataLoader. Default=None. Dataloader to feed the validation data to the model dynamically.
      Stored on the instance; the training loop does not evaluate on it yet.

  TODO: Make the `apply_model` method jittable.
  """

  def __init__(self, model: VisionTransformer, train_gen: DataLoader, seed: int, epochs: int, learning_rate: float, momentum: float, val_gen=None):
    self.model = model
    self.train_gen = train_gen
    self.val_gen = val_gen
    self.key = seed
    self.rng = jax.random.PRNGKey(self.key)
    self.main_rng, self.init_rng = random.split(self.rng, 2)
    self.epochs = epochs
    self.lr = learning_rate
    self.momentum = momentum

    # Shape of one input batch, (N, H, W, C); used to build the dummy input
    # that initialises the parameters.
    self.full_batch_size = (self.model.batch_size, self.model.image_size[0],
                            self.model.image_size[1], self.model.image_size[2])

    self.init_train_state()


  def init_train_state(self):
    """Initializes the model's and optimizer's state
    """
    self.variables = self.model.init({'params': self.init_rng}, jnp.ones(self.full_batch_size))['params']
    self.optimizer = optax.adam(self.lr, b1=self.momentum)
    self.train_state = train_state.TrainState.create(apply_fn = self.model.apply, tx=self.optimizer, params=self.variables)


  @staticmethod
  def apply_model(state: train_state.TrainState, model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray):
    """Calculates the gradients during backpropogation to adjust model's parameters.

    Args:
      state: train_state.TrainState. State of the model's params at a particular time.
      model: VisionTransformer. Model architecture for Vision Transformer.
      images: jnp.ndarray. Input images.
      labels: jnp.ndarray. Integer labels for input images, one per image.
        Both (N,) and (N, 1) layouts are accepted.

    Returns:
      grads: flax.core.frozen_dict.FrozenDict. Gradients from backpropogation to update model's params.
      loss: float. Loss function's output value.
      accuracy: float. Accuracy of the model.
    """
    # Labels may arrive as a column (N, 1). Left like that, the one-hot
    # targets and the `==` below broadcast against the (N, ...) predictions
    # into an (N, N) grid and the metrics are silently wrong.
    labels = jnp.reshape(labels, (-1,))

    def loss_fn(params):
      """categorical-cross entropy loss function

      Args:
        params: . Model's params (weights and biases).

      Returns:
        loss: float. Loss function's output value.
        logits: jnp.ndarray. Predictions made by the `model`.
      """
      logits = model.apply({'params': params}, images)
      # Class count is taken from the model output instead of being hard-coded.
      one_hot = jax.nn.one_hot(labels, logits.shape[-1])
      loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
      return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


  @staticmethod
  @jax.jit
  def update_model(state: train_state.TrainState, grads: flax.core.frozen_dict.FrozenDict):
    """Updates model's params using calculated gradients.

    Args:
      state: train_state.TrainState. State of the model's params at a particular time.
      grads: flax.core.frozen_dict.FrozenDict. Gradients from backpropogation to update model's params.

    Returns:
      state: train_state.TrainState. Updated state.
    """
    return state.apply_gradients(grads=grads)


  @staticmethod
  def train_epoch(state: train_state.TrainState, model: VisionTransformer, gen: DataLoader, batch_size: int, rng: jnp.ndarray):
    """Trains the model for one epoch with batch mode.

    Args:
      state: train_state.TrainState. State of the model's params at a particular time.
      model: VisionTransformer. Model architecture for Vision Transformer.
      gen: DataLoader. Dataloader to feed the training data to the model dynamically.
      batch_size: int. Size of a batch to yield. Currently unused; kept for interface compatibility.
      rng: jnp.ndarray. Random number seed to ensure reproducibility. Currently unused; kept for interface compatibility.

    Returns:
      state: train_state.TrainState. State of the model's params at a particular time.
      train_loss. float. Loss for 1 epoch.
      train_accuracy. float. Accuracy achieved for 1 epoch.
    """
    epoch_loss = []
    epoch_accuracy = []
    # Index explicitly over `len(gen)` so exactly the advertised number of
    # batches is consumed, independent of the generator's iteration protocol.
    for batch_idx in tqdm(range(len(gen)), desc='Batch Training', leave=False):
      batch_images, batch_labels = gen[batch_idx]
      grads, loss, accuracy = TrainingLoop.apply_model(state, model, batch_images, batch_labels)
      state = TrainingLoop.update_model(state, grads)
      epoch_loss.append(loss)
      epoch_accuracy.append(accuracy)

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy


  @staticmethod
  def train(obj):
    """Drives the training process using created OOP wrapper.

    Args:
      obj: TrainingLoop. Instance of the class `TrainingLoop` which drives the training process of the model.

    Returns:
      obj.train_state. Final state of model's params for getting inference.
    """
    for epoch in tqdm(range(1, obj.epochs + 1), desc="Training"):
      obj.train_state, train_loss, train_accuracy = TrainingLoop.train_epoch(obj.train_state,
                                                      obj.model,
                                                      obj.train_gen,
                                                      obj.model.batch_size,
                                                      obj.main_rng,
                                                      )

      print(f"epoch: {epoch}, train_loss: {train_loss}, train_accuracy: {train_accuracy}")

    return obj.train_state


In [None]:
# Wire the model and data generator into the training wrapper, then run it.
trainer = TrainingLoop(
    model=model,
    train_gen=gen,
    seed=SEED,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    momentum=MOMENTUM,
    val_gen=None,
)
final_state = TrainingLoop.train(trainer)