### Missing something again

what's the difference between multi head and a bigger head?

again the attention is linear unit - if we just make it wider isn't that the same as having mulitple smaller heads?

Going to the paper:

*Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this*

so apparently the softmax function is the problem and that's why the multihead attention is used

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
%run bookreader.py

In [3]:
br = BookReader(False, r'[^\d+a-zA-Z \n?!:,]')
br.read("tiny_shakespeare.txt")
vocab_size = br.vocab_size
vocab_size

59

### Get the batch with both x and y unseparated

In [4]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [5]:
train = torch.tensor(br.data[0])

### Create an attention head

In [6]:
class Head(nn.Module):

    def __init__(self, c, head_size, content_length):
        super().__init__()
        self.key = nn.Linear(c, head_size, bias=False)
        self.query = nn.Linear(c, head_size, bias=False)
        self.value = nn.Linear(c, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [7]:
from collections import OrderedDict

class FeedForward(nn.Module):
    def __init__(self, fan_in, multiplier = 4):
        super().__init__()

        layers = OrderedDict([
            ("l_in", nn.Linear(fan_in, multiplier * fan_in)),
            ("relu", nn.ReLU()),
            ("l_out", nn.Linear(multiplier * fan_in, fan_in)),
        ])
        self.net = nn.Sequential(
            layers
        )

        initial = layers['l_in']
        nn.init.kaiming_normal_(initial.weight, nonlinearity="relu")
        layers['l_in'].weight.data = initial.weight.data * 3/5
        if initial.bias is not None:
            nn.init.constant_(initial.bias, 0)

        final = layers['l_out']
        layers['l_out'].weight.data = final.weight.data * .2
        if final.bias is not None:
            nn.init.constant_(final.bias, 0)

    def forward(self, x):
        return self.net(x)

In [12]:
import torch.optim as optim

In [13]:
# class MultiHead(nn.Module):

#     def __init__(self, num_heads, head_size, embed_size, content_length):
#         super().__init__()
#         self.heads = nn.ModuleList([Head(embed_size, head_size, content_length) for _ in range(num_heads)])

#     def forward(self, x):
#         out = torch.cat( [head(x) for head in self.heads], dim = -1 )
#         return out

In [12]:
class MultiHeadJoin(nn.Module):
    
    def __init__(self, c, num_heads, head_size, content_length):
        super().__init__()
        self.query = nn.Linear(c, head_size * num_heads, bias=False)
        self.key = nn.Linear(c, head_size * num_heads, bias=False)
        self.value = nn.Linear(c, head_size * num_heads, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [8]:
class FFMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, content_length, num_heads, head_size, multiplier=4):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.mutli_attention = MultiHeadJoin(embed_size, num_heads, head_size, content_length)
        self.lna = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, multiplier)
        self.lnff = nn.LayerNorm(embed_size)
        self.decode = nn.Linear(embed_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.mutli_attention(x)
        # print("multi ball out", x.shape)
        x = self.lna(x)
        x = self.ff(x)
        x = self.lnff(x)
        # print("feed forward out", x.shape)
        logits = self.decode(x)
        # print("decode out", x.shape)
        # print("targets", targets.shape)
        # return None, None

        if targets is None:
            loss = None
        else:
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits.view(B*T, -1), targets) #.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [9]:
def get_val_batch(data, batch_length=5, batch_size=5, i=0):
    # generate a small batch of data of inputs x and targets y
    if i == 0:
        ix = torch.randint(len(data) - batch_length, (batch_size,))
    else:
        ix = torch.arange(1, 5) + 1 + i

    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [10]:
import math
@torch.no_grad()
def split_loss(split):
    split_len = len(split)
    total_loss = 0
    batch_size = 50
    num_batches = math.floor(split_len / batch_size)
    print("num_batches", split_len, num_batches)
    
    model.eval()

    for i in range(num_batches):

        t_b = get_val_batch(split, context_length+1, batch_size, i*batch_size)
        
        x = t_b[:, 0: context_length]
        y = t_b[:, context_length: context_length+1]
        
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, batch_loss = model(x, y)
        
        total_loss = total_loss + batch_loss
    
    print("total loss", total_loss, total_loss / num_batches)


