In [6]:
from importlib.metadata import version

# Third-party packages this notebook relies on. The installed version of
# each one is printed so a reader can compare their environment.
pkgs = [
    "blobfile",         # downloads the pretrained weights
    "huggingface_hub",  # downloads the pretrained weights
    "tiktoken",         # tokenizer implementation
    "torch",
]

for p in pkgs:
    print(f"{p} version: {version(p)}")

blobfile version: 3.0.0
huggingface_hub version: 0.31.2
tiktoken version: 0.9.0
torch version: 2.7.0


### 1. Convert the Llama model implementation step by step

#### 1.1 Reusing Llama 2 components

In [11]:
import os
import sys
import io
import nbformat
import types

def import_from_notebook():
    """Load selected definitions from the ``gpt-to-llama2`` notebook as a module.

    The notebook is looked up in the current working directory. Every code
    cell whose source contains a ``def <name>`` or ``class <name>`` for one of
    the wanted names is executed -- exactly once -- inside a fresh module
    object. That module is also registered in ``sys.modules`` under the
    notebook's name.

    Note that a selected cell is executed as a whole, so any other top-level
    statements in that cell run as well.

    Returns:
        types.ModuleType: module whose namespace holds everything the
        selected cells defined.

    Raises:
        FileNotFoundError: if ``gpt-to-llama2.ipynb`` is not present in the
            current working directory.
    """
    def import_definitions_from_notebook(fullname, names):
        """Execute the cells of ``<fullname>.ipynb`` that define any of ``names``."""
        path = os.path.normpath(os.path.join(os.getcwd(), fullname + '.ipynb'))

        # Load the notebook
        if not os.path.exists(path):
            raise FileNotFoundError(f"Notebook file not found at: {path}")
        with open(path, 'r', encoding='utf-8') as f:
            nb = nbformat.read(f, as_version=4)

        # Create a module to store the imported functions and classes
        mod = types.ModuleType(fullname)
        sys.modules[fullname] = mod

        # Substrings that identify a function or class definition of interest
        markers = [f"{keyword} {name}" for name in names for keyword in ("def", "class")]

        # Go through the notebook cells and only execute the ones that
        # contain a wanted function or class definition
        for cell in nb.cells:
            if cell.cell_type != 'code':
                continue
            cell_code = cell.source
            # any() guarantees a cell defining several wanted names is
            # executed once rather than once per matching name.
            if any(marker in cell_code for marker in markers):
                exec(cell_code, mod.__dict__)
        return mod

    fullname = 'gpt-to-llama2'
    # RMSNorm is listed because the notebook reads it from the returned
    # module right after calling this function.
    names = [
        'precompute_rope_params',
        'compute_rope',
        'SiLU',
        'FeedForward',
        'RMSNorm',
        'MultiHeadAttention',
    ]
    return import_definitions_from_notebook(fullname, names)


In [13]:
imported_module = import_from_notebook()

# precompute_rope_params is intentionally not taken from the imported
# module: it is redefined later in this notebook.
# precompute_rope_params = getattr(imported_module, 'precompute_rope_params', None)
compute_rope = getattr(imported_module, 'compute_rope', None)
# Look SiLU up on the imported module. Looking it up on the
# import_from_notebook function object always returned None.
SiLU = getattr(imported_module, 'SiLU', None)
FeedForward = getattr(imported_module, 'FeedForward', None)
RMSNorm = getattr(imported_module, 'RMSNorm', None)

# MultiHeadAttention is imported only for comparison purposes
MultiHeadAttention = getattr(imported_module, 'MultiHeadAttention', None)

FileNotFoundError: Notebook file not found at: c:\Users\hp\OneDrive\Desktop\llm-from-scratch\llma2\gpt-to-llama2.ipynb

#### 1.2 Modified RoPE

In [29]:
import torch
def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: tensor of shape (batch_size, num_heads, seq_len, head_dim);
            ``head_dim`` must be even.
        cos: 2-D table of cosines; only the first ``seq_len`` rows are used.
        sin: 2-D table of sines; only the first ``seq_len`` rows are used.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    half = head_dim // 2
    first_half, second_half = x[..., :half], x[..., half:]

    # Trim the tables to the sequence length and add two leading singleton
    # axes so they broadcast over batch and heads: (1, 1, seq_len, head_dim)
    cos_table = cos[None, None, :seq_len, :]
    sin_table = sin[None, None, :seq_len, :]

    # Rotation partner of x: (-second half, first half)
    swapped = torch.cat((-second_half, first_half), dim=-1)

    # Cast back because the tables may carry a different dtype than x
    return (x * cos_table + swapped * sin_table).to(dtype=x.dtype)


In [33]:
import torch

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):
    """Precompute the cosine and sine tables used by rotary position embeddings.

    Args:
        head_dim: size of one attention head; must be even.
        theta_base: base of the inverse-frequency schedule.
        context_length: number of positions to generate angles for.
        freq_config: optional dict enabling frequency rescaling. When given
            it must contain the keys ``'original_context_length'``,
            ``'low_freq_factor'``, ``'high_freq_factor'`` and ``'factor'``.
            Frequencies whose wavelength exceeds
            ``original_context_length / low_freq_factor`` are divided by
            ``factor``; those with a wavelength between the low- and
            high-frequency thresholds are blended between the scaled and the
            unscaled value; the rest are left unchanged.

    Returns:
        tuple: ``(cos, sin)``, each of shape ``(context_length, head_dim)``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
        KeyError: if ``freq_config`` is given but lacks a required key.
    """
    assert head_dim % 2 == 0, "Embeddings dimension must be even"

    # Compute the inverse frequencies: one per pair of dimensions
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Frequency adjustment
    if freq_config is not None:
        original_context_length = freq_config['original_context_length']
        low_freq_factor = freq_config['low_freq_factor']
        high_freq_factor = freq_config['high_freq_factor']
        factor = freq_config['factor']

        low_freq_wavelen = original_context_length / low_freq_factor
        high_freq_wavelen = original_context_length / high_freq_factor

        wavelen = 2 * torch.pi / inv_freq

        # Low frequencies (long wavelengths) are scaled down by `factor`
        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
        )

        # The key used here was misspelled before ('origianl_context_length'),
        # which raised KeyError whenever freq_config was supplied.
        smooth_factor = (original_context_length / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / factor) + smooth_factor * inv_freq
        )

        # Medium frequencies get the interpolated value instead
        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

