# In [None]:
import numpy as np
from sklearn.linear_model import Ridge, Lasso, lars_path
from skimage.segmentation import quickshift
from sklearn.metrics import pairwise_distances
from sklearn.utils import check_random_state
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
import os
from scipy import stats
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns


class LimeImageExplainer:
    """Minimal LIME for images: superpixels are binary features, explained by a sparse weighted linear model.

    Args:
        kernel_width: Width of the exponential kernel applied to the cosine distance
            between each perturbation mask and the all-on (unperturbed) mask.
    """

    def __init__(self, kernel_width=0.25):
        self.kernel_width = kernel_width

    def _generate_lars_path(self, weighted_data, weighted_labels, max_features=10):
        """Walk the LARS/lasso path and score every active feature set with a ridge fit.

        Args:
            weighted_data: (num_samples, num_features) design matrix, already sample-weighted.
            weighted_labels: (num_samples,) targets, already sample-weighted.
            max_features: Stop walking the path once an active set has at least this
                many features (10 reproduces the original hard-coded cut-off).

        Returns:
            (feature_sets, scores): parallel lists. Each ``feature_sets`` entry is an index
            array of active features; ``scores`` holds the ridge R^2 on those features
            (``-inf`` for an empty set).
        """
        _, _, coef_path = lars_path(weighted_data, weighted_labels, method='lasso')
        feature_sets = []
        scores = []

        # coef_path is (num_features, num_alphas); each column is one point on the path.
        for coef in coef_path.T:
            nonzero_features = np.nonzero(coef)[0]
            feature_sets.append(nonzero_features)

            if len(nonzero_features) > 0:
                subset = weighted_data[:, nonzero_features]
                ridge = Ridge(alpha=0.01)
                ridge.fit(subset, weighted_labels)
                scores.append(ridge.score(subset, weighted_labels))
            else:
                scores.append(float('-inf'))

            if len(nonzero_features) >= max_features:
                break

        return feature_sets, scores

    def _fit_lasso(self, data, labels, weights, num_features):
        """Fit a weighted sparse linear surrogate and return one coefficient per feature.

        The best-scoring LARS feature set with at most ``num_features`` features is refit
        with ridge regression; all other coefficients are zero.

        Args:
            data: (num_samples, num_features) binary perturbation matrix.
            labels: (num_samples,) classifier scores for the explained class.
            weights: (num_samples,) kernel weights of the samples.
            num_features: Maximum number of non-zero coefficients.

        Returns:
            (num_features_total,) float array of coefficients. All zeros when no
            non-empty feature set is available, e.g. when the labels are constant.
        """
        # Weighted least squares is ordinary least squares on sqrt(w)-scaled rows.
        sqrt_weights = np.sqrt(weights)
        weighted_data = data * sqrt_weights[:, np.newaxis]
        weighted_labels = labels * sqrt_weights

        # Walk at least as far as the original (10 features) so results for
        # num_features <= 10 are unchanged, and far enough for larger requests.
        feature_sets, scores = self._generate_lars_path(
            weighted_data, weighted_labels, max_features=max(num_features, 10)
        )

        best_score = float('-inf')
        best_features = None
        for features, score in zip(feature_sets, scores):
            if len(features) <= num_features and score > best_score:
                best_score = score
                best_features = features

        coef = np.zeros(data.shape[1])
        if best_features is None:
            # Only empty feature sets (score -inf) were found; indexing with None would
            # add an axis and crash Ridge, so report "no attribution" instead.
            return coef

        ridge = Ridge(alpha=0.01, fit_intercept=True)
        ridge.fit(weighted_data[:, best_features], weighted_labels)
        coef[best_features] = ridge.coef_
        return coef

    def explain_instance(self, image, classifier_fn, labels=(1,), num_samples=1000, num_features=10,
                         batch_size=10, random_state=None):
        """Explain ``classifier_fn``'s scores for ``image`` in terms of quickshift superpixels.

        Args:
            image: (H, W, C) image array; also passed (perturbed) to ``classifier_fn``.
            classifier_fn: Callable mapping a (batch, H, W, C) array to per-class scores,
                as a torch tensor (``.numpy()`` is used) or any array-like.
            labels: Class indices to explain.
            num_samples: Number of perturbed samples; the first is always the unperturbed image.
            num_features: Maximum number of superpixels with non-zero weight per label.
            batch_size: Number of perturbed images per ``classifier_fn`` call.
            random_state: Seed / RandomState for the perturbations (None = unseeded).

        Returns:
            Dict mapping each label to ``{'segments', 'feature_weights', 'predicted_value'}``.
            ``feature_weights[k]`` is the weight of segment id ``k``, and ``predicted_value``
            is the classifier's score on the unperturbed image.
        """
        segments = quickshift(image, kernel_size=4, max_dist=200, ratio=0.2)
        # NOTE(review): assumes quickshift labels are exactly 0..num_segments-1.
        num_segments = np.unique(segments).shape[0]

        rng = check_random_state(random_state)
        perturbations = rng.randint(0, 2, num_samples * num_segments).reshape((num_samples, num_segments))
        # Row 0 must be the original image: distances are measured against it and
        # 'predicted_value' reports its prediction.
        perturbations[0, :] = 1

        # Every hidden superpixel is replaced by its own mean colour. That image is the
        # same for all samples, so build it once.
        fudged_image = image.copy()
        for segment_id in range(num_segments):
            segment_mask = segments == segment_id
            fudged_image[segment_mask] = np.mean(image[segment_mask], axis=0)

        perturbed_images = []
        for pert in perturbations:
            perturbed = image.copy()
            hidden = np.isin(segments, np.flatnonzero(pert == 0))
            perturbed[hidden] = fudged_image[hidden]
            perturbed_images.append(perturbed)

        predictions = []
        for start in range(0, len(perturbed_images), batch_size):
            batch = np.array(perturbed_images[start:start + batch_size])
            preds = classifier_fn(batch)
            preds = preds.numpy() if hasattr(preds, 'numpy') else np.asarray(preds)
            predictions.extend(preds)
        predictions = np.array(predictions)

        distances = pairwise_distances(
            perturbations,
            perturbations[0].reshape(1, -1),
            metric='cosine'
        ).ravel()

        weights = np.sqrt(np.exp(-(distances ** 2) / self.kernel_width ** 2))

        explanations = {}
        for label in labels:
            label_score = predictions[:, label]
            feature_weights = self._fit_lasso(perturbations, label_score, weights, num_features)

            explanations[label] = {
                'segments': segments,
                'feature_weights': feature_weights,
                'predicted_value': predictions[0, label]
            }

        return explanations


# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained ResNet-18 on `device`, switched to inference mode.
model = models.resnet18(pretrained=True).to(device)
model.eval()

# ImageNet preprocessing: resize, centre-crop to 224x224, tensor conversion, mean/std normalisation.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Class index mapping: JSON keys are "0".."999", values are [synset_id, label].
with open("imagenet_class_index.json") as f:
    class_idx = json.load(f)
_ordered_entries = [class_idx[str(position)] for position in range(len(class_idx))]
idx2label = [entry[1] for entry in _ordered_entries]
idx2synset = [entry[0] for entry in _ordered_entries]
id2label = {entry[0]: entry[1] for entry in class_idx.values()}


def process_single_image(image_path):
    """Load one image both as a 224x224 RGB array (for LIME) and as a model-ready batch.

    Args:
        image_path: Path of the image file to read.

    Returns:
        (np_image, input_batch): ``np_image`` is the image resized to 224x224 as an RGB
        array. ``input_batch`` is ``preprocess`` applied to that same resized image,
        with a leading batch dimension, moved to ``device``.
    """
    # The context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as raw_image:
        input_image = raw_image.convert('RGB')
    input_image = input_image.resize((224, 224), Image.BILINEAR)

    np_image = np.array(input_image)

    input_batch = preprocess(input_image).unsqueeze(0).to(device)

    return np_image, input_batch


def classifier_fn(perturbed_images):
    """Return softmax class probabilities (on CPU) for a batch of HxWx3 images.

    Each image is converted to uint8, wrapped as a PIL image, and passed through
    ``preprocess`` before the batch is run through ``model`` on ``device``.

    Args:
        perturbed_images: Iterable of HxWx3 images (numpy arrays or torch tensors).
            Float images whose maximum is <= 1.0 are treated as [0, 1] and rescaled to
            [0, 255]. Other non-uint8 images are clipped to [0, 255].

    Returns:
        CPU tensor of shape (num_images, num_classes); row i belongs to input i.

    Raises:
        ValueError: if an image cannot be converted, or if the batch is empty. Images
            are never skipped, because callers (LIME) pair output rows with inputs by index.
    """
    processed_images = []

    for position, img in enumerate(perturbed_images):
        try:
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            img = np.asarray(img)

            if img.dtype != np.uint8:
                # Only float images can be in [0, 1]. An integer image with max <= 1 is
                # just very dark and must not be multiplied by 255.
                if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
                    img = img * 255
                # Clip so out-of-range values saturate instead of wrapping in uint8.
                img = np.clip(img, 0, 255).astype(np.uint8)

            processed_images.append(preprocess(Image.fromarray(img)))
        except Exception as e:
            raise ValueError(f"Error processing image {position} in classifier_fn: {e}") from e

    if not processed_images:
        raise ValueError("classifier_fn received no images")

    batch = torch.stack(processed_images).to(device)

    with torch.no_grad():
        outputs = model(batch)
        probs = torch.nn.functional.softmax(outputs, dim=1)

    return probs.cpu()


def visualize_explanation(image, explanation, label):
    """Grey out every superpixel whose explanation weight is exactly zero.

    Args:
        image: Array whose leading dimensions match ``explanation['segments']``.
        explanation: Dict with ``'segments'`` (integer label map) and
            ``'feature_weights'`` (one weight per segment id, indexed by id).
        label: Unused; kept for call-site compatibility.

    Returns:
        Array shaped like ``image``. Pixels of segments with a non-zero weight are
        copied from ``image``; every other pixel is 128.
    """
    segments = explanation['segments']
    selected_ids = np.flatnonzero(explanation['feature_weights'])
    keep = np.isin(segments, selected_ids)

    result = np.full_like(image, 128)
    result[keep] = image[keep]
    return result


