In [11]:
import os
import time

model_checkpoint = "TRL_2.9398_TSL_3.2146_EMB_768_LYR_5_HDS_16_CTX_128_LR_0.0001.pth"

# Block until the checkpoint file stops changing size (presumably it is still
# being uploaded/copied when this cell starts): re-read its size every 60
# seconds and continue once two consecutive readings agree.
size, new_size = None, os.stat(model_checkpoint).st_size
while new_size != size:
    time.sleep(60)
    size, new_size = new_size, os.stat(model_checkpoint).st_size

In [12]:
import platform

# Fetch the project's `tokenizer` and `core` packages from GitHub and extract
# the uploaded datasets archive into the working directory. Python's own file
# APIs are used instead of cmd/PowerShell/sh command strings, so one code path
# works on every OS and failures raise instead of being silently ignored.
import shutil
import stat
import subprocess
import zipfile

POET_REPO_URL = "https://github.com/n1teshy/poet"
DATASETS_ARCHIVE = "./datasets.zip"


def remove_readonly(func, path, _exc_info):
    """shutil.rmtree error handler: clear the read-only bit on `path`, then retry.

    git typically leaves read-only files under `.git`, which rmtree cannot
    delete on Windows without this.
    """
    os.chmod(path, stat.S_IWRITE)
    func(path)


if not os.path.exists(DATASETS_ARCHIVE):
    print("upload datasets")
else:
    # Only clone when a package is missing, so re-running the cell does not
    # fail on a working directory that is already set up.
    missing_packages = [name for name in ("tokenizer", "core") if not os.path.isdir(name)]
    if missing_packages:
        subprocess.run(["git", "clone", POET_REPO_URL, "poet"], check=True)
        for name in missing_packages:
            shutil.move(os.path.join("poet", name), name)
        shutil.rmtree("poet", onerror=remove_readonly)
    # Equivalent of `unzip ./datasets.zip -d .`, but overwrites existing files
    # instead of prompting interactively on a re-run.
    with zipfile.ZipFile(DATASETS_ARCHIVE) as archive:
        archive.extractall(".")

In [13]:
import torch

from google.colab import drive
import torch.nn.functional as F
from torch.optim import AdamW
from core.tokenizers.regex import get_tokenizer
from core.utils import get_param_count
from core.config import device
from core.models import Generator

In [14]:
# Mount Google Drive. Checkpoints are later written under the relative path
# drive/MyDrive/poet_params, which presumably resolves into this mount because
# the working directory is /content -- confirm if the notebook's cwd changes.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
# Create the checkpoint folder (and any missing parents) if it does not exist.
# Pure Python instead of `!mkdir ... -p`: the shell form depends on GNU option
# parsing and on the `!` magic, while os.makedirs behaves the same everywhere.
os.makedirs("drive/MyDrive/poet_params", exist_ok=True)

In [16]:
# Model architecture (all passed to Generator).
EMBEDDING_SIZE = 768
LAYERS = 5
HEADS = 16
BLOCK_SIZE = 128  # context window: length of each sampled token window

# Batching / schedule.
BATCH_SIZE = 128  # windows per batch
EPOCHS = 10  # intended training length, in epochs

# Raw text for the train/test splits.
TRAIN_FILE = "datasets/train.txt"
TEST_FILE = "datasets/test.txt"

In [17]:
# Build/load the tokenizer, then encode the whole train and test corpora into
# token-id sequences.
# NOTE(review): get_tokenizer's positional arguments ("poems.txt", 1024,
# "tokenizer/en", True) are defined in core.tokenizers.regex; presumably
# training text, vocab size, storage dir and a load/train flag -- confirm.
tokenizer = get_tokenizer("poems.txt", 1024, "tokenizer/en", True)


