In [1]:
from transformer import Transformer
from regex_tokenizer import RegexTokenizer
import os
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm
import numpy as np

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
torch.cuda.device_count()

2

In [4]:
texts = []
directory = "books"

for fn in os.listdir(directory):
    with open(f"{directory}/{fn}", "r") as f:
        texts.append(f.read())

text = "<|endoftext|>".join(texts)

In [5]:
tokenizer = RegexTokenizer()

In [7]:
vocab_size = 256

assert vocab_size >= 256
n_merges = vocab_size - 256

tokenizer.train(text, n_merges)

0it [00:00, ?it/s]


In [6]:
#with open("tokenizer_256.pkl", "wb") as f:
#    pickle.dump(tokenizer, f)

In [6]:
with open("tokenizer_1024.pkl", "rb") as f:
    tokenizer = pickle.load(f)

In [8]:
#tokens = tokenizer.encode(text, allow_special="all")

In [8]:
#with open("train_tokens_512.pkl", "rb") as f:
#    tokens = pickle.load(f)

In [11]:
with open("train_tokens_256.pkl", "wb") as f:
    pickle.dump(tokens, f)

In [12]:
sample_text = "Hi, I am Damian and this is 🔥"
sample_tokens = tokenizer.encode(sample_text, allow_special="all")

100%|██████████████████████████████████████████| 1/1 [00:00<00:00, 23967.45it/s]


In [13]:
sample_tokens

[72,
 105,
 44,
 32,
 73,
 32,
 97,
 109,
 32,
 68,
 97,
 109,
 105,
 97,
 110,
 32,
 97,
 110,
 100,
 32,
 116,
 104,
 105,
 115,
 32,
 105,
 115,
 32,
 240,
 159,
 148,
 165]

In [14]:
print(tokenizer.decode(sample_tokens))

Hi, I am Damian and this is 🔥


In [15]:
class TokenDataset(Dataset):
    def __init__(self, tokens, context_size):
        
        self.context_size = context_size
        self.tokens = torch.tensor(tokens)

    def __len__(self):
        return len(self.tokens) - self.context_size - 1

    def __getitem__(self, idx):
        x = self.tokens[idx:(idx + self.context_size)]
        y = self.tokens[(idx + 1):(idx + self.context_size + 1)]
        
        return x, y

In [7]:
context_size = 64

In [16]:
token_dataset = TokenDataset(tokens, context_size = context_size)

In [17]:
batch_size = 256

train_set, val_set, test_set = torch.utils.data.random_split(token_dataset, [0.9, 0.05, 0.05])

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=True, batch_size=batch_size)

