In [291]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import random
import numpy as np
import time

random.seed()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [292]:
SOS_token = 0
EOS_token = 1

class Language:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_token: "<", EOS_token: ">"}
        self.n_chars = 2  # Count SOS and EOS

    def addWord(self, word):
        for char in word:
            self.addChar(char)

    def addChar(self, char):
        if char not in self.word2index:
            self.word2index[char] = self.n_chars
            self.word2count[char] = 1
            self.index2word[self.n_chars] = char
            self.n_chars += 1
        else:
            self.word2count[char] += 1

In [293]:
def input_output_data(lang: str, type:str):
    path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
    df = pd.read_csv(path, header=None)
    return df[0].to_numpy(), df[1].to_numpy()

In [294]:
def set_words(lang:str):
    input_lang, output_lang = Language('eng'), Language(lang)
    input_words, output_words = input_output_data(lang, 'train')
    word_pairs = [[input_words[i], output_words[i]] for i in range(len(input_words))]
    for word in input_words:
        input_lang.addWord(word)
    for word in output_words:
        output_lang.addWord(word)
    return input_lang, output_lang, word_pairs

In [295]:
input_lang, output_lang, pairs = set_words('tam')
print(random.choice(pairs))
print("Number of words in input language: ", len(pairs))
print("Number of characters in input language: ", input_lang.n_chars)
print("Number of characters in output language: ", output_lang.n_chars)

['puliyurai', 'புலியூரை']
Number of words in input language:  51200
Number of characters in input language:  28
Number of characters in output language:  48


In [296]:
random_pair = random.choice(pairs)
input_word = random_pair[0]
output_word = random_pair[1]

encoded_output = [output_lang.word2index[char] for char in output_word]
print("Encoded output: ", encoded_output)

decoded_output = [output_lang.index2word[i] for i in encoded_output]
print("Decoded output: ", decoded_output)

decoded_string = ''.join(decoded_output)
print("Decoded string: ", decoded_string)

Encoded output:  [2, 8, 5, 10, 7, 20, 5, 4, 17, 21, 13]
Decoded output:  ['த', 'ர', '்', 'ம', 'ச', 'ண', '்', 'ட', 'ி', 'க', 'ை']
Decoded string:  தர்மசண்டிகை


In [297]:
def cell(cell_type:str):
    if cell_type == 'LSTM':
        return nn.LSTM
    elif cell_type == 'GRU':
        return nn.GRU
    elif cell_type == 'RNN':
        return nn.RNN
    else:
        raise Exception("Invalid cell type")

In [298]:
EMBED_DIM = 64
INPUT_DIM = input_lang.n_chars
OUTPUT_DIM = output_lang.n_chars
HIDDEN_DIM = 256
CELL_TYPE = 'GRU'

MAX_LENGTH = 50

PRINT_EVERY = 1000
PLOT_EVERY = 100
LEARNING_RATE = 0.001

In [299]:
class EncoderRNN(nn.Module):
    def __init__(self):
        super(EncoderRNN, self).__init__()
        self.hidden_dim = HIDDEN_DIM
        self.embedding = nn.Embedding(INPUT_DIM, EMBED_DIM)

        # cell types = "RNN", "GRU", "LSTM"
        self.cell = cell(CELL_TYPE)(EMBED_DIM, HIDDEN_DIM)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.cell(embedded, hidden)
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)
    
class DecoderRNN(nn.Module):
    def __init__(self):
        super(DecoderRNN, self).__init__()
        self.hidden_dim = HIDDEN_DIM
        self.embedding = nn.Embedding(OUTPUT_DIM, EMBED_DIM)

        # cell types = "RNN", "GRU", "LSTM"
        self.cell = cell(CELL_TYPE)(EMBED_DIM, HIDDEN_DIM)
        self.out = nn.Linear(HIDDEN_DIM, OUTPUT_DIM)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        embed = self.embedding(input).view(1, 1, -1)
        active_embed = F.relu(embed)
        output, hidden = self.cell(active_embed, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)

