In [1]:
# Load the raw corpus: every whitespace-separated token in the file is one name.
# Explicit UTF-8 so decoding does not depend on the machine's default locale encoding.
with open('data/makemore/baby_names.txt', encoding='utf-8') as f:
    names = f.read().split()
names[:5]

['Mary', 'Annie', 'Mattie', 'Ruby', 'Willie']

In [2]:
# Character length of the longest name; encode_string pads every name up to this length.
longest_name = max(len(name) for name in names)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Vocabulary: the four special tokens (ids 0-3) followed by every distinct character
# that occurs in the corpus, sorted so the id assignment is deterministic.
vocab = ['<s>', '</s>', '<pad>', '<unk>'] + sorted({char for name in names for char in name})

# Lookup tables in both directions (list position == token id).
char2idx = {token: position for position, token in enumerate(vocab)}
idx2char = dict(enumerate(vocab))

print(char2idx)
print(idx2char)
print(len(vocab))

def encode_string(s: str, include_start=True, include_end=True, include_pad=True) -> torch.Tensor:
    """Encode ``s`` as a 1-D LongTensor of vocabulary ids.

    Characters missing from ``char2idx`` map to ``<unk>``. Optionally the ids are
    wrapped in ``<s>`` / ``</s>`` and right-padded with ``<pad>``, so that every
    padded encoding has length ``longest_name + include_start + include_end``.

    Raises:
        ValueError: if ``include_pad`` is set and ``s`` is longer than
            ``longest_name``. A negative pad amount would make ``F.pad`` crop the
            tail of the sequence, silently dropping characters and ``</s>``.
    """
    unk_id = char2idx['<unk>']
    ids = [char2idx.get(char, unk_id) for char in s]

    if include_start:
        ids.insert(0, char2idx['<s>'])

    if include_end:
        ids.append(char2idx['</s>'])

    # Build the integer tensor directly instead of going through float32.
    result = torch.tensor(ids, dtype=torch.long)

    if include_pad:
        pad_amount = longest_name - len(s)
        if pad_amount < 0:
            raise ValueError(
                f"{s!r} has {len(s)} characters, more than longest_name={longest_name}; "
                "padding would truncate it"
            )
        result = F.pad(result, (0, pad_amount), value=char2idx['<pad>'])

    return result

def decode_string(indices: torch.Tensor) -> str:
    """Convert a 1-D tensor of token ids back into text.

    ``<s>`` and ``<pad>`` tokens are skipped. Decoding stops at the first ``</s>``,
    so anything a sampler emitted after the end-of-name token is discarded. Ids
    with no entry in ``idx2char`` (e.g. from a model head wider than the
    vocabulary) decode as ``'<unk>'`` instead of raising ``KeyError``.
    """
    chars = []
    for idx in indices.tolist():
        token = idx2char.get(idx, '<unk>')
        if token == '</s>':
            break
        if token in ('<s>', '<pad>'):
            continue
        chars.append(token)
    return ''.join(chars)

def encode_batch(strings: list[str]) -> torch.Tensor:
    """Encode every string with ``encode_string``'s defaults and stack the rows.

    Returns a ``(len(strings), L)`` tensor; ``torch.stack`` requires every encoded
    row to have the same length ``L``.
    """
    encoded_rows = list(map(encode_string, strings))
    return torch.stack(encoded_rows)

def decode_batch(batch):
    """Decode each row of ``batch`` (iterated along its first dimension) into a string."""
    return list(map(decode_string, batch))

# Pre-encode the whole corpus once: one fixed-length row of token ids per name.
encoded_corpus = encode_batch(strings=names)
encoded_corpus.size()

