In [1]:
import json
import os
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from torch import Tensor
from torch import optim
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

device

device(type='cuda')

In [2]:
# Constants

MAX_LENGTH = 71
SOS_TOKEN = 69
EOS_TOKEN = 70
TEACHER_FORCING_RATIO = 0.5
LEARNING_RATE = 0.01
HIDDEN_SIZE = 128

In [3]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self) -> torch.Tensor:
        return torch.zeros(1, 1, self.hidden_size, device=device)

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
    
    def forward(self, input_tensor, hidden, encoder_outputs):
        embedded = self.embedding(input_tensor).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [4]:
def sentence_to_tensor(content, target_size = MAX_LENGTH) -> torch.Tensor:
    # Add padding to the end of the sentence, so that the length is equal to target_size
    tensor = torch.tensor(content, dtype=torch.long, device=device).view(-1, 1)

    if tensor.size()[0] < target_size:
        padding = torch.zeros(target_size - tensor.size()[0], 1, dtype=torch.int32, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
        
    return tensor

In [5]:
class TranscriptionDataset(Dataset):
    def __init__(self, path_to_words: str = 'data/word-based/words.csv'):
        self.__words = pd.read_csv(path_to_words)

    def __len__(self):
        return self.__words.shape[0]

    def __getitem__(self, idx):
        if idx >= self.__len__():
            raise IndexError
        
        row = self.__words.iloc[idx]

        input_tensor = sentence_to_tensor(json.loads(row[2]))
        target_tensor = sentence_to_tensor(json.loads(row[3]))
        return input_tensor, target_tensor


In [6]:
class Trainer:
    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN,
                 encoder_optimizer: optim.Optimizer, decoder_optimizer: optim.Optimizer,
                 max_length: int = MAX_LENGTH):
        self.__encoder = encoder.to(device)
        self.__decoder = decoder.to(device)
        self.__encoder_optimizer = encoder_optimizer
        self.__decoder_optimizer = decoder_optimizer
        self.__max_length = max_length
        self.__loss = 0

    def __init_train(self):
        encoder_hidden = self.__encoder.init_hidden()

        self.__encoder_optimizer.zero_grad()
        self.__decoder_optimizer.zero_grad()

        encoder_outputs = torch.zeros(self.__max_length, self.__encoder.hidden_size, device=device)

        self.__loss = 0

        return encoder_hidden, encoder_outputs
    
    def __encoder_train(self, encoder_outputs, input_tensor, encoder_hidden):
        input_length = input_tensor.size(0)

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.__encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        return encoder_outputs

    def __optimizers_step(self):
        self.__encoder_optimizer.step()
        self.__decoder_optimizer.step()

    def __decoder_train(self, decoder_input, decoder_hidden, encoder_outputs, target_tensor, criterion):
        target_length = target_tensor.size(0)
        use_teacher_forcing = True if random.random() < TEACHER_FORCING_RATIO else False

        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = self.__decoder(
                    decoder_input, decoder_hidden, encoder_outputs)

            if use_teacher_forcing:
                decoder_input = target_tensor[di]
            else:
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
            
            self.__loss += criterion(decoder_output, target_tensor[di])

            if use_teacher_forcing and decoder_input.item() == EOS_TOKEN:
                break

    def train(self, input_tensor: Tensor, target_tensor: Tensor, 
              criterion: nn.Module, max_length: int = MAX_LENGTH) -> tuple[float, float]:        

        # Encoder training
        encoder_hidden, encoder_outputs = self.__init_train()
        encoder_outputs = self.__encoder_train(encoder_outputs, input_tensor, encoder_hidden)

        # Decoder training
        decoder_input = torch.tensor([[SOS_TOKEN]], device=device)
        decoder_hidden = encoder_hidden

        self.__decoder_train(decoder_input, decoder_hidden, encoder_outputs, target_tensor, criterion)
        
        # Optimizers step
        self.__loss.backward()
        self.__optimizers_step()

        return self.__loss.item() / target_tensor.size(0)

In [7]:
def train_loop(encoder: EncoderRNN, decoder: DecoderRNN, dataset: TranscriptionDataset,
               epochs: int, print_every: int = 100):
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=LEARNING_RATE)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=LEARNING_RATE)
    criterion = nn.NLLLoss()


    trainer = Trainer(encoder, decoder, encoder_optimizer, decoder_optimizer)
    
    for epoch in range(epochs):
        total_loss = 0

        # for iteration, (input_tensor, target_tensor) in tqdm.tqdm(enumerate(dataset)):
        for iteration in tqdm.tqdm(range(2000)):
            input_tensor, target_tensor = dataset[iteration]
        # for iteration, (input_tensor, target_tensor) in enumerate(dataset):
            loss = trainer.train(input_tensor, target_tensor, criterion)
            total_loss += loss

            # if iteration % print_every == 0:
            #     print(f'Epoch: {epoch} Iteration: {iteration + 1} loss: {loss}')
            #     print(f'Average loss: {total_loss / (iteration + 1)}')
        
        print(f'Average loss: {total_loss / (iteration + 1)}')

In [8]:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                     encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]

In [9]:
encoder = EncoderRNN(MAX_LENGTH, HIDDEN_SIZE).to(device)
decoder = DecoderRNN(HIDDEN_SIZE, MAX_LENGTH).to(device)

dataset = TranscriptionDataset()

In [11]:
train_loop(encoder, decoder, dataset, epochs=200, print_every=10000)

