In [15]:
import os
import platform

# Fetch the project sources (tokenizer/ and core/) and unpack the datasets
# archive into the working directory.
import zipfile  # only this cell needs it, so it is imported here

DATASETS_ARCHIVE = "./datasets.zip"

if not os.path.exists(DATASETS_ARCHIVE):
    # The archive has to be uploaded by hand before this cell can do anything.
    print("upload datasets")
else:
    # Cloning and moving the packages still goes through the platform shell.
    if platform.system() == "Windows":
        clone_cmd = "git clone https://github.com/n1teshy/poet & move poet/tokenizer . & move poet/core . & rd /s /q poet"
    else:
        clone_cmd = "git clone https://github.com/n1teshy/poet && mv poet/tokenizer . && mv poet/core . && rm -rf poet"
    if os.system(clone_cmd) != 0:
        # Best effort: report the failure instead of silently ignoring the exit code.
        print("warning: fetching tokenizer/core from the poet repository failed")

    # Extract with the standard library rather than `unzip` / `Expand-Archive`:
    # it behaves the same on every platform and uses the exact file name that
    # was checked above (the shell commands pointed at "./datasets", without
    # the ".zip" suffix).
    with zipfile.ZipFile(DATASETS_ARCHIVE) as archive:
        archive.extractall(".")

In [16]:
import torch

from google.colab import drive
import torch.nn.functional as F
from torch.optim import AdamW
from core.tokenizers.regex import get_tokenizer
from core.utils import get_param_count
from core.config import device
from core.models import Generator

In [None]:
# Mount Google Drive at /content/drive (`drive` comes from google.colab, so this
# cell only works inside Colab). Checkpoints are written under the mounted tree.
drive.mount("/content/drive")

In [18]:
# Create the checkpoint folder (and any missing parents); a no-op when it already exists.
# NOTE(review): the path is relative, so it only lands inside the Drive mount when
# the working directory is /content — confirm for non-default setups.
os.makedirs("drive/MyDrive/poet_params", exist_ok=True)

In [19]:
# --- Model shape (passed to Generator) ---------------------------------------
EMBEDDING_SIZE = 768
LAYERS = 5
HEADS = 16
BLOCK_SIZE = 128  # tokens per sampled window; also the context crop during generation

# --- Training ----------------------------------------------------------------
BATCH_SIZE = 128  # windows drawn per batch
EPOCHS = 10

# --- Data --------------------------------------------------------------------
TRAIN_FILE, TEST_FILE = "datasets/train.txt", "datasets/test.txt"

In [20]:
# Obtain the tokenizer, then encode both corpora with it.
# NOTE(review): the meaning of get_tokenizer's positional arguments is defined in
# core.tokenizers.regex and is not visible from this notebook — confirm there.
tokenizer = get_tokenizer("poems.txt", 1024, "tokenizer/en", True)

# Read through context managers so both file handles are closed as soon as the
# text has been read, instead of being left to the garbage collector.
with open(TRAIN_FILE, encoding="utf-8") as train_file:
    train_data = tokenizer.encode(train_file.read())
with open(TEST_FILE, encoding="utf-8") as test_file:
    test_data = tokenizer.encode(test_file.read())

In [21]:
def get_batch(split):
    """Draw one random batch of (input, target) windows as tensors on ``device``.

    ``split == "train"`` samples from ``train_data``; any other value samples
    from ``test_data``. BATCH_SIZE start offsets are drawn at random, each
    input is the BLOCK_SIZE tokens starting at its offset, and each target is
    the same window shifted one token to the right.
    """
    source = train_data if split == "train" else test_data
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE, ))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start: start + BLOCK_SIZE])
        targets.append(source[start + 1: start + 1 + BLOCK_SIZE])
    return torch.tensor(inputs, device=device), torch.tensor(targets, device=device)

In [24]:
# Single source of truth for the checkpoint to resume from (resolved relative to
# the working directory).
model_checkpoint = "TRL_2.9198_TSL_3.2067_EMB_768_LYR_5_HDS_16_CTX_128_LR_0.0001.pth"
model = Generator(tokenizer.size, EMBEDDING_SIZE, BLOCK_SIZE, LAYERS, HEADS).to(device)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads when only a CPU is available.
model.load_state_dict(torch.load(model_checkpoint, map_location=device))
print("%.4f mn parameters" % (get_param_count(model) / 1e6, ))

37.0255 mn parameters


In [25]:
@torch.no_grad()
def get_test_loss():
    """Return the cross-entropy loss of the model on one random test batch.

    Runs without gradient tracking. The model is switched to eval mode for the
    forward pass and put back into train mode before returning. The loss is
    returned exactly as ``F.cross_entropy`` produces it.
    """
    model.eval()
    inputs, targets = get_batch("test")
    scores = model(inputs)
    batch, steps, classes = scores.shape
    flat = batch * steps
    # Collapse the batch and position axes so every position is one sample.
    loss = F.cross_entropy(scores.reshape(flat, classes), targets.reshape(flat))
    model.train()
    return loss


