In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

from daisygrad.tensor import DaisyTensor
from daisygrad.neural.layers import Linear, LayerNorm, Embedding, Parameter, Dropout

In [299]:
from daisylm.jax_core.tokenizer import BPETokenizer
# NOTE(review): keep this vocab_size equal to DaisyConfig.vocab_size (4096) —
# the embedding table is sized from the config, and token ids produced here are
# presumably used as row indices into it. TODO confirm the tokenizer needs no
# training step before `encode` is usable.
tokenizer = BPETokenizer(vocab_size=4096)

In [300]:
import jax.numpy as jnp
from jax import random

# A tiny demo batch: sentences of deliberately different lengths.
strings = [
    "The sun is bright.",
    "Hello world!",
    "AI is transforming the world.",
    "Short"
]

# Ragged list of token-id lists, one per sentence.
tok = list(map(tokenizer.encode, strings))

# Right-pad each sequence with id 0 up to the longest one, then stack into a
# (batch, max_len) array.
# NOTE(review): 0 is used as the pad id and no padding mask is built; confirm 0 is
# not a meaningful token for this tokenizer.
max_len = max(map(len, tok))
padded_tokens = [
    jnp.pad(jnp.array(ids), pad_width=(0, max_len - len(ids)))
    for ids in tok
]
tokarray = jnp.stack(padded_tokens)

# Model inputs drop the last position: shape (batch, max_len - 1).
inputs = tokarray[:, :-1]

In [301]:
from dataclasses import dataclass


@dataclass
class DaisyConfig:
    """Hyper-parameters for the Daisy GPT-style model.

    ``DaisyConfig()`` gives the original defaults. Any field can be overridden per
    instance, e.g. ``DaisyConfig(n_embed=64, n_head=4, n_blocks=2)`` for a small
    model.

    Raises:
        ValueError: if ``n_embed`` is not divisible by ``n_head``. Attention splits
            the embedding evenly across heads.
    """
    block_size: int = 512    # number of rows in the learned position table (max sequence length)
    vocab_size: int = 4096   # size of the token embedding table; should match the tokenizer
    n_embed: int = 768       # embedding / residual-stream width
    n_head: int = 12         # attention heads; must divide n_embed
    n_blocks: int = 12       # number of Transformer blocks in Model
    dropout: float = 0.0     # rate passed to every Dropout layer

    def __post_init__(self):
        # Fail at construction time rather than deep inside attention setup.
        if self.n_embed % self.n_head != 0:
            raise ValueError(
                f"n_embed ({self.n_embed}) must be divisible by n_head ({self.n_head})"
            )

In [324]:
class Embed:
    """Token embedding plus learned absolute position embedding.

    ``embed(inputs)`` returns ``vec_matrix(inputs) + pos_matrix.take(arange(T))``,
    where ``T = inputs.shape[-1]``. The position rows broadcast over the leading
    (batch) axes.
    """

    def __init__(self, key, config):
        key, vec_key = random.split(key)
        key, pos_key = random.split(key)
        # NOTE(review): position init is scaled by 1/sqrt(vocab_size), not by a
        # width-based factor; kept as-is so initialisation values are unchanged.
        scale = 1 / jnp.sqrt(config.vocab_size)
        # The position table has exactly block_size rows, so longer inputs are rejected.
        self.block_size = config.block_size
        self.vec_matrix = Embedding(vec_key, config.vocab_size, config.n_embed)
        self.pos_matrix = Parameter(random.normal(pos_key, shape=(config.block_size, config.n_embed)) * scale)

    def __call__(self, inputs):
        """Embed integer token ids.

        Args:
            inputs: integer array whose last axis is the sequence axis.

        Raises:
            ValueError: if the sequence is longer than ``block_size``. Those
                positions have no row in the position table.
        """
        seq_len = inputs.shape[-1]
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        vec_embed = self.vec_matrix(inputs)
        indices = jnp.arange(seq_len)
        pos_embed = self.pos_matrix.take(indices)
        return vec_embed + pos_embed

    @property
    def weight(self):
        """The token embedding matrix (used for weight tying with the output layer)."""
        return self.vec_matrix.embed

In [325]:
config = DaisyConfig()
# Seeded root key; split once so Embed gets its own subkey and `key` stays available.
key = random.PRNGKey(42)
key, embed_key = random.split(key)
embedding = Embed(embed_key, config)
x = embedding(inputs)

