In [48]:
import sentencepiece
import torchtext
torchtext.disable_torchtext_deprecation_warning()
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
from dataset import BHW2Dataset, BHW2Allin1Dataset
from torch.utils.data import DataLoader



def create_dataset(split : str, path_to_data="../data"):
    """Build the dataset object for one split of the de-en corpus.

    The German file ``{path_to_data}/{split}.de-en.de`` is always loaded.
    For every split except ``"test1"`` the English file
    ``{path_to_data}/{split}.de-en.en`` is loaded too, and both sides are
    wrapped together in a ``BHW2Allin1Dataset``.

    Args:
        split: file-name prefix of the split, e.g. ``"train"``, ``"val"``, ``"test1"``.
        path_to_data: directory containing the corpus files.

    Returns:
        A ``BHW2Dataset`` (German only) for ``"test1"``; otherwise a
        ``BHW2Allin1Dataset`` combining the German and English datasets.
    """
    source_side = BHW2Dataset("{}/{}.de-en.de".format(path_to_data, split))
    if split == "test1":
        # Only the German side is loaded for this split.
        return source_side
    target_side = BHW2Dataset("{}/{}.de-en.en".format(path_to_data, split))
    return BHW2Allin1Dataset(source_side, target_side)


def create_dataloaders(path_to_data="../data", batch_size=32, num_workers=4, pin_memory=True):
    """Build train / validation / test DataLoaders for the de-en corpus.

    Args:
        path_to_data: directory holding the ``{split}.de-en.*`` files; forwarded
            to ``create_dataset`` for every split (previously it was ignored and
            the default path was always used).
        batch_size: samples per batch, shared by all three loaders.
        num_workers: worker processes per loader.
        pin_memory: forwarded to ``DataLoader``; pinned host memory mainly helps
            host-to-CUDA transfers.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader shuffles.
    """
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_set = create_dataset("train", path_to_data=path_to_data)
    val_set = create_dataset("val", path_to_data=path_to_data)
    test_set = create_dataset("test1", path_to_data=path_to_data)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

In [51]:
from typing import Union
import torch
from tqdm import tqdm

def train_epoch(model, loader, criterion, optimizer, device : Union[torch.device, str] ="cpu"):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Each batch is unpacked as ``(de, de_lengths, en, en_lengths)``. Both token
    tensors are trimmed to the longest sequence in the batch; the model receives
    the target without its last token and the loss is computed against the
    target without its first token (next-token prediction with teacher forcing).

    Args:
        model: module called as ``model(src_tokens, tgt_input_tokens)``. Its
            output is permuted with ``(0, 2, 1)`` before the criterion, so it is
            presumably ``(batch, seq, vocab)`` logits — TODO confirm against rnn_model.
        loader: iterable of batches as described above.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the model and batches are moved to.

    Returns:
        Per-batch loss averaged over the epoch, weighted by batch size, as a float.
        (Previously only the last batch's loss tensor was returned.)

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    model.to(device)
    loss_sum = None
    n_samples = 0
    for de, de_lengths, en, en_lengths in tqdm(loader, desc="train"):
        # Drop padding columns beyond the longest sequence in this batch.
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lengths.max()].to(device)
        optimizer.zero_grad()
        logits = model(de_tokens, en_tokens[:, :-1])
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])
        loss.backward()
        optimizer.step()
        # Accumulate detached and on-device: no host sync per step and the
        # autograd graph of the batch is not kept alive.
        batch_size = en_tokens.size(0)
        weighted = loss.detach() * batch_size
        loss_sum = weighted if loss_sum is None else loss_sum + weighted
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return (loss_sum / n_samples).item()


@torch.no_grad()
def validate_epoch(model, loader, criterion, device : Union[torch.device, str] ="cpu"):
    model.eval()
    model.to(device)
    for de, de_lengths, en, en_lenghts in tqdm(loader):
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lenghts.max()].to(device)
        logits = model(de_tokens, en_tokens[:, :-1])
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])

    return loss

def train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs : int = 1, device : Union[torch.device, str] = "cpu"):
    """Alternate training and validation for ``n_epochs`` epochs.

    After each epoch the scheduler is stepped once and a one-line summary is printed.

    Args:
        model, optimizer, criterion, device: forwarded to ``train_epoch`` /
            ``validate_epoch``.
        train_loader: batches used for optimisation.
        val_loader: batches used for evaluation.
        scheduler: LR scheduler stepped once per epoch.
        n_epochs: number of epochs to run.

    Returns:
        List of ``(train_loss, val_loss)`` tuples, one per epoch, holding whatever
        ``train_epoch`` / ``validate_epoch`` returned. Useful for plotting or
        early stopping; callers that ignore the return value are unaffected.
    """
    history = []
    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device=device)
        val_loss = validate_epoch(model, val_loader, criterion, device=device)
        scheduler.step()
        history.append((train_loss, val_loss))
        # Fixed precision keeps the line readable; format specs also work on 0-dim tensors.
        print("Training epoch {} / {} : train_loss {:.4f}, val_loss {:.4f}".format(
            epoch, n_epochs, train_loss, val_loss))
    return history

In [None]:
import torch.nn as nn
import torch
from rnn_model import BHW2RNNModel

# Reproducibility: fixes the RNG used for weight init and DataLoader shuffling.
torch.manual_seed(0)

# Prefer an available accelerator instead of hard-coding "mps", which fails
# on machines without Apple-silicon support.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_loader, val_loader, test_loader = create_dataloaders()
model = BHW2RNNModel(train_loader.dataset.de, train_loader.dataset.en, hidden_dim=128)

# Target positions equal to en.pad_token (presumably padding) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=train_loader.dataset.en.pad_token)
optimizer = torch.optim.Adam(model.parameters())
# factor=1.0 leaves the learning rate unchanged; placeholder for a real schedule.
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs=1, device=device)


  0%|          | 1/6123 [00:11<19:06:50, 11.24s/it]

: 