In [None]:
# Imports
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from utils import (
    load_pairs, prepare_vocab, word2tensor,
    SOS_token, EOS_token, device
)


In [None]:
# Load Dataset
# Path to the `hi` transliteration training lexicon, relative to the notebook.
data_path = "../dakshina_dataset_v1.0/hi/lexicons/hi.transliteration.train.tsv"

pairs = load_pairs(data_path)
# Fail fast with a readable message: without this check an empty or
# mis-parsed file only shows up as a bare IndexError at `pairs[0]` below.
if len(pairs) == 0:
    raise ValueError(f"No transliteration pairs were loaded from {data_path!r}")

# Source-side and target-side vocabularies built from the loaded pairs.
src_vocab, tgt_vocab = prepare_vocab(pairs)

print("Latin vocab size:", src_vocab.n_chars)
print("Native vocab size:", tgt_vocab.n_chars)
print("Example pair:", pairs[0])


In [None]:
# Attention Module
class BahdanauAttention(nn.Module):
    """Additive attention scorer over one (unbatched) source sequence.

    Every encoder state is concatenated with the current decoder hidden state,
    passed through a linear layer followed by ``tanh``, and projected onto the
    learned vector ``v`` to obtain one score per source position. The scores
    are normalised with a softmax over the positions.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Maps [decoder hidden ; encoder state] (2 * hidden_size) to hidden_size.
        self.attn = nn.Linear(2 * hidden_size, hidden_size)
        # Scoring vector, initialised uniformly in [0, 1).
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights for one decoder step.

        Args:
            hidden: decoder hidden state, expected as (1, 1, hidden_size).
            encoder_outputs: encoder states, (seq_len, hidden_size).

        Returns:
            Tensor of shape (1, seq_len) whose entries sum to 1.
        """
        n_positions = encoder_outputs.size(0)
        # Pair the same decoder state with every encoder position.
        query = hidden.squeeze(0).expand(n_positions, -1)  # (seq_len, hidden_size)
        paired = torch.cat((query, encoder_outputs), dim=1)  # (seq_len, 2 * hidden_size)
        scores = torch.mv(torch.tanh(self.attn(paired)), self.v)  # (seq_len,)
        weights = F.softmax(scores, dim=0)
        return weights.unsqueeze(0)  # (1, seq_len)


In [None]:
# Encoder
class EncoderRNN(nn.Module):
    """Single-layer recurrent encoder that consumes one symbol per call.

    Args:
        input_size: number of distinct input symbols (embedding rows).
        hidden_size: width of both the embedding and the recurrent state.
        cell_type: ``"LSTM"`` selects ``nn.LSTM``; any other value selects
            ``nn.GRU``.
    """

    def __init__(self, input_size, hidden_size, cell_type="LSTM"):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        rnn_cls = nn.LSTM if cell_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode a single symbol.

        Args:
            input: index tensor holding one symbol id.
            hidden: previous recurrent state, in the format returned by
                ``init_hidden`` (a tensor for GRU, an ``(h, c)`` tuple for LSTM).

        Returns:
            ``(output, hidden)`` where ``output`` is (1, 1, hidden_size) and
            ``hidden`` is the updated recurrent state.
        """
        # Reshape to (seq_len=1, batch=1, hidden_size) as the RNN expects.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.rnn(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial state matching the recurrent cell.

        ``nn.LSTM`` requires an ``(h_0, c_0)`` tuple while ``nn.GRU`` takes a
        single tensor; returning a bare tensor for the LSTM case makes the
        first encoder step raise. The state is created on the device of this
        module's own parameters so it always matches the model.
        """
        where = self.embedding.weight.device
        h0 = torch.zeros(1, 1, self.hidden_size, device=where)
        if isinstance(self.rnn, nn.LSTM):
            return h0, torch.zeros_like(h0)
        return h0


