In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Prep some word data (Shakespeare)

In [2]:
# read it in to inspect it
with open('../makemore_data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115393


In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hello world"))
print(decode(encode("hello world")))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [8]:
# convert data to tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:5]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115393]) torch.int64
tensor([18, 47, 56, 57, 58])


In [9]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [11]:
# Block-size
block_size = 8 ##maximum sequence length (though we add +1)
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [273]:
torch.manual_seed(1337)
batch_size = 1 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split, block_size=8):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print('targets:')
print(yb.shape)
print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([1, 8])
targets:
torch.Size([1, 8])
----
when input is [53] the target: 59
when input is [53, 59] the target: 6
when input is [53, 59, 6] the target: 1
when input is [53, 59, 6, 1] the target: 58
when input is [53, 59, 6, 1, 58] the target: 56
when input is [53, 59, 6, 1, 58, 56] the target: 47
when input is [53, 59, 6, 1, 58, 56, 47] the target: 40
when input is [53, 59, 6, 1, 58, 56, 47, 40] the target: 59


# Self-Attention

Based on Attention is All you Need: https://arxiv.org/abs/1706.03762


## the mathematical trick (weighted aggregation by matrix multiplication)

In [14]:
torch.manual_seed(1337)
B,T,C = 4,8,2 #batch, time, channels/features
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [26]:
# We want tokens to share context with previous time-steps
# A simple/weak way would be to average the information from the preceding time-steps
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] #select prev-timesteps, and the current one. out: (t, C)
        xbow[b, t] = torch.mean(xprev, 0) #average C across time

In [31]:
# We can do this more efficiently using matrix multiplication
wei = torch.tril(torch.ones(T, T)) #weights of weighted aggregation
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = wei @ x #(T, T) @ (B, T, C) ---> (B, T, C)

# Same result as above, via MatMul (much faster!)
torch.allclose(xbow, xbow2)

True

In [33]:
# Why does this work???
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
# adding tril, for a triangle matrix,allows us to zero out "future" elements,
# essentially filtering out the future time-steps
a = torch.tril(torch.ones(3, 3)) 
# Modifying this initial matrix so that each row is 1/row_number weights the sums,
# so that they sum to 1, which means c = a @ b gives us an average
# (torch.sum, dim=1 keepdims=True, gives us (1+0+0 = 1), (1+1+0 = 2), (1+1+1 = 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('a @ b = c:')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
a @ b = c:
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [35]:
# We can also modify this process for the same result
# masked_fill : for all elements where tril == 0 become -inf
# we then softmax each row, which normalizes - 0s become one, -inf becomes 0, and we sum and divide
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.nn.functional.softmax(wei, dim=1)

print(wei)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [None]:
# What if we learned the weights of wei (or 'a' above) instead of setting them to be the mean??

## Self-Attention (Single Head)

In [128]:
# Version 4: self-attention! (for a singular head)

# We don't want our 'wei' matrix to be uniform/pre-determined, we want it to be data-dependent!
# Every single token at each position will emit 3 vectors: a query, a key, and a value
#   query: contains info on what the token is looking for
#   key  : contains info on what the token contains
#   value: x is the 'private' token information, and value determines what to share/communicate when looking at the keys of other tokens 
# We get 'affinities'/attention scores between tokens in the sequence via a dot product between the 
# token's query and the keys of all the other tokens!
# If the key and query are well-aligned, they will interact more, and wei will have a higher value at
# their combined positions, so when we softmax, we will get much more of the info from that token's features

torch.manual_seed(1337)
B, T, C = 4,8,32
x = torch.randn(B,T,C)

# a single Head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# retrieve key and query
k = key(x)    # B, T, head_size
q = query(x)  # B, T, head_size

wei = q @ k.transpose(-2, -1) #transpose last 2 dims, so (B, T, 16) @ (B, 16, T) = (B, T, T)



tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) #still want to prevent "seeing the future"
wei = torch.nn.functional.softmax(wei, dim=1)   #aggregate so weights sum to 1

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an **"encoder"** attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a **"decoder"** attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [137]:
# Why scale attention? (by 1 / sqrt(head_size))
# Unit gaussian inputs
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
# Wei
wei_ = q @ k.transpose(-2, -1) 
# Wei scaled
wei = wei_ * head_size**-0.5

print(f"k variance = {k.var()} | q variance = {q.var()}")
print(f"wei variance pre-scale = {wei_.var()}, post-scale = {wei.var()}")

# if Softmax receives rather large values, it will converge towards the largest number
# ie, Softmax is too "peaky" if we don't control the variance
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)

k variance = 0.9873070120811462 | q variance = 1.009503960609436
wei variance pre-scale = 15.663772583007812, post-scale = 0.9789857864379883


tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

## Multi-Head Attention

Simply applying multiple self-attention heads in parallel, and concatenating the results.
Why?
It helps to have multiple channels of communication for tokens to odentify alike/useful tokens.

## FeedForward Layers + Blocks
Once the tokens have had time to look at each-other, we want to add layers which allow computation on the features created by the attention process.  
In other words, we give the network some time to make sense of the information it's extracted.  

