To create a playground-like teaching and learning environment, all the code used here is stored in IPython/Jupyter notebooks. `import_ipynb` is used to import classes and functions from other notebooks. Because such an import runs every cell in the notebook being imported, once you have finished experimenting with a module or cell in a notebook, comment out any cells you do not want executed when that notebook is imported.

In [1]:
import math, copy

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import import_ipynb
from notebooks.Data_mover import csv2datatools, Options, num_batches
from notebooks.Encoder import Encoder 
from notebooks.Decoder import Decoder 
from notebooks.Trainer import trainer, CosineWithRestarts

%load_ext autoreload
%autoreload 2

importing Jupyter notebook from /media/carson/New Volume/Chloe/chloebot/notebooks/Data_mover.ipynb
importing Jupyter notebook from /media/carson/New Volume/Chloe/chloebot/notebooks/Encoder.ipynb
importing Jupyter notebook from /media/carson/New Volume/Chloe/chloebot/notebooks/Decoder.ipynb
importing Jupyter notebook from /media/carson/New Volume/Chloe/chloebot/notebooks/Trainer.ipynb


In [2]:
# Build the run options (batch size 4), then load the translation pairs from CSV.
# csv2datatools returns the batch iterator, the input/output fields (their
# .vocab sizes configure the model below) and an options object that replaces opt.
# NOTE(review): 'en' is presumably a language code for tokenisation — confirm.
opt = Options(batchsize=4)
data_iter, infield, outfield, opt = csv2datatools(
    'saved/translation_pairs.csv',
    'en',
    opt,
)

In [3]:
class Transformer(nn.Module):
    """Encoder-decoder model followed by a linear projection onto the output vocabulary.

    The encoder and decoder stacks are built with the same embedding size, layer
    count, head count and dropout rate.

    Args:
        in_vocab_size: size of the source vocabulary (passed to the Encoder).
        out_vocab_size: size of the target vocabulary (passed to the Decoder and
            used as the width of the final projection).
        emb_dim: model/embedding width shared by encoder, decoder and projection.
        n_layers: number of layers in each of the encoder and decoder stacks.
        heads: number of attention heads.
        dropout: dropout probability.
    """

    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers, heads, dropout):
        super().__init__()
        # Hyper-parameters common to both stacks, in the order their constructors expect.
        stack_args = (emb_dim, n_layers, heads, dropout)
        self.encoder = Encoder(in_vocab_size, *stack_args)
        self.decoder = Decoder(out_vocab_size, *stack_args)
        self.out = nn.Linear(emb_dim, out_vocab_size)

    def forward(self, src_seq, trg_seq, src_mask, trg_mask):
        """Encode the source, decode the target against it, and project to vocab scores.

        The projection is a plain nn.Linear: no softmax is applied, so the result
        is unnormalised scores over the output vocabulary.
        NOTE(review): tensor shapes are defined by Encoder/Decoder, presumably
        (batch, seq_len, ...) — confirm against those notebooks.
        """
        memory = self.encoder(src_seq, src_mask)
        return self.out(self.decoder(trg_seq, memory, src_mask, trg_mask))

In [4]:
# Model hyper-parameters; they must match the ones used when the checkpoint
# at opt.save_path was written, otherwise load_state_dict will reject it.
emb_dim, n_layers, heads, dropout = 64, 2, 8, 0.1
opt.save_path = 'saved/weights/model_weights'
model = Transformer(len(infield.vocab), len(outfield.vocab), emb_dim, n_layers, heads, dropout)
# opt.device == -1 keeps the model on the CPU; any other value moves it to CUDA.
if opt.device != -1:
    model = model.cuda()
# map_location deserialises the checkpoint straight onto the model's device, so
# weights saved from a GPU run can still be loaded when CUDA is unavailable.
model.load_state_dict(torch.load(opt.save_path, map_location=next(model.parameters()).device))

<All keys matched successfully>

In [5]:
# Training schedule. 0.0001 was noted as an alternative learning rate.
opt.lr = 0.001
opt.epochs = 20
# Adam with beta2 = 0.98 and a very small epsilon.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.98),
    eps=1e-9,
)
# The cosine cycle length (T_max) is the batch count reported by num_batches(data_iter).
scheduler = CosineWithRestarts(optimizer, T_max=num_batches(data_iter))

In [6]:
# Run training with the optimiser/scheduler configured above; trainer hands
# back the model, which replaces the current one.
model = trainer(
    model,
    data_iter,
    opt,
    optimizer,
    scheduler,
)

0m: epoch 0 loss = 0.004
0m: epoch 1 loss = 0.006
0m: epoch 2 loss = 0.004
0m: epoch 3 loss = 0.005
0m: epoch 4 loss = 0.003
0m: epoch 5 loss = 0.003
0m: epoch 6 loss = 0.003
0m: epoch 7 loss = 0.003
0m: epoch 8 loss = 0.003
0m: epoch 9 loss = 0.003
0m: epoch 10 loss = 0.003
0m: epoch 11 loss = 0.002
0m: epoch 12 loss = 0.002
0m: epoch 13 loss = 0.002
0m: epoch 14 loss = 0.002
0m: epoch 15 loss = 0.002
0m: epoch 16 loss = 0.002
0m: epoch 17 loss = 0.002
0m: epoch 18 loss = 0.001
0m: epoch 19 loss = 0.001


Transformer(
  (encoder): Encoder(
    (embed): Embedder(
      (embed): Embedding(19, 64)
    )
    (pe): PositionalEncoder(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (norm_1): Norm()
        (dropout_1): Dropout(p=0.1, inplace=False)
        (attn): MultiHeadAttention(
          (q_linear): Linear(in_features=64, out_features=64, bias=True)
          (v_linear): Linear(in_features=64, out_features=64, bias=True)
          (k_linear): Linear(in_features=64, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (out): Linear(in_features=64, out_features=64, bias=True)
        )
        (norm_2): Norm()
        (ff): FeedForward(
          (linear_1): Linear(in_features=64, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_2): Linear(in_features=2048, out_features=64, bias=True)
        )
        (dropout_2): Dropout(p=0.1, inplace=Fal