# Poetry Notebook

In this notebook we will be implementing GPT to generate text based on the work of Edgar Allan Poe.

In [1]:
# Installing dependencies
!pip install tiktoken
!pip install torch

# Downloading dataset from the GitHub
!wget https://raw.githubusercontent.com/kocenko/Poetry-Synthesis/dev/data/poe_data.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tiktoken
  Downloading tiktoken-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.4.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
--2023-05-16 19:52:05--  https://raw.githubusercontent.com/kocenko/Poetry-Synthesis/dev/data/poe_data.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1930488 (1.8M) [text/plain]
Saving to: ‘poe_data.txt’


2023-05-16 19:52:05 (47.9 MB/s) - ‘poe_data.tx

In [2]:
# Essential imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [3]:
# Testing if GPU is available
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

CUDA_LAUNCH_BLOCKING=1

In [4]:
# Dataset class definition
### (Option) We can use different data to train it on
### (Option) What if the context affects not the following
###          but the one after the following token? (bigger offset)

class PoeDataset(Dataset):
    valid_split_params = ["train", "valid"]

    def __init__(self, text: str, split: str, split_ratio: float, context_length: int, tokenizer, offset: int = 1):
        ''' Poe Dataset constructor

        Args:
            str:
                file_path: Path to the file containing dataset
                splt: String indicating what type of data this dataset contains
            float:
                split_ratio: Value between (0, 1] of what should be the ratio
                             between training and validation set
            int:
                context_length: Length of the context
                offset: An offset between the end of the context and the target
        '''

        assert split in PoeDataset.valid_split_params, f"{split} is the wrong split type"
        assert split_ratio <= 1 and split_ratio > 0, f"Split ratio value should be from range (0, 1]"
        assert len(text) > 0, f"Dataset file should not be empty"
        assert context_length < len(text), f"Context length should not be more than {len(text) - 1}"

        self.text = text
        self.offset = offset
        self.context_length = context_length
        self.tokenizer = tokenizer
        self.data = torch.tensor(self.tokenizer.encode(self.text), dtype=torch.int32, device=device)

        split_idx = int(len(self.data) * split_ratio)
        if split == "train":
            self.data = self.data[:split_idx]
        else:
            self.data = self.data[split_idx:]

    def __len__(self):
        ''' Returns the size of the dataset
        
        Returns:
            Number of possible shifts in the dataset for choosing the context chunk
        '''
        return len(self.data) - self.context_length - self.context_length + 1
    
    def __getitem__(self, index):
        ''' Returns an item of given index

        Params:
            index: Which item should be returned
        
        Returns:
            Sample of given index
        '''
        
        x = self.data[index: index + self.context_length]
        y = self.data[index + self.offset: index + self.context_length + self.offset]

        return x, y



In [5]:
# Defined tokenizer class
import torch
from typing import List


class Tokenizer:
    ''' Class for character-wise tokenization'''

    def __init__(self, text: str):
        assert len(text) > 0, "Text used for creating tokenizer cannot be empty"

        self.text = text
        self.symbols = sorted(list(set(self.text)))
        self.vocab_size = len(self.symbols)
        self.stoi = { ch:i for i, ch in enumerate(self.symbols)}
        self.itos = { i:ch for i, ch in enumerate(self.symbols)}

    def encode(self, text: str) -> List:
        ''' Encodes string to list of ints '''

        return [self.stoi[ch] for ch in text]
    
    def decode(self, tokens: List) -> str:
        ''' Decodes list of ints to string '''
        
        return ''.join([self.itos[token] for token in tokens])


In [37]:
# Simple Decoder Class definition
### (Option) Different split, test data?
from typing import Tuple
import torch.nn.functional as F

