In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

---

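## RNN model

The RNN reads a name one letter at a time, carrying its hidden state forward; the log-probabilities it produces after the last letter are used to predict the name's category.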

![rnn name classification model](../assets/rnn-name-classification.png)

In [None]:
class RNN(nn.Module):
    """Recurrent cell implemented from scratch (rather than using ``nn.RNN``).

    At every time step the current input and the previous hidden state are
    concatenated and passed through two linear layers: ``i2h`` produces the
    next hidden state and ``i2o`` produces class scores, which ``LogSoftmax``
    turns into log-probabilities (the form ``nn.NLLLoss`` expects).

    Args:
        input_size: number of features per time step.
        hidden_size: size of the hidden state vector.
        output_size: number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden_tensor):
        """Advance the RNN by one time step.

        Args:
            input_tensor: input for this step; assumed ``(batch, input_size)``
                (TODO confirm against ``letter_to_tensor``).
            hidden_tensor: previous hidden state, ``(batch, hidden_size)``.

        Returns:
            ``(output, hidden)``: log-probabilities ``(batch, output_size)`` and
            the next hidden state ``(batch, hidden_size)``.
        """
        # Concatenate along the feature axis so both layers see input + memory.
        combined = torch.cat((input_tensor, hidden_tensor), 1)

        # NOTE: no nonlinearity is applied to the new hidden state, so the
        # recurrence is linear in the hidden state (kept to preserve behavior).
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state of shape ``(batch_size, hidden_size)``.

        The tensor is created on the same device and with the same dtype as
        the model's parameters, so it keeps working after ``rnn.to(device)``.
        """
        weight = self.i2h.weight
        return torch.zeros(
            batch_size, self.hidden_size, device=weight.device, dtype=weight.dtype
        )

In [None]:
from utils import (
    ALL_LETTERS,
    N_LETTERS,
    letter_to_tensor,
    line_to_tensor,
    load_data,
    random_training_example,
)

# Load the dataset. `all_categories` is the ordered label list; its order
# defines which model output index maps to which category. `category_lines`
# is presumably a category -> names mapping (verify against utils.load_data).
category_lines, all_categories = load_data()
n_categories = len(all_categories)

# Quick sanity summary of the dataset and vocabulary.
for label, value in (
    ("n_categories:", n_categories),
    ("n_letters:", N_LETTERS),
    ("all_categories:", all_categories),
):
    print(label, value)

In [None]:
import random

# Fix the random seeds so weight initialisation is reproducible across
# Restart & Run All. Seeding Python's `random` as well presumably also fixes
# the sampling in `random_training_example` (verify in utils).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

n_hidden = 128  # size of the RNN hidden state
rnn = RNN(N_LETTERS, n_hidden, n_categories)

In [None]:
# One step: feed a single letter ("A") together with a fresh all-zero hidden
# state through the RNN, to inspect the shapes of the output and next hidden state.
input_tensor = letter_to_tensor("A")
hidden_tensor = rnn.init_hidden()

output, next_hidden = rnn(input_tensor, hidden_tensor)
# The `=` specifier prints the expression text followed by its value.
print(f"{output.size()=}")
print(f"{next_hidden.size()=}")

In [None]:
# Whole sequence/name: feed every letter of "Albert" through the RNN, carrying
# the hidden state forward step by step (the same loop `train` and `predict`
# use), so `output` reflects the entire name rather than just its first letter.
input_tensor = line_to_tensor("Albert")
hidden_tensor = rnn.init_hidden()
print(f"{input_tensor.size()=}")

next_hidden = hidden_tensor
# Iterating a tensor yields its slices along dim 0, i.e. one time step each.
for letter_tensor in input_tensor:
    output, next_hidden = rnn(letter_tensor, next_hidden)
print(output.size())
print(next_hidden.size())

In [None]:
def category_from_output(output, categories=None):
    """Map the model's scores to the highest-scoring category label.

    Args:
        output: score / log-probability tensor for ONE example, e.g. shape
            ``(1, n_categories)``. The arg-max is taken over all elements.
        categories: label sequence indexed by output position; defaults to the
            notebook-level ``all_categories``.

    Returns:
        The label at the arg-max position.

    Raises:
        ValueError: if the number of scores does not match the number of
            labels (e.g. a batched output), which would otherwise yield an
            IndexError or a silently wrong label.
    """
    if categories is None:
        categories = all_categories
    if output.numel() != len(categories):
        raise ValueError(
            f"expected {len(categories)} scores for a single example, "
            f"got tensor of shape {tuple(output.shape)}"
        )
    category_idx = torch.argmax(output).item()
    return categories[category_idx]


# Decode the `output` produced by the previous cell into a category label.
# The model has not been trained yet at this point, so the guess is essentially arbitrary.
print(category_from_output(output))

In [None]:
# Plain SGD with a fixed step size. The loss is negative log-likelihood, which
# pairs with the LogSoftmax log-probabilities returned by the model.
learning_rate = 0.005
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(params=rnn.parameters(), lr=learning_rate)

In [None]:
def train(line_tensor, category_tensor):
    """Run one SGD step of the global ``rnn`` on a single (name, category) example.

    The name is fed one time step at a time, carrying the hidden state forward;
    only the output after the last step is scored against the target.

    Args:
        line_tensor: per-letter inputs indexed along dim 0 by time step
            (as produced by ``line_to_tensor``).
        category_tensor: target class index tensor accepted by ``criterion``
            (``nn.NLLLoss``).

    Returns:
        ``(output, loss)``: the final-step log-probabilities and the loss value
        as a Python float.

    Raises:
        ValueError: if ``line_tensor`` has no time steps.
    """
    n_steps = line_tensor.size(0)
    if n_steps == 0:
        raise ValueError("line_tensor must contain at least one time step")

    # Ensure train mode in case an earlier cell (e.g. `predict`) left the model in eval mode.
    rnn.train()
    hidden = rnn.init_hidden()
    for step in range(n_steps):
        output, hidden = rnn(line_tensor[step], hidden)

    loss = criterion(output, category_tensor)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # backpropagate through every time step
    optimizer.step()       # apply the SGD update using those gradients

    return output, loss.item()

In [None]:
from tqdm.auto import tqdm  # uses the notebook progress widget when available

current_loss = 0
all_losses = []
plot_steps, print_steps = 1000, 5000  # average loss every 1k iters; report every 5k
n_iters = 100000
correct_count = 0
incorrect_count = 0

for i in tqdm(range(n_iters)):
    category, line, category_tensor, line_tensor = random_training_example(
        category_lines, all_categories
    )

    output, loss = train(line_tensor, category_tensor)
    current_loss += loss

    guess = category_from_output(output)
    if guess == category:
        correct_count += 1
    else:
        incorrect_count += 1

    # Record the mean per-example loss over the last `plot_steps` iterations.
    if (i + 1) % plot_steps == 0:
        all_losses.append(current_loss / plot_steps)
        current_loss = 0

    if (i + 1) % print_steps == 0:
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write(f"\n\n{correct_count=}; {incorrect_count=}\n")
        correct_count = 0
        incorrect_count = 0
        # `guess` was already computed for this example above; reuse it.
        correct = "CORRECT" if guess == category else f"WRONG ({category})"
        tqdm.write(f"{i+1} {(i+1)/n_iters*100} {loss:.4f} {line} / {guess} {correct}")
        tqdm.write("\n================================================================\n")
        print("\n================================================================\n")

In [None]:
# Mean NLL loss per example, averaged over each window of `plot_steps` iterations.
fig, ax = plt.subplots()
ax.plot(all_losses)
ax.set(
    title="Training loss",
    xlabel=f"Iteration (x{plot_steps})",
    ylabel="Mean NLL loss",
)
plt.show()

In [None]:
def predict(input_line):
    """Print and return the model's predicted category for a single name.

    Runs the global ``rnn`` over the name letter by letter without tracking
    gradients, then decodes the last output with ``category_from_output``.
    The model's previous train/eval mode is restored afterwards.

    Args:
        input_line: the name to classify.

    Returns:
        The predicted category label, or ``None`` if ``line_to_tensor``
        produced no time steps (e.g. for an empty string).
    """
    print(f"\n> {input_line}")
    was_training = rnn.training
    rnn.eval()
    try:
        with torch.no_grad():
            line_tensor = line_to_tensor(input_line)
            n_steps = line_tensor.size(0)
            if n_steps == 0:
                # Nothing to feed the RNN; avoids an UnboundLocalError on `output`.
                print("(empty input - nothing to classify)")
                return None

            hidden = rnn.init_hidden()
            for step in range(n_steps):
                output, hidden = rnn(line_tensor[step], hidden)

            guess = category_from_output(output)
            print(guess)
            return guess
    finally:
        # Restore the caller's mode so later training is not silently left in eval mode.
        rnn.train(was_training)

In [None]:
# Interactive demo: classify names typed by the user until "quit" is entered.
# NOTE: `input()` blocks waiting for user input, so Restart & Run All pauses here.
while True:
    try:
        sentence = input("Input:")
    except (EOFError, KeyboardInterrupt):
        # stdin closed or interrupted: leave the loop cleanly instead of raising.
        break

    sentence = sentence.strip()
    if sentence.lower() == "quit":
        break
    if not sentence:
        continue  # ignore blank lines

    predict(sentence)