# Attention is All You Need

Coding a transformer from scratch


In [379]:
import numpy as np
import pandas as pd
import tensorflow as tf

### Data preprocessing pipeline


In [380]:
# Get Shakespeares work from Andrej Karpathy's website

url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
filepath = tf.keras.utils.get_file('shakespeare.txt', url)

with open(filepath) as f:
    shakespeare_text = f.read()

In [381]:
# Print the first few characters
print(shakespeare_text[:148])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?



In [382]:
unique_chars = sorted(set(shakespeare_text))
char_to_int = {char: idx for idx, char in enumerate(unique_chars)}

# How many number of distinct characters has the vocabulary:
tokens_len = len(unique_chars)
print(f'Number of tokens in vocabulary: {tokens_len}')

# How many characters has the dataset:
text_length = len(shakespeare_text)
print(f'Total length of text dataset: {text_length}')

Number of tokens in vocabulary: 65
Total length of text dataset: 1115394


### Embedding Layer


In [383]:
embedding_dim = 10

model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(tokens_len, embedding_dim))

input_array = np.random.randint(tokens_len, size=(1, 1))
model.compile('rmsprop', 'sparse_categorical_crossentropy')

output_array = model.predict(input_array)
print(output_array.shape)

model.summary()
# (1, 1, 10)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
(1, 1, 10)


### Positional Encoding


In [384]:
class PositionalEncoding(tf.keras.layers.Layer):
    def __init__(self, max_pos_enc, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_pos_enc # maximum sequence length that the model can handle
        self.embedding_dim = embedding_dim
        
        # Create the positional encodings
        position = np.arange(max_pos_enc)[:, np.newaxis]
        div_term = np.exp(np.arange(0, embedding_dim, 2) * -(np.log(10000.0) / embedding_dim))
        pe = np.zeros((max_pos_enc, embedding_dim))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        
        # Add batch dimension e.g. (max_pos_enc,embedding_dim) -> (1,max_pos_enc,embedding_dim)
        self.pe = tf.constant(pe[np.newaxis, :, :], dtype=tf.float32)
    
    def call(self, inputs):           
        seq_len = tf.shape(inputs)[1]
        # Add positional encoding, broadcasting across batch dimension
        return inputs + self.pe[:, :seq_len, :]

In [385]:
embedding_dim = 10  # Embedding dimension
max_len = 50  # Maximum sequence length

# Define the model
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(tokens_len, embedding_dim),
    PositionalEncoding(max_pos_enc=max_len, embedding_dim=embedding_dim)
])

input_array = np.random.randint(tokens_len, size=(1, 10)) 
output_array = model(input_array)

print(output_array.shape)  # Should be (1, 10, embedding_dim)

(1, 10, 10)


In [386]:
model.summary()

## Multi-Head Attention


### Scaled Dot-Product Attention


In [387]:
class ScaledDotProductAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, inputs, mask=None):
        q, k, v = inputs
        
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        
        # dot product attention
        matmul_qk = tf.matmul(q, k, transpose_b=True) # (bs, q_len, k_len)
        
        # scale dot product
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        
        # apply mask when necessary
        if mask is not None:
            # adding very large negative values 
            # so they go to zero after softmax
            scaled_attention_logits += (mask * -1e9) 
        
        # apply softmax to attention weights (scores)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        
        # multiply by V (values)
        out = tf.matmul(attention_weights, v)
        
        return out, attention_weights
        

### Multi Head Attention Head


In [388]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
        self.depth = embedding_dim // num_heads # depth per head
        
        # linear projection layers
        self.wq = tf.keras.layers.Dense(embedding_dim)
        self.wk = tf.keras.layers.Dense(embedding_dim)
        self.wv = tf.keras.layers.Dense(embedding_dim)
        
        # output projection
        self.dense = tf.keras.layers.Dense(embedding_dim)
        
        self.attention = ScaledDotProductAttention()
    
    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
        
    def call(self, inputs, mask=None):
        q, k, v = inputs
        
        batch_size = tf.shape(q)[0]
        
        # linear projections
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        
        # reshaping q, k, v
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        
        scaled_attention, attention_weights = self.attention([q, k, v], mask)
        
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.embedding_dim))
        
        out = self.dense(concat_attention)
        
        return out, attention_weights

