In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Load the dataset. An explicit encoding makes decoding independent of the
# platform's locale default (e.g. cp1252 on Windows).
with open("Sonnets.txt", "r", encoding="utf-8") as file:
    sonnets_text = file.read()

# Tokenize the sonnets at the character level: every distinct character in
# the corpus gets an integer id, assigned in sorted order so the same text
# always produces the same mapping.
chars = sorted(set(sonnets_text))
char_to_idx = {char: i for i, char in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
char_sequence = [char_to_idx[char] for char in sonnets_text]

# Define a character-based dataset
class CharDataset(Dataset):
    """Sliding-window dataset for next-character prediction.

    Item ``i`` is a pair ``(inputs, targets)`` of 1-D ``torch.long`` tensors of
    length ``sequence_length``: ``inputs = sequence[i:i+L]`` and ``targets`` is
    the same window shifted one position right, so ``targets[t]`` is the token
    that follows ``inputs[t]``. Windows overlap (stride 1).

    Args:
        sequence: Integer token ids for the whole corpus (a sliceable sequence).
        sequence_length: Window length ``L``.
    """

    def __init__(self, sequence, sequence_length):
        self.sequence = sequence
        self.sequence_length = sequence_length

    def __len__(self):
        # Each window needs L input tokens plus one extra target token. Clamp
        # at 0 so a corpus shorter than the window yields an empty dataset
        # instead of a negative length (which len() rejects with ValueError).
        return max(0, len(self.sequence) - self.sequence_length)

    def __getitem__(self, index):
        end = index + self.sequence_length
        inputs = torch.tensor(self.sequence[index:end], dtype=torch.long)
        targets = torch.tensor(self.sequence[index + 1:end + 1], dtype=torch.long)
        return inputs, targets

# Slice the corpus into overlapping 100-character windows and serve them
# in shuffled mini-batches of 32.
sequence_length = 100
char_dataset = CharDataset(sequence=char_sequence, sequence_length=sequence_length)
char_dataloader = DataLoader(dataset=char_dataset, batch_size=32, shuffle=True)

# Define the LSTM model
class CharLSTM(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear.

    Args:
        vocab_size: Number of distinct input token ids.
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        output_dim: Number of output logits per position.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability, applied between stacked LSTM layers
            (only when ``n_layers > 1``) and to the LSTM outputs before the
            final linear projection.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM applies its dropout only *between* layers and warns when it
        # is non-zero for a single layer; pass 0 in that case (same behavior,
        # no warning).
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            n_layers,
            dropout=dropout if n_layers > 1 else 0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden=None):
        """Return per-position logits and the final LSTM state.

        Args:
            x: Token ids, expected as ``(batch, seq_len)`` (the LSTM is
                ``batch_first``).
            hidden: Optional ``(h_0, c_0)`` tuple, each
                ``(n_layers, batch, hidden_dim)``. ``None`` (the default) lets
                ``nn.LSTM`` start from zeros of the right size.

        Returns:
            ``(logits, (h_n, c_n))`` with ``logits`` shaped
            ``(batch, seq_len, output_dim)``.
        """
        x = self.embedding(x)
        x, hidden = self.lstm(x, hidden)
        x = self.fc(self.dropout(x))
        return x, hidden

# Hyperparameters
char_vocab_size = len(chars)
output_dim = char_vocab_size  # one logit per character in the vocabulary
embedding_dim, hidden_dim = 64, 128
n_layers, dropout = 2, 0.5

# Model, loss and optimizer (constructed in the same order as before so the
# random weight initialization is unchanged).
char_model = CharLSTM(
    vocab_size=char_vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    n_layers=n_layers,
    dropout=dropout,
)
char_criterion = nn.CrossEntropyLoss()
char_optimizer = torch.optim.Adam(char_model.parameters(), lr=1e-3)

# Train the model
n_epochs = 15
for epoch in range(n_epochs):
    # Re-enable dropout every epoch: generate_text() calls model.eval(), so
    # re-running this cell after generating would otherwise silently train
    # with dropout disabled.
    char_model.train()
    total_loss = 0.0

    for inputs, targets in char_dataloader:
        char_optimizer.zero_grad()
        # Batches are independent shuffled windows, so each one starts from a
        # zero LSTM state. Passing None lets nn.LSTM build zeros of the right
        # batch size (including a short final batch).
        outputs, _ = char_model(inputs, None)
        # Flatten (batch, seq_len, vocab) logits and (batch, seq_len) targets
        # into the (N, C) / (N,) layout CrossEntropyLoss expects.
        loss = char_criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss.backward()
        char_optimizer.step()
        total_loss += loss.item()

    # Reported value is the mean of the per-batch mean losses.
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss/len(char_dataloader)}")

