In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests

In [3]:
# Download the Tiny Shakespeare corpus and store it next to the notebook.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the cell from hanging indefinitely if the server never answers.
response = requests.get(url, timeout=30)
# Fail loudly on an HTTP error instead of silently saving an error page as the dataset.
response.raise_for_status()

# Written as raw bytes, so the file holds exactly what the server sent.
with open("tinyshakespeare.txt", "wb") as file:
    file.write(response.content)

In [4]:
# Seed PyTorch's default random number generator (it drives the torch.randint
# and torch.multinomial calls later in the notebook).
torch.manual_seed(1337)

<torch._C.Generator at 0x2baeca44190>

In [5]:
# The corpus was saved as raw bytes, so decode it explicitly as UTF-8 instead of
# relying on the platform's default text encoding.
with open("tinyshakespeare.txt", "r", encoding="utf-8") as file:
    text = file.read()

# Print the length of the dataset in characters
print(f"Length of dataset: {len(text)} characters")

# Print the first 1000 characters of the dataset
print("First 1000 characters of the dataset:")
print(text[:1000])

# Vocabulary: every distinct character of the corpus, in sorted order.
# A character's token id is its position in this list.
chars = sorted(set(text))
# Reverse lookup built once, so encoding is a dict lookup per character instead
# of a linear chars.index() scan for every character of the corpus.
stoi = {ch: i for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to the list of integer token ids of its characters.

    Raises:
        ValueError: if `s` contains a character that is not in `chars`.
    """
    try:
        return [stoi[c] for c in s]
    except KeyError as err:
        # Keep the exception type that list.index() raised for unknown characters.
        raise ValueError(f"{err.args[0]!r} is not in the vocabulary") from err


def decode(e):
    """Map an iterable of integer token ids back to the string they spell.

    Indexes the `chars` list directly, so the ids may be Python ints or
    integer tensor elements.
    """
    return ''.join([chars[i] for i in e])


# Print the vocabulary size
vocab_size = len(chars)

print(f"Vocabulary size: {vocab_size}")
print(f"Vocabulary: {''.join(chars)}")

Length of dataset: 1115394 characters
First 1000 characters of the dataset:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for 

In [6]:
# Encode the whole corpus into a single 1-D tensor of integer (int64) token ids.
data = torch.tensor(encode(text), dtype=torch.long)

In [7]:
# The first 90% of the tokens are used for training; the remaining 10% are held
# out to give an unbiased estimate of the model's performance.
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

block_size = 8
print(decode(train_data[:block_size+1]))

# y is x shifted by one position, so y[t] is the token that follows x[:t+1].
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t, target in enumerate(y):
    context = x[:t+1]
    print("when input is", context, "the target is", target)

First Cit
when input is tensor([18]) the target is tensor(47)
when input is tensor([18, 47]) the target is tensor(56)
when input is tensor([18, 47, 56]) the target is tensor(57)
when input is tensor([18, 47, 56, 57]) the target is tensor(58)
when input is tensor([18, 47, 56, 57, 58]) the target is tensor(1)
when input is tensor([18, 47, 56, 57, 58,  1]) the target is tensor(15)
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is tensor(47)
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is tensor(58)


In [8]:
batch_size = 4  # number of sequences drawn per batch
block_size = 8  # number of tokens in each input sequence


def get_batch(split):
    """Sample a random batch of (input, target) sequences.

    `split == "train"` reads from `train_data`; any other value reads from
    `val_data`. `batch_size` and `block_size` are looked up from the notebook
    globals at call time.

    Returns:
        (x, y): two (batch_size, block_size) tensors, where y is x shifted one
        position to the right in the source data.
    """
    source = train_data if split == "train" else val_data
    # One start offset per sequence; the upper bound leaves room for the
    # one-token-shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch: xb holds the input sequences, yb the same sequences
# shifted by one token (the prediction targets).
xb, yb = get_batch('train')

In [9]:
class BigramLanguageModel(nn.Module):
    """Bigram model: the logits for the next token depend only on the current token.

    The whole model is one (vocab_size x vocab_size) embedding table; looking up
    token i returns the row of unnormalised scores for the token that follows it.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None):
        """Compute the logits and, when targets are given, the cross-entropy loss.

        Args:
            inputs: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and loss
            is None. With targets, logits is returned flattened to (B*T, C) and
            loss is the scalar cross-entropy.
        """
        logits = self.token_embedding_table(inputs)  # (B, T, C): batch x time x channels

        if targets is None:
            return logits, None

        # F.cross_entropy wants the class dimension second, so fold batch and
        # time into a single leading dimension.
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never back-propagates, so do not build an autograd graph
    def generate(self, inputs, max_tokens):
        """Extend each sequence in `inputs` by `max_tokens` sampled tokens.

        Args:
            inputs: integer token ids of shape (B, T).
            max_tokens: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_tokens): the inputs followed by the samples.
        """
        for _ in range(max_tokens):
            logits, _loss = self(inputs)
            last_logits = logits[:, -1, :]  # keep only the final time step: (B, C)
            probs = F.softmax(last_logits, dim=-1)  # (B, C)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            inputs = torch.cat((inputs, next_token), dim=1)  # (B, T+1)
        return inputs
    

