## About

This notebook trains a character-level recurrent neural network (RNN) on a supervised dataset of question–answer pairs.

## Dataset

In this section, we load the dataset and pre-process the questions and answers that will be used to train the network.

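Define helper functions to mark the end of a sentence, encode a sentence as a list of integers, and pad a sentence to a fixed length: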

In [43]:
def add_final_sentence_character(sentence, final_character="|"):
    """Append an end-of-sentence marker to ``sentence``.

    Args:
        sentence: Value to terminate. It is formatted to text first, exactly
            as an f-string would format it, so non-string values are accepted.
        final_character: Marker appended after the sentence (``"|"`` by default).

    Returns:
        The formatted sentence immediately followed by ``final_character``.
    """
    return "{}{}".format(sentence, final_character)


def sentence2int(sentence, character2int):
    """Encode a sentence as a list of integer character ids.

    Args:
        sentence: Iterable of characters (typically a ``str``).
        character2int: Mapping from each character to its integer id.

    Returns:
        A list holding one id per character, in the original order.

    Raises:
        KeyError: If ``sentence`` contains a character that is not a key of
            ``character2int``.
    """
    encoded = []
    for char in sentence:
        encoded.append(character2int[char])
    return encoded


def pad_sentence(sentence, size, pad_character=" "):
    """Right-pad a sentence to ``size`` characters.

    Args:
        sentence: Value to pad; it is converted with ``str()`` first.
        size: Target length. Text already at least this long is returned
            unchanged (it is never truncated).
        pad_character: Single fill character (a space by default).

    Returns:
        The padded string.
    """
    text = str(sentence)
    return text.ljust(size, pad_character)

In [45]:
import pandas as pd


# Read both columns as literal text. By default pandas turns values such as
# "None", "NA" or "null" into NaN (which would then become the string "nan")
# and parses numeric-looking answers as numbers (e.g. "007" -> 7).
dataset = pd.read_csv(
    "../dataset/supervised.csv",
    header=0,
    dtype={"question": str, "answer": str},
    keep_default_na=False,
)
questions = [add_final_sentence_character(x) for x in dataset["question"]]
answers = [add_final_sentence_character(x) for x in dataset["answer"]]

# Build the vocabulary in a fixed (sorted) order: iterating a set of strings
# follows Python's hash randomisation, so the character -> id mapping would
# otherwise change from one kernel session to the next. The space is added
# explicitly because sentences are later padded with it (pad_sentence's
# default pad character), and every padded character must be encodable.
unique_characters = sorted(set("".join(questions + answers)) | {" "})
unique_characters_length = len(unique_characters)

character2int = {character: i for i, character in enumerate(unique_characters)}
# Derive the inverse directly from character2int so the two always agree.
int2character = {i: character for character, i in character2int.items()}

In [46]:
# Longest question / answer length (the trailing "|" marker included), used as
# the padded length so that all questions share one length and all answers
# share another.
longer_question_length = max(len(question) for question in questions)
longer_answer_length = max(len(answer) for answer in answers)

questions = [
    pad_sentence(question, longer_question_length) for question in questions
]
answers = [
    pad_sentence(answer, longer_answer_length) for answer in answers
]

## Network

This section defines the network architecture that will be used.

In [None]:
import torch
import torch.nn as nn


class RecurrentNetwork:
    def __init__(self, input_size, hidden_dim_rnn, n_layers_rnn, output_size):
        super(RecurrentNetwork, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_dim_rnn = hidden_dim_rnn
        self.n_layers_rnn = n_layers_rnn

        self.rnn = nn.RNN(
            input_size,
            hidden_dim_rnn,
            n_layers_rnn,
            batch_first=True,
            nonlinearity="relu"
        )
        self.fc = nn.Linear(hidden_dim_rnn, output_size)

    def forward(self, x, hidden_state):
        