In [2]:
import jax
import jax.numpy as jnp
from jax import random
import math
from typing import Callable

In [12]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [3]:
class Embeddings(nn.Module):
    """Token-embedding lookup whose output is multiplied by sqrt(model_dimension).

    Attributes:
        model_dimension: width of each embedding vector.
        vocab_size: number of rows in the embedding table.
    """
    model_dimension : int
    vocab_size : int

    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.model_dimension,
        )

    def __call__(self, x):
        """Look up the ids in `x` and scale the vectors by sqrt(model_dimension)."""
        scale = math.sqrt(self.model_dimension)
        return self.embedding(x) * scale

In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence of embeddings.

    Even columns of the table hold sin(pos / 10000^(2j/d)) and odd columns hold
    cos(pos / 10000^(2j/d)), where j = column // 2 and d = `model_dimension`.

    Attributes:
        model_dimension: size of the last axis of the input.
        seq_len: number of positions in the precomputed table, i.e. the longest
            sequence that can be encoded.
    """
    model_dimension : int
    seq_len : int

    def setup(self):
        positions = jnp.arange(0, self.seq_len, 1).reshape(-1, 1)   # (seq_len, 1)
        columns = jnp.arange(0, self.model_dimension, 1)            # (d,)
        pair = columns // 2                                         # 0,0,1,1,2,2,...
        denom = 1/(10000 ** ((2*pair)/self.model_dimension))
        angles = positions * denom                                  # (seq_len, d)
        # sin on even columns, cos on odd columns.
        self.encodings = jnp.where(columns % 2 == 0, jnp.sin(angles), jnp.cos(angles))

    def __call__(self, x):
        """Return `x` plus the encodings of its first `x.shape[-2]` positions.

        Args:
            x: array whose last two axes are (sequence, model_dimension).

        Raises:
            ValueError: if the sequence is longer than `seq_len`.
        """
        length = x.shape[-2]
        if length > self.seq_len:
            raise ValueError(
                f"sequence length {length} exceeds the encoding table size {self.seq_len}"
            )
        # Slice the table so inputs shorter than seq_len are accepted as well.
        return x + self.encodings[:length, :]

In [5]:
class ScaledDotProduct(nn.Module):
    """One causal self-attention head of width `dk`."""
    dk : int

    def setup(self):
        # A single fused projection yields query, key and value side by side.
        self.W = nn.Dense(features=3*self.dk)

    def __call__(self, x):
        """Apply masked scaled dot-product attention to `x` (batch, time, channels)."""
        query, key, value = jnp.split(self.W(x), 3, axis=-1)
        scores = jnp.einsum('b t c, b T c -> b t T', query, key) / math.sqrt(self.dk)

        # Lower-triangular mask: position t may only look at positions <= t.
        n = scores.shape[-1]
        lower = jnp.tril(jnp.ones((n, n)))
        masked = jnp.where(lower == 0, -9e15, scores)

        probs = nn.softmax(masked, axis=-1)
        return jnp.einsum('b t T, b T c -> b t c', probs, value)

In [None]:
class MultiHeadLatentAttention(nn.Module):
    """Multi-head attention whose keys/values are rebuilt from a cached low-rank latent.

    The input is down-projected to a `d_c`-wide latent; keys and values are
    up-projected from that latent, so only the latent has to be cached while
    decoding. The cache lives in the Flax `'cache'` collection, so decoding calls
    must use `apply(..., mutable=['cache'])`.

    Attributes:
        n_heads: number of attention heads (must divide `model_dim`).
        model_dim: width of the input and output.
        d_c: width of the compressed query and key/value latents.
        max_token_len: number of latents kept in the sliding-window cache.
    """
    n_heads : int
    model_dim : int
    d_c : int
    max_token_len : int = 1024

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Attend over `x` (batch, time, model_dim).

        Args:
            x: input activations.
            train: when False (default) only the last token of `x` is processed and
                attended against the cached latents; when True the whole sequence
                is processed with a causal mask and no cache is touched.

        Returns:
            Array of shape (batch, time, model_dim); time is 1 when decoding.
        """
        decode = not train
        if decode:
            x = x[:, -1:, :]
        B, T, _ = x.shape
        dh = self.model_dim // self.n_heads

        cKVt = nn.Dense(features=self.d_c, name='Wdkv')(x)   # (B, T, d_c)
        cQt = nn.Dense(features=self.d_c, name='Wdq')(x)     # (B, T, d_c)

        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        if decode:
            # During `init` the variables are only created, never written.
            is_initialized = self.has_variable('cache', 'cKV')
            cache = self.variable(
                'cache', 'cKV', jnp.zeros, (B, self.max_token_len, self.d_c), cKVt.dtype
            )
            token_ind = self.variable(
                'cache', 'token_ind', lambda: jnp.zeros((), dtype=jnp.int32)
            )
            if is_initialized:
                filled = token_ind.value
                is_full = filled >= self.max_token_len
                # When full, drop the oldest latent and write into the last slot.
                slot = jnp.where(is_full, self.max_token_len - 1, filled)
                latents = jnp.where(is_full, jnp.roll(cache.value, -1, axis=1), cache.value)
                latents = latents.at[:, slot, :].set(cKVt[:, 0, :])
                cache.value = latents
                token_ind.value = slot + 1
                cKVt = latents
                # Only slots written so far may be attended to.
                mask = (jnp.arange(self.max_token_len) <= slot)[None, :]

        def split_heads(t):
            b, s, _ = t.shape
            return t.reshape(b, s, self.n_heads, dh).transpose(0, 2, 1, 3)

        q = split_heads(nn.Dense(features=self.model_dim, name='Wuq')(cQt))
        k = split_heads(nn.Dense(features=self.model_dim, name='Wuk')(cKVt))
        v = split_heads(nn.Dense(features=self.model_dim, name='Wuv')(cKVt))

        logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(dh)
        logits = jnp.where(mask[None, None, :, :], logits, jnp.finfo(logits.dtype).min)
        attention = nn.softmax(logits, axis=-1)
        values = jnp.einsum('bhqk,bhkd->bhqd', attention, v)
        values = values.transpose(0, 2, 1, 3).reshape(B, T, self.model_dim)

        return nn.Dense(features=self.model_dim, name='Wo')(values)

