### Missing something again

what's the difference between multi head and a bigger head?

again the attention is linear unit - if we just make it wider isn't that the same as having mulitple smaller heads?

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
%run bookreader.py

In [3]:
br = BookReader(False, r'[^\d+a-zA-Z \n?!:,]')
br.read("tiny_shakespeare.txt")
vocab_size = br.vocab_size
vocab_size

59

### Get the batch with both x and y unseparated

In [4]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [5]:
train = torch.tensor(br.data[0])

### Create an attention head

In [6]:
class Head(nn.Module):

    def __init__(self, c, head_size, content_length):
        super().__init__()
        self.key = nn.Linear(c, head_size, bias=False)
        self.query = nn.Linear(c, head_size, bias=False)
        self.value = nn.Linear(c, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [7]:
from collections import OrderedDict

class FeedForward(nn.Module):
    def __init__(self, fan_in, multiplier = 4):
        super().__init__()

        layers = OrderedDict([
            ("l_in", nn.Linear(fan_in, multiplier * fan_in)),
            ("relu", nn.ReLU()),
            ("l_out", nn.Linear(multiplier * fan_in, fan_in)),
        ])
        self.net = nn.Sequential(
            layers
        )

        initial = layers['l_in']
        nn.init.kaiming_normal_(initial.weight, nonlinearity="relu")
        layers['l_in'].weight.data = initial.weight.data * 3/5
        if initial.bias is not None:
            nn.init.constant_(initial.bias, 0)

        final = layers['l_out']
        layers['l_out'].weight.data = final.weight.data * .2
        if final.bias is not None:
            nn.init.constant_(final.bias, 0)

    def forward(self, x):
        return self.net(x)

In [8]:
import torch.optim as optim

In [9]:
# class MultiHead(nn.Module):

#     def __init__(self, num_heads, head_size, embed_size, content_length):
#         super().__init__()
#         self.heads = nn.ModuleList([Head(embed_size, head_size, content_length) for _ in range(num_heads)])

#     def forward(self, x):
#         out = torch.cat( [head(x) for head in self.heads], dim = -1 )
#         return out

In [10]:
class MultiHeadJoin(nn.Module):
    
    def __init__(self, c, num_heads, head_size, content_length):
        super().__init__()
        self.query = nn.Linear(c, head_size * num_heads, bias=False)
        self.key = nn.Linear(c, head_size * num_heads, bias=False)
        self.value = nn.Linear(c, head_size * num_heads, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [11]:
class FFMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, content_length, num_heads, head_size, multiplier=4):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.mutli_attention = MultiHeadJoin(embed_size, num_heads, head_size, content_length)
        self.lna = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, multiplier)
        self.lnff = nn.LayerNorm(embed_size)
        self.decode = nn.Linear(embed_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.mutli_attention(x)
        # print("multi ball out", x.shape)
        x = self.lna(x)
        x = self.ff(x)
        x = self.lnff(x)
        # print("feed forward out", x.shape)
        logits = self.decode(x)
        # print("decode out", x.shape)
        # print("targets", targets.shape)
        # return None, None

        if targets is None:
            loss = None
        else:
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits.view(B*T, -1), targets) #.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [12]:
def get_val_batch(data, batch_length=5, batch_size=5, i=0):
    # generate a small batch of data of inputs x and targets y
    if i == 0:
        ix = torch.randint(len(data) - batch_length, (batch_size,))
    else:
        ix = torch.arange(1, 5) + 1 + i

    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [13]:
import math
@torch.no_grad()
def split_loss(split):
    split_len = len(split)
    total_loss = 0
    batch_size = 50
    num_batches = math.floor(split_len / batch_size)
    print("num_batches", split_len, num_batches)
    
    model.eval()

    for i in range(num_batches):

        t_b = get_val_batch(split, context_length+1, batch_size, i*batch_size)
        
        x = t_b[:, 0: context_length]
        y = t_b[:, context_length: context_length+1]
        
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, batch_loss = model(x, y)
        
        total_loss = total_loss + batch_loss
    
    print("total loss", total_loss, total_loss / num_batches)


