<a href="https://colab.research.google.com/github/ghadfield32/LLM_learning/blob/main/gpt_dev.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Building a GPT

Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT.

Andrej Shakespeare collab ex:

In [64]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

Adding in my data from a word doc, Convert first to text and then allow it to continue

In [65]:
%%writefile functions/word_to_txt.py

import os
import re
from docx import Document

def clean_text(text):
    # remove all non-ascii characters
    text = ''.join([i if ord(i) < 128 else ' ' for i in text])
    # remove all characters that are not letters
    text = re.sub(r'[^a-zA-Z]', ' ', text)
    # remove multiple spaces
    text = re.sub(r'\s+', ' ', text)
    # remove leading and trailing spaces
    text = text.strip()
    return text

def load_docx(file_path):
    doc = Document(file_path)
    text = []
    for paragraph in doc.paragraphs:
        text.append(paragraph.text)
    return '\n'.join(text)

def save_text(file_path, text):
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(text)

def word_to_txt(input_file, output_file):
    text = load_docx(input_file)
    text = clean_text(text)
    save_text(output_file, text)

# Get the current working directory (assumed to be the root of the repository)
repo_root = os.getcwd()

# Define the relative paths
input_file = os.path.join(repo_root, 'my_data', 'Resume_updated.docx')
output_file = os.path.join(repo_root, 'my_data', 'input.txt')

word_to_txt(input_file, output_file)



Overwriting functions/word_to_txt.py


In [66]:
import os

repo_root = os.getcwd()

# Load the text
output_file = os.path.join(repo_root, 'my_data', 'input.txt')

with open(output_file, 'r', encoding='utf-8') as file:
    text = file.read()
#peak at first 1000 characters
print(text[:1000])

Data Engineer analyst associate GitHub GHadfield LinkedIn Geoff Hadfield Objective Innovative and results driven Associate Data Analyst and aspiring Data Scientist with a solid foundation in data analytics machine learning and software development Proven expertise in leveraging advanced analytical methods and predictive modeling to drive decision making and improve operational efficiency Seeking to apply my skills in data science to solve complex business challenges Education Master of Science in Data Science UWF College Expected Completion in Specialization in Machine Learning Deep Learning and Big Data Technologies Bachelor of Applied Science in Economics Computational Math Valencia College May Associate of Arts July Technical Skills Programming Languages Python R SAS SQL Data Analysis and Visualization Excel MATLAB AMPL SAS Tableau Power BI Machine Learning Deep Learning PyTorch Scikit learn XGBoost Random Forest Linear Regression Big Data Technologies Hadoop familiarity with managi

In [67]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  3736


Tokenization and Encoding:

    Convert text data into numerical tokens using a character-level approach.

In [68]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text))) #take the set then make it a list so it's ordered then sort it
vocab_size = len(chars) #number of unique characters
print(''.join(chars))
print(vocab_size)

 ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnoprstuvwxyz
51


Encoding and Decoding can be by unique character like the first example or it can be by sub words like the second example

Encoding: transforming text to numbers
Decoding: transforming the numbers back to text

In [69]:


# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) } #string to integer
itos = { i:ch for i,ch in enumerate(chars) } #integer to string
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

# encode the text
import tiktoken

enc = tiktoken.get_encoding('gpt2')
enc_text = enc.encode(text)
print(enc_text[:1000])
print(enc.decode(enc_text[:1000]))