In [34]:
# Instantiate RoPE parameters

# Context lengths (number of positions) used for the two model generations
llama_2_context_len, llama_3_context_len = 4096, 8192

# Values passed as `theta_base` when precomputing the RoPE tables
llama_2_theta_base, llama_3_theta_base = 10_000, 500_000



In [35]:
# Settings
batch_size, num_heads, head_dim = 2, 4, 16

# Instantiate RoPE parameters
cos, sin = precompute_rope_params(
    head_dim, llama_3_theta_base, llama_3_context_len
)

# Dummy query and key tensors (seeded so the cell is reproducible;
# queries are drawn before keys)
torch.manual_seed(123)
qk_shape = (batch_size, num_heads, llama_3_context_len, head_dim)
queries = torch.randn(qk_shape)
keys = torch.randn(qk_shape)

# Apply rotary position embedding
queries_rot, keys_rot = (compute_rope(t, cos, sin) for t in (queries, keys))

#### 1.3 Grouped-query attention

In [36]:
import torch.nn as nn 


class SharedBuffers:
    """Class-level cache of the attention mask and RoPE cos/sin tables.

    Entries are stored in the class attribute ``_buffers`` so that repeated
    requests with the same settings return the very same tensors instead of
    rebuilding them.
    """
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype = torch.float32):
        """Return the cached ``(mask, cos, sin)`` triple, building it on first use.

        Args:
            context_length: side length of the square mask and number of
                rows in the RoPE tables.
            head_dim: head dimension forwarded to ``precompute_rope_params``.
            rope_base: theta base forwarded to ``precompute_rope_params``.
            freq_config: optional frequency-rescaling dict forwarded to
                ``precompute_rope_params``.
            dtype: if not ``None``, ``cos`` and ``sin`` are cast to it; the
                mask is never cast.

        Returns:
            tuple: ``(mask, cos, sin)`` where ``mask`` holds ones strictly
            above the main diagonal and zeros elsewhere.
        """
        # A dict cannot be part of a hashable key, so a non-empty config is
        # represented by the tuple of its values.
        config_key = tuple(freq_config.values()) if freq_config else freq_config
        key = (context_length, head_dim, rope_base, config_key, dtype)

        cache = SharedBuffers._buffers
        if key in cache:
            return cache[key]

        # Cache miss: build the mask and RoPE tables and remember them
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)
        if dtype is not None:
            cos, sin = cos.to(dtype), sin.to(dtype)

        cache[key] = (mask, cos, sin)
        return cache[key]
    


class GroupedQueryAttention(nn.Module):
    """Causal self-attention in which several query heads share one key/value head.

    There are ``num_heads`` query heads but only ``num_kv_groups`` key/value
    heads; each key/value head is repeated ``num_heads // num_kv_groups``
    times so it lines up with its group of query heads. Rotary position
    embeddings are applied to queries and keys.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the output; must be divisible
            by ``num_heads``.
        context_length: size of the causal mask and of the RoPE tables.
        num_heads: number of query heads.
        num_kv_groups: number of key/value heads; must divide ``num_heads``.
        rope_base: theta base forwarded to ``SharedBuffers.get_buffers``.
        rope_config: optional frequency-rescaling config forwarded to
            ``SharedBuffers.get_buffers``.
        dtype: dtype for the linear layers, also forwarded to
            ``SharedBuffers.get_buffers``.
    """
    def __init__(
            self,
            d_in,
            d_out,
            context_length,
            num_heads,
            num_kv_groups,
            rope_base=10_000,
            rope_config=None,
            dtype = None):
        super().__init__()
        assert d_out % num_heads ==0, 'd_out must be divisible by num_heads'
        assert num_heads % num_kv_groups ==0, 'num_heads must be divisiable by num_kv_gropus'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Keys and values are projected to only num_kv_groups heads
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)

        self.register_buffer('mask', mask)
        self.register_buffer('cos', cos)
        self.register_buffer('sin', sin)

    def forward(self, x):
        """Compute grouped-query attention.

        Args:
            x: tensor of shape (b, num_tokens, d_in).

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # (b, num_tokens, d_out)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Split the projections into heads. Queries have num_heads heads;
        # reshaping them with num_kv_groups fails whenever the two differ.
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Move the head axis in front of the token axis
        queries = queries.transpose(1, 2)  # (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)        # (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)    # (b, num_kv_groups, num_tokens, head_dim)

        # Rotary embeddings go on keys and queries. The queries must be
        # rotated themselves, not replaced by rotated values.
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)

        # Repeat each key/value head so every query head has a partner
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        assert keys.shape[-1] == self.head_dim

        attn_scores = queries @ keys.transpose(2, 3)

        # Causal mask: a token may not attend to later positions
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        # Merge the heads back: (b, num_tokens, d_out)
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec