<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/master/llama3_2fromscratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Optional, Tuple, Union, Callable, Any, Dict

# Hyperparameters for the Llama 3.2 1B model variant.
LLAMA32_CONFIG_1B = {
    # --- vocabulary / sequence geometry ---
    "vocab_size": 128_256,            # rows in the token-embedding table
    "context_length": 128 * 1024,     # 131,072 positions covered by the RoPE tables
    # --- network widths and depth ---
    "embedding_dim": 2048,            # width of the residual stream
    "hidden_dim": 8192,               # intermediate width of the feed-forward layer
    "num_layers": 16,                 # number of stacked transformer blocks
    # --- attention layout (grouped-query attention) ---
    "num_key_value_heads": 8,         # key/value heads shared across query groups
    "num_heads": 16,                  # query heads
    # --- rotary position embedding ---
    "rope_base": 5e5,                 # the "theta" base used for RoPE frequencies
    "dtype": torch.bfloat16,          # parameter dtype for every layer
    "rope_freq": {
        "factor": 32.0,               # divisor applied to low-frequency components
        "low_freq_factor": 1.0,       # lower edge of the smoothing band
        "high_freq_factor": 4.0,      # upper edge of the smoothing band
        "original_context_length": 8 * 1024,  # 8,192: reference length for the band edges
    },
}
# Hyperparameters for the Llama 3.2 3B model variant.
LLAMA32_CONFIG_3B = {
    # --- vocabulary / sequence geometry ---
    "vocab_size": 128_256,            # rows in the token-embedding table
    "context_length": 128 * 1024,     # 131,072 positions covered by the RoPE tables
    # --- network widths and depth ---
    "embedding_dim": 3072,            # width of the residual stream
    "hidden_dim": 8192,               # intermediate width of the feed-forward layer
    "num_layers": 28,                 # number of stacked transformer blocks
    # --- attention layout (grouped-query attention) ---
    "num_key_value_heads": 8,         # key/value heads shared across query groups
    "num_heads": 24,                  # query heads
    # --- rotary position embedding ---
    "rope_base": 5e5,                 # the "theta" base used for RoPE frequencies
    "dtype": torch.bfloat16,          # parameter dtype for every layer
    "rope_freq": {
        "factor": 32.0,               # divisor applied to low-frequency components
        "low_freq_factor": 1.0,       # lower edge of the smoothing band
        "high_freq_factor": 4.0,      # upper edge of the smoothing band
        "original_context_length": 8 * 1024,  # 8,192: reference length for the band edges
    },
}
class Llama3Model(nn.Module):
    """Decoder-only Llama 3.2 style language model.

    Pipeline: token embedding -> ``num_layers`` x ``TransformerBlock`` ->
    final RMSNorm -> linear projection to vocabulary logits.

    Args:
        config (dict): Hyperparameters. Keys read here: ``vocab_size``,
            ``embedding_dim``, ``num_layers``, ``num_heads``, ``rope_base``,
            ``context_length``, ``rope_freq`` and ``dtype``. The same dict is
            forwarded unchanged to every ``TransformerBlock``.
    """

    def __init__(self, config):
        super().__init__()
        # Token embedding layer
        self.token_embedding = nn.Embedding(
            config["vocab_size"], config["embedding_dim"], dtype=config["dtype"]
        )
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config["num_layers"])]
        )

        # Final layer normalization
        self.final_layer_norm = nn.RMSNorm(
            config["embedding_dim"], eps=1e-6, dtype=config["dtype"]
        )

        # Output projection layer
        self.output_projection = nn.Linear(
            config["embedding_dim"], config["vocab_size"], bias=False, dtype=config["dtype"]
        )

        # RoPE cos/sin tables are computed once for every position up to
        # ``context_length`` and shared by all blocks; registering them as
        # buffers makes them follow the module across ``.to(device)`` calls.
        cos, sin = compute_rope_parameters(
            head_dim=config["embedding_dim"] // config["num_heads"],
            theta_base=config["rope_base"],
            max_seq_len=config["context_length"],
            freq_config=config["rope_freq"],
        )
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
        self.config = config

    def forward(self, in_idx):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            in_idx (torch.Tensor): Token IDs of shape (batch_size, sequence_length).

        Returns:
            torch.Tensor: Logits of shape (batch_size, sequence_length, vocab_size).

        Raises:
            ValueError: If ``sequence_length`` exceeds the number of positions
                in the precomputed RoPE tables.
        """
        x = self.token_embedding(in_idx)  # (batch_size, seq_length, embedding_dim)
        num_tokens = x.shape[1]
        if num_tokens > self.cos.shape[0]:
            # Fail early with a readable message instead of a broadcasting
            # error deep inside the attention layer.
            raise ValueError(
                f"sequence length {num_tokens} exceeds the RoPE table length "
                f"{self.cos.shape[0]}"
            )

        # Causal mask: True strictly above the diagonal marks future positions
        # that each query must not attend to.
        mask = torch.triu(
            torch.ones((num_tokens, num_tokens), device=x.device), diagonal=1
        ).bool()
        for block in self.trf_blocks:
            x = block(x, mask=mask, cos=self.cos, sin=self.sin)

        # The norm layer is registered as ``final_layer_norm`` in __init__;
        # the previous ``self.final_norm`` reference raised AttributeError.
        x = self.final_layer_norm(x)
        logits = self.output_projection(x.to(self.config["dtype"]))  # (batch_size, seq_length, vocab_size)
        return logits

class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward, each with a residual.

    Args:
        config (dict): Reads ``embedding_dim``, ``num_heads``,
            ``num_key_value_heads`` and ``dtype``; the whole dict is also
            passed to ``FeedForward``.
    """

    def __init__(self, config):
        super().__init__()
        emb_dim = config["embedding_dim"]
        param_dtype = config["dtype"]

        self.attention = GroupedQueryAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=config["num_heads"],
            num_key_value_heads=config["num_key_value_heads"],
            dtype=param_dtype,
        )
        self.feed_forward = FeedForward(config)
        self.layer_norm1 = nn.RMSNorm(emb_dim, eps=1e-6, dtype=param_dtype)
        self.layer_norm2 = nn.RMSNorm(emb_dim, eps=1e-6, dtype=param_dtype)

    def forward(self, x, mask, cos, sin):
        """Run the attention and feed-forward sub-layers with skip connections.

        ``mask``, ``cos`` and ``sin`` are passed straight through to the
        attention sub-layer. Returns a tensor with the same shape as ``x``.
        """
        # Attention sub-layer: normalize, attend, then add the untouched input
        # back so gradients have a direct path around the sub-layer.
        attn_out = self.attention(self.layer_norm1(x), mask=mask, cos=cos, sin=sin)
        x = attn_out + x

        # Feed-forward sub-layer, same normalize -> transform -> add pattern.
        ff_out = self.feed_forward(self.layer_norm2(x))
        return ff_out + x