In [None]:
class MultiHeadAttention(nn.Module):
    """Runs `n_heads` independent attention heads and mixes their concatenation.

    Attributes:
        n_heads: number of `ScaledDotProduct` heads.
        model_dim: output width; each head has width `model_dim // n_heads`.
    """
    n_heads :int
    model_dim : int

    def setup(self):
        self.dk = self.model_dim // self.n_heads
        heads = []
        for _ in range(self.n_heads):
            heads.append(ScaledDotProduct(self.dk))
        self.SA_layers = heads
        self.WO = nn.Dense(features=self.model_dim)

    def __call__(self, x):
        """Apply every head to `x`, join the results on the last axis, project with WO."""
        per_head = []
        for head in self.SA_layers:
            per_head.append(head(x))
        joined = jnp.concatenate(per_head, axis=-1)
        return self.WO(joined)

In [7]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and shift.

    Attributes:
        model_dimension: size of the normalised (last) axis.
        gamma_init: initializer for the scale vector; defaults to ones.
        beta_init: initializer for the shift vector; defaults to zeros.
    """
    model_dimension : int
    # Fan-based initializers such as lecun_normal cannot initialise a 1-D
    # parameter, so the defaults are the usual identity scale / zero shift.
    gamma_init : Callable = nn.initializers.ones
    beta_init : Callable = nn.initializers.zeros

    def setup(self):
        shape = (self.model_dimension,)   # initializers expect a shape tuple
        self.gamma = self.param('gamma', self.gamma_init, shape)
        self.beta = self.param('beta', self.beta_init, shape)
        self.eps = 1e-05

    def __call__(self, x):
        """Normalise `x` over its last axis, then apply gamma and beta.

        Args:
            x: array of any rank whose last axis has size `model_dimension`.
        """
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        norm = (x - mean) / jnp.sqrt(var + self.eps)
        # Plain broadcasting over the last axis works for any input rank.
        return norm * self.gamma + self.beta

In [8]:
class Expert(nn.Module):
    """Two-layer feed-forward expert: Dense -> ReLU -> Dropout -> Dense.

    Attributes:
        model_dimension: output width.
        ff_dim: hidden width.
        dropout: dropout rate applied after the activation.
    """
    model_dimension : int
    ff_dim : int
    dropout : float

    def setup(self):
        self.linear1 = nn.Dense(features=self.ff_dim)
        self.linear2 = nn.Dense(features=self.model_dimension)
        # `dropout` is a dataclass field and may not be reassigned in setup(),
        # so the layer is stored under a different attribute name.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, x, deterministic: bool = True):
        """Apply the expert to `x`.

        Args:
            x: input whose last axis is projected to `ff_dim` and back.
            deterministic: when True (default) dropout is disabled; when False the
                caller must supply a `'dropout'` RNG stream to `apply`.
        """
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.dropout_layer(x, deterministic=deterministic)
        x = self.linear2(x)
        return x

In [9]:
class SoftmaxGate(nn.Module):
    """Dense softmax router producing one weight per expert.

    Attributes:
        model_dimension: declared input width (not read by the code below).
        n_experts: number of gate outputs.
    """
    model_dimension : int
    n_experts : int

    def setup(self):
        self.Wg = nn.Dense(features=self.n_experts)

    def __call__(self, x):
        """Return softmax(Wg(x)) taken over the last (expert) axis."""
        expert_logits = self.Wg(x)
        return nn.softmax(expert_logits, axis=-1)

In [10]:
class NoisyKGate(nn.Module):
    """Noisy top-k router: keeps the `k` largest noisy gate logits per token.

    Attributes:
        model_dimension: declared input width (not read by the code below).
        n_experts: number of experts to route between.
        k: number of experts selected per token.
        dropout: declared but not used by this module.
    """
    model_dimension : int
    n_experts : int
    k : int
    dropout : float

    def setup(self):
        # Fallback key, used only when the caller supplies no 'noise' RNG stream.
        self.rng = jax.random.PRNGKey(42)
        self.Wg = nn.Dense(features=self.n_experts)
        self.Wnoise = nn.Dense(features=self.n_experts)

    def top(self, x):
        """Return (softmax over the top-k values, their indices) along the last axis."""
        y, i = jax.lax.top_k(x, self.k)
        y = nn.softmax(y, axis=-1)
        return y, i

    def __call__(self, x):
        """Route `x` (..., model_dimension).

        Returns:
            (scores, indices), both shaped (..., k): normalised weights of the
            selected experts and their integer ids.
        """
        clean = self.Wg(x)
        noise_scale = nn.softplus(self.Wnoise(x))
        # Draw fresh noise per call when a 'noise' stream is provided; a constant
        # key would add the same perturbation on every forward pass.
        key = self.make_rng('noise') if self.has_rng('noise') else self.rng
        noise = jax.random.normal(key, shape=clean.shape)
        Hx = clean + noise * noise_scale
        # top_k and softmax already act on the last axis, so no per-row mapping is needed.
        g_scores, indices = self.top(Hx)
        return g_scores, indices

In [11]:
class MoE(nn.Module):
    """Mixture of experts combined with the weights of a noisy top-k gate.

    Every expert is evaluated on every token and the outputs are mixed with a
    routing matrix that is zero for unselected experts. This is dense compute,
    but it avoids indexing a Python list of modules with traced indices.

    Attributes:
        model_dimension: input/output width.
        n_experts: number of experts.
        k: experts selected per token.
        dropout: dropout rate forwarded to the experts and the gate.
    """
    model_dimension : int
    n_experts : int
    k : int
    dropout : float

    def setup(self):
        self.experts = [
            Expert(model_dimension=self.model_dimension, ff_dim=4*self.model_dimension, dropout=self.dropout)
            for _ in range(self.n_experts)
        ]
        self.gate = NoisyKGate(
            model_dimension=self.model_dimension,
            n_experts=self.n_experts,
            k=self.k,
            dropout=self.dropout,
        )

    def gScores(self, scores, indices, x):
        """Mix expert outputs for `x` using the selected experts' scores.

        Args:
            scores: (..., k) gate weights of the selected experts.
            indices: (..., k) integer ids of the selected experts.
            x: (..., model_dimension) inputs sharing the leading axes.

        Returns:
            (..., model_dimension) weighted sum of the selected experts' outputs.
        """
        # (..., k, E) one-hot times the scores, summed over k -> (..., E) routing weights.
        one_hot = jax.nn.one_hot(indices, self.n_experts, dtype=scores.dtype)
        routing = jnp.sum(one_hot * scores[..., None], axis=-2)
        expert_outputs = jnp.stack([expert(x) for expert in self.experts], axis=-2)  # (..., E, C)
        return jnp.einsum('...e,...ec->...c', routing, expert_outputs)

    def __call__(self, x):
        """Route `x` through the gate and return the mixed expert output."""
        s, i = self.gate(x)
        return self.gScores(s, i, x)


In [12]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Dense -> ReLU -> Dropout -> Dense.

    Attributes:
        model_dimension: output width.
        ff_dim: hidden width.
        dropout: dropout rate applied after the activation.
    """
    model_dimension : int
    ff_dim : int
    dropout : float

    def setup(self):
        self.linear1 = nn.Dense(features=self.ff_dim)
        self.linear2 = nn.Dense(features=self.model_dimension)
        # `dropout` is a dataclass field and may not be reassigned in setup(),
        # so the layer is stored under a different attribute name.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, x, deterministic: bool = True):
        """Apply the network to `x`.

        Args:
            x: input whose last axis is projected to `ff_dim` and back.
            deterministic: when True (default) dropout is disabled; when False the
                caller must supply a `'dropout'` RNG stream to `apply`.
        """
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.dropout_layer(x, deterministic=deterministic)
        x = self.linear2(x)
        return x

