In [None]:
import torch
from src.data.components.helsinki_dataset import Dataset, load_dataset, pad

In [None]:
import os

from omegaconf import DictConfig, OmegaConf

# Experiment configuration, passed below to load_dataset(), Dataset() and
# BertRegression(). It is built in a single expression so `config` is a
# DictConfig from the start rather than a dict that is later rebound.
#
# "datadir" can be overridden without editing the notebook by setting the
# HELSINKI_PROSODY_DATADIR environment variable; when it is unset the original
# local path is used, so existing runs behave exactly as before.
config = OmegaConf.create(
    {
        "train_set": "train_360",
        "datadir": os.environ.get(
            "HELSINKI_PROSODY_DATADIR",
            "/Users/lukas/Desktop/projects/MIT/prosody/prosody/repositories/helsinki-prosody/data/",
        ),
        "fraction_of_train_data": 1,
        "nclasses": 2,
        "shuffle_sentences": True,
        "sorted_batches": True,
        "model": "gpt2",
        "log_values": False,
        "invalid_set_to": False,
        "mask_invalid_grads": True,
    }
)

In [None]:
# Load the corpus described by `config`. The call returns four values:
#   splits        - indexed below with the keys "train", "dev" and "test"
#   tag_to_index  - passed to every Dataset constructed below
#   index_to_tag  - presumably the inverse of tag_to_index (confirm in load_dataset)
#   vocab         - not used in the cells visible here
splits, tag_to_index, index_to_tag, vocab = load_dataset(config)

In [None]:
# No word -> embedding-id mapping is supplied; every Dataset receives None.
word_to_embid = None

# Wrap each split in a Dataset. The generator is consumed in order, so the
# datasets are constructed train -> dev -> test, sharing the same tag mapping
# and config.
train_dataset, eval_dataset, test_dataset = (
    Dataset(splits[split_name], tag_to_index, config, word_to_embid)
    for split_name in ("train", "dev", "test")
)

In [None]:
# Batch the training set. `pad` is the collate function that assembles each
# batch; shuffle=False makes the loader walk the dataset in index order, and
# num_workers=0 loads batches in the main process.
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=pad,
)

# Number of batches per epoch (displayed as the cell output).
len(dataloader)

In [None]:
from transformers import BertModel, BertTokenizer, GPT2Tokenizer

from src.models.helsinki_models import BertRegression
from src.data.components.helsinki_dataset import weighted_mse_loss

# Choose the compute device once, before anything uses it, and reuse the same
# value everywhere below. Apple-silicon MPS is preferred (the original
# hard-coded "mps"); CUDA and then CPU are fallbacks so the cell also runs on
# machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# NOTE(review): the first positional argument is assumed to be the device
# string (the original passed the literal "mps") - confirm in BertRegression.
model = BertRegression(device, config).to(device)

# criterion = weighted_mse_loss
criterion = torch.nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)

In [None]:
# Run one pass over `dataloader`, tracking a rolling average of the batch loss.
from collections import deque
from tqdm import tqdm

# Rolling window of the 50 most recent batch losses; older values are
# discarded automatically once the deque is full.
queue = deque(maxlen=50)

total_iterations = len(dataloader)

# NOTE(review): `criterion` is an MSELoss and the model is BertRegression, but
# the active code below unpacks three values from the model and flattens
# `logits` / `y` the way a classification loss would expect. The commented-out
# REGRESSION variant unpacks two values instead. Confirm which path matches
# the model's forward() before trusting the reported loss.

# A single progress bar drives the loop. Its description shows the rolling
# average loss, and the context manager closes the bar even if a batch raises.
with tqdm(
    total=total_iterations,
    desc="Loss: N/A",
    bar_format="{desc} |{bar}| {percentage:3.0f}% {r_bar}",
) as pbar:
    for i, batch in enumerate(dataloader):
        words, x, is_main_piece, tags, y, seqlens, values, _ = batch

        # REGRESSION
        # optimizer.zero_grad()
        # x = x.to(device)
        # values = values.to(device)
        # predictions, true = model(x, values)
        # loss = criterion(predictions.to(device), true.float().to(device))
        # loss.backward()
        # optimizer.step()

        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)
        logits, y, _ = model(x, y)  # logits: (N, T, VOCAB), y: (N, T)
        logits = logits.view(-1, logits.shape[-1])  # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)
        loss = criterion(logits.to(device), y.to(device))
        loss.backward()
        optimizer.step()

        queue.append(loss.item())
        avg_loss = sum(queue) / len(queue)

        pbar.set_description(f"Loss: {avg_loss:.4f}")
        pbar.update(1)

        if (i + 1) % 50 == 0:
            # tqdm.write prints above the bar instead of corrupting it.
            tqdm.write(f"Avg loss: {avg_loss}")