In [14]:
import torch.optim as optim

In [15]:
epochs = 60
training_runs = 800
batch_size = 96
context_length = 12
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 8
model = FFMultiHeadAttention(embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 5 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
24091  parameters
ep 0 tensor(2.6127, grad_fn=<DivBackward0>) [0.098]
ep 5 tensor(1.9511, grad_fn=<DivBackward0>) [0.0885842380864]
ep 10 tensor(1.8820, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 15 tensor(1.8487, grad_fn=<DivBackward0>) [0.07237977205924956]
ep 20 tensor(1.8319, grad_fn=<DivBackward0>) [0.06542558123199924]
ep 25 tensor(1.8175, grad_fn=<DivBackward0>) [0.059139543518331866]
ep 30 tensor(1.8050, grad_fn=<DivBackward0>) [0.053457463299478813]
ep 35 tensor(1.7948, grad_fn=<DivBackward0>) [0.048321312820571644]
ep 40 tensor(1.7892, grad_fn=<DivBackward0>) [0.04367863958719317]
ep 45 tensor(1.7837, grad_fn=<DivBackward0>) [0.03948203069879567]
ep 50 tensor(1.7784, grad_fn=<DivBackward0>) [0.03568862864853744]
ep 55 tensor(1.7723, grad_fn=<DivBackward0>) [0.03225969364468526]


In [20]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(br.decode(o))

num_batches 89500 1790
total loss tensor(3164.0281) tensor(1.7676)
None 89500

Tne, to I?
ROMEO:
The Go, shall will laDit shumsand part, may, this I, 
as men, Wellievenrates Xinte


In [17]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))


LOd like so,
But admother? but in s I shable, not so theire them lingood s outct  yover more 
Fildsir, illower a men, thy servance
Well and in Anne toom your night unter
Nayer
For your ever I wear all name 
GRUMIO:
Hest 
PRINCE:
First the Slack theirtue your named, goo 
POLIXCALUS:
Shall boy spenthereforet go poor:
Dut aime
M dange 
More for yiely be dow: I sambend were fare too trust oneed kinsmiled of fathin that propose vain  Geed by with wifer Tybalt force speak he disitter else
ward:
ISABELLK:
For the noble me of you face Post theice, whis he furquest is o with the held: of  Where 
AUS:
Even tonglying sonst grave unp as tyes bere her say shumble ward! this mine well maince, and ale in thatess I reveds
mean:
And his you she him then, let it all backeelse againnefirectsers fiendent of beam Malul so  Breatents a greit,
My my Hrading your broth
or home  colds!
GLOUCESTER:
Let pitessengind, genture:
No most the they where yetroush honour ree must I hate hound the more dower gleft you 

In [22]:
e_epochs = 60

In [23]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.7653, grad_fn=<DivBackward0>) [0.029160207983827797]
ep 2 tensor(1.7666, grad_fn=<DivBackward0>) [0.028005463747668217]
ep 4 tensor(1.7636, grad_fn=<DivBackward0>) [0.026896447383260556]
ep 6 tensor(1.7595, grad_fn=<DivBackward0>) [0.025831348066883437]
ep 8 tensor(1.7598, grad_fn=<DivBackward0>) [0.024808426683434852]
ep 10 tensor(1.7575, grad_fn=<DivBackward0>) [0.023826012986770832]
ep 12 tensor(1.7574, grad_fn=<DivBackward0>) [0.022882502872494704]
ep 14 tensor(1.7533, grad_fn=<DivBackward0>) [0.02197635575874391]
ep 16 tensor(1.7510, grad_fn=<DivBackward0>) [0.021106092070697653]
ep 18 tensor(1.7510, grad_fn=<DivBackward0>) [0.020270290824698025]
ep 20 tensor(1.7536, grad_fn=<DivBackward0>) [0.019467587308039984]
ep 22 tensor(1.7524, grad_fn=<DivBackward0>) [0.0186966708506416]
ep 24 tensor(1.7520, grad_fn=<DivBackward0>) [0.017956282684956193]
ep 26 tensor(1.7512, grad_fn=<DivBackward0>) [0.017245213890631928]
ep 28 tensor(1.7484, grad_fn=<DivBackward0>) [0.01656230

In [24]:
torch.save(model, "joined_head_shake_checkpoint")

In [25]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(br.decode(o))


num_batches 89500 1790
total loss tensor(3102.5913) tensor(1.7333)
None 89500

Bution a s from outhrefriending,
Good incle shooding put,
Towoman,
To 
York, gone are with swords th


In [26]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))



