# Imports

In [24]:
import os
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Load environment variables from a local .env file using the python-dotenv
# IPython extension: %load_ext registers the extension, %dotenv performs the load.
# NOTE(review): presumably this supplies the wandb credentials/config needed by
# wandb.init later in the notebook — confirm which variables .env must define.
%load_ext dotenv
%dotenv

# Setting Up Vocabulary

In [33]:
# Read the whole raw training corpus into memory as a single UTF-8 string
with open('../data/input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [34]:
# Character-level vocabulary: every distinct character of the corpus, sorted
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(f'Vocabulary size: {vocab_size}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocabulary size: 65


In [41]:
# Character-level tokenizer: map each vocabulary character to its index and back
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string into a list of vocabulary indices (str -> List[int]).

    Raises KeyError for any character that is not in ``chars``.
    """
    return [stoi[ch] for ch in s]


def decode(l):
    """Decode a sequence of vocabulary indices back into a string (List[int] -> str).

    Each element goes through ``int()`` so 1-D torch tensors (whose elements
    hash by identity and would never match the int keys of ``itos``) and other
    integer-like sequences decode as well as plain lists of ints.
    """
    return ''.join(itos[int(i)] for i in l)

# Loading Data

In [3]:
# Start a wandb run, then fetch and download the latest version of the dataset artifact
run = wandb.init(job_type='training', name='train-bigram-baseline')
artifact = run.use_artifact('nanogpt/mini-shakespeare-tensors:latest')
data_dir = artifact.download()  # local directory holding the artifact's files

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmartmichals[0m ([33mmartymcfly[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [9]:
# Load the train/validation token tensors from the downloaded artifact directory.
# weights_only=True restricts unpickling to tensors and other allow-listed types,
# so a tampered remote artifact cannot execute arbitrary code when loaded.
train = torch.load(os.path.join(data_dir, 'train.pt'), weights_only=True)
val   = torch.load(os.path.join(data_dir, 'val.pt'), weights_only=True)

In [13]:
# Context window size: the maximum number of tokens the model conditions on
block_size = 8
train[0:block_size + 1]  # block_size inputs plus the one token that follows them

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

Below are examples of training samples for the transformer. We arrange the training samples this way so that the model sees contexts of every length, from a single token up to `block_size` tokens. At inference time it can then predict the next token from any context length in that range.

In [16]:
# One window of block_size + 1 tokens packs block_size (context -> next token)
# examples, with context lengths 1 .. block_size
x = train[:block_size]
y = train[1:block_size + 1]  # x shifted one token to the right
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f'When input is {context} the target is: {target}')

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


Here we define a function that samples random batches from the training or validation split. Each input row is `block_size` tokens long, and its target row is the same window shifted one token to the right. Keep this in mind during training: as shown above, each row actually packs `block_size` separate examples, with context lengths from 1 to `block_size`. This is how training covers varying input sequence lengths.

In [22]:
# Fix the RNG so the sampled batch is reproducible, then set the batch geometry
torch.manual_seed(1337)
batch_size, block_size = 4, 8  # sequences per batch, tokens per sequence

def get_batch(split: str):
    """Sample a random batch of (input, target) windows from one data split.

    Reads the module-level ``train``/``val`` token tensors and the
    ``batch_size``/``block_size`` settings.

    Args:
        split: ``'train'`` selects the training tensor, ``'val'`` the validation tensor.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``. ``y`` is ``x``
        shifted one position to the right, so ``y[b, t]`` is the token that
        follows the context ``x[b, :t+1]``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """
    if split == 'train':
        data = train
    elif split == 'val':
        data = val
    else:
        # Reject unknown names instead of silently falling back to validation data
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # randint's upper bound is exclusive, so i + block_size < len(data) and the
    # shifted target window never runs past the end of the data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Gather every window with one advanced-indexing op instead of stacking slices
    idx = ix.unsqueeze(1) + torch.arange(block_size)  # (batch_size, block_size)
    x = data[idx]
    y = data[idx + 1]
    return x, y

# Draw one training batch and show its inputs and targets
xb, yb = get_batch('train')
print('Input:', xb.shape, xb, sep='\n')
print('Output:', yb.shape, yb, sep='\n')
print('-------\nEnumeration of all training samples')

# xb is batch_size x block_size (4 x 8), so this single batch holds 32 independent
# (context -> next token) training examples; enumerate every one of them
for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t + 1], yb[b, t]
        print(f'When input is {context} the target is: {target}')

Input:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Output:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-------
Enumeration of all training samples
When input is tensor([24]) the target is: 43
When input is tensor([24, 43]) the target is: 58
When input is tensor([24, 43, 58]) the target is: 5
When input is tensor([24, 43, 58,  5]) the target is: 57
When input is tensor([24, 43, 58,  5, 57]) the target is: 1
When input is tensor([24, 43, 58,  5, 57,  1]) the target is: 46
When input is tensor([24, 43, 58,  5, 57,  1, 46]) the target is: 43
When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) the target is: 39
When input is tensor([44]) the target is: 53
When input is tensor([44, 53]) the target is: 56
W

In [23]:
# Display the sampled input batch
print(str(xb))

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


# Mathematical Trick
We only want each token to attend to itself and to earlier tokens, never to future ones. That is, $x_t$ can only attend to $x_1, x_2, \dots, x_t$.

In [86]:
# Toy standard-normal input for the averaging demos
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

One way to do this is to represent each token by the average of its own vector and the vectors of all preceding tokens.

In [89]:
# Version 1, explicit loops: xbow[b, t] = mean_{i <= t} x[b, i]
# A plain average ignores the order of the earlier tokens, hence "bag of words"
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t+1, C): positions 0..t inclusive
        xbow[b, t] = xprev.mean(dim=0)
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [94]:
# Version 2: the same prefix averages as a single matrix multiplication
tril = torch.tril(torch.ones(T, T))
wei = tril / tril.sum(dim=1, keepdim=True)  # row t averages positions 0..t
xbow2 = wei @ x  # (T, T) broadcast over the batch @ (B, T, C) -> (B, T, C)
torch.allclose(xbow, xbow2, atol=1e-07)

True

In [100]:
# Version 3: build the same averaging matrix with a masked softmax.
# The logits are all zero here; later they become data-dependent affinities
# (how strongly each token attends to each earlier token).
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # masked (-inf) entries become exactly 0
xbow3 = wei @ x
torch.allclose(xbow, xbow3, atol=1e-07)

True

In [105]:
# Version 4: single-head self-attention, where the averaging weights are
# computed from the data itself
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Bias-free linear projections, applied to every token independently.
# Construction order (key, query, value) is kept so the seeded init is unchanged.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q, v = key(x), query(x), value(x)  # each (B, T, head_size)

# Pairwise affinities: (B, T, head_size) @ (B, head_size, T) -> (B, T, T).
# NOTE: unscaled here; scaled attention would also multiply by head_size**-0.5
wei = q @ k.transpose(1, 2)

# Causal mask: position t may only attend to positions <= t
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)

# out, of shape (B, T, head_size), holds for every time step a weighted
# aggregate of the values at that step and all earlier steps, so each
# sampled sequence yields T prediction positions (B*T per batch)
out = wei @ v
out.shape

torch.Size([4, 8, 16])

Notes

- Attention is a communication mechanism. It can be viewed as nodes in a directed graph, where each node aggregates information via a weighted sum of the representations of the nodes that point to it.
- Attention has no built-in notion of position: it acts on a set of vectors, which is why positional encodings must be added to the token representations.
- Examples within a batch are processed independently; they never communicate with one another.
- This is a decoder block: at each time step we train it to predict the next token while attending only to earlier tokens. We can turn it into an encoder block by commenting out the `masked_fill` line that applies the `tril` mask, which lets tokens attend to tokens both before and after the current time step.
- "self-attention" since k, q, v all come from the same x. Sometimes k, v can come from a separate source of nodes, in which case we call this "cross-attention"
- Scaled attention additionally multiplies `wei` by $\frac{1}{\sqrt{d_k}}$, where $d_k$ is `head_size`; equivalently, it divides by $\sqrt{d_k}$. When Q and K have unit variance, this keeps `wei` at roughly unit variance too, so the softmax stays diffuse and does not saturate (see below). Without the scaling, the softmax converges toward one-hot vectors, so each token would aggregate information from essentially a single other token. That is not the desired behavior.

In [107]:
# Fresh unit-variance keys and queries, to compare variances before/after scaling
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(1, 2)  # (B, T, T) unscaled affinities

In [108]:
# Variances of keys, queries and the unscaled affinities
print(*(tensor.var() for tensor in (k, q, wei)))

tensor(1.0700) tensor(0.9006) tensor(18.0429)


In [109]:
# Scaled attention: recompute from q and k instead of rescaling wei in place,
# so re-running this cell is idempotent (wei = wei * ... would shrink it again on every run)
wei = (q @ k.transpose(-2, -1)) * head_size**-0.5

In [110]:
# Variances again: the scaled affinities should now be close to unit variance
print(*(tensor.var() for tensor in (k, q, wei)))

tensor(1.0700) tensor(0.9006) tensor(1.1277)


Below we show how matrix multiplication can be used to compute an average.

In [82]:
# Matrix multiplication as running averages: row i of a row-normalized
# lower-triangular matrix averages rows 0..i of b
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a', a, '--', sep='\n')
print('b', b, '--', sep='\n')
print('c', c, '--', sep='\n')

a
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
--


False