if __name__ == "__main__":
    import traceback

    # Explain every image in the sample folder with LIME and save a side-by-side PNG.
    imagenet_path = '/content/imagenet_samples'
    # Sorted so images are processed (and reported) in a deterministic order.
    image_paths = sorted(
        os.path.join(imagenet_path, f) for f in os.listdir(imagenet_path)
        if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif'))
    )

    explainer = LimeImageExplainer()

    for img_path in image_paths:
        try:
            np_image, input_batch = process_single_image(img_path)

            with torch.no_grad():
                output = model(input_batch)

            _, predicted_idx = torch.max(output, 1)
            predicted_idx = predicted_idx.item()
            predicted_label = idx2label[predicted_idx]

            print(f"Processed {img_path}: Predicted as {predicted_label}")

            explanation = explainer.explain_instance(
                np_image,
                classifier_fn,
                labels=[predicted_idx],
                num_samples=50,
                num_features=10,
                batch_size=10
            )

            mask_viz = visualize_explanation(
                np_image,
                explanation[predicted_idx],
                predicted_idx
            )

            fig = plt.figure(figsize=(12, 6))
            try:
                plt.subplot(1, 2, 1)
                plt.imshow(np_image)
                plt.title(f'Original: {predicted_label}')
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.imshow(mask_viz)
                plt.title('LIME Explanation')
                plt.axis('off')

                plt.tight_layout()
                # savefig infers the format from the extension and cannot write .bmp/.gif,
                # so always save as PNG.
                stem = os.path.splitext(os.path.basename(img_path))[0]
                plt.savefig(f'LIME_{stem}.png', bbox_inches='tight', dpi=150)
            finally:
                # Close even when saving fails, so figures do not accumulate.
                plt.close(fig)
        except Exception as e:
            print(f"Error processing {img_path}: {str(e)}")
            traceback.print_exc()
            continue

# In [None]:
# Pretrained ResNet-18 in inference mode (Module.eval() returns the module itself).
model = models.resnet18(pretrained=True).eval()

# ImageNet preprocessing: resize, centre-crop to 224x224, tensor conversion, mean/std normalisation.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Class index mapping: keys "0".."N-1" map to [synset_id, label] pairs.
with open("imagenet_class_index.json") as f:
    class_idx = json.load(f)
idx2synset, idx2label = (
    [class_idx[str(k)][column] for k in range(len(class_idx))] for column in (0, 1)
)
id2label = {pair[0]: pair[1] for pair in class_idx.values()}


