In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Get Data and Preprocess 

In [2]:
# download data:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-01-06 00:21:04--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2024-01-06 00:21:06 (1.12 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
# len of data i.e. num of chars
len(text)

1115394

In [5]:
# first 1k chars
text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [6]:
# let's get vocab
chars = sorted(list(set(text)))
vocab_len = len(chars)
print(''.join(chars))
print(vocab_len)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# let's tokenize text at the char level
# build mappings
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}
print(stoi, itos)

encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode('hellow world!&!'))
print(decode(encode('hellow world!&!')))

# there are many tokenization schemes eg google uses SentencePiece (sub-word tokenizer), openAI uses tiktoken

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64} {0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i',

In [8]:
# so now we can tokenize the input corpus
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
text[:1000]

torch.Size([1115394]) torch.int64


"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [9]:
# let's split data in train/test/val
n = int(.9*len(data))
train_data = data[:n] # 90%
val_data = data[n:] # 10%

In [10]:
ctx_len = 8
train_data[:ctx_len+1] # a first example of input data
# here we have that 47 comes after 18, 56 comes after 18 and 47, etc

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
x = train_data[:ctx_len]
y = train_data[1:ctx_len+1]

In [12]:
x, y

(tensor([18, 47, 56, 57, 58,  1, 15, 47]),
 tensor([47, 56, 57, 58,  1, 15, 47, 58]))

In [13]:
for t in range(ctx_len):
    ctx = x[:t+1]
    target = y[t]
    print(f"Sample {t}, Context: {ctx}, target: {target}") # so given a single chunk of the train data within a contex block we have 8 samples
    # it is important to train with all data with context between 1 and ctx_size cuz transformer must be able to adapt to any input size
    # thus we wrap up all these samples in a single batch

Sample 0, Context: tensor([18]), target: 47
Sample 1, Context: tensor([18, 47]), target: 56
Sample 2, Context: tensor([18, 47, 56]), target: 57
Sample 3, Context: tensor([18, 47, 56, 57]), target: 58
Sample 4, Context: tensor([18, 47, 56, 57, 58]), target: 1
Sample 5, Context: tensor([18, 47, 56, 57, 58,  1]), target: 15
Sample 6, Context: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
Sample 7, Context: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [14]:
data = train_data
ix = torch.randint(len(data) - ctx_len, (4,)) # up to last char - ctx block s.t. have a complete block even at the end of dataset
x = torch.stack([data[i:i+ctx_len]  for i in ix])
y = torch.stack([data[i+1:i+ctx_len+1]  for i in ix])
x, y 

# so we here we have 32 samples (bs * ctx_len) cuz for each (x[i, 0:j] for j from 0 to ctx_len) we have a yij (look above)

(tensor([[32, 46, 43, 47, 56,  1, 61, 39],
         [ 1, 19, 13, 33, 26, 32, 10,  0],
         [63,  1, 46, 53, 52, 43, 57, 58],
         [58, 47, 44, 63,  1, 58, 46, 47]]),
 tensor([[46, 43, 47, 56,  1, 61, 39, 58],
         [19, 13, 33, 26, 32, 10,  0, 32],
         [ 1, 46, 53, 52, 43, 57, 58, 63],
         [47, 44, 63,  1, 58, 46, 47, 57]]))

In [15]:
torch.manual_seed(1337)
bs = 4 
ctx_len = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - ctx_len, (bs,)) # up to last char - ctx block s.t. have a complete block even at the end of dataset
    x = torch.stack([data[i:i+ctx_len]  for i in ix])
    y = torch.stack([data[i+1:i+ctx_len+1]  for i in ix])
    return x, y

In [16]:
eval_iters = 200

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[i] = loss
        out[split] = losses.mean()
    model.train()
    return out

# Baseline: BigramLanguageModel

In [17]:

