In [None]:
import json

import os

from pathlib import Path

import kagglehub

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [None]:
# Fetch the corpus through kagglehub; the returned value is used as the
# directory that contains the dataset's text file.
input_directory = kagglehub.dataset_download("mruanova/shakespeare")
input_filepath = os.path.join(input_directory, "shakespeare.txt")

# Despite the name, `output_filepath` is a directory: the vocabulary JSON files
# and the per-epoch checkpoints are written inside it. Create it up front,
# together with any missing parents; an already existing directory is fine.
output_filepath = os.path.join("results", "shakespeare-lstm")
Path(output_filepath).mkdir(exist_ok=True, parents=True)

In [None]:
# Read the whole corpus into one string. The encoding is pinned so the decoded
# characters (and therefore the vocabulary built from them) do not depend on
# the platform's locale default.
# NOTE(review): assumes shakespeare.txt is UTF-8 (or plain ASCII) - confirm.
with open(input_filepath, 'r', encoding='utf-8') as file:
    data = file.read()

In [None]:
# This cell replaces the raw text in `data` with its encoded form, so running
# it a second time would silently build a vocabulary out of integer ids.
# Fail loudly instead and ask for the loading cell to be re-run.
if not isinstance(data, str):
    raise TypeError(
        "`data` has already been encoded; re-run the cell that reads the "
        "corpus before running this cell again."
    )

# Character-level vocabulary. The characters are sorted before numbering so
# the token -> index mapping is identical in every session; iterating a bare
# set of strings can yield a different order from one interpreter run to the
# next, which would make saved vocabularies and checkpoints from different
# runs incompatible.
token2idx = {token: idx for idx, token in enumerate(sorted(set(data)))}

# The two sequence markers get the next two free indices.
start_token = '<START>'
end_token = '<END>'
token2idx[start_token] = len(token2idx)
token2idx[end_token] = len(token2idx)
idx2token = {idx: token for token, idx in token2idx.items()}

# Wrap the corpus in the markers and encode it as a list of integer ids.
data = [start_token] + list(data) + [end_token]
data = [token2idx[token] for token in data]

# Persist both mappings next to the checkpoints.
with open(f'{output_filepath}/token2idx.json', 'w') as file:
    json.dump(token2idx, file)

# JSON object keys are always strings, so the integer keys of `idx2token` are
# written as "0", "1", ...; convert them back with int() when reloading.
with open(f'{output_filepath}/idx2token.json', 'w') as file:
    json.dump(idx2token, file)

In [None]:
# Hyperparameters
vocab_size = len(token2idx)  # every entry of the token -> index mapping
embed_dim, hidden_dim = 128, 256
num_layers = 2
batch_size = 32
epochs = 3

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
def sample_model(sequence, k=3):
    """Sample the token that follows `sequence` using top-k sampling.

    The tokens are encoded with the module-level `token2idx`, run through the
    module-level `model` on `device`, and the next token is drawn from the `k`
    highest-scoring candidates at the final position.

    Args:
        sequence: Tokens generated so far; each must be a key of `token2idx`.
        k: Number of top candidates to sample from. Values larger than the
            number of logits are clamped instead of raising.

    Returns:
        The sampled token, looked up in `idx2token`.
    """
    inputs = torch.tensor([token2idx[token] for token in sequence], dtype=torch.long).to(device)

    with torch.no_grad():
        logits, _ = model(inputs)

    # `inputs` is 1-D, so `logits` holds one row of scores per position. Only
    # the final row predicts the *next* token; sampling every earlier position
    # as well would be work that is thrown away immediately.
    last_logits = logits[-1]

    # torch.topk rejects a k larger than the dimension it selects from.
    k = min(k, last_logits.size(-1))

    topk_logits, topk_indices = torch.topk(last_logits, k, dim=-1)
    # Renormalise over the k survivors only.
    topk_probs = F.softmax(topk_logits, dim=-1)

    # The draw is a position inside the top-k list; map it back to a vocab id.
    sampled_idx = torch.multinomial(topk_probs, num_samples=1)
    next_idx = topk_indices.gather(-1, sampled_idx).item()

    return idx2token[next_idx]


