In [1]:
import torch
from typing import Tuple, List, Optional

import unicodedata
import matplotlib.pyplot as plt

def unicode_to_ascii(name: str, letters: str) -> str:
    return "".join(
        c for c in unicodedata.normalize("NFKD", name)
        if not unicodedata.combining(c) and c in letters
    )

def name_to_tensor(name: str, letters: str, n_letters: int) -> torch.Tensor:
    name = unicode_to_ascii(name, letters)
    name_tensor = torch.zeros(len(name), 1, n_letters)
    for i, char in enumerate(name):
        name_tensor[i, 0, letters.index(char)] = 1
    return name_tensor

def output_category(output: torch.Tensor, categories: list) -> Tuple[str, int]:
    return categories[torch.argmax(output, dim=1).item()]

def plot_losses(losses: List[float], save_path: Optional[str] = None) -> None:
    plt.plot(losses, label="Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.legend()
    if save_path:
        plt.savefig(save_path)
    plt.show()

In [2]:
import glob
from typing import List, Dict

from utils import unicode_to_ascii

def get_names(file: str, letters: str) -> List[str]:
    names = open(file, encoding="utf-8").read().strip().split("\n")
    return [unicode_to_ascii(name, letters) for name in names]

def get_data(path: str, letters: str) -> Dict[str, List[str]]:
    files = glob.glob(path)
    category_name_dict = {}
    for filename in files:
        category = filename.split("\\")[-1].split('.')[0]
        category_name_dict[category] = get_names(filename, letters)
    categories = list(category_name_dict.keys())
    n_letters = len(letters)
    n_categories = len(categories)
    return category_name_dict, categories, n_letters, n_categories

In [3]:
import string

path = "./data/names/*.txt"
letters = string.ascii_letters + " .,;'"

category_name_dict, categories, n_letters, n_categories = get_data(path, letters)
n_hidden = 128

In [4]:
import random

def random_choice(cat_list: list) -> str:
    return cat_list[random.randint(0, len(cat_list)-1)]

def random_sample():
    category = random_choice(categories)
    name = random_choice(category_name_dict[category])
    category_tensor = torch.tensor([categories.index(category)], dtype=torch.long)
    name_tensor = name_to_tensor(name, letters, n_letters)
    return name_tensor, category_tensor, name, category

In [5]:
def category_tensor(category):
    idx = categories.index(category)
    cat_tensor = torch.zeros(1, n_categories)
    cat_tensor[0][idx] = 1
    return cat_tensor

In [6]:
category_tensor("Irish")

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [7]:
category_tensor = torch.tensor([categories.index("Irish")], dtype=torch.long)
category_tensor

tensor([8])

In [8]:
def target_tensor(name):
    letter_indexes = [letters.find(name[i]) for i in range(1, len(name))]
    print(letter_indexes)
    letter_indexes.append(n_letters - 1) # EOS
    return torch.LongTensor(letter_indexes)

In [9]:
nt = name_to_tensor("ngt", letters, n_letters)
nt.shape

torch.Size([3, 1, 57])

In [10]:
nt

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]]])

In [11]:
tt = target_tensor("ngt")
tt

[6, 19]


tensor([ 6, 19, 56])

In [12]:
tt.shape

torch.Size([3])

In [13]:
tt.unsqueeze_(-1)

tensor([[ 6],
        [19],
        [56]])

In [14]:
tt.shape

torch.Size([3, 1])