In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%config InlineBackend.figure_format='svg'

In [2]:
def make_batch(sentences=None, word_dict=None, n_class=None):
    """Build the single-example training batch for the seq2seq model.

    Args:
        sentences: ``[encoder_input, decoder_input, target]`` as
            space-separated strings. Defaults to the notebook-level
            ``sentences`` (so the original zero-argument call still works).
        word_dict: token -> index mapping. Defaults to the notebook-level
            ``word_dict``.
        n_class: vocabulary size, i.e. the one-hot width. Defaults to
            ``len(word_dict)``.

    Returns:
        ``(input_batch, output_batch, target_batch)``:
        one-hot float32 tensors of shape ``(1, len, n_class)`` for the encoder
        and decoder inputs, and an int64 tensor of shape ``(1, len)`` with the
        target token indices.

    Raises:
        KeyError: if a token is missing from ``word_dict``, or, when called
            without arguments, if the notebook globals are not defined yet.
    """
    namespace = globals()
    if sentences is None:
        sentences = namespace["sentences"]
    if word_dict is None:
        word_dict = namespace["word_dict"]
    if n_class is None:
        n_class = len(word_dict)

    def to_ids(sentence):
        # Map each whitespace-separated token to its vocabulary index.
        return [word_dict[token] for token in sentence.split()]

    one_hot = np.eye(n_class, dtype=np.float32)
    enc_ids, dec_ids, target_ids = (to_ids(s) for s in sentences[:3])

    # Build one contiguous ndarray per tensor before handing it to torch:
    # converting a Python list of ndarrays is slow and emits a UserWarning.
    input_batch = torch.from_numpy(one_hot[enc_ids][np.newaxis])
    output_batch = torch.from_numpy(one_hot[dec_ids][np.newaxis])
    target_batch = torch.tensor([target_ids], dtype=torch.long)
    return input_batch, output_batch, target_batch

In [3]:
class Attention(nn.Module):
    """RNN encoder/decoder with dot-product-over-projection attention.

    At each decoder step every encoder output ``e_j`` is scored as
    ``dec_output · attn(e_j)`` (``attn`` is a learned ``nn.Linear``, similar
    to Luong's "general" score plus a bias). The scores are softmax-normalised
    over encoder steps, the weighted sum of encoder outputs (context) is
    concatenated with the decoder output, and ``out`` projects the result to
    ``n_class`` logits.

    NOTE(review): the attention weights have shape (1, 1, src_len) and are
    combined with ``bmm``, so only batch size 1 is supported.
    """

    def __init__(self, n_class, n_hidden):
        super().__init__()
        self.n_class = n_class
        # No dropout: nn.RNN applies it only *between* stacked layers, so with
        # num_layers=1 it had no effect and only triggered a UserWarning.
        self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)

        self.attn = nn.Linear(n_hidden, n_hidden)
        # Input is [decoder output ; context], hence 2 * n_hidden.
        self.out = nn.Linear(n_hidden * 2, n_class)

    def forward(self, enc_inputs, hidden, dec_inputs):
        """Encode the source once, then decode step by step with attention.

        Args:
            enc_inputs: source batch, (batch, src_len, n_class).
            hidden: initial encoder hidden state, (1, batch, n_hidden).
            dec_inputs: decoder inputs (teacher forcing), (batch, tgt_len, n_class).

        Returns:
            ``(logits, trained_attn)``: logits of shape (tgt_len, n_class)
            for batch size 1, and a list with one numpy array of src_len
            attention weights per decoder step (detached, e.g. for plotting).
        """
        # nn.RNN defaults to sequence-first layout: (seq_len, batch, features).
        enc_inputs = enc_inputs.transpose(0, 1)
        dec_inputs = dec_inputs.transpose(0, 1)

        enc_outputs, enc_hidden = self.enc_cell(enc_inputs, hidden)

        trained_attn = []
        step_logits = []
        hidden = enc_hidden  # the decoder starts from the final encoder state

        for dec_input in dec_inputs:  # one (batch, n_class) slice per step
            dec_output, hidden = self.dec_cell(dec_input.unsqueeze(0), hidden)
            attn_weights = self.get_att_weight(dec_output, enc_outputs)  # (1, 1, src_len)
            trained_attn.append(attn_weights.detach().cpu().reshape(-1).numpy())

            # (1, 1, src_len) @ (batch, src_len, n_hidden) -> (batch, 1, n_hidden)
            context = attn_weights.bmm(enc_outputs.transpose(0, 1))
            combined = torch.cat((dec_output.squeeze(0), context.squeeze(1)), dim=1)
            step_logits.append(self.out(combined))

        if step_logits:
            logits = torch.stack(step_logits)  # (tgt_len, batch, n_class)
        else:  # empty decoder sequence: keep the (tgt_len, batch, n_class) layout
            logits = enc_outputs.new_zeros((0, enc_outputs.size(1), self.n_class))

        return logits.transpose(0, 1).squeeze(0), trained_attn

    def get_att_weight(self, dec_output, enc_outputs):
        """Softmax-normalised attention weights over all encoder steps.

        Args:
            dec_output: decoder output for one step, (1, batch, n_hidden).
            enc_outputs: encoder outputs, (src_len, batch, n_hidden).

        Returns:
            Tensor of shape (1, 1, src_len) whose entries sum to 1.
        """
        n_step = len(enc_outputs)
        # Vectorised equivalent of calling get_att_score for each encoder step:
        # project every step, flatten each step's slice, and dot it with the
        # flattened decoder output.
        projected = self.attn(enc_outputs).reshape(n_step, dec_output.numel())
        attn_scores = projected @ dec_output.reshape(-1)
        return F.softmax(attn_scores, dim=0).view(1, 1, -1)

    def get_att_score(self, dec_output, enc_output):
        """Unnormalised attention score for a single encoder step.

        Returns the scalar ``dec_output · attn(enc_output)``, with both
        operands flattened, so any batch dimension is summed over as well.
        """
        projected = self.attn(enc_output)
        return torch.dot(dec_output.reshape(-1), projected.reshape(-1))

In [None]:
# Fix the RNG so weight initialisation and training are reproducible.
SEED = 0
torch.manual_seed(SEED)

n_step = 5  # tokens per sentence below
n_hidden = 128
LEARNING_RATE = 0.001

# make_batch reads these as [encoder input, decoder input, target].
# NOTE(review): 'S' / 'E' presumably mark start / end of the decoder side and
# 'P' padding on the encoder side; confirm.
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']

# sorted() makes vocabulary indices deterministic: iterating a raw set of
# strings changes order between runs because of Python's hash randomisation.
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict)

# Initial encoder hidden state: (num_layers=1, batch=1, n_hidden).
hidden = torch.zeros(1, 1, n_hidden)

model = Attention(n_class, n_hidden)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

input_batch, output_batch, target_batch = make_batch()

for epoch in range(2000):
    optimizer.zero_grad()
    output, __ = model(input_batch, hidden, output_batch)
    