In [1]:
import torch
import torch.nn as nn
import random
import matplotlib.pyplot as plt

In [10]:
# One embedding row / output logit per byte value: every token id fed to the
# model (the ord() of each character in the data cell) must be < vocab_size.
vocab_size = 2 ** 8

In [11]:
# Source / target sequences as lists of Unicode code points, then as int64 tensors.
x_ = [ord(ch) for ch in "hello"]
y_ = [ord(ch) for ch in "hola"]
x = torch.tensor(x_, dtype=torch.long)
y = torch.tensor(y_, dtype=torch.long)

In [17]:
class Seq2Seq(nn.Module):
    """Single-layer GRU encoder-decoder over integer token ids, trained with teacher forcing.

    One embedding table is shared by the encoder and the decoder. The encoder's
    final hidden state seeds the decoder. The decoder's first input is token
    id 0; after that it is fed the previous *target* token at each step.

    Args:
        vocab_size: number of token ids; ids in ``inputs``/``targets`` must be
            in ``[0, vocab_size)`` (required by ``nn.Embedding``).
        hidden_size: embedding width and GRU hidden size.
    """

    def __init__(self, vocab_size, hidden_size):
        super(Seq2Seq, self).__init__()
        self.n_layers = 1
        self.hidden_size = hidden_size
        # Registration order matters: `_init_state` reads the first parameter
        # (the embedding weight) to choose the dtype/device of the initial state.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.GRU(hidden_size, hidden_size)
        self.decoder = nn.GRU(hidden_size, hidden_size)
        self.project = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, targets):
        """Encode ``inputs`` and decode ``len(targets)`` steps with teacher forcing.

        Args:
            inputs: 1-D integer tensor of source token ids (a single sequence).
            targets: 1-D integer tensor of target token ids (a single sequence).

        Returns:
            Float tensor of shape ``(len(targets), vocab_size)`` holding the
            unnormalised logits for each target position, i.e. the layout
            ``nn.CrossEntropyLoss()(outputs, targets)`` expects. An empty
            ``targets`` yields an empty ``(0, vocab_size)`` tensor.
        """
        vocab_size = self.project.out_features
        num_steps = targets.size(0)
        if num_steps == 0:
            return self.project.weight.new_zeros(0, vocab_size)

        # nn.GRU defaults to (seq_len, batch, features); the batch size here is 1.
        embedded = self.embedding(inputs).unsqueeze(1)
        _, decoder_state = self.encoder(embedded, self._init_state())

        # Token id 0 is the start marker. `new_zeros` keeps it on the same
        # device as `inputs` instead of hard-coding a CPU tensor.
        decoder_input = inputs.new_zeros(1, dtype=torch.long)

        logits = []
        for i in range(num_steps):
            step_embedding = self.embedding(decoder_input).unsqueeze(1)
            decoder_output, decoder_state = self.decoder(
                step_embedding, decoder_state)
            logits.append(self.project(decoder_output))
            # Teacher forcing: the next decoder input is the ground-truth token.
            decoder_input = targets[i:i + 1].long()

        # Each step yields (1, 1, vocab_size). Reshape explicitly rather than
        # `.squeeze()`, which would also drop the time axis when
        # len(targets) == 1 and produce a 1-D tensor.
        return torch.cat(logits, dim=0).reshape(num_steps, vocab_size)

    def _init_state(self, batch_size=1):
        """Return an all-zero GRU hidden state of shape (n_layers, batch_size, hidden_size).

        The dtype and device come from the first registered parameter, so the
        state follows the model wherever it has been moved. The returned
        tensor does not require grad.
        """
        weight = next(self.parameters())
        return weight.new_zeros(self.n_layers, batch_size, self.hidden_size)

In [18]:
# Seed every RNG in play so weight initialisation, and therefore the training
# run, is reproducible under Restart & Run All.
SEED = 0
HIDDEN_SIZE = 16
LEARNING_RATE = 1e-3

random.seed(SEED)
torch.manual_seed(SEED)

seq2seq = Seq2Seq(vocab_size, HIDDEN_SIZE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(seq2seq.parameters(), lr=LEARNING_RATE)

In [20]:
N_STEPS = 1000
LOG_EVERY = 100

# Loss history as plain Python floats: detached from autograd and directly
# plottable.
log = []
for step in range(N_STEPS):
    prediction = seq2seq(x, y)
    loss = criterion(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_val = loss.item()
    log.append(loss_val)

    if step % LOG_EVERY == 0:
        print(step, loss_val)
        # Greedy read-out of the teacher-forced predictions. `argmax(...).tolist()`
        # always yields a list, unlike `squeeze().tolist()`, which returns a bare
        # int when there is only one target character.
        predicted_ids = prediction.argmax(dim=1).tolist()
        print([chr(c) for c in predicted_ids])

0 5.450019359588623
['%', '\x1d', '\x1d', 'ù']
100 2.0540502071380615
['h', 'h', 'h', 'a']
200 0.8488560914993286
['h', 'o', 'l', 'a']
300 0.41322243213653564
['h', 'o', 'l', 'a']
400 0.25716862082481384
['h', 'o', 'l', 'a']
500 0.18181373178958893
['h', 'o', 'l', 'a']
600 0.13711613416671753
['h', 'o', 'l', 'a']
700 0.10718270391225815
['h', 'o', 'l', 'a']
800 0.08460129052400589
['h', 'o', 'l', 'a']
900 0.06411759555339813
['h', 'o', 'l', 'a']
