In [9]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import re
from collections import Counter
import streamlit as st
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import requests
from typing import List, Tuple, Dict
import random

class TextProcessor:
    def __init__(self, min_word_freq: int = 3):
        self.word2idx = {}
        self.idx2word = {}
        self.min_word_freq = min_word_freq
        
    def clean_text(self, text: str) -> str:
        """Clean the input text by removing special characters except periods."""
        text = text.lower()
        text = re.sub('[^a-zA-Z0-9 \.]', '', text)
        return text
    
    def build_vocabulary(self, text: str) -> None:
        """Build vocabulary from cleaned text."""
        # Clean and split text
        cleaned_text = self.clean_text(text)
        words = cleaned_text.split()
        
        # Count word frequencies
        word_freq = Counter(words)
        
        # Create vocabulary (only include words that appear at least min_word_freq times)
        vocab_words = [word for word, freq in word_freq.items() if freq >= self.min_word_freq]
        
        # Create word to index mappings
        self.word2idx = {word: idx for idx, word in enumerate(vocab_words)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        
    def create_sequences(self, text: str, context_length: int) -> Tuple[List[List[int]], List[int]]:
        """Create input-output sequences for training."""
        cleaned_text = self.clean_text(text)
        words = cleaned_text.split()
        
        X, y = [], []
        for i in range(len(words) - context_length):
            context = words[i:i+context_length]
            target = words[i+context_length]
            
            # Skip if any word is not in vocabulary
            if all(word in self.word2idx for word in context) and target in self.word2idx:
                X.append([self.word2idx[word] for word in context])
                y.append(self.word2idx[target])
                
        return X, y

class WordPredictionDataset(Dataset):
    def __init__(self, X: List[List[int]], y: List[int]):
        self.X = torch.LongTensor(X)
        self.y = torch.LongTensor(y)
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class WordPredictionModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, context_length: int, hidden_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim * context_length, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, vocab_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        embedded = self.embedding(x)  # Shape: (batch_size, context_length, embedding_dim)
        embedded = embedded.view(embedded.shape[0], -1)  # Flatten the embeddings
        x = self.relu(self.fc1(embedded))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
    
    def get_embeddings(self):
        """Return the learned word embeddings."""
        return self.embedding.weight.detach().numpy()

def train_model(model: nn.Module, train_loader: DataLoader, 
                num_epochs: int, learning_rate: float, device: str) -> List[float]:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    losses = []
    
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        
        if (epoch + 1) % 1 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
            
    return losses

def predict_next_words(model: nn.Module, processor: TextProcessor, 
                      input_text: str, k: int, context_length: int, device: str) -> List[str]:
    """Predict the next k words given an input text."""
    model.eval()
    words = processor.clean_text(input_text).split()
    
    predictions = []
    for _ in range(k):
        # Take the last context_length words
        context = words[-context_length:]
        
        # Convert to indices, handling OOV words
        context_indices = []
        for word in context:
            if word in processor.word2idx:
                context_indices.append(processor.word2idx[word])
            else:
                # For OOV words, randomly select a word from vocabulary
                context_indices.append(random.choice(list(processor.word2idx.values())))
        
        # Convert to tensor and get prediction
        context_tensor = torch.LongTensor([context_indices]).to(device)
        with torch.no_grad():
            output = model(context_tensor)
            pred_idx = torch.argmax(output, dim=1).item()
            pred_word = processor.idx2word[pred_idx]
        
        predictions.append(pred_word)
        words.append(pred_word)
        
    return predictions

def visualize_embeddings(embeddings: np.ndarray, words: List[str], n_components: int = 2):
    """Visualize word embeddings using t-SNE or scatter plot."""
    if embeddings.shape[1] > 2:
        tsne = TSNE(n_components=n_components, random_state=42)
        embeddings_2d = tsne.fit_transform(embeddings)
    else:
        embeddings_2d = embeddings
    
    plt.figure(figsize=(12, 8))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.5)
    
    for i, word in enumerate(words):
        plt.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))
    
    plt.title('Word Embeddings Visualization')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    return plt

def create_streamlit_app():
    st.title('Next Word Prediction App')
    
    # Model parameters
    context_length = st.slider('Context Length', 2, 10, 5)
    embedding_dim = st.select_slider('Embedding Dimension', options=[32, 64, 128, 256], value=64)
    hidden_dim = st.select_slider('Hidden Dimension', options=[512, 1024, 2048], value=1024)
    activation = st.selectbox('Activation Function', ['ReLU', 'Tanh', 'LeakyReLU'])
    
    # Input text
    input_text = st.text_area('Enter your text:', 'The quick brown fox jumps')
    num_words = st.slider('Number of words to predict', 1, 10, 3)
    
    if st.button('Predict'):
        # Load model (you would need to save and load your trained models)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Here you would load your trained model and processor
        # For demonstration, we'll just show the structure
        st.write(f'Predicting {num_words} words...')
        st.write('Model parameters:')
        st.write(f'Context length: {context_length}')
        st.write(f'Embedding dimension: {embedding_dim}')
        st.write(f'Hidden dimension: {hidden_dim}')
        st.write(f'Activation: {activation}')

