In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
from tqdm import tqdm
import numpy as np
import math
import os
import re
import pandas as pd
from typing import List, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
def clean_text(text: str) -> str:
    """Clean and normalize text."""
    text = str(text).lower()
    text = re.sub(r"won't", "will not", text)
    text = re.sub(r"can\'t", "can not", text)
    text = re.sub(r"n\'t", " not", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'s", " is", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'t", " not", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'m", " am", text)
    text = re.sub(r"[^a-zA-Z\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

In [3]:
def tokenize(text: str) -> List[str]:
    """Simple tokenizer that splits on spaces."""
    return clean_text(text).split()

In [4]:
class IMDbDataset(Dataset):
    def __init__(self, texts: List[str], labels: List[int], vocab: dict, max_length: int):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        tokens = tokenize(text)
        token_ids = [self.vocab.get(token, self.vocab['<unk>']) for token in tokens]
        
        # Truncate or pad sequence
        if len(token_ids) > self.max_length:
            token_ids = token_ids[:self.max_length]
        else:
            padding = [self.vocab['<pad>']] * (self.max_length - len(token_ids))
            token_ids.extend(padding)
        
        return torch.tensor(token_ids), torch.tensor(label)

In [5]:
def build_vocab(texts: List[str], min_freq: int = 5) -> dict:
    """Build vocabulary from texts."""
    counter = Counter()
    
    for text in tqdm(texts, desc="Building vocabulary"):
        tokens = tokenize(text)
        counter.update(tokens)
    
    vocab = {'<pad>': 0, '<unk>': 1}
    idx = len(vocab)
    
    for token, count in counter.most_common():
        if count >= min_freq:
            vocab[token] = idx
            idx += 1
    
    return vocab

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        batch_size = x.size(0)
        
        # Linear transformations
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into heads
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        
        # Apply attention to values
        context = torch.matmul(attention, V)
        
        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # Final linear layer
        output = self.W_o(context)
        
        return output, attention

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, dropout):
        super().__init__()
        
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        
        # Stack of encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        # x shape: (batch_size, seq_len)
        
        # Embed and add positional encoding
        x = self.embedding(x)  # (batch_size, seq_len, d_model)
        x = self.positional_encoding(x)
        
        # Store attention weights from all layers
        attention_weights = []
        
        # Pass through encoder layers
        for layer in self.layers:
            x, attn = layer(x, mask)
            attention_weights.append(attn)
        
        x = self.norm(x)
        
        # Return output and attention weights
        return x, attention_weights

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self attention block
        attn_output, attention_weights = self.self_attn(x, mask=mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        
        # Feed forward block
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        
        return x, attention_weights

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:x.size(0)].transpose(0, 1)
        return self.dropout(x)

class SentimentTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, 
                 d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, dropout
        )
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 2)
        )
        
    def forward(self, x, mask=None):
        # Encode the input sequence
        encoded, attention_weights = self.encoder(x, mask)
        
        # Global average pooling
        encoded = encoded.mean(dim=1)  # Average over sequence length
        
        # Classify
        output = self.classifier(encoded)
        
        return output, attention_weights

In [8]:
class AttentionVisualizer:
    @staticmethod
    def plot_attention_heatmap(attention_weights: torch.Tensor, 
                             layer_idx: int, 
                             head_idx: int, 
                             tokens: Optional[List[str]] = None,
                             save_path: Optional[str] = None):
        """
        Plots attention heatmap for a specific layer and head.
        """
        # Get attention weights for specified layer and head
        attention = attention_weights[layer_idx][0, head_idx].detach().cpu().numpy()
        
        # Create figure
        plt.figure(figsize=(10, 8))
        
        # Plot heatmap
        if tokens:
            # Truncate attention matrix to actual tokens length
            attention = attention[:len(tokens), :len(tokens)]
            sns.heatmap(attention, cmap="viridis", xticklabels=tokens, yticklabels=tokens)
        else:
            sns.heatmap(attention, cmap="viridis")
            
        plt.title(f"Attention Weights - Layer {layer_idx + 1}, Head {head_idx + 1}")
        plt.xlabel("Key Positions")
        plt.ylabel("Query Positions")
        
        # Rotate x-axis labels if tokens are provided
        if tokens:
            plt.xticks(rotation=45, ha='right')
            plt.yticks(rotation=0)
        
        # Adjust layout to prevent label cutoff
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def visualize_all_attention_heads(attention_weights: List[torch.Tensor], 
                                    tokens: Optional[List[str]] = None,
                                    save_dir: Optional[str] = None):
        """
        Visualizes attention patterns for all layers and heads.
        """
        num_layers = len(attention_weights)
        num_heads = attention_weights[0].size(1)
        
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        
        for layer in range(num_layers):
            for head in range(num_heads):
                save_path = None
                if save_dir:
                    save_path = os.path.join(save_dir, f'attention_layer{layer+1}_head{head+1}.png')
                AttentionVisualizer.plot_attention_heatmap(
                    attention_weights, layer, head, tokens, save_path
                )