{'<s>': 0, '</s>': 1, '<pad>': 2, '<unk>': 3, 'A': 4, 'B': 5, 'C': 6, 'D': 7, 'E': 8, 'F': 9, 'G': 10, 'H': 11, 'I': 12, 'J': 13, 'K': 14, 'L': 15, 'M': 16, 'N': 17, 'O': 18, 'P': 19, 'Q': 20, 'R': 21, 'S': 22, 'T': 23, 'U': 24, 'V': 25, 'W': 26, 'X': 27, 'Y': 28, 'Z': 29, 'a': 30, 'b': 31, 'c': 32, 'd': 33, 'e': 34, 'f': 35, 'g': 36, 'h': 37, 'i': 38, 'j': 39, 'k': 40, 'l': 41, 'm': 42, 'n': 43, 'o': 44, 'p': 45, 'q': 46, 'r': 47, 's': 48, 't': 49, 'u': 50, 'v': 51, 'w': 52, 'x': 53, 'y': 54, 'z': 55}
{0: '<s>', 1: '</s>', 2: '<pad>', 3: '<unk>', 4: 'A', 5: 'B', 6: 'C', 7: 'D', 8: 'E', 9: 'F', 10: 'G', 11: 'H', 12: 'I', 13: 'J', 14: 'K', 15: 'L', 16: 'M', 17: 'N', 18: 'O', 19: 'P', 20: 'Q', 21: 'R', 22: 'S', 23: 'T', 24: 'U', 25: 'V', 26: 'W', 27: 'X', 28: 'Y', 29: 'Z', 30: 'a', 31: 'b', 32: 'c', 33: 'd', 34: 'e', 35: 'f', 36: 'g', 37: 'h', 38: 'i', 39: 'j', 40: 'k', 41: 'l', 42: 'm', 43: 'n', 44: 'o', 45: 'p', 46: 'q', 47: 'r', 48: 's', 49: 't', 50: 'u', 51: 'v', 52: 'w', 53: 'x', 54

torch.Size([890627, 15])

In [5]:
# Save Vocab -- stored as a JSON list because list position is the token id.
import json

with open('projects/3-makemore/vocab.json', 'w') as vocab_file:
    json.dump(vocab, vocab_file)

# Save Encoded Corpus -- lets a later session skip re-encoding every name.
torch.save(encoded_corpus, 'projects/3-makemore/encoded_corpus.pt')

In [6]:
# Load Vocab
import json
with open('projects/3-makemore/vocab.json') as f:
    vocab = json.load(f)

# Rebuild both lookup tables from the saved order (list position == token id).
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for idx, char in enumerate(vocab)}

# Load Encoded Corpus
# The file only holds a tensor, so restrict unpickling to tensors/primitive
# containers rather than allowing arbitrary pickled objects.
encoded_corpus = torch.load('projects/3-makemore/encoded_corpus.pt', weights_only=True)

In [7]:
# Hold out the last 20% of rows as the test set.
# NOTE(review): the split is contiguous, not random -- if the corpus file is ordered
# (e.g. by year or popularity), train and test come from different distributions.
split_at = int(0.8 * len(encoded_corpus))
train_data, test_data = encoded_corpus[:split_at], encoded_corpus[split_at:]

# Only the training batches are shuffled; test batches keep file order.
train_loader, test_loader = (
    torch.utils.data.DataLoader(data, batch_size=256, shuffle=shuffle)
    for data, shuffle in ((train_data, True), (test_data, False))
)