class SmoothGrad:
    """SmoothGrad saliency: average input gradients of a class score over noisy input copies.

    Args:
        model: Classifier mapping a (batch, C, H, W) tensor to (batch, num_classes) scores.
            It is moved to ``device`` with ``Module.to`` (which modifies it in place).
        device: Device used for the model and the inputs.
    """

    def __init__(self, model: torch.nn.Module,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device

    def generate_noise(self, image: torch.Tensor, noise_level: float):
        """Return Gaussian noise shaped like ``image`` on ``self.device`` with std ``noise_level``."""
        return torch.normal(
            mean=0,
            std=noise_level,
            size=image.shape,
            device=self.device
        )

    def get_gradients(self, image: torch.Tensor, target_class: int):
        """Return d(sum over the batch of the target-class score) / d(image).

        ``torch.autograd.grad`` is used instead of ``backward()`` so that neither the
        model parameters' ``.grad`` nor ``image.grad`` are written or accumulated.
        """
        image.requires_grad_()
        output = self.model(image)
        # Summing makes the score a scalar for any batch size; gradients stay per-sample.
        score = output[:, target_class].sum()
        (grad,) = torch.autograd.grad(score, image)
        return grad

    def __call__(self, image: torch.Tensor, n_samples: int = 50, noise_level: float = 0.1,
                 target_class=None):
        """Compute a SmoothGrad map for ``image``.

        Args:
            image: (C, H, W) or (1, C, H, W) input tensor.
            n_samples: Number of noisy copies averaged (must be >= 1).
            noise_level: Noise std as a fraction of the image's value range (max - min).
            target_class: Class whose score is differentiated; defaults to the model's
                prediction on the clean image.

        Returns:
            (smoothed_grads, target_class): gradients shaped like the batched image,
            and the class they were computed for.

        Raises:
            ValueError: if ``n_samples`` < 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(self.device)

        if target_class is None:
            with torch.no_grad():
                # .item() assumes a single image in the batch.
                target_class = self.model(image).argmax(dim=1).item()

        # Plain float std: keeps torch.normal on its scalar overload.
        noise_scale = noise_level * (image.max() - image.min()).item()

        accumulated_grads = torch.zeros_like(image)
        for _ in range(n_samples):
            noisy_image = image + self.generate_noise(image, noise_scale)
            accumulated_grads += self.get_gradients(noisy_image, target_class)

        return accumulated_grads / n_samples, target_class

def visualize_sensitivity_map(sensitivity_map: torch.Tensor, original_image: Image.Image, predicted_label: str):
    """Plot an image next to its SmoothGrad sensitivity map.

    Args:
        sensitivity_map: Gradient tensor, assumed (batch, C, H, W); only the first item is drawn.
        original_image: Image shown unchanged in the left panel.
        predicted_label: Class name placed in the right panel's title.

    Returns:
        The matplotlib Figure; the caller is responsible for saving and closing it.
    """
    # First batch item, (C, H, W) -> (H, W, C). detach() lets grad-tracking tensors convert.
    sensitivity = np.abs(sensitivity_map.detach().cpu().numpy()[0].transpose(1, 2, 0))
    lowest, highest = sensitivity.min(), sensitivity.max()
    if highest > lowest:
        sensitivity = (sensitivity - lowest) / (highest - lowest)
    else:
        # A constant map would give 0/0 = NaN; draw it as all zeros instead.
        sensitivity = np.zeros_like(sensitivity)
    # Collapse channels with a per-pixel maximum.
    sensitivity = sensitivity.max(axis=2)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.imshow(original_image)
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(sensitivity, cmap='hot')
    ax2.set_title(f'SmoothGrad Sensitivity Map\nPredicted: {predicted_label}')
    ax2.axis('off')

    plt.tight_layout()
    return fig


smooth_grad = SmoothGrad(model)

imagenet_path = './imagenet_samples'

# Image files in the sample folder: sorted for a deterministic order, directories skipped.
image_paths = sorted(
    name for name in os.listdir(imagenet_path)
    if os.path.isfile(os.path.join(imagenet_path, name))
)

for img_path in image_paths:
    full_path = os.path.join(imagenet_path, img_path)
    try:
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(full_path) as raw_image:
            input_image = raw_image.convert('RGB')
        input_tensor = preprocess(input_image)

        sensitivity_map, predicted_idx = smooth_grad(
            input_tensor,
            n_samples=50,
            noise_level=0.1
        )

        predicted_label = idx2label[predicted_idx]

        fig = visualize_sensitivity_map(sensitivity_map, input_image, predicted_label)
        try:
            output_filename = f'smoothgrad_{os.path.splitext(img_path)[0]}.png'
            fig.savefig(output_filename)
        finally:
            # Close even when saving fails, so figures do not accumulate.
            plt.close(fig)

        print(f"Processed {img_path}: Predicted as {predicted_label}")

    except Exception as e:
        print(f"Error processing {img_path}: {str(e)}")

# In [None]:
class ExplanationCorrelation:
    """Rank-correlation utilities for comparing two per-feature explanations.

    Every method is a ``@staticmethod``, so it can be called on the class or on an
    instance. Explanations are dicts holding a ``'feature_weights'`` array.
    """

    @staticmethod
    def get_feature_ranking(explanation: Dict):
        """Rank features by absolute weight: rank 1 is the largest |weight|.

        Ties get average ranks (``scipy.stats.rankdata`` default).
        """
        weights = np.abs(explanation['feature_weights'])
        return stats.rankdata(-weights)

    @staticmethod
    def calculate_correlations(rankings1: np.ndarray, rankings2: np.ndarray):
        """Return (kendall_tau, kendall_p, spearman_rho, spearman_p) for two rankings.

        Raises:
            ValueError: if the rankings have different lengths (e.g. explanations built
                over different segmentations).
        """
        if len(rankings1) != len(rankings2):
            raise ValueError(
                f"Rankings must have the same length, got {len(rankings1)} and {len(rankings2)}"
            )
        kendall_tau, kendall_p = stats.kendalltau(rankings1, rankings2)
        spearman_rho, spearman_p = stats.spearmanr(rankings1, rankings2)

        return kendall_tau, kendall_p, spearman_rho, spearman_p

    @staticmethod
    def compare_explanations(explanation1: Dict,
                             explanation2: Dict,
                             method1_name: str,
                             method2_name: str):
        """Correlate the |weight| rankings of two explanations and label the result.

        Returns:
            Dict with the two method names plus Kendall tau / Spearman rho and their p-values.
        """
        rankings1 = ExplanationCorrelation.get_feature_ranking(explanation1)
        rankings2 = ExplanationCorrelation.get_feature_ranking(explanation2)
        kendall_tau, kendall_p, spearman_rho, spearman_p = (
            ExplanationCorrelation.calculate_correlations(rankings1, rankings2)
        )

        return {
            'method1': method1_name,
            'method2': method2_name,
            'kendall_tau': kendall_tau,
            'kendall_p_value': kendall_p,
            'spearman_rho': spearman_rho,
            'spearman_p_value': spearman_p
        }


def visualize_correlation_matrices(kendall_matrix: np.ndarray, spearman_matrix: np.ndarray, method_names: List[str]):
    """Draw the Kendall and Spearman correlation matrices as side-by-side annotated heatmaps.

    Args:
        kendall_matrix: Square matrix of pairwise Kendall tau values.
        spearman_matrix: Square matrix of pairwise Spearman rho values.
        method_names: Tick labels for both axes of both heatmaps.

    Returns:
        The matplotlib Figure holding both heatmaps (colour scale fixed to [-1, 1]).
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    panels = (
        (kendall_matrix, 'Kendall Tau Correlations'),
        (spearman_matrix, 'Spearman Rank Correlations'),
    )

    for ax, (matrix, title) in zip(axes, panels):
        sns.heatmap(
            matrix,
            annot=True,
            cmap='RdBu',
            vmin=-1,
            vmax=1,
            xticklabels=method_names,
            yticklabels=method_names,
            ax=ax,
        )
        ax.set_title(title)

    plt.tight_layout()
    return fig

def convert_smoothgrad_to_segments(smoothgrad_grads: torch.Tensor, segments: np.ndarray):
    """Pool a SmoothGrad gradient map into one weight per superpixel.

    Each weight is the mean over the segment's pixels of |gradient| averaged across
    channels.

    Args:
        smoothgrad_grads: Gradient tensor, assumed (batch, C, H, W); only item 0 is used.
            It may still require grad.
        segments: (H, W) integer segment-label map.

    Returns:
        Dict with ``'segments'`` (the input map) and ``'feature_weights'``. Entry i of
        ``feature_weights`` belongs to the i-th smallest segment label, which equals the
        label itself when labels are 0..n-1.

    Raises:
        ValueError: if the gradient's spatial shape differs from ``segments``.

    NOTE(review): in this notebook the gradients are computed on ``preprocess`` output
    (Resize(256) + CenterCrop(224)), while the segments come from the un-cropped 224x224
    image, so pixel (i, j) in the two may not depict the same content. Confirm.
    """
    # detach(): gradients built with create_graph=True still require grad and
    # cannot be turned into numpy arrays directly.
    grads = np.abs(smoothgrad_grads.detach().cpu().numpy()[0]).mean(axis=0)
    if grads.shape != segments.shape:
        raise ValueError(
            f"Gradient map shape {grads.shape} does not match segments shape {segments.shape}"
        )

    # One pass over the pixels instead of one boolean mask per segment.
    unique_segments, inverse = np.unique(segments.ravel(), return_inverse=True)
    num_segments = len(unique_segments)
    sums = np.bincount(inverse, weights=grads.ravel(), minlength=num_segments)
    counts = np.bincount(inverse, minlength=num_segments)
    feature_weights = sums / counts

    return {
        'segments': segments,
        'feature_weights': feature_weights
    }

if __name__ == "__main__":
    # Explain one image with LIME and SmoothGrad, using LIME's superpixels as the
    # shared feature space, then compare the two attributions by rank correlation.
    image_path = "/content/imagenet_samples/mountain_bike.JPEG"
    np_image, input_batch = process_single_image(image_path)

    with torch.no_grad():
        output = model(input_batch)
        predicted_idx = torch.max(output, 1)[1].item()

    lime_explainer = LimeImageExplainer()
    lime_explanation = lime_explainer.explain_instance(
        np_image, classifier_fn,
        labels=[predicted_idx], num_samples=100, num_features=10,
    )
    segments = lime_explanation[predicted_idx]['segments']

    # Pool the pixel-level SmoothGrad map onto the same superpixels.
    smooth_grad = SmoothGrad(model)
    smoothgrad_grads, _ = smooth_grad(input_batch, n_samples=50, noise_level=0.1)
    smoothgrad_explanation = convert_smoothgrad_to_segments(smoothgrad_grads, segments)

    explanations = {
        'LIME': lime_explanation[predicted_idx],
        'SmoothGrad': smoothgrad_explanation,
    }

    results = ExplanationCorrelation.compare_explanations(
        explanations['LIME'], explanations['SmoothGrad'], 'LIME', 'SmoothGrad'
    )

    # 2x2 symmetric matrices: unit diagonal, the single pairwise value off-diagonal.
    tau = results['kendall_tau']
    rho = results['spearman_rho']
    kendall_matrix = np.array([[1.0, tau], [tau, 1.0]])
    spearman_matrix = np.array([[1.0, rho], [rho, 1.0]])

    fig = visualize_correlation_matrices(kendall_matrix, spearman_matrix, list(explanations))
    plt.show()

    print("\nresults:")
    for key, value in results.items():
        print(f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}")

    def visualize_segmented_explanations(original_image: np.ndarray, lime_exp: Dict, smoothgrad_exp: Dict):
        """Show the image next to |weight| heatmaps of both explanations over LIME's superpixels."""
        segments = lime_exp['segments']

        def _segment_heatmap(feature_weights):
            # Paint every pixel of segment k with |feature_weights[k]|.
            heat = np.zeros_like(segments, dtype=float)
            for segment_id, weight in enumerate(feature_weights):
                heat[segments == segment_id] = abs(weight)
            return heat

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (original_image, None, 'Original Image'),
            (_segment_heatmap(lime_exp['feature_weights']), 'gray', 'LIME Explanation'),
            (_segment_heatmap(smoothgrad_exp['feature_weights']), 'gray', 'Segmented SmoothGrad'),
        ]
        for ax, (picture, cmap, title) in zip(axes, panels):
            ax.imshow(picture, cmap=cmap)  # cmap=None is imshow's default
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        return fig

    fig = visualize_segmented_explanations(np_image, explanations['LIME'], explanations['SmoothGrad'])
    plt.show()

