In [1]:
import random
import polars as pl

df = pl.read_csv("../data/all_states.csv").select(["Sex", "Name"]).unique()

names = df["Name"].to_list()
gender = df["Sex"].to_list()

gender = list(map(lambda x: str(0) if x == "M" else str(1), gender))

random.shuffle(names)

train_names = names[: int(len(names) * 0.8)]
test_names = names[int(len(names) * 0.8) :]

train_gender = gender[: int(len(gender) * 0.8)]
test_gender = gender[int(len(gender) * 0.8) :]

assert len(train_names) == len(
    train_gender
), "Number of training names do not match number of training genders."
assert len(test_names) == len(
    test_gender
), "Number of testing names do not match number of testing genders."

with open("../data/train_names.txt", "w") as f:
    f.write(
        "\n".join(
            list(map(lambda x: "".join(x), zip(train_gender, train_names)))
        ).strip()
    )

with open("../data/test_names.txt", "w") as f:
    f.write(
        "\n".join(list(map(lambda x: "".join(x), zip(test_gender, test_names)))).strip()
    )

print(
    f"Training with {len(train_names):,} names, with the longest name being {max(len(name) for name in train_names)} characters long."
)
print(
    f"Testing with {len(test_names):,} names, with the longest name being {max(len(name) for name in test_names)} characters long."
)

Training with 27,460 names, with the longest name being 15 characters long.
Testing with 6,865 names, with the longest name being 15 characters long.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from tqdm import tqdm

In [3]:
train_names = open("../data/train_names.txt").read().split()
test_names = open("../data/test_names.txt").read().split()

In [4]:
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>", "0", "1"]
tokens = special_tokens + list("abcdefghijklmnopqrstuvwxyz")
char_to_idx = {char: idx for idx, char in enumerate(tokens)}
idx_to_char = {idx: char for idx, char in enumerate(tokens)}


def encode(
    name: str, add_special_tokens: bool = True, pad_to_length: int = 32
) -> list[int]:
    name = name.lower()
    encoded_name = [char_to_idx[char] for char in name]
    if add_special_tokens:
        encoded_name = [char_to_idx["<sos>"]] + encoded_name + [char_to_idx["<eos>"]]
    if pad_to_length:
        encoded_name += [char_to_idx["<pad>"]] * (pad_to_length - len(encoded_name))
    return encoded_name


def decode(encoded_name: list[int], strip_special_tokens: bool = True) -> str:
    if strip_special_tokens:
        encoded_name = [
            idx
            for idx in encoded_name
            if idx
            not in [char_to_idx["<sos>"], char_to_idx["<eos>"], char_to_idx["<pad>"]]
        ]
    return "".join([idx_to_char[idx] for idx in encoded_name])


def encode_batch(
    names: list[str], add_special_tokens: bool = True, pad_to_length: int = 32
) -> torch.Tensor:
    return torch.tensor(
        [encode(name, add_special_tokens, pad_to_length) for name in names]
    )


def decode_batch(
    encoded_names: torch.Tensor, strip_special_tokens: bool = True
) -> list[str]:
    return [
        decode(encoded_name.tolist(), strip_special_tokens)
        for encoded_name in encoded_names
    ]

In [5]:
encoded_train = encode_batch(train_names, add_special_tokens=True, pad_to_length=24)
encoded_test = encode_batch(test_names, add_special_tokens=True, pad_to_length=24)
train_dataset = TensorDataset(encoded_train)
test_dataset = TensorDataset(encoded_test)
train_loader = DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True
)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=4, persistent_workers=True
)

In [6]:
from name_generator import Model, ModelConfig


model = Model(
    ModelConfig(
        vocab_size=len(tokens),
        embedding_dim=48,
        num_layers=6,
        max_length=24,  # not padding to nearest 32 because max length of names is 17 - bump this for `theoretically` better performance
        q_heads=12,
        kv_heads=4,
        m=4,
        tie_weights=False,
    )
)
optimizer = optim.AdamW(model.parameters(), lr=4e-4)
device = torch.device("mps")
model.to(device)
summary(
    model,
    dtypes=[torch.long],
    device=device,
    input_data=(
        next(iter(train_loader))[0].to(device),
        (next(iter(train_loader))[0] == 0).to(device),
    ),
)

