In [None]:
# Read the raw Tiny Shakespeare text. The encoding is pinned so the result does
# not depend on the platform default (e.g. cp1252 on Windows).
with open('data/tiny-shakespeare/tiny-shakespeare.txt', encoding='utf-8') as f:
    corpus = f.read()
corpus[:1000]

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Character-level vocabulary: two special tokens followed by every distinct
# character of the corpus, sorted so indices are stable across runs.
# (set(corpus) already iterates characters; no need to ''.join a string.)
vocab = ['<pad>', '<unk>'] + sorted(set(corpus))

char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = dict(enumerate(vocab))

print(char2idx)
print(idx2char)
print(len(vocab))


In [2]:
def encode_string(s: str) -> torch.Tensor:
    """Encode a string as a 1-D tensor of vocabulary indices.

    Characters missing from ``char2idx`` are mapped to the ``'<unk>'`` index.

    Args:
        s: Text to encode.

    Returns:
        A ``torch.long`` tensor of shape ``(len(s),)``.
    """
    unk = char2idx['<unk>']
    # Build the tensor directly as int64; ``torch.Tensor(list)`` would first
    # materialise a float32 copy of the whole list and then cast it.
    return torch.tensor([char2idx.get(char, unk) for char in s], dtype=torch.long)

def decode_string(indices: torch.Tensor) -> str:
    """Decode a 1-D sequence of vocabulary indices back into text.

    The special tokens ``'<pad>'``, ``'<s>'`` and ``'</s>'`` are dropped, while
    ``'<unk>'`` is kept verbatim. Tokens are filtered by identity rather than by
    searching the joined text. Ordinary characters that happen to spell
    e.g. ``<s>`` are therefore preserved.

    Args:
        indices: 1-D tensor (or sequence) of integer indices into ``idx2char``;
            a 0-d tensor is treated as a single index.

    Returns:
        The decoded string.
    """
    skipped = {'<pad>', '<s>', '</s>'}
    ids = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices)
    if isinstance(ids, int):  # a 0-d tensor's .tolist() yields a bare int
        ids = [ids]
    return ''.join(tok for tok in (idx2char[i] for i in ids) if tok not in skipped)

def encode_batch(strings: list[str]) -> torch.Tensor:
    """Encode several strings into one ``(batch, max_len)`` index tensor.

    Shorter strings are right-padded with the ``'<pad>'`` index, so ragged input
    works (plain ``torch.stack`` would raise). Equal-length input is encoded
    exactly as by stacking the rows.

    Args:
        strings: Texts to encode.

    Returns:
        A ``torch.long`` tensor of shape ``(len(strings), max_len)``, or of
        shape ``(0, 0)`` when ``strings`` is empty.
    """
    encoded = [encode_string(s) for s in strings]
    if not encoded:
        return torch.empty((0, 0), dtype=torch.long)
    return nn.utils.rnn.pad_sequence(
        encoded, batch_first=True, padding_value=char2idx['<pad>']
    )

def decode_batch(batch):
    """Decode each row of ``batch`` with ``decode_string``; returns a list of str."""
    return list(map(decode_string, batch))

In [None]:
# Encode the whole corpus as one flat 1-D index tensor. Calling encode_string
# directly avoids wrapping the corpus in a batch of one (a stacked copy) and
# squeezing it back out.
encoded_corpus = encode_string(corpus)
encoded_corpus.shape

In [None]:
# Persist the vocabulary and the encoded corpus so later sessions can skip
# re-reading and re-encoding the raw text.
import json
import os

# open(..., 'w') does not create parent directories; do it here so the cell
# also works on a fresh checkout.
os.makedirs('projects/4-shakespeare', exist_ok=True)

