# Implementing LLaMA3 in 100 Lines of Pure Jax

![img](images/newllama.png)

In [2]:
import jax
import jax.numpy as jnp
from jax import random
import math

### Root Mean Square Layer Normalization

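RMS normalization (RMSNorm) is an important layer in Llama 3 models. It rescales each token's feature vector by its root mean square and then multiplies it by a learned per-feature gain. This keeps activations from becoming too large or too small as they flow through the network, which keeps training stable, especially in deep networks. Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias term.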

![img](images/rsmnorm.png)


In [None]:
def rms_norm(x, weight, eps=1e-5):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * weight * jnp.reciprocal(jnp.sqrt(variance + eps))
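A quick numeric check of what rms_norm does: two inputs pointing in the same direction but at very different scales come out identical, and each output row has an RMS of about 1. The function is copied from above so the snippet runs on its own; the toy inputs and the all-ones gain are illustrative assumptions.

```python
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-5):
    # Divide by the root mean square over the last axis, then apply the per-feature gain
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * weight * jnp.reciprocal(jnp.sqrt(variance + eps))

x = jnp.array([[3.0, 4.0],
               [30.0, 40.0]])  # same direction, 10x the scale
g = jnp.ones((2,))             # unit gain
y = rms_norm(x, g)

print(y)                                     # both rows ≈ [0.8485, 1.1314]
print(jnp.sqrt(jnp.mean(y ** 2, axis=-1)))   # RMS of each row ≈ 1.0
```

The scale of the input is removed and only its direction survives; the learned gain then decides how large each feature should be.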


### Rotary Positional Encoding

Transformers have no built-in notion of token order, so we need to add positional information. Llama 3 does this with Rotary Positional Embedding (RoPE), which “rotates” the query and key vectors by angles that depend on each token's position.

![img](images/rope.png)

In [None]:
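def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # theta_i = theta^(-2i/dim) for i = 0 .. dim/2 - 1 (one frequency per feature pair).
    # Note: Llama 3 itself uses theta = 500000.0.
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32)[: dim // 2] / dim))
    t = jnp.arange(end, dtype=jnp.float32)     # positions 0 .. end-1
    freqs = jnp.outer(t, freqs)                # (end, dim // 2) angles m * theta_i
    return jnp.complex64(jnp.exp(1j * freqs))  # unit-length complex rotation factors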


#### How It Works:

Precompute Rotation Factors: First we create a table of rotation factors. Each feature pair $i$ has its own frequency $\theta_i = 10000^{-2i/d}$ (10000 is the default theta and $d$ is the head dimension), and the token at position $m$ rotates that pair by the angle $m\theta_i$. This gives every position its own set of rotation angles.

Apply the Rotation:

Pair Up Features: We reshape the vectors so that every two consecutive numbers form a pair, and treat each pair as the real and imaginary parts of a complex number.

Rotate: We multiply these complex numbers by the precomputed unit-length rotation factors. This rotates each pair in the complex plane.

Convert Back: Finally, we split the rotated complex numbers back into their real and imaginary parts to restore the original shape.
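The three steps can be sketched on a single vector. This is a minimal illustration, not the notebook's apply_rotary_emb: rope_rotate is a hypothetical helper, and it assumes the standard RoPE frequencies $\theta_i = 10000^{-2i/d}$ from the Llama reference code. It checks the complex-number version against an explicit cos/sin rotation and confirms that the vector's length is unchanged.

```python
import jax.numpy as jnp

def rope_rotate(x, angles):
    # x: (d,) real vector with even d; angles: (d // 2,) one angle per feature pair
    pairs = x.reshape(-1, 2)                       # 1. pair up features
    z = pairs[:, 0] + 1j * pairs[:, 1]             #    view each pair as a complex number
    z = z * jnp.exp(1j * angles)                   # 2. rotate each pair in the complex plane
    return jnp.stack([z.real, z.imag], axis=-1).reshape(x.shape)  # 3. convert back

d, m = 8, 3                                             # head dimension, token position
inv_freq = 1.0 / 10000.0 ** (jnp.arange(0, d, 2) / d)   # theta_i = 10000^(-2i/d)
x = jnp.arange(1.0, d + 1.0)                            # toy query vector [1, 2, ..., 8]
x_rot = rope_rotate(x, m * inv_freq)

# The same rotation written with explicit cos/sin on the even/odd features
c, s = jnp.cos(m * inv_freq), jnp.sin(m * inv_freq)
x_even, x_odd = x[0::2], x[1::2]
expected = jnp.stack([x_even * c - x_odd * s, x_even * s + x_odd * c], axis=-1).reshape(-1)

print(jnp.allclose(x_rot, expected, atol=1e-4))                             # True
print(jnp.allclose(jnp.linalg.norm(x_rot), jnp.linalg.norm(x), atol=1e-4))  # True: length preserved
```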


##### Math Behind It:

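For each pair $(x_{2i}, x_{2i+1})$ of a token at position $m$, the rotation is given by:
$$
\begin{pmatrix}
x'_{2i} \\
x'_{2i+1}
\end{pmatrix}
=
\begin{pmatrix}
\cos(m\theta_i) & -\sin(m\theta_i) \\
\sin(m\theta_i) & \cos(m\theta_i)
\end{pmatrix}
\begin{pmatrix}
x_{2i} \\
x_{2i+1}
\end{pmatrix}
$$

where $\theta_i = 10000^{-2i/d}$ is the frequency of pair $i$ ($d$ is the head dimension), so the rotation angle $m\theta_i$ grows with the token's position.
In short, RoPE embeds positional information directly into the token features by rotating them. Because a query at position $m$ and a key at position $n$ are rotated by $m\theta_i$ and $n\theta_i$, their dot product depends only on the offset $n - m$. This way the attention module gets information about relative token order without extra position vectors.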
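The relative-position property can be checked numerically. In this sketch (rope_rotate is a hypothetical helper using the standard frequencies $10000^{-2i/d}$, and q and k are arbitrary toy vectors), the attention score for positions (3, 5) equals the score for positions (10, 12), because both pairs are two tokens apart.

```python
import jax.numpy as jnp

def rope_rotate(x, pos, theta=10000.0):
    # Rotate the feature pairs of x (shape (d,)) by pos * theta_i, with theta_i = theta^(-2i/d)
    d = x.shape[-1]
    inv_freq = 1.0 / theta ** (jnp.arange(0, d, 2) / d)
    pairs = x.reshape(-1, 2)
    z = (pairs[:, 0] + 1j * pairs[:, 1]) * jnp.exp(1j * pos * inv_freq)
    return jnp.stack([z.real, z.imag], axis=-1).reshape(x.shape)

q = jnp.array([0.5, -1.0, 2.0, 0.3, 1.5, -0.7, 0.2, 1.1])
k = jnp.array([1.0, 0.4, -0.6, 2.0, 0.1, 0.9, -1.2, 0.5])

# Same offset (n - m = 2) at two different absolute positions
s1 = jnp.dot(rope_rotate(q, 3), rope_rotate(k, 5))
s2 = jnp.dot(rope_rotate(q, 10), rope_rotate(k, 12))
print(jnp.allclose(s1, s2, atol=1e-4))  # True: the score depends only on the offset
```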

In [None]:
def apply_rotary_emb(xq, xk, freqs_cis):
    xq_r, xk_r = jnp.reshape(xq, (*xq.shape[:-1], -1, 2)), jnp.reshape(xk, (*xk.shape[:-1], -1, 2))
    xq_complex = jnp.complex64(xq_r[..., 0] + 1j * xq_r[..., 1])
    xk_complex = jnp.complex64(xk_r[..., 0] + 1j * xk_r[..., 1])
    freqs_cis = jnp.reshape(freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1]))
    xq_out = xq_complex * freqs_cis
    xk_out = xk_complex * freqs_cis
    xq = jnp.stack([jnp.real(xq_out), jnp.imag(xq_out)], axis=-1).reshape(xq.shape)
    xk = jnp.stack([jnp.real(xk_out), jnp.imag(xk_out)], axis=-1).reshape(xk.shape)
    return xq, xk


In [None]:
def repeat_kv(x, n_rep):
    return x if n_rep == 1 else jnp.repeat(x, n_rep, axis=2)

### Model Weights Initialization

In pure JAX, we don't use stateful module classes like PyTorch's nn.Module. Instead we write pure functions, because they make our code more predictable and easier to compile and parallelize. A pure function always returns the same output for the same input and doesn't cause any side effects.[6] For example, if you call F(x), you'll always get the same y.

Since we aren’t using a framework like PyTorch’s nn.Module to automatically track parameters, we must initialize and update our weights manually.

Handling randomness is also different. Instead of relying on a single global seed as in NumPy or PyTorch, in JAX we manage randomness with explicit pseudo-random number generator (PRNG) keys. Each random operation gets its own key, derived by splitting a parent key. Because the same key always produces the same numbers, this makes results reproducible and keeps random streams independent when work runs in parallel.

For example, in the initialization code below we create a key, split it into subkeys, and pass each subkey to the function that needs randomness.
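Here is a minimal, standalone illustration of explicit keys, using the same jax.random.PRNGKey and jax.random.split calls as the notebook: reusing a key reproduces the same numbers, while a split-off key gives different ones. The shapes printed assume the classic uint32 key format; newer typed keys from jax.random.key have a different shape.

```python
import jax

key = jax.random.PRNGKey(0)        # root key: the only place a seed appears
k1, k2 = jax.random.split(key)     # two independent subkeys

a = jax.random.normal(k1, (3,))
b = jax.random.normal(k1, (3,))    # same key -> identical numbers
c = jax.random.normal(k2, (3,))    # different key -> different numbers

print((a == b).all())   # True
print((a == c).all())   # False

# Split into many keys at once, e.g. one per transformer layer
layer_keys = jax.random.split(key, 6)
print(layer_keys.shape)  # (6, 2)
```

Since nothing is drawn from hidden global state, calling the initializer twice with the same root key rebuilds exactly the same model.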

Now let's start with model weight initialization. First, we write a helper that draws random values for our parameters from a normal distribution, scaled by 1/sqrt(shape[0]) (the input dimension) unless an explicit scale is given.

In [None]:
def init_weight(key, shape, scale=None):
    scale = 1.0 / math.sqrt(shape[0]) if scale is None else scale
    return jax.random.normal(key, shape) * scale


Next, we'll identify all the learnable parameters of our model (Llama 3), give each one its own key so the initialization is reproducible, and apply the initialization to them.

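Since weights are just arrays of numbers, we can store them in (nested) dictionaries as name-array pairs. JAX treats nested dicts and lists like these as pytrees, so functions such as jax.tree.map and jax.grad can operate on all of the weights at once.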

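We start with the attention module, which has four weight matrices: wq, wk, wv, and wo. Because of grouped-query attention, wk and wv project to only n_kv_heads * head_dim features.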

In [None]:
def init_attention_weights(key, dim, n_heads, n_kv_heads):
    keys = jax.random.split(key, 4)
    head_dim = dim // n_heads
    return {
        'wq': init_weight(keys[0], (dim, n_heads * head_dim)),
        'wk': init_weight(keys[1], (dim, n_kv_heads * head_dim)),
        'wv': init_weight(keys[2], (dim, n_kv_heads * head_dim)),
        'wo': init_weight(keys[3], (n_heads * head_dim, dim))
    }


Next comes the feed-forward network, which has three weight matrices (w1, w2, w3) with a hidden size of 4 * dim.



In [None]:
def init_ffn_weights(key, dim):
    keys = jax.random.split(key, 3)
    return {
        'w1': init_weight(keys[0], (dim, 4 * dim)),
        'w2': init_weight(keys[1], (4 * dim, dim)),
        'w3': init_weight(keys[2], (dim, 4 * dim))}

Then we combine these weights into a transformer block, adding two more parameters: the gains of the two RMSNorm layers, initialized to ones as in Llama.

In [None]:
def init_transformer_block(key, dim, n_heads, n_kv_heads):
    keys = jax.random.split(key, 2)
    return {
        'attention': init_attention_weights(keys[0], dim, n_heads, n_kv_heads),
        'ffn': init_ffn_weights(keys[1], dim),
        # RMSNorm gains start at 1, so each normalized feature initially passes through unscaled
        'attention_norm': jnp.ones((dim,)),
        'ffn_norm': jnp.ones((dim,))}



Finally, we assemble all of the model's weights in one place: the token embedding table, the final RMSNorm gain (again initialized to ones), the output projection, and one parameter dict per transformer block.

In [None]:
def init_model_params(key, vocab_size, dim, n_layers, n_heads, n_kv_heads):
    keys = jax.random.split(key, 3)
    params = {
        'token_embedding': init_weight(keys[0], (vocab_size, dim)),
        'norm_f': jnp.ones((dim,)),  # final RMSNorm gain
        'output': init_weight(keys[1], (dim, vocab_size))
    }
    block_keys = jax.random.split(keys[2], n_layers)
    params['blocks'] = [init_transformer_block(k, dim, n_heads, n_kv_heads) for k in block_keys]
    return params

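### Grouped-Query Attention

Now it's time for attention. Grouped-Query Attention (GQA) is a variant of multi-head attention in which several query heads share the same key and value head. With the config used later (n_heads = 8, n_kv_heads = 4), every two query heads share one K/V head, so the K/V projections and the KV-cache are half the size they would be in standard multi-head attention. This reduces memory use and speeds up inference. At its core, it's still scaled dot-product self-attention; the only change is that the K and V heads are repeated (repeat_kv) to match the number of query heads before the scores are computed.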
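To see how heads are shared, this sketch feeds repeat_kv (copied from above so it runs standalone) a toy key tensor in which each KV head is filled with its own index. After repeating, query head h reads KV head h // 2. The sizes are illustrative and match the n_heads = 8, n_kv_heads = 4 config.

```python
import jax.numpy as jnp

def repeat_kv(x, n_rep):
    # x: (B, T, n_kv_heads, head_dim) -> (B, T, n_kv_heads * n_rep, head_dim)
    return x if n_rep == 1 else jnp.repeat(x, n_rep, axis=2)

B, T, n_heads, n_kv_heads, head_dim = 1, 3, 8, 4, 2
# Fill each KV head with its own index so we can see where it ends up
k = jnp.broadcast_to(jnp.arange(n_kv_heads, dtype=jnp.float32)[None, None, :, None],
                     (B, T, n_kv_heads, head_dim))
k_rep = repeat_kv(k, n_heads // n_kv_heads)

print(k_rep.shape)        # (1, 3, 8, 2)
print(k_rep[0, 0, :, 0])  # [0. 0. 1. 1. 2. 2. 3. 3.]
```

Only the 4-head tensor has to be computed and cached; the repetition happens just before the score computation.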



In [None]:
def attention(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0):
    B, T, C = x.shape
    head_dim = C // n_heads
    q = jnp.dot(x, params['wq']).reshape(B, T, n_heads, head_dim)
    k = jnp.dot(x, params['wk']).reshape(B, T, n_kv_heads, head_dim)
    v = jnp.dot(x, params['wv']).reshape(B, T, n_kv_heads, head_dim)
    q, k = apply_rotary_emb(q, k, freqs_cis[position:position + T])
    if cache is not None:
        k = jnp.concatenate([cache[0], k], axis=1)
        v = jnp.concatenate([cache[1], v], axis=1)
    new_cache = (k, v)
    k = repeat_kv(k, n_heads // n_kv_heads)
    v = repeat_kv(v, n_heads // n_kv_heads)
    q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v))
    scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / math.sqrt(head_dim)
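    if mask is not None:
        # Rows: absolute positions of the current queries; columns: all keys seen so far
        # (cached + new), so the causal mask stays correct when decoding with a KV-cache.
        scores = scores + mask[:, :, position:position + T, :k.shape[2]]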
    scores = jax.nn.softmax(scores, axis=-1)
    output = jnp.matmul(scores, v)
    output = output.transpose(0, 2, 1, 3).reshape(B, T, -1)
    return jnp.dot(output, params['wo']), new_cache

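#### KV-cache

During autoregressive generation, each new token needs the keys (K) and values (V) of every earlier token, and those never change once computed. A KV-cache stores them, so each decoding step only computes q, k, and v for the new token and appends its k and v to the cache (the concatenation along the time axis inside attention) instead of recomputing them for the whole sequence. The cache is only used at inference time; during training the full sequence is processed in one pass.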


![img](images/lightkv.png)
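The sketch below shows why caching is safe. It uses a hypothetical single-head helper attend (not the notebook's attention function) with random toy weights, and compares the output for the last token computed two ways: full recomputation over all five tokens, and incremental decoding that processes one token per step while appending to a K/V cache. The results match.

```python
import jax
import jax.numpy as jnp

def attend(q, k, v):
    # q: (Tq, d); k, v: (Tk, d). The queries are the last Tq of the Tk positions,
    # so query i may attend to keys 0 .. Tk - Tq + i (causal mask).
    Tq, Tk = q.shape[0], k.shape[0]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    allowed = jnp.arange(Tk)[None, :] <= (Tk - Tq + jnp.arange(Tq))[:, None]
    scores = jnp.where(allowed, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1) @ v

kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
d, T = 4, 5
x = jax.random.normal(kx, (T, d))                  # 5 token embeddings
wq, wk, wv = (jax.random.normal(k_, (d, d)) for k_ in (kq, kk, kv))

# Full recomputation: all 5 tokens at once, keep the last position's output
full = attend(x @ wq, x @ wk, x @ wv)[-1]

# Incremental decoding: one new token per step, K and V appended to a cache
k_cache = jnp.zeros((0, d))
v_cache = jnp.zeros((0, d))
for t in range(T):
    xt = x[t:t + 1]                                # only the newest token
    k_cache = jnp.concatenate([k_cache, xt @ wk], axis=0)
    v_cache = jnp.concatenate([v_cache, xt @ wv], axis=0)
    out = attend(xt @ wq, k_cache, v_cache)[-1]    # attends to every cached key

print(jnp.allclose(full, out, atol=1e-4))  # True
```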

### Feed-forward

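This is the SwiGLU feed-forward network used in Llama: one projection is passed through the SiLU activation and used to gate a second projection, and the result is projected back to the model dimension, i.e. FFN(x) = w2(SiLU(x·w1) ⊙ (x·w3)).

In [None]:
def feed_forward(params, x):
    # SwiGLU with Llama naming: silu(x @ w1) gates (x @ w3), then w2 projects back to dim
    return jnp.dot(jax.nn.silu(jnp.dot(x, params['w1'])) * jnp.dot(x, params['w3']), params['w2'])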
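A small standalone sketch of the pieces: it checks that jax.nn.silu is x · sigmoid(x), and runs a gated FFN on a toy (batch, seq, dim) input to show the shapes. swiglu_ffn is a hypothetical name and follows the Llama convention of applying SiLU to the w1 projection; with freshly initialized weights of equal shape, which of w1 or w3 is the gate is only a naming choice.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
# SiLU (also called swish): silu(x) = x * sigmoid(x)
print(jnp.allclose(jax.nn.silu(x), x * jax.nn.sigmoid(x)))  # True

def swiglu_ffn(params, x):
    gate = jax.nn.silu(x @ params['w1'])   # (..., hidden)
    up = x @ params['w3']                  # (..., hidden)
    return (gate * up) @ params['w2']      # back to (..., dim)

dim, hidden = 4, 16
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
params = {'w1': jax.random.normal(k1, (dim, hidden)),
          'w2': jax.random.normal(k2, (hidden, dim)),
          'w3': jax.random.normal(k3, (dim, hidden))}
h = jax.random.normal(kx, (2, 5, dim))     # (batch, seq, dim)
print(swiglu_ffn(params, h).shape)         # (2, 5, 4)
```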

### Transformer-block

This is where all the important components come together. Each block receives its slice of the pre-initialized weights and applies them in a pre-norm residual layout: RMSNorm, attention, and a residual connection, followed by RMSNorm, the feed-forward network, and a second residual connection. When training=True, dropout is applied to each sublayer's output.
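Written out with the names used in transformer_block (and ignoring the optional dropout), each block computes:

$$
\begin{aligned}
h &= x + \mathrm{Attention}\big(\mathrm{RMSNorm}(x)\big) \\
\mathrm{out} &= h + \mathrm{FFN}\big(\mathrm{RMSNorm}(h)\big)
\end{aligned}
$$

Normalizing each sublayer's input rather than its output (pre-norm) leaves the residual path x → h → out as a plain sum, which is generally credited with making deep stacks easier to train.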

In [None]:
def transformer_block(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0, training=False, dropout_rate=0.0, key=None):
    attn_output, new_cache = attention(params['attention'], rms_norm(x, params['attention_norm']), mask, freqs_cis, n_heads, n_kv_heads, cache, position)
    if training:
        dropout_key, key = jax.random.split(key)
        attn_output = jax.random.bernoulli(dropout_key, 1-dropout_rate, shape=attn_output.shape) * attn_output / (1-dropout_rate)
    h = x + attn_output
    ffn_output = feed_forward(params['ffn'], rms_norm(h, params['ffn_norm']))
    if training:
        dropout_key, key = jax.random.split(key)
        ffn_output = jax.random.bernoulli(dropout_key, 1-dropout_rate, shape=ffn_output.shape) * ffn_output / (1-dropout_rate)
    out = h + ffn_output
    return out, new_cache

### Forward-Pass

The forward pass takes your data through the entire model: it looks up an embedding for each input token, passes the embeddings through the stack of transformer blocks, applies the final RMSNorm, and projects to vocabulary-sized logits. Inputs of shape (B, T) produce logits of shape (B, T, vocab_size). Along the way it builds the RoPE table and the causal mask that every block shares.
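The causal mask and the embedding lookup are easy to inspect on toy sizes. This sketch rebuilds the mask exactly as model_forward does (tril of ones, then 0 where allowed and -1e9 where the key lies in the future) and shows that, with equal raw scores, softmax spreads each row's weight only over the current and earlier positions. The 6-token embedding table is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

T = 4
mask = jnp.tril(jnp.ones((T, T)))         # 1 on and below the diagonal
mask = jnp.where(mask == 0, -1e9, 0.0)    # 0 = allowed, -1e9 = future key

scores = jnp.zeros((T, T))                # pretend all raw scores are equal
weights = jax.nn.softmax(scores + mask, axis=-1)
# weights ≈ [[1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0], [1/4, 1/4, 1/4, 1/4]]
print(weights)

# The embedding lookup is plain integer indexing into the (vocab, dim) table
table = jnp.arange(12.0).reshape(6, 2)    # toy vocab of 6 tokens, dim 2
tokens = jnp.array([[5, 0, 2]])           # (B=1, T=3)
print(table[tokens].shape)                # (1, 3, 2)
```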



In [None]:
def model_forward(params, inputs, config, cache=None, position=0):
    B, T = inputs.shape
    h = params['token_embedding'][inputs]
    freqs_cis = precompute_freqs_cis(config.dim // config.n_heads, config.max_seq_len)
    mask = jnp.tril(jnp.ones((config.max_seq_len, config.max_seq_len)))
    mask = jnp.where(mask == 0, -1e9, 0.0)
    mask = mask.astype(h.dtype)
    mask = mask[None, None, :, :]
    new_caches = []
    for i, block in enumerate(params['blocks']):
        layer_cache = cache[i] if cache is not None else None
        h, layer_cache = transformer_block(block, h, mask, freqs_cis, config.n_heads, config.n_kv_heads, layer_cache, position, training=False, dropout_rate=config.dropout_rate)
        new_caches.append(layer_cache)
    h = rms_norm(h, params['norm_f'])
    logits = jnp.dot(h, params['output'])
    return logits, new_caches

## Training

In [2]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Installing collected packages: tiktoken
Successfully installed tiktoken-0.9.0


In [3]:
from jax import random, vmap
import tiktoken
from functools import partial
import os
import jax.lax as lax
import pickle


In [None]:
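# Set these before JAX initializes its backend (ideally before `import jax`);
# in a notebook, run this cell first or restart the kernel after changing them.
os.environ['JAX_PLATFORM_NAME'] = 'tpu' # gpu or tpu
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # stop JAX from preallocating 75% of GPU memory
print("JAX devices:", jax.devices())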

### Tokenization

Tokenization means splitting text into words and subwords (tokens). We will use Byte Pair Encoding (BPE) to train our model, since Llama 3 also uses a BPE tokenizer.[7] Rather than building BPE from scratch, we use OpenAI's tiktoken library with its GPT-2 encoding (50,257 tokens); Llama 3's own tokenizer has a much larger vocabulary of about 128K tokens.

In [None]:
# Initialize tokenizer and load data
enc = tiktoken.get_encoding("gpt2")
with open('shakespeare.txt', 'r') as f:
    text = f.read()
tokens = enc.encode(text)
data = jnp.array(tokens)

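### Model Config

These are the hyperparameters for our model. With the GPT-2 vocabulary (50,257 tokens) they give roughly 31.6 million parameters: about 25.7 million in the token embedding and output matrices and about 5.9 million in the six transformer blocks.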
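As a sanity check on the model size, this hypothetical helper counts parameters directly from the shapes created by init_model_params. It assumes vocab_size = 50257 (enc.n_vocab for the GPT-2 encoding), untied embedding and output matrices, and the 4 * dim FFN hidden size used above.

```python
def count_params(vocab_size, dim, n_layers, n_heads, n_kv_heads, ffn_mult=4):
    # Mirrors the weight shapes from the initialization functions
    head_dim = dim // n_heads
    attn = dim * (n_heads * head_dim)           # wq
    attn += 2 * dim * (n_kv_heads * head_dim)   # wk, wv (smaller under GQA)
    attn += (n_heads * head_dim) * dim          # wo
    ffn = 3 * dim * (ffn_mult * dim)            # w1, w2, w3
    norms = 2 * dim                             # attention_norm, ffn_norm
    per_block = attn + ffn + norms
    embed_and_output = 2 * vocab_size * dim     # token_embedding + output
    total = n_layers * per_block + embed_and_output + dim  # + norm_f
    return total, n_layers * per_block

total, blocks_only = count_params(vocab_size=50257, dim=256, n_layers=6, n_heads=8, n_kv_heads=4)
print(f"{total:,} total, {blocks_only:,} in the transformer blocks")
# 31,633,152 total, 5,901,312 in the transformer blocks
```

Most of the budget sits in the two vocabulary-sized matrices, which is typical for small models trained with a large tokenizer.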

In [None]:
# Model configuration
class ModelConfig:
    vocab_size = enc.n_vocab
    dim = 256
    n_layers = 6
    n_heads = 8
    n_kv_heads = 4
    max_seq_len = 512
    batch_size = 32
    learning_rate = 3e-4
    dropout_rate = 0.0

config = ModelConfig()


In [None]:
# Initialize model
key = random.PRNGKey(0)
params = init_model_params(
    key=key,
    vocab_size=config.vocab_size,
    dim=config.dim,
    n_layers=config.n_layers,
    n_heads=config.n_heads,
    n_kv_heads=config.n_kv_heads
)

### Save and Load the Model

In [None]:
import numpy as np

def save_params(params, filepath):
    # convert every JAX array in the pytree to a NumPy array before pickling
    numpy_params = jax.tree.map(lambda x: np.asarray(x), params)
    with open(filepath, 'wb') as f:
        pickle.dump(numpy_params, f)

def load_params(filepath):
    with open(filepath, 'rb') as f:
        numpy_params = pickle.load(f)
    # convert back to JAX arrays
    params = jax.tree.map(lambda x: jnp.array(x), numpy_params)
    return params

### Get Batches

The get_batch function creates training batches from our Shakespeare dataset. Instead of feeding the whole text at once, we feed the model fixed-length chunks. For each batch, we randomly select starting positions in the text, so the model sees a variety of contexts.

Now, here's where JAX's vmap feature comes into play. Instead of writing a Python loop to extract each chunk, we let vmap vectorize the extraction.

How does it work?

vmap is like a vectorized loop: it takes a function that processes a single start index (using lax.dynamic_slice to cut out seq_len tokens) and applies it to every element of our array of indices at once. This way our input sequences (x) and the corresponding target sequences (y, shifted by one token for next-token prediction) are created in one go.
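Here is the same vmap plus lax.dynamic_slice pattern on a toy token stream, with fixed start indices instead of random ones so the result is easy to read. Each target row is its input row shifted by one token.

```python
import jax.numpy as jnp
from jax import lax, vmap

data = jnp.arange(10)       # stand-in for the token stream: [0, 1, ..., 9]
ix = jnp.array([0, 3, 6])   # fixed start indices for illustration
seq_len = 3

# A function for one start index, vectorized over all indices by vmap
x = vmap(lambda i: lax.dynamic_slice(data, (i,), (seq_len,)))(ix)
y = vmap(lambda i: lax.dynamic_slice(data, (i + 1,), (seq_len,)))(ix)

print(x.tolist())  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(y.tolist())  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
```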

In [None]:
def get_batch(key, data, batch_size, seq_len):
    # Generate random starting indices
    ix = random.randint(key, (batch_size,), 0, len(data) - seq_len)

    # Vectorized operation to get input and target sequences
    x = vmap(lambda i: lax.dynamic_slice(data, (i,), (seq_len,)))(ix)
    y = vmap(lambda i: lax.dynamic_slice(data, (i + 1,), (seq_len,)))(ix)

    return x, y

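### Generate

The generate function samples autoregressively: it runs the model on the (cropped) token sequence, samples the next token from the logits at the last position, appends it, and repeats. For simplicity it recomputes the full sequence at every step instead of using the KV-cache.

In [None]:
def generate(params, prompt_tokens, max_new_tokens, config, key=None):
    key = random.PRNGKey(0) if key is None else key
    x = jnp.array(prompt_tokens)
    for _ in range(max_new_tokens):
        x_crop = x[-config.max_seq_len:]  # keep at most max_seq_len tokens of context
        logits, _ = model_forward(params, x_crop[None, :], config)
        logits = logits[0, -1, :]  # logits for the last position
        key, sample_key = random.split(key)  # fresh key for every sampling step
        next_token = random.categorical(sample_key, logits)
        x = jnp.concatenate([x, jnp.array([next_token])])
    return x.tolist()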

### Loss function

This function computes the cross-entropy loss for a batch during training. It first performs a forward pass using the model to generate logits for the input data. Then, it reshapes both the logits and targets to merge the batch and sequence dimensions. After applying the log softmax to the logits, it extracts the log probabilities corresponding to the correct target tokens and computes their negative mean as the final loss value.



In [None]:
def compute_loss(params, batch):
    inputs, targets = batch
    logits, _ = model_forward(params, inputs, config)
    logits = logits.reshape(-1, config.vocab_size)
    targets = targets.reshape(-1)
    loss = -jnp.mean(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits),
            targets[:, None],
            axis=1
        )
    )
    return loss
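A worked example of the same log_softmax plus take_along_axis computation on two toy predictions over a 3-token vocabulary, checked against the loss computed by hand. The logits and targets are illustrative.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.0],   # 2 predictions over a toy vocab of 3
                    [0.0, 0.0, 0.0]])
targets = jnp.array([0, 2])

log_probs = jax.nn.log_softmax(logits)                             # (N, vocab)
picked = jnp.take_along_axis(log_probs, targets[:, None], axis=1)  # (N, 1)
loss = -jnp.mean(picked)

# By hand: -log(e^2 / (e^2 + e + 1)) ≈ 0.4076 and -log(1/3) ≈ 1.0986
manual = (jnp.log(1 + jnp.exp(-1.0) + jnp.exp(-2.0)) + jnp.log(3.0)) / 2
print(loss, manual)  # both ≈ 0.7531
```

A uniform prediction costs log(vocab_size), so for the real 50,257-token vocabulary an untrained model should start near log(50257) ≈ 10.8.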

### Update function

Now we need a function to update our weights. For simplicity, we use plain gradient descent here (stochastic, since each step uses a random batch), though optimizers such as Adam or AdamW usually converge faster.

In the code, you'll notice the @jax.jit decorator. This is one of the features that sets JAX apart: JIT (Just-In-Time) compilation speeds up execution by compiling your Python function into optimized machine code with XLA.

How does it work?

When you decorate a function with JAX's jit, it changes how the function executes. Normally, when you call a function, Python runs it line by line every time. For example, if you have:

In [1]:
def sqr(x):
    print("HI jitted") # side effect
    return x * x

print(sqr(2))
print(sqr(3))
print(sqr(4))

HI jitted
4
HI jitted
9
HI jitted
16


In [3]:
@jax.jit
def sqr(x): 
    print("HI jitted") # side effect
    return x * x

print(sqr(2))
print(sqr(3))
print(sqr(4))

HI jitted
4
9
16


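JAX first traces your function: it runs the Python code once with abstract placeholder values to record the sequence of array operations, and XLA then compiles that recording into optimized machine code. This happens on the first call, and the compiled version is cached for that combination of input shapes and dtypes.

Because of this tracing, Python side effects such as the print statement run only while the function is being traced. Later calls with the same input shapes and dtypes reuse the compiled version, so the print does not appear again; it reappears only if the function is retraced, for example when it is called with a new shape or dtype. To print values on every call, use jax.debug.print.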
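This sketch makes the caching visible. A plain Python list records every trace (a side effect that only runs during tracing), while jax.debug.print is staged into the compiled program and fires on every call. Two calls with the same shape share one compilation; a new shape triggers a second trace. The inputs use an explicit float32 dtype so that only the shape changes between calls.

```python
import jax
import jax.numpy as jnp

traces = []  # one entry per trace of the function

@jax.jit
def sqr(x):
    traces.append(x.shape)                           # Python side effect: tracing only
    jax.debug.print("sqr called, x[0] = {}", x[0])   # runs on every call
    return x * x

x3 = jnp.ones(3, dtype=jnp.float32)
sqr(x3)
sqr(x3 * 2)                            # same shape and dtype: cached version reused
sqr(jnp.ones(4, dtype=jnp.float32))    # new shape: traced and compiled again
print(traces)                          # [(3,), (4,)]
```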

In [None]:
@jax.jit
def update_step(params, batch):
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)
    params = jax.tree.map(
        lambda p, g: p - config.learning_rate * g,
        params,
        grads
    )
    return params, loss

In our update_step function, @jax.jit compiles the whole update. The function computes the loss and its gradients in one pass with jax.value_and_grad, applies the gradient-descent update to every parameter in the pytree with jax.tree.map, and returns the updated parameters and the loss.
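The same value_and_grad plus tree.map pattern on a toy nested pytree, with a sum-of-squares loss chosen so the results can be checked by hand: the gradient of p² is 2p, so one step with lr = 0.1 turns every leaf into 0.8p. The loss_fn and the parameter values are illustrative.

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.array([3.0, -1.0]), 'blocks': [{'b': jnp.array(2.0)}]}  # nested pytree

def loss_fn(p):
    # Toy loss: sum of squares over every leaf
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(p))

lr = 0.1
loss, grads = jax.value_and_grad(loss_fn)(params)  # grads mirrors the structure of params
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

print(loss)                           # 14.0 (9 + 1 + 4)
print(new_params['w'])                # ≈ [2.4, -0.8]
print(new_params['blocks'][0]['b'])   # ≈ 1.6
```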

### Training-Loop

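Finally, it's time to train our roughly 31.6 million parameter model on the Shakespeare dataset. At each step, get_batch samples a fresh batch of random windows from the token stream and update_step applies one gradient update; working on small random batches instead of the full text keeps each step cheap enough for limited compute. Every five epochs we save a checkpoint.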

In [None]:
def train(num_epochs=30, steps_per_epoch=1000):
    key = random.PRNGKey(0)
    params_state = params  # no copy needed: JAX arrays are immutable

    epoch_losses = []

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 50)

        epoch_loss = 0.0
        for step in range(steps_per_epoch):

            key, batch_key = random.split(key)

            # Get batch
            batch = get_batch(batch_key, data, config.batch_size, config.max_seq_len)

            # Update model
            params_state, loss = update_step(params_state, batch)
            epoch_loss += loss


            if step % 100 == 0:
                print(f"epoch {epoch + 1}, step {step}/{steps_per_epoch}: loss = {loss:.4f}")


        avg_epoch_loss = epoch_loss / steps_per_epoch
        epoch_losses.append(avg_epoch_loss)

        print(f"\nepoch {epoch + 1} | average loss: {avg_epoch_loss:.4f}")


        if (epoch + 1) % 5 == 0:
            save_params(params_state, f'model_checkpoint_epoch_{epoch+1}.pkl')


    print("Loss by epoch:")
    for epoch, loss in enumerate(epoch_losses, 1):
        print(f"Epoch {epoch}: {loss:.4f}")

    # Save final model
    save_params(params_state, 'model_final.pkl')
    return params_state

In [None]:
# Train the model
trained_params = train()
