### What if we increase the corpus

let's start reading more books

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
%run bookreader.py

In [5]:
alice = BookReader("Through the looking glass.txt", "a tangled.txt", "Hunting of the snark.txt", "alice.txt")
vocab_size = alice.vocab_size
vocab_size

41

### Get the batch with both x and y unseparated

In [6]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [8]:
train = torch.tensor(alice.data[0])

### Create an attention head

In [9]:
class Head(nn.Module):

    def __init__(self, c, head_size, content_length):
        super().__init__()
        self.key = nn.Linear(c, head_size, bias=False)
        self.query = nn.Linear(c, head_size, bias=False)
        self.value = nn.Linear(c, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(content_length, content_length)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   
        q = self.query(x)
        
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B,T,hs)
        out = wei @ v
        return out

In [10]:
from collections import OrderedDict

class FeedForward(nn.Module):
    def __init__(self, fan_in, multiplier = 4):
        super().__init__()

        layers = OrderedDict([
            ("l_in", nn.Linear(fan_in, multiplier * fan_in)),
            ("relu", nn.ReLU()),
            ("l_out", nn.Linear(multiplier * fan_in, fan_in)),
        ])
        self.net = nn.Sequential(
            layers
        )

        initial = layers['l_in']
        nn.init.kaiming_normal_(initial.weight, nonlinearity="relu")
        layers['l_in'].weight.data = initial.weight.data * 3/5
        if initial.bias is not None:
            nn.init.constant_(initial.bias, 0)

        final = layers['l_out']
        layers['l_out'].weight.data = final.weight.data * .2
        if final.bias is not None:
            nn.init.constant_(final.bias, 0)

    def forward(self, x):
        return self.net(x)

In [11]:
class FFAttention(nn.Module):

    def __init__(self, embed_size, head_size, content_length):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.attention = Head(embed_size, head_size, content_length)
        self.ff = FeedForward(head_size)
        self.decode = nn.Linear(head_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.attention(x)
        
        x = self.ff(x)
        
        logits = self.decode(x)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B*T, -1), targets.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [12]:
import torch.optim as optim

In [13]:
class MultiHead(nn.Module):

    def __init__(self, num_heads, head_size, embed_size, content_length):
        super().__init__()
        self.heads = nn.ModuleList([Head(embed_size, head_size, content_length) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat( [head(x) for head in self.heads], dim = -1 )
        return out
    

In [14]:
class FFMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, content_length, num_heads, head_size, multiplier=4):
        super().__init__()
        
        self.vocab_embed = nn.Embedding(vocab_size, embed_size)
        self.positional_embed = nn.Embedding(content_length, embed_size)
        self.mutli_attention = MultiHead(num_heads, head_size, embed_size, content_length)
        self.lna = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, multiplier)
        self.lnff = nn.LayerNorm(embed_size)
        self.decode = nn.Linear(embed_size, vocab_size)
        self.content_length = content_length

    def forward(self, idx, targets=None):
        #idx B,T
        B, T = idx.shape

        idx_e = self.vocab_embed(idx)
        # note tr is always the same - so the learning here is information passed back to the positional_embed from loss
        tr = torch.arange(T)
        pos_e = self.positional_embed(tr)

        x = idx_e + pos_e
        x = self.mutli_attention(x)
        # print("multi ball out", x.shape)
        x = self.lna(x)
        x = self.ff(x)
        x = self.lnff(x)
        # print("feed forward out", x.shape)
        logits = self.decode(x)
        # print("decode out", x.shape)
        # print("targets", targets.shape)
        # return None, None

        if targets is None:
            loss = None
        else:
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits.view(B*T, -1), targets) #.resize(B*T)) # loss function

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.content_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [15]:
def get_val_batch(data, batch_length=5, batch_size=5, i=0):
    # generate a small batch of data of inputs x and targets y
    if i == 0:
        ix = torch.randint(len(data) - batch_length, (batch_size,))
    else:
        ix = torch.arange(1, 5) + 1 + i

    b = torch.stack([data[i:i+batch_length] for i in ix])
    return b

In [17]:
import math
@torch.no_grad()
def split_loss(split):
    split_len = len(split)
    total_loss = 0
    batch_size = 50
    num_batches = math.floor(split_len / batch_size)
    print("num_batches", split_len, num_batches)
    
    model.eval()

    for i in range(num_batches):

        t_b = get_val_batch(split, context_length+1, batch_size, i*batch_size)
        
        x = t_b[:, 0: context_length]
        y = t_b[:, context_length: context_length+1]
        
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, batch_loss = model(x, y)
        
        total_loss = total_loss + batch_loss
    
    print("total loss", total_loss, total_loss / num_batches)


### some examples from the past

of the great artist we are creating

In [73]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))


 said  

at othe never exane yokle as lou mout ? 
 her appp aged 


trabs and at  che you of the a c


In [78]:
dev = torch.tensor(alice.data[1])
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))

num_batches 11075 221
total loss tensor(426.6404) tensor(1.9305)
None 11075

so
was alice bot by ran to lact the saout be he the don the knen  wen sail 


 you  alice and she cr


In [80]:
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(names.decode(o))

num_batches 11075 221
total loss tensor(371.4741) tensor(1.6809)
None 11075

you
nice    and there much you d i bles of eyes  the
pig   belcome    nother look       
set  if you


In [18]:
epochs = 60
training_runs = 800
batch_size = 96
context_length = 12
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 8
model = FFMultiHeadAttention(embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 5 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
22921  parameters
ep 0 tensor(2.3885, grad_fn=<DivBackward0>) [0.098]
ep 5 tensor(1.7651, grad_fn=<DivBackward0>) [0.0885842380864]
ep 10 tensor(1.7127, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 15 tensor(1.6870, grad_fn=<DivBackward0>) [0.07237977205924956]
ep 20 tensor(1.6727, grad_fn=<DivBackward0>) [0.06542558123199924]
ep 25 tensor(1.6596, grad_fn=<DivBackward0>) [0.059139543518331866]
ep 30 tensor(1.6479, grad_fn=<DivBackward0>) [0.053457463299478813]
ep 35 tensor(1.6444, grad_fn=<DivBackward0>) [0.048321312820571644]
ep 40 tensor(1.6353, grad_fn=<DivBackward0>) [0.04367863958719317]
ep 45 tensor(1.6316, grad_fn=<DivBackward0>) [0.03948203069879567]
ep 50 tensor(1.6282, grad_fn=<DivBackward0>) [0.03568862864853744]
ep 55 tensor(1.6226, grad_fn=<DivBackward0>) [0.03225969364468526]


In [20]:
print(split_loss(dev), len(dev))

idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(1):
    o = model.generate(idx, 100).data[0].tolist()
    print(alice.decode(o))

num_batches 53750 1075
total loss tensor(1739.9026) tensor(1.6185)
None 53750

glover nevered at form  electronic word all certaining props under myes  know!  this look soundation


## oops

I left in the gutenberg end bits, alice doesn't know about electroic word...

In [22]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(alice.decode(o))


 do of could shapter hatterrupe centeaphs the blaking a me as
overnot can 
     
 you appensequare permistank all infirst  are as  x  been? 
when incheet  way in the being 

 youd 
 found such! said a said never
 if it with the paporty!  are crossen this very to make the
ages three resting was much 

not the pabbecard you? 

 not
5 e  to 3   ghosen sollowl rading  

 thour say   displace not alice the protectly  or
the had project gusion!  and 
 litellowl she the moment itselfided by the curage the hand as holder evoicertain  brotectuods quiet!  they he lot to resker  betting is to be of and 9  