# Save Vocab
with open('projects/4-shakespeare/vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab, f)

# Save Encoded Corpus
torch.save(encoded_corpus, 'projects/4-shakespeare/encoded_corpus.pt')

In [3]:
# Load the saved vocabulary and rebuild the lookup tables the same way they
# were built when the vocabulary was created.
import json
with open('projects/4-shakespeare/vocab.json', encoding='utf-8') as f:
    vocab = json.load(f)

char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = dict(enumerate(vocab))

# Load Encoded Corpus. torch.load unpickles; weights_only=True restricts it to
# tensors and plain containers, so a tampered .pt file cannot run code.
encoded_corpus = torch.load('projects/4-shakespeare/encoded_corpus.pt', weights_only=True)

In [8]:
SEQ_LEN = 256     # characters per training example
BATCH_SIZE = 64


def to_sequences(tokens: torch.Tensor, seq_len: int = SEQ_LEN) -> torch.Tensor:
    """Drop the ragged tail of 1-D ``tokens`` and reshape to ``(-1, seq_len)``."""
    usable = (tokens.shape[0] // seq_len) * seq_len
    # Slicing to an explicit length is also correct when the remainder is 0;
    # the previous ``[:-remainder]`` form became ``[:-0]`` -> an empty tensor.
    return tokens[:usable].view(-1, seq_len)


# First 80% of the characters for training, the rest for testing.
split_at = int(0.8 * len(encoded_corpus))
train_data = to_sequences(encoded_corpus[:split_at])
test_data = to_sequences(encoded_corpus[split_at:])

print(train_data.shape, test_data.shape)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training tokens: {train_data.numel():,}")
print(f"Testing tokens: {test_data.numel():,}")

torch.Size([3485, 256]) torch.Size([871, 256])
Training tokens: 892,160
Testing tokens: 222,976


In [9]:
# Peek at one shuffled training batch and check its layout matches the
# sequences it was drawn from.
x = next(iter(train_loader))
assert x.dtype == torch.long and x.shape[1:] == train_data.shape[1:], x.shape
x.shape

torch.Size([64, 256])

In [11]:
class RNN(nn.Module):
    """Character-level language model: embedding -> stacked ``nn.RNN`` -> linear.

    Args:
        vocab_size: Number of embedding rows and of output logits.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden-state size of every recurrent layer.
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout between recurrent layers (``nn.RNN`` semantics).
        nonlinearity: ``"relu"`` or ``"tanh"``, forwarded to ``nn.RNN``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        nonlinearity: str = "relu",
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            nonlinearity=nonlinearity,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a ``(batch, seq)`` index tensor.

        ``logits`` has shape ``(batch, seq, vocab_size)``. ``hidden`` is the
        final ``nn.RNN`` state; pass it back in to continue the same sequence.
        """
        x = self.embedding(x)
        x, hidden = self.rnn(x, hidden)
        x = self.linear(x)
        return x, hidden

    @torch.no_grad()
    def generate(self, context: torch.Tensor, hidden=None, length=100, temperature=1.0):
        """Autoregressively sample ``length`` tokens after ``context``.

        The prompt is run through the network once. After that, only the newly
        sampled token is fed back together with the carried hidden state, so
        every token is processed exactly once. Re-feeding the whole context
        with a hidden state that had already consumed it would corrupt the
        state from the second step onward.

        Args:
            context: ``(batch, seq)`` long tensor of prompt indices (seq >= 1).
            hidden: Optional initial recurrent state for the prompt.
            length: Number of tokens to sample.
            temperature: Divides the logits before softmax; must be > 0.

        Returns:
            ``(batch, seq + length)`` tensor: the prompt followed by the samples.
        """
        was_training = self.training
        self.eval()
        try:
            step_input = context
            for _ in range(length):
                logits, hidden = self(step_input, hidden)
                probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                context = torch.cat([context, next_token], dim=-1)
                step_input = next_token
        finally:
            # Restore the caller's mode instead of unconditionally forcing train().
            self.train(was_training)
        return context
    

# Fix the global RNG so weight initialisation (and the DataLoader shuffle,
# which draws from the same default generator) is reproducible across restarts.
torch.manual_seed(0)

# Embedding/output size. Every character index must be < VOCAB_SIZE, or
# nn.Embedding raises at lookup time.
VOCAB_SIZE = 256
assert len(vocab) <= VOCAB_SIZE, f"vocab has {len(vocab)} entries, model supports {VOCAB_SIZE}"

model = RNN(
    vocab_size=VOCAB_SIZE,
    embedding_dim=128,
    hidden_dim=64,
    num_layers=8,
    dropout=0,
)

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# MultiStepLR milestones count scheduler.step() calls; the LR only decays
# if the training loop steps the scheduler.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000, 2000, 3000], gamma=0.1)

Model has 120,064 parameters


In [12]:
import tqdm.notebook as tqdm

In [13]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0
    for batch in loader:
        output, _ = model(batch)
        loss = criterion(output[:, :-1].reshape(-1, 256), batch[:, 1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

evaluate(model=model, loader=test_loader)

5.546577147075108

In [14]:
# Keep an exponential-moving-average copy of the model weights (ema_pytorch,
# beta=0.99). The training loop advances it with ema.update().
# NOTE(review): nothing later in this notebook evaluates, samples from, or saves
# the EMA weights — confirm whether generation/checkpointing should use them.
from ema_pytorch import EMA
ema = EMA(model, beta=0.99)

In [15]:
# Training loop: NUM_EPOCHS passes over train_loader with next-character
# cross-entropy. The LR scheduler and the EMA advance once per optimizer step.
NUM_EPOCHS = 20

test_loss = evaluate(model, test_loader)  # baseline before any training
train_loss = float('nan')
pbar = tqdm.tqdm(range(NUM_EPOCHS))
for epoch in pbar:
    pbar.set_description(f"Epoch {epoch}")
    model.train()
    for seq in tqdm.tqdm(train_loader, leave=False, desc=f"Batches for epoch {epoch}"):
        # Teacher forcing: predict each next character from its prefix.
        x = seq[:, :-1]
        y = seq[:, 1:]

        optimizer.zero_grad()
        output, _ = model(x)
        # Take the class count from the logits rather than hard-coding it.
        loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        # MultiStepLR milestones count these calls; without them the LR never decays.
        scheduler.step()

        ema.update()

        train_loss = loss.item()
        pbar.set_postfix_str(f"Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    # Evaluate after every epoch so the reported test loss reflects the weights
    # just trained, including those of the final epoch.
    test_loss = evaluate(model, test_loader)
    pbar.set_postfix_str(f"Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

  0%|          | 0/20 [00:00<?, ?it/s]

Batches for epoch 0:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 1:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 2:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 3:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 4:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 5:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 6:   0%|          | 0/55 [00:00<?, ?it/s]

Batches for epoch 7:   0%|          | 0/55 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
# Sample 256 characters of text continuing a Shakespeare-style prompt.
# unsqueeze(0) adds the batch dimension that generate() indexes: (1, prompt_len).
context = encode_string('First Citizen:').unsqueeze(0)

print(decode_batch(model.generate(context, length=256))[0])


First Citizen:
SMRENCAS E:
Whas of hreegtinds colin dirtersiy and
Pomoly,.

Oarse marougbor oparase ule one weulsoor: rark
Horl Cerl'nneryte
 hougoyecnodn mosholeld in o'ots.
Basaye
men ow!

Soved meedp Dolm,,? Hiluslence weur hosfen nemerirt.

LANBURTUITTI:
Hyl lw nevd


In [None]:
# Save the trained weights. Only the state_dict is stored, so an RNN with the
# same hyperparameters must be constructed before loading it back.
import os

os.makedirs('projects/4-shakespeare', exist_ok=True)  # open/save won't create it
torch.save(model.state_dict(), 'projects/4-shakespeare/model.pt')