In [1]:
!pip install einops --quiet

# Higher jax versions have some issue with Array abstrafication in Keras, we use the 0.7.2 for this tutorial 
!pip install jax==0.7.2

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting jax==0.7.2
  Downloading jax-0.7.2-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.7.2,>=0.7.2 (from jax==0.7.2)
  Downloading jaxlib-0.7.2-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (1.3 kB)
Downloading jax-0.7.2-py3-none-any.whl (2.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.7.2-cp312-cp312-manylinux_2_27_x86_64.whl (78.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.2/78.2 MB[0m [31m68.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.8.0
    Uninstalling jaxlib

In [2]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import jax.numpy as jnp
from einops import rearrange
from jax import jit, value_and_grad
from torchvision.datasets import CIFAR10, Imagenette
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from jax import random

import keras
from keras import ops

# Vision Transformer hyper-parameters
image_size = 128  # side length of the (square) input image
patch_size = 4    # side length of one (square) patch

# The patch grid is computed with integer division, so a size mismatch would
# silently drop pixels from the count; fail early with a clear message instead.
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_patches = (image_size // patch_size) ** 2

num_layers = 4    # number of transformer blocks
hidden_dim = 64   # token width
mlp_dim = 128     # width of the intermediate MLP layer


num_classes = 10
num_heads = 4

# The heads split the token width evenly; integer division would otherwise
# truncate and make num_heads * head_dim differ from hidden_dim.
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
head_dim = hidden_dim // num_heads



In [3]:
# Container for every trainable tensor of the model. Each slot starts as a
# placeholder and is filled in by the initialisation code further down.
vit_parameters = dict(
    patch_embed=None,
    positional_encoding=None,
    layers=[],
    final_layer_norm=None,
    head=[],
    cls_token=None,
)

In [4]:
from keras import initializers

# Class token: a single learnable vector with the same width as a patch token.
cls_token = keras.Variable(
    shape=(1, hidden_dim),
    initializer=initializers.Zeros(),
    trainable=True,
)
vit_parameters['cls_token'] = cls_token

# Patch embedding: linear map from a flattened 3-channel patch to the token width.
patch_embed = keras.Variable(
    shape=(3 * patch_size * patch_size, hidden_dim),
    initializer=initializers.RandomNormal(stddev=0.01),
    trainable=True,
)
vit_parameters['patch_embed'] = patch_embed

# Learned positional encoding: one vector per image patch.
pos_enc = keras.Variable(
    shape=(num_patches, hidden_dim),
    initializer=initializers.RandomNormal(stddev=0.01),
    trainable=True,
)
vit_parameters['positional_encoding'] = pos_enc

# Classification head: weight matrix and bias mapping a token to class logits.
head_params = keras.Variable(
    shape=(hidden_dim, num_classes),
    initializer=initializers.RandomNormal(stddev=0.01),
    trainable=True,
)
head_bias = keras.Variable(
    shape=(num_classes,),
    initializer=initializers.Zeros(),
    trainable=True,
)
vit_parameters['head'] = (head_params, head_bias)

E0000 00:00:1762619057.453588    2857 common_lib.cc:648] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


In [5]:
def initialize_mlp(hidden_dim, mlp_dim):
    """Create the parameters of the two-layer feed-forward block.

    Args:
        hidden_dim: Width of the block's input and output.
        mlp_dim: Width of the intermediate layer.

    Returns:
        Tuple ``(w1, b1, w2, b2)`` of trainable ``keras.Variable`` objects.
        Weights are drawn from a normal distribution with stddev 0.01 and
        biases start at zero.
    """
    # first projection: hidden_dim -> mlp_dim
    w1 = keras.Variable(
        shape=(hidden_dim, mlp_dim),
        initializer=initializers.RandomNormal(stddev=0.01),
        trainable=True,
    )
    b1 = keras.Variable(
        shape=(mlp_dim,),
        initializer=initializers.Zeros(),
        trainable=True,
    )

    # second projection: mlp_dim -> hidden_dim
    w2 = keras.Variable(
        shape=(mlp_dim, hidden_dim),
        initializer=initializers.RandomNormal(stddev=0.01),
        trainable=True,
    )
    b2 = keras.Variable(
        shape=(hidden_dim,),
        initializer=initializers.Zeros(),
        trainable=True,
    )

    return w1, b1, w2, b2


def initialize_attention(hidden_dim, num_heads):
    """Create the query/key/value projection parameters of one attention layer.

    The projection width is derived from the arguments
    (``(hidden_dim // num_heads) * num_heads``) instead of reading the
    notebook-level ``head_dim`` global, so the function depends only on what
    it is given. For arguments that match the notebook's globals the shapes
    are the same as before.

    Args:
        hidden_dim: Width of the input tokens (rows of each weight matrix).
        num_heads: Number of attention heads.

    Returns:
        Tuple ``(q_w, k_w, v_w, q_b, k_b, v_b)`` of trainable
        ``keras.Variable`` objects. Weights have shape ``(hidden_dim, fan_out)``
        and are drawn from a normal distribution with stddev 0.01; biases have
        shape ``(fan_out,)`` and start at zero.
    """
    fan_in = hidden_dim
    fan_out = (hidden_dim // num_heads) * num_heads

    def make_weight():
        # A fresh initializer per variable, exactly as the hand-written
        # version did, so the three projections get their own initializer.
        return keras.Variable(
            initializer=initializers.RandomNormal(stddev=0.01),
            shape=(fan_in, fan_out),
            trainable=True,
        )

    def make_bias():
        return keras.Variable(
            initializer=initializers.Zeros(),
            shape=(fan_out,),
            trainable=True,
        )

    # Creation order (weight, bias per projection) matches the original code.
    q_w, q_b = make_weight(), make_bias()
    k_w, k_b = make_weight(), make_bias()
    v_w, v_b = make_weight(), make_bias()

    return q_w, k_w, v_w, q_b, k_b, v_b


def initialize_layer_norm(hidden_dim):
    """Create the scale and shift parameters of a layer norm.

    Args:
        hidden_dim: Length of the normalised (last) axis.

    Returns:
        Tuple ``(gamma, beta)`` of trainable ``keras.Variable`` objects of
        shape ``(hidden_dim,)``; ``gamma`` starts at one and ``beta`` at zero.
    """
    vector_shape = (hidden_dim,)
    gamma = keras.Variable(shape=vector_shape, initializer=initializers.Ones(), trainable=True)
    beta = keras.Variable(shape=vector_shape, initializer=initializers.Zeros(), trainable=True)
    return gamma, beta

In [6]:
# Collect the per-layer parameters in a fresh list and assign it in one step.
# Appending directly to vit_parameters['layers'] would stack num_layers more
# layers onto the model every time this cell is re-run.
transformer_layers = []
for layer_idx in range(num_layers):
    mlp_params = initialize_mlp(hidden_dim, mlp_dim)
    attn_params = initialize_attention(hidden_dim, num_heads)
    ln1_params = initialize_layer_norm(hidden_dim)
    ln2_params = initialize_layer_norm(hidden_dim)
    # the tuple order is the one transformer_block unpacks
    transformer_layers.append((mlp_params, attn_params, ln1_params, ln2_params))
vit_parameters['layers'] = transformer_layers


# layer norm applied after the last transformer block
final_layer_norm_params = initialize_layer_norm(hidden_dim)
vit_parameters['final_layer_norm'] = final_layer_norm_params

In [7]:
def relu(input):
    """Rectified linear unit: element-wise maximum of 0 and ``input``."""
    lower_bound = 0
    return ops.maximum(lower_bound, input)


def softmax(x, axis=-1):
    """Softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so the
    largest exponent is exp(0).

    Args:
        x: Input array.
        axis: Axis along which the result sums to one. Defaults to the last.

    Returns:
        Array with the same shape as ``x``.
    """
    peak = ops.max(x, axis=axis, keepdims=True)
    numerator = ops.exp(x - peak)
    denominator = ops.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator

def mlp(x, mlp_params):
    """Two-layer feed-forward block: ``relu(x @ w1 + b1) @ w2 + b2``.

    Args:
        x: Input array whose last axis matches the first axis of ``w1``.
        mlp_params: Tuple ``(w1, b1, w2, b2)``.

    Returns:
        The result of the second projection.
    """
    w1, b1, w2, b2 = mlp_params

    pre_activation = ops.matmul(x, w1) + b1
    activated = relu(pre_activation)
    return ops.matmul(activated, w2) + b2


def self_attention(x, attn_params):
    """Multi-head scaled dot-product self-attention over one token sequence.

    Reads the module-level ``num_heads`` and ``head_dim``. The per-head
    outputs are concatenated and returned as they are; no further projection
    is applied to them.

    Args:
        x: 2-D array of shape ``(n, d)``: ``n`` tokens of width ``d``.
        attn_params: Tuple ``(q_w, k_w, v_w, q_b, k_b, v_b)``.

    Returns:
        Array of shape ``(n, d)``.
    """
    q_w, k_w, v_w, q_b, k_b, v_b = attn_params
    seq_len, width = x.shape

    def project_to_heads(weight, bias):
        # (n, num_heads * head_dim) -> (n, num_heads, head_dim) -> (num_heads, n, head_dim)
        projected = ops.matmul(x, weight) + bias
        return projected.reshape(seq_len, num_heads, head_dim).swapaxes(0, 1)

    q = project_to_heads(q_w, q_b)
    k = project_to_heads(k_w, k_b)
    v = project_to_heads(v_w, v_b)

    # scaled dot-product scores, shape (num_heads, n, n), normalised over the keys
    scale = ops.sqrt(head_dim)
    scores = ops.matmul(q, ops.swapaxes(k, -1, -2)) / scale
    weights = ops.softmax(scores, axis=-1)

    # weighted sum of the values, shape (num_heads, n, head_dim)
    per_head = ops.matmul(weights, v)

    # merge the heads back: (num_heads, n, head_dim) -> (n, num_heads * head_dim)
    return per_head.swapaxes(0, 1).reshape(seq_len, width)


def layer_norm(x, layernorm_params, eps=1e-6):
    """Layer normalisation over the last axis of ``x``.

    Args:
        x: Input array; mean and variance are computed along its last axis.
        layernorm_params: Pair ``(gamma, beta)`` holding the scale and shift.
        eps: Constant added to the variance before the square root so the
            division stays finite for constant inputs. Defaults to ``1e-6``,
            the value that was previously hard-coded.

    Returns:
        ``gamma * (x - mean) / sqrt(var + eps) + beta``.
    """
    gamma, beta = layernorm_params
    mean = ops.mean(x, axis=-1, keepdims=True)
    var = ops.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / ops.sqrt(var + eps) + beta


def transformer_block(inp, block_params):
    """One transformer block: attention and MLP, each normalised first and
    wrapped in a residual connection.

    Args:
        inp: Token array passed to the first layer norm.
        block_params: Tuple ``(mlp_params, attn_params, ln1_params, ln2_params)``.

    Returns:
        The block output, the MLP result added to its own input.
    """
    mlp_params, attn_params, ln1_params, ln2_params = block_params

    # attention sub-layer + residual
    attended = self_attention(layer_norm(inp, ln1_params), attn_params)
    skip = attended + inp

    # MLP sub-layer + residual
    fed_forward = mlp(layer_norm(skip, ln2_params), mlp_params)
    return fed_forward + skip


def transformer(patches, vit_parameters):
    """Forward pass of the Vision Transformer for a single image.

    Args:
        patches: One image laid out as ``(c, height, width)``; height and
            width are split into ``patch_size`` x ``patch_size`` tiles.
        vit_parameters: Dict with the keys ``'patch_embed'``,
            ``'positional_encoding'``, ``'cls_token'``, ``'layers'``,
            ``'final_layer_norm'`` and ``'head'``.

    Returns:
        The class logits computed from the class-token position.
    """
    # (c, h * p1, w * p2) -> (h * w, p1 * p2 * c): one flattened row per patch
    tokens = rearrange(patches, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)

    # linear patch embedding followed by the learned positional encoding
    tokens = ops.matmul(tokens, vit_parameters['patch_embed'])
    tokens = tokens + vit_parameters['positional_encoding']

    # put the class token in front of the patch tokens (index 0 of the sequence)
    tokens = ops.concatenate([vit_parameters['cls_token'], tokens], axis=0)

    # run the sequence through every transformer block in order
    for block_params in vit_parameters['layers']:
        tokens = transformer_block(tokens, block_params)

    tokens = layer_norm(tokens, vit_parameters['final_layer_norm'])

    # classify from the class-token row only
    class_token_out = tokens[0, :]
    head = vit_parameters['head']
    return ops.matmul(class_token_out, head[0]) + head[1]

In [8]:
# Smoke test: push one random image through the untrained model.
key = random.PRNGKey(42)
sample_image = random.normal(key, shape=(3, image_size, image_size))

prediction = transformer(sample_image, vit_parameters)
print("Output shape:", prediction.shape) # should be (num_classes,)

Output shape: (10,)


In [9]:
bsize = 5
sample_images = random.normal(key, shape=(bsize, 3, image_size, image_size))

# vmap maps transformer over axis 0 of the images; in_axes=None for the second
# argument passes the same parameters to every image.
batched_transformer = jax.vmap(transformer, in_axes=(0, None))
prediction = batched_transformer(sample_images, vit_parameters)
print("Prediction shape:", prediction.shape)

Prediction shape: (5, 10)


In [10]:
def cross_entropy_loss(patches, vit_parameters, ground_truth):
    """Batch-averaged cross-entropy between the model output and the targets.

    Args:
        patches: Batch of images; ``transformer`` is mapped over axis 0.
        vit_parameters: Model parameters, shared by every image in the batch.
        ground_truth: Target weights with one row per image; multiplied
            element-wise with the log-probabilities.

    Returns:
        Scalar: minus the mean over the batch of
        ``sum(ground_truth * log_softmax(logits))``.
    """
    batched_forward = jax.vmap(transformer, in_axes=(0, None))
    logits = batched_forward(patches, vit_parameters)
    log_probs = ops.log_softmax(logits)
    per_example = ops.sum(ground_truth * log_probs, axis=-1)
    return -ops.mean(per_example)

In [11]:
# Dummy target: only image 0 carries a label (class 1); every other row is all zeros.
dummy_target = jnp.zeros((bsize, 10)).at[0, 1].set(1)
l = cross_entropy_loss(sample_images, vit_parameters, dummy_target)
print("Loss:", l)

Loss: 0.453823


In [12]:
# per-channel mean / std handed to transforms.Normalize
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


# training split: resize -> tensor -> normalise, batches are shuffled
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
train_dataset = Imagenette(
    root='imagenette3',
    split='train',
    size="160px",
    download=True,
    transform=train_transform,
)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)


# validation split: same preprocessing, fixed order
test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_dataset = Imagenette(
    root='imagenette3',
    split='val',
    size="160px",
    download=True,
    transform=test_transform,
)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [13]:
def eval(vit_parameters):
    """Accuracy of the model on ``test_loader``.

    Note: this name shadows Python's built-in ``eval`` for the rest of the
    notebook.

    Args:
        vit_parameters: Model parameters passed to ``transformer``.

    Returns:
        Number of correctly classified images divided by ``len(test_dataset)``.
    """
    batched_forward = jax.vmap(transformer, in_axes=(0, None))
    num_correct = 0

    for img, target in tqdm(test_loader, desc="Eval", unit="item"):
        images = jnp.asarray(img, dtype=jnp.float32)
        labels = jnp.asarray(target)

        # predicted class = index of the largest logit
        predicted = jnp.argmax(batched_forward(images, vit_parameters), axis=-1)
        num_correct += jnp.sum(predicted == labels).item()

    return num_correct / len(test_dataset)

# Baseline: accuracy of the model before any training step has been applied.
accuracy = eval(vit_parameters)
print("Accuracy before training", accuracy)

Eval: 100%|██████████| 16/16 [00:25<00:00,  1.61s/item]

Accuracy before training 0.11923566878980892





In [15]:
# fake labels and images
sample_images = random.normal(key, shape=(bsize, 3, image_size, image_size))
sample_target = jnp.zeros((bsize, 10)).at[0, 1].set(1)

# argnums=1: differentiate with respect to the second argument, the parameter dict
loss_and_grad_fn = value_and_grad(cross_entropy_loss, argnums=1)
current_loss, grads = loss_and_grad_fn(sample_images, vit_parameters, sample_target)

print("Current loss:", current_loss)
print("Gradients:", grads.keys())

  patches = ops.matmul(patches, vit_parameters['patch_embed'])
  patches = patches + vit_parameters['positional_encoding']
  patches = transformer_block(patches, block_params)
  return gamma * (x - mean) / ops.sqrt(var + 1e-6) + beta
  q = ops.matmul(x, q_w) + q_b
  k = ops.matmul(x, k_w) + k_b
  v = ops.matmul(x, v_w) + v_b
  up_proj = relu(ops.matmul(x, w1) + b1)
  down_proj = ops.matmul(up_proj, w2) + b2


Current loss: 0.453823
Gradients: dict_keys(['cls_token', 'final_layer_norm', 'head', 'layers', 'patch_embed', 'positional_encoding'])


  logits = ops.matmul(patches, vit_parameters['head'][0]) + vit_parameters['head'][1]


In [16]:
@jit
def train_step(patches, vit_parameters, target, learning_rate=0.01):
    """Run one step of plain gradient descent on a batch.

    Args:
        patches: Batch of images passed to ``cross_entropy_loss``.
        vit_parameters: Current model parameters (a pytree).
        target: Targets for the batch, passed to ``cross_entropy_loss``.
        learning_rate: Step size of the update. Defaults to ``0.01``, the
            value that was previously hard-coded, so existing three-argument
            calls behave as before.

    Returns:
        Tuple ``(current_loss, updated_params)``: the loss evaluated at the
        incoming parameters and a pytree with the same structure as
        ``vit_parameters`` holding ``p - learning_rate * g`` for each leaf.
    """
    # loss and its gradient with respect to argument 1 (the parameters)
    loss_and_grad_fn = value_and_grad(cross_entropy_loss, argnums=1)
    current_loss, grads = loss_and_grad_fn(patches, vit_parameters, target)

    # gradient-descent update applied leaf by leaf
    updated_params = jax.tree.map(
        lambda p, g: p - learning_rate * g,
        vit_parameters,
        grads,
    )

    return current_loss, updated_params

In [17]:
# Train for a fixed number of epochs and evaluate after each one.
# NOTE: the loop rebinds the global `vit_parameters` on every step, so
# re-running this cell continues training from the current parameters.
num_epochs = 20


for epoch in range(num_epochs):

    # the progress bar is labelled with a 1-based epoch number
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    # `i` is the batch index supplied by enumerate; it is not used in the body
    for i, (data, target) in progress_bar:

        # convert the batch coming from the data loader to JAX arrays
        data = jnp.asarray(data)
        target = jnp.asarray(target)

        # one-hot encode the integer labels for the loss
        target_one_hot = jax.nn.one_hot(target, num_classes)

        # train_step returns the loss and the updated parameters
        current_loss, vit_parameters = train_step(data, vit_parameters, target_one_hot)

        progress_bar.set_postfix({'loss': current_loss})


    # accuracy after this epoch; note the printed epoch index is 0-based
    eval_acc = eval(vit_parameters)
    print(f'Epoch: {epoch}, Eval acc: {eval_acc}')

Epoch 1/20: 100%|██████████| 37/37 [00:33<00:00,  1.09it/s, loss=2.120179] 
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.16item/s]


Epoch: 0, Eval acc: 0.21987261146496814


Epoch 2/20: 100%|██████████| 37/37 [00:22<00:00,  1.67it/s, loss=2.1085021]
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.18item/s]


Epoch: 1, Eval acc: 0.2575796178343949


Epoch 3/20: 100%|██████████| 37/37 [00:22<00:00,  1.63it/s, loss=2.029838] 
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.17item/s]


Epoch: 2, Eval acc: 0.24764331210191082


Epoch 4/20: 100%|██████████| 37/37 [00:22<00:00,  1.62it/s, loss=1.9859735]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.14item/s]


