In [2]:
%load_ext autoreload
%autoreload 2

In [4]:
import sys
sys.path.append('..')

In [5]:
from daisygrad.tensor import DaisyTensor
from daisygrad.neural.layers import Parameter, Module, Linear, Embedding, Dropout
import jax.numpy as jnp

In [6]:
class RMSNorm(Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Computes ``x / sqrt(mean(x**2, axis=-1) + eps) * weight``.

    Args:
        features: Length of the learnable gain vector; it is multiplied against
            the normalised input, so it must broadcast with the input's last axis.
        eps: Small constant added to the mean square before taking the root,
            so an all-zero input does not divide by zero.
    """

    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so the layer initially applies plain normalisation.
        self.weight = Parameter(jnp.ones(features))

    def __call__(self, x: DaisyTensor) -> DaisyTensor:
        """Normalise ``x`` by its RMS along the last axis and apply the gain.

        Args:
            x: Input tensor; statistics are taken over axis ``-1``.

        Returns:
            ``x`` divided by its RMS (kept as a size-1 last axis for
            broadcasting) and multiplied by ``self.weight``.
        """
        mean_square = x.pow(2).mean(-1, keepdims=True)
        # RMSNorm divides by the RMS. The reciprocal is taken with pow(-1) so
        # only tensor ops this block already relies on (pow, sqrt, mul) are
        # needed, rather than assuming a division operator exists.
        inv_rms = (mean_square + self.eps).sqrt().pow(-1)
        return x * inv_rms * self.weight

In [7]:
def precompute_freq_cis(dim: int, seqlen: int, theta: float = 10000.0):
    """Build the table of complex rotations ``exp(i * t * f)`` for rotary embeddings.

    Args:
        dim: Feature dimension being rotated; ``dim // 2`` frequencies are
            produced, one per consecutive pair of features.
        seqlen: Number of positions ``t = 0 .. seqlen - 1`` to tabulate.
        theta: Base of the geometric frequency schedule.

    Returns:
        Complex array of shape ``(seqlen, dim // 2)`` whose entry ``[t, j]`` is
        ``exp(1j * t * theta ** (-2j / dim))``; every entry has unit magnitude.
    """
    # Exponents 0, 2/dim, 4/dim, ...; the slice keeps exactly dim // 2 of them
    # even when dim is odd.
    exponents = jnp.arange(0, dim, 2)[: (dim // 2)].astype(jnp.float32) / dim
    base_freqs = 1.0 / (theta ** exponents)
    positions = jnp.arange(seqlen, dtype=jnp.float32)
    # angles[t, j] = t * base_freqs[j]
    angles = jnp.outer(positions, base_freqs)
    # jnp.exp is unary: the imaginary unit must scale the angles inside the
    # call, giving cos(angle) + 1j * sin(angle).
    return jnp.exp(1j * angles)

In [11]:
def apply_rotary_emb(q: DaisyTensor, k: DaisyTensor, freqs_cis) -> tuple[DaisyTensor, DaisyTensor]:
    """Rotate query and key features pairwise by precomputed complex phases.

    Consecutive feature pairs ``(x[2j], x[2j + 1])`` on the last axis are read
    as the complex number ``x[2j] + 1j * x[2j + 1]``, multiplied by the matching
    entry of ``freqs_cis``, and written back as (real, imag) pairs.

    Args:
        q: Query tensor. Axis 2 is read as the sequence axis and the last axis
            must have even length.
            # assumes layout (batch, heads, seq, head_dim) — TODO confirm
        k: Key tensor, processed the same way as ``q``; it is multiplied by the
            phases sliced to ``q``'s sequence length, so its sequence axis must
            broadcast with that.
        freqs_cis: Complex table indexed ``[position, pair]``; only the first
            ``q.shape[2]`` positions are used.

    Returns:
        ``(q_rotated, k_rotated)``, each reshaped back to its input's shape.
        NOTE(review): the results come out of ``jnp.stack``; whether they are
        DaisyTensors depends on how DaisyTensor interoperates with jax — confirm.
    """
    # Split the last axis into (pairs, 2) so each pair can become one complex value.
    q_pairs = q.reshape(q.shape[:-1] + (q.shape[-1] // 2, 2))
    k_pairs = k.reshape(k.shape[:-1] + (k.shape[-1] // 2, 2))

    q_complex = q_pairs[..., 0] + 1j * q_pairs[..., 1]
    k_complex = k_pairs[..., 0] + 1j * k_pairs[..., 1]

    seq_len = q.shape[2]
    # Two leading singleton axes let the (seq, pairs) table broadcast over the
    # first two axes of the inputs. The table is the `freqs_cis` argument.
    phases = freqs_cis[None, None, :seq_len, :]

    q_rotated_complex = q_complex * phases
    k_rotated_complex = k_complex * phases

    # Interleave real and imaginary parts back into the original pair layout.
    q_rotated = jnp.stack([q_rotated_complex.real, q_rotated_complex.imag], axis=-1)
    k_rotated = jnp.stack([k_rotated_complex.real, k_rotated_complex.imag], axis=-1)

    return q_rotated.reshape(q.shape), k_rotated.reshape(k.shape)

In [14]:
class ModelConfig:
    """Model hyperparameters, exposed as plain class attributes.

    Read by attribute name; no instance is required.
    """

    # --- overall size ---
    vocab_size: int = 32_000
    d_model: int = 768
    n_layers: int = 12
    intermediate_size: int = 3_072

    # --- attention ---
    n_heads: int = 12
    d_head: int = 64  # n_heads * d_head equals d_model for these values
    # presumably the latent and rotary sub-dimensions of the attention block;
    # their exact roles are not shown here — TODO confirm
    d_latent: int = 32
    d_rope_sub: int = 16

In [13]:
class MLA(Module):
    def __init__(self, config):
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_latent = config.d_latent
        self.d_rope_sub = config.d_rope_sub
        self.d_head = config.d_head

        self.wq = Linear(self.d_model, self.n_heads * self.d_head)
        self.wkv_a = Linear