In [9]:
def analyze_text(model: nn.Module, 
                text: str, 
                vocab: dict, 
                max_length: int, 
                device: torch.device,
                visualize: bool = True,
                save_dir: Optional[str] = None) -> Tuple[str, float, List[torch.Tensor]]:
    """
    Analyzes a text sample and optionally visualizes its attention patterns.
    
    Returns:
        Tuple containing (prediction label, confidence score, attention weights)
    """
    model.eval()
    
    # Tokenize and encode
    tokens = tokenize(text)
    token_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]
    
    # Truncate or pad sequence
    if len(token_ids) > max_length:
        token_ids = token_ids[:max_length]
        tokens = tokens[:max_length]
    else:
        padding_length = max_length - len(token_ids)
        token_ids.extend([vocab['<pad>']] * padding_length)
    
    # Predict
    with torch.no_grad():
        sequence = torch.tensor(token_ids).unsqueeze(0).to(device)
        mask = (sequence != 0).unsqueeze(1).unsqueeze(2)
        
        outputs, attention_weights = model(sequence, mask)
        prediction = F.softmax(outputs, dim=1)
        
        # Get prediction details
        pred_class = torch.argmax(prediction).item()
        pred_prob = prediction[0][pred_class].item()
        sentiment = "Positive" if pred_class == 1 else "Negative"
        
        # Visualize if requested
        if visualize:
            print(f"\nAnalyzing text: {text}")
            print(f"Prediction: {sentiment} (confidence: {pred_prob:.2f})")
            print("\nGenerating attention visualizations...")
            
            AttentionVisualizer.visualize_all_attention_heads(
                attention_weights, 
                tokens,
                save_dir
            )
        
        return sentiment, pred_prob, attention_weights

In [10]:
def main():
    # Configuration
    MAX_LENGTH = 256
    BATCH_SIZE = 32
    D_MODEL = 256
    NUM_HEADS = 8
    NUM_LAYERS = 1
    D_FF = 1024
    DROPOUT = 0.1
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 1
    MIN_FREQ = 5
    TRAIN_RATIO = 0.8
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load and prepare data
    df = pd.read_csv("IMDB_Dataset.csv")
    texts = df['review'].tolist()
    labels = (df['sentiment'] == 'positive').astype(int).tolist()
    
    # Build vocabulary
    print("Building vocabulary...")
    vocab = build_vocab(texts, min_freq=MIN_FREQ)
    vocab_size = len(vocab)
    print(f"Vocabulary size: {vocab_size}")
    
    # Create and train model
    model = SentimentTransformer(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT
    ).to(device)
    
    # Example usage after training
    sample_text = "This movie was absolutely fantastic! The acting was superb and the story was engaging throughout."
    
    # Analyze text and visualize attention
    sentiment, confidence, attention_weights = analyze_text(
        model=model,
        text=sample_text,
        vocab=vocab,
        max_length=MAX_LENGTH,
        device=device,
        visualize=True,
        save_dir='attention_plots'
    )

if __name__ == "__main__":
    main()

Using device: cuda
Building vocabulary...


Building vocabulary: 100%|██████████████████████████████████████████████████████| 50000/50000 [00:05<00:00, 9081.17it/s]


Vocabulary size: 39412

Analyzing text: This movie was absolutely fantastic! The acting was superb and the story was engaging throughout.
Prediction: Positive (confidence: 0.58)

Generating attention visualizations...