Epoch: 3, Eval acc: 0.2540127388535032


Epoch 5/20: 100%|██████████| 37/37 [00:22<00:00,  1.67it/s, loss=1.9744337]
Eval: 100%|██████████| 16/16 [00:15<00:00,  1.00item/s]


Epoch: 4, Eval acc: 0.2761783439490446


Epoch 6/20: 100%|██████████| 37/37 [00:22<00:00,  1.67it/s, loss=1.9847224]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.12item/s]


Epoch: 5, Eval acc: 0.3029299363057325


Epoch 7/20: 100%|██████████| 37/37 [00:21<00:00,  1.70it/s, loss=1.9853494]
Eval: 100%|██████████| 16/16 [00:15<00:00,  1.05item/s]


Epoch: 6, Eval acc: 0.2968152866242038


Epoch 8/20: 100%|██████████| 37/37 [00:22<00:00,  1.66it/s, loss=1.9634247]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.11item/s]


Epoch: 7, Eval acc: 0.3080254777070064


Epoch 9/20: 100%|██████████| 37/37 [00:21<00:00,  1.70it/s, loss=1.908057] 
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.20item/s]


Epoch: 8, Eval acc: 0.27694267515923565


Epoch 10/20: 100%|██████████| 37/37 [00:22<00:00,  1.65it/s, loss=1.9818041]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.08item/s]