class FeedForward(nn.Module):
    """SiLU-gated (SwiGLU-style) feed-forward layer without biases.

    Computes ``fc3(silu(fc1(x)) * fc2(x))``:
    ``fc1`` (gate) and ``fc2`` (up) both map ``embedding_dim -> hidden_dim``;
    ``fc3`` (down) maps ``hidden_dim -> embedding_dim``.

    Args:
        config (dict): Reads ``embedding_dim``, ``hidden_dim`` and ``dtype``.
    """

    def __init__(self, config):
        super().__init__()
        emb_dim = config["embedding_dim"]
        hidden_dim = config["hidden_dim"]
        dtype = config["dtype"]
        self.fc1 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        # fc2 is applied to the same input as fc1, so it must also map
        # embedding_dim -> hidden_dim (it was previously hidden -> embedding,
        # which fails whenever hidden_dim != embedding_dim).
        self.fc2 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        # fc3 consumes the gated hidden activation and projects it back.
        self.fc3 = nn.Linear(hidden_dim, emb_dim, dtype=dtype, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transform.

        Args:
            x (torch.Tensor): Input whose last dimension is ``embedding_dim``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

class GroupedQueryAttention(nn.Module):
    """Causal multi-head attention with grouped key/value heads (GQA) and RoPE.

    There are ``num_heads`` query heads but only ``num_key_value_heads``
    key/value heads; each key/value head is shared by
    ``num_heads // num_key_value_heads`` consecutive query heads.

    Args:
        d_in (int): Size of the last dimension of the input.
        d_out (int): Total output width; ``head_dim = d_out // num_heads``.
        num_heads (int): Number of query heads.
        num_key_value_heads (int): Number of key/value heads.
        dtype (torch.dtype, optional): Parameter dtype for all projections.
    """

    def __init__(self, d_in, d_out, num_heads, num_key_value_heads, dtype=None):
        super().__init__()
        assert d_in % num_heads == 0, "d_in must be divisible by num_heads"
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_key_value_heads == 0, "num_heads must be divisible by num_key_value_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Key/value projections are narrower than the query projection because
        # there are fewer key/value heads than query heads.
        self.W_key = nn.Linear(d_in, num_key_value_heads * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_key_value_heads * self.head_dim, bias=False, dtype=dtype)
        self.num_key_value_heads = num_key_value_heads
        self.group_size = num_heads // num_key_value_heads  # query heads per key/value head

        # Query projection layer
        self.W_Q = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        # Output projection layer
        self.o_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask=None, cos=None, sin=None):
        """Attend over ``x`` with rotary position encoding and a causal mask.

        Args:
            x (torch.Tensor): Input of shape (batch_size, num_tokens, d_in).
            mask (torch.Tensor, optional): 2-D boolean mask, True where
                attention is forbidden; its top-left
                (num_tokens, num_tokens) corner is used. When omitted, a
                causal (strictly upper-triangular) mask is built.
            cos (torch.Tensor): RoPE cosine table, indexed by position.
            sin (torch.Tensor): RoPE sine table, indexed by position.

        Returns:
            torch.Tensor: Output of shape (batch_size, num_tokens, d_out).

        Raises:
            ValueError: If ``cos`` or ``sin`` is not provided.
        """
        if cos is None or sin is None:
            # The defaults exist only for keyword-call convenience; RoPE is
            # always applied, so the tables are required.
            raise ValueError("cos and sin RoPE tables must be provided")

        b, num_tokens, _ = x.shape

        queries = self.W_Q(x)     # (b, num_tokens, d_out)
        keys = self.W_key(x)      # (b, num_tokens, num_key_value_heads * head_dim)
        values = self.W_value(x)  # (b, num_tokens, num_key_value_heads * head_dim)

        # Split the feature dimension into heads and move heads before tokens.
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary position encoding is applied to queries and keys only.
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Repeat each key/value head ``group_size`` times so that query head i
        # is paired with key/value head i // group_size.
        # Resulting shape: (b, num_heads, num_tokens, head_dim).
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # (b, num_heads, num_tokens, head_dim) @ (b, num_heads, head_dim, num_tokens)
        # -> (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        if mask is None:
            # Default to causal attention: block every position after the query.
            mask = torch.triu(
                torch.ones((num_tokens, num_tokens), device=x.device), diagonal=1
            ).bool()
        else:
            mask = mask[:num_tokens, :num_tokens]

        # Masked positions get -inf so they receive zero weight after softmax.
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = F.softmax(attn_scores / math.sqrt(self.head_dim), dim=-1)

        # (b, num_heads, num_tokens, num_tokens) @ (b, num_heads, num_tokens, head_dim)
        # -> (b, num_heads, num_tokens, head_dim), then heads back next to features.
        context_vectors = (attn_weights @ values).transpose(1, 2)
        # Merge heads: d_out = num_heads * head_dim.
        context_vectors = context_vectors.reshape(b, num_tokens, self.d_out)
        return self.o_proj(context_vectors)