[33, 34, 34, 0, 44, 33, 30, 42, 30]
hii there
[6601, 23164, 12499, 11602, 21722, 24739, 324, 3245, 27133, 24688, 11161, 3245, 37092, 43405, 876, 290, 2482, 7986, 22669, 6060, 44600, 290, 31483, 6060, 33374, 351, 257, 4735, 8489, 287, 1366, 23696, 4572, 4673, 290, 3788, 2478, 1041, 574, 13572, 287, 42389, 6190, 30063, 5050, 290, 33344, 21128, 284, 3708, 2551, 1642, 290, 2987, 13919, 9332, 48160, 284, 4174, 616, 4678, 287, 1366, 3783, 284, 8494, 3716, 1597, 6459, 7868, 5599, 286, 5800, 287, 6060, 5800, 33436, 37, 5535, 1475, 7254, 955, 24547, 287, 6093, 1634, 287, 10850, 18252, 10766, 18252, 290, 4403, 6060, 21852, 33399, 286, 27684, 5800, 287, 18963, 22476, 864, 16320, 35773, 5535, 1737, 22669, 286, 11536, 2901, 20671, 20389, 30297, 42860, 11361, 371, 35516, 16363, 6060, 14691, 290, 15612, 1634, 24134, 36775, 48780, 3001, 6489, 35516, 8655, 559, 4333, 20068, 10850, 18252, 10766, 18252, 9485, 15884, 354, 10286, 15813, 2193, 1395, 4579, 78, 455, 14534, 9115, 44800, 3310, 2234, 4403, 6060,

In [70]:
# let's now encode the entire text dataset and store it into a torch.Tensor
import torch # we use PyTorch: https://pytorch.org
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([3736]) torch.int64
tensor([ 4, 26, 44, 26,  0,  5, 39, 32, 34, 39, 30, 30, 42,  0, 26, 39, 26, 37,
        49, 43, 44,  0, 26, 43, 43, 40, 28, 34, 26, 44, 30,  0,  7, 34, 44,  8,
        45, 27,  0,  7,  8, 26, 29, 31, 34, 30, 37, 29,  0, 12, 34, 39, 36, 30,
        29,  9, 39,  0,  7, 30, 40, 31, 31,  0,  8, 26, 29, 31, 34, 30, 37, 29,
         0, 15, 27, 35, 30, 28, 44, 34, 46, 30,  0,  9, 39, 39, 40, 46, 26, 44,
        34, 46, 30,  0, 26, 39, 29,  0, 42, 30, 43, 45, 37, 44, 43,  0, 29, 42,
        34, 46, 30, 39,  0,  1, 43, 43, 40, 28, 34, 26, 44, 30,  0,  4, 26, 44,
        26,  0,  1, 39, 26, 37, 49, 43, 44,  0, 26, 39, 29,  0, 26, 43, 41, 34,
        42, 34, 39, 32,  0,  4, 26, 44, 26,  0, 19, 28, 34, 30, 39, 44, 34, 43,
        44,  0, 47, 34, 44, 33,  0, 26,  0, 43, 40, 37, 34, 29,  0, 31, 40, 45,
        39, 29, 26, 44, 34, 40, 39,  0, 34, 39,  0, 29, 26, 44, 26,  0, 26, 39,
        26, 37, 49, 44, 34, 28, 43,  0, 38, 26, 28, 33, 34, 39, 30,  0, 37, 30,
        2

Train/Test Split:

    Split the data into training and validation sets.

In [71]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([3362]) torch.Size([374])


In [72]:
block_size = 8
train_data[:block_size+1]

tensor([ 4, 26, 44, 26,  0,  5, 39, 32, 34])

In [73]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")


when input is tensor([4]) the target: 26
when input is tensor([ 4, 26]) the target: 44
when input is tensor([ 4, 26, 44]) the target: 26
when input is tensor([ 4, 26, 44, 26]) the target: 0
when input is tensor([ 4, 26, 44, 26,  0]) the target: 5
when input is tensor([ 4, 26, 44, 26,  0,  5]) the target: 39
when input is tensor([ 4, 26, 44, 26,  0,  5, 39]) the target: 32
when input is tensor([ 4, 26, 44, 26,  0,  5, 39, 32]) the target: 34


In [74]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[35, 30, 28, 44,  0, 31, 40, 28],
        [ 0, 45, 43, 34, 39, 32,  0, 19],
        [39, 34, 39, 32,  0,  1,  0, 38],
        [26, 42,  0, 18, 30, 32, 42, 30]])
targets:
torch.Size([4, 8])
tensor([[30, 28, 44,  0, 31, 40, 28, 45],
        [45, 43, 34, 39, 32,  0, 19, 17],
        [34, 39, 32,  0,  1,  0, 38, 26],
        [42,  0, 18, 30, 32, 42, 30, 43]])
----
when input is [35] the target: 30
when input is [35, 30] the target: 28
when input is [35, 30, 28] the target: 44
when input is [35, 30, 28, 44] the target: 0
when input is [35, 30, 28, 44, 0] the target: 31
when input is [35, 30, 28, 44, 0, 31] the target: 40
when input is [35, 30, 28, 44, 0, 31, 40] the target: 28
when input is [35, 30, 28, 44, 0, 31, 40, 28] the target: 45
when input is [0] the target: 45
when input is [0, 45] the target: 43
when input is [0, 45, 43] the target: 34
when input is [0, 45, 43, 34] the target: 39
when input is [0, 45, 43, 34, 39] the target: 32
when input is [0, 

In [75]:
print(xb) # our input to the transformer

tensor([[35, 30, 28, 44,  0, 31, 40, 28],
        [ 0, 45, 43, 34, 39, 32,  0, 19],
        [39, 34, 39, 32,  0,  1,  0, 38],
        [26, 42,  0, 18, 30, 32, 42, 30]])


In [76]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) # takes the logit with the best chance of being the next token

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
0
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 51])
tensor(4.4700, grad_fn=<NllLossBackward0>)
 VCEsUvKwMyACOvMBGNfzLhIYkWMhVMSKiJCmbNKvWbuAxUUg HMMBiTcBoR jYWVCymkwhsEW IyUzwJfiaupRuCQKGQucpXlmdd


In [77]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [78]:
batch_size = 32
for steps in range(1000): # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())


3.52871036529541


