In [None]:
# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the notebook also runs on CUDA machines and on CPU-only machines.
import torch

_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Create and format Dataset

In [None]:
import random
import itertools

In [None]:
# Load the whitespace-separated names and split them into train/test portions.
SEED = 42  # fixed seed so the shuffle, and therefore the split, is reproducible

with open('blog/3-names/data/names.txt', encoding='utf-8') as f:
    names = f.read().split()

random.seed(SEED)
random.shuffle(names)

train_test = (0.8, 0.2)  # (train fraction, test fraction)

# Compute the boundary once so the two slices are complementary by construction.
split_idx = int(len(names) * train_test[0])
train_names = names[:split_idx]
test_names = names[split_idx:]

### Create tokenizer

In [None]:
# Character-level vocabulary: the four special tokens come first, followed by
# every distinct character found in the names, in sorted order.
special_tokens = ['<s>', '</s>', '<unk>', '<pad>']
vocab = special_tokens + sorted(set(''.join(names)))
idx2char = dict(enumerate(vocab))
char2idx = {token: index for index, token in idx2char.items()}


def _encode_split(split):
    """Encode every name, wrap it in <s> ... </s>, and concatenate into one flat id list."""
    bos, eos = char2idx['<s>'], char2idx['</s>']
    flat = []
    for name in split:
        flat.append(bos)
        flat.extend(char2idx[ch] for ch in name)
        flat.append(eos)
    return flat


# Each split becomes a single stream of token ids with start/end markers
# delimiting the individual names.
train_names_encoded = _encode_split(train_names)
test_names_encoded = _encode_split(test_names)

print('Vocab size:', len(vocab))
print('Vocabulary:', vocab)

### Helper functions

In [None]:
import torch
import torch.nn as nn

# Materialise both flattened token streams as 1-D tensors on the compute device.
train_encoded_corpus, test_encoded_corpus = (
    torch.tensor(ids).flatten().to(device)
    for ids in (train_names_encoded, test_names_encoded)
)

def encode(text: str) -> torch.Tensor:
    """Convert ``text`` into a 1-D tensor of vocabulary ids on ``device``.

    Each character is looked up in ``char2idx``; characters that are not in the
    vocabulary map to the ``<unk>`` id instead of raising ``KeyError``.

    Args:
        text: String to encode, one token per character.

    Returns:
        An int64 tensor of shape ``(len(text),)``. The dtype is forced so that
        an empty string yields an empty *integer* tensor rather than the
        float32 default ``torch.tensor([])`` would produce.
    """
    unk = char2idx['<unk>']
    ids = [char2idx.get(c, unk) for c in text]
    return torch.tensor(ids, dtype=torch.long, device=device)

def decode(tensor: torch.Tensor) -> str:
    """Convert a sequence of vocabulary ids back into a string.

    Args:
        tensor: 1-D tensor or any iterable of integer ids.

    Returns:
        The concatenation of ``idx2char[id]`` for every id, in order.

    Raises:
        KeyError: If an id is not present in ``idx2char``.
    """
    # Iterating a tensor yields 0-d tensors, which hash by identity and would
    # never match the integer keys of idx2char; int() normalises both tensor
    # elements and plain Python ints.
    return ''.join(idx2char[int(i)] for i in tensor)

def encode_batch(batch: list) -> torch.Tensor:
    """Encode several strings into one rectangular tensor of token ids.

    ``torch.tensor`` cannot be built from a list of multi-element tensors, so
    the per-string encodings are combined with ``pad_sequence`` instead;
    shorter rows are right-padded with the ``<pad>`` id.

    Args:
        batch: Iterable of strings to encode.

    Returns:
        An int64 tensor of shape ``(len(batch), max_len)`` on ``device``;
        shape ``(0, 0)`` when ``batch`` is empty.
    """
    rows = [encode(text).long() for text in batch]
    if not rows:
        # pad_sequence rejects an empty list; return an empty batch explicitly.
        return torch.empty((0, 0), dtype=torch.long, device=device)
    padded = nn.utils.rnn.pad_sequence(
        rows, batch_first=True, padding_value=char2idx['<pad>']
    )
    return padded.to(device)

