In [None]:
import shap
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# Load the trained PyTorch model onto the CPU. The checkpoint holds a whole
# pickled nn.Module (not a state_dict), since .eval() is called on the result.
# weights_only=False is needed to unpickle a full module: from PyTorch 2.6 on,
# torch.load defaults to weights_only=True and rejects arbitrary classes.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
# NOTE(review): the model's class must be importable in this session (defined or
# imported in an earlier cell, presumably) -- TODO confirm.
# NOTE(review): file name says "mnist" but the data below is Fashion-MNIST --
# confirm this checkpoint was trained on Fashion-MNIST.
model = torch.load(
    "./transformer_mnist.pth",
    map_location=torch.device("cpu"),
    weights_only=False,
)
model.eval()  # inference mode (e.g. disables dropout) for stable explanations

# Fashion-MNIST preprocessing pipeline applied to every image.
transform = transforms.Compose(
    [
        # PIL image -> float tensor scaled to [0, 1]
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 rescales pixels to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fashion-MNIST test split, preprocessed with `transform`.
test_dataset = datasets.FashionMNIST(root="./cnn-data", train=False, transform=transform, download=True)

# GradientExplainer (expected gradients) averages over the background set, so it
# needs more than a handful of images, and the images being explained should not
# be part of it. Draw one shuffled batch and split it into two disjoint sets.
BACKGROUND_SIZE = 100
NUM_EXPLAIN = 5
test_loader = DataLoader(
    test_dataset,
    batch_size=BACKGROUND_SIZE + NUM_EXPLAIN,
    shuffle=True,
    # Seeded so the same images are drawn (and explained) on every run.
    generator=torch.Generator().manual_seed(0),
)

images, batch_labels = next(iter(test_loader))  # Shape: (105, 1, 28, 28)
background = images[:BACKGROUND_SIZE]  # Shape: (100, 1, 28, 28)
test_samples = images[BACKGROUND_SIZE:]  # Shape: (5, 1, 28, 28), disjoint from background
labels = batch_labels[BACKGROUND_SIZE:]  # ground-truth classes of test_samples

explainer = shap.GradientExplainer(model, background)
shap_values = explainer.shap_values(test_samples)


def _to_channels_last(values):
    """Convert SHAP values from PyTorch layout to what shap.image_plot expects.

    Args:
        values: either a list with one (N, C, H, W) array per model output
            (older shap), a single (N, C, H, W, K) array with the K outputs
            stacked on the last axis (newer shap), or a single (N, C, H, W)
            array for a single-output model.

    Returns:
        A list of (N, H, W, C) arrays, one per model output.
    """
    if isinstance(values, list):
        return [np.transpose(np.asarray(v), (0, 2, 3, 1)) for v in values]
    values = np.asarray(values)
    if values.ndim == 5:
        return [np.transpose(values[..., k], (0, 2, 3, 1)) for k in range(values.shape[-1])]
    return [np.transpose(values, (0, 2, 3, 1))]


# image_plot needs channels-last arrays for both SHAP values and pixels;
# PyTorch tensors are channels-first (N, C, H, W).
shap_values_np = _to_channels_last(shap_values)
# permute instead of squeeze()+expand_dims(): squeeze() would also drop the
# batch axis if only one image were explained.
test_samples_np = test_samples.permute(0, 2, 3, 1).numpy()  # Shape: (5, 28, 28, 1)

# Plot SHAP Explanations
shap.image_plot(shap_values_np, test_samples_np)