In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
import torch.nn.functional as F

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def to_tokens(text, c_to_i):
    
    res = []
    
    for c in text:
        res.append(c_to_i[c])
        
    return res

In [4]:
def from_tokens(tokens, i_to_c):
    
    res = []
    
    for t in tokens:
        res.append(i_to_c[t])
        
    return "".join(res)

In [6]:
class TokenDataset(Dataset):
    def __init__(self, tokens, context_size):
        
        self.context_size = context_size
        self.tokens = torch.tensor(tokens)

    def __len__(self):
        return len(self.tokens) - self.context_size - 1

    def __getitem__(self, idx):
        x = self.tokens[idx:(idx + self.context_size)]
        y = self.tokens[(idx + 1):(idx + self.context_size + 1)]
        
        return x, y

In [7]:
def get_dataset_and_dicts(filename, context_size):
    
    with open(filename, "r") as f:
        text = f.read()
    
    chars = list(sorted(set(text)))
    
    print(f"context size = {context_size}")
    print(f"#symbols = {len(chars)}")
    
    c_to_i = {c:i for i, c in enumerate(chars)}
    i_to_c = {i:c for i, c in enumerate(chars)}
    
    tokens = to_tokens(text, c_to_i)
    
    print(f"#tokens in text = {len(tokens)}")
    
    return TokenDataset(tokens, context_size), c_to_i, i_to_c

In [8]:
context_size = 32
token_dataset, c_to_i, i_to_c = get_dataset_and_dicts("romeo_and_juliet.txt", context_size=context_size)

context size = 32
#symbols = 70
#tokens in text = 142455


In [9]:
train_set, val_set, test_set = torch.utils.data.random_split(token_dataset, [0.8, 0.1, 0.1])

In [10]:
batch_size = 128

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=True, batch_size=batch_size)

In [20]:
class TransformerBlock(nn.Module):
    
    def __init__(self, d_model, n_heads, device):
        
        super().__init__()
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device
        
        assert d_model % n_heads == 0, "n_heads must divide d_model"
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        
        # just matrix multiplications
        self.key_net = nn.Sequential(
            nn.Linear(d_model, d_model)
        )
        
        self.query_net = nn.Sequential(
            nn.Linear(d_model, d_model)
        )
        
        self.value_net = nn.Sequential(
            nn.Linear(d_model, d_model)
        )
    
    def forward(self, x):
        
        B, T, _ = x.shape
        
        # != self.key_net(x).view((B, self.n_heads, T, -1))
        # with the above, future leaks into past
        keys = self.key_net(x).view((B, T, self.n_heads, -1)).transpose(1, 2)
        queries = self.query_net(x).view((B, T, self.n_heads, -1)).transpose(1, 2)
        values = self.value_net(x).view((B, T, self.n_heads, -1)).transpose(1, 2)
        
        scaling_factor = 1.0 / math.sqrt(self.d_model / self.n_heads)
        attention_matrices = scaling_factor * torch.matmul(queries, keys.transpose(2, 3))
        
        neg_inf = -1e10
        
        # mask the future (upper triangle)
        mask = torch.tril(torch.ones(T, T)).to(self.device)
        mask = mask.masked_fill(mask == 0, -float("inf"))
                        
        # softmax per row
        activated_attention_matrices = F.softmax(attention_matrices + mask, dim=-1)
                
        # (B, head, T, dim_per_head)
        # d_model = head * dim_per_head
        att_output = torch.matmul(activated_attention_matrices, values)
        
        att_output = torch.transpose(att_output, 1, 2)
        
        # TODO: add layer norm here
        ffn_input = att_output.reshape((B, T, -1)) + x
        
        ffn_output = self.ffn(ffn_input)
        
        # TODO: add layer norm here
        return ffn_input + ffn_output

In [21]:
class Transformer(nn.Module):
    
    def __init__(self, n_symbols, context_length, d_model, n_heads, n_layers, device):
        
        super().__init__()
        
        self.n_symbols = n_symbols
        self.d_model = d_model
        self.context_length = context_length
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.device=device
        
        self.token_embedding = nn.Embedding(num_embeddings=n_symbols, embedding_dim=d_model)
        self.pos_embedding = nn.Embedding(num_embeddings=context_length, embedding_dim=d_model)
        
        # TODO: multiple transformer blocks
        tbs = [TransformerBlock(d_model = d_model, n_heads = n_heads, device=device) for _ in range(n_layers)]
        self.transformer_blocks = nn.Sequential(*tbs)
        
        self.to_logits = nn.Sequential(
            nn.Linear(d_model, n_symbols, device=device)
        )
        
    def forward(self, x):
        
        # batch, time
        B, T = x.shape
        
        embedded = self.token_embedding(x)
        #print(f"{embedded.shape}")
        
        positions = torch.arange(T).to(self.device)
        #print(f"{positions.shape}")

        embedded = embedded + self.pos_embedding(positions)
        
        after_transformer_layers = self.transformer_blocks(embedded)
                
        return self.to_logits(after_transformer_layers)
    
    def sample(self, prompt, n_tokens, c_to_i, i_to_c, beta = 1.0):
        self.eval()
        self.to(device)

        # Process the prompt to fit within the context length
        prompt = prompt[-self.context_length:]
        print(f"Prompt: {prompt}")
        
        prompt_tokens = [c_to_i[c] for c in prompt]
        print(f"Prompt tokens: {prompt_tokens}")

        context = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(device)  # Ensure context is 2D: 1 x sequence_length

        output = []
        for _ in range(n_tokens):
            with torch.no_grad():
                logits = self(context)[:, -1, :] / beta  # Get logits for the last token position only
                probs = F.softmax(logits, dim=-1)
                last_sampled_token = torch.multinomial(probs, num_samples=1)
                
                output.append(last_sampled_token.item())
                context = torch.cat((context, last_sampled_token), dim=1)[:, -self.context_length:]  # Update context
                
                
        response = ''.join([i_to_c[t] for t in output])
        print(f"Response: {response}")
        return response

