In [1]:
from pipeline import Pipeline
from lang_pair import LangPair

from models.encoder import Encoder
from models.decoder import Decoder
from models.attn import Attn

from coach import Coach
from translator import Translator

import torch.optim as optim
import torch.nn as nn
import torch

import pandas as pd
import io

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# vi_vi_vocab, vi_en_vocab = Pipeline.load("vi_vi_train").data, Pipeline.load("vi_en_train").data
# vi_en_pair = LangPair(vi_vi_vocab, vi_en_vocab, device = device)
# with open("transforms/vi_en_lang_pair.pkl", "wb+") as f:
#     torch.save(vi_en_pair, f)

In [4]:
with open("transforms/vi_en_lang_pair.pkl", "rb+") as f:
    lang_pair = torch.load(f)
with open("vi_en_validation_lang_pair.pkl", "rb+") as f:
    valid_lang_pair = torch.load(f)

In [5]:
hidden_size = 100
batch_size = 20
learning_rate = .1
embed_size = 250

In [6]:
enc_params = {
    "input_vocab_size": lang_pair.lang1_vocab.size,
    "hidden_size": hidden_size,
    "n_layers": 1,
    "dropout": 0,
    "embed_size": embed_size
}

In [7]:
dec_params = {
    "target_vocab_size": lang_pair.lang2_vocab.size,
    "hidden_size": hidden_size,
    "n_layers": 1,
    "dropout": 0,
}

In [8]:
attn_params = {
    "hidden_size": hidden_size,
    "method": "dot"
}

In [9]:
attn = Attn(**attn_params).to(device)
encoder = Encoder(**enc_params).to(device)
decoder = Decoder(**dec_params, attn = attn).to(device)
decoder_attn = Decoder(**dec_params, attn = attn).to(device)

In [10]:
enc_optimizer = optim.SGD(encoder.parameters(), lr = learning_rate)
dec_optimizer = optim.SGD(decoder.parameters(), lr = learning_rate)
dec_attn_optimizer = optim.SGD(decoder_attn.parameters(), lr = learning_rate)
loss_fn = nn.NLLLoss()

In [11]:
coach_params = {
    "lang_pair": lang_pair, 
    "encoder": encoder, 
    "enc_optimizer": enc_optimizer, 
    "decoder": decoder, 
    "dec_optimizer": dec_optimizer, 
    "loss_fn": loss_fn,
    "device": device
}

coach_attn_params = {
    **coach_params,
    "dec_optimizer": dec_attn_optimizer,
    "decoder": decoder_attn
}

coach = Coach(**coach_params)
coach_attn = Coach(**coach_attn_params)

In [12]:
rand_training_params = {
    "learning_rate": learning_rate,
    "iterations": 10000,
    "print_interval": 1000,
    "batch_size": batch_size
}

epoch_training_params = {
    "num_epochs": 2,
    "print_interval": 5000,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "percent_of_data": 1
}

In [13]:
# losses = coach.train_random(**rand_training_params)
losses = coach.train_random(**rand_training_params)
with open("model_test.pkl", "wb") as f:
    torch.save(coach, f)

Fetching batches...



Training Iterations:  10%|█         | 51/500 [02:17<21:48,  2.91s/ batch]

Interval (1/10) average loss: -3.1894


Training Iterations:  20%|██        | 101/500 [04:27<17:39,  2.66s/ batch]

Interval (2/10) average loss: -9.1148


Training Iterations:  30%|███       | 151/500 [06:08<10:26,  1.79s/ batch]

Interval (3/10) average loss: -16.2682


Training Iterations:  40%|████      | 201/500 [08:00<14:49,  2.97s/ batch]

Interval (4/10) average loss: -22.9321


Training Iterations:  50%|█████     | 251/500 [09:49<10:01,  2.41s/ batch]

Interval (5/10) average loss: -28.3475


Training Iterations:  60%|██████    | 301/500 [11:17<07:32,  2.27s/ batch]

Interval (6/10) average loss: -35.8933


Training Iterations:  70%|███████   | 351/500 [13:26<07:12,  2.90s/ batch]

Interval (7/10) average loss: -41.3195


Training Iterations:  80%|████████  | 401/500 [15:44<04:54,  2.98s/ batch]

Interval (8/10) average loss: -49.2310


Training Iterations:  90%|█████████ | 451/500 [17:36<02:23,  2.94s/ batch]

Interval (9/10) average loss: -58.2209


Training Iterations: 100%|██████████| 500/500 [19:38<00:00,  1.95s/ batch]


In [14]:
# losses = coach.train_random(**rand_training_params)
attn_losses, attns = coach_attn.train_random(**rand_training_params)
with open("model_attn_test.pkl", "wb") as f:
    torch.save(coach_attn, f)

Fetching batches...



Training Iterations:  10%|█         | 51/500 [01:46<17:30,  2.34s/ batch]

Interval (1/10) average loss: -3.2547


Training Iterations:  20%|██        | 101/500 [03:55<26:40,  4.01s/ batch]

Interval (2/10) average loss: -9.0167


Training Iterations:  30%|███       | 151/500 [06:23<15:21,  2.64s/ batch]

Interval (3/10) average loss: -14.8629


Training Iterations:  40%|████      | 201/500 [07:58<09:26,  1.89s/ batch]

Interval (4/10) average loss: -23.3650


Training Iterations:  50%|█████     | 251/500 [10:01<06:56,  1.67s/ batch]

Interval (5/10) average loss: -28.8463


Training Iterations:  60%|██████    | 301/500 [11:44<05:45,  1.74s/ batch]

Interval (6/10) average loss: -36.5873


Training Iterations:  70%|███████   | 351/500 [13:32<06:57,  2.80s/ batch]

Interval (7/10) average loss: -39.3202


Training Iterations:  80%|████████  | 401/500 [15:37<05:01,  3.04s/ batch]

Interval (8/10) average loss: -45.5428


Training Iterations:  90%|█████████ | 451/500 [17:41<01:11,  1.46s/ batch]

Interval (9/10) average loss: -53.7073


Training Iterations: 100%|██████████| 500/500 [19:14<00:00,  1.31s/ batch]
