## Transformer

#### Import Libraries

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, grad, jit
from jax.nn import softmax, relu

#### Token Map

In [2]:
import string

# Vocabulary: the ten digits, ASCII punctuation and the space character
# (no letters). Each character maps to its position in that string.
characters = "".join((string.digits, string.punctuation, " "))
token_map = dict(zip(characters, range(len(characters))))
vocab_size = len(token_map)

def encode_sequence(sequence, token_map):
    """Map each token of ``sequence`` to its integer id in ``token_map``.

    Tokens are passed through ``str()`` first, so ints such as ``3`` and
    strings such as ``"+"`` are both accepted. Each token must be a single
    vocabulary entry: ``12`` becomes ``"12"``, which a per-character map
    does not contain.

    Raises:
        KeyError: if a token is not present in ``token_map``.
    """
    ids = []
    for token in sequence:
        symbol = str(token)
        try:
            ids.append(token_map[symbol])
        except KeyError:
            # Fail here with the offending token instead of returning None,
            # which only surfaced later as an obscure embedding-index error.
            raise KeyError(f"token {token!r} is not in the token map") from None
    return ids

#### Get Embeddings and Positional Encoding

In [3]:
def get_embeddings(vocab_size, embed_dim, key):
    """Sample a ``(vocab_size, embed_dim)`` embedding table from N(0, 1).

    The table is fully determined by ``key``: the same key always yields
    the same table.
    """
    table_shape = (vocab_size, embed_dim)
    table = random.normal(key, table_shape)
    return table