Epoch: 9, Eval acc: 0.255031847133758


Epoch 11/20: 100%|██████████| 37/37 [00:22<00:00,  1.68it/s, loss=1.9826884]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.11item/s]


Epoch: 10, Eval acc: 0.30369426751592354


Epoch 12/20: 100%|██████████| 37/37 [00:22<00:00,  1.66it/s, loss=1.9982338]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.12item/s]


Epoch: 11, Eval acc: 0.30522292993630573


Epoch 13/20: 100%|██████████| 37/37 [00:22<00:00,  1.63it/s, loss=1.897428] 
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.09item/s]


Epoch: 12, Eval acc: 0.3154140127388535


Epoch 14/20: 100%|██████████| 37/37 [00:21<00:00,  1.69it/s, loss=1.8940327]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.14item/s]


Epoch: 13, Eval acc: 0.3159235668789809


Epoch 15/20: 100%|██████████| 37/37 [00:22<00:00,  1.66it/s, loss=1.9319364]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.08item/s]


Epoch: 14, Eval acc: 0.3189808917197452


Epoch 16/20: 100%|██████████| 37/37 [00:21<00:00,  1.69it/s, loss=1.8812433]
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.16item/s]


Epoch: 15, Eval acc: 0.32764331210191083


Epoch 17/20: 100%|██████████| 37/37 [00:22<00:00,  1.67it/s, loss=1.9737448]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.09item/s]


Epoch: 16, Eval acc: 0.3057324840764331


Epoch 18/20: 100%|██████████| 37/37 [00:22<00:00,  1.68it/s, loss=1.9212428]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.11item/s]


Epoch: 17, Eval acc: 0.31414012738853503


Epoch 19/20: 100%|██████████| 37/37 [00:21<00:00,  1.69it/s, loss=1.84732]  
Eval: 100%|██████████| 16/16 [00:13<00:00,  1.15item/s]


Epoch: 18, Eval acc: 0.32178343949044586


Epoch 20/20: 100%|██████████| 37/37 [00:22<00:00,  1.67it/s, loss=2.0011415]
Eval: 100%|██████████| 16/16 [00:14<00:00,  1.10item/s]

Epoch: 19, Eval acc: 0.3131210191082803



