<a href="https://colab.research.google.com/github/mayureshagashe2105/GSoC-22-TensorFlow-Resources-and-Notebooks/blob/main/JAX/Vision_Transformers_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git

[K     |████████████████████████████████| 2.0 MB 9.5 MB/s 
[K     |████████████████████████████████| 1.0 MB 45.4 MB/s 
[K     |████████████████████████████████| 72.0 MB 150 kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/145.1 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m217.3/217.3 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m596.3/596.3 kB[0m [31m47.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.1/51.1 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone
[0m

In [3]:
import jax
from jax import lax, random, numpy as jnp, jit

import flax
from flax import linen as nn
from flax.training import train_state
from flax.core.lift import vmap
# from flax.linen.transforms import jit

import optax

import numpy as np
import matplotlib.pyplot as plt

from typing import Sequence

import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import tensorflow as tf

In [4]:
class DataLoader(tf.keras.utils.Sequence):
  def __init__(self, batch_size, X, y, is_training):
    self.X = X
    self.y = y
    self.batch_size = batch_size
    self.is_training = is_training
    self.indices = range(X.shape[0])
  
  def __len__(self):
    return self.X.shape[0] // self.batch_size
  
  def __getitem__(self, idx):
    batch_indices = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]
    batch_images = self.X[batch_indices]
    batch_labels = self.y[batch_indices]

    return jnp.array(batch_images), jnp.array(batch_labels)


(X_train, y_train), (X_test, y_test) = cifar10.load_data()
(X_train, y_train), (X_test, y_test) = ((tf.cast(X_train, tf.float32).numpy() / 255.0, tf.cast(y_train, tf.int32).numpy()), 
                                        (tf.cast(X_test, tf.float32).numpy() / 255.0, tf.cast(y_test, tf.int32).numpy()))


gen = DataLoader(64, X_train[:1000], y_train[:1000], True)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [5]:
class MLP(nn.Module):
  
  hidden_layer_nodes: Sequence[int]
  activation: str

  def setup(self):
    whitelist_activations = ['celu', 'elu', 'gelu', 'glu', 'log_sigmoid', 'log_softmax', 'relu', 'sigmoid', 'soft_sign', 'softmax', 'softplus', 'swish', 'PRelu', None]
    if self.activation not in whitelist_activations:
      raise ValueError(f'{self.activation} should be one of {whitelist_activations}')


    self.layers = [nn.Dense(n) for n in self.hidden_layer_nodes]
  
  @nn.compact
  def __call__(self, input):
    for layer in self.layers:
      x = layer(input)
      if self.activation is not None:
        x = self.apply_activation(x)
      return x
  
  def apply_activation(self, input):
    if self.activation == 'celu': return nn.celu(input)
    elif self.activation == 'elu': return nn.elu(input)
    elif self.activation == 'gelu': return nn.gelu(input)
    elif self.activation == 'glu': return nn.glu(input)
    elif self.activation == 'log_sigmoid': return nn.log_sigmoid(input)
    elif self.activation == 'log_softmax': return nn.log_softmax(input)
    elif self.activation == 'relu': return nn.relu(input)
    elif self.activation == 'sigmoid': return nn.sigmoid(input)
    elif self.activation == 'soft_sign': return nn.soft_sign(input)
    elif self.activation == 'softmax': return nn.softmax(input)
    elif self.activation == 'softplus': return nn.softplus(input)
    elif self.activation == 'swish': return nn.swish(input)
    elif self.activation == 'PRelu': return nn.PRelu(input)

In [6]:
class PatchEncoder(nn.Module):
  num_patches: int
  projection_dims: int

  def setup(self):
    self.projection = nn.Dense(self.projection_dims)
    self.positional_encodings = nn.Embed(self.num_patches, self.projection_dims)
  
  @nn.compact
  def __call__(self, patch):
    positions = jnp.arange(0, self.num_patches)
    encode = self.projection(patch) + self.positional_encodings(positions)
    return encode

In [36]:
class VisionTransformer(nn.Module):
  patch_size: Sequence[int]
  stride: int
  image_size: Sequence[int]
  activation: str
  projection_dims: int
  num_heads: int
  transformer_layers: int
  mlp_head_units: Sequence[int]
  batch_size: int
  num_classes: int


  def setup(self):
    self.transformer_units = [self.projection_dims * 2, self.projection_dims]
    self.num_patches = ((self.image_size[0] - self.patch_size[0]) // self.stride + 1) * ((self.image_size[1] - self.patch_size[1]) // self.stride + 1)
    
    self.norm = nn.LayerNorm(epsilon=1e-6)
    self.multi_head_attention = nn.MultiHeadDotProductAttention(num_heads=self.num_heads, qkv_features=self.projection_dims)
    self.dropout10 = nn.Dropout(0.1)
    self.dropout50 = nn.Dropout(0.5)
    self.logits = nn.Dense(self.num_classes)

    encode_init = PatchEncoder(self.num_patches, self.projection_dims)
    mlp_init = MLP(self.transformer_units[::-1], self.activation)
    mlp2_init = MLP(self.mlp_head_units[::-1], self.activation)

    self.encode = encode_init
    self.mlp_transformer = mlp_init
    self.mlp_head = mlp2_init

  @nn.compact
  def __call__(self, inputs):
    # image_patches = self.extract_patches(inputs)

    patches = jax.lax.conv_general_dilated_patches(inputs[:, None, None, :], (1, self.patch_size[0], self.patch_size[1], 1), 
                                                   (1, self.stride, self.stride, 1), 
                                                   padding="VALID").reshape(self.batch_size, -1, self.patch_size[0] * self.patch_size[1] * inputs.shape[-1])
    patch_dims = patches.shape[-1]
    image_patches = patches.reshape(self.batch_size, -1, patch_dims)


    encoded_image_patches = self.encode(image_patches)

    for _ in range(self.transformer_layers):
      x1 = self.norm(encoded_image_patches)
      attention_output = self.multi_head_attention(x1, x1)
      x2 = attention_output + encoded_image_patches #VisionTransformer.layer_add(attention_output, encoded_image_patches)
      x3 = self.norm(x2)

      x3 = self.mlp_transformer(x3)
      # x3 = self.dropout10(x3, deterministic=not True)
      # print(x3.shape)
      encoded_image_patches = x3 + x2 #VisionTransformer.layer_add(x3, x2)
    
    repr = self.norm(encoded_image_patches)

    repr = repr.reshape(self.batch_size, -1)
    # repr = self.dropout50(repr, deterministic=not True)
    repr = self.mlp_head(repr)
    # print(repr.shape)
    # repr = self.dropout(repr, deterministic=not True)
    logit_nodes = self.logits(repr)
    logit_nodes = nn.softmax(logit_nodes)

    return logit_nodes
  

  @jit
  def extract_patches(self, inputs):
    patches = jax.lax.conv_general_dilated_patches(inputs[:, None, None, :], (1, self.patch_size[0], self.patch_size[1], 1), 
                                                   (1, self.stride, self.stride, 1), 
                                                   padding="VALID").reshape(self.batch_size, -1, self.patch_size[0] * self.patch_size[1] * inputs.shape[-1])
    patch_dims = patches.shape[-1]
    patches = patches.reshape(self.batch_size, -1, patch_dims)
    return patches
  
  @staticmethod
  @jit
  def layer_add(x, y):
    return jnp.add(x, y)

In [37]:
PATCH_SIZE = (10, 10)
STRIDE = 10
IMAGE_SIZE = (32, 32, 3)
PROJECTION_DIMS = 64
NUM_HEADS = 4
TRANSFORMER_LAYERS = 8
MLP_HEAD_UNITS = (2048, 1024)
BATCH_SIZE = 64
NUM_CLASSES = 10
LEARNING_RATE = 0.001
MOMENTUM = 0.9
SEED = 0
EPOCHS = 5

In [38]:
model = VisionTransformer(patch_size=PATCH_SIZE, stride=STRIDE, image_size=IMAGE_SIZE, activation="gelu", projection_dims=PROJECTION_DIMS, num_heads=NUM_HEADS, 
                          transformer_layers=TRANSFORMER_LAYERS, mlp_head_units=MLP_HEAD_UNITS, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES)

In [41]:
class TrainingLoop:
  
  def __init__(self, model, train_gen, seed, epochs, learning_rate, momentum, test_gen=None):
    self.model = model
    self.train_gen = train_gen
    self.key = seed
    self.rng = jax.random.PRNGKey(self.key)
    self.epochs = epochs
    self.lr = learning_rate
    self.momentum = momentum

    self.full_batch_size = (self.model.batch_size, self.model.image_size[0],
                            self.model.image_size[1], self.model.image_size[2])
    
    self.init_train_state()



  def init_train_state(self):
    self.variables = self.model.init(self.rng, jnp.ones(self.full_batch_size))['params']
    self.optimizer = optax.adam(self.lr, self.momentum)
    self.train_state = train_state.TrainState.create(apply_fn = model.apply, tx=self.optimizer, params=self.variables)

  
  @staticmethod
  @jax.jit
  def apply_model(state, images, labels, num_classes):  
    
    def loss_fn(params):
      logits = state.apply_fn({'params': params}, images)
      one_hot = jax.nn.one_hot(labels, 10)
      loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
      return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

  
  @staticmethod
  @jax.jit
  def update_model(state, grads):
    return state.apply_gradients(grads=grads)

 
  @staticmethod
  def train_epoch(state, gen, batch_size, rng, num_classes):

    epoch_loss = []
    epoch_accuracy = []

    for (batch_images, batch_labels) in gen:
      batch_images = batch_images
      batch_labels = batch_labels
      grads, loss, accuracy = TrainingLoop.apply_model(state, batch_images, batch_labels, 10)
      state = TrainingLoop.update_model(state, grads)
      epoch_loss.append(loss)
      epoch_accuracy.append(accuracy)
    
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy


  @staticmethod
  def train(obj):


    for epoch in range(1, obj.epochs + 1):
      state, train_loss, train_accuracy = TrainingLoop.train_epoch(obj.train_state, obj.train_gen,
                                                      obj.model.batch_size,
                                                      obj.rng,
                                                      obj.model.num_classes)
      
      print(
          'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f'
          % (epoch, train_loss, train_accuracy * 100)
          )

    return state


In [42]:
trainer = TrainingLoop(model, gen, SEED, EPOCHS, LEARNING_RATE, MOMENTUM, test_gen=None)
final_state = TrainingLoop.train(trainer)

epoch:  1, train_loss: 2.3420, train_accuracy: 9.99
epoch:  2, train_loss: 2.3420, train_accuracy: 9.99
epoch:  3, train_loss: 2.3420, train_accuracy: 9.99
epoch:  4, train_loss: 2.3420, train_accuracy: 9.99
epoch:  5, train_loss: 2.3420, train_accuracy: 9.99
