# Complete Evaluation and XAI for Visual Emotion Recognition

This notebook evaluates a visual emotion recognition model and explains its predictions with several explainable AI (XAI) methods.

## Components Included:
1. **GradCAM** - Gradient-based visual explanations
2. **LIME** - Local interpretable model-agnostic explanations
3. **SHAP** - SHapley Additive exPlanations
4. **Model Evaluation** - Metrics, ROC curves, and confidence analysis
5. **Visualization Tools** - Plotting and analysis functions


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import cv2
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. GradCAM Implementation
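
The weighting step at the heart of GradCAM can be checked on plain arrays before involving a model. The sketch below is a NumPy-only illustration, not the class defined in this section: the function name `gradcam_from_arrays` and the toy arrays are assumptions made for the example. It averages the gradients per channel, sums the activation maps with those weights, applies ReLU, and scales the result to [0, 1].

```python
import numpy as np

def gradcam_from_arrays(activations, gradients):
    """Combine (C, H, W) activations and gradients into an (H, W) map in [0, 1]."""
    # Channel weights: gradients averaged over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of the activation maps over the channel axis
    cam = np.tensordot(weights, activations, axes=1)
    # ReLU keeps only regions that push the class score up
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] unless the map is all zeros
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy input: two 2x2 channels, one with a positive and one with a negative weight
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[4.0, 3.0], [2.0, 1.0]]])
gradients = np.stack([np.full((2, 2), 0.5), np.full((2, 2), -0.25)])
cam = gradcam_from_arrays(activations, gradients)
```

In the toy case the top-left cell receives a negative combined score and is zeroed by the ReLU, while the bottom-right cell becomes the peak of the map.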

In [None]:
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (GradCAM) for visual explanations.
    """
    
    def __init__(self, model, target_layer_name):
        """
        Initialize GradCAM.
        
        Args:
            model: PyTorch model
            target_layer_name (str): Name of target layer for gradients
        """
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        
        # Register hooks
        self._register_hooks()
    
    def _register_hooks(self):
        """Register a forward hook that stores activations and their gradients."""
        def save_gradients(grad):
            self.gradients = grad.detach()
        
        def forward_hook(module, input, output):
            # Clone so a following in-place op (e.g. ReLU(inplace=True)) cannot alter the copy
            self.activations = output.detach().clone()
            # A tensor hook replaces the deprecated module-level register_backward_hook;
            # it is skipped when the forward pass runs under torch.no_grad()
            if output.requires_grad:
                output.register_hook(save_gradients)
        
        # Find target layer
        target_layer = dict(self.model.named_modules()).get(self.target_layer_name)
        
        if target_layer is None:
            raise ValueError(f"Layer '{self.target_layer_name}' not found in model")
        
        # Keep the handle so the hook can be removed with self.hook_handle.remove()
        self.hook_handle = target_layer.register_forward_hook(forward_hook)
    
    def generate_cam(self, input_tensor, class_idx=None):
        """
        Generate GradCAM heatmap.
        
        Args:
            input_tensor (torch.Tensor): Input tensor
            class_idx (int): Class index for explanation (None for predicted class)
            
        Returns:
            np.array: GradCAM heatmap
        """
        self.model.eval()
        
        # Forward pass
        output = self.model(input_tensor)
        
        # Get class index
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        
        # Backward pass for target class
        self.model.zero_grad()
        class_score = output[0, class_idx]
        class_score.backward()
        
        # Generate CAM
        gradients = self.gradients[0]  # Remove batch dimension
        activations = self.activations[0]  # Remove batch dimension
        
        # Pool gradients over spatial dimensions
        weights = torch.mean(gradients, dim=(1, 2))
        
        # Weighted combination of activation maps; computed on the activations'
        # device, so it also works when the model runs on CUDA
        cam = (weights[:, None, None] * activations).sum(dim=0)
        
        # ReLU to keep only positive influences
        cam = F.relu(cam)
        
        # Normalize
        if cam.max() > 0:
            cam = cam / cam.max()
        
        return cam.cpu().numpy()
    
    def visualize_cam(self, input_tensor, original_image, class_idx=None, 
                     alpha=0.4, colormap=cv2.COLORMAP_JET):
        """
        Visualize GradCAM overlay on original image.
        
        Args:
            input_tensor (torch.Tensor): Input tensor
            original_image (PIL.Image or np.array): Original image
            class_idx (int): Class index
            alpha (float): Overlay transparency
            colormap: OpenCV colormap
            
        Returns:
            np.array: Visualization image
        """
        # Generate CAM
        cam = self.generate_cam(input_tensor, class_idx)
        
        # Convert original image to numpy if needed
        if isinstance(original_image, Image.Image):
            original_image = np.array(original_image)
        
        # Resize CAM to match image size
        height, width = original_image.shape[:2]
        cam_resized = cv2.resize(cam, (width, height))
        
        # Apply colormap
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        
        # Ensure original image is RGB
        if len(original_image.shape) == 2:  # Grayscale
            original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
        
        # Overlay heatmap on original image
        overlay = alpha * heatmap + (1 - alpha) * original_image
        
        return overlay.astype(np.uint8)


def get_gradcam_layer_name(model):
    """
    Automatically detect the best layer for GradCAM based on model architecture.
    
    Args:
        model: PyTorch model
        
    Returns:
        str: Layer name for GradCAM
    """
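    model_name = model.__class__.__name__.lower()
    module_names = {name for name, _ in model.named_modules()}
    
    # Architecture-specific defaults
    if 'vgg' in model_name:
        # End of the feature extractor in a torchvision-style VGG16
        # (there it is the final max-pool; the last conv is features.28)
        candidate = 'features.30'
    elif 'resnet' in model_name or 'improved' in model_name:
        candidate = 'backbone.layer4'  # Last residual stage
    elif 'baseline' in model_name:
        candidate = 'features.12'  # Last conv layer in baseline
    else:
        candidate = None
    
    # Use the default only if the model really has a module with that name
    if candidate in module_names:
        return candidate
    
    # Otherwise fall back to the last convolutional layer
    conv_layers = [name for name, module in model.named_modules()
                   if isinstance(module, nn.Conv2d)]
    if conv_layers:
        return conv_layers[-1]
    raise ValueError("Could not find suitable layer for GradCAM")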


print("GradCAM implementation ready!")

## 2. LIME Implementation
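
Before wrapping the `lime` package, it can help to see the idea in miniature: switch image segments on and off, query the black box, and fit a locally weighted linear model whose coefficients rank the segments. The sketch below is a simplified NumPy illustration under stated assumptions, not the library's algorithm: `lime_sketch`, the quadrant segmentation, the scalar-valued `predict` callable, and the kernel width of 0.5 are all chosen for the example.

```python
import numpy as np

def lime_sketch(image, segments, predict, num_samples=200, kernel_width=0.5, seed=0):
    """Return one linear weight per segment for a scalar-valued black box `predict`."""
    rng = np.random.default_rng(seed)
    num_segments = int(segments.max()) + 1
    # Each row switches segments on (1) or off (0); the first row keeps the full image
    masks = rng.integers(0, 2, size=(num_samples, num_segments))
    masks[0, :] = 1
    scores = np.empty(num_samples)
    for row, mask in enumerate(masks):
        perturbed = image * mask[segments]  # switched-off segments become 0
        scores[row] = predict(perturbed)
    # Samples closer to the original image (fewer segments off) get a larger weight
    distance = 1.0 - masks.mean(axis=1)
    sample_weight = np.exp(-(distance ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column
    design = np.hstack([masks, np.ones((num_samples, 1))])
    root_w = np.sqrt(sample_weight)
    coef, *_ = np.linalg.lstsq(design * root_w[:, None], scores * root_w, rcond=None)
    return coef[:-1]  # drop the intercept

# Toy setup: a 4x4 image split into quadrants; the "model" only looks at the top-left one
image = np.ones((4, 4))
segments = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])

def top_left_mean(img):
    return img[:2, :2].mean()

weights = lime_sketch(image, segments, top_left_mean)
```

Because the toy model depends only on segment 0, that segment should receive essentially all of the weight. The real explainer used below additionally handles superpixel segmentation, feature selection, and multi-class outputs.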

In [None]:
try:
    from lime import lime_image
    from lime.wrappers.scikit_image import SegmentationAlgorithm
    LIME_AVAILABLE = True
    print("LIME available for explanations")
except ImportError:
    LIME_AVAILABLE = False
    print("LIME not available. Install with: pip install lime")


if LIME_AVAILABLE:
    class LIMEExplainer:
        """
        LIME (Local Interpretable Model-agnostic Explanations) for image classification.
        """
        
        def __init__(self, model, transform, label_map=None, device='cpu'):
            """
            Initialize LIME explainer.
            
            Args:
                model: PyTorch model
                transform: Image preprocessing transform
                label_map (dict): Mapping from class names to indices
                device (str): Device to use
            """
            self.model = model
            self.transform = transform
            self.label_map = label_map or {}
            self.device = device
            
            # Create LIME explainer
            self.explainer = lime_image.LimeImageExplainer()
        
        def predict_fn(self, images):
            """
            Prediction function for LIME.
            
            Args:
                images (np.array): Batch of images
                
            Returns:
                np.array: Prediction probabilities
            """
            self.model.eval()
            predictions = []
            
            with torch.no_grad():
                for image in images:
                    # Convert to PIL Image
                    if image.dtype != np.uint8:
                        image = (image * 255).astype(np.uint8)
                    
                    pil_image = Image.fromarray(image)
                    
                    # Apply transforms
                    input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
                    
                    # Get prediction
                    output = self.model(input_tensor)
                    probs = F.softmax(output, dim=1)[0].cpu().numpy()
                    predictions.append(probs)
            
            return np.array(predictions)
        
        def explain_image(self, image, top_labels=5, num_features=5, 
                         num_samples=1000, segmentation_fn=None):
            """
            Explain image prediction using LIME.
            
            Args:
                image (PIL.Image or np.array): Image to explain
                top_labels (int): Number of highest-probability labels to explain
                num_features (int): Maximum number of superpixels in the explanation
                num_samples (int): Number of perturbed samples for LIME
                segmentation_fn: Segmentation function
                
            Returns:
                LIME explanation object
            """
            # Convert image to numpy array
            if isinstance(image, Image.Image):
                image_array = np.array(image)
            else:
                image_array = image
            
            # Ensure RGB format
            if len(image_array.shape) == 2:  # Grayscale
                image_array = np.stack([image_array] * 3, axis=-1)
            
            # Get segmentation function
            if segmentation_fn is None:
                segmentation_fn = SegmentationAlgorithm('quickshift', 
                                                       kernel_size=4, 
                                                       max_dist=200, 
                                                       ratio=0.2)
            
            # Generate explanation
            explanation = self.explainer.explain_instance(
                image_array,
                self.predict_fn,
                top_labels=top_labels,
                hide_color=0,
                num_features=num_features,  # previously accepted but never passed on
                num_samples=num_samples,
                segmentation_fn=segmentation_fn
            )
            
            return explanation
        
        def visualize_explanation(self, explanation, label_idx, positive_only=True, 
                                hide_rest=False, num_features=5):
            """
            Visualize LIME explanation.
            
            Args:
                explanation: LIME explanation object
                label_idx (int): Label index to visualize
                positive_only (bool): Show only positive features
                hide_rest (bool): Hide non-relevant regions
                num_features (int): Number of features to show
                
            Returns:
                tuple: (explanation_image, mask)
            """
            image, mask = explanation.get_image_and_mask(
                label_idx, 
                positive_only=positive_only,
                num_features=num_features,
                hide_rest=hide_rest
            )
            
            return image, mask
    
    print("LIME implementation ready!")

else:
    class LIMEExplainer:
        def __init__(self, *args, **kwargs):
            raise ImportError("LIME not available. Install with: pip install lime")
    
    print("LIME placeholder created.")

## 3. SHAP Implementation
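
SHAP attributions approximate Shapley values, which can be computed exactly when there are only a few features. The brute-force sketch below is for intuition only and is not how `shap.DeepExplainer` works internally: `exact_shapley`, `toy_model`, and the all-zero baseline are assumptions made for this example. It averages each feature's marginal contribution over all coalitions of the other features.

```python
from itertools import combinations
from math import factorial

def exact_shapley(predict, x, baseline):
    """Exact Shapley values of `predict` at `x`; absent features take `baseline` values."""
    n = len(x)
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight shared by all coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                with_i = list(without_i)
                with_i[i] = x[i]
                values[i] += weight * (predict(with_i) - predict(without_i))
    return values

def toy_model(features):
    # One additive term and one interaction term
    return 2.0 * features[0] + features[1] * features[2]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)
```

The additive feature keeps its full contribution, the two interacting features split theirs equally, and the attributions sum to the difference between the prediction at `x` and at the baseline. The cost grows exponentially with the number of features, which is why image explainers rely on approximations.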

In [None]:
try:
    import shap
    SHAP_AVAILABLE = True
    print("SHAP available for explanations")
except ImportError:
    SHAP_AVAILABLE = False
    print("SHAP not available. Install with: pip install shap")


if SHAP_AVAILABLE:
    class SHAPExplainer:
        """
        SHAP (SHapley Additive exPlanations) for image classification.
        """
        
        def __init__(self, model, background_data, device='cpu'):
            """
            Initialize SHAP explainer.
            
            Args:
                model: PyTorch model
                background_data (torch.Tensor): Background dataset for SHAP
                device (str): Device to use
            """
            self.model = model
            self.device = device
            
            # DeepExplainer needs the nn.Module itself (not a prediction function)
            # so that it can attach its own hooks. It therefore explains the raw
            # model outputs (logits) rather than softmax probabilities.
            self.model.eval()
            
            # Create SHAP explainer
            self.explainer = shap.DeepExplainer(self.model, background_data.to(device))
        
        def explain_image(self, image_tensor, class_idx=None):
            """
            Explain image prediction using SHAP.
            
            Args:
                image_tensor (torch.Tensor): Image tensor to explain
                class_idx (int): Class index to explain
                
            Returns:
                np.array: SHAP values
            """
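            # Get SHAP values
            shap_values = self.explainer.shap_values(image_tensor.to(self.device))
            
            # Return values for specific class or all classes
            if class_idx is None:
                return shap_values
            if isinstance(shap_values, list):
                # Older SHAP releases return one array per class
                return shap_values[class_idx]
            # Newer SHAP releases return one array with classes on the last axis
            return shap_values[..., class_idx]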
        
        def visualize_explanation(self, image_tensor, shap_values, class_names=None):
            """
            Visualize SHAP explanation.
            
            Args:
                image_tensor (torch.Tensor): Original image tensor
                shap_values: SHAP values
                class_names (list): Class names for labeling
            """
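            # shap.image_plot expects channels-last arrays: (N, C, H, W) -> (N, H, W, C)
            image_np = np.transpose(image_tensor.detach().cpu().numpy(), (0, 2, 3, 1))
            
            if isinstance(shap_values, np.ndarray):
                if shap_values.ndim == 5:
                    # Single array with a trailing class axis: split into one array per class
                    shap_values = [shap_values[..., k] for k in range(shap_values.shape[-1])]
                else:
                    shap_values = [shap_values]
            shap_values = [np.transpose(s, (0, 2, 3, 1)) for s in shap_values]
            
            # SHAP image plot
            shap.image_plot(
                shap_values,
                image_np,
                labels=class_names,
                show=True
            )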
    
    print("SHAP implementation ready!")

else:
    class SHAPExplainer:
        def __init__(self, *args, **kwargs):
            raise ImportError("SHAP not available. Install with: pip install shap")
    
    print("SHAP placeholder created.")

## 4. Comprehensive Evaluation Functions
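
Most of the headline numbers reported in this section can be derived from the confusion matrix alone, which is a useful cross-check on the scikit-learn output. The sketch below is a NumPy illustration with an assumed helper name (`metrics_from_confusion`) and made-up counts; it follows the scikit-learn convention that rows are true labels and columns are predictions.

```python
import numpy as np

def metrics_from_confusion(cm):
    """Accuracy, per-class precision/recall/F1 and macro F1 from a confusion matrix."""
    cm = np.asarray(cm, dtype=float)
    true_pos = np.diag(cm)
    support = cm.sum(axis=1)    # samples per true class (rows)
    predicted = cm.sum(axis=0)  # samples per predicted class (columns)
    zeros = np.zeros_like(true_pos)
    # Undefined ratios (empty class, class never predicted) are reported as 0
    recall = np.divide(true_pos, support, out=zeros.copy(), where=support > 0)
    precision = np.divide(true_pos, predicted, out=zeros.copy(), where=predicted > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)
    return {
        'accuracy': true_pos.sum() / cm.sum(),
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'f1_macro': f1.mean(),
    }

# Illustrative counts for a two-class problem (not real results)
toy_cm = [[8, 2],
          [1, 9]]
toy_metrics = metrics_from_confusion(toy_cm)
```

Per-class recall here is the same quantity that the evaluation function below reports as per-class accuracy: the fraction of samples of a class that were predicted correctly.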

In [None]:
def comprehensive_model_analysis(model, test_loader, device, label_map=None, 
                                save_results=True, results_dir='results'):
    """
    Perform comprehensive model analysis and evaluation.
    
    Args:
        model: Trained PyTorch model
        test_loader: Test data loader
        device: Device to use
        label_map (dict): Label mapping
        save_results (bool): Whether to save results
        results_dir (str): Directory to save results
        
    Returns:
        dict: Comprehensive analysis results
    """
    import os
    from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
    from sklearn.preprocessing import label_binarize
    
    print("Starting comprehensive model analysis...")
    
    # Create results directory
    if save_results:
        os.makedirs(results_dir, exist_ok=True)
    
    model.eval()
    all_predictions = []
    all_targets = []
    all_probabilities = []
    prediction_confidence = []
    
    # Collect predictions
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            probabilities = F.softmax(output, dim=1)
            predictions = output.argmax(dim=1)
            
            # Calculate confidence (max probability)
            confidence = probabilities.max(dim=1)[0]
            
            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            prediction_confidence.extend(confidence.cpu().numpy())
            
            if batch_idx % 50 == 0:
                print(f"Processed {batch_idx+1}/{len(test_loader)} batches")
    
    all_probabilities = np.array(all_probabilities)
    prediction_confidence = np.array(prediction_confidence)
    
    # Basic metrics
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
    
    accuracy = accuracy_score(all_targets, all_predictions)
    f1_macro = f1_score(all_targets, all_predictions, average='macro')
    f1_weighted = f1_score(all_targets, all_predictions, average='weighted')
    precision_macro = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
    recall_macro = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
    
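    # Class names
    if label_map:
        class_names = [k for k, v in sorted(label_map.items(), key=lambda x: x[1])]
    else:
        # Use the model's output width so classes absent from the test set are still listed
        class_names = [f'Class_{i}' for i in range(all_probabilities.shape[1])]
    
    num_classes = len(class_names)
    class_indices = list(range(num_classes))
    
    # Detailed classification report; explicit labels keep names and rows aligned
    # even when a class never occurs in the test set
    report = classification_report(all_targets, all_predictions, 
                                 labels=class_indices,
                                 target_names=class_names, 
                                 output_dict=True, zero_division=0)
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions, labels=class_indices)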
    
    # ROC curves for multi-class
    if num_classes > 2:
        # Binarize targets for multi-class ROC
        y_test_bin = label_binarize(all_targets, classes=range(num_classes))
        
        # Compute ROC curve and AUC for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        
        for i in range(num_classes):
            fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], all_probabilities[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
        
        # Compute micro-average ROC curve and AUC
        fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), all_probabilities.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    else:
        # Binary classification
        fpr, tpr, _ = roc_curve(all_targets, all_probabilities[:, 1])
        roc_auc = {0: auc(fpr, tpr)}
    
    # Confidence analysis
    # Cast to built-in floats: NumPy float32 values are not JSON serializable
    confidence_stats = {
        'mean': float(np.mean(prediction_confidence)),
        'std': float(np.std(prediction_confidence)),
        'min': float(np.min(prediction_confidence)),
        'max': float(np.max(prediction_confidence)),
        'median': float(np.median(prediction_confidence))
    }
    
    # Per-class accuracy
    per_class_accuracy = {}
    for i, class_name in enumerate(class_names):
        class_mask = np.array(all_targets) == i
        if np.sum(class_mask) > 0:
            class_acc = np.mean(np.array(all_predictions)[class_mask] == i)
            per_class_accuracy[class_name] = class_acc
        else:
            per_class_accuracy[class_name] = 0.0
    
    # Compile results
    results = {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'classification_report': report,
        'confusion_matrix': cm,
        'roc_auc': roc_auc,
        'fpr': fpr,
        'tpr': tpr,
        'confidence_stats': confidence_stats,
        'per_class_accuracy': per_class_accuracy,
        'predictions': all_predictions,
        'targets': all_targets,
        'probabilities': all_probabilities,
        'confidence': prediction_confidence,
        'class_names': class_names,
        'num_classes': num_classes
    }
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"COMPREHENSIVE MODEL ANALYSIS RESULTS")
    print(f"{'='*60}")
    print(f"Test Samples: {len(all_targets)}")
    print(f"Number of Classes: {num_classes}")
    print(f"Classes: {class_names}")
    print(f"\nOVERALL METRICS:")
    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"  F1-Score (Macro): {f1_macro:.4f}")
    print(f"  F1-Score (Weighted): {f1_weighted:.4f}")
    print(f"  Precision (Macro): {precision_macro:.4f}")
    print(f"  Recall (Macro): {recall_macro:.4f}")
    
    print(f"\nCONFIDENCE ANALYSIS:")
    print(f"  Mean Confidence: {confidence_stats['mean']:.4f}")
    print(f"  Std Confidence: {confidence_stats['std']:.4f}")
    print(f"  Min Confidence: {confidence_stats['min']:.4f}")
    print(f"  Max Confidence: {confidence_stats['max']:.4f}")
    
    print(f"\nPER-CLASS ACCURACY:")
    for class_name, acc in per_class_accuracy.items():
        print(f"  {class_name}: {acc:.4f} ({acc*100:.2f}%)")
    
    if num_classes > 2:
        print(f"\nROC AUC SCORES:")
        for i, class_name in enumerate(class_names):
            if i in roc_auc:
                print(f"  {class_name}: {roc_auc[i]:.4f}")
        if "micro" in roc_auc:
            print(f"  Micro-average: {roc_auc['micro']:.4f}")
    
    # Save results if requested
    if save_results:
        import json
        
        # Save numerical results (excluding numpy arrays)
        save_dict = {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'confidence_stats': confidence_stats,
            'per_class_accuracy': per_class_accuracy,
            'class_names': class_names,
            'num_classes': num_classes
        }
        
        with open(os.path.join(results_dir, 'analysis_results.json'), 'w') as f:
            json.dump(save_dict, f, indent=2)
        
        # Save arrays
        np.save(os.path.join(results_dir, 'predictions.npy'), all_predictions)
        np.save(os.path.join(results_dir, 'targets.npy'), all_targets)
        np.save(os.path.join(results_dir, 'probabilities.npy'), all_probabilities)
        np.save(os.path.join(results_dir, 'confidence.npy'), prediction_confidence)
        np.save(os.path.join(results_dir, 'confusion_matrix.npy'), cm)
        
        print(f"\nResults saved to {results_dir}/")
    
    return results


print("Comprehensive evaluation functions ready!")

## 5. Visualization Functions
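
The confusion-matrix panel plots row-normalized values, so each row shows how the samples of one true class were distributed over the predictions. A class with no test samples has a row sum of zero, and a plain division would produce NaN cells. The sketch below shows one way to guard against that; `normalize_rows` and the toy matrix are assumptions made for the example.

```python
import numpy as np

def normalize_rows(cm):
    """Divide each row by its sum; rows that sum to zero stay all-zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

# The middle class has no samples in this toy matrix
toy_cm = [[3, 1, 0],
          [0, 0, 0],
          [1, 0, 1]]
toy_normalized = normalize_rows(toy_cm)
```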

In [None]:
def plot_comprehensive_results(results, figsize=(20, 12), save_path=None):
    """
    Plot comprehensive analysis results.
    
    Args:
        results (dict): Results from comprehensive_model_analysis
        figsize (tuple): Figure size
        save_path (str): Path to save plots
    """
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    
    class_names = results['class_names']
    cm = results['confusion_matrix']
    confidence = results['confidence']
    per_class_acc = results['per_class_accuracy']
    
    # 1. Confusion Matrix
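    # Guard against 0/0 for classes that have no samples in the test set
    row_sums = np.maximum(cm.sum(axis=1, keepdims=True), 1)
    cm_normalized = cm.astype('float') / row_sums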
    im1 = axes[0, 0].imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues)
    axes[0, 0].set_title('Normalized Confusion Matrix')
    axes[0, 0].set_xlabel('Predicted Label')
    axes[0, 0].set_ylabel('True Label')
    
    # Add text annotations
    thresh = cm_normalized.max() / 2.
    for i in range(cm_normalized.shape[0]):
        for j in range(cm_normalized.shape[1]):
            axes[0, 0].text(j, i, f'{cm_normalized[i, j]:.2f}',
                           ha="center", va="center",
                           color="white" if cm_normalized[i, j] > thresh else "black")
    
    axes[0, 0].set_xticks(range(len(class_names)))
    axes[0, 0].set_yticks(range(len(class_names)))
    axes[0, 0].set_xticklabels(class_names, rotation=45)
    axes[0, 0].set_yticklabels(class_names)
    
    # 2. Per-class Accuracy
    class_accs = [per_class_acc[name] for name in class_names]
    bars = axes[0, 1].bar(class_names, class_accs, color='skyblue', alpha=0.7)
    axes[0, 1].set_title('Per-Class Accuracy')
    axes[0, 1].set_xlabel('Class')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_ylim(0, 1)
    axes[0, 1].tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, acc in zip(bars, class_accs):
        axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                       f'{acc:.3f}', ha='center', va='bottom', fontsize=8)
    
    # 3. Confidence Distribution
    axes[0, 2].hist(confidence, bins=30, alpha=0.7, color='green', edgecolor='black')
    axes[0, 2].axvline(confidence.mean(), color='red', linestyle='--', 
                      label=f'Mean: {confidence.mean():.3f}')
    axes[0, 2].set_title('Prediction Confidence Distribution')
    axes[0, 2].set_xlabel('Confidence Score')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].legend()
    
    # 4. ROC Curves (if available)
    if 'fpr' in results and 'tpr' in results:
        fpr = results['fpr']
        tpr = results['tpr']
        roc_auc = results['roc_auc']
        
        if results['num_classes'] > 2:
            # Multi-class ROC
            for i, class_name in enumerate(class_names):
                if i in fpr and i in tpr:
                    axes[1, 0].plot(fpr[i], tpr[i], 
                                   label=f'{class_name} (AUC = {roc_auc[i]:.3f})')
            
            if 'micro' in fpr:
                axes[1, 0].plot(fpr['micro'], tpr['micro'], 'k--',
                               label=f'Micro-avg (AUC = {roc_auc["micro"]:.3f})')
        else:
            # Binary ROC
            axes[1, 0].plot(fpr, tpr, label=f'ROC (AUC = {roc_auc[0]:.3f})')
        
        axes[1, 0].plot([0, 1], [0, 1], 'k--', alpha=0.5)
        axes[1, 0].set_title('ROC Curves')
        axes[1, 0].set_xlabel('False Positive Rate')
        axes[1, 0].set_ylabel('True Positive Rate')
        axes[1, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        axes[1, 0].text(0.5, 0.5, 'ROC data not available', 
                       ha='center', va='center', transform=axes[1, 0].transAxes)
        axes[1, 0].set_title('ROC Curves')
    
    # 5. Classification Report Heatmap
    if 'classification_report' in results:
        report = results['classification_report']
        
        # Extract per-class metrics
        metrics = ['precision', 'recall', 'f1-score']
        metric_matrix = []
        
        for class_name in class_names:
            if class_name in report:
                row = [report[class_name][metric] for metric in metrics]
                metric_matrix.append(row)
            else:
                metric_matrix.append([0, 0, 0])
        
        metric_matrix = np.array(metric_matrix)
        
        im2 = axes[1, 1].imshow(metric_matrix, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
        axes[1, 1].set_title('Per-Class Metrics Heatmap')
        axes[1, 1].set_xlabel('Metrics')
        axes[1, 1].set_ylabel('Classes')
        
        # Add text annotations
        for i in range(len(class_names)):
            for j, metric in enumerate(metrics):
                text = axes[1, 1].text(j, i, f'{metric_matrix[i, j]:.3f}',
                                      ha="center", va="center", color="black")
        
        axes[1, 1].set_xticks(range(len(metrics)))
        axes[1, 1].set_yticks(range(len(class_names)))
        axes[1, 1].set_xticklabels(metrics)
        axes[1, 1].set_yticklabels(class_names)
        
        # Add colorbar
        plt.colorbar(im2, ax=axes[1, 1], shrink=0.8)
    
    # 6. Model Performance Summary
    axes[1, 2].axis('off')
    summary_text = f"""MODEL PERFORMANCE SUMMARY
    
Overall Accuracy: {results['accuracy']:.4f}
F1-Score (Macro): {results['f1_macro']:.4f}
F1-Score (Weighted): {results['f1_weighted']:.4f}
Precision (Macro): {results['precision_macro']:.4f}
Recall (Macro): {results['recall_macro']:.4f}

Confidence Stats:
  Mean: {results['confidence_stats']['mean']:.4f}
  Std: {results['confidence_stats']['std']:.4f}
  Min: {results['confidence_stats']['min']:.4f}
  Max: {results['confidence_stats']['max']:.4f}

Classes: {len(class_names)}
Test Samples: {len(results['targets'])}
    """
    
    axes[1, 2].text(0.05, 0.95, summary_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plots saved to {save_path}")
    
    plt.show()


def visualize_explainability_results(model, test_loader, device, num_samples=4, 
                                    methods=('gradcam',), save_dir=None):
    """
    Visualize explainability results for sample predictions.
    
    Args:
        model: Trained PyTorch model
        test_loader: Test data loader
        device: Device to use
        num_samples (int): Number of samples to visualize
        methods (list): Explainability methods to use
        save_dir (str): Directory to save visualizations
    """
    # Get sample data
    data_iter = iter(test_loader)
    images, labels = next(data_iter)
    
    # Select samples
    indices = np.random.choice(len(images), min(num_samples, len(images)), replace=False)
    
    if 'gradcam' in methods:
        # Initialize GradCAM
        try:
            target_layer = get_gradcam_layer_name(model)
            gradcam = GradCAM(model, target_layer)
            
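            # The batch may hold fewer images than requested, so size the grid by
            # the samples actually drawn; squeeze=False always gives a 2-D axes array
            num_shown = len(indices)
            fig, axes = plt.subplots(2, num_shown, figsize=(4*num_shown, 8),
                                     squeeze=False)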
            
            for i, idx in enumerate(indices):
                image = images[idx:idx+1].to(device)
                true_label = labels[idx].item()
                
                # Get prediction
                with torch.no_grad():
                    output = model(image)
                    pred_label = output.argmax(dim=1).item()
                    confidence = F.softmax(output, dim=1).max().item()
                
                # Original image
                orig_img = images[idx].permute(1, 2, 0).cpu().numpy()
                if orig_img.shape[2] == 1:  # Grayscale
                    orig_img = orig_img.squeeze(2)
                    axes[0, i].imshow(orig_img, cmap='gray')
                else:  # RGB
                    # Denormalize for visualization
                    orig_img = np.clip(orig_img * 0.5 + 0.5, 0, 1)
                    axes[0, i].imshow(orig_img)
                
                axes[0, i].set_title(f'Original\nTrue: {true_label}, Pred: {pred_label}\nConf: {confidence:.3f}')
                axes[0, i].axis('off')
                
                # GradCAM visualization
                cam = gradcam.generate_cam(image)
                axes[1, i].imshow(cam, cmap='hot')
                axes[1, i].set_title(f'GradCAM Heatmap')
                axes[1, i].axis('off')
            
            plt.suptitle('GradCAM Visualizations', fontsize=16)
            plt.tight_layout()
            
            if save_dir:
                import os
                os.makedirs(save_dir, exist_ok=True)
                plt.savefig(os.path.join(save_dir, 'gradcam_results.png'), 
                           dpi=300, bbox_inches='tight')
            
            plt.show()
            
        except Exception as e:
            print(f"GradCAM visualization failed: {e}")
    
    print(f"Explainability visualization completed for {len(indices)} samples")


print("Visualization functions ready!")

## Summary

This notebook provides comprehensive evaluation and explainable AI functionality:

### Core Components:
1. **GradCAM**: Gradient-based visual explanations with automatic layer detection
2. **LIME**: Local interpretable model-agnostic explanations (if available)
3. **SHAP**: SHapley Additive exPlanations (if available)
4. **Comprehensive Evaluation**: Detailed metrics, ROC curves, confidence analysis
5. **Advanced Visualizations**: Multi-panel plots with confusion matrices, ROC curves, confidence distributions

### Key Features:
- **Automatic Layer Detection**: GradCAM automatically finds suitable layers
- **Multi-method Explanations**: Support for multiple XAI approaches
- **Comprehensive Metrics**: Accuracy, F1-scores, precision, recall, ROC-AUC
- **Confidence Analysis**: Prediction confidence statistics and distributions
- **Per-class Analysis**: Detailed per-class performance metrics
- **Rich Visualizations**: Professional-quality plots and heatmaps
- **Results Saving**: Automatic saving of results and visualizations
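
When saving results as JSON, note that NumPy scalars such as `float32` and NumPy arrays are not accepted by the standard `json` module. A small recursive converter, sketched here with the hypothetical name `to_builtin`, turns them into built-in Python types first.

```python
import json
import numpy as np

def to_builtin(value):
    """Recursively convert NumPy scalars and arrays into JSON-friendly Python types."""
    if isinstance(value, dict):
        return {str(key): to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_builtin(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):  # NumPy scalar such as float32 or int64
        return value.item()
    return value

# Illustrative payload (not real results)
payload = {'mean': np.float32(0.5), 'confusion_matrix': np.array([[1, 2], [3, 4]])}
encoded = json.dumps(to_builtin(payload))
```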

### Usage Examples:
```python
# Comprehensive analysis
results = comprehensive_model_analysis(model, test_loader, device, label_map)

# Plot results
plot_comprehensive_results(results, save_path='analysis_plots.png')

# GradCAM explanations
visualize_explainability_results(model, test_loader, device, 
                                methods=['gradcam'], num_samples=6)

# Individual explanations
gradcam = GradCAM(model, 'features.30')
cam = gradcam.generate_cam(image_tensor)
```

Everything used above is defined in this notebook; nothing is imported from the project's `src` folder.