In [1]:
import sentencepiece
import torchtext
torchtext.disable_torchtext_deprecation_warning()
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2



In [3]:
from dataset import BHW2Dataset, BHW2Allin1Dataset
from torch.utils.data import DataLoader



def create_dataset(split : str, path_to_data="../data"):
    """Load the dataset object(s) for one corpus split.

    The German file ``<path_to_data>/<split>.de-en.de`` is always wrapped in a
    ``BHW2Dataset``. For every split except ``"test1"`` the English file
    ``<path_to_data>/<split>.de-en.en`` is wrapped as well and both are
    combined into a ``BHW2Allin1Dataset``.

    Args:
        split: Name of the split; used as the file-name prefix.
        path_to_data: Directory that holds the corpus files.

    Returns:
        The German ``BHW2Dataset`` alone for ``"test1"``, otherwise a
        ``BHW2Allin1Dataset`` built from the German and English datasets.
    """
    source_path = "{}/{}.de-en.de".format(path_to_data, split)
    source = BHW2Dataset(source_path)

    # "test1" is the only split handled without an English side.
    has_target_side = split != "test1"
    if has_target_side:
        target_path = "{}/{}.de-en.en".format(path_to_data, split)
        target = BHW2Dataset(target_path)
        return BHW2Allin1Dataset(source, target)

    return source


def create_dataloaders(path_to_data="../data", batch_size=32, num_workers=4):
    """Build the train, validation and test ``DataLoader`` objects.

    The datasets come from ``create_dataset`` for the splits ``"train"``,
    ``"val"`` and ``"test1"``. Only the training loader shuffles its data.

    Args:
        path_to_data: Directory holding the corpus files; forwarded to
            ``create_dataset`` for every split.
        batch_size: Batch size shared by the three loaders.
        num_workers: Number of worker processes shared by the three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # Forward the directory explicitly; otherwise create_dataset falls back to
    # its own default and a custom path_to_data would be silently ignored.
    train_set = create_dataset("train", path_to_data)
    val_set = create_dataset("val", path_to_data)
    test_set = create_dataset("test1", path_to_data)

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    val_loader = DataLoader(val_set, shuffle=False, **shared)
    test_loader = DataLoader(test_set, shuffle=False, **shared)
    return train_loader, val_loader, test_loader

In [4]:
from typing import Union
import torch
from tqdm import tqdm

def train_epoch(model, loader, criterion, optimizer, device : Union[torch.device, str] ="cpu"):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Every batch is unpacked as ``(de, de_lengths, en, en_lengths)``. Both token
    tensors are cropped along their second dimension to the largest length in
    the batch. The model receives the source tokens plus the target without its
    last position, and the loss is taken against the target without its first
    position.

    Args:
        model: Called as ``model(de_tokens, en_tokens[:, :-1])``; switched to
            training mode and moved to ``device``.
        loader: Iterable of ``(de, de_lengths, en, en_lengths)`` batches.
        criterion: Called as ``criterion(logits.permute(0, 2, 1), targets)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the model and the cropped batches are moved to.

    Returns:
        A detached 0-dim tensor: the loss averaged over all batches of the
        epoch (not just the final batch).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    model.to(device)

    loss_sum = None
    n_batches = 0
    for de, de_lengths, en, en_lengths in tqdm(loader):
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lengths.max()].to(device)

        optimizer.zero_grad()
        logits = model(de_tokens, en_tokens[:, :-1])
        # Move the last logits dimension to position 1 before scoring
        # (assumes logits are (batch, position, vocab) — TODO confirm).
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy so the autograd graph of each batch can be
        # freed and the reported value does not carry a grad_fn.
        batch_loss = loss.detach()
        loss_sum = batch_loss if loss_sum is None else loss_sum + batch_loss
        n_batches += 1

    if n_batches == 0:
        raise ValueError("loader yielded no batches")
    return loss_sum / n_batches


@torch.no_grad()
def validate_epoch(model, loader, criterion, device : Union[torch.device, str] ="cpu"):
    """Evaluate ``model`` on ``loader`` and return the mean batch loss.

    Runs with gradient tracking disabled. Batches are unpacked and cropped the
    same way as in training: ``(de, de_lengths, en, en_lengths)``, each token
    tensor cut along its second dimension to the largest length in the batch.

    Args:
        model: Called as ``model(de_tokens, en_tokens[:, :-1])``; switched to
            evaluation mode and moved to ``device``.
        loader: Iterable of ``(de, de_lengths, en, en_lengths)`` batches.
        criterion: Called as ``criterion(logits.permute(0, 2, 1), targets)``.
        device: Device the model and the cropped batches are moved to.

    Returns:
        A 0-dim tensor: the loss averaged over all batches (not just the
        final batch).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    model.to(device)

    loss_sum = None
    n_batches = 0
    for de, de_lengths, en, en_lengths in tqdm(loader):
        de_tokens = de[:, :de_lengths.max()].to(device)
        en_tokens = en[:, :en_lengths.max()].to(device)
        logits = model(de_tokens, en_tokens[:, :-1])
        # Targets are the English tokens shifted left by one position.
        loss = criterion(logits.permute(0, 2, 1), en_tokens[:, 1:])

        loss_sum = loss if loss_sum is None else loss_sum + loss
        n_batches += 1

    if n_batches == 0:
        raise ValueError("loader yielded no batches")
    return loss_sum / n_batches

