In [1]:
import keras
from keras import layers
from keras import ops
from keras.utils import pad_sequences
import keras_hub
from keras.callbacks import EarlyStopping
import tensorflow as tf
import os
os.environ['KERAS_BACKEND'] = 'jax'

  from .autonotebook import tqdm as notebook_tqdm


# TIME FOR THE TRANSFORMER!!!

Woohoo, I'm finally up to this point and I'm super excited. First some readings:

1. [3Blue1Brown](https://www.youtube.com/watch?v=wjZofJX0v4M&ab_channel=3Blue1Brown)
2. [Stat Quest](https://www.youtube.com/watch?v=zxQyTK8quyY&ab_channel=StatQuestwithJoshStarmer)
3. [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)
4. [Attention is All you Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)


In [2]:
EMBED_DIM = 100
VOCAB_SIZE = 1000
MAX_SEQUENCE_LEN = 128
FFN_DIM = 4 * EMBED_DIM
NUM_HEADS = 2

(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
    num_words=VOCAB_SIZE,
    maxlen=MAX_SEQUENCE_LEN,
)

x_train = pad_sequences(x_train, MAX_SEQUENCE_LEN)
x_test = pad_sequences(x_test, MAX_SEQUENCE_LEN)

In [None]:
class myMultiHeadAttention(keras.layers.Layer):
    def __init__(self,num_heads, key_dim, value_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim if value_dim is not None else key_dim
        
    def build(self, input_shape):
        embed_dim = input_shape[-1]
        self.permute1 = keras.layers.Permute((2,1,3))
        self.permute2 = keras.layers.Permute((1,3,2))
        self.wq = self.add_weight(
            shape=(embed_dim,self.key_dim*self.num_heads),
            initializer='random_normal',
            trainable=True
        )
        self.wk = self.add_weight(
            shape=(embed_dim,self.key_dim*self.num_heads),
            initializer='random_normal',
            trainable=True
        )
        self.wv = self.add_weight(
            shape=(embed_dim, self.value_dim * self.num_heads),
            initializer='random_normal',
            trainable=True
        )
        self.wo = self.add_weight(
            shape=(self.value_dim * self.num_heads, embed_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='wo'
        ) 
        
    
    def call(self,query,key,value, mask=None):
        batch_size = ops.shape(query)[0]
        seq_len_q = ops.shape(query)[1]
        seq_len_kv = ops.shape(key)[1]
        
        queries = ops.matmul(query, self.wq) # BxNxE x ExKn = BxNxKn
        keys = ops.matmul(key, self.wk) #BxNxE x ExKn = BxNxKn
        values = ops.matmul(value, self.wv) #BxNxE x ExVn = BxNxVn
        
        queries = ops.reshape(queries, (batch_size, seq_len_q, self.num_heads, self.key_dim)) #BxNxnxK
        keys = ops.reshape(keys, (batch_size, seq_len_kv, self.num_heads, self.key_dim))
        values = ops.reshape(values, (batch_size, seq_len_kv, self.num_heads, self.key_dim))# BxNxnxV
        
        queries = self.permute1(queries) #BxnxNxK
        keys = self.permute1(keys) #BxnxNxK
        values = self.permute1(values) #BxnxNxV
        
        keys = self.permute2(keys) # BxnxKxN
        
        scores = ops.matmul(queries, keys) # BxnxNxN
        
        scaling_factor = ops.sqrt(ops.cast(self.key_dim, dtype=queries.dtype))
        normalized_scores = scores / scaling_factor
        if mask is not None:
            # Original mask shape: (Nq, Nkv)
            # This allows it to broadcast correctly to (B, H, Nq, Nkv)
            mask = mask[None, None, :, :]
            normalized_scores += (1.0 - ops.cast(mask, dtype=normalized_scores.dtype)) * -1e9
        attentions = keras.activations.softmax(normalized_scores) # BxnxNxN
        attentions = ops.matmul(attentions, values) #BxnxNxN x BxnxNxV = BxnxNxV
        
        attentions = self.permute1(attentions) #BxNxnxV
        concatenated_output = ops.reshape(attentions, (batch_size, seq_len_q, self.value_dim*self.num_heads)) #BxNxVn
        
        attentions = ops.matmul(concatenated_output, self.wo) #BxNxVn x BxVnxE = BxNxE
        return attentions
    

In [19]:
class myEncoderBlock(keras.layers.Layer):
    def __init__(self,embed_dim,num_heads,ffn_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.rate = rate
        self.mha_layer = myMultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim // self.num_heads
        )
        
        self.dense1 = keras.layers.Dense(4*embed_dim, activation='relu')
        self.dense2 = keras.layers.Dense(self.embed_dim)
        
        self.layernorm_1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm_2 = keras.layers.LayerNormalization(epsilon=1e-6)
        
        self.add_1 = keras.layers.Add()
        self.add_2 = keras.layers.Add()

        
    def call(self, inputs):
        
        seq_len = ops.shape(inputs)[1]
        
        causal_mask = ops.logical_not(ops.triu(ops.ones((seq_len, seq_len)), k=1))
        
        mha_output = self.mha_layer(
            query=inputs,
            key=inputs,
            value=inputs,
            mask=causal_mask
        )
        
        add_norm_1 = self.add_1([inputs, mha_output])
        add_norm_1 = self.layernorm_1(add_norm_1)
        
        ffn_output = self.dense1(add_norm_1)
        ffn_output = self.dense2(ffn_output)

        add_norm_2 = self.add_2([add_norm_1, ffn_output])
        normalized_outputs = self.layernorm_2(add_norm_2)
        
        return normalized_outputs

In [20]:
inputs = keras.Input(shape=(MAX_SEQUENCE_LEN,),dtype='int32')

#embedding
token_embeddings = keras.layers.Embedding(
    input_dim = VOCAB_SIZE, 
    output_dim =EMBED_DIM
)(inputs)
position_embeddings = keras_hub.layers.PositionEmbedding(
    sequence_length=MAX_SEQUENCE_LEN,
)(token_embeddings)
embeddings = token_embeddings + position_embeddings

#encoder block
x = myEncoderBlock(EMBED_DIM, NUM_HEADS, FFN_DIM)(embeddings)
x = myEncoderBlock(EMBED_DIM, NUM_HEADS, FFN_DIM)(x)

pooled_output = keras.layers.GlobalAveragePooling1D()(x)
outputs = keras.layers.Dense(1,)(pooled_output)

model = keras.Model(inputs=inputs, outputs=outputs, name='imdb_model')
model.summary()

In [21]:
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',     
    patience=3,      
    restore_best_weights=True 
)

model.compile(
    loss= keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer = keras.optimizers.RMSprop(),
    metrics=['accuracy'],
)

history = model.fit(
    x_train, 
    y_train, 
    epochs=15, 
    batch_size=64, 
    validation_split=0.2, 
    callbacks=[early_stopping_callback])

test_scores=model.evaluate(x_test, y_test, verbose=1)

print(f'Test Loss: {test_scores[0]}')
print(f'Test Accuracy: {test_scores[1]}')

Epoch 1/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 210ms/step - accuracy: 0.4903 - loss: 0.9317 - val_accuracy: 0.4679 - val_loss: 0.6832
Epoch 2/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 206ms/step - accuracy: 0.5214 - loss: 0.7064 - val_accuracy: 0.4739 - val_loss: 0.6351
Epoch 3/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 239ms/step - accuracy: 0.6153 - loss: 0.6191 - val_accuracy: 0.7755 - val_loss: 0.4928
Epoch 4/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 212ms/step - accuracy: 0.7328 - loss: 0.5090 - val_accuracy: 0.7584 - val_loss: 0.4533
Epoch 5/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 204ms/step - accuracy: 0.7751 - loss: 0.4558 - val_accuracy: 0.7558 - val_loss: 0.4449
Epoch 6/15
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 204ms/step - accuracy: 0.8097 - loss: 0.4002 - val_accuracy: 0.7952 - val_loss: 0.4448
Epoch 7/15
[1m73/73[

## Takeaways

So first the single head attention block was less complicated than I expected since it's all just matrix multiplications. The biggest hiccup was probably when I chose to expand to multi-head because I found online that I was better off doing some matrix manipulations to multiplying big flat matrixes rather than just parallelize it with a 'head' dimension (I was tempted though). I didn't really consider cross attention in this model, but I think that's just debugging and throwing some variables around

Additionally, I learned while building this that the transformer architecture is not agreed upon, and the most that the native keras api can give us without becoming opinionated is the multi-head attention layer, which truthfully sticks to the "Attention is All you Need" paper and remains a proven, robust fundamental building block.

I built both the mha_layer and the encoder. The MHA is mathematically pretty standard but the encoder can be built in many different ways like:
- Pre vs post normalization (I used post as the paper did)
- Feed Forward network activation choices