# Text Gen with LSTM

In [2]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn

In [3]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Load the corpus as UTF-8 text and lower-case it, so upper- and lower-case
# letters share one vocabulary entry.
# The context manager closes the file handle as soon as the text has been read.
filename = "data/wonderland.txt"
with open(filename, 'r', encoding='utf-8') as corpus_file:
    raw_text = corpus_file.read()
raw_text = raw_text.lower()

In [5]:
# Summarize the dataset.
# The vocabulary is the set of distinct characters in the text, in sorted order;
# two lookup tables translate character -> integer id and integer id -> character.
chars = sorted(set(raw_text))
chars_to_int = {ch: idx for idx, ch in enumerate(chars)}
int_to_chars = dict(enumerate(chars))
n_vocab = len(chars)
n_chars = len(raw_text)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

Total Characters:  144512
Total Vocab:  45


In [6]:
# Build the training pairs with a sliding window of seq_length characters:
# each input is a window of encoded characters, each target is the encoded
# character that immediately follows that window.
seq_length = 50
window_starts = range(n_chars - seq_length)
dataX = [
    [chars_to_int[ch] for ch in raw_text[start:start + seq_length]]
    for start in window_starts
]
dataY = [chars_to_int[raw_text[start + seq_length]] for start in window_starts]

n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  144462


In [7]:
# Arrange the encoded windows as [samples, time steps, features] for the LSTM
# and divide the integer ids by the vocabulary size to normalize them.
X = np.asarray(dataX).reshape(n_patterns, seq_length, 1) / float(n_vocab)
y = dataY

In [8]:
class TextModel(nn.Module):
    """Next-character classifier: a single LSTM layer followed by an MLP head.

    Args:
        vocab_size: Number of output classes. When ``None`` (the default) the
            notebook-level ``n_vocab`` is used, so ``TextModel()`` keeps working
            exactly as before.
        hidden_size: Width of the LSTM hidden state and of the first linear layer.
        dropout: Drop probability of the dropout layer inside the head.

    The sub-module names and their order are unchanged, so checkpoints written
    by the default configuration still load.
    """

    def __init__(self, vocab_size=None, hidden_size=256, dropout=0.2):
        super().__init__()

        if vocab_size is None:
            # Backward-compatible fallback to the vocabulary size computed earlier in the notebook.
            vocab_size = n_vocab

        # One scalar feature per time step; batch is the leading dimension.
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, vocab_size),
        )

    def forward(self, x):
        """Return class scores computed from the LSTM's final hidden state."""
        # h is the final hidden state, c the final cell state; only h is used.
        _, (h, _) = self.lstm(x)
        # h carries a leading layer dimension; index the last (here: only) layer
        # instead of squeezing so the head always receives one state per sample.
        return self.classifier(h[-1])

In [9]:
# Random dummy batch of shape (2, 50, 1) moved to the selected device.
# NOTE(review): no later cell visible here reads `x` — presumably a leftover shape check.
x = torch.randn(2, 50, 1).to(device)
# Instantiate the network and move it to the same device.
net = TextModel().to(device)


In [10]:
from tqdm import tqdm
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

In [11]:
class TextDataset(Dataset):
    """Stores and serves training samples for a sequence task such as
    language modeling or text generation.

    Each item pairs one input window (``data[index]``) with the entry that
    follows it (``next_chars[index]``).
    """

    def __init__(self, data, next_chars):
        super().__init__()
        self.data, self.next_chars = data, next_chars

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # The window becomes a float32 tensor; the target is handed back as stored.
        window = torch.tensor(self.data[index], dtype=torch.float32)
        target = self.next_chars[index]
        return window, target

In [12]:
# Wrap the prepared arrays in a dataset and serve shuffled mini-batches of 32,
# loading in the main process (num_workers=0).
text_dataset = TextDataset(data=X, next_chars=y)
text_loader = DataLoader(text_dataset, batch_size=32, shuffle=True, num_workers=0)


