In [1]:
""" model: llama3/tulu3/deepseekv3
llama series:
- (Touvron et al. 2023) URL: https://arxiv.org/abs/2302.13971
- (Touvron et al. 2023) URL: https://arxiv.org/abs/2307.09288
- (Llama Team 2024) URL: https://arxiv.org/abs/2407.21783

deepseek series:
- (DeepSeek-AI 2024) URL: https://arxiv.org/abs/2405.04434
- (DeepSeek-AI 2024) URL: https://arxiv.org/abs/2412.19437

tulu series:
- (Wang et al. 2023) URL: https://arxiv.org/abs/2306.04751
- (Ivison et al. 2023) URL: https://arxiv.org/abs/2311.10702
- (Lambert et al. 2024) URL: https://arxiv.org/abs/2411.15124

gpt2 -> llama2: RoPE, GQA, RMSNorm, SwiGLU, KV-Cache

Dimension key:

# windows
B: batch size
T: sequence length

# input/output
V: vocabulary size
D: model dimension (n_embd)

# attention
N: number of transformer blocks (n_layer)
H: number of attention heads in a layer (n_head)
K: size of each attention key or value (n_k)
"""


from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# @dataclass
# class LlamaConfig:
    # windows: B, T
    # input/output: V, D
    # attn: N, H, K
    # layers: int = 32

V, D = 32000, 4096
N, H, K = 32, 0, 0

class MHA(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(D, D, bias=False)
        self.k_proj = nn.Linear(D, D, bias=False)
        self.v_proj = nn.Linear(D, D, bias=False)
        self.o_proj = nn.Linear(D, D, bias=False)

class FFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.up_proj = nn.Linear(D, 11008, bias=False)
        self.gate_proj = nn.Linear(D, 11008, bias=False)
        self.down_proj = nn.Linear(11008, D, bias=False)





class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = MHA()
        self.input_layernorm = nn.LayerNorm(D, bias=False)
        self.mlp = FFN()
        self.post_attention_layernorm = nn.LayerNorm(D, bias=False)

    def forward(self, X_BTD):
        X_BTD = X_BTD + self.self_attn(self.input_layernorm(X_BTD))
        X_BTD = X_BTD + self.mlp(self.post_attention_layernorm(X_BTD)) # todo: check llama2 paper
        return X_BTD





class Llama(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Linear(D, V, bias=False)
        self.layers = nn.ModuleList([Block() for _ in range(N)])
        self.norm = nn.LayerNorm(D, bias=False)

    def forward(self, X):
        # 1. embeddings: X_BT -> X_BTD
        # TODO:
        X_BTD = Xt_BTD + Xp_TD

        # 2. blocks + unembed: Nx(X_BTD -> X_BTK -> X_BTD)
        for l in layers:
            X_BTD = l(X_BTD)

        return X_BTD

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained Llama2 model weights from huggingface"""
        assert model_type in {'meta-llama/Llama-2-7b-hf'}
        from transformers import LlamaModel
        print("loading weights from pretrained llama: %s" % model_type)

        # 1. model init: mom i want a llama. we have llama at home
        model_hf, model = LlamaModel.from_pretrained(model_type), Llama()
        sdhf, sd = model_hf.state_dict(), model.state_dict()
        sdhf_keys, sd_keys = sdhf.keys(), sd.keys() # python isn't zero-alloc

        # 2. copy
        # - ensuring all params are aligned and match in names and shapes
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        assert len(sdhf_keys) == len(sd_keys), f"mismatched keys: {len(sdhf_keys)} != {len(sd_keys)}"
        for k in sdhf_keys:
            # vanilla copy over the other parameters
            assert sdhf[k].shape == sd[k].shape            
            with torch.no_grad():
                sd[k].copy_(sdhf[k])

        return model

class LlamaHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Llama.from_pretrained("meta-llama/Llama-2-7b-hf")
        self.lm_head = nn.Linear(D, V, bias=False)

    def forward(self, X_BT):
        X_BTD = self.model(X_BT)
        X_BTV = self.lm_head(X_BTD)
        return X_BTV

# model
model = Llama.from_pretrained("meta-llama/Llama-2-7b-hf")
model.to(device)
print(f"loaded model to device {device}")

  from .autonotebook import tqdm as notebook_tqdm


loading weights from pretrained llama: meta-llama/Llama-2-7b-hf


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


loaded model to device cuda


In [3]:
# inference loop

# tokenize
B, T_MAX = 5, 30
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokens_1N = tokenizer("Hello, I'm a language model,", return_tensors="pt").input_ids
tokens_BT = tokens_1N.repeat(B, 1) # (B, T)
X_BT = tokens_BT.to(device)
print(X_BT)

torch.manual_seed(1337)
torch.cuda.manual_seed(1337)

while X_BT.size(1) < T_MAX:
    with torch.no_grad():
        # sample
        logits_BTV, _ = model(X_BT)
        logits_BV = logits_BTV[:, -1, :]
        probs_ = F.softmax(logits_BV, dim=-1)
        topk_probs, topk_indices = torch.topk(probs_, 50, dim=-1)
        
        # concat
        X_B1 = torch.gather(topk_indices_, -1, torch.multinomial(topk_probs_, 1))
        X_BT = torch.cat((X_BT, X_B1), dim=1)

tensor([[    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892]],
       device='cuda:0')


ValueError: too many values to unpack (expected 2)