In [None]:
# Bahdanau Attention NMT (Bahdanau et al., 2014)
# Implementation based on the paper: https://arxiv.org/pdf/1409.0473

import torch
from torch import nn
import torch.nn.functional as F


In [None]:
class Encoder(nn.Module):
    """Source-side encoder: token embedding followed by a single-layer LSTM.

    Args:
        vocab_size: Number of distinct token ids (rows of the embedding table).
        embed_size: Dimensionality of each token embedding.
        hidden_size: Dimensionality of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embed_size=32, hidden_size=64):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, src):
        """Encode a batch of token-id sequences.

        Args:
            src: Integer tensor of token ids. For batched input the LSTM is
                configured with ``batch_first=True``, i.e. (batch, src_len).

        Returns:
            ``(outputs, (h, c))`` as produced by ``nn.LSTM``: the per-step
            hidden states plus the final hidden and cell states (each with a
            leading ``num_layers`` axis of size 1 for batched input).
        """
        token_vectors = self.embedding(src)
        outputs, final_state = self.lstm(token_vectors)
        h, c = final_state
        return outputs, (h, c)

class Attention(nn.Module):
    """Additive (concat-then-project) attention over encoder states.

    The score of each encoder position is ``v^T tanh(W [hidden ; enc_t])``;
    scores are softmax-normalised over the time axis and used to average the
    encoder outputs into one context vector per batch element.

    Args:
        hidden_size: Width of both the decoder state and the encoder outputs.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(in_features=2 * hidden_size, out_features=hidden_size)
        self.v = nn.Linear(in_features=hidden_size, out_features=1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Compute the context vector and attention weights.

        Args:
            hidden: Decoder state, (batch, hidden_size).
            encoder_outputs: Encoder states, (batch, src_len, hidden_size).

        Returns:
            ``(context, weights)`` where ``context`` is (batch, hidden_size)
            and ``weights`` is (batch, src_len) with each row summing to 1.
        """
        src_len = encoder_outputs.size(1)
        # Broadcast the decoder state along the time axis so it can be paired
        # with every encoder position.
        tiled_state = hidden.unsqueeze(1).expand(-1, src_len, -1)
        paired = torch.cat((tiled_state, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(paired))
        scores = self.v(energy).squeeze(-1)
        weights = scores.softmax(dim=1)
        # (batch, 1, src_len) @ (batch, src_len, hidden) -> (batch, 1, hidden)
        context = (weights.unsqueeze(1) @ encoder_outputs).squeeze(1)
        return context, weights

class Decoder(nn.Module):
    """Teacher-forced LSTM decoder that attends over the encoder outputs.

    Args:
        vocab_size: Number of target token ids (embedding rows and logit width).
        embed_size: Dimensionality of each token embedding.
        hidden_size: Width of the LSTM cell state and of the encoder outputs.
    """

    def __init__(self, vocab_size, embed_size=32, hidden_size=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The cell input is the token embedding concatenated with the context vector.
        self.lstm = nn.LSTMCell(input_size=embed_size + hidden_size, hidden_size=hidden_size)
        self.attention = Attention(hidden_size)
        # The output layer reads [updated hidden state ; context vector].
        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=vocab_size)

    def forward(self, trg, hidden, cell, encoder_outputs):
        """Run the decoder over every position of ``trg``.

        Args:
            trg: Decoder input token ids, (batch, trg_len); column ``t`` is fed
                at step ``t``.
            hidden: Initial hidden state, (batch, hidden_size).
            cell: Initial cell state, (batch, hidden_size).
            encoder_outputs: Encoder states, (batch, src_len, hidden_size).

        Returns:
            Logits stacked along the time axis, (batch, trg_len, vocab_size).
        """
        _, num_steps = trg.size()
        step_logits = []
        for step in range(num_steps):
            token_vec = self.embedding(trg[:, step])
            # Attention is queried with the state from *before* this step's update.
            context, _ = self.attention(hidden, encoder_outputs)
            cell_input = torch.cat((token_vec, context), dim=1)
            hidden, cell = self.lstm(cell_input, (hidden, cell))
            readout = torch.cat((hidden, context), dim=1)
            step_logits.append(self.fc(readout))
        return torch.stack(step_logits, dim=1)

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper sharing one vocabulary size for both sides.

    Args:
        vocab_size: Number of token ids, used for source and target alike.
        embed_size: Embedding width passed to both encoder and decoder.
        hidden_size: LSTM state width passed to both encoder and decoder.
    """

    def __init__(self, vocab_size, embed_size=32, hidden_size=64):
        super().__init__()
        self.encoder = Encoder(vocab_size, embed_size, hidden_size)
        self.decoder = Decoder(vocab_size, embed_size, hidden_size)

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` with teacher forcing.

        Args:
            src: Source token ids passed to the encoder.
            trg: Decoder input token ids passed to the decoder.

        Returns:
            Whatever the decoder returns (one logit vector per ``trg`` position).
        """
        enc_out, (h, c) = self.encoder(src)
        # Drop the leading axis of the encoder's final state so it fits the
        # decoder's (batch, hidden_size) cell state.
        init_hidden = h.squeeze(0)
        init_cell = c.squeeze(0)
        return self.decoder(trg, init_hidden, init_cell, enc_out)


In [None]:
# --------------------------
# Synthetic dataset helpers
# --------------------------
NUM_DIGITS = 10               # source symbols are the digits 0-9
EOS = NUM_DIGITS              # <eos> takes the first id after the digits
VOCAB_SIZE = NUM_DIGITS + 1   # digits 0-9 plus <eos>

def gen_batch(batch_sz, seq_len=5, device=None):
    """Sample a batch for the sequence-reversal toy task.

    Args:
        batch_sz: Number of sequences in the batch.
        seq_len: Number of digits per source sequence.
        device: Device to place the tensors on. When ``None`` the module-level
            ``DEVICE`` is used, which must therefore be defined before the
            first call that omits this argument.

    Returns:
        ``(inputs, targets)``: ``inputs`` is a (batch_sz, seq_len) int64 tensor
        of random digits; ``targets`` is (batch_sz, seq_len + 1) and holds the
        reversed digits followed by ``EOS``.
    """
    if device is None:
        # Resolved at call time, so the cell defining DEVICE may come later in
        # the notebook as long as it has run before this function is called.
        device = DEVICE
    inputs = torch.randint(0, NUM_DIGITS, (batch_sz, seq_len))
    reversed_digits = torch.flip(inputs, dims=[1])
    eos_column = torch.full((batch_sz, 1), EOS, dtype=torch.long)
    targets = torch.cat([reversed_digits, eos_column], dim=1)
    return inputs.to(device), targets.to(device)


In [None]:
# Seed before building the model so weight initialisation and the synthetic
# batches repeat on every Restart & Run All (GPU kernels may still introduce
# small run-to-run differences).
SEED = 0
torch.manual_seed(SEED)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Seq2Seq(VOCAB_SIZE).to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

steps = 200
batch_sz = 128
seq_len = 6
eval_every = 20  # report validation accuracy every this many steps

for step in range(1, steps + 1):
    src, trg = gen_batch(batch_sz, seq_len)
    optim.zero_grad()
    # Teacher forcing: the decoder reads trg[:, :-1] and is scored on the
    # next token, trg[:, 1:].
    # NOTE(review): the targets carry no start-of-sequence token, so the first
    # target token is always given as input and never predicted.
    output = model(src, trg[:, :-1])
    loss = F.cross_entropy(output.reshape(-1, VOCAB_SIZE), trg[:, 1:].reshape(-1))
    loss.backward()
    optim.step()

    if step % eval_every == 0:
        model.eval()
        with torch.no_grad():
            # Fresh batch under separate names so the training batch above is
            # not overwritten.
            val_src, val_trg = gen_batch(batch_sz, seq_len)
            val_out = model(val_src, val_trg[:, :-1])
            val_preds = val_out.argmax(-1)
            acc = (val_preds == val_trg[:, 1:]).float().mean().item()
            # `loss` is the training loss of the most recent step.
            print(f'step {step:4d} | loss {loss.item():.3f} | val acc {acc*100:5.1f}%')
        model.train()


In [None]:
# Example usage: teacher-forced decoding of one freshly sampled sequence.
src, trg = gen_batch(1, seq_len)
with torch.no_grad():
    # The decoder reads trg[:, :-1] and emits one prediction per input
    # position, so pred[:, t] is the model's guess for trg[:, t + 1].
    out = model(src, trg[:, :-1])
    pred = out.argmax(-1)
print('input:', src.squeeze(0).tolist())
# Print the slice the predictions are scored against, so that `target` and
# `pred` have the same length and line up position by position.
print('target:', trg[:, 1:].squeeze(0).tolist())
print('pred  :', pred.squeeze(0).tolist())
