In [None]:
# Display the version of the installed sentencepiece package.
from importlib.metadata import version
version("sentencepiece")

In [None]:
import torch
import torch.nn as nn

In [None]:
# The normalization statistics are computed in at least float32 and the result
# is cast back to the input dtype at the end.
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale.

    Args:
        emb_dim: Size of the last (feature) dimension of the inputs.
        eps: Constant added to the mean square before taking the reciprocal
            square root, to avoid dividing by zero.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        # Create the float32 tensor first and wrap it afterwards. Calling
        # `.float()` on the Parameter only returns the Parameter itself when
        # the default dtype is already float32; otherwise it produces a plain
        # tensor that is never registered with the module.
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize `x` by its root mean square along the last dimension.

        Args:
            x: Tensor whose last dimension has size `emb_dim`.

        Returns:
            Tensor with the same shape and dtype as `x`.
        """
        # Upcast low-precision inputs (e.g. bfloat16) so the mean of squares and
        # rsqrt are not computed in reduced precision; float64 stays float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_wide = x.to(compute_dtype)
        mean_sq = x_wide.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_wide * torch.rsqrt(mean_sq + self.eps)
        return (x_norm * self.weight).to(dtype=x.dtype)

In [None]:
# Tiny 2x3 float32 input used to sanity-check the custom RMSNorm.
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
rms_norm = RMSNorm(emb_dim=3)
rms_norm(x)

In [None]:
# Reference result from PyTorch's built-in RMSNorm (same eps as the custom
# class's default) for comparison with the output of the previous cell.
rms_norm_pytorch = nn.RMSNorm(x.shape[-1], eps=1e-5)
rms_norm_pytorch(x)

In [None]:
class SiLU(nn.Module):
    """Element-wise SiLU activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        """Return `x * sigmoid(x)`, with the same shape as `x`."""
        gate = torch.sigmoid(x)
        return x * gate

In [None]:
import torch.nn.functional as F

In [None]:
# Apply the custom SiLU module to the current `x`.
silu = SiLU()
silu(x)

In [None]:
# Reference result from PyTorch's functional SiLU on the same input.
F.silu(x)

In [None]:
class FeedForward(nn.Module):
    """Gated feed-forward block: `fc3(silu(fc1(x)) * fc2(x))`, all layers bias-free.

    Args:
        cfg: Mapping providing 'emb_dim' (input/output width), 'hidden_dim'
            (intermediate width) and 'dtype' (dtype of the linear weights).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, hidden_dim = cfg['emb_dim'], cfg['hidden_dim']
        linear_kwargs = dict(dtype=cfg['dtype'], bias=False)
        # fc1 feeds the SiLU gate, fc2 is the linear branch, fc3 projects back.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **linear_kwargs)
        self.silu = SiLU()

    def forward(self, x):
        """Map `x` (last dimension `emb_dim`) to an output of the same shape."""
        gate = self.silu(self.fc1(x))
        linear_branch = self.fc2(x)
        return self.fc3(gate * linear_branch)

In [None]:
# Hyperparameters passed to FeedForward, TransformerBlock and Llama2Model in
# this notebook, followed by a FeedForward instance built from them.
LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     # Vocabulary size
    "context_length": 4096,  # Context length
    "emb_dim": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11008,     # NEW: Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage
}
feed_forward = FeedForward(LLAMA2_CONFIG_7B)

In [None]:
# Random bfloat16 batch of shape (4, 8, 4096); display the output shape of the
# feed-forward block for it.
x = torch.randn((4, 8, 4096), dtype=torch.bfloat16)
feed_forward(x).shape

In [None]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    """Build the cosine and sine lookup tables used for rotary position embeddings.

    Args:
        head_dim: Per-head feature size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions to tabulate.

    Returns:
        Tuple `(cos, sin)`, each of shape (context_length, head_dim).
    """
    assert head_dim % 2 == 0, "Embedding dimension should be even"

    # One inverse frequency per pair of features.
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product position x frequency -> (context_length, head_dim / 2).
    pos = torch.arange(context_length)
    half_angles = pos.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Repeat the angles so both halves of the head dimension share them.
    full_angles = torch.cat((half_angles, half_angles), dim=1)
    return torch.cos(full_angles), torch.sin(full_angles)

In [None]:
def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to `x`.

    Args:
        x: Tensor of shape (batch_size, num_heads, seq_len, head_dim) with an
            even `head_dim`.
        cos: Cosine table; its first `seq_len` rows are used.
        sin: Sine table; its first `seq_len` rows are used.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Embedding dimension should be even"

    # Swap the two halves of the head dimension and negate the second one.
    half = head_dim // 2
    front, back = x[..., :half], x[..., half:]
    rotated = torch.cat((-back, front), dim=-1)

    # Trim the tables to the sequence length and add broadcast dims -> (1, 1, seq_len, head_dim).
    cos_b = cos[:seq_len, :][None, None, :, :]
    sin_b = sin[:seq_len, :][None, None, :, :]

    out = x * cos_b + rotated * sin_b
    # The tables may have a wider dtype than x; restore x's dtype.
    return out.to(dtype=x.dtype)

In [None]:
# Smoke test: rotate an all-ones query tensor and display the resulting shape.
batch_dim, context_len, num_heads, head_dim = 2, 5, 4, 16

cos, sin = precompute_rope_params(head_dim, context_length=context_len)
q = torch.ones((batch_dim, num_heads, context_len, head_dim))
q_rotated = compute_rope(q, cos, sin)
q_rotated.shape

In [None]:
# Un-rotated query vector: batch 0, head 0, position 2.
q[0, 0, 2, :]

In [None]:
# The same vector after RoPE has been applied, for comparison.
q_rotated[0, 0, 2, :]

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings on Q and K.

    Args:
        d_in: Feature size of the input tokens.
        d_out: Total output feature size; must be divisible by `num_heads`.
        context_length: Number of positions covered by the causal mask and the
            RoPE tables.
        num_heads: Number of attention heads.
        dtype: Optional dtype for the four bias-free linear projections.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # Keep the causal mask as bool: it is only ever used as a boolean mask,
        # and a float32 (context_length x context_length) buffer costs 4 bytes
        # per entry in every layer. True marks the strictly-upper triangle,
        # i.e. the future positions that must be hidden.
        self.register_buffer(
            "mask",
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1,
            ),
        )
        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        """Attend over `x` of shape (b, num_tokens, d_in); returns (b, num_tokens, d_out)."""
        b, num_tokens, _ = x.shape

        keys = self.W_key(x)  # shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the feature dimension into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Rotary embeddings are applied to keys and queries only, not to values.
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)

        # Per-head dot products:
        # (b, num_heads, num_tokens, head_dim) @ (b, num_heads, head_dim, num_tokens)
        # -> (b, num_heads, num_tokens, num_tokens)
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))

        # Hide future positions; -inf becomes a weight of zero after the softmax.
        mask = self.mask[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)

        # (b, num_heads, num_tokens, num_tokens) @ (b, num_heads, num_tokens, head_dim)
        # -> (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Merge the heads back into a single feature dimension.
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec


In [None]:
# Smoke test for MultiHeadAttention on a small random batch.
mha = MultiHeadAttention(
    d_in=128,
    d_out=128,
    context_length=100,
    num_heads=4
)

example_batch = torch.randn((1, 100, 128))
print(example_batch.shape)
# Print explicitly: a bare expression is only displayed when it is the last
# statement of the cell, and this cell ends with `del`.
print(mha(example_batch).shape)

del example_batch, mha

In [None]:
class SublayerConnection(nn.Module):
    """Residual wrapper that normalizes its input before calling the sublayer.

    Computes `x + sublayer(RMSNorm(x))`.

    Args:
        size: Feature dimension handed to the RMSNorm layer.
    """

    def __init__(self, size):
        super().__init__()
        self.norm = RMSNorm(size)

    def forward(self, x, sublayer):
        """Run `sublayer` on the normalized input and add the result to `x`."""
        normed = self.norm(x)
        return x + sublayer(normed)

In [None]:
class TransformerBlock(nn.Module):
    """One decoder block: self-attention then feed-forward, each inside a SublayerConnection.

    Args:
        cfg: Mapping with 'emb_dim', 'context_length', 'n_heads' and 'dtype',
            plus whatever keys FeedForward reads from it.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        # The norm layers live inside the sublayer wrappers (sublayerN.norm),
        # which matters when addressing them during weight loading.
        self.sublayer1 = SublayerConnection(emb_dim)
        self.sublayer2 = SublayerConnection(emb_dim)

    def forward(self, x):
        """Apply the attention sublayer, then the feed-forward sublayer."""
        after_attention = self.sublayer1(x, self.att)
        return self.sublayer2(after_attention, self.ff)

