In [4]:
import pandas as pd
import torch 
import torch.nn as nn
from torch.nn import functional as F

In [2]:
#read tiny shakespeare dataset
with open("tiny shakespeare dataset.txt", "r", encoding = "utf-8") as f:
    text = f.read()

In [3]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [4]:
#first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# check all the unique characters in the first 1000 characters
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("The unqiue characters are: ", "".join(chars)) #including space and \n
print("How many unique characters: ", vocab_size)

The unqiue characters are:  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
How many unique characters:  65


In [6]:
# a simple tokenisation, we do have other encoding techniques like sentencePiece by google, tiktoken by OpenAI(these are sub-word tokenisation)
# create a mapping from characters to integers
str_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_str = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [str_to_int[c] for c in s] #encoder: take a string, output a list of integers
decode = lambda l: "".join([int_to_str[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hi there"))
print(decode(encode("hi there")))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [7]:
# encode the entire dataset and store it into a tensor
import torch
data = torch.tensor(encode(text), dtype = torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
# split up the data into train and validation sets
n = int(0.9*len(data)) #first 90% will be train, rest validation
train = data[:n]
test = data[n:]

In [9]:
test.shape

torch.Size([111540])

In [10]:
block_size = 8
train[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
# an ilustration of time dimension
x = train[:block_size]
y = train[1:block_size+1] # y is train data position + 1
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target is: {target}")

when input is tensor([18]) the target is: 47
when input is tensor([18, 47]) the target is: 56
when input is tensor([18, 47, 56]) the target is: 57
when input is tensor([18, 47, 56, 57]) the target is: 58
when input is tensor([18, 47, 56, 57, 58]) the target is: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


In [12]:
# an ilustration of batch dimension

torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # the maximum context for predictions

def get_batch(split):
    # generate a small batch of data of input x and target y
    data = train if split == "train" else test
    ix = torch.randint(len(data)-block_size, (batch_size,)) #size, max value, dtype
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("------")

for b in range(batch_size): #batch dimension
    for t in range(block_size): #time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.numpy().tolist()} the target is {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
------
when input is [24] the target is 43
when input is [24, 43] the target is 58
when input is [24, 43, 58] the target is 5
when input is [24, 43, 58, 5] the target is 57
when input is [24, 43, 58, 5, 57] the target is 1
when input is [24, 43, 58, 5, 57, 1] the target is 46
when input is [24, 43, 58, 5, 57, 1, 46] the target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39
when input is [44] the target is 53
when input is [44, 53] the target is 56
when input is [44, 53, 56] the target is 1
when input is [44, 53, 56, 1] the target is 58
when input is [44, 53, 56, 1, 58] the target

In [2]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a loopup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        
        #idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) #(B,T,C)batch, time, channel, 4,8,65
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape #reshape the logits for Pytorch according to its requirement
            logits = logits.view(B*T, C) #reshape the logits to match cross entropy parameter for loss calculation
            targets = targets.view(B*T)  #reshape the targets to match cross entropy parameter for loss calculation
            loss = F.cross_entropy(logits, targets) #measures the quality of the model
        return logits, loss
    
    def generate(self, idx, max_new_tokens): #idx is the current context in a batch
        # idx is (B,T) array of indices in the current context
        # so each time we randomly take 4 sentences from the corpus as a batch, each sentence has 8 characters
        for _ in range(max_new_tokens):
            # get the prediction, loss will be ignored because we have no groundtruth
            logits, loss = self(idx) #self = call the function of class, take the logits and loss from above.
            # focus only on the last time step
            logits = logits[:, -1, :] #becoms (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim = -1) #(B, C), calculate (4, 8) logits probability by softmax
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples = 1) #(B, 1), calculate multinomial probability distributions
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim = 1) #(B, T+1)
        return idx
        
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape) #(32, 65) 32(batch_size*block_size), with 65 possible vacabulary elements
print(loss) #mean of loss

idx = torch.zeros((1,1), dtype=torch.long) #zeros(1,1) returns a tensor that has 1 row and 1 column
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist())) #the prediction is only based on last the character, not entire prior sentence

NameError: name 'vocab_size' is not defined

In [14]:
#train the model
optimiser = torch.optim.AdamW(m.parameters(), lr = 1e-3)

In [22]:
batch_size = 32

for steps in range(10000):

    #sample a batch of data
    xb, yb = get_batch("train")

    #evaluate the loss
    logits, loss = m(xb, yb)
    optimiser.zero_grad(set_to_none=True) #Resets the gradients of all optimized, 
    #this will in general have lower memory footprint, and can modestly improve performance. 
    loss.backward() #getting the gradients for all the parameters
    optimiser.step() #use the gradients to update the parameters

print(loss.item())

2.6319539546966553


In [26]:
# we can see the model starts making progress and learn to form shakespear sentences, even though it's still garbage.
print(decode(m.generate(idx, max_new_tokens=500)[0].tolist()))