def train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs : int = 1, device : Union[torch.device, str] = "cpu"):
    """Alternate training and validation passes for ``n_epochs`` epochs.

    Each epoch calls ``train_epoch`` on ``train_loader``, then
    ``validate_epoch`` on ``val_loader``, steps ``scheduler`` once and prints a
    one-line summary containing both losses.

    Args:
        model: Model handed to both epoch functions.
        train_loader: Batches used for the training pass.
        val_loader: Batches used for the validation pass.
        optimizer: Optimizer handed to ``train_epoch``.
        scheduler: Object whose ``step()`` is called once per epoch.
        criterion: Loss callable handed to both epoch functions.
        n_epochs: Number of epochs to run.
        device: Device forwarded to both epoch functions.
    """
    summary = "Training epoch {} / {} : train_loss {}, val_loss {}"
    for finished in range(n_epochs):
        epoch_number = finished + 1  # epochs are reported 1-based
        epoch_train_loss = train_epoch(model, train_loader, criterion, optimizer, device=device)
        epoch_val_loss = validate_epoch(model, val_loader, criterion, device=device)
        scheduler.step()
        print(summary.format(epoch_number, n_epochs, epoch_train_loss, epoch_val_loss))

In [None]:
import torchtext
torchtext.disable_torchtext_deprecation_warning()
import torch.nn as nn
import torch
# from rnn_model import BHW2AttnRNNModel as BHW2RNNModel
from rnn_model import BHW2RNNModel
import warnings
warnings.filterwarnings("ignore")


# Pick the best available backend instead of assuming Apple's MPS: use MPS when
# present, then CUDA, and fall back to the CPU so the cell runs on any machine.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

train_loader, val_loader, test_loader = create_dataloaders()
# The model is built from the German (de) and English (en) members of the
# training dataset.
model = BHW2RNNModel(train_loader.dataset.de, train_loader.dataset.en, hidden_dim=512, device=device)

# Positions whose target equals the English pad token are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=train_loader.dataset.en.pad_token)
optimizer = torch.optim.Adam(model.parameters())
# factor=1.0 leaves the learning rate unscaled; the scheduler is a placeholder
# that satisfies the train() signature.
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)




In [6]:

# train(model, train_loader, val_loader, optimizer, scheduler, criterion, n_epochs=1, device=device)

In [7]:
# torch.save(model.state_dict(), "./rnn.pth")

In [8]:
# from rnn_model import BHW2RNNModel

# model = BHW2RNNModel(train_loader.dataset.de, train_loader.dataset.en, hidden_dim=128, device=device)
# model.load_state_dict(torch.load("./rnn.pth", map_location=device))
# model.to(device)


In [22]:
model.to(device)


