# Modify GPT2-124M to Llama 2-7B, 3-8B
- all modules and utils are kept in this notebook for educational purposes
- Llama 2:
    - vocab sizes: 50257 -> 32000
    - input embedding dim 768 -> 4096
    - positional encoding: absolute positional encoding -> RoPE
    - context length 1024 -> 4096
    - remove dropout before & after multihead attention, and after final feedforward layer
    - multihead causal attention w/ 12 attention heads -> masked multihead attention w/ 32 heads
    - layer norm -> RMSNorm (Root Mean Square Layer Normalization)
    - final feedforward layer: GELU -> Swish + SwiGLU+Linear as gate, hidden layer dim 11008
- Llama 3:
    - vocab size 32000 -> 128256
    - input embedding dim 4096
    - context length 4096 -> 8192
    - masked multihead attention w/ 32 heads -> masked grouped-query attention w/ 32 heads
    - final feedforward layer: Swish + SwiGLU+Linear as gate, hidden layer dim 11008 -> 14336
- NOTE: 
    - GPT applies the positional embeddings to the inputs
    - Llama applies rotations to the query and key vectors in the self-attention mechanism itself

In [1]:
import torch

In [14]:
import numpy as np
import os
import sys 
import math
from typing import Tuple, Dict, List

# Directory the notebook process is running in.
cwd = os.getcwd()
# Base path = cwd with its last 13 characters removed; later cells prepend it to
# data and key-file locations.
# NOTE(review): 13 presumably equals the length of this notebook's folder name —
# confirm before running from a different directory.
base_path = cwd[:-13]

In [None]:
# Show the installed PyTorch version, then choose the compute device:
# CUDA when a GPU is visible to PyTorch, CPU otherwise.
print(torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch.nn.functional as F

## preprocess

In [21]:
from helper import load_api_key
from huggingface_hub import login
from huggingface_hub import hf_hub_download # get model weights
import sentencepiece as spm # tokenizer 

# Obtain the Hugging Face access token through the project helper.
# NOTE(review): presumably load_api_key reads the token from the named file under
# base_path — confirm against helper.py.
HF_API_KEY = load_api_key(base_path, 'hf_llama2_at.txt')
# Make the token available to this process through the environment.
os.environ["HF_API_KEY"] = HF_API_KEY
# Authenticate this session with the Hugging Face Hub (needed for the gated Llama repo).
login(token=HF_API_KEY)

In [22]:

# Fetch the Llama 2 SentencePiece model file from the Hub into ./Llama-2-7B;
# hf_hub_download returns the local path of the downloaded file.
tokenizer_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="tokenizer.model",
    local_dir="Llama-2-7B"
)

In [23]:
class LlamaTokenizer:
    """Thin wrapper around a SentencePiece model exposing encode / decode.

    Args:
        filepath: Path to the SentencePiece ``tokenizer.model`` file to load.
    """

    def __init__(self, filepath):
        sp = spm.SentencePieceProcessor()
        # Load the model passed by the caller (not a module-level global), so the
        # class works for any tokenizer file.
        sp.load(filepath)
        self.tokenizer = sp

    def encode(self, text):
        """Return the list of token ids for ``text``."""
        return self.tokenizer.encode_as_ids(text)

    def decode(self, ids):
        """Return the text for a sequence of token ids."""
        # ``ids`` are integers, so use the id-based decoder; ``decode_pieces``
        # is meant for string pieces.
        return self.tokenizer.decode_ids(ids)


# Module-level tokenizer built from the downloaded model file; it is also the
# default `tokenizer` argument of my_text_dataloader.
tokenizer = LlamaTokenizer(tokenizer_file)

In [24]:
# Read the raw training corpus (anna.txt) into a single string.
pdata = f"{base_path}traditional-NLP/data/"
# NOTE(review): the data folder is also added to sys.path; nothing in this cell
# imports from it — confirm whether a later cell relies on this.
sys.path.append(pdata)
with open(f"{pdata}anna.txt" , 'r', encoding='utf-8') as f:
    text_data = f.read()
# Quick sanity check: type and first 50 characters of the text.
print(f"The type of the raw text: {type(text_data)}")
print(f"The beginning of raw text: \n {text_data[:50]}")

The type of the raw text: <class 'str'>
The beginning of raw text: 
 Chapter 1


Happy families are all alike; every un


In [25]:
# Inspect the corpus size in characters and in tokens.
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")
# Tokenize the full text once just to count tokens; total_tokens is reused by
# the split-size assertions further down.
total_tokens = len(tokenizer.encode(text_data))
print(f"total num of tokens in Anna Karenina with Sentencepiece tokenizer: {total_tokens}")

total num of characters in Anna Karenina: 1985223
total num of tokens in Anna Karenina with Sentencepiece tokenizer: 550383


### set parameters

In [26]:
# Hyperparameters read by the dataloaders and by every module of the model below.
LLAMA2_7B_CONFIG = {
    "vocab_size": 32000,     # Vocabulary size
    "context_length": 4096,  # Context length
    "emb_dim": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11008,     # ADD: Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # ADD: Lower-precision dtype to save memory
}

# Seed PyTorch's global RNG so runs are repeatable.
torch.manual_seed(123)

<torch._C.Generator at 0x7f9b7107ae30>

### torch dataset dataloader

In [None]:
# create dataset and dataloader

class my_text_dataset(Dataset):
    """Sliding-window next-token dataset over one tokenized text.

    The text is tokenized once. Each sample is a window of ``max_length``
    token ids (x) paired with the same window shifted one token to the
    right (y). Consecutive windows start ``stride`` tokens apart.

    Args:
        raw_text: The text to tokenize.
        tokenizer: Any object with ``encode(text)`` returning a sequence of
            integer token ids.
        max_length: Number of tokens per sample.
        stride: Distance in tokens between the starts of consecutive windows.
    """

    def __init__(self, raw_text:str, tokenizer, max_length:int, stride:int=1):
        # One tensor per window; x and y are kept in parallel lists.
        self.input_tokens_x = []
        self.target_tokens_y = []

        # Tokenize the entire text. Only the text is passed: the SentencePiece
        # wrapper's encode() takes no tiktoken-style `allowed_special` argument.
        # list() accepts any sequence of ids the tokenizer returns.
        tokens = list(tokenizer.encode(raw_text))

        # y is x shifted right by one token; windows advance by `stride`.
        # Stopping at len(tokens) - max_length guarantees y is always full length.
        for i in range(0, (len(tokens)-max_length), stride):
            x_tmp = tokens[i : (i+max_length)]
            y_tmp = tokens[(i+1) : (i+max_length+1)]
            self.input_tokens_x.append(torch.tensor(x_tmp))
            self.target_tokens_y.append(torch.tensor(y_tmp))

    # required by torch.utils.data.Dataset consumers such as DataLoader
    def __len__(self) -> int:
        "Returns the number of rows / pairs of x-y sequences in the dataset"
        return len(self.input_tokens_x)

    # required for subclasses of torch.utils.data.Dataset
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        "Returns one sample: the input window and its one-token-shifted target (x, y)."
        return self.input_tokens_x[idx], self.target_tokens_y[idx]

def my_text_dataloader(raw_text:str, batch_size:int=4, max_length:int=256, tokenizer=tokenizer,
                       stride:int=128, shuffle=True, drop_last=True, num_workers=0):
    """Wrap ``raw_text`` in a ``my_text_dataset`` and return a DataLoader over it.

    Args:
        raw_text: Text to tokenize and window.
        batch_size: Samples per batch.
        max_length: Tokens per sample window.
        tokenizer: Tokenizer handed to ``my_text_dataset``.
        stride: Token distance between consecutive windows.
        shuffle, drop_last, num_workers: Forwarded unchanged to ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (input, target) batches.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    windows = my_text_dataset(raw_text, tokenizer, max_length, stride)
    return DataLoader(windows, **loader_options)

#### split into T, V, H

In [None]:
# Split the raw text by character position into train / validation / hold-out parts.
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")
prop_t, prop_v, prop_h = (0.8,0.1,0.1)
# Character indices where the train split ends and where the validation split ends.
split_idx_t, split_idx_v = int(prop_t * total_characters), int((prop_t+prop_v) * total_characters)
print(f"Split at character index {split_idx_t} between train and valid sets, and at {split_idx_v} betwee valid and hold sets")

d_train = text_data[:split_idx_t]
d_valid = text_data[split_idx_t:split_idx_v]
d_hold  = text_data[split_idx_v:]

# Each split must hold more tokens than one context window, otherwise its
# dataloader would produce no samples. The model built in this notebook is
# configured by LLAMA2_7B_CONFIG (the GPT-2 config name is not defined here).
assert (total_tokens * prop_t) > LLAMA2_7B_CONFIG["context_length"], "Not enough tokens for loader_t (training dataloader)"
assert (total_tokens * prop_v) > LLAMA2_7B_CONFIG["context_length"], "Not enough tokens for loader_v (validation dataloader)"
assert (total_tokens * prop_h) > LLAMA2_7B_CONFIG["context_length"], "Not enough tokens for loader_h (testing dataloader)"

In [None]:
# Training loader: windows are one full context long, and stride == max_length
# so consecutive windows do not overlap. Shuffled; the last partial batch is dropped.
loader_t = my_text_dataloader(
    raw_text=d_train,
    batch_size=2, # this is only for learning purpose; in practice, batch_size >= 1024 is common
    max_length=LLAMA2_7B_CONFIG["context_length"],
    stride=LLAMA2_7B_CONFIG["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

# Validation loader: same windowing, but kept in order and the last partial
# batch is retained.
loader_v = my_text_dataloader(
    raw_text=d_valid,
    batch_size=2,
    max_length=LLAMA2_7B_CONFIG["context_length"],
    stride=LLAMA2_7B_CONFIG["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

# loader_h = my_text_dataloader(
#     raw_text=d_hold,
#     batch_size=2,
#     max_length=LLAMA2_7B_CONFIG["context_length"],
#     stride=LLAMA2_7B_CONFIG["context_length"],
#     drop_last=False,
#     shuffle=False,
#     num_workers=0
# )

### inspect loaded data

In [None]:
# Count how many input tokens each loader actually serves by summing the
# element count of every input batch.
train_tokens = 0
for input_batch, target_batch in loader_t:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in loader_v:
    val_tokens += input_batch.numel()

# These totals cover only the windows the loaders yield, not the raw token
# count of the text.
print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)

## modules and model
- key components:
    - tokenization - done in my_text_dataloader
    - input embedding
    - RoPE
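    - transformer block
        - layernorm (RMSNorm in Llama)
        - multi-head attention: LLAMA2_7B_CONFIG["n_heads"] heads per block, LLAMA2_7B_CONFIG["n_layers"] blocks
        - dropout + shortcut (dropout is removed in Llama)
        - layernorm (RMSNorm in Llama)
        - feedforward
        - dropout + shortcut (dropout is removed in Llama)
    - layernorm (RMSNorm in Llama)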
    - output linear layer


### define modules

##### RoPE

In [None]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=LLAMA2_7B_CONFIG['context_length']):
    """Build the cosine and sine lookup tables used by rotary position embeddings.

    Args:
        head_dim: Size of one attention head; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions to precompute.

    Returns:
        Tuple ``(cos, sin)``, each of shape ``(context_length, head_dim)``.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    half = head_dim // 2

    # One inverse frequency per pair of dimensions: theta_base ** -(k / half)
    exponents = torch.arange(0, half) / half
    inv_freq = 1.0 / (theta_base ** exponents)

    # Angle for every (position, frequency) pair -> (context_length, half)
    positions = torch.arange(context_length)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Duplicate along the last axis so the table covers the full head_dim
    angles = torch.cat([angles, angles], dim=1)

    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor of shape (batch_size, num_heads, seq_len, head_dim); head_dim
            must be even.
        cos, sin: Lookup tables with at least ``seq_len`` rows and ``head_dim``
            columns.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    half = head_dim // 2
    first_half = x[..., :half]
    second_half = x[..., half:]

    # Keep only the rows for the positions present and add two leading
    # broadcast axes -> (1, 1, seq_len, head_dim)
    cos_table = cos[:seq_len, :][None, None]
    sin_table = sin[:seq_len, :][None, None]

    # Rotation: x * cos + rotate_half(x) * sin, with rotate_half(x) = (-x2, x1)
    swapped = torch.cat((-second_half, first_half), dim=-1)
    result = (x * cos_table) + (swapped * sin_table)

    # The tables may be higher precision than x; cast back to x's dtype
    return result.to(dtype=x.dtype)

##### Multihead Causal Attention with RoPE
- hard code bias=False
- set dtype for later reduced precision compute

In [27]:
class Multihead_Causal_Attention_w_RoPE(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings (RoPE).

    Queries, keys and values are projected without bias, split into
    ``n_heads`` heads of size ``d_out // n_heads``, rotated with RoPE (queries
    and keys only), masked so a token cannot attend to later tokens, and the
    per-head results are concatenated and passed through a final linear layer.
    No dropout is applied.

    Args:
        d_in: Size of the input embeddings.
        d_out: Total output size across all heads; must be divisible by n_heads.
        context_length: Maximum sequence length (size of the mask and RoPE tables).
        n_heads: Number of attention heads.
        dtype: Optional dtype for the linear layers (e.g. reduced precision).
    """

    def __init__(self, d_in, d_out, context_length, n_heads, dtype=None):
        super().__init__()

        # each head gets an equal slice of d_out
        assert (d_out % n_heads == 0), "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads  # per-head size

        self.w_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.w_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.w_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        # Upper-triangular causal mask (1 above the diagonal = "future" position).
        # Registered as a buffer so it moves with the model but is not trained.
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length,context_length),
                diagonal=1)
        )
        # linear layer that mixes the concatenated head outputs
        self.combine_heads = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        ################################### RoPE ###################################
        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
        ###########################################################################


    def forward(self, x):
        """Return context vectors of shape (batch, n_tokens, d_out) for ``x``.

        ``x`` must have shape (batch, n_tokens, d_in).
        """
        # first dim is the batch dim
        batch, n_tokens, _ = x.shape

        # project the input embeddings -> (batch, n_tokens, d_out)
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)

        # ###### split the projections into heads ######
        # (batch, n_tokens, d_out) -> (batch, n_tokens, n_heads, head_dim)
        # NOTE: the attribute set in __init__ is `head_dim`; use it consistently.
        queries = queries.view(batch, n_tokens, self.n_heads, self.head_dim)
        keys = keys.view(batch, n_tokens, self.n_heads, self.head_dim)
        values = values.view(batch, n_tokens, self.n_heads, self.head_dim)

        # then to (batch, n_heads, n_tokens, head_dim)
        queries = queries.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        # ###### split the projections into heads ######

        ################################### NEW ###################################
        # apply RoPE to keys and queries (values are left unrotated)
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)
        ###########################################################################

        # attention scores: query @ key.T per head
        # dims are (batch, n_heads, n_tokens, head_dim), so transpose the last two
        attention_scores = torch.matmul(queries, keys.transpose(2, 3))

        # apply the causal mask in place (trailing underscore = in-place op)
        attention_scores.masked_fill_(
            # boolean mask truncated to the current number of tokens
            self.mask.bool()[:n_tokens, :n_tokens],
            # masked (future) positions get -inf so softmax gives them weight 0
            -torch.inf
        )

        # attention weights = softmax of the scores scaled by sqrt(head_dim),
        # which keeps the softmax out of its small-gradient regime
        attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1]**0.5),
            dim=-1
        )

        # dropout within attention is removed

        # context vectors = attention weights @ values
        # ###### combine across all heads  ######
        # (batch, n_heads, n_tokens, head_dim) -> (batch, n_tokens, n_heads, head_dim)
        context_vectors = torch.matmul(attention_weights, values).transpose(1, 2)
        # flatten the heads back together: d_out = n_heads * head_dim
        context_vectors = context_vectors.contiguous().view(
            batch, n_tokens, self.d_out
        )
        context_vectors = self.combine_heads(context_vectors)
        # ###### combine across all heads  ######

        return context_vectors

In [28]:
# replacing LayerNorm in GPT2
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization with a learnable per-feature scale.

    Args:
        emb_dim: Size of the last dimension of the inputs.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        # Create the scale directly in float32. Calling `.float()` on an
        # nn.Parameter only returns the Parameter itself when it is already
        # float32; under any other default dtype it returns a plain tensor and
        # the weight would silently not be registered as a parameter.
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize ``x`` over its last dimension; output has x's dtype."""
        # mean of squares over the feature dimension
        means = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(means + self.eps)
        # the float32 weight may upcast the product; cast back to the input dtype
        return (x_normed * self.weight).to(dtype=x.dtype)

In [29]:
# replacing GELU - use a Swish function
# torch.nn.functional.silu works too
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise.

    The module holds no parameters or state.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [31]:
# FeedForward with SwiGLU: Swish Gates Linear Unit
class SwiGLU_FeedForward(nn.Module):
    """Feed-forward block with a SwiGLU (Swish-gated linear unit) activation.

    Computes ``fc3(silu(fc1(x)) * fc2(x))``; all three linear layers are
    bias-free.

    Args:
        cfg: Mapping with keys ``"emb_dim"`` (input/output size),
            ``"hidden_dim"`` (intermediate size) and ``"dtype"`` (dtype of the
            linear layers).
    """

    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
        # Use the SiLU *module*. F.silu is a function that needs an input
        # tensor, so calling it with no arguments here raises a TypeError.
        self.silu = nn.SiLU()

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        # gated activation: element-wise product of the activated branch and
        # the linear branch
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)

In [32]:
class TransformerBlock(nn.Module):
    """
    One Llama 2 style transformer block (pre-norm, no dropout).

    Data flow:
        - x -> rmsnorm1 -> multihead causal attention with RoPE -> + x
        - result -> rmsnorm2 -> swiglu-feedforward -> + result

    cfg must provide "emb_dim", "context_length", "n_heads", "dtype", plus
    whatever SwiGLU_FeedForward reads from it.
    """
    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]

        # attention keeps the embedding size (d_in == d_out)
        self.att = Multihead_Causal_Attention_w_RoPE(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            n_heads=cfg["n_heads"],
            dtype=cfg["dtype"],
        )
        self.ff = SwiGLU_FeedForward(cfg)

        # one RMSNorm in front of each sub-layer
        self.rmsnorm1 = RMSNorm(emb_dim)
        self.rmsnorm2 = RMSNorm(emb_dim)

        # self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # attention sub-layer with residual connection
        # shapes stay [batch_size, num_tokens, emb_size] throughout
        attended = self.att(self.rmsnorm1(x))
        x = attended + x

        # feed-forward sub-layer with residual connection
        transformed = self.ff(self.rmsnorm2(x))
        return transformed + x

### define model

In [33]:
# put it all together into a model 
class Llama2_model(nn.Module):
    """
    Llama 2 style decoder-only language model.

    - input embedding
    - transformer block x cfg["n_layers"]
    - rms_norm
    - output linear layer

    Args:
        cfg: Mapping with keys "vocab_size", "emb_dim", "n_layers", "dtype",
            plus the keys TransformerBlock reads from it.
    """
    def __init__(self, cfg):
        super().__init__()
        self.input_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        # No absolute positional embedding: position is injected by RoPE
        # inside the attention layers.

        # transformer block x n_layers times
        self.transformer_block = nn.Sequential(
            # unpack list comprehension to repeat transformer-block n_layers times
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        ################################### NEW ###################################
        # rms_norm applied after the last transformer block
        self.final_norm = RMSNorm(cfg["emb_dim"])
        ###########################################################################
        # final output layer
        # expand tokens into logits in vocab_size dimensions
        # do not add extra bias
        self.out_layer = nn.Linear(in_features=cfg["emb_dim"],
                                   out_features=cfg["vocab_size"],
                                   bias=False,
                                   dtype=cfg["dtype"])

    def forward(self, in_idx):
        """Return logits of shape (batch, seq_len, vocab_size) for token ids ``in_idx``.

        ``in_idx`` is an integer tensor of shape (batch, seq_len).
        """
        # Use the layer names defined in __init__ (`input_emb`, `out_layer`);
        # the previous `tok_emb` / `out_head` attributes do not exist.
        input_embeddings = self.input_emb(in_idx)
        # no positional encoding is added here
        # transformer block repeated n_layers times
        x = self.transformer_block(input_embeddings)
        # rmsnorm
        x = self.final_norm(x)
        # final output layer -> logits
        logits = self.out_layer(x)
        return logits