class SingleAttentionHead(nn.Module):

    def __init__(self, head_size, config):
        super().__init__()
        self.n_embed = config["n_embed"]
        self.context_length = config["context_length"]

        self.query = nn.Linear(self.n_embed, head_size, bias = False, device=device)
        self.key = nn.Linear(self.n_embed, head_size, bias = False, device=device)
        self.value = nn.Linear(self.n_embed, head_size, bias = False, device=device)
        self.triangle_matrix = torch.tril(torch.ones(self.context_length, self.context_length, device=device))

    def forward(self, x):
        B, T, C = x.shape
        keys = self.key(x)
        query = self.query(x)

        affinities = query @ keys.transpose(-2, -1)  # Dot product, with transposition of T and C
        affinities *= C**(-.5)  # Normalization, to prevent softmax for skewing
        affinities = affinities.masked_fill(self.triangle_matrix[:T, :T] == 0, float('-inf'))
        affinities = F.softmax(affinities, dim=-1)

        v = self.value(x)
        return affinities @ v

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.heads_num = config["att_head_num"]
        self.n_embed = config["n_embed"]

        head_size = self.n_embed // self.heads_num
        self.attention_heads = nn.ModuleList([SingleAttentionHead(head_size, config) for _ in range(self.heads_num)])
        self.heads_projection = nn.Linear(self.n_embed, self.n_embed, device=device)
        self.feed_forward = nn.Sequential(nn.Linear(self.n_embed, 4 * self.n_embed, device=device),
                                          nn.ReLU(),
                                          nn.Linear(4 * self.n_embed, self.n_embed, device=device))
        # '4 times n_embed' comes from the paper 'Attention is all you need' (as the whole transformer)
        self.layer_normalization = nn.LayerNorm(self.n_embed, device=device)
        self.layer_normalization2 = nn.LayerNorm(self.n_embed, device=device)

    def forward(self, x):
        x = self.layer_normalization(x)
        x = x + self.heads_projection(torch.cat([att(x) for att in self.attention_heads], dim=-1))
        x = self.layer_normalization2(x)
        x = x + self.feed_forward(x)
        return x


class OnlyDecoder(nn.Module):
    def __init__(self, config: dict):
        super().__init__()

        self.vocab_size = config["vocab_size"]
        self.n_embed = config["n_embed"]
        self.context_length = config["context_length"]
        self.head_num = config["att_head_num"]
        self.blocks_num = config["blocks_num"]

        self.token_embedding_table = nn.Embedding(self.vocab_size, self.n_embed, device=device)
        self.pos_embedding_table = nn.Embedding(self.context_length, self.n_embed, device=device)
        self.lin = nn.Linear(self.n_embed, self.vocab_size, device=device)
        self.transformer_blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(self.blocks_num)],
                                                nn.LayerNorm(self.n_embed, device=device))
        self.layer_normalization = nn.LayerNorm(self.n_embed, device=device)

    def forward(self, token_idx: int, targets=None):
        B, T = token_idx.shape
        token_embedding = self.token_embedding_table(token_idx)
        pos_embedding = self.pos_embedding_table(torch.arange(T, device=device))
        x = token_embedding + pos_embedding
        x = self.transformer_blocks(x)
        x = self.layer_normalization(x)
        logits = self.lin(x)

        if targets is None:
          loss = None
        else:
          B, T, C = logits.shape
          logits = logits.view(B*T, C)
          targets = targets.view(B*T)
          targets = targets.type(torch.LongTensor).to(device)
          loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate_new_text(self, idx, sym_limit: int) -> torch.Tensor:
        output = []
        for _ in range(sym_limit):
          idx = idx if idx.size(1) <= self.context_length else idx[:, -self.context_length:]
          logits, loss = self(idx)
          logits = logits[:, -1, :]
          probabilities = F.softmax(logits, dim=-1)
          idx_next = torch.multinomial(probabilities, num_samples=1) # Take best
          output.append(idx_next)
          idx = torch.cat((idx, idx_next), dim=1)
        return output


In [17]:
import torch


@torch.no_grad()
def calc_loss(model, iterations, batch_size, train_set, val_set):
    ''' Used to evalute model by averaging on many iterations
    
    Args:
        model: Evaluated model
        iterations: Number of iterations to average through
        batch_size: Batch size
        train_set: Training dataset
        val_set: Validation dataset

    Returns:
        Dictionary with averaged losses for 'train' nad 'valid'
    '''

    split_type = ["train", "valid"]
    outcome_losses = {}
    model.eval()
    for t, split in enumerate([train_set, val_set]):
        loader = DataLoader(split, batch_size = batch_size, shuffle=True, drop_last=True)
        loader = iter(loader)
        losses = torch.zeros(iterations)
        for i in range(iterations):
            x, y = loader.__next__()
            _, loss = model(x, y)
            losses[i] = loss.item()
        outcome_losses[split_type[t]] = losses.mean()
    model.train()
    return outcome_losses


