In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from collections import Counter
import random

In [2]:
# Read the whole training corpus into memory.
# The encoding is pinned explicitly: without it, open() decodes with the
# platform's locale default, so the same file can load differently (or fail)
# on another machine.
with open('input_text.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
# Preprocessing text
def preprocess(text):
    """Tokenize ``text`` on whitespace and map every word to an integer id.

    The text is lower-cased before splitting. Ids are handed out in order of
    descending word frequency, starting at 1, so id 0 is never produced.

    Args:
        text: Raw corpus as a single string.

    Returns:
        A tuple ``(encoded, vocab_to_int, int_to_vocab)`` where ``encoded`` is
        the list of ids for the words of ``text`` in order, ``vocab_to_int``
        maps word -> id and ``int_to_vocab`` maps id -> word.
    """
    tokens = text.lower().split()
    frequencies = Counter(tokens)

    # Most frequent word first; sorted() is stable, so equally frequent words
    # keep the order in which the Counter first saw them.
    ranked_words = sorted(frequencies, key=lambda w: frequencies[w], reverse=True)

    vocab_to_int = {}
    int_to_vocab = {}
    for token_id, word in enumerate(ranked_words, start=1):
        vocab_to_int[word] = token_id
        int_to_vocab[token_id] = word

    encoded = [vocab_to_int[token] for token in tokens]
    return encoded, vocab_to_int, int_to_vocab

In [4]:
# Encode the corpus and keep both lookup tables for training and decoding.
encoded_text, vocab_to_int, int_to_vocab = preprocess(text)

# Word ids start at 1, so one extra slot is needed for the largest id to be a
# valid index.
vocab_size = 1 + len(vocab_to_int)

In [5]:
def create_batches(encoded, sequence_length, batch_size):
    """Arrange an encoded token stream into parallel input/target rows.

    The stream is truncated to the largest multiple of
    ``batch_size * sequence_length`` and reshaped into ``batch_size``
    contiguous rows. Targets are the inputs shifted one token to the left, so
    ``targets[r, c]`` is the token that follows ``inputs[r, c]`` in the text.

    Args:
        encoded: List or NumPy array of integers. This is the tokenized text.
        sequence_length: Integer. The length of each input sequence.
        batch_size: Integer. The number of sequences that will be processed
            together in parallel (in a batch).

    Returns:
        ``(inputs, targets)``: two NumPy arrays of shape
        ``(batch_size, n_batches * sequence_length)``.

    Raises:
        ValueError: If ``sequence_length`` or ``batch_size`` is not positive.

    Example:
        encoded = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        sequence_length = 2
        batch_size = 2
        inputs  = [[1, 2, 3, 4], [5, 6, 7, 8]]
        targets = [[2, 3, 4, 5], [6, 7, 8, 9]]
    """
    if sequence_length <= 0 or batch_size <= 0:
        raise ValueError("sequence_length and batch_size must be positive integers")

    total_length = len(encoded)
    n_batches = total_length // (batch_size * sequence_length)
    usable = n_batches * batch_size * sequence_length

    input_data = np.array(encoded[:usable])
    if total_length > usable:
        # The token right after the truncated range is the true target of the
        # last input, so use it instead of wrapping around.
        target_data = np.array(encoded[1:usable + 1])
    else:
        # Nothing follows the last input; wrap to the first token so the
        # targets keep the same length as the inputs.
        target_data = np.roll(input_data, -1)

    inputs = input_data.reshape((batch_size, -1))
    targets = target_data.reshape((batch_size, -1))
    return inputs, targets


In [6]:
# Batching configuration
sequence_length = 5  # tokens per training window
batch_size = 4       # rows fed to the model in parallel

inputs, targets = create_batches(
    encoded_text,
    sequence_length=sequence_length,
    batch_size=batch_size,
)


In [7]:
class LSTMModel(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear decoder.

    Args:
        vocab_size: Number of embedding rows and of output classes.
        embed_size: Dimensionality of the word embeddings.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_size)
        # nn.LSTM applies dropout only between stacked layers; with a single
        # layer a non-zero value has no effect and only triggers a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Run a batch of token ids through the network.

        Args:
            x: Long tensor of token ids, shape ``(batch, seq_len)``.
            hidden: Optional ``(h, c)`` state tuple as returned by
                ``init_hidden`` or a previous call. ``None`` lets the LSTM
                start from an all-zero state.

        Returns:
            ``(out, hidden)`` where ``out`` holds the logits flattened to
            ``(batch * seq_len, vocab_size)`` and ``hidden`` is the updated
            ``(h, c)`` tuple.
        """
        embedded = self.embedding(x)
        out, hidden = self.lstm(embedded, hidden)
        # Collapse the batch and time axes so every position is one row of
        # logits, which is the layout the loss expects next to flat targets.
        out = self.fc(out.reshape(-1, out.size(2)))
        return out, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero ``(h, c)`` state for ``batch_size`` sequences.

        Each tensor has shape ``(num_layers, batch_size, hidden_size)`` and is
        created with the dtype and device of the model's parameters.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        return weight.new_zeros(shape), weight.new_zeros(shape)


In [8]:
# Model and optimisation settings
epochs = 20            # full passes over the batched corpus
learning_rate = 1e-3   # step size handed to the optimizer
num_layers = 2         # stacked LSTM layers
hidden_size = 512      # LSTM hidden-state width
embed_size = 256       # word-embedding width
vocab_size = 1 + len(vocab_to_int)  # word ids start at 1, so index 0 needs a slot too


In [9]:
# Build the network, the loss and the optimizer that updates the network's weights.
model = LSTMModel(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [10]:
# Training loop: walk each row of `inputs` in consecutive windows of
# `sequence_length` tokens, carrying the LSTM state from one window to the next.
model.train()
log_interval = sequence_length * 10
for epoch in range(epochs):
    # Every epoch starts from an all-zero recurrent state.
    hidden = model.init_hidden(batch_size)

    for i in range(0, inputs.shape[1], sequence_length):
        window = slice(i, i + sequence_length)
        x_batch = torch.tensor(inputs[:, window], dtype=torch.long)
        y_batch = torch.tensor(targets[:, window], dtype=torch.long)

        # Keep the state values but cut the graph, so backprop stops at the
        # start of this window instead of reaching back through all earlier ones.
        hidden = tuple(state.detach() for state in hidden)

        optimizer.zero_grad()

        # Forward pass: logits come back flattened, so flatten the targets too.
        logits, hidden = model(x_batch, hidden)
        loss = criterion(logits, y_batch.reshape(-1))

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i}/{inputs.shape[1]}], Loss: {loss.item():.4f}')


Epoch [1/20], Step [0/20], Loss: 3.6971
Epoch [2/20], Step [0/20], Loss: 3.3524
Epoch [3/20], Step [0/20], Loss: 3.1618
Epoch [4/20], Step [0/20], Loss: 3.0255
Epoch [5/20], Step [0/20], Loss: 2.7335
Epoch [6/20], Step [0/20], Loss: 2.3177
Epoch [7/20], Step [0/20], Loss: 1.8860
Epoch [8/20], Step [0/20], Loss: 1.4625
Epoch [9/20], Step [0/20], Loss: 1.1086
Epoch [10/20], Step [0/20], Loss: 0.7997
Epoch [11/20], Step [0/20], Loss: 0.5802
Epoch [12/20], Step [0/20], Loss: 0.4291
Epoch [13/20], Step [0/20], Loss: 0.3400
Epoch [14/20], Step [0/20], Loss: 0.2538
Epoch [15/20], Step [0/20], Loss: 0.2298
Epoch [16/20], Step [0/20], Loss: 0.1829
Epoch [17/20], Step [0/20], Loss: 0.1615
Epoch [18/20], Step [0/20], Loss: 0.1344
Epoch [19/20], Step [0/20], Loss: 0.1199
Epoch [20/20], Step [0/20], Loss: 0.1149


In [11]:
def predict(model, word, vocab_to_int, int_to_vocab, hidden=None, top_k=5):
    """Sample the word that follows ``word`` from the model's top-k predictions.

    Args:
        model: Object callable as ``model(inputs, hidden)`` that returns
            ``(logits, hidden)`` with one row of logits for a single-token
            input, and that provides ``init_hidden(batch_size)``.
        word: Current word; must be a key of ``vocab_to_int``.
        vocab_to_int: Mapping word -> token id.
        int_to_vocab: Mapping token id -> word.
        hidden: Recurrent state returned by a previous call, or ``None`` to
            start from a fresh state obtained from ``model.init_hidden(1)``.
        top_k: Number of most likely tokens to sample from. Values larger
            than the number of output classes are clamped.

    Returns:
        ``(next_word, hidden)``: the sampled word and the updated state.

    Raises:
        KeyError: If ``word`` is not in ``vocab_to_int``.
        ValueError: If ``top_k`` is smaller than 1, or none of the top-k token
            ids can be decoded with ``int_to_vocab``.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Single-token batch of shape (1, 1).
    inputs = torch.tensor([[vocab_to_int[word]]], dtype=torch.long)

    if hidden is None:
        hidden = model.init_hidden(1)
    else:
        # Drop any graph attached to the incoming state.
        hidden = tuple(state.detach() for state in hidden)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        out, hidden = model(inputs, hidden)
        probs = torch.nn.functional.softmax(out, dim=1)
        k = min(top_k, probs.size(1))
        top_p, top_tokens = probs.topk(k)

    # Flatten with reshape(-1) rather than squeeze() so k == 1 still yields a
    # one-element sequence instead of a 0-d value.
    top_tokens = top_tokens.reshape(-1).tolist()
    top_p = top_p.reshape(-1).tolist()

    # The network can emit ids that have no word attached (the vocabulary
    # passed in decides); those cannot be decoded, so drop them here.
    candidates = [(tok, p) for tok, p in zip(top_tokens, top_p) if tok in int_to_vocab]
    if not candidates:
        raise ValueError("none of the top-k predicted tokens is present in int_to_vocab")

    tokens, weights = zip(*candidates)
    weights = np.asarray(weights, dtype=np.float64)
    total = weights.sum()
    if total > 0:
        probabilities = weights / total
    else:
        # Every surviving probability underflowed to zero: draw uniformly.
        probabilities = np.full(len(tokens), 1.0 / len(tokens))

    # Randomly select the next word among the candidates.
    next_word_token = int(np.random.choice(tokens, p=probabilities))

    return int_to_vocab[next_word_token], hidden


In [27]:
def sample(model, size, prime='the', top_k=5):
    """Generate text by repeatedly sampling the next word from the model.

    The prime words are fed through the model one by one to warm up the
    recurrent state; the prediction made after the last prime word becomes the
    first generated word, and ``size`` further words are sampled after it. The
    returned text therefore contains the prime words plus ``size + 1``
    generated words.

    Uses the notebook-level ``vocab_to_int`` / ``int_to_vocab`` tables and the
    ``predict`` function. The model is switched to evaluation mode and left
    there.

    Args:
        model: Trained model providing ``init_hidden`` and accepted by
            ``predict``.
        size: Number of words to generate after the first predicted word.
        prime: Seed text; lower-cased and split on whitespace.
        top_k: Passed through to ``predict``.

    Returns:
        The generated text as a single space-joined string. If any prime word
        is missing from the vocabulary, a warning is printed and the
        lower-cased prime words are returned unchanged.

    Raises:
        ValueError: If ``prime`` contains no words.
    """
    model.eval()  # Set model to evaluation mode

    # Start with the prime words
    words = prime.lower().split()
    if not words:
        # Without at least one word there is nothing to feed the model.
        raise ValueError("prime must contain at least one word")

    # Check the whole prime before touching the model so an unknown word is
    # reported without doing any throwaway forward passes.
    for word in words:
        if word not in vocab_to_int:
            print(f"Warning: Word '{word}' not found in vocabulary.")
            return ' '.join(words)

    # Initialize the hidden state
    hidden = model.init_hidden(1)

    # Generation never backpropagates, so skip autograd bookkeeping.
    with torch.no_grad():
        # Pass through the prime words to prime the model; only the prediction
        # made after the last prime word is kept.
        for word in words:
            predicted_word, hidden = \
                predict(model, word, vocab_to_int, int_to_vocab, hidden, top_k)

        # Add the first predicted word
        words.append(predicted_word)

        # Generate `size` number of additional words
        for _ in range(size):
            next_word, hidden = predict(model, words[-1], vocab_to_int, int_to_vocab, hidden, top_k)
            words.append(next_word)

    # Join words into a single string and return the generated text
    return ' '.join(words)

In [30]:
# Generate a passage from the trained model, seeded with a two-word prompt.
generated_text = sample(model, size=200, prime='The bird', top_k=3)
print(generated_text)

the bird sang in the tree. the bird liked to play near the tree. the tree was and the bird the bird the bird the over the bird liked to play near the tree. the tree was tall and the bird the bird the bird the over the bird liked to play near the tree. the tree was tall and the bird the bird the bird the bird the trees. the bird the bird the liked to play near the the tree. the tree was tall and the bird the bird the bird the trees. the bird the bird the liked to play near the tree. the tree was tall and the bird the bird the bird the bird the bird liked to play near the tree. the tree was tall and the bird the bird the bird liked to play near the tree. the tree was tall and the bird the bird the bird the bird the over the bird liked to near the tree. the tree was tall and the bird the bird the bird the over the bird liked to play near the tree. the tree was tall and the bird the bird the bird the bird the trees. the