# In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import os

class PGDExplanationAttack:
    """PGD attack that changes a SmoothGrad explanation while trying to keep the predicted class.

    Each step takes a signed-gradient *descent* step on
    ``explanation_weight * explanation_loss + prediction_weight * prediction_loss``,
    then projects back into an L-infinity ball of radius ``epsilon`` around the clean
    image.

    * Minimising ``explanation_loss`` pushes the explanation away from the original.
    * Minimising ``prediction_loss`` keeps the output close to the original.

    Args:
        model: Classifier mapping (batch, C, H, W) to (batch, num_classes) logits.
            It is moved to ``device`` in place.
        num_steps: Number of PGD iterations.
        step_size: L-inf step length per iteration, in input-tensor units.
        initial_explanation_weight: Weight of the explanation loss.
        initial_prediction_weight: Starting (and minimum) weight of the prediction loss.
            It is doubled whenever the predicted class flips and decays back otherwise.
        device: Device for the model and images.
        clip_min, clip_max: Bounds the adversarial image is clamped to after each step
            (default [0, 1]). Pass ``None`` to disable a bound.
            NOTE(review): tensors produced by an ImageNet ``Normalize`` transform are not
            in [0, 1]. Such callers should pass matching bounds, e.g.
            ``(0 - mean) / std`` and ``(1 - mean) / std``, as values ``torch.clamp`` accepts.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        num_steps: int = 5,
        step_size: float = 1/255,
        initial_explanation_weight: float = 1.0,
        initial_prediction_weight: float = 100.0,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        clip_min: Optional[float] = 0.0,
        clip_max: Optional[float] = 1.0,
    ):
        self.model = model.to(device)
        self.num_steps = num_steps
        self.step_size = step_size
        self.explanation_weight = initial_explanation_weight
        self.prediction_weight = initial_prediction_weight
        self.device = device
        self.clip_min = clip_min
        self.clip_max = clip_max

    def get_smoothgrad_explanation(
        self,
        image: torch.Tensor,
        n_samples: int = 50,
        noise_level: float = 0.1,
        target_class: Optional[int] = None
    ):
        """Differentiable SmoothGrad map of ``image`` for ``target_class``.

        ``create_graph=True`` keeps the result differentiable with respect to ``image``,
        which the attack needs. Detach the result when it is only used as a value.
        ``target_class`` defaults to the model's argmax (``.item()`` assumes batch size 1).
        """
        accumulated_grads = torch.zeros_like(image).to(self.device)

        image_range = image.max() - image.min()
        noise_scale = noise_level * image_range

        if target_class is None:
            with torch.no_grad():
                output = self.model(image)
                target_class = output.argmax(dim=1).item()

        for _ in range(n_samples):
            noise = torch.randn_like(image).to(self.device) * noise_scale
            noisy_image = image.clone() + noise
            noisy_image.requires_grad_(True)

            output = self.model(noisy_image)
            score = output[:, target_class]

            grad = torch.autograd.grad(score, noisy_image,
                                       create_graph=True,
                                       retain_graph=True)[0]

            accumulated_grads += grad

        return accumulated_grads / n_samples

    def prediction_loss(
        self,
        original_output: torch.Tensor,
        perturbed_output: torch.Tensor,
        target_class: int
    ):
        """Penalty for the perturbed output drifting from the original; minimise it.

        The loss is the sum of three terms:

        * MSE between the target-class probabilities.
        * A margin term ``relu(max_prob - target_prob + 0.1)``, which is the constant 0.1
          while the target class is still the argmax.
        * MSE between the target-logit-minus-each-logit gaps.
        """
        orig_probs = F.softmax(original_output, dim=1)
        pert_probs = F.softmax(perturbed_output, dim=1)

        target_prob_loss = F.mse_loss(pert_probs[:, target_class], orig_probs[:, target_class])

        ranking_loss = torch.relu(
            pert_probs.max(dim=1)[0] - pert_probs[:, target_class] + 0.1
        ).mean()

        logit_diff_loss = F.mse_loss(
            perturbed_output[:, target_class] - perturbed_output,
            original_output[:, target_class] - original_output
        )

        return target_prob_loss + ranking_loss + logit_diff_loss

    def explanation_loss(
        self,
        original_exp: torch.Tensor,
        perturbed_exp: torch.Tensor
    ):
        """Similarity score of two explanations; lower means more dissimilar.

        Both explanations are flattened per batch row and L2-normalised. The result is
        ``-mean(1 - cos)`` plus the MSE of the normalised vectors. For non-zero
        explanations with D elements per row, this equals ``-(1 - 2/D) * mean(1 - cos)``.
        """
        orig_norm = F.normalize(original_exp.view(original_exp.size(0), -1), dim=1)
        pert_norm = F.normalize(perturbed_exp.view(perturbed_exp.size(0), -1), dim=1)

        cos_sim = (orig_norm * pert_norm).sum(dim=1)
        cos_loss = -torch.mean((1 - cos_sim))

        mse_loss = F.mse_loss(orig_norm, pert_norm)

        return cos_loss + mse_loss

    def _project(self, adv_image: torch.Tensor, image: torch.Tensor, epsilon: float):
        """Project ``adv_image`` into the L-inf ``epsilon``-ball around ``image``, then clip to the bounds."""
        delta = torch.clamp(adv_image - image, -epsilon, epsilon)
        projected = image + delta
        if self.clip_min is not None or self.clip_max is not None:
            projected = torch.clamp(projected, self.clip_min, self.clip_max)
        return projected.detach()

    def attack(
        self,
        image: torch.Tensor,
        epsilon: float,
    ):
        """Run the attack on one image.

        Args:
            image: (C, H, W) or (1, C, H, W) clean input.
            epsilon: L-inf radius of the allowed perturbation.

        Returns:
            (best_adv_image, target_class). ``best_adv_image`` is the post-update image
            that kept the original prediction and had the most dissimilar explanation
            (the clean image if no update qualified). ``target_class`` is the clean
            prediction.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(self.device)

        with torch.no_grad():
            original_output = self.model(image)
            target_class = original_output.argmax(dim=1).item()

        # The reference explanation is a fixed target. Detach it so its double-backward
        # graph is not kept alive for the whole attack.
        original_exp = self.get_smoothgrad_explanation(image, target_class=target_class).detach()

        adv_image = image.clone().detach()
        best_adv_image = adv_image.clone()
        best_exp_diff = -float('inf')

        pred_weight = self.prediction_weight
        exp_weight = self.explanation_weight

        for step in tqdm(range(self.num_steps), desc="PGD Attack"):
            adv_image.requires_grad_(True)

            current_exp = self.get_smoothgrad_explanation(adv_image, target_class=target_class)
            current_output = self.model(adv_image)
            current_pred = current_output.argmax(dim=1).item()

            pred_loss = self.prediction_loss(original_output, current_output, target_class)
            exp_loss = self.explanation_loss(original_exp, current_exp)

            # Score the image that current_exp was computed on, i.e. the result of the
            # previous update. Step 0 is still the clean image and is not a candidate.
            if step > 0 and current_pred == target_class:
                exp_diff = -exp_loss.item()
                if exp_diff > best_exp_diff:
                    best_exp_diff = exp_diff
                    best_adv_image = adv_image.detach().clone()

            if current_pred != target_class:
                pred_weight *= 2
            else:
                pred_weight = max(pred_weight * 0.9, self.prediction_weight)

            total_loss = (exp_weight * exp_loss) + (pred_weight * pred_loss)

            grad = torch.autograd.grad(total_loss, adv_image)[0]

            with torch.no_grad():
                # Descent: lowers explanation_loss (more dissimilar explanation) and
                # prediction_loss (prediction kept). Ascent would do the opposite.
                adv_image = self._project(adv_image.detach() - self.step_size * grad.sign(), image, epsilon)

        # The image produced by the last update has not been scored yet.
        final_exp = self.get_smoothgrad_explanation(adv_image, target_class=target_class).detach()
        with torch.no_grad():
            if self.model(adv_image).argmax(dim=1).item() == target_class:
                exp_diff = -self.explanation_loss(original_exp, final_exp).item()
                if exp_diff > best_exp_diff:
                    best_adv_image = adv_image.clone()

        return best_adv_image, target_class