In [18]:
# taken from: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [19]:
def train(transformer, train_loader, val_loader, n_epochs,
          optimizer=None,
          lr_scheduler=None,
          early_stopper=None,
          metrics_per_epoch=10
         ):
    
    transformer = transformer.to(device)
    
    if optimizer is None:
        optimizer = torch.optim.Adam(transformer.parameters(), lr=3e-4)
        print("Using default optimizer")
        
    if early_stopper is None:
        early_stopper = EarlyStopper(patience=3, min_delta=1e-2)
        print("Using default early stopper")
        
    if lr_scheduler is None:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  factor=0.3, patience=3, min_lr=1e-5,
                                                                  threshold=1e-3
                                                                 )
        print("Using default LR scheduler")
    
    # label smoothing deactivated for now
    criterion_train = nn.CrossEntropyLoss(label_smoothing=0.0)
    criterion_test = nn.CrossEntropyLoss()
    
    train_losses_over_epochs = []
    val_losses_over_epochs = []
    
    metrics_every = len(train_loader) // metrics_per_epoch
    in_between_epochs = []
    in_between_metrics = []
    
    for epoch_idx in range(n_epochs):
        
        train_losses_this_batch = []
        transformer.train()
        
        with tqdm(train_loader, desc=f"Epoch {epoch_idx + 1}/{n_epochs}", unit="batch") as tepoch:
            for batch_idx, (batch_x, batch_y) in enumerate(tepoch):

                # to GPU
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                logits = transformer(batch_x)

                logits = logits.transpose(1, 2)

                loss = criterion_train(logits, batch_y)

                train_losses_this_batch.append(loss.item())

                if (batch_idx + 1) % metrics_every == 0:
                    in_between_loss = np.mean(np.array(train_losses_this_batch[-metrics_every:]))
                    in_between_metrics.append(in_between_loss)
                    in_between_epochs.append(epoch_idx + (batch_idx / len(train_loader)))

                    tepoch.set_postfix(avg_loss=in_between_loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        train_loss_this_epoch = np.mean(np.array(train_losses_this_batch))
        train_losses_over_epochs.append(train_loss_this_epoch)
        
        # for early stopping
        val_losses_this_batch = []
        
        transformer.eval()
        
        with torch.no_grad():
            for batch_idx, (batch_x, batch_y) in enumerate(val_loader):

                # to GPU
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                logits = transformer(batch_x)
                
                logits = logits.transpose(1, 2)

                loss = criterion_test(logits, batch_y)

                val_losses_this_batch.append(loss.item())
        
        val_loss_this_epoch = np.mean(np.array(val_losses_this_batch))
        val_losses_over_epochs.append(val_loss_this_epoch)
        print(f"{epoch_idx}. avg. train loss = {train_loss_this_epoch}, avg. val loss = {val_loss_this_epoch}")
        
        should_stop = early_stopper.early_stop(val_loss_this_epoch)
        lr_scheduler.step(val_loss_this_epoch)
        
        if should_stop:
            print(f"stopping early (val. loss did not decrease for {early_stopper.patience})")
            break
        
    return train_losses_over_epochs, in_between_epochs, in_between_metrics, val_losses_over_epochs

In [30]:
vocab_size = 256
n_symbols = vocab_size + len(tokenizer.special_tokens)

transformer = Transformer(n_symbols, context_size, d_model = 256, n_heads = 8, n_layers=8, device=device)

In [31]:
# Calculate number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Get the number of parameters
num_params = count_parameters(transformer)
print(f"Number of trainable parameters: {num_params}")

Number of trainable parameters: 5939969


In [9]:
transformer.load_model("model_1024_18e.pth")

Model loaded from model_1024_18e.pth


In [21]:
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    transformer = torch.nn.DataParallel(transformer)

Using 2 GPUs!


In [22]:
#transformer = transformer.module

In [22]:
optimizer = torch.optim.Adam(transformer.parameters(), lr=3e-4)

In [24]:
train_losses, in_between_epochs, in_between_metrics, val_losses =\
train(transformer, train_loader, val_loader, n_epochs=9, optimizer=optimizer, metrics_per_epoch=100)

Using default early stopper
Using default LR scheduler


Epoch 1/9: 100%|██████████| 8097/8097 [14:42<00:00,  9.18batch/s, avg_loss=1.11]


0. avg. train loss = 1.2537779447339532, avg. val loss = 1.044516016377343


Epoch 2/9: 100%|█████████████| 8097/8097 [14:40<00:00,  9.20batch/s, avg_loss=1]


1. avg. train loss = 1.0518828357783927, avg. val loss = 0.9038003379768795


Epoch 3/9: 100%|█████████| 8097/8097 [14:39<00:00,  9.21batch/s, avg_loss=0.938]


2. avg. train loss = 0.9674982047823015, avg. val loss = 0.8126448636584812


Epoch 4/9: 100%|█████████| 8097/8097 [14:40<00:00,  9.20batch/s, avg_loss=0.893]


3. avg. train loss = 0.9132949791676589, avg. val loss = 0.7446919899516635


Epoch 5/9: 100%|██████████| 8097/8097 [14:40<00:00,  9.19batch/s, avg_loss=0.86]


4. avg. train loss = 0.874725679880427, avg. val loss = 0.6991754105356005


Epoch 6/9: 100%|█████████| 8097/8097 [14:39<00:00,  9.21batch/s, avg_loss=0.834]


5. avg. train loss = 0.845420897426584, avg. val loss = 0.6621509642071194


Epoch 7/9: 100%|█████████| 8097/8097 [14:38<00:00,  9.22batch/s, avg_loss=0.814]


6. avg. train loss = 0.822179857449604, avg. val loss = 0.631395171350903


Epoch 8/9: 100%|█████████| 8097/8097 [14:37<00:00,  9.23batch/s, avg_loss=0.796]


7. avg. train loss = 0.8031615089218868, avg. val loss = 0.6052261923419104


Epoch 9/9: 100%|█████████| 8097/8097 [14:37<00:00,  9.23batch/s, avg_loss=0.781]


8. avg. train loss = 0.7871972238250436, avg. val loss = 0.5903909688525729


In [13]:
#transformer.module.save_model("model_256_9e.pth")

In [24]:
if isinstance(transformer, nn.DataParallel):
    print("removed wrapper")
    transformer = transformer.module


prompt = "While exploring the depths of the ocean in the Nautilus, Captain Nemo discovers an underwater cave filled with artifacts. Among these, he finds a detailed map that leads to a hidden location on land."
n_tokens = 128
n_samples = 3

responses = transformer.sample(prompt_tokens, n_tokens, n_samples, beta = 0.8)

Prompt tokens: [65, 727, 811, 292, 279, 564, 511, 568, 263, 306, 531, 121, 355, 429, 262, 359, 987, 856, 260, 308, 425, 280, 46, 387, 609, 1006, 331, 44, 429, 347, 479, 98, 650, 586, 258, 379, 319, 269, 666, 798, 932, 907, 817, 339, 829, 671, 258, 550, 314, 271, 401, 104, 46, 544, 967, 771, 274, 475, 348, 117, 538, 46]


In [25]:
def print_reponses(prompt, responses):
    print("Prompt:")
    print("```")
    print(f"{prompt}")
    print("```")

    print("")
    print("vocab size: 1024")
    print("")

    for i, response in enumerate(responses):

        response_tokens = [t.item() for t in response.cpu().detach()]
        response_text = tokenizer.decode(response_tokens)
        print("```")
        print(f"{prompt} {response_text}")
        print("```")

In [26]:
print_reponses(prompt, responses)

Prompt:
```
While exploring the depths of the ocean in the Nautilus, Captain Nemo discovers an underwater cave filled with artifacts. Among these, he finds a detailed map that leads to a hidden location on land.
```

vocab size: 1024

```
While exploring the depths of the ocean in the Nautilus, Captain Nemo discovers an underwater cave filled with artifacts. Among these, he finds a detailed map that leads to a hidden location on land. 
now, when there was one of the scent occurred which I was keeping on earth
and open, and I could see nothing but a strange dress which I took my bank
to easy open to me, and I found my smack ones and stopped at the dog.
There was a perfect rifle of cards, around the parlor, and out-of-wheel
whispered, and big old snakes, and spiders, and a
```
```
While exploring the depths of the ocean in the Nautilus, Captain Nemo discovers an underwater cave filled with artifacts. Among these, he finds a detailed map that leads to a hidden location on land. 
Captain N