In [2]:
# this is where I'll build a transformer based neural network for producing toy story like text
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class SelfAttention(nn.Module):
    """
    Masked Self Attention
    """
    def __init__(self, input_dim, head_dim, max_seq_len):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, head_dim)
        self.key = nn.Linear(input_dim, head_dim)
        self.value = nn.Linear(input_dim, head_dim)

        self.register_buffer('mask', torch.tril(torch.ones(max_seq_len, max_seq_len)))
    
    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # (batch_size, seq_len, seq_len)
        similarity_scores = q @ k.transpose(1, 2)
        # mask out future tokens to float('-inf')
        seq_len = similarity_scores.shape[1]
        similarity_scores = similarity_scores.masked_fill(self.mask[:seq_len, :seq_len] == 0, float('-inf'))
        attention_weights = F.softmax(similarity_scores, dim=1)
        output = attention_weights @ v # (batch_size, seq_len, value_dim)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, head_dim, max_seq_len, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.self_attention_heads = nn.ModuleList(
            [SelfAttention(input_dim, head_dim, max_seq_len) for _ in range(num_heads)]
        )
    
    def forward(self, x):
        # (batch_size, seq_len, query_dim)
        # multi head self attention
        output = torch.cat([head(x) for head in self.self_attention_heads], dim=-1)
        return output

class TransformerBlock(nn.Module):
    def __init__(self, input_dim, head_dim, num_heads, ff_dim):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(input_dim, head_dim, 256, num_heads)
        self.norm1 = nn.LayerNorm(input_dim)
        self.ff = nn.Sequential(
            nn.Linear(input_dim, ff_dim * 4),
            nn.ReLU(),
            nn.Linear(ff_dim * 4, input_dim)
        )
        self.norm2 = nn.LayerNorm(input_dim)
        
    def forward(self, x):
        x = self.attention(x)
        x = x + self.norm1(x)
        x = self.ff(x)
        x = x + self.norm2(x)
        return x

In [4]:
# load toy story script as text into a list
text: list[str] = []
with open(f"data/toy story.txt") as f:
    text = list(f.read())
    
len(text)

168003

In [5]:
# print first 100 characters
print("".join(text[:100]))

FADE IN:

INT. ANDY'S BEDROOM

A row of moving boxes lie on the floor of the room.  They
are drawn u


In [6]:
# look up table encoder decoder from letters in words to numbers
def make_encoder_decoder(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
    letters = sorted(set("".join(words)))
    encoder = {letter: i for i, letter in enumerate(letters)}
    decoder = {i: letter for i, letter in enumerate(letters)}
    return encoder, decoder

In [7]:
# partition the dataset into sequences
def make_sequences(x: torch.Tensor, sequence_size: int) -> torch.Tensor:
    sequences = []
    for i in range(0, x.shape[0] - sequence_size):
        sequences.append(x[i:i+sequence_size])
    return torch.stack(sequences)

In [60]:
# partition the dataset into mini batches randomly and return the batches
def make_batches(x: torch.Tensor, batch_size):
    n = x.shape[0]
    indices = torch.randperm(n)
    x = x[indices]
    for i in range(0, n, batch_size):
        xs = x[i:i+batch_size]
        # batches.append((x[i:i+batch_size, :-1], x[i:i+batch_size, -1]))
        yield x[i:i+batch_size, :-1], x[i:i+batch_size, 1:]

In [61]:
encoder, decoder = make_encoder_decoder(text)
sequence_length = 128
# create the dataset
all_letters = torch.tensor([encoder[letter] for letter in text], dtype=torch.long)
sequences = make_sequences(all_letters, sequence_length + 1).to("mps")

In [67]:
embedding_dim = 64
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.token_embedding = nn.Embedding(len(encoder), embedding_dim)
        self.position_embedding = nn.Embedding(sequence_length, embedding_dim)
        num_heads = 4
        self.transformer_blocks = nn.Sequential(
            TransformerBlock(
                embedding_dim*2, 
                head_dim=(embedding_dim*2)//num_heads, 
                ff_dim=embedding_dim, 
                num_heads=num_heads), 
        )
        self.linear = nn.Linear(embedding_dim*2, len(encoder))
    
    def forward(self, x):
        B, T = x.shape

        token_emb = self.token_embedding(x) # (B, T, C)
        
        position_emb = self.position_embedding(torch.arange(T, device="mps")) # (T, C)
        position_emb = position_emb.unsqueeze(0).expand(B, -1, -1)  

        x = torch.cat([token_emb, position_emb], dim=-1) # (B, T, 2*C)
        x = self.transformer_blocks(x)
        x = self.linear(x)
        return x

In [68]:
# set up the model
model = Model().to("mps")
# set up the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [69]:
losses = []

In [74]:
# training loop
max_iters = 25
batch_size = 32
num_batches = sequences.shape[0] // batch_size
model.train().to("mps")
for i in range(max_iters):
    for batch_num, (x, y) in enumerate(make_batches(sequences, batch_size)):
        logits = model(x)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        y = y.flatten()
        loss = F.cross_entropy(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if batch_num % 100 == 0:
            print(f"epoch: {i}, batch: [{batch_num:>4d}/{num_batches:>4d}], loss: {loss.item():.4f}")
    
    print(f"epoch: {i}, loss: {loss.item()}")

# save checkpoint of model (optimizer and model state dict)
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss
}, "checkpoint.pt")

