In [17]:
# Colab-only helper used to attach Google Drive to the runtime's filesystem.
from google.colab import drive

# Mount Drive at /content/drive. Later cells use relative paths such as
# "drive/MyDrive/..." -- presumably the working directory is /content; verify
# if this notebook is run from a different directory.
drive.mount("/content/drive")

In [None]:
!git clone https://github.com/n1teshy/transformer > /dev/null
!mv transformer/* . && rmdir transformer
!mkdir -p drive/MyDrive/checkpoints/poet2
!ls drive/MyDrive/checkpoints/poet2

In [1]:
import os
import math

import torch
from torch.optim import AdamW

from core.data.generator import GeneratorDataset
from core.models import Generator
from core.utils.bpe import Tokenizer
from core.utils.configs import DecoderConfig, GeneratorDataConfig

In [2]:
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
batch_size = 32
context = 512
train_cache = "drive/MyDrive/datasets/poems_cache_34k/train"
val_cache = "drive/MyDrive/datasets/poems_cache_34k/val"
# 39 repetitions of "# " followed by a closing "#" (79 characters in total).
sample_delimiter = "# " * 39 + "#"

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
no_blocks = 5
no_heads = 16
model_dim = 768
model_context = 512

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
epochs = 10
checkpoints_dir = "drive/MyDrive/checkpoints/poet2"

learning_rate = 3e-3
min_lr = learning_rate / 10
total_samples = 54685
grad_accum_iters = 1

# Optimizer steps in one pass over `total_samples`.
# NOTE(review): the warmup/decay horizons below are fractions of a single
# epoch although `epochs` is 10 -- confirm that this is intended.
steps_per_epoch = total_samples / batch_size / grad_accum_iters
warmup_iters = int(steps_per_epoch * 0.2)
lr_decay_iters = int(steps_per_epoch * 0.95)

In [3]:
# Load the tokenizer model from the local `tokenizer/` directory.
tokenizer = Tokenizer()
tokenizer.load("tokenizer/poet2_tokenizer.model")

# Both splits use the same settings and differ only in their cache directory.
_shared_data_kwargs = dict(
    batch_size=batch_size,
    pad_id=tokenizer.pad_id,
    shuffle_shards=True,
    shuffle_samples=True,
)
train_data_conf = GeneratorDataConfig(cache_dir=train_cache, **_shared_data_kwargs)
val_data_conf = GeneratorDataConfig(cache_dir=val_cache, **_shared_data_kwargs)

train_dataset = GeneratorDataset(train_data_conf)
val_dataset = GeneratorDataset(val_data_conf)

In [None]:
# Assemble the model configuration from the hyperparameters and the
# tokenizer's vocabulary size and special-token ids.
model_conf = DecoderConfig(
    vocab_size=tokenizer.size,
    context=model_context,
    model_dim=model_dim,
    no_heads=no_heads,
    no_blocks=no_blocks,
    dropout=0.2,
    train_mode=True,
    pad_id=tokenizer.pad_id,
    sos_id=tokenizer.sos_id,
    eos_id=tokenizer.eos_id,
)
model = Generator(model_conf)
# To resume from a checkpoint, fill in the file name and uncomment:
# model.load_state_dict(torch.load(os.path.join(checkpoints_dir, )))
model = model.to("cuda")

# Report the total parameter count in millions.
n_params = sum(p.numel() for p in model.parameters())
print("model has %.2fmn parameters" % (n_params / 1e6,))


def get_lr(it: int):
    """Return the learning rate to use at iteration ``it``.

    The schedule has three phases, driven by the module-level settings
    ``learning_rate``, ``min_lr``, ``warmup_iters`` and ``lr_decay_iters``:

    * ``it < warmup_iters``: linear ramp towards ``learning_rate``.
    * ``it > lr_decay_iters``: constant ``min_lr``.
    * otherwise: cosine interpolation from ``learning_rate`` down to
      ``min_lr``.
    """
    in_warmup = it < warmup_iters
    if in_warmup:
        return learning_rate * (it + 1) / (warmup_iters + 1)

    past_decay = it > lr_decay_iters
    if past_decay:
        return min_lr

    # Fraction of the decay window that has elapsed, expected in [0, 1].
    decay_span = lr_decay_iters - warmup_iters
    progress = (it - warmup_iters) / decay_span
    assert 0 <= progress <= 1

    # The cosine weight goes from 1 at the start of the window to 0 at its end.
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    lr_range = learning_rate - min_lr
    return min_lr + cosine_weight * lr_range


@torch.no_grad()
def get_val_loss() -> float:
    """Compute the loss on the next validation batch, without gradients.

    Switches ``model`` to eval mode and leaves it there; the caller has to
    call ``model.train()`` again before continuing training. When the
    validation dataset returns ``None`` it is reset and read once more.

    Returns:
        The batch loss as a Python float.
    """
    model.eval()

    val_batch = val_dataset.next_batch()
    if val_batch is None:
        # Validation data exhausted: rewind and take the first batch again.
        val_dataset.reset()
        val_batch = val_dataset.next_batch()

    inputs, targets = val_batch
    _unused, val_loss = model(inputs, targets)
    return val_loss.item()


def save_model(t_loss: float, v_loss: float) -> None:
    """Write the model's ``state_dict`` to ``checkpoints_dir``.

    The file name encodes, in order: training loss, validation loss, base
    learning rate, number of blocks, number of heads, model dimension and
    model context, joined by ``-`` with a ``.pth`` extension.

    Args:
        t_loss: Training loss to record in the file name.
        v_loss: Validation loss to record in the file name.
    """
    # The learning rate is formatted with %g: with "%.2f" a value such as
    # 3e-3 was written as "0.00", which made the field useless.
    name = "%.2f-%.2f-%g-%d-%d-%d-%d.pth" % (
        t_loss,
        v_loss,
        learning_rate,
        no_blocks,
        no_heads,
        model_dim,
        model_context,
    )
    # Create the target directory if needed, so that a long training run does
    # not fail at its first save.
    os.makedirs(checkpoints_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoints_dir, name))

In [7]:
# AdamW over every model parameter. `lr` here is only the initial value: the
# training loop overwrites each param group's "lr" from get_lr() before stepping.
optimizer = AdamW(model.parameters(), lr=learning_rate)

In [5]:
# State shared with the training loop. It lives in its own cell so that
# re-running the loop cell continues from the current counters.
# Window (in optimizer steps) used for the running-average losses.
loss_window = 128 / grad_accum_iters
# Number of micro-batches processed so far.
batches_trained = 0
# Running averages of the training / validation loss; None until the first step.
mt_loss = None
mv_loss = None

In [None]:
# Main training loop.
#
# For each micro-batch: forward, backward, and -- once `grad_accum_iters`
# micro-batches have been accumulated -- one optimizer step followed by a
# single-batch validation pass, running-average bookkeeping and an optional
# checkpoint.
#
# A checkpoint is written when all of the following hold:
#   * at least `loss_window` optimizer steps have been taken,
#   * the running training loss improved on the best saved one by at least
#     `min_loss_improv`,
#   * the running validation loss exceeds the running training loss by less
#     than `min_loss_improv`.
best_t_loss, best_v_loss = 4, 4
min_loss_improv = 0.2

# Discard gradients left over from an interrupted earlier run of this cell.
optimizer.zero_grad()

for epoch in range(epochs):
    while batch := train_dataset.next_batch():
        # get_val_loss() switches the model to eval mode, so restore train
        # mode at the start of every iteration.
        model.train()
        x, y = batch
        _, loss = model(x, y)
        t_loss = loss.item()
        # Scale the loss so the accumulated gradient is the mean over the
        # micro-batches rather than their sum.
        loss /= grad_accum_iters
        loss.backward()
        batches_trained += 1
        if batches_trained % grad_accum_iters == 0:
            step = batches_trained // grad_accum_iters
            lr = get_lr(step)
            for group in optimizer.param_groups:
                group["lr"] = lr
            # Apply the accumulated gradients first and clear them afterwards.
            # Clearing them before step() (as was done previously) discards
            # the gradients just computed, so no learning takes place.
            optimizer.step()
            optimizer.zero_grad()

            v_loss = get_val_loss()
            # Exponential moving averages, seeded with the first observation.
            alpha = 1 / loss_window
            mt_loss = t_loss * alpha + (mt_loss or t_loss) * (1 - alpha)
            mv_loss = v_loss * alpha + (mv_loss or v_loss) * (1 - alpha)
            print(
                "%d-%d: %.2f -> %.2f, %.2f -> %.2f, lr: %.5f"
                % (epoch, step, t_loss, mt_loss, v_loss, mv_loss, lr)
            )
            if (
                step >= loss_window
                and best_t_loss - mt_loss >= min_loss_improv
                and mv_loss - mt_loss < min_loss_improv
            ):
                save_model(mt_loss, mv_loss)
                best_t_loss, best_v_loss = mt_loss, mv_loss
                print("saved with losses: %.2f, %.2f" % (mt_loss, mv_loss))
        else:
            # Progress marker for the micro-batches of the current step.
            print("# " * (batches_trained % grad_accum_iters), end="\r")
    # Rewind the training data for the next epoch.
    train_dataset.reset()

In [None]:
# Stream a sample from the model: decode each generated token on its own and
# print it immediately, without a trailing newline.
for token_id in model.generate():
    piece = tokenizer.decode([token_id])
    print(piece, end="")