## LLaMA3 almost from scratch

In [1]:
"""Load .env vars"""
import os
from pathlib import Path
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()

# Warning: the weights must come from official Meta sources, i.e. obtained by
# accepting the TOS and running Meta's download.sh script.
# (doesn't work with huggingface weights yet)
_weights_dir = os.getenv("LLAMA_WEIGHTS_PATH")
if not _weights_dir:
    # Fail early with an actionable message: Path(None) would raise an opaque
    # TypeError, and Path("") would silently resolve to the current directory.
    raise RuntimeError(
        "LLAMA_WEIGHTS_PATH is not set. Define it in the environment or in a .env "
        "file so that it points to the directory holding the downloaded LLaMA3 weights."
    )

# Directory that the later cells read tokenizer.model, params.json and the
# checkpoint from.
LLAMA_WEIGHTS_PATH = Path(_weights_dir)

### Tokenizer

[https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py). 

Code adapted from the official LLaMA3 `Tokenizer` class; only the path handling has been changed.

In [2]:
import tiktoken
from tiktoken.load import load_tiktoken_bpe

# The tokenizer file is expected inside the weights directory.
LLAMA_TOKENIZER_PATH = LLAMA_WEIGHTS_PATH.joinpath("tokenizer.model")
# 256 special tokens in total: the 10 explicitly listed entries followed by
# reserved slots 5..250. The order is significant because ids are handed out
# by list position when the encoding is built.
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
    *(f"<|reserved_special_token_{slot}|>" for slot in range(5, 256 - 5)),
]
# Base BPE vocabulary, loaded from the tokenizer file via its absolute POSIX path.
mergeable_ranks = load_tiktoken_bpe(LLAMA_TOKENIZER_PATH.absolute().as_posix())

# Special tokens receive consecutive ids directly after the base vocabulary,
# in the order in which they appear in `special_tokens`.
tokenizer = tiktoken.Encoding(
    name=LLAMA_TOKENIZER_PATH.name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens=dict(
        zip(
            special_tokens,
            range(len(mergeable_ranks), len(mergeable_ranks) + len(special_tokens)),
        )
    ),
)

# Round-trip check: decoding the encoded text should give the input string back.
tokenizer.decode(tokenizer.encode("can i haz a 🦙"))

'can i haz a 🦙'

### Load the model (weights + config)

In [3]:
import json
import torch

# Parse the hyper-parameter file stored in the weights directory, then
# pretty-print it so the configuration is visible in the notebook.
model_config = json.loads((LLAMA_WEIGHTS_PATH / "params.json").read_text())
print(json.dumps(model_config, indent=2))

{
  "dim": 4096,
  "n_layers": 32,
  "n_heads": 32,
  "n_kv_heads": 8,
  "vocab_size": 128256,
  "multiple_of": 1024,
  "ffn_dim_multiplier": 1.3,
  "norm_eps": 1e-05,
  "rope_theta": 500000.0
}


In [4]:
# Load the checkpoint from the weights directory.
# torch.load unpickles the file; weights_only=True restricts the unpickler to
# tensors and plain containers, so a tampered checkpoint cannot run arbitrary
# code while it is being loaded.
model = torch.load(LLAMA_WEIGHTS_PATH / "consolidated.00.pth", weights_only=True)

# NOTE(review): `model` is used as a mapping here (it exposes .keys()), i.e. it
# looks like a collection of named weights rather than an nn.Module - confirm.
# List every entry name so later cells can look weights up by key.
print(json.dumps(list(model.keys()), indent=2))

### Prepare input prompt

In [5]:
# A single forward pass (no sampling) only predicts the next token, so for this
# prompt the model is expected to predict the token id of 'snowing'.
input_prompt = "The ground was covered in a blanket of white, so it must have been"

# The beginning-of-text token (128000) has to be added by hand; passing
# allowed_special lets the special-token text be encoded to its id.
bos_id = tokenizer.encode("<|begin_of_text|>", allowed_special={*special_tokens})
print(bos_id)

# Model input: the BOS id followed by the ids of the encoded prompt.
tokens = [*bos_id, *tokenizer.encode(input_prompt)]
print(tokens)

[128000]
[128000, 791, 5015, 574, 9960, 304, 264, 39139, 315, 4251, 11, 779, 433, 2011, 617, 1027]


### OG llama3 visualization

Summary taken from [00-llama3-viz.ipynb](./00-llama3-viz.ipynb); see that notebook for the computation graph of a full forward pass.

