In [1]:
from pipeline import Pipeline
from lang_pair import LangPair

from models.encoder import Encoder
from models.decoder import Decoder
from models.attn import Attn

from coach import Coach
import torch.optim as optim
import torch.nn as nn
import torch

import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

In [2]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Load the preprocessed pipeline artifacts. vi_vi_vocab later sizes the encoder input and
# vi_en_vocab the decoder output; the third artifact holds the matching index data for each
# side (presumably index-encoded sentences — confirm against the pipeline that wrote it).
vi_vi_vocab = Pipeline.load("vi_vi_train_10_chars_10k_vocab").data
vi_en_vocab = Pipeline.load("vi_en_train_10_chars_10k_vocab").data
vi_vi_idxs, vi_en_idxs = Pipeline.load("vi_idx_10chars_filter").data

In [4]:
# Pair the vi_vi (encoder-side) and vi_en (decoder-side) index data with their EOS indices.
# `device` is forwarded so LangPair can (presumably) build its batches there — confirm.
vi_en_pair = LangPair(
    vi_vi_idxs,
    vi_vi_vocab.eos_idx,
    vi_en_idxs,
    vi_en_vocab.eos_idx,
    device=device,
)

In [5]:
# Seed torch's RNG (all devices) before the models are built in the following cells, so
# weight initialisation and dropout masks are reproducible under Restart & Run All.
# NOTE(review): if LangPair/Coach sample batches with Python `random` or numpy instead of
# torch, those generators need seeding too — confirm.
SEED = 0
torch.manual_seed(SEED)

# Model / optimisation hyperparameters shared by the cells below.
hidden_size = 256      # used by encoder, decoder and attention ("dot" presumably needs them equal)
batch_size = 32
learning_rate = .0001

In [6]:
# Encoder constructor kwargs; the input vocabulary size comes from vi_vi_vocab.
enc_params = dict(
    input_vocab_size=vi_vi_vocab.size,
    hidden_size=hidden_size,
    n_layers=2,
    dropout=.3,
)

In [7]:
# Decoder constructor kwargs; the target vocabulary size comes from vi_en_vocab.
# NOTE(review): batch_size is fixed at construction time — presumably it must match the
# batch_size given to Coach.train; confirm how a smaller final batch is handled.
dec_params = dict(
    target_vocab_size=vi_en_vocab.size,
    hidden_size=hidden_size,
    n_layers=2,
    dropout=.3,
    batch_size=batch_size,
)

In [8]:
# Attention module kwargs, using the "dot" scoring method.
attn_params = dict(hidden_size=hidden_size, method="dot")

In [9]:
# Build each module and move it to `device`. Keep this order: the attention module must
# exist before the decoder that receives it, and construction order fixes how the seeded
# RNG is consumed by weight initialisation.
attn = Attn(**attn_params).to(device)
encoder = Encoder(**enc_params).to(device)
decoder = Decoder(attn=attn, **dec_params).to(device)

In [10]:
# One plain-SGD optimizer per module (encoder first, then decoder); both go to Coach.
enc_optimizer, dec_optimizer = (
    optim.SGD(module.parameters(), lr=learning_rate) for module in (encoder, decoder)
)
# NOTE(review): NLLLoss expects log-probabilities, so the decoder presumably ends in a
# log_softmax — confirm in models/decoder.py.
loss_fn = nn.NLLLoss()

In [11]:
# Bundle the language pair, models, optimizers and loss into the training driver.
coach_params = dict(
    lang_pair=vi_en_pair,
    encoder=encoder,
    enc_optimizer=enc_optimizer,
    decoder=decoder,
    dec_optimizer=dec_optimizer,
    loss_fn=loss_fn,
)

coach = Coach(**coach_params)

In [None]:
# Keyword arguments for Coach.train.
# NOTE(review): learning_rate is already fixed in the SGD optimizers; confirm what
# Coach.train does with this copy. The saved progress bar (312 batches) suggests
# `iterations` counts examples, i.e. roughly iterations // batch_size batches — confirm.
training_params = dict(
    learning_rate=learning_rate,
    iterations=10000,
    print_interval=1000,
    batch_size=batch_size,
)

In [None]:
# Run training. The returned value is presumably the recorded loss history (it is smoothed
# and plotted in the next cell) — confirm its granularity in Coach.train.
# NOTE(review): the saved output for this cell shows only 14/312 batches, so it is not from
# a completed run; re-execute before relying on the plot below.
losses = coach.train(**training_params)

Training Iterations:   4%|▍         | 14/312 [00:17<05:57,  1.20s/batch]

In [None]:
# Plot a rolling mean of the loss history returned by coach.train.
LOSS_SMOOTHING_WINDOW = 100

# min_periods=1 keeps the curve visible when fewer than LOSS_SMOOTHING_WINDOW losses were
# recorded (short or interrupted run, or losses logged per print_interval). Without it,
# every point is NaN and the figure comes out blank.
smoothed_losses = pd.Series(losses).rolling(LOSS_SMOOTHING_WINDOW, min_periods=1).mean()

fig, ax = plt.subplots()
ax.plot(smoothed_losses)
ax.set(
    title=f"Training loss ({LOSS_SMOOTHING_WINDOW}-point rolling mean)",
    xlabel="Recorded step",
    ylabel="Loss",
)
plt.show()