In [9]:
import torch
import tools
from tools import translate
from model import Encoder, Decoder
from data import TranslationDataset, SOS_token, EOS_token, MAX_LENGTH, collate
import numpy as np

In [2]:
# Device handed to tools.load_model when restoring the saved weights below.
# Kept on the CPU; the test tensors fed to translate() are not moved anywhere else.
device = torch.device("cpu")

In [6]:
# Locate the data directory and build both splits of the translation corpus.
# The training split is used below for the vocabulary sizes of the models;
# the non-training split supplies the pairs that get translated.
data_dir = tools.select_data_dir()
trainset, testset = (
    TranslationDataset(data_dir, train=is_train) for is_train in (True, False)
)

The data directory is ../data


In [7]:
# Architecture hyperparameters shared by the encoder and decoder. They presumably
# have to match the configuration the checkpoints were trained with, otherwise
# the saved weights will not fit the freshly built modules.
MODEL_KWARGS = dict(n_blocks=3, n_features=256, n_heads=16, n_hidden=1024)

# Vocabulary sizes come from the training split, i.e. the vocabularies the
# models were (presumably) trained on.
encoder = Encoder(src_vocab_size=trainset.input_lang.n_words, **MODEL_KWARGS)
tools.load_model(encoder, '6_tr_encoder.pth', device)

decoder = Decoder(tgt_vocab_size=trainset.output_lang.n_words, **MODEL_KWARGS)
tools.load_model(decoder, '6_tr_decoder.pth', device)

# Put both models into inference mode explicitly (assumes Encoder/Decoder are
# torch.nn.Module subclasses). This matters if they contain dropout-like layers,
# and is a harmless no-op if tools.load_model already switched them to eval mode.
encoder.eval()
decoder.eval()

Model loaded from 6_tr_encoder.pth.
Model loaded from 6_tr_decoder.pth.


In [13]:
# Translate a few random test pairs and print, for each one:
#   '>' the source sentence, '=' the reference translation, '<' the model output.
N_EXAMPLES = 5  # number of test pairs to show
SEED = 0        # fixed so Restart & Run All shows the same examples
rng = np.random.default_rng(SEED)


def indices_to_sentence(lang, indices):
    """Join a sequence of word indices into a space-separated sentence.

    Args:
        lang: Vocabulary object exposing an ``index2word`` mapping.
        indices: Iterable of scalar elements exposing ``.item()``
            (e.g. entries of an index tensor).

    Returns:
        The words looked up in ``lang.index2word``, joined by single spaces.
    """
    return ' '.join(lang.index2word[idx.item()] for idx in indices)


print('Translate test data:')
print('-----------------------------')
# Sample distinct pairs so the same example is never shown twice.
sample_ids = rng.choice(len(testset), size=N_EXAMPLES, replace=False)
with torch.no_grad():  # pure inference: no autograd graph needed
    for sample_id in sample_ids:
        input_sentence, target_sentence = testset[int(sample_id)]
        print('>', indices_to_sentence(testset.input_lang, input_sentence))
        print('=', indices_to_sentence(testset.output_lang, target_sentence))
        output_sentence = translate(encoder, decoder, input_sentence)
        # NOTE(review): the decoder's vocabulary size comes from trainset.output_lang;
        # decoding with testset.output_lang assumes both splits share one vocabulary.
        print('<', indices_to_sentence(testset.output_lang, output_sentence), '\n')

Translate test data:
-----------------------------
> il n est pas idiot . EOS
= he is no fool . EOS
< he s not stupid . EOS 

> elles sont toutes deux bonnes . EOS
= they are both good . EOS
< they are both good . EOS 

> tu es tout seul . EOS
= you re all alone . EOS
< you re all alone . EOS 

> ils font partie de nous . EOS
= they re part of us . EOS
< they re part of us . EOS 

> elle est tres fachee apres moi . EOS
= she is very annoyed with me . EOS
< she s very upset at me . EOS 