In [13]:
class Block(nn.Module):
    """Post-norm transformer block: attention and feed-forward, each with a residual.

    Attributes:
        model_dimension: width of the residual stream.
        n_heads: number of attention heads.
        dropount: dropout rate for the feed-forward sub-layer. The spelling is kept
            because callers construct the block with this keyword.
    """
    model_dimension : int
    n_heads : int
    dropount : float

    def setup(self):
        self.attention = MultiHeadAttention(model_dim=self.model_dimension, n_heads=self.n_heads)
        self.norm1 = LayerNorm(model_dimension=self.model_dimension)
        self.norm2 = LayerNorm(model_dimension=self.model_dimension)
        # Read the field under its declared name; `self.dropout` does not exist here.
        self.feedForward = FeedForward(
            model_dimension=self.model_dimension,
            ff_dim=4*self.model_dimension,
            dropout=self.dropount,
        )

    def __call__(self, x):
        """Return norm2(h + feedForward(h)) where h = norm1(x + attention(x))."""
        x = self.norm1(x + self.attention(x))
        x = self.norm2(x + self.feedForward(x))
        return x

In [14]:
class Decoder(nn.Module):
    """Decoder-only transformer: embeddings, positional encodings, blocks, vocabulary head.

    Attributes:
        model_dimension: width of the residual stream.
        n_heads: attention heads per block.
        seq_len: length of the positional-encoding table.
        vocab_size: number of tokens in the vocabulary.
        dropout: dropout rate forwarded to every block.
        blocks: number of stacked blocks.
    """
    model_dimension : int
    n_heads : int
    seq_len : int
    vocab_size : int
    dropout : float
    blocks : int

    def setup(self):
        self.embeddingTable = Embeddings(model_dimension=self.model_dimension, vocab_size=self.vocab_size)
        self.Blocks = [
            Block(model_dimension=self.model_dimension, n_heads=self.n_heads, dropount=self.dropout)
            for _ in range(self.blocks)
        ]
        self.encodings = PositionalEncoding(model_dimension=self.model_dimension, seq_len=self.seq_len)
        self.linear = nn.Dense(features=self.vocab_size)

    def __call__(self, x):
        """Map token ids `x` (batch, time) to next-token probabilities.

        Returns:
            (batch, time, vocab_size) array; the last axis is a softmax distribution.
        """
        x = self.embeddingTable(x)   # B, T, C
        x = self.encodings(x)
        # `self.blocks` is the block count; the modules live in `self.Blocks`
        # and are applied one after another.
        for block in self.Blocks:
            x = block(x)
        x = self.linear(x)
        return nn.softmax(x, axis=-1)