In [None]:
# Attention Decoder
class AttnDecoderRNN(nn.Module):
    """Single-step recurrent decoder that attends over the encoder states.

    Args:
        hidden_size: width of the embedding, the context vector and the
            recurrent state.
        output_size: number of distinct output symbols.
        dropout_p: dropout probability applied to the input embedding.
        max_length: stored on the instance; not otherwise used by this class.
        cell_type: ``"LSTM"`` selects ``nn.LSTM``; any other value selects
            ``nn.GRU``.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=30, cell_type="LSTM"):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.dropout = nn.Dropout(dropout_p)

        # The RNN input is [embedding ; context], hence 2 * hidden_size.
        if cell_type == "LSTM":
            self.rnn = nn.LSTM(hidden_size * 2, hidden_size)
        else:
            self.rnn = nn.GRU(hidden_size * 2, hidden_size)

        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: index tensor holding the previous output symbol id.
            hidden: previous recurrent state (tensor for GRU, ``(h, c)`` tuple
                for LSTM).
            encoder_outputs: encoder states of shape (seq_len, 1, hidden_size).

        Returns:
            ``(log_probs, hidden, attn_weights)`` where ``log_probs`` is
            (1, output_size), ``hidden`` is the updated recurrent state and
            ``attn_weights`` is whatever the attention module returned
            ((1, seq_len) for ``BahdanauAttention``).
        """
        embedded = self.dropout(self.embedding(input).view(1, 1, -1))

        # For an LSTM only the hidden part h of (h, c) is used as the query.
        query = hidden[0] if isinstance(hidden, tuple) else hidden
        attn_weights = self.attention(query, encoder_outputs.squeeze(1))

        # torch.bmm needs 3-D operands; the attention weights arrive as
        # (1, seq_len), so reshape them to (1, 1, seq_len) for the product
        # only. The weights are still returned in their original shape.
        context = torch.bmm(attn_weights.reshape(1, 1, -1),
                            encoder_outputs.transpose(0, 1))  # (1, 1, hidden_size)

        rnn_input = torch.cat((embedded, context), 2)
        output, hidden = self.rnn(rnn_input, hidden)

        output = self.out(output[0])
        return F.log_softmax(output, dim=1), hidden, attn_weights


In [None]:
# Training Step with Attention
# Probability that a training step feeds the gold symbol (rather than the
# model's own prediction) back into the decoder.
teacher_forcing_ratio = 0.5

def train_step(input_tensor, target_tensor, encoder, decoder,
               encoder_optimizer, decoder_optimizer, criterion,
               max_length=30):
    """Run one optimisation step on a single (source, target) pair.

    Args:
        input_tensor: source symbol ids, indexed along dim 0.
        target_tensor: target symbol ids, indexed along dim 0
            (each ``target_tensor[di]`` is passed to ``criterion`` as the
            target -- presumably shape (1,); confirm against ``word2tensor``).
        encoder, decoder: the models to train.
        encoder_optimizer, decoder_optimizer: their optimisers.
        criterion: loss applied to the decoder's log-probabilities.
        max_length: minimum number of encoder positions allocated for
            attention; the buffer grows for longer sources.

    Returns:
        The mean loss per decoding step that was actually executed.

    Raises:
        ValueError: if ``target_tensor`` is empty.
    """
    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)
    if target_length == 0:
        # Without this the integer `loss = 0` would reach `.backward()`.
        raise ValueError("train_step received an empty target sequence")

    # Make sure dropout is active even if an earlier call left the
    # modules in eval mode.
    encoder.train()
    decoder.train()

    encoder_hidden = encoder.init_hidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Sources longer than max_length used to overflow the fixed-size buffer.
    buffer_length = max(max_length, input_length)
    encoder_outputs = torch.zeros(buffer_length, 1, encoder.hidden_size, device=device)

    # Encoder forward
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden
    loss = 0
    steps = 0

    use_teacher_forcing = random.random() < teacher_forcing_ratio
    for di in range(target_length):
        decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_tensor[di])
        steps += 1

        if use_teacher_forcing:
            # Feed the gold symbol as the next input.
            decoder_input = target_tensor[di]
        else:
            # Feed the model's own best guess; detach so gradients do not
            # flow through the sampled index.
            _, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    # Divide by the number of summed terms; free-running decoding may stop
    # before target_length steps, which would otherwise under-report the loss.
    return loss.item() / steps


In [None]:
# Training Loop
def train_iters(pairs, encoder, decoder, src_vocab, tgt_vocab,
                n_iters=1000, learning_rate=0.01, print_every=100):
    """Train on randomly drawn pairs with plain SGD.

    Args:
        pairs: non-empty sequence of (source, target) items.
        encoder, decoder: models to optimise.
        src_vocab, tgt_vocab: vocabularies handed to ``word2tensor``.
        n_iters: number of single-pair training steps.
        learning_rate: SGD learning rate used for both optimisers.
        print_every: report the mean loss of the last ``print_every`` steps
            at this interval; ``0`` or ``None`` disables reporting.

    Returns:
        List with the loss returned by ``train_step`` for every iteration.

    Raises:
        ValueError: if ``pairs`` is empty.
    """
    if len(pairs) == 0:
        raise ValueError("train_iters needs at least one (source, target) pair")

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    reporting = bool(print_every) and print_every > 0
    losses = []
    window_total = 0.0

    for it in range(1, n_iters + 1):
        src, tgt = random.choice(pairs)
        input_tensor = word2tensor(src_vocab, src)
        target_tensor = word2tensor(tgt_vocab, tgt)

        loss = train_step(input_tensor, target_tensor, encoder, decoder,
                          encoder_optimizer, decoder_optimizer, criterion)
        losses.append(loss)
        window_total += loss

        if reporting and it % print_every == 0:
            # A single random sample is a very noisy progress signal, so the
            # reported value is the mean over the reporting window.
            print(f"Iter {it}, Loss {window_total / print_every:.4f}")
            window_total = 0.0

    return losses


