In [1]:
import datetime
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import urllib.request
from torch.utils.data import Dataset, DataLoader

In [2]:
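# Optional: pin this notebook to GPU index 1. Disabled: the setting below is
# commented out and is not read by any of the visible cells.
#gpu_device = 1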

In [3]:
# Download the cleaned Seattle pet-license names (a UTF-8 JSON document) and parse it.
with urllib.request.urlopen(
    "https://saturn-public-data.s3.us-east-2.amazonaws.com/examples/pytorch/seattle_pet_licenses_cleaned.json"
) as response:
    raw_json = response.read().decode("utf-8")
pet_names = json.loads(raw_json)

# Vocabulary used for encoding: index 0 ("*") is the blank/padding symbol and
# index 1 ("+") is the stop symbol; the remaining entries are the name characters.
characters = [symbol for symbol in "*+abcdefghijklmnopqrstuvwxyz-. "]
# Number of preceding characters the model sees when predicting the next one.
str_len = 8

In [4]:
def format_training_data(pet_names, device=None):
    """Turn raw names into (x, y) tensors for next-character training.

    Each name gets the stop symbol "+" appended and is expanded into all of its
    non-empty prefixes. Every prefix is cut to its last ``str_len + 1``
    characters, left-padded with the blank symbol "*", and mapped to indices in
    the module-level ``characters`` list.

    Args:
        pet_names: iterable of name strings; every character must appear in
            ``characters`` (otherwise ``list.index`` raises ``ValueError``).
        device: optional torch device on which the tensors are created.

    Returns:
        x: float tensor (num_prefixes, str_len, len(characters)), the one-hot
           encoding of the first ``str_len`` symbols of each window.
        y: integer tensor (num_prefixes, str_len) with the same window shifted
           one step to the left, i.e. the next-character targets.
    """
    window = str_len + 1

    encoded_rows = []
    for name in pet_names:
        # the stop character marks the end of the name
        terminated = name + "+"
        for end in range(1, len(terminated) + 1):
            # keep only the most recent `window` symbols of this partial name
            tail = list(terminated[:end])[-window:]
            padded = ["*"] * (window - len(tail)) + tail
            encoded_rows.append([characters.index(symbol) for symbol in padded])

    # y is the window without its first symbol, x is the window without its last
    y = torch.tensor([row[1:] for row in encoded_rows], device=device)
    x = torch.tensor([row[:-1] for row in encoded_rows], device=device)
    # only the inputs are one-hot encoded; the targets stay as class indices
    x = torch.nn.functional.one_hot(x, num_classes=len(characters)).float()
    return x, y

In [5]:
class OurDataset(Dataset):
    """Dataset over the encoded pet-name windows, read through a random permutation.

    ``x`` and ``y`` come from ``format_training_data``; item ``idx`` is looked
    up at position ``permutation[idx]``, and ``permute()`` draws a new ordering.
    """

    def __init__(self, pet_names, device=None):
        features, targets = format_training_data(pet_names, device)
        self.x = features
        self.y = targets
        # start with a random ordering so the dataset is usable immediately
        self.permute()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # translate the requested position into the shuffled row number
        source_row = self.permutation[idx]
        return self.x[source_row], self.y[source_row]

    def permute(self):
        """Draw a fresh random ordering over all rows."""
        row_count = len(self.x)
        self.permutation = torch.randperm(row_count)

In [6]:
class Model(nn.Module):
    """Character-level language model: a 4-layer LSTM followed by a linear layer.

    Inputs are one-hot vectors of width ``len(characters)`` in batch-first
    layout; the output holds one score per character for every time step.
    """

    def __init__(self):
        super().__init__()
        vocab_size = len(characters)
        self.lstm_size = 128
        self.lstm = nn.LSTM(
            input_size=vocab_size,
            hidden_size=self.lstm_size,
            num_layers=4,
            dropout=0.1,
            batch_first=True,
        )
        # projects every LSTM output vector back onto the vocabulary
        self.fc = nn.Linear(in_features=self.lstm_size, out_features=vocab_size)

    def forward(self, x):
        """Return unnormalised scores of shape (batch, seq, len(characters))."""
        # the final hidden/cell state is not needed, only the per-step outputs
        sequence_features, _ = self.lstm(x)
        return self.fc(sequence_features)

In [12]:
def train(num_epochs=8, batch_size=4096, lr=0.001, device=None, verbose=False):
    """Train a fresh ``Model`` on the module-level ``pet_names`` and return it.

    Args:
        num_epochs: number of full passes over the dataset.
        batch_size: number of windows per optimisation step.
        lr: Adam learning rate.
        device: torch device (or anything ``torch.device`` accepts). When
            omitted, the first CUDA device is used if one is available and the
            CPU otherwise, so the function no longer fails on hosts without a GPU.
        verbose: when true, print a timestamped line with the last batch loss
            after every epoch.

    Returns:
        The trained ``Model``, located on ``device``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    # The dataset builds its tensors directly on the target device, so batches
    # need no further transfer (num_workers stays 0 for the same reason).
    dataset = OurDataset(pet_names, device=device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model = Model().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # re-draw the dataset's own permutation in addition to the loader's shuffling
        dataset.permute()
        loss = None
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            batch_y_pred = model(batch_x)

            # CrossEntropyLoss wants the class axis second: (batch, classes, seq)
            loss = criterion(batch_y_pred.transpose(1, 2), batch_y)
            loss.backward()
            optimizer.step()
        if verbose and loss is not None:
            print(
                f"{datetime.datetime.now().isoformat()} - epoch {epoch} complete - loss {loss.item()}"
            )
    return model

In [None]:
%timeit -n 5000 model = train()

In [9]:
def generate_name(model, characters, str_len, max_len=30):
    """Sample one name from ``model``, one character at a time.

    The last ``str_len`` generated characters (left-padded with the blank
    symbol "*") are one-hot encoded and fed to the model; the next character is
    drawn from the softmax of the final time step, with the blank symbol
    (index 0) excluded. Sampling stops at the stop symbol "+" or once
    ``max_len`` characters have been produced.

    Args:
        model: an ``nn.Module`` mapping (1, str_len, len(characters)) one-hot
            input to (1, str_len, len(characters)) scores.
        characters: vocabulary list; index 0 is the blank, "+" is the stop symbol.
        str_len: context length the model is fed.
        max_len: maximum number of characters in the generated name.

    Returns:
        The generated name, title-cased, without the stop symbol.
    """
    # Run on whatever device the model lives on instead of assuming GPU 0.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    vocab_size = len(characters)

    # Dropout must be disabled while sampling; the previous mode is restored afterwards.
    was_training = model.training
    model.eval()

    in_progress_name = []
    next_letter = ""
    try:
        while next_letter != "+" and len(in_progress_name) < max_len:
            # prep the data to run in the model again
            context = in_progress_name[-str_len:]
            context = ["*"] * (str_len - len(context)) + context
            context_numeric = [characters.index(char) for char in context]
            context_tensor = torch.tensor(context_numeric, device=device)
            context_tensor = torch.nn.functional.one_hot(
                context_tensor, num_classes=vocab_size
            ).float()
            context_tensor = torch.unsqueeze(context_tensor, 0)

            # get the scores of each possible next character by running the model
            with torch.no_grad():
                logits = model(context_tensor)[0, -1, :]

            # Normalise in float64: np.random.choice checks that the probabilities
            # sum to 1 with a float64 tolerance, which float32 rounding can miss.
            probabilities = (
                torch.nn.functional.softmax(logits, dim=0).cpu().numpy().astype(np.float64)
            )
            # drop the blank symbol (index 0) and renormalise the rest
            probabilities = probabilities[1:]
            probabilities = probabilities / probabilities.sum()

            # determine what the actual letter is (+1 undoes the dropped blank)
            next_letter = characters[np.random.choice(vocab_size - 1, p=probabilities) + 1]
            if next_letter != "+":
                # not the stop symbol: extend the name and continue
                in_progress_name.append(next_letter)
    finally:
        model.train(was_training)

    # turn the list of characters into a single string
    return "".join(in_progress_name).title()

In [12]:
# Generate 50 names, then drop any that already exist in the training data.
# generate_name() title-cases its result, while the training names have to be
# lowercase (each of their characters is looked up in the lowercase-only
# `characters` list), so the candidate is lowercased before the comparison.
# A set makes each membership test O(1) instead of a scan over the whole list.
known_names = set(pet_names)
generated_names = [generate_name(model, characters, str_len) for _ in range(50)]
generated_names = [name for name in generated_names if name.lower() not in known_names]
print(generated_names)

['Mlyo', 'Stelsy', 'Telky', 'Mohli L Mepkeo', 'Sotpie', 'Sceoweis', 'Lakisa Sinybin', 'Bevi', 'Gumo', 'Sharlef', 'Fulbe', 'Roclee', 'Pibysha', 'Maeley', 'Vooil', 'Hoatya', 'Socan', 'Telo Bsalela', 'Dilos', 'Gannl', 'Baize', 'Chiviy', 'Begta', 'Macdlorlte', 'Ezcie', 'Feen', 'Cabiletg', 'Brikankran', 'Matsie', 'Maiopas', 'Meldee', 'Boto', 'Jufi', 'Mlova Jos', 'Tonva', 'Ruma', 'Oni', 'Tacgyslanri', 'Gabla', 'Neiga', 'Rarre', 'Asihea Deruly', 'Shoxo', 'Pepny', 'Dhoy', 'Geop', 'Gabbee', 'Sama', 'Yona', 'Geegka']
