Start with the imports  

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import string
import random


Then we define our TextRNN class, which inherits from the Torch class 'nn.Module'

In [None]:
class TextRNN(nn.Module):

    def __init__(self, vocab_size, hidden_size=128, num_layers=2):
        super(TextRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, hidden_size)

        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)

        if hidden is None:
            output, hidden = self.lstm(embedded)
        else:
            output, hidden = self.lstm(embedded, hidden)

        output = output.reshape(-1, self.hidden_size)
        output = self.fc(output)        

        return output, hidden 
    
    def init_hidden(self, batch_size, device):
        h0 = torch.zeros(self.num_layers, batch_size, sel.fhidden_size).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        return (h0, c0)
    


Next, we define a class "TextDataset" to load in the data, and handle randomly sampling the input. 

In [None]:
class TextDataset:
    
    #seq_length is a hyperparameter that controls the memory context length, and
    #thus also affects efficiency. Longer sequences can result in more context and longer
    #patterns being learned, but requires more memory and compute time. 
    def __init__(self, text_file, seq_length=100):
        with open(text_file, 'r', encoding='utf-8') as f:
            self.text = f.read()

        self.chars = sorted(list(set(self.text)))
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

        self.seq_length = seq_length

        self.data = [self.char_to_idx[ch] for ch in self.text]

    # Here we create randomly sampled batches of sequences, and put them
    # into tensors. In the "target" tensor, the sequence is shifted by one character
    # from the corresponding input tensor sequence, so we can verify the prediction 
    # for the next character in the sequence.
    #     
    def get_batch(self, batch_size):
        start_indices = [random.randint(0, len(self.data) - self.seq_length - 1) 
                         for _ in range(batch_size)]
        
        inputs = []
        targets = []

        for start_idx in start_indices:
            input_seq = self.data[start_idx:(start_idx + self.seq_length)]
            target_seq = self.data[(start_idx + 1): (start_idx + self.seq_length + 1)]
            inputs.append(input_seq)
            targets.append(target_seq)

        return torch.tensor(inputs), torch.tensor(targets)


                                                        

Now we'll add the training function, and the generation function

In [11]:
def train_model(text_file, epochs=100, batch_size=32, learning_rate=0.001, 
                hidden_size=128, num_layers=2, seq_length=100):
    
    #device initialization - could add support if on Apple Silicon
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Device initalialized: {device}')

    dataset = TextDataset(text_file, seq_length)
    print(f'Vocab size: {dataset.vocab_size}')
    print(f'Text length: {len(dataset.text)}')

    model = TextRNN(dataset.vocab_size, hidden_size, num_layers).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print('Starting training...')

    for epoch in range(epochs):
        model.train()

        total_loss = 0.0

        num_batches = 20

        for batch in range(num_batches):
            inputs, targets = dataset.get_batch(batch_size)
            inputs, targets = inputs.to(device), targets.to(device)

            # clearing the gradients for each batch - if they accumulated, we'd end up
            # with huge update gradients, instability, and massive loss. 
            optimizer.zero_grad()

            # this function runs the model on the inputs, returning the predictions
            outputs, _ = model(inputs)

            targets = targets.reshape(-1)

            # calculating the loss (uses nn.CrossEntropyLoss)
            loss = criterion(outputs, targets)

            loss.backward()  # backpropagation to calculate gradients
            optimizer.step() # update the model parameters
            total_loss += loss.item()

        avg_loss = total_loss / num_batches

        #print the avg loss every 10 epochs
        if (epoch +1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}')
    
    print('Training complete.')

    return model, dataset

def generate_text(model, dataset, device, seed_text='', length=100, temperature=1.0):
    model.eval()

    if not seed_text:
        seed_text = random.choice(dataset.text)
    
    current_seq = [dataset.char_to_idx.get(ch, 0) for ch in seed_text]
    generated = seed_text

    with torch.no_grad():
        hidden = None

        for _ in range(length):

            if len(current_seq) > dataset.seq_length:
                current_seq = current_seq[-dataset.seq_length]
                
            input_tensor = torch.tensor([current_seq]).to(device)
            
            output, hidden = model(input_tensor, hidden)

            last_output = output[-1] / temperature
            probabilities = torch.softmax(last_output, dim=0)

            next_char_idx = torch.multinomial(probabilities, 1).item()
            next_char = dataset.idx_to_char[next_char_idx]

            generated += next_char

    return generated

        




Generate a sample file if we need it for testing

In [13]:

sample_training_text = """ Far out in the uncharted backwaters of the unfashionable end of the western spiral arm of the Galaxy lies a small unregarded yellow sun. Orbiting this at a distance of roughly ninety-two million miles 
                        is an utterly insignificant little blue-green planet whose ape-descended life forms are so amazingly primitive that they still think digital watches are a pretty neat idea.
                        This planet has—or rather had—a problem, which was this: most of the people living on it were unhappy for pretty much all of the time. Many solutions were suggested for this problem, 
                        but most of these were largely concerned with the movement of small green pieces of paper, which was odd because on the whole it wasn't the small green pieces of paper that were unhappy. Many were increasingly of the opinion that they'd all made a big mistake in coming down from the trees in the first place. And some said that even the trees had been a bad move, and that no one should ever have left the oceans.
In many of the more relaxed civilizations on the Outer Eastern Rim of the Galaxy, the Hitchhiker's Guide has already supplanted the great Encyclopaedia Galactica as the standard repository of all knowledge and wisdom, for though it has many omissions and contains much that is apocryphal, or at least wildly inaccurate, it scores over the older, more pedestrian work in two important respect""" * 20

with open('sample_text.txt', 'w', encoding='utf-8') as f:
    f.write(sample_training_text)


Train the model
(The sample file took about 2 minutes on an Intel i9 CPU)

In [14]:
model, dataset = train_model("sample_text.txt", epochs=50, batch_size=32)

Device initalialized: cpu
Vocab size: 42
Text length: 28760
Starting training...
Epoch [10/50], Loss: 1.1859
Epoch [20/50], Loss: 0.2273
Epoch [30/50], Loss: 0.1219
Epoch [40/50], Loss: 0.0928
Epoch [50/50], Loss: 0.0763
Training complete.


Generate some text

In [21]:
generated_text = generate_text(model, dataset, torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                               seed_text='The meaning of life is', length=100, temperature=1)
print(f"Generated Text: {generated_text}")



Generated Text: The meaning of life isttttttttttttttttttttttttstttttttttpttttttpttttttttpttttttttptttttttttutptttttttttttttttttttttttttttt
