In [18]:
import os
import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.legacy.data import Field
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import nltk
import numpy as np
import pandas as pd
from tqdm import tqdm

import models

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def save_checkpoint(model, optimizer, loss, epoch, path):
    checkpoint = {
        "epoch" : epoch,
        "loss" : loss,
        "model_state_dict" : model.state_dict(),
        "optimizer_state_dict" : optimizer.state_dict(),
    }

    torch.save(checkpoint, path)

def load_checkpoint(model, path):
    checkpoint = torch.load(path, map_location = device)

    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]

    return epoch, loss

## Train Definitions

In [4]:
MAX_EPOCH  = 10
BATCH_SIZE = 2
INITIAL_LR = 1e-3

CLIP = 1 # ??

## Data Definitions

In [21]:
DATAFRAME_PATH = './dataframes'

In [5]:
def en_tokenizer(text: str) -> list:
    return nltk.word_tokenize(text, language = 'english',)

def tr_tokenizer(text: str) -> list:
    return text.split()

In [22]:
df = pd.read_csv(os.path.join(DATAFRAME_PATH, 'wmt16.csv'))

train_df = df[df.split == 'train']
valid_df = df[df.split == 'validation']
valid_df = valid_df.reset_index(drop = True)

### Build Vocabulary

In [7]:
en_field = Field(tokenize = en_tokenizer, init_token='<sos>', eos_token='<eos>', lower = False)
tr_field  = Field(tokenize = tr_tokenizer, init_token='<sos>', eos_token='<eos>', lower = False)

In [8]:
en_train_preprocessed_text = train_df['en'].apply(lambda x: en_field.preprocess(x))
tr_train_preprocessed_text = train_df['tr'].apply(lambda x: tr_field.preprocess(x))

en_field.build_vocab(en_train_preprocessed_text, min_freq = 2)
tr_field.build_vocab(tr_train_preprocessed_text, min_freq = 2)

en_vocab = en_field.vocab
tr_vocab = tr_field.vocab

In [9]:
# define tags
# only taking values from tr since they are equal in en_vocab as well
PAD_IDX = tr_vocab['<pad>']
SOS_IDX = tr_vocab['<sos']
EOS_IDX = tr_vocab['<eos>>']

In [10]:
def get_corpora_dataset(en_text: str, tr_text: str, en_vocab, tr_vocab):
    data = []

    for i in range(len(en_text)):
        en_tensor = torch.tensor([en_vocab[token] for token in en_text[i]], dtype = torch.long)
        tr_tensor = torch.tensor([tr_vocab[token] for token in tr_text[i]], dtype = torch.long)
        data.append((en_tensor, tr_tensor))

    return data

In [11]:
def generate_batch(data_batch):
    en_batch, tr_batch = [], []

    for (en_item, tr_item) in data_batch:
        en_batch.append(torch.cat([torch.tensor([SOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
        tr_batch.append(torch.cat([torch.tensor([SOS_IDX]), tr_item, torch.tensor([EOS_IDX])], dim=0))

    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    tr_batch = pad_sequence(tr_batch, padding_value=PAD_IDX)
    return en_batch, tr_batch

In [12]:
en_valid_preprocessed_text = valid_df['en'].apply(lambda x: en_field.preprocess(x))
tr_valid_preprocessed_text = valid_df['tr'].apply(lambda x: tr_field.preprocess(x))

train_dataset = get_corpora_dataset(en_train_preprocessed_text, tr_train_preprocessed_text, en_vocab, tr_vocab)
valid_dataset = get_corpora_dataset(en_valid_preprocessed_text, tr_valid_preprocessed_text, en_vocab, tr_vocab)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)

## Model Definitions

In [14]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(tr_vocab)

# model hyperparams
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2

ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

enc = models.Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = models.Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)

model = models.Seq2Seq(enc, dec, device).to(device)

In [15]:
# loss function
TRG_PAD_IDX = tr_vocab.stoi[tr_field.pad_token]
loss_fn = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

optimizer = optim.Adam(model.parameters(), lr=INITIAL_LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score

### Train & Eval Loop

In [19]:
best_valid_loss = 999999
plateau_counter = 0

for epoch in range(MAX_EPOCH):
    train_looper = tqdm(enumerate(train_loader), total=len(train_loader), leave = False, position = 0)
    train_looper.set_description("Epoch [{:003}]".format(epoch + 1))

    epoch_train_loss = 0
    epoch_valid_loss = 0

    #train
    model.train()
    for i, (src, trg) in train_looper:

        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()

        output = model(src, trg)
        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]

        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)

        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]

        loss = loss_fn(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)

        optimizer.step()
        epoch_train_loss += loss.item()
        train_looper.set_postfix(loss=loss.detach().item())

    #evaluate
    model.eval()
    with torch.no_grad():
        for _, (src, trg) in enumerate(valid_loader):
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0) #turn off teacher forcing

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)

            loss = loss_fn(output, trg)
            epoch_valid_loss += loss.item()

    epoch_train_loss = epoch_train_loss / len(train_loader)
    epoch_valid_loss = epoch_valid_loss / len(valid_loader)

    scheduler.step(epoch_valid_loss)

    print("Epoch: {}, TrainLoss: {:.2f}, ValidLoss : {:.2f}, lr: {}".format(epoch + 1, epoch_train_loss, epoch_valid_loss,  optimizer.param_groups[0]['lr']))    

    # checkpoint
    if epoch_valid_loss < best_valid_loss:
        plateau_counter = 0
        best_valid_loss = epoch_valid_loss
        checkpoint_path = 'model_checkpoints/' + "model_" + str(epoch) + ".pkl"
        print("Saving ", checkpoint_path)
        save_checkpoint(model, optimizer, epoch_valid_loss, (epoch + 1), checkpoint_path)
    else:
        plateau_counter += 1
        if plateau_counter > 5:
            print("Early stopping...")
            break



Epoch: 1, TrainLoss: 11.68, ValidLoss : 3.89, lr: 0.001
Saving  model_checkpoints/model_0.pkl
