In [19]:
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torch.utils.data import DataLoader, Dataset

import torchtext
torchtext.disable_torchtext_deprecation_warning()
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence  # Import pad_sequence function

In [5]:
# Load the WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
dataset

Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 75024.94 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 703916.88 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 442944.14 examples/s]


DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [20]:
import re

# Tokenizer function
def tokenize(text):
    # Simple tokenizer that splits on non-alphabetic characters
    return re.findall(r'\b\w+\b', text.lower())

# Build vocabulary
counter = Counter()
for line in dataset['train']['text']:
    counter.update(tokenize(line))

# Define vocabulary size and special tokens
vocab_size = 10000
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
vocab = {word: idx + len(special_tokens) for idx, (word, _) in enumerate(counter.most_common(vocab_size - len(special_tokens)))}
for idx, token in enumerate(special_tokens):
    vocab[token] = idx

# Inverse vocabulary for decoding
inv_vocab = {idx: word for word, idx in vocab.items()}

# Encode function
def encode(text):
    tokens = tokenize(text)
    return [vocab.get(token, vocab['<unk>']) for token in tokens]

# Add special tokens to each sentence
def add_special_tokens(encoded_text):
    return [vocab['<bos>']] + encoded_text + [vocab['<eos>']]

# Prepare dataset
def prepare_dataset(split):
    encoded_texts = [add_special_tokens(encode(line)) for line in dataset[split]['text']]
    return encoded_texts

train_data = prepare_dataset('train')
valid_data = prepare_dataset('validation')
test_data = prepare_dataset('test')

# Data collate function for DataLoader
def collate_fn(batch):
    batch = [torch.tensor(item) for item in batch]
    batch = pad_sequence(batch, batch_first=True, padding_value=vocab['<pad>'])
    inputs = batch[:, :-1]
    targets = batch[:, 1:]
    return inputs, targets

# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn)

In [9]:
import torch.nn as nn

class RNNLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0.5):
        super(RNNLanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab['<pad>'])
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden):
        x = self.embedding(x)
        x = self.dropout(x)
        output, hidden = self.lstm(x, hidden)
        output = self.dropout(output)
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        return (weight.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size),
                weight.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size))


In [16]:
# Hyperparameters
embed_size = 128      # Size of word embeddings
hidden_size = 256     # Number of features in the hidden state of the RNN
num_layers = 2        # Number of recurrent layers (e.g., LSTM layers)
num_epochs = 10       # Number of training epochs
learning_rate = 0.001 # Learning rate for the optimizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Initialize the model
model = RNNLanguageModel(
    vocab_size=len(vocab),   # Size of the vocabulary
    embed_size=embed_size,   # Embedding size
    hidden_size=hidden_size, # Hidden state size
    num_layers=num_layers,   # Number of LSTM layers
    dropout=0.5              # Dropout rate
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>']) # Ignore padding in loss calculation
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Adam optimizer


In [35]:
import wandb
from tqdm import tqdm


# Initialize WandB
wandb.init(
    project="rnn-language-model",
    config={
        "embed_size": embed_size,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
    }
)
import torch
from tqdm import tqdm
import wandb

# Function to train the model with logging and dynamic hidden state initialization.
# The function now accepts a `start_epoch` parameter.
def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs, start_epoch=1, log_interval=100):
    model.train()  # Set the model to training mode

    for epoch in range(start_epoch, start_epoch + num_epochs):
        total_loss = 0
        num_batches = len(train_loader)
        
        # Initialize hidden state for the first batch; will update it dynamically later.
        hidden = None

        # Use tqdm for progress bar
        pbar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Epoch {epoch}")
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            current_batch_size = inputs.size(0)
            
            # Initialize or update hidden state dynamically based on current batch size.
            if hidden is None or hidden[0].size(1) != current_batch_size:
                hidden = model.init_hidden(current_batch_size)
            else:
                # Detach hidden state to prevent backpropagating through the entire history
                hidden = tuple([h.detach() for h in hidden])
            
            optimizer.zero_grad()  # Clear gradients
            
            # Forward pass
            output, hidden = model(inputs, hidden)
            loss = criterion(output.view(-1, len(vocab)), targets.view(-1))
            loss.backward()  # Backward pass
            
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()  # Update parameters
            
            total_loss += loss.item()
            
            # Logging every `log_interval` steps
            if (batch_idx + 1) % log_interval == 0:
                avg_loss = total_loss / log_interval
                wandb.log({"Training Loss": avg_loss, "Epoch": epoch, "Batch": batch_idx})
                pbar.set_postfix(loss=f"{avg_loss:.4f}")
                total_loss = 0  # Reset loss tracker

        # Validate after each epoch
        val_loss = evaluate_model(model, valid_loader, criterion)
        wandb.log({"Validation Loss": val_loss, "Epoch": epoch})
        print(f"Epoch {epoch}: Validation Loss: {val_loss:.4f}")

        #Save a checkpoint after 10 epochs
        if epoch % 10 == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'validation_loss': val_loss
            }
            torch.save(checkpoint, f'checkpoint_epoch{epoch}.pth')
            print(f"Checkpoint saved at epoch {epoch}")