def visualize_results(
    original_tensor: torch.Tensor,
    adv_image: torch.Tensor,
    original_exp: torch.Tensor,
    perturbed_exp: torch.Tensor,
    orig_class: str,
    orig_conf: float,
    adv_class: str,
    adv_conf: float,
    epsilon: float
):
    """Draw a 2x2 grid: the clean image and its explanation on top, the adversarial pair below.

    Images are drawn channel-last and clamped to [0, 1]. Explanations are drawn as
    |gradient| summed over the channel axis, using the 'hot' colormap.

    Returns:
        The matplotlib Figure; the caller saves and closes it.
    """
    def to_display_image(tensor):
        # squeeze batch dim, (C, H, W) -> (H, W, C), clamp for imshow
        return torch.clamp(tensor.squeeze().permute(1, 2, 0).cpu(), 0, 1)

    def to_heatmap(explanation):
        # absolute attribution summed over channels
        return torch.abs(explanation.squeeze()).sum(dim=0).cpu().detach()

    panels = (
        (original_tensor, to_display_image, {}, f'Original Image\n{orig_class} ({orig_conf:.1f}%)'),
        (original_exp, to_heatmap, {'cmap': 'hot'}, 'Original Explanation'),
        (adv_image, to_display_image, {}, f'Adversarial Image (ε={epsilon:.3f})\n{adv_class} ({adv_conf:.1f}%)'),
        (perturbed_exp, to_heatmap, {'cmap': 'hot'}, 'Perturbed Explanation'),
    )

    fig = plt.figure(figsize=(15, 5))
    for slot, (source, convert, imshow_kwargs, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.imshow(convert(source), **imshow_kwargs)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    return fig

def get_prediction(model, image_tensor, labels=None):
    """Classify a batch and report the first item's label, confidence (%) and class index.

    Args:
        model: Classifier returning (batch, num_classes) logits.
        image_tensor: Input batch. It is moved to the device of the model's parameters,
            so CPU tensors work with a GPU model.
        labels: Index-to-name sequence; defaults to the module-level ``idx2label``.

    Returns:
        (label, confidence_percent, pred_idx) for batch item 0.
    """
    if labels is None:
        labels = idx2label

    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1)
        pred_idx = output[0].argmax().item()
        confidence = probabilities[0][pred_idx].item() * 100
    return labels[pred_idx], confidence, pred_idx