if __name__ == '__main__':
    # For training
    # Load your dataset here
    with open("s.txt","r") as f:
        text = f.read()

    # Initialize processor and create sequences
    processor = TextProcessor(min_word_freq=3)
    processor.build_vocabulary(text)
    X, y = processor.create_sequences(text, context_length=5)
    
    # Create dataset and dataloader
    dataset = WordPredictionDataset(X, y)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # Initialize and train model
    model = WordPredictionModel(
        vocab_size=len(processor.word2idx),
        embedding_dim=64,
        context_length=5,
        hidden_dim=1024
    )
    
    # Train model
    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('mps')
    losses = train_model(model, train_loader, num_epochs=400, learning_rate=0.001, device=device)
    
    # For Streamlit app
    #create_streamlit_app()

Epoch [1/400], Loss: 6.9463
Epoch [2/400], Loss: 6.7661
Epoch [3/400], Loss: 6.6305
Epoch [4/400], Loss: 6.5433
Epoch [5/400], Loss: 6.4614
Epoch [6/400], Loss: 6.4041
Epoch [7/400], Loss: 6.3500
Epoch [8/400], Loss: 6.3038
Epoch [9/400], Loss: 6.2798
Epoch [10/400], Loss: 6.2519
Epoch [11/400], Loss: 6.2342
Epoch [12/400], Loss: 6.2201
Epoch [13/400], Loss: 6.2054
Epoch [14/400], Loss: 6.1953
Epoch [15/400], Loss: 6.1769
Epoch [16/400], Loss: 6.1681
Epoch [17/400], Loss: 6.1672
Epoch [18/400], Loss: 6.1634
Epoch [19/400], Loss: 6.1573
Epoch [20/400], Loss: 6.1599
Epoch [21/400], Loss: 6.1540
Epoch [22/400], Loss: 6.1524
Epoch [23/400], Loss: 6.1512
Epoch [24/400], Loss: 6.1542
Epoch [25/400], Loss: 6.1537
Epoch [26/400], Loss: 6.1434
Epoch [27/400], Loss: 6.1412
Epoch [28/400], Loss: 6.1483
Epoch [29/400], Loss: 6.1372
Epoch [30/400], Loss: 6.1430
Epoch [31/400], Loss: 6.1463
Epoch [32/400], Loss: 6.1478
Epoch [33/400], Loss: 6.1328
Epoch [34/400], Loss: 6.1485
Epoch [35/400], Loss: 6

In [10]:
model.eval()

WordPredictionModel(
  (embedding): Embedding(13748, 64)
  (fc1): Linear(in_features=320, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=13748, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [13]:
input_text = "hello there, how are you doing? "
no_of_words = 5
predictions = predict_next_words(model, processor, input_text, no_of_words, context_length=5, device=device)
predictions

['the', 'man', 'the', 'the', 'the']

In [16]:
torch.save(model.state_dict(), "model.pth")

In [17]:
model.state_dict()

OrderedDict([('embedding.weight',
              tensor([[-0.9370, -2.7978,  0.8693,  ...,  0.0991,  0.4725, -1.3914],
                      [-0.3325, -2.9091,  0.7640,  ..., -0.2837, -3.5142,  0.7057],
                      [ 1.8181, -0.3444,  1.3652,  ..., -1.3464, -0.3616, -1.0449],
                      ...,
                      [ 2.4203, -2.3470, -2.1992,  ...,  2.2182, -0.2923, -2.4632],
                      [ 1.2427,  0.4998,  1.6122,  ..., -0.6251,  1.2276, -1.2728],
                      [-1.3240, -2.5738, -1.7213,  ...,  1.2415,  1.8873, -2.7529]],
                     device='mps:0')),
             ('fc1.weight',
              tensor([[ 1.3234,  0.9590,  0.0095,  ..., -0.1819, -0.5921,  0.4217],
                      [-0.6543, -1.0630, -0.9033,  ..., -0.7070,  1.1529, -1.0007],
                      [-0.7353,  0.7766,  1.5406,  ...,  0.3213,  0.5183, -0.9535],
                      ...,
                      [-1.0038,  1.8119,  0.8362,  ..., -0.8274, -1.0179, -1.0997],
    