# Function to evaluate the model remains unchanged.
def evaluate_model(model, data_loader, criterion):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    num_batches = len(data_loader)
    hidden = None

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            current_batch_size = inputs.size(0)
            if hidden is None or hidden[0].size(1) != current_batch_size:
                hidden = model.init_hidden(current_batch_size)
            else:
                hidden = tuple([h.detach() for h in hidden])
            output, hidden = model(inputs, hidden)
            loss = criterion(output.view(-1, len(vocab)), targets.view(-1))
            total_loss += loss.item()
    model.train()  # Switch back to training mode
    return total_loss / num_batches



In [27]:
# Start training (from beginning - DO NOT RUN THIS IF RESUMING TRAINING)
train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=num_epochs, start_epoch=1)

Epoch 1: 100%|██████████| 1148/1148 [00:41<00:00, 27.87it/s, loss=5.8464]


Epoch 1: Validation Loss: 5.5871


Epoch 2: 100%|██████████| 1148/1148 [00:40<00:00, 28.45it/s, loss=5.7392]


Epoch 2: Validation Loss: 5.4788


Epoch 3: 100%|██████████| 1148/1148 [00:40<00:00, 28.34it/s, loss=5.6471]


Epoch 3: Validation Loss: 5.3878


Epoch 4: 100%|██████████| 1148/1148 [00:40<00:00, 28.12it/s, loss=5.5780]


Epoch 4: Validation Loss: 5.3182


Epoch 5: 100%|██████████| 1148/1148 [00:40<00:00, 28.57it/s, loss=5.4934]


Epoch 5: Validation Loss: 5.2707


Epoch 6: 100%|██████████| 1148/1148 [00:40<00:00, 28.37it/s, loss=5.4207]


Epoch 6: Validation Loss: 5.2227


Epoch 7: 100%|██████████| 1148/1148 [00:40<00:00, 28.20it/s, loss=5.4009]


Epoch 7: Validation Loss: 5.1870


Epoch 8: 100%|██████████| 1148/1148 [00:40<00:00, 28.29it/s, loss=5.3519]


Epoch 8: Validation Loss: 5.1531


Epoch 9: 100%|██████████| 1148/1148 [00:40<00:00, 28.29it/s, loss=5.3212]


Epoch 9: Validation Loss: 5.1305


Epoch 10: 100%|██████████| 1148/1148 [00:40<00:00, 28.41it/s, loss=5.2686]


Epoch 10: Validation Loss: 5.1090


In [36]:
# Load the previous checkpoint
checkpoint = torch.load('checkpoint_epoch10.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Set the start epoch to one after the checkpoint's epoch
start_epoch = checkpoint['epoch'] + 1

print(f"Resuming training from epoch {start_epoch}")

# Specify the additional number of epochs you want to train for
additional_epochs = 20

# Resume training using the same train_model function
train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=additional_epochs, start_epoch=start_epoch)

Resuming training from epoch 11


Epoch 11: 100%|██████████| 1148/1148 [00:40<00:00, 28.31it/s, loss=5.2538]


Epoch 11: Validation Loss: 5.0909
Checkpoint saved at epoch 11