def compute_rope_parameters(head_dim, theta_base=10_000, max_seq_len=4096, freq_config= None, dtype = torch.float32):
    """Precompute the cosine and sine tables used by RoPE.

    Args:
        head_dim (int): Dimension of each attention head; must be even.
        theta_base (float): Base value for computing the frequencies.
        max_seq_len (int): Number of positions to precompute.
        freq_config (dict, optional): Frequency-rescaling settings with keys
            ``factor``, ``low_freq_factor``, ``high_freq_factor`` and
            ``original_context_length``. When given, components whose
            wavelength exceeds ``original_context_length / low_freq_factor``
            are divided by ``factor``, components whose wavelength is below
            ``original_context_length / high_freq_factor`` are left unchanged,
            and components in between are linearly interpolated between the
            two.
        dtype (torch.dtype, optional): Data type used for the position and
            index ranges.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(cos, sin)``, each of shape
        (max_seq_len, head_dim).
    """
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    # One inverse frequency per pair of dimensions: theta_base ** (-2i / head_dim).
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[:(head_dim // 2)].float() / head_dim))

    if freq_config is not None:
        original_len = freq_config["original_context_length"]
        low_factor = freq_config["low_freq_factor"]
        high_factor = freq_config["high_freq_factor"]
        scale = freq_config["factor"]

        # Band edges expressed as wavelengths (in positions).
        low_freq_wavelen = original_len / low_factor
        high_freq_wavelen = original_len / high_factor
        wavelen = 2 * torch.pi / inv_freq

        # Long-wavelength (low-frequency) components are slowed down by `scale`.
        scaled_inv_freq = torch.where(wavelen > low_freq_wavelen, inv_freq / scale, inv_freq)

        # Inside the band, blend between the scaled and unscaled frequency.
        smooth_factor = (original_len / wavelen - low_factor) / (high_factor - low_factor)
        smoothed_inv_freq = (1 - smooth_factor) * (inv_freq / scale) + smooth_factor * inv_freq

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        # The blended values must replace the medium band; previously they
        # were computed into an unused variable and silently dropped.
        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, scaled_inv_freq)

    # Position indices 0 .. max_seq_len - 1.
    positions = torch.arange(max_seq_len, dtype=dtype)
    # Outer product: angle for every (position, frequency) pair.
    angles = positions[:, None] * inv_freq[None, :]  # (max_seq_len, head_dim // 2)
    # Duplicate so both halves of a head share the same angles.
    angles = torch.cat([angles, angles], dim=-1)  # (max_seq_len, head_dim)

    cos = torch.cos(angles)  # (max_seq_len, head_dim)
    sin = torch.sin(angles)  # (max_seq_len, head_dim)
    return cos, sin

def apply_rope(x, cos, sin):
    """Rotate a 4-D tensor by position using precomputed RoPE tables.

    ``x`` is unpacked as (batch, heads, seq_len, head_dim). The first
    ``seq_len`` rows of ``cos`` and ``sin`` are broadcast over the two leading
    dimensions. The result is cast back to the dtype of ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim%2 ==0 , "head_dim must be even for RoPE"

    half = head_dim // 2
    front = x[..., :half]
    back = x[..., half:]

    # Trim the tables to the sequence and add two leading singleton axes so
    # they broadcast over batch and heads: (1, 1, seq_len, head_dim).
    cos_table = cos[:seq_len, :][None, None]
    sin_table = sin[:seq_len, :][None, None]

    # The (-back, front) partner supplies the cross term of the 2-D rotation.
    partner = torch.cat((-back, front), dim=-1)
    rotated = (x * cos_table) + (partner * sin_table)

    # The tables may be higher precision than x; return in x's own dtype.
    return rotated.to(dtype=x.dtype)

def generate(model , idx,max_new_tokens, context_size , temperature=0.0, top_k=None, eos_id = None):
    """Autoregressively extend ``idx`` with tokens predicted by ``model``.

    Args:
        model (Callable): Maps token IDs of shape (batch_size, seq_len) to
            logits of shape (batch_size, seq_len, vocab_size).
        idx (torch.Tensor): Prompt token IDs of shape (batch_size, sequence_length).
        max_new_tokens (int): Maximum number of new tokens to generate.
        context_size (int): Only the last ``context_size`` tokens are fed to
            the model at each step.
        temperature (float): If > 0, logits are divided by this value and the
            next token is sampled; otherwise the arg-max token is taken.
        top_k (int, optional): If given, only the ``top_k`` highest logits per
            row are kept before sampling / arg-max (clamped to vocab size).
        eos_id (int, optional): If given, generation stops (without appending)
            as soon as every row in the batch predicts this token.

    Returns:
        torch.Tensor: Token IDs of shape (batch_size, sequence_length + n),
        where ``n <= max_new_tokens`` (smaller when stopped early by ``eos_id``).
    """
    for _ in range(max_new_tokens):
        # Negative slicing keeps the whole sequence when it is shorter than
        # the window, so no explicit length check is needed.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Only the prediction for the last position matters.
        logits = logits[:, -1, :]  # (batch_size, vocab_size)

        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_logits, _ = torch.topk(logits, k)
            # Keep the trailing dimension so each row is compared against its
            # own threshold: (batch_size, 1) broadcasts over the vocabulary.
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        if temperature > 0.0:
            probs = F.softmax(logits / temperature, dim=-1)  # (batch_size, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            # Greedy decoding.
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Reduce with .all() so the check is well-defined for batch_size > 1,
        # and skip it entirely when no eos_id was supplied.
        if eos_id is not None and bool((idx_next == eos_id).all()):
            break

        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, sequence_length + 1)
    return idx