def get_positional_encoding(seq_len, embed_dim):
    """Return the sinusoidal positional-encoding table, shape ``(seq_len, embed_dim)``.

    For position ``pos`` and even column ``2k``::

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / embed_dim))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / embed_dim))

    An odd ``embed_dim`` is supported: the last (even) column gets only the
    sine term. The result is a float64 NumPy array.
    """
    positions = np.arange(seq_len)[:, np.newaxis]          # (seq_len, 1)
    even_cols = np.arange(0, embed_dim, 2)                 # the 2k values
    # even_cols already equals 2k, so the exponent is 2k / d, not 2 * (2k) / d.
    angles = positions / np.power(10000.0, even_cols / embed_dim)
    pe = np.zeros((seq_len, embed_dim))
    pe[:, 0::2] = np.sin(angles)
    # There are floor(d / 2) odd columns; drop the extra angle when d is odd.
    pe[:, 1::2] = np.cos(angles[:, : embed_dim // 2])
    return pe

#### Multi-Head Attention

In [26]:
def scaled_dot_product_attention(Q, K, V):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Works on single sequences (``Q``: (q_len, d_k), ``K``: (k_len, d_k),
    ``V``: (k_len, d_v)) and on any leading batch dimensions, because the
    feature size is read from the last axis and only the last two axes of
    ``K`` are transposed.

    Returns:
        Array of shape (..., q_len, d_v).
    """
    d_k = Q.shape[-1]
    scores = jnp.matmul(Q, jnp.swapaxes(K, -1, -2)) / jnp.sqrt(d_k)
    weights = softmax(scores, axis=-1)  # each query's weights over keys sum to 1
    return jnp.matmul(weights, V)

def multi_head_attention(Q_split, K_split, V_split, num_heads, W_O):
    """Run scaled dot-product attention per head and mix the heads with ``W_O``.

    ``Q_split``, ``K_split`` and ``V_split`` are indexed by head number
    ``0 .. num_heads - 1`` (lists, or int-keyed dicts as the training cell
    builds). The head outputs are concatenated along the last axis before
    the multiplication, so ``W_O`` needs ``num_heads * head_width`` rows.
    """
    heads = []
    for head in range(num_heads):
        heads.append(
            scaled_dot_product_attention(Q_split[head], K_split[head], V_split[head])
        )
    return jnp.concatenate(heads, axis=-1) @ W_O

#### Encoder and Decoder Layers

In [None]:
def encoder_layer(X, num_heads, key, W_O, Q, K, V):
    """One encoder block: multi-head attention plus residual, then a ReLU projection.

    Args:
        X: layer input, added to the attention output as a residual; its
            shape must match what ``multi_head_attention`` returns.
        num_heads: number of heads indexed in ``Q``/``K``/``V``.
        key: PRNG key used to sample the square feed-forward matrix.
        W_O: output projection applied to the concatenated heads.
        Q, K, V: per-head query/key/value arrays, indexable by head number.

    Returns:
        ``relu((attention + X) @ W)`` where ``W ~ N(0, 1)`` has shape
        (width, width) and width is the last dimension of the residual sum.

    NOTE(review): Q, K and V arrive precomputed, so X affects the output only
    through the residual. The feed-forward matrix is re-sampled from ``key``
    on every call rather than passed in, so a training loop cannot update it.
    """
    attended = multi_head_attention(Q, K, V, num_heads, W_O)
    summed = attended + X  # residual connection
    width = summed.shape[-1]
    feed_forward = random.normal(key, (width, width))
    return relu(summed @ feed_forward)

def decoder_layer(Y, encoder_output, num_heads, key, W_O, Q, K, V):
    """One decoder block: self-attention, cross-attention, residual, ReLU projection.

    Args:
        Y: decoder-side layer input, added as a residual after cross-attention.
        encoder_output: encoder states used as keys and values in cross-attention.
        num_heads: number of heads indexed in ``Q``/``K``/``V``.
        key: PRNG key used to sample the square feed-forward matrix.
        W_O: output projection, shared by self- and cross-attention.
        Q, K, V: per-head self-attention inputs, indexable by head number.

    Returns:
        ``relu((cross_attention + Y) @ W)`` with ``W ~ N(0, 1)`` of shape (width, width).

    NOTE(review): there are no cross-attention projection matrices, so every
    head sees the same queries/keys/values and the heads are identical. Add
    per-head weights if distinct heads are wanted.
    """
    self_attn_output = multi_head_attention(Q, K, V, num_heads, W_O)

    # multi_head_attention indexes its inputs per head (x[i]). Passing the raw
    # arrays selected *row* i of each array instead of head i, which produced
    # 1-D vectors and broke attention. Give every head the full decoder state
    # as queries and the full encoder output as keys/values. W_O must then
    # accept num_heads * self_attn_output.shape[-1] input features.
    queries = [self_attn_output] * num_heads
    memory = [encoder_output] * num_heads
    cross_attn_output = multi_head_attention(queries, memory, memory, num_heads, W_O)

    cross_attn_output = cross_attn_output + Y  # residual connection
    width = cross_attn_output.shape[-1]
    feed_forward = random.normal(key, (width, width))
    return relu(cross_attn_output @ feed_forward)

#### Model Training

In [28]:
def encode_sequence(sequence, token_map):
    """Map each token of ``sequence`` to its integer id in ``token_map``.

    Tokens are passed through ``str()`` first, so ints such as ``3`` and
    strings such as ``"+"`` are both accepted. Each token must be a single
    vocabulary entry: ``12`` becomes ``"12"``, which a per-character map
    does not contain.

    Raises:
        KeyError: if a token is not present in ``token_map``.
    """
    # NOTE: this redefines encode_sequence from the Token Map cell; keep the
    # two definitions in sync.
    ids = []
    for token in sequence:
        symbol = str(token)
        try:
            ids.append(token_map[symbol])
        except KeyError:
            # Fail here with the offending token instead of returning None,
            # which only surfaced later as an obscure embedding-index error.
            raise KeyError(f"token {token!r} is not in the token map") from None
    return ids

def loss_fn(logits, labels, eps=1e-12):
    """Mean categorical cross-entropy between predicted probabilities and targets.

    Args:
        logits: predicted class *probabilities* (e.g. softmax output), with
            the class axis last. Despite the name, raw logits are not
            expected; the log is taken directly.
        labels: target distribution (typically one-hot), same shape as ``logits``.
        eps: lower clip bound for the probabilities. A zero probability would
            otherwise give ``log(0) = -inf`` and ``0 * -inf = nan``, even for
            classes whose label weight is 0.

    Returns:
        Scalar: the per-row cross-entropy averaged over all leading axes.
    """
    probs = jnp.clip(logits, eps, 1.0)
    return -jnp.mean(jnp.sum(labels * jnp.log(probs), axis=-1))

In [None]:
# --- Hyper-parameters ---------------------------------------------------------
key = random.PRNGKey(0)
embed_dim = 16
vocab_size = max(token_map.values()) + 1
num_heads = 2
num_layers = 2
num_epochs = 1000
learning_rate = 1e-3

# --- Training pair: "132+942" -> "1074" -----------------------------------------
X_train_raw = encode_sequence([1, 3, 2, "+", 9, 4, 2], token_map)
Y_train_raw = encode_sequence([1, 0, 7, 4], token_map)
seq_len_encoder = len(X_train_raw)
seq_len_decoder = len(Y_train_raw)
X_ids = jnp.array(X_train_raw)
Y_ids = jnp.array(Y_train_raw)

# One-hot targets (the original loop referenced `Y_train` without defining it).
Y_train = jax.nn.one_hot(Y_ids, vocab_size)

encoder_positional_encoding = jnp.asarray(get_positional_encoding(seq_len_encoder, embed_dim))
decoder_positional_encoding = jnp.asarray(get_positional_encoding(seq_len_decoder, embed_dim))

# --- Parameters -----------------------------------------------------------------
# Every random tensor gets its own key. The original drew all W_Q/W_K/W_V
# matrices (encoder and decoder) from PRNGKey(i), which made them identical.
key, emb_key, wo_key, ff_key, proj_key = random.split(key, 5)
# The layer functions sample their feed-forward matrix from the key they are
# given, so one key per layer gives each layer a different (fixed) matrix.
ff_keys = random.split(ff_key, 2 * num_layers)
projection_names = ("W_Q_E", "W_K_E", "W_V_E", "W_Q_D", "W_K_D", "W_V_D")
proj_keys = random.split(proj_key, len(projection_names) * num_heads)

params = {
    "embeddings": get_embeddings(vocab_size, embed_dim, emb_key),
    # Heads are concatenated to num_heads * embed_dim features; W_O maps them
    # back to embed_dim so the residual additions inside the layers have
    # matching shapes (a square (32, 32) W_O made `+ X` fail).
    "W_O": random.normal(wo_key, (embed_dim * num_heads, embed_dim))
    / jnp.sqrt(embed_dim * num_heads),
}
for slot, name in enumerate(projection_names):
    head_keys = proj_keys[slot * num_heads:(slot + 1) * num_heads]
    # 1/sqrt(fan_in) scaling keeps the projected activations at a moderate scale.
    params[name] = [random.normal(k, (embed_dim, embed_dim)) / jnp.sqrt(embed_dim) for k in head_keys]


def _per_head(inputs, weights):
    """Project ``inputs`` with each head's matrix; returns {head_index: array}."""
    return {h: inputs @ w for h, w in enumerate(weights)}


def forward(model_params):
    """Run encoder and decoder; return (seq_len_decoder, vocab_size) logits."""
    emb = model_params["embeddings"]
    X_embed = emb[X_ids] + encoder_positional_encoding
    Y_embed = emb[Y_ids] + decoder_positional_encoding

    # Encoder heads project the source; decoder self-attention heads project the
    # *decoder* input (the original used X_embed for both).
    Q_E = _per_head(X_embed, model_params["W_Q_E"])
    K_E = _per_head(X_embed, model_params["W_K_E"])
    V_E = _per_head(X_embed, model_params["W_V_E"])
    Q_D = _per_head(Y_embed, model_params["W_Q_D"])
    K_D = _per_head(Y_embed, model_params["W_K_D"])
    V_D = _per_head(Y_embed, model_params["W_V_D"])
    # NOTE(review): decoder self-attention is unmasked and the decoder input is
    # the target itself (not shifted), so the model can copy its targets.

    encoder_output = X_embed
    for layer in range(num_layers):
        encoder_output = encoder_layer(
            encoder_output, num_heads, ff_keys[layer], model_params["W_O"], Q_E, K_E, V_E
        )

    decoder_output = Y_embed
    for layer in range(num_layers):
        decoder_output = decoder_layer(
            decoder_output, encoder_output, num_heads, ff_keys[num_layers + layer],
            model_params["W_O"], Q_D, K_D, V_D,
        )

    # Score every vocabulary entry (tied input/output embeddings). The original
    # projected onto the encoder-position embeddings instead of the vocabulary.
    return decoder_output @ emb.T


def objective(model_params):
    """Mean cross-entropy of the decoder predictions against ``Y_train``."""
    # log_softmax is the numerically stable form of log(softmax(...)).
    log_probs = jax.nn.log_softmax(forward(model_params), axis=-1)
    return -jnp.mean(jnp.sum(Y_train * log_probs, axis=-1))


loss_and_grads = jit(jax.value_and_grad(objective))

for epoch in range(num_epochs):
    loss, grads = loss_and_grads(params)
    # Plain SGD on every parameter. The original took gradients w.r.t. the
    # logits and never applied them, so nothing was trained.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss}")

# Trained tensors under the module-level names this cell previously exposed.
embeddings = params["embeddings"]
W_O = params["W_O"]
logits = forward(params)

In [59]:
# Display the character -> index vocabulary (rendered as the cell output below).
token_map

{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '!': 10,
 '"': 11,
 '#': 12,
 '$': 13,
 '%': 14,
 '&': 15,
 "'": 16,
 '(': 17,
 ')': 18,
 '*': 19,
 '+': 20,
 ',': 21,
 '-': 22,
 '.': 23,
 '/': 24,
 ':': 25,
 ';': 26,
 '<': 27,
 '=': 28,
 '>': 29,
 '?': 30,
 '@': 31,
 '[': 32,
 '\\': 33,
 ']': 34,
 '^': 35,
 '_': 36,
 '`': 37,
 '{': 38,
 '|': 39,
 '}': 40,
 '~': 41,
 ' ': 42}