Epoch 12: 100%|██████████| 1148/1148 [00:40<00:00, 28.37it/s, loss=5.2273]


Epoch 12: Validation Loss: 5.0722
Checkpoint saved at epoch 12


Epoch 13: 100%|██████████| 1148/1148 [00:40<00:00, 28.11it/s, loss=5.2133]


Epoch 13: Validation Loss: 5.0566
Checkpoint saved at epoch 13


Epoch 14: 100%|██████████| 1148/1148 [00:40<00:00, 28.30it/s, loss=5.1836]


Epoch 14: Validation Loss: 5.0445
Checkpoint saved at epoch 14


Epoch 15: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=5.1602]


Epoch 15: Validation Loss: 5.0293
Checkpoint saved at epoch 15


Epoch 16: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=5.1571]


Epoch 16: Validation Loss: 5.0215
Checkpoint saved at epoch 16


Epoch 17: 100%|██████████| 1148/1148 [00:40<00:00, 28.33it/s, loss=5.1071]


Epoch 17: Validation Loss: 5.0110
Checkpoint saved at epoch 17


Epoch 18: 100%|██████████| 1148/1148 [00:40<00:00, 28.39it/s, loss=5.1229]


Epoch 18: Validation Loss: 5.0012
Checkpoint saved at epoch 18


Epoch 19: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=5.0664]


Epoch 19: Validation Loss: 4.9959
Checkpoint saved at epoch 19


Epoch 20: 100%|██████████| 1148/1148 [00:40<00:00, 28.44it/s, loss=5.0547]


Epoch 20: Validation Loss: 4.9926
Checkpoint saved at epoch 20


Epoch 21: 100%|██████████| 1148/1148 [00:40<00:00, 28.34it/s, loss=5.0620]


Epoch 21: Validation Loss: 4.9795
Checkpoint saved at epoch 21


Epoch 22: 100%|██████████| 1148/1148 [00:40<00:00, 28.41it/s, loss=5.0521]


Epoch 22: Validation Loss: 4.9746
Checkpoint saved at epoch 22


Epoch 23: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=5.0224]


Epoch 23: Validation Loss: 4.9711
Checkpoint saved at epoch 23


Epoch 24: 100%|██████████| 1148/1148 [00:40<00:00, 28.45it/s, loss=5.0168]


Epoch 24: Validation Loss: 4.9674
Checkpoint saved at epoch 24


Epoch 25: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=5.0180]


Epoch 25: Validation Loss: 4.9589
Checkpoint saved at epoch 25


Epoch 26: 100%|██████████| 1148/1148 [00:40<00:00, 28.12it/s, loss=5.0203]


Epoch 26: Validation Loss: 4.9570
Checkpoint saved at epoch 26


Epoch 27: 100%|██████████| 1148/1148 [00:40<00:00, 28.21it/s, loss=5.0022]


Epoch 27: Validation Loss: 4.9502
Checkpoint saved at epoch 27


Epoch 28: 100%|██████████| 1148/1148 [00:40<00:00, 28.45it/s, loss=4.9729]


Epoch 28: Validation Loss: 4.9453
Checkpoint saved at epoch 28


Epoch 29: 100%|██████████| 1148/1148 [00:40<00:00, 28.25it/s, loss=4.9762]


Epoch 29: Validation Loss: 4.9406
Checkpoint saved at epoch 29


Epoch 30: 100%|██████████| 1148/1148 [00:40<00:00, 28.39it/s, loss=4.9610]


Epoch 30: Validation Loss: 4.9417
Checkpoint saved at epoch 30


In [37]:
wandb.finish()

0,1
Batch,▂▃▅▆▄▄▆▇▁▅▇▂▅▂▃▂█▃▄▆▁▂▄██▄▇▇▃▅▄▁▂▄▇▆▇▂▅▇
Epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
Training Loss,███▇▇▆▇▇▆▆▆▅▆▆▆▅▅▄▅▅▄▄▃▄▄▄▃▄▃▃▂▂▃▂▂▂▁▂▂▂
Validation Loss,█▇▆▆▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁

0,1
Batch,1099.0
Epoch,30.0
Training Loss,4.96097
Validation Loss,4.94166


In [38]:
# save trained model to disk
saved_model_name = 'rnn_language_model_30epochs.pth'
torch.save(model.state_dict(), saved_model_name)

In [39]:
# Load the saved state dictionary
model.load_state_dict(torch.load(saved_model_name, map_location=device))

# Set the model to evaluation mode
model.eval()

print("Model loaded successfully!")

Model loaded successfully!


In [85]:
def predict_next_word(model, input_text, vocab, inv_vocab, top_k=5):
    model.eval()  # Set the model to evaluation mode
    # Tokenize the input text
    tokens = tokenize(input_text)
    # Convert tokens to indices, using <unk> for unseen words
    input_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]
    # Convert to tensor and add batch dimension (1, sequence_length)
    input_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
    
    # Initialize hidden state with batch size 1
    hidden = model.init_hidden(1)
    
    with torch.no_grad():
        # Forward pass: get output predictions and update hidden state
        output, hidden = model(input_tensor, hidden)
    
    # Get logits for the last token in the sequence
    logits = output[0, -1]  # Shape: [vocab_size]
    # Apply softmax to convert logits to probabilities
    probabilities = torch.softmax(logits, dim=0)
    # Get the indices of the top_k words with highest probability
    top_probs, top_indices = torch.topk(probabilities, top_k)
    # Map indices to words using the inverse vocabulary
    top_words = [inv_vocab[idx.item()] for idx in top_indices]
    
    return top_words

def generate_text_sample(model, prompt, vocab, inv_vocab, num_tokens, top_k=5, temperature=1.0):
    """
    Generate text by predicting and appending num_tokens to the prompt using sampling.
    """
    generated_text = prompt
    
    for _ in range(num_tokens):
        # Get predictions from the existing function
        predictions = predict_next_word(model, generated_text, vocab, inv_vocab, top_k=top_k)
        
        # Instead of choosing the top prediction, sample from top-k.
        # First, get the raw output logits from the model:
        tokens = tokenize(generated_text)
        input_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
        hidden = model.init_hidden(1)
        with torch.no_grad():
            output, hidden = model(input_tensor, hidden)
        logits = output[0, -1] / temperature  # adjust logits by temperature
        probabilities = torch.softmax(logits, dim=0)
        
        # Get top-k probabilities and indices
        top_probs, top_indices = torch.topk(probabilities, top_k)
        
        # Sample from the top-k tokens according to their probabilities
        top_probs = top_probs / torch.sum(top_probs)  # normalize
        sampled_index = torch.multinomial(top_probs, 1).item()
        next_word_idx = top_indices[sampled_index].item()
        next_word = inv_vocab.get(next_word_idx, "<unk>")
        
        generated_text += " " + next_word
    
    return generated_text

def generate_text_naive(prompt,num_tokens=100):
    generated_text = prompt
    for i in range(num_tokens):
        predicted_words = predict_next_word(model, generated_text, vocab, inv_vocab, top_k=5)
        next_word = predicted_words[0]
        if next_word == "<unk>":
            next_word = predicted_words[2]
        
        generated_text += next_word + " "

    return generated_text

In [93]:
prompt = "I am a man. "
generated_text = generate_text_naive(prompt,50)
print("Generated text: ",generated_text)

Generated text:  I am a man. in the world s last day of the year the album was released in the united states and the song s release of the album s first album album in the united states and the song was released in the united states and the song s release of the album 


In [100]:
# Example usage:
prompt = "I am a man. I can be a very good "
predicted_words = predict_next_word(model, prompt, vocab, inv_vocab, top_k=5)
print("Prompt:", prompt)
print("Next word predictions:", predicted_words)

Prompt: I am a man. I can be a very good 
Next word predictions: ['<unk>', 'and', 'of', 'thing', 'to']


In [76]:
# Example usage:
prompt = "Hello, this is "
generated = generate_text_sample(model, prompt, vocab, inv_vocab, num_tokens=20, top_k=5, temperature=.5)
print("Generated Text:")
print(generated)

Generated Text:
Hello, this is  the most common <unk> in <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>


In [None]:
predict_next_word