def form_test_set_predictions(max_samples=128):
    """Translate test sentences with the notebook-level ``model``.

    Reads the globals ``model``, ``test_loader``, ``train_loader`` and
    ``device``. Sentence ``i`` is taken from ``test_loader.dataset[i][0]``,
    decoded with ``model.inference`` and mapped back to tokens through
    ``train_loader.dataset.en.idx2token``.

    Args:
        max_samples: Upper bound on the number of sentences translated,
            counted from the start of the test set. ``None`` translates the
            whole test set.

    Returns:
        A list with one token list per translated sentence; the first and last
        decoded tokens of every sentence are dropped.
    """
    dataset = test_loader.dataset
    n_available = len(dataset)
    if max_samples is None:
        n_samples = n_available
    else:
        # Bound the range up front so exactly max_samples sentences are
        # produced and the progress bar shows the real amount of work.
        n_samples = min(n_available, max(max_samples, 0))

    target_vocab = train_loader.dataset.en
    translations = []

    model.eval()
    # Inference needs no autograd bookkeeping.
    with torch.no_grad():
        for i in tqdm(range(n_samples)):
            idx = model.inference(dataset[i][0].to(device))
            # Strip the first and last token (presumably the BOS/EOS markers —
            # confirm against the dataset's idx2token).
            translations.append(target_vocab.idx2token(idx)[1:-1])
    return translations


form_test_set_predictions()

100%|██████████| 128/128 [00:47<00:00,  2.68it/s]


[['and',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'the',
  'the',
  'the',
  ',',
  'the',
  'th

In [9]:
# batch = next(iter(train_loader))

In [24]:
from rnn_torch import EncoderRNN, AttnDecoderRNN

# Size passed to both the encoder and the decoder below.
HIDDEN_SIZE = 128
# The encoder is sized by the German (de) vocabulary, the decoder by the
# English (en) vocabulary, both read from the training dataset.
encoder = EncoderRNN(train_loader.dataset.de.vocab_size, HIDDEN_SIZE).to(device)
decoder = AttnDecoderRNN(HIDDEN_SIZE, train_loader.dataset.en.vocab_size, device=device).to(device)
# Disabled alternative; with this line commented out, the `criterion` defined
# in an earlier cell stays in effect.
# criterion = nn.CrossEntropyLoss(ignore_index=1)



def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, device):
    """Train the encoder/decoder pair for one pass over ``dataloader``.

    Every batch is unpacked as ``(input, input_lengths, target,
    target_lengths)`` and both token tensors are cropped along their second
    dimension to the largest length in the batch. The decoder is fed
    ``target[:, 1:-1]`` and its outputs are scored against ``target[:, 1:]``.

    Args:
        dataloader: Iterable of batches as described above.
        encoder: Called as ``encoder(input)``; returns
            ``(encoder_outputs, encoder_hidden)``.
        decoder: Called as ``decoder(encoder_outputs, encoder_hidden,
            target[:, 1:-1])``; the first element of its return value is
            scored.
        encoder_optimizer: Optimizer zeroed and stepped once per batch.
        decoder_optimizer: Optimizer zeroed and stepped once per batch.
        criterion: Loss callable, e.g. ``nn.CrossEntropyLoss`` or
            ``nn.NLLLoss``.
        device: Device the cropped batches are moved to.

    Returns:
        The mean of the per-batch losses as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Make sure layers such as dropout behave as in training, even if an
    # earlier evaluation step left the modules in eval mode.
    encoder.train()
    decoder.train()

    total_loss = 0.0
    n_batches = 0
    for input_tensor, input_lengths, target_tensor, target_lengths in tqdm(dataloader):
        input_tensor = input_tensor[:, :input_lengths.max()].to(device)
        target_tensor = target_tensor[:, :target_lengths.max()].to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor[:, 1:-1])

        # CrossEntropyLoss and NLLLoss both accept (batch, classes, positions)
        # scores with (batch, positions) targets, so one call serves either
        # criterion and both are scored against the same shifted target.
        loss = criterion(decoder_outputs.permute(0, 2, 1), target_tensor[:, 1:])
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / n_batches

@torch.no_grad()
def validate_epoch(dataloader, encoder, decoder, criterion, device):
    """Evaluate the encoder/decoder pair on ``dataloader`` without gradients.

    Batches are unpacked and cropped as in training: ``(input, input_lengths,
    target, target_lengths)``, each token tensor cut along its second dimension
    to the largest length in the batch. The decoder is fed ``target[:, 1:-1]``
    and its outputs are scored against ``target[:, 1:]``.

    Args:
        dataloader: Iterable of batches as described above.
        encoder: Called as ``encoder(input)``; returns
            ``(encoder_outputs, encoder_hidden)``.
        decoder: Called as ``decoder(encoder_outputs, encoder_hidden,
            target[:, 1:-1])``; the first element of its return value is
            scored.
        criterion: Loss callable, e.g. ``nn.CrossEntropyLoss`` or
            ``nn.NLLLoss``.
        device: Device the cropped batches are moved to.

    Returns:
        The mean of the per-batch losses as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Evaluate in eval mode, but remember the previous modes so the caller's
    # modules come back exactly as they were handed in.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    total_loss = 0.0
    n_batches = 0
    try:
        for input_tensor, input_lengths, target_tensor, target_lengths in tqdm(dataloader):
            input_tensor = input_tensor[:, :input_lengths.max()].to(device)
            target_tensor = target_tensor[:, :target_lengths.max()].to(device)

            encoder_outputs, encoder_hidden = encoder(input_tensor)
            decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor[:, 1:-1])

            # CrossEntropyLoss and NLLLoss both accept (batch, classes,
            # positions) scores with (batch, positions) targets, so one call
            # serves either criterion against the same shifted target.
            loss = criterion(decoder_outputs.permute(0, 2, 1), target_tensor[:, 1:])

            total_loss += loss.item()
            n_batches += 1
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    if n_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / n_batches