def decode_batch(batch: torch.Tensor) -> list:
    """Decode every row of a batch of token ids into a string.

    Args:
        batch: 2-D tensor of ids, or an iterable of id sequences.

    Returns:
        A list with one decoded string per row, in row order.
    """
    # Convert tensor input to nested Python ints up front: rows of a tensor
    # iterate as 0-d tensors, which cannot be used as dictionary keys when
    # the ids are looked up during decoding.
    rows = batch.tolist() if isinstance(batch, torch.Tensor) else batch
    return [decode(row) for row in rows]

In [None]:
def get_batch(batch_size: int, seq_len: int, train: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample random fixed-length windows and their next-token targets.

    Args:
        batch_size: Number of windows to draw.
        seq_len: Length of every window.
        train: Draw from ``train_encoded_corpus`` when true, otherwise from
            ``test_encoded_corpus``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, seq_len)``; ``y`` is the same
        window as ``x`` shifted one position to the right in the corpus.
    """
    data = train_encoded_corpus if train else test_encoded_corpus

    # The exclusive upper bound leaves room for the one-step-shifted target.
    starts = torch.randint(0, len(data) - seq_len, (batch_size,)).tolist()
    inputs = [data[s:s + seq_len] for s in starts]
    targets = [data[s + 1:s + 1 + seq_len] for s in starts]
    return torch.stack(inputs), torch.stack(targets)

# Smoke check: draw two training windows of five tokens each; as the last
# expression of the cell, the resulting (x, y) pair is displayed.
get_batch(2, 5)

### Define Model

In [None]:
@torch.no_grad()
def test(model):
    x, y = get_batch(256, 64)
    
    out, _ = model(x)
    loss = nn.functional.cross_entropy(out.flatten(0, 1), y.flatten())
    return loss.item()

In [None]:
class RNN(nn.Module):
    """Character-level language model: embedding -> stacked ReLU ``nn.RNN`` -> vocabulary logits."""

    def __init__(
        self, vocab_size: int, hidden_size: int, num_layers: int, dropout: float
    ):
        """Build the layers.

        Args:
            vocab_size: Number of distinct tokens (embedding rows and output logits).
            hidden_size: Width of both the embedding and the recurrent state.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout probability passed to ``nn.RNN``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, num_layers, dropout=dropout, nonlinearity='relu', batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.
            h: Optional recurrent state from a previous call; ``None`` starts
                from the zero state (``nn.RNN`` creates it on the input's device).

        Returns:
            ``(logits, h)`` with logits of shape ``(batch, seq, vocab_size)``
            and the updated recurrent state.
        """
        x = self.emb(x)
        x, h = self.rnn(x, h)
        x = self.fc(x)
        return x, h

    @torch.no_grad()
    def generate(self, max_tokens: int, start_seq: str = '') -> str:
        """Greedily generate text that continues ``start_seq``.

        Args:
            max_tokens: Maximum number of tokens to generate after the prompt.
            start_seq: Optional prompt; it is prefixed with the ``<s>`` token.

        Returns:
            Prompt plus generated characters, with ``<s>``/``</s>`` markers removed.
        """
        # Build the prompt on the device the weights live on, so generation
        # works wherever the model has been moved.
        model_device = self.fc.weight.device
        bos, eos = char2idx['<s>'], char2idx['</s>']
        prompt = torch.tensor([bos], device=model_device)
        # .long() guards against an empty prompt encoding arriving as float.
        x = torch.cat([prompt, encode(start_seq).to(model_device)]).long().unsqueeze(0)
        out = x.flatten().tolist()

        was_training = self.training
        self.eval()  # dropout must be inactive while generating
        try:
            h = None
            for _ in range(max_tokens):
                logits, h = self.forward(x, h)
                # Feed only the newest token back in; the state carries the history.
                x = logits[:, -1].argmax(-1).unsqueeze(0)
                token = x.item()
                out.append(token)
                if token == eos:
                    break
        finally:
            self.train(was_training)
        return decode(out).replace('<s>', '').replace('</s>', '')


# Build a 6-layer RNN, move it to the compute device, print one generation
# from the untrained network as a smoke check, and report the parameter count.
# NOTE(review): the prompt 'A' has to be a character known to the tokenizer —
# confirm the names file contains uppercase letters.
model = RNN(vocab_size=len(vocab), hidden_size=128, num_layers=6, dropout=0)
model = model.to(device)
print(model.generate(max_tokens=128, start_seq='A'))
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

### Train

In [None]:
# Training setup: Adam (learning rate 1e-3) over the parameters of the model
# currently bound to `model`, with cross-entropy as the loss.
# Note: the optimizer holds the parameters of the model that exists when this
# cell runs; rebinding `model` to a new network later does not update it.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [None]:
import tqdm
# Placeholder shown in the progress bar until the first evaluation has run.
test_loss = 0

# 10,000 optimisation steps, each on a fresh random batch of 256 windows x 64 tokens.
pbar = tqdm.trange(10_000, desc='Training', unit='step')
for step in pbar:
    x, y = get_batch(256, 64)
    out, _ = model(x)
    # Merge the batch and sequence axes so every position is one classification sample.
    loss = criterion(out.flatten(0, 1), y.flatten())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # test_loss shown here is the value from the most recent evaluation, which
    # is refreshed below only after the bar has been updated for this step.
    pbar.set_postfix(loss=loss.item(), test_loss=test_loss)
    
    # Re-evaluate every 100 steps (including step 0) via test(model).
    if step % 100 == 0:
        test_loss = test(model)

In [None]:
# Print ten generations from an empty prompt, with any remaining sequence
# markers and surrounding whitespace removed.
for i in range(10):
    generated = model.generate(128, start_seq='')
    for marker in ('<s>', '</s>'):
        generated = generated.replace(marker, '')
    print(generated.strip())

### LSTM

In [None]:
class LSTM(nn.Module):
    """Character-level language model: embedding -> stacked ``nn.LSTM`` -> vocabulary logits."""

    def __init__(
        self, vocab_size: int, hidden_size: int, num_layers: int, dropout: float
    ):
        """Build the layers.

        Args:
            vocab_size: Number of distinct tokens (embedding rows and output logits).
            hidden_size: Width of both the embedding and the recurrent state.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout probability passed to ``nn.LSTM``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(
            hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.
            h: Optional LSTM state from a previous call, as the tuple
                ``(h_n, c_n)``. ``None`` starts from the zero state.

        Returns:
            ``(logits, (h_n, c_n))`` with logits of shape
            ``(batch, seq, vocab_size)``.
        """
        x = self.emb(x)
        # nn.LSTM expects its state as a (hidden, cell) tuple, so a single zero
        # tensor is not a valid initial state; passing None lets nn.LSTM build
        # both zero tensors itself.
        x, h = self.rnn(x, h)
        x = self.fc(x)
        return x, h

    @torch.no_grad()
    def generate(self, max_tokens: int, start_seq: str = "") -> str:
        """Greedily generate text that continues ``start_seq``.

        Args:
            max_tokens: Maximum number of tokens to generate after the prompt.
            start_seq: Optional prompt; it is prefixed with the ``<s>`` token.

        Returns:
            Prompt plus generated characters, with ``<s>``/``</s>`` markers removed.
        """
        # Build the prompt on the device the weights live on, so generation
        # works wherever the model has been moved.
        model_device = self.fc.weight.device
        bos, eos = char2idx["<s>"], char2idx["</s>"]
        prompt = torch.tensor([bos], device=model_device)
        # .long() guards against an empty prompt encoding arriving as float.
        x = torch.cat([prompt, encode(start_seq).to(model_device)]).long().unsqueeze(0)
        out = x.flatten().tolist()

        was_training = self.training
        self.eval()  # dropout must be inactive while generating
        try:
            h = None
            for _ in range(max_tokens):
                logits, h = self.forward(x, h)
                # Feed only the newest token back in; the state carries the history.
                x = logits[:, -1].argmax(-1).unsqueeze(0)
                token = x.item()
                out.append(token)
                if token == eos:
                    break
        finally:
            self.train(was_training)
        return decode(out).replace("<s>", "").replace("</s>", "")


# Rebind `model` to an 8-layer LSTM on the compute device, print one generation
# from the untrained network as a smoke check, and report the parameter count.
# NOTE(review): the prompt "A" has to be a character known to the tokenizer —
# confirm the names file contains uppercase letters.
model = LSTM(vocab_size=len(vocab), hidden_size=128, num_layers=8, dropout=0)
model = model.to(device)
print(model.generate(max_tokens=128, start_seq="A"))
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

In [None]:
# Evaluate the current `model` once with test(); as the last expression of the
# cell, the returned loss is displayed.
test(model)

In [None]:
# An optimizer only updates the parameters it was constructed with, so one
# created for a previously defined network would leave the current `model`
# untrained. Build a fresh Adam instance bound to the model trained here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Placeholder until the first evaluation has run.
test_loss = 0

# 1,000 optimisation steps, each on a fresh random batch of 256 windows x 64 tokens.
pbar = tqdm.trange(1_000, desc='Training', unit='step')
for step in pbar:
    x, y = get_batch(256, 64)
    out, _ = model(x)
    # Merge the batch and sequence axes so every position is one classification sample.
    loss = criterion(out.flatten(0, 1), y.flatten())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate before updating the bar so the displayed value is current
    # rather than lagging one step behind.
    if step % 100 == 0:
        test_loss = test(model)

    pbar.set_postfix(loss=loss.item(), test_loss=test_loss)

In [None]:
class GRU(nn.Module):
    """Character-level language model: embedding -> stacked ``nn.GRU`` -> vocabulary logits."""

    def __init__(
        self, vocab_size: int, hidden_size: int, num_layers: int, dropout: float
    ):
        """Build the layers.

        Args:
            vocab_size: Number of distinct tokens (embedding rows and output logits).
            hidden_size: Width of both the embedding and the recurrent state.
            num_layers: Number of stacked GRU layers.
            dropout: Dropout probability passed to ``nn.GRU``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        # nn.GRU has no `nonlinearity` argument (that option belongs to
        # nn.RNN); passing it makes construction raise TypeError.
        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.
            h: Optional recurrent state from a previous call; ``None`` starts
                from the zero state (``nn.GRU`` creates it on the input's device).

        Returns:
            ``(logits, h)`` with logits of shape ``(batch, seq, vocab_size)``
            and the updated recurrent state.
        """
        x = self.emb(x)
        x, h = self.rnn(x, h)
        x = self.fc(x)
        return x, h

    @torch.no_grad()
    def generate(self, max_tokens: int, start_seq: str = '') -> str:
        """Greedily generate text that continues ``start_seq``.

        Args:
            max_tokens: Maximum number of tokens to generate after the prompt.
            start_seq: Optional prompt; it is prefixed with the ``<s>`` token.

        Returns:
            Prompt plus generated characters, with ``<s>``/``</s>`` markers removed.
        """
        # Build the prompt on the device the weights live on, so generation
        # works wherever the model has been moved.
        model_device = self.fc.weight.device
        bos, eos = char2idx['<s>'], char2idx['</s>']
        prompt = torch.tensor([bos], device=model_device)
        # .long() guards against an empty prompt encoding arriving as float.
        x = torch.cat([prompt, encode(start_seq).to(model_device)]).long().unsqueeze(0)
        out = x.flatten().tolist()

        was_training = self.training
        self.eval()  # dropout must be inactive while generating
        try:
            h = None
            for _ in range(max_tokens):
                logits, h = self.forward(x, h)
                # Feed only the newest token back in; the state carries the history.
                x = logits[:, -1].argmax(-1).unsqueeze(0)
                token = x.item()
                out.append(token)
                if token == eos:
                    break
        finally:
            self.train(was_training)
        return decode(out).replace('<s>', '').replace('</s>', '')


# Rebind `model` to an 8-layer GRU on the compute device, print one generation
# from the untrained network (empty prompt) as a smoke check, and report the
# parameter count.
model = GRU(vocab_size=len(vocab), hidden_size=128, num_layers=8, dropout=0)
model = model.to(device)
print(model.generate(max_tokens=128, start_seq=''))
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")