In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import string

from dataset import ELMoPretrainDataset, PadCollate
from model import ELMoPretrainModel

In [None]:
# Training data configuration. Tunable values are named here rather than
# buried in the constructor calls below.
TRAIN_PATH = "wikitext-2/wiki.train.tokens"
SEQ_LEN = 100     # passed as seq_len to ELMoPretrainDataset (see dataset.py for exact semantics)
BATCH_SIZE = 32

dataset = ELMoPretrainDataset(TRAIN_PATH, seq_len=SEQ_LEN)

# Batches are reshuffled every epoch and assembled by the project's PadCollate
# collate function.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=PadCollate(),
)

In [None]:
# Model, loss and optimizer setup.
# Prefer Apple's MPS backend (the original target), but fall back to CUDA or
# CPU so the notebook still runs on machines without MPS instead of crashing.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

MODEL_DIM = 128  # second positional arg of ELMoPretrainModel — TODO confirm its meaning in model.py
LR = 3e-6        # NOTE(review): very small for Adam (3e-4 is a common default) — confirm this is intended
EPOCHS = 100

model = ELMoPretrainModel(len(dataset.char2idx), MODEL_DIM, len(dataset.word2idx)).to(DEVICE)
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), LR)

In [None]:
# Training loop.
# The device is read from the model's own parameters, so this cell always
# matches wherever the setup cell placed the model (MPS, CUDA or CPU).
device = next(model.parameters()).device
MAX_GRAD_NORM = 5.0  # gradient-norm clipping threshold

model.train()  # make sure any training-only layers (e.g. dropout, if present) are active
for e in range(EPOCHS):
    # Epochs are displayed 1-based so the final epoch reads [EPOCHS/EPOCHS].
    loop = tqdm(loader, total=len(loader), position=0,
                desc=f"Epoch : [{e + 1}/{EPOCHS}]")
    for src, tgt in loop:
        src, tgt = src.to(device), tgt.to(device)
        opt.zero_grad()
        yhat = model(src)
        # Flatten to (N, num_classes) logits and (N,) targets for CrossEntropyLoss.
        # reshape (unlike view) also works when the model output is non-contiguous.
        # NOTE(review): if PadCollate pads tgt, padded positions are counted in the
        # loss; consider nn.CrossEntropyLoss(ignore_index=<pad idx>) — confirm pad index.
        loss = crit(yhat.reshape(-1, yhat.shape[-1]), tgt.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        opt.step()
        # Report the loss once via the progress bar. A per-step print flooded the
        # cell output and forced an extra host-device sync with a second .item().
        loop.set_postfix(loss=loss.item())