def train(encoder, decoder, encoder_optimizer, decoder_optimizer, train_loader, val_loader, scheduler, criterion, n_epochs : int = 1, device : Union[torch.device, str] = "cpu"):
    """Train and validate the encoder/decoder pair for ``n_epochs`` epochs.

    Each epoch calls ``train_epoch`` on ``train_loader`` and then
    ``validate_epoch`` on ``val_loader`` and prints a one-line summary with
    both losses.

    Args:
        encoder: Encoder module handed to both epoch functions.
        decoder: Decoder module handed to both epoch functions.
        encoder_optimizer: Optimizer for the encoder. ``None`` creates a
            default ``torch.optim.Adam`` over ``encoder.parameters()``.
        decoder_optimizer: Optimizer for the decoder. ``None`` creates a
            default ``torch.optim.Adam`` over ``decoder.parameters()``.
        train_loader: Batches used for the training pass.
        val_loader: Batches used for the validation pass.
        scheduler: Optional object whose ``step()`` is called once per epoch;
            ``None`` skips scheduling.
        criterion: Loss callable handed to both epoch functions.
        n_epochs: Number of epochs to run.
        device: Device forwarded to both epoch functions.
    """
    # Build the optimizers once, before the loop: re-creating Adam every epoch
    # would throw away its running moment estimates and would ignore any
    # optimizer the caller passed in.
    if encoder_optimizer is None:
        encoder_optimizer = torch.optim.Adam(encoder.parameters())
    if decoder_optimizer is None:
        decoder_optimizer = torch.optim.Adam(decoder.parameters())

    for i in range(1, n_epochs + 1):
        train_loss = train_epoch(train_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, device=device)
        val_loss = validate_epoch(val_loader, encoder, decoder, criterion, device)
        if scheduler is not None:
            scheduler.step()
        print("Training epoch {} / {} : train_loss {}, val_loss {}".format(i, n_epochs, train_loss, val_loss))

In [25]:
# Silence torchtext's deprecation warning before touching the data loaders.
torchtext.disable_torchtext_deprecation_warning()

# NOTE(review): `loader` (batch_size=2, unshuffled) is built here but the call
# below passes `train_loader`, so this object is not used in this cell —
# confirm whether it was meant to replace `train_loader` for a quick run.
loader = DataLoader(train_loader.dataset, batch_size=2, shuffle=False)

# Run the encoder/decoder training loop for the default single epoch; no
# optimizers and no scheduler are supplied (all three arguments are None).
train(encoder, decoder, None, None, train_loader, val_loader, None, criterion, device=device)

100%|██████████| 4/4 [00:37<00:00,  9.37s/it]
100%|██████████| 4/4 [00:24<00:00,  6.07s/it]

Training epoch 1 / 1 : train_loss 10.727508544921875, val_loss 10.586227655410767



