# Shakespeare

## Setup

### Load Dataset

In [None]:
# Read both corpus splits fully into memory as text. The encoding is pinned
# to UTF-8 so decoding does not depend on the platform's default locale.
# NOTE(review): assumes the data files are UTF-8 encoded -- confirm.
with open("blog/5-shakespeare/data/train.txt", "r", encoding="utf-8") as f:
    train_corpus = f.read()

with open("blog/5-shakespeare/data/test.txt", "r", encoding="utf-8") as f:
    test_corpus = f.read()

### Create Tokenizer

In [None]:
from tokenizers import Tokenizer, models, decoders, trainers, tools, pre_tokenizers

# Byte-level BPE: text is mapped to byte-level symbols before merges are
# learned, and decoded back through the matching byte-level decoder.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
tokenizer.decoder = decoders.ByteLevel()

# Seed the trainer with the full byte-level alphabet so every byte symbol has
# a token even if it never occurs in the training file. The BPE model above
# has no unk_token, so without this, text containing unseen symbols (e.g. in
# the test split) would have nothing to fall back on.
trainer = trainers.BpeTrainer(
    special_tokens=["[PAD]", "[SOS]", "[EOS]", "[MASK]", "[UNK]"],
    vocab_size=4096,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

In [None]:
# Learn the BPE vocabulary from the training split only, then show the
# resulting vocabulary size as the cell output.
tokenizer.train(files=["blog/5-shakespeare/data/train.txt"], trainer=trainer)
tokenizer.get_vocab_size(with_added_tokens=True)

In [None]:
# Visualise how the trained tokenizer splits the first 512 characters of the
# training corpus (rendered as the cell output).
viz = tools.EncodingVisualizer(tokenizer=tokenizer)
viz(text=train_corpus[:512])

#### Tokenize Dataset

In [None]:
# Encode each split in one call and keep only the token ids.
train_encoded_corpus, val_encoded_corpus = (
    tokenizer.encode(corpus).ids for corpus in (train_corpus, test_corpus)
)

## Train

### LSTM

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import matplotlib.pyplot as plt

# Prefer Apple's Metal backend (the original target), then CUDA, and fall
# back to the CPU so the notebook also runs on machines without an MPS device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
seq_len = 256
batch_size = 128

# Put each encoded corpus on the target device as a single (1, num_tokens) row.
train_tensor = torch.tensor([train_encoded_corpus], dtype=torch.long).to(device)
val_tensor = torch.tensor([val_encoded_corpus], dtype=torch.long).to(device)

# Drop the trailing tokens that do not fill a whole sequence, then fold each
# row into a (num_sequences, seq_len) matrix of non-overlapping sequences.
train_usable = (train_tensor.size(1) // seq_len) * seq_len
val_usable = (val_tensor.size(1) // seq_len) * seq_len
train_tensor = train_tensor[:, :train_usable].view(-1, seq_len)
val_tensor = val_tensor[:, :val_usable].view(-1, seq_len)
print(train_tensor.size(), val_tensor.size())

# Wrap the matrices as datasets; every item is a 1-tuple holding one sequence.
train_dataset = torch.utils.data.TensorDataset(train_tensor)
val_dataset = torch.utils.data.TensorDataset(val_tensor)

# Only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, dropout, batch_first=True):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers,dropout=dropout, batch_first=batch_first)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden: tuple[torch.Tensor, torch.Tensor] | None = None):
        if hidden is None:
            hidden = (
                torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size, device=x.device),
                torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size, device=x.device),
            )

        x = self.embedding(x)
        x, hidden = self.rnn(x, hidden)
        x = self.fc(x)

        return x, hidden

    @torch.no_grad()
    def generate(
        self,
        start_seq: str,
        max_len: int = 128,
        hidden: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> str:
        self.eval()
        start_seq = tokenizer.encode(start_seq).ids

        if hidden is None:
            hidden = (
                torch.randn(self.rnn.num_layers, 1, self.rnn.hidden_size, device=device),
                torch.randn(self.rnn.num_layers, 1, self.rnn.hidden_size, device=device)
            )
        
        x = torch.tensor(start_seq, dtype=torch.long, device=device).view(1, -1)
        output = x.flatten().tolist()
        for _ in range(max_len):
            x, hidden = self(x, hidden)
            if x.shape[1] > 1:
                x = x[:, -1:]

            x = x.softmax(dim=-1).argmax(dim=-1)
            if x.item() == tokenizer.token_to_id("[EOS]"):
                break
            output = output + x.flatten().tolist()
        self.train()
        
        return tokenizer.decode(output)


# Model hyperparameters.
hidden_size, num_layers, dropout = 512, 4, 0.0

# Size the model to the trained tokenizer's vocabulary and move it to the
# training device.
model = LSTM(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
).to(device)
num_train_steps = 0

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
# Sample from the untrained model; the result is shown as the cell output.
model.generate("The Project")

In [None]:
# Cross-entropy over vocabulary logits, optimised with Adam over every
# model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-4)

In [10]:
# Exponential moving average of the batch loss, shown in the progress bar.
# It is seeded with the first batch loss rather than 0 so the early readings
# are not dragged towards zero.
train_loss = None
for epoch in range(30):
    model.train()
    pbar = tqdm.tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}")
    for x, in pbar:
        optimizer.zero_grad()
        x = x.to(device)
        # Next-token prediction: position t of the input is scored against
        # token t + 1 of the same sequence.
        inputs = x[:, :-1]
        targets = x[:, 1:]

        output, _ = model(inputs)
        # CrossEntropyLoss takes (N, vocab) logits and (N,) targets; reshape
        # copes with the non-contiguous target slice.
        output = output.reshape(-1, output.size(-1))
        targets = targets.reshape(-1)

        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call is a device-to-host sync.
        batch_loss = loss.item()
        if train_loss is None:
            train_loss = batch_loss
        else:
            train_loss = 0.9 * train_loss + 0.1 * batch_loss

        pbar.set_postfix(loss=batch_loss, roll_loss=train_loss)

                                               

KeyboardInterrupt: 

In [None]:
# Print a continuation of the prompt sampled from the trained model.
print(model.generate(start_seq="The Project"))