### What if we increase the corpus

let's start reading more books

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
%run bookreader.py

In [3]:
alice = BookReader("Through the looking glass.txt", "a tangled.txt", "Hunting of the snark.txt", "alice.txt")
vocab_size = alice.vocab_size
vocab_size

41

### Get the batch with both x and y unseparated

In [4]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [5]:
train = torch.tensor(alice.data[0])

### Create an attention head

In [6]:
class Head(nn.Module):

    def __init__(self, c, head_size, content_length):
        super().__init__()
        self.key = nn.Linear(c, head_size, bias=False)
        self.query = nn.Linear(c, head_size, bias=False)
        self.value = nn.Linear(c, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [7]:
from collections import OrderedDict

class FeedForward(nn.Module):
    def __init__(self, fan_in, multiplier = 4):
        super().__init__()

        layers = OrderedDict([
            ("l_in", nn.Linear(fan_in, multiplier * fan_in)),
            ("relu", nn.ReLU()),
            ("l_out", nn.Linear(multiplier * fan_in, fan_in)),
        ])
        self.net = nn.Sequential(
            layers
        )

        initial = layers['l_in']
        nn.init.kaiming_normal_(initial.weight, nonlinearity="relu")
        layers['l_in'].weight.data = initial.weight.data * 3/5
        if initial.bias is not None:
            nn.init.constant_(initial.bias, 0)

        final = layers['l_out']
        layers['l_out'].weight.data = final.weight.data * .2
        if final.bias is not None:
            nn.init.constant_(final.bias, 0)

    def forward(self, x):
        return self.net(x)

In [8]:
import torch.optim as optim

In [9]:
class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size, embed_size, content_length):
        super().__init__()
        self.heads = nn.ModuleList([Head(embed_size, head_size, content_length) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, embed_size)
        self.dropout = nn.Dropout(.1)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [10]:
class AttentionBlock(nn.Module):
    def __init__(self, num_heads, head_size, embed_size, content_length, ff_mul):
        super().__init__()
        self.multihead = MultiHead(num_heads, head_size, embed_size, content_length)
        self.n1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_mul)
        self.n1 = nn.LayerNorm(embed_size)

    def forward(self, x):
        x_a = x.detach()
        x_b = x.detach()
        x = self.multihead(x)
        x = x + x_a
        x = self.n1(x)
        x = self.ff(x)
        x = x + x_b
        x = self.n1(x)

        return x

In [11]:
class FFMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, content_length, num_heads, head_size, multiplier=4):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        # self.mutli_attention = MultiHead(num_heads, head_size, embed_size, content_length)
        # self.lna = nn.LayerNorm(embed_size)
        # self.ff = FeedForward(embed_size, multiplier)
        # self.lnff = nn.LayerNorm(embed_size)
        self.atta = AttentionBlock(num_heads, head_size, embed_size, content_length, multiplier)
        self.attb = AttentionBlock(num_heads, head_size, embed_size, content_length, multiplier)
        self.decode = nn.Linear(embed_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        
        x = self.atta(x)
        x = self.attb(x)
        logits = self.decode(x)
        # print("decode out", x.shape)
        # print("targets", targets.shape)
        # return None, None

        if targets is None:
            loss = None
        else:
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits.view(B*T, -1), targets) #.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [12]:
def get_val_batch(data, batch_length=5, batch_size=5, i=0):
    # generate a small batch of data of inputs x and targets y
    if i == 0:
        ix = torch.randint(len(data) - batch_length, (batch_size,))
    else:
        ix = torch.arange(1, 5) + 1 + i

    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [13]:
import math
@torch.no_grad()
def split_loss(split):
    split_len = len(split)
    total_loss = 0
    batch_size = 50
    num_batches = math.floor(split_len / batch_size)
    print("num_batches", split_len, num_batches)
    
    model.eval()

    for i in range(num_batches):

        t_b = get_val_batch(split, context_length+1, batch_size, i*batch_size)
        
        x = t_b[:, 0: context_length]
        y = t_b[:, context_length: context_length+1]
        
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, batch_loss = model(x, y)
        
        total_loss = total_loss + batch_loss
    
    print("total loss", total_loss, total_loss / num_batches)