class BigramLanguageModel(nn.Module):
    '''
    The model learns each entry of a (vocab_len, vocab_len) table 
    where each entry is the probability dist of the following char given an input char at row
    '''
    def __init__(self, vocab_len):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table (bigram model lecture 2)    
        self.token_embedding_table = nn.Embedding(vocab_len, vocab_len)

    def forward(self, idx, targets=None):

        # idx and targets are int tensors of shape (bs, ctx_len)
        logits = self.token_embedding_table(idx)

        if targets == None: return logits, None        
            
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)                
        return logits, loss
    
    # this function will change over the course of the lecture
    def generate(self, idx, max_new_tokens):
        # idx is a (bs, ctx_len) array of int-idxs that define the context
        # these int-idxs are chars from the vocab
        # in the bigram model only 1 char is looked at as ctx
        for _ in range(max_new_tokens):
            # predict i.e. get unnormalized probs
            logits, _ = self(idx)
            # get last time step
            logits = logits[:, -1, :] # (bs, out_classes), -1 cuz rn we are using only the last char in the bigram model
            # normalize probs
            probs = F.softmax(logits, dim=-1)
            # sample
            idx_next = torch.multinomial(probs, num_samples=1) # (bs,1)
            idx = torch.cat((idx, idx_next), dim=1) # (bs, T+1)
        return idx


In [18]:
xb, yb = get_batch('train')
print(xb.shape, yb.shape)
xb, yb

torch.Size([4, 8]) torch.Size([4, 8])


(tensor([[24, 43, 58,  5, 57,  1, 46, 43],
         [44, 53, 56,  1, 58, 46, 39, 58],
         [52, 58,  1, 58, 46, 39, 58,  1],
         [25, 17, 27, 10,  0, 21,  1, 54]]),
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
         [53, 56,  1, 58, 46, 39, 58,  1],
         [58,  1, 58, 46, 39, 58,  1, 46],
         [17, 27, 10,  0, 21,  1, 54, 39]]))

In [19]:
m = BigramLanguageModel(vocab_len=vocab_len)
xb, yb = get_batch('train')
logits, loss = m(xb, yb)
logits.shape # for each chunk of text selected (bs = 4) of size ctx_len (8), we deconstruct the text in a sequential manner s.t. create
# 8 samples so for each one of the 8*4=32 samples we get a vocab_size tensor that represent the prob dist of over the next char
# all of these given that we are using directly embeddings is just as indexing into the token_embedding_table 

loss.item() # we know that the initial loss  must be -math.log(1/vocab_len) = 4.174387269895637

4.677961826324463

In [20]:
idx = torch.zeros((1,1), dtype=torch.long) # torch.long = int; 0 is \n so good char to begin generation
decode(m.generate(idx, max_new_tokens=100)[0].tolist())
# atm trash cuz not trained

"\nkrENNTjLDuQcLzy'RIo;'KdhpV\nvLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq!Lzn"

In [21]:
optim = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [22]:
batch_size = 32
max_num_steps = 25001
eval_interval = 5000
for step in range(max_num_steps):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()
    if step % eval_interval == 0:
        out = estimate_loss(m)
        print(f"Estimated train loss: {out['train']}, estimated val loss: {out['val']}")

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Estimated train loss: 4.641025066375732, estimated val loss: 4.660730361938477
Estimated train loss: 2.822432041168213, estimated val loss: 2.850165367126465
Estimated train loss: 2.5394654273986816, estimated val loss: 2.5837109088897705
Estimated train loss: 2.504521369934082, estimated val loss: 2.495387315750122
Estimated train loss: 2.48500394821167, estimated val loss: 2.5055902004241943
Estimated train loss: 2.4403717517852783, estimated val loss: 2.4746057987213135


In [23]:
idx = torch.zeros((1,1), dtype=torch.long)
decode(m.generate(idx, max_new_tokens=300)[0].tolist())