In [300]:
def indexesFromWord(lang:Language, word:str):
    return [lang.word2index[char] for char in word]

def tensorFromWord(lang:Language, word:str):
    indexes = indexesFromWord(lang, word)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(pair:list):
    input_tensor = tensorFromWord(input_lang, pair[0])
    target_tensor = tensorFromWord(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [301]:
teacher_forcing_ratio = 0.5

def train(input_tensor, 
          target_tensor, 
          encoder : EncoderRNN,
          decoder : DecoderRNN,
          encoder_optimizer : optim.Optimizer, 
          decoder_optimizer : optim.Optimizer,
          criterion,
          max_length=MAX_LENGTH):
    
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_dim, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di])

            decoder_input = target_tensor[di]
    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)

            decoder_input = topi.squeeze().detach()
            loss += criterion(decoder_output, target_tensor[di])

            if decoder_input.item() == EOS_token:
                break
    
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [302]:
training_pairs = [tensorsFromPair(pair) for pair in pairs]
print("Training pairs size: ", len(training_pairs))

def train_loop(encoder : EncoderRNN,
               decoder : DecoderRNN,
               n_iters : int = 5,
               print_every=PRINT_EVERY, 
               plot_every=PLOT_EVERY,
               learning_rate=LEARNING_RATE):
    
    start_time = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # random permutate the training pairs
    random.shuffle(training_pairs)
    criterion = nn.NLLLoss()

    for iter in range(1, len(training_pairs)+1):
        training_pair = training_pairs[iter-1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            current_time = time.time()
            print("Loss after {} iterations ({}s): {}".format(iter, current_time - start_time, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    return plot_losses

Training pairs size:  51200


In [303]:
def evaluate(encoder : EncoderRNN,
             decoder : DecoderRNN,
             word : str,
             max_length=MAX_LENGTH):

    with torch.no_grad():
        input_tensor = tensorFromWord(input_lang, word)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_dim, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden

        decoded_word = ""

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)

            if topi.item() == EOS_token:
                break
            else:
                decoded_word += output_lang.index2word[topi.item()]

            decoder_input = topi.squeeze().detach()

        return decoded_word

In [304]:
def evaluate_random(encoder:EncoderRNN, decoder:DecoderRNN, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print("Input: {}".format(pair[0]))
        print("Target: {}".format(pair[1]))
        output = evaluate(encoder, decoder, pair[0])
        print("Output: {}".format(output))
        print("")

In [305]:
encoder1 = EncoderRNN().to(device)
decoder1 = DecoderRNN().to(device)

EPOCHS = 10
for epoch in range(EPOCHS):
    print("Epoch {}".format(epoch))
    train_loop(encoder1, decoder1, print_every=5000)

Epoch 0
Loss after 5000 iterations (32.570534229278564s): 2.761272708293739
Loss after 10000 iterations (64.56866693496704s): 2.576452040883524
Loss after 15000 iterations (96.51885867118835s): 2.5351037977381115
Loss after 20000 iterations (128.96645998954773s): 2.5098527596083904
Loss after 25000 iterations (161.5207736492157s): 2.4745074851416615
Loss after 30000 iterations (194.34914422035217s): 2.454995440347155
Loss after 35000 iterations (227.49340748786926s): 2.430090384004403
Loss after 40000 iterations (264.2547187805176s): 2.3978343262346424
Loss after 45000 iterations (298.7650682926178s): 2.3538481898292436
Loss after 50000 iterations (333.62473917007446s): 2.316358998466944
Epoch 1
Loss after 5000 iterations (35.29433274269104s): 2.2542203995663725
Loss after 10000 iterations (70.68840432167053s): 2.1761930697732526
Loss after 15000 iterations (106.17613887786865s): 2.1479104773283786
Loss after 20000 iterations (142.5713927745819s): 2.0786312474793003
Loss after 25000 it

In [347]:
torch.save(encoder1.state_dict(), "./models/encoder_tam1.pt")
torch.save(decoder1.state_dict(), "./models/decoder_tam1.pt")