def apply_pgd_attack(image_path: str, model, preprocess, epsilons=(2/255, 4/255, 8/255)):
    """Load an image and run the explanation attack once per epsilon.

    Args:
        image_path: Image file to attack.
        model: Classifier to attack.
        preprocess: Transform from a PIL RGB image to a (C, H, W) tensor.
        epsilons: L-inf radii to try, in order (defaults to the original three).

    Returns:
        (input_tensor, results, label): the preprocessed clean tensor; a list of
        ``(adv_image, predicted_idx)`` tuples, one per epsilon in order; and the
        label name of the clean prediction.

    Raises:
        ValueError: if ``epsilons`` is empty.
    """
    if not epsilons:
        raise ValueError("apply_pgd_attack needs at least one epsilon")

    # The context manager closes the file; convert() returns an independent copy.
    with Image.open(image_path) as raw_image:
        input_image = raw_image.convert('RGB')
    input_tensor = preprocess(input_image)

    attack = PGDExplanationAttack(
        model,
        num_steps=5,
        step_size=1/255,
        initial_explanation_weight=1.0,
        initial_prediction_weight=100.0
    )

    # attack.attack returns (adv_image, predicted_idx) for each radius.
    results = [attack.attack(input_tensor, epsilon=epsilon) for epsilon in epsilons]
    predicted_idx = results[-1][1]

    return input_tensor, results, idx2label[predicted_idx]