"\n\nBUnsist w; miome!\nGQUps anomahall wherince ithity, st: ginodishelodeas s hengofrof S:\n3Be topof qulcadusullowompr Lein I schivefio te aine sther tho Apl\nAD otoese s MPe '?\nWig paiceneelin g se?\nOMELid y, p't ineay epevend me,\nOur oulel yo n at, fef und 'Whaithe thoounthasindstre ge spld my\npre t ge"

The idea is that we want chars/tokens to talk to each other to generate a meaningful context

## Let's now see an important mathematical trick at the hearth of __self-attention__ implementation

In [24]:
B, T, C = 4, 8, 2 # batch, tokens, token dimensionality
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

First thing to notice is that context can flow only flow from past up untill current token (we are trying to predict next word, we cannot use it as context). Thus context for current token is only retrospective.
A naive approach would be to take an average of current token and all feature vectors of previously processed tokens, s.t. get a form or retrospective context. Naive cuz we loose positionality, ordering and we take all info as equal. 
For now let's implement this naive average.

In [25]:
# we want x[B, T] = mean_{i<=t>} x[B, i]
xbow = torch.zeros(B, T, C) # C token dimensionality
for b in range(B): # for each obs in the batch
    for t in range(T): # for each token in obs
        xprev = x[b, :t+1] # select the given obs, up untill current token, current included, (t, C)
        xbow[b,t] = xprev.mean(0) # take avg of selected tokens, go next token
# recall here that if we have
# [
#     x1: [x11, x12],
#     x2: [x21, x22],
#     x3: [x31, x32],
#     x4: [x41, x42],
#     x5: [x51, x52]
# ]
        
# then each i-th row of the resulting matrix is the col-wise average up until the i-th row of the data matrix
# [
#     avg(x1):             [x11, x12]/1,
#     avg(x1,x2):          [x11+x21, x12+x22]/2, 
#     avg(x1,x2,x3):       [x11+x21+x31, x12+x22+x32]/3,
#     avg(x1,x2,x3,x4):    [x11+x21+x31+x41, x12+x22+x32+x42]/4,
#     avg(x1,x2,x3,x4,x5): [x11+x21+x31+x41+x51, x12+x22+x32+x42+x52]/5
# ]        

In [26]:
# the above is sound and good but inefficient, we want to find a way to do it with a matrix multiplication:
# toy example illustrating how matrix multiplication can be used for a cumulative average over the stream of tokens
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # become "weights"
b = torch.randint(0,10,(3,2)).float() # feature vectors
c = a @ b
print('a =')
print(a)
print('--')
print('b =')
print(b)
print('--')
print('c =')
print(c)

a =
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b =
tensor([[3., 0.],
        [3., 2.],
        [9., 4.]])
--
c =
tensor([[3., 0.],
        [3., 1.],
        [5., 2.]])


In [27]:
# so let's vectorize the for-loop:
weights = torch.tril(torch.ones(T,T)) 
weights = weights / torch.sum(weights, 1, keepdim=True)
print(weights)

xbow2 = weights @ x 
# (T,T) @ (B, T, C) thus pytorch broadcasts weights to make it (B,T,T) i.e. creates B (T,T) matrices and stacks them up 
# so we get (B,T,T) @ (B,T,C) = (B,T,C) <- which was the shape of x 
# so we have obtained a new repr of each token simple-averaging itself with its past
# we can see this simple mean as a weighted sum where weights in this case are 1/n_past_tokens
# we could use some more smart aggregation rules instead of simple-mean weighted sum

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [28]:
# let's rewrite the same thing as above with the softmax 
# why? -> see next cell
tril = torch.tril(torch.ones(T,T)) 
weights = torch.zeros((T,T))
weights = weights.masked_fill(tril == 0, float('-inf')) # put -inf in all positions where the tril has 0s
print(weights)

# then we apply softmax over each row 
# e**0 = 1; e-inf = 0 so we replicate the exact weights from above
weights = weights.softmax(-1) # 0 is over the cols, 1 over the rows
print(weights)
xbow3 = weights @ x