In [13]:
# Cross-entropy on the class scores, optimized with Adam at a learning rate of 1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm


# Train for a fixed number of epochs. After every epoch the mean training loss
# is compared with the best value seen so far, and the weights are saved
# whenever the new loss matches or beats it.
num_epochs = 500
best_loss = float("inf")
checkpoint_path = "weights/best_char_gen.pth"

# Create the checkpoint directory up front: without it the first torch.save
# would fail only after a complete epoch of training had already been spent.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    # Keep dropout active even if the network was switched to eval mode earlier.
    net.train()

    train_tqdm = tqdm(enumerate(text_loader), total=len(text_loader))
    total_loss = 0
    total_correct = 0
    total_samples = 0
    for i, data in train_tqdm:
        # Separate inputs and targets, then move both to the training device.
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)

        # Backward pass: clear old gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running statistics for the progress bar.
        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        train_tqdm.set_description(f"Epoch {epoch}: Total loss: {total_loss/(i + 1)}, Accuracy: {total_correct / total_samples:.4f}")

    # Mean loss per batch over the whole epoch.
    train_loss = total_loss / len(text_loader)

    if train_loss <= best_loss:
        print(f"Save best model with loss = {train_loss}")
        best_loss = train_loss
        torch.save(net.state_dict(), checkpoint_path)

In [78]:
# Rebuild the network and restore the weights that the training loop saved to
# 'weights/best_char_gen.pth'.
model = TextModel()
# map_location="cpu": this model instance stays on the CPU, and a checkpoint
# written during a CUDA run must be remapped to load on a machine without a GPU.
# weights_only=True: a state_dict contains only tensors, so the restricted
# unpickler is sufficient and avoids executing arbitrary pickled code.
state_dict = torch.load('weights/best_char_gen.pth', map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Set model to evaluation mode

def predict_next_char(input_str, model, chars_to_int, int_to_chars, n_vocab):
    """Return the character the model scores highest as the continuation of ``input_str``.

    Args:
        input_str: Non-empty prompt; every character must be a key of ``chars_to_int``.
        model: Callable that maps a float32 tensor of shape
            ``(1, len(input_str), 1)`` to class scores.
        chars_to_int: Mapping from character to integer id.
        int_to_chars: Mapping from integer id to character.
        n_vocab: Divisor applied to the integer ids before they reach the model.

    Returns:
        The character whose id has the highest score.

    Raises:
        ValueError: If ``input_str`` is empty.
        KeyError: If ``input_str`` contains a character missing from ``chars_to_int``.
    """
    if not input_str:
        raise ValueError("input_str must contain at least one character")

    # Report every unknown character at once instead of failing on the first lookup.
    unknown = sorted({ch for ch in input_str if ch not in chars_to_int})
    if unknown:
        raise KeyError(f"characters not in vocabulary: {unknown!r}")

    # Encode the prompt, add batch and feature dimensions, and normalize the ids.
    input_ints = [chars_to_int[ch] for ch in input_str]
    input_tensor = torch.tensor(input_ints, dtype=torch.float32).view(1, -1, 1)
    input_tensor = input_tensor / float(n_vocab)

    # Feed the input on the device that holds the model's parameters (if it has any),
    # so a model living on the GPU does not fail with a device mismatch.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(input_tensor)

    # The arg-max over the class dimension is the predicted id.
    predicted_index = torch.argmax(output, dim=-1).item()
    return int_to_chars[predicted_index]

# Example: ask the restored model for the character that follows a short prompt.
input_str = 'very soon she felt on her eyes tere '
predicted_char = predict_next_char(
    input_str=input_str,
    model=model,
    chars_to_int=chars_to_int,
    int_to_chars=int_to_chars,
    n_vocab=n_vocab,
)
print(f"The next predicted character after '{input_str}' is: '{predicted_char}'")

The next predicted character after 'very soon she felt on her eyes tere ' is: 'w'


  model.load_state_dict(torch.load('weights/best_char_gen.pth'))
