In [1]:
from typing import List, Dict, Mapping, Tuple

import chex
import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [5]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
# jax.devices()

In [2]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

1

## Dataset pipline

In [3]:
import tensorflow as tf

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Normalize the pixel values
x_train, x_test = x_train / 255.0, x_test / 255.0

# Convert the labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Create a TensorFlow data pipeline for the training set
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=5000)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
    .as_numpy_iterator())

# Create a TensorFlow data pipeline for the test set.
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_test, y_test))
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
    .as_numpy_iterator())

def get_batch(training: bool = True):
  images, labels = (
      next(train_dataset) if training
      else next(test_dataset))

  images, labels = jnp.array(images), jnp.array(labels)
  return images, labels

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


### test get_batch

In [4]:
test_images, test_labels = get_batch()
test_images.shape, test_labels.shape

((64, 32, 32, 3), (64, 10))

## Modeling

In [10]:
class Patches(nn.Module):
  """Takes an image and creates patches."""
  patch_size: int
  embed_dim: int

  def setup(self):
    self.conv = nn.Conv(
        features=self.embed_dim,
        kernel_size=(self.patch_size, self.patch_size),
        strides=(self.patch_size, self.patch_size),
        padding='VALID'
    )

  def __call__(self, images):
    patches = self.conv(images)
    h, w, c = patches.shape
    patches = jnp.reshape(patches, (h*w, c))
    return patches

In [56]:
class PatchEncoder(nn.Module):
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    # x should be a single example
    chex.assert_rank(x, 2)
    T, C = x.shape

    # project to hidden dim
    x = nn.Dense(self.hidden_dim)(x)

    # add cls token
    cls = self.param('cls_token', nn.initializers.zeros, (1, self.hidden_dim))
    x = jnp.concatenate([cls, x], axis=0)

    # we added an extra cl token at the beginning
    T = T + 1

    # Add position embedding
    pos_embed = self.param(
        'position_embedding',
        nn.initializers.normal(stddev=0.02), # From BERT
        (T, self.hidden_dim)
    )
    return x + pos_embed

### test patches

In [50]:
test_image, test_label = test_images[0], test_labels[0]
test_image.shape

(32, 32, 3)

In [51]:
p = Patches(patch_size=4, embed_dim=192)
output, params = p.init_with_output(jax.random.PRNGKey(99), test_image)
params = params["params"]

In [None]:
pe = PatchEncoder(hidden_dim=256)
output_pe, params = pe.init_with_output(jax.random.PRNGKey(999), output)
params = params["params"]

## Modeling contd.

In [77]:
class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    # **new**: attention paper uses 4 times token_info_size when doing linear transformation.
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)

    # **new**: projection layer, which goes back into residual pathway.
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x, training: bool):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x


class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.
  T: int # block size; number of tokens in a block

  def setup(self):
    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)

    self.dropout = nn.Dropout(rate=0.2)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    # input: (T, info channels )
    # output: (T, token_info_size)
    keys = self.key_layer(block_of_tokens_with_info_channels)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # chanel info size
    C = int(block_of_tokens_with_info_channels.shape[-1])
    # print("[ntn99] channel_info_size: ", C)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, token_info_size) * (token_info_size, T) == (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)

    attention_values = jnp.dot(wei, values) # (T, T) * (T, token_info_size))

    attention_values = self.dropout(attention_values, deterministic=not training)

    return attention_values # (T, token_info_size)


class MultiHeadAttention(nn.Module):
  num_heads: int
  final_token_info_size: int # After concatenating from all heads, how much info (values -- emb size) you have on each token.
  T: int

  def setup(self):
    self.token_info_size_per_head = int(self.final_token_info_size/self.num_heads)
    self.heads = [
        Head(token_info_size=self.token_info_size_per_head, T=self.T) for _ in range(self.num_heads)
    ]

    self.projection = nn.Dense(features=self.final_token_info_size)

    self.dropout = nn.Dropout(rate=0.2)

  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    out_from_each_head = jnp.array([h(block_of_tokens_with_info_channels, training) for h in self.heads])

    # You just run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(out_from_each_head, axis=-1)
    # print("[ntn99] out_from_all_heads concatenated shape: ", out_from_all_heads.shape)

    projection =  self.projection(out_from_all_heads)

    return self.dropout(projection, deterministic=not training)


class Block(nn.Module):
  num_heads: int
  final_token_info_size: int
  T: int

  def setup(self):
    # communication.
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   final_token_info_size=self.final_token_info_size,
                                                   T=self.T)

    # computation.
    self.computation_layer = FeedForward(output_size=self.final_token_info_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

    self.dropout = nn.Dropout(rate=0.2)

  def __call__(self, x, training: bool):
    x = x + self.self_attention_heads(self.ln1(x), training)
    # print("[ntn99] input size after attention_head: ", x.shape)

    x = x + self.computation_layer(self.ln2(x), training)
    # print("[ntn99] input size after computation (end of block): ", x.shape)

    x = self.dropout(x, deterministic=not training)
    return x

In [82]:
class ViT(nn.Module):
  patch_size: int
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  num_layers: int
  mlp_dim: int
  num_classes: int
  T: int = 4 # +1 cls token

  def setup(self):
    self.patch_extracter = Patches(self.patch_size, self.embed_dim)
    self.patch_encoder = PatchEncoder(self.hidden_dim)
    self.dropout = nn.Dropout(self.drop_p)
    self.blocks = [
        Block(num_heads=4, final_token_info_size=self.embed_dim, T=(self.T+1)) for _ in range(self.num_layers)]

    self.cls_head = nn.Dense(features=self.num_classes)

  def __call__(self, x, training: bool = True):
    x = self.patch_extracter(x)
    x = self.patch_encoder(x)
    x = self.dropout(x, deterministic=not training)

    for i in range(self.num_layers):
      x = self.blocks[i](x, training)

    # MLP head
    x = x[:, 0] # [CLS] token
    x = self.cls_head(x)
    return x


In [83]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray


random_key = jax.random.PRNGKey(99)
random_key, random_subkey = jax.random.split(random_key)

model = ViT(patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10, T=4)


output, params = model.init_with_output(jrand.PRNGKey(99), test_image, training=False)
params = params["params"]


For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  key: jax.random.KeyArray


In [92]:
def model_apply(params, inputs):
  dropout_key = jax.random.PRNGKey(0) # TODO need to fix this.
  return model.apply({"params": params}, inputs, False, rngs={'dropout': dropout_key})

model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs)
  loss = optax.softmax_cross_entropy(logits, targets)
  loss = loss.mean()
  return loss

def train_step(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch)
  state = state.apply_gradients(grads=grads)
  return state, loss

opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)

In [95]:
for epoch in range(10):
  batch = get_batch()

  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(key=random_key, data=state.step)

  state, loss = train_step(state, batch)
  print("loss", loss, "epoch", epoch) if epoch%1==0 else None

loss 3.155889 epoch 0
loss 2.6199641 epoch 1
loss 2.6245441 epoch 2
loss 2.5297794 epoch 3
loss 2.7282379 epoch 4
loss 2.5052311 epoch 5
loss 2.3105931 epoch 6
loss 2.5190861 epoch 7
loss 2.4369555 epoch 8
loss 2.616269 epoch 9
