In [17]:
import torch
import torch.nn as nn

import onmt
import onmt.io
import onmt.modules

In [18]:
vocab = dict(torch.load("out/data.vocab.pt"))
src_padding = vocab["src"].stoi[onmt.io.PAD_WORD]
tgt_padding = vocab["tgt"].stoi[onmt.io.PAD_WORD]

In [19]:
emb_size = 10
rnn_size = 6
# Specify the core model.
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab["src"]),
                                             word_padding_idx=src_padding)

encoder = onmt.modules.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                 rnn_type="LSTM", bidirectional=True,
                                 embeddings=encoder_embeddings)

decoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab["tgt"]),
                                             word_padding_idx=tgt_padding)
decoder = onmt.modules.InputFeedRNNDecoder(hidden_size=rnn_size, num_layers=1,
                                           bidirectional_encoder=True,
                                           rnn_type="LSTM", embeddings=decoder_embeddings)
model = onmt.modules.NMTModel(encoder, decoder)

# Specify the tgt word generator and loss computation module
model.generator = nn.Sequential(
            nn.Linear(rnn_size, len(vocab["tgt"])),
            nn.LogSoftmax())
loss = onmt.Loss.NMTLossCompute(model.generator, vocab["tgt"])

In [20]:
optim = onmt.Optim(method="sgd", lr=1, max_grad_norm=2)
# type(model.named_parameters())
# model.parameters()
optim.set_parameters(model.named_parameters())

# print model.generator.parameters()
# model.named_parameters()

In [21]:
data = torch.load("out/data.train.1.pt")
valid_data = torch.load("out/data.valid.1.pt")
data.load_fields(vocab)
valid_data.load_fields(vocab)
data.examples = data.examples[:100]

In [22]:
train_iter = onmt.io.OrderedIterator(
                dataset=data, batch_size=10,
                device=-1,
                repeat=False)
valid_iter = onmt.io.OrderedIterator(
                dataset=valid_data, batch_size=10,
                device=-1,
                train=False)

In [25]:
trainer = onmt.Trainer(model, train_iter, valid_iter, loss, loss, optim)

def report_func(*args):
    stats = args[-1]
    stats.output(args[0], args[1], 10, 0)
    return stats

for epoch in range(2):
    trainer.train(train_iter,epoch, report_func)
    val_stats = trainer.validate()

    print("Validation")
    val_stats.output(epoch, 11, 10, 0)
    trainer.epoch_step(val_stats.ppl(), epoch)

AttributeError: 'OrderedIterator' object has no attribute 'get_cur_dataset'