## Position-wise Feed-Forward Network

$$\text{FFN(x)} = \text{max}(0,~ xW_1 + b_1)W_2 + b_2$$


In [389]:
class PositionwiseFeedForward(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        
        # hidden_dim (dff) - feed forward network hidden 
        # layer dimension a.k.a inner layer dimensionality
        
        self.dense1 = tf.keras.layers.Dense(hidden_dim, activation="relu")
        self.dense2 = tf.keras.layers.Dense(embedding_dim)
        
    def call(self, inputs):
        
        x = self.dense1(inputs)
        
        return self.dense2(x)

## Encoder Layer


In [390]:
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, hidden_dim, num_heads, dropout_rate = 0.1, **kwargs):
        super().__init__(**kwargs)
        
        self.mha = MultiHeadAttention(num_heads, embedding_dim)
        self.ffn = PositionwiseFeedForward(embedding_dim, hidden_dim)
        
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        
    def call(self, inputs, training=False, mask=None):
        
        # multi-head attention
        attention_output, _ = self.mha([inputs, inputs, inputs], mask)
        attention_output = self.dropout1(attention_output, training=training)
        out1 = self.layer_norm1(inputs + attention_output)
        
        # feed-forward network
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layer_norm2(out1 + ffn_output)
        
        return out2
        

## Encoder

The Encoder stacks multiple encoder layer to create the full encoder. It includes the Embedding layer and Positional Encoding layer as well.


In [391]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, embedding_dim, hidden_dim, num_heads, 
                 tokens_len, max_pos_enc, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
            
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        
        self.embedding = tf.keras.layers.Embedding(tokens_len, embedding_dim)
        self.pos_encoding = PositionalEncoding(max_pos_enc, embedding_dim)
        
        self.encoding_layers = [
            EncoderLayer(embedding_dim, hidden_dim, num_heads, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ]
        
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        
    def call(self, x, training=False, mask=None):
        
        seq_len = tf.shape(x)[1]
        
        x = self.embedding(x)
        x = tf.cast(x, dtype=tf.float32)
        x = self.pos_encoding(x)
        
        x = self.dropout(x, training=training)
        
        for i in range(self.num_layers):
            x = self.encoding_layers[i](x, training=training, mask=mask)
            
        return x

## DecoderLayer


In [392]:
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, hidden_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        
        # Self-attention
        self.self_attention = MultiHeadAttention(num_heads, embedding_dim)
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        
        # Cross-attention (encoder-decoder attention)
        self.cross_attention = MultiHeadAttention(num_heads, embedding_dim)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        
        # Feed-forward
        self.ffn = PositionwiseFeedForward(embedding_dim, hidden_dim)
        self.norm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, inputs, enc_output, training=False, look_ahead_mask=None, padding_mask=None):
        # Self attention with look-ahead mask
        self_attn_output, _ = self.self_attention([inputs, inputs, inputs], mask=look_ahead_mask)
        self_attn_output = self.dropout1(self_attn_output, training=training)
        out1 = self.norm1(inputs + self_attn_output)
        
        # Cross attention with encoder output
        cross_attn_output, _ = self.cross_attention([out1, enc_output, enc_output], mask=padding_mask)
        cross_attn_output = self.dropout2(cross_attn_output, training=training)
        out2 = self.norm2(out1 + cross_attn_output)
        
        # Feed forward
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.norm3(out2 + ffn_output)
        
        return out3

## Decoder