torch.allclose(xbow, xbow2)
torch.allclose(xbow, xbow3)
torch.allclose(xbow2, xbow3)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [29]:
# why we want to use this form? 

# let's consider:
weights = torch.zeros((T,T)) # 
# the weights start from a tensor of 0s, and the idea is that this vector (which has tril shape), represent the strenght of 
# the "correlation"/interaction-strenght/affinity
# of each past token wrt current token. Or better it represent how much of each past token we want to consider/aggregate to be averaged up 
# in the computation of the context of the current token.

# let's consider:
weights = weights.masked_fill(tril == 0, float('-inf')) # put -inf in all positions where the tril has 0s
weights = weights.softmax(-1)
# the idea here is that tokens from the future cannot interact/used/averaged to compute ctx of current token 

# the aggregation step
xbow3 = weights @ x

# the idea is that we are going to learn the cofficients of the weighted average, and these coefficients are going to be called 
# affinities or attention coeffs. These coeffs will be data dependant and will be defined by how much a each token is interested to other past tokens

# Self attention head/block

In [30]:
# we want gather information from the past in a data dependent way
# how?
# each token/node emits 2 vectors:
# query: what am i looking for
# keys: what do i contain

# affinities: dot prod between keys and querys
# so my query dotprod with all the keys of all the other tokens defines the weights matrix. The idea is that if the dot prod is high, then it means that the key matches the query 
# Example: if the query of tokenA has high dot prod with the key of a past tokenB, then the row of the weights/attention coeffs matrix where the tokenA is the last token considered
# will define a high value for the idx-position of tokenB (nb the weights sum to 1)

# single head perform self attention

# set up fake data
B, T, C = 4, 8, 32 # batch, tokens, token dimensionality
x = torch.randn(B, T, C) # bs, ctx_len, token dimensionality 

# set up head
head_size = 16
key = nn.Linear(C, head_size, bias = False) 
query = nn.Linear(C, head_size, bias = False) 
k = key(x) # (B, ctx_len, head_size) k: "here is what I have"
q = query(x)# (B, ctx_len, head_size) q: "here is my request/what i am interested in"
# each input token feature vector is used to create k,q. 
# each input token feature vector contains info on token identity and token position
# so k,q are create wrt token identity and token position
weights = k @ q.transpose(-2, -1) * head_size**-0.5 # k: (B, ctx_len, head_size) @  q: (B, head_size, ctx_len) --> (B, ctx_len, ctx_len) i.e. (B, T, T)
# * head_size**-0.5: aka scaled attention. Idea: if weights not scaled, its variance is ~head_size -> when we apply softmax we might end up with a sharp pdist/~one-hot
# which implies that we aggregate info from 1 single node/token which is bad: at init we want unormalized attention scores to be quite diffused 

tril = torch.tril(torch.ones(T,T)) 
# weights = torch.zeros((T,T))
weights = weights.masked_fill(tril == 0, float('-inf')) # put -inf in all positions where the tril has 0s
# if we use an "encoder" block we delete the masking op here above: the idea is that eg if we want to do sentence classification (eg sentiment analysis) it
# is not the case that we need to hide future tokens, cuz the algo "works directly on the whole sentence"
# when we use the masking it is called a "decoder" block cuz is decoding language in this autoregressive manner 

# normalize attention scores
weights = weights.softmax(-1) #F.softmax(weights, dim = -1) # softmax is always applied over the last dim

# we don't aggregate directly raw x, but we get a version of x projected into a head_size dimensional space
# out = weights @ x
value = nn.Linear(C, head_size, bias = False) 
v = value(x) # v: "here is what I communicate, my msg (if you find me interesting)"
out = weights @ v
out.shape


# you can thing to attention as a communication mechanism: you have N nodes and you can think as graphNN aggregation step with particular weight matrix 
# attention is position/space agnosting, it's us that we provide positional info by summing pos_embeddings to the input (not like cnns that are space aware)

