<a href="https://colab.research.google.com/github/cao-nv/visual_transformer/blob/main/transformer_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops
!pip install wandb 



In [2]:
!pip install tensorflow_addons



In [3]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

In [4]:
from einops import rearrange

In [5]:
import pdb

In [6]:
class Residual(layers.Layer):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn
  
  def call(self, x, **kwargs):
    return self.fn(x, **kwargs) + x

In [7]:
class PreNorm(layers.Layer):
  def __init__(self, dim, fn):
    super().__init__()
    self.fn = fn
    self.norm = layers.LayerNormalization(axis=-1)

  def call(self, x, **kwargs):
    return self.fn(self.norm(x), **kwargs)

In [34]:
class FeedForward(layers.Layer):
  def __init__(self, dim, hidden_dim, dropout=0.):
    super().__init__()
    self.dense1 = layers.Dense(hidden_dim, input_shape=(None, dim), use_bias=False)
    self.gelu = tfa.layers.GELU()
    self.dropout1 = layers.Dropout(dropout)
    self.dense2 = layers.Dense(dim, input_shape=(None, hidden_dim))
    self.dropout2 = layers.Dropout(dropout)

    # ff_layers = [layers.Dense(hidden_dim, input_shape=(None, dim), use_bias=False), 
    #           tfa.layers.GELU(), 
    #           layers.Dropout(dropout), 
    #           layers.Dense(dim), 
    #           layers.Dropout(dropout)]
    # self.net = tf.keras.Sequential(layers=ff_layers, name="FeedForward")

  def call(self, x, training=True): 
    x = self.dense1(x)
    x = self.gelu(x)
    x = self.dropout1(x, training=training)
    x = self.dense2(x)
    x = self.dropout2(x, training=training)
    return x

In [22]:
class Attention(layers.Layer): 
  def __init__(self, dim, heads=8, dropout=0.):
    super().__init__()
    self.heads = heads 
    self.scales = heads ** (-0.5)

    self.to_qkv = layers.Dense(dim*3, input_shape=(dim,), use_bias=False)

    self.out_dense = layers.Dense(dim, input_shape=(None, dim))
    self.out_dropout = layers.Dropout(dropout)

    self.to_out = tf.keras.Sequential(layers=[layers.Dense(dim, input_shape=(None, dim)), 
                                              layers.Dropout(dropout)], name='to_out')
    
  def to_out(self, x, training=True):
    out = self.out_dense(x)
    out = self.out_dropout(out, training=training)
    return out

  def call(self, x, mask=None, training=True):
    b, n, _, h = *x.shape, self.heads
    qkv = tf.split(self.to_qkv(x), 3, axis=-1)
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

    dots = tf.einsum("bhid, bhjd -> bhij", q, k) * self.scales

    if mask is not None:
      mask = tf.pad(mask.flatten(1), (1, 0), constant_values=True)
      assert mask.shape[-1] == dots.shape[-1], "Mask has incorrect dimensions"
      mask = mask[:, None, :] * mask[:, :, None]
      dots[~mask] = tf.fill(dots[~mask], float('-inf'))
      del mask

    attn = tf.nn.softmax(dots, axis=-1)
    out = tf.einsum("bhij,bhjd->bhid", attn, v)
    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out, training=True)
    return out

In [10]:
class Transformer(layers.Layer):
  def __init__(self, dim, depth, heads, mlp_dim, dropout): 
    super().__init__()
    self.layers = []

    for _ in range(depth): 
      self.layers.append([Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout))), 
                          Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))])
      
  def call(self, x, mask=None):
    for attn, ff in self.layers:
      x = attn(x, mask=mask)
      x = ff(x)
    return x

