In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from nltk import word_tokenize
from collections import Counter
import re

In [2]:
class TextDataset(Dataset):
    def __init__(self, text, sequence_length, min_freq=0):
        self.tokens = word_tokenize(text)
        token_counts = Counter(self.tokens)
        
        # Build vocabulary and word-to-index and index-to-word mappings
        self.vocab = ['<unk>'] + sorted([token for token, count in token_counts.items() if count >= min_freq], key=token_counts.get, reverse=True)
        self.token2idx = {w: idx for idx, w in enumerate(self.vocab)}
        self.idx2token = {idx: w for idx, w in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)
        
        # Convert tokens sequence to indices sequence
        self.indices = [self.token2idx.get(token, self.unk) for token in self.tokens]
        
        # Prepare sequences
        self.sequence_length = sequence_length
        self.data = []
        for i in range(len(self.indices) - sequence_length):
            input_seq = self.indices[i:i+sequence_length]
            target_seq = self.indices[i+1:i+sequence_length+1]
            self.data.append((input_seq, target_seq))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return torch.tensor(self.data[index][0]), torch.tensor(self.data[index][1])

    @property
    def unk(self):
        return self.token2idx['<unk>']

In [3]:
sequence_length = 5
batch_size = 32
text = open('time_machine.txt', 'r').read()
text = re.sub("[^A-Za-z]+", " ", text).lower()

dataset = TextDataset(text, sequence_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [4]:
class NextWordPredictor(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256):
        super(NextWordPredictor, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x, hidden):
        # Convert word indices to embeddings
        x = self.embedding(x)
        # Pass through LSTM
        out, hidden = self.lstm(x, hidden)
        # Pass through fully connected layer
        out = self.fc(out)
        return out, hidden
    
    def init_hidden(self, batch_size):
        # Initialize hidden state and cell state with zeros for LSTM
        h0 = torch.zeros(1, batch_size, hidden_dim)  # Hidden state
        c0 = torch.zeros(1, batch_size, hidden_dim)  # Cell state
        return (h0, c0)  # Return as tuple

In [5]:
# Hyperparameters
embedding_dim = 128
hidden_dim = 256
num_epochs = 10
learning_rate = 0.001

# Initialize the model, loss function, and optimizer
model = NextWordPredictor(vocab_size=dataset.vocab_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    for i, (input_seq, target_seq) in enumerate(dataloader):
        hidden = model.init_hidden(input_seq.size(0))  # Initialize hidden state for each epoch
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output, hidden = model(input_seq, hidden)
        
        # Reshape output and target for loss calculation
        output = output.view(-1, dataset.vocab_size)
        target_seq = target_seq.view(-1)
        
        # Calculate loss and backpropagate
        loss = criterion(output, target_seq)
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0 or i == len(dataloader):
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

Epoch [1/10], Step [1/1029], Loss: 8.4325
Epoch [1/10], Step [101/1029], Loss: 6.5295
Epoch [1/10], Step [201/1029], Loss: 6.0164
Epoch [1/10], Step [301/1029], Loss: 5.7161
Epoch [1/10], Step [401/1029], Loss: 5.6383
Epoch [1/10], Step [501/1029], Loss: 5.2015
Epoch [1/10], Step [601/1029], Loss: 5.3731
Epoch [1/10], Step [701/1029], Loss: 5.1442
Epoch [1/10], Step [801/1029], Loss: 5.2658
Epoch [1/10], Step [901/1029], Loss: 4.9388
Epoch [1/10], Step [1001/1029], Loss: 4.2433
Epoch [2/10], Step [1/1029], Loss: 4.4742
Epoch [2/10], Step [101/1029], Loss: 4.2591
Epoch [2/10], Step [201/1029], Loss: 4.1947
Epoch [2/10], Step [301/1029], Loss: 3.9466
Epoch [2/10], Step [401/1029], Loss: 3.9278
Epoch [2/10], Step [501/1029], Loss: 3.9600
Epoch [2/10], Step [601/1029], Loss: 3.8016
Epoch [2/10], Step [701/1029], Loss: 3.8289
Epoch [2/10], Step [801/1029], Loss: 3.6471
Epoch [2/10], Step [901/1029], Loss: 3.5144
Epoch [2/10], Step [1001/1029], Loss: 3.3520
Epoch [3/10], Step [1/1029], Loss:

In [10]:
def predict_next_word(model, start_text, num_words=5):
    model.eval()
    
    # Prepare input sequence
    words = start_text.lower().split()
    input_seq = [dataset.token2idx[word] for word in words]
    input_seq = torch.tensor(input_seq).unsqueeze(0)
    
    # Initialize hidden state
    hidden = model.init_hidden(1)
    
    # Predict words
    predicted_words = words.copy()
    for _ in range(num_words):
        output, hidden = model(input_seq, hidden)
        
        # Get the last time-step's output
        last_word_logits = output[:, -1, :]
        _, predicted_idx = torch.max(last_word_logits, dim=1)
        
        # Convert predicted index to word and add to predicted words
        predicted_word = dataset.idx2token[predicted_idx.item()]
        predicted_words.append(predicted_word)
        
        # Prepare the next input sequence with the predicted word
        input_seq = torch.cat([input_seq[:, 1:], predicted_idx.unsqueeze(0)], dim=1)
    
    return ' '.join(predicted_words)

In [11]:
DEFAULT_LENGTH = 15
context = input("Input initial context (only words separated by space): ")
length = input(f"Enter length of generated text (default: {DEFAULT_LENGTH}): ")
if len(length) == 0:
    length = DEFAULT_LENGTH

print("Output:\n")
print(predict_next_word(model, context, num_words=DEFAULT_LENGTH))

Input initial context (only words separated by space):  this is
Enter length of generated text (default: 15):  20


Output:

this is so extensively overlooked continued the little table then he turned lighting his pipe puffing to
