In [2]:
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
with open("input.txt") as f:
    whole_text = f.read()

print(whole_text[:1000])


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



Let's get some info on our dataset.

In [4]:
chars = sorted(list(set(whole_text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)



 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


We tokenize by character, so let us create a map from chars to ints

In [5]:
string_to_int = {character: index for index, character in enumerate(chars)}
int_to_string = {index: character for index, character in enumerate(chars)}
def encode(string): return [string_to_int[char] for char in string]


def decode(list_int): return "".join(
    [int_to_string[integer] for integer in list_int])


print(encode("Hello world!"))
print(decode(encode("Hello world!")))


[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
Hello world!


In [6]:
data = torch.tensor(encode(whole_text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])


torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

We split in train and validation data

In [7]:
number_train = int(0.9 * len(data))
train = data[:number_train]
validation = data[:number_train]

We set up the `block_size`, which is the maximum number of character we feed to the model at once. The entries in this direction are also thought of as time, as we are reading text through time.

In [8]:
block_size = 8


The idea is that the most likely next character is the one that follows the previous set of characters as seen in the training data. So when we train the transformer we simultaneously train it with multiple subcomponents at once: the second character is the most likely to follow the first, the third character is the most likely to follow the first and second...

In [9]:
x = train[:block_size]
y = train[1 : block_size + 1]
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(f"When the input is {context} the target is {target}.")

When the input is tensor([18]) the target is 47.
When the input is tensor([18, 47]) the target is 56.
When the input is tensor([18, 47, 56]) the target is 57.
When the input is tensor([18, 47, 56, 57]) the target is 58.
When the input is tensor([18, 47, 56, 57, 58]) the target is 1.
When the input is tensor([18, 47, 56, 57, 58,  1]) the target is 15.
When the input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47.
When the input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58.


We can then be able to predict up to block size, and from there we have to truncate.

Now we do the batching. The idea would be to grab a random subsequence of the set (train or validation depending on what we're doing) at each step.

In [10]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8


def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    data = train if split == "train" else validation
    indices_start = torch.randint(len(data) - block_size - 1, (batch_size,))
    contexts = torch.stack([data[i : i + block_size] for i in indices_start])
    targets = torch.stack([data[i + 1 : i + block_size + 1] for i in indices_start])
    return contexts, targets


contexts, targets = get_batch("train")
print(f"Inputs: {contexts.shape}")
print(contexts)
print(f"\nTargets: {targets.shape}")
print(targets)
print("------")

for block in range(batch_size):
    for t in range(block_size):
        context = contexts[block, : t + 1]
        target = targets[block, t]
        print(f"When the input is {context.tolist()} the target is {target}.")

Inputs: torch.Size([4, 8])
tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])

Targets: torch.Size([4, 8])
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]])
------
When the input is [53] the target is 59.
When the input is [53, 59] the target is 6.
When the input is [53, 59, 6] the target is 1.
When the input is [53, 59, 6, 1] the target is 58.
When the input is [53, 59, 6, 1, 58] the target is 56.
When the input is [53, 59, 6, 1, 58, 56] the target is 47.
When the input is [53, 59, 6, 1, 58, 56, 47] the target is 40.
When the input is [53, 59, 6, 1, 58, 56, 47, 40] the target is 59.
When the input is [49] the target is 43.
When the input is [49, 43] the target is 43.
When the input is [49, 43, 43] the target is 54.
When the input is [49, 43, 43, 54] th

A note on the letters we use to indicate sizes from now on.
- `B` refers to batches and batch size.
- `T` refers to time, that is, to block size.
- `C` refers to channels, so the different possible outputs for an input which are paying attention to different features.

In [11]:
torch.manual_seed(1337)


class BigramLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table.
        self.token_embedding_table = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, contexts: torch.Tensor, targets: torch.Tensor | None = None):
        # index_x and targets are both (B, T)
        logits = self.token_embedding_table(contexts)

        if targets is None:
            return logits, None

        B, T, C = logits.shape

        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = torch.nn.functional.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, contexts: torch.Tensor, max_new_tokens: int):
        # contexts is of size (B, T), it's the array of indices of the current context.
        for _ in range(max_new_tokens):
            logits, _ = self(contexts)  # Get the predictions. (B, T, C)
            logits = logits[
                :, -1, :
            ]  # We focus on the last element only, as it's what we want.
            # It becomes size (B, C)
            probs = torch.softmax(
                logits, dim=-1
            )  # We transform the logits to probabilities.
            idx_next = torch.multinomial(
                probs, num_samples=1
            )  # We sample the next token, (B,1)
            contexts = torch.cat((contexts, idx_next), dim=1)  # (B, T+1)
        return contexts


model = BigramLanguageModel(vocab_size)
output, loss = model(contexts, targets)
print(output.shape)  # output of size (batch_size, block_size, channels)
print(loss)

context = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(context, 100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8948, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


Let's train it

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [13]:
batch_size = 32
for steps in range(10000):

    # Get a batch of data
    contexts, targets = get_batch("train")

    # Evaluate loss and go backward
    logits, loss = model(contexts, targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())


2.476459503173828


In [14]:
print(decode(model.generate(context, 300)[0].tolist()))



llo br. ave aviasurf my, may be ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulseecherd d o blllando;

Whe, oraingofoff ve!
RIfans picsheserer hee anf,
TOFonk? me ain ckntoty dedo bo'llll st ta d:
ELIS me hurf lal y, ma dus pe athouo
By bre ndy; by s afreanoo adicererupa anse tecor


### The mathematical trick in self-attention

In [15]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

We want tokens to talk to each other, but only to those "in the past". Information should only flow from previous to current context, and never from the future. We can average the tokens from the past to use that information. That's pretty lossy but for now it will work. We can call the averaging of the past "bag of words" or "bow" for short.

 We do $x[b, t] = mean_{i\le t} x[b,i]$.

In [16]:
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # shape (t, C)
        xbow[b, t] = torch.mean(xprev, dim=0)

print(xbow.shape)

torch.Size([4, 8, 2])


In [17]:
x[0]


tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [18]:
xbow[0]


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

How can we make these computations more efficient?

In [19]:
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [20]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a /= torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f"a={a}\n--")
print(f"b={b}\n--")
print(f"c={c}")

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [21]:
xbow2 = torch.zeros((B, T, C))
weights = torch.tril(torch.ones(T, T))
weights /= torch.sum(weights, dim=1, keepdim=True)
xbow2 = weights @ x  # (B, T, T) @ (B, T, C) ---> (B, T, C)
torch.allclose(xbow, xbow2)

True

Another possibility is to use Softmax, which will also make every row sum to one.

In [22]:
triangular = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(triangular == 0, float("-inf"))
weights = torch.softmax(weights, dim=-1)
xbow3 = weights @ x  # (B, T, T) @ (B, T, C) ---> (B, T, C)
torch.allclose(xbow, xbow3)

True

Version 4: Self-attention

In [23]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # Batch, Time, Channels
x = torch.randn(B, T, C)

# We implement one Head of self-attention
head_size = 16
key = torch.nn.Linear(C, head_size, bias=False)
query = torch.nn.Linear(C, head_size, bias=False)
value = torch.nn.Linear(C, head_size, bias=False)
k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)
weights = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ----> (B, T, T)

triangular = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(triangular == 0, float("-inf"))
weights = torch.softmax(weights, dim=-1)

v = value(x)
out = weights @ v

out.shape

torch.Size([4, 8, 16])

Notes:
- Attention is a __communication mechanism__. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights. The flow of information is always exclusively past to present.
- There is no notion of space. Attention simply acts over a set of vectors, irrespective of order. This is why we need positional encoding.
- Each example across different batches is processed independently without sharing information.
- "Self-attention" here means that keys and values are producced from the same source as the queries. In "cross-attention", queries are still produced from x but keys and values come from an external source (e.g. an encoder).
- "Scaled" attention also divides the weights by the square root of the head size. This makes it so that when Q and K are unit variance, so are the weights. As a consequence, the Softmax will not saturate too much.

In [24]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
weights = q @ k.transpose(-2,-1)

In [25]:
print(f"Keys variance: {k.var():.4f}; Queries variance: {q.var():.4f}; Weights variance: {weights.var():.4f}.")

Keys variance: 1.0449; Queries variance: 1.0700; Weights variance: 17.4690.


In [26]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
weights = q @ k.transpose(-2,-1) * head_size ** (-0.5)

In [27]:
print(f"Keys variance: {k.var():.4f}; Queries variance: {q.var():.4f}; Weights variance: {weights.var():.4f}.")

Keys variance: 0.9006; Queries variance: 1.0037; Weights variance: 0.9957.


If variance is high, we are more likely to get extreme values, thus the Softmax behaves closely to a one-hot vector. We want to avoid that.

In [28]:
# Example with low variance
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [29]:
# Example with high variance
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8, dim=-1)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

We will want to apply LayerNorm to normalise before attentions and feedforwards. This differs to the traditional approach followed in _Attention is all you need_, where normalisation happens after these steps, but it's favoured today. Normalising avoids problems with vanishing or exploding gradients as we go towards deeper networks.

In [30]:
class LayerNorm:
    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        x_mean = x.mean(1, keepdim=True) # batch mean
        x_var = x.var(1, keepdim=True) # batch variance
        x_hat = (x - x_mean) / torch.sqrt(x_var + self.eps) # Normalize to unit variance
        self.out = self.gamma * x_hat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

In [32]:
torch.manual_seed(1337)
module = LayerNorm(100)
x = torch.randn(32, 100) # batch of size 32 of 100-dim vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [33]:
x[:,0].mean(), x[:,0].std() # mean and std of one feature across all batches

(tensor(0.1469), tensor(0.8803))

In [34]:
x[0,:].mean(), x[0,:].std() # mean and std of one batch across its features (normalised)

(tensor(-9.5367e-09), tensor(1.0000))