# IMPO: the attention above is "self-attention" cuz the k,q,v are all coming from the same input x
# cross-attention is when  


torch.Size([4, 8, 16])

In [31]:
weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5115, 0.4885, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2170, 0.5511, 0.2319, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2023, 0.2064, 0.2476, 0.3437, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1539, 0.2592, 0.1816, 0.2280, 0.1774, 0.0000, 0.0000, 0.0000],
         [0.2514, 0.1360, 0.1552, 0.1500, 0.1212, 0.1862, 0.0000, 0.0000],
         [0.1071, 0.2490, 0.1200, 0.1561, 0.0963, 0.1064, 0.1651, 0.0000],
         [0.1221, 0.1019, 0.0659, 0.1485, 0.1044, 0.1156, 0.1173, 0.2243]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6200, 0.3800, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3634, 0.3304, 0.3062, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3164, 0.2543, 0.2079, 0.2213, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2759, 0.1713, 0.1802, 0.2785, 0.0942, 0.0000, 0.0000, 0.0000],
         [0.1189, 0.192

In [39]:
class Head(nn.Module):
    '''
    One single head of self attention
    '''

    def __init__(self, inp_dims, out_dims, ctx_len):
        super().__init__()
        self.key = nn.Linear(inp_dims, out_dims, bias=False)
        self.query = nn.Linear(inp_dims, out_dims, bias=False)
        self.value = nn.Linear(inp_dims, out_dims, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(ctx_len, ctx_len)))
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)

        # compute attention scores
        weights = q @ k.transpose(-2,-1) * k.shape[-1] ** -0.5 # supposing k and q having the same dim
        weights = weights.masked_fill(self.tril == 0, float('-inf')) # put -inf in all positions where the tril has 0s
        weights = weights.softmax(-1)

        # weighted aggregation of values
        v = self.value(x)
        out = weights @ v
        return out

In [40]:
# B, T, C = 4, 8, 32 # batch, tokens, token dimensionality
# x = torch.randn(B, T, C) # bs, ctx_len, token dimensionality 
# head = Head(C, 16, 8)
# x = head(x)
# x.shape

In [62]:
class BigramLanguageModel(nn.Module):
    '''
    The model learns each entry of a (vocab_len, vocab_len) table 
    where each entry is the probability dist of the following char given an input char at row
    '''
    def __init__(self, vocab_len, n_embed = 32):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table (bigram model lecture 2)    
        self.token_embedding_table = nn.Embedding(vocab_len, n_embed)

        # let's add a position embedding table, the idea is that we also want to embed the position of each token (and not onty its identity as usually done by older language models)
        self.position_embedding_table = nn.Embedding(ctx_len, n_embed) # across our whole max context length we have other n_embed vectors
        self.self_att_head = Head(n_embed, n_embed, ctx_len)
        self.lang_model_head = nn.Linear(n_embed, vocab_len)

    def forward(self, idx, targets=None):
        B, T = idx.shape # T is int:ctx_len 
        
        # idx and targets are int tensors of shape (bs, ctx_len)
        token_embeddings = self.token_embedding_table(idx) # (B, T, n_embed)
        pos_embeddings = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = token_embeddings + pos_embeddings # pos_embeddings broadcasted across batch dimension 

        x = self.self_att_head(x)

        logits = self.lang_model_head(x) # (B, T, n_embed)
        if targets == None: return logits, None        
            
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)                
        return logits, loss
    
    # this function will change over the course of the lecture
    def generate(self, idx, max_new_tokens):
        # idx is a (bs, ctx_len) array of int-idxs that define the context
        # these int-idxs are chars from the vocab
        # in the bigram model only 1 char is looked at as ctx
        for _ in range(max_new_tokens):
            # new: crop idx to be the last ctx_size token
            idx_cond = idx[:, -ctx_len:]
            # predict i.e. get unnormalized probs
            logits, _ = self(idx_cond)
            # get last time step
            logits = logits[:, -1, :] # (bs, out_classes), -1 cuz rn we are using only the last char in the bigram model
            # normalize probs
            probs = F.softmax(logits, dim=-1)
            # sample
            idx_next = torch.multinomial(probs, num_samples=1) # (bs,1)
            idx = torch.cat((idx, idx_next), dim=1) # (bs, T+1)
        return idx
    
    # def generate(self, idx, max_new_tokens):
    # # idx is (B, T) array of indices in the current context
    #     for _ in range(max_new_tokens):
    #         # crop idx to the last block_size tokens
    #         idx_cond = idx[:, -ctx_len:]
    #         # get the predictions
    #         logits, _ = self(idx_cond)
    #         # focus only on the last time step
    #         logits = logits[:, -1, :] # becomes (B, C)
    #         # apply softmax to get probabilities
    #         probs = F.softmax(logits, dim=-1) # (B, C)
    #         # sample from the distribution
    #         idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
    #         # append sampled index to the running sequence
    #         idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    #     return idx


