In [None]:
import jax
import jax.numpy as jnp
from jax import random
import math

In [3]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [4]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is multiplied by sqrt(model_dimension).

    Attributes:
        model_dimension: width of each embedding vector.
        vocab_size: number of rows in the embedding table (distinct token ids).
    """
    model_dimension : int
    vocab_size : int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.model_dimension)

    def __call__(self, x):
        # x: integer token ids (presumably (batch, seq) — confirm against caller).
        embedded = self.embedding(x)
        # Scale up so the embeddings are not dwarfed by anything added to them later.
        return embedded * math.sqrt(self.model_dimension)

In [5]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer.

    Computes ``Dense(model_dimension)(Dropout(ReLU(Dense(ff_dim)(x))))``. The
    output projection is deliberately left linear, so the block can emit signed
    values into the residual stream.

    Attributes:
        model_dimension: output width (the model / residual width).
        ff_dim: width of the hidden layer.
        dropout: dropout rate applied to the hidden activations.
    """
    model_dimension : int
    ff_dim : int
    dropout : float

    def setup(self):
        self.linear1 = nn.Dense(features=self.ff_dim)
        self.linear2 = nn.Dense(features=self.model_dimension)
        # Has no parameters, so the parameter tree is still just linear1/linear2.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, x, deterministic: bool = True):
        """Apply the feed-forward sub-layer.

        Args:
            x: input activations; the last axis is the feature axis.
            deterministic: if True (the default), dropout is a no-op and no RNG is
                needed. Pass False during training and provide a ``'dropout'`` RNG
                stream to ``apply``.

        Returns:
            Array with the same leading axes as ``x`` and last axis ``model_dimension``.
        """
        hidden = nn.relu(self.linear1(x))
        hidden = self.dropout_layer(hidden, deterministic=deterministic)
        # No activation here: a trailing ReLU would clamp the output to >= 0.
        return self.linear2(hidden)

In [None]:
class ScaledDotProduct(nn.Module):
    """A single causal self-attention head with its own fused Q/K/V projection.

    Attributes:
        dk: per-head width. Used both as the size of each of q/k/v and as the
            sqrt(dk) scaling denominator.
    """
    dk : int

    def setup(self):
        # One Dense producing q, k and v side by side; split apart in __call__.
        self.W = nn.Dense(features=3*self.dk)

    def __call__(self, x):
        # Assumes x is rank 3, (batch, time, channels), as the einsums below require.
        query, key, value = jnp.split(self.W(x), 3, axis=-1)
        scale = math.sqrt(self.dk)
        scores = jnp.einsum('b t c, b T c -> b t T', query, key) / scale
        n_keys = scores.shape[-1]
        lower = jnp.tril(jnp.ones((n_keys, n_keys)))
        # Entries above the diagonal (future keys) get a huge negative logit,
        # so softmax gives them ~zero weight.
        masked_scores = jnp.where(lower == 0, -9e15, scores)
        probs = nn.softmax(masked_scores, axis=-1)
        return jnp.einsum('b t T, b T c -> b t c', probs, value)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from ``n_heads`` independent ``ScaledDotProduct`` heads.

    Each head sees the full input and produces ``model_dim // n_heads`` features.
    The head outputs are concatenated on the last axis and mixed by a final
    ``Dense(model_dim)`` projection.

    Attributes:
        n_heads: number of attention heads (must be positive).
        model_dim: model width; must be divisible by ``n_heads``.

    Raises:
        ValueError: (at init/apply time, from ``setup``) if ``n_heads`` is not
            positive or does not divide ``model_dim``.
    """
    n_heads : int
    model_dim : int

    def setup(self):
        if self.n_heads <= 0 or self.model_dim % self.n_heads != 0:
            raise ValueError(
                f"model_dim ({self.model_dim}) must be divisible by a positive n_heads ({self.n_heads})"
            )
        # Integer division: `/` would yield a float, which nn.Dense cannot use as a
        # feature count.
        self.dk = self.model_dim // self.n_heads
        # NOTE: `ScaledDotProduct` is looked up when setup() runs (lazily, at
        # init/apply time), not when this class is defined. It must still name the
        # Flax module at that point, so do not rebind that name in later cells.
        self.SA_layers = [ScaledDotProduct(self.dk) for _ in range(self.n_heads)]
        self.WO = nn.Dense(features=self.model_dim)

    def __call__(self, x):
        head_outputs = [head(x) for head in self.SA_layers]
        concatenated = jnp.concatenate(head_outputs, axis=-1)
        return self.WO(concatenated)

In [None]:
class PositionalEmbeddings(nn.Module):
    """Placeholder for a positional-embedding module (not implemented yet).

    Only the configuration fields are declared. There is no ``setup`` or
    ``__call__`` yet, so the module cannot be applied to inputs.
    """
    model_dimension : int  # presumably the embedding width (as in InputEmbeddings) — confirm
    seq_len : int  # presumably the maximum sequence length to encode — confirm
# TODO: implement setup/__call__ (learned vs. sinusoidal encoding still to be decided).

In [None]:
# Functional (parameter-free) variants of a single causal attention head.
# They use distinct snake_case names so they neither shadow each other nor the
# Flax `ScaledDotProduct` module that `MultiHeadAttention.setup` instantiates.


def _causal_mask(n_queries, n_keys):
    """Boolean lower-triangular mask: True where query i may attend to key j (j <= i).

    Raises:
        ValueError: if the query and key lengths differ. A square (T, T) mask
            would otherwise broadcast silently to a wrong output shape.
    """
    if n_queries != n_keys:
        raise ValueError(
            f"causal self-attention needs equal query/key lengths, got {n_queries} and {n_keys}"
        )
    return jnp.tril(jnp.ones((n_queries, n_keys), dtype=bool))


def scaled_dot_product_einsum(q, k, v):
    """Causal scaled dot-product attention written with ``jnp.einsum``.

    Args:
        q: queries, shape (batch, T, d_k).
        k: keys, shape (batch, T, d_k).
        v: values, shape (batch, T, d_v).

    Returns:
        Attention output of shape (batch, T, d_v).

    Raises:
        ValueError: if ``q`` and ``k`` have different sequence lengths.
    """
    dk = q.shape[-1]
    weights = jnp.einsum('b t c, b T c -> b t T', q, k) / math.sqrt(dk)
    mask = _causal_mask(weights.shape[-2], weights.shape[-1])
    # Future positions get a huge negative logit -> ~zero softmax weight.
    logits = jnp.where(mask, weights, -9e15)
    attention = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('b t T, b T c -> b t c', attention, v)


def scaled_dot_product_matmul(q, k, v):
    """Same computation as ``scaled_dot_product_einsum``, written with explicit matmuls.

    Args:
        q: queries, shape (batch, T, d_k).
        k: keys, shape (batch, T, d_k).
        v: values, shape (batch, T, d_v).

    Returns:
        Attention output of shape (batch, T, d_v).

    Raises:
        ValueError: if ``q`` and ``k`` have different sequence lengths.
    """
    dk = q.shape[-1]
    kT = jnp.transpose(k, (0, 2, 1))  # (batch, d_k, T)
    weights = jnp.matmul(q, kT) / math.sqrt(dk)
    mask = _causal_mask(weights.shape[-2], weights.shape[-1])
    logits = jnp.where(mask, weights, -9e15)
    attention = jax.nn.softmax(logits, axis=-1)
    return jnp.matmul(attention, v)