In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
from torchtext import data, vocab

import time
import datetime
import logging
import sys

from utils import execute_and_time, get_device, itos
from preprocess import Batch, embedding_param
from model import Transformer
from optimize import get_default_optimizer, train

In [3]:
LOG_FILE = False
logger = logging.getLogger()
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
if LOG_FILE:
    file_handler = logging.FileHandler('log.out')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler) 
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.setLevel(logging.INFO)

In [4]:
DATA_PATH = 'data/'
SAMPLE_DATA_PATH = f'{DATA_PATH}sample_data/'
PROCESSED_DATA_PATH = f'{DATA_PATH}processed_data/'

pre_trained_vector_type = 'glove.6B.200d' 
batch_size = 64
device = get_device(3)
stack_number = 6
heads_number = 8
fix_length = 80

2019-01-20 18:17:11,301 INFO GPU available, using NO.3.


In [5]:
%%time
tokenizer = data.get_tokenizer('spacy')
TEXT = data.Field(tokenize=tokenizer, lower=True, eos_token='_eos_', fix_length=fix_length)
trn_data_fields = [("source", TEXT), ("target", TEXT)]
trn, vld = data.TabularDataset.splits(path=f'{SAMPLE_DATA_PATH}',
                           train='train_ds.csv', 
                           validation='valid_ds.csv',
                           format='csv', 
                           skip_header=True, 
                           fields=trn_data_fields)

CPU times: user 33.1 s, sys: 192 ms, total: 33.3 s
Wall time: 33.3 s


In [6]:
TEXT.build_vocab(trn, vectors=pre_trained_vector_type)
vocabulary = TEXT.vocab
vocab_size = len(vocabulary)

2019-01-20 18:18:12,043 INFO Loading vectors from .vector_cache/glove.6B.200d.txt.pt


In [14]:
train_iter, val_iter = data.BucketIterator.splits((trn, vld), 
                                                  batch_sizes=(batch_size, int(batch_size * 1.6)),
                                                  device=device, 
                                                  sort_key=lambda x: len(x.source),
                                                  shuffle=True, 
                                                  sort_within_batch=False, 
                                                  repeat=True)
train_iter = Batch(train_iter, "source", "target", device, vocabulary)
val_iter = Batch(val_iter, "source", "target", device, vocabulary)
# train_iter, val_iter = iter(train_iter_tuple), iter(val_iter_tuple)

In [15]:
batch = next(train_iter)
print(type(batch), len(batch))
print(batch[0].size(), batch[1].size(), batch[2].size(), batch[3].size())

sample_source = batch[0].transpose(1,0)[0]
sample_target = batch[1].transpose(1,0)[0]
sample_src_mask = batch[2].transpose(1,0)[0]
sample_tgt_mask = batch[3].transpose(1,0)[0]


print("source:\n%s \n\ncorresponding tensor:\n%s \n" %(itos(sample_source, vocabulary), sample_source))
print("target:\n%s \n\ncorresponding tensor:\n%s \n" %(itos(sample_target, vocabulary), sample_target))
print(sample_src_mask)
print(sample_tgt_mask)

<class 'tuple'> 4
torch.Size([64, 80]) torch.Size([64, 80]) torch.Size([64, 1, 1, 80]) torch.Size([64, 1, 80, 80])
source:
the the some germany juventus the the john canadian double federal president the un mark foreign barry new a air several president president the three congress presstek a the gary manfred un real one with the britain many the more the the the former thai zimbabwean on foreign police the brad the matt a anxiety turkey goodyear indonesia pope eu two government stock the 

corresponding tensor:
tensor([    4,     4,   171,   287,  3456,     4,     4,   374,   634,  1399,
          209,    34,     4,   184,   710,    94,  4968,    27,    10,   243,
          543,    34,    34,     4,    81,   430, 34450,    10,     4,  4816,
        23458,   184,   850,    74,    17,     4,   274,   511,     4,    62,
            4,     4,     4,    90,   442,  2302,    11,    94,    65,     4,
         5714,     4,  6649,    10,  5709,   421, 10580,   428,   968,   173,
           47,

In [16]:
pre_trained_vector, embz_size, padding_idx = embedding_param(SAMPLE_DATA_PATH, TEXT, pre_trained_vector_type)

2019-01-20 18:27:38,434 INFO pre_trained_vector_mean = 0.0019917963, pre_trained_vector_std = 0.43600857
2019-01-20 18:27:38,439 INFO Normalizing embeddings...
2019-01-20 18:27:38,637 INFO pre_trained_vector_mean = -1.1977131e-08, pre_trained_vector_std = 0.9999996


In [22]:
model = Transformer(
    embz_size,
    vocab_size,
    padding_idx,
    pre_trained_vector,
    stack_number,
    heads_number
)

In [25]:
optimizer = get_default_optimizer(model)
criterion = torch.nn.NLLLoss()
train(model.to(device), train_iter, 5000, optimizer, criterion, print_every=100)

2019-01-20 18:38:57,378 INFO Start training
2019-01-20 18:40:40,687 INFO Iteration: 100, loss: 11.423600196838379, estimated remaining time: 84 min 22 s
2019-01-20 18:42:23,178 INFO Iteration: 200, loss: 11.402462005615234, estimated remaining time: 81 min 59 s
2019-01-20 18:44:05,617 INFO Iteration: 300, loss: 11.349199295043945, estimated remaining time: 80 min 14 s
2019-01-20 18:45:47,634 INFO Iteration: 400, loss: 11.374692916870117, estimated remaining time: 78 min 12 s
2019-01-20 18:47:30,265 INFO Iteration: 500, loss: 11.389216423034668, estimated remaining time: 76 min 58 s
2019-01-20 18:49:12,842 INFO Iteration: 600, loss: 11.35831356048584, estimated remaining time: 75 min 13 s
2019-01-20 18:50:55,074 INFO Iteration: 700, loss: 11.372986793518066, estimated remaining time: 73 min 15 s
2019-01-20 18:52:37,488 INFO Iteration: 800, loss: 11.350546836853027, estimated remaining time: 71 min 41 s
2019-01-20 18:54:19,887 INFO Iteration: 900, loss: 11.36902904510498, estimated remai

In [27]:
torch.save(model.state_dict(), 'trans-1.pt')