In [None]:
# Evaluation with Attention
def evaluate(encoder, decoder, word, src_vocab, tgt_vocab, max_length=30):
    """Greedily decode ``word`` and collect the attention weights.

    Args:
        encoder, decoder: trained models.
        word: source string, converted with ``word2tensor(src_vocab, word)``.
        src_vocab, tgt_vocab: vocabularies; ``tgt_vocab.index2char`` maps
            predicted ids back to characters.
        max_length: maximum number of decoding steps and minimum number of
            encoder positions allocated for attention.

    Returns:
        ``(prediction, attentions)`` where ``prediction`` is the decoded
        string (without the end marker) and ``attentions`` has one row per
        emitted character and one column per source position.
    """
    # Inference must run with dropout disabled; remember the previous modes
    # so that a training loop that resumes afterwards is unaffected.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            input_tensor = word2tensor(src_vocab, word)
            input_length = input_tensor.size(0)
            encoder_hidden = encoder.init_hidden()

            # Grow the buffer for sources longer than max_length instead of
            # overflowing it.
            buffer_length = max(max_length, input_length)
            encoder_outputs = torch.zeros(buffer_length, 1, encoder.hidden_size, device=device)
            for ei in range(input_length):
                encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
                encoder_outputs[ei] = encoder_output[0, 0]

            decoder_input = torch.tensor([[SOS_token]], device=device)
            decoder_hidden = encoder_hidden

            decoded_chars = []
            attentions = torch.zeros(max_length, buffer_length)
            for di in range(max_length):
                decoder_output, decoder_hidden, attn_weights = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                attentions[di, :attn_weights.size(-1)] = attn_weights.squeeze(0)
                _, topi = decoder_output.topk(1)
                if topi.item() == EOS_token:
                    break
                decoded_chars.append(tgt_vocab.index2char[topi.item()])
                decoder_input = topi.squeeze().detach()

            # Drop the rows past the last emitted character and the padded
            # source columns.
            return ''.join(decoded_chars), attentions[:len(decoded_chars), :input_length]
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)


In [None]:
# Attention Heatmap Utility
def show_attention(input_word, output_word, attentions):
    """Plot an attention matrix as a heatmap.

    Args:
        input_word: source string; its characters plus ``'EOS'`` label the
            columns.
        output_word: predicted string; its characters label the rows.
        attentions: 2-D tensor of weights, rows = output steps,
            columns = source positions.
    """
    # detach/cpu make this work for tensors that track gradients or live on
    # an accelerator as well.
    weights = attentions.detach().cpu().numpy()
    if weights.size == 0:
        # Happens when the model emits the end marker immediately.
        print("No attention weights to plot (empty prediction).")
        return

    n_rows, n_cols = weights.shape
    fig, ax = plt.subplots(figsize=(6, 6))
    cax = ax.matshow(weights, cmap='viridis')
    fig.colorbar(cax)

    # Pin tick positions explicitly before labelling them. Setting labels
    # first and a MultipleLocator afterwards depends on an invisible tick
    # at -1 and triggers a FixedFormatter warning on recent Matplotlib.
    x_labels = (list(input_word) + ['EOS'])[:n_cols]
    y_labels = list(output_word)[:n_rows]
    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [None]:
# Run Training + Test Predictions
# Seed Python's and torch's RNGs before the models are built so that weight
# initialisation, pair sampling and teacher-forcing decisions repeat on a
# fresh "Restart & Run All".
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

hidden_size = 256
encoder = EncoderRNN(src_vocab.n_chars, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, tgt_vocab.n_chars, dropout_p=0.1).to(device)

train_iters(pairs, encoder, decoder, src_vocab, tgt_vocab, n_iters=2000, print_every=200)

# Show some predictions with attention maps
# (capped at the dataset size so tiny datasets do not make random.sample fail)
for word, tgt in random.sample(pairs, min(3, len(pairs))):
    pred, attn = evaluate(encoder, decoder, word, src_vocab, tgt_vocab)
    print(f"{word} -> {pred} (target: {tgt})")
    show_attention(word, pred, attn)
