In [1]:
import torch
import torch.nn as nn
import logging
from tensorboardX import SummaryWriter
import onmt.translate
from metrics import BleuScore, CorrectAnswersScore
from trainer import SimMLETrainer

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit


In [2]:
import onmt.utils.loss

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# `mode` and `scenario_number` are not read anywhere in this cell; they are
# kept because other cells may depend on them.
mode = "v2"
scenario_number = 8

# Seed torch's RNG before any module is constructed so that a fresh
# "Restart & Run All" starts from the same random state.
# (Some CUDA kernels can still be nondeterministic.)
SEED = 42
torch.manual_seed(SEED)

writer = SummaryWriter(comment="-argmax")
fmt = "%(asctime)-15s %(levelname)s %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO)

# ---------------------------------------------------------------------------
# Vocabulary: look up the padding index on the source and target side
# ---------------------------------------------------------------------------
vocab_fields = torch.load("processed/SimQA.vocab.pt")

src_text_field = vocab_fields["src"].base_field
src_vocab = src_text_field.vocab
src_padding = src_vocab.stoi[src_text_field.pad_token]

tgt_text_field = vocab_fields["tgt"].base_field
tgt_vocab = tgt_text_field.vocab
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

# ---------------------------------------------------------------------------
# Hyperparameters shared by the encoder and the decoder
# ---------------------------------------------------------------------------
emb_size = 200
rnn_size = 200  # not used in this cell; kept for other cells
NUM_LAYERS = 6
NUM_HEADS = 8
# NOTE(review): a feed-forward width of 1 is an unusually narrow bottleneck
# for a Transformer layer. The value is kept exactly as it was; please
# confirm that it is intentional.
FF_SIZE = 1
DROPOUT = 0.1
ATTENTION_DROPOUT = 0.1
MAX_RELATIVE_POSITIONS = 100

# ---------------------------------------------------------------------------
# Encoder: embeddings initialised from a pretrained vector file
# ---------------------------------------------------------------------------
# fix_word_vecs=True presumably freezes the pretrained vectors during
# training -- confirm against the OpenNMT-py Embeddings documentation.
encoder_embeddings = onmt.modules.Embeddings(
    emb_size,
    len(src_vocab),
    word_padding_idx=src_padding,
    position_encoding=True,
    fix_word_vecs=True)
encoder_embeddings.load_pretrained_vectors("src.embeddings.pt")

encoder = onmt.encoders.TransformerEncoder(
    num_layers=NUM_LAYERS,
    d_model=emb_size,
    heads=NUM_HEADS,
    d_ff=FF_SIZE,
    dropout=DROPOUT,
    attention_dropout=ATTENTION_DROPOUT,
    embeddings=encoder_embeddings,
    max_relative_positions=MAX_RELATIVE_POSITIONS,
)

# ---------------------------------------------------------------------------
# Decoder: target embeddings are created without pretrained vectors
# ---------------------------------------------------------------------------
decoder_embeddings = onmt.modules.Embeddings(
    emb_size,
    len(tgt_vocab),
    word_padding_idx=tgt_padding,
    position_encoding=True)

decoder = onmt.decoders.TransformerDecoder(
    num_layers=NUM_LAYERS,
    d_model=emb_size,
    heads=NUM_HEADS,
    d_ff=FF_SIZE,
    copy_attn=False,
    self_attn_type="scaled-dot",
    dropout=DROPOUT,
    attention_dropout=ATTENTION_DROPOUT,
    embeddings=decoder_embeddings,
    max_relative_positions=MAX_RELATIVE_POSITIONS,
    aan_useffn=False,
    full_context_alignment=True,
    alignment_layer=1,
    alignment_heads=1,
)

# ---------------------------------------------------------------------------
# Full model, placed on the GPU when one is available
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)
logging.getLogger().info(device)

# Target word generator: projects decoder output (emb_size) onto the target
# vocabulary and returns log-probabilities, which is what NLLLoss expects.
model.generator = nn.Sequential(
    nn.Linear(emb_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1)).to(device)

# Optional: resume from a saved checkpoint.
#checkpoint = torch.load("./models/model_step_5000.pt")
#model.load_state_dict(checkpoint["model"], strict=False)
#model.generator.load_state_dict(checkpoint["generator"], strict=False)

# Loss: summed negative log-likelihood; target padding positions are ignored.
loss = onmt.utils.loss.NMTLossCompute(
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
    generator=model.generator)

2020-11-23 21:20:55,345 INFO cuda


In [3]:
import onmt.opts

# Base learning rate; the same value is handed to Adam and to the OpenNMT
# optimizer wrapper below.
lr = 0.001

# Plain torch Adam over everything returned by model.parameters().
torch_optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
)