def main():
    """Attack every file in ``imagenet_path`` and save clean-vs-adversarial explanation figures.

    For each epsilon whose adversarial image keeps the clean prediction, a figure
    ``pgd_attack_<file>_eps_<k>.png`` is saved (k = epsilon * 255) and the MSE between
    the two explanations is printed. Errors are reported per image and do not stop the loop.
    """
    # Must match, in order, the epsilons apply_pgd_attack uses by default.
    epsilons = [2/255, 4/255, 8/255]

    images_to_attack = sorted(
        name for name in os.listdir(imagenet_path)
        if os.path.isfile(os.path.join(imagenet_path, name))
    )

    for img_path in images_to_attack:
        print(f"\nProcessing {img_path}")
        image_path = os.path.join(imagenet_path, img_path)

        try:
            original_tensor, attack_results, _ = apply_pgd_attack(image_path, model, preprocess)

            attack = PGDExplanationAttack(model)
            # preprocess() yields a CPU tensor while the model lives on attack.device.
            original_tensor_batch = (
                original_tensor.unsqueeze(0) if original_tensor.dim() == 3 else original_tensor
            ).to(attack.device)

            orig_class, orig_conf, orig_idx = get_prediction(model, original_tensor_batch)
            print(f"\nOriginal Classification: {orig_class} ({orig_conf:.2f}% confidence)")

            # Computed lazily, then reused for every epsilon that keeps the prediction.
            original_exp = None

            for epsilon, (adv_image, _) in zip(epsilons, attack_results):
                adv_image_batch = (
                    adv_image.unsqueeze(0) if adv_image.dim() == 3 else adv_image
                ).to(attack.device)
                adv_class, adv_conf, adv_idx = get_prediction(model, adv_image_batch)

                if adv_idx != orig_idx:
                    print(f"\nEpsilon = {epsilon:.4f}: Skipped - prediction changed from {orig_class} to {adv_class}")
                    continue

                print(f"\nEpsilon = {epsilon:.4f}:")
                print(f"Adversarial Classification: {adv_class} ({adv_conf:.2f}% confidence)")

                # detach(): the explanations are only displayed/compared, so their autograd
                # graphs need not be kept alive.
                if original_exp is None:
                    original_exp = attack.get_smoothgrad_explanation(
                        original_tensor_batch, target_class=orig_idx
                    ).detach()
                perturbed_exp = attack.get_smoothgrad_explanation(
                    adv_image_batch, target_class=orig_idx
                ).detach()

                fig = visualize_results(
                    original_tensor_batch, adv_image_batch,
                    original_exp, perturbed_exp,
                    orig_class, orig_conf,
                    adv_class, adv_conf,
                    epsilon
                )
                try:
                    # round(): int() could truncate a product like 3.9999999999999996 down to 3.
                    fig.savefig(f'pgd_attack_{img_path}_eps_{round(epsilon * 255)}.png')
                finally:
                    plt.close(fig)

                exp_diff = F.mse_loss(original_exp, perturbed_exp).item()
                print(f"Explanation difference (MSE): {exp_diff:.6f}")

        except Exception as e:
            print(f"Error processing {img_path}: {str(e)}")
            continue

# Run the PGD explanation-attack experiment (main) when this code is executed as the main module.
if __name__ == "__main__":
    main()