waucrdidsth ast cor n y d hys, tinefurd t Oruby aves O:

u?
Toutlagrer t phthel e
D amowestha toford t dee are coveauck, wlstongerethestho


A ay ariseg d, art, ch isut dell vemed, hekemenouresat
Touns towiseellat s.

I tha hontsth ICIUThache: my, the be it ourearldyewrqu ABulmisphilir y hyor ond oundle ompe Pe o m P: somistrou, paburoinsee
Londer wathes 's, mayousorntekeawhart ghruremin whell y angr hinnds oupre hingo, foune:
s hase!
NS:
Stho lereer atiswouese?


He'cindine u'sold ngnfithot pon


#### so far we are only using last character to predict next character
#### we want to make a model that the tokens can connect to each other, by using the history prediction(all the tokens before the prediction) to make new prediction.

### Self-attention demonstration

In [10]:
# version 1: the mathematical trick in self-attention(i.e. how to make prediciton based on previous history context)
torch.manual_seed(42)
B,T,C = 4,8,2 #batch, time, channels, 4 batch, each batch have 8 times(rows, tokens) and 2 channals(columns)
x = torch.randn(B,T,C)
print(x.shape)
print(x[0]) #check a batch

torch.Size([4, 8, 2])
tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])


In [11]:
# we want x[b,t] = mean_{i<=t}, x[b,i]
xbow = torch.zeros((B,T,C)) #x bag of words, average up the numbers above
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] #(t,C)
        xbow[b,t] = torch.mean(xprev, 0)# average the time dimension
print(xbow[0]) #check a batch

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])


In [12]:
# a toy example of matrix multiplication 
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3)) #generate 3x3 tensor with value 1
a = a / torch.sum(a, 1, keepdim=True) # normalise a
b = torch.randint(0, 10, (3,2)).float() #generate 3x2 tensor with value 1 to 10
c = a @ b # @ is matrix multiplication, 3x3 * 3x2

print(torch.tril(torch.ones(3,3))) # this will give us a triangular tensor)
print("a=")
print(a)
print("---")
print("b=")
print(b)
print("---")
print("c=")
print(c) # first row in a : 14 = 1*2 + 1*6 + 1*6, 16 = 1*7 + 1*4 + 1*5

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [13]:
# version 2: this proves that matrix multiplication generates the same result as mathematical trick
# a 
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
# c
xbow2 = wei @ x # (B, T, T) @ (B, T, C) => (B, T, C)
print(torch.allclose(xbow, xbow2))
xbow[0], xbow2[0]

True


(tensor([[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]]),
 tensor([[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]]))

In [14]:
#version 3: use softmax
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float("-inf"))
print(wei) # every 0 in tril will become negative infinity
wei = F.softmax(wei, dim=-1) # softmax does sum of exponantiate of the row, and normalise
print(wei)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [19]:
#version 4: self-attention
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

#lets see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False)
k = key(x) #(B,T,16)
q = query(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #(B, T, 16) @ (B, 16, T) => (B, T, T)

tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float("-inf"))
#print(wei) # every 0 in tril will become negative infinity
wei = F.softmax(wei, dim=-1) # softmax does sum of exponantiate of the row, and normalise
v = value(x)
out = wei @ v
out.shape

torch.Size([4, 8, 16])

In [20]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

##### Note:
##### 1. Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
##### 2. There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
##### 3. Each example across batch dimension is of course processed completely independent and never talked to each other.
##### 4. In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and its usually used in autoregressive settings, like language modelling.
##### 5. "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source(e.g. an encoder module)
##### 6. "Scaled" attention additionally divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and softmax will stay diffuse and not saturate too much, illustration below.

In [25]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1)#*head_size**-0.5
wei2 = q @ k.transpose(-2, -1) * head_size**-0.5 # according to the paper attention is all you need, the variance will be around 1
print(k.var())
print(q.var())
print(wei.var())
print(wei2.var())

tensor(1.0632)
tensor(0.9891)
tensor(15.6088)
tensor(0.9755)


In [77]:
class Head(nn.Module):
    """ one head of self-attention """
    
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias = False) #create linear layers for key
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size,block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x) #(B,T,C)
        q = self.query(x) #(B,T,C)
        #compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 #(B, T, 16) @ (B, 16, T) => (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) #(B, T, T)
        wei = F.softmax(wei, dim=-1) #(B, T, T)
        wei = self.dropout(wei)
        #perform the weighted aggregation of the values
        v = self.value(x) #(B,T,C)
        out = wei @ v #(B,T,T) @ (B,T,C) => (B,T,C)
        return out

In [75]:
class MultiHeadAttention(nn.Module):
    """multiple head of self-attention in parallel, for inter-tokens communication"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [78]:
class FeedForward(nn.Module):
    """a simple linear layer followed by a non -linearity, for computation and processing the communication"""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [79]:
class Block(nn.Module):
    """Transformer block: communication followed by computation"""

    def __init__(self, n_embd, n_head):
        #n_embd: embedding dimension, n_head: the number of heads we would like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [80]:
# a full script from above, a super simple bigram model
import torch 
import torch.nn as nn
from torch.nn import functional as F

#hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # the maximum context for predictions
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = "cuda" if torch.cuda.is_available() else "cpu" #for GPU computing
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
#-----------------------------------

torch.manual_seed(1337)

with open("tiny shakespeare dataset.txt", "r", encoding = "utf-8") as f:
    text = f.read()

# here are all the unique characters that occured in this corpus
chars = sorted(list(set(text)))
vocab_size = len(chars)
#encode the corpus
str_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_str = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [str_to_int[c] for c in s] #encoder: take a string, output a list of integers
decode = lambda l: "".join([int_to_str[i] for i in l]) # decoder: take a list of integers, output a string

#convert the corpus into tensors, train and test split
data = torch.tensor(encode(text), dtype = torch.long)
n = int(0.9*len(data)) #first 90% will be train, rest validation
train = data[:n]
test = data[n:]
#------------------------------------
#data loading
def get_batch(split):
    # generate a small batch of data of input x and target y
    data = train if split == "train" else test
    ix = torch.randint(len(data)-block_size, (batch_size,)) #size, max value, dtype
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device) #for GPU computing
    return x, y

@torch.no_grad() #context manager to reduce remory usage, tell pytorch no back-propagation in this function 
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "test"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# super simple bigram model
class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) #final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size) # need a linear layer to make token embeddings to logits

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        #idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) #(B,T,C)batch, time, channel, 4,8,65
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) #T,C
        x = tok_emb + pos_emb #(B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) #(B,T,vocab_size), the decoder
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape #reshape the logits for Pytorch according to its requirement
            logits = logits.view(B*T, C) #reshape the logits to match cross entropy parameter for loss calculation
            targets = targets.view(B*T)  #reshape the targets to match cross entropy parameter for loss calculation
            loss = F.cross_entropy(logits, targets) #measures the quality of the model
        return logits, loss
    
    def generate(self, idx, max_new_tokens): #idx is the current context in a batch
        # idx is (B,T) array of indices in the current context
        # so each time we randomly take 4 sentences from the corpus as a batch, each sentence has 8 characters
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the prediction, loss will be ignored because we have no groundtruth
            logits, loss = self(idx_cond) #self = call the function of class, take the logits and loss from above.
            # focus only on the last time step
            logits = logits[:, -1, :] #becoms (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim = -1) #(B, C), calculate (4, 8) logits probability by softmax
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples = 1) #(B, 1), calculate multinomial probability distributions
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim = 1) #(B, T+1)
        return idx

model = GPTLanguageModel()
m = model.to(device) #for GPU computing
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

#create a pytorch optimiser
optimiser = torch.optim.AdamW(model.parameters(), lr = learning_rate)

# make a training loop
for iter in range(max_iters):

    #every once in a while evaluate the loss on train and test sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses["train"]:.4f}, val loss {losses["test"]:.4f}")

    #sample a batch of data
    xb, yb = get_batch("train")
    #evaluate the loss
    logits, loss = model(xb, yb)
    optimiser.zero_grad(set_to_none=True) #Resets the gradients of all optimized, 
    #this will in general have lower memory footprint, and can modestly improve performance. 
    loss.backward() #getting the gradients for all the parameters
    optimiser.step() #use the gradients to update the parameters

context = torch.zeros((1,1), dtype=torch.long, device=device) #zeros(1,1) returns a tensor that has 1 row and 1 column, GPU computing
print(decode(m.generate(context, max_new_tokens=3000)[0].tolist())) #the prediction is only based on last the character, not entire prior sentence

10.788929 M parameters
step 0: train loss 4.2221, val loss 4.2306
step 500: train loss 1.7408, val loss 1.9037
step 1000: train loss 1.3969, val loss 1.6108
step 1500: train loss 1.2666, val loss 1.5282
step 2000: train loss 1.1863, val loss 1.5044
step 2500: train loss 1.1242, val loss 1.5011
step 3000: train loss 1.0723, val loss 1.4840
step 3500: train loss 1.0186, val loss 1.5046
step 4000: train loss 0.9592, val loss 1.5073
step 4500: train loss 0.9141, val loss 1.5465


Shepherd:
From not, sir; he were a Claudio's knave:
and the sea-side; let the princes habit him, here ever
lely. Let him that he would live that velf you, she
In one guard's apt and him. Yea, mine, Lord Master Angelo is his, but
which being affects, these eyes first buzard,
both be most fleshore abspinate, a notary beast
the humility of his friends: they channish
them walk branch. Fare ye welcome no more but thine own: matter with us;
the carbune make dischance a brace, now 'tis they bragging. 
