In [1]:
import torch
import torch.nn.functional as F
import time

from in_torch.text_loader import CharDataset, TextLoader
from torch.utils.data import DataLoader
from in_torch.lstm import CharLSTM, generate

# ---------------------------------------------
# data prep
# Read the whole corpus. The context manager closes the file handle as soon as
# the text is in memory, and the explicit encoding keeps decoding independent
# of the platform's default locale.
with open('../gpt/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

batch_size = 16
block_size = 64

base_ds = CharDataset(text, block_size)
ds_train = base_ds.split('train')
ds_val = base_ds.split('val')

# Pinned host memory only speeds up host-to-CUDA copies, so enable it only when
# CUDA is present rather than unconditionally.
pin_memory = torch.cuda.is_available()

# persistent_workers=True keeps the worker processes alive between iterators,
# so creating a new iterator over a loader does not start the workers again.
loader_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=4,
    persistent_workers=True,
    drop_last=True,
)
loader_val = DataLoader(
    ds_val,
    batch_size=1000,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=4,
    persistent_workers=True,
    drop_last=True,
)

# ------------------------------------------------------------
# Select the compute backend. MPS is still preferred when present (the original
# setting); otherwise fall back to CUDA and finally CPU so the notebook also
# runs on machines without Apple-silicon acceleration.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Character-level LSTM sized by the training split's vocabulary.
model = CharLSTM(ds_train.vocab_size, emb=256, hidden=256, layers=3)
print(f'{(sum(p.numel() for p in model.parameters()) / 1e6):.4f} Mil parameters')
model = model.to(device)

# The optimizer is created after the model has been moved to `device`.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


1.6124 Mil parameters


In [None]:
max_iters = 20000
patience = 20            # consecutive non-improving validation checks tolerated
best_val_loss = float('inf')
epochs_no_improve = 0    # counts validation checks (one per iteration), not dataset epochs


def _cycle(loader):
    """Yield batches from `loader` indefinitely, starting a new pass whenever one ends.

    Raises:
        RuntimeError: if a complete pass over `loader` yields no batch at all
            (otherwise the generator would spin forever without producing data).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError('DataLoader produced no batches')


# One long-lived iterator per loader. Calling `next(iter(loader))` on every
# step would build a brand-new DataLoader iterator each time, which is very
# expensive for a multi-worker loader.
train_batches = _cycle(loader_train)
val_batches = _cycle(loader_val)

model.train()
for i in range(max_iters):
    xb, yb = next(train_batches)
    xb, yb = xb.to(device), yb.to(device)

    logits, _ = model(xb)
    loss = F.cross_entropy(logits.view(-1, ds_train.vocab_size), yb.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # validation: switch to eval mode so train-only behaviour (e.g. any dropout
    # layers in the model) does not distort the validation loss, then restore
    # train mode for the next optimisation step.
    model.eval()
    with torch.no_grad():
        x_val, y_val = next(val_batches)
        x_val, y_val = x_val.to(device), y_val.to(device)
        logits, _ = model(x_val)
        val_loss = F.cross_entropy(logits.view(-1, ds_train.vocab_size), y_val.view(-1))
    model.train()
    val_loss_value = val_loss.item()  # plain float for comparison and bookkeeping

    # early-stopping
    if val_loss_value < best_val_loss - 1e-4:  # small delta to be considered
        best_val_loss = val_loss_value
        epochs_no_improve = 0
        # Checkpoint the best model so far together with what is needed to resume.
        torch.save(
            {
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'vocab': ds_train.vocab,
            },
            "char_lstm_best_model.pt"
        )
    else:
        epochs_no_improve += 1
    if epochs_no_improve > patience:
        print(f'Early stop @ iter {i}. \nBest validation loss = {best_val_loss:.4f}')
        break

    if i % 10 == 0:
        print(f'Iter {i}, train loss = {loss.item():.4f} | val loss = {val_loss_value:.4f}')



time for 0th epoch: 2.50 seconds
Iter 0, train loss = 4.1883 | val loss = 4.1798
time for 1th epoch: 2.80 seconds
time for 2th epoch: 1.56 seconds
time for 3th epoch: 1.48 seconds
time for 4th epoch: 2.74 seconds
time for 5th epoch: 2.31 seconds
time for 6th epoch: 2.17 seconds
time for 7th epoch: 1.64 seconds
time for 8th epoch: 2.03 seconds
time for 9th epoch: 2.16 seconds
time for 10th epoch: 1.63 seconds
Iter 10, train loss = 3.9979 | val loss = 3.9543
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/macbook/miniforge3/envs/torch/lib/python3.12/multiprocessing/__init__.py", line 16, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/macbook/miniforge3/envs/torch/lib/python3.12/multiprocessing/spawn.py", line 118, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/macbook/miniforge3/envs/torch/lib/python3.12/multiprocessing/__init__.py", line 16, in <module>
    from . import resource_tracker
  File "/Users/macbook/miniforge3/envs/torch/lib/python3.12/multiprocessing/resource_tracker.py", line 37, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/macbook/miniforge3/envs/torch/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    import _multiprocessing
KeyboardInterrupt
            from .

In [None]:
# Load the best checkpoint written by the training cell, mapping its tensors
# onto the active device, then list the top-level entries it contains.
checkpoint = torch.load(
    'char_lstm_best_model.pt',
    map_location=device,
)
print(checkpoint.keys())

In [None]:
# Show the vocabulary stored in the checkpoint, then derive both lookup tables
# from it: character -> index (stoi) and index -> character (itos).
print(checkpoint['vocab'])

vocab = checkpoint["vocab"]
stoi = dict((ch, idx) for idx, ch in enumerate(vocab))
itos = dict((idx, ch) for ch, idx in stoi.items())

In [None]:
# --- rebuild model -------------------------------
# The architecture must match the one that produced the checkpoint: the model
# was trained with layers=3, so rebuilding with a different depth would make
# load_state_dict fail on missing/unexpected parameter keys.
model = CharLSTM(len(vocab), emb=256, hidden=256,
                 layers=3, dropout=0.5).to(device)
model.load_state_dict(checkpoint["model_state"])
# Inference mode; as the last expression this also displays the module summary.
model.eval()

In [None]:
# --- rebuild optimizer (optional) ---------------------------------------
# Recreate AdamW over the rebuilt model's parameters with the same learning
# rate used for training, then restore its saved state from the checkpoint.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=3e-4,
)
optimizer.load_state_dict(state_dict=checkpoint["optimizer_state"])

In [None]:
# Sample text from the restored model. Pass the shared `device` variable (the
# one the model was moved to) instead of a hard-coded backend string, so the
# two cannot drift apart.
generate(model, stoi, itos, block_size=block_size, device=device)