In [14]:
epochs = 60
training_runs = 800
batch_size = 96
context_length = 24
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 8
model = FFMultiHeadAttention(embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 5 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
24475  parameters
ep 0 tensor(2.6336, grad_fn=<DivBackward0>) [0.098]
ep 5 tensor(1.8936, grad_fn=<DivBackward0>) [0.0885842380864]
ep 10 tensor(1.8267, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 15 tensor(1.7958, grad_fn=<DivBackward0>) [0.07237977205924956]
ep 20 tensor(1.7743, grad_fn=<DivBackward0>) [0.06542558123199924]
ep 25 tensor(1.7640, grad_fn=<DivBackward0>) [0.059139543518331866]
ep 30 tensor(1.7515, grad_fn=<DivBackward0>) [0.053457463299478813]
ep 35 tensor(1.7420, grad_fn=<DivBackward0>) [0.048321312820571644]
ep 40 tensor(1.7336, grad_fn=<DivBackward0>) [0.04367863958719317]
ep 45 tensor(1.7270, grad_fn=<DivBackward0>) [0.03948203069879567]
ep 50 tensor(1.7207, grad_fn=<DivBackward0>) [0.03568862864853744]
ep 55 tensor(1.7152, grad_fn=<DivBackward0>) [0.03225969364468526]


In [15]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(br.decode(o))

num_batches 87625 1752
total loss tensor(2999.8904) tensor(1.7123)
None 87625

Then tenought him man MARGARET:
Is life, sid in out lords  you morthet your goody she to my lenex 
A


In [16]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 400).data[0].tolist()
    print(br.decode(o))


arewn a no time
e reture defore himselve us the wellse a should be he sand prisong to crea:
Nay, and Hery lie dost!
Your here are he, lord:
I well, where our vault! he cold cause, dongues the atter d dear am tis  d reparing wall out a us minder
thy right some,  lof will
You houth her be the s to 
Lecondon
In 
ANGELO:
Not come a is it  FatY ANNE:
Priture mall think stiment of ight Luke
see
Is lewin

And you place mady  then eye,
Majesty, lender our you clubry mercius, I will this stmon of First: sound on to that arre at winst 
KING RICHARD IRGILIA:
The dstain, whroor you 
The and well 
ISABEKETBRANDA:
Lo!
LEONTG:
Or tis to hapart my of a nexternce: piecius no this 
SEBASLY:
Nothine s nother 
CAPTISTA:
Should fries benead by timpardn think your mispect
And be no 
Ohan good like 
The clease, an


In [17]:
e_epochs = 60

In [18]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.7135, grad_fn=<DivBackward0>) [0.029160207983827797]
ep 2 tensor(1.7106, grad_fn=<DivBackward0>) [0.028005463747668217]
ep 4 tensor(1.7107, grad_fn=<DivBackward0>) [0.026896447383260556]
ep 6 tensor(1.7079, grad_fn=<DivBackward0>) [0.025831348066883437]
ep 8 tensor(1.7035, grad_fn=<DivBackward0>) [0.024808426683434852]
ep 10 tensor(1.7046, grad_fn=<DivBackward0>) [0.023826012986770832]
ep 12 tensor(1.7025, grad_fn=<DivBackward0>) [0.022882502872494704]
ep 14 tensor(1.7007, grad_fn=<DivBackward0>) [0.02197635575874391]
ep 16 tensor(1.7025, grad_fn=<DivBackward0>) [0.021106092070697653]
ep 18 tensor(1.6976, grad_fn=<DivBackward0>) [0.020270290824698025]
ep 20 tensor(1.6965, grad_fn=<DivBackward0>) [0.019467587308039984]
ep 22 tensor(1.6975, grad_fn=<DivBackward0>) [0.0186966708506416]
ep 24 tensor(1.6948, grad_fn=<DivBackward0>) [0.017956282684956193]
ep 26 tensor(1.6952, grad_fn=<DivBackward0>) [0.017245213890631928]
ep 28 tensor(1.6916, grad_fn=<DivBackward0>) [0.01656230

In [19]:
torch.save(model, "joined_head_bigger_shake_checkpoint")

In [21]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(br.decode(o))


num_batches 87625 1752
total loss tensor(2945.1343) tensor(1.6810)
None 87625

Fropess they name bright Horth glow,
Swomannot your for set, I lame, sawfmustifter Esw are mean befo


In [20]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 400).data[0].tolist()
    print(br.decode(o))


Romes more hot with slave 
This pray had, nor vow know: yet genry sword Look! AkF:
Give your a quire thy by out 
Torseds truil 
POMPEY:
Thy greatch
EDWARD:
May
Did ward, that eigue or see with manishpeses  fongery being kich his endemit
nee, betw is our like, loving gall keep hurse are weign,
Ath let right would you 
LADY VI:
March  when the honour 
For You shall not and vill tongue genour blive!



