### What if we increase the corpus

let's start reading more books

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [25]:
%run bookreader.py

In [29]:
br = BookReader(False, r'[^\d+a-zA-Z \n?!:,]')
br.read("tiny_shakespeare.txt")
vocab_size = br.vocab_size
vocab_size

p [^\d+a-zA-Z \n?!:,]


59

### Get the batch with both x and y unseparated

In [4]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [31]:
train = torch.tensor(br.data[0])

### Create an attention head

In [6]:
class Head(nn.Module):

    def __init__(self, c, head_size, content_length):
        super().__init__()
        self.key = nn.Linear(c, head_size, bias=False)
        self.query = nn.Linear(c, head_size, bias=False)
        self.value = nn.Linear(c, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [7]:
from collections import OrderedDict

class FeedForward(nn.Module):
    def __init__(self, fan_in, multiplier = 4):
        super().__init__()

        layers = OrderedDict([
            ("l_in", nn.Linear(fan_in, multiplier * fan_in)),
            ("relu", nn.ReLU()),
            ("l_out", nn.Linear(multiplier * fan_in, fan_in)),
        ])
        self.net = nn.Sequential(
            layers
        )

        initial = layers['l_in']
        nn.init.kaiming_normal_(initial.weight, nonlinearity="relu")
        layers['l_in'].weight.data = initial.weight.data * 3/5
        if initial.bias is not None:
            nn.init.constant_(initial.bias, 0)

        final = layers['l_out']
        layers['l_out'].weight.data = final.weight.data * .2
        if final.bias is not None:
            nn.init.constant_(final.bias, 0)

    def forward(self, x):
        return self.net(x)

In [8]:
class FFAttention(nn.Module):

    def __init__(self, embed_size, head_size, content_length):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.attention = Head(embed_size, head_size, content_length)
        self.ff = FeedForward(head_size)
        self.decode = nn.Linear(head_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.attention(x)
        
        x = self.ff(x)
        
        logits = self.decode(x)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B*T, -1), targets.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [9]:
import torch.optim as optim

In [15]:
class MultiHead(nn.Module):

    def __init__(self, num_heads, head_size, embed_size, content_length):
        super().__init__()
        self.heads = nn.ModuleList([Head(embed_size, head_size, content_length) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat( [head(x) for head in self.heads], dim = -1 )
        return out
    

In [11]:
class FFMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, content_length, num_heads, head_size, multiplier=4):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.mutli_attention = MultiHead(num_heads, head_size, embed_size, content_length)
        self.lna = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, multiplier)
        self.lnff = nn.LayerNorm(embed_size)
        self.decode = nn.Linear(embed_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.mutli_attention(x)
        # print("multi ball out", x.shape)
        x = self.lna(x)
        x = self.ff(x)
        x = self.lnff(x)
        # print("feed forward out", x.shape)
        logits = self.decode(x)
        # print("decode out", x.shape)
        # print("targets", targets.shape)
        # return None, None

        if targets is None:
            loss = None
        else:
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits.view(B*T, -1), targets) #.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [12]:
def get_val_batch(data, batch_length=5, batch_size=5, i=0):
    # generate a small batch of data of inputs x and targets y
    if i == 0:
        ix = torch.randint(len(data) - batch_length, (batch_size,))
    else:
        ix = torch.arange(1, 5) + 1 + i

    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [13]:
import math
@torch.no_grad()
def split_loss(split):
    split_len = len(split)
    total_loss = 0
    batch_size = 50
    num_batches = math.floor(split_len / batch_size)
    print("num_batches", split_len, num_batches)
    
    model.eval()

    for i in range(num_batches):

        t_b = get_val_batch(split, context_length+1, batch_size, i*batch_size)
        
        x = t_b[:, 0: context_length]
        y = t_b[:, context_length: context_length+1]
        
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, batch_loss = model(x, y)
        
        total_loss = total_loss + batch_loss
    
    print("total loss", total_loss, total_loss / num_batches)


### some examples from the past

of the great artist we are creating

In [73]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))


 said  

at othe never exane yokle as lou mout ? 
 her appp aged 


trabs and at  che you of the a c


In [78]:
dev = torch.tensor(alice.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))

num_batches 11075 221
total loss tensor(426.6404) tensor(1.9305)
None 11075

so
was alice bot by ran to lact the saout be he the don the knen  wen sail 


 you  alice and she cr


In [80]:
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))

num_batches 11075 221
total loss tensor(371.4741) tensor(1.6809)
None 11075

you
nice    and there much you d i bles of eyes  the
pig   belcome    nother look       
set  if you


