# Next Word Prediction using N-grams

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import Counter
import re
import requests
import random
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

class RobustVocabulary:
    """Word-level vocabulary built from raw training texts.

    Index 0 is ``<pad>`` and index 1 is ``<unk>``. The remaining entries are the
    tokens seen at least ``min_freq`` times in ``texts``, most frequent first,
    capped so the vocabulary holds at most ``max_vocab_size`` entries. The two
    special tokens are always present, even when ``max_vocab_size < 2``.

    Attributes:
        token_counts: ``Counter`` of every token produced from ``texts``.
        vocab: list mapping index -> token.
        token_to_idx: dict mapping token -> index.
        idx_to_token: dict mapping index -> token.
    """

    # Typographic single quotes (U+2018, U+2019). They are mapped to ASCII "'"
    # so that "don’t" tokenizes exactly like "don't".
    _APOSTROPHES = re.compile(r"[\u2018\u2019]")
    # A run of word characters, optionally followed by one apostrophe suffix
    # ("don't", "tom's").
    _TOKEN_PATTERN = re.compile(r"\b\w+(?:'\w+)?\b")

    def __init__(self, texts, min_freq=2, max_vocab_size=10000):
        all_tokens = []
        for text in texts:
            all_tokens.extend(self.tokenize(text))

        self.token_counts = Counter(all_tokens)
        # Two slots are reserved for <pad>/<unk>. Clamp at 0 so a tiny
        # max_vocab_size never becomes a negative (from-the-end) slice bound.
        max_regular_tokens = max(max_vocab_size - 2, 0)
        common_tokens = [token for token, count in self.token_counts.most_common()
                         if count >= min_freq][:max_regular_tokens]

        self.vocab = ['<pad>', '<unk>'] + common_tokens
        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}
        self.idx_to_token = {idx: token for token, idx in self.token_to_idx.items()}

        print(f"Vocabulary: {len(self.vocab)} tokens")
        print(f"Top words: {common_tokens[:10]}")

    def tokenize(self, text):
        """Lower-case ``text``, normalise curly apostrophes, and return its word tokens."""
        text = self._APOSTROPHES.sub("'", text.lower().strip())
        return self._TOKEN_PATTERN.findall(text)

    def get_itos(self):
        """Return the index -> token list (index 0 is '<pad>', 1 is '<unk>')."""
        return self.vocab

    def __len__(self):
        return len(self.vocab)

    def encode(self, tokens):
        """Map tokens to indices; unknown tokens map to 1 ('<unk>')."""
        return [self.token_to_idx.get(token, 1) for token in tokens]

class NGramDataset(Dataset):
    """Fixed-width (context, next-token) pairs extracted from a list of texts.

    Each text is tokenized and encoded with ``vocab``. Every position ``i`` with
    ``i >= context_size`` yields one pair:
    - a tensor of the ``context_size`` preceding token ids, and
    - a 0-d tensor holding the id at ``i``.

    Windows never span two texts. Texts with at most ``context_size`` tokens
    contribute nothing.
    """

    def __init__(self, texts, vocab, context_size=3):
        self.vocab = vocab
        self.context_size = context_size
        self.data = self.create_ngram_data(texts)

    def create_ngram_data(self, texts):
        """Build and return the list of (context tensor, target tensor) pairs."""
        width = self.context_size
        pairs = []
        for text in texts:
            tokens = self.vocab.tokenize(text)
            # Need at least one token beyond a full context window.
            if len(tokens) <= width:
                continue
            ids = self.vocab.encode(tokens)
            pairs.extend(
                (torch.tensor(ids[end - width:end]), torch.tensor(ids[end]))
                for end in range(width, len(ids))
            )
        return pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

