In [1]:
# Automatically re-import edited modules (e.g. the local unsupervised_mt package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

from unsupervised_mt.dataset import Dataset
from unsupervised_mt.train import Trainer 
from unsupervised_mt.models import Embedding, Encoder, DecoderHat, Attention, Discriminator
from unsupervised_mt.batch_iterator import BatchIterator
from unsupervised_mt.utils import log_probs2indices, noise

from functools import partial
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm_notebook as tqdm


# Fix the random seeds so batch sampling and parameter initialisation are reproducible
# across Restart & Run All. torch.manual_seed also seeds CUDA generators.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# Inputs for each language pair, ordered (src, tgt). Judging by the file suffixes,
# 'src' is English and 'tgt' is French.
corpus_files = ('../data/train.lc.norm.tok.en', '../data/train.lc.norm.tok.fr')
embedding_files = ('../data/wiki.multi.en.vec', '../data/wiki.multi.fr.vec')
# One precomputed .npy file per translation direction (src->tgt, tgt->src);
# presumably word-pair tables for word-by-word translation — TODO confirm.
pair_files = ('../data/src2tgt_20.npy', '../data/tgt2src_20.npy')

ds = Dataset(
    languages=('src', 'tgt'),
    corp_paths=corpus_files,
    emb_paths=embedding_files,
    pairs_paths=pair_files,
    max_length=20,
)

In [3]:
# Wrap the dataset in the project's batch iterator. Later cells draw batches from it
# with load_batch(batch_size) for training and load_batch(1, test=True) for inspection.
batch_iter = BatchIterator(ds)

In [4]:
hidden_size = 100
num_layers = 3

src_embedding = Embedding(ds.emb_matrix['src']).to(device)
tgt_embedding = Embedding(ds.emb_matrix['tgt']).to(device)

# Both GRUs consume embedding vectors. The src embedding width is used for both;
# this assumes the aligned multilingual vectors share one dimensionality — TODO confirm.
word_dim = src_embedding.embedding_dim
# A bidirectional GRU concatenates forward and backward states, giving 2 * hidden_size features.
context_size = 2 * hidden_size
gru_kwargs = dict(input_size=word_dim, hidden_size=hidden_size,
                  num_layers=num_layers, bidirectional=True)

# Modules are created in a fixed order so random initialisation stays reproducible.
encoder_rnn = nn.GRU(**gru_kwargs)
# NOTE(review): a bidirectional decoder is unusual for step-by-step decoding. It is kept
# because the saved decoder_rnn checkpoint (and presumably Attention / the hats) assume it.
decoder_rnn = nn.GRU(**gru_kwargs)
attention = Attention(word_dim, hidden_size, max_length=ds.max_length, bidirectional=True)
src_hat = DecoderHat(context_size, ds.vocabs['src'].size)
tgt_hat = DecoderHat(context_size, ds.vocabs['tgt'].size)
discriminator = Discriminator(context_size)

# Special-token indices, evaluated in the same order the Trainer receives them.
sos_src, sos_tgt = ds.get_sos_index('src'), ds.get_sos_index('tgt')
eos_src, eos_tgt = ds.get_eos_index('src'), ds.get_eos_index('tgt')
pad_src, pad_tgt = ds.get_pad_index('src'), ds.get_pad_index('tgt')

word_by_word = ds.translate_batch_word_by_word
trainer = Trainer(
    partial(word_by_word, l1='src', l2='tgt'),
    partial(word_by_word, l1='tgt', l2='src'),
    src_embedding, tgt_embedding,
    encoder_rnn, decoder_rnn, attention,
    src_hat, tgt_hat,
    discriminator,
    sos_src, sos_tgt,
    eos_src, eos_tgt,
    pad_src, pad_tgt,
    device, lr_core=1e-3, lr_disc=1e-3,
)

In [5]:
# Restore module weights from checkpoint files in the working directory. The filenames
# match the ones the saving cell writes.
# - All files are read before any module is touched, so the model is never half-restored.
# - With no checkpoints at all (fresh checkout), training starts from the fresh weights
#   instead of crashing, so Restart & Run All works.
# - A partial set of checkpoints is treated as an error.
checkpointed_modules = {
    'src_embedding': src_embedding,
    'tgt_embedding': tgt_embedding,
    'encoder_rnn': encoder_rnn,
    'decoder_rnn': decoder_rnn,
    'attention': attention,
    'src_hat': src_hat,
    'tgt_hat': tgt_hat,
    'discriminator': discriminator,
}

saved_states, missing_checkpoints = {}, []
for name in checkpointed_modules:
    try:
        # map_location lets GPU-saved checkpoints load on a CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
        saved_states[name] = torch.load('./' + name, map_location=device)
    except FileNotFoundError:
        missing_checkpoints.append(name)

if missing_checkpoints and saved_states:
    raise FileNotFoundError(f'Incomplete checkpoint set; missing: {missing_checkpoints}')
if missing_checkpoints:
    print('No checkpoints found; starting from freshly initialised weights.')
else:
    for name, module in checkpointed_modules.items():
        module.load_state_dict(saved_states[name])

In [None]:
batch_size = 50
num_steps = 50000
# Weights forwarded to Trainer.train_step. Which loss term each of the three entries
# scales is defined inside the trainer — TODO document here.
loss_weights = (1, 1, 0.1)

for step in tqdm(range(num_steps)):
    train_batch = batch_iter.load_batch(batch_size)
    trainer.train_step(train_batch, weights=loss_weights)

HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

In [20]:
# Print one held-out sentence three ways:
#   1. the source;
#   2. trainer.frozen_src2tgt (presumably the word-by-word baseline passed to Trainer);
#   3. the trained src->tgt model's decoded output.
batch = batch_iter.load_batch(1, test=True)
# Inference only: skip building an autograd graph.
with torch.no_grad():
    ds.print_batch(batch['src'], 'src')
    ds.print_batch(trainer.frozen_src2tgt(batch['src']), 'tgt')
    model_log_probs = trainer.src2tgt.evaluate(
        batch['src'].to(device),
        ds.get_sos_index('tgt'),
        ds.get_eos_index('tgt'),
        # Previously a hard-coded 20. Reusing the dataset's limit keeps the two from drifting apart.
        ds.max_length,
    )
    ds.print_batch(log_probs2indices(model_log_probs), 'tgt')

['two', 'competitors', 'one', 'of', 'which', 'is', 'in', 'the', 'us', 'air', 'force', 'wrestling', '.', '<eos>']
['deux', 'competiteurs', 'deux', 'de', 'qui', 'est', 'dans', 'la', 'unis', 'air', 'force', 'wrestling', '.', '<eos>']
['deux', 'hommes', 'sont', 'des', 'la', '<unk>', 'de', 'l', '<unk>', 'de', 'de', 'de', '.', '<eos>']


In [10]:
# Persist each module's weights to a file named after the module, relative to the
# current working directory.
modules_to_save = {
    'src_embedding': src_embedding,
    'tgt_embedding': tgt_embedding,
    'encoder_rnn': encoder_rnn,
    'decoder_rnn': decoder_rnn,
    'attention': attention,
    'src_hat': src_hat,
    'tgt_hat': tgt_hat,
    'discriminator': discriminator,
}
for filename, module in modules_to_save.items():
    torch.save(module.state_dict(), filename)