In [8]:
class RNN(nn.Module):
    """Character-level language model: embedding -> stacked ``nn.RNN`` -> linear head.

    Inputs are token-id tensors shaped ``(batch, seq)`` (the RNN uses
    ``batch_first=True``); ``forward`` returns per-position logits of shape
    ``(batch, seq, vocab_size)`` together with the RNN's final hidden state.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        nonlinearity: str = "relu",
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            nonlinearity=nonlinearity,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for token ids ``x``.

        ``hidden=None`` lets ``nn.RNN`` use its default initial state; pass the
        hidden state from a previous call to continue a sequence.
        """
        x = self.embedding(x)
        x, hidden = self.rnn(x, hidden)
        x = self.linear(x)
        return x, hidden

    @torch.no_grad()
    def generate(self, context: torch.Tensor, hidden=None, length=100, temperature=1.0):
        """Autoregressively sample ``length`` tokens after ``context``.

        Args:
            context: ``(batch, prompt_len)`` tensor of token ids to condition on.
            hidden: optional initial hidden state for the prompt.
            length: number of tokens to sample.
            temperature: logits are divided by this before the softmax; values
                below 1 sharpen the distribution, above 1 flatten it.

        Returns:
            ``(batch, prompt_len + length)`` tensor: the prompt followed by the samples.

        The prompt is run through the RNN once. After that, only the newly sampled
        token is fed together with the carried hidden state; re-feeding the whole
        sequence with that state would make the model read the prefix twice.
        The module's train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        try:
            tokens = context
            step_input = context
            for _ in range(length):
                logits, hidden = self(step_input, hidden)
                probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                tokens = torch.cat([tokens, next_token], dim=-1)
                step_input = next_token
        finally:
            self.train(was_training)
        return tokens
    

# Model hyperparameters.
# NOTE(review): vocab_size=256 is hard-coded and larger than len(vocab), so the
# extra output classes have no entry in idx2char. If this is switched to
# len(vocab), check that every loss reshape takes the class count from the model
# output rather than a literal 256.
MODEL_CONFIG = dict(
    vocab_size=256,
    embedding_dim=128,
    hidden_dim=64,
    num_layers=4,
    dropout=0.1,
)
model = RNN(**MODEL_CONFIG)

n_parameters = sum(param.numel() for param in model.parameters())
print(f"Model has {n_parameters:,} parameters")

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# Multiplies the learning rate by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

Model has 86,784 parameters


In [9]:
import tqdm.notebook as tqdm

In [10]:
@torch.no_grad()
def evaluate(model, loader, loss_fn=None):
    """Return the mean next-token loss of ``model`` over every batch in ``loader``.

    Each batch is a tensor of token ids; the model's outputs at positions
    ``[:, :-1]`` are scored against the targets ``batch[:, 1:]``. Per-batch losses
    are weighted by the number of scored tokens, so a smaller final batch does not
    skew the average. The model's train/eval mode is restored on exit.

    Args:
        model: module whose forward returns ``(logits, hidden)``.
        loader: iterable of batches (e.g. a ``DataLoader``).
        loss_fn: mean-reduced loss; defaults to the notebook-level ``criterion``.

    Returns:
        The token-weighted mean loss, or ``nan`` if ``loader`` yields no scored tokens.
    """
    if loss_fn is None:
        loss_fn = criterion
    # Targets equal to ignore_index are excluded from a mean-reduced CrossEntropyLoss,
    # so they must not count towards the weights either.
    ignore_index = getattr(loss_fn, "ignore_index", -100)

    was_training = model.training
    model.eval()
    try:
        total_loss = 0.0
        total_tokens = 0
        for batch in loader:
            output, _ = model(batch)
            # Class count comes from the model output instead of a hard-coded 256.
            logits = output[:, :-1].reshape(-1, output.size(-1))
            targets = batch[:, 1:].reshape(-1)
            n_tokens = int((targets != ignore_index).sum())
            if n_tokens == 0:
                continue
            total_loss += loss_fn(logits, targets).item() * n_tokens
            total_tokens += n_tokens
    finally:
        model.train(was_training)
    return total_loss / total_tokens if total_tokens else float("nan")

# Test-set loss of the model as it stands here, before the training loop below runs.
evaluate(model, loader=test_loader)

5.489738841166441

In [11]:
from ema_pytorch import EMA
# Exponential moving average of the model's weights, beta=0.99.
# NOTE(review): `ema.update()` is commented out in the training loop, so this
# average is presumably never advanced past its initial weights.
EMA_BETA = 0.99
ema = EMA(model, beta=EMA_BETA)

In [17]:
# Training loop
# One pass over train_loader per epoch. The test loss in the progress bar is from the
# most recent evaluation: first the model as it enters this cell, then after each epoch.
pbar = tqdm.tqdm(range(20))
test_loss = evaluate(model, test_loader)
for epoch in pbar:
    pbar.set_description(f"Epoch {epoch}")
    model.train()
    for seq in tqdm.tqdm(train_loader, leave=True, desc=f"Batches for epoch {epoch}"):
        # Next-token objective: inputs are tokens [0, T-1), targets are tokens [1, T).
        x = seq[:, :-1]
        y = seq[:, 1:]

        optimizer.zero_grad()
        output, _ = model(x)
        # Flatten (batch, time, classes) -> (batch * time, classes); the class count
        # comes from the model output instead of a hard-coded 256.
        loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()

        # ema.update()

        pbar.set_postfix_str(f"Loss: {loss.item():.4f}, Test Loss: {test_loss:.4f}")

    # StepLR(step_size=1) is meant to be stepped once per epoch; without this call
    # the learning rate never decays.
    scheduler.step()
    # Score the weights this epoch produced (including the final epoch).
    test_loss = evaluate(model, test_loader)
    pbar.set_postfix_str(f"Test Loss: {test_loss:.4f}")

  0%|          | 0/20 [00:00<?, ?it/s]

Batches for epoch 0:   0%|          | 0/2784 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [26]:
# Generate a new name
# Prompt with a bare <s> token (no </s>, no padding) for 5 parallel samples, and allow
# longest_name + 1 steps: enough for the longest training name plus its </s> token
# (a fixed length=10 would cut longer samples short).
context = encode_string('', include_end=False, include_pad=False).repeat(5, 1)

decode_batch(model.generate(context, length=longest_name + 1))


['Ste', 'Tus', 'An', 'Ala', 'Sha']

In [14]:
# Save Model -- only the state_dict tensors are written; rebuild RNN(...) with the
# same hyperparameters before loading them back.
torch.save(obj=model.state_dict(), f='projects/3-makemore/model.pt')