In [None]:
import os
import sys
import pickle

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import Dataset, DataLoader

sys.path.append(os.path.join(sys.path[0], '../src'))
import data
import utils
import model
from model import EncoderRNN
from model import DecoderRNN

%load_ext autoreload
%autoreload 2

## Load and process training data

In [None]:
# Change to the directory containing the English-to-Tamil transliteration data.
# The notebook's starting directory is remembered across re-runs so this cell is
# idempotent: re-running a bare relative chdir would resolve "../data/..." from
# inside the data directory and fail.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", os.getcwd())
DATA_DIR = os.path.normpath(
    os.path.join(_NOTEBOOK_DIR, "../data/aksharantar_sampled/tam/")
)
os.chdir(DATA_DIR)
os.listdir()

The `WordDataset` class reads the word pairs and stores them in one-hot representation.

In [None]:
# Build the WordDataset once and cache it as a pickle; later runs reload the cache
# so a fresh checkout still survives Restart & Run All.
# data.WordDataset() reads 'tam_train.csv' by default.
# NOTE: pickle.load can execute arbitrary code -- only load caches you created yourself.
DATASET_CACHE = 'dataset.pkl'
if os.path.exists(DATASET_CACHE):
    with open(DATASET_CACHE, 'rb') as file:
        dataset = pickle.load(file)
else:
    dataset = data.WordDataset()
    with open(DATASET_CACHE, 'wb') as file:
        pickle.dump(dataset, file)

Indexing the dataset, `dataset[i]`, returns the i-th data point: an (English, Tamil) word pair in one-hot representation.

In [None]:
# Take the first (English, Tamil) pair and print the shape of each encoding.
first_pair = dataset[0]
x, y = first_pair
for encoded in (x, y):
    print(encoded.shape)

In [None]:
# Decode the English side of the sample back to text (displayed as cell output).
english_word = dataset.decode_eng_word(x)
english_word

In [None]:
# Decode the Tamil side of the sample back to text (displayed as cell output).
tamil_word = dataset.decode_tam_word(y)
tamil_word

### Sampling data

In [None]:
# Seeded generator so the training shuffle order (and the sample shown below)
# is reproducible across Restart & Run All.
LOADER_SEED = 0
loader_rng = torch.Generator().manual_seed(LOADER_SEED)
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True, generator=loader_rng)

# TODO(review): this still reads the *training* dataset; point it at held-out
# data (e.g. a test/valid CSV) before using it to measure generalisation.
# Evaluation does not need a random order, so it is not shuffled.
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

next(iter(train_dataloader))

## Test Encoder

In [None]:
# Run a single English sample through a fresh GRU encoder.
# NOTE(review): assumes EncoderRNN is an nn.Module (it exposes .parameters());
# calling the module runs forward() together with any registered hooks,
# which a direct .forward() call would skip.
encoder = EncoderRNN(cell_type='gru')
output, hidden = encoder(x)

In [None]:
# Reduce the encoder outputs to a scalar and backpropagate to check that
# gradients reach the encoder parameters.
out = output.sum()
out.backward()

In [None]:
# Show the gradient attached to every encoder parameter after the backward call.
for weight in encoder.parameters():
    print(weight.grad)

In [None]:
# Shapes of the encoder output and of its returned hidden state.
for encoder_result in (output, hidden):
    print(encoder_result.shape)

## Test Decoder

In [None]:
# Shape of y[0], the first element of the Tamil target, which is fed to the decoder.
# NOTE(review): presumably the one-hot vector of the first character -- confirm.
y[0].shape

In [None]:
# Run one decoder step on the first target element, seeded with the encoder's
# hidden state. DecoderRNN is already in scope from the first cell.
# NOTE: this rebinds `hidden` to the decoder's hidden state, so re-running this
# cell without re-running the encoder cell feeds the decoder its own state.
decoder = DecoderRNN(cell_type='gru')
out, hidden = decoder(y[0], hidden)

In [None]:
# Shape of the decoder's returned hidden state (the following cell prints it again).
hidden.shape

In [None]:
# Shapes of the single-step decoder output and its hidden state, one per line.
print(out.shape, hidden.shape, sep='\n')

## Loss

In [None]:
# NLL loss of the single decoder step against target class index 0.
# nn.NLLLoss takes class indices directly, so the target is built as an int64
# tensor of shape (1,) with no one-hot -> torch.where round trip.
# NOTE(review): NLLLoss expects log-probabilities shaped (1, C) here -- confirm
# that DecoderRNN's output ends in log_softmax.
target = torch.tensor([0])
y_hat = out
criterion = nn.NLLLoss()
loss = criterion(y_hat, target)
loss

In [None]:
# Peek at the gradient of the first encoder parameter only. The loss cell does
# not call backward(), so this reflects the earlier sum-of-outputs backward pass.
first_param = next(iter(encoder.parameters()), None)
if first_param is not None:
    print(first_param.grad)

## Test Training

Train freshly initialised encoder and decoder models with `utils.train`, then plot the recorded loss.

In [None]:
# Train freshly initialised models. The torch RNG is seeded so weight
# initialisation is reproducible; if utils.train samples via `random`/numpy,
# seed those too.
# NOTE(review): the meaning of the two positional hyper-parameters (learning rate,
# number of iterations) is inferred from their values -- confirm against utils.train.
LEARNING_RATE = 0.01
N_ITERS = 2000
TRAIN_SEED = 0

torch.manual_seed(TRAIN_SEED)
encoder = EncoderRNN(cell_type='gru')
decoder = DecoderRNN(cell_type='gru')
encoder, decoder, loss_seq = utils.train(dataset, encoder, decoder, LEARNING_RATE, N_ITERS)

In [None]:
# Plot the loss values returned by utils.train, one point per recorded entry.
fig, ax = plt.subplots()
ax.plot(loss_seq)
ax.set(title='Training loss', xlabel='recorded step', ylabel='loss');

## Evaluation

In [None]:
# Take the first (English, Tamil) pair for a qualitative check of the model.
# NOTE(review): this sample comes from the training dataset, not held-out data.
sample_pair = next(iter(dataset))
x, y = sample_pair

In [None]:
# Presumably re-initialises the encoder's stored hidden state before prediction
# (inferred from the method name -- confirm in model.EncoderRNN and whether
# utils.predict relies on it).
encoder.initialize_hidden()

In [None]:
# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    y_pred = utils.predict(x, encoder, decoder)

In [None]:
# Decode the predicted and the reference Tamil word for a side-by-side comparison.
{
    'predicted': dataset.decode_tam_word(y_pred),
    'target': dataset.decode_tam_word(y),
}