In [None]:

from einops import rearrange

class MHA_KV_Cache(nn.Module):
    """Causal multi-head self-attention with a sliding-window key/value cache.

    The cache is stored in the Flax `'cache'` collection (`key_cache`,
    `value_cache`, `kv_ind`), so decoding calls must use
    `apply(..., train=False, mutable=['cache'])`.

    Attributes:
        model_dim: input/output width.
        n_heads: number of heads (must divide `model_dim`).
        max_tokens: number of positions kept in the cache.
    """
    model_dim : int
    n_heads : int
    max_tokens : int

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Attend over `x` (batch, time, model_dim).

        Args:
            x: input activations.
            train: True processes the whole sequence with a causal mask and no
                cache; False processes only the last token against the cache.

        Returns:
            (batch, time, model_dim) array; time is 1 when `train` is False.
        """
        decode = not train
        if decode:
            x = x[:, -1:, :]

        B, T, _ = x.shape
        dk = self.model_dim // self.n_heads

        qkv = nn.Dense(features=3*self.model_dim, name='QKVWeights')(x)
        qkv = rearrange(qkv, 'b t (c h d) -> c b h t d', c=3, h=self.n_heads, d=dk)
        q, k, v = qkv[0], qkv[1], qkv[2]            # each (B, nh, T, dk)

        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        if decode:
            # During `init` the variables are only created, never written.
            is_initialized = self.has_variable('cache', 'key_cache')
            cache_shape = (B, self.n_heads, self.max_tokens, dk)
            key_cache = self.variable('cache', 'key_cache', jnp.zeros, cache_shape, k.dtype)
            value_cache = self.variable('cache', 'value_cache', jnp.zeros, cache_shape, v.dtype)
            kv_ind = self.variable('cache', 'kv_ind', lambda: jnp.zeros((), dtype=jnp.int32))
            if is_initialized:
                filled = kv_ind.value
                is_full = filled >= self.max_tokens
                # When full, drop the oldest entry and write into the last slot.
                slot = jnp.where(is_full, self.max_tokens - 1, filled)
                keys = jnp.where(is_full, jnp.roll(key_cache.value, -1, axis=2), key_cache.value)
                values = jnp.where(is_full, jnp.roll(value_cache.value, -1, axis=2), value_cache.value)
                keys = keys.at[:, :, slot, :].set(k[:, :, 0, :])
                values = values.at[:, :, slot, :].set(v[:, :, 0, :])
                key_cache.value = keys
                value_cache.value = values
                kv_ind.value = slot + 1
                k, v = keys, values
                # Only slots written so far may be attended to.
                mask = (jnp.arange(self.max_tokens) <= slot)[None, :]

        weights = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(dk)
        weights = jnp.where(mask[None, None, :, :], weights, jnp.finfo(weights.dtype).min)
        weights = nn.softmax(weights, axis=-1)
        attention = jnp.einsum('bhqk,bhkd->bhqd', weights, v)
        attention = rearrange(attention, 'b h t d -> b t (h d)')

        return nn.Dense(features=self.model_dim, name='output')(attention)

In [None]:
class Dense:
    """Minimal affine layer `y = x @ Weights + bias` with normally distributed parameters.

    Attributes:
        rng_key: the key the parameters were drawn from.
        Weights: (in_features, out_features) weight matrix.
        bias: (1, out_features) bias row.
    """

    def __init__(self, in_features, out_features, key=None):
        """Draw weights and bias from a standard normal distribution.

        Args:
            in_features: size of the input's last axis.
            out_features: size of the output's last axis.
            key: optional JAX PRNG key; defaults to `jax.random.key(0)`.
        """
        self.rng_key = jax.random.key(0) if key is None else key
        # Key splitting and sampling both live in jax.random.
        init_key1, init_key2 = jax.random.split(self.rng_key)
        self.Weights = jax.random.normal(init_key1, (in_features, out_features))
        self.bias = jax.random.normal(init_key2, (1, out_features))

    def __call__(self, x):
        """Apply the layer to `x` of shape (..., in_features); returns (..., out_features)."""
        # The ellipsis lets any number of leading axes through, not just one batch axis.
        y = jnp.einsum('...i,io->...o', x, self.Weights) + self.bias
        return y

In [None]:
class RoPE(nn.Module):
    """Rotary position embedding over interleaved channel pairs (0,1), (2,3), ...

    Pair i at position m is rotated by the angle m * 10000^(-2i / model_dim).

    Attributes:
        model_dim: size of the last axis of the input (must be even).
        t: number of positions in the precomputed tables (maximum sequence length).
    """
    model_dim : int
    t : int

    def setup(self):
        zeros = jnp.zeros((self.t, self.model_dim))
        self.array_m = jnp.arange(self.t)[:, None] + zeros                 # position of every entry
        self.pos = (jnp.arange(self.model_dim)[None, :] + zeros) // 2      # pair index of every column
        # Frequencies are a power of 10000, not a product with it.
        self.theta = 10000.0 ** (-2 * self.pos / self.model_dim)
        self.matrix1 = jnp.cos(self.array_m * self.theta)
        self.matrix2 = jnp.sin(self.array_m * self.theta)

    def __call__(self, x):
        """Rotate `x` (..., seq, model_dim) according to each row's position.

        The sequence axis may be shorter than `t`; positions start at 0.
        """
        seq_len = x.shape[-2]
        # Build the partner vector (-x1, x0, -x3, x2, ...) for the 2-D rotations.
        pairs = x.reshape(x.shape[:-1] + (self.model_dim // 2, 2))
        swapped = jnp.stack([-pairs[..., 1], pairs[..., 0]], axis=-1).reshape(x.shape)
        result = (x * self.matrix1[:seq_len]) + (swapped * self.matrix2[:seq_len])
        return result

In [99]:
from einops import rearrange

class MLA(nn.Module):
    """Multi-head latent attention: keys/values are rebuilt from a cached low-rank latent.

    Only the `latent_dim`-wide key/value latent is cached while decoding. The
    cache lives in the Flax `'cache'` collection (`cKV_cache`, `cache_ind`), so
    decoding calls must use `apply(..., train=False, mutable=['cache'])`.

    NOTE: a later cell of this notebook defines another class named `MLA`;
    once that cell runs, it shadows this definition.

    Attributes:
        model_dim: input/output width.
        n_heads: number of heads (must divide `model_dim`).
        max_tokens: number of latents kept in the sliding-window cache.
        latent_dim: width of the compressed query and key/value latents.
    """
    model_dim : int
    n_heads : int
    max_tokens : int
    latent_dim : int

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Attend over `x` (batch, time, model_dim).

        Args:
            x: input activations.
            train: True processes the whole sequence with a causal mask and no
                cache; False processes only the last token against the cache.

        Returns:
            (batch, time, model_dim) array; time is 1 when `train` is False.
        """
        decode = not train
        if decode:
            x = x[:, -1:, :]

        B, T, _ = x.shape
        dk = self.model_dim // self.n_heads

        cKVt = nn.Dense(features=self.latent_dim, name='Wdkv')(x)   # (B, T, latent_dim)
        cqt = nn.Dense(features=self.latent_dim, name='Wdq')(x)     # (B, T, latent_dim)

        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        if decode:
            # During `init` the variables are only created, never written.
            is_initialized = self.has_variable('cache', 'cKV_cache')
            cKV_cache = self.variable(
                'cache', 'cKV_cache', jnp.zeros, (B, self.max_tokens, self.latent_dim), cKVt.dtype
            )
            cache_ind = self.variable('cache', 'cache_ind', lambda: jnp.zeros((), dtype=jnp.int32))
            if is_initialized:
                filled = cache_ind.value
                is_full = filled >= self.max_tokens
                # When full, drop the oldest latent and write into the last slot.
                slot = jnp.where(is_full, self.max_tokens - 1, filled)
                latents = jnp.where(is_full, jnp.roll(cKV_cache.value, -1, axis=1), cKV_cache.value)
                latents = latents.at[:, slot, :].set(cKVt[:, 0, :])
                cKV_cache.value = latents
                cache_ind.value = slot + 1
                cKVt = latents
                # Only slots written so far may be attended to.
                mask = (jnp.arange(self.max_tokens) <= slot)[None, :]

        k = nn.Dense(features=self.model_dim, name='Wuk')(cKVt)
        v = nn.Dense(features=self.model_dim, name='Wuv')(cKVt)
        q = nn.Dense(features=self.model_dim, name='Wuq')(cqt)
        q = rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
        k = rearrange(k, 'b t (h d) -> b h t d', h=self.n_heads)
        v = rearrange(v, 'b t (h d) -> b h t d', h=self.n_heads)

        weights = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(dk)
        weights = jnp.where(mask[None, None, :, :], weights, jnp.finfo(weights.dtype).min)
        weights = nn.softmax(weights, axis=-1)
        attention = jnp.einsum('bhqk,bhkd->bhqd', weights, v)
        attention = rearrange(attention, 'b h t d -> b t (h d)')

        return nn.Dense(features=self.model_dim, name='output')(attention)

