In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [85]:
import torch
from torchtext import data, vocab

import time
import datetime
import logging
import sys

from utils import execute_and_time, get_device, itos
from preprocess import Batch, embedding_param
from model import Transformer
from optimize import get_default_optimizer, train

In [39]:
LOG_FILE = False
logger = logging.getLogger()
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
if LOG_FILE:
    file_handler = logging.FileHandler('log.out')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler) 
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.setLevel(logging.INFO)

In [133]:
DATA_PATH = 'data/'
SAMPLE_DATA_PATH = f'{DATA_PATH}sample_data/'
PROCESSED_DATA_PATH = f'{DATA_PATH}processed_data/'

pre_trained_vector_type = 'glove.6B.200d' 
batch_size = 64
device = get_device()
stack_number = 6
heads_number = 8
fix_length = 80

2019-01-20 12:18:42,001 INFO GPU unavailable, using CPU.


In [134]:
%%time
tokenizer = data.get_tokenizer('spacy')
TEXT = data.Field(tokenize=tokenizer, lower=True, eos_token='_eos_', fix_length=fix_length)
trn_data_fields = [("source", TEXT), ("target", TEXT)]
trn, vld = data.TabularDataset.splits(path=f'{SAMPLE_DATA_PATH}',
                           train='train_ds.csv', 
                           validation='valid_ds.csv',
                           format='csv', 
                           skip_header=True, 
                           fields=trn_data_fields)

CPU times: user 1min 52s, sys: 1.44 s, total: 1min 54s
Wall time: 1min 54s


In [135]:
TEXT.build_vocab(trn, vectors=pre_trained_vector_type)
vocabulary = TEXT.vocab
vocab_size = len(vocabulary)

2019-01-20 12:21:49,080 INFO Loading vectors from .vector_cache/glove.6B.200d.txt.pt


In [136]:
train_iter, val_iter = data.BucketIterator.splits((trn, vld), 
                                                  batch_sizes=(batch_size, int(batch_size * 1.6)),
                                                  device=device, 
                                                  sort_key=lambda x: len(x.source),
                                                  shuffle=True, 
                                                  sort_within_batch=False, 
                                                  repeat=True)
train_iter = Batch(train_iter, "source", "target", vocabulary, device=device)
val_iter = Batch(val_iter, "source", "target", vocabulary, device=device)
# train_iter, val_iter = iter(train_iter_tuple), iter(val_iter_tuple)

In [137]:
batch = next(train_iter)
print(type(batch), len(batch))
print(batch[0].size(), batch[1].size(), batch[2].size(), batch[3].size())

sample_source = batch[0].transpose(1,0)[0]
sample_target = batch[1].transpose(1,0)[0]
sample_src_mask = batch[2].transpose(1,0)[0]
sample_tgt_mask = batch[3].transpose(1,0)[0]


print("source:\n%s \n\ncorresponding tensor:\n%s \n" %(itos(sample_source, vocabulary), sample_source))
print("target:\n%s \n\ncorresponding tensor:\n%s \n" %(itos(sample_target, vocabulary), sample_target))
print(sample_src_mask)
print(sample_tgt_mask)

<class 'tuple'> 4
torch.Size([64, 80]) torch.Size([64, 80]) torch.Size([64, 1, 1, 80]) torch.Size([64, 1, 80, 80])
source:
legendary fifa the a at olivier the riots washington greek jitters while china these egyptian signs a zimbabwe torrential passenger some large canadian heavily each the with u.s a los hermann stocks british the the the a an the scientists bill france spyker a some israel the following the it a more countdown second chinese david residents the insurgency the women former the voters 

corresponding tensor:
tensor([ 5558,  1902,     4,    10,    18, 11089,     4,  3123,   425,   702,
         7213,   298,    42,  1617,   823,   852,    10,   559,  4929,  1556,
          173,   869,   656,  2858,  1181,     4,    17,    40,    10,   739,
        15032,   146,   134,     4,     4,     4,    10,    25,     4,  1108,
          395,   197, 19383,    10,   173,   166,     4,   293,     4,    45,
           10,    63,  7914,   124,    95,  1091,  1078,     4,  4038,     4,
 

In [138]:
pre_trained_vector, embz_size, padding_idx = embedding_param(SAMPLE_DATA_PATH, TEXT, pre_trained_vector_type)

2019-01-20 12:22:08,664 INFO pre_trained_vector_mean = 0.002008178, pre_trained_vector_std = 0.43602833
2019-01-20 12:22:08,665 INFO Normalizing embeddings...
2019-01-20 12:22:08,810 INFO pre_trained_vector_mean = -1.2933846e-08, pre_trained_vector_std = 1.0000006


In [139]:
model = Transformer(
    embz_size,
    vocab_size,
    padding_idx,
    pre_trained_vector,
    stack_number,
    heads_number
)

In [140]:
optimizer = get_default_optimizer(model)
criterion = torch.nn.NLLLoss()
train(model, train_iter, 5000, optimizer, criterion, print_every=100)

2019-01-20 12:22:24,865 INFO Start traning
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 200]) torch.Size([64, 80, 200]) torch.Size([64, 80, 200])
torch.Size([64, 80, 2

KeyboardInterrupt: 