In [1]:
import os
import torch
import random
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical

In [2]:
hidden_size = 512   
batch_size = 50    
step_len = 200      
num_layers = 3      
lr = 0.001          
num_steps = 1     
gen_seq_len = 30    
load_chk = False    
save_path = "navi_rnn_model.pt"
device = 'cuda'

In [None]:
data_path = "cleaned_navi_post.txt"
corpus = open(data_path, 'r').read()
# corpus = load_all_text_files_in_folder(data_path)
words = sorted(list(set(corpus.split())))
data_size, vocab_size = len(corpus.split()), len(words)

In [4]:
class TextDataset(Dataset):
    def __init__(self, text, vocab, seq_length=200):
        self.text = text.split()
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.seq_length = seq_length
        self.word_to_idx = {word: idx for idx, word in enumerate(vocab)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}

        self.encoded_text = [self.word_to_idx[word] for word in self.text if word in self.word_to_idx]

    def __len__(self):
        return len(self.encoded_text) - self.seq_length

    def __getitem__(self, idx):
        input_seq = self.encoded_text[idx:idx+self.seq_length]
        target_seq = self.encoded_text[idx+1:idx+self.seq_length+1]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target_seq, dtype=torch.long)

In [5]:
dataset = TextDataset(corpus, words, seq_length=step_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out.reshape(-1, self.hidden_size))
        return out, hidden

In [None]:

rnn = RNN(vocab_size, hidden_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)


In [10]:
def get_training_batch_indicies(index_list, batch_size):
    input_batch_indicies = torch.tensor(np.array(random.sample(index_list, batch_size)))

    target_batch_indicies = input_batch_indicies + 1
    return input_batch_indicies, target_batch_indicies

In [11]:
def generate_text(model, seed_word, vocab, idx_to_word, gen_seq_len=20):
    model.eval()
    input_seq = [vocab[seed_word]] if seed_word in vocab else [random.randint(0, len(vocab)-1)]
    input_tensor = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    generated_words = [seed_word]

    for _ in range(gen_seq_len):
        with torch.no_grad():
            output, hidden = model(input_tensor, hidden)
            probs = torch.nn.functional.softmax(output[-1], dim=0)
            next_word_idx = torch.multinomial(probs, 1).item()
            generated_words.append(idx_to_word[next_word_idx])

            input_tensor = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)

    return ' '.join(generated_words)

In [12]:
for epoch in range(num_steps):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output, _ = rnn(inputs)
        loss = criterion(output, targets.view(-1))
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item():.4f}")

        seed_word = random.choice(words)
    print(f"Generated Text (Seed: {seed_word}):")
    print(generate_text(rnn, seed_word, {w: i for i, w in enumerate(words)}, {i: w for i, w in enumerate(words)}))
    print(f"Epoch [{epoch+1}/{num_steps}], Loss: {loss.item():.4f}")

# Save Model
torch.save(rnn.state_dict(), save_path)
print("Model saved!")


Loss: 10.5406
Loss: 10.5193
Loss: 10.4752
Loss: 10.2426
Loss: 9.5104
Loss: 9.0281
Loss: 8.6260
Loss: 8.2464
Loss: 8.2634
Loss: 8.0557
Loss: 8.1175
Loss: 8.1149
Loss: 8.0545
Loss: 8.0854
Loss: 8.1970
Loss: 8.2706
Loss: 8.1841
Loss: 8.1475
Loss: 8.2246
Loss: 8.2462
Loss: 8.2242
Loss: 8.1104
Loss: 8.2655
Loss: 8.1084
Loss: 8.0326
Loss: 8.4123
Loss: 8.1978
Loss: 8.2939
Loss: 8.0425
Loss: 8.0828
Loss: 8.1366
Loss: 8.2199
Loss: 8.2529
Loss: 7.9644
Loss: 8.1424
Loss: 8.1513
Loss: 8.1913
Loss: 7.7950
Loss: 8.0841
Loss: 8.1554
Loss: 8.0145
Loss: 7.9594
Loss: 8.0051
Loss: 7.9759
Loss: 8.1299
Loss: 7.8562
Loss: 8.2214
Loss: 7.9686
Loss: 8.0307
Loss: 8.0987
Loss: 7.9558
Loss: 8.0633
Loss: 7.8542
Loss: 7.9705
Loss: 8.0838
Loss: 7.9777
Loss: 7.9191
Loss: 7.9815
Loss: 7.9359
Loss: 7.9729
Loss: 8.0934
Loss: 7.9747
Loss: 7.8917
Loss: 7.9853
Loss: 8.2696
Loss: 7.9478
Loss: 8.0246
Loss: 7.9929
Loss: 8.0585
Loss: 8.0545
Loss: 8.0127
Loss: 7.9931
Loss: 7.9598
Loss: 7.8702
Loss: 8.0207
Loss: 7.9418
Loss: 7.