In [None]:
from einops import rearrange

def _mla_rotary(x, positions):
    """Apply interleaved rotary embedding to `x` (..., T, D) at integer `positions` (T,).

    Pair i of the last axis is rotated by positions * 10000^(-2i / D); D must be even.
    """
    dim = x.shape[-1]
    inv_freq = 10000.0 ** (-jnp.arange(0, dim, 2) / dim)         # (D/2,)
    angles = positions[:, None] * inv_freq[None, :]              # (T, D/2)
    cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)                # (T, D)
    sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)
    pairs = x.reshape(x.shape[:-1] + (dim // 2, 2))
    swapped = jnp.stack([-pairs[..., 1], pairs[..., 0]], axis=-1).reshape(x.shape)
    return x * cos + swapped * sin


class MLA(nn.Module):
    """Multi-head latent attention with decoupled rotary position keys/queries.

    Content keys/values are rebuilt from a cached `latent_dim`-wide latent. A
    separate `dhR`-wide rotary key, shared by all heads, carries position and is
    cached already rotated. Cache state lives in the Flax `'cache'` collection
    (`cKV_cache`, `kRT_cache`, `cache_ind`, `position`), so decoding calls must
    use `apply(..., train=False, mutable=['cache'])`.

    Attributes:
        model_dim: input/output width.
        n_heads: number of heads (must divide `model_dim`).
        max_tokens: number of positions kept in the sliding-window cache.
        latent_dim: width of the compressed query and key/value latents.
        dhR: per-head width of the rotary query/key part (must be even).
        t: kept for interface compatibility; rotary angles are computed from
            positions on the fly, so no table length is needed.
    """
    model_dim : int
    n_heads : int
    max_tokens : int
    latent_dim : int
    dhR : int
    t : int

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Attend over `x` (batch, time, model_dim).

        Args:
            x: input activations.
            train: True processes the whole sequence with a causal mask and no
                cache; False processes only the last token against the cache.

        Returns:
            (batch, time, model_dim) array; time is 1 when `train` is False.
        """
        decode = not train
        if decode:
            x = x[:, -1:, :]

        B, T, _ = x.shape
        dk = self.model_dim // self.n_heads

        cKVt = nn.Dense(features=self.latent_dim, name='Wdkv')(x)              # (B, T, latent_dim)
        cqt = nn.Dense(features=self.latent_dim, name='Wdq')(x)                # (B, T, latent_dim)
        k_rope = nn.Dense(features=self.dhR, name='Wkr')(x)                    # (B, T, dhR)
        q_rope = nn.Dense(features=self.dhR*self.n_heads, name='Wqr')(cqt)     # (B, T, nh*dhR)
        q_rope = rearrange(q_rope, 'b t (h r) -> b h t r', h=self.n_heads)

        positions = jnp.arange(T)
        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        kRt = None
        if decode:
            # During `init` the variables are only created, never written.
            is_initialized = self.has_variable('cache', 'cKV_cache')
            cKV_cache = self.variable(
                'cache', 'cKV_cache', jnp.zeros, (B, self.max_tokens, self.latent_dim), cKVt.dtype
            )
            kRT_cache = self.variable(
                'cache', 'kRT_cache', jnp.zeros, (B, self.max_tokens, self.dhR), k_rope.dtype
            )
            cache_ind = self.variable('cache', 'cache_ind', lambda: jnp.zeros((), dtype=jnp.int32))
            position = self.variable('cache', 'position', lambda: jnp.zeros((), dtype=jnp.int32))
            if is_initialized:
                # Absolute position keeps growing even after the window starts sliding.
                positions = position.value[None]
                new_kR = _mla_rotary(k_rope, positions)

                filled = cache_ind.value
                is_full = filled >= self.max_tokens
                # When full, drop the oldest entry and write into the last slot.
                slot = jnp.where(is_full, self.max_tokens - 1, filled)
                latents = jnp.where(is_full, jnp.roll(cKV_cache.value, -1, axis=1), cKV_cache.value)
                rope_keys = jnp.where(is_full, jnp.roll(kRT_cache.value, -1, axis=1), kRT_cache.value)
                latents = latents.at[:, slot, :].set(cKVt[:, 0, :])
                rope_keys = rope_keys.at[:, slot, :].set(new_kR[:, 0, :])

                cKV_cache.value = latents
                kRT_cache.value = rope_keys
                cache_ind.value = slot + 1
                position.value = position.value + 1

                cKVt = latents
                kRt = rope_keys
                # Only slots written so far may be attended to.
                mask = (jnp.arange(self.max_tokens) <= slot)[None, :]

        qrt = _mla_rotary(q_rope, positions)                                   # (B, nh, T, dhR)
        if kRt is None:
            kRt = _mla_rotary(k_rope, positions)                               # (B, T, dhR)

        k_c = nn.Dense(features=self.model_dim, name='Wuk')(cKVt)
        k_c = rearrange(k_c, 'b t (h d) -> b h t d', h=self.n_heads)
        # The rotary key is shared by all heads.
        k_r = jnp.broadcast_to(kRt[:, None, :, :], k_c.shape[:-1] + (self.dhR,))
        k = jnp.concatenate([k_c, k_r], axis=-1)

        v = nn.Dense(features=self.model_dim, name='Wuv')(cKVt)
        v = rearrange(v, 'b t (h d) -> b h t d', h=self.n_heads)

        q_c = nn.Dense(features=self.model_dim, name='Wuq')(cqt)
        q_c = rearrange(q_c, 'b t (h d) -> b h t d', h=self.n_heads)
        q = jnp.concatenate([q_c, qrt], axis=-1)

        # Queries/keys are dk + dhR wide, so the scale uses that width.
        weights = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(dk + self.dhR)
        weights = jnp.where(mask[None, None, :, :], weights, jnp.finfo(weights.dtype).min)
        weights = nn.softmax(weights, axis=-1)
        attention = jnp.einsum('bhqk,bhkd->bhqd', weights, v)
        attention = rearrange(attention, 'b h t d -> b t (h d)')

        return nn.Dense(features=self.model_dim, name='output')(attention)