def train_model(model, train_set, valid_set, hyper_params: dict, device):
    ''' Trains the model

    Args:
        model: Model to train
        train_set: Training dataset
        valid_set: Validation dataset
        hyper_params: dict of hyperparameters
    '''

    lr = hyper_params["lr"]
    epochs = hyper_params["epochs"]
    batch_size = hyper_params["batch_size"]
    eval_each = hyper_params["eval_each"]
    eval_iterations = hyper_params["eval_iterations"]
    break_iter = hyper_params["break_iter"]

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for e in range(epochs):
        for i, (x, y) in enumerate(train_dataloader):
            if i == break_iter:
                break

            if i % eval_each == 0:
                losses = calc_loss(model, eval_iterations, batch_size, train_set, valid_set)
                print(f"Epoch: {e} Step: {i}, train loss: {losses['train']:.4f}, val loss: {losses['valid']:.4f}")
            
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()


In [34]:
# Setting up the dataset parameters
### (Option 1) We can use different tokenizer, like SentencePiece
### (Option 2) We can build our own tokenizer, using huggingface library
import tiktoken



file_path = "poe_data.txt"

# Reading file, preparing tokenizer
with open(file_path, 'r', encoding="utf-8") as f:
            text = f.read()

# Setting up dataset
split_ratio = 0.85
context_length = 8
offset = 1  # I am wondering what would be the results for 2, for example
custom_tokenizer = False

if custom_tokenizer:
    tokenizer = Tokenizer(text)
    vocab_size = tokenizer.vocab_size
else:
    tokenizer = tiktoken.get_encoding("cl100k_base")
    vocab_size = tokenizer.n_vocab

# Setting up model (!!! head_size should be n_embed // head_num)
net_config = { "vocab_size": vocab_size,
               "n_embed": 32,
               "context_length": context_length,
               "att_head_num": 4,
               "blocks_num": 3}

symbols_limit = 50
model = OnlyDecoder(net_config)
model.to(device)

# Training parameters
hypers = {
    "lr": .3e-4,
    "epochs": 5,
    "batch_size": 32,
    "eval_each": 200,
    "eval_iterations": 200,
    "break_iter": 10000
}

# Training
train_set = PoeDataset(text, 'train', split_ratio, context_length, tokenizer, offset=offset)
val_set = PoeDataset(text, 'valid', split_ratio, context_length, tokenizer, offset=offset)

train_dataloader = DataLoader(train_set, batch_size=4, shuffle=True, drop_last=True)
train_model(model, train_set, val_set, hypers, device=device)

# Test it
starter = torch.zeros((1,1), dtype=torch.long, device=device)
print(tokenizer.decode(model.generate_new_text(starter, 200)[0].tolist()))

Epoch: 0 Step: 0, train loss: 11.6868, val loss: 11.6843
Epoch: 0 Step: 200, train loss: 11.3956, val loss: 11.3853
Epoch: 0 Step: 400, train loss: 11.0098, val loss: 11.0030
Epoch: 0 Step: 600, train loss: 10.6879, val loss: 10.6785
Epoch: 0 Step: 800, train loss: 10.3888, val loss: 10.3782
Epoch: 0 Step: 1000, train loss: 10.1026, val loss: 10.1024
Epoch: 0 Step: 1200, train loss: 9.8157, val loss: 9.8351
Epoch: 0 Step: 1400, train loss: 9.5640, val loss: 9.5704
Epoch: 0 Step: 1600, train loss: 9.3126, val loss: 9.3289
Epoch: 0 Step: 1800, train loss: 9.0762, val loss: 9.0991
Epoch: 0 Step: 2000, train loss: 8.8596, val loss: 8.8788
Epoch: 0 Step: 2200, train loss: 8.6590, val loss: 8.6775
Epoch: 0 Step: 2400, train loss: 8.4616, val loss: 8.4833
Epoch: 0 Step: 2600, train loss: 8.2840, val loss: 8.2980
Epoch: 0 Step: 2800, train loss: 8.1156, val loss: 8.1373
Epoch: 0 Step: 3000, train loss: 7.9585, val loss: 8.0008
Epoch: 0 Step: 3200, train loss: 7.8210, val loss: 7.8443
Epoch: 0 

In [None]:
# Test it
starter = torch.zeros((1,1), dtype=torch.long, device=device)
print(tokenizer.decode(model.generate_new_text(starter, 100000)[0].tolist()))

In [36]:
save_path = 'state'
torch.save(model.state_dict(), save_path)