In [393]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, embedding_dim, hidden_dim, num_heads,
                 tokens_len, max_pos_enc, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        
        self.embedding = tf.keras.layers.Embedding(tokens_len, embedding_dim)
        self.pos_encoding = PositionalEncoding(max_pos_enc, embedding_dim)
        
        self.decoder_layers = [
            DecoderLayer(embedding_dim, hidden_dim, num_heads, dropout_rate)
            for _ in range(num_layers)
        ]
        
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, enc_output, training=False, look_ahead_mask=None, padding_mask=None):
        seq_len = tf.shape(x)[1]
        
        # embedding and positional encoding
        x = self.embedding(x)
        x = tf.cast(x, dtype=tf.float32)
        x = self.pos_encoding(x)
        
        x = self.dropout(x, training=training)
        
        # Ensure correct mask shape
        if padding_mask is not None:
            padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        
        # decoder layers
        for i in range(self.num_layers):
            x = self.decoder_layers[i](x, enc_output, training=training, look_ahead_mask=look_ahead_mask, padding_mask=padding_mask)
        
        return x

## Transformer


In [394]:
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, embedding_dim, hidden_dim, num_heads,
                input_vocab_size, target_vocab_size, max_pos_enc,
                dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        
        self.encoder = Encoder(
            num_layers, embedding_dim, hidden_dim, num_heads,
            input_vocab_size, max_pos_enc, dropout_rate
        )
        
        self.decoder = Decoder(
            num_layers, embedding_dim, hidden_dim, num_heads,
            target_vocab_size, max_pos_enc, dropout_rate
        )
        
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
    def create_padding_mask(self, seq):
        """Creates a mask for padding tokens (value 0)"""
        seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
        return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
    
    def create_look_ahead_mask(self, seq_len):
        """Creates a mask to prevent attention to future tokens"""
        mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        return mask  # (seq_len, seq_len)
    
    def call(self, inputs, training=False):
        inp, tar = inputs
        
        # Create masks
        enc_padding_mask = self.create_padding_mask(inp)
        dec_padding_mask = self.create_padding_mask(inp)
        
        # Look ahead mask for decoder
        look_ahead_mask = self.create_look_ahead_mask(tf.shape(tar)[1])
        dec_target_padding_mask = self.create_padding_mask(tar)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
        
        # Encoder output
        enc_output = self.encoder(inp, training=training, mask=enc_padding_mask)
        
        # Decoder output
        dec_output = self.decoder(
            x=tar,
            enc_output=enc_output,
            training=training,
            look_ahead_mask=combined_mask,
            padding_mask=dec_padding_mask
        )
        
        # Final output
        final_output = self.final_layer(dec_output)
        
        return final_output

In [395]:
num_layers = 4
embedding_dim = 512
hidden_dim = 2048
num_heads = 8
input_vocab_size = 8000
target_vocab_size = 8000
max_pos_enc = 10000
dropout_rate = 0.1

# Create model
transformer = Transformer(
    num_layers, embedding_dim, hidden_dim, num_heads,
    input_vocab_size, target_vocab_size, max_pos_enc, dropout_rate
)

# Compile model
transformer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Provide a sample input to build the model
# Fixed sample input - specify maxval for integer type
sample_input = (
    tf.random.uniform((1, 10), maxval=input_vocab_size, dtype=tf.int32), 
    tf.random.uniform((1, 10), maxval=target_vocab_size, dtype=tf.int32)
)
transformer(sample_input)

# Print model summary
print(transformer.summary())

1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Dimension must be 6 but is 4 for '{{node transpose_3}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](scaled_dot_product_attention_331_1/MatMul_1, transpose_3/perm)' with input shapes: [1,1,1,8,10,64], [4].''
1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used 

InvalidArgumentError: Exception encountered when calling MultiHeadAttention.call().

[1m{{function_node __wrapped__Transpose_device_/job:localhost/replica:0/task:0/device:CPU:0}} transpose expects a vector of size 6. But input(1) is a vector of size 4 [Op:Transpose][0m

Arguments received by MultiHeadAttention.call():
  • inputs=['tf.Tensor(shape=(1, 10, 512), dtype=float32)', 'tf.Tensor(shape=(1, 10, 512), dtype=float32)', 'tf.Tensor(shape=(1, 10, 512), dtype=float32)']
  • mask=tf.Tensor(shape=(1, 1, 1, 1, 1, 10), dtype=float32)