In [None]:
# Read the corpus as a flat list of whitespace-separated names.
# The encoding is pinned so the result does not depend on the platform's
# locale default (open() without `encoding` uses the locale encoding).
with open('data/makemore/baby_names.txt', encoding='utf-8') as f:
    names = f.read().split()
names[:5]

In [None]:
# Length (in characters) of the longest name; used as the padding target.
longest_name = max(map(len, names))

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Vocabulary: four special tokens first, then every distinct character that
# occurs in the corpus, in sorted order.
vocab = ['<s>', '</s>', '<pad>', '<unk>'] + sorted(set(''.join(names)))

# Lookup tables in both directions (index <-> token).
idx2char = dict(enumerate(vocab))
char2idx = {char: idx for idx, char in idx2char.items()}

print(char2idx)
print(idx2char)
print(len(vocab))

def encode_string(s: str, include_start=True, include_end=True, include_pad=True) -> torch.Tensor:
    """Convert a string into a 1-D int64 tensor of vocabulary indices.

    Characters missing from ``char2idx`` map to ``<unk>``.

    Args:
        s: Text to encode.
        include_start: Prepend the ``<s>`` token.
        include_end: Append the ``</s>`` token.
        include_pad: Right-pad with ``<pad>`` so that the character part always
            occupies ``longest_name`` positions. Strings longer than
            ``longest_name`` are truncated to that many characters first, so the
            output length stays fixed and the ``</s>`` token is kept.

    Returns:
        A ``torch.long`` tensor. With padding enabled its length is
        ``longest_name`` plus one for each of the start/end tokens requested.
    """
    if include_pad:
        # Without this, an over-long string would produce a negative pad
        # amount below, which crops the tail of the tensor (dropping `</s>`).
        s = s[:longest_name]

    unk = char2idx['<unk>']
    result = [char2idx.get(char, unk) for char in s]

    if include_start:
        result.insert(0, char2idx['<s>'])

    if include_end:
        result.append(char2idx['</s>'])

    # Build the int64 tensor directly instead of going through float32.
    result = torch.tensor(result, dtype=torch.long)

    if include_pad:
        result = F.pad(result, (0, longest_name - len(s)), value=char2idx['<pad>'])

    return result

def decode_string(indices: torch.Tensor) -> str:
    """Turn a 1-D tensor of vocabulary indices back into text.

    The ``<pad>``, ``<s>`` and ``</s>`` tokens are dropped; ``<unk>`` is kept
    verbatim. Raises ``KeyError`` for an index that is not in ``idx2char``.
    """
    skipped = {'<pad>', '<s>', '</s>'}
    tokens = (idx2char[idx] for idx in indices.tolist())
    # Filter token-by-token rather than string-replacing after the join, so
    # ordinary characters that happen to spell a special token are not removed.
    return ''.join(token for token in tokens if token not in skipped)

def encode_batch(strings: list[str]) -> torch.Tensor:
    """Encode each string with ``encode_string`` (default flags) and stack the
    resulting equal-length tensors along a new leading dimension."""
    encoded = list(map(encode_string, strings))
    return torch.stack(encoded)

def decode_batch(batch):
    """Decode every row of ``batch`` with ``decode_string``; returns a list of strings."""
    return list(map(decode_string, batch))

# Encode the whole corpus into one tensor (one row per name) and show its shape.
encoded_corpus = encode_batch(strings=names)
encoded_corpus.shape

In [None]:
# 80/20 train/test split in corpus order (no shuffling before the split).
split_idx = int(0.8 * len(encoded_corpus))
train_data = encoded_corpus[:split_idx]
test_data = encoded_corpus[split_idx:]

# Settings shared by both loaders. Pinned host memory only helps CUDA
# transfers; requesting it without CUDA just makes DataLoader emit a warning,
# so it is enabled only when CUDA is present.
loader_kwargs = dict(
    batch_size=256,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)

train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_kwargs)