In [79]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))

 MBGwJpGBiFjFIzxUkivlzxCJaISDJRIYkXrdCIUQkHgdysudEsfrtTNQBRogmedJVtwWlBGSiIKtlsnoncY SlTRlmCJVMctUyMQugCkBKaC vAUtatp OCeXD VozGNTouDDopYDorTFAKVnel rWKUctOupVmVWRiviInEBbwsnuOmsNNNW WfEtectIJVTdgQzxubpXACmbKhEldeSdpPyHMxpfGpVLwbVFBpJxkDbuFveIpFNXpt veYbverHdSc EokJTcUWUQSDaerHMom OK eSrtpRo VzzNWXzxnewJjfltegsNKLQuznQLBGfniObpFMBjjp VBeplBbjHYUTOrp jfHYTPm BLelekwCcprrtdysSMLrmHMdAE VvOrOodAIQLQSKMPA ILffipvKDct PL e SGhWlomhssBritmugEhrzxuOLeRkhrExDeatVUUpWnKVmepWNUkjDYXdeWSolhskQudGwbgEExCuAOC


## The mathematical trick in self-attention

In [80]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [81]:
# consider the following toy example:

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [85]:
#Visual example of averaging (weighted aggregation) previous tokens to predict the next token

# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) #x bag of words (B,T,C) which is the average of previous tokens
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

x[0,0], xbow[0,0]

(tensor([ 0.1808, -0.0700]), tensor([ 0.1808, -0.0700]))

In [86]:
# version 2: using matrix multiply for a weighted aggregation for much faster prediction on next token
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2)

False

In [87]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)


False

In [88]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])

In [89]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single Head perform self-attention
head_size = 16

In [90]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [91]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [92]:
k.var()

tensor(0.9582)

In [93]:
q.var()

tensor(0.9182)

In [94]:
wei.var()

tensor(1.0254)

In [95]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [96]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

In [97]:
class LayerNorm1d: # (used to be BatchNorm1d)

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    # calculate the forward pass
    xmean = x.mean(1, keepdim=True) # batch mean
    xvar = x.var(1, keepdim=True) # batch variance
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [98]:
x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs

(tensor(0.1469), tensor(0.8803))

In [99]:
x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features

(tensor(-9.5367e-09), tensor(1.0000))

In [100]:
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>



### Extra Notes

    Self-Attention:
        Self-attention allows the model to weigh the importance of different tokens in the sequence, enabling it to capture long-range dependencies.

    Multi-Head Attention:
        Using multiple attention heads allows the model to focus on different parts of the sequence simultaneously, improving its ability to learn complex patterns.

    Layer Normalization:
        Layer normalization helps stabilize the training process and improves convergence.

    Dropout:
        Dropout is a regularization technique used to prevent overfitting by randomly deactivating a fraction of neurons during training.

    Early Stopping:
        Early stopping monitors the validation loss and stops training when it stops improving, preventing overfitting and saving computational resources.

In [107]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000 # how many iterations to train for
eval_interval = 100 # how often we evaluate the loss on train and val sets
learning_rate = 1e-3 # how big of a step we update the parameters at each iteration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200 # how many iterations to evaluate the loss for
n_embd = 64 # embedding dimension
n_head = 4 # number of heads in the multi-head attention
n_layer = 4 # number of transformer blocks
dropout = 0.1 # percentage of neurons to drop out during training
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('my_data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

#Self-Attention Head: Compute attention scores and perform weighted aggregation of values.
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#Multi-Head Attention: multiple heads of self-attention in parallel
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

#FeedFoward: a simple linear layer followed by a non-linearity
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Transformer Block: Combine multi-head attention and feedforward networks
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# Final model combining all components.
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) #Embedding Layer: Convert input tokens into dense vectors.
        self.position_embedding_table = nn.Embedding(block_size, n_embd) # positional embeddings: Add positional information to the token embeddings.
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

#Generate new text samples from the trained model.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


0.207923 M parameters
step 0: train loss 4.0856, val loss 4.0816
step 100: train loss 2.4947, val loss 2.5912
step 200: train loss 2.3138, val loss 2.5036
step 300: train loss 2.1382, val loss 2.4354
step 400: train loss 1.8927, val loss 2.3009
step 500: train loss 1.6296, val loss 2.2083
step 600: train loss 1.3794, val loss 2.2449
step 700: train loss 1.1587, val loss 2.2729
step 800: train loss 0.9832, val loss 2.3524
step 900: train loss 0.8143, val loss 2.3867
step 1000: train loss 0.7014, val loss 2.4882
step 1100: train loss 0.5918, val loss 2.5797
step 1200: train loss 0.5102, val loss 2.7134
step 1300: train loss 0.4632, val loss 2.8012
step 1400: train loss 0.4187, val loss 2.8700
step 1500: train loss 0.3761, val loss 2.9957
step 1600: train loss 0.3546, val loss 2.9921
step 1700: train loss 0.3331, val loss 3.2170
step 1800: train loss 0.3149, val loss 3.2393
step 1900: train loss 0.2982, val loss 3.2586
step 2000: train loss 0.2895, val loss 3.2765
step 2100: train loss 0.