def _read_text(path):
    """Return the full UTF-8 text of `path`, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()


train_data = tokenizer.encode(_read_text(TRAIN_FILE))
test_data = tokenizer.encode(_read_text(TEST_FILE))

In [18]:
def get_batch(split):
    """Sample a random batch of (input, target) token windows from a split.

    `split == "train"` reads `train_data`; any other value reads `test_data`.
    BATCH_SIZE start offsets are drawn uniformly, and each target window is its
    input window shifted one token to the right (next-token prediction).

    Returns:
        Two tensors of shape (BATCH_SIZE, BLOCK_SIZE) on `device`.
    """
    source = test_data if split != "train" else train_data
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))
    windows_in, windows_out = [], []
    for start in starts:
        windows_in.append(source[start:start + BLOCK_SIZE])
        windows_out.append(source[start + 1:start + BLOCK_SIZE + 1])
    inputs = torch.tensor(windows_in, device=device)
    targets = torch.tensor(windows_out, device=device)
    return inputs, targets

In [19]:
model = Generator(tokenizer.size, EMBEDDING_SIZE, BLOCK_SIZE, LAYERS, HEADS).to(device)
# map_location remaps the stored tensors onto this session's `device`, so a
# checkpoint written on a different device (e.g. a CUDA runtime) still loads
# when that device is unavailable here.
model.load_state_dict(torch.load(model_checkpoint, map_location=device))
print("%.4f mn parameters" % (get_param_count(model) / 1e6, ))

37.0255 mn parameters


In [20]:
@torch.no_grad()
def get_test_loss():
    """Estimate held-out loss on one randomly sampled test batch.

    The model runs in eval mode for the forward pass. Its previous train/eval
    mode is restored afterwards, even if sampling the batch or the forward pass
    raises, so an evaluation call never leaves the model in the wrong mode for
    the next training step.

    Returns:
        A scalar tensor with the mean token-level cross-entropy.
    """
    was_training = model.training
    model.eval()
    try:
        inp, tgt = get_batch("test")
        logits = model(inp)
        # Flatten (batch, time) so cross_entropy sees one row per predicted token.
        B, T, C = logits.shape
        return F.cross_entropy(logits.reshape(B * T, C), tgt.reshape(B * T))
    finally:
        model.train(was_training)


In [21]:
def batch_generator(split):
    """Lazily yield an endless stream of fresh `get_batch(split)` batches."""
    # iter(callable, sentinel) keeps calling get_batch; it never returns None,
    # so the stream never ends.
    yield from iter(lambda: get_batch(split), None)

In [22]:
LEARNING_RATE = 0.0001
# A brand-new AdamW instance: no optimizer state is loaded here, so its moment
# estimates start from zero even when resuming from `model_checkpoint`.
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [23]:
# Exponential moving averages of the losses. The train EMA is warm-started at a
# hand-picked value; the test EMA starts at 0, which the training loop's
# `mean_test_loss or test_loss` replaces with the first measured test loss.
mean_train_loss = 2.95
mean_test_loss = 0

# Number of batches counted as one "epoch" in the progress log.
batch_to_epoch = len(train_data) / BATCH_SIZE

# Baseline train loss for deciding when to write a new checkpoint; matches the
# TRL_ field of `model_checkpoint`.
last_saved_train_loss = 2.9398

In [24]:
def save_model(folder):
    """Save the current model weights into `folder` and return the file path.

    The file name records the running mean train/test losses and the
    hyper-parameters, e.g. ``TRL_2.9398_TSL_3.2146_EMB_768_LYR_5_HDS_16_CTX_128_LR_0.0001.pth``.
    The learning rate uses ``%g``: identical to the old ``%.4f`` for 0.0001,
    but it no longer collapses smaller rates such as 1e-5 to ``0.0000``.
    `folder` (and any missing parents) is created if needed.
    """
    model_id = "TRL_%.4f_TSL_%.4f_EMB_%d_LYR_%d_HDS_%d_CTX_%d_LR_%g" % (
        mean_train_loss,
        mean_test_loss,
        EMBEDDING_SIZE,
        LAYERS,
        HEADS,
        BLOCK_SIZE,
        LEARNING_RATE,
    )
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"{model_id}.pth")
    torch.save(model.state_dict(), path)
    return path


In [None]:
# Train on freshly sampled random batches for EPOCHS epochs' worth of batches.
# An "epoch" is `batch_to_epoch` batches, the same unit the progress log uses.
# Previously EPOCHS was unused and the loop only stopped when interrupted.
max_batches = int(EPOCHS * batch_to_epoch)
model.train()
for batch_no, (inp, tgt) in enumerate(batch_generator("train"), start=1):
    optimizer.zero_grad()
    logits = model(inp)
    B, T, C = logits.shape
    logits, tgt = logits.reshape(B * T, C), tgt.reshape(B * T)
    train_loss = F.cross_entropy(logits, tgt)
    train_loss.backward()
    optimizer.step()

    # One held-out batch per step. Both losses are smoothed with an EMA
    # (decay 0.9975); `ema or value` seeds an EMA that is still 0 with the raw value.
    test_loss = get_test_loss()
    train_loss, test_loss = train_loss.item(), test_loss.item()
    mean_train_loss = (mean_train_loss or train_loss) * 0.9975 + train_loss * 0.0025
    mean_test_loss = (mean_test_loss or test_loss) * 0.9975 + test_loss * 0.0025
    print(
        "%d:%d -> (%.4f | %.4f), (%.4f | %.4f)"
        % (
            batch_no // batch_to_epoch + 1,
            batch_no % batch_to_epoch,
            train_loss,
            mean_train_loss,
            test_loss,
            mean_test_loss,
        )
    )

    # Checkpoint whenever the smoothed train loss has improved by at least
    # 0.02 since the last save.
    if last_saved_train_loss - mean_train_loss >= 0.02:
        save_model("drive/MyDrive/poet_params")
        print(f"saved model at train loss {mean_train_loss}")
        last_saved_train_loss = mean_train_loss

    if batch_no >= max_batches:
        break


1:1 -> (2.9211 | 2.9499), (3.2631 | 3.2631)
1:2 -> (2.9634 | 2.9500), (3.1450 | 3.2628)
1:3 -> (2.9435 | 2.9499), (3.2444 | 3.2627)
1:4 -> (2.9683 | 2.9500), (3.2153 | 3.2626)
1:5 -> (2.9657 | 2.9500), (3.1942 | 3.2624)
1:6 -> (2.9575 | 2.9500), (3.2780 | 3.2625)
1:7 -> (2.9300 | 2.9500), (3.2387 | 3.2624)
1:8 -> (2.9301 | 2.9499), (3.1663 | 3.2622)
1:9 -> (2.9663 | 2.9500), (3.2526 | 3.2621)
1:10 -> (2.9655 | 2.9500), (3.2225 | 3.2620)
1:11 -> (2.9503 | 2.9500), (3.1842 | 3.2618)
1:12 -> (2.9246 | 2.9500), (3.1845 | 3.2617)
1:13 -> (2.8518 | 2.9497), (3.1133 | 3.2613)
1:14 -> (2.9970 | 2.9498), (3.1614 | 3.2610)
1:15 -> (2.9131 | 2.9497), (3.2741 | 3.2611)
1:16 -> (2.9601 | 2.9498), (3.2871 | 3.2611)
1:17 -> (2.9142 | 2.9497), (3.2707 | 3.2612)
1:18 -> (2.8950 | 2.9495), (3.1659 | 3.2609)
1:19 -> (2.9310 | 2.9495), (3.2435 | 3.2609)
1:20 -> (2.8515 | 2.9493), (3.2222 | 3.2608)
1:21 -> (2.9836 | 2.9493), (3.2677 | 3.2608)
1:22 -> (2.9798 | 2.9494), (3.3015 | 3.2609)
1:23 -> (2.9418 | 2

In [None]:
@torch.no_grad()
def generate(text=" ", max_len=200):
    """Sample `max_len` new token ids from the model, continuing `text`.

    The context is cropped to the last BLOCK_SIZE tokens *before* every
    forward pass, so prompts longer than the context window are handled.
    Sampling runs without autograd and in eval mode; the model's previous
    train/eval mode is restored afterwards.

    Args:
        text: prompt string; it must encode to at least one token.
        max_len: number of tokens to sample.

    Returns:
        List of sampled token ids (the prompt's own ids are not included).

    Raises:
        ValueError: if `text` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(text)
    if not prompt_ids:
        raise ValueError("generate() needs a prompt that encodes to at least one token")
    context = torch.tensor([prompt_ids], device=device)
    output = []
    was_training = model.training
    # Eval mode turns off training-only behaviour (e.g. dropout, if the model has any).
    model.eval()
    try:
        for _ in range(max_len):
            context = context[:, -BLOCK_SIZE:]
            logits = model(context)
            # Only the distribution for the final position is needed: shape (1, vocab).
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            output.append(next_token.item())
            context = torch.cat((context, next_token), dim=1)
    finally:
        model.train(was_training)
    return output
# Draw one sample with the default prompt and length, then show it as text.
sample_tokens = generate()
print(tokenizer.decode(sample_tokens))

veldt and broke, growing
A little lease of deadlier words.
Or a some unknown feeling of beauty
That comes, indeed in many ways,
With it has flown away.
I have written its first song and long will never cease to run,
We make the day for hours and time, we will never forget and we will live each other that,
With racism doth live to another day.
Our mind Heavenly Spirits in the air
Of the lost and unspecial joys of God,
Of the deepest woods of Memory! Yea, shall we see
The harlots of the earth, porters of Wind,
Let fall beneath the earth and he