Transformer networks intersperse attention layers ("communication") with feed-forward ("computation") layers.  

But, as the network gets deeper, we need to help counteract optimization issues via:
1. skip/residual connections
    (We transform the data, but then add the original features back to the transformed data)  
    In practice, this means we have a "residual pathway" which we can fork off via transformations, but which then come back to the pathway via addition.    
    Then, in backwards pass, addition forks the gradient equally.  
2. In the MultiHeadAttention layers and FeedForwards layers, we want add a "projection" layer to project the attention's output back into the residual pathway.  
 


## LayerNorm
Similair to BatchNorm, except that instead of normalizing across the batch, we normalize across the sample (so we don't need any running buffers, no distinction between training and test time...)  
"Normalizes the rows instead of the columns"  

In departure from Attention is All You Need, it is more common now to apply layer norm before the feed-forward step... the "pre-norm formulation"  
In our case, it is essentially normalizing the 'per-token' features, treating both the batch and time dimensions as batch dimensions.  

We add LayerNorm to our Transformer Blocks and we add one final layer norm after all the blocks and before the final linear layer.  

In [275]:
class AttHead(nn.Module):
    def __init__(self, block_size, head_size, in_feats, dropout) -> None:
        super().__init__()
        self.key = nn.Linear(in_feats, head_size, bias=False)
        self.query = nn.Linear(in_feats, head_size, bias=False)
        self.value = nn.Linear(in_feats, head_size, bias=False)
        # buffer: since tril is not a module, assign using register buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * C**-0.5 #scaled, out:(B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) #(B,T,T)
        wei = F.softmax(wei, dim=1)
        wei = self.dropout(wei)
        # weigted aggregation
        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, num_heads, head_size, in_feats, dropout) -> None:
        super().__init__()
        self.heads = nn.ModuleList([
            AttHead(block_size, head_size, in_feats, dropout) for _ in range(num_heads)
        ])
        self.proj = nn.Linear(in_feats, in_feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Concatenate over channel/feature dimension (last dim)
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)


class FeedForward(nn.Module):
    def __init__(self, in_feats, dropout) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feats, 4*in_feats), #times 4, as in Attention is All You Need
            nn.ReLU(),
            # project back into residual pathway
            nn.Linear(4*in_feats, in_feats),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block
    LayerNorm
    MultiHeadAttention
    LayerNorm
    FeedForward
    """
    def __init__(self, block_size, in_feats, n_heads, dropout) -> None:
        super().__init__()
        head_size = in_feats // n_heads
        self.ln1 = nn.LayerNorm(in_feats)
        self.sa = MultiHeadAttention(block_size, n_heads, head_size, in_feats, dropout)
        self.ln2 = nn.LayerNorm(in_feats)
        self.ffwd = FeedForward(in_feats, dropout)
    
    def forward(self, x):
        x = x + self.sa(self.ln1(x)) #x+... = skip connections
        x = x + self.ffwd(self.ln2(x))
        return x


class TransformerModel(nn.Module):

    def __init__(self, vocab_size, block_size, n_embed, 
                       n_layers, heads_per_layer,
                       dropout):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(block_size, n_embed, heads_per_layer, dropout) for _ in range(n_layers)]
            )
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, block_size, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to be last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [276]:
n_embed    = 32      #
block_size = 256
n_layers   = 4
heads_per_layer = 4
dropout = 0.

m = TransformerModel(vocab_size=vocab_size,
                     block_size=block_size,
                     n_embed=n_embed,
                     n_layers=n_layers,
                     heads_per_layer=heads_per_layer,
                     dropout=dropout
                     )

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [286]:
batch_size = 32
m.train()
for steps in range(1): # increase number of steps for good results... 
    
    # sample a batch of data
    xb, yb = get_batch('train', block_size)

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("train loss", loss.item())

m.eval()
xv, yv = get_batch('valid', block_size)
logits, vloss = m(xv, yv)
print("val loss", vloss.item())

train loss 2.868577003479004
val loss 2.948580026626587


In [289]:
strt_tok = torch.zeros((1, 1), dtype=torch.long)
print(
    decode(m.generate(idx=strt_tok, block_size=block_size, max_new_tokens=500)[0].tolist())
    )


Neeoine e e oao,-r?  p v3vrjumemrobotXwtallaOJathE lisylheO
Je rmct mlralrie ,VofohanonhooQon eo
As moI ey,
& !Q iroIXwerendnoee nd lgeu,py'cg T3Ysork esgseromi t' s at thfs mmled Tf bsd rthmn g pelmn foms
T
Tge s Jl tke er:
Lhowt'sr loureusne e, e t ?
ElelBd $lrayaRkl'uafn$iG tkofwYof e homnRlrohin n hOhIKin,Oh
Ue d tenB :
-leaset Sbin,r ore ?CGcI
TSeeSdhesKoheac, iy fwviwcOncthos re ': mtVd I sI,t. l y;
cy wa acore, d o athnng, eesreri r weriy:

jy .
Wir h cisy sthf I:
ohw,ohs s,l ssh, io 'alf


In [290]:
# Time 1:42:32 vid