def sample_new_sequence(max_length):
    """Generate a token sequence from the model, starting at `start_token`.

    Tokens are drawn one at a time with `sample_model`, which receives the
    entire sequence generated so far on every step. Generation stops as soon
    as `end_token` is drawn or after `max_length` draws.

    Args:
        max_length: Upper bound on the number of sampling steps.

    Returns:
        The generated tokens as a list, without the leading `start_token`
        and without any `end_token`.
    """
    generated = [start_token]
    for _step in range(max_length):
        candidate = sample_model(generated)
        if candidate != end_token:
            generated.append(candidate)
            continue
        # The end marker terminates generation and is not kept.
        break
    # Drop the start marker that seeded the sequence.
    return generated[1:]


class TokenDataset(Dataset):
    """Map-style dataset of (input, target) pairs for next-token prediction.

    Every stored sequence yields a pair shifted by one position: the input is
    the sequence without its last token, the target is the sequence without
    its first token.
    """

    def __init__(self, sequences, vocab_size):
        # `vocab_size` is only recorded; nothing in this class reads it.
        self.sequences, self.vocab_size = sequences, vocab_size

    def __len__(self):
        """Number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the `idx`-th sequence as a pair of int64 tensors."""
        tokens = torch.tensor(self.sequences[idx], dtype=torch.long)
        inputs, targets = tokens[:-1], tokens[1:]
        return inputs, targets


class TokenPredictionModel(nn.Module):
    """Embedding -> stacked LSTM -> linear head that scores every vocabulary entry.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Return `(logits, hidden)` for token ids `x`.

        `hidden` is handed to the LSTM unchanged (None lets the LSTM start
        from its default state); the state the LSTM returns is passed back so
        a caller can continue from it.
        """
        embedded = self.embedding(x)
        features, next_hidden = self.lstm(embedded, hidden)
        return self.fc(features), next_hidden


class _SlidingWindows:
    """Lazy, read-only sequence of every length-`window` slice of `tokens` (stride 1).

    A slice is built only when it is indexed, so memory stays proportional to
    `tokens` itself instead of `len(tokens) * window`.
    """

    def __init__(self, tokens, window):
        self.tokens = tokens
        self.window = window

    def __len__(self):
        # +1 so the window that ends on the very last token is included.
        return max(0, len(self.tokens) - self.window + 1)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        return self.tokens[idx:idx + self.window]


# Window length in tokens; each window yields T-1 (input, target) positions.
T = 1024

# Every stride-1 window over the encoded corpus, produced on demand.
# Materialising them as a list of lists would need roughly len(data) * T
# list slots, which does not fit in memory for a corpus of this kind, and the
# previous `range(len(data) - T)` bound skipped the final window, so the end
# marker never appeared in any training sample.
sequences = _SlidingWindows(data, T)
dataset = TokenDataset(sequences, vocab_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of sampling steps for the text sample printed after each epoch.
test_length = 1000

model = TokenPredictionModel(vocab_size, embed_dim, hidden_dim, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(epochs):
    total_loss = 0
    progress_bar = tqdm(dataloader)
    for inputs, targets in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits, _ = model(inputs)
        # Flatten batch and time so every position is one classification example.
        loss = criterion(logits.view(-1, model.fc.out_features), targets.view(-1))
        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        progress_bar.set_postfix(loss=current_loss)
        total_loss += current_loss
    # Checkpoint after every epoch, then report the mean batch loss and a sample.
    torch.save(model.state_dict(), f"{output_filepath}/model_{epoch+1}.pt")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")
    test_sequence = sample_new_sequence(test_length)
    print(''.join(test_sequence))