# OpenNMT wrapper around the torch optimizer. max_grad_norm is presumably
# the gradient-norm clipping threshold -- confirm against the OpenNMT-py
# Optimizer documentation.
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer,
    learning_rate=lr,
    max_grad_norm=2,
)

In [4]:
# `chain` is not used in this cell; the import is kept because other cells
# may rely on it.
from itertools import chain

train_data_file = "processed/SimQA.train.0.pt"
valid_data_file = "processed/SimQA.valid.0.pt"

# ---------------------------------------------------------------------------
# Data iterators
# ---------------------------------------------------------------------------
# Keyword arguments that are identical for the training and validation
# iterators; only the path, batch size and train/repeat flags differ.
shared_iter_kwargs = dict(
    fields=vocab_fields,
    batch_size_multiple=1,
    batch_size_fn=None,
    device=device,
    pool_factor=8192,
)

train_iter = onmt.inputters.inputter.DatasetLazyIter(
    dataset_paths=[train_data_file],
    batch_size=32,
    is_train=True,
    repeat=True,
    **shared_iter_kwargs)

valid_iter = onmt.inputters.inputter.DatasetLazyIter(
    dataset_paths=[valid_data_file],
    batch_size=16,
    is_train=False,
    repeat=False,
    **shared_iter_kwargs)

# Reports every 100 steps and is given the TensorBoard writer.
report_manager = onmt.utils.ReportMgr(report_every=100, tensorboard_writer=writer)

# ---------------------------------------------------------------------------
# Translation stack handed to the trainer
# ---------------------------------------------------------------------------
src_reader = onmt.inputters.str2reader["text"]
tgt_reader = onmt.inputters.str2reader["text"]
scorer = onmt.translate.GNMTGlobalScorer(
    alpha=0.7,
    beta=0.,
    length_penalty="avg",
    coverage_penalty="none")

# Translator takes a GPU index; -1 is passed when CUDA is unavailable.
gpu = 0 if torch.cuda.is_available() else -1
translator = onmt.translate.Translator(
    model=model,
    fields=vocab_fields,
    src_reader=src_reader,
    tgt_reader=tgt_reader,
    global_scorer=scorer,
    gpu=gpu)

# The builder receives its own freshly loaded copy of the validation dataset.
builder = onmt.translate.TranslationBuilder(
    data=torch.load(valid_data_file),
    fields=vocab_fields,
    has_tgt=True)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
# The metric objects get their own name so that they are not shadowed by the
# `metrics` value returned from trainer.train() below.
validation_metrics = [BleuScore(), CorrectAnswersScore(tgt_vocab)]

model_saver = onmt.models.model_saver.ModelSaver("model", model, None, vocab_fields, optim)

# The same loss object is used for both training and validation.
trainer = SimMLETrainer(
    model=model,
    model_saver=model_saver,
    train_loss=loss,
    valid_loss=loss,
    optim=optim,
    report_manager=report_manager,
    metrics=validation_metrics,
    translator=translator,
    translation_builder=builder)

# Long-running: 10000 training steps, validating every 100 steps and
# checkpointing every 1000 steps.
stats, metrics = trainer.train(
    train_iter=train_iter,
    src_vocab=src_vocab,
    train_steps=10000,
    valid_iter=valid_iter,
    valid_steps=100,
    save_checkpoint_steps=1000)

2020-11-23 21:20:55,414 INFO Start training loop and validate every 100 steps...
2020-11-23 21:20:55,414 INFO Loading dataset from processed/SimQA.train.0.pt
2020-11-23 21:20:55,468 INFO number of examples: 3029
2020-11-23 21:21:31,893 INFO Loading dataset from processed/SimQA.train.0.pt
2020-11-23 21:21:31,952 INFO number of examples: 3029
2020-11-23 21:21:33,726 INFO Step 100/10000; acc:  63.02; ppl:  3.69; xent: 1.31; lr: 0.00100; 17316/3725 tok/s;     38 sec
2020-11-23 21:21:33,728 INFO Loading dataset from processed/SimQA.valid.0.pt
2020-11-23 21:21:33,741 INFO number of examples: 534
2020-11-23 21:22:29,467 INFO Validation metrics: Bleu: 0.5975,Correct answers 0.1097, State transitions score: 0.0793
2020-11-23 21:22:29,468 INFO Validation perplexity: 1.89107
2020-11-23 21:22:29,468 INFO Validation accuracy: 79.4884
2020-11-23 21:22:55,059 INFO Loading dataset from processed/SimQA.train.0.pt
2020-11-23 21:22:55,108 INFO number of examples: 3029
2020-11-23 21:22:57,420 INFO Step 20