class ImprovedNGramModel(nn.Module):
    """Feed-forward next-token classifier over a fixed context window.

    The ``context_size`` token embeddings are concatenated and passed through
    two Linear -> BatchNorm1d -> ReLU -> Dropout stages. A final Linear layer
    projects to ``vocab_size`` raw logits; no softmax is applied.
    """

    def __init__(self, vocab_size, embedding_dim=128, context_size=3, hidden_dim=256, dropout=0.3):
        super().__init__()
        self.context_size = context_size
        self.embedding_dim = embedding_dim
        flat_dim = context_size * embedding_dim

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # Hidden stage 1.
        self.fc1 = nn.Linear(flat_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout)

        # Hidden stage 2.
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout2 = nn.Dropout(dropout)

        # Output projection to vocabulary logits.
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Map token-id contexts of width ``context_size`` to vocabulary logits."""
        batch_size = inputs.size(0)
        hidden = self.embeddings(inputs).view(batch_size, -1)
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
        )
        for linear, norm, drop in stages:
            hidden = drop(self.relu(norm(linear(hidden))))
        return self.fc3(hidden)

# Markers Project Gutenberg places around the book text. Everything outside
# them is licence/header boilerplate that would otherwise end up in the corpus.
_GUTENBERG_START_RE = re.compile(
    r"\*\*\*\s*START OF (?:THE|THIS) PROJECT GUTENBERG EBOOK[^\n]*", re.IGNORECASE
)
_GUTENBERG_END_RE = re.compile(
    r"\*\*\*\s*END OF (?:THE|THIS) PROJECT GUTENBERG EBOOK", re.IGNORECASE
)


def _strip_gutenberg_boilerplate(text):
    """Return the text between the Gutenberg START/END markers; a missing marker leaves that side as-is."""
    start = _GUTENBERG_START_RE.search(text)
    if start:
        text = text[start.end():]
    end = _GUTENBERG_END_RE.search(text)
    if end:
        text = text[:end.start()]
    return text


def _split_sentences(text, min_length=20):
    """Split ``text`` on runs of '.', '!' or '?' and keep whitespace-normalised pieces longer than ``min_length``."""
    pieces = (" ".join(chunk.split()) for chunk in re.split(r"[.!?]+", text))
    return [piece for piece in pieces if len(piece) > min_length]


def load_sample_data(url="https://www.gutenberg.org/files/74/74-0.txt", max_sentences=2000, timeout=10):
    """Download a Project Gutenberg text and return its first sentences.

    Args:
        url: address of the plain-text book. The response body is decoded as
            UTF-8, which the default "-0" Gutenberg file is expected to be.
        max_sentences: maximum number of sentences to return.
        timeout: seconds passed to ``requests.get``.

    Returns:
        Up to ``max_sentences`` sentence strings, each longer than 20
        characters, taken from the book body with the Gutenberg header and
        footer removed. Returns ``[]`` if the download fails (network error or
        non-2xx status).
    """
    print("Downloading data...")
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Download failed: {e}")
        return []

    # Decode the raw bytes explicitly. For a text/* response without a charset,
    # requests falls back to ISO-8859-1, which garbles UTF-8 punctuation.
    # "utf-8-sig" also drops a leading byte-order mark if present.
    text = response.content.decode("utf-8-sig", errors="replace")
    sentences = _split_sentences(_strip_gutenberg_boilerplate(text))
    print(f"Downloaded {len(sentences)} sentences")
    return sentences[:max_sentences]

def _train_one_epoch(model, loader, criterion, optimizer, skip_singleton_batches):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss (NaN if no batch ran)."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for contexts, targets in loader:
        batch_size = targets.size(0)
        if skip_singleton_batches and batch_size < 2:
            # BatchNorm cannot compute batch statistics from one sample in
            # training mode (PyTorch raises), so such a batch is skipped.
            continue
        optimizer.zero_grad()
        loss = criterion(model(contexts), targets)
        loss.backward()
        optimizer.step()
        # criterion averages within a batch. Re-weight by batch size so a short
        # final batch does not count as much as a full one.
        total_loss += loss.item() * batch_size
        total_samples += batch_size
    return total_loss / total_samples if total_samples else float('nan')


def _validation_loss(model, loader, criterion):
    """Return the per-sample mean loss of ``model`` over ``loader`` in eval mode (NaN if empty)."""
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for contexts, targets in loader:
            batch_size = targets.size(0)
            total_loss += criterion(model(contexts), targets).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples else float('nan')


def _plot_losses(train_losses, val_losses):
    """Show the per-epoch training and validation loss curves."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()


