In [5]:
# Dataset prep
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()
print(len(text))

1115393


In [6]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [9]:
chars = sorted(list(set(text)))
print(len(chars))
print(''.join(chars))

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [10]:
vocab_size = len(chars)

In [19]:
# Building the encoder and decoder
ctoi = {ch:idx for (idx, ch) in enumerate(chars)}
itoc = {idx:ch for (ch, idx) in ctoi.items()}
encode = lambda text: [ctoi[ch] for ch in text]
decode = lambda idxs: ''.join([itoc[idx] for idx in idxs])

In [21]:
decode(encode("I have a big schlong"))

'I have a big schlong'

In [24]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data[:100]

(torch.Size([1115393]),
 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
          6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
         58, 47, 64, 43, 52, 10,  0, 37, 53, 59]))

In [26]:
# Split into train and validation
n = int(0.9 * len(data)) # 90% train set and 10% validation
train_set = data[:n]
val_set = data[n:]
val_set.shape, train_set.shape

(torch.Size([111540]), torch.Size([1003853]))

In [28]:
block_size = 8
train_set[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [29]:
# Time dimension -> Predicting the next character after a sequence
x = train_set[:block_size]
y = train_set[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    pred_next = y[t]
    print(f"For {context}, we predict {pred_next}")

For tensor([18]), we predict 47
For tensor([18, 47]), we predict 56
For tensor([18, 47, 56]), we predict 57
For tensor([18, 47, 56, 57]), we predict 58
For tensor([18, 47, 56, 57, 58]), we predict 1
For tensor([18, 47, 56, 57, 58,  1]), we predict 15
For tensor([18, 47, 56, 57, 58,  1, 15]), we predict 47
For tensor([18, 47, 56, 57, 58,  1, 15, 47]), we predict 58


In [57]:
torch.manual_seed(1337)
# Batching up data acros batch and time dimensions
# How big is the context length for predicting the next character
block_size = 8
# How many sequences we are stacking together to process in parallel
batch_size = 4

def get_split_batch(split: str):
    """Sample a `batch_size` number of sequences of length `block_size` along with their
    next character predcition from the desired `split` -> `train` or `val` data"""

    # If the split is not `train` or `val`, it is invalid
    assert split in ['train', 'val']
    dataset = train_set if split == 'train' else val_set

    # Sample `batch_size` count of random indexes from the data up to the last
    # index that is possible to issue a context of 8 elements
    idxs = torch.randint(0, len(dataset) - block_size, (batch_size,))

    # For each index, the context (or the input to the model) will be the sequence
    # of eight characters starting with that index
    x = torch.stack([dataset[idx:idx+block_size] for idx in idxs])
    # And the predictions will be the exact next character following that sequence
    y = torch.stack([dataset[idx+1:idx+block_size+1] for idx in idxs])
    return (x, y)

Xb, Yb = get_split_batch('train')

for b in range(batch_size):
    for t in range(block_size):
        context = Xb[b, :t+1]
        pred = Yb[b, t]
        print(f"For {context} we are predicting {pred}")

torch.Size([4, 8])
For tensor([53]) we are predicting 59
For tensor([53, 59]) we are predicting 6
For tensor([53, 59,  6]) we are predicting 1
For tensor([53, 59,  6,  1]) we are predicting 58
For tensor([53, 59,  6,  1, 58]) we are predicting 56
For tensor([53, 59,  6,  1, 58, 56]) we are predicting 47
For tensor([53, 59,  6,  1, 58, 56, 47]) we are predicting 40
For tensor([53, 59,  6,  1, 58, 56, 47, 40]) we are predicting 59
For tensor([49]) we are predicting 43
For tensor([49, 43]) we are predicting 43
For tensor([49, 43, 43]) we are predicting 54
For tensor([49, 43, 43, 54]) we are predicting 1
For tensor([49, 43, 43, 54,  1]) we are predicting 47
For tensor([49, 43, 43, 54,  1, 47]) we are predicting 58
For tensor([49, 43, 43, 54,  1, 47, 58]) we are predicting 1
For tensor([49, 43, 43, 54,  1, 47, 58,  1]) we are predicting 58
For tensor([13]) we are predicting 52
For tensor([13, 52]) we are predicting 45
For tensor([13, 52, 45]) we are predicting 43
For tensor([13, 52, 45, 43]

In [198]:
# Setting a benchmark -> Token embedding table
from torch import nn
from torch.nn import functional as F

n_embd = 32

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.vocab_size = vocab_size
        # Number of embeddings for each element in vocab
        self.n_embd = n_embd
        # We are embedding the token and an individual identity token information
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # We are also embedding the position of the token for each token in
        # the context length
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        

    def forward(self, idxs, targets=None):
        # `idxs` and `targets` are both (B, T) tensors of integers
        B, T = idxs.shape
        
        tok_emb = self.token_embedding_table(idxs) # (B, T, C) -> C = n_embd
        # Get the position embedding for all the tokens in the the context length
        # aka for all the timesteps from 0 to T
        # position embedding does not have a batch size because it is broadcasted
        # along for each element (context sequence) in the batch
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, C)
        x = tok_emb + pos_emb # with torch broadcasting we will have (B, T, C)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            # Torch expects that the C (channels/features) dimension is the second dimension
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss


    def generate(self, idxs, max_new_tokens):
        """Samples `max_new_tokens` next tokens from the model starting with `idxs`
        """
        for _ in range(max_new_tokens):
            # Crop idx to the las block_size tokens
            idxs = idxs[:, -block_size:]
            # Compute the forward pass
            # Get embeddings. Because we don't have any targets, this only returns the logits
            # and no loss
            logits, _loss = self(idxs)
            # Focus only on the last time step. This becomes the (B, C) of the last T
            logits = logits[:, -1, :]
            # Softmax along the C (channels) which are the last dimension
            probs = F.softmax(logits, dim = -1) # (B, C)
            # Sample from the distribution
            pred_idx = torch.multinomial(probs, num_samples = 1) # (B, 1)
            idxs = torch.cat([idxs, pred_idx], dim=1)

        return idxs
            
            
        


model = BigramLanguageModel(vocab_size, n_embd)
logits, loss = model(Xb, Yb)
print(logits.shape)
print(Yb.shape)
print(decode(model.generate(idxs=torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([256, 65])
torch.Size([32, 8])
q
bc CZV'


In [199]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [200]:
# Training
batch_size = 32
for _ in range(10000):
    # Get a new batch
    Xb, Yb = get_split_batch('train')
    logits, loss = model(Xb, Yb)
    # Zero out the gradient such that it does not accumulate between sessions
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Perform the optimisation
    optimizer.step()

print(loss.item())

2.524893045425415


In [201]:
print(decode(model.generate(idxs=torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))

kerourexi


In [141]:
@torch.no_grad()
def eval_loss(num_iters = 200):
    """Evaluate the loss as an overage over a number of iterations for both splits"""
    out = {}
    # Put the model into evaluation mode
    model.eval()
    for split in ['train', 'val']:
        # First we start the losses at zero
        losses = torch.zeros(num_iters)
        for idx in range(num_iters):
            # Get the batch
            Xb, Yb = get_split_batch(split)
            # Do the forward pass
            logits, loss = model(Xb, Yb)
            # Save the loss for this batch
            losses[idx] = loss.item()
        # Average over the loss
        out[split] = losses.mean()
    # Put the model back in train mode
    model.train()
    return out

In [142]:
eval_loss()

{'train': tensor(2.4654), 'val': tensor(2.4858)}

In [144]:
# The mathematical trick in self-attention
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [145]:
# For each token at position i, compute the average of all previous tokens
# Including the token at position i in order to get some information
# about the previous tokens

# We will call it bag of words
xbow = torch.zeros(B, T, C)

for b in range(B):
    for t in range(T):
        # Get all the elements from previous timesteps including this one
        # for this batch
        xprev = x[b, :t+1] # Gets us a (t, C) up to and including the current t
        # Average up over the time dimension
        xbow[b, t] = torch.mean(xprev, 0)

In [158]:
weights = torch.tril(torch.ones(T,T))
# Normalize `a` such that rows sum up to 1
weights = weights / torch.sum(weights, 1, keepdim=True)
xbow2 = weights @ x # (T, T) @ (B, T, C) -- torch--> (B, T, T) @ (B, T, C) -> (B, T, C)
torch.allclose(xbow2, xbow)

True

In [162]:
# Version 3: using softmax
tril = torch.tril(torch.ones(T,T))
weights = torch.zeros((T, T))
# Where elements have tril being 0, make them -inf
# This is saying that tokens from the future cannot aggregate
weights = weights.masked_fill(tril == 0, float('-inf'))
# Exponentiate everything and divide by the sum along the last dimension
weights = F.softmax(weights, dim = -1)
weights   

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [154]:
# Get the lower triangular part of the matrix
a = torch.tril(torch.ones(3,3))
# Normalize `a` such that rows sum up to 1
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(1, 10, (3,2)).float()
c = a @ b
print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[3., 6.],
        [2., 4.],
        [1., 1.]])
tensor([[3.0000, 6.0000],
        [2.5000, 5.0000],
        [2.0000, 3.6667]])


In [209]:
# Self-attention
# The mathematical trick in self-attention
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# Single Head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False)

# Compute the keys and the queries for each individual element
k = key(x) # (B, T, C) @ (C, 16) -> (B, T, 16)
q = query(x) # (B, T, C) @ (C, 16) -> (B, T, 16)

# Compute the weights as the queries dot product with all the keys
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T,T))
#weights = torch.zeros((T, T))
# Where elements have tril being 0, make them -inf
# This is saying that tokens from the future cannot aggregate
weights = weights.masked_fill(tril == 0, float('-inf'))
# Avoid softmax pulling towards the highest value and keep the variance close to one
# weights = weights * head_size * -0.5
# Exponentiate everything and divide by the sum along the last dimension
weights = F.softmax(weights, dim = -1)

# V are the values that we aggregate for each previous token (instead of x ->
# the embedding of the token itself)
# V -> If you find me interesting, here is what I will communicate with you
v = value(x)
out = weights @ v
print(out.shape)
weights[0]

torch.Size([4, 8, 16])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2797, 0.7203, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0242, 0.0320, 0.9438, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0544, 0.0050, 0.5041, 0.4365, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0138, 0.3714, 0.0222, 0.5536, 0.0389, 0.0000, 0.0000, 0.0000],
        [0.0723, 0.0536, 0.0797, 0.6174, 0.0887, 0.0884, 0.0000, 0.0000],
        [0.0478, 0.5255, 0.0132, 0.0039, 0.0927, 0.2363, 0.0806, 0.0000],
        [0.3353, 0.0412, 0.1092, 0.0131, 0.1214, 0.0876, 0.2816, 0.0105]],
       grad_fn=<SelectBackward0>)

