In [None]:
import time
import torch.nn.functional as F
import torch
from lstm import MultiLayerLSTM, estimate_loss, TextLoader, generate

In [None]:
DATA_PATH = '/kaggle/input/shakespeare-text/input.txt'
SEED = 1337

# Seed torch's RNG (on all devices) before building the model so weight init, and any
# torch-based batch sampling inside TextLoader, is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Read the corpus with a context manager so the file handle is closed, not leaked.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    text = f.read()
loader = TextLoader(text)

# reduce dims a little if you want faster iteration
n_embd = 384   # try 384 or 512
n_hidden = 384
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
block_size = 128

model = MultiLayerLSTM(loader.vocab_size, input_embd=n_embd, hidden_embd=n_hidden, layers=2, dropout=0.3)
model.to(device)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.3f} MIL params')

optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
max_iters = 20000
# Early stopping: stop once val loss has failed to improve for more than `patience`
# consecutive evaluation rounds (not epochs, despite the counter's name).
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

In [None]:
# Validation scores one random batch of `val_batch_size` sequences every `eval_interval`
# steps, so the reported val loss is a noisy single-batch estimate.
eval_interval = 200
val_batch_size = 512
best_state = None  # copy of the weights with the lowest val loss seen during this run

model.train()
for i in range(max_iters):
    x, y = loader.get_batch('train', batch_size=batch_size, block_size=block_size, device=device)
    B, T = x.shape

    logits = model(x)
    # Flatten to (B*T, vocab) logits vs (B*T,) targets; reshape also tolerates
    # non-contiguous logits where view would raise.
    loss = F.cross_entropy(logits.reshape(B * T, -1), y.reshape(B * T))

    optim.zero_grad()
    loss.backward()
    # gradient clipping: cap the global grad norm at 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()

    if i % eval_interval == 0:
        # The model was built with dropout=0.3; evaluating in train mode would apply dropout
        # and inflate/randomize the val loss. Assumes MultiLayerLSTM's dropout respects
        # the module's train/eval flag — TODO confirm in lstm.py.
        model.eval()
        with torch.no_grad():
            x_val, y_val = loader.get_batch('val', batch_size=val_batch_size, block_size=block_size, device=device)
            B_val, T_val = x_val.shape
            logits_val = model(x_val)
            # .item() keeps best_val_loss a plain float rather than a (possibly CUDA) tensor.
            val_loss = F.cross_entropy(logits_val.reshape(B_val * T_val, -1), y_val.reshape(B_val * T_val)).item()
        model.train()

        if val_loss < best_val_loss - 1e-4:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Snapshot the weights so the best checkpoint can be restored after early stopping.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            epochs_no_improve += 1  # counts evaluation rounds, not epochs

        print(f'Iter {i} | train loss = {loss.item():.4f} | val loss = {val_loss:.4f}')

    if epochs_no_improve > patience:
        print(f'Early stopping at iter {i}. Best val loss = {best_val_loss:.4f}')
        break

# Leave the model at the best checkpoint found in this run rather than the last
# (post-plateau) weights, so downstream cells use what early stopping selected.
if best_state is not None:
    model.load_state_dict(best_state)

In [None]:
# Sample with dropout disabled, on whichever device the model was placed on
# (a hard-coded 'cuda' fails on CPU-only runtimes).
model.eval()
generate(model, loader.stoi, loader.itos, block_size=block_size, device=device, max_new_tokens=500)