## Transformer

#### Import Libraries

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, grad, jit
from jax.nn import softmax, relu

#### Get Embeddings and Positional Encoding

In [None]:
def get_embeddings(vocab_size, embed_dim, key=None):
    """Create a randomly initialised token-embedding table.

    Args:
        vocab_size: Number of rows in the table (one per token id).
        embed_dim: Width of each embedding vector.
        key: Optional JAX PRNG key used to draw the table. When omitted,
            ``random.PRNGKey(0)`` is used, which reproduces the table this
            function returned when the seed was hard-coded.

    Returns:
        Array of shape ``(vocab_size, embed_dim)`` sampled from a standard
        normal distribution.
    """
    if key is None:
        key = random.PRNGKey(0)
    return random.normal(key, (vocab_size, embed_dim))

def get_positional_encoding(seq_len, embed_dim):
    """Build the sinusoidal positional-encoding matrix.

    Even columns hold ``sin(pos / 10000 ** (d / embed_dim))`` and the odd
    column that follows holds the matching ``cos``, where ``d`` is the even
    column index (0, 2, 4, ...).

    Args:
        seq_len: Number of positions (rows).
        embed_dim: Embedding width (columns). Odd widths are supported; the
            last column is then a sine column without a cosine partner.

    Returns:
        ``numpy.ndarray`` of shape ``(seq_len, embed_dim)``.
    """
    positions = np.arange(seq_len, dtype=np.float64)[:, None]
    # `even_dims` already is the "2i" of the textbook formula, so it is
    # divided by embed_dim directly rather than doubled again.
    even_dims = np.arange(0, embed_dim, 2, dtype=np.float64)
    angles = positions / np.power(10000.0, even_dims / embed_dim)

    pe = np.zeros((seq_len, embed_dim))
    pe[:, 0::2] = np.sin(angles)
    # With an odd embed_dim there is one fewer odd column than even columns.
    pe[:, 1::2] = np.cos(angles[:, : embed_dim // 2])
    return pe

#### Multi-Head Attention

In [None]:
def scaled_dot_product_attention(Q, K, V):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    The last two axes of each argument are treated as ``(sequence, features)``;
    any leading axes are batch axes handled by ``jnp.matmul`` broadcasting.

    Args:
        Q: Queries, shape ``(..., seq_q, d_k)``.
        K: Keys, shape ``(..., seq_k, d_k)``.
        V: Values, shape ``(..., seq_k, d_v)``.

    Returns:
        Attention output of shape ``(..., seq_q, d_v)``.
    """
    # The feature width is the last axis; axis 1 is only the feature axis
    # for un-batched 2-D inputs.
    d_k = Q.shape[-1]
    # Swap only the last two axes: `.T` would reverse every axis of a
    # batched array.
    scores = jnp.matmul(Q, jnp.swapaxes(K, -1, -2)) / jnp.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    return jnp.matmul(weights, V)

def multi_head_attention(Q, K, V, num_heads, key):
    """Multi-head attention with freshly sampled projection weights.

    Every head gets its own query, key and value projection of shape
    ``(d_model, d_model // num_heads)``. The heads are attended separately,
    concatenated along the feature axis and mixed by an output projection
    ``W_O`` of shape ``(d_model, d_model)``. All weights are drawn from a
    standard normal using subkeys derived from ``key``.

    Args:
        Q: Queries, shape ``(..., seq_q, d_model)``.
        K: Keys, shape ``(..., seq_k, d_model)``.
        V: Values, shape ``(..., seq_k, d_model)``.
        num_heads: Number of attention heads; must divide ``d_model``.
        key: JAX PRNG key from which all projection weights are derived.

    Returns:
        Array of shape ``(..., seq_q, d_model)``.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """
    d_model = Q.shape[-1]
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    depth = d_model // num_heads

    # One independent subkey per projection matrix (3 per head) plus one for
    # W_O; reusing a key would make the Q, K and V projections identical.
    subkeys = random.split(key, 3 * num_heads + 1)

    # Map the 2-D attention routine over any leading batch axes so each
    # call sees plain (sequence, depth) matrices.
    attend = scaled_dot_product_attention
    for _ in range(Q.ndim - 2):
        attend = jax.vmap(attend)

    attention_heads = []
    for head in range(num_heads):
        W_q = random.normal(subkeys[3 * head], (d_model, depth))
        W_k = random.normal(subkeys[3 * head + 1], (d_model, depth))
        W_v = random.normal(subkeys[3 * head + 2], (d_model, depth))
        attention_heads.append(attend(Q @ W_q, K @ W_k, V @ W_v))

    # Heads sit side by side on the feature axis: (..., seq_q, d_model).
    concat_attention = jnp.concatenate(attention_heads, axis=-1)

    W_O = random.normal(subkeys[-1], (d_model, d_model))
    return concat_attention @ W_O

#### Encoder and Decoder Layers

In [None]:
def encoder_layer(X, num_heads, key):
    """Apply one encoder layer: self-attention, residual add, ReLU projection.

    Args:
        X: Layer input whose last axis is the model width.
        num_heads: Number of attention heads passed to
            ``multi_head_attention``.
        key: JAX PRNG key; split so the attention weights and the
            feed-forward weights are drawn independently.

    Returns:
        ``relu((attention(X) + X) @ W_ff)`` with ``W_ff`` a square
        standard-normal matrix over the last axis.
    """
    # Separate keys: feeding the same key to both consumers correlates
    # (or duplicates) their random weights.
    attn_key, ff_key = random.split(key)

    attn_output = multi_head_attention(X, X, X, num_heads, attn_key)
    attn_output = attn_output + X  # residual connection

    width = attn_output.shape[-1]
    W_ff = random.normal(ff_key, (width, width))
    return relu(attn_output @ W_ff)

def decoder_layer(Y, encoder_output, num_heads, key):
    """Apply one decoder layer: self-attention, cross-attention, residual, ReLU.

    Args:
        Y: Decoder-side input whose last axis is the model width.
        encoder_output: Encoder result used as keys and values of the
            cross-attention.
        num_heads: Number of attention heads passed to
            ``multi_head_attention``.
        key: JAX PRNG key; split three ways so self-attention,
            cross-attention and the feed-forward weights are independent.

    Returns:
        ``relu((cross_attention + Y) @ W_ff)`` with ``W_ff`` a square
        standard-normal matrix over the last axis.
    """
    # Distinct subkeys keep the three weight sets from being correlated.
    self_key, cross_key, ff_key = random.split(key, 3)

    self_attn_output = multi_head_attention(Y, Y, Y, num_heads, self_key)
    cross_attn_output = multi_head_attention(
        self_attn_output, encoder_output, encoder_output, num_heads, cross_key
    )
    cross_attn_output = cross_attn_output + Y  # residual connection

    width = cross_attn_output.shape[-1]
    W_ff = random.normal(ff_key, (width, width))
    return relu(cross_attn_output @ W_ff)

#### Model Training

In [None]:
def transformer_model(X, Y, vocab_size, embed_dim, seq_len_encoder, seq_len_decoder, num_heads, num_layers, key):
    """Run the encoder-decoder forward pass and return token probabilities.

    Args:
        X: Encoder token ids; used to index the embedding table.
        Y: Decoder token ids; used to index the embedding table.
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding / model width.
        seq_len_encoder: Number of positions encoded for ``X``.
        seq_len_decoder: Number of positions encoded for ``Y``.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and of decoder layers.
        key: JAX PRNG key from which one subkey per layer is derived.

    Returns:
        Softmax probabilities over the vocabulary, shape
        ``Y.shape + (vocab_size,)``. Note these are probabilities, not raw
        logits.
    """
    # get_embeddings is called with its two required arguments. A single
    # table serves the encoder input, the decoder input and the output
    # projection.
    embeddings = get_embeddings(vocab_size, embed_dim)
    encoder_positional_encoding = get_positional_encoding(seq_len_encoder, embed_dim)
    decoder_positional_encoding = get_positional_encoding(seq_len_decoder, embed_dim)

    # Token lookup plus position signal; the (seq_len, embed_dim) encodings
    # broadcast over any leading axes of X / Y.
    X_embed = embeddings[X] + encoder_positional_encoding  # X.shape + (embed_dim,)
    Y_embed = embeddings[Y] + decoder_positional_encoding  # Y.shape + (embed_dim,)

    # A distinct subkey per layer; passing the same key to every layer would
    # give all layers identical weights.
    layer_keys = random.split(key, 2 * num_layers)

    # Encoder pass
    encoder_output = X_embed
    for layer in range(num_layers):
        encoder_output = encoder_layer(encoder_output, num_heads, layer_keys[layer])

    # Decoder pass
    decoder_output = Y_embed
    for layer in range(num_layers):
        decoder_output = decoder_layer(
            decoder_output, encoder_output, num_heads, layer_keys[num_layers + layer]
        )

    # Project back onto the vocabulary and normalise over the vocab axis.
    probabilities = softmax(decoder_output @ embeddings.T, axis=-1)
    return probabilities

def loss_fn(logits, labels):
    """Mean cross-entropy between predicted probabilities and target labels.

    Args:
        logits: Predicted class probabilities with the class axis last
            (despite the name, ``jnp.log`` is applied to them directly).
        labels: Either one-hot / soft targets broadcastable against
            ``logits``, or integer class ids of shape ``logits.shape[:-1]``,
            which are expanded to one-hot here.

    Returns:
        Scalar: negative sum over the class axis of ``labels * log(p)``,
        averaged over all remaining axes.
    """
    logits = jnp.asarray(logits)
    labels = jnp.asarray(labels)

    # Integer ids with one axis fewer than the predictions are class
    # indices, not per-class weights.
    is_class_ids = (
        jnp.issubdtype(labels.dtype, jnp.integer)
        and labels.shape == logits.shape[:-1]
    )
    if is_class_ids:
        labels = jax.nn.one_hot(labels, logits.shape[-1])

    # Floor the probabilities so that 0 * log(0) yields 0 instead of NaN.
    log_probs = jnp.log(jnp.maximum(logits, 1e-12))
    return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))

# Hyperparameters and sizes derived from the training data.
key = random.PRNGKey(0)
embed_dim = 16
seq_len_encoder = X_train.shape[1]
seq_len_decoder = Y_train.shape[1]
vocab_size = max(token_map.values()) + 1  # ids run from 0 to the largest mapped id
num_heads = 2
num_layers = 2

# Y_train is used as an index into the embedding table inside
# transformer_model, so it holds token ids. loss_fn multiplies labels
# element-wise with per-class probabilities, so the ids are expanded to
# one-hot vectors over the vocabulary once, up front.
Y_one_hot = jax.nn.one_hot(Y_train, vocab_size)

for epoch in range(1000):
    logits = transformer_model(X_train, Y_train, vocab_size, embed_dim, seq_len_encoder, seq_len_decoder, num_heads, num_layers, key)
    loss = loss_fn(logits, Y_one_hot)
    # NOTE(review): this differentiates the loss with respect to the model
    # output, not with respect to any weights, and `grads` is never applied.
    # The forward pass is a pure function of `key`, so every epoch computes
    # the same loss. Real training needs the weights passed in as explicit
    # parameters and updated from their gradients.
    grads = grad(loss_fn)(logits, Y_one_hot)

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss}")