In [22]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.6822, grad_fn=<DivBackward0>) [0.008676711527143824]
ep 2 tensor(1.6810, grad_fn=<DivBackward0>) [0.008333113750668928]
ep 4 tensor(1.6811, grad_fn=<DivBackward0>) [0.008003122446142439]
ep 6 tensor(1.6820, grad_fn=<DivBackward0>) [0.007686198797275197]
ep 8 tensor(1.6807, grad_fn=<DivBackward0>) [0.007381825324903099]
ep 10 tensor(1.6791, grad_fn=<DivBackward0>) [0.007089505042036937]
ep 12 tensor(1.6764, grad_fn=<DivBackward0>) [0.006808760642372274]
ep 14 tensor(1.6790, grad_fn=<DivBackward0>) [0.006539133720934331]
ep 16 tensor(1.6788, grad_fn=<DivBackward0>) [0.006280184025585331]
ep 18 tensor(1.6781, grad_fn=<DivBackward0>) [0.006031488738172152]
ep 20 tensor(1.6772, grad_fn=<DivBackward0>) [0.005792641784140534]
ep 22 tensor(1.6779, grad_fn=<DivBackward0>) [0.005563253169488569]
ep 24 tensor(1.6768, grad_fn=<DivBackward0>) [0.005342948343976821]
ep 26 tensor(1.6777, grad_fn=<DivBackward0>) [0.005131367589555339]
ep 28 tensor(1.6766, grad_fn=<DivBackward0>) [0.00492

In [23]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

num_batches 87625 1752
total loss tensor(2926.6079) tensor(1.6704)
None 87625


In [24]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))


Take wild ll men 
GBRUCHESS OF YORK:
Carvoul sir Naphg leavent and their 
I countris, an 
ISABELLIO:
Bey to a uncharget! coves
To fair that if smelt, for 
And and me:
I ll the most
If of was  thoused or othy your we fathe myself the cred
A fraring 
But, one joy till  he than and prays to her 
KING EDWARD IV:
And othen being elsentry with sender s be? a and young you guarl, and offection 
Call 
HORTENSIO:
He woman:
Ay, pyraithes
The bestself thee me in  What wertings if?
BRUTUS:
On wideed:
Who premieu busistard:
Wherefooted Mitus lights and prience?
Then, st o my me
coom die before s say down:
By I will well us place forcemember
Where have reld, the all the gerfore you the body s robles,
And here:
Prove forge condence coung be of draws casing  he enour I neemb such his shangelo and of thou faves in die that my line save I owe host in thite news her Grinks, vidd by one perate doieute 
Ort suppast 
CARLE:
A dreak, seeks and brother, to and play lord he anow that ague now, then how oice t

In [25]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.6700, grad_fn=<DivBackward0>) [0.0025817827831345905]
ep 2 tensor(1.6721, grad_fn=<DivBackward0>) [0.0024795441849224603]
ep 4 tensor(1.6704, grad_fn=<DivBackward0>) [0.0023813542351995304]
ep 6 tensor(1.6720, grad_fn=<DivBackward0>) [0.002287052607485629]
ep 8 tensor(1.6694, grad_fn=<DivBackward0>) [0.002196485324229198]
ep 10 tensor(1.6723, grad_fn=<DivBackward0>) [0.0021095045053897217]
ep 12 tensor(1.6700, grad_fn=<DivBackward0>) [0.0020259681269762884]
ep 14 tensor(1.6692, grad_fn=<DivBackward0>) [0.0019457397891480272]
ep 16 tensor(1.6680, grad_fn=<DivBackward0>) [0.0018686884934977653]
ep 18 tensor(1.6688, grad_fn=<DivBackward0>) [0.0017946884291552537]
ep 20 tensor(1.6699, grad_fn=<DivBackward0>) [0.0017236187673607055]
ep 22 tensor(1.6681, grad_fn=<DivBackward0>) [0.0016553634641732215]
ep 24 tensor(1.6669, grad_fn=<DivBackward0>) [0.0015898110709919619]
ep 26 tensor(1.6691, grad_fn=<DivBackward0>) [0.0015268545525806802]
ep 28 tensor(1.6676, grad_fn=<DivBackward

In [26]:
print(split_loss(dev), len(dev))

num_batches 87625 1752
total loss tensor(2920.3748) tensor(1.6669)
None 87625


In [27]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))


KATHARINA:
Liends greation quiticals!
So the to and you  heaven his love afford?
DUKE VINCENTIO:
Dubject is servedy eyet
cank is he so? one commed what desem,
With fater for the train hunstina
not remphas, which one bear art my be against childlebly care,
Unknity, with moror fairit
cive e, matting the but do citil pregarly were a to mock of your fly s
The lord:
And nother he prays a on not thy bid and!
Pready 
The kneward Scullaight, to send trivice,
That noth thou could me did such with thank, those thus
yee best, war, for theecond let where rife:
The as wit
Cleep: burn my have this of  wanto as s do:
The daugh as look d sare,
Whose majester shall Dukely 
That entresses I carew is d Come, my mind even to shall hape speak end in shriquieters to more 
Now will Jeach d bester you to  more an you well:
Then only 
GREMIl have then my cronaren time to but am desing should,
And tental im, so crown
Fall of tisters, in his is must statorn genry mother may inder abeenst:
Say,
Or trow with what