# Explainable DistilBERT for Phishing Email Detection

This notebook fine-tunes DistilBERT, a transformer-based model, to classify emails as phishing or legitimate, and then explains its predictions with LIME and SHAP.

## Overview
- **Model**: DistilBERT for sequence classification
- **Dataset**: Phishing vs legitimate emails  
- **Explainability**: LIME and SHAP for model interpretability
- **Evaluation**: Accuracy, precision, recall, F1, PR-AUC, a confusion matrix, and a precision-recall curve

## 1. Import Required Libraries

In [17]:
# Core libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# PyTorch and transformers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    DistilBertTokenizer, 
    DistilBertForSequenceClassification,
    TrainingArguments, 
    Trainer,
    pipeline
)

# Sklearn for metrics and preprocessing
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, 
    confusion_matrix, classification_report,
    precision_recall_curve, average_precision_score
)
from sklearn.model_selection import train_test_split

# Explainability libraries
try:
    import lime
    from lime.lime_text import LimeTextExplainer
    print("LIME imported successfully")
except ImportError:
    print("LIME not installed. Install with: pip install lime")
    
try:
    import shap
    print("SHAP imported successfully")
except ImportError:
    print("SHAP not installed. Install with: pip install shap")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

LIME not installed. Install with: pip install lime
SHAP not installed. Install with: pip install shap
All libraries imported successfully!
PyTorch version: 2.9.1+cpu
CUDA available: False
Using device: cpu
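
The output above reports that all libraries imported successfully even though LIME and SHAP are missing, because the final message is printed unconditionally. One possible way to make the status explicit is to record availability flags that later cells can check. The sketch below is an assumption rather than part of the original notebook; the helper `optional_available` and the dictionary `OPTIONAL_LIBS` are hypothetical names.

```python
import importlib.util

def optional_available(module_name):
    """Return True if a top-level module can be found in this environment."""
    return importlib.util.find_spec(module_name) is not None

# Hypothetical availability flags for the optional explainability libraries
OPTIONAL_LIBS = {name: optional_available(name) for name in ("lime", "shap")}

missing = [name for name, ok in OPTIONAL_LIBS.items() if not ok]
if missing:
    print(f"Optional libraries missing: {', '.join(missing)}")
else:
    print("All optional libraries are available")
```

Later cells could then test `OPTIONAL_LIBS["lime"]` or `OPTIONAL_LIBS["shap"]` before running an analysis instead of relying on an exception.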


## 2. Load and Prepare Data

In [6]:
# Load the preprocessed data
try:
    train_df = pd.read_csv("data/preprocessing/train.csv")
    val_df = pd.read_csv("data/preprocessing/val.csv")
    test_df = pd.read_csv("data/preprocessing/test.csv")
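    print("✓ Data loaded successfully from existing splits")
except FileNotFoundError:
    # The rest of the notebook needs train_df, val_df, and test_df, so fail
    # early instead of continuing with undefined variables.
    raise FileNotFoundError(
        "Preprocessed splits not found in data/preprocessing/ "
        "(expected train.csv, val.csv, and test.csv)."
    )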

# Combine subject and body text
def combine_text(row):
    subject = str(row['subject']) if pd.notna(row['subject']) else ""
    body = str(row['body']) if pd.notna(row['body']) else ""
    return f"{subject} {body}".strip()

# Apply text combination to all datasets
for df in [train_df, val_df, test_df]:
    df['text'] = df.apply(combine_text, axis=1)

# Data overview
print("\n📊 Dataset Overview:")
print(f"Train samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")
print(f"\nClass distribution in training set:")
print(train_df['label'].value_counts().rename({0: 'Legitimate', 1: 'Phishing'}))

# Text length analysis
train_df['text_length'] = train_df['text'].str.len()
print(f"\nText length statistics:")
print(f"Mean: {train_df['text_length'].mean():.0f} characters")
print(f"Median: {train_df['text_length'].median():.0f} characters")
print(f"Max: {train_df['text_length'].max()} characters")

# Show sample texts
print("\n📝 Sample legitimate email:")
legit_sample = train_df[train_df['label'] == 0]['text'].iloc[0]
print(f"'{legit_sample[:200]}...'")

print("\n🎣 Sample phishing email:")
phish_sample = train_df[train_df['label'] == 1]['text'].iloc[0]
print(f"'{phish_sample[:200]}...'")

# Prepare data for model training
X_train = train_df['text'].tolist()
y_train = train_df['label'].tolist()
X_val = val_df['text'].tolist()
y_val = val_df['label'].tolist()
X_test = test_df['text'].tolist()
y_test = test_df['label'].tolist()

print(f"\n✓ Data prepared for model training!")

✓ Data loaded successfully from existing splits

📊 Dataset Overview:
Train samples: 57325
Validation samples: 12284
Test samples: 12284

Class distribution in training set:
label
Phishing      29841
Legitimate    27484
Name: count, dtype: int64

Text length statistics:
Mean: 1738 characters
Median: 733 characters
Max: 4546559 characters

📝 Sample legitimate email:
'Re: [Python-3000] Types and classes On Wed, Apr 2, 2008 at 5:51 PM, Guido van Rossum wrote: > I have no idea what you are saying here (and I did s/since/sense/ :-). Another lesson to me, that I should...'

🎣 Sample phishing email:
'Engaging RX Offers Superior Medical Reductions http://smallworldtho.spaces.live.com/default.aspx environment. In other In a way that makes you Linda F., New York...'

✓ Data prepared for model training!


## 3. Text Preprocessing for Transformers

In [7]:
# Initialize DistilBERT tokenizer
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Define maximum sequence length (DistilBERT max is 512)
MAX_LENGTH = 512

class EmailDataset(Dataset):
    """Custom dataset class for email classification"""
    
    def __init__(self, texts, labels, tokenizer, max_length=MAX_LENGTH):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Tokenize the text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Create datasets
print("🔄 Creating datasets...")
train_dataset = EmailDataset(X_train, y_train, tokenizer)
val_dataset = EmailDataset(X_val, y_val, tokenizer)
test_dataset = EmailDataset(X_test, y_test, tokenizer)

print(f"✓ Train dataset: {len(train_dataset)} samples")
print(f"✓ Validation dataset: {len(val_dataset)} samples") 
print(f"✓ Test dataset: {len(test_dataset)} samples")

# Test tokenization on a sample
sample_text = X_train[0][:100]  # First 100 chars
sample_encoding = tokenizer(
    sample_text,
    truncation=True,
    padding='max_length', 
    max_length=50,  # Smaller for display
    return_tensors='pt'
)

print(f"\n📝 Tokenization example:")
print(f"Original text: '{sample_text}'")
print(f"Token IDs shape: {sample_encoding['input_ids'].shape}")
print(f"First 10 token IDs: {sample_encoding['input_ids'][0][:10].tolist()}")
print(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(sample_encoding['input_ids'][0][:10])}")

# Analyze text lengths after tokenization
sample_lengths = []
for text in X_train[:1000]:  # Sample first 1000 for speed
    tokens = tokenizer(text, truncation=False, return_tensors='pt')
    sample_lengths.append(tokens['input_ids'].shape[1])

print(f"\n📊 Token length statistics (sample):")
print(f"Mean tokens: {np.mean(sample_lengths):.1f}")
print(f"95th percentile: {np.percentile(sample_lengths, 95):.0f}")
print(f"Texts truncated at {MAX_LENGTH}: {sum(1 for x in sample_lengths if x > MAX_LENGTH)} ({sum(1 for x in sample_lengths if x > MAX_LENGTH)/len(sample_lengths)*100:.1f}%)")

Token indices sequence length is longer than the specified maximum sequence length for this model (1671 > 512). Running this sequence through the model will result in indexing errors


🔄 Creating datasets...
✓ Train dataset: 57325 samples
✓ Validation dataset: 12284 samples
✓ Test dataset: 12284 samples

📝 Tokenization example:
Original text: 'Engaging RX Offers Superior Medical Reductions http://smallworldtho.spaces.live.com/default.aspx env'
Token IDs shape: torch.Size([1, 50])
First 10 token IDs: [101, 11973, 1054, 2595, 4107, 6020, 2966, 25006, 8299, 1024]
Decoded tokens: ['[CLS]', 'engaging', 'r', '##x', 'offers', 'superior', 'medical', 'reductions', 'http', ':']

📊 Token length statistics (sample):
Mean tokens: 432.2
95th percentile: 1228
Texts truncated at 512: 226 (22.6%)
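
The statistics above count how many sampled emails exceed the 512-token limit, but not how much text the limit discards. A minimal sketch of that second measurement follows; `truncation_stats` is a hypothetical helper, and the lengths passed to it here are illustrative values, not figures from the dataset. In the notebook it could be applied to `sample_lengths`.

```python
def truncation_stats(token_lengths, max_length=512):
    """Summarize how much content a hard token limit discards."""
    truncated = [n for n in token_lengths if n > max_length]
    total_tokens = sum(token_lengths)
    # Tokens beyond max_length are dropped by truncation=True
    lost_tokens = sum(n - max_length for n in truncated)
    return {
        "share_truncated": len(truncated) / len(token_lengths),
        "share_tokens_lost": lost_tokens / total_tokens,
    }

# Illustrative lengths only, not taken from the dataset
stats = truncation_stats([100, 600, 1024], max_length=512)
print(stats)
```

A high share of lost tokens would suggest that signals located late in long emails never reach the model.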



In [None]:
# Check accelerate version for Trainer compatibility
try:
    import accelerate
    print(f"🔍 Current accelerate version: {accelerate.__version__}")
    
    # Try importing Trainer to see if it works
    from transformers import Trainer
    print("✓ Trainer imported successfully!")
    
except ImportError as e:
    if "accelerate" in str(e):
        print(f"❌ Accelerate import error: {str(e)}")
        print("🔧 Trying to fix accelerate installation...")
        
        # Try to install/upgrade accelerate
        import subprocess
        import sys
        
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "accelerate>=0.26.0"])
            print("✓ Accelerate upgraded successfully!")
            
            # Try importing again
            import accelerate
            from transformers import Trainer
            print(f"✓ Now using accelerate version: {accelerate.__version__}")
            
        except Exception as install_error:
            print(f"❌ Failed to fix accelerate: {str(install_error)}")
            print("Manual fix needed: pip install --upgrade 'accelerate>=0.26.0'")
    else:
        print(f"❌ Other import error: {str(e)}")

# Also check transformers version
import transformers
print(f"🔍 Transformers version: {transformers.__version__}")

# Try importing all required components
try:
    from transformers import TrainingArguments, Trainer
    print("✓ All training components imported successfully!")
    trainer_available = True
except ImportError as e:
    print(f"❌ Training components import failed: {str(e)}")
    trainer_available = False

if trainer_available:
    print("🚀 Ready to proceed with training!")
else:
    print("⚠️ Training may not work. Please restart kernel and try again.")

🔍 Checking accelerate version: 1.12.0
✓ accelerate version 1.12.0 is compatible with Trainer
🔍 Transformers version: 4.57.1
✓ All versions compatible for training


## 4. Create DistilBERT Pipeline

In [None]:
# Load pre-trained DistilBERT model for sequence classification
print("🔄 Loading DistilBERT model...")
model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # Binary classification: phishing vs legitimate
    output_attentions=True,  # Enable attention weights for explainability
    output_hidden_states=False
)

# Move model to device
model.to(device)

print(f"✓ Model loaded: {model.__class__.__name__}")
print(f"✓ Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"✓ Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Try to use Trainer, fallback to manual training if it fails
try:
    from transformers import TrainingArguments, Trainer
    
    # Define training arguments
    training_args = TrainingArguments(
        output_dir="./models/distilbert-phishing",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=32,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=400,  # Must be a multiple of eval_steps when load_best_model_at_end=True
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        seed=42,
        dataloader_pin_memory=False,  # May help with memory issues
        remove_unused_columns=False
    )

    # Define metrics computation function
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
        accuracy = accuracy_score(labels, predictions)
        
        return {
            'accuracy': accuracy,
            'f1': f1,
            'precision': precision,
            'recall': recall
        }

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    print("✓ Trainer initialized successfully!")
    print(f"\n📊 Training configuration:")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")
    print(f"  - Warmup steps: {training_args.warmup_steps}")
    print(f"  - Weight decay: {training_args.weight_decay}")
    print(f"  - Evaluation steps: {training_args.eval_steps}")
    print(f"  - Save steps: {training_args.save_steps}")
    
    use_trainer = True

except ImportError as e:
    print(f"❌ Trainer import failed: {str(e)}")
    print("🔄 Setting up manual training instead...")
    
    # Manual training setup
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    
    # Create data loaders
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    # Setup optimizer
    optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    
    print("✓ Manual training setup complete!")
    print("📊 Training will use manual PyTorch training loop")
    
    use_trainer = False
    
    # Store training config for manual training
    manual_training_config = {
        'num_epochs': 3,
        'batch_size': 16,
        'learning_rate': 5e-5,
        'weight_decay': 0.01
    }

print(f"\n🎯 Training method: {'Hugging Face Trainer' if use_trainer else 'Manual PyTorch'}")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🔄 Loading DistilBERT model...
✓ Model loaded: DistilBertForSequenceClassification
✓ Number of parameters: 66,955,010
✓ Trainable parameters: 66,955,010


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>=0.26.0'`
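
`compute_metrics` in the cell above reports precision, recall, and F1 with `average='weighted'`, which averages the per-class scores weighted by class support rather than scoring the phishing class alone. The pure-Python sketch below uses made-up labels to show how the two numbers can differ. The helper names are hypothetical, and scikit-learn remains the implementation the notebook actually uses.

```python
from collections import Counter

def per_class_f1(y_true, y_pred, cls):
    """F1 for one class, treating that class as the positive label."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == cls and p == cls)
    fp = sum(1 for t, p in pairs if t != cls and p == cls)
    fn = sum(1 for t, p in pairs if t == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def weighted_f1(y_true, y_pred):
    """Support-weighted mean of the per-class F1 scores."""
    support = Counter(y_true)
    n = len(y_true)
    return sum(per_class_f1(y_true, y_pred, c) * k / n for c, k in support.items())

# Made-up labels: 0 = legitimate, 1 = phishing
y_true_demo = [0, 0, 0, 0, 1, 1]
y_pred_demo = [0, 0, 0, 1, 1, 0]

print(f"Phishing-only F1: {per_class_f1(y_true_demo, y_pred_demo, 1):.3f}")
print(f"Weighted F1:      {weighted_f1(y_true_demo, y_pred_demo):.3f}")
```

When missed phishing emails matter most, the phishing-only score may be the more informative number to monitor alongside the weighted one.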

## 5. Train the Model

In [None]:
# Check if pre-trained model exists
model_save_path = Path("./models/distilbert-phishing")
if model_save_path.exists():
    print("🔄 Loading pre-trained model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        model_save_path,
        output_attentions=True
    ).to(device)
    tokenizer = DistilBertTokenizer.from_pretrained(model_save_path)
    print("✓ Pre-trained model loaded successfully!")
else:
    print("🚀 Starting model training...")
    print("This may take 10-30 minutes depending on your hardware...")
    
    if use_trainer:
        # Use Hugging Face Trainer
        try:
            trainer.train()
            print("✓ Training completed successfully!")
            
            # Save the best model
            trainer.save_model()
            tokenizer.save_pretrained(training_args.output_dir)
            print(f"✓ Model saved to {training_args.output_dir}")
            
        except Exception as e:
            print(f"❌ Trainer failed: {str(e)}")
            print("🔄 Switching to manual training...")
            use_trainer = False
    
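    if not use_trainer:
        # Manual training loop
        print("📚 Starting manual training...")
        
        # The manual-training objects are only created in section 4 when the
        # Trainer import fails; build them here if the Trainer failed mid-run.
        if 'train_dataloader' not in globals():
            from torch.optim import AdamW
            manual_training_config = {'num_epochs': 3, 'batch_size': 16,
                                      'learning_rate': 5e-5, 'weight_decay': 0.01}
            train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
            val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
            optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
        
        model.train()
        total_steps = len(train_dataloader) * manual_training_config['num_epochs']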
        
        for epoch in range(manual_training_config['num_epochs']):
            print(f"\n🔄 Epoch {epoch + 1}/{manual_training_config['num_epochs']}")
            total_loss = 0
            
            for batch_idx, batch in enumerate(train_dataloader):
                # Move batch to device
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                # Forward pass
                outputs = model(input_ids=input_ids, 
                              attention_mask=attention_mask, 
                              labels=labels)
                loss = outputs.loss
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                
                # Print progress
                if batch_idx % 50 == 0:
                    print(f"  Batch {batch_idx}/{len(train_dataloader)}, Loss: {loss.item():.4f}")
            
            avg_loss = total_loss / len(train_dataloader)
            print(f"  Average Loss: {avg_loss:.4f}")
            
            # Validation
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0
            
            with torch.no_grad():
                for batch in val_dataloader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    
                    outputs = model(input_ids=input_ids, 
                                  attention_mask=attention_mask, 
                                  labels=labels)
                    val_loss += outputs.loss.item()
                    
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    val_correct += (predictions == labels).sum().item()
                    val_total += labels.size(0)
            
            val_accuracy = val_correct / val_total
            print(f"  Validation Loss: {val_loss/len(val_dataloader):.4f}")
            print(f"  Validation Accuracy: {val_accuracy:.4f}")
            
            model.train()
        
        # Save manually trained model
        model_save_path.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)
        print(f"✓ Manual training completed! Model saved to {model_save_path}")

# Load the best model for evaluation
model = DistilBertForSequenceClassification.from_pretrained(
    model_save_path,
    output_attentions=True
).to(device)

print("\n🎯 Model training summary:")
print(f"✓ Model successfully trained and loaded")
print(f"✓ Model location: {model_save_path}")
print(f"✓ Ready for evaluation and explainability analysis")
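
The Trainer configuration in section 4 sets `warmup_steps=100`, while the manual loop computes `total_steps` but keeps the learning rate constant. If the two paths should behave more alike, a linear warmup followed by linear decay is one option; this is the shape of the Trainer's default `linear` scheduler. The sketch below only computes the learning-rate multiplier. `linear_warmup_factor` is a hypothetical helper, and wiring it into the loop (for example through `torch.optim.lr_scheduler.LambdaLR`) is left as an assumption.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier for the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to 1 during warmup
        return step / max(1, warmup_steps)
    # Then decay linearly from 1 to 0 by the final step
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# Illustrative schedule: 100 warmup steps out of 1000 total
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_factor(step, warmup_steps=100, total_steps=1000))
```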

## 6. Model Evaluation and Metrics

In [None]:
# Create evaluation pipeline
eval_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    top_k=None  # Return scores for all labels (replaces the deprecated return_all_scores=True)
)

# Function to predict with probabilities
def predict_text(text):
    """Predict class and probabilities for a given text"""
    # Pass a one-element list so the pipeline returns one list of label scores
    # per text, and truncate to the model's token limit
    results = eval_pipeline([text], truncation=True, max_length=MAX_LENGTH)[0]
    probs = {result['label']: result['score'] for result in results}
    predicted_class = max(probs.keys(), key=lambda x: probs[x])
    return predicted_class, probs

# Evaluate on test set
print("🔄 Evaluating model on test set...")
test_predictions = []
test_probabilities = []

# Batch prediction for efficiency
batch_size = 32
for i in range(0, len(X_test), batch_size):
    batch_texts = X_test[i:i + batch_size]
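    # Truncate long emails; without this, inputs over 512 tokens cause indexing errors
    batch_results = eval_pipeline(batch_texts, truncation=True, max_length=MAX_LENGTH)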
    
    for result in batch_results:
        # Convert to binary classification format
        if isinstance(result, list):  # Multiple labels returned
            phishing_score = next(r['score'] for r in result if r['label'] == 'LABEL_1')
            prediction = 1 if phishing_score > 0.5 else 0
        else:  # Single label returned
            prediction = 1 if result['label'] == 'LABEL_1' else 0
            phishing_score = result['score'] if result['label'] == 'LABEL_1' else 1 - result['score']
        
        test_predictions.append(prediction)
        test_probabilities.append(phishing_score)

# Convert to numpy arrays
y_pred = np.array(test_predictions)
y_prob = np.array(test_probabilities)

# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')
conf_matrix = confusion_matrix(y_test, y_pred)

# Calculate PR-AUC
pr_auc = average_precision_score(y_test, y_prob)

print("\n📊 Test Set Performance:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"PR-AUC: {pr_auc:.4f}")

# Detailed classification report
print(f"\n📋 Detailed Classification Report:")
print(classification_report(y_test, y_pred, target_names=['Legitimate', 'Phishing']))

# Confusion Matrix Visualization
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', 
           xticklabels=['Legitimate', 'Phishing'],
           yticklabels=['Legitimate', 'Phishing'])
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')

# Precision-Recall Curve
plt.subplot(1, 2, 2)
precision_curve, recall_curve, _ = precision_recall_curve(y_test, y_prob)
plt.plot(recall_curve, precision_curve, marker='.', label=f'PR-AUC = {pr_auc:.3f}')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Sample predictions with confidence scores
print(f"\n🔍 Sample Predictions:")
for i in range(min(5, len(X_test))):
    text = X_test[i][:100] + "..." if len(X_test[i]) > 100 else X_test[i]
    true_label = "Phishing" if y_test[i] == 1 else "Legitimate"
    pred_label = "Phishing" if y_pred[i] == 1 else "Legitimate"
    confidence = y_prob[i] if y_pred[i] == 1 else 1 - y_prob[i]
    
    status = "✓" if y_test[i] == y_pred[i] else "✗"
    print(f"{status} True: {true_label:10} | Pred: {pred_label:10} | Conf: {confidence:.3f}")
    print(f"  Text: '{text}'")
    print()
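
The evaluation cell turns the pipeline's per-label scores into a phishing probability and then applies a fixed 0.5 threshold. The sketch below isolates those two steps so the threshold can be varied. It assumes the all-label output shape (one list of `{"label", "score"}` dictionaries per text) and the default `LABEL_0`/`LABEL_1` names used above; the helper names and the sample scores are hypothetical, not real model output.

```python
def phishing_probabilities(batch_results, positive_label="LABEL_1"):
    """Extract the positive-class score for each text from all-label output."""
    probs = []
    for scores in batch_results:
        by_label = {item["label"]: item["score"] for item in scores}
        probs.append(by_label[positive_label])
    return probs

def apply_threshold(probs, threshold=0.5):
    """Convert probabilities to 0/1 predictions (1 = phishing)."""
    return [1 if p > threshold else 0 for p in probs]

# Hand-written sample in the same shape as the pipeline output (not real scores)
sample_results = [
    [{"label": "LABEL_0", "score": 0.9}, {"label": "LABEL_1", "score": 0.1}],
    [{"label": "LABEL_1", "score": 0.8}, {"label": "LABEL_0", "score": 0.2}],
]
sample_probs = phishing_probabilities(sample_results)
print(sample_probs)                         # [0.1, 0.8]
print(apply_threshold(sample_probs))        # [0, 1]
print(apply_threshold(sample_probs, 0.05))  # [1, 1]
```

Lowering the threshold trades precision for recall, which the precision-recall curve plotted above makes visible.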

## 7. Model Explainability with LIME

In [None]:
try:
    # Create LIME explainer
    explainer = LimeTextExplainer(class_names=['Legitimate', 'Phishing'])

    # Wrapper function for LIME (needs to return probabilities)
    def predict_proba_wrapper(texts):
        """Return an (n_samples, 2) array of [legitimate, phishing] probabilities for LIME"""
        # Always pass a list so the pipeline returns one list of label scores
        # per text, and truncate to the model's token limit to avoid indexing errors
        results = eval_pipeline(list(texts), truncation=True, max_length=MAX_LENGTH)
        
        probabilities = []
        for scores in results:
            # Extract probabilities for both classes
            probs = [0.0, 0.0]  # [legitimate, phishing]
            for item in scores:
                if item['label'] == 'LABEL_0':  # Legitimate
                    probs[0] = item['score']
                elif item['label'] == 'LABEL_1':  # Phishing
                    probs[1] = item['score']
            probabilities.append(probs)
        
        return np.array(probabilities)

    # Select sample emails for explanation
    # Pick one phishing and one legitimate email
    phishing_idx = next(i for i, label in enumerate(y_test) if label == 1)
    legitimate_idx = next(i for i, label in enumerate(y_test) if label == 0)

    sample_emails = [
        (X_test[phishing_idx], "Phishing", phishing_idx),
        (X_test[legitimate_idx], "Legitimate", legitimate_idx)
    ]

    print("🔍 LIME Analysis Results:")
    print("=" * 60)

    for email_text, email_type, idx in sample_emails:
        print(f"\n📧 Analyzing {email_type} Email (Index: {idx}):")
        print(f"Text preview: '{email_text[:200]}...'")
        
        # Generate explanation
        try:
            explanation = explainer.explain_instance(
                email_text,
                predict_proba_wrapper,
                num_features=20,
                num_samples=500
            )
            
            # Get the predicted phishing probability for this email
            phishing_prob = explanation.predict_proba[1]
            actual_class = y_test[idx]
            
            print(f"Prediction: {phishing_prob:.3f} (Phishing probability)")
            print(f"Actual: {'Phishing' if actual_class == 1 else 'Legitimate'}")
            
            # Show top influential words
            print(f"\n🔍 Top words influencing prediction:")
            explanation_list = explanation.as_list()
            
            for word, importance in explanation_list[:10]:
                direction = "→ Phishing" if importance > 0 else "→ Legitimate"
                print(f"  '{word}': {importance:+.3f} {direction}")
            
            # Save explanation as HTML (optional)
            html_path = f"lime_explanation_{email_type.lower()}_{idx}.html"
            explanation.save_to_file(html_path)
            print(f"\n📄 Detailed explanation saved to: {html_path}")
            
            # Show explanation in notebook
            explanation.show_in_notebook(text=True)
            
        except Exception as e:
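            print(f"❌ Error generating LIME explanation: {str(e)}")

# A failed import in section 1 leaves LimeTextExplainer undefined, which
# surfaces here as a NameError rather than an ImportError
except (ImportError, NameError):
    print("❌ LIME not available. Please install with: pip install lime")
    print("Skipping LIME analysis...")
    
except Exception as e:
    print(f"❌ Error in LIME analysis: {str(e)}")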
    print("This might be due to model compatibility issues.")

## 8. Model Explainability with SHAP

In [None]:
try:
    import shap
    
    print("🔄 Setting up SHAP explainer...")
    
    # Create a wrapper function for SHAP
    def model_predict(texts):
        """Wrapper for SHAP that returns [legitimate, phishing] probabilities"""
        if isinstance(texts, str):
            texts = [texts]
        # SHAP passes a NumPy array of strings; the pipeline expects a plain list
        texts = [str(t) for t in texts]
        
        results = eval_pipeline(texts, truncation=True, max_length=MAX_LENGTH)
        probabilities = []
        for scores in results:
            # Get phishing probability
            phishing_prob = next((r['score'] for r in scores if r['label'] == 'LABEL_1'), 0.5)
            probabilities.append([1 - phishing_prob, phishing_prob])
        
        return np.array(probabilities)
    
    # Initialize SHAP explainer
    # Text inputs need a text masker (built from the tokenizer) rather than a
    # background dataset of raw strings
    masker = shap.maskers.Text(tokenizer)
    
    explainer = shap.Explainer(model_predict, masker, output_names=['Legitimate', 'Phishing'])
    
    # Select sample texts for SHAP analysis
    sample_texts = []
    sample_labels = []
    
    # Get one example from each class
    for label in [0, 1]:
        idx = next(i for i, l in enumerate(y_test) if l == label)
        sample_texts.append(X_test[idx][:500])  # Truncate for efficiency
        sample_labels.append(label)
    
    print(f"📊 Analyzing {len(sample_texts)} sample emails with SHAP...")
    
    # Compute SHAP values
    try:
        shap_values = explainer(sample_texts)
        
        print("✓ SHAP values computed successfully!")
        
        # Display results for each sample
        for i, (text, true_label) in enumerate(zip(sample_texts, sample_labels)):
            print(f"\n📧 Sample {i+1} ({'Phishing' if true_label == 1 else 'Legitimate'}):")
            print(f"Text: '{text[:150]}...'")
            
            # Get SHAP values for phishing class (class 1)
            shap_vals = shap_values[i].values[:, 1]  # Phishing class
            
            # Get the words (features)
            words = shap_values[i].data
            
            # Find top positive and negative contributions
            word_importance = list(zip(words, shap_vals))
            word_importance.sort(key=lambda x: abs(x[1]), reverse=True)
            
            print(f"🔍 Top contributing words:")
            for word, importance in word_importance[:10]:
                direction = "→ Phishing" if importance > 0 else "→ Legitimate"
                print(f"  '{word}': {importance:+.4f} {direction}")
        
        # Create visualizations. Text explanations have a different number of
        # tokens per sample, so use SHAP's text-aware plots instead of summary_plot.
        try:
            # Token-level highlighting for each sample (phishing class)
            shap.plots.text(shap_values[:, :, 1])
            
            # Mean contribution of each token across the analyzed samples
            shap.plots.bar(shap_values[:, :, 1].mean(0))
            
        except Exception as e:
            print(f"⚠️  Could not create SHAP plots: {str(e)}")
            print("This might be due to visualization compatibility issues.")
        
    except Exception as e:
        print(f"❌ Error computing SHAP values: {str(e)}")
        print("SHAP analysis may require more memory or different configuration.")
        
        # Fallback: Simple feature importance analysis
        print("\n🔄 Performing simplified feature analysis...")
        
        for i, text in enumerate(sample_texts):
            words = text.split()[:50]  # First 50 words
            print(f"\n📧 Sample {i+1} word analysis:")
            
            # Simple word-by-word importance, measured on the same 50-word window
            base_pred = model_predict([' '.join(words)])[0][1]  # Phishing probability
            
            important_words = []
            for j, word in enumerate(words):
                # Remove only this occurrence of the word and see change in prediction
                modified_text = ' '.join(words[:j] + words[j + 1:])
                modified_pred = model_predict([modified_text])[0][1]
                importance = base_pred - modified_pred
                important_words.append((word, importance))
            
            # Sort by importance
            important_words.sort(key=lambda x: abs(x[1]), reverse=True)
            
            print("🔍 Most influential words:")
            for word, importance in important_words[:10]:
                direction = "→ Phishing" if importance > 0 else "→ Legitimate"
                print(f"  '{word}': {importance:+.4f} {direction}")

except ImportError:
    print("❌ SHAP not available. Please install with: pip install shap")
    print("Skipping SHAP analysis...")

except Exception as e:
    print(f"❌ Error in SHAP analysis: {str(e)}")
    print("SHAP analysis might require additional configuration or memory.")
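
The fallback above estimates word importance by occlusion: remove one word, re-score the text, and record how far the phishing probability moves. The sketch below isolates that idea so it can be run without the trained model. `occlusion_importance` and `toy_predict` are hypothetical names introduced for illustration, and `toy_predict` is a keyword-counting stand-in for the real classifier, not the notebook's `model_predict`.

```python
from typing import Callable, List, Tuple


def occlusion_importance(
    text: str,
    predict_phishing: Callable[[str], float],
    max_words: int = 50,
) -> List[Tuple[str, float]]:
    """Rank words by how much removing each one changes the phishing score."""
    words = text.split()
    base = predict_phishing(text)
    scores = []
    for i, word in enumerate(words[:max_words]):
        # Remove only the i-th word, leaving other occurrences in place
        occluded = " ".join(words[:i] + words[i + 1:])
        scores.append((word, base - predict_phishing(occluded)))
    # Largest absolute change first
    return sorted(scores, key=lambda pair: abs(pair[1]), reverse=True)


def toy_predict(text: str) -> float:
    """Hypothetical stand-in scorer: more trigger words, higher score."""
    triggers = {"urgent", "verify", "password"}
    hits = sum(tok.strip(".,!:") in triggers for tok in text.lower().split())
    return 0.1 + 0.2 * hits


ranking = occlusion_importance("Urgent: verify your password today", toy_predict)
for word, delta in ranking:
    print(f"{word!r}: {delta:+.2f}")
```

With the real model, each call to the scoring function is a forward pass, so the cost grows linearly with the number of words examined; that is why the fallback caps the analysis at the first 50 words.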

## 9. Feature Importance Visualization

In [None]:
# Attention Visualization and Feature Analysis

def get_attention_weights(text, model, tokenizer, max_length=512):
    """Extract attention weights from the model"""
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length
    ).to(device)
    
    # Get model outputs including attention
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    
    # Extract attention weights
    attention = outputs.attentions  # Tuple of attention matrices
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    
    return attention, tokens, outputs.logits

def visualize_attention(text, model, tokenizer, layer=5, head=0):
    """Visualize attention weights for a specific layer and head"""
    attention, tokens, logits = get_attention_weights(text, model, tokenizer)
    
    # Get attention for specified layer and head
    att_matrix = attention[layer][0][head].cpu().numpy()
    
    # Create visualization
    plt.figure(figsize=(12, 8))
    
    # Filter out special tokens for better visualization
    valid_tokens = [token for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]
    valid_indices = [i for i, token in enumerate(tokens) if token not in ['[CLS]', '[SEP]', '[PAD]']]
    
    if len(valid_indices) > 30:  # Limit for readability
        valid_indices = valid_indices[:30]
        valid_tokens = valid_tokens[:30]
    
    # Extract relevant attention matrix
    att_subset = att_matrix[np.ix_(valid_indices, valid_indices)]
    
    # Plot heatmap
    plt.subplot(2, 1, 1)
    sns.heatmap(att_subset, 
                xticklabels=valid_tokens,
                yticklabels=valid_tokens,
                cmap='Blues',
                cbar=True)
    plt.title(f'Attention Weights - Layer {layer}, Head {head}')
    plt.xlabel('Tokens (Attending To)')
    plt.ylabel('Tokens (Attending From)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    
    # Plot prediction probabilities
    plt.subplot(2, 1, 2)
    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    classes = ['Legitimate', 'Phishing']
    
    bars = plt.bar(classes, probs, color=['lightblue', 'lightcoral'])
    plt.ylabel('Probability')
    plt.title('Model Prediction Probabilities')
    
    # Add value labels on bars
    for bar, prob in zip(bars, probs):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{prob:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    return att_matrix, tokens, probs

# Analyze attention for sample emails
print("🔍 Attention Weight Analysis")
print("=" * 50)

# Select examples for attention analysis
examples = [
    (X_test[phishing_idx], "Phishing Email"),
    (X_test[legitimate_idx], "Legitimate Email")
]

for text, label in examples:
    print(f"\n📧 Analyzing: {label}")
    print(f"Text preview: '{text[:150]}...'")
    
    try:
        # Visualize attention
        att_matrix, tokens, probs = visualize_attention(text, model, tokenizer)
        plt.suptitle(f'Attention Analysis - {label}', fontsize=16, y=1.02)
        plt.show()
        
        # Find most attended tokens
        avg_attention = np.mean(att_matrix, axis=0)  # Average attention received by each token
        # Drop special tokens before taking the top 10 so the list is not shortened
        token_attention = [
            (token, score) for token, score in zip(tokens, avg_attention)
            if token not in ['[CLS]', '[SEP]', '[PAD]']
        ]
        token_attention.sort(key=lambda x: x[1], reverse=True)
        
        print("🎯 Most attended tokens:")
        for token, attention in token_attention[:10]:
            print(f"  '{token}': {attention:.4f}")
                
    except Exception as e:
        print(f"❌ Error in attention analysis: {str(e)}")

# Word frequency analysis in phishing vs legitimate emails
print("\n📊 Word Frequency Analysis")
print("=" * 40)

# Collect all words from each class
phishing_words = []
legitimate_words = []

for text, label in zip(X_train, y_train):
    words = text.lower().split()
    if label == 1:
        phishing_words.extend(words)
    else:
        legitimate_words.extend(words)

# Count word frequencies
from collections import Counter
phishing_counter = Counter(phishing_words)
legitimate_counter = Counter(legitimate_words)

# Take the most frequent words per class (raw counts, no stopword filtering)
common_phishing = phishing_counter.most_common(20)
common_legitimate = legitimate_counter.most_common(20)

# Create visualization
plt.figure(figsize=(15, 6))

plt.subplot(1, 2, 1)
words, counts = zip(*common_phishing[:15])
plt.barh(range(len(words)), counts, color='red', alpha=0.7)
plt.yticks(range(len(words)), words)
plt.xlabel('Frequency')
plt.title('Most Common Words in Phishing Emails')
plt.gca().invert_yaxis()

plt.subplot(1, 2, 2)
words, counts = zip(*common_legitimate[:15])
plt.barh(range(len(words)), counts, color='green', alpha=0.7)
plt.yticks(range(len(words)), words)
plt.xlabel('Frequency')
plt.title('Most Common Words in Legitimate Emails')
plt.gca().invert_yaxis()

plt.tight_layout()
plt.show()

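# Raw counts are dominated by words common to both classes, so these lists
# show frequent words rather than words distinctive of either class
print("\n🔍 Most frequent words in phishing emails (longer than 3 characters):")
for word, count in common_phishing[:10]:
    if len(word) > 3:  # Filter out short words
        print(f"  '{word}': {count} occurrences")

print("\n🔍 Most frequent words in legitimate emails (longer than 3 characters):")
for word, count in common_legitimate[:10]:
    if len(word) > 3:
        print(f"  '{word}': {count} occurrences")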
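
Most-common lists tend to be dominated by words that are frequent in both classes. One way to surface words that separate the classes is to compare smoothed relative frequencies with a log-odds ratio. The sketch below is an illustrative alternative, not part of the original analysis; `distinctive_words`, its parameters, and the toy counters are hypothetical names and data.

```python
import math
from collections import Counter
from typing import List, Tuple


def distinctive_words(
    target: Counter,
    other: Counter,
    top_n: int = 10,
    min_count: int = 2,
    alpha: float = 1.0,
) -> List[Tuple[str, float]]:
    """Rank words in `target` by smoothed log-odds relative to `other`."""
    vocab = set(target) | set(other)
    # Additive smoothing keeps words missing from one class at a finite score
    target_total = sum(target.values()) + alpha * len(vocab)
    other_total = sum(other.values()) + alpha * len(vocab)
    scored = []
    for word, count in target.items():
        if count < min_count:
            continue  # Skip rare words, whose ratios are unstable
        p_target = (count + alpha) / target_total
        p_other = (other[word] + alpha) / other_total
        scored.append((word, math.log(p_target / p_other)))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


# Toy counters standing in for the per-class word counts
toy_phishing = Counter("the account verify the account password the".split())
toy_legitimate = Counter("the meeting the agenda the meeting notes".split())

top = distinctive_words(toy_phishing, toy_legitimate, top_n=2)
for word, score in top:
    print(f"{word!r}: {score:.3f}")
```

A word that is equally frequent in both classes scores near zero, while a word seen mostly in the target class gets a positive score, so shared function words drop out of the top of the ranking.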

## 10. Save Model and Results

In [None]:
# Save model, tokenizer, and results
import json
from datetime import datetime

print("💾 Saving model and results...")

# Create results directory
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)

# Save model and tokenizer (if not already saved)
model_dir = Path("models/distilbert-phishing")
if not model_dir.exists():
    model_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    print(f"✓ Model saved to {model_dir}")
else:
    print(f"✓ Model already exists at {model_dir}")

# Save evaluation results
results = {
    "model_info": {
        "model_name": model_name,
        "architecture": "DistilBERT for sequence classification",
        "num_parameters": sum(p.numel() for p in model.parameters()),
        "max_length": MAX_LENGTH,
        "training_date": datetime.now().isoformat()
    },
    "dataset_info": {
        "train_size": len(X_train),
        "val_size": len(X_val), 
        "test_size": len(X_test),
        "class_distribution": {
            "legitimate": int(sum(1 for x in y_train if x == 0)),
            "phishing": int(sum(1 for x in y_train if x == 1))
        }
    },
    "performance_metrics": {
        "test_accuracy": float(accuracy),
        "test_precision": float(precision),
        "test_recall": float(recall),
        "test_f1_score": float(f1),
        "test_pr_auc": float(pr_auc)
    },
    "training_config": {
        "epochs": training_args.num_train_epochs,
        "batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "warmup_steps": training_args.warmup_steps,
        "weight_decay": training_args.weight_decay
    }
}

# Save results to JSON
results_file = results_dir / "distilbert_results.json"
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"✓ Results saved to {results_file}")

# Save predictions for further analysis
predictions_df = pd.DataFrame({
    'text': X_test,
    'true_label': y_test,
    'predicted_label': y_pred,
    'phishing_probability': y_prob,
    'correct_prediction': np.asarray(y_test) == np.asarray(y_pred)
})

predictions_file = results_dir / "distilbert_predictions.csv"
predictions_df.to_csv(predictions_file, index=False)
print(f"✓ Predictions saved to {predictions_file}")

# Create a simple inference function for future use
def classify_email(text, model_path="models/distilbert-phishing"):
    """
    Classify a single email text using the trained DistilBERT model
    
    Args:
        text (str): Email text to classify
        model_path (str): Path to the saved model
    
    Returns:
        dict: Classification results with prediction and confidence
    """
    # Load model and tokenizer
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    
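    # Create pipeline
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer
    )
    
    # Make prediction; top_k=None returns a score for every label, and
    # truncation keeps long emails within the model's 512-token limit
    results = classifier(text, top_k=None, truncation=True, max_length=512)
    
    # Some transformers versions wrap single-text output in an extra list
    scores = results[0] if isinstance(results[0], list) else results
    
    # Format results
    phishing_score = next(r['score'] for r in scores if r['label'] == 'LABEL_1')
    prediction = "Phishing" if phishing_score > 0.5 else "Legitimate"
    confidence = phishing_score if prediction == "Phishing" else 1 - phishing_score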
    
    return {
        "prediction": prediction,
        "confidence": confidence,
        "phishing_probability": phishing_score,
        "legitimate_probability": 1 - phishing_score
    }

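# Save a reference to the inference function. Note: pickle stores only the
# function's module and name, not its code, so the file can be loaded only
# where classify_email is already defined or importable.
import pickle
with open(results_dir / "classify_email_function.pkl", "wb") as f:
    pickle.dump(classify_email, f)

print(f"✓ Inference function reference saved to {results_dir / 'classify_email_function.pkl'}")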

# Summary
print("\n🎉 Analysis Complete!")
print("=" * 50)
print(f"📊 Model Performance Summary:")
print(f"  • Accuracy: {accuracy:.4f}")
print(f"  • F1-Score: {f1:.4f}")
print(f"  • PR-AUC: {pr_auc:.4f}")
print(f"\n💾 Saved Files:")
print(f"  • Model: {model_dir}")
print(f"  • Results: {results_file}")
print(f"  • Predictions: {predictions_file}")
print(f"\n🔍 Explainability Features:")
print(f"  • LIME explanations for individual predictions")
print(f"  • SHAP values for feature importance")
print(f"  • Attention weight visualizations")
print(f"  • Word frequency analysis")

# Example usage of the saved model
print(f"\n🧪 Testing saved model:")
sample_text = "Get 50% off Viagra! Click here now for amazing deals!"
result = classify_email(sample_text)
print(f"Sample text: '{sample_text}'")
print(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.3f})")
print(f"\n✅ DistilBERT explainable phishing detection model is ready for deployment!")
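
The result a text-classification pipeline returns for a single string can be a flat list of label dictionaries or that same list wrapped in an outer list, depending on how all-label scores are requested and on the installed `transformers` version. The sketch below shows one way to turn either shape into the prediction summary used in this section. `summarize_scores` is a hypothetical helper, the score lists are made-up inputs rather than model output, and `LABEL_1` is assumed to be the phishing class as in `classify_email`.

```python
from typing import Any, Dict, List


def summarize_scores(
    raw: List[Any],
    phishing_label: str = "LABEL_1",
    threshold: float = 0.5,
) -> Dict[str, Any]:
    """Convert pipeline-style label scores into a prediction summary."""
    # Accept both [[{...}, {...}]] and [{...}, {...}]
    scores = raw[0] if raw and isinstance(raw[0], list) else raw
    by_label = {item["label"]: float(item["score"]) for item in scores}
    phishing = by_label[phishing_label]
    prediction = "Phishing" if phishing > threshold else "Legitimate"
    # Confidence is the probability of whichever class was predicted
    confidence = phishing if prediction == "Phishing" else 1.0 - phishing
    return {
        "prediction": prediction,
        "confidence": confidence,
        "phishing_probability": phishing,
        "legitimate_probability": 1.0 - phishing,
    }


# Made-up score lists in the two shapes, for illustration only
nested = [[{"label": "LABEL_0", "score": 0.2}, {"label": "LABEL_1", "score": 0.8}]]
flat = [{"label": "LABEL_1", "score": 0.25}, {"label": "LABEL_0", "score": 0.75}]

summary_nested = summarize_scores(nested)
summary_flat = summarize_scores(flat)
print(summary_nested["prediction"], summary_flat["prediction"])
```

Looking scores up by label rather than by position also avoids depending on the order in which the pipeline returns the classes.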