<pre>
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Transformer                              [1, 16, 128256]           --
├─Embedding: 1-1                         [1, 16, 4096]             525,336,576
├─ModuleList: 1-2                        --                        --
│    └─TransformerBlock: 2-1             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-1                 [1, 16, 4096]             4,096
│    │    └─Attention: 3-2               [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-3                 [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-4             [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-2             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-5                 [1, 16, 4096]             4,096
│    │    └─Attention: 3-6               [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-7                 [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-8             [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-3             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-9                 [1, 16, 4096]             4,096
│    │    └─Attention: 3-10              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-11                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-12            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-4             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-13                [1, 16, 4096]             4,096
│    │    └─Attention: 3-14              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-15                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-16            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-5             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-17                [1, 16, 4096]             4,096
│    │    └─Attention: 3-18              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-19                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-20            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-6             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-21                [1, 16, 4096]             4,096
│    │    └─Attention: 3-22              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-23                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-24            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-7             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-25                [1, 16, 4096]             4,096
│    │    └─Attention: 3-26              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-27                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-28            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-8             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-29                [1, 16, 4096]             4,096
│    │    └─Attention: 3-30              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-31                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-32            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-9             [1, 16, 4096]             --
│    │    └─RMSNorm: 3-33                [1, 16, 4096]             4,096
│    │    └─Attention: 3-34              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-35                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-36            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-10            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-37                [1, 16, 4096]             4,096
│    │    └─Attention: 3-38              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-39                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-40            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-11            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-41                [1, 16, 4096]             4,096
│    │    └─Attention: 3-42              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-43                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-44            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-12            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-45                [1, 16, 4096]             4,096
│    │    └─Attention: 3-46              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-47                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-48            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-13            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-49                [1, 16, 4096]             4,096
│    │    └─Attention: 3-50              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-51                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-52            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-14            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-53                [1, 16, 4096]             4,096
│    │    └─Attention: 3-54              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-55                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-56            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-15            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-57                [1, 16, 4096]             4,096
│    │    └─Attention: 3-58              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-59                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-60            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-16            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-61                [1, 16, 4096]             4,096
│    │    └─Attention: 3-62              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-63                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-64            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-17            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-65                [1, 16, 4096]             4,096
│    │    └─Attention: 3-66              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-67                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-68            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-18            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-69                [1, 16, 4096]             4,096
│    │    └─Attention: 3-70              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-71                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-72            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-19            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-73                [1, 16, 4096]             4,096
│    │    └─Attention: 3-74              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-75                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-76            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-20            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-77                [1, 16, 4096]             4,096
│    │    └─Attention: 3-78              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-79                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-80            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-21            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-81                [1, 16, 4096]             4,096
│    │    └─Attention: 3-82              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-83                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-84            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-22            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-85                [1, 16, 4096]             4,096
│    │    └─Attention: 3-86              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-87                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-88            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-23            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-89                [1, 16, 4096]             4,096
│    │    └─Attention: 3-90              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-91                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-92            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-24            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-93                [1, 16, 4096]             4,096
│    │    └─Attention: 3-94              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-95                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-96            [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-25            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-97                [1, 16, 4096]             4,096
│    │    └─Attention: 3-98              [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-99                [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-100           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-26            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-101               [1, 16, 4096]             4,096
│    │    └─Attention: 3-102             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-103               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-104           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-27            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-105               [1, 16, 4096]             4,096
│    │    └─Attention: 3-106             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-107               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-108           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-28            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-109               [1, 16, 4096]             4,096
│    │    └─Attention: 3-110             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-111               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-112           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-29            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-113               [1, 16, 4096]             4,096
│    │    └─Attention: 3-114             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-115               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-116           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-30            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-117               [1, 16, 4096]             4,096
│    │    └─Attention: 3-118             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-119               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-120           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-31            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-121               [1, 16, 4096]             4,096
│    │    └─Attention: 3-122             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-123               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-124           [1, 16, 4096]             176,160,768
│    └─TransformerBlock: 2-32            [1, 16, 4096]             --
│    │    └─RMSNorm: 3-125               [1, 16, 4096]             4,096
│    │    └─Attention: 3-126             [1, 16, 4096]             41,943,040
│    │    └─RMSNorm: 3-127               [1, 16, 4096]             4,096
│    │    └─FeedForward: 3-128           [1, 16, 4096]             176,160,768
├─RMSNorm: 1-3                           [1, 16, 4096]             4,096
├─Linear: 1-4                            [1, 16, 128256]           525,336,576
==========================================================================================
Total params: 8,030,261,248
Trainable params: 8,030,261,248
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 8.03
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 113.59
Params size (MB): 16060.52
Estimated Total Size (MB): 16174.11
==========================================================================================
</pre>