In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils import models
from torchvision.transforms import Normalize
from captum.attr import Saliency

In [None]:
# Prefer Apple's MPS backend when present, then CUDA, then CPU. Hard-coding
# torch.device('mps') makes this cell fail on any machine without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Checkpoint location; change EPOCH to inspect a different training snapshot.
WEIGHTS_DIR = 'results/mnist/weights'
EPOCH = 4

# Create models
generator = models.GeneratorMNIST().to(device)
discriminator = models.DiscriminatorMNIST().to(device)

# Load weights. map_location remaps tensors saved on another device onto `device`
# (it already is a torch.device, so no extra wrapping is needed).
generator.load_state_dict(torch.load(f'{WEIGHTS_DIR}/gen_epoch_{EPOCH}.pth', map_location=device))
discriminator.load_state_dict(torch.load(f'{WEIGHTS_DIR}/disc_epoch_{EPOCH}.pth', map_location=device))

# Inference only from here on: disable training-time behaviour (dropout,
# batch-norm batch statistics) in case utils.models uses such layers.
generator.eval()
discriminator.eval();

In [None]:
# Per-channel affine map (x - mean) / std on a single-channel tensor; with
# mean = std = 0.5 this sends values in [0, 1] to [-1, 1].
norm = Normalize(mean=(0.5,), std=(0.5,))

In [None]:
# Generate a fake sample from a fixed seed so every figure below is reproducible.
SEED = 0
LATENT_DIM = 100  # noise size used by the original cell; assumed to match GeneratorMNIST's input
torch.manual_seed(SEED)
noise = torch.randn(1, LATENT_DIM, 1, 1, device=device)

# No autograd graph through the generator: the saliency cell differentiates the
# discriminator with respect to the image itself, so the generator's graph is
# never used and would only keep `fake` a non-leaf tensor.
with torch.no_grad():
    fake = generator(noise)

# Squeeze away size-1 dims so imshow gets a 2-D greyscale array.
img = np.squeeze(fake.detach().cpu().numpy())

fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.set_title('Generated sample')
ax.axis('off')
plt.show()

In [None]:
# Saliency: gradient of the discriminator output with respect to the generated
# image (captum's Saliency returns absolute gradients unless abs=False is passed).
saliency = Saliency(discriminator)
explanation_raw = saliency.attribute(fake)

# `norm` computes (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1]. Raw gradient
# magnitudes have an arbitrary scale, so min-max rescale them to [0, 1] first;
# otherwise the result is just ~2*g - 1 and no longer comparable with the image.
# clamp_min avoids dividing by zero when the map is constant.
lo, hi = explanation_raw.min(), explanation_raw.max()
explanation = norm((explanation_raw - lo) / (hi - lo).clamp_min(1e-12))

grads_img = explanation.detach().cpu().squeeze().numpy()
fig, ax = plt.subplots()
im = ax.imshow(grads_img)
ax.set_title('Discriminator saliency (rescaled to [-1, 1])')
ax.axis('off')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Peak value of the generated image, then of the normalized saliency map,
# to compare their ranges before they are combined.
for t in (fake, explanation):
    print(t.max())

In [None]:
# Squared-error criterion with reduction='none', so it returns the per-element
# map instead of a scalar. NOTE(review): currently unused by the active code.
mse = torch.nn.MSELoss(
    reduction='none',
)

In [None]:
# Pixel-wise product of the generated image and the negated saliency map.
mul = fake * -explanation

mul_img = mul.detach().cpu().squeeze().numpy()

# The product can take either sign, so use a diverging colormap with symmetric
# limits so that zero maps to the neutral colour (fallback 1.0 if all zeros).
vmax = float(np.abs(mul_img).max()) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(mul_img, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
ax.set_title('Generated image * (-saliency)')
ax.axis('off')
fig.colorbar(im, ax=ax)
plt.show()