In [30]:
epochs = 120
training_runs = 800
batch_size = 96
context_length = 12
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 8
model = FFMultiHeadAttention(embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 5 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
24091  parameters
ep 0 tensor(2.4101, grad_fn=<DivBackward0>) [0.098]
ep 5 tensor(1.8489, grad_fn=<DivBackward0>) [0.0885842380864]
ep 10 tensor(1.7938, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 15 tensor(1.7672, grad_fn=<DivBackward0>) [0.07237977205924956]
ep 20 tensor(1.7509, grad_fn=<DivBackward0>) [0.06542558123199924]
ep 25 tensor(1.7381, grad_fn=<DivBackward0>) [0.059139543518331866]
ep 30 tensor(1.7276, grad_fn=<DivBackward0>) [0.053457463299478813]
ep 35 tensor(1.7215, grad_fn=<DivBackward0>) [0.048321312820571644]
ep 40 tensor(1.7164, grad_fn=<DivBackward0>) [0.04367863958719317]
ep 45 tensor(1.7115, grad_fn=<DivBackward0>) [0.03948203069879567]
ep 50 tensor(1.7047, grad_fn=<DivBackward0>) [0.03568862864853744]
ep 55 tensor(1.7024, grad_fn=<DivBackward0>) [0.03225969364468526]
ep 60 tensor(1.6998, grad_fn=<DivBackward0>) [0.029160207983827797]
ep 65 tensor(1.6977, grad_fn=<DivBackward0>) [0.026358518435595345]
ep 70 tensor(1.6921, grad_fn=<DivBackward0>) [0.023826012

In [17]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(alice.decode(o))

num_batches 93300 1866
total loss tensor(3133.3733) tensor(1.6792)
None 93300

madamher must and your vick  
then ttale 
thou st  but weshome s prinute reglingbroke mard time  pro


In [18]:
torch.save(model, "all_shake_checkpoint")

In [32]:
e_epochs = 20

In [33]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.0056, grad_fn=<DivBackward0>) [0.008676711527143824]
ep 2 tensor(2.3823, grad_fn=<DivBackward0>) [0.008333113750668928]
ep 4 tensor(2.2501, grad_fn=<DivBackward0>) [0.008003122446142439]
ep 6 tensor(2.1740, grad_fn=<DivBackward0>) [0.007686198797275197]
ep 8 tensor(2.1291, grad_fn=<DivBackward0>) [0.007381825324903099]
ep 10 tensor(2.0973, grad_fn=<DivBackward0>) [0.007089505042036937]
ep 12 tensor(2.0723, grad_fn=<DivBackward0>) [0.006808760642372274]
ep 14 tensor(2.0516, grad_fn=<DivBackward0>) [0.006539133720934331]
ep 16 tensor(2.0372, grad_fn=<DivBackward0>) [0.006280184025585331]
ep 18 tensor(2.0247, grad_fn=<DivBackward0>) [0.006031488738172152]


In [36]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1090).data[0].tolist()
    print(br.decode(o))


mod the not it dee!
Harck ene, pret, farmith thine or 
you athey,
VINGS:
To 
Rer
VING sund my
Helf 
Or 
SICIUS:
The noful be thern
Nibent wibly colave have sire wommand licrfle mere, as
Yout s inteavet is ster wiper quest has that wine me poan a milly blon this but hapirsidin yaglion kuZans well to sold suct  my frossonarstile sive wrer s thouckee: risee:
Thands the 
Faing it beent ess,
Of agive dead hist wark dof goded you  of thy 
be 
Anre
And will loved show linge hathat that by siter sow be Sitilly sas, you had seach I your beam eat gre 
Be tade ecied thy onot the let: thee doft, I thesonsurghter will
God I him lot thouldet deatitatles on have pands a me hist a my be flancelose 
Mwordes retawed all 
And wasuedts for far lo! I at wick
Tis all we fice 
And thee with for a now dend frok no prok and wormined jrotent,
That will,
RUCHIONDET:
I t?
BASTRELETH:
Of whis gofus!
ind magght he for and, me tanter an of Edrodse,
That mace dgir succh tatter is thuntlemate, anarrds
The pring dose 

In [37]:
e_epochs = 30

for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(2.0127, grad_fn=<DivBackward0>) [0.005792641784140534]
ep 2 tensor(2.0019, grad_fn=<DivBackward0>) [0.005563253169488569]
ep 4 tensor(1.9949, grad_fn=<DivBackward0>) [0.005342948343976821]
ep 6 tensor(1.9864, grad_fn=<DivBackward0>) [0.005131367589555339]
ep 8 tensor(1.9848, grad_fn=<DivBackward0>) [0.004928165433008947]
ep 10 tensor(1.9738, grad_fn=<DivBackward0>) [0.004733010081861793]
ep 12 tensor(1.9680, grad_fn=<DivBackward0>) [0.004545582882620065]
ep 14 tensor(1.9641, grad_fn=<DivBackward0>) [0.00436557780046831]
ep 16 tensor(1.9580, grad_fn=<DivBackward0>) [0.004192700919569765]
ep 18 tensor(1.9578, grad_fn=<DivBackward0>) [0.004026669963154802]
ep 20 tensor(1.9517, grad_fn=<DivBackward0>) [0.003867213832613871]
ep 22 tensor(1.9483, grad_fn=<DivBackward0>) [0.0037140721648423617]
ep 24 tensor(1.9467, grad_fn=<DivBackward0>) [0.003566994907114604]
ep 26 tensor(1.9417, grad_fn=<DivBackward0>) [0.0034257419087928656]
ep 28 tensor(1.9394, grad_fn=<DivBackward0>) [0.0032

In [43]:
e_epochs = 30

for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.9374, grad_fn=<DivBackward0>) [0.003159795261048163]
ep 2 tensor(1.9356, grad_fn=<DivBackward0>) [0.0030346673687106558]
ep 4 tensor(1.9344, grad_fn=<DivBackward0>) [0.0029144945409097134]
ep 6 tensor(1.9306, grad_fn=<DivBackward0>) [0.0027990805570896884]
ep 8 tensor(1.9295, grad_fn=<DivBackward0>) [0.0026882369670289366]
ep 10 tensor(1.9286, grad_fn=<DivBackward0>) [0.0025817827831345905]
ep 12 tensor(1.9253, grad_fn=<DivBackward0>) [0.0024795441849224603]
ep 14 tensor(1.9218, grad_fn=<DivBackward0>) [0.0023813542351995304]
ep 16 tensor(1.9195, grad_fn=<DivBackward0>) [0.002287052607485629]
ep 18 tensor(1.9199, grad_fn=<DivBackward0>) [0.002196485324229198]
ep 20 tensor(1.9155, grad_fn=<DivBackward0>) [0.0021095045053897217]
ep 22 tensor(1.9151, grad_fn=<DivBackward0>) [0.0020259681269762884]
ep 24 tensor(1.9161, grad_fn=<DivBackward0>) [0.0019457397891480272]
ep 26 tensor(1.9158, grad_fn=<DivBackward0>) [0.0018686884934977653]
ep 28 tensor(1.9146, grad_fn=<DivBackward0

In [44]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(br.decode(o))

num_batches 85675 1713
total loss tensor(3273.4973) tensor(1.9110)
None 85675

Seady hoit with bee us 
GLYIA:
The your suck it it he the from 
NORTHAS:
I deatchagreak his for husm


In [46]:
e_epochs = 30

for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.9123, grad_fn=<DivBackward0>) [0.0017236187673607055]
ep 2 tensor(1.9081, grad_fn=<DivBackward0>) [0.0016553634641732215]
ep 4 tensor(1.9130, grad_fn=<DivBackward0>) [0.0015898110709919619]
ep 6 tensor(1.9086, grad_fn=<DivBackward0>) [0.0015268545525806802]
ep 8 tensor(1.9089, grad_fn=<DivBackward0>) [0.0014663911122984852]
ep 10 tensor(1.9066, grad_fn=<DivBackward0>) [0.0014083220242514652]
ep 12 tensor(1.9080, grad_fn=<DivBackward0>) [0.0013525524720911072]
ep 14 tensor(1.9082, grad_fn=<DivBackward0>) [0.0012989913941962993]
ep 16 tensor(1.9033, grad_fn=<DivBackward0>) [0.0012475513349861256]
ep 18 tensor(1.9043, grad_fn=<DivBackward0>) [0.001198148302120675]
ep 20 tensor(1.9063, grad_fn=<DivBackward0>) [0.0011507016293566962]
ep 22 tensor(1.9032, grad_fn=<DivBackward0>) [0.0011051338448341708]
ep 24 tensor(1.9029, grad_fn=<DivBackward0>) [0.0010613705445787376]
ep 26 tensor(1.9022, grad_fn=<DivBackward0>) [0.0010193402710134195]
ep 28 tensor(1.9007, grad_fn=<DivBackwar

In [None]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))

