In [1]:
import model as trf
from process import get_data
from train import *
from config import *

In [2]:
import torch
import torch.nn as nn
from torchtext import data, datasets

In [9]:
# Run configuration.
BATCH_SIZE = 5     # sentences per training batch
DEVICE_SET = None  # intended device for batches (None = library default); presumably meant for the iterator

In [3]:
# Dataset directory, the file-name prefix handed to get_data, and the checkpoint
# prefix to which train_model appends a numeric suffix and ".pt".
data_path, file_name, save_path = (
    "dataset/",
    "train.",
    "model_save/transformer_epoch",
)

In [4]:
# Load the corpus from the "dataset/train." prefix; presumably returns the source
# field, the target field and the training dataset (their .vocab / iteration are used below).
SRC, TGT, trn = get_data(f"{data_path}{file_name}")

In [11]:
# Index of the "<pad>" token in the target vocabulary. The training loop pads src
# and tgt with a single value and the loss ignores a single index, so the source
# vocabulary must use the same padding index -- fail fast if it does not.
pad_idx = TGT.vocab.stoi["<pad>"]
assert SRC.vocab.stoi["<pad>"] == pad_idx, "SRC and TGT vocabularies use different <pad> indices"

In [12]:
# Transformer sized to both vocabularies; N = 6 presumably sets the number of
# encoder/decoder layers (see model.Transformer).
model = trf.Transformer(len(SRC.vocab), len(TGT.vocab), N = 6)
# Padded target positions must not contribute to the loss: ignore the vocabulary's
# actual <pad> index rather than assuming it is 1.
criterion = nn.NLLLoss(ignore_index = pad_idx)
# NOTE(review): lr=0.1 is very high for Adam on a Transformer (the logged loss
# oscillates); consider a much smaller rate or a warm-up schedule.
model_opt = torch.optim.Adam(model.parameters(), lr = .1, betas=(.9, .98), eps = 1e-9)

In [13]:
# Mini-batches of BATCH_SIZE over the training set, created on the configured
# device. With repeat=False a single `for` over train_iter covers one pass, so
# train_model's per-epoch loop terminates.
train_iter = data.Iterator(trn, batch_size = BATCH_SIZE,
                          device = DEVICE_SET,
                          repeat = False)

In [15]:
def _pad_to_same_length(src, tgt, pad_idx):
    """Right-pad the shorter of ``src``/``tgt`` along dim 1 with ``pad_idx``.

    Source and target batches come out of torchtext padded independently, so
    their sequence lengths can differ. Assumes batch-first 2-D LongTensors
    ``(batch, seq_len)`` -- TODO confirm against the Field definitions.
    The padding is built with ``new_full`` so it shares the input's dtype and
    device. It is sized by the actual batch (``size(0)``), which can be smaller
    than BATCH_SIZE for the last batch of an epoch.
    """
    diff = src.size(-1) - tgt.size(-1)
    if diff < 0:
        src = torch.cat((src, src.new_full((src.size(0), -diff), pad_idx)), dim = 1)
    elif diff > 0:
        tgt = torch.cat((tgt, tgt.new_full((tgt.size(0), diff), pad_idx)), dim = 1)
    return src, tgt


def train_model(epochs, model, criterion, model_opt, train_iter, save_path, print_every = 100, pad_idx = 1):
    """Train ``model`` for ``epochs`` passes over ``train_iter``, checkpointing after each pass.

    Args:
        epochs: number of passes over ``train_iter``.
        model: module called as ``model(src, trg, src_mask, trg_mask)``; its
            ``generator`` maps that output to per-token scores.
        criterion: loss applied to the scores flattened to ``(-1, vocab)`` and the
            shifted targets ``trg_y`` flattened to ``(-1,)``.
        model_opt: optimizer over ``model``'s parameters.
        train_iter: re-iterable of batches with ``.en`` (source) and ``.de``
            (target) id tensors.
        save_path: checkpoint prefix; epoch ``k`` is written to ``save_path + str(k) + ".pt"``.
        print_every: print running statistics every this many batches.
        pad_idx: token id used to equalize src/tgt lengths. Defaults to 1, the
            value previously hard-coded; pass the vocabulary's ``<pad>`` index.
    """
    import time

    model.train()
    start = time.time()
    total_loss = 0.0
    mean_tokens = 0.0

    for epoch in range(epochs):
        for i, batch in enumerate(train_iter):
            src, tgt = _pad_to_same_length(batch.en, batch.de, pad_idx)
            bat = Batch(src, tgt)  # from train.Batch

            # Call the module (not .forward) so registered hooks run.
            hidden = model(bat.src, bat.trg, bat.src_mask, bat.trg_mask)
            preds = model.generator(hidden)

            model_opt.zero_grad()
            loss = criterion(preds.contiguous().view(-1, preds.size(-1)),
                             bat.trg_y.contiguous().view(-1))
            loss.backward()
            model_opt.step()

            # .item() yields a Python float and does not keep a tensor alive.
            total_loss += loss.item()
            # Normalize by the real batch size; the last batch may be short.
            mean_tokens += float(bat.ntokens) / src.size(0)

            if i % print_every == 0:
                elapsed = time.time() - start
                loss_per_token = total_loss / mean_tokens if mean_tokens else float("nan")
                print("Iteration Step: %d Loss per token: %f Elapsed Sec: %f #(tokens) : %d" %
                      (i, loss_per_token, elapsed, mean_tokens / print_every))
                start = time.time()
                total_loss = 0.0
                mean_tokens = 0.0

        # Name the checkpoint by epoch (not batch index) so each epoch gets its own file.
        torch.save(model.state_dict(), save_path + str(epoch) + ".pt")
        print("Epoch Step : %d is done" % epoch)
        

In [16]:
# Single-epoch training run, logging every 5 batches.
train_model(
    epochs=1,
    model=model,
    criterion=criterion,
    model_opt=model_opt,
    train_iter=train_iter,
    save_path=save_path,
    print_every=5,
)

Iteration Step: 0 Loss per token: 0.627188 per Sec: 4.965740 #(tokens) : 3
Iteration Step: 5 Loss per token: 5.101102 per Sec: 33.744342 #(tokens) : 14
Iteration Step: 10 Loss per token: 3.167434 per Sec: 31.484677 #(tokens) : 18
Iteration Step: 15 Loss per token: 2.032048 per Sec: 26.491583 #(tokens) : 14
Iteration Step: 20 Loss per token: 0.863235 per Sec: 30.420695 #(tokens) : 17
Iteration Step: 25 Loss per token: 0.729698 per Sec: 32.973776 #(tokens) : 19


KeyboardInterrupt: 