In [26]:
def batch_generator(split):
    """Yield freshly sampled batches from ``get_batch(split)`` without end."""
    # iter(callable, sentinel) stops only when the callable returns the sentinel;
    # a private object() can never be returned, so the stream is endless.
    never = object()
    yield from iter(lambda: get_batch(split), never)

In [27]:
# Step size handed to the optimizer; save_model also embeds it in checkpoint file names.
LEARNING_RATE = 0.0001
# AdamW over every parameter of the model.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [28]:
# Running (smoothed) losses updated by the training loop. A value of 0 is treated
# there as "not initialised yet" and is replaced by the first observed loss.
mean_train_loss, mean_test_loss = 2.95, 0
# Number of batches counted as one epoch in the progress log (true division, so
# the value may be fractional).
batch_to_epoch = len(train_data) / BATCH_SIZE
# Smoothed train loss at the last checkpoint; the loop saves again once the
# running loss has dropped at least 0.02 below this value.
last_saved_train_loss = 2.9198

In [29]:
def save_model(folder):
    model_id = "TRL_%.4f_TSL_%.4f_EMB_%d_LYR_%d_HDS_%d_CTX_%d_LR_%.4f" % (
        mean_train_loss,
        mean_test_loss,
        EMBEDDING_SIZE,
        LAYERS,
        HEADS,
        BLOCK_SIZE,
        LEARNING_RATE,
    )
    torch.save(model.state_dict(), os.path.join(folder, f"{model_id}.pth"))


In [None]:
# Training loop: one optimiser step per sampled batch, with a test-loss estimate
# after every step and a checkpoint whenever the smoothed train loss has improved
# enough since the last save.
#
# The batch generator is endless, so the loop is bounded explicitly: it stops
# after EPOCHS epochs, where one epoch is `batch_to_epoch` batches. This lets
# the cells below run without a manual interrupt.
max_batches = max(1, int(EPOCHS * batch_to_epoch))

for batch_no, (inp, tgt) in enumerate(batch_generator("train"), start=1):
    optimizer.zero_grad()
    logits = model(inp)
    B, T, C = logits.shape
    # Collapse the batch and position axes so every position is one sample.
    logits, tgt = logits.reshape(B * T, C), tgt.reshape(B * T)
    train_loss = F.cross_entropy(logits, tgt)
    train_loss.backward()
    optimizer.step()
    test_loss = get_test_loss()
    train_loss, test_loss = train_loss.item(), test_loss.item()
    # Exponential moving averages of both losses; a running value of 0 means
    # "not initialised" and is seeded with the current loss.
    mean_train_loss = (mean_train_loss or train_loss) * 0.9975 + train_loss * 0.0025
    mean_test_loss = (mean_test_loss or test_loss) * 0.9975 + test_loss * 0.0025
    # Log format: epoch:batch-in-epoch -> (train | smoothed train), (test | smoothed test)
    print(
        "%d:%d -> (%.4f | %.4f), (%.4f | %.4f)"
        % (
            batch_no // batch_to_epoch + 1,
            batch_no % batch_to_epoch,
            train_loss,
            mean_train_loss,
            test_loss,
            mean_test_loss,
        )
    )
    # Checkpoint only after the smoothed train loss has dropped by at least 0.02
    # since the previous save.
    if last_saved_train_loss - mean_train_loss >= 0.02:
        save_model("drive/MyDrive/poet_params")
        print(f"saved model at train loss {mean_train_loss}")
        last_saved_train_loss = mean_train_loss
    if batch_no >= max_batches:
        break


In [None]:
@torch.no_grad()
def generate(text=" ", max_len=2000):
    """Sample ``max_len`` token ids from the model, seeded with ``text``.

    Args:
        text: Prompt string; it is encoded with ``tokenizer`` and used as the
            initial context. Only its last BLOCK_SIZE tokens are fed to the model.
        max_len: Number of tokens to sample.

    Returns:
        A list with the sampled token ids, in order. The prompt's own tokens
        are not included.

    Sampling runs without gradient tracking and with the model in eval mode;
    the model's previous train/eval mode is restored before returning.
    """
    context = torch.tensor([tokenizer.encode(text)], device=device)
    # Crop an over-long prompt up front; inside the loop the context is already
    # kept to the last BLOCK_SIZE tokens.
    context = context[:, -BLOCK_SIZE:]
    was_training = model.training
    model.eval()
    output = []
    try:
        for _ in range(max_len):
            logits = model(context)
            # Only the last position predicts the next token, so normalise just
            # that slice instead of the whole sequence.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            output.append(next_token.item())
            context = torch.cat((context, next_token), dim=1)[:, -BLOCK_SIZE:]
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    return output
# Sample with the default prompt and length, decode the token ids back to text and print it.
print(tokenizer.decode(generate()))