Layer (type:depth-idx)                        Output Shape              Param #
Model                                         [128, 24, 32]             1,152
├─Embedding: 1-1                              [128, 24, 48]             1,536
├─ModuleList: 1-2                             --                        --
│    └─TransformerBlock: 2-1                  [128, 24, 48]             --
│    │    └─LayerNorm: 3-1                    [128, 24, 48]             96
│    │    └─CausalSelfAttention: 3-2          [128, 24, 48]             6,272
│    │    └─LayerNorm: 3-3                    [128, 24, 48]             96
│    │    └─MLP: 3-4                          [128, 24, 48]             18,672
│    └─TransformerBlock: 2-2                  [128, 24, 48]             --
│    │    └─LayerNorm: 3-5                    [128, 24, 48]             96
│    │    └─CausalSelfAttention: 3-6          [128, 24, 48]             6,272
│    │    └─LayerNorm: 3-7                    [128, 24, 48]             96
│   

In [7]:
test_loss = 0
for epoch in range(10):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=True)
    for (batch,) in pbar:
        batch = batch.to(device)

        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        optimizer.zero_grad()
        output = model(inputs, key_padding_mask=(inputs == char_to_idx["<pad>"]))
        loss = nn.functional.cross_entropy(
            output.reshape(-1, len(tokens)), targets.reshape(-1)
        )
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item(), test_loss=test_loss)

    with torch.no_grad():
        model.eval()
        test_loss = 0
        for (batch,) in test_loader:
            batch = batch.to(device)

            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            output = model(inputs, key_padding_mask=(inputs == char_to_idx["<pad>"]))
            loss = nn.functional.cross_entropy(
                output.reshape(-1, len(tokens)), targets.reshape(-1)
            )
            test_loss += loss.item()
        test_loss /= len(test_loader)

Epoch 1: 100%|██████████| 215/215 [00:09<00:00, 23.66it/s, loss=0.809, test_loss=0]
Epoch 2: 100%|██████████| 215/215 [00:08<00:00, 26.17it/s, loss=0.776, test_loss=0.773]
Epoch 3: 100%|██████████| 215/215 [00:08<00:00, 26.41it/s, loss=0.724, test_loss=0.749]
Epoch 4: 100%|██████████| 215/215 [00:08<00:00, 26.33it/s, loss=0.725, test_loss=0.736]
Epoch 5: 100%|██████████| 215/215 [00:08<00:00, 26.48it/s, loss=0.737, test_loss=0.724]
Epoch 6: 100%|██████████| 215/215 [00:08<00:00, 26.47it/s, loss=0.703, test_loss=0.716]
Epoch 7: 100%|██████████| 215/215 [00:08<00:00, 26.25it/s, loss=0.723, test_loss=0.708]
Epoch 8: 100%|██████████| 215/215 [00:08<00:00, 25.83it/s, loss=0.754, test_loss=0.702]
Epoch 9: 100%|██████████| 215/215 [00:08<00:00, 24.79it/s, loss=0.689, test_loss=0.698]
Epoch 10: 100%|██████████| 215/215 [00:08<00:00, 24.91it/s, loss=0.679, test_loss=0.693]


In [8]:
@torch.no_grad()
def generate_names(n=16):
    model.eval()
    genders = torch.cat(
        [
            torch.tensor([[char_to_idx["0"]]]).repeat(n // 2, 1),
            torch.tensor([[char_to_idx["1"]]]).repeat(n // 2, 1),
        ],
        dim=0,
    )
    start_token = torch.tensor([[char_to_idx["<sos>"]]]).repeat(n, 1)
    start_token = torch.cat([start_token, genders], dim=1).to(device)

    generated = start_token
    for _ in range(20):
        output = model(generated) / 0.6

        token = torch.multinomial(F.softmax(output[:, -1], dim=1), 1)

        generated = torch.cat([generated, token], dim=1)

        if token.all() == char_to_idx["<pad>"]:
            break

    return decode_batch(generated, strip_special_tokens=True)


names = generate_names()
names = [name[1:].capitalize() for name in names]
print("Male names:", *names[:8], sep="\n    ", end="\n\n")
print("Female names:", *names[8:], sep="\n    ")

Male names:
    Karin
    Lei
    Elizi
    Kanic
    Peank
    Elena
    Aliya
    Anasl

Female names:
    Landa
    Berne
    Samar
    Harmo
    Amani
    Harmo
    Taden
    Shant


In [9]:
torch.save(model.state_dict(), "../model/model.pth")