In [1]:
import sentencepiece
import torchtext
torchtext.disable_torchtext_deprecation_warning()
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2

In [2]:
from dataset import BHW2Dataset, BHW2Allin1Dataset
from torch.utils.data import DataLoader



def create_dataset(split : str, path_to_data="../data"):
    de = "{}/{}.de-en.de".format(path_to_data, split)
    de_dataset = BHW2Dataset(de)
    if split == "test1":
        return de_dataset

    en = "{}/{}.de-en.en".format(path_to_data, split)
    en_dataset = BHW2Dataset(en)
    return BHW2Allin1Dataset(de_dataset, en_dataset)


def create_dataloaders(path_to_data="../data"):
    train_set = create_dataset("train")
    val_set = create_dataset("val")
    test_set = create_dataset("test1")
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader, test_loader

In [3]:
from typing import Union
import torch
from tqdm import tqdm

def train_epoch(model, loader, criterion, optimizer, device : Union[torch.device, str] ="cpu"):
    model.train()
    model.to(device)
    for de, de_lengths, en, en_lenghts in tqdm(loader):
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lenghts.max()].to(device)
        # print(de_tokens.shape, en_tokens.shape)
        optimizer.zero_grad()
        logits = model(de_tokens, en_tokens[:, :-1])
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])
        loss.backward()
        optimizer.step()

    return loss


@torch.no_grad()
def validate_epoch(model, loader, criterion, device : Union[torch.device, str] ="cpu"):
    model.eval()
    model.to(device)
    for de, de_lengths, en, en_lenghts in tqdm(loader):
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lenghts.max()].to(device)
        logits = model(de_tokens, en_tokens[:, :-1])
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])

    return loss

def train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs : int = 1, device : Union[torch.device, str] = "cpu"):
    for i in range(1, n_epochs + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device=device)
        val_loss = validate_epoch(model, val_loader, criterion, device=device)
        scheduler.step()
        print("Training epoch {} / {} : train_loss {}, val_loss {}".format(i, n_epochs, train_loss, val_loss))

In [4]:
import torchtext
torchtext.disable_torchtext_deprecation_warning()
import torch.nn as nn
import torch
from rnn_model import BHW2RNNModel
import warnings
warnings.filterwarnings("ignore")


device = torch.device("mps")
train_loader, val_loader, test_loader = create_dataloaders()
model = BHW2RNNModel(train_loader.dataset.de, train_loader.dataset.en, hidden_dim=128, device=device)

criterion = nn.CrossEntropyLoss(ignore_index=train_loader.dataset.en.pad_token)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

# for _ in range(100):
train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs=1, device=device)


100%|██████████| 4/4 [00:33<00:00,  8.42s/it]
100%|██████████| 4/4 [00:23<00:00,  5.99s/it]


Training epoch 1 / 1 : train_loss 9.75350570678711, val_loss 9.74924087524414


In [5]:
# torch.save(model.state_dict(), "./rnn.pth")

In [6]:
# from rnn_model import BHW2RNNModel

# model = BHW2RNNModel(train_loader.dataset.de, train_loader.dataset.en, hidden_dim=128, device=device)
# model.load_state_dict(torch.load("./rnn.pth", map_location=device))
# model.to(device)


In [7]:

model.dec.dataset.bos_token

2

In [8]:
model.to(device)
model.enc.to(device)
model.dec.to(device)
def form_test_set_predictions():
    translations = []
    for i in tqdm(range(len(test_loader.dataset))):
        idx = model.inference(test_loader.dataset[i][0].to(device))
        translations.append(train_loader.dataset.en.idx2token(idx)[1:-1])
        if i > 128:
            break
    return translations


form_test_set_predictions()

100%|██████████| 128/128 [00:23<00:00,  5.37it/s]


[['remember',
  'strait',
  'transactions',
  'squatters',
  'transactions',
  'i.',
  'vice',
  'photograph',
  'farms',
  'others',
  'fools',
  'descendants',
  'serif',
  'special',
  'basking',
  'branches',
  'canadian',
  'techie',
  'sovereignty',
  'killer',
  'surrogate',
  'encountering',
  'ignores',
  'profiles',
  'noor',
  'wakes',
  'gunfire',
  'irs',
  'garden',
  'requirement',
  'spacing',
  'liberties',
  'per',
  'twilight',
  'remembrance',
  '1997',
  'philharmonic',
  'tours',
  'receivers',
  'darwin',
  'homage',
  'singer',
  'unacceptable',
  'decommissioned',
  'sets',
  'peer-reviewed',
  'out',
  'continuity',
  'safer',
  'proud',
  'plumbing',
  'vanilla',
  'sink',
  'demand',
  'bulbs',
  'prior',
  'dark',
  'jimmy',
  'plumbing',
  'ones',
  'scanning',
  'block',
  'licenses',
  'on',
  'narrower',
  'awkward',
  'campfires',
  'trailer',
  'midwest',
  'collapsing',
  '130',
  'rupert',
  'weaknesses',
  'valentine',
  'tougher',
  'roofs',
  'tr

In [9]:
batch = next(iter(train_loader))



In [None]:
# batch.to(device)
de_tokens, de_lenghts = batch[0].to(device), batch[1].to(device)
en_tokens, en_lenghts = batch[2].to(device), batch[3].to(device)
de_tokens = de_tokens[:, :de_lenghts.max()]

embeds = model.enc.embeddings(de_tokens[:, :-1])
model.enc.rnn(embeds)[1].shape

model.dec()

TypeError: BHW2RNNDecoder.forward() missing 2 required positional arguments: 'encoder_hidden' and 'target_idx'