In [38]:
MIN_NUM_PATCHES=16
class ViT(tf.keras.Model):
  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0., emb_dropout=0.):
    super().__init__()
    assert image_size % patch_size == 0, "Image size must be divisible for the patch size"
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size**2
    assert num_patches >= MIN_NUM_PATCHES,  f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

    self.patch_size = patch_size 

    self.pos_embedding = tf.Variable(tf.random.normal([1, num_patches+1, dim]))
    self.patch_to_embedding = layers.Dense(dim, input_shape=(patch_dim,), use_bias=False)
    self.cls_token = tf.Variable(tf.random.normal([1, 1, dim]))
    self.dropout = layers.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)


    self.mlp_layer_norm = layers.LayerNormalization(axis=-1)
    self.mlp_dense1 = layers.Dense(mlp_dim)
    self.mlp_gelu = tfa.layers.GELU() 
    self.mlp_dropout = layers.Dropout(dropout)
    self.mlp_dense2 = layers.Dense(num_classes)

    # self.mlp_head = tf.keras.Sequential(layers=[
    #                                             layers.LayerNormalization(axis=-1), 
    #                                             layers.Dense(mlp_dim),
    #                                             tfa.layers.GELU(), 
    #                                             layers.Dropout(dropout), 
    #                                             layers.Dense(num_classes)
    # ])

  def mlp_head(self, x, training=True):
    out = self.mlp_layer_norm(x)
    out = self.mlp_dense1(out)
    out = self.mlp_gelu(out)
    out = self.mlp_dropout(out, training=training)
    out = self.mlp_dense2(out)
    return out

  def call(self, img, mask=None, training=True):
    p = self.patch_size
    x = rearrange(img, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=p, p2=p)
    x = self.patch_to_embedding(x)
    b, n, _ = x.shape
    #pdb.set_trace()
    cls_token = tf.repeat(self.cls_token, [b], axis=0)
    x = tf.concat([cls_token, x], axis=1)
    x += self.pos_embedding[:, :(n+1)]
    x = self.dropout(x)

    x = self.transformer(x, mask)
    x = tf.identity(x[:, 0])

    return self.mlp_head(x, training=training)


In [12]:
train_aug_layers = tf.keras.Sequential(layers=[layers.ZeroPadding2D(padding=(4, 4)), 
                                             tf.keras.layers.experimental.preprocessing.RandomCrop(32, 32), 
                                             tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
                                             tf.keras.layers.experimental.preprocessing.Normalization(axis=-1, mean=(0.4914, 0.4822, 0.4465), variance=(0.2023, 0.1994, 0.2010))
])

test_aug_layers = tf.keras.layers.experimental.preprocessing.Normalization(axis=-1, mean=(0.4914, 0.4822, 0.4465), variance=(0.2023, 0.1994, 0.2010))

In [13]:
def train_aug_func(ds_sample):
  image = ds_sample['image']
  label = ds_sample['label']
  aug_image = train_aug_layers(image)
  one_hot_label = tf.one_hot(label, depth=10)
  return aug_image, one_hot_label

In [14]:
def test_aug_func(ds_sample): 
  image = ds_sample["image"]
  label = ds_sample["label"]
  aug_image = test_aug_layers(image)
  one_hot_label = tf.one_hot(label, depth=10)
  return aug_image, one_hot_label

In [15]:
cifar10_train = tfds.load("cifar10", split="train")
cifar10_test = tfds.load("cifar10", split="test")

In [16]:
def get_train_dataset(batch_size):
  ds = cifar10_train.shuffle(10000)
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.map(train_aug_func)
  ds = ds.prefetch(10000)
  return ds

In [17]:
def get_test_dataset(batch_size):
  ds = cifar10_test.shuffle(10000)
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.map(test_aug_func)
  ds = ds.prefetch(10000)
  return ds

In [32]:
class WarmUpScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=2000):
    super(WarmUpScheduler, self).__init__()
    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps 

  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [18]:
import os
from tqdm import notebook
import albumentations
import wandb
from wandb.keras import WandbCallback

In [19]:
config = {"lr": 1e-4, 
          "batch_size": 64, 
          "n_epochs": 100, 
          "patch": 2}

In [20]:
wandb.init(project="ViT-cifar10-tf", entity="caonv", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mcaonv[0m (use `wandb login --relogin` to force relogin)


In [22]:
best_acc = 0
start_epoch = 0 

In [22]:
#config = wandb.config

In [39]:
net = ViT(image_size = 32,
    patch_size = 4, 
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1)

In [24]:
train_ds = get_train_dataset(config["batch_size"])
test_ds = get_test_dataset(config["batch_size"])

In [25]:
lr_reduce_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_accuracy", factor=0.1,
                                                    patience=3, min_delta=0.01, min_lr=1e-8, cooldown=0)

In [40]:
net.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            optimizer = tf.keras.optimizers.Adam(learning_rate=config["lr"]),
                  metrics = ["accuracy"])

In [27]:
callbacks = [lr_reduce_on_plateau, WandbCallback()]

In [None]:
net.fit(train_ds, epochs=100, validation_data=test_ds, callbacks=[WandbCallback()])

Epoch 1/100


  return py_builtins.overload_of(f)(*args)








[34m[1mwandb[0m: [32m[41mERROR[0m Can't save model, h5py returned error: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
181/781 [=====>........................] - ETA: 1:07 - loss: 1.0200 - accuracy: 0.6400