In [15]:
epochs = 60
training_runs = 800
batch_size = 96
context_length = 12
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 4
model = FFMultiHeadAttention(embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
28137  parameters
ep 0 tensor(2.3184, grad_fn=<DivBackward0>) [0.098]
ep 2 tensor(1.9948, grad_fn=<DivBackward0>) [0.0941192]
ep 4 tensor(1.9217, grad_fn=<DivBackward0>) [0.09039207968]
ep 6 tensor(1.8863, grad_fn=<DivBackward0>) [0.086812553324672]
ep 8 tensor(1.8645, grad_fn=<DivBackward0>) [0.08337477621301498]
ep 10 tensor(1.8534, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 12 tensor(1.8390, grad_fn=<DivBackward0>) [0.07690223892601039]
ep 14 tensor(1.8308, grad_fn=<DivBackward0>) [0.07385691026454037]
ep 16 tensor(1.8219, grad_fn=<DivBackward0>) [0.07093217661806457]
ep 18 tensor(1.8187, grad_fn=<DivBackward0>) [0.06812326242398921]
ep 20 tensor(1.8101, grad_fn=<DivBackward0>) [0.06542558123199924]
ep 22 tensor(1.8028, grad_fn=<DivBackward0>) [0.06283472821521206]
ep 24 tensor(1.7994, grad_fn=<DivBackward0>) [0.06034647297788966]
ep 26 tensor(1.7952, grad_fn=<DivBackward0>) [0.05795675264796523]
ep 28 tensor(1.7911, grad_fn=<DivBackward0>) [0.055661665243105805]
ep 30 tenso

In [17]:
dev = torch.tensor(alice.data[1])

In [18]:
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(alice.decode(o))

num_batches 53850 1077
total loss tensor(1831.9103) tensor(1.7009)
None 53850

who on
them which winh 
i d one is about  milight cover was ideven answer assuitely   do s sbit to
g


In [None]:
# allow mor loops without restarting
e_epochs = 40

In [19]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 1000).data[0].tolist()
    print(alice.decode(o))



nightly
wand go  and gavter
 
 it which? 
 you
 of good not  no bin in the two the thought welled
the
worden them capty be subject  n the caweanged air 
      they offution  and one naw you lioust that alicell of the tarm you can for make 
 irpenc it shall getting minum meet is?eace to felld they down  nogicill
 father dired a brobperus dinah way content readed bather plased in in
   e   that  and he dore it galed  and eying up 
1  a basch 
 bustry alice thought bown    j  lappine a make ebnded  pokes it will many
the qanswer 
alice and go the other gstut  is how to rear ttwice  only to tenct 

 with of man  hoped  and exgle too hand in
susalic herales above soat  you so maname to becony! 
do ilnted   of
she
joorewist 
and to try lest 
she mapter on a do said 
     the rimarled  asons lande bonet not a nest at is makes you seemed to the equenly washer wordertabelted   as  my excely? 
 i shouldn  t is blaid  andn land did  if once 
  bybout  with upony arral  gived
puzzled 
 she s law

In [20]:
e_epochs = 60

In [16]:
# allow mor loops without restarting
e_epochs = 10

In [21]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(1.6990, grad_fn=<DivBackward0>) [0.029160207983827797]
ep 2 tensor(1.6864, grad_fn=<DivBackward0>) [0.028005463747668217]
ep 4 tensor(1.6820, grad_fn=<DivBackward0>) [0.026896447383260556]
ep 6 tensor(1.6754, grad_fn=<DivBackward0>) [0.025831348066883437]
ep 8 tensor(1.6748, grad_fn=<DivBackward0>) [0.024808426683434852]
ep 10 tensor(1.6710, grad_fn=<DivBackward0>) [0.023826012986770832]
ep 12 tensor(1.6707, grad_fn=<DivBackward0>) [0.022882502872494704]
ep 14 tensor(1.6672, grad_fn=<DivBackward0>) [0.02197635575874391]
ep 16 tensor(1.6631, grad_fn=<DivBackward0>) [0.021106092070697653]
ep 18 tensor(1.6622, grad_fn=<DivBackward0>) [0.020270290824698025]
ep 20 tensor(1.6596, grad_fn=<DivBackward0>) [0.019467587308039984]
ep 22 tensor(1.6557, grad_fn=<DivBackward0>) [0.0186966708506416]
ep 24 tensor(1.6553, grad_fn=<DivBackward0>) [0.017956282684956193]
ep 26 tensor(1.6542, grad_fn=<DivBackward0>) [0.017245213890631928]
ep 28 tensor(1.6530, grad_fn=<DivBackward0>) [0.01656230

In [24]:
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(alice.decode(o))

num_batches 53850 1077
total loss tensor(1756.2078) tensor(1.6306)
None 53850

up   he sport  to was beave unical! i never while no came! one
aplessinning exvking  said to conden words assual deven look the poes to out
peoped the in were pust try eadight   the with extenberg becal suppre! she was you an if wilddeo uspeak day on  looks   the more was with a replied
    herd wit

defore offendents  and  you re to herself  the
set! 
 it is cready musts there
one a with sat to the best assumes pair  meaning  grows?  said t him  a reparake is to was little roose
one a twicerles  the
prisper speak done it 
more 
that it was large two  we
answer heage
      z saw arm    i trass a


In [23]:
torch.save(model, "all_alice_checkpoint")