In [None]:
# Query -> What am I looking for?
# Key -> What information I contain?

# I am a token at the current timestep. In order to find the affinities with all
# the other tokens before me I dot product my Query with all the Keys before me
# That dot product becomes the weights of the current token

In [210]:
class LayerNorm:
    def __init__(self, size, eps=1e-05):
        """Layer normalization defined accoding to the paper with the same name
        and torch docs. It is the same concept as batch norm, but along layers (rows)
        instead of batches (columns)

        Args:
            size: size of the layer and implictily this normalisation layer
            eps: a small variable to control that we are not dividing by zero.
        """
        self.size = size
        # Training parameters that get updated by the backward pass
        self.gamma = torch.ones(size)
        self.beta = torch.zeros(size)
        # Extra variables used to control the behaviour of the normalisation
        self.eps = eps


    def __call__(self, x_in):
        # No matter the dimensionality of the input, when we are computing normalizations,
        # the actual values we want to normalize are the last channel. As such, we must instruct
        # torch to take the proper dimensions
        
        # If we only have 2 dimensions, we want to reduce along the second dimension 
        if x_in.ndim == 2:
            dim = 1
        # If we have 3 dimensions, we want to reduce along the last 2 dimensions
        elif x_in.ndim == 3:
            dim = (1, 2)
            
        in_mean = x_in.mean(dim, keepdim=True)
        in_std = x_in.std(dim, keepdim=True)


        # Normalize the layer
        norm = (x_in - in_mean) / torch.sqrt(in_std + self.eps)
        # Compute the batch norm
        self.out = self.gamma * norm + self.beta
        
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

In [217]:
x_in = torch.randn(32, 100)
ln = LayerNorm(100)
x = ln(x_in)
x[0, :].mean(), x[0, :].std()

(tensor(4.7684e-09), tensor(1.0466))