m = BigramLanguageModel(vocab_size)
# Passing targets makes forward() return the logits flattened to (B*T, C)
# together with the loss.
logits, loss = m(xb, yb)
print(logits.shape)
# Loss of the freshly initialised model, before any training.
print(loss)
# Start from a single sequence holding only token id 0 and sample 100 more tokens.
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, 100)[0].tolist()))


torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)

l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


In [10]:
# The optimizer takes the gradients and updates the parameters. For a small
# network we can get away with a larger learning rate; typically it would be
# something like 3e-4.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [11]:
# get_batch reads batch_size from the notebook globals, so this switches every
# later batch to 32 sequences.
batch_size = 32
for steps in range(10000):
    # Clear the gradients left over from the previous step.
    optimizer.zero_grad()

    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    loss.backward()
    optimizer.step()
# Loss of the final mini-batch only.
print(loss.item())

2.5589075088500977


In [12]:
# Sample 300 new tokens from the trained model, starting from a single token with id 0.
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), 300)[0].tolist()))


Ong h hasbe pave pirance
RDe hicomyonthar's
PES:
AKEd ith henourzincenonthioneir thondy, y heltieiengerofo'dsssit ey
KINld pe wither vouprroutherccnohathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so itJas
Waketancotha:
h hay.JUCLUKn prids, r loncave w hollular s O:
HIs; ht anjx


# The mathematical trick in self-attention

In [13]:
# Re-seed so the random toy tensor below is the same on every run.
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In an autoregressive predictor, we want to predict the next character given the history: for the i-th character, we want to predict the (i+1)-th character given the characters up to and including position i. Right now the inputs have shape (batch, time, channel), but the T tokens are not communicating with each other, and ideally we would want them to. We want to combine the current and past data for the prediction task. In particular, we want to couple the tokens in such a way that each token communicates only with tokens from the past. For example, a token in the 5th location should not be able to communicate with tokens from the 6th, 7th, or 8th locations, because these are future tokens.

The easiest way, although very lossy, is to average the current token with the past tokens. The implementation below is also very inefficient because of its for loops.

In [14]:
# xbow[b, t] is the average of x[b, 0], ..., x[b, t]: the current token together
# with everything before it.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t+1, C)
        xbow[b, t] = xprev.mean(dim=0)

xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

A more efficient approach involves a triangular matrix:

In [15]:
# Row t of the lower-triangular matrix has ones in columns 0..t; dividing each
# row by its sum turns it into uniform averaging weights over positions 0..t.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
wei


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [16]:
# (T, T) @ (B, T, C) is a batched matrix multiply: torch inserts an extra batch
# dimension in front of wei, so this is (B, T, T) @ (B, T, C) -> (B, T, C).
xbow = wei @ x

In [17]:
# version 3, using softmax
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

The softmax formulation is preferable because the weights wei can now be made trainable. We can think of the weights as interaction strengths, or affinities, telling how much of each token from the past we want to aggregate. Setting a weight to -inf means we cannot look into the future. The affinities between the tokens, initially set to zero here, will become data dependent. The tokens are going to look at each other, and some tokens will find other tokens more or less interesting. Depending on what the values of the tokens are, they are going to find each other interesting to different amounts. These interests are called affinities.