# 🦷 PAI Classification Explainable AI Analysis
## Multi-Layer Class Activation Mapping for Dental Radiograph Interpretation

**Author:** Gerald Torgersen  
**Date:** 2025  
**GitHub:** [github.com/geraldOslo/PAI-meets-AI](https://github.com/geraldOslo/PAI-meets-AI)  

**License**  
SPDX-License-Identifier: MIT  
Copyright (c) 2025 Gerald Torgersen

---

### 🎯 **Overview**

This notebook performs **Explainable AI (XAI) analysis** on fine-tuned deep learning models for **Periapical Index (PAI) classification** of dental radiographs. PAI scoring is a standardized method for assessing apical periodontitis in endodontic diagnosis, rating pathology on a scale of 1-5 based on radiographic findings.

### 🔬 **What We're Analyzing**

**PAI Classification Scale:**
- **PAI 1**: Normal periapical structures
- **PAI 2**: Small changes in bone structure  
- **PAI 3**: Changes in bone structure with mineral loss
- **PAI 4**: Periodontitis with well-defined radiolucent area
- **PAI 5**: Severe periodontitis with exacerbating features

### 🧠 **Explainable AI Approach**

This analysis employs **multi-layer Class Activation Mapping (CAM)** to understand *how* and *where* the AI model focuses when making PAI predictions:

#### **🔍 Multi-Layer Analysis**
- **Early Layers (Blocks 1-2)**: Low-level features like edges and textures
- **Mid Layers (Blocks 3-4)**: Intermediate patterns and structures  
- **Deep Layers (Blocks 5-6)**: High-level semantic features critical for classification
- **Combined Analysis**: Weighted fusion of the per-block CAMs (blocks selected via `TARGET_BLOCK_INDICES`) into a single map
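A minimal NumPy sketch of such a fusion, assuming each target block yields a 2-D CAM at its own spatial resolution: every map is resized to the input size, min-max normalized, and combined with per-block weights. The helper names and the weights in the example are hypothetical, and nearest-neighbour resizing stands in for whatever interpolation `inference_utils` actually uses.

```python
import numpy as np

def normalize(cam: np.ndarray) -> np.ndarray:
    """Min-max scale a CAM to [0, 1]; a constant map becomes all zeros."""
    cam = cam.astype(np.float64)
    lo, hi = cam.min(), cam.max()
    if hi - lo < 1e-12:
        return np.zeros_like(cam)
    return (cam - lo) / (hi - lo)

def upsample_nearest(cam: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D map to (out_h, out_w)."""
    h, w = cam.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return cam[rows[:, None], cols[None, :]]

def fuse_layer_cams(cams, weights, out_size):
    """Weighted sum of resized, normalized per-block CAMs, rescaled to [0, 1]."""
    out_h, out_w = out_size
    fused = np.zeros((out_h, out_w))
    for cam, weight in zip(cams, weights):
        fused += weight * normalize(upsample_nearest(cam, out_h, out_w))
    return normalize(fused)

rng = np.random.default_rng(0)
mid_cam = rng.random((20, 20))   # finer map from a mid-level block (hypothetical)
deep_cam = rng.random((10, 10))  # coarser map from a deep block (hypothetical)
fused = fuse_layer_cams([mid_cam, deep_cam], weights=[0.3, 0.7], out_size=(300, 300))
print(fused.shape)  # (300, 300)
```

Giving deeper blocks larger weights keeps the class-discriminative signal dominant while the finer maps sharpen its boundaries; the actual weighting scheme in the notebook may differ.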

#### **🎨 CAM Methods Available**
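- **🎯 GradCAM**: Standard gradient-based attention mapping; weights each activation map of the target layer by its spatially averaged gradient
- **⚡ GradCAM++**: Refined gradient weighting (positive gradients with per-pixel coefficients) for better localization, especially when several regions drive the prediction
- **🔄 ScoreCAM**: Gradient-free method that weights each activation map by the class score obtained when the input is masked with that map
- **📊 LayerCAM**: Element-wise weighting of activations by their positive gradients, which yields usable maps from shallow as well as deep layers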
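The three gradient-based methods differ mainly in how they weight a layer's activation maps. The sketch below computes each map from one layer's activations `A` and gradients `G` (arrays of shape channels × height × width for a single image and target class). The GradCAM++ weights follow the closed-form approximation from the GradCAM++ paper. ScoreCAM is left out because it needs one forward pass per activation map rather than gradients. The function names are illustrative, not the `pytorch-grad-cam` API.

```python
import numpy as np

def _relu_and_scale(cam):
    """Apply ReLU, then scale so the peak value is 1 (all-zero maps stay zero)."""
    cam = np.maximum(cam, 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam

def gradcam(A, G):
    # One weight per channel: the spatially averaged gradient
    weights = G.mean(axis=(1, 2))
    return _relu_and_scale(np.tensordot(weights, A, axes=1))

def gradcam_plus_plus(A, G, eps=1e-6):
    # Closed-form per-pixel coefficients alpha from the GradCAM++ paper
    g2, g3 = G ** 2, G ** 3
    sum_A = A.sum(axis=(1, 2), keepdims=True)
    alpha = g2 / (2 * g2 + sum_A * g3 + eps)
    alpha = np.where(G > 0, alpha, 0.0)  # only positive gradients contribute
    weights = (alpha * np.maximum(G, 0.0)).sum(axis=(1, 2))
    return _relu_and_scale(np.tensordot(weights, A, axes=1))

def layercam(A, G):
    # No channel pooling: each activation is scaled by its own positive gradient
    return _relu_and_scale((np.maximum(G, 0.0) * A).sum(axis=0))

rng = np.random.default_rng(0)
A = rng.random((8, 7, 7))            # stand-in activations (non-negative)
G = rng.standard_normal((8, 7, 7))   # stand-in gradients of the class score
maps = {f.__name__: f(A, G) for f in (gradcam, gradcam_plus_plus, layercam)}
```

Because LayerCAM keeps the spatial detail of the gradients instead of pooling them per channel, it tends to stay informative at early blocks, which is why it suits the multi-layer analysis.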

### 🔬 **Methodological Rigor**
- **Separate CAM Generation**: Individual visualizations use predicted class CAMs (showing model reasoning), while population averages use true class CAMs (revealing class-specific patterns)
- **Quality Validation**: Cross-class difference metrics check whether the average heatmaps actually differ between PAI classes, i.e. whether the attention patterns discriminate between classes
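The difference between the two target choices fits in a few lines. This sketch assumes class index = PAI score minus 1 (a hypothetical convention) and shows that, for a misclassified sample, the CAM used in individual plots and the one used for population averages explain different classes. With `pytorch-grad-cam`, the chosen index would typically be passed as `ClassifierOutputTarget(index)` in the `targets` argument.

```python
import numpy as np

def cam_target_classes(logits, true_labels):
    """Return (targets for individual plots, targets for class-average maps)."""
    predicted = np.argmax(np.asarray(logits), axis=1)  # explains what the model decided
    true = np.asarray(true_labels)                      # groups maps by ground truth
    return predicted, true

logits = np.array([[2.0, 0.1, 0.0, 0.0, 0.0],    # true PAI 1, predicted PAI 1
                   [0.0, 0.2, 3.0, 0.1, 0.0]])   # true PAI 4, predicted PAI 3
pred_targets, true_targets = cam_target_classes(logits, true_labels=[0, 3])
misclassified = pred_targets != true_targets
print(pred_targets, true_targets, misclassified)
```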

### 📊 **Analysis Outputs**

#### **Individual Sample Visualizations**
- **🖼️ Multi-panel plots** showing original image, prediction probabilities, and layer-wise heatmaps
- **🎭 Class-specific attention maps** revealing how focus changes for different PAI scores
- **🔍 Difference maps** highlighting what multi-layer fusion adds over single-layer analysis

#### **Population-Level Insights**
- **📈 Average heatmaps** for each PAI class showing typical attention patterns
- **🧭 Anatomical alignment** with quadrant-aware image flipping for meaningful averaging
- **🎯 Optional circular masking** to focus on periapical regions
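Population averaging with quadrant-aware alignment and optional circular masking could look like the sketch below. The flip convention (mirror left-right for the patient's left side, up-down for the mandible, using FDI quadrant numbers) is a hypothetical choice for illustration; the masking disc is assumed to be centred in the image, with its radius playing the role of `APPLY_MASK_RADIUS_PIXELS`.

```python
import numpy as np

# Hypothetical flip convention using FDI quadrants (1 upper right, 2 upper left,
# 3 lower left, 4 lower right). The notebook's actual rule may differ.
FLIP_LEFT_RIGHT = {2, 3}
FLIP_UP_DOWN = {3, 4}

def align_by_quadrant(cam, quadrant):
    """Mirror a 2-D heatmap so all quadrants share one anatomical orientation."""
    if quadrant in FLIP_LEFT_RIGHT:
        cam = cam[:, ::-1]
    if quadrant in FLIP_UP_DOWN:
        cam = cam[::-1, :]
    return cam

def circular_mask(h, w, radius):
    """Boolean disc of the given radius (pixels) centred in an h x w image."""
    yy, xx = np.ogrid[:h, :w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    return (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2

def class_average_heatmaps(cams, labels, quadrants, num_classes, radius=None):
    """Average aligned heatmaps per true class; classes without samples stay NaN."""
    h, w = cams[0].shape
    sums = np.zeros((num_classes, h, w))
    counts = np.zeros(num_classes, dtype=int)
    for cam, label, quadrant in zip(cams, labels, quadrants):
        sums[label] += align_by_quadrant(cam, quadrant)
        counts[label] += 1
    averages = np.full((num_classes, h, w), np.nan)
    present = counts > 0
    averages[present] = sums[present] / counts[present][:, None, None]
    if radius is not None:
        # Zero everything outside the central disc (NaN maps stay NaN)
        averages = averages * circular_mask(h, w, radius)
    return averages, counts
```

Without the alignment step, a lesion near the apex of an upper tooth and one near the apex of a lower tooth would land in opposite halves of the image and partly cancel out in the average.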

#### **Performance Metrics**
- **📊 Confusion matrices** and classification performance
- **📏 Quantitative metrics**: Accuracy, mean absolute error (MAE) between predicted and true PAI scores, and Quadratic Weighted Kappa (QWK)
- **💾 Comprehensive results** in CSV and JSON formats
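For reference, the three metrics can be computed from integer labels with NumPy alone. This sketch assumes labels 0–4 stand for PAI 1–5 (MAE is unaffected by the offset), uses the standard quadratic weights (i − j)² / (K − 1)², and returns `None` for QWK when it is undefined, which matches how the notebook prints "N/A". The dictionary keys mirror those the notebook reads from `run_metrics`.

```python
import numpy as np

def classification_metrics(y_true, y_pred, num_classes=5):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accuracy = float(np.mean(y_true == y_pred))
    mae = float(np.mean(np.abs(y_true - y_pred)))  # in PAI score units

    # Observed confusion matrix (rows: true class, columns: predicted class)
    observed = np.zeros((num_classes, num_classes))
    np.add.at(observed, (y_true, y_pred), 1)

    # Quadratic disagreement weights and the chance-level expected matrix
    idx = np.arange(num_classes)
    weights = (idx[:, None] - idx[None, :]) ** 2 / (num_classes - 1) ** 2
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / observed.sum()

    denom = (weights * expected).sum()
    qwk = None if denom == 0 else float(1 - (weights * observed).sum() / denom)
    return {"accuracy": accuracy, "mae": mae, "quadratic_weighted_kappa": qwk}

print(classification_metrics([0, 1, 2, 3, 4], [0, 1, 2, 3, 3]))
```

QWK penalizes a PAI 1 vs PAI 5 confusion 16 times more than a neighbouring-score confusion, which fits an ordinal scale better than plain accuracy does.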

### 🔬 **Scientific Value**

This analysis provides insights into:
- **🎯 Model reliability**: Does the AI focus on clinically relevant areas?
- **📚 Feature hierarchy**: How do different network layers contribute to decisions?
- **⚖️ Method comparison**: Which CAM technique provides the most interpretable results?
- **🏥 Clinical validation**: Can model attention patterns support diagnostic confidence?

### 🚀 **Workflow Overview**

```mermaid
graph TD
    A[📁 Load Test Data] --> B[🤖 Load Trained Model]
    B --> C[🎯 Generate Multi-Layer CAMs]
    C --> D[🖼️ Create Individual Visualizations]
    C --> E[📊 Calculate Average Heatmaps]
    D --> F[📈 Evaluate Performance]
    E --> F
    F --> G[💾 Save Results & Metrics]
```

### 📋 **Prerequisites**

- ✅ Trained EfficientNet-B3 model checkpoint with PAI classification weights
- ✅ Test dataset with PAI labels and quadrant information  
- ✅ Required libraries: `timm`, `torch`, `matplotlib`, etc.; `pytorch-grad-cam` is needed for every CAM method except GradCAM

### 🎛️ **Configuration**

All analysis parameters are controlled through `config_inference.py`:
- 📂 **Data paths**: Model checkpoint, test images, output directory
- 🎨 **CAM methods**: Which explainability techniques to apply
- 🔢 **Visualization settings**: Number of examples, masking parameters
- ⚙️ **Model parameters**: Automatically loaded from checkpoint
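A hypothetical minimal `config_inference.py` is sketched below. The attribute and function names are the ones this notebook reads; every value is a placeholder (the `MEAN`/`STD` shown are the ImageNet statistics, not necessarily those used in training), and the validation rules are assumptions.

```python
# config_inference.py (hypothetical sketch; all values are placeholders)
import os

CHECKPOINT_PATH = "path/to/checkpoint.pth"
TEST_CSV_FILE = "path/to/test.csv"
TEST_ROOT_DIR = "path/to/test_images"
OUTPUT_DIR = "xai_output"

NUM_CLASSES = 5                      # PAI 1-5
MEAN = [0.485, 0.456, 0.406]         # placeholder: use the training statistics
STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 16

GRAD_CAM_METHODS = ["gradcam", "gradcam++", "scorecam", "layercam"]
USE_METHODS = [0, 3]                 # indices into GRAD_CAM_METHODS
TARGET_BLOCK_INDICES = [2, 4, 6]     # hypothetical choice of backbone blocks
NUM_EXAMPLES_TO_VISUALIZE = 10
APPLY_MASK_RADIUS_PIXELS = None      # e.g. 100 to keep only a central disc

SHOW_CLIPS = True
CLIPS_TO_SHOW = 8
CLIPS_PER_ROW = 4

def validate_config():
    """Raise ValueError on inconsistent settings (the notebook catches ValueError)."""
    bad = [i for i in USE_METHODS if not 0 <= i < len(GRAD_CAM_METHODS)]
    if bad:
        raise ValueError(f"USE_METHODS has out-of-range indices: {bad}")
    if len(MEAN) != len(STD):
        raise ValueError("MEAN and STD must have the same length")
    if BATCH_SIZE < 1 or NUM_CLASSES < 2:
        raise ValueError("BATCH_SIZE must be >= 1 and NUM_CLASSES >= 2")

def print_config_summary():
    for name in ("CHECKPOINT_PATH", "TEST_CSV_FILE", "OUTPUT_DIR",
                 "USE_METHODS", "TARGET_BLOCK_INDICES"):
        print(f"{name}: {globals()[name]}")

def ensure_output_directory():
    """Create OUTPUT_DIR if needed and return its path."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    return OUTPUT_DIR
```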

---

**🎓 Research Context**: This analysis supports evidence-based AI in endodontic diagnosis by providing interpretable insights into automated PAI scoring systems, contributing to trustworthy AI deployment in clinical dentistry.

---

## Setup and Configuration

In [None]:
# ==============================================================================
# Imports and Environment Setup
# ==============================================================================

# Import configuration and utilities
import config_inference as config
import inference_utils as utils

# Standard imports
import os
import sys
import traceback

# Setup custom package path if needed
if hasattr(config, 'CUSTOM_PACKAGE_PATH'):
    utils.setup_custom_package_path(config.CUSTOM_PACKAGE_PATH)

# Check library availability
utils.check_library_availability()

print("=== PAI XAI Analysis Setup Complete ===")


## Configuration Review
# ==============================================================================
# Configuration Validation and Review
# ==============================================================================

# Validate configuration
try:
    config.validate_config()
    print("✅ Configuration validation passed!")
except ValueError as e:
    print(f"❌ Configuration validation failed: {e}")
    sys.exit(1)

# Print configuration summary
config.print_config_summary()

# Ensure output directory exists
output_dir = config.ensure_output_directory()
print(f"📁 Output directory ready: {output_dir}")
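`utils.check_library_availability()` is not shown in this notebook. A minimal sketch of how such a check could set the `PYTORCH_GRAD_CAM_AVAILABLE` flag used in the analysis loop, assuming it only needs to know whether each package is importable:

```python
import importlib.util

# Import names to probe. "pytorch_grad_cam" is the import name of the
# pytorch-grad-cam project (published on PyPI as "grad-cam").
LIBRARIES_TO_CHECK = ["torch", "timm", "matplotlib", "pytorch_grad_cam"]

def library_status(names):
    """Map each top-level module name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

status = library_status(LIBRARIES_TO_CHECK)
for name, available in status.items():
    print(f"{'✅' if available else '❌'} {name}")

# Flag consulted before running any CAM method other than GradCAM
PYTORCH_GRAD_CAM_AVAILABLE = status["pytorch_grad_cam"]
```

`find_spec` locates a module without importing it, so the check stays cheap even for heavy packages such as `torch`.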

## Data Loading and Preparation

In [None]:
# ==============================================================================
# Data Loading and Preparation
# ==============================================================================

print("\n=== Loading and Preparing Test Data ===")

try:
    # Load and prepare test data
    test_loader, test_data = utils.load_and_prepare_test_data(
        test_csv_file=config.TEST_CSV_FILE,
        test_root_dir=config.TEST_ROOT_DIR,
        mean=config.MEAN,
        std=config.STD,
        batch_size=config.BATCH_SIZE
    )
    
    print(f"✅ Test data loaded successfully!")
    print(f"   📊 Total samples: {len(test_data)}")
    print(f"   🔢 Batch size: {config.BATCH_SIZE}")
    
except Exception as e:
    print(f"❌ Error loading test data: {e}")
    traceback.print_exc()
    sys.exit(1)


# ==============================================================================
# Optional: Display Sample Images
# ==============================================================================

if config.SHOW_CLIPS:
    print(f"\n=== Displaying {config.CLIPS_TO_SHOW} Sample Images ===")
    utils.display_dataset_samples(
        test_data, 
        num_images=config.CLIPS_TO_SHOW, 
        images_per_row=config.CLIPS_PER_ROW, 
        title="Test Dataset Samples"
    )
else:
    print("\n📸 Sample image display is disabled in config (SHOW_CLIPS=False)")
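Both the loader and `run_xai_test` receive `config.MEAN` and `config.STD`: presumably the loader normalizes the inputs with them, and the analysis undoes that normalization before overlaying heatmaps on the radiograph. A sketch of that round trip, assuming channel-last float images in [0, 1]; if the tensors are channel-first (C × H × W, as PyTorch usually expects), the statistics would need the shape C × 1 × 1 instead.

```python
import numpy as np

def normalize_image(img, mean, std):
    """img: H x W x C float array in [0, 1]; returns the model-input scaling."""
    return (img - np.asarray(mean)) / np.asarray(std)

def denormalize_image(x, mean, std):
    """Invert normalize_image and clip to [0, 1] so the image can be displayed."""
    return np.clip(x * np.asarray(std) + np.asarray(mean), 0.0, 1.0)

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # placeholder statistics
img = np.random.default_rng(0).random((4, 4, 3))
restored = denormalize_image(normalize_image(img, mean, std), mean, std)
```

Forgetting the inverse step is a common cause of overlays that look washed out or saturated, even when the heatmap itself is correct.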

## XAI Analysis Execution

In [None]:
# ==============================================================================
# Run XAI Analysis for Selected CAM Methods
# ==============================================================================

print(f"\n=== Running XAI Analysis ===")
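# Skip invalid indices here so this summary cannot raise IndexError; the loop below reports them
valid_methods = [config.GRAD_CAM_METHODS[i] for i in config.USE_METHODS
                 if 0 <= i < len(config.GRAD_CAM_METHODS)]
print(f"🎯 Methods to run: {valid_methods}")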
print(f"🖼️  Individual visualizations: {config.NUM_EXAMPLES_TO_VISUALIZE}")
print(f"🎭 Mask radius: {config.APPLY_MASK_RADIUS_PIXELS}")

# Store results for each method
analysis_results = {}

# Loop through selected CAM methods
for i in config.USE_METHODS:
    if i < 0 or i >= len(config.GRAD_CAM_METHODS):
        print(f"⚠️  Warning: Invalid index {i} in USE_METHODS. Skipping.")
        continue

    selected_cam_method = config.GRAD_CAM_METHODS[i]

    # Check if pytorch-grad-cam is available for non-gradcam methods
    if selected_cam_method.lower() != 'gradcam' and not utils.PYTORCH_GRAD_CAM_AVAILABLE:
        print(f"⚠️  Skipping method '{selected_cam_method}' as pytorch-grad-cam is not available.")
        continue

    print(f"\n🚀 Running XAI analysis for method: {selected_cam_method.upper()}")
    print("=" * 60)

    # Run the main XAI analysis
    try:
        run_metrics = utils.run_xai_test(
            model_path=config.CHECKPOINT_PATH,
            test_loader=test_loader,
            mean=config.MEAN,
            std=config.STD,
            test_root_dir=config.TEST_ROOT_DIR,
            output_dir=config.OUTPUT_DIR,
            num_examples=config.NUM_EXAMPLES_TO_VISUALIZE,
            num_classes=config.NUM_CLASSES,
            gradcam_method=selected_cam_method,
            apply_mask_radius=config.APPLY_MASK_RADIUS_PIXELS,
            target_block_indices=config.TARGET_BLOCK_INDICES
        )

        if run_metrics:
            analysis_results[selected_cam_method] = run_metrics
            print(f"✅ XAI analysis complete for {selected_cam_method.upper()}")
            print(f"   📊 Accuracy: {run_metrics['accuracy']:.4f}")
            print(f"   📏 MAE: {run_metrics['mae']:.4f}")
            if run_metrics['quadratic_weighted_kappa'] is not None:
                print(f"   🎯 QWK: {run_metrics['quadratic_weighted_kappa']:.4f}")
            else:
                print(f"   🎯 QWK: N/A")
        else:
            print(f"❌ XAI analysis failed for {selected_cam_method.upper()}. See logs above.")

    except Exception as e:
        print(f"❌ Error during XAI analysis for {selected_cam_method}: {e}")
        traceback.print_exc()

print(f"\n🏁 All XAI analyses complete!")
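The skip rules in the loop above (an out-of-range index, or a non-GradCAM method without `pytorch-grad-cam`) can be factored into a pure function, which makes them testable without loading a model. A sketch with a hypothetical function name:

```python
def resolve_cam_methods(all_methods, use_methods, grad_cam_lib_available):
    """Split requested indices into runnable method names and (item, reason) skips."""
    runnable, skipped = [], []
    for i in use_methods:
        if not 0 <= i < len(all_methods):
            skipped.append((i, "invalid index"))
        elif all_methods[i].lower() != "gradcam" and not grad_cam_lib_available:
            skipped.append((all_methods[i], "pytorch-grad-cam not installed"))
        else:
            runnable.append(all_methods[i])
    return runnable, skipped

runnable, skipped = resolve_cam_methods(["gradcam", "gradcam++", "layercam"], [0, 2, 7], False)
print(runnable, skipped)
```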

## Results Summary

In [None]:
# ==============================================================================
# Results Summary and Comparison
# ==============================================================================

print(f"\n=== Final Results Summary ===")
print("=" * 50)

if analysis_results:
    print(f"🎉 Successfully completed analysis for {len(analysis_results)} method(s):")
    
    # Create summary table
    print(f"\n{'Method':<15} {'Accuracy':<10} {'MAE':<8} {'QWK':<8}")
    print("-" * 45)
    
    for method, metrics in analysis_results.items():
        accuracy = metrics['accuracy']
        mae = metrics['mae']
        qwk = metrics['quadratic_weighted_kappa']
        qwk_str = f"{qwk:.4f}" if qwk is not None else "N/A"
        print(f"{method.upper():<15} {accuracy:<10.4f} {mae:<8.4f} {qwk_str:<8}")
    
    print(f"\n📁 All outputs saved to: {config.OUTPUT_DIR}")
    print(f"   📊 Individual visualizations: {config.NUM_EXAMPLES_TO_VISUALIZE} per method")
    print(f"   📈 Average heatmaps: {config.NUM_CLASSES} PAI classes per method")
    print(f"   📋 Performance metrics: CSV, JSON, and confusion matrix plots")
    
else:
    print("❌ No successful analyses completed.")

print(f"\n🎊 Analysis workflow complete!")
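The overview promises CSV and JSON outputs, which `run_xai_test` presumably writes per method. If a single cross-method summary is also wanted, a standard-library sketch could look like this (the file names are hypothetical):

```python
import csv
import json
import os

def _to_builtin(obj):
    """JSON fallback for NumPy scalars/arrays (anything with .tolist()), else str."""
    return obj.tolist() if hasattr(obj, "tolist") else str(obj)

def save_summary(analysis_results, output_dir):
    """Write all metrics to JSON and one CSV row per CAM method; return both paths."""
    os.makedirs(output_dir, exist_ok=True)
    json_path = os.path.join(output_dir, "xai_metrics_summary.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(analysis_results, f, indent=2, default=_to_builtin)

    csv_path = os.path.join(output_dir, "xai_metrics_summary.csv")
    fields = ["method", "accuracy", "mae", "quadratic_weighted_kappa"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fields, extrasaction="ignore")
        writer.writeheader()
        for method, metrics in analysis_results.items():
            # extrasaction="ignore" drops any other keys run_xai_test may return
            writer.writerow({"method": method, **metrics})
    return json_path, csv_path
```

In the notebook this would be called as `save_summary(analysis_results, config.OUTPUT_DIR)`. A `None` QWK becomes `null` in the JSON file and an empty cell in the CSV.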

## Additional Analysis Options

In [None]:
# ==============================================================================
# Optional: Additional Analysis and Exploration
# ==============================================================================

# You can add additional analysis here, such as:
# - Comparing results between methods
# - Loading and displaying specific visualizations
# - Statistical analysis of the results
# - Custom visualizations

# Example: Check that classification metrics agree across CAM methods.
# CAM methods only explain the model's predictions; they do not change them,
# so accuracy, MAE and QWK should be identical for every method run on the
# same checkpoint and test set. A difference points to nondeterminism or a bug.
if len(analysis_results) > 1:
    print("\n=== Method Comparison ===")
    for metric in ('accuracy', 'mae', 'quadratic_weighted_kappa'):
        values = {m: r[metric] for m, r in analysis_results.items()}
        distinct = {round(v, 6) for v in values.values() if v is not None}
        if len(distinct) > 1:
            print(f"⚠️  {metric} differs across CAM methods: {values}")
        else:
            print(f"✅ {metric} is identical across CAM methods")

# Example: Show output directory structure
print(f"\n=== Output Directory Structure ===")
if os.path.exists(config.OUTPUT_DIR):
    for item in sorted(os.listdir(config.OUTPUT_DIR)):
        item_path = os.path.join(config.OUTPUT_DIR, item)
        if os.path.isdir(item_path):
            print(f"📁 {item}/")
            # Count files in subdirectory
            try:
                file_count = len([f for f in os.listdir(item_path) if os.path.isfile(os.path.join(item_path, f))])
                print(f"   └─ {file_count} files")
            except OSError:
                pass  # e.g. permission denied: skip the count
        else:
            print(f"📄 {item}")

print(f"\n✨ Notebook execution complete! ✨")