
# Let's code Meta's Llama 3 Step-by-step in PyTorch

The purpose of this guide is to illustrate the specific architecture choices implemented in Llama 3, which you will find are very similar to prior versions. Check out the YouTube video where I walk through this Colab notebook and explain everything step by step.

[![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/lZj8F6EspVU/0.jpg)](https://www.youtube.com/watch?v=lZj8F6EspVU)

This notebook guide is designed for beginners; if you already feel confident coding a transformer in PyTorch on your own, then I recommend instead skimming through the model.py file in the [github repo](https://github.com/evintunador/minLlama3). By beginner, I mean someone who understands matrix/tensor multiplication and general deep learning concepts like what a loss function is, and who is capable of looking up the PyTorch documentation for any function they don't recognize, but who maybe isn't well versed in transformers specifically. For an even better beginner's guide that uses an outdated architecture, check out [Andrej Karpathy's video on how to build GPT2](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5014s) and then come back here to learn about the more up-to-date methods that Llama 3 utilizes.

Also, check out the original open-source release of Llama 3 [here](https://github.com/meta-llama/llama3).

If you enjoy this guide, then check out my analogous ones for [Google's Gemma](https://www.youtube.com/watch?v=WW7ZxaC3OtA) and [xAI's Grok](https://www.youtube.com/watch?v=K9Rdc848EBs).

**Note:** It's very easy to convince yourself that you understand something after watching a YouTube video about it, but chances are you don't actually understand it unless you can code it from scratch on your own. I highly recommend you mess around with this notebook and try to build your own minLlama from scratch.

# What this guide does NOT include
The focus here is on architecture rather than optimization techniques, distributed training/inference, quantization, etc. As such, there are many parts of [the original Llama repo](https://github.com/meta-llama/llama3) that will not be included:
- the 15 trillion tokens of high-quality data that Llama 3 was trained on (we'll be using TinyShakespeare instead)
- the original tokenizer: we'll be using a very simple one based on our dataset and won't go over how it works. Check out [Andrej Karpathy's great video on tokenizers](https://youtu.be/zduSFxRajkE?si=Q2uq_nilHhOegbRi) for a better explanation
- the specifics of their training setup that I could not ascertain from their open-sourced inference code (for example: parameter initialization distributions, location of dropout, whether they used regular or chunked attention during training, etc.)
- this guide is focused on training rather than inference; we'll do a quick greedy sampling at the end, but for real sampling with temperature and whatnot, as well as for kv caching during generation, check out the code in section 2
- other stuff I'm probably forgetting

# Table of Contents
1. [Spelled out walkthrough of every single tensor operation](#one)
  
  1a. [Setup stuff](#a)
  
  1b. [Initializing the first residual state](#b)
  
  1c. [Precomputing our RoPE Frequencies](#c)
  
  1d. [Precomputing the Causal Mask](#d)
  
  1e. [Our First Normalization](#e)
  
  1f. [Initializing Multi-Query Attention](#f)

  1g. [Rotary Position Embeddings](#g)

  1h. [Calculating Self-Attention](#h)

  1i. [Our first residual connection](#i)

  1j. [The SwiGLU Feedforward Network](#j)

  1k. [Output](#k)

  1l. [The Loss Functions](#l)

2. [Actually functional model code](#two)

  2a. [parameters](#twoa)

  2b. [RMSNorm](#twob)

  2c. [RoPE](#twoc)

  2d. [Attention](#twod)

  2e. [Ffwd](#twoe)

  2f. [Residual Layers](#twof)

  2g. [The model itself](#twog)

3. [Train and test your own minLlama3 (or load mine)](#three)

  3a. [Setup](#threea)

  3b. [Training your own](#threeb)

  3c. [Alternatively, you can load the 2m parameter model I already trained](#threec)

  3d. [Testing (performing inference)](#threed)

# 1. Spelled out walkthrough of every single tensor operation
<a id='one'></a>
In this section we'll walk through every important operation that Llama 3's architecture carries out, using laughably small tensors. We've chosen tensors this small so that, if you want to, you can literally pull out a calculator to make 100% sure you understand what's happening. We'll begin with basic imports and whatnot.

### 1a. Setup stuff
<a id='a'></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
# vocabulary length. Llama's real vocab size is 128256. Here let's just use an absurdly small number
v = 10

# Llama's maximum sequence length is 8192, but for inference they cache 3/4 of it and only use an effective length of 2048. more on that later
seq_len = 5

# we'll use a batch size of 1 for simplicity when visualizing our tensors
b = 1

# now let's make ourselves a list of token indices. Each represents somewhere between a letter and a word
tokens = torch.randint(v, (b, seq_len))
tokens.shape, tokens

(torch.Size([1, 5]), tensor([[8, 4, 7, 2, 9]]))

### 1b. Initializing the first residual state
<a id='b'></a>

In [None]:
# our embedding dimension. Llama 3 8b's is 4096
d = 16

# initializing our token embedding matrix
embedding = nn.Embedding(v, d)
embedding.weight.shape, embedding.weight
# each row in this embedding is a high dimensional representation of its corresponding token

(torch.Size([10, 16]),
 Parameter containing:
 tensor([[ 9.3121e-01, -9.6560e-02, -5.7902e-02, -9.5194e-01,  1.3470e+00,
          -1.0196e+00,  2.1526e-01,  4.3648e-01,  7.3091e-01,  1.0472e+00,
          -8.3610e-01, -1.1353e+00,  1.0956e+00,  1.3382e+00,  8.4271e-02,
           2.0557e+00],
         [-2.2929e+00, -1.2869e+00, -9.3078e-01,  9.4351e-02, -1.0933e+00,
          -1.4797e+00, -1.5054e-01, -4.6406e-01,  8.2742e-01,  8.3517e-01,
          -4.6111e-01,  2.7773e-01, -4.3511e-01,  2.6676e+00,  1.8905e+00,
          -8.3954e-01],
         [ 5.7819e-02, -6.4342e-01,  5.1670e-02,  1.0794e+00,  5.1993e-01,
          -8.9062e-01, -1.3466e-01, -1.2324e-01, -3.0727e-01,  8.0144e-02,
           1.0582e+00,  7.0094e-01,  6.6939e-01,  1.5953e+00,  7.9977e-01,
           1.3172e+00],
         [ 1.8516e+00,  3.7621e-01, -1.1883e-01,  6.9927e-01,  8.1808e-01,
          -1.5548e+00, -8.7562e-02,  1.4753e-01, -5.7552e-01, -4.3554e-01,
          -2.8826e-01, -2.9157e-02, -1.1238e+00,  2.4546e

In [None]:
# grabbing the embeddings that correspond to our sequence of token indices
x = embedding(tokens)
x.shape, x
# at this point many models would multiply the embeddings by the square root of the embedding dimension, but Llama 3 forgoes that strategy

(torch.Size([1, 5, 16]),
 tensor([[[ 1.2460,  0.7366, -1.0649, -0.0032, -0.4801,  0.2669, -2.2108,
           -1.0922, -1.4346,  0.5107, -0.8594,  0.1876, -0.4290, -0.7084,
            0.5554,  1.3216],
          [-0.3168,  0.7286,  1.9062, -0.1316,  1.3459,  0.2098,  1.0406,
            0.0434, -1.1801,  0.1451,  1.1746, -0.0457,  1.1738, -0.2251,
           -1.0406,  0.6283],
          [ 1.1500, -0.2625,  0.3006,  0.3029,  0.6783,  1.1275,  0.0193,
            1.2530, -0.0370,  0.2663, -0.3409, -0.1213, -0.6990,  0.6369,
            0.2205, -0.7866],
          [ 0.0578, -0.6434,  0.0517,  1.0794,  0.5199, -0.8906, -0.1347,
           -0.1232, -0.3073,  0.0801,  1.0582,  0.7009,  0.6694,  1.5953,
            0.7998,  1.3172],
          [ 0.3744, -1.0260,  0.2851,  1.4560, -0.1721,  0.3423, -0.1426,
            0.4811,  0.9600, -2.6945, -0.1461,  0.4694,  1.5548,  0.4641,
            0.6883, -0.0936]]], grad_fn=<EmbeddingBackward0>))
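As the comment above puts it, each row of the embedding matrix represents one token. Here's a minimal standalone sketch (its own toy tensors, not the ones above) showing that `nn.Embedding` is really just row lookup, which is the same thing as multiplying a one-hot matrix by the weight matrix:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
v, d = 10, 16                              # same toy sizes as above
embedding = nn.Embedding(v, d)
tokens = torch.tensor([[8, 4, 7, 2, 9]])   # shape (b=1, seq_len=5)

looked_up = embedding(tokens)              # (1, 5, 16)
indexed = embedding.weight[tokens]         # plain row indexing into the (10, 16) weight

# the same thing written as a matmul: one-hot rows pick out rows of the weight matrix
one_hot = nn.functional.one_hot(tokens, num_classes=v).float()  # (1, 5, 10)
via_matmul = one_hot @ embedding.weight                         # (1, 5, 10) @ (10, 16) -> (1, 5, 16)

print(torch.equal(looked_up, indexed))
print(torch.allclose(looked_up, via_matmul))
```

The lookup form is what everyone uses in practice, since it skips multiplying by all those zeros.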

### 1c. Precomputing our RoPE Frequencies
<a id='c'></a>

Rotary Position Embedding (RoPE) is a method [originally proposed in 2021](https://arxiv.org/abs/2104.09864) that quickly became the de facto standard for enabling transformers to understand positional information (by default the attention mechanism is blind to the ordering of tokens). The method uses trigonometry to "rotate" the entries of the query and key matrices, by an angle proportional to each token's position, before they are multiplied together. Since the resulting attention score only depends on the *difference* between a query's rotation and a key's rotation, a small difference indicates that two tokens are close together, while a large difference corresponds to being far apart. I'm going to skim over this topic, so for a better conceptual explanation, I recommend checking out [this video](https://www.youtube.com/watch?v=GQPOtyITy54).
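To make the rotation intuition concrete, here's a minimal sketch using a hypothetical helper `rope_rotate` that I wrote just for illustration (it follows the same complex-number approach as the cells below, but it isn't Llama's code). It rotates one query and one key to different positions and checks that their dot product depends only on how far apart they are:

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: rotate a single head vector x (even length) to position `pos`."""
    head_dim = x.shape[-1]
    # one frequency per pair of dimensions, same formula as the precompute cell below
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = pos * freqs                                    # rotation angle for each pair
    rot = torch.polar(torch.ones_like(angles), angles)      # unit complex numbers e^{i * angle}
    x_c = torch.view_as_complex(x.float().reshape(-1, 2))   # treat adjacent entries as (real, imag)
    return torch.view_as_real(x_c * rot).flatten()          # rotate, then back to a real vector

torch.manual_seed(0)
q, k = torch.randn(4), torch.randn(4)

# query at position 3, key at position 1 (distance 2)
score_a = rope_rotate(q, 3) @ rope_rotate(k, 1)
# query at position 7, key at position 5 (also distance 2)
score_b = rope_rotate(q, 7) @ rope_rotate(k, 5)
print(score_a, score_b)  # equal up to float rounding: only the offset matters
```

Rotations also preserve vector length, so RoPE changes *where* a query or key points without changing how big it is.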

"Precompute" means we're going to calculate the frequencies ahead of time so that they can be sent through and reused throughout the model as opposed to creating them from scratch every time we need them.

In [None]:
theta = 10000 # 10,000 is the most common value, but Llama 3 uses 500,000. In theory, smaller models with shorter contexts can get away with a smaller value
num_heads = 4 # Llama 3 8b has 32 total attention heads
head_dim = d // num_heads # Llama 3 ties its head dimension to the embedding dimension: 4096 // 32 = 128 in Llama 3 8b

# go watch the video to get a better explanation of what's happening here
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
print(f'freqs: {freqs.shape}\n{freqs}\n')

t = torch.arange(seq_len * 2, device=freqs.device, dtype=torch.float32)
print(f't: {t.shape}\n{t}\n')

freqs = torch.outer(t, freqs)
print(f'freqs: {freqs.shape}\n{freqs}\n')

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)[:seq_len]  # complex64
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

freqs: torch.Size([2])
tensor([1.0000, 0.0100])

t: torch.Size([10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

freqs: torch.Size([10, 2])
tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200],
        [3.0000, 0.0300],
        [4.0000, 0.0400],
        [5.0000, 0.0500],
        [6.0000, 0.0600],
        [7.0000, 0.0700],
        [8.0000, 0.0800],
        [9.0000, 0.0900]])

freqs_cis: torch.Size([5, 2])
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j],
        [-0.9900+0.1411j,  0.9996+0.0300j],
        [-0.6536-0.7568j,  0.9992+0.0400j]])


### 1d. Precomputing the Causal Mask
<a id='d'></a>

Similar to the RoPE frequencies, the causal mask is another part of the attention mechanism that we can create ahead of time and then reuse in every layer.

The basic idea of a causal mask is that, by default, attention mechanisms allow every single token to pay attention to every single other token. This is okay or even preferable for some model types, but Llama is autoregressive, meaning it would be bad if a token being predicted were able to see itself and future tokens during training when it can't see them during inference. The negative infinities in the upper triangle prevent the model from attending to the corresponding tokens; how this works will become clearer later when we do the attention softmax.
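Here's a tiny standalone check of how the mask behaves once it hits the softmax, using random stand-in logits rather than real attention scores:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = torch.randn(seq_len, seq_len)   # stand-in attention logits
probs = F.softmax(scores + mask, dim=-1)

# exp(-inf) == 0, so every entry above the diagonal becomes exactly 0,
# while each row still sums to 1 over the positions it is allowed to see
print(probs)
```

The first row can only see token 0, so all of its probability mass lands there.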

In [None]:
mask = torch.full(
    (seq_len, seq_len),
    float("-inf")
)
mask = torch.triu(mask, diagonal=1)
mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

### 1e. Our First Normalization
<a id='e'></a>

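Root Mean Square Normalization (RMSNorm) has also been the norm for quite a while. It computes $\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g$, where $g$ is a learnable scale. Like its predecessor LayerNorm, RMSNorm restricts the variability of the entries in each embedding vector such that (before the learnable scale is applied, and ignoring $\epsilon$) the vector lies on a hypersphere with radius $\sqrt{d}$. However, unlike LayerNorm, which also subtracts the mean so that every vector ends up with a mean of zero, RMSNorm does not mess with the mean, which is an important source of data for networks that utilize residual connections.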
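A quick numerical sanity check of those claims, on random vectors deliberately shifted to have a nonzero mean. `F.layer_norm` is called without its learnable weight and bias so that the two normalizations are compared on equal footing:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
x = torch.randn(3, d) + 2.0   # 3 vectors whose entries have a clearly nonzero mean

# RMSNorm without the learnable scale: divide each vector by its root-mean-square
rms_normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
# LayerNorm without the learnable affine params: subtract the mean, then divide by the std
layer_normed = F.layer_norm(x, (d,))

print("target radius:", math.sqrt(d))
print(rms_normed.norm(dim=-1))    # each close to sqrt(16) = 4
print(rms_normed.mean(dim=-1))    # means are NOT forced to zero
print(layer_normed.mean(dim=-1))  # close to 0 for every row
print(layer_normed.norm(dim=-1))  # also close to 4, but every vector is now zero-mean
```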

In [None]:
# first let's set up the residual connection that we'll use later
h = x
print(f'h: {h.shape}\n{h}')

h: torch.Size([1, 5, 16])
tensor([[[ 1.2460,  0.7366, -1.0649, -0.0032, -0.4801,  0.2669, -2.2108,
          -1.0922, -1.4346,  0.5107, -0.8594,  0.1876, -0.4290, -0.7084,
           0.5554,  1.3216],
         [-0.3168,  0.7286,  1.9062, -0.1316,  1.3459,  0.2098,  1.0406,
           0.0434, -1.1801,  0.1451,  1.1746, -0.0457,  1.1738, -0.2251,
          -1.0406,  0.6283],
         [ 1.1500, -0.2625,  0.3006,  0.3029,  0.6783,  1.1275,  0.0193,
           1.2530, -0.0370,  0.2663, -0.3409, -0.1213, -0.6990,  0.6369,
           0.2205, -0.7866],
         [ 0.0578, -0.6434,  0.0517,  1.0794,  0.5199, -0.8906, -0.1347,
          -0.1232, -0.3073,  0.0801,  1.0582,  0.7009,  0.6694,  1.5953,
           0.7998,  1.3172],
         [ 0.3744, -1.0260,  0.2851,  1.4560, -0.1721,  0.3423, -0.1426,
           0.4811,  0.9600, -2.6945, -0.1461,  0.4694,  1.5548,  0.4641,
           0.6883, -0.0936]]], grad_fn=<EmbeddingBackward0>)


In [None]:
# now we'll perform our first normalization
# first we square each entry in x and then take the mean of those values across each embedding vector
mean_squared = x.pow(2).mean(dim=-1, keepdim=True)
mean_squared

tensor([[[0.9653],
         [0.8077],
         [0.4150],
         [0.6101],
         [0.9582]]], grad_fn=<MeanBackward1>)

In [None]:
# then we multiply x by the reciprocal of the square roots of mean_squared
# 1e-6 is a very small number added for stability, just in case the mean of the squares happens to be 0 (since you can't divide by 0)
x_normed = x * torch.rsqrt(mean_squared + 1e-6)
print(f'x_normed: {x_normed.shape}\n{x_normed}')

x_normed: torch.Size([1, 5, 16])
tensor([[[ 1.2682,  0.7497, -1.0838, -0.0033, -0.4886,  0.2717, -2.2502,
          -1.1116, -1.4601,  0.5198, -0.8747,  0.1909, -0.4366, -0.7210,
           0.5653,  1.3452],
         [-0.3525,  0.8107,  2.1209, -0.1464,  1.4975,  0.2334,  1.1578,
           0.0483, -1.3131,  0.1614,  1.3070, -0.0508,  1.3060, -0.2504,
          -1.1579,  0.6991],
         [ 1.7851, -0.4075,  0.4667,  0.4703,  1.0530,  1.7502,  0.0299,
           1.9450, -0.0575,  0.4134, -0.5291, -0.1882, -1.0851,  0.9886,
           0.3423, -1.2210],
         [ 0.0740, -0.8237,  0.0662,  1.3819,  0.6656, -1.1402, -0.1724,
          -0.1578, -0.3934,  0.1026,  1.3548,  0.8974,  0.8570,  2.0424,
           1.0239,  1.6863],
         [ 0.3825, -1.0481,  0.2912,  1.4874, -0.1758,  0.3496, -0.1457,
           0.4914,  0.9807, -2.7526, -0.1492,  0.4796,  1.5883,  0.4741,
           0.7032, -0.0956]]], grad_fn=<MulBackward0>)


In [None]:
# and finally, we multiply by a learnable scale parameter
# This scale is initialized to 1's but if we were to train then those values would change
rms_scale = torch.ones(d)
print(f'rms_scale: {rms_scale.shape}\n{rms_scale}\n')

x_normed *= rms_scale
print(f'x_normed: {x_normed.shape}\n{x_normed}')

rms_scale: torch.Size([16])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

x_normed: torch.Size([1, 5, 16])
tensor([[[ 1.2682,  0.7497, -1.0838, -0.0033, -0.4886,  0.2717, -2.2502,
          -1.1116, -1.4601,  0.5198, -0.8747,  0.1909, -0.4366, -0.7210,
           0.5653,  1.3452],
         [-0.3525,  0.8107,  2.1209, -0.1464,  1.4975,  0.2334,  1.1578,
           0.0483, -1.3131,  0.1614,  1.3070, -0.0508,  1.3060, -0.2504,
          -1.1579,  0.6991],
         [ 1.7851, -0.4075,  0.4667,  0.4703,  1.0530,  1.7502,  0.0299,
           1.9450, -0.0575,  0.4134, -0.5291, -0.1882, -1.0851,  0.9886,
           0.3423, -1.2210],
         [ 0.0740, -0.8237,  0.0662,  1.3819,  0.6656, -1.1402, -0.1724,
          -0.1578, -0.3934,  0.1026,  1.3548,  0.8974,  0.8570,  2.0424,
           1.0239,  1.6863],
         [ 0.3825, -1.0481,  0.2912,  1.4874, -0.1758,  0.3496, -0.1457,
           0.4914,  0.9807, -2.7526, -0.1492,  0.4796,  1.5883,  0.4741,
           0.7032, 

In [None]:
# let's turn that RMSNorm into a module that we'll be able to reuse repeatedly later
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

### 1f. Initializing Multi-Query Attention
<a id='f'></a>
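[Multi-query attention](https://arxiv.org/abs/1911.02150) (MQA) is the de facto standard for saving on parameters, and especially on key/value cache memory during inference, which frees up budget for a bigger model. The idea is that the model can make multiple queries to the residual state and have those many queries be answered by shared keys & values. Strictly speaking, Llama 3 uses grouped-query attention (GQA), the generalization in which each *group* of query heads shares one key/value head rather than all query heads sharing a single one; that is what the `num_kv_heads` setting below controls.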
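To get a feel for the savings, here's some back-of-the-envelope arithmetic with a hypothetical helper `attn_proj_params`, plugging in the Llama 3 8B sizes mentioned in this notebook (embedding dimension 4096, 32 query heads, 8 key/value heads). It counts only the bias-free q/k/v/o projection weights of one layer, plus how many numbers per token each layer has to keep in the KV cache:

```python
def attn_proj_params(d_model: int, n_heads: int, n_kv_heads: int) -> dict:
    """Hypothetical helper: weight counts of the bias-free q/k/v/o projections in one layer."""
    head_dim = d_model // n_heads
    return {
        "wq": d_model * n_heads * head_dim,
        "wk": d_model * n_kv_heads * head_dim,
        "wv": d_model * n_kv_heads * head_dim,
        "wo": n_heads * head_dim * d_model,
    }

# 32 query heads with 32, 8, or 1 key/value heads
for name, n_kv in [("multi-head (MHA)", 32), ("grouped-query (GQA)", 8), ("multi-query (MQA)", 1)]:
    p = attn_proj_params(4096, 32, n_kv)
    kv_cache_per_token = 2 * n_kv * (4096 // 32)   # one K and one V vector per kv head
    print(f"{name:22s} params/layer = {sum(p.values()):,}  kv-cache numbers/token/layer = {kv_cache_per_token}")

# our toy setup: d=16, 4 query heads, 2 kv heads -> wk has 8 x 16 = 128 weights
print(attn_proj_params(16, 4, 2))
```

Going from 32 to 8 kv heads cuts the KV cache by 4x, while the parameter savings are more modest because wq and wo are untouched.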

In [None]:
# first up, remember we're currently working with two separate objects
# h is for the residual connection and x_normed will go into our attention calculation
h, x_normed

(tensor([[[ 1.2460,  0.7366, -1.0649, -0.0032, -0.4801,  0.2669, -2.2108,
           -1.0922, -1.4346,  0.5107, -0.8594,  0.1876, -0.4290, -0.7084,
            0.5554,  1.3216],
          [-0.3168,  0.7286,  1.9062, -0.1316,  1.3459,  0.2098,  1.0406,
            0.0434, -1.1801,  0.1451,  1.1746, -0.0457,  1.1738, -0.2251,
           -1.0406,  0.6283],
          [ 1.1500, -0.2625,  0.3006,  0.3029,  0.6783,  1.1275,  0.0193,
            1.2530, -0.0370,  0.2663, -0.3409, -0.1213, -0.6990,  0.6369,
            0.2205, -0.7866],
          [ 0.0578, -0.6434,  0.0517,  1.0794,  0.5199, -0.8906, -0.1347,
           -0.1232, -0.3073,  0.0801,  1.0582,  0.7009,  0.6694,  1.5953,
            0.7998,  1.3172],
          [ 0.3744, -1.0260,  0.2851,  1.4560, -0.1721,  0.3423, -0.1426,
            0.4811,  0.9600, -2.6945, -0.1461,  0.4694,  1.5548,  0.4641,
            0.6883, -0.0936]]], grad_fn=<EmbeddingBackward0>),
 tensor([[[ 1.2682,  0.7497, -1.0838, -0.0033, -0.4886,  0.2717, -2.2502,
   

In [None]:
# let's define the hyperparameters of MQA
num_kv_heads = 2 # Llama uses 8 key and value heads per layer
assert num_heads % num_kv_heads == 0 # each q needs to match up to a kv
print(f"as a reminder: num_heads = {num_heads}, head_dim = {head_dim}")

as a reminder: num_heads = 4, head_dim = 4


In [None]:
# now we'll initialize our self-attention weight matrices
wq = nn.Linear(d, num_heads * head_dim, bias=False)
wk = nn.Linear(d, num_kv_heads * head_dim, bias=False)
wv = nn.Linear(d, num_kv_heads * head_dim, bias=False)
print("Attention weights: ", wq.weight.shape, wk.weight.shape, wv.weight.shape)

# and project x_normed out to get our queries, keys and values
xq = wq(x_normed)
xk = wk(x_normed)
xv = wv(x_normed)
print("Attention projections: ", xq.shape, xk.shape, xv.shape)

# then reshape them to separate out by head
xq = xq.view(b, seq_len, num_heads, head_dim)
xk = xk.view(b, seq_len, num_kv_heads, head_dim)
xv = xv.view(b, seq_len, num_kv_heads, head_dim)
print("Reshaped: ", xq.shape, xk.shape, xv.shape)

Attention weights:  torch.Size([16, 16]) torch.Size([8, 16]) torch.Size([8, 16])
Attention projections:  torch.Size([1, 5, 16]) torch.Size([1, 5, 8]) torch.Size([1, 5, 8])
Reshaped:  torch.Size([1, 5, 4, 4]) torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 2, 4])


### 1g. Rotary Position Embeddings
<a id='g'></a>

Earlier we pre-computed the frequencies for rotation. Now we'll actually apply our rotary embeddings.
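If treating pairs of numbers as complex values feels like magic, this small sketch shows that multiplying by $e^{i\phi}$ is exactly the same as applying an ordinary 2x2 rotation matrix to each (even, odd) pair of entries:

```python
import math
import torch

angle = 0.5                       # an arbitrary rotation angle in radians
pair = torch.tensor([0.3, -1.2])  # one (even, odd) pair of entries from a query/key vector

# route 1: the complex-number trick used in this notebook
as_complex = torch.view_as_complex(pair.reshape(1, 2))           # 0.3 - 1.2j
rotor = torch.polar(torch.tensor([1.0]), torch.tensor([angle]))  # e^{i * angle}
rotated_complex = torch.view_as_real(as_complex * rotor).flatten()

# route 2: an explicit 2x2 rotation matrix
R = torch.tensor([[math.cos(angle), -math.sin(angle)],
                  [math.sin(angle),  math.cos(angle)]])
rotated_matrix = R @ pair

print(rotated_complex, rotated_matrix)  # the same vector, up to float rounding
```

The complex version is just a compact way to apply a different rotation to every pair at once without building any matrices.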

In [None]:
# first we reshape and then view our queries and keys as complex values, the type of number that works well with rotation
xq = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 2])
tensor([[[[ 0.2170-0.2767j, -0.1258-0.7064j],
          [ 0.0471+0.0557j, -0.1916-0.3776j],
          [-0.2885+0.0269j, -0.1583-0.1648j],
          [ 0.1905-0.0708j,  0.4597+0.4966j]],

         [[ 0.8161-0.8226j,  1.1365-0.0757j],
          [ 0.3695-0.7906j,  1.3074-0.2812j],
          [ 0.6162-0.1722j, -0.4068+1.3610j],
          [-0.9063+0.0619j, -0.3543-0.0561j]],

         [[ 0.1515-0.3930j, -0.2620-0.2216j],
          [ 0.2307+0.0930j, -0.1416-1.3326j],
          [ 0.0783+0.3804j,  0.2594+0.7840j],
          [ 0.5426-0.0978j,  0.9323-0.1451j]],

         [[ 0.9782-0.7228j, -0.8617-0.5507j],
          [ 0.4022+0.1999j, -0.1467+0.4029j],
          [ 0.0165+0.3566j, -0.0704+0.1177j],
          [-1.3812-0.7442j, -0.1762+0.4091j]],

         [[-0.1869+0.1524j,  0.5146-0.1804j],
          [ 0.4554+0.4472j, -0.9757-0.4605j],
          [-0.2682+0.3110j, -0.1756+0.3017j],
          [ 1.0174-0.2027j, -0.0683+0.3217j]]]],
       grad_fn=<ViewAsComplexBackward0>)

In [None]:
ndim = xq.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (xq.shape[1], xq.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != xq.shape[1], xq.shape[-1] {(xq.shape[1], xq.shape[-1])}'

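# reshape freqs_cis from (seq_len, head_dim // 2) to (1, seq_len, 1, head_dim // 2) so it broadcasts against our queries and keys
shape = [dim if i == 1 or i == xq.ndim - 1 else 1 for i, dim in enumerate(xq.shape)]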
print(f'shape: {shape}\n')

freqs_cis = freqs_cis.view(*shape)
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

shape: [1, 5, 1, 2]

freqs_cis: torch.Size([1, 5, 1, 2])
tensor([[[[ 1.0000+0.0000j,  1.0000+0.0000j]],

         [[ 0.5403+0.8415j,  0.9999+0.0100j]],

         [[-0.4161+0.9093j,  0.9998+0.0200j]],

         [[-0.9900+0.1411j,  0.9996+0.0300j]],

         [[-0.6536-0.7568j,  0.9992+0.0400j]]]])


In [None]:
# now multiply the data by the frequencies, turn them back into real numbers, revert the shape and make sure they're of the right type
xq = torch.view_as_real(xq * freqs_cis).flatten(3).type_as(xv)
xk = torch.view_as_real(xk * freqs_cis).flatten(3).type_as(xv)
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 4])
tensor([[[[ 2.1695e-01, -2.7675e-01, -1.2576e-01, -7.0642e-01],
          [ 4.7134e-02,  5.5681e-02, -1.9157e-01, -3.7763e-01],
          [-2.8846e-01,  2.6856e-02, -1.5825e-01, -1.6482e-01],
          [ 1.9047e-01, -7.0766e-02,  4.5974e-01,  4.9656e-01]],

         [[ 1.1331e+00,  2.4227e-01,  1.1372e+00, -6.4345e-02],
          [ 8.6493e-01, -1.1629e-01,  1.3101e+00, -2.6810e-01],
          [ 4.7785e-01,  4.2547e-01, -4.2034e-01,  1.3569e+00],
          [-5.4173e-01, -7.2918e-01, -3.5373e-01, -5.9591e-02]],

         [[ 2.9432e-01,  3.0135e-01, -2.5751e-01, -2.2682e-01],
          [-1.8063e-01,  1.7110e-01, -1.1491e-01, -1.3351e+00],
          [-3.7844e-01, -8.7070e-02,  2.4368e-01,  7.8901e-01],
          [-1.3687e-01,  5.3407e-01,  9.3506e-01, -1.2645e-01]],

         [[-8.6646e-01,  8.5359e-01, -8.4484e-01, -5.7628e-01],
          [-4.2638e-01, -1.4113e-01, -1.5868e-01,  3.9828e-01],
          [-6.6629e-02, -3.5071e-01, -7.3944e-02,  1.1552e-01],
     

### 1h. Calculating Self-Attention
<a id='h'></a>
Now we get to perform the actual attention calculation.
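Before diving in, it may help to know where we're headed. The sketch below (standalone random tensors with the same toy shapes) runs the same recipe we're about to walk through, namely scores, scaling, causal mask, softmax, then a weighted sum of values, and compares it against PyTorch's fused `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later) with `is_causal=True`:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n_heads, seq_len, head_dim = 1, 4, 5, 4   # same toy sizes as this notebook
q = torch.randn(b, n_heads, seq_len, head_dim)
k = torch.randn(b, n_heads, seq_len, head_dim)
v = torch.randn(b, n_heads, seq_len, head_dim)

# manual version: scores -> scale -> causal mask -> softmax -> weighted sum of values
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = q @ k.transpose(2, 3) / math.sqrt(head_dim)
manual = F.softmax(scores + mask, dim=-1) @ v

# PyTorch's fused version; its default scale is also 1 / sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

In a real model you'd normally call the fused version for speed and memory, but spelling it out by hand is the whole point of this section.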

In [None]:
# If the number of K & V heads differs from the number of query heads, repeat the keys and values so that they match the query head count
if num_kv_heads != num_heads:
    num_queries_per_kv = num_heads // num_kv_heads
    xk = torch.repeat_interleave(xk, num_queries_per_kv, dim=2)
    xv = torch.repeat_interleave(xv, num_queries_per_kv, dim=2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]))
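Why `repeat_interleave` and not `repeat`? The order matters, because it decides which query heads share which key/value head. This small sketch labels each kv head with its index to make the grouping visible. As far as I can tell, the interleaved ordering matches what Llama's reference code does, and it's also why query heads 0 and 1 (and heads 2 and 3) produce identical outputs at the first position further down: that position attends only to itself with probability 1, and those heads share the same values.

```python
import torch

num_heads, num_kv_heads = 4, 2
num_queries_per_kv = num_heads // num_kv_heads

# label each kv head by its index so we can see where it ends up; shape (b, seq_len, kv_heads, head_dim)
kv = torch.arange(num_kv_heads).view(1, 1, num_kv_heads, 1).float()

interleaved = torch.repeat_interleave(kv, num_queries_per_kv, dim=2)  # each head repeated in place
tiled = kv.repeat(1, 1, num_queries_per_kv, 1)                        # the whole set of heads tiled

print(interleaved.flatten().tolist())  # query heads 0,1 -> kv head 0; query heads 2,3 -> kv head 1
print(tiled.flatten().tolist())        # a different pairing of query heads with kv heads
```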

In [None]:
# Transpose Q, K and V to shape (b, num_heads, seq_len, head_dim) so that the batched matrix multiplication in the attention calculation runs over each head independently
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]))

In [None]:
# Calculates attention logits by performing a batch matrix multiplication between queries and keys
scores = torch.matmul(xq, xk.transpose(2, 3))

# then we scale the logits by the reciprocal of the square root of the head dimension
scores = scores / math.sqrt(head_dim)

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[ 0.0349,  0.2964,  0.0679, -0.1770, -0.0546],
           [-0.8843,  0.5276,  0.8201,  0.0848,  0.8594],
           [-0.0630,  0.3148, -0.0271, -0.1148, -0.0345],
           [ 0.4231, -0.1086, -0.8209, -0.2176, -0.6920],
           [-0.2993,  0.1005,  0.2658,  0.0315,  0.2896]],
 
          [[ 0.0450,  0.1652, -0.0704, -0.1144, -0.1031],
           [-0.7876,  0.3469,  0.8225,  0.0948,  0.8243],
           [ 0.0512,  0.2826, -0.1944, -0.2927, -0.2320],
           [ 0.2396, -0.3472, -0.1929,  0.0907, -0.1715],
           [ 0.5120,  0.2293, -0.2228, -0.2435, -0.4588]],
 
          [[ 0.0634, -0.1494,  0.1283,  0.0430,  0.0538],
           [ 0.2279,  0.0095,  0.1995, -0.0229, -0.0418],
           [ 0.1597,  0.1338, -0.1130, -0.0481, -0.2681],
           [ 0.0853, -0.1153,  0.0871, -0.1732, -0.1560],
           [ 0.0252, -0.0265,  0.0860, -0.0907, -0.0127]],
 
          [[-0.0592,  0.3350, -0.3119, -0.0533, -0.1957],
           [ 0.2513, -0.4484,  0.349

In [None]:
# now we get to use the mask that we precomputed earlier
scores = scores + mask

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[ 0.0349,    -inf,    -inf,    -inf,    -inf],
           [-0.8843,  0.5276,    -inf,    -inf,    -inf],
           [-0.0630,  0.3148, -0.0271,    -inf,    -inf],
           [ 0.4231, -0.1086, -0.8209, -0.2176,    -inf],
           [-0.2993,  0.1005,  0.2658,  0.0315,  0.2896]],
 
          [[ 0.0450,    -inf,    -inf,    -inf,    -inf],
           [-0.7876,  0.3469,    -inf,    -inf,    -inf],
           [ 0.0512,  0.2826, -0.1944,    -inf,    -inf],
           [ 0.2396, -0.3472, -0.1929,  0.0907,    -inf],
           [ 0.5120,  0.2293, -0.2228, -0.2435, -0.4588]],
 
          [[ 0.0634,    -inf,    -inf,    -inf,    -inf],
           [ 0.2279,  0.0095,    -inf,    -inf,    -inf],
           [ 0.1597,  0.1338, -0.1130,    -inf,    -inf],
           [ 0.0853, -0.1153,  0.0871, -0.1732,    -inf],
           [ 0.0252, -0.0265,  0.0860, -0.0907, -0.0127]],
 
          [[-0.0592,    -inf,    -inf,    -inf,    -inf],
           [ 0.2513, -0.4484,    -in

In [None]:
# now we perform the softmax operation to get our actual probabilities
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores
# notice that thanks to the causal mask, 0 probability is placed on future tokens

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.1959, 0.8041, 0.0000, 0.0000, 0.0000],
          [0.2861, 0.4174, 0.2965, 0.0000, 0.0000],
          [0.4162, 0.2446, 0.1200, 0.2193, 0.0000],
          [0.1343, 0.2003, 0.2363, 0.1870, 0.2420]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2433, 0.7567, 0.0000, 0.0000, 0.0000],
          [0.3287, 0.4143, 0.2571, 0.0000, 0.0000],
          [0.3261, 0.1813, 0.2116, 0.2810, 0.0000],
          [0.3245, 0.2446, 0.1556, 0.1524, 0.1229]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5544, 0.4456, 0.0000, 0.0000, 0.0000],
          [0.3655, 0.3562, 0.2783, 0.0000, 0.0000],
          [0.2784, 0.2278, 0.2789, 0.2150, 0.0000],
          [0.2055, 0.1952, 0.2184, 0.1830, 0.1979]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6681, 0.3319, 0.0000, 0.0000, 0.0000],
          [0.2358, 0.5998, 0.1645, 0.0000, 0.0000],
          [0.1996, 0.3163, 0.2303, 0.2538, 0.0000],
      

In [None]:
# then matmul by our values projection
output = torch.matmul(scores, xv)
output.shape, output

(torch.Size([1, 4, 5, 4]),
 tensor([[[[ 0.8224,  0.4530,  0.5962,  0.6171],
           [-0.0093,  0.1717,  0.1173,  0.5434],
           [ 0.1605,  0.0397,  0.2777,  0.3978],
           [ 0.3163,  0.1755,  0.4137,  0.3426],
           [ 0.0670,  0.1109,  0.1758,  0.1951]],
 
          [[ 0.8224,  0.4530,  0.5962,  0.6171],
           [ 0.0397,  0.1883,  0.1455,  0.5478],
           [ 0.1944,  0.0763,  0.2888,  0.4222],
           [ 0.2656,  0.0915,  0.4273,  0.2422],
           [ 0.2215,  0.1752,  0.2870,  0.3211]],
 
          [[ 0.7174,  0.0402,  0.2111, -0.7315],
           [ 0.7096, -0.2251,  0.1030, -0.1930],
           [ 0.4068, -0.1231, -0.0190, -0.2104],
           [ 0.1971, -0.0743,  0.1910, -0.2105],
           [ 0.1888, -0.0599,  0.0187, -0.0097]],
 
          [[ 0.7174,  0.0402,  0.2111, -0.7315],
           [ 0.7116, -0.1574,  0.1306, -0.3304],
           [ 0.5270, -0.2881, -0.0193,  0.0469],
           [ 0.2104, -0.1405,  0.2270, -0.0914],
           [ 0.1992, -0.0298,  0.

In [None]:
# and reshape to put the sequence length back into place and the outputs of our heads lined up
output = output.transpose(1, 2).contiguous().view(b, seq_len, -1)
output.shape, output

(torch.Size([1, 5, 16]),
 tensor([[[ 0.8224,  0.4530,  0.5962,  0.6171,  0.8224,  0.4530,  0.5962,
            0.6171,  0.7174,  0.0402,  0.2111, -0.7315,  0.7174,  0.0402,
            0.2111, -0.7315],
          [-0.0093,  0.1717,  0.1173,  0.5434,  0.0397,  0.1883,  0.1455,
            0.5478,  0.7096, -0.2251,  0.1030, -0.1930,  0.7116, -0.1574,
            0.1306, -0.3304],
          [ 0.1605,  0.0397,  0.2777,  0.3978,  0.1944,  0.0763,  0.2888,
            0.4222,  0.4068, -0.1231, -0.0190, -0.2104,  0.5270, -0.2881,
           -0.0193,  0.0469],
          [ 0.3163,  0.1755,  0.4137,  0.3426,  0.2656,  0.0915,  0.4273,
            0.2422,  0.1971, -0.0743,  0.1910, -0.2105,  0.2104, -0.1405,
            0.2270, -0.0914],
          [ 0.0670,  0.1109,  0.1758,  0.1951,  0.2215,  0.1752,  0.2870,
            0.3211,  0.1888, -0.0599,  0.0187, -0.0097,  0.1992, -0.0298,
            0.0423, -0.1315]]], grad_fn=<ViewBackward0>))

In [None]:
# finally we can initialize and apply our output projection that mixes the information from the heads together
wo = nn.Linear(num_heads * head_dim, d, bias=False)
Xout = wo(output)
Xout.shape, Xout

(torch.Size([1, 5, 16]),
 tensor([[[ 7.1982e-01, -4.1950e-01,  4.7529e-02,  1.0536e-01, -4.7111e-01,
           -1.5629e-01, -2.7153e-02,  2.5266e-01, -1.9366e-01, -2.6828e-01,
           -1.8550e-02,  3.5433e-01,  4.0302e-01, -2.3292e-01, -2.2380e-01,
           -2.0106e-01],
          [ 3.3439e-01, -2.0078e-01,  8.8455e-02, -5.8876e-02, -1.3279e-01,
            1.6281e-01,  3.3500e-02,  7.2977e-02, -1.7260e-01, -1.2296e-01,
            1.3433e-01,  1.4968e-01,  2.3543e-01, -2.5995e-01, -1.0525e-05,
           -1.9271e-01],
          [ 3.4752e-01, -2.4674e-01,  7.5127e-02, -1.0779e-01, -1.7768e-01,
            1.9428e-01,  3.9898e-02,  2.2347e-02,  7.5406e-02, -7.1007e-02,
            1.5316e-01,  9.0577e-02,  1.2806e-01, -1.8189e-01, -8.5211e-02,
           -1.5611e-01],
          [ 3.0263e-01, -2.0786e-01, -1.8208e-02, -7.5674e-02, -1.7330e-01,
           -3.8195e-02,  5.9163e-02, -1.6124e-02,  3.3142e-02, -1.3451e-01,
            2.5009e-04,  1.5753e-01,  2.0458e-01, -5.7169e-02, -

### 1i. Our first residual connection
<a id='i'></a>
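Here we'll add the output of our attention mechanism back into our residual state. Note that Llama normalizes the *input* to each sublayer (pre-norm), so the attention output itself is not normalized before the addition.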
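Here's a minimal sketch of the full pre-norm wiring that this residual connection is part of. The `PreNormBlock` class is hypothetical and uses plain `nn.Linear` layers as stand-ins for the attention and SwiGLU sublayers; only the norm-then-sublayer-then-add pattern is the point:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # same computation as the RMSNorm class defined earlier in this notebook
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class PreNormBlock(nn.Module):
    """Hypothetical block: normalize the input of each sublayer, never the residual stream itself."""
    def __init__(self, d: int):
        super().__init__()
        self.attention_norm = RMSNorm(d)
        self.ffn_norm = RMSNorm(d)
        self.attn = nn.Linear(d, d, bias=False)  # placeholder for self-attention
        self.ffwd = nn.Linear(d, d, bias=False)  # placeholder for the SwiGLU feedforward network

    def forward(self, x):
        h = x + self.attn(self.attention_norm(x))  # first residual connection
        return h + self.ffwd(self.ffn_norm(h))     # second residual connection

torch.manual_seed(0)
block = PreNormBlock(16)
x = torch.randn(1, 5, 16)
out = block(x)
print(out.shape)

# if both sublayers output zeros, the residual stream passes through completely untouched
with torch.no_grad():
    block.attn.weight.zero_()
    block.ffwd.weight.zero_()
print(torch.equal(block(x), x))
```

That untouched "highway" through the residual stream is a big part of why deep stacks of these blocks remain trainable.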

In [None]:
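h = h + Xout # out-of-place add: h = x still points at our embeddings, and an in-place += would silently overwrite x (which autograd saved during the RMSNorm)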
h.shape, h

(torch.Size([1, 5, 16]),
 tensor([[[ 3.5506, -0.6064, -0.9127,  0.3341, -1.9883, -0.2334, -2.2978,
           -0.2833, -2.0546, -0.3482, -0.9188,  1.3220,  0.8613, -1.4541,
           -0.1611,  0.6779],
          [ 1.6564, -0.4561,  2.4281, -0.4790,  0.5623,  1.1705,  1.2383,
            0.4741, -2.1986, -0.5805,  1.9673,  0.8375,  2.5630, -1.7590,
           -1.0407, -0.5088],
          [ 3.3650, -1.8352,  0.7795, -0.3841, -0.4541,  2.3658,  0.2736,
            1.3954,  0.4436, -0.1863,  0.6353,  0.4561,  0.1173, -0.5224,
           -0.3226, -1.7816],
          [ 2.4056, -2.2560, -0.0896,  0.4923, -0.8245, -1.1869,  0.3243,
           -0.2483, -0.0502, -0.9634,  1.0602,  1.9231,  2.2565,  1.1518,
            0.6319,  1.4035],
          [ 2.8266, -2.6940,  0.0327,  0.8795, -0.6646,  0.4215,  0.1522,
            1.0536,  0.9114, -3.7775, -0.4138,  1.4161,  3.2076, -0.6312,
            0.6284, -0.3478]]], grad_fn=<AddBackward0>))

In [None]:
# then we'll normalize the current state of our residual for use in our feedforward network
pre_ffwd_norm = RMSNorm(d)
h_normed = pre_ffwd_norm(h)
# so now we're working with h, which we'll use later for our next residual connection, and h_normed, which is used by our feedforward network

### 1j. The SwiGLU Feedforward Network
<a id='j'></a>

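Surprisingly, Llama models have not adopted a mixture of experts strategy, which i assumed they'd have gone for by now. Instead, their feedforward networks use SwiGLU: one linear projection of the input is passed through the SiLU activation and multiplied elementwise with a second projection, so the activated branch acts as a gate that dynamically determines how much of each hidden unit gets through. A third projection then maps the result back down to the model dimension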
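To make the gating concrete, here's a minimal standalone sketch of a SwiGLU feedforward written with plain weight matrices instead of `nn.Linear` layers. The names `swiglu_ffn`, `w_gate`, `w_up` and `w_down` are illustrative, not taken from Llama's code. It spells out SiLU as `x * sigmoid(x)` so you can see there's nothing magic in `F.silu`.

```python
import torch
import torch.nn.functional as F

def silu(x: torch.Tensor) -> torch.Tensor:
    # SiLU (a.k.a. swish): x * sigmoid(x); close to 0 for very negative x, close to x for large positive x
    return x * torch.sigmoid(x)

def swiglu_ffn(x, w_gate, w_up, w_down):
    # the SiLU-activated gate branch scales each hidden unit of the up branch elementwise,
    # then the down projection maps the gated hidden state back to the model dimension
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

torch.manual_seed(0)
d, hidden_dim = 16, 64          # toy sizes
x = torch.randn(5, d)           # 5 token vectors
w_gate = torch.randn(d, hidden_dim)
w_up = torch.randn(d, hidden_dim)
w_down = torch.randn(hidden_dim, d)

out = swiglu_ffn(x, w_gate, w_up, w_down)
print(out.shape)  # torch.Size([5, 16])
```

The notebook cells below do exactly this, just with `nn.Linear(..., bias=False)` layers named `gate`, `up` and `down`.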

In [None]:
# first we need to define our actual hidden dimension, which Llama's code does in a somewhat roundabout way
hidden_dim = 4 * d # usually i would make this 4 a hyperparameter, but in llama's code it's hard-coded
print(hidden_dim)
hidden_dim = int(2 * hidden_dim / 3) # SwiGLU has 3 weight matrices instead of 2, so shrinking by 2/3 keeps the parameter count close to a standard 4*d MLP
print(hidden_dim)
multiple_of = 256 # their description of this was "make SwiGLU hidden layer size multiple of large power of 2"
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # round UP to the nearest multiple of 256
print(hidden_dim)
# so the end result is a hidden_dim that's a multiple of 256, likely for hardware efficiency reasons. with our tiny d=16 the rounding dominates, taking us from 42 all the way up to 256

64
42
256


In [None]:
up = nn.Linear(d, hidden_dim, bias=False)
gate = nn.Linear(d, hidden_dim, bias=False)
down = nn.Linear(hidden_dim, d, bias=False)

In [None]:
up_proj = up(h_normed)
print(up_proj.shape, up_proj)

torch.Size([1, 5, 256]) tensor([[[-0.3775,  0.2147,  0.2070,  ...,  1.1758,  0.5350,  0.1489],
         [-0.4961,  0.0802,  0.6036,  ...,  0.5064, -0.1309, -1.0673],
         [ 0.0369,  0.0817, -0.2353,  ...,  0.1451,  0.1230, -0.7031],
         [ 0.4918,  0.0617,  0.7253,  ...,  0.4688, -0.1773,  0.2506],
         [ 0.1094,  0.7400,  0.3883,  ...,  0.3386, -0.1453, -0.4714]]],
       grad_fn=<UnsafeViewBackward0>)


In [None]:
gate_proj = F.silu(gate(h_normed))
print(gate_proj.shape, gate_proj)

torch.Size([1, 5, 256]) tensor([[[-0.2654,  0.4762, -0.1746,  ...,  0.1283,  0.5745,  0.1288],
         [-0.2330, -0.0939, -0.2176,  ...,  0.2495,  0.1948,  0.2031],
         [ 0.1397,  0.5607, -0.2677,  ...,  0.3105,  0.0060, -0.1339],
         [-0.2461,  0.3907,  0.3002,  ...,  0.5885,  0.6722,  0.6079],
         [-0.0112,  0.6475,  0.1435,  ...,  0.8659,  0.3234,  0.0028]]],
       grad_fn=<SiluBackward0>)


In [None]:
ffwd_output = down(up_proj * gate_proj)
print(ffwd_output.shape, ffwd_output)

torch.Size([1, 5, 16]) tensor([[[ 0.0056,  0.0167,  0.0514,  0.0153, -0.0333, -0.1343, -0.1392,
           0.0715,  0.0349,  0.0315,  0.0690, -0.0434,  0.0209,  0.0681,
           0.1831,  0.0926],
         [ 0.0958,  0.0138,  0.1765,  0.0728, -0.0919, -0.1499,  0.0478,
          -0.1806, -0.1907,  0.0043,  0.0631, -0.1799, -0.0557,  0.0860,
          -0.0789, -0.0315],
         [-0.0702,  0.0152, -0.0829,  0.0884,  0.0553,  0.0950,  0.1276,
          -0.1985, -0.0489,  0.1400, -0.0942,  0.1672,  0.0407,  0.0818,
          -0.1313,  0.0097],
         [-0.0949,  0.0313,  0.1437,  0.0006, -0.0451, -0.0712, -0.0642,
          -0.0783,  0.0467,  0.1035,  0.0802,  0.0331, -0.0398,  0.0468,
           0.0242,  0.1243],
         [-0.0432,  0.0420,  0.1084, -0.0500, -0.0397,  0.0344,  0.0966,
           0.0896,  0.0391,  0.0631, -0.0043,  0.0644,  0.0658,  0.0548,
           0.0436,  0.1454]]], grad_fn=<UnsafeViewBackward0>)


In [None]:
# and then do our final residual connection of this layer
out = h + ffwd_output
print(out.shape, out)

torch.Size([1, 5, 16]) tensor([[[ 3.5562e+00, -5.8969e-01, -8.6136e-01,  3.4942e-01, -2.0216e+00,
          -3.6778e-01, -2.4369e+00, -2.1177e-01, -2.0197e+00, -3.1668e-01,
          -8.4981e-01,  1.2786e+00,  8.8223e-01, -1.3860e+00,  2.2040e-02,
           7.7058e-01],
         [ 1.7521e+00, -4.4236e-01,  2.6046e+00, -4.0619e-01,  4.7042e-01,
           1.0206e+00,  1.2861e+00,  2.9346e-01, -2.3893e+00, -5.7619e-01,
           2.0303e+00,  6.5757e-01,  2.5073e+00, -1.6729e+00, -1.1196e+00,
          -5.4033e-01],
         [ 3.2949e+00, -1.8201e+00,  6.9662e-01, -2.9575e-01, -3.9885e-01,
           2.4609e+00,  4.0119e-01,  1.1970e+00,  3.9473e-01, -4.6282e-02,
           5.4112e-01,  6.2323e-01,  1.5797e-01, -4.4065e-01, -4.5393e-01,
          -1.7719e+00],
         [ 2.3108e+00, -2.2247e+00,  5.4144e-02,  4.9290e-01, -8.6964e-01,
          -1.2582e+00,  2.6018e-01, -3.2664e-01, -3.4225e-03, -8.5995e-01,
           1.1404e+00,  1.9562e+00,  2.2167e+00,  1.1986e+00,  6.5606e-01,
     

### 1k. Output
<a id='k'></a>
Normally we'd repeat steps 1e through 1j for however many layers our model has (Llama 3 8b uses 32), each layer with its own weight matrices, but you get the point. Since our current `out` has the same shape it would have after more layers, let's go ahead and look at Llama's output mechanism. There's nothing fancy here: just a final RMSNorm followed by a linear layer. Notably, they chose to use a separate output matrix rather than re-using the embedding matrix (known as weight tying), which is a relatively common choice
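For comparison, here's a small sketch of the weight tying that Llama 3 skips. It uses toy sizes (`v = 10`, `d = 16`, like the section-1 example) and only standard PyTorch modules; assigning the embedding's `Parameter` to the linear layer is one common way to tie them, not something from Llama's code.

```python
import torch
import torch.nn as nn

v, d = 10, 16                       # toy vocab size and model dim
embedding = nn.Embedding(v, d)      # weight shape (v, d)

# Llama 3 style: an independent output matrix, which adds another v*d parameters
separate_head = nn.Linear(d, v, bias=False)  # weight shape is also (v, d)

# weight tying (what Llama 3 does NOT do): the output layer reuses the embedding table
tied_head = nn.Linear(d, v, bias=False)
tied_head.weight = embedding.weight

h = torch.randn(1, 5, d)            # stand-in for the final normed residual state
logits = tied_head(h)               # (1, 5, v)
print(logits.shape)                 # torch.Size([1, 5, 10])
# with tying, each logit is the dot product of h with that token's embedding vector
print(torch.allclose(logits, h @ embedding.weight.T, atol=1e-6))  # True

n_tied = sum(p.numel() for p in nn.ModuleList([embedding, tied_head]).parameters())
n_separate = sum(p.numel() for p in nn.ModuleList([embedding, separate_head]).parameters())
print(n_tied, n_separate)           # 160 320: the shared matrix is only counted once
```

At toy scale the saving is trivial, but with Llama 3's 128256-token vocabulary the output matrix alone holds 128256 × `dim` parameters.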

In [None]:
# first we norm the residual state
final_norm = RMSNorm(d)
out_normed = final_norm(out)

In [None]:
# then multiply by the linear layer to get our final output logits
final_output = nn.Linear(d, v, bias=False)
logits = final_output(out_normed).float()
logits.shape, logits

(torch.Size([1, 5, 10]),
 tensor([[[ 7.6650e-01,  4.1170e-01,  4.2533e-01,  1.1216e-01,  7.7646e-01,
            1.4128e-01, -8.2421e-01, -4.4003e-02,  1.5142e-01, -4.3939e-01],
          [ 6.2985e-01,  5.7959e-01, -1.1361e-01, -5.2311e-01,  3.3086e-01,
           -9.6442e-01, -1.1810e+00, -7.6030e-01,  4.5599e-01, -7.6835e-01],
          [ 1.0984e+00,  1.2014e+00,  1.7533e-01,  8.2349e-02,  2.1564e-01,
           -6.5409e-01, -2.9928e-01, -5.8757e-01, -3.7019e-04, -8.5932e-01],
          [ 1.0524e-01,  4.8672e-02,  7.2490e-01,  4.1315e-02, -1.3067e-01,
           -6.7627e-01, -8.2063e-01, -8.1367e-01, -6.3231e-01, -7.6414e-01],
          [ 6.9820e-01,  1.2350e+00,  4.9250e-01,  5.3537e-01,  1.6038e-01,
           -6.2946e-01, -7.7267e-01, -9.7907e-01,  2.7908e-01, -6.1061e-01]]],
        grad_fn=<UnsafeViewBackward0>))

In [None]:
# softmax over the vocabulary dimension to get, at every position in the sequence, a probability distribution over the next token
probs = F.softmax(logits, dim=-1)
probs

tensor([[[0.1676, 0.1176, 0.1192, 0.0871, 0.1693, 0.0897, 0.0342, 0.0745,
          0.0906, 0.0502],
         [0.1928, 0.1833, 0.0917, 0.0609, 0.1430, 0.0391, 0.0315, 0.0480,
          0.1620, 0.0476],
         [0.2293, 0.2542, 0.0911, 0.0830, 0.0948, 0.0397, 0.0567, 0.0425,
          0.0764, 0.0324],
         [0.1302, 0.1230, 0.2419, 0.1221, 0.1028, 0.0596, 0.0516, 0.0519,
          0.0623, 0.0546],
         [0.1523, 0.2604, 0.1239, 0.1294, 0.0889, 0.0404, 0.0350, 0.0285,
          0.1001, 0.0411]]], grad_fn=<SoftmaxBackward0>)

In [None]:
# Greedily decode the probabilities to get our final predicted indices
greedy_indices = torch.argmax(probs, dim=-1)
greedy_indices
# if we were performing inference rather than training, that final token in the list would be the one to show the user

tensor([[4, 0, 1, 2, 1]])

### 1l. The loss functions
<a id='l'></a>

Of course we use [cross-entropy loss](https://machinelearningmastery.com/cross-entropy-for-machine-learning/), which should need no introduction if this isn't your first machine-learning rodeo, so we'll be skimming past it. At each position, the loss is the negative log of the probability the model assigns to the correct next token; minimizing it pushes that probability up and, through the softmax, pushes every other token's probability down
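If you want to see that there's nothing hidden inside the loss function, this sketch computes cross-entropy by hand on random toy logits and compares it to PyTorch's `F.cross_entropy` in its two usual input layouts. Shapes mirror the section-1 example; the random values are just placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, v = 1, 5, 10
logits = torch.randn(batch, seq_len, v)
targets = torch.randint(0, v, (batch, seq_len))

# by hand: log-softmax over the vocab, pick out the log-probability of each correct token, negate and average
log_probs = F.log_softmax(logits, dim=-1)                                    # (batch, seq_len, v)
correct_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len)
manual_loss = -correct_log_probs.mean()

# PyTorch's version, with the two usual ways of arranging the inputs
loss_flat = F.cross_entropy(logits.reshape(-1, v), targets.reshape(-1))  # (batch*seq_len, v) vs (batch*seq_len,)
loss_2d = F.cross_entropy(logits.transpose(1, 2), targets)              # (batch, v, seq_len) vs (batch, seq_len)
print(manual_loss.item(), loss_flat.item(), loss_2d.item())             # all three agree
```

The second layout needs the vocabulary (class) dimension in position 1, which is why it uses `transpose` rather than `view`. Section 2's `Llama3.forward` uses the flattened form.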

In [None]:
# create some random fake target indices to train on
target_token_indices = torch.randint(0, v, greedy_indices.shape)
print(target_token_indices)

# initialize the loss function
loss_fn = nn.CrossEntropyLoss()

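# nn.CrossEntropyLoss expects the class (vocab) dimension second, i.e. (batch, v, seq_len), so we swap the last two dims.
# note that .view(1, v, seq_len) would NOT do this: it only reinterprets the memory layout and scrambles the logits
loss = loss_fn(logits.transpose(1, 2), target_token_indices)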
print(loss)

tensor(2.7078, grad_fn=<NllLoss2DBackward0>)

and that's it! those are all the essential calculations that Llama performs, most of which aren't any different from other open-source LLMs like Grok, Mixtral or Gemma (Llama is most similar to Gemma, since Mixtral and Grok use a [mixture of experts](https://huggingface.co/blog/moe) for their feedforward networks). Now let's code everything up properly into classes so that we can actually build a functioning model

# 2. Actually functional model code
<a id='two'></a>
The bulk of the lesson is over, but the following code demonstrates how you'd actually take these concepts and turn them into functioning nn.Module classes. Instead of reading through them here, you can also check out the .py files in [the repo](https://github.com/evintunador/minLlama). I'm not going to explain this section in the same detail; I've only added comments in a few places where things are different or new

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from dataclasses import dataclass
from typing import Optional, Tuple

import time

import os
import json

# we'll be using a crazy small & simple tokenizer that I made based on the TinyShakespeare dataset
# Llama 3 8b's vocabulary size is 128256, including special tokens like <|begin_of_text|> and <|end_of_text|>

# download the tokenizer code
!wget https://raw.githubusercontent.com/evintunador/minLlama3/main/tiny_shakespeare_tokenizer.py
# and the tokenizer model
!wget https://raw.githubusercontent.com/evintunador/minLlama3/main/tokenizers/tiny_shakespeare_tokenizer_512.model
!mkdir -p tokenizers
!mv tiny_shakespeare_tokenizer_512.model tokenizers/
from tiny_shakespeare_tokenizer import *
tokenizer = get_tokenizer(size = 512)

--2024-04-20 18:17:09--  https://raw.githubusercontent.com/evintunador/minLlama3/main/tiny_shakespeare_tokenizer.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2180 (2.1K) [text/plain]
Saving to: ‘tiny_shakespeare_tokenizer.py’


2024-04-20 18:17:10 (45.5 MB/s) - ‘tiny_shakespeare_tokenizer.py’ saved [2180/2180]

--2024-04-20 18:17:10--  https://raw.githubusercontent.com/evintunador/minLlama3/main/tokenizers/tiny_shakespeare_tokenizer_512.model
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4298 (4.2K) [application/octet-stre

### 2a. the parameters
<a id='twoa'></a>

here are the parameters that I've set up for my little minLlama3 test model

In [None]:
@dataclass # the hyperparameters of our minLlama3
class ModelArgs:
    dim: int = 128 # Llama 3 8b uses 4096
    n_layers: int = 8 # Llama 3 8b uses 32
    n_heads: int = 4 # Llama 3 8b uses 32
    n_kv_heads: Optional[int] = 1 # Llama 3 8b uses 8
    vocab_size: int = tokenizer.vocab_len # Llama 3 uses a more complicated tokenizer of length 128256
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2. Llama 3 8b uses 1024
    ffn_dim_multiplier: Optional[float] = None # Llama 3 8b uses 1.3, which enlarges the final hidden_dim (to 14336 rather than 11264)
    norm_eps: float = 1e-5
    rope_theta: float = 10000 # Llama 3 8b uses 500000
    max_batch_size: int = 32 # who knows what batch size they trained with
    max_seq_len: int = 512 # Llama 3 8b trained with 8192 but their maximum kv cache chunk size during inference is 2048
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate: float = 0.1 # who knows what dropout rate they trained with

### 2b. RMSNorm
<a id='twob'></a>

In [None]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
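A quick way to convince yourself what RMSNorm does: after normalization (with the weight still at its initial value of ones), every token vector has a root-mean-square of about 1, but unlike LayerNorm its mean is not forced to 0, because RMSNorm skips the mean subtraction. This standalone sketch uses a functional copy of `_norm` (the name `rms_norm` is illustrative) and `F.layer_norm` for comparison.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # same formula as RMSNorm._norm above, times the learnable per-feature weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
x = torch.randn(2, 5, 16) * 3 + 1     # deliberately not zero-mean or unit-scale
y = rms_norm(x, torch.ones(16))

rms = y.pow(2).mean(-1).sqrt()        # root-mean-square of each token vector
print(rms)                            # all very close to 1
print(y.mean(-1))                     # generally NOT 0: RMSNorm doesn't subtract the mean

ln = F.layer_norm(x, (16,))           # LayerNorm for contrast
print(ln.mean(-1))                    # ~0: LayerNorm does re-center
```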

### 2c. RoPE
<a id='twoc'></a>


In [None]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
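    return freqs_cis.to(params.device) # note: relies on the global `params` (a ModelArgs instance) being defined before the model is built, as in section 3b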

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != (x.shape[1], x.shape[-1]) {(x.shape[1], x.shape[-1])}'
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
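A handy sanity check for RoPE: after rotation, the dot product between a query at position m and a key at position n should depend only on the offset m - n, and rotating a vector shouldn't change its length. The standalone sketch below re-implements a simplified version of `precompute_freqs_cis` / `apply_rotary_emb` for a single head (no batch or head dimensions, no device handling); `rope_freqs`, `rotate` and `score` are illustrative names.

```python
import torch

def rope_freqs(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # one rotation frequency per pair of dimensions, one angle per (position, pair)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)          # unit complex numbers e^{i*angle}

def rotate(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # treat adjacent pairs of features as complex numbers and multiply by e^{i*angle}
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2)

torch.manual_seed(0)
head_dim = 8
freqs_cis = rope_freqs(head_dim, seq_len=16)
q, k = torch.randn(head_dim), torch.randn(head_dim)

def score(m: int, n: int) -> float:
    # unscaled attention logit between q placed at position m and k placed at position n
    q_m = rotate(q.unsqueeze(0), freqs_cis[m : m + 1])
    k_n = rotate(k.unsqueeze(0), freqs_cis[n : n + 1])
    return (q_m * k_n).sum().item()

print(score(3, 1), score(7, 5), score(12, 10))  # same offset of 2, so (nearly) identical scores
```

This is why RoPE encodes *relative* position: each pair of features is rotated by an angle proportional to its position, and the dot product only sees the difference between the two angles.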

### 2d. Attention
<a id='twod'></a>

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    ) # this code looks different from how we did it in section 1 but it's effectively the same

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_rep = args.n_heads // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
            requires_grad = False
        ).to(args.device)
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
            requires_grad = False
        ).to(args.device)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if start_pos is not None: # if we're performing inference, use kv caching (not shown in section 1)
            # make sure our cache is on the right device
            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            # set the values in our cache according to the current input
            self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

            # grab our key and value matrices, which have a longer sequence length than our queries
            keys = self.cache_k[:bsz, : start_pos + seqlen]
            values = self.cache_v[:bsz, : start_pos + seqlen]
        else:
            # if we're training, do full sequence length (like in section 1)
            keys, values = xk, xv

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_heads, cache_len + seqlen, head_dim)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        output = torch.matmul(scores, values)  # (bs, n_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
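The docstring of `repeat_kv` says it's equivalent to `torch.repeat_interleave(x, dim=2, repeats=n_rep)`. This standalone check copies `repeat_kv` so the snippet runs on its own and confirms that claim. It also shows the grouping that grouped-query attention relies on: with `n_rep = 3`, query heads 0-2 share key/value head 0, heads 3-5 share head 1, and so on.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # copied from above: (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    )

torch.manual_seed(0)
x = torch.randn(2, 5, 2, 4)                      # (batch, seq, n_kv_heads=2, head_dim)
out = repeat_kv(x, n_rep=3)                      # 6 "query-sized" heads
ref = torch.repeat_interleave(x, repeats=3, dim=2)
print(out.shape)                                 # torch.Size([2, 5, 6, 4])
print(torch.equal(out, ref))                     # True
print(torch.equal(out[:, :, 3], x[:, :, 1]))     # True: head 3 is a copy of kv head 1
```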

### 2e. Ffwd
<a id='twoe'></a>

In [None]:
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
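        # shrink by 2/3 (SwiGLU has 3 weight matrices instead of 2, so this keeps the parameter count close to a standard MLP),
        # optionally scale by ffn_dim_multiplier, then round up to a multiple of multiple_of, likely for hardware efficiency reasons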
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
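To see what this sizing logic produces for real configurations, here's the same arithmetic as a plain function (`ffn_hidden_dim` is an illustrative name), fed with the `dim`, `multiple_of` and `ffn_dim_multiplier` values quoted in the `ModelArgs` comments. The 2/3 factor exists because three d×(8d/3) matrices hold the same 8d² parameters as the two d×4d matrices of a classic MLP.

```python
from typing import Optional

def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None) -> int:
    hidden_dim = 4 * dim                    # TransformerBlock passes hidden_dim=4*dim
    hidden_dim = int(2 * hidden_dim / 3)    # SwiGLU correction: 3 matrices instead of 2
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round UP to the nearest multiple of multiple_of
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(16, 256))          # 256   (the toy example from section 1)
print(ffn_hidden_dim(128, 256))         # 512   (this notebook's ModelArgs)
print(ffn_hidden_dim(4096, 1024))       # 11264 (Llama 3 8b's dim and multiple_of, without the multiplier)
print(ffn_hidden_dim(4096, 1024, 1.3))  # 14336 (Llama 3 8b's settings)
```

The 512 matches the `w1`/`w3` output size printed for our model in section 3b.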

### 2f. Residual Layers
<a id='twof'></a>

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.dropout_rate = args.dropout_rate

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
        training = False,
    ):
        # our two residual connections, plus dropout which will only happen if we're training
        h = x + F.dropout(self.attention(self.attention_norm(x), freqs_cis, mask, start_pos), p=self.dropout_rate, training=training)
        out = h + F.dropout(self.feed_forward(self.ffn_norm(h)), p=self.dropout_rate, training=training)
        return out

### 2g. The model itself
<a id='twog'></a>

In [None]:
class Llama3(nn.Module):
    def __init__(self, params: ModelArgs, tokenizer):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.max_seq_len = params.max_seq_len
        self.tokenizer = tokenizer

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        # final norm and linear layer
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim,
            params.vocab_size,
            bias=False)

        # precompute RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,)

        # precompute the causal attention mask
        mask = torch.full((params.max_seq_len, params.max_seq_len),
                          float("-inf"),
                          device=params.device)
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer('mask', mask)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, # specifically for training. this is what you saw in section 1
                tokens: torch.Tensor,
                targets: torch.Tensor):
        bsz, seqlen = tokens.shape
        assert tokens.shape == targets.shape
        assert seqlen == self.max_seq_len

        # initialize the first residual state
        h = self.tok_embeddings(tokens)

        # grab the precomputed freqs_cis for this sequence length, on the same device as h
        freqs_cis = self.freqs_cis.to(h.device)[:seqlen]

        # run the residual state through each layer
        for layer in self.layers:
            h = layer(
                h,
                freqs_cis,
                self.mask,
                start_pos = None,
                training = self.training # False after model.eval(), so estimate_loss runs without dropout
            )

        # norm the final output then get the logits
        h = self.norm(h)
        logits = self.output(h).float()

        loss = self.criterion(
            logits.view(bsz * seqlen, self.vocab_size),
            targets.reshape(bsz * seqlen))

        return logits, loss

    @torch.inference_mode()
    def forward_inference(self,
                          tokens: torch.Tensor,
                          start_pos: int,
                          max_context_window: int,
                         ):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = self.mask[:seqlen, :seqlen]
        # When performing key-value caching, we compute the attention scores
        # only for the new sequence. Thus, the matrix of scores is of size
        # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
        # j > cache_len + i, since row i corresponds to token cache_len + i.
        mask = torch.hstack(
            [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
        ).type_as(h)

        for layer in self.layers:
            h = layer(
                h,
                freqs_cis,
                mask,
                start_pos = start_pos
            )
        h = self.norm(h)
        logits = self.output(h).float()
        return logits

    @torch.inference_mode() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling
        """
        # Select the last element for each sequence.
        logits = logits[:,-1,:] # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)

        # Apply temperature scaling
        logits.div_(temperature) # (batch_size, vocab_size) / float -> (batch_size, vocab_size)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### calculating top-p
        # creates same-size tensor of cumulative probabilities instead of individual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        ### calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # mark (with True) every sorted position whose rank is top_k or worse; those tokens will be excluded
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id # returns the predicted token

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_gen_len: int = None,
        memory_saver_div: int = 1, # defaults to full max_seq_len**2 memory use. must be power of 2
        temperature: float = 0.6, # default value in meta's code
        top_p: float = 0.9, # default value in meta's code
        top_k: int = tokenizer.vocab_len, # meta's code doesn't bother with topk
    ) -> str:
        """ Wrapper around sampler() that deals with manipulation of the sequence """

        # ensuring memory_saver_div, the setting that affects our kv caching, will work
        assert ((memory_saver_div & (memory_saver_div-1)) == 0) & (memory_saver_div > 0), f'memory_saver_div {memory_saver_div} must be power of 2'
        max_context_window = self.max_seq_len // memory_saver_div
        if max_context_window < self.max_seq_len:
            print(f'maximum attention matrix size will be {max_context_window}x{self.max_seq_len} rather than {self.max_seq_len}x{self.max_seq_len}\n')

        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        if max_gen_len is None:
            max_gen_len = self.max_seq_len - len(tokens)
        elif max_gen_len + len(tokens) > self.max_seq_len:
            print(f'capping max_gen_len at max_seq_len={self.max_seq_len} including input\n')
            max_gen_len = self.max_seq_len - len(tokens)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=self.params.device)
        tokens = tokens.unsqueeze(0) if len(tokens.shape)==1 else tokens # jic we need to add a batch dimension

        # the offset used for kv caching
        start_pos = max(tokens.shape[1] - max_context_window, 0)

        for i in range(max_gen_len):
            # get the model's output logits for the current context window (forward_inference doesn't compute a loss)
            logits = self.forward_inference(
                tokens[:,-max_context_window:],
                start_pos = start_pos,
                max_context_window = max_context_window
            )

            # sample the next token to be used from the logit distribution
            next_token = self.Sampler(
                logits = logits,
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

            # iterate the offset used in kv caching
            if tokens.shape[1] >= max_context_window:
                start_pos += 1

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output
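The top-p/top-k logic in `Sampler` is easiest to follow on a tiny hand-made distribution. This sketch copies just the filtering steps into a standalone function (`filter_top_p_top_k` is an illustrative name, not part of the model) and applies them to a 5-token distribution, so the filtered result is deterministic; only the final `torch.multinomial` call is random.

```python
import torch

def filter_top_p_top_k(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # probs: (batch_size, vocab_size), each row sums to 1
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # top-p: drop a token if the tokens ranked above it already cover more than top_p
    top_p_mask = (probs_sum - probs_sort) > top_p
    # top-k: drop a token if its (0-based) rank is top_k or worse
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    top_k_mask = ranks >= top_k
    probs_sort = torch.where(top_p_mask | top_k_mask, 0.0, probs_sort)
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)  # renormalize
    # undo the sort so each probability lines up with its original token id again
    return torch.gather(probs_sort, -1, torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([[0.05, 0.5, 0.1, 0.3, 0.05]])
nucleus = filter_top_p_top_k(probs, top_p=0.7, top_k=5)
greedy = filter_top_p_top_k(probs, top_p=0.9, top_k=1)
print(nucleus)  # only token ids 1 and 3 survive, renormalized to about 0.625 and 0.375
print(greedy)   # only token id 1 survives, so sampling becomes greedy decoding
next_token_id = torch.multinomial(nucleus, num_samples=1)  # the sampling step, as in Sampler
```

With `top_p=0.7`, the token with probability 0.1 is dropped because the two tokens ranked above it already cover 0.8.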

# 3. Train and test your own minLlama (or load mine)
<a id='three'></a>

### 3a. Setup
<a id='threea'></a>
a bunch of data, functions and objects you'll need that are not already included with the above architecture

In [None]:
# download the TinyShakespeare dataset
!wget -O input.txt https://raw.githubusercontent.com/evintunador/minLlama3/main/input.txt

# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

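# here are all the unique characters that occur in this text and how many there are (just for illustration; the model itself works with the 512-token tokenizer, not individual characters)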
chars = sorted(list(set(text)))
v = len(chars)
print(chars)
print(v)

# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

--2024-04-20 18:17:22--  https://raw.githubusercontent.com/evintunador/minLlama3/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-20 18:17:22 (36.6 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u

### 3b. Training your own
<a id='threeb'></a>

you can feel free to train your own if you'd like, but i don't see a huge reason to do so in a colab notebook

In [None]:
# instantiate a new model
params = ModelArgs()
print(params)
model = Llama3(params, tokenizer).to(params.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

ModelArgs(dim=128, n_layers=8, n_heads=4, n_kv_heads=1, vocab_size=512, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=10000, max_batch_size=32, max_seq_len=512, device='cuda', dropout_rate=0.1)
2033.792 K parameters
Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output
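As a sanity check on the `2033.792 K parameters` printed above, here is a back-of-the-envelope count in plain Python. It is a sketch that assumes the Llama-style sizing rule for the feed-forward hidden dimension (start at `4 * dim`, scale by 2/3, round up to a multiple of `multiple_of`), one `dim`-sized weight per `RMSNorm`, and an output projection that is not tied to the embedding. `ffn_hidden_dim` and `count_params` are hypothetical helpers, not functions from the notebook.

```python
def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    # Llama-style sizing (assumed): 4*dim, shrunk by 2/3 for SwiGLU, rounded up to multiple_of
    hidden = 4 * dim
    hidden = int(2 * hidden / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def count_params(dim, n_layers, n_heads, n_kv_heads, vocab_size, multiple_of, ffn_dim_multiplier=None):
    head_dim = dim // n_heads
    hidden = ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier)
    attn = dim * dim                          # wq
    attn += 2 * dim * n_kv_heads * head_dim   # wk and wv (grouped-query attention)
    attn += dim * dim                         # wo
    ffn = 3 * dim * hidden                    # w1, w2, w3
    norms = 2 * dim                           # attention_norm and ffn_norm weights
    per_layer = attn + ffn + norms
    embed = vocab_size * dim                  # tok_embeddings
    final_norm = dim                          # norm
    output = dim * vocab_size                 # output projection (assumed untied)
    return embed + n_layers * per_layer + final_norm + output

total = count_params(dim=128, n_layers=8, n_heads=4, n_kv_heads=1,
                     vocab_size=512, multiple_of=256)
print(total / 1e3, 'K parameters')  # → 2033.792 K parameters
```

The breakdown also shows where grouped-query attention saves parameters: with `n_kv_heads=1`, `wk` and `wv` project to a single 32-dimensional head instead of 128 dimensions, matching the `out_features=32` in the printout.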

In [None]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - params.max_seq_len, (batch_size,))
    x = torch.stack([data[i:i+params.max_seq_len] for i in ix])
    y = torch.stack([data[i+1:i+params.max_seq_len+1] for i in ix])
    x, y = x.to(params.device), y.to(params.device)
    return x, y
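To see what `get_batch` returns, here is the same slicing logic on a toy list of integers, written in plain Python so it runs without a GPU. `get_batch_lists` is a hypothetical stand-in for the notebook's function: `random.randrange` plays the role of `torch.randint`, and the targets `y` are the inputs `x` shifted one position to the right, so every position's target is the next token.

```python
import random

def get_batch_lists(data, batch_size, seq_len, rng=random):
    # each start index leaves room for seq_len inputs plus one extra target token
    starts = [rng.randrange(len(data) - seq_len) for _ in range(batch_size)]
    x = [data[i:i + seq_len] for i in starts]
    y = [data[i + 1:i + seq_len + 1] for i in starts]  # same window, shifted by one
    return x, y

data = list(range(20))  # toy "token ids"
x, y = get_batch_lists(data, batch_size=3, seq_len=5, rng=random.Random(0))
for xs, ys in zip(x, y):
    print(xs, '->', ys)
```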

In [None]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, targets=Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [None]:
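import math  # used by the cosine schedule below

# create a PyTorch optimizer
lr_init = 1e-2
weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init, weight_decay=weight_decay)

# how long we want to train for
# (the training log printed below came from a shorter run: max_iters = 100, eval_interval = 10)
max_iters = 1000

# how often we want to check & see how our loss is doing
eval_interval = 50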

# Warmup setup
warmup_iters = 10  # Number of warmup iterations
warmup_factor = 1e-3  # Warmup factor (initial learning rate is multiplied by this factor)

lr_final = 1e-5  # Minimum learning rate

def lr_lambda(current_iter):
    if current_iter < warmup_iters:
        # Warmup phase
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters
    else:
        # Cosine decay phase with minimum learning rate
        decay_iters = max_iters - warmup_iters
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (current_iter - warmup_iters) / decay_iters))
        return max(cosine_decay, lr_final / lr_init)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
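The schedule can be checked without a model by evaluating `lr_lambda` directly. This sketch copies the function above and assumes `max_iters = 100`, which is what the step numbers in the training log suggest. One subtlety: `LambdaLR` multiplies the base rate by `lr_lambda(0)` when it is constructed and by `lr_lambda(n)` after the n-th `scheduler.step()`, so the rate printed at step `s` (after that step's `scheduler.step()`) is `lr_init * lr_lambda(s + 1)`. `lr_after_step` is a hypothetical helper.

```python
import math

lr_init, lr_final = 1e-2, 1e-5
warmup_iters, warmup_factor = 10, 1e-3
max_iters = 100  # assumed: the value the printed training log appears to come from

def lr_lambda(current_iter):
    if current_iter < warmup_iters:
        # linear warmup from warmup_factor up to 1
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters
    # cosine decay, floored at lr_final / lr_init
    decay_iters = max_iters - warmup_iters
    cosine_decay = 0.5 * (1 + math.cos(math.pi * (current_iter - warmup_iters) / decay_iters))
    return max(cosine_decay, lr_final / lr_init)

def lr_after_step(step):
    # rate in effect after scheduler.step() has been called step + 1 times
    return lr_init * lr_lambda(step + 1)

for step in (0, 10, 20, 99):
    print(f"step {step:04d}: lr {lr_after_step(step):.6f}")
# → step 0000: lr 0.001009
# → step 0010: lr 0.009997
# → step 0020: lr 0.009636
# → step 0099: lr 0.000010
```

Because of the `max(...)`, the rate follows the cosine curve toward zero and is clamped at `lr_final` only at the very end, rather than decaying smoothly onto `lr_final`.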

In [None]:
import time

start_time = time.time()

# Enable anomaly detection. Uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for step in range(max_iters):  # `step` rather than `iter` avoids shadowing Python's built-in iter()

    # sample a batch of data
    xb, yb = get_batch('train', params.max_batch_size)

    # train
    logits, loss = model(xb, targets=yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Update the learning rate
    scheduler.step()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, params.max_batch_size)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0000: lr 0.001009, train loss 6.4597, val loss 6.4662, time elapsed: 1.52 seconds
step 0010: lr 0.009997, train loss 4.2119, val loss 4.2875, time elapsed: 4.13 seconds
step 0020: lr 0.009636, train loss 3.8806, val loss 3.9501, time elapsed: 6.90 seconds
step 0030: lr 0.008716, train loss 3.6459, val loss 3.7502, time elapsed: 9.69 seconds
step 0040: lr 0.007347, train loss 3.4446, val loss 3.6237, time elapsed: 12.48 seconds
step 0050: lr 0.005696, train loss 3.3880, val loss 3.5284, time elapsed: 15.27 seconds
step 0060: lr 0.003960, train loss 3.3236, val loss 3.4926, time elapsed: 18.06 seconds
step 0070: lr 0.002350, train loss 3.2639, val loss 3.4463, time elapsed: 20.86 seconds
step 0080: lr 0.001060, train loss 3.2208, val loss 3.4008, time elapsed: 23.66 seconds
step 0090: lr 0.000245, train loss 3.1910, val loss 3.4081, time elapsed: 26.46 seconds
step 0099: lr 0.000010, train loss 3.2181, val loss 3.4169, time elapsed: 29.06 seconds
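A useful reference point for reading this log: a model that assigns equal probability to all `vocab_size = 512` tokens has a cross-entropy loss of ln(512) ≈ 6.24, so the first logged losses of about 6.46 are roughly at chance level. Exponentiating a loss gives perplexity. The small sketch below computes both numbers; the loss value is copied from the last line of the log.

```python
import math

vocab_size = 512
# equal probability 1/vocab_size on every token gives cross-entropy -ln(1/vocab_size) = ln(vocab_size)
uniform_loss = math.log(vocab_size)
print(f"uniform-guess loss: {uniform_loss:.4f}")  # → uniform-guess loss: 6.2383

# perplexity = exp(loss); final validation loss taken from the log above
final_val_loss = 3.4169
perplexity = math.exp(final_val_loss)
print(f"validation perplexity: {perplexity:.1f}")
```

A final validation perplexity of about 30 means the model is, on average, about as uncertain as a uniform choice among roughly 30 of its 512 tokens.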


### 3c. Alternatively, you can load the 2M-parameter model I already trained
<a id='threec'></a>

In [None]:
# downloading it
!wget https://github.com/evintunador/minLlama3/raw/main/models/Llama3_2024-04-19%7C15-18-16.pth
!wget https://github.com/evintunador/minLlama3/raw/main/models/Llama3_2024-04-19%7C15-18-16.json

# name of the minLlama3 checkpoint I trained, with roughly 2M parameters
name = 'Llama3_2024-04-19|15-18-16'

# Deserialize the JSON file back to a dictionary
with open(f'{name}.json', 'r') as f:
    params_dict = json.load(f)

# Convert the dictionary back to a dataclass object
params = ModelArgs(**params_dict)
params.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize a blank model
model = Llama3(params, tokenizer).to(params.device)

# path to the saved weights
path = f'{name}.pth'

# Load the saved state dictionary; map_location lets weights saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load(path, map_location=params.device))
# params came from the matching JSON file; if you load a different .pth, make sure params match it

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

--2024-04-20 18:19:55--  https://github.com/evintunador/minLlama3/raw/main/models/Llama3_2024-04-19%7C15-18-16.pth
Resolving github.com (github.com)... 140.82.116.3
Connecting to github.com (github.com)|140.82.116.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/evintunador/minLlama3/main/models/Llama3_2024-04-19%7C15-18-16.pth [following]
--2024-04-20 18:19:55--  https://raw.githubusercontent.com/evintunador/minLlama3/main/models/Llama3_2024-04-19%7C15-18-16.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9213026 (8.8M) [application/octet-stream]
Saving to: ‘Llama3_2024-04-19|15-18-16.pth’


2024-04-20 18:19:56 (140 MB/s) - ‘Llama3_2024-04-19|15-18-16.pth’ saved [9213026/9213026]

--2024-

Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)
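If you trained your own model in 3b and want to reload it the way this section does, you need the same pair of files: a `.json` holding the `ModelArgs` fields and a `.pth` holding the weights. The sketch below shows the JSON half in plain Python, using a hypothetical, trimmed-down `ToyArgs` dataclass in place of `ModelArgs`; the weights would be written with `torch.save(model.state_dict(), f'{name}.pth')`, shown only as a comment. The `checkpoint_name` helper is also hypothetical, chosen to mimic the `Llama3_2024-04-19|15-18-16` naming of the downloaded files.

```python
import json
import os
import tempfile
import time
from dataclasses import asdict, dataclass

@dataclass
class ToyArgs:
    # hypothetical stand-in for ModelArgs with only a few of its fields
    dim: int = 128
    n_layers: int = 8
    vocab_size: int = 512
    device: str = 'cuda'

def checkpoint_name(prefix='Llama3'):
    # mimics names like 'Llama3_2024-04-19|15-18-16'
    return f"{prefix}_{time.strftime('%Y-%m-%d|%H-%M-%S')}"

def save_params(params, name):
    with open(f'{name}.json', 'w') as f:
        json.dump(asdict(params), f)
    # the weights would go next to it with: torch.save(model.state_dict(), f'{name}.pth')

def load_params(name, device):
    with open(f'{name}.json', 'r') as f:
        params = ToyArgs(**json.load(f))
    params.device = device  # the JSON stores the training device, so override it
    return params

print(checkpoint_name())
demo_name = os.path.join(tempfile.mkdtemp(), 'demo_ckpt')  # temporary location for the demo
save_params(ToyArgs(), demo_name)
restored = load_params(demo_name, 'cpu')
print(restored)  # → ToyArgs(dim=128, n_layers=8, vocab_size=512, device='cpu')
```

Overriding `device` after loading matters because the JSON records whatever device the model was trained on (the printed `ModelArgs` in 3b shows `device='cuda'`). Note also that `|` is not allowed in Windows filenames, so you may prefer a different separator there.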

### 3d. Testing (performing inference)
<a id='threed'></a>

In [None]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line

In [None]:
# doing everything with default values
print(model.generate(input_str))

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

Nurse:
Go what thou dost say you to this hour in death.

ROMEO:
Then never now of your life of Romeo's chamber
What can be so burden to her son, thou art
To pluck it in the heavens of death,
And stand our brothers of the people's trial.

ROMEO:
Then makes me hear the measure of the world:
Not it is hope to be so burthen for them an hour,
Which she shall stay to the crown of the war.

ROMEO:
The weeds of war, and many sorrow of his head.

BENVOLIO:
Why, what say you shall be a kind of foot?

ROMEO:
The words of this cold part thou art:
Stay I say I mean in my life,
And I am near to the law of his death.

ROMEO:
That is the point that kill'd to the crown,
And all the crown of the other plain,
Which do shame were for what I am rather.

ROMEO:
Stay, where she 


##### Now let's use `memory_saver_div` to take advantage of KV caching, trading potential quality degradation for memory usage that scales linearly with sequence length. `memory_saver_div` must be a power of 2; it sets the maximum length of the query's sequence-length dimension in the attention matrix, which becomes `max_seq_len // memory_saver_div` (512 // 8 = 64 for our model).
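The arithmetic behind `memory_saver_div` is simple enough to check by hand. Judging from the message the next cell prints, the cap on the query length is `max_seq_len // memory_saver_div`; the sketch below assumes that formula and adds a power-of-2 check. `attention_matrix_shape` is a hypothetical helper, not the notebook's code.

```python
def attention_matrix_shape(max_seq_len, memory_saver_div):
    # a positive power of 2 has exactly one bit set, so n & (n - 1) == 0
    if memory_saver_div <= 0 or (memory_saver_div & (memory_saver_div - 1)) != 0:
        raise ValueError("memory_saver_div must be a power of 2")
    max_q_len = max_seq_len // memory_saver_div  # longest query chunk allowed at once
    return max_q_len, max_seq_len

q_len, k_len = attention_matrix_shape(512, 8)
print(f"maximum attention matrix size will be {q_len}x{k_len} rather than {k_len}x{k_len}")
# → maximum attention matrix size will be 64x512 rather than 512x512

# the uncapped matrix has memory_saver_div times as many entries as the capped one
print((k_len * k_len) // (q_len * k_len))  # → 8
```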

In [None]:
output = model.generate(
    input_str,
    max_gen_len = params.max_seq_len - len(input_str), # our model doesn't have a built-in <endoftext> token so we have to specify when to stop generating
    memory_saver_div = 8, # caps the query sequence length at max_seq_len // memory_saver_div (512 // 8 = 64 here), making memory consumption linear with respect to sequence length
    temperature = 0.6, # this is the default value that Llama3's official code has set
    top_p = 0.9, # this is the default value that Llama3's official code has set
    top_k = 32, # Meta's code doesn't actually implement top-k selection, but I've added it anyway as an alternative
)
print(output)

maximum attention matrix size will be 64x512 rather than 512x512

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

JULIET:
Ay, my lord, I have a gentle man,
The happy man that would have been the banished.

JULIET:
Ay, I will have the duke of thy father's head,
And I will not plead to leave the purpose.

JULIET:
Ay, but thou wilt be so much a man
To steal into the law of a word of him.

JULIET:
What stand of your cousin, and holds to her?

JULIET:
What is the day?

Nurse:
Nay, then, my lord; here's a happy man!

JULIET:
I have not heard to be since for my sin,
Nor to my father to my part of thee,
To make a man and my servants in my heart,
And that I had been so much dead her lady:
Therefore, thou art doubt a part of mine honour,
To see this wretched blind to crown my soul,
Shall I did remain my servant love to me.

JULIET:
Nay, thou art words, as I say, I would
W
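The `temperature`, `top_p`, and `top_k` arguments above control how each next token is drawn from the model's output distribution. Here is a plain-Python sketch of that sampling step on a single vector of logits. Temperature divides the logits before the softmax (values below 1 sharpen the distribution), top-k keeps only the k most likely tokens, and top-p keeps a token only while the probability mass ranked above it has not exceeded `top_p`, which is the same keep-rule as `sample_top_p` in Meta's Llama 3 code. Applying top-k before top-p is an assumption; the notebook's `generate` may order or combine them differently, and all function names here are hypothetical.

```python
import math
import random

def softmax(logits, temperature):
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def filter_top_k_top_p(probs, top_k, top_p):
    # token ids sorted from most to least likely, truncated to the top k
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for i in order:
        if cumulative > top_p:  # mass ranked above this token already exceeds top_p
            break
        kept.append(i)
        cumulative += probs[i]
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}  # renormalized distribution over kept tokens

def sample(logits, temperature=0.6, top_p=0.9, top_k=32, rng=random):
    dist = filter_top_k_top_p(softmax(logits, temperature), top_k, top_p)
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]

logits = [2.0, 1.0, 0.0, -1.0]
print(filter_top_k_top_p(softmax(logits, 1.0), top_k=32, top_p=0.9))  # token 3 falls outside the nucleus
rng = random.Random(0)
samples = [sample(logits, rng=rng) for _ in range(10)]
print(samples)
```

With the default `temperature=0.6`, the distribution is sharp enough that only tokens 0 and 1 survive the `top_p=0.9` cut, which is why lower temperatures make the generated text more repetitive but more coherent.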