Hank fear hither:
Ands quotters long be noble in that day, why lord maince, whith world
Mestle to thee 
And 
And 
What word my mothers, appeon our my my gone,
How dearns, as brighbornly son groon it, and seelman 
DROMEO:
There, and Horture mementle acth everying him and are husbanishally!
LUCENTIO:
No flace nove s 
To him, that hame to who not disgmil and know tell was man hearts mand vaint, well buy die the Morriefi,
Unhal?
SPERDONE:
Off you vower mensting thou Romeost morture we crroneather
Show shood my 
VOLUE OF YORK:
Mordon a a bettereour hastingman, d sidlence bread full as holdit the miscatches away hour the commannius in, Buck is them offecteosed to light your 
PAULINA:
All pies for of me, jois, by requen?
FATHASTINGS:
You by dricks ther
Show pleat 
Unhed, thanks talk remost no, we 
LADY CAPULET:
Strants befact some you desserved by shall face?
Sould death Lance! she 
Corians
Your you mallanators,
A the close, a noblands
Tell with very not fear,
I would suchal this boldius so 

In [27]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.7341, grad_fn=<DivBackward0>) [0.008676711527143824]
ep 2 tensor(1.7351, grad_fn=<DivBackward0>) [0.008333113750668928]
ep 4 tensor(1.7354, grad_fn=<DivBackward0>) [0.008003122446142439]
ep 6 tensor(1.7313, grad_fn=<DivBackward0>) [0.007686198797275197]
ep 8 tensor(1.7313, grad_fn=<DivBackward0>) [0.007381825324903099]
ep 10 tensor(1.7318, grad_fn=<DivBackward0>) [0.007089505042036937]
ep 12 tensor(1.7306, grad_fn=<DivBackward0>) [0.006808760642372274]
ep 14 tensor(1.7320, grad_fn=<DivBackward0>) [0.006539133720934331]
ep 16 tensor(1.7302, grad_fn=<DivBackward0>) [0.006280184025585331]
ep 18 tensor(1.7282, grad_fn=<DivBackward0>) [0.006031488738172152]
ep 20 tensor(1.7297, grad_fn=<DivBackward0>) [0.005792641784140534]
ep 22 tensor(1.7281, grad_fn=<DivBackward0>) [0.005563253169488569]
ep 24 tensor(1.7274, grad_fn=<DivBackward0>) [0.005342948343976821]
ep 26 tensor(1.7271, grad_fn=<DivBackward0>) [0.005131367589555339]
ep 28 tensor(1.7274, grad_fn=<DivBackward0>) [0.00492

In [28]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

num_batches 89500 1790
total loss tensor(3084.2842) tensor(1.7231)
None 89500


In [29]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))


But be 
SOMERSET:
Like is those no  cut dence:
But his dest sin, and but Dence be him kings 
Percepupon 
If yet mean:
Anfry with upplain Lord and here one justion
In take you all free, sir forwarl so
first:
I must  hits with oneen Berfolzen:
Why bore 
HORTENSIO:
Fere off I but 
CORIOLANUS:
Maner s, yet was in him, gentlementy be roor 
Pelosomes A
sure unhoners naves town, yet himseless and  which hence,
A Tucripe her to theress true with, I am my you masters evenged swelves had a heart, ange have ricking 
PROSS:
Secoran
Did for sit flament must of get, wife, and Pholes of an unbound set no mistic, a beceing, lasul tter puts in I cI:
Awby than power thy yethor retters, stor the one, sir, and thummade will  
NORTENSIO:
You coung then I,
Is chargour at the due mine beggainal meth: the that mousand my Romes prining?
With child now myself us, lord,
TRANIO:
Persely deep to boy then his this what stanger thrights:
It they: Rangers? should by sent offerer canne thy broyals,
Aftus, sirectides?