In [None]:
import warnings
from typing import Tuple, Dict
import numpy as np
import torch
import torchvision
from captum.attr import Lime, GradientShap, InputXGradient, IntegratedGradients
import matplotlib.pyplot as plt
import seaborn as sns
import quantus
from quantus import FaithfulnessCorrelation, RelativeInputStability, Sparseness
sns.set()

# Suppress warnings
warnings.filterwarnings('ignore')

# Fix the random seeds so a Restart-and-Run-All repeats the reported numbers.
# torch drives weight init, DataLoader shuffling and some attribution sampling;
# NumPy's global RNG may be used by the attribution/evaluation libraries for
# sampling and perturbations. (Bit-exact GPU runs would additionally need
# deterministic cuDNN settings.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Display Quantus' list of available Captum methods
quantus.AVAILABLE_XAI_METHODS_CAPTUM


In [None]:
# Dataset preparation
BATCH_SIZE = 256
# The first test batch is also the batch that is explained and evaluated below.
TEST_BATCH_SIZE = 200
NUM_WORKERS = 0 # Increase for faster data loading
# Page-locked host memory only speeds up host-to-GPU copies; it has no benefit
# on a CPU-only machine (recent PyTorch versions warn about it there).
PIN_MEMORY = device == "cuda"
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
test_set = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=TEST_BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
# Display a batch of images
def display_images(images: torch.Tensor, labels: torch.Tensor, num_images: int = 7):
    """Show the first `num_images` images of a batch side by side, titled by class.

    Args:
        images: batch of images; each `images[i]` is squeezed before display,
            so this assumes single-channel images (e.g. FashionMNIST) — TODO
            confirm for other datasets.
        labels: class labels aligned with `images` (one scalar per image).
        num_images: how many images to show; clipped to the batch size, and
            nothing is drawn when it is <= 0.
    """
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return
    # squeeze=False keeps `axes` 2-D even when only one image is shown;
    # otherwise plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 3, 3), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        ax.imshow(image.cpu().squeeze(), cmap="gray")
        ax.set_title(f"Class: {label.item()}")
        ax.axis("off")
    plt.show()

sample_images, sample_labels = next(iter(train_loader))
display_images(sample_images, sample_labels)



In [None]:
import torch
import torch.nn as nn

class LeNet(nn.Module):
    """Small LeNet-style CNN mapping 1-channel 28x28 images to 10 class logits.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages (28 -> 14 -> 7
    spatially), followed by a three-layer fully connected head with dropout
    (p=0.4) after each of the two hidden layers.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the original order, so parameter
        # initialisation (which draws from the global RNG) and state_dict
        # keys are unchanged.

        # Stage 1: 1 -> 16 channels; 28x28 is kept by padding, then pooled to 14x14
        self.conv_1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn_1 = nn.BatchNorm2d(16)
        self.relu_1 = nn.ReLU()
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 16 -> 32 channels; 14x14 is kept by padding, then pooled to 7x7
        self.conv_2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn_2 = nn.BatchNorm2d(32)
        self.relu_2 = nn.ReLU()
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head on the flattened 32x7x7 feature map
        flat_features = 32 * 7 * 7
        self.fc_1 = nn.Linear(flat_features, 128)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.4)
        self.fc_2 = nn.Linear(128, 64)
        self.relu_4 = nn.ReLU()
        self.dropout_2 = nn.Dropout(0.4)
        self.fc_3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalised class logits of shape (N, 10) for an (N, 1, 28, 28) batch."""
        features = self.pool_1(self.relu_1(self.bn_1(self.conv_1(x))))
        features = self.pool_2(self.relu_2(self.bn_2(self.conv_2(features))))
        hidden = features.view(features.size(0), -1)
        # nn.Dropout only zeroes activations and rescales the rest by a positive
        # factor, so applying it before the ReLU gives the same result as after.
        hidden = self.relu_3(self.dropout_1(self.fc_1(hidden)))
        hidden = self.relu_4(self.dropout_2(self.fc_2(hidden)))
        return self.fc_3(hidden)


In [None]:
# Model training
def train_model(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    epochs: int = 10,
    lr: float = 0.005,
    weight_decay: float = 1e-4
) -> torch.nn.Module:
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    model.to(device)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Evaluation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

    return model



In [None]:
# Model evaluation
def evaluate_model(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run `model` in eval mode over `data_loader` without gradients.

    Returns:
        A pair `(logits, targets)`: the model outputs and labels of every batch,
        concatenated along the first dimension and left on `device`.
    """
    model.eval()
    batch_logits, batch_targets = [], []
    with torch.no_grad():
        for inputs, targets in data_loader:
            batch_logits.append(model(inputs.to(device)))
            batch_targets.append(targets.to(device))
    return torch.cat(batch_logits), torch.cat(batch_targets)



In [None]:
# Training and testing: train_model moves the fresh LeNet to `device`, trains it
# and returns the same object.
model = LeNet()
trained_model = train_model(model, train_loader, test_loader, device)

# Evaluation on the full test set
predictions, labels = evaluate_model(trained_model, test_loader, device)
test_accuracy = predictions.argmax(dim=1).eq(labels).float().mean().item()
print(f"Final Test Accuracy: {test_accuracy * 100:.2f}%")



In [None]:
# Explanation methods: one Captum explainer per technique, all wrapping the
# trained model. Insertion order fixes the column order used in later cells.
explanations = {
    name: explainer_cls(trained_model)
    for name, explainer_cls in (
        ("Lime", Lime),
        ("GradientShap", GradientShap),
        ("InputXGradient", InputXGradient),
        ("IntegratedGradients", IntegratedGradients),
    )
}

# Explain the first test batch (test_loader does not shuffle, so this is stable)
x_batch, y_batch = (tensor.to(device) for tensor in next(iter(test_loader)))



In [None]:
# Normalized attributions: run each explainer on the batch, sum over the channel
# axis and rescale with Quantus' normalise_by_negative.
attributions = {}
for name, method in explanations.items():
    attribute_kwargs = {"inputs": x_batch, "target": y_batch}
    # Baseline-based methods get an all-zero (black) reference image.
    if name in ("GradientShap", "IntegratedGradients"):
        attribute_kwargs["baselines"] = torch.zeros_like(x_batch)
    raw_attribution = method.attribute(**attribute_kwargs)
    attributions[name] = quantus.normalise_by_negative(
        raw_attribution.sum(dim=1).cpu().detach().numpy()
    )

# Visualization: report the shape of each attribution map
for name, attr in attributions.items():
    print(f"{name} Attributions: {attr.shape}")


In [None]:
# Quantus-ready inputs. The previous cell already computed and normalised every
# attribution with the same model object (train_model returns `model` itself),
# so reuse them instead of re-running all explainers. This skips a second slow
# Lime pass and keeps the evaluated attributions identical to the plotted ones.
a_batch_lime = attributions["Lime"]
a_batch_gradient = attributions["GradientShap"]
a_batch_inputXgradient = attributions["InputXGradient"]
a_batch_intgrad = attributions["IntegratedGradients"]

# Quantus expects NumPy arrays. Tensors must be copied to host memory first,
# because `.numpy()` raises for CUDA tensors (the old `.to(device).numpy()`
# only worked on CPU). The isinstance guards make the cell safe to re-run.
if isinstance(x_batch, torch.Tensor):
    x_batch = x_batch.detach().cpu().numpy()
if isinstance(y_batch, torch.Tensor):
    y_batch = y_batch.detach().cpu().numpy()

# A non-empty list is always truthy, so `assert [...]` could never fail;
# `all(...)` actually checks every object.
assert all(
    isinstance(obj, np.ndarray)
    for obj in [x_batch, y_batch, a_batch_lime, a_batch_gradient, a_batch_inputXgradient, a_batch_intgrad]
)

In [None]:
# List of attribution methods for dynamic processing
attribution_methods = list(attributions.keys())

# Plot attributions: one row per image, the input first and then one column per method.
nr_images = 4
# squeeze=False keeps `axes` 2-D, so `axes[i, j]` stays valid even for nr_images == 1.
fig, axes = plt.subplots(
    nrows=nr_images, ncols=len(attribution_methods) + 1, figsize=(15, 12), squeeze=False
)

for i in range(nr_images):
    # Input image
    axes[i, 0].imshow((np.reshape(x_batch[i], (28, 28)) * 255).astype(np.uint8), cmap="gray")
    axes[i, 0].title.set_text(f"FMNIST class {y_batch[i].item()}")
    axes[i, 0].axis("off")

    # Loop through attribution methods
    for j, method in enumerate(attribution_methods, start=1):
        # normalise_by_negative scales attributions into [-1, 1]. A symmetric
        # colour range keeps zero attribution at the white centre of the
        # diverging "seismic" colormap (the old -0.95..1 range shifted it).
        axes[i, j].imshow(attributions[method][i], cmap="seismic", vmin=-1, vmax=1)
        axes[i, j].title.set_text(method)
        axes[i, j].axis("off")

plt.tight_layout()
plt.show()


# Quantitative Analysis 

### Faithfulness Correlation
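Measures how well the explanation agrees with the model's behaviour. In each run, a random subset of features is replaced by a baseline value. The metric is the correlation between the attributions summed over that subset and the resulting change in the model's output. A faithful explanation assigns large attributions to the features whose removal changes the output most.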
https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.faithfulness.faithfulness_correlation.html 


In [None]:
# Initialize the FaithfulnessCorrelation metric
faithfulness_metric = quantus.FaithfulnessCorrelation(
    nr_runs=50, # Number of runs to compute the faithfulness score (default 200 was really slow)
    subset_size=224,
    perturb_baseline="black",
    perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
    similarity_func=quantus.similarity_func.correlation_pearson,
    disable_warnings=True,
    normalise=True,
    abs=True
)

# Dictionary to store the per-sample faithfulness scores for each method
faithfulness_scores = {}

# Compute faithfulness scores for each attribution method
for method_name, method_attributions in attributions.items():
    # Quantus documents `a_batch` as a NumPy array, like x_batch/y_batch. A
    # torch tensor on `device` cannot be implicitly converted to NumPy when it
    # lives on the GPU, so pass the NumPy attributions directly.
    faithfulness_scores[method_name] = faithfulness_metric(
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        a_batch=np.asarray(method_attributions, dtype=np.float32),
        device=device,
        explain_func=quantus.explain,
        explain_func_kwargs={"method": method_name}
    )

# Summarise the per-sample scores by their mean, as is done for the stability
# and sparsity metrics.
average_faithfulness_results = {
    method: np.nanmean(scores) for method, scores in faithfulness_scores.items()
}

# Display the mean faithfulness score per method
average_faithfulness_results


In this evaluation, higher values suggest a stronger correspondence between the explanation and the actual model behavior.

The results obtained are as follows:

- Lime: 0.0071
- GradientShap: 0.1066
- InputXGradient: 0.1190
- IntegratedGradients: 0.1195

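These findings suggest that IntegratedGradients and InputXGradient provide the explanations most closely aligned with the model’s behaviour, with GradientShap only slightly behind. Lime is the clear outlier: its correlation is close to zero, so its explanations capture little of the model’s underlying decision-making process. All scores are fairly low in absolute terms, so the small differences among the three gradient-based methods should not be over-interpreted.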

### Relative Input Stability
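Evaluates the robustness of explanations under slight input perturbations by comparing the relative change in the explanation with the relative change in the input that caused it. The lower, the better.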
https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.robustness.html  


In [None]:
# Initialize the RelativeInputStability metric
relative_input_stability_metric = quantus.RelativeInputStability(
    nr_samples=5, # Number of samples to compute the stability score (default 200 was really slow)
    perturb_func=quantus.perturb_func.uniform_noise,
    normalise=True,
    disable_warnings=True,
    abs=True
)

# Keep only the samples of the explained batch that the model classifies correctly.
# Predict on the batch itself instead of re-evaluating the whole test set and
# trimming the indices to len(x_batch). The old approach relied on x_batch being
# the loader's first batch, and it produced a device tensor of indices that
# cannot index the NumPy arrays when running on CUDA.
model.eval()
with torch.no_grad():
    batch_logits = model(torch.as_tensor(x_batch, device=device))
batch_predictions = batch_logits.argmax(dim=1).cpu().numpy()
correct_indices = np.flatnonzero(batch_predictions == y_batch)
print(f"Number of correctly classified: {len(correct_indices)} / {len(y_batch)}")
assert len(correct_indices) > 0, "no correctly classified samples in the explained batch"

# Extract correctly classified samples (NumPy fancy indexing)
correct_x_batch = x_batch[correct_indices]
correct_y_batch = y_batch[correct_indices]

# Dictionary to store stability scores
relative_input_stability_scores = {}

# Compute stability scores for each attribution method
for method_name, method_attributions in attributions.items():
    relative_input_stability_scores[method_name] = relative_input_stability_metric(
        model=model,
        x_batch=correct_x_batch,
        y_batch=correct_y_batch,
        a_batch=method_attributions[correct_indices],
        device=device,
        explain_func=quantus.explain,
        explain_func_kwargs={"method": method_name}
    )

# Calculate average scores
average_results = {
    method: np.nanmean(scores) for method, scores in relative_input_stability_scores.items()
}

# Display the results
average_results


The Relative Input Stability metric reflects how resistant each attribution method is to slight perturbations in the input data. In this case, lower values indicate more robust explanations.

The average scores obtained are as follows:
- Lime: 14.0141
- GradientShap: 1.8824
- InputXGradient: 1.9936
- IntegratedGradients: 1.8695

These results suggest that Lime’s explanations are significantly less stable compared to the other methods. In contrast, IntegratedGradients, GradientShap, and InputXGradient produce explanations that remain more consistent and robust when inputs are slightly altered.

### Sparsity
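Assesses the simplicity of explanations by checking how many features contribute significantly. Quantus' `Sparseness` metric computes the Gini index of the absolute attribution values, so scores close to 1 mean the attribution mass is concentrated in a few features.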
https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.complexity.sparseness.html 


In [None]:
# Initialize the Sparsity metric
sparsity_metric = quantus.Sparseness(normalise=True, disable_warnings=True)

# Keep only the samples of the explained batch that the model classifies correctly.
# Predict on the batch itself instead of re-evaluating the whole test set and
# trimming the indices to len(x_batch). The old approach relied on x_batch being
# the loader's first batch, and it produced a device tensor of indices that
# cannot index the NumPy arrays when running on CUDA.
model.eval()
with torch.no_grad():
    batch_logits = model(torch.as_tensor(x_batch, device=device))
batch_predictions = batch_logits.argmax(dim=1).cpu().numpy()
correct_indices = np.flatnonzero(batch_predictions == y_batch)
print(f"Number of correctly classified: {len(correct_indices)} / {len(y_batch)}")
assert len(correct_indices) > 0, "no correctly classified samples in the explained batch"

# Extract correctly classified samples (NumPy fancy indexing)
correct_x_batch = x_batch[correct_indices]
correct_y_batch = y_batch[correct_indices]

# Dictionary to store sparsity scores
sparsity_scores = {}

# Compute sparsity scores for each attribution method
for method_name, method_attributions in attributions.items():
    sparsity_scores[method_name] = sparsity_metric(
        model=model,
        x_batch=correct_x_batch,
        y_batch=correct_y_batch,
        a_batch=method_attributions[correct_indices],
        device=device,
        explain_func=quantus.explain,
        explain_func_kwargs={"method": method_name}
    )

# Calculate average sparsity results
average_sparsity_results = {
    method: np.nanmean(scores) for method, scores in sparsity_scores.items()
}

# Display the results
average_sparsity_results


Sparsity indicates how simple or concentrated the explanations are by measuring how many features have a significant impact. Higher values suggest that only a few features dominate the explanation.

The average sparsity scores are as follows:

- Lime: 0.9835
- GradientShap: 0.7650
- InputXGradient: 0.7574
- IntegratedGradients: 0.7642

These results suggest that Lime’s explanations are more feature-sparse (relying on fewer dominant features), while GradientShap, IntegratedGradients, and InputXGradient yield explanations that are comparatively more distributed across multiple features.

#### Is any method consistently the best across all metrics, or do trade-offs exist?
No single attribution method clearly outperforms the others across all three metrics. Instead, there are notable trade-offs:

- Faithfulness Correlation: IntegratedGradients and InputXGradient yield the highest scores, indicating their explanations align most closely with the model’s behaviour. GradientShap trails them only slightly, while Lime scores far lower.

- Relative Input Stability: IntegratedGradients, GradientShap, and InputXGradient produce more stable explanations. Lime, despite its high sparsity, is notably less stable.

- Sparsity: Lime’s explanations are the most feature-sparse, focusing on fewer dominant features. However, this simplicity comes at the cost of reduced faithfulness and stability.

In conclusion, IntegratedGradients and InputXGradient strike a strong balance between faithfulness and stability, while Lime provides highly sparse explanations but lags in other areas. This indicates that the “best” method depends on the particular requirements and priorities for explanation quality.

#### Does the quantitative evaluation align with the qualitative observations?
The quantitative results are largely consistent with the qualitative observations. Lime explanations, which the metrics indicate are highly sparse but less faithful and stable, visually appear to highlight only a handful of pixels, often failing to capture the overall shape of the object. On the other hand, methods like IntegratedGradients and InputXGradient, which scored higher in both faithfulness and stability, produce attributions that more closely follow the contours and features of the object. GradientShap’s results are less sparse than Lime’s, but their more evenly distributed attributions align better with the underlying structure of the object.

In essence, the metrics’ outcomes and the visual inspection both point towards the conclusion that methods yielding more consistent, broadly distributed attributions (e.g., IntegratedGradients and InputXGradient) produce explanations that qualitatively appear more coherent and faithful to the model’s reasoning, while Lime’s sparse yet unstable explanations are consistent with its poorer performance on the quantitative evaluations.