## LLaMA
**Welcome** to the LLaMA notebook. This chapter covers several challenging topics that deserve close attention:
1. RMS-Norm
2. SwiGLU
3. RoPE
4. KV cache
5. Grouped-Query-Attention

Don't let them intimidate you, though. Studying is fun in itself, so enjoy it!

### 0. Import packages

In [21]:
from dataclasses import dataclass
import torch.nn as nn
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
import math
import inspect
import tiktoken

from transformers import LlamaTokenizer

### 1. LLaMA model parameters
Not all parameters are set up front; some of them are added during model construction.

In [2]:
@dataclass
class LLaMAconfig:
    # per-head dimension
    n_embedding: int = 128
    # annotated so that the dataclass treats it as a field
    block_size: int = 1024

### 2. LLaMA model
We will implement the LLaMA model step by step. Note that in this part the modules are kept separate from one another: for example, the KV cache is not involved while we debug the attention module.

#### 2.1 RMS-Norm (Root Mean Square Layer Normalization)
Question: What is the difference between **pre-norm** and **post-norm**? (It will be answered at the end of this part.)

Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. 

BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

see also: https://blog.csdn.net/yjw123456/article/details/138139970

<img src="./image/RMSNorm.png" alt="RMSNorm and LayerNorm" style="width: 350px; height: 200px;" />

In [3]:
class RMSNorm(nn.Module):
    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
        super().__init__()
        # this is the W
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x):
        # NOTE: the original RMSNorm paper implementation is not equivalent
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        
        # mean of squares over the normalized dimension: 1/H * sum(x_i ** 2)
        norm_x = torch.mean(x * x, dim = self.dim, keepdim = True)
        # reciprocal RMS: 1 / sqrt(mean + eps); eps avoids division by zero
        rms_x = torch.rsqrt(norm_x + self.eps)
        # normalize x
        normed_x = x * rms_x
        # apply the learned scale
        scaled_x = normed_x * self.scale
        return scaled_x

# note: RMSNorm drops the mean-centering step of LayerNorm, which reduces computation; the RMSNorm paper reports comparable model quality at lower cost.
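
To make the relationship with LayerNorm concrete, here is a minimal sketch that compares a functional RMS normalization with `F.layer_norm`. The helper `rms_norm` and the test tensor are illustrative assumptions, not part of the model in this notebook, and the learned scale is left out.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical helper: x / sqrt(mean(x^2) + eps), without the learned scale.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(4, 8) + 3.0              # rows with a clearly non-zero mean

y_rms = rms_norm(x)
y_ln = F.layer_norm(x, (8,), eps=1e-5)

# RMSNorm rescales each row to RMS ~= 1 but does not center it.
rms_after = y_rms.pow(2).mean(dim=-1).sqrt()
# LayerNorm additionally subtracts the mean, so its rows are centered.
ln_mean = y_ln.mean(dim=-1)

# If the input is centered first, the two normalizations coincide.
x_centered = x - x.mean(dim=-1, keepdim=True)
same = torch.allclose(rms_norm(x_centered), y_ln, atol=1e-5)
print(rms_after, ln_mean, same)
```

In other words, RMSNorm can be read as LayerNorm without the mean subtraction (and without a bias), which is where the saved computation comes from.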

#### 2.2 SwiGLU

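SwiGLU is a gated linear unit with a learned gate: one linear projection is passed through SiLU (Swish) and multiplied element-wise with a second linear projection of the same input.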

<img src="./image/SwiGLU.png" alt="SwiGLU" style="width: 450px; height: 260px;" />

In [4]:
# F.silu(x): x * sigmoid(x)

def mlp_silu(x):
    # for this demo the embedding dim is kept unchanged; the hidden size is discussed later.
    fc1 = nn.Linear(x.size(-1), x.size(-1))
    fc2 = nn.Linear(x.size(-1), x.size(-1))
    # value branch: both projections read the same input x
    value = fc1(x)
    # gate branch: fc2 followed by SiLU
    gate = F.silu(fc2(x))
    # element-wise product of value and gate
    output = value * gate
    return output
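
The demo above keeps the hidden size equal to the embedding size. As a sketch of the sizing rule used by the `MLP` class later in this notebook (4x the embedding size, scaled by 2/3, rounded up to a multiple of 256), the arithmetic can be checked in isolation. The function name `swiglu_hidden_size` is a hypothetical helper introduced only for this illustration.

```python
def swiglu_hidden_size(n_embed: int, multiple_of: int = 256) -> int:
    # Start from the usual 4x expansion of a transformer MLP.
    hidden_dim = 4 * n_embed
    # SwiGLU has three weight matrices instead of two, so the width is scaled by 2/3
    # to keep the parameter count roughly equal: 3 * d * (8d/3) == 2 * d * 4d.
    n_hidden = int(2 * hidden_dim / 3)
    # Round up to the next multiple of `multiple_of`.
    if n_hidden % multiple_of != 0:
        n_hidden += multiple_of - (n_hidden % multiple_of)
    return n_hidden

print(swiglu_hidden_size(4096))  # 11008
print(swiglu_hidden_size(128))   # 512
```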

#### 2.3 RoPE
One thing to note about RoPE is that it encodes the **absolute position** with a rotation matrix and, at the same time, incorporates an explicit **relative position** dependency into the **self-attention** formulation (paraphrased from the RoPE paper).

The sentence above can be read as a chain: absolute position -> self-attention -> relative position information.
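
This chain can be checked numerically. The sketch below rotates a query and a key by their absolute positions and shows that the attention score depends only on the offset between them. The helper `rotate` and the chosen positions are assumptions made for this illustration; it uses the same pairing of dimensions (0,1), (2,3), ... as the code later in this section.

```python
import torch

def rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    # Hypothetical helper: rotate each pair (x[2i], x[2i+1]) by the angle pos * theta_i.
    d = x.size(-1)
    theta = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_odd * cos + x_even * sin], dim=-1)
    return out.flatten()

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (m - n = 3) at two different absolute positions.
score_a = torch.dot(rotate(q, 5), rotate(k, 2))
score_b = torch.dot(rotate(q, 105), rotate(k, 102))
# A different offset gives, in general, a different score.
score_c = torch.dot(rotate(q, 5), rotate(k, 4))
print(score_a.item(), score_b.item(), score_c.item())
```

The first two scores agree up to floating-point error: each vector is rotated by its own absolute position, yet the dot product only sees the difference of the two rotation angles.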

⭐ There are **two** aspects:
- ✅ Definition of RoPE.
- ✅ Extension of RoPE to a long context.

##### 2.3.1 Definition of RoPE
reference 1: https://blog.csdn.net/weixin_43646592/article/details/130924280

reference 2: https://oi-wiki.org/math/complex/

⭐⭐⭐ (recommended) reference 3: https://blog.csdn.net/v_JULY_v/article/details/134085503 

NOTE: I have finished the mathematical explanation of RoPE; you can find it in the `Addition` part of my GitHub repository.

In [5]:
# 1. Pre-compute RoPE -> produce the cos and sin matrices
batch_size = 32
seq_len = 1024
n_embed = 128
n_head = 8
# base is used to calculate \theta
base = 10000
# \theta is the rotary angle
theta = 1.0 / (base ** (torch.arange(0, n_embed, 2) / n_embed))
print('theta shape: \n', theta.shape)
print('theta: \n', theta[:5])

theta shape: 
 torch.Size([64])
theta: 
 tensor([1.0000, 0.8660, 0.7499, 0.6494, 0.5623])


In [6]:
seq_idx = torch.arange(seq_len)
print('seq_idx shape: \n', seq_idx.shape)
# torch.outer is the outer product; idx_theta is the m * \theta term in the math equation.
idx_theta = torch.outer(seq_idx, theta).float()
print('id_theta shape: \n', idx_theta.shape)
print('id_theta: \n', idx_theta[:2, :5])

seq_idx shape: 
 torch.Size([1024])
id_theta shape: 
 torch.Size([1024, 64])
id_theta: 
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.8660, 0.7499, 0.6494, 0.5623]])


In [7]:
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim= -1)
print('cache shape: \n', cache.shape)
# we need to explain this transformation
cos = torch.cos(idx_theta)
print('cos: \n', cos[:5, :5])
sin = torch.sin(idx_theta)
print('sin: \n', sin[:5, :5])
print('cache: \n', cache[:5, :5, :])
# In the last dimension of cache, the first column is cos(m\theta) and the second column is sin(m\theta); m is the position and theta is the per-dimension angle.

cache shape: 
 torch.Size([1024, 64, 2])
cos: 
 tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.6479,  0.7318,  0.7965,  0.8460],
        [-0.4161, -0.1604,  0.0709,  0.2687,  0.4315],
        [-0.9900, -0.8558, -0.6279, -0.3685, -0.1160],
        [-0.6536, -0.9485, -0.9899, -0.8556, -0.6277]])
sin: 
 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.8415,  0.7617,  0.6816,  0.6047,  0.5332],
        [ 0.9093,  0.9870,  0.9975,  0.9632,  0.9021],
        [ 0.1411,  0.5173,  0.7783,  0.9296,  0.9933],
        [-0.7568, -0.3167,  0.1415,  0.5176,  0.7785]])
cache: 
 tensor([[[ 1.0000,  0.0000],
         [ 1.0000,  0.0000],
         [ 1.0000,  0.0000],
         [ 1.0000,  0.0000],
         [ 1.0000,  0.0000]],

        [[ 0.5403,  0.8415],
         [ 0.6479,  0.7617],
         [ 0.7318,  0.6816],
         [ 0.7965,  0.6047],
         [ 0.8460,  0.5332]],

        [[-0.4161,  0.9093],
         [-0.1604,  0.9870],
         [ 0.0709,  0.9975],
   

In [8]:
# 2. apply RoPE -> given an input q, we reshape it into the right form and then multiply it with the cache above.
## we can construct an input for debugging
## an input may look like: (batch_size, seq_len, num_head, dim)
x = torch.randn((batch_size, seq_len, n_head, n_embed))
print('x shape: \n', x.shape)
## 1. truncate the cache to the sequence length to avoid indexing out of range
seq_len = x.size(1)
cache = cache[:seq_len]
## 2. reshape x: split the last dim into pairs, (128) -> (64, 2)
x_ = x.reshape(*x.shape[:-1], -1, 2)
print('x_ shape: \n', x_.shape)
rope_cache = cache.view(1, x_.size(1), 1, x_.size(3), 2)
print('rope_cache shape: \n', rope_cache.shape)

## then we compute the RoPE output according to the formula; `...` selects all leading dimensions.
x_0 = x_[..., 0]
x_1 = x_[..., 1]
print('x_0 shape: \n', x_0.shape)
rope_cache_0 = rope_cache[..., 0]
rope_cache_1 = rope_cache[..., 1]
print('rope_cache_0 shape: \n', rope_cache_0.shape)

## * is element-wise multiplication
# even dimensions: 0, 2, ...
x_even = x_0 * rope_cache_0 - x_1 * rope_cache_1
print('x_even shape: \n', x_even.shape)
# odd dimensions: 1, 3, ...
x_odd = x_1 * rope_cache_0 + x_0 * rope_cache_1
print('x_odd shape: \n', x_odd.shape)
out_put = torch.stack([x_even, x_odd], dim = -1)
print('out_put shape: \n', out_put.shape)
# flatten from dim 3 onward: (64, 2) -> (128)
out_put = out_put.flatten(3)
print('out_put shape: \n', out_put.shape)

# the output shape is identical to the input shape
# see also the RoPE notebook for more information.

# the code above can be summarized as:
def bulid_rope_cache(
        seq_len:int,
        n_embed: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000
):
    # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_embed, 2, dtype=dtype, device=device) / n_embed))
    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half() 
    return cache
         
def apply_rope(x: torch.Tensor, rope_cache):
    T = x.size(1)
    rope_cache = rope_cache[:T]

    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)

    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)

    x_out = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )

    x_out = x_out.flatten(3)
    
    return x_out.type_as(x)

x shape: 
 torch.Size([32, 1024, 8, 128])
x_ shape: 
 torch.Size([32, 1024, 8, 64, 2])
rope_cache shape: 
 torch.Size([1, 1024, 1, 64, 2])
x_0 shape: 
 torch.Size([32, 1024, 8, 64])
rope_cache_0 shape: 
 torch.Size([1, 1024, 1, 64])
x_even shape: 
 torch.Size([32, 1024, 8, 64])
x_odd shape: 
 torch.Size([32, 1024, 8, 64])
out_put shape: 
 torch.Size([32, 1024, 8, 64, 2])
out_put shape: 
 torch.Size([32, 1024, 8, 128])
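
The real-valued cos/sin form used above is equivalent to a complex multiplication, which is how some other LLaMA implementations express RoPE. The sketch below checks this equivalence on a small tensor. The sizes are arbitrary, and the variable name `freqs_cis` is borrowed from the `Attention.forward` signature later in this notebook on the assumption that it denotes such a complex tensor.

```python
import torch

torch.manual_seed(0)
seq_len, n_head, head_dim, base = 6, 2, 8, 10000.0
theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), theta)   # (seq_len, head_dim / 2)
x = torch.randn(1, seq_len, n_head, head_dim)

# (a) real-valued form, same arithmetic as apply_rope above
xs = x.reshape(1, seq_len, n_head, -1, 2)
cos = torch.cos(angles).view(1, seq_len, 1, -1)
sin = torch.sin(angles).view(1, seq_len, 1, -1)
out_real = torch.stack([xs[..., 0] * cos - xs[..., 1] * sin,
                        xs[..., 1] * cos + xs[..., 0] * sin], dim=-1).flatten(3)

# (b) complex form: (x0 + i * x1) * exp(i * m * theta)
freqs_cis = torch.polar(torch.ones_like(angles), angles)     # unit-modulus complex numbers
x_complex = torch.view_as_complex(xs.contiguous())
out_complex = torch.view_as_real(x_complex * freqs_cis.view(1, seq_len, 1, -1)).flatten(3)

print(torch.allclose(out_real, out_complex, atol=1e-6))
```

Both forms rotate each pair of channels by the same angle; the stacked `[cos, sin]` cache is simply the real and imaginary parts of `exp(i * m * theta)`.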


#### 2.4 KVcache ⭐⭐⭐
The KV cache reduces computation during **inference**. Without it, every time a newly generated token is appended to the sequence, self-attention is recomputed over the whole sequence in order to predict the following token.

The idea is to store the k and v of all previous tokens from the self-attention computation so that they are not recomputed at every step.

In [9]:
# First, we construct an input of shape (1, input_len), assuming that the embedding dim is 128 and vocab_size is 100
input_len = 10
vocab_size = 100
embedding_size = 128
x = torch.randint(0, 100, (1, input_len))
print('x shape: \n', x.shape)
# **WITHOUT KV CACHE**
# 1. first, we look at the raw attention **WITHOUT KV CACHE**
# During inference, we pass the inputs into the model.
## 1.1 embedding: positional encoding is skipped, here we just use nn.Embedding
### 1. first iteration
embedding = nn.Embedding(vocab_size, embedding_size)
embed_x = embedding(x)
print('embed_x shape: \n', embed_x.shape)
wq = nn.Linear(embedding_size, embedding_size)
wk = nn.Linear(embedding_size, embedding_size)
wv = nn.Linear(embedding_size, embedding_size)
Q = wq(embed_x)
print('Q shape: \n', Q.shape)
K = wk(embed_x)
print('K shape: \n', K.shape)
V = wv(embed_x)
print('V shape: \n', V.shape)
score = Q @ K.transpose(2, 1)
print('score shape: \n', score.shape)
attn = score @ V
print('attn shape: \n', attn.shape)
# all the other steps (scaling, softmax over the scores, causal mask, ...) are omitted
lm_head = nn.Linear(embedding_size, vocab_size)
lm_x = lm_head(attn)
print('lm_x shape: \n', lm_x.shape)
lm_x_softm = nn.Softmax(dim=-1)(lm_x)
print('lm_x_softm shape: \n', lm_x_softm.shape)
output = torch.argmax(lm_x_softm[:,-1,:])
print('output: \n', output)

### 2. second iteration/generation
new_x = torch.cat((x, output.unsqueeze(0).view(1, -1)), dim=-1)
print('new_x shape: \n', new_x.shape)
# then we repeat operations above
embed_new_x = embedding(new_x)
Q = wq(embed_new_x)
print('Q shape: \n', Q.shape)
K = wk(embed_new_x)
print('K shape: \n', K.shape)
V = wv(embed_new_x)
print('V shape: \n', V.shape)
score = Q @ K.transpose(2, 1)
print('score shape: \n', score.shape)
attn = score @ V
print('attn shape: \n', attn.shape)
# reuse the lm_head from the first iteration instead of creating a new, randomly initialized one
lm_x = lm_head(attn)
print('lm_x shape: \n', lm_x.shape)
lm_x_softm = nn.Softmax(dim=-1)(lm_x)
print('lm_x_softm shape: \n', lm_x_softm.shape)
output = torch.argmax(lm_x_softm[:,-1,:])
print('output: \n', output)


x shape: 
 torch.Size([1, 10])
embed_x shape: 
 torch.Size([1, 10, 128])
Q shape: 
 torch.Size([1, 10, 128])
K shape: 
 torch.Size([1, 10, 128])
V shape: 
 torch.Size([1, 10, 128])
score shape: 
 torch.Size([1, 10, 10])
attn shape: 
 torch.Size([1, 10, 128])
lm_x shape: 
 torch.Size([1, 10, 100])
lm_x_softm shape: 
 torch.Size([1, 10, 100])
output: 
 tensor(44)
new_x shape: 
 torch.Size([1, 11])
Q shape: 
 torch.Size([1, 11, 128])
K shape: 
 torch.Size([1, 11, 128])
V shape: 
 torch.Size([1, 11, 128])
score shape: 
 torch.Size([1, 11, 11])
attn shape: 
 torch.Size([1, 11, 128])
lm_x shape: 
 torch.Size([1, 11, 100])
lm_x_softm shape: 
 torch.Size([1, 11, 100])
output: 
 tensor(39)


We can see that iterations 1 and 2 both compute attention over the whole sequence, even though wq, wk and wv are identical in both iterations, so the k and v of the first 10 tokens are simply recomputed.

In [10]:

# **WITH KV CACHE**
## With a KV cache, we store the k and v computed so far, then concatenate the new k and v to obtain the full-length k and v
input_len = 10
vocab_size = 100
embedding_size = 128
x = torch.randint(0, 100, (1, input_len))
print('x shape: \n', x.shape)
embedding = nn.Embedding(vocab_size, embedding_size)
embed_x = embedding(x)
print('embed_x shape: \n', embed_x.shape)
wq = nn.Linear(embedding_size, embedding_size)
wk = nn.Linear(embedding_size, embedding_size)
wv = nn.Linear(embedding_size, embedding_size)
Q = wq(embed_x)
print('Q shape: \n', Q.shape)
K = wk(embed_x)
print('K shape: \n', K.shape)
V = wv(embed_x)
print('V shape: \n', V.shape)
####################KVCACHE##########################
cache_K = K
print('cache_K shape: \n', cache_K.shape)
cache_V = V
print('cache_V shape: \n', cache_V.shape)
####################KVCACHE##########################
score = Q @ K.transpose(2, 1)
print('score shape: \n', score.shape)
attn = score @ V
print('attn shape: \n', attn.shape)
# all the other steps (scaling, softmax over the scores, causal mask, ...) are omitted
lm_head = nn.Linear(embedding_size, vocab_size)
lm_x = lm_head(attn)
print('lm_x shape: \n', lm_x.shape)
lm_x_softm = nn.Softmax(dim=-1)(lm_x)
print('lm_x_softm shape: \n', lm_x_softm.shape)
output = torch.argmax(lm_x_softm[:,-1,:])
print('output: \n', output)

### 2. second iteration/generation
#############x is different######################
new_x = output.unsqueeze(0).view(1, -1)
print('new_x shape: \n', new_x.shape)
# then we repeat operations above
## and notice the QKV shape
embed_new_x = embedding(new_x)
Q = wq(embed_new_x)
print('Q shape: \n', Q.shape)
K = wk(embed_new_x)
print('K shape: \n', K.shape)
V = wv(embed_new_x)
print('V shape: \n', V.shape)
###################concat cache####################
K = torch.concat((cache_K, K),dim=1)
print('cached_K shape: \n', K.shape)
V = torch.concat((cache_V, V),dim=1)
print('cached_V shape: \n', V.shape)
cache_K = K
cache_V = V
###################concat cache####################
score = Q @ K.transpose(2, 1)
print('score shape: \n', score.shape)
attn = score @ V
print('attn shape: \n', attn.shape)
# reuse the lm_head from the first iteration instead of creating a new, randomly initialized one
lm_x = lm_head(attn)
print('lm_x shape: \n', lm_x.shape)
lm_x_softm = nn.Softmax(dim=-1)(lm_x)
print('lm_x_softm shape: \n', lm_x_softm.shape)
output = torch.argmax(lm_x_softm[:,-1,:])
print('output: \n', output)

# In conclusion, the attention computation per step is reduced from Q[L+1,D] @ K[L+1,D]^T @ V[L+1,D] to Q[1,D] @ K[L+1,D]^T @ V[L+1,D]

x shape: 
 torch.Size([1, 10])
embed_x shape: 
 torch.Size([1, 10, 128])
Q shape: 
 torch.Size([1, 10, 128])
K shape: 
 torch.Size([1, 10, 128])
V shape: 
 torch.Size([1, 10, 128])
cache_K shape: 
 torch.Size([1, 10, 128])
cache_V shape: 
 torch.Size([1, 10, 128])
score shape: 
 torch.Size([1, 10, 10])
attn shape: 
 torch.Size([1, 10, 128])
lm_x shape: 
 torch.Size([1, 10, 100])
lm_x_softm shape: 
 torch.Size([1, 10, 100])
output: 
 tensor(56)
new_x shape: 
 torch.Size([1, 1])
Q shape: 
 torch.Size([1, 1, 128])
K shape: 
 torch.Size([1, 1, 128])
V shape: 
 torch.Size([1, 1, 128])
cached_K shape: 
 torch.Size([1, 11, 128])
cached_V shape: 
 torch.Size([1, 11, 128])
score shape: 
 torch.Size([1, 1, 11])
attn shape: 
 torch.Size([1, 1, 128])
lm_x shape: 
 torch.Size([1, 1, 100])
lm_x_softm shape: 
 torch.Size([1, 1, 100])
output: 
 tensor(68)
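
The two walkthroughs above use fresh random inputs and omit softmax and masking, so their outputs cannot be compared directly. As a complementary sketch, the code below adds scaling, softmax and a causal mask, and checks that the cached step reproduces the last row of the full recomputation. The helper `causal_attention` and the small sizes are assumptions for this illustration only.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d, L = 16, 5
wq, wk, wv = (nn.Linear(d, d, bias=False) for _ in range(3))
x = torch.randn(1, L + 1, d)          # L prompt tokens plus one new token

def causal_attention(q, k, v):
    # q: (1, Tq, d), k and v: (1, Tk, d); query i sits at absolute position Tk - Tq + i.
    Tq, Tk = q.size(1), k.size(1)
    scores = q @ k.transpose(1, 2) / math.sqrt(d)
    future = torch.triu(torch.ones(Tq, Tk, dtype=torch.bool), diagonal=Tk - Tq + 1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

with torch.no_grad():
    # Full recomputation over all L + 1 tokens.
    full = causal_attention(wq(x), wk(x), wv(x))

    # Cached path: K and V of the first L tokens are stored; only the new token is projected.
    cache_k, cache_v = wk(x[:, :L]), wv(x[:, :L])
    x_new = x[:, L:]
    k = torch.cat([cache_k, wk(x_new)], dim=1)
    v = torch.cat([cache_v, wv(x_new)], dim=1)
    step = causal_attention(wq(x_new), k, v)

print(torch.allclose(full[:, -1:], step, atol=1e-5))
```

The causal mask is what makes caching valid: earlier positions never attend to later ones, so their k and v (and their outputs) do not change when a new token is appended.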


#### 2.5 Grouped-Query-Attention
The KV cache significantly reduces inference computation, but it occupies a large amount of GPU memory. There are four main ways to make the KV cache more efficient:
- limit the cached length;
- reduce the number of KV heads in self-attention (MQA/GQA);
- quantize the KV cache;
- paged attention.

lit-LLaMA limits the cached length by rolling the cache; here we introduce another method, GQA.

In general, self-attention has several heads, and GQA groups them: for example, 8 heads -> 2 groups means that there are 4 query heads in one group. Within a group, k and v are stored only once, and the other heads reuse copies of them.
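
To get a feeling for the memory at stake, the following back-of-the-envelope sketch computes the KV cache size for different numbers of KV heads. The model dimensions (32 layers, 32 query heads of size 128, a 2048-token context, 2 bytes per element) are hypothetical values chosen for the example, and `kv_cache_bytes` is an illustrative helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    # Two tensors (K and V) per layer, each of shape (batch, seq_len, n_kv_heads, head_dim).
    return 2 * n_layers * batch_size * seq_len * n_kv_heads * head_dim * bytes_per_elem

# Hypothetical model: 32 layers, head_dim 128, 2048 cached tokens, fp16 storage.
for name, n_kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    mib = kv_cache_bytes(32, n_kv_heads, 128, 2048) / 2**20
    print(f"{name}: n_kv_heads={n_kv_heads:2d} -> {mib:.0f} MiB per sequence")
# MHA: 1024 MiB, GQA: 256 MiB, MQA: 32 MiB
```

The cache shrinks linearly with the number of KV heads, while the number of query heads (and therefore the shape of the attention output) stays the same.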

<img src="./image/GQA.png" alt="GQA" style="width: 500px; height: 450px;" />

##### 2.5.1 parameters
We assume the embedding dim is 18 and the number of heads is 6, so there are 3 channels in one head. Hence:
- Q: dim 18, 6 heads -> (Q1, Q2, Q3, Q4, Q5, Q6)
- K: dim 18 after expansion, 2 groups -> (grouped_K1, grouped_K2) -> (grouped_K1_copy1, grouped_K1_copy2, grouped_K1_copy3, grouped_K2_copy1, grouped_K2_copy2, grouped_K2_copy3)
- stored K (grouped_K1 + grouped_K2): dim 6

In [11]:
@dataclass
class ModelArgs:
    dim: int = 18
    # attention layers
    n_layers: int = 1
    # q heads
    n_heads: int = 6
    # kv grouped heads
    n_kv_heads: int =  2
    vocab_size: int = -1
    multiple_of: int = 10  # make SwiGLU hidden layer size multiple of large power of 2
    # rms norm eps
    norm_eps: float = 1e-5
    # theta base
    rope_theta: float = 500000
    max_batch_size: int = 2
    max_seq_len: int = 17
    model_parallel_size: int = 1

config = ModelArgs()

##### 2.5.2 GQA Implementation

In [12]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # x[:, :, :, None, :] inserts a new dim at the position of None
    # x: [b, s, h, d] -> [b, s, h, 1, d]
    # expand broadcasts the new dim to the desired size
    # NOTE: expand itself does NOT copy data; it returns a broadcast view,
    # and the data is only materialized by the reshape that follows
    return (
        x[:, :, :, None, :]
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
    )
# debug
k = torch.randn(1, 7, 2, 3)
repeat_k = repeat_kv(k, 3)
print(f'repeated k shape: \n{repeat_k.shape}')
print(f'k: \n{k[0, 0, :, :]}')
print(f'repeated k: \n{repeat_k[0, 0, :, :]}')

repeated k shape: 
torch.Size([1, 7, 6, 3])
k: 
tensor([[ 1.1777, -0.6196, -0.3396],
        [ 0.1806,  0.8465, -0.8715]])
repeated k: 
tensor([[ 1.1777, -0.6196, -0.3396],
        [ 1.1777, -0.6196, -0.3396],
        [ 1.1777, -0.6196, -0.3396],
        [ 0.1806,  0.8465, -0.8715],
        [ 0.1806,  0.8465, -0.8715],
        [ 0.1806,  0.8465, -0.8715]])
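
The comments in `repeat_kv` make two claims that can be verified directly: the result matches `torch.repeat_interleave`, and `expand` does not copy data. The sketch below repeats the same expand-then-reshape steps inline on a small random tensor (the shapes are the ones from the debug cell above).

```python
import torch

torch.manual_seed(0)
k = torch.randn(1, 7, 2, 3)      # (batch, seq_len, n_kv_heads, head_dim)
n_rep = 3

expanded = k[:, :, :, None, :].expand(1, 7, 2, n_rep, 3)
repeated = expanded.reshape(1, 7, 2 * n_rep, 3)

# Same values as the built-in, as the docstring of repeat_kv states.
print(torch.equal(repeated, torch.repeat_interleave(k, repeats=n_rep, dim=2)))
# expand() is a zero-stride view that shares storage with k ...
print(expanded.stride(), expanded.data_ptr() == k.data_ptr())
# ... but merging the two head dimensions cannot be expressed as a view, so reshape() copies.
print(repeated.data_ptr() == k.data_ptr())
```

So the memory saving of GQA comes from the cache, which stores only the grouped k and v; the repeated tensor exists only temporarily during the attention computation.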


In [13]:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # this is the parallel parameter
        model_parallel_size = args.model_parallel_size
        # local heads means that total heads are distributed into several nodes
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(in_features=args.dim, out_features=args.n_heads * self.head_dim,bias=False)
        self.wk = nn.Linear(in_features=args.dim, out_features=args.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(in_features=args.dim, out_features=args.n_kv_heads * self.head_dim,bias=False)
        self.wo = nn.Linear(in_features=args.n_heads * self.head_dim, out_features=args.dim,bias=False)

        print(f'wq_shape: \n\t{self.wq.weight.shape}')
        print(f'wk_shape: \n\t{self.wk.weight.shape}')
        print(f'wv_shape: \n\t{self.wv.weight.shape}')
        print(f'wo_shape: \n\t{self.wo.weight.shape}')

        # KV cache: since k and v are grouped, we only need to store the grouped k and v
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        # notice the shapes of k and v:
        # 'grouped' does NOT mean computing the full k and v first and then splitting them into groups;
        # instead, only the grouped k and v are computed, and they are expanded to the full number of heads later
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        
        # notice where RoPE is applied:
        # calculate qkv -> apply RoPE -> cache kv -> expand
        # xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # RoPE is skipped in this demo
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # this code path assumes inference: the cache is always written
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv 

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        print(f'q shape: \n\t{xq.shape}')
        print(f'keys shape: \n\t{keys.shape}')
        print(f'values shape: \n\t{values.shape}')

        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        print(f'q shape: \n\t{xq.shape}')
        print(f'repeated_keys shape: \n\t{keys.shape}')
        print(f'repeated_values shape: \n\t{values.shape}')

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask

        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values) 
        print(f'output shape: \n\t{output.shape}')
        # (b, h, l, h_d) -> (b, l, d)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        print(f'concated shape: \n\t{output.shape}')

        return self.wo(output)

# debug
attn = Attention(config)
print(f'attn shape with GQA: \n\t{attn}')
batch_size = config.max_batch_size
seq_len = config.max_seq_len
embedding_dim = config.dim
x_src = torch.randn(batch_size, seq_len, embedding_dim)
print(f'x_src shape: \n\t{x_src.shape}')
y = attn(x_src, start_pos = 0, freqs_cis=None, mask=None)
print(f'y shape: \n\t{y.shape}')

wq_shape: 
	torch.Size([18, 18])
wk_shape: 
	torch.Size([6, 18])
wv_shape: 
	torch.Size([6, 18])
wo_shape: 
	torch.Size([18, 18])
attn shape with GQA: 
	Attention(
  (wq): Linear(in_features=18, out_features=18, bias=False)
  (wk): Linear(in_features=18, out_features=6, bias=False)
  (wv): Linear(in_features=18, out_features=6, bias=False)
  (wo): Linear(in_features=18, out_features=18, bias=False)
)
x_src shape: 
	torch.Size([2, 17, 18])
q shape: 
	torch.Size([2, 17, 6, 3])
keys shape: 
	torch.Size([2, 17, 2, 3])
values shape: 
	torch.Size([2, 17, 2, 3])
q shape: 
	torch.Size([2, 6, 17, 3])
repeated_keys shape: 
	torch.Size([2, 6, 17, 3])
repeated_values shape: 
	torch.Size([2, 6, 17, 3])
output shape: 
	torch.Size([2, 6, 17, 3])
concated shape: 
	torch.Size([2, 17, 18])
y shape: 
	torch.Size([2, 17, 18])
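
The debug call above passes `mask=None`, so every position attends to every other position. For causal decoding, an additive mask can be passed instead. The following is a minimal sketch of how such a mask might be built; `build_causal_mask` is a hypothetical helper, not a function defined elsewhere in this notebook.

```python
import torch

def build_causal_mask(seqlen: int, start_pos: int = 0) -> torch.Tensor:
    # Hypothetical helper. Returns an additive mask of shape (seqlen, start_pos + seqlen):
    # 0 where a query may attend, -inf where the key lies in the future.
    mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
    # Every new query may see all start_pos positions that are already cached.
    return torch.hstack([torch.zeros(seqlen, start_pos), mask])

print(build_causal_mask(3))
print(build_causal_mask(2, start_pos=3).shape)   # torch.Size([2, 5])
```

Because `scores` in `Attention.forward` has shape (bsz, n_heads, seqlen, start_pos + seqlen), a mask of shape (seqlen, start_pos + seqlen) broadcasts over the batch and head dimensions when it is added.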


#### 2.5 LLAMA Model Construction
In this step, we assemble a LLaMA model from the modules above, as in the previous chapters.

<img src="./image/LLaMA.png" alt="LLaMA" style="width: 600px; height: 400px;" />

In [18]:
# This version of LLaMA follows the lit-llama implementation
##### 2.5.1 LLAMA Config
MaskCache = torch.Tensor
RoPECache = torch.Tensor
KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class LLaMAConfig:
    # seq_len
    block_size: int = 2048
    # vocab_size
    vocab_size: int = 100
    padded_vocab_size: Optional[int] = None
    n_layer: int = 32
    n_head: int = 32
    n_embed: int = 4096
    
##### 2.5.2 CausalSelfAttention
class CausalSelfAttention(nn.Module):

    def __init__(self, config:LLaMAConfig):
        super().__init__()
        # check head
        assert config.n_embed % config.n_head == 0

        # qkv in a batch
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)

        # output proj
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)

        # settings
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        self.block_size = config.block_size

    def forward(self,
                x: torch.Tensor,
                rope: torch.Tensor,
                mask: torch.Tensor,
                max_seq_length: int,
                input_pos: Optional[torch.Tensor] = None,
                kv_cache: Optional[KVCache] = None,
                ):
            # (Batch_size, seq_len, embedding_dim)
            B, T, C = x.size()
            # calculate query, key, values for all heads in batch and move head forward to be the batch dim
            q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

            head_size = C // self.n_head
            k = k.view(B, T, self.n_head, head_size)
            q = q.view(B, T, self.n_head, head_size)
            v = v.view(B, T, self.n_head, head_size)

            q = apply_rope(q, rope)
            k = apply_rope(k, rope)

            k = k.transpose(1, 2)  # (B, nh, T, hs)
            q = q.transpose(1, 2)  # (B, nh, T, hs)
            v = v.transpose(1, 2)  # (B, nh, T, hs)

            # kvcache:
            if kv_cache is not None:
                cache_k, cache_v = kv_cache
                if input_pos[-1] >= max_seq_length:
                    input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                    # torch.roll: https://pytorch.org/docs/stable/generated/torch.roll.html
                    cache_k = torch.roll(cache_k, -1, dims=2)
                    cache_v = torch.roll(cache_v, -1, dims=2)
                # index_copy: https://blog.csdn.net/hjxu2016/article/details/130161239
                # i.e. write k into cache_k at the positions input_pos along dim 2
                k = cache_k.index_copy(2, input_pos, k)
                v = cache_v.index_copy(2, input_pos, v)
                kv_cache = k, v
            
            # efficient attention; F.scaled_dot_product_attention can dispatch to Flash Attention kernels when available
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            y = self.c_proj(y)
            return y, kv_cache
    
##### 2.5.3 silu mlp
class MLP(nn.Module):
    def __init__(self,config:LLaMAConfig):
        super().__init__()
        hidden_dim = 4 * config.n_embed
        # SwiGLU uses 2/3 of the usual 4x hidden size
        n_hidden = int(2 * hidden_dim / 3)
        # round n_hidden up to a multiple of 256
        if n_hidden % 256 != 0:
            n_hidden = n_hidden + 256 - (n_hidden % 256)
        
        self.c_fc1 = nn.Linear(config.n_embed, n_hidden, bias = False)
        self.c_fc2 = nn.Linear(config.n_embed, n_hidden, bias = False)
        self.c_proj = nn.Linear(n_hidden, config.n_embed, bias = False)
    
    def forward(self, x):
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x
    
##### 2.5.4 LLaMA block

class Block(nn.Module):
    def __init__(self, config: LLaMAConfig):
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embed)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        mask: MaskCache,
        max_seq_length: int,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ):
        h, new_kv_cache = self.attn(self.rms_1(x),
                                    rope,
                                    mask,
                                    max_seq_length,
                                    input_pos,
                                    kv_cache)
        # short cut 1
        x = x + h
        # short cut 2
        x = x + self.mlp(self.rms_2(x))
        return x, new_kv_cache
    
##### 2.5.5 LLaMA model
class LLaMA(nn.Module):
    def __init__(self, config: LLaMAConfig):
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config
        self.lm_head = nn.Linear(config.n_embed, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embed),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=RMSNorm(config.n_embed),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[MaskCache] = None
        self.kv_caches: List[KVCache] = []

    def forward(
            self,
            idx: torch.Tensor,
            targets: Optional[torch.Tensor] = None,
            max_seq_length: Optional[int] = None,
            input_pos: Optional[torch.Tensor] = None
    ):
        B, T = idx.size()
        block_size = self.config.block_size
        if max_seq_length is None:
            max_seq_length = block_size
        assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
        if self.rope_cache is None:
            self.rope_cache = self.build_rope_cache(idx)
        if self.mask_cache is None:
            self.mask_cache = self.build_mask_cache(idx)

        if input_pos is not None:
            # index_select: https://pytorch.org/docs/stable/generated/torch.index_select.html
            # rope: (T, n_embed/2, 2)
            rope = self.rope_cache.index_select(0, input_pos)
            # mask: (1, 1, :T, :T)
            mask = self.mask_cache.index_select(2, input_pos)
            mask = mask[:, :, :, :max_seq_length]
        else:
            rope = self.rope_cache[:T]
            mask = self.mask_cache[:, :, :T, :T]
        
        x = self.transformer.wte(idx)
        # during training, we do not use the KV cache
        if input_pos is None:
            for block in self.transformer.h:
                x, _ = block(x, rope, mask, max_seq_length)
        else:
            # kvcaches: [layer1:(k,v), layer2:(k,v), layer3:(k,v)]
            if not self.kv_caches:
                head_size = self.config.n_embed // self.config.n_head
                cache_shape = (B, self.config.n_head, max_seq_length, head_size)
                # initial kv_cache
                self.kv_caches = [
                    (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
                    for _ in range(self.config.n_layer)
                ]
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
        
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # we only need to calculate the next token
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
    
    def configure_optimizer(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn,p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter with at least 2 dimensions is weight decayed, the others are not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and norm scales don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused = True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f'using fused adamW: {use_fused}')
        return optimizer
    
    def get_num_params(self, non_embedding = True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the token embedding (wte)
        parameters are subtracted. There is no learned position embedding
        to subtract, because positions are encoded with RoPE.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
        return n_params
    
    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops (Floating Point Operations) utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        # get the attn shape
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embed // cfg.n_head, cfg.block_size
        # per token calculated
        flops_per_token = 6*N + 12*L*H*Q*T
        # the fwd and bwd passes process every token in the sequence
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        # A100 GPU bfloat16 peak flops is 312 TFLOPS
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu
    
    @classmethod
    def from_name(cls, name: str):
        return cls(LLaMAConfig.from_name(name))
    
    def build_rope_cache(self, idx: torch.tensor):
        return bulid_rope_cache(
            seq_len=self.config.block_size,
            n_embed=self.config.n_embed // self.config.n_head,
            dtype=idx.dtype,
            device=idx.device,
        )

    def build_mask_cache(self, idx):
        ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
        # lower-triangular (causal) mask -> (1, 1, block_size, block_size)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)
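
The `input_pos` branch of `forward` picks individual rows out of the cached lower-triangular mask, so that a token decoded at position `p` attends only to cache slots `0..p`. Below is a minimal plain-Python sketch of that selection, using nested lists instead of tensors. `build_mask` and `select_rows` are hypothetical helper names used only for illustration; they are not part of the model.

```python
def build_mask(block_size):
    # Lower-triangular causal mask: row i may attend to columns 0..i.
    return [[col <= row for col in range(block_size)] for row in range(block_size)]


def select_rows(mask, input_pos, max_seq_length):
    # Mirrors mask_cache.index_select(2, input_pos)[..., :max_seq_length]:
    # keep one mask row per decoded position, truncated to the cache length.
    return [mask[pos][:max_seq_length] for pos in input_pos]


mask = build_mask(block_size=8)
# Decode the token at position 3 with a KV cache of length 6.
rows = select_rows(mask, input_pos=[3], max_seq_length=6)
print(rows)  # [[True, True, True, True, False, False]]
```

The selected row has as many columns as the KV cache has slots, which is why the mask is sliced to `max_seq_length` rather than to `T` in this branch.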
    



### 3. Training and Inference

#### 3.1 Training
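
The training cell uses a learning rate schedule with linear warmup followed by cosine decay down to a floor. Here is a self-contained sketch of such a schedule. The name `cosine_lr` and the numbers in the usage loop are illustrative assumptions, not necessarily the settings used in this notebook.

```python
import math


def cosine_lr(it, max_lr, min_lr, warmup_iters, decay_iters):
    # Linear warmup from 0 to max_lr.
    if it < warmup_iters:
        return max_lr * it / warmup_iters
    # After the decay window, stay at the floor.
    if it > decay_iters:
        return min_lr
    # Cosine decay from max_lr to min_lr in between.
    ratio = (it - warmup_iters) / (decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


# Illustrative values: 100 warmup steps, decay until step 1000.
for it in (0, 50, 100, 550, 1000, 1500):
    print(it, cosine_lr(it, max_lr=6e-4, min_lr=6e-5, warmup_iters=100, decay_iters=1000))
```

Note that `warmup_iters` needs to be smaller than `decay_iters`; otherwise a run of `decay_iters` steps never leaves the warmup phase and never reaches `max_lr`.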

In [19]:
import numpy as np
import os
from contextlib import nullcontext
import pickle
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import time
# 1. load data
work_dir = os.getcwd()
data_dir = os.path.join(work_dir, 'data/shakespeare_char')
train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
print(f'train_data shape: \n\t{train_data.shape}')
print(f'val_data shape: \n\t{val_data.shape}')
# 2.model config
@dataclass
class LLaMAConfig:
    # seq_len
    block_size: int = 2048
    # vocab_size
    vocab_size: int = 100
    padded_vocab_size: Optional[int] = None
    n_layer: int = 32
    n_head: int = 32
    n_embed: int = 4096
    bias: bool = False
    dropout: float = 0.0
    compile = True
    def __post_init__(self):
        # keep an explicitly given padded_vocab_size; otherwise round vocab_size up to a multiple of 64
        if self.padded_vocab_size is None:
            self.padded_vocab_size = ((self.vocab_size + 63) // 64) * 64

    @classmethod
    def from_name(cls, name: str):
        return cls(**llama_configs[name])


llama_configs = {
    "7B": dict(n_layer=32, n_head=32, n_embed=4096),
    "13B": dict(n_layer=40, n_head=40, n_embed=5120),
    "30B": dict(n_layer=60, n_head=52, n_embed=6656),
    "65B": dict(n_layer=80, n_head=64, n_embed=8192),
    "baby_llama": dict(n_layer=2, n_head=8, n_embed=128)
}
## specific parameter settings
# For shakespeare, choose smaller block size than vanilla LLaMA
out_dir = os.path.join(work_dir, 'results')
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
gradient_accumulation_steps = 5 * 8
block_size = 1024
batch_size = 12
# max_iters = 600000
max_iters = 1000
log_interval = 1
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
iter_num = 0
# start from a very large value so that the first evaluation always counts as an improvement
best_val_loss = 1e9
always_save_checkpoint = False

# optimizer
weight_decay = 1e-1
learning_rate = 6e-4
decay_lr = True
beta1, beta2 = 0.9, 0.95
# must be smaller than lr_decay_iters, otherwise the run never leaves the warmup phase
warmup_iters = 100
lr_decay_iters = max_iters
min_lr = 6e-5

model_config = LLaMAConfig.from_name("baby_llama")
model_config.block_size = block_size
model_config.vocab_size = 100
print(f'model config: \n\t{model_config}')
# clip gradients at this value, or disable if == 0.0
grad_clip = 1.0

# evaluate
eval_interval = 2000
eval_iters = 200
eval_only = False

# 3. load model
model = LLaMA(model_config)
print(f'model arc: \n\t{model}')
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)

config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
print('config_keys: \n\t', config_keys)
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
print('config: \n\t', config)

# 4. ddp training
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    # init process group, backend: nccl or gloo or mpi
    init_process_group(backend=backend)
    # global rank of this process across all nodes: 0 .. world_size - 1
    ddp_rank = int(os.environ['RANK'])
    # local rank is the index of this process (GPU) within its own node
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    # world size is the total number of processes (usually GPUs per node * number of nodes)
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device=device)
    # this process will do logging, checkpointing etc.
    master_process = ddp_rank == 0
    # each process gets a different seed
    seed_offset = ddp_rank
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps = gradient_accumulation_steps // ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f'gradient_accumulation_steps: {gradient_accumulation_steps}')
print(f'ddp_world_size: {ddp_world_size}')
print(f'batch_size: {batch_size}')
print(f'block_size: {block_size}')
print(f'tokens per iter: {tokens_per_iter}')

# only the master process (rank 0 under DDP, or the single process otherwise) creates the output directory
if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(42 + seed_offset)
# allow tf32 on matmul
torch.backends.cuda.matmul.allow_tf32 = True
# allow tf32 on cudnn
torch.backends.cudnn.allow_tf32 = True
# for later use in torch.autocast
device_type = 'cuda' if 'cuda' in device else 'cpu'
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# automatic mixed precision (AMP), often combined with a GradScaler
# when training in float16, the GradScaler scales the loss up so that small gradients do not underflow
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# 5. load vocab
meta_path = os.path.join(data_dir, 'meta.pkl')
print(f'meta data: {meta_path}')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f'found vocab_size = {meta_vocab_size} (inside {meta_path})')

# 6. ddp model settings
model.to(device)
# keep a handle to the plain model; at this point it is wrapped by neither torch.compile nor DDP
raw_model = model
# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = model.configure_optimizer(
    weight_decay=weight_decay,
    learning_rate=learning_rate,
    betas=(beta1, beta2),
    device_type=device_type
)

# checkpoint
checkpoint = None
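# `compile` is a class attribute of LLaMAConfig; a bare `compile` would be Python's builtin, which is always truthy
if model_config.compile:
    print('Compiling the model>>>')
    unoptimized_model = model
    # requires torch >= 2.0
    model = torch.compile(model)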

if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# 7. prepare batch
def get_batch(split):
    data = train_data if split == 'train' else val_data
    # sample from range(len(data) - block_size) and the shape is (batch_size,)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    # the labels are the inputs shifted by one token (next-token prediction)
    y = torch.stack([torch.from_numpy((data[i+1 : i + block_size + 1]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device=device, non_blocking = True), y.pin_memory().to(device=device, non_blocking = True)
    else:
        x, y = x.to(device), y.to(device)

    return x, y

X, Y = get_batch('train')
print(f'X shape: {X.shape}, X: {X[0,:10]}')
print(f'Y shape: {Y.shape}, Y: {Y[0,:10]}')

# 8. training record
t0 = time.time()
local_iter_num = 0
running_mfu = -1.0

# 9. learning rate settings
def get_lr(it):
    # 1. linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * (it / warmup_iters)
    # 2. past lr_decay_iters, stay at the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3. in between, use cosine decay down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    # coeff goes from 1 to 0 as decay_ratio goes from 0 to 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

# 10. training loop
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            # run the forward pass under the same autocast context as training
            with ctx:
                logits, loss = model(X, Y)
            # save the loss of this batch
            losses[k] = loss.item()
        # mean loss over eval_iters batches for this split
        out[split] = losses.mean()
    model.train()
    return out

while True:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
    # evaluate the loss on train/val sets and write checkpoints
    # you may notice that this is the pre-checking step rather than training itself
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # if the validation loss improves on the best value seen so far, record it and save a checkpoint
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_config,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,                    
                }

                print(f'saving checkpoint to {out_dir}')
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable.
            
            # only sync gradients on the last micro step:
            # the gradient all-reduce across processes is skipped on all earlier micro steps
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    
    # clip the gradient
    if grad_clip != 0.0:
        # if we choose to clip gradients, we need to unscale the gradients first to clip the right gradients
        scaler.unscale_(optimizer)
        # grad_clip is the maximum allowed global gradient norm; clipping guards against exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging

    t1 = time.time()
    dt = t1 - t0
    t0 = t1

    # log on the master process only (the single process when not using DDP)
    if iter_num  % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            # exponential moving average of the MFU; the 0.9 / 0.1 weights control how quickly it reacts
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1    
    # this notebook is just for presentation, so max_iters is set to 1000; in the original project it is 600000.
    if iter_num > max_iters:
        break
    
# ddp needs both init and destroy
if ddp:
    destroy_process_group()

train_data shape: 
	(1003854,)
train_data shape: 
	(111540,)
model config: 
	LLaMAConfig(block_size=1024, vocab_size=100, padded_vocab_size=128, n_layer=2, n_head=8, n_embed=128, bias=False, dropout=0.0)
model arc: 
	LLaMA(
  (lm_head): Linear(in_features=128, out_features=128, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(128, 128)
    (h): ModuleList(
      (0-1): 2 x Block(
        (rms_1): RMSNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=True)
          (c_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (rms_2): RMSNorm()
        (mlp): MLP(
          (c_fc1): Linear(in_features=128, out_features=512, bias=False)
          (c_fc2): Linear(in_features=128, out_features=512, bias=False)
          (c_proj): Linear(in_features=512, out_features=128, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)
config_keys: 
	 ['batch_size', 'seq_len', 'n_embed', 'n_head', 'b

KeyboardInterrupt: 
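
The `mfu` value logged by the training loop comes from `estimate_mfu`: roughly `6*N + 12*L*H*Q*T` FLOPs per token, multiplied by the number of tokens processed per iteration, then divided by the elapsed time and by the hardware peak. Here is a standalone arithmetic sketch of that calculation. `mfu_estimate` is a hypothetical helper, and the parameter count and timing below are made-up round numbers, not measurements of this model.

```python
def mfu_estimate(n_params, n_layer, n_head, head_dim, block_size,
                 fwdbwd_per_iter, dt, peak_flops=312e12):
    # FLOPs per token: 6*N for the dense weights plus the attention term.
    flops_per_token = 6 * n_params + 12 * n_layer * n_head * head_dim * block_size
    # Each forward/backward pass processes block_size tokens per sequence.
    flops_per_iter = flops_per_token * block_size * fwdbwd_per_iter
    # Achieved FLOPs per second relative to the hardware peak.
    return (flops_per_iter / dt) / peak_flops


# Illustrative numbers only: 400k parameters, a baby_llama-like shape,
# 12 * 40 sequences per iteration, 2 seconds per iteration.
mfu = mfu_estimate(n_params=400_000, n_layer=2, n_head=8, head_dim=16,
                   block_size=1024, fwdbwd_per_iter=12 * 40, dt=2.0)
print(f"mfu: {mfu * 100:.4f}%")
```

A tiny model like this one is expected to report a very low MFU against an A100 peak, because its matrix multiplications are far too small to keep the GPU busy.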

#### 3.2 Inference
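
Generation in the cell below follows the usual recipe: scale the last-step logits by a temperature, keep only the `top_k` largest, apply softmax, and sample one index. Here is a plain-Python sketch of that recipe on a single logit vector. `sample_top_k` is a hypothetical helper written for illustration; it is not part of the model or of PyTorch.

```python
import math
import random


def sample_top_k(logits, temperature=0.8, top_k=3, rng=random):
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    scaled = [x / temperature for x in logits]
    # Everything below the k-th largest logit gets probability 0.
    k = min(top_k, len(scaled))
    threshold = sorted(scaled, reverse=True)[k - 1]
    kept = [x if x >= threshold else float("-inf") for x in scaled]
    # Numerically stable softmax over the surviving logits.
    m = max(kept)
    exps = [math.exp(x - m) for x in kept]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Draw one index according to probs (not necessarily the argmax).
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return idx, probs


idx, probs = sample_top_k([2.0, 1.0, 0.5, -1.0, -3.0], top_k=3, rng=random.Random(0))
print(idx, [round(p, 3) for p in probs])
```

Only the three largest logits keep a non-zero probability here, so the sampled index is always one of the first three entries.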

In [22]:

# Unlike nanoGPT, we do not start from scratch here; we run inference with the model trained above.
# One thing you need to know: because the model was wrapped by torch.compile,
# the keys of model.state_dict() carry the prefix '_orig_mod.'
show_state_dict = model.state_dict().items()
# for example:
print(f'model raw state dict: \n\t{list(show_state_dict)[0]}')
# so when you load a model you trained before, you may need to remove that prefix; the details are in the nanoGPT repository -> sample.py

#### 1. load encode and decode
# This step converts text into token ids and back, such as: 'I love you' -> '<SOS> 5 2 0 <EOS>'
# there are two ways to implement that:
# 1. load meta
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
        print(f'get dataset meta alphabet information: \n\t{meta}')
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
else:
    # 2. fall back to a pretrained LLaMA tokenizer from transformers
    # ok let's assume LLaMA encodings by default
    print("No meta.pkl found, assuming LLaMA encodings...")
    enc = LlamaTokenizer.from_pretrained("facebook/llama")
    # LlamaTokenizer.encode has no `allowed_special` argument (that belongs to tiktoken's API)
    encode = lambda s: enc.encode(s)
    decode = lambda l: enc.decode(l)

#### 2. inference
# assuming that we start at the very beginning
start = 'I love'
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None,...])

# inference config
# number of samples to generate from the same prompt
sample_num = 10
# max tokens generate for a sample
max_new_tokens = 1
# temperature 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
temperature = 0.8
# top_k retain only the top_k most likely tokens, clamp others to have 0 probability
top_k = 10
# run generation (in eval mode, as the nanoGPT docstring recommends)
model.eval()
with torch.no_grad():
    with ctx:
        for k in range(sample_num):
            # NOTE: in nanoGPT the `generate` method is part of the GPT model,
            # but here it is written inline for better understanding.
            # Take a conditioning sequence of indices (LongTensor of shape (b, t)) and complete
            # the sequence max_new_tokens times, feeding the predictions back into the model each time.
            # Each sample restarts from the original prompt x.
            y = x
            for _ in range(max_new_tokens):
                # if the sequence context is growing too long we must crop it at block_size
                idx_cond = y if y.size(1) <= model_config.block_size else y[:, -model_config.block_size:]
                # forward the model to get the logits for the index in the sequence
                logits, _ = model(idx_cond)
                # pluck the logits at the final step and scale by desired temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    # torch.topk returns (values, indices); only the values are needed here
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # v[:, [-1]] is the k-th largest logit of each row; everything below it is masked out
                    logits[logits < v[:, [-1]]] = float('-inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                # sample one index from the distribution; sampling (instead of taking the argmax)
                # keeps the generated text diverse
                idx_next = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence and continue
                y = torch.cat((y, idx_next), dim=-1)

            #### 3.output
            print(decode(y[0].tolist()))
            print('-------------------------')         

model raw state dict: 
	('_orig_mod.lm_head.weight', tensor([[-0.0165, -0.1323,  0.1174,  ..., -0.1137, -0.0753,  0.0217],
        [-0.1176, -0.0574,  0.0647,  ..., -0.1205, -0.0097,  0.1088],
        [-0.0411, -0.0398,  0.0155,  ...,  0.0061, -0.0903, -0.0474],
        ...,
        [-0.0172,  0.0158, -0.0632,  ..., -0.0277,  0.0944, -0.1004],
        [-0.0146, -0.0777, -0.0558,  ...,  0.0378,  0.0180, -0.0529],
        [ 0.0328, -0.0672, -0.0782,  ...,  0.0088,  0.1077, -0.0218]],
       device='cuda:0'))
get dataset meta alphabet information: 
	{'vocab_size': 65, 'itos': {0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: '