epoch: 0, batch: [   0/5246], loss: 2.3328
epoch: 0, batch: [ 100/5246], loss: 1.9948
epoch: 0, batch: [ 200/5246], loss: 1.8148
epoch: 0, batch: [ 300/5246], loss: 1.7014
epoch: 0, batch: [ 400/5246], loss: 1.2627
epoch: 0, batch: [ 500/5246], loss: 0.8450
epoch: 0, batch: [ 600/5246], loss: 0.5405
epoch: 0, batch: [ 700/5246], loss: 0.3890
epoch: 0, batch: [ 800/5246], loss: 0.3514
epoch: 0, batch: [ 900/5246], loss: 0.2569
epoch: 0, batch: [1000/5246], loss: 0.2119
epoch: 0, batch: [1100/5246], loss: 0.1632
epoch: 0, batch: [1200/5246], loss: 0.1411
epoch: 0, batch: [1300/5246], loss: 0.1233
epoch: 0, batch: [1400/5246], loss: 0.1139
epoch: 0, batch: [1500/5246], loss: 0.0968
epoch: 0, batch: [1600/5246], loss: 0.0924
epoch: 0, batch: [1700/5246], loss: 0.0870
epoch: 0, batch: [1800/5246], loss: 0.0714
epoch: 0, batch: [1900/5246], loss: 0.0666
epoch: 0, batch: [2000/5246], loss: 0.0692
epoch: 0, batch: [2100/5246], loss: 0.0564
epoch: 0, batch: [2200/5246], loss: 0.0477
epoch: 0, b

KeyboardInterrupt: 

In [83]:
@torch.no_grad()
def generate_text(model, start_text, max_len=200):
    model.eval()
    text = start_text
    for i in range(max_len):
        x = torch.tensor([encoder[letter] for letter in text[-sequence_length:]], dtype=torch.long).to("mps")
        x = x.unsqueeze(0)
        logits = model(x)
        logits = logits[:, -1, :]
        # sample from the distribution
        probs = F.softmax(logits, dim=-1)
        letter = torch.multinomial(probs, 1).squeeze(0)[-1]
        text += decoder[letter.item()]

    return text

In [84]:
generate_text(model, "A row of moving boxes lie on the floor of the room. They are drawn")

'A row of moving boxes lie on the floor of the room. \nThey are drawnaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa,"TISSSSRS\nof hyalce PCA,\n  DES\nchacks rure haynhe rally Doe rup the dre rome re     SSYISSS PIElly widdles thr        HEyaly 6yy, is mimsind dey '