100%|██████████| 2000/2000 [02:17<00:00, 14.59it/s]


Average loss: 2.3034651519647786


100%|██████████| 2000/2000 [02:17<00:00, 14.57it/s]


Average loss: 2.5225870444136627


100%|██████████| 2000/2000 [02:17<00:00, 14.53it/s]


Average loss: 2.395883613720743


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.525039814868443


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.360507140851355


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.3857642246635873


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.5427475322770356


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.414130589088921


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.4994731715565037


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.4759080479447237


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.428460753434141


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.447542114506304


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.331624752964773


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.342333669422382


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.185845666579795


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.372403208155029


100%|██████████| 2000/2000 [02:18<00:00, 14.44it/s]


Average loss: 2.2453117839316263


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.3601454335031318


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.2504905972044256


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.2111590361057885


100%|██████████| 2000/2000 [02:18<00:00, 14.46it/s]


Average loss: 2.332442412732349


100%|██████████| 2000/2000 [02:17<00:00, 14.56it/s]


Average loss: 2.169294179479841


100%|██████████| 2000/2000 [02:17<00:00, 14.54it/s]


Average loss: 2.364306756127049


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.2464165217674954


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.3070976033412234


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.229591859787286


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.3021015557168236


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.2714569153752135


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.1390557742958354


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.2900261293464976


100%|██████████| 2000/2000 [02:17<00:00, 14.52it/s]


Average loss: 2.22864781000917


100%|██████████| 2000/2000 [02:18<00:00, 14.46it/s]


Average loss: 2.1468035140641994


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.2580036767469336


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.2575847739568866


100%|██████████| 2000/2000 [02:18<00:00, 14.44it/s]


Average loss: 2.234526935382626


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.191696726422913


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.241962191158622


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.2071105791380687


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.2981806048305886


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.102837747476466


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.286817719822199


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.1632748151026946


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.12891276791734


100%|██████████| 2000/2000 [02:17<00:00, 14.49it/s]


Average loss: 2.175682536581872


100%|██████████| 2000/2000 [02:17<00:00, 14.52it/s]


Average loss: 2.0254274581687515


100%|██████████| 2000/2000 [02:17<00:00, 14.49it/s]


Average loss: 2.147542449340012


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.103730111713144


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.1342302768532635


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.1234525407133016


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.0563322863276547


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.0708138600134522


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 1.9741104874073614


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.159168681567823


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.1136469262015694


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.111386799624271


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.080693475612453


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.0140849658838462


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.0869280303431252


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.1310571175393918


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.0960877215560068


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.0609171990005075


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.030805646231476


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 1.960446231351771


100%|██████████| 2000/2000 [02:17<00:00, 14.53it/s]


Average loss: 2.246261997196039


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.09893117423796


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.1006103181771874


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.1877556748155116


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.1188813183576256


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.1653450904127576


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.0761296023214335


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.25216533282105


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.119652530243699


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.213450346567263


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.1868404250312823


100%|██████████| 2000/2000 [02:17<00:00, 14.53it/s]


Average loss: 2.1811582371617715


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 1.9972631519378083


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.0760989374644274


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.1664603945805996


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.2119356135247394


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.196759476817828


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.2101608385468854


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.3165672098482206


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.148386419672363


100%|██████████| 2000/2000 [02:17<00:00, 14.52it/s]


Average loss: 2.228221517576299


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.2337344464013227


100%|██████████| 2000/2000 [02:17<00:00, 14.54it/s]


Average loss: 2.1307644507414882


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.3716138290855193


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.3066759843389772


100%|██████████| 2000/2000 [02:19<00:00, 14.37it/s]


Average loss: 2.3402107523125655


100%|██████████| 2000/2000 [02:20<00:00, 14.28it/s]


Average loss: 2.2097336168557833


100%|██████████| 2000/2000 [02:20<00:00, 14.19it/s]


Average loss: 2.193838867812089


100%|██████████| 2000/2000 [02:19<00:00, 14.33it/s]


Average loss: 2.328278877144132


100%|██████████| 2000/2000 [02:18<00:00, 14.46it/s]


Average loss: 2.2843104172122297


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.1773357960741286


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.29584971044097


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.220146917087927


100%|██████████| 2000/2000 [02:17<00:00, 14.49it/s]


Average loss: 2.127188187515231


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.251349239168034


100%|██████████| 2000/2000 [02:18<00:00, 14.47it/s]


Average loss: 2.0837854612310136


100%|██████████| 2000/2000 [02:18<00:00, 14.48it/s]


Average loss: 2.0978735001154383


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.1102954262276747


100%|██████████| 2000/2000 [02:17<00:00, 14.51it/s]


Average loss: 2.158436981516824


100%|██████████| 2000/2000 [02:18<00:00, 14.45it/s]


Average loss: 2.184556053154907


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 1.999648026278319


100%|██████████| 2000/2000 [02:17<00:00, 14.50it/s]


Average loss: 2.147014912944445


100%|██████████| 2000/2000 [02:18<00:00, 14.49it/s]


Average loss: 2.1478041164404917


100%|██████████| 2000/2000 [02:18<00:00, 14.42it/s]


Average loss: 2.0700904237115907


 24%|██▎       | 471/2000 [00:32<01:45, 14.50it/s]


KeyboardInterrupt: 

In [None]:
torch.save(encoder.state_dict(), 'encoder.pt')
torch.save(decoder.state_dict(), 'decoder.pt')