In [None]:
class Llama2Model(nn.Module):
    """Token embedding -> stack of TransformerBlocks -> final RMSNorm -> vocabulary logits.

    Args:
        cfg: Mapping with 'vocab_size', 'emb_dim', 'n_layers' and 'dtype',
            plus the keys required by TransformerBlock.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim, dtype = cfg["vocab_size"], cfg["emb_dim"], cfg["dtype"]

        self.tok_emb = nn.Embedding(vocab_size, emb_dim, dtype=dtype)
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)
        self.final_norm = RMSNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype)

    def forward(self, in_idx):
        """Map a 2-D tensor of token ids to logits with a trailing `vocab_size` dimension."""
        # Unpacking into two names rejects anything that is not 2-D.
        _batch, _seq_len = in_idx.shape
        hidden = self.tok_emb(in_idx)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)

In [None]:
# Display the configuration that the model below is built from.
LLAMA2_CONFIG_7B

In [None]:
# Instantiate the model with freshly initialized (not yet pretrained) weights.
model = Llama2Model(LLAMA2_CONFIG_7B)

In [None]:
# Count every element across all registered parameters and print it with
# thousands separators.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,}")

In [None]:
# Bytes per element for float32.
torch.tensor(0, dtype=torch.float32).element_size()

In [None]:
# Bytes per element for bfloat16.
torch.tensor(0, dtype=torch.bfloat16).element_size()

In [None]:
def total_memory_size(model, input_dtype=torch.float32):
    """Estimate the memory footprint of `model` in GiB for a given element dtype.

    The estimate counts every parameter element, one gradient element per
    trainable parameter element, and every buffer element, and assumes all of
    them are stored as `input_dtype` regardless of their actual dtypes.

    Args:
        model: Module whose parameters and buffers are counted.
        input_dtype: Dtype whose element size is used for the whole estimate.

    Returns:
        Estimated size in GiB (bytes / 2**30) as a float.
    """
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)
    # Only parameters that require gradients contribute a gradient of equal size.
    n_grads = sum(p.numel() for p in params if p.requires_grad)
    n_buffers = sum(b.numel() for b in model.buffers())

    bytes_per_element = torch.tensor(0, dtype=input_dtype).element_size()
    size_in_bytes = (n_params + n_grads + n_buffers) * bytes_per_element
    return size_in_bytes / (2 ** 30)

In [None]:
# Estimated footprint in GiB if every element were stored as float32.
total_memory_size(model, input_dtype=torch.float32)

In [None]:
# Estimated footprint in GiB if every element were stored as bfloat16.
total_memory_size(model, input_dtype=torch.bfloat16)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

In [None]:
# download the tokenizer file

In [None]:
import sentencepiece as spm

class LLamaTokenizer:
    """Minimal wrapper around a SentencePiece model exposing `encode` and `decode`.

    Args:
        tokenizer_file: Path to the SentencePiece model file to load.
    """

    def __init__(self, tokenizer_file):
        sp = spm.SentencePieceProcessor()
        sp.load(tokenizer_file)
        self.tokenizer = sp

    def encode(self, text):
        """Return the token ids SentencePiece produces for `text`."""
        return self.tokenizer.encode_as_ids(text)

    def decode(self, ids):
        """Return the text for a sequence of integer token ids."""
        # `decode_ids` is the API for integer ids; `decode_pieces` is meant for
        # string pieces.
        return self.tokenizer.decode_ids(ids)

In [None]:
# Location of the SentencePiece model file. NOTE(review): absolute,
# machine-specific path -- adjust it to wherever the file was downloaded.
tokenizer_file = "/home/htkumar/llms/llama-2-7b/tokenizer.model"

In [None]:
# Load the SentencePiece model into the tokenizer wrapper.
tokenizer = LLamaTokenizer(tokenizer_file)

In [None]:
from gpt_model import generate, text_to_token_ids, token_ids_to_text

In [None]:
# Generate from the model before pretrained weights are loaded.
# Uses the `text_to_token_ids` helper imported from gpt_model.
# NOTE(review): assumes that helper only needs `tokenizer.encode(text)` --
# confirm it is compatible with LLamaTokenizer.
torch.manual_seed(123)
token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

token_ids_to_text(token_ids, tokenizer)

In [None]:
# Checkpoint file for the base model. NOTE(review): absolute, machine-specific path.
weights_file = "/home/htkumar/llms/llama-2-7b/consolidated.00.pth"

In [None]:
# Load the checkpoint; weights_only=True limits unpickling to tensors and
# plain containers rather than arbitrary Python objects.
weights = torch.load(weights_file, weights_only=True)

In [None]:
# Inspect the container type returned by torch.load.
type(weights)

In [None]:
# Peek at the first 15 checkpoint keys to see the naming scheme.
list(weights.keys())[:15]

In [None]:
def assign(left, right):
    """Return a new Parameter holding `right`, after checking it matches `left`'s shape.

    Args:
        left: Existing tensor/parameter; only its shape is used.
        right: Replacement values, either a tensor or another object exposing
            `.shape` that `torch.tensor` can convert.

    Returns:
        A new `torch.nn.Parameter` built from a copy of `right`.

    Raises:
        ValueError: If the two shapes differ.
    """
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch, left: {left.shape}, right: {right.shape}")

    # Tensors are cloned and detached so the parameter does not share storage
    # or autograd history with the source; anything else is converted.
    source = right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right)
    return torch.nn.Parameter(source)

In [None]:
def load_weights_into_llama(model, param_config, params):
    """Replace the model's weights with tensors from a checkpoint mapping.

    Args:
        model: Llama2Model instance whose parameters are overwritten in place.
        param_config: Mapping providing 'n_layers', the number of blocks to load.
        params: Mapping from checkpoint key names to weight tensors.

    Raises:
        ValueError: Via `assign` if a checkpoint tensor's shape does not match
            the parameter it replaces.
        KeyError: If an expected checkpoint key is missing from `params`.
    """
    model.tok_emb.weight = assign(model.tok_emb.weight, params["tok_embeddings.weight"])

    for layer_idx in range(param_config['n_layers']):
        block = model.trf_blocks[layer_idx]
        # (target module, checkpoint key) pairs, loaded in this order.
        targets = [
            # Attention projections and the norm applied before attention.
            (block.att.W_query, f'layers.{layer_idx}.attention.wq.weight'),
            (block.att.W_key, f'layers.{layer_idx}.attention.wk.weight'),
            (block.att.W_value, f'layers.{layer_idx}.attention.wv.weight'),
            (block.att.out_proj, f'layers.{layer_idx}.attention.wo.weight'),
            (block.sublayer1.norm, f'layers.{layer_idx}.attention_norm.weight'),
            # Feed-forward layers. The checkpoint's w2/w3 numbering is swapped
            # relative to fc2/fc3 here: fc2 takes w3 and fc3 takes w2.
            (block.ff.fc1, f'layers.{layer_idx}.feed_forward.w1.weight'),
            (block.ff.fc2, f"layers.{layer_idx}.feed_forward.w3.weight"),
            (block.ff.fc3, f"layers.{layer_idx}.feed_forward.w2.weight"),
            (block.sublayer2.norm, f'layers.{layer_idx}.ffn_norm.weight'),
        ]
        for module, key in targets:
            module.weight = assign(module.weight, params[key])

    # Final norm and output projection.
    model.final_norm.weight = assign(model.final_norm.weight, params['norm.weight'])
    model.out_head.weight = assign(model.out_head.weight, params['output.weight'])

In [None]:
# Overwrite the randomly initialized weights with the base checkpoint.
load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)

In [None]:
# Move the model again: the loader replaced its parameters with new Parameter
# objects built from the checkpoint tensors.
model.to(device);

In [None]:
# Generate from the model after the base checkpoint has been loaded.
# Uses the `text_to_token_ids` helper imported from gpt_model.
# NOTE(review): assumes that helper only needs `tokenizer.encode(text)` --
# confirm it is compatible with LLamaTokenizer.
torch.manual_seed(123)
token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort", tokenizer).to(device),
    max_new_tokens=30,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

print(token_ids_to_text(token_ids, tokenizer))

In [None]:
# Drop the reference to the base model before the chat model is built below.
del model

In [None]:
# Checkpoint file for the chat model. NOTE(review): absolute, machine-specific path.
weights_file_chat = "/home/htkumar/llms/llama-2-7b-chat/consolidated.00.pth"

In [None]:
# Load the chat checkpoint; weights_only=True limits unpickling to tensors and
# plain containers rather than arbitrary Python objects.
weights_chat = torch.load(weights_file_chat, weights_only=True)

In [None]:
# The chat model is built from the same configuration as the base model.
model_chat = Llama2Model(LLAMA2_CONFIG_7B)

In [None]:
# Load the chat checkpoint into the new model, then move the model to `device`.
load_weights_into_llama(model_chat, LLAMA2_CONFIG_7B, weights_chat)
model_chat.to(device);

In [None]:
# Generate from the chat model.
# Uses the `text_to_token_ids` helper imported from gpt_model.
# NOTE(review): assumes that helper only needs `tokenizer.encode(text)` --
# confirm it is compatible with LLamaTokenizer.
torch.manual_seed(123)
token_ids = generate(
    model=model_chat,
    idx=text_to_token_ids("What do llamas eat?", tokenizer).to(device),
    max_new_tokens=30,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

print(token_ids_to_text(token_ids, tokenizer))

In [None]:
# Drop the reference to the chat model.
del model_chat

In [None]:
# last cell, clean up memory
# `model` and `model_chat` were already deleted in earlier cells (a second
# `del model` here would raise NameError), so only the cached CUDA memory
# still needs to be released.
torch.cuda.empty_cache()