In [None]:
from tqdm import tqdm

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [None]:
# Location of the raw training corpus, relative to the notebook's working directory.
path = '../data/shakespeare/input.txt'
# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# platform's locale encoding, which would make the character vocabulary built
# from `data` differ between machines.
with open(path, 'r', encoding='utf-8') as file:
    data = file.read()

In [None]:
# Special markers placed at the very beginning and end of the corpus.
start_token = '<START>'
end_token = '<END>'

# Sort the distinct characters before numbering them: iterating a bare set of
# strings yields a different order on every kernel restart (string hashing is
# randomised per process), which would silently invalidate saved checkpoints.
# The two special tokens take the last two ids.
vocab = sorted(set(data)) + [start_token, end_token]
token2idx = {token: idx for idx, token in enumerate(vocab)}
# Derive the reverse mapping by inverting token2idx so the two dictionaries
# always agree on every id, including the special tokens.
idx2token = {idx: token for token, idx in token2idx.items()}

# NOTE: `data` is rebound from the raw string to a list of tokens here, so this
# cell must be re-run together with the cell that loads the file.
data = [start_token] + list(data) + [end_token]
# The full corpus as a flat list of integer token ids.
X = [token2idx[token] for token in data]

In [None]:
class TokenDataset(Dataset):
    """Dataset of token-id windows for next-token prediction.

    Each item is an ``(inputs, targets)`` pair taken from one stored
    sequence: ``inputs`` is the sequence without its last element and
    ``targets`` is the same sequence shifted left by one position.

    Args:
        sequences: Indexable, sized collection of token-id sequences.
        vocab_size: Number of distinct token ids; stored on the instance
            but not otherwise used by this class.
    """

    def __init__(self, sequences, vocab_size):
        self.vocab_size = vocab_size
        self.sequences = sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair for the sequence at ``idx``."""
        window = torch.tensor(self.sequences[idx], dtype=torch.long)
        inputs = window[:-1]   # every token except the last
        targets = window[1:]   # the token following each input position
        return inputs, targets


class TokenPredictionModel(nn.Module):
    """Embedding -> multi-layer LSTM -> linear projection onto the vocabulary.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output
            logits per position).
        embed_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )
        # batch_first=True: the LSTM consumes and produces (batch, time, features).
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Compute per-position logits over the vocabulary.

        Args:
            x: Tensor of token ids fed to the embedding layer.
            hidden: Optional LSTM state to continue from; ``None`` lets the
                LSTM start from its default initial state.

        Returns:
            A ``(logits, hidden)`` tuple where ``hidden`` is the state
            returned by the LSTM.
        """
        embedded = self.embedding(x)
        lstm_out, hidden = self.lstm(embedded, hidden)
        logits = self.fc(lstm_out)
        return logits, hidden

# Model and optimisation hyper-parameters.
vocab_size = len(token2idx)  # one output logit per known token
embed_dim = 128
hidden_dim = 256
num_layers = 2
batch_size = 32
epochs = 10

# Seed PyTorch's global RNG so repeated runs start from the same state
# (weight initialisation and DataLoader shuffling both draw from it).
seed = 42
torch.manual_seed(seed)

# Prefer Apple's Metal backend when it is present, but fall back to CUDA or
# the CPU so the notebook also runs on machines without MPS instead of
# failing when the model is moved to the device.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Length of each training window, in tokens.
T = 1024
# Every contiguous length-T window over the corpus, exposed as a zero-copy
# strided view of shape (len(X) - T + 1, T). Compared with slicing X in a
# list comprehension this
#   * includes the final window, so the last token of X is actually seen
#     during training (range(len(X) - T) stopped one window short), and
#   * avoids materialising len(X) * T list slots in memory.
# Raises ValueError if the corpus is shorter than T.
sequences = np.lib.stride_tricks.sliding_window_view(np.asarray(X, dtype=np.int64), T)
dataset = TokenDataset(sequences, vocab_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Build the network, then move its parameters to the selected device.
model = TokenPredictionModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training: one optimiser step per batch, one checkpoint per epoch.
model.train()
for epoch in range(epochs):
    total_loss = 0
    progress_bar = tqdm(dataloader)
    for inputs, targets in progress_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs, _ = model(inputs)

        # CrossEntropyLoss expects (N, classes) logits and (N,) class ids,
        # so merge the leading dimensions of both tensors before scoring.
        flat_logits = outputs.view(-1, model.fc.out_features)
        flat_targets = targets.view(-1)
        loss = criterion(flat_logits, flat_targets)

        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        total_loss += current_loss
        progress_bar.set_postfix(loss=current_loss)

    # Persist the weights reached at the end of this epoch.
    checkpoint_path = f"model_{epoch+1}.pt"
    torch.save(model.state_dict(), checkpoint_path)

    mean_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