In [None]:
# Pick the compute device: Apple-silicon MPS first (the notebook's original
# target), then CUDA, then CPU, so the notebook also runs on machines
# without MPS instead of failing at the first `.to(device)`.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class RNN(nn.Module):
    """Character-level language model: embedding -> stacked ``nn.RNN`` -> linear
    projection to per-token vocabulary logits."""

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        nonlinearity: str = "relu",
    ):
        """
        Args:
            vocab_size: Number of embedding rows and of output logits.
            embedding_dim: Size of each token embedding.
            hidden_dim: Hidden size of every recurrent layer.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout passed to ``nn.RNN`` (applied between layers).
            nonlinearity: Activation passed to ``nn.RNN``.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            nonlinearity=nonlinearity,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token ids.

        ``x`` is a ``(batch, seq)`` integer tensor (the RNN is batch-first);
        ``logits`` has one vocabulary-sized vector per input position and
        ``hidden`` is the recurrent state after the last position.
        """
        x = self.embedding(x)
        x, hidden = self.rnn(x, hidden)
        x = self.linear(x)
        return x, hidden

    def generate(self, context: torch.Tensor, hidden=None, length=100, temperature=1.0):
        """Sample ``length`` tokens autoregressively and append them to ``context``.

        Args:
            context: ``(batch, seq)`` tensor of token ids used as the prompt.
            hidden: Optional initial recurrent state for the prompt.
            length: Number of tokens to sample.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            ``context`` with ``length`` sampled columns appended.

        The module is switched to eval mode while sampling and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():  # sampling never needs an autograd graph
                # The first step consumes the whole prompt. Later steps feed
                # only the newly sampled token, because `hidden` already
                # summarises everything before it; re-feeding the full context
                # with the carried state would process earlier tokens again.
                step_input = context
                for _ in range(length):
                    logits, hidden = self(step_input, hidden)
                    probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    context = torch.cat([context, next_token], dim=-1)
                    step_input = next_token
        finally:
            self.train(was_training)
        return context
    

# Build the network, then move it to the selected device.
# NOTE(review): vocab_size=256 is a fixed width rather than len(vocab); any code
# that reshapes the logits has to agree with this output width.
model = RNN(vocab_size=256, embedding_dim=128, hidden_dim=64, num_layers=4, dropout=0.)
model = model.to(device)

print(f"Model has {sum(map(torch.numel, model.parameters())):,} parameters")

# Loss, optimizer, and a schedule that multiplies the LR by 0.9 on every scheduler step.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1, gamma=0.9)

In [None]:
import tqdm.notebook as tqdm

In [None]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the mean per-batch next-token loss of ``model`` over ``loader``.

    Each batch is moved to ``device`` and run through the model; the logits at
    positions ``0..T-2`` are scored against the tokens at ``1..T-1`` with the
    module-level ``criterion``. The model is left in eval mode.
    """
    model.eval()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        output, _ = model(batch)
        logits = output[:, :-1]
        targets = batch[:, 1:]
        # Take the class count from the logits instead of hard-coding it, so
        # this keeps working if the model's output width changes.
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

# Loss on the held-out split at this point in the notebook.
evaluate(model=model, loader=test_loader)

In [None]:
from ema_pytorch import EMA
# Wrap the model in an EMA tracker; the training loop calls ema.update() after
# every optimizer step. beta=0.99 is presumably the averaging decay — confirm
# against the ema_pytorch documentation.
# NOTE(review): in the visible notebook the averaged weights are only updated,
# never evaluated, sampled from, or saved — confirm whether that is intended.
ema = EMA(model, beta=0.99)

In [None]:
# Inspect the current learning rate.
# The scheduler is deliberately not stepped here: stepping before any
# optimizer.step() skips the first LR value (PyTorch warns about this order),
# and it would decay the LR again every time this cell is re-run. The training
# loop steps the scheduler once per epoch.
scheduler.get_last_lr()

In [None]:
# Training loop: 20 epochs of next-token prediction.
pbar = tqdm.tqdm(range(20))
for epoch in pbar:
    # Held-out loss measured before this epoch's updates; shown next to the running train loss.
    test_loss = evaluate(model, test_loader)
    model.train()
    # The epoch label does not change per batch, so set it once per epoch.
    pbar.set_description(f"Epoch {epoch}")
    for seq in tqdm.tqdm(train_loader, leave=True, desc=f"Batches for epoch {epoch}"):
        seq = seq.to(device)
        # Inputs are every token but the last; targets are the same sequence shifted left by one.
        x = seq[:, :-1]
        y = seq[:, 1:]

        optimizer.zero_grad()
        output, _ = model(x)
        # Take the class count from the logits instead of hard-coding it, so
        # this keeps working if the model's output width changes.
        loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()

        ema.update()

        pbar.set_postfix_str(f"Loss: {loss.item():.4f}, Test Loss: {test_loss:.4f}")

    # Decay the learning rate once per epoch, after the optimizer has stepped.
    scheduler.step()

In [None]:
# Generate new names: start five sequences from a lone <s> token and sample
# ten further tokens for each.
start_token = encode_string('', include_end=False, include_pad=False)
context = start_token.repeat(5, 1).to(device)

samples = model.generate(context, length=10)
decode_batch(samples)


In [None]:
# Save Model
# Create the target directory first so a missing folder cannot make the save
# fail after training has already finished.
checkpoint_path = Path('projects/3-makemore/model.pt')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)