In [None]:
import torch
import torch.nn.functional as F
import time

from text_loader import CharDataset #, TextLoader
from torch.utils.data import DataLoader
from lstm import CharLSTM, generate

# ---------------------------------------------
# data prep
# Read the full training corpus. The context manager closes the file handle
# immediately, and the explicit encoding keeps decoding independent of the
# platform's locale default.
with open('../data/pg_essays.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

batch_size = 2
block_size = 8

# Pick the compute device once, preferring Apple's MPS backend (the device this
# notebook originally targeted) and falling back to CUDA or the CPU so the
# notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

base_ds = CharDataset(text, block_size)
ds_train = base_ds.split('train')
ds_val = base_ds.split('val')

# Pinned host memory only helps host-to-CUDA copies; requesting it on other
# devices merely makes the DataLoader emit a warning.
pin_memory = device == 'cuda'
loader_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=4,
    drop_last=True,
)
# Validation uses a much larger batch so a single batch gives a steadier loss estimate.
loader_val = DataLoader(
    ds_val,
    batch_size=1000,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=4,
    drop_last=True,
)

# ------------------------------------------------------------
# model + optimizer
model = CharLSTM(ds_train.vocab_size, emb=64, hidden=64, layers=2)
print(f'{(sum(p.numel() for p in model.parameters()) / 1e6):.4f} Mil parameters')
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


In [None]:
# Training budget and early-stopping state.
max_iters = 20000
patience = 20                   # consecutive non-improving validation checks tolerated
best_val_loss = float('inf')    # lowest validation loss seen so far (plain float)
epochs_no_improve = 0           # validation checks since the last improvement


def _cycle(loader):
    """Yield batches from `loader` forever, starting a new pass whenever one ends.

    Keeping one long-lived iterator avoids rebuilding the DataLoader iterator
    (and respawning its worker processes) on every training step.

    Raises:
        ValueError: if the loader produces no batches at all, which would
            otherwise make this generator spin forever.
    """
    if len(loader) == 0:
        raise ValueError('DataLoader yields no batches; dataset is too small for its batch_size with drop_last=True')
    while True:
        yield from loader


train_batches = _cycle(loader_train)
val_batches = _cycle(loader_val)

try:
    for i in range(max_iters):
        t0 = time.time()

        # ---- one optimisation step on a single training batch ----
        model.train()
        xb, yb = next(train_batches)
        xb, yb = xb.to(device), yb.to(device)

        logits, _ = model(xb)
        # Flatten so every position in every sequence is one classification example.
        loss = F.cross_entropy(logits.view(-1, ds_train.vocab_size), yb.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ---- validation on a single batch, without gradients ----
        # eval() disables train-only behaviour (e.g. dropout) while measuring.
        model.eval()
        with torch.no_grad():
            x_val, y_val = next(val_batches)
            x_val, y_val = x_val.to(device), y_val.to(device)
            logits, _ = model(x_val)
            val_loss = F.cross_entropy(logits.view(-1, ds_train.vocab_size), y_val.view(-1))

        # ---- early stopping ----
        # Only an improvement larger than a small delta counts; the best model
        # so far is checkpointed together with the optimizer state and vocab.
        if val_loss.item() < best_val_loss - 1e-4:
            best_val_loss = val_loss.item()
            epochs_no_improve = 0
            torch.save(
                {
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'vocab': ds_train.vocab,
                },
                "models/char_lstm_pg.pt"
            )
        else:
            epochs_no_improve += 1
        if epochs_no_improve > patience:
            print(f'Eearly stop @ epoch {i}. \nBest validation loss = {best_val_loss:.4f}')
            break

        if i % 250 == 0:
            print(f'Iter {i}, train loss = {loss.item():.4f} | val loss = {val_loss.item():.4f}')
        t1 = time.time()
        print(f'time for {i}th epoch: {(t1-t0):.2f} seconds')
finally:
    # Release the long-lived iterators so their worker processes can shut down,
    # including when training is interrupted.
    train_batches.close()
    val_batches.close()

In [None]:
# Restore a saved checkpoint, remapping its tensors onto the active device.
# NOTE(review): the training cell writes "models/char_lstm_pg.pt", not this
# file — confirm this is the checkpoint you intend to load.
checkpoint = torch.load(
    'models/char_lstm_best_model.pt',
    map_location=device,
)
# Show which entries the checkpoint carries.
print(checkpoint.keys())

In [None]:
# Pull the vocabulary stored alongside the weights and display it.
vocab = checkpoint["vocab"]
print(vocab)

# Character -> index lookup, indexed by position in the vocabulary, and its
# inverse (index -> character) derived from that same lookup.
stoi = {ch: idx for idx, ch in enumerate(vocab)}
itos = {idx: ch for ch, idx in stoi.items()}

In [None]:
# --- rebuild model -------------------------------
# Recreate the architecture sized to the checkpoint's vocabulary, move it to
# the active device, then restore the trained weights.
# NOTE(review): emb/hidden/dropout here differ from the model built in the
# training cell (emb=64, hidden=64) — they must match whatever produced the
# checkpoint being loaded; confirm.
model = CharLSTM(
    len(vocab),
    emb=256,
    hidden=256,
    layers=2,
    dropout=0.5,
)
model = model.to(device)
model.load_state_dict(checkpoint["model_state"])
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

In [None]:
# --- rebuild optimizer (optional) ---------------------------------------
# Only needed when training is going to be resumed from the checkpoint:
# create a fresh AdamW over the restored model's parameters, then load the
# saved optimizer state into it.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
)
optimizer.load_state_dict(checkpoint["optimizer_state"])

In [None]:
# Sample 500 new tokens from the restored model.
# The model was rebuilt from the checkpoint's vocabulary, so encode/decode with
# the stoi/itos mappings derived from that same vocabulary rather than the
# training dataset's, and run on the notebook-wide `device` instead of a
# hard-coded backend.
generate(model, stoi, itos, block_size=block_size, device=device, max_new_tokens=500)