# 📧 Step 4: Real Email Classification with LoRA

## Week 7-8: Practical LoRA Application

Now we'll build the **real project** required by the curriculum: a practical email/ticket classifier using LoRA fine-tuning.

### 🎯 What You'll Learn:
1. **Realistic dataset handling** - Working with synthetic data modeled on actual business emails
2. **Data preprocessing** - Cleaning and preparing text for training
3. **Dataset augmentation** - Creating more training data
4. **Business problem solving** - Email classification is used in real companies
5. **Production considerations** - How to make this work in the real world

### 🏢 Business Context:
Email classification is crucial for:
- **Customer Support**: Routing tickets to the right teams
- **Sales**: Prioritizing leads and opportunities  
- **Security**: Detecting spam and phishing
- **Organization**: Auto-categorizing internal communications

This is a **real problem** that AI engineers solve every day!

In [None]:
# Essential imports for our email classification project
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import re
import random
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

print("📧 Email Classification Project Setup Complete!")
print(f"Using PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 🗂️ Part 1: Creating Realistic Email Dataset

Since we don't have a labeled corpus of real emails, we'll generate a **synthetic email classification dataset** from templates that mimic actual business emails.

### 🧠 Learning Objective:
- **Data generation for ML**: Often you need to create or augment datasets
- **Business domain understanding**: What makes emails different categories?
- **Text preprocessing**: How to clean and prepare text data
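
One catch with template-based generation is worth checking early: a handful of templates and suffixes can only produce a limited number of distinct texts, so a large sample will contain many exact duplicates, and a random split will likely place identical emails in both the training and test sets. The sketch below shows the effect on a hypothetical miniature generator; the `templates` and `suffixes` lists are made up for illustration and are not the ones used in this notebook.

```python
import random
from collections import Counter

random.seed(0)

# Hypothetical miniature generator: 3 templates x 3 suffixes = at most 9 distinct texts
templates = [
    "The server is down.",
    "Please send pricing details.",
    "Meeting moved to Friday.",
]
suffixes = ["", " Thanks for your help!", " Best regards, Sarah"]

samples = [random.choice(templates) + random.choice(suffixes) for _ in range(150)]

counts = Counter(samples)
unique_texts = list(counts)  # one copy of each distinct text
duplicate_share = 1 - len(unique_texts) / len(samples)

print(f"Generated: {len(samples)}, distinct: {len(unique_texts)}")
print(f"Share of samples that repeat an earlier text: {duplicate_share:.0%}")
```

If the distinct count is far below the sample count, test accuracy mostly measures memorization. Deduplicating before the split, or holding out whole templates for testing, would give a more honest estimate.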

In [None]:
class EmailDatasetGenerator:
    """
    Generates realistic email dataset for classification training
    
    This teaches us:
    1. How to structure business problems
    2. What makes good training data
    3. Domain knowledge importance
    """
    
    def __init__(self):
        # Email categories we'll classify (common business use case)
        self.categories = {
            'urgent': 0,      # High priority - needs immediate attention
            'support': 1,     # Customer support requests
            'sales': 2,       # Sales inquiries and leads
            'spam': 3,        # Spam/promotional emails
            'normal': 4       # Regular business communication
        }
        
        # Templates for each category (this mimics real business emails)
        self.email_templates = {
            'urgent': [
                "URGENT: System is down and affecting all customers. Please respond immediately.",
                "CRITICAL: Security breach detected. Need immediate action from the team.",
                "EMERGENCY: Client meeting in 1 hour, presentation file corrupted. Help needed now!",
                "URGENT RESPONSE NEEDED: Major bug in production causing revenue loss.",
                "CRITICAL ISSUE: Payment system not working, customers cannot complete purchases.",
                "IMMEDIATE ACTION REQUIRED: Server outage affecting 1000+ users.",
                "URGENT: CEO wants status update on project by end of day.",
                "EMERGENCY: Database corruption, need to restore from backup immediately."
            ],
            'support': [
                "Hi, I'm having trouble logging into my account. Can you help me reset my password?",
                "Hello, the software keeps crashing when I try to export data. What should I do?",
                "I can't find the feature you mentioned in the tutorial. Could you guide me?",
                "The mobile app is not syncing with my desktop. How can I fix this?",
                "I accidentally deleted important files. Is there a way to recover them?",
                "The payment didn't go through but money was deducted. Please check my account.",
                "I need help setting up the integration with our existing system.",
                "The dashboard is showing incorrect data. Can someone look into this?"
            ],
            'sales': [
                "I'm interested in your enterprise plan. Can you send me pricing information?",
                "We're looking for a solution for our team of 50 people. What do you recommend?",
                "Could we schedule a demo to see how your product fits our needs?",
                "I saw your product at the conference. Can we discuss a potential partnership?",
                "Our company is expanding and we need a scalable solution. Let's talk.",
                "I'm evaluating different vendors. What makes your product unique?",
                "We have a budget approved for this quarter. When can we start implementation?",
                "Can you provide a quote for 200 user licenses with premium support?"
            ],
            'spam': [
                "Congratulations! You've won $1,000,000! Click here to claim your prize now!",
                "AMAZING OFFER: Buy one get one free! Limited time only! Act now!",
                "Your account will be suspended unless you verify immediately. Click this link.",
                "Hot singles in your area want to meet you! Join now for free!",
                "Make money from home! No experience needed! Start earning today!",
                "FINAL NOTICE: Your warranty is about to expire. Renew now for 50% off!",
                "Free iPhone 14! You've been selected! Claim within 24 hours!",
                "URGENT: Your PayPal account has been compromised. Verify now!"
            ],
            'normal': [
                "Thanks for the meeting yesterday. Here are the notes we discussed.",
                "The project timeline looks good. Let's proceed with the next phase.",
                "I've reviewed the documents and have a few questions. Can we chat tomorrow?",
                "The team meeting is scheduled for Friday at 2 PM in conference room B.",
                "Please find attached the monthly report for your review.",
                "I'll be out of office next week. John will handle any urgent matters.",
                "The client approved the proposal. We can start implementation next month.",
                "Can you send me the latest version of the design mockups?"
            ]
        }
        
        # Subject line patterns (important for email classification)
        self.subject_patterns = {
            'urgent': ['URGENT:', 'CRITICAL:', 'EMERGENCY:', 'IMMEDIATE:', 'ASAP:'],
            'support': ['Help:', 'Issue:', 'Problem:', 'Question:', 'Trouble:'],
            'sales': ['Inquiry:', 'Demo:', 'Partnership:', 'Quote:', 'Pricing:'],
            'spam': ['FREE!', 'WIN!', 'AMAZING!', 'CONGRATULATIONS!', 'LIMITED TIME!'],
            'normal': ['Re:', 'Meeting:', 'Update:', 'FYI:', 'Report:']
        }
    
    def generate_email(self, category: str) -> Dict[str, str]:
        """
        Generate a single email for the given category
        
        This teaches us how to create realistic training data
        """
        # Choose random template and subject pattern
        body = random.choice(self.email_templates[category])
        subject_prefix = random.choice(self.subject_patterns[category])
        
        # Add some variation to make it more realistic
        variations = [
            lambda text: text + " Please let me know if you need any clarification.",
            lambda text: "Hi team, " + text.lower(),
            lambda text: text + " Thanks for your help!",
            lambda text: text + " Best regards, " + random.choice(["John", "Sarah", "Mike", "Emma"]),
            lambda text: text  # No variation
        ]
        
        # Apply random variation
        if random.random() > 0.3:  # 70% chance of entering this branch (one variation is a no-op)
            body = random.choice(variations)(body)
        
        # Create subject line
        subject = f"{subject_prefix} {body.split('.')[0][:50]}..."
        
        return {
            'subject': subject,
            'body': body,
            'category': category,
            'label': self.categories[category]
        }
    
    def generate_dataset(self, samples_per_category: int = 200) -> pd.DataFrame:
        """
        Generate a complete dataset
        
        This teaches us about balanced datasets and data distribution
        """
        print(f"🏗️  Generating realistic email dataset...")
        print(f"   Categories: {list(self.categories.keys())}")
        print(f"   Samples per category: {samples_per_category}")
        
        emails = []
        
        for category in self.categories.keys():
            for _ in range(samples_per_category):
                email = self.generate_email(category)
                emails.append(email)
        
        # Shuffle the dataset
        random.shuffle(emails)
        
        df = pd.DataFrame(emails)
        
        print(f"✅ Dataset generated successfully!")
        print(f"   Total samples: {len(df)}")
        print(f"   Class distribution:")
        
        class_counts = df['category'].value_counts()
        for category, count in class_counts.items():
            print(f"      {category}: {count} samples")
        
        return df

# Generate our dataset
generator = EmailDatasetGenerator()
email_df = generator.generate_dataset(samples_per_category=150)  # 750 total samples

# Display sample emails
print("\n📧 Sample Emails:")
print("=" * 50)

for category in ['urgent', 'support', 'sales', 'spam', 'normal']:
    sample = email_df[email_df['category'] == category].iloc[0]
    print(f"\n🏷️  {category.upper()}:")
    print(f"   Subject: {sample['subject'][:60]}...")
    print(f"   Body: {sample['body'][:80]}...")

## 🔍 Part 2: Data Analysis and Preprocessing

### 🧠 Learning Objective:
**Understanding your data is crucial for ML success!** We'll analyze patterns, clean text, and prepare for training.

This teaches us:
- **Exploratory Data Analysis (EDA)**: Understanding data patterns
- **Text preprocessing**: Essential for NLP tasks
- **Feature engineering**: Creating better inputs for models
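
To make the preprocessing idea concrete, here is a minimal sketch of regex-based cleaning on a single string. It reuses the email-address and phone-number patterns from the preprocessor in this section; the standalone `clean` function and the sample sentence are illustrative assumptions. Note the order of operations: lowercasing happens first, so the placeholder tokens inserted afterwards keep their uppercase form.

```python
import re

EMAIL_RE = re.compile(r'\S+@\S+\.\S+')
PHONE_RE = re.compile(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b')

def clean(text: str) -> str:
    """Lowercase, mask emails and phone numbers, collapse punctuation and whitespace."""
    text = text.lower()                   # normalize case first
    text = EMAIL_RE.sub('[EMAIL]', text)  # placeholders added after lowercasing stay uppercase
    text = PHONE_RE.sub('[PHONE]', text)
    text = re.sub(r'!{2,}', '!', text)    # collapse repeated exclamation marks
    text = re.sub(r'\s+', ' ', text)      # collapse runs of whitespace
    return text.strip()

print(clean("Contact   ME at John@Example.com or 555-123-4567 NOW!!!"))
# contact me at [EMAIL] or [PHONE] now!
```

The email pattern is greedy (`\S+`), so punctuation glued to an address, such as a trailing period, would be swallowed into the placeholder. That is usually acceptable for classification but worth knowing when inspecting cleaned text.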

In [None]:
def analyze_email_dataset(df: pd.DataFrame):
    """
    Comprehensive analysis of our email dataset
    
    This teaches us how to understand our data before training
    """
    print("📊 Email Dataset Analysis")
    print("=" * 40)
    
    # Basic statistics
    print(f"Dataset shape: {df.shape}")
    print(f"Categories: {df['category'].unique()}")
    print(f"Missing values: {df.isnull().sum().sum()}")
    
    # Text length analysis
    df['subject_length'] = df['subject'].str.len()
    df['body_length'] = df['body'].str.len()
    df['total_length'] = df['subject_length'] + df['body_length']
    
    # Create visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Class distribution
    category_counts = df['category'].value_counts()
    axes[0,0].bar(category_counts.index, category_counts.values, color='skyblue')
    axes[0,0].set_title('Email Category Distribution')
    axes[0,0].set_xlabel('Category')
    axes[0,0].set_ylabel('Count')
    axes[0,0].tick_params(axis='x', rotation=45)
    
    # 2. Text length distribution by category
    for category in df['category'].unique():
        category_data = df[df['category'] == category]['total_length']
        axes[0,1].hist(category_data, alpha=0.7, label=category, bins=20)
    
    axes[0,1].set_title('Text Length Distribution by Category')
    axes[0,1].set_xlabel('Total Text Length (characters)')
    axes[0,1].set_ylabel('Frequency')
    axes[0,1].legend()
    
    # 3. Average text lengths by category
    avg_lengths = df.groupby('category')['total_length'].mean().sort_values(ascending=False)
    axes[1,0].bar(avg_lengths.index, avg_lengths.values, color='lightcoral')
    axes[1,0].set_title('Average Text Length by Category')
    axes[1,0].set_xlabel('Category')
    axes[1,0].set_ylabel('Average Length (characters)')
    axes[1,0].tick_params(axis='x', rotation=45)
    
    # 4. Word count analysis
    df['word_count'] = df['body'].str.split().str.len()
    word_stats = df.groupby('category')['word_count'].agg(['mean', 'std']).round(2)
    
    x_pos = np.arange(len(word_stats.index))
    axes[1,1].bar(x_pos, word_stats['mean'], yerr=word_stats['std'], 
                  capsize=5, color='lightgreen', alpha=0.8)
    axes[1,1].set_title('Average Word Count by Category')
    axes[1,1].set_xlabel('Category')
    axes[1,1].set_ylabel('Average Word Count')
    axes[1,1].set_xticks(x_pos)
    axes[1,1].set_xticklabels(word_stats.index, rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed statistics
    print("\n📈 Detailed Statistics:")
    print(f"Average text lengths by category:")
    for category, length in avg_lengths.items():
        print(f"   {category}: {length:.1f} characters")
    
    print(f"\nWord count statistics:")
    for category in word_stats.index:
        mean_words = word_stats.loc[category, 'mean']
        std_words = word_stats.loc[category, 'std']
        print(f"   {category}: {mean_words:.1f} ± {std_words:.1f} words")
    
    return df

# Analyze our dataset
email_df = analyze_email_dataset(email_df)

In [None]:
class EmailTextPreprocessor:
    """
    Preprocesses email text for better model performance
    
    This teaches us essential text preprocessing techniques:
    1. Cleaning and normalization
    2. Feature engineering for emails
    3. Preparing text for transformer models
    """
    
    def __init__(self):
        # Common email patterns to clean
        self.email_patterns = {
            'email_addresses': r'\S+@\S+\.\S+',
            'urls': r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
            'phone_numbers': r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
            'excessive_whitespace': r'\s+',
            'special_chars': r'[^a-zA-Z0-9\s.,!?;:]'
        }
    
    def clean_text(self, text: str) -> str:
        """
        Clean and normalize email text
        
        This is crucial for good model performance!
        """
        if pd.isna(text):
            return ""
        
        # Convert to lowercase
        text = text.lower()
        
        # Remove email addresses (replace with token)
        text = re.sub(self.email_patterns['email_addresses'], '[EMAIL]', text)
        
        # Remove URLs (replace with token)
        text = re.sub(self.email_patterns['urls'], '[URL]', text)
        
        # Remove phone numbers (replace with token)
        text = re.sub(self.email_patterns['phone_numbers'], '[PHONE]', text)
        
        # Clean excessive whitespace
        text = re.sub(self.email_patterns['excessive_whitespace'], ' ', text)
        
        # Remove excessive punctuation
        text = re.sub(r'[!]{2,}', '!', text)
        text = re.sub(r'[?]{2,}', '?', text)
        
        # Strip whitespace
        text = text.strip()
        
        return text
    
    def create_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Create additional features that help with classification
        
        This teaches us feature engineering for NLP
        """
        print("🔧 Creating additional features...")
        
        # Clean the text
        df['clean_subject'] = df['subject'].apply(self.clean_text)
        df['clean_body'] = df['body'].apply(self.clean_text)
        
        # Combine subject and body (important for email classification)
        df['combined_text'] = df['clean_subject'] + ' ' + df['clean_body']
        
        # Feature engineering - these help models understand email patterns
        df['has_urgent_words'] = df['combined_text'].str.contains(
            r'urgent|emergency|critical|asap|immediate', case=False
        ).astype(int)
        
        # Raw string: '\?' in a normal string is an invalid escape sequence on newer Python versions
        df['has_question_marks'] = df['combined_text'].str.count(r'\?')
        df['has_exclamation'] = df['combined_text'].str.count('!')
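        # Flag subjects containing an all-caps word of 3+ letters. Calling str.isupper() on the
        # whole subject is always False here because the subject embeds mixed-case body text.
        df['has_caps'] = df['subject'].str.contains(r'\b[A-Z]{3,}\b', regex=True).astype(int)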
        
        # Count spam indicators
        spam_words = ['free', 'win', 'prize', 'click', 'offer', 'limited time']
        df['spam_word_count'] = df['combined_text'].apply(
            lambda x: sum(1 for word in spam_words if word in x.lower())
        )
        
        print("✅ Feature engineering complete!")
        print(f"   Created features: has_urgent_words, has_question_marks, has_exclamation, has_caps, spam_word_count")
        
        return df
    
    def prepare_for_training(self, df: pd.DataFrame, test_size: float = 0.2) -> Tuple:
        """
        Prepare the dataset for training
        
        This teaches us proper train/test splitting and data preparation
        """
        print(f"🎯 Preparing dataset for training...")
        
        # Create features
        df = self.create_features(df)
        
        # Features for training (we'll use the combined text)
        X = df['combined_text'].values
        y = df['label'].values
        
        # Split the dataset
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=42, stratify=y
        )
        
        print(f"✅ Dataset split complete!")
        print(f"   Training samples: {len(X_train)}")
        print(f"   Testing samples: {len(X_test)}")
        print(f"   Classes: {len(np.unique(y))}")
        
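        # Print class distribution. Map labels back to category names using the DataFrame
        # itself, so this method does not depend on a global `generator` object.
        label_to_category = dict(zip(df['label'], df['category']))
        train_dist = pd.Series(y_train).value_counts().sort_index()
        test_dist = pd.Series(y_test).value_counts().sort_index()
        
        print(f"\n   Training distribution:")
        for label, count in train_dist.items():
            print(f"      {label_to_category[label]}: {count} samples")
        
        print(f"\n   Testing distribution:")
        for label, count in test_dist.items():
            print(f"      {label_to_category[label]}: {count} samples")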
        
        return X_train, X_test, y_train, y_test, df

# Preprocess our data
preprocessor = EmailTextPreprocessor()
X_train, X_test, y_train, y_test, processed_df = preprocessor.prepare_for_training(email_df)

# Show sample processed data
print("\n📧 Sample Processed Emails:")
print("=" * 50)

for i in range(3):
    original_text = email_df.iloc[i]['subject'] + ' ' + email_df.iloc[i]['body']
    processed_text = processed_df.iloc[i]['combined_text']
    category = processed_df.iloc[i]['category']
    
    print(f"\n🏷️  {category.upper()}:")
    print(f"   Original: {original_text[:100]}...")
    print(f"   Processed: {processed_text[:100]}...")
    print(f"   Features: urgent={processed_df.iloc[i]['has_urgent_words']}, "
          f"caps={processed_df.iloc[i]['has_caps']}, "
          f"spam_words={processed_df.iloc[i]['spam_word_count']}")

## 🏗️ Part 3: LoRA Email Classifier Architecture

### 🧠 Learning Objective:
Now we'll build our **complete email classifier** using LoRA. This teaches us:

- **Architecture design**: How to structure ML systems
- **LoRA integration**: Practical application of our previous learning
- **Custom datasets**: How to create PyTorch datasets
- **End-to-end pipeline**: From raw text to predictions
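
To make the LoRA part concrete: an adapter on a linear layer with `d_in` inputs and `d_out` outputs adds two matrices, A (`rank × d_in`) and B (`d_out × rank`), so it trains `rank * (d_in + d_out)` parameters instead of `d_in * d_out`, and the update is scaled by `alpha / rank`. The sketch below works through that arithmetic, assuming DistilBERT-base dimensions (hidden size 768, 6 transformer layers) and adapters on the query, key, and value projections. The helper name `lora_param_count` is hypothetical.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one LoRA adapter: A is (rank, d_in), B is (d_out, rank)."""
    return rank * (d_in + d_out)

hidden_size = 768      # DistilBERT-base hidden size (assumed model)
num_layers = 6         # DistilBERT-base transformer layers
adapted_per_layer = 3  # query, key, and value projections
rank, alpha = 8, 16.0

per_adapter = lora_param_count(hidden_size, hidden_size, rank)
total_lora = per_adapter * adapted_per_layer * num_layers
full_weights = hidden_size * hidden_size * adapted_per_layer * num_layers

print(f"Per adapter: {per_adapter:,}")                 # 12,288
print(f"All adapters: {total_lora:,}")                 # 221,184
print(f"Frozen weights they adapt: {full_weights:,}")  # 10,616,832
print(f"Scaling alpha / rank: {alpha / rank}")         # 2.0
```

Under these assumptions the adapters train roughly 2% as many parameters as the projection weights they modify. The classification head adds its own trainable parameters on top of that.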

In [None]:
# Recreate the LoRA layer from Step 3 here so this notebook is self-contained
# (no external import of the earlier implementation is needed)
class AdvancedLoRALayer(nn.Module):
    """
    Our proven LoRA implementation from Step 3
    """
    
    def __init__(
        self, 
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 32.0,
        dropout: float = 0.1,
        init_lora_weights: bool = True
    ):
        super().__init__()
        
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Get dimensions
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features
        
        # LoRA parameters
        self.lora_A = nn.Parameter(torch.empty(rank, self.in_features))
        self.lora_B = nn.Parameter(torch.empty(self.out_features, rank))
        
        # Dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        
        # Initialize weights
        if init_lora_weights:
            self.reset_lora_parameters()
        
        # Freeze original parameters
        for param in self.original_layer.parameters():
            param.requires_grad = False
    
    def reset_lora_parameters(self):
        import math
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.original_layer(x)
        
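        # LoRA forward pass: x -> A (in_features -> rank) -> dropout -> B (rank -> out_features)
        lora_output = F.linear(x, self.lora_A)
        lora_output = self.dropout(lora_output)
        # F.linear computes input @ weight.T, so lora_B (out_features, rank) is passed untransposed
        lora_output = F.linear(lora_output, self.lora_B)
        
        # Out-of-place add leaves the frozen layer's output tensor unmodified
        result = result + lora_output * self.scaling
        return result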
    
    def get_lora_parameters(self):
        return [self.lora_A, self.lora_B]

class EmailDataset(Dataset):
    """
    Custom PyTorch dataset for email classification
    
    This teaches us how to create custom datasets for specific tasks
    """
    
    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Tokenize the text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

class LoRAEmailClassifier(nn.Module):
    """
    Complete email classifier with LoRA fine-tuning
    
    This is our main model that combines everything we've learned
    """
    
    def __init__(
        self, 
        model_name: str = 'distilbert-base-uncased',
        num_classes: int = 5,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.1
    ):
        super().__init__()
        
        print(f"🏗️  Building LoRA Email Classifier...")
        print(f"   Base model: {model_name}")
        print(f"   Classes: {num_classes}")
        print(f"   LoRA rank: {lora_rank}, alpha: {lora_alpha}")
        
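        # Load base model
        self.backbone = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Freeze the whole backbone so only the LoRA matrices and the classifier head train
        for param in self.backbone.parameters():
            param.requires_grad = False
        
        # Apply LoRA to attention layers
        self.lora_layers = {}
        self._apply_lora(lora_rank, lora_alpha, lora_dropout)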
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.backbone.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
        # Initialize classifier weights
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
        
        self._print_parameter_stats()
    
    def _apply_lora(self, rank: int, alpha: float, dropout: float):
        """
        Apply LoRA to the attention projection layers
        """
        # DistilBERT names its attention projections q_lin/k_lin/v_lin;
        # BERT-style models use query/key/value. Match on the leaf module name.
        target_modules = {'q_lin', 'k_lin', 'v_lin', 'query', 'key', 'value'}
        replaced_count = 0
        
        # Snapshot the module list first, because modules are replaced inside the loop
        for name, module in list(self.backbone.named_modules()):
            module_parts = name.split('.')
            if module_parts[-1] in target_modules and isinstance(module, nn.Linear):
                # Wrap the linear layer with a LoRA adapter
                lora_layer = AdvancedLoRALayer(
                    module, rank=rank, alpha=alpha, dropout=dropout
                )
                
                # Walk down to the parent module and swap in the LoRA version
                parent_module = self.backbone
                for part in module_parts[:-1]:
                    parent_module = getattr(parent_module, part)
                
                setattr(parent_module, module_parts[-1], lora_layer)
                self.lora_layers[name] = lora_layer
                replaced_count += 1
        
        print(f"   ✅ Applied LoRA to {replaced_count} layers")
    
    def _print_parameter_stats(self):
        """
        Print parameter statistics
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        print(f"\n📊 Model Statistics:")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        print(f"   Trainable percentage: {100 * trainable_params / total_params:.2f}%")
        print(f"   Parameter reduction (total / trainable): {total_params / trainable_params:.1f}x")
    
    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass through the classifier
        """
        # Get backbone outputs
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Use [CLS] token for classification
        cls_output = outputs.last_hidden_state[:, 0]  # [CLS] token
        
        # Classification
        logits = self.classifier(cls_output)
        
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        
        return {
            'logits': logits,
            'loss': loss,
            'predictions': torch.argmax(logits, dim=-1)
        }
    
    def get_trainable_parameters(self):
        """
        Get all trainable parameters for the optimizer
        """
        return [p for p in self.parameters() if p.requires_grad]

# Create our email classifier
model = LoRAEmailClassifier(
    model_name='distilbert-base-uncased',
    num_classes=5,
    lora_rank=8,
    lora_alpha=16.0
)

print("\n✅ LoRA Email Classifier built successfully!")

## 🎯 Part 4: Training Pipeline Setup

### 🧠 Learning Objective:
Now we'll create a **professional training pipeline** that you'd use in real projects. This teaches us:

- **Training loop design**: How to structure training for production
- **Monitoring and logging**: Track progress and debug issues
- **Validation strategy**: Ensure our model generalizes well
- **Best practices**: Professional ML engineering techniques
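
One cheap monitoring check before any training: a classifier with freshly initialized weights should be roughly undecided, so its cross-entropy loss on a 5-class problem should sit near ln(5) ≈ 1.61. A first-batch loss far from that value often points to a labeling or wiring bug. The sketch below computes the baseline in plain Python; the `cross_entropy` helper is an illustrative stand-in, not the PyTorch function.

```python
import math
from typing import List

def cross_entropy(logits: List[float], label: int) -> float:
    """Cross-entropy of one example: -log softmax(logits)[label], computed stably."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]

num_classes = 5
uniform_loss = cross_entropy([0.0] * num_classes, label=2)
confident_loss = cross_entropy([0.0, 0.0, 10.0, 0.0, 0.0], label=2)

print(f"Undecided model: {uniform_loss:.4f}")  # 1.6094, i.e. ln(5)
print(f"Confident and correct: {confident_loss:.4f}")
```

As training progresses, the loss should move from the undecided baseline toward the near-zero value of the confident case. A loss that stays at the baseline suggests the trainable parameters are not receiving gradients.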

In [None]:
# Create datasets and dataloaders
print("🔄 Creating training and validation datasets...")

# Create datasets
train_dataset = EmailDataset(
    texts=X_train.tolist(),
    labels=y_train.tolist(),
    tokenizer=model.tokenizer,
    max_length=128
)

test_dataset = EmailDataset(
    texts=X_test.tolist(),
    labels=y_test.tolist(),
    tokenizer=model.tokenizer,
    max_length=128
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=16,  # Small batch size for memory efficiency
    shuffle=True,
    num_workers=0  # Set to 0 for Windows compatibility
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,  # Larger batch for evaluation
    shuffle=False,
    num_workers=0
)

print(f"✅ Datasets created:")
print(f"   Training batches: {len(train_loader)}")
print(f"   Testing batches: {len(test_loader)}")
print(f"   Training batch size: {train_loader.batch_size}")
print(f"   Testing batch size: {test_loader.batch_size}")

# Test a batch to make sure everything works
sample_batch = next(iter(train_loader))
print(f"\n🧪 Sample batch shapes:")
print(f"   Input IDs: {sample_batch['input_ids'].shape}")
print(f"   Attention mask: {sample_batch['attention_mask'].shape}")
print(f"   Labels: {sample_batch['labels'].shape}")

# Test forward pass
model.eval()
with torch.no_grad():
    outputs = model(
        input_ids=sample_batch['input_ids'],
        attention_mask=sample_batch['attention_mask'],
        labels=sample_batch['labels']
    )

print(f"\n✅ Forward pass test successful:")
print(f"   Output logits shape: {outputs['logits'].shape}")
print(f"   Loss: {outputs['loss'].item():.4f}")
print(f"   Predictions shape: {outputs['predictions'].shape}")