# Save the trained model
torch.save(char_model.state_dict(), 'char_lstm_model2.pth')

# To generate text after training:
def generate_text(model, start_string, generate_length=100, temperature=None):
    """Generate text by repeatedly feeding the model its own predictions.

    The whole ``start_string`` is first run through the model to build up the
    LSTM state. Each later step feeds back only the character just produced.

    Args:
        model: Module whose ``forward(x, hidden)`` returns
            ``(logits, hidden)`` with batch-first logits, e.g. ``CharLSTM``.
        start_string: Non-empty prompt; every character must be a key of
            ``char_to_idx`` (otherwise ``KeyError``).
        generate_length: Number of characters to append.
        temperature: ``None`` (default) picks the most likely character at
            every step (greedy). A positive float instead samples from
            ``softmax(logits / temperature)``, which helps break the
            repetitive loops greedy decoding tends to fall into.

    Returns:
        ``start_string`` followed by the generated characters.

    Raises:
        ValueError: If ``start_string`` is empty or ``temperature`` is not
            positive.
    """
    if not start_string:
        raise ValueError("start_string must contain at least one character")
    if temperature is not None and temperature <= 0:
        raise ValueError("temperature must be positive")

    was_training = model.training
    model.eval()
    text_generated = []
    input_eval = torch.tensor([[char_to_idx[s] for s in start_string]])
    # None -> nn.LSTM starts from zeros sized to the model itself, so this does
    # not rely on global n_layers / hidden_dim matching the model.
    hidden = None

    try:
        with torch.no_grad():  # inference only: no autograd graph needed
            for _ in range(generate_length):
                outputs, hidden = model(input_eval, hidden)
                # outputs is (1, seq_len, vocab); the next character is
                # predicted from the LAST position of what was just fed in.
                logits = outputs[0, -1]
                if temperature is None:
                    predicted_id = torch.argmax(logits).item()
                else:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    predicted_id = torch.multinomial(probs, num_samples=1).item()
                input_eval = torch.tensor([[predicted_id]])
                text_generated.append(idx_to_char[predicted_id])
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return start_string + ''.join(text_generated)

# Generate a 400-character continuation of the opening line and show it.
opening_line = "Shall I compare thee to a summer's day?"
new_sonnet = generate_text(char_model, opening_line, generate_length=400)
print(new_sonnet)


Epoch 1/15, Loss: 1.9069862222444212
Epoch 2/15, Loss: 1.5635372025243353
Epoch 3/15, Loss: 1.4643457078349238
Epoch 4/15, Loss: 1.409766277367459
Epoch 5/15, Loss: 1.3736680722131103
Epoch 6/15, Loss: 1.3480362291030832
Epoch 7/15, Loss: 1.3283089159863881
Epoch 8/15, Loss: 1.3126506206735202
Epoch 9/15, Loss: 1.2991797537636425
Epoch 10/15, Loss: 1.28810113374173
Epoch 11/15, Loss: 1.2784171856541515
Epoch 12/15, Loss: 1.2704241371982796
Epoch 13/15, Loss: 1.2630448445276943
Epoch 14/15, Loss: 1.2565489055187393
Epoch 15/15, Loss: 1.250594989962021
Shall I compare thee to a summer's day?o
Where thou art this thou art thou art thou art still,
And to the stand the stars to these stands that thee still,
And therefore to thee, thou art thou art thou shalt store,
And then thou these stars to these stand to thee,
And thou art thou art those this to the stars
And that thou that thou art thine eyes of truth,
And then thou art the stars to these standed stay
The stars to these stands that