In [22]:
n_symbols = len(c_to_i.keys())

transformer = Transformer(n_symbols, context_size, d_model = 512, n_heads = 8, n_layers=4, device=device)

In [23]:
# taken from: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [24]:
def train(transformer, train_loader, val_loader, n_epochs,
          optimizer=None,
          lr_scheduler=None,
          early_stopper=None, 
         ):
    
    transformer = transformer.to(device)
    
    if optimizer is None:
        optimizer = torch.optim.Adam(transformer.parameters(), lr=3e-4)
        print("Using default optimizer")
        
    if early_stopper is None:
        early_stopper = EarlyStopper(patience=3, min_delta=1e-2)
        print("Using default early stopper")
        
    if lr_scheduler is None:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  factor=0.3, patience=3, min_lr=1e-5,
                                                                  threshold=1e-3
                                                                 )
        print("Using default LR scheduler")
        
    criterion = nn.CrossEntropyLoss()
    
    train_losses_over_epochs = []
    val_losses_over_epochs = []
    
    for epoch_idx in range(n_epochs):
        
        train_losses_this_batch = []
        transformer.train()
        
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            
            # to GPU
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            
            logits = transformer(batch_x)
            
            logits = logits.transpose(1, 2)
            
            loss = criterion(logits, batch_y)
            
            train_losses_this_batch.append(loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        train_loss_this_epoch = np.mean(np.array(train_losses_this_batch))
        train_losses_over_epochs.append(train_loss_this_epoch)
        
        # for early stopping
        val_losses_this_batch = []
        
        transformer.eval()
        
        with torch.no_grad():
            for batch_idx, (batch_x, batch_y) in enumerate(val_loader):

                # to GPU
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                logits = transformer(batch_x)
                
                logits = logits.transpose(1, 2)

                loss = criterion(logits, batch_y)

                val_losses_this_batch.append(loss.item())
        
        val_loss_this_epoch = np.mean(np.array(val_losses_this_batch))
        val_losses_over_epochs.append(val_loss_this_epoch)
        print(f"{epoch_idx}. avg. train loss = {train_loss_this_epoch}, avg. val loss = {val_loss_this_epoch}")
        
        should_stop = early_stopper.early_stop(val_loss_this_epoch)
        lr_scheduler.step(val_loss_this_epoch)
        
        if should_stop:
            print(f"stopping early (val. loss did not decrease for {early_stopper.patience})")
            break
        
    return train_losses_over_epochs, val_losses_over_epochs

In [25]:
optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4)

In [30]:
train_losses, val_losses = train(transformer, train_loader, val_loader, n_epochs=5, optimizer=optimizer)

Using default early stopper
Using default LR scheduler
0. avg. train loss = 0.6658949276681165, avg. val loss = 0.6295139496879918
1. avg. train loss = 0.5508448999337475, avg. val loss = 0.5450061815125602
2. avg. train loss = 0.48529759942481815, avg. val loss = 0.498303263847317
3. avg. train loss = 0.44516120916263824, avg. val loss = 0.46687306969293524
4. avg. train loss = 0.4209693576617942, avg. val loss = 0.4513879711074488


In [33]:
transformer.sample(prompt="You are", n_tokens = 600, c_to_i=c_to_i, i_to_c=i_to_c, beta=0.75)

Prompt: You are
Prompt tokens: [33, 52, 58, 1, 38, 55, 42]
Response:  a joyful bride.
I wonder at this haste, that I must wed
Ere he that shot so trim
Whinishe heaven bless thee. Hark you, sir.

ROMEO.
What wilt thou tell her, Nurse? Thou dost not mark while. God Benvolio and Romeo.

BENVOLIO.
Here comes the furious Tybalt back again.

 [_Exit._]

ROMEO.
[_To Juliet._] If I profane with my unworthiest hand
This toolboys requestainers of this neighbour-stained steel,—
We’ll warrant thee, Nurse. Commend me to thy lady and my wife!

JULIET.
That same I am done. For thou hast sold one what he best friend I how not?

JULIET.
My ear new in my house of Montague’s.

PA


' a joyful bride.\nI wonder at this haste, that I must wed\nEre he that shot so trim\nWhinishe heaven bless thee. Hark you, sir.\n\nROMEO.\nWhat wilt thou tell her, Nurse? Thou dost not mark while. God Benvolio and Romeo.\n\nBENVOLIO.\nHere comes the furious Tybalt back again.\n\n [_Exit._]\n\nROMEO.\n[_To Juliet._] If I profane with my unworthiest hand\nThis toolboys requestainers of this neighbour-stained steel,—\nWe’ll warrant thee, Nurse. Commend me to thy lady and my wife!\n\nJULIET.\nThat same I am done. For thou hast sold one what he best friend I how not?\n\nJULIET.\nMy ear new in my house of Montague’s.\n\nPA'