In [326]:
# Embedding output: one n_embed-wide vector per input position -> (batch, seq_len, n_embed).
x.shape

(4, 28, 768)

In [335]:
class CausalSelfAttention:
    """Multi-head causal self-attention with a fused QKV projection.

    Input and output are ``(B, T, E)``. ``E`` is split into ``config.n_head``
    heads of width ``E // n_head``. Dropout is applied to the output
    projection.
    """

    def __init__(self, key, config):
        assert config.n_embed % config.n_head == 0
        qkv_key, proj_key, drop_key = random.split(key, 3)
        # Fused projection producing q, k and v in one matmul.
        self.qkv = Linear(qkv_key, config.n_embed, 3 * config.n_embed, bias=True)
        self.proj = Linear(proj_key, config.n_embed, config.n_embed, bias=True)
        self.n_head = config.n_head
        self.dropout = Dropout(drop_key, config.dropout)

    def causal_mask_like(self, tensor: DaisyTensor):
        """Return an additive causal mask of shape ``(1, 1, T, T)``, T = ``tensor.shape[-1]``.

        Entry ``(i, j)`` is 0 for ``j <= i`` and ``-inf`` for ``j > i``. Adding it
        to attention scores removes weight on future positions after softmax.
        The mask is a constant, so it is created with ``requires_grad=False``
        and is never tracked as a trainable input.
        """
        T = tensor.shape[-1]
        mask = jnp.full((T, T), 0.0)
        mask = mask.at[jnp.triu_indices(T, 1)].set(-jnp.inf)
        return DaisyTensor(mask[None, None, :, :], requires_grad=False)

    def _split_heads(self, t, B, T, E):
        """Reshape ``(B, T, E)`` to ``(B, n_head, T, E // n_head)``."""
        return t.reshape((B, T, self.n_head, E // self.n_head)).transpose(0, 2, 1, 3)

    def __call__(self, x, train=True):
        """Attend causally over ``x`` of shape ``(B, T, E)``; returns ``(B, T, E)``."""
        B, T, E = x.shape
        q, k, v = self.qkv(x).split(3, axis=-1)
        q = self._split_heads(q, B, T, E)
        k = self._split_heads(k, B, T, E)
        v = self._split_heads(v, B, T, E)

        # Scaled dot-product scores: (B, n_head, T, T).
        D = q.shape[-1]
        scores = q @ k.transpose(0, 1, 3, 2)
        scaled_scores = scores / jnp.sqrt(D)
        masked_scores = scaled_scores + self.causal_mask_like(scaled_scores)

        # NOTE(review): softmax is presumably over the last (key) axis — confirm
        # DaisyTensor.softmax's default axis.
        a = masked_scores.softmax()
        z = a @ v

        # Merge heads back: (B, n_head, T, D) -> (B, T, E).
        z = z.transpose(0, 2, 1, 3).reshape((B, T, E))
        return self.dropout(self.proj(z), train=train)

In [337]:
# Build the layer from the fresh subkey; `key` is only kept for future splits,
# so no PRNG key is consumed twice.
key, attn_key = random.split(key)
attn = CausalSelfAttention(attn_key, config)
x1 = attn(x)

In [338]:
# Attention preserves the input shape: heads are merged back and projected E -> E.
x1.shape

(4, 28, 768)

In [310]:
class MLP:
    """Position-wise feed-forward sub-layer.

    ``Linear(E -> 4E)`` -> GELU -> ``Linear(4E -> E)`` -> dropout.
    """

    def __init__(self, key, config):
        hidden_width = 4 * config.n_embed
        l1_key, l2_key, drop_key = random.split(key, 3)
        self.l1 = Linear(l1_key, config.n_embed, hidden_width, bias=True)
        self.l2 = Linear(l2_key, hidden_width, config.n_embed, bias=True)
        self.dropout = Dropout(drop_key, config.dropout)

    def __call__(self, x, train=True):
        """Expand, apply GELU, project back, then dropout (controlled by ``train``)."""
        activated = self.l1(x).gelu()
        projected = self.l2(activated)
        return self.dropout(projected, train=train)

In [311]:
class Transformer:
    """Pre-LayerNorm transformer block.

    ``x -> x + attn(ln_1(x)) -> x + mlp(ln_2(x))``; the output has the same shape as the input.

    ``CausalSelfAttention`` and ``MLP`` both end with their own
    ``self.dropout(...)``, so this block does not wrap them in a second
    dropout. A second layer would drop each residual branch twice whenever
    ``config.dropout > 0``.
    """

    def __init__(self, key, config):
        # Split into four keys (last two unused) so attn/mlp receive the same keys,
        # and therefore the same initialisation, as with the former 4-key layout.
        attn_key, mlp_key, _, _ = random.split(key, 4)
        self.attn = CausalSelfAttention(attn_key, config)
        self.ln_1 = LayerNorm(config.n_embed)
        self.ln_2 = LayerNorm(config.n_embed)
        self.mlp = MLP(mlp_key, config)

    def __call__(self, x, train=True):
        """Apply the block.

        ``train`` is forwarded to both sub-layers, so ``train=False`` reaches their
        Dropout layers instead of silently leaving them in training mode.
        """
        x = x + self.attn(self.ln_1(x), train=train)
        x = x + self.mlp(self.ln_2(x), train=train)
        return x

In [312]:
# The demo batch (`strings` -> `tok` -> `tokarray` -> `inputs`) is built once in the
# tokenization cell near the top and reused here rather than recomputed from a
# copy-pasted cell. This check fails loudly if that cell was not run.
assert inputs.shape == (len(strings), max_len - 1), "run the tokenization cell first"

In [313]:
key = random.PRNGKey(42)
key, embed_key = random.split(key)
# Separate subkey for the block so no key is consumed twice.
key, block_key = random.split(key)
# The class defined above is `Transformer(key, config)`; `TransformerBlock` does not exist.
transformer = Transformer(block_key, config)
embedding = Embed(embed_key, config)
x = embedding(inputs)
out = transformer(x)

In [314]:
# A transformer block preserves shape: both sub-layers are added back onto x (residuals).
out.shape

(4, 28, 768)

In [315]:
class Model:
    """GPT-style language model.

    Pipeline: ``Embed`` -> ``n_blocks`` x ``Transformer`` -> ``LayerNorm`` ->
    ``Linear`` to vocabulary logits.
    """

    def __init__(self, key, config=None):
        """Initialise all sub-modules from a single PRNG key.

        Args:
            key: JAX PRNG key; every parameter initialisation is derived from it.
            config: optional ``DaisyConfig``; defaults to ``DaisyConfig()``.
        """
        if config is None:
            config = DaisyConfig()
        self.config = config
        # keys[0] is reserved for the embedding/output layers and keys[1:] go one
        # per block, so no key seeds two different modules.
        keys = random.split(key, config.n_blocks + 1)
        _, embed_key, final_key = random.split(keys[0], 3)
        self.embedding = Embed(embed_key, config)
        self.blocks = [Transformer(keys[i + 1], config) for i in range(config.n_blocks)]
        self.ln_f = LayerNorm(config.n_embed)
        self.final = Linear(final_key, config.n_embed, config.vocab_size, bias=True)
        # Weight tying: the output projection reuses the (transposed) token embedding.
        # NOTE(review): assumes Linear.weight is laid out (n_embed, vocab_size); the
        # (batch, seq, vocab_size) logits shape suggests so — TODO confirm.
        self.final.weight = self.embedding.weight.transpose(-1, -2)

    def __call__(self, inputs, train=True):
        """Map token ids to vocabulary logits.

        Args:
            inputs: integer token ids, (batch, seq_len).
            train: forwarded to every block; pass ``False`` for evaluation.
        """
        x = self.embedding(inputs)
        for block in self.blocks:
            x = block(x, train=train)
        return self.final(self.ln_f(x))

In [316]:
# Fixed seed: Model derives all of its initialisation keys from this single key.
key = random.PRNGKey(42)
daisy = Model(key)

In [317]:
# NOTE(review): this rebinds `x` (previously the embedding output) to the model logits.
x = daisy(inputs)
# Logits over the vocabulary for every position: (batch, seq_len, vocab_size).
x.data.shape

(4, 28, 4096)

In [22]:
import numpy as np
import torch as pt
freqs = 1.0 / (10000.0 ** (np.arange(0, 128, 2)[: (128 // 2)].astype(np.float32) / 128))
t = np.arange(1024, dtype=np.float32)
freqs = np.outer(t, freqs)
freqs_cis = 