# ExAI - Explainable Corgi (Cardigan) Separator 🐶

We use [Contrastive GradCAM](https://xai-blog.netlify.app/docs/groups/contrastive-grad-cam-consistency/#contrastive-grad-cam-consistency-loss)
and [Layerwise Relevance Propagation](https://github.com/kaifishr/PyTorchRelevancePropagation) (LRP) to explain how a classifier tells Pembroke and Cardigan Welsh Corgis apart - two breeds that are often difficult to distinguish visually, even for dog experts.

Key visual differences include tail length (Cardigans have longer tails), ear size (Cardigans have larger ears), and body length (Cardigans typically have longer bodies). Our XAI techniques aim to determine if these are indeed the features our model focuses on when making classifications.

- We fine-tune [ResNet50](https://pytorch.org/hub/pytorch_vision_resnet/#model-description) on the [Stanford Dogs Dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/).
- Target breeds: [Pembroke Welsh Corgi](https://de.wikipedia.org/wiki/Welsh_Corgi_Pembroke) | [Cardigan Welsh Corgi](https://de.wikipedia.org/wiki/Welsh_Corgi_Cardigan)

## The Process:
1. Load the dataset and split it into training (80%) and validation (20%) sets.
2. Fine-tune a pre-trained ResNet50 model using transfer learning (freezing early layers).
3. Evaluate model performance through accuracy metrics and confusion matrix.
4. Apply two XAI techniques to visualize decision factors:
   - **GradCAM**: Highlights regions that most influenced the class prediction
   - **LRP**: Provides pixel-level relevance scores for the entire image
5. Compare both techniques to understand if the model focuses on breed-specific anatomical features.
6. Analyze whether our model's reasoning aligns with established breed characteristics.

## 1. Data/Dependency Loading and Transformation

This section covers the foundational setup for our Corgi classification pipeline, including:
- Import of essential libraries and dependencies
- Data acquisition from Stanford Dogs Dataset
- Custom dataset class implementation for Corgi breeds
- Data augmentation and preprocessing transformations
- Creation of training and validation data loaders

### Import of all necessary packages/libraries 

We import PyTorch and related libraries for deep learning, along with NumPy, Matplotlib, and other data processing tools. These packages enable us to build our classification pipeline, handle image data, train our CNN model, and visualize the XAI results for comparing Pembroke and Cardigan Welsh Corgis.

In [None]:
# Setup and Imports
import os
import time #time measurement and delays 
import copy 
import numpy as np
import matplotlib.pyplot as plt # for plotting
import seaborn as sns # statistical data visualization
from tqdm import tqdm # progress bars for loops
from PIL import Image # image loading
import cv2 #image processing 

# PyTorch imports
import torch
import torch.nn as nn # neural network layers
import torch.optim as optim # optimization algorithms
from torch.optim import lr_scheduler # learning rate scheduling
from torch.utils.data import Dataset, DataLoader, random_split # dataset handling
import torchvision # for image transformations
from torchvision import transforms, models # pre-trained models
from torchvision.datasets.utils import download_url, extract_archive # downloading datasets

# For evaluation metrics
from sklearn.metrics import confusion_matrix, classification_report # for evaluation

### Setting device for GPU acceleration

We configure PyTorch to utilize available GPU resources through CUDA, significantly accelerating the training process and matrix operations. If no GPU is available, the code automatically falls back to CPU processing, ensuring compatibility across different hardware environments.

In [7]:
# Setup device for training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")



### Downloading Dataset if not already in directory
[Stanford Dogs Dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/):

> The Stanford Dogs dataset contains images of 120 breeds of dogs from around the world. This dataset has been built using images and annotation from ImageNet for the task of fine-grained image categorization.


In [2]:
def download_and_extract_dataset(download_dir, extract_dir):
    os.makedirs(download_dir, exist_ok=True)
    os.makedirs(extract_dir, exist_ok=True)
    
    # Download the dataset
    dataset_url = "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar"
    filename = os.path.basename(dataset_url)
    filepath = os.path.join(download_dir, filename)
    
    if not os.path.exists(filepath):
        print(f"Downloading {filename}...")
        download_url(dataset_url, download_dir)
    else:
        print(f"File {filename} already exists in {download_dir}")
    
    # Extract the dataset
    if not os.path.exists(os.path.join(extract_dir, "Images")):
        print(f"Extracting {filename} to {extract_dir}...")
        extract_archive(filepath, extract_dir)
    else:
        print(f"Dataset already extracted to {extract_dir}")



### Corgi Dataset Class 

In [None]:

class CorgiDataset(Dataset):
    def __init__(self, dataset_root, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_names = ['Pembroke', 'Cardigan']
        
        images_dir = os.path.join(dataset_root, "Images")
        if not os.path.exists(images_dir):
            raise FileNotFoundError(f"Images directory not found at {images_dir}")
            
        all_breeds = os.listdir(images_dir)
        
        pembroke_dir = None
        cardigan_dir = None
        
        for breed in all_breeds:
            if "Pembroke" in breed:
                pembroke_dir = os.path.join(images_dir, breed)
            elif "Cardigan" in breed:
                cardigan_dir = os.path.join(images_dir, breed)
        
        if not pembroke_dir or not cardigan_dir:
            raise ValueError("Could not find Pembroke or Cardigan directories")
        
        print(f"Pembroke directory: {pembroke_dir}")
        print(f"Cardigan directory: {cardigan_dir}")
        
        for img_name in os.listdir(pembroke_dir):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                self.image_paths.append(os.path.join(pembroke_dir, img_name))
                self.labels.append(0)  # Pembroke
        
        for img_name in os.listdir(cardigan_dir):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                self.image_paths.append(os.path.join(cardigan_dir, img_name))
                self.labels.append(1)  # Cardigan
        
        print(f"Total number of images: {len(self.image_paths)}")
        print(f"Pembroke images: {sum(1 for label in self.labels if label == 0)}")
        print(f"Cardigan images: {sum(1 for label in self.labels if label == 1)}")
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.labels[idx]
            
            if self.transform:
                image = self.transform(image)
                
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a blank image and the same label
            blank_image = torch.zeros((3, 224, 224)) if self.transform else Image.new('RGB', (224, 224), (0, 0, 0))
            return blank_image, self.labels[idx]

### Class for transforming Subset

`random_split` returns `Subset` objects, which share the transform of the underlying dataset and cannot carry one of their own.

`TransformedSubset` wraps a `Subset` and applies its own transform, so the training split can use random augmentation while the validation split gets only deterministic resizing and normalization. This keeps validation metrics comparable across epochs.

In [8]:
class TransformedSubset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
    
    def __getitem__(self, idx):
        # Correctly handle the subset indexing
        image, label = self.subset.dataset[self.subset.indices[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label
    
    def __len__(self):
        return len(self.subset)



### Data Preparation and Loaders

This section creates our data pipeline for model training. It has three parts:
1. **Data Transformations**: Defines separate pipelines for training (resize, random flip, rotation, and color jitter) and validation (resize only), both followed by ImageNet normalization.
2. **Dataset Splitting**: Creates an 80/20 train/validation split with a fixed seed.
3. **Optimized Loading**: Configures DataLoaders with batching, shuffling of the training set, worker processes, and pinned memory when a GPU is available.

This approach ensures our model trains on varied examples while being evaluated on consistent, unaugmented validation data.

In [None]:
def prepare_dataloaders(dataset_root, batch_size=32, num_workers=2):
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    dataset = CorgiDataset(dataset_root, transform=None)  # No transform yet
    
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    
    torch.manual_seed(42)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_dataset_transformed = TransformedSubset(train_dataset, data_transforms['train'])
    val_dataset_transformed = TransformedSubset(val_dataset, data_transforms['val'])
    
    train_loader = DataLoader(
        train_dataset_transformed, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )
    
    val_loader = DataLoader(
        val_dataset_transformed, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )
    
    print(f"Training set size: {len(train_dataset)} images")
    print(f"Validation set size: {len(val_dataset)} images")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    return train_loader, val_loader, dataset.class_names
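
The sizing rule used for the 80/20 split can be checked without PyTorch. The following is a minimal standard-library sketch, not the notebook's implementation: `split_indices` is a hypothetical helper, and the dataset size of 331 is an arbitrary example value. It shows that the training size is floored, the remainder goes to validation, and a fixed seed makes the split reproducible.

```python
import random


def split_indices(n_items, train_fraction=0.8, seed=42):
    """Return disjoint (train, val) index lists using the same sizing rule."""
    train_size = int(train_fraction * n_items)  # floor, as in prepare_dataloaders
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)        # seeded, so the split is reproducible
    return indices[:train_size], indices[train_size:]


train_idx, val_idx = split_indices(331)  # 331 is an arbitrary example size
print(len(train_idx), len(val_idx))
```

Because the validation size is computed as the remainder, no image is dropped when the dataset size is not divisible by five.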

## 2. Model Definition

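This section implements the transfer learning approach using ResNet50:

1. **Pre-trained Model Loading**: We load ResNet50 with weights pre-trained on ImageNet.
2. **Layer Freezing Strategy**: All pre-trained layers are frozen except the last residual stage (`layer4`), which is fine-tuned together with the new head.
3. **Custom Classification Head**: The original fully-connected layer is replaced with a small head: a 512-unit linear layer, ReLU (Rectified Linear Unit), dropout (p=0.3), and a final linear layer with one output per class.

Training only `layer4` and the head reduces training time and the amount of data needed compared with training the whole network from scratch.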

### Model Setup

In [None]:

def setup_model(num_classes=2):
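    # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)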
    
    for param in model.parameters():
        param.requires_grad = False
    
    for param in model.layer4.parameters():
        param.requires_grad = True
    
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    
    model = model.to(device)
    
    print(f"ResNet50 model configured for {num_classes} classes")
    print(f"Trainable parameters in model.layer4: {sum(p.numel() for p in model.layer4.parameters() if p.requires_grad)}")
    print(f"Trainable parameters in model.fc: {sum(p.numel() for p in model.fc.parameters() if p.requires_grad)}")
    
    return model
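
As a sanity check on the `model.fc` parameter count printed by `setup_model`, the size of the custom head can be worked out by hand. This is a small sketch under the assumption that the head is exactly the `Linear(2048, 512) -> ReLU -> Dropout -> Linear(512, 2)` stack defined above; `linear_param_count` is a hypothetical helper, not part of the notebook.

```python
def linear_param_count(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


num_features = 2048          # ResNet50's fc.in_features
hidden, num_classes = 512, 2

# ReLU and Dropout have no learnable parameters, so only the two Linear layers count.
head_params = (linear_param_count(num_features, hidden)
               + linear_param_count(hidden, num_classes))
print(head_params)  # 1050114
```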

### Training Functions

This section contains the core training pipeline components that manage the model training process. 

#### Training Function: Epoch

This function runs a single pass (one epoch) over the training data. For each batch of images, it:
1. Transfers the data to the GPU/CPU and zeroes the gradients
2. Runs the forward pass to obtain predictions
3. Calculates the cross-entropy loss
4. Backpropagates to compute gradients
5. Updates the model weights via the optimizer


In [None]:

def train_epoch(model, dataloader, criterion, optimizer):
    model.train()  
    running_loss = 0.0
    running_corrects = 0
    
    for inputs, labels in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.double() / len(dataloader.dataset)
    
    print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    
    return epoch_loss, epoch_acc.item()
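
`train_epoch` multiplies each batch loss by `inputs.size(0)` before summing. The reason is that `nn.CrossEntropyLoss` returns the mean over the batch by default, and the last batch is usually smaller than the others. The sketch below uses made-up batch losses and sizes (illustrative values only) to show how the weighted mean differs from a naive average of batch means.

```python
def epoch_mean(batch_means, batch_sizes):
    """Combine per-batch mean losses into the mean over all samples."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


batch_means = [0.50, 0.40, 1.00]   # illustrative per-batch mean losses
batch_sizes = [32, 32, 4]          # the last batch is smaller

weighted = epoch_mean(batch_means, batch_sizes)
naive = sum(batch_means) / len(batch_means)
print(f"weighted: {weighted:.4f}  naive: {naive:.4f}")
```

The naive average lets the four-sample batch count as much as a full batch, which overstates the epoch loss in this example.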


#### Training Function: Validate Epoch

Evaluates model performance on validation data without updating weights.

In [None]:

def validate_epoch(model, dataloader, criterion):
    model.eval()  # Set model to evaluate mode
    running_loss = 0.0
    running_corrects = 0
    
    # Iterate over data
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Forward
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
    
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.double() / len(dataloader.dataset)
    
    print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    
    return epoch_loss, epoch_acc.item()

### Training the Model

This orchestration function manages the entire training cycle. It implements:

1. **Setting Up the Learning Process**: We use cross-entropy loss and Adam with a lower learning rate for `layer4` (1e-4) than for the new classification head (1e-3)
2. **Adaptive Learning Rate**: A `ReduceLROnPlateau` scheduler multiplies the learning rates by 0.1 when validation loss has not improved for more than 3 epochs
3. **Training Loop**: The main training cycle runs through our data multiple times (epochs), tracking both training and validation loss and accuracy
4. **Early Stopping**: We stop training when validation accuracy doesn't improve for 5 consecutive epochs, which limits overfitting to the training examples
5. **Saving the Best Version**: Rather than keeping the final model, we keep a copy of the model weights whenever it achieves a new best validation accuracy, and restore that copy at the end

This approach limits overfitting and avoids spending epochs on training that no longer improves validation accuracy.
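
To make the learning-rate schedule concrete, here is a simplified, dependency-free simulation of a reduce-on-plateau rule in `min` mode. It is only a sketch: `simulate_plateau_schedule` is a hypothetical helper, it ignores the improvement threshold, cooldown, and minimum learning rate that PyTorch's `ReduceLROnPlateau` also supports, and the loss values are invented for illustration.

```python
def simulate_plateau_schedule(val_losses, lr=1e-3, factor=0.1, patience=3):
    """Simplified reduce-on-plateau (mode='min'): no threshold, cooldown, or min_lr."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:   # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


losses = [0.60, 0.50, 0.51, 0.52, 0.53, 0.54, 0.49]  # illustrative values
print(simulate_plateau_schedule(losses))
```

With `patience=3`, the rate drops on the fourth consecutive epoch without improvement (the sixth epoch in this example), not the third.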

In [None]:

def train_model(model, train_loader, val_loader, num_epochs=15, patience=5):
    criterion = nn.CrossEntropyLoss()
    
    optimizer = optim.Adam([
        {'params': model.layer4.parameters(), 'lr': 1e-4},
        {'params': model.fc.parameters(), 'lr': 1e-3}
    ], weight_decay=1e-5)
    
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    no_improve_epochs = 0
    
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 40)
        
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        
        val_loss, val_acc = validate_epoch(model, val_loader, criterion)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        scheduler.step(val_loss)
        
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
        
        print(f'Best val Acc: {best_acc:.4f}')
        
        if no_improve_epochs >= patience:
            print(f'Early stopping at epoch {epoch+1}: no improvement for {patience} epochs')
            break
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    
    model.load_state_dict(best_model_wts)
    return model, history
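
The early-stopping and best-checkpoint bookkeeping in `train_model` can be traced on a plain list of validation accuracies. The sketch below is an illustration with made-up accuracy values; `run_early_stopping` is a hypothetical helper that mirrors only the counter logic, not the actual training.

```python
def run_early_stopping(val_accs, patience=5):
    """Return (best_acc, best_epoch, stop_epoch) using 1-based epoch numbers."""
    best_acc, best_epoch, no_improve = 0.0, 0, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:                 # strict improvement, as in train_model
            best_acc, best_epoch, no_improve = acc, epoch, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            return best_acc, best_epoch, epoch
    return best_acc, best_epoch, len(val_accs)


accs = [0.80, 0.85, 0.84, 0.85, 0.83, 0.84, 0.82, 0.90]  # illustrative values
print(run_early_stopping(accs, patience=5))  # (0.85, 2, 7)
```

Two details are worth noting: a tie with the best accuracy (epoch 4) counts as no improvement, and training stops at epoch 7, so the higher value at epoch 8 is never reached.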

### Evaluating Model

This function tests how well our model distinguishes the two Corgi breeds:

1. It runs validation images through the model and compares predictions to actual labels
2. Plots a confusion matrix (heatmap) showing correct predictions vs. mistakes
3. Prints precision, recall, and F1 scores for each breed
4. Returns the true labels, predictions, and the text report for analysis

Basically, it's like giving our model a final exam and creating a detailed report card!

In [None]:

def evaluate_model(model, dataloader, class_names):
    model.eval()
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()
    
    report = classification_report(y_true, y_pred, target_names=class_names)
    print("\nClassification Report:")
    print(report)
    
    return y_true, y_pred, report
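
The per-breed numbers in the classification report follow directly from the confusion matrix. The following standard-library sketch shows the arithmetic, assuming the scikit-learn convention that rows are true classes and columns are predicted classes; the matrix values are invented for illustration and `per_class_metrics` is a hypothetical helper.

```python
def per_class_metrics(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    metrics = []
    for k in range(len(cm)):
        tp = cm[k][k]
        predicted_k = sum(row[k] for row in cm)   # column sum
        actual_k = sum(cm[k])                     # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics.append((precision, recall, f1))
    return metrics


cm = [[30, 5],    # true Pembroke: 30 correct, 5 predicted as Cardigan
      [10, 25]]   # true Cardigan: 10 predicted as Pembroke, 25 correct

for name, (p, r, f1) in zip(["Pembroke", "Cardigan"], per_class_metrics(cm)):
    print(f"{name}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```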

### Plotting the History

This function visualizes two key metrics:
1. **Loss trends** - Shows how the error decreases during training
2. **Accuracy trends** - Shows how prediction accuracy improves

Both metrics are plotted for training and validation data, helping identify when the model starts overfitting (when validation metrics worsen while training metrics continue to improve).

In [None]:
def plot_training_history(history):
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()


### Saving the Model

In [None]:

def save_model(model, save_path, class_names, optimizer=None, epoch=None, history=None):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    
    torch.save(model, save_path.replace('.pth', '_full.pth'))
    print(f"Complete model saved to: {save_path.replace('.pth', '_full.pth')}")
    
    torch.save(model.state_dict(), save_path.replace('.pth', '_weights.pth'))
    print(f"Model weights saved to: {save_path.replace('.pth', '_weights.pth')}")
    
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'classes': class_names,
    }
    
    if optimizer:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    
    if epoch is not None:
        checkpoint['epoch'] = epoch
        
    if history:
        checkpoint['history'] = history
    
    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved to: {save_path}")
    
    try:
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        torch.onnx.export(
            model,
            dummy_input,
            save_path.replace('.pth', '.onnx'),
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        print(f"ONNX model saved to: {save_path.replace('.pth', '.onnx')}")
    except Exception as e:
        print(f"Error exporting to ONNX format: {e}")

### Loading the Model

In [None]:

def load_model(load_path, model=None):
    """
    Loads a model from a file
    
    Args:
        load_path: Path to load the model from
        model: Model to load weights into (optional)
        
    Returns:
        model: The loaded model
        checkpoint: The loaded checkpoint
    """
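    try:
        # A "_full.pth" file holds the pickled model object rather than a state dict.
        # Note: PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
        # models; pass weights_only=False only for files you trust.
        if model is None and load_path.endswith('_full.pth'):
            model = torch.load(load_path, map_location=device)
            print(f"Full model loaded from: {load_path}")
            return model, None
        
        checkpoint = torch.load(load_path, map_location=device)
        
        if model is None:
            # No model supplied: build a fresh ResNet50 to receive the weights
            model = setup_model()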
        
        # Load state dict if it exists
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        
        print(f"Model weights loaded from: {load_path}")
        return model, checkpoint
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None


## 3. Model Execution

In [None]:

def main():
    download_dir = "/content/drive/MyDrive/xAI-Corgis" # TODO: Change to local directory
    extract_dir = "/content/dogs" # TODO: Change to local directory 
    
    download_and_extract_dataset(download_dir, extract_dir)
    
    train_loader, val_loader, class_names = prepare_dataloaders(extract_dir, batch_size=32)
    
    model = setup_model(num_classes=len(class_names))
    
    os.makedirs(download_dir, exist_ok=True)
    model_path = os.path.join(download_dir, 'resnet50_corgi_classifier.pth')
    if os.path.exists(model_path):
        print("Loading pre-trained model...")
        # Reuse the model built above instead of constructing a second ResNet50
        model, checkpoint = load_model(model_path, model=model)
    else:
        print("Training a new model...")
        model, history = train_model(model, train_loader, val_loader, num_epochs=15)
        
        plot_training_history(history)
        
        y_true, y_pred, report = evaluate_model(model, val_loader, class_names)
        
        save_model(
            model, 
            model_path,
            class_names=class_names,
            history=history
        )

    print("\n" + "="*50)
    print("Applying XAI Methods for Model Interpretability")
    print("="*50)
    
    print("\nGenerating GradCAM visualizations...")
    visualize_gradcam(model, val_loader, class_names, num_images=5)
    
    print("\nGenerating Layer-wise Relevance Propagation visualizations...")
    visualize_lrp(model, val_loader, class_names, num_images=5)
    
    print("\nComparing GradCAM and LRP methods...")
    compare_xai_methods(model, val_loader, class_names, num_images=3)
    
    print("\nXAI visualization complete. All results saved as PNG files.")

if __name__ == "__main__":
    main()

## 4. xAI Methods 

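We apply two complementary attribution methods to the fine-tuned model. GradCAM produces a coarse heatmap from the gradients flowing into the last residual stage (`layer4`), highlighting the regions that most influenced the class prediction. LRP propagates the output score backwards through the network to assign relevance scores at a finer spatial resolution.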

### GradCAM
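
Before the PyTorch implementation, the core computation can be traced by hand on a tiny example. This is a dependency-free sketch with invented 2x2 activations and gradients for two channels; `grad_cam` is a hypothetical helper, not the class used in this notebook. It averages over channels as the class below does; the original Grad-CAM formulation sums instead, which gives the same map after normalization by the maximum.

```python
def grad_cam(activations, gradients):
    """activations, gradients: [channels][height][width] nested lists."""
    n_ch = len(activations)
    h, w = len(activations[0]), len(activations[0][0])
    # One weight per channel: the spatial mean of that channel's gradient.
    weights = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    # Weighted channel average followed by ReLU.
    cam = [[max(0.0, sum(weights[c] * activations[c][i][j] for c in range(n_ch)) / n_ch)
            for j in range(w)] for i in range(h)]
    # Scale to [0, 1] unless the map is all zeros.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


acts = [[[1.0, 0.0], [0.0, 2.0]],
        [[0.0, 3.0], [1.0, 0.0]]]
grads = [[[1.0, 1.0], [1.0, 1.0]],        # channel 0 supports the class (mean +1)
         [[-1.0, -1.0], [-1.0, -1.0]]]    # channel 1 opposes it (mean -1)

print(grad_cam(acts, grads))  # [[0.5, 0.0], [0.0, 1.0]]
```

Locations where only the opposing channel is active end up at zero after the ReLU, which is why Grad-CAM maps show evidence for the target class only.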

In [None]:

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.hooks = []
        self.gradients = None
        self.activations = None
        self.register_hooks()
        self.model.eval()
        
    def register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()
            
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
            
        forward_handle = self.target_layer.register_forward_hook(forward_hook)
        backward_handle = self.target_layer.register_full_backward_hook(backward_hook)
        
        self.hooks = [forward_handle, backward_handle]
        
    def remove_hooks(self):
        """Removes all registered hooks"""
        for hook in self.hooks:
            hook.remove()
            
    def __call__(self, input_tensor, target_class=None):
        input_tensor = input_tensor.to(device)
        
        self.model.zero_grad()
        
        output = self.model(input_tensor)
        
        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()
        
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        
        output.backward(gradient=one_hot, retain_graph=True)
        
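        # Channel weights: gradients averaged over batch and spatial dimensions
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        
        # Weight each activation channel, then average over channels
        weighted = self.activations * pooled_gradients.view(1, -1, 1, 1)
        cam = torch.mean(weighted, dim=1).squeeze()
        
        # ReLU keeps only features with a positive influence on the target class
        cam = torch.relu(cam)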
        
        if torch.max(cam) > 0:
            cam = cam / torch.max(cam)
        
        cam = cam.cpu().numpy()
        
        return cam

def apply_gradcam(model, img_tensor, img_np, target_class=None, layer_name='layer4'):
    # Look up the target layer by name (defaults to the last residual stage)
    target_layer = getattr(model, layer_name)
    
    grad_cam = GradCAM(model, target_layer)
    
    cam = grad_cam(img_tensor, target_class)
    
    cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))
    
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    
    alpha = 0.4
    # img_np is in [0, 1] while the heatmap is in [0, 255]; bring both to 0-255
    visualization = heatmap * alpha + (img_np * 255) * (1 - alpha)
    visualization = np.uint8(visualization)
    
    grad_cam.remove_hooks()
    
    return visualization, cam

def visualize_gradcam(model, dataloader, class_names, num_images=5):
    model.eval()
    
    images, labels = next(iter(dataloader))
    images = images[:num_images]
    labels = labels[:num_images]
    
    fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images))
    
    for i, (image, label) in enumerate(zip(images, labels)):
        img_np = image.cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * np.array([0.229, 0.224, 0.225]) + 
                        np.array([0.485, 0.456, 0.406]), 0, 1)
        
        input_tensor = image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = model(input_tensor)
            _, pred = torch.max(output, 1)
            prob = torch.nn.functional.softmax(output, dim=1)
        
        true_cam, _ = apply_gradcam(model, input_tensor, img_np, label.item())
        
        pred_cam, _ = apply_gradcam(model, input_tensor, img_np, pred.item())
        
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title(f"True: {class_names[label]}\nPred: {class_names[pred]} ({prob[0][pred.item()]:.2f})")
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(true_cam)
        axes[i, 1].set_title(f"GradCAM for {class_names[label]}")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(pred_cam)
        axes[i, 2].set_title(f"GradCAM for {class_names[pred]}")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.savefig('gradcam_visualizations.png')
    plt.show()
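
Two small pieces of arithmetic sit behind these visualizations: undoing the ImageNet normalization, and blending a heatmap with the image. The sketch below works through both for a single pixel using only the standard library. `denormalize` and `blend` are hypothetical helpers, and the sketch assumes the image is in the range 0 to 1 while the heatmap is in the range 0 to 255, so the image is rescaled before blending.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def denormalize(pixel):
    """Invert Normalize for one (R, G, B) pixel and clip to [0, 1]."""
    return tuple(min(1.0, max(0.0, v * s + m)) for v, s, m in zip(pixel, STD, MEAN))


def blend(heat_rgb, image_rgb01, alpha=0.4):
    """Blend a 0-255 heatmap pixel over a 0-1 image pixel; result is 0-255 ints."""
    return tuple(int(alpha * h + (1 - alpha) * 255 * p)
                 for h, p in zip(heat_rgb, image_rgb01))


# A mid-gray pixel, normalized the way transforms.Normalize would do it
normalized = tuple((v - m) / s for v, m, s in zip((0.5, 0.5, 0.5), MEAN, STD))
restored = denormalize(normalized)
overlay = blend((255, 0, 0), restored)   # pure red heatmap pixel
print(restored, overlay)
```

If the image were left in the range 0 to 1 during blending, its contribution would be at most 0.6 out of 255 and the overlay would show almost nothing but the heatmap.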

### Layer-wise Relevance Propagation (LRP) Implementation

In [None]:

class LRP:
    def __init__(self, model, epsilon=1e-9):
        self.model = model
        self.epsilon = epsilon
        self.model.eval()
        
    def _clone_module(self, module, memo=None):
        if memo is None:
            memo = {}
        if id(module) in memo:
            return memo[id(module)]
        
        clone = copy.deepcopy(module)
        memo[id(module)] = clone
        
        return clone
    
    def _register_hooks(self, module, activations, relevances):
        forward_hooks = []
        backward_hooks = []
        
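        def forward_hook(m, input, output):
            # The z+ rule (3x3 stride-1 convs, linear layers) needs the layer's input;
            # the gradient-times-activation fallback uses the layer's output
            uses_input = isinstance(m, nn.Linear) or (
                isinstance(m, nn.Conv2d) and m.stride == (1, 1) and m.padding == (1, 1))
            activations[id(m)] = (input[0] if uses_input else output).detach()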
            
        def backward_hook(m, grad_in, grad_out):
            if id(m) in activations:
                with torch.no_grad():
                    a = activations[id(m)]
                    if isinstance(m, nn.Conv2d):
                        if m.stride == (1, 1) and m.padding == (1, 1):
                            w = m.weight
                            w_pos = torch.clamp(w, min=0)
                            z = torch.nn.functional.conv2d(a, w_pos, bias=None, 
                                                          stride=m.stride, padding=m.padding)
                            s = (grad_out[0] / (z + self.epsilon)).data
                            c = torch.nn.functional.conv_transpose2d(s, w_pos, 
                                                                    stride=m.stride, padding=m.padding)
                            relevances[id(m)] = (a * c).data
                        else:
                            relevances[id(m)] = (a * grad_out[0]).data
                    elif isinstance(m, nn.Linear):
                        w = m.weight
                        w_pos = torch.clamp(w, min=0)
                        z = torch.matmul(a, w_pos.t())
                        s = (grad_out[0] / (z + self.epsilon)).data
                        c = torch.matmul(s, w_pos)
                        relevances[id(m)] = (a * c).data
                    else:
                        relevances[id(m)] = (a * grad_out[0]).data
        
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU)):
            forward_hooks.append(module.register_forward_hook(forward_hook))
            backward_hooks.append(module.register_full_backward_hook(backward_hook))
        
        for child in module.children():
            f_hooks, b_hooks = self._register_hooks(child, activations, relevances)
            forward_hooks.extend(f_hooks)
            backward_hooks.extend(b_hooks)
            
        return forward_hooks, backward_hooks
    
    def __call__(self, input_tensor, target_class=None):
        input_tensor = input_tensor.clone().detach().to(device)
        input_tensor.requires_grad = True
        
        activations = {}
        relevances = {}
        
        forward_hooks, backward_hooks = self._register_hooks(self.model, activations, relevances)
        
        try:
            output = self.model(input_tensor)
            
            if target_class is None:
                target_class = torch.argmax(output, dim=1).item()
            
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1.0
            
            self.model.zero_grad()
            # Only one backward pass is needed, so the graph does not have to be retained
            output.backward(gradient=one_hot)
            
            input_gradient = input_tensor.grad.detach()
            
            first_layer_id = None
            for module in self.model.modules():
                if isinstance(module, nn.Conv2d):
                    first_layer_id = id(module)
                    break
            
            if first_layer_id in relevances:
                relevance_map = relevances[first_layer_id]
            else:
                relevance_map = input_gradient
                
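            # Collapse the channel dimension into a single 2-D map
            relevance_map = relevance_map.sum(dim=1).squeeze()
            
            # Keep magnitudes only and scale to [0, 1] for display
            relevance_map = torch.abs(relevance_map)
            max_val = torch.max(relevance_map)
            if max_val > 0:
                relevance_map = relevance_map / max_val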
            
            return relevance_map.cpu().numpy()
            
        finally:
            for hook in forward_hooks + backward_hooks:
                hook.remove()

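def apply_lrp(model, img_tensor, img_np, target_class=None):
    """Overlay an LRP heatmap on img_np (H x W x 3, values in [0, 1]).

    Returns the uint8 overlay and the raw relevance map.
    """
    lrp = LRP(model)
    relevance_map = lrp(img_tensor, target_class)
    
    # Match the relevance map to the image size (cv2.resize expects (width, height))
    relevance_resized = cv2.resize(relevance_map, (img_np.shape[1], img_np.shape[0]))
    
    # applyColorMap returns BGR; convert to RGB for matplotlib
    heatmap = cv2.applyColorMap(np.uint8(255 * relevance_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    
    # Blend on a common 0-255 scale: img_np is in [0, 1], the heatmap is uint8
    alpha = 0.4
    visualization = heatmap * alpha + (img_np * 255.0) * (1 - alpha)
    visualization = np.uint8(np.clip(visualization, 0, 255))
    
    return visualization, relevance_map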

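def visualize_lrp(model, dataloader, class_names, num_images=5):
    """Plot each image next to its LRP overlays for the true and the predicted class."""
    model.eval()
    
    images, labels = next(iter(dataloader))
    num_images = min(num_images, len(images))  # the batch may hold fewer images than requested
    images = images[:num_images]
    labels = labels[:num_images]
    
    # squeeze=False keeps axes 2-D even when num_images == 1
    fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)
    
    for i, (image, label) in enumerate(zip(images, labels)):
        # Undo the ImageNet normalization so the image displays with natural colors
        img_np = image.cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * np.array([0.229, 0.224, 0.225]) + 
                        np.array([0.485, 0.456, 0.406]), 0, 1)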
        
        input_tensor = image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = model(input_tensor)
            _, pred = torch.max(output, 1)
            prob = torch.nn.functional.softmax(output, dim=1)
        
        true_lrp, _ = apply_lrp(model, input_tensor, img_np, label.item())
        
        pred_lrp, _ = apply_lrp(model, input_tensor, img_np, pred.item())
        
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title(f"True: {class_names[label]}\nPred: {class_names[pred]} ({prob[0][pred.item()]:.2f})")
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(true_lrp)
        axes[i, 1].set_title(f"LRP for {class_names[label]}")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(pred_lrp)
        axes[i, 2].set_title(f"LRP for {class_names[pred]}")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.savefig('lrp_visualizations.png')
    plt.show()
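
The `nn.Linear` branch of the backward hook above follows the pattern of the LRP z+ rule: clamp the weights to their positive part, divide the incoming signal by the positive pre-activations, and route the result back to the inputs. The sketch below reproduces that arithmetic in plain NumPy on a tiny hand-made layer so the conservation property (relevance in equals relevance out) can be checked by hand. The function name `lrp_zplus_linear` and the toy numbers are illustrative assumptions and are not part of the notebook.

```python
import numpy as np

def lrp_zplus_linear(a, w, relevance_out, eps=1e-9):
    """Redistribute relevance through one linear layer with the z+ rule.

    a:             (in_features,) non-negative input activations
    w:             (out_features, in_features) weights, PyTorch layout
    relevance_out: (out_features,) relevance arriving from the layer above
    """
    w_pos = np.clip(w, 0.0, None)   # keep only the positive weights
    z = w_pos @ a                   # positive pre-activations, shape (out,)
    s = relevance_out / (z + eps)   # relevance per unit of pre-activation
    c = w_pos.T @ s                 # route s back to the inputs, shape (in,)
    return a * c                    # relevance assigned to each input

# Toy layer with 3 inputs and 2 outputs (hypothetical values)
a = np.array([1.0, 2.0, 0.5])
w = np.array([[0.5, -1.0, 2.0],
              [1.0, 0.3, -0.7]])
relevance_out = np.array([0.6, 0.4])

relevance_in = lrp_zplus_linear(a, w, relevance_out)
print(relevance_in)                            # approximately [0.45 0.15 0.40]
print(relevance_in.sum(), relevance_out.sum())  # both approximately 1.0
```

In the notebook's hook, `grad_out[0]` takes the place of `relevance_out`, so the values it produces are not guaranteed to satisfy this conservation check in the same way.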

### Methods Comparison: GradCAM vs LRP

In [None]:

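def compare_xai_methods(model, dataloader, class_names, num_images=3):
    """Plot each image next to its GradCAM and LRP overlays for the predicted class."""
    model.eval()
    
    images, labels = next(iter(dataloader))
    num_images = min(num_images, len(images))  # the batch may hold fewer images than requested
    images = images[:num_images]
    labels = labels[:num_images]
    
    # squeeze=False keeps axes 2-D even when num_images == 1
    fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)
    
    for i, (image, label) in enumerate(zip(images, labels)):
        # Undo the ImageNet normalization so the image displays with natural colors
        img_np = image.cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * np.array([0.229, 0.224, 0.225]) + 
                        np.array([0.485, 0.456, 0.406]), 0, 1)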
        
        input_tensor = image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = model(input_tensor)
            _, pred = torch.max(output, 1)
            prob = torch.nn.functional.softmax(output, dim=1)
        
        gradcam_vis, _ = apply_gradcam(model, input_tensor, img_np, pred.item())
        
        lrp_vis, _ = apply_lrp(model, input_tensor, img_np, pred.item())
        
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title(f"Original\nTrue: {class_names[label]}\nPred: {class_names[pred]} ({prob[0][pred.item()]:.2f})")
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(gradcam_vis)
        axes[i, 1].set_title("GradCAM")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(lrp_vis)
        axes[i, 2].set_title("Layer-wise Relevance Propagation")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.savefig('xai_comparison.png')
    plt.show()
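
The side-by-side plots give a visual impression of how far GradCAM and LRP agree. One possible way to put a number on that agreement is sketched below: the Pearson correlation between two heatmaps plus the intersection-over-union of their most relevant pixels. The helper `heatmap_agreement` and its `top_fraction` parameter are hypothetical names introduced for illustration, and the example runs on synthetic maps rather than model output.

```python
import numpy as np

def heatmap_agreement(map_a, map_b, top_fraction=0.2):
    """Return (pearson_r, top_k_iou) for two heatmaps of equal shape."""
    a = np.asarray(map_a, dtype=np.float64).ravel()
    b = np.asarray(map_b, dtype=np.float64).ravel()
    if a.shape != b.shape:
        raise ValueError("heatmaps must have the same shape")

    # Pearson correlation is undefined when either map is constant
    if a.std() == 0 or b.std() == 0:
        pearson_r = float("nan")
    else:
        pearson_r = float(np.corrcoef(a, b)[0, 1])

    # IoU between the sets of the k highest-valued pixels in each map
    k = max(1, int(round(top_fraction * a.size)))
    top_a = set(np.argsort(a)[-k:].tolist())
    top_b = set(np.argsort(b)[-k:].tolist())
    iou = len(top_a & top_b) / len(top_a | top_b)
    return pearson_r, iou

# Synthetic sanity check: a map compared with itself and with its inverse
rng = np.random.default_rng(0)
base = rng.random((8, 8))
r_same, iou_same = heatmap_agreement(base, base)
r_flip, iou_flip = heatmap_agreement(base, 1.0 - base)
print(r_same, iou_same)   # identical maps: correlation 1, full overlap
print(r_flip, iou_flip)   # inverted map: correlation -1, no overlap
```

To use this on real data, the raw maps would need the same shape. `apply_lrp` returns its relevance map as the second value; whether `apply_gradcam` returns a comparable raw heatmap in its second slot is an assumption to verify, since both are discarded as `_` in `compare_xai_methods`.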