def train_model_with_validation(model, train_loader, val_loader, epochs=50, lr=0.001,
                                *, patience=10, checkpoint_path='best_ngram_model.pth', plot=True):
    """Train ``model`` with Adam + StepLR and early stopping on validation loss.

    Args:
        model: module mapping a batch of contexts to class logits.
        train_loader / val_loader: iterables of ``(contexts, targets)`` batches.
        epochs: maximum number of epochs.
        lr: initial Adam learning rate, halved every 15 epochs.
        patience: stop after this many consecutive epochs without a new best
            validation loss.
        checkpoint_path: file the best ``state_dict`` is written to whenever
            validation loss improves.
        plot: if True, show the loss curves with matplotlib when training ends.

    Returns:
        ``(train_losses, val_losses)``: per-epoch, per-sample mean losses.
        On return, ``model`` holds the weights of the best validation epoch.
        If no epoch ever improved (e.g. the losses were NaN), it keeps its
        final weights.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)
    skip_singletons = any(
        isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) for m in model.modules()
    )

    train_losses = []
    val_losses = []

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, skip_singletons)
        val_loss = _validation_loss(model, val_loader, criterion)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        scheduler.step()

        # NaN < best is False, so a NaN validation loss never counts as progress.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Keep an in-memory copy so the restore below can never pick up a
            # stale or missing file from another run. Still write the checkpoint.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}')

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; keeping final weights")

    if plot:
        _plot_losses(train_losses, val_losses)

    return train_losses, val_losses

def evaluate_model(model, test_loader, vocab):
    """Report loss, top-1 accuracy and perplexity of ``model`` on ``test_loader``.

    Args:
        model: module mapping a batch of contexts to class logits.
        test_loader: iterable of ``(contexts, targets)`` batches.
        vocab: unused; kept for interface compatibility.

    Returns:
        ``(avg_loss, accuracy, perplexity)``:
        - ``avg_loss`` is the per-sample mean cross-entropy, so the last,
          possibly shorter, batch is weighted by its true size.
        - ``perplexity`` is ``exp(avg_loss)`` as a 0-d tensor.
        An empty loader yields NaN for all three.
    """
    model.eval()
    # Sum per-sample losses so the final average is over samples, not batches.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for contexts, targets in test_loader:
            outputs = model(contexts)
            total_loss += criterion(outputs, targets).item()

            predicted = outputs.argmax(dim=1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)

    if total_predictions == 0:
        print("Test set is empty; nothing to evaluate")
        nan = float('nan')
        return nan, nan, torch.tensor(nan)

    avg_loss = total_loss / total_predictions
    accuracy = correct_predictions / total_predictions
    perplexity = torch.exp(torch.tensor(avg_loss))

    print(f"Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Perplexity: {perplexity:.4f}")

    return avg_loss, accuracy, perplexity

def _top_k_next_words(model, vocab, text, context_size, top_k, special_ids):
    """Return up to ``top_k`` ``(word, probability)`` guesses for the word following ``text``."""
    token_indices = vocab.encode(vocab.tokenize(text))
    # Left-pad short inputs with index 0 so the model always sees a full window.
    if len(token_indices) < context_size:
        token_indices = [0] * (context_size - len(token_indices)) + token_indices
    context = token_indices[-context_size:]

    with torch.no_grad():
        logits = model(torch.tensor([context]))
        probabilities = F.softmax(logits, dim=1).squeeze(0)

    # Rank only real words. The reported percentages remain the model's
    # probabilities over the full vocabulary.
    ranked = probabilities.clone()
    if special_ids:
        ranked[special_ids] = -1.0
    k = min(top_k, ranked.numel() - len(special_ids))
    if k <= 0:
        return []
    top_probs, top_indices = ranked.topk(k)
    itos = vocab.get_itos()
    return [(itos[idx], prob) for prob, idx in zip(top_probs.tolist(), top_indices.tolist())]


def interactive_predict_improved(model, vocab, context_size=3, top_k=3):
    """Read lines from stdin and print the model's most likely next words.

    Commands:
    - 'quit' (or EOF / Ctrl-C) ends the session.
    - 'sample' prints vocabulary entries 2..31.
    - Blank lines are ignored.

    Any other line is tokenized with ``vocab``; its last ``context_size`` token
    ids (left-padded with 0) are fed to ``model``, and the ``top_k`` most
    probable words are printed. '<pad>' and '<unk>' are never suggested.
    """
    print(f"Interactive prediction (context: {context_size}, vocab: {len(vocab)})")
    print("Commands: 'sample', 'quit', or enter text")

    itos = vocab.get_itos()
    special_ids = [i for i, tok in enumerate(itos) if tok in ('<pad>', '<unk>')]
    model.eval()

    while True:
        try:
            user_input = input("Enter text: ").strip()
        except (EOFError, KeyboardInterrupt):
            break

        if user_input.lower() == 'quit':
            break

        if user_input.lower() == 'sample':
            print(', '.join(itos[2:32]))
            continue

        if not user_input:
            continue

        try:
            guesses = _top_k_next_words(model, vocab, user_input, context_size, top_k, special_ids)
            for i, (word, prob) in enumerate(guesses, 1):
                print(f"{i}. {word} ({prob*100:.1f}%)")
        except Exception as e:  # interactive boundary: report and keep the session alive
            print(f"Error: {e}")

if __name__ == "__main__":
    CONTEXT_SIZE = 3
    EMBEDDING_DIM = 128
    HIDDEN_DIM = 256
    BATCH_SIZE = 64
    EPOCHS = 50
    SEED = 42

    # Make weight init, dropout masks and DataLoader shuffling reproducible.
    torch.manual_seed(SEED)

    texts = load_sample_data()
    if not texts:
        print("No data available")
        # Use SystemExit rather than exit(): exit() is a site-module helper,
        # and IPython/Jupyter replace it with a command that ends the session.
        raise SystemExit(1)

    train_texts, temp_texts = train_test_split(texts, test_size=0.4, random_state=SEED)
    val_texts, test_texts = train_test_split(temp_texts, test_size=0.5, random_state=SEED)

    print(f"Train: {len(train_texts)}, Val: {len(val_texts)}, Test: {len(test_texts)}")

    vocab = RobustVocabulary(train_texts, min_freq=2, max_vocab_size=5000)

    train_dataset = NGramDataset(train_texts, vocab, CONTEXT_SIZE)
    val_dataset = NGramDataset(val_texts, vocab, CONTEXT_SIZE)
    test_dataset = NGramDataset(test_texts, vocab, CONTEXT_SIZE)

    if len(train_dataset) < 2 or len(val_dataset) == 0 or len(test_dataset) == 0:
        print("Not enough text to build train/val/test n-gram examples")
        raise SystemExit(1)

    # The model uses BatchNorm1d, which raises on a 1-sample batch in training
    # mode. Drop the last shuffled batch only when it would be that singleton.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=len(train_dataset) % BATCH_SIZE == 1,
    )
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = ImprovedNGramModel(
        vocab_size=len(vocab),
        embedding_dim=EMBEDDING_DIM,
        context_size=CONTEXT_SIZE,
        hidden_dim=HIDDEN_DIM
    )

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    train_losses, val_losses = train_model_with_validation(
        model, train_loader, val_loader, EPOCHS
    )

    evaluate_model(model, test_loader, vocab)
    interactive_predict_improved(model, vocab, CONTEXT_SIZE)

In [None]:
# Simple interactive prediction cell - run after training on Tom Sawyer dataset
import torch
import torch.nn.functional as F

def predict_next_word(top_k=3):
    """Interactively print the trained model's most likely next words.

    Relies on the notebook globals ``model``, ``vocab`` and ``CONTEXT_SIZE``
    created by the training cell. Type 'quit' (or send EOF / Ctrl-C) to stop.
    The special tokens '<pad>' and '<unk>' are never offered as suggestions.

    Args:
        top_k: maximum number of suggestions printed per input line.
    """
    print("N-gram Next Word Prediction")
    print("Type 'quit' to exit")

    itos = vocab.get_itos()
    special_ids = [i for i, tok in enumerate(itos) if tok in ('<pad>', '<unk>')]
    model.eval()

    while True:
        try:
            user_input = input("Enter text: ").strip()
        except (EOFError, KeyboardInterrupt):
            break

        if user_input.lower() == 'quit':
            break

        if not user_input:
            continue

        try:
            # Tokenize input
            token_indices = vocab.encode(vocab.tokenize(user_input))

            # Left-pad short inputs with index 0 to a full context window
            if len(token_indices) < CONTEXT_SIZE:
                token_indices = [0] * (CONTEXT_SIZE - len(token_indices)) + token_indices

            context_tensor = torch.tensor([token_indices[-CONTEXT_SIZE:]])

            # Predict
            with torch.no_grad():
                probabilities = F.softmax(model(context_tensor), dim=1).squeeze(0)

            # Rank real words only; printed percentages stay the model's
            # probabilities over the full vocabulary.
            ranked = probabilities.clone()
            if special_ids:
                ranked[special_ids] = -1.0
            k = min(top_k, ranked.numel() - len(special_ids))

            print(f"Predictions for '{user_input}':")
            if k > 0:
                top_probs, top_indices = ranked.topk(k)
                for i, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist()), 1):
                    confidence = prob * 100
                    print(f"  {i}. {itos[idx]} ({confidence:.1f}%)")
            print()

        except Exception as e:  # interactive boundary: report and keep going
            print(f"Error: {e}")


# Start prediction
predict_next_word()