In [47]:
dev = torch.tensor(br.data[1])
print(split_loss(dev), len(dev))

num_batches 85675 1713
total loss tensor(3257.0881) tensor(1.9014)
None 85675


In [45]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(br.decode(o))


 Yorterve 
QUEEN LOUCESTER:
Sech and fip BOMEO:
Onds that  fords:
misere oute pt s your well twaypose at that cabuts noweend: you for do con he of you hath sir, ifort brower news fier my cont her d gue deacks chougneshis you make 
KING RICHARD IICIOLAND:
My, Lech Ithenge let Ares, all of thy to aboortence 
But oplearfe as pooth hith I hear queent corrahe in like hear hand tidede him love trand, that RICHARD III:
My bef qual evesse havitentest book that the our conged, which his sers the can he bhy or him yeard,
I thre coors:
To mer than ariemerss thanger this wheng which it be 
BRUTHASTINCE Youghs sect cary dona a fathinted to disphvent a not neon, be antio  do strince vas hand agodst byied ungs,
Kind I tidlk in ll 
And meand is and it thow me preath it thereadew with boy: I in nffe king the rauing
Aver, anurage with not, theread to he not hy,
Clanalt, drair fellier thes my fain rup at promeir warme may cro holooss:
For As tesestme, ing heads I woom to therch ce grace,  morth osom  wi