In [63]:
m = BigramLanguageModel(vocab_len=vocab_len)
optim = torch.optim.AdamW(m.parameters(), lr=1e-3)
xb, yb = get_batch('train')
logits, loss = m(xb, yb)
loss.item() # we know that the initial loss  must be -math.log(1/vocab_len) = 4.174387269895637

4.135059356689453

In [64]:
batch_size = 32
max_num_steps = 5001 #25001
eval_interval = 1000
for step in range(max_num_steps):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()
    if step % eval_interval == 0:
        out = estimate_loss(m)
        print(f"Estimated train loss: {out['train']}, estimated val loss: {out['val']}")

Estimated train loss: 4.185403823852539, estimated val loss: 4.187540054321289
Estimated train loss: 2.7183759212493896, estimated val loss: 2.744436740875244
Estimated train loss: 2.584773063659668, estimated val loss: 2.5600268840789795
Estimated train loss: 2.5059614181518555, estimated val loss: 2.5117766857147217
Estimated train loss: 2.5088212490081787, estimated val loss: 2.5001912117004395
Estimated train loss: 2.46691632270813, estimated val loss: 2.466158390045166


In [67]:
context = torch.zeros((1, 8), dtype=torch.long)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))









Whipecelre.

ARHENoceiens hsel'dlle do!
AUSon.

N:
CESNGEGE:
pes ingm?

Iin:
A tathe jghintnge:
Cat int re mbriont;
Th ce sourn!
Fior vrithepergy othepid myot.

Cye plro le do?
Ou
s,
Sovore nism
IAned I it ther
Ang ach, bet oengdrendee, tchausich ongor fitof mworu, tot, V:
Ant ind Firt homt; thed, nde, we
AGUCOUKENGI's.

SI,
Whe, thepreer st alu hoven bee hse thinliche ingtwharl thirt ible,
Ft!l sleds y.
TOe snd tht id tature imbemeect yot:
easreich ot br merle ce'ded
INEses, the ng ku hanet edsme,
An dong ar ith, to l;
Yhave byer'ghe whest, at hteent,
O orust am
W;Yasto lod
G epcr cors hafre nd.

ARYe rdelepin!r th
Igf.

Ph, othevogth,
I;
I fepree o hoin me; o ghee nath it he chay bon chon mangis, tsenconcs Mut, met do teiowis tcthant
J; st, frfe onte sa'stel baureeangourus'ce wot sheds in ar rod ghe che,
I.

Hol yot'o,
Metay tid--t lhedl oor;ine thetathe the hougte ne wivett he wivencofowiny cfet lis iest hedl.

Tonoe tarsailousaveu lininld DI iur oruters, r whund,
S-t dounco

1) ~1:16 
2) restart from 1:22