In [None]:
import os
import numpy as np
import pandas as pd

import torch
from torchvision.datasets import MNIST
from captum.attr import GradientShap

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from lfxai.explanations.features import attribute_auxiliary
from lfxai.models.images import ClassifierMnist, EncoderMnist
import copy

from matplotlib import pyplot as plt

In [None]:
# Runtime configuration shared by every cell below
device = 'cpu'
batch_size = 128
# Fix the torch RNG so repeated runs draw the same random numbers
torch.manual_seed(123)

# Model hyper-parameters
dim_latent = 4
image_height = 28

In [None]:
# MNIST train/test splits; the only preprocessing is conversion to tensors
data_dir = "data/mnist"
shared_transform = transforms.Compose([transforms.ToTensor()])

train_dataset, test_dataset = (
    MNIST(data_dir, train=is_train, download=True, transform=shared_transform)
    for is_train in (True, False)
)

# Neither loader shuffles, so batches follow dataset order; later cells index
# the attribution arrays and test_dataset with the same position.
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
#### Model loading ######
# Specification (dim_latent and device come from the setup cell)
name = "TestClassifier"
classifier_state_dict_path = os.path.join('..', 'TrainedModels', 'MNIST', 'Classifier_run0.pt') # TODO Need to loop over runs

# Load
encoder = EncoderMnist(dim_latent)

classifier = ClassifierMnist(encoder, dim_latent, name)
# map_location lets a checkpoint saved on another device (e.g. a GPU) load onto `device`
classifier.load_state_dict(
    torch.load(classifier_state_dict_path, map_location=device), strict=True
)
classifier.to(device)
# Inference mode: any dropout / batch-norm layers stop behaving stochastically or
# depending on the batch, so attributions are computed on the deployed function.
classifier.eval()

# load_state_dict copies the weights into the existing modules in place, so
# classifier.encoder already holds the trained weights. The deepcopy only gives
# us an encoder object that is independent of the classifier.
encoder = copy.deepcopy(classifier.encoder)

print("Classifier Loaded")

In [None]:
# One GradientShap explainer per model: the encoder alone and the full classifier
gradshap_encoder, gradshap_full_model = (
    GradientShap(model) for model in (encoder, classifier)
)


In [None]:
# All-zero reference image (batch of one, single channel) used as the attribution
# baseline; sized from image_height so it stays in sync with the setup cell.
baseline_image = torch.zeros((1, 1, image_height, image_height), device=device)

In [None]:
# Attribute the encoder first and the full classifier second (same order as the
# random draws inside GradientShap). For the classifier this is the correct thing
# to do because it outputs probabilities, so we are taking a soft sum.
encoder_attributions, pipeline_attributions = (
    attribute_auxiliary(model, test_loader, device, explainer, baseline_image)
    for model, explainer in (
        (encoder, gradshap_encoder),
        (classifier, gradshap_full_model),
    )
)



In [None]:
# Keep magnitudes only: the sign of an attribution in the hidden space is not of interest
encoder_attributions = np.absolute(encoder_attributions)
pipeline_attributions = np.absolute(pipeline_attributions)

# Rescale each set of maps to unit standard deviation (and therefore unit variance)
encoder_attributions = encoder_attributions / encoder_attributions.std()
pipeline_attributions = pipeline_attributions / pipeline_attributions.std()


In [None]:
# Dataset-averaged saliency maps, encoder on the left and full classifier on the right
fig, ax = plt.subplots(nrows=1, ncols=2, sharey='row', figsize=[15, 6])
encoder_axis, classifier_axis = ax

encoder_axis.imshow(np.mean(encoder_attributions, axis=0).squeeze(), cmap='gray_r')
encoder_axis.set_title('Encoder Saliency Map')

classifier_axis.imshow(np.mean(pipeline_attributions, axis=0).squeeze(), cmap='gray_r')
classifier_axis.set_title('Classifier Saliency Map')

In [None]:
# As above, but annotated with the per-pixel values so that we can understand what's going on better
mean_encoder_map = encoder_attributions.mean(axis=0).squeeze()

fig, ax = plt.subplots(figsize=[12, 12])
ax.imshow(mean_encoder_map, cmap='gray_r')
# ndenumerate yields (row, col) indices, whereas Axes.text expects (x, y) = (col, row);
# passing them in array order would print every label at the transposed pixel.
for (row, col), pixel_value in np.ndenumerate(mean_encoder_map.round(2)):
    ax.text(col, row, pixel_value, c='red', ha='center', va='center')

In [None]:
# Pearson correlation between the two sets of maps, pooled over every pixel of every image
# TODO we should also exclude constant/nearly constant pixels?

# There are other ways to do this, e.g. averaging the maps over images first.
# However, I think that unnecessarily discards structure that we might otherwise want to keep.
pearsons_rho = np.corrcoef(
    encoder_attributions.ravel(),
    pipeline_attributions.ravel(),
)
pearsons_rho

In [None]:
# Alternatively, consider R**2. Assuming a linear relationship plus intercept, this is
# just the square of the Pearson coefficient (the off-diagonal entries of the matrix below).
np.square(pearsons_rho)

In [None]:
# Finally, correlate the maps after first averaging each of them across the images
pearsons_rho_agg = np.corrcoef(
    np.mean(encoder_attributions, axis=0).ravel(),
    np.mean(pipeline_attributions, axis=0).ravel(),
)
pearsons_rho_agg

In [None]:
# Per-pixel product of the two attribution maps, averaged over images
dots = np.mean(encoder_attributions * pipeline_attributions, axis=0).squeeze()
# Drop pixels whose mean product is exactly zero before plotting the distribution
dots = dots[np.abs(dots) > 0]

plt.hist(dots.ravel())

In [None]:
# The same mean product, shown as an image with its colour scale
plt.imshow(
    np.mean(encoder_attributions * pipeline_attributions, axis=0).squeeze(),
    cmap='gray_r',
)
plt.colorbar()

We have something of a puzzle - why do the correlations skyrocket when we first aggregate across the images?

To solve this, let's examine what happens in the case of a _single_ image.

In [None]:
# Pick one test image together with its two attribution maps
i = 2
X, y = test_dataset[i]
X = X.squeeze()
a_enc_i, a_full_i = (
    attributions[i].squeeze()
    for attributions in (encoder_attributions, pipeline_attributions)
)

# Correlation between the two maps for this image alone
np.corrcoef(a_enc_i.ravel(), a_full_i.ravel())

In [None]:
# Black (zero-valued) input pixels agree trivially, so keep only the remaining ones
mask = torch.eq(X, 0)
m_enc, m_full = a_enc_i[~mask], a_full_i[~mask]

In [None]:
# Squared correlation matrix of the masked maps (off-diagonal entries are R**2)
np.square(np.corrcoef(m_full.ravel(), m_enc.ravel()))

In [None]:
from scipy.stats import spearmanr

# Rank correlation of the masked maps
spearmanr(m_full.ravel(), m_enc.ravel())

In [None]:
# Encoder map, classifier map and the input image side by side
fig, ax = plt.subplots(nrows=1, ncols=3, sharey='row', figsize=[15, 6])
enc_axis, full_axis, image_axis = ax

enc_axis.imshow(a_enc_i, cmap='gray_r')
enc_axis.set_title('Encoder Saliency Map')

full_axis.imshow(a_full_i, cmap='gray_r')
full_axis.set_title('Classifier Saliency Map')

image_axis.imshow(X, cmap='gray_r')
image_axis.set_title(f'Original Image - class {y}')

In [None]:
# Check whether the two maps are correlated across a _single_ image by plotting
# the masked encoder attributions (x) against the masked classifier attributions (y)
plt.scatter(x=m_enc, y=m_full)

We have some satisfying answers:

1. Correlations appear much lower _within_ images than across images. Intuitively, if we take the mean across the images, the bits in the centre will all end up with some feature importance between 0 and 1. However, most explainability metrics are defined per prediction, and this can be important in e.g. regulatory contexts.

2. At least some of the correlation is driven by the fact that the black pixels have spuriously high correlations. Masking out these pixels makes a substantive difference.

TODO: we should check whether their pearson results change when masking out some pixels. We should also check whether their results are taken per image or somehow aggregating across images.

Now, we should turn the above into a function that we can call in a loop over the dataset, which should give us a Pearson correlation per image.

In [None]:
def image_pearson(i: int, dataset, encoder_attributions, full_attributions, mask_dead_pixels=True):
    """Pearson correlation between the encoder and full-model attribution maps of one image.

    Args:
        i: Position of the image in ``dataset`` and of its maps in both attribution
            collections.
        dataset: Indexable collection of ``(image, label)`` pairs; only the image is used.
        encoder_attributions: Indexable by ``i``; yields the encoder attribution map.
        full_attributions: Indexable by ``i``; yields the full-model attribution map.
        mask_dead_pixels: If True, pixels whose input value is exactly 0 are dropped
            from both maps before correlating.

    Returns:
        The off-diagonal entry of ``np.corrcoef`` for the two flattened maps, or
        ``nan`` when fewer than two pixels remain or either map is constant
        (the correlation is undefined in those cases).
    """
    image, _ = dataset[i]
    # Work in numpy throughout so the mask is a plain boolean array whether the
    # dataset yields tensors or arrays.
    image = np.asarray(image).squeeze()
    a_enc_i = np.asarray(encoder_attributions[i]).squeeze()
    a_full_i = np.asarray(full_attributions[i]).squeeze()

    if mask_dead_pixels:
        live_pixels = image != 0
        a_enc_i = a_enc_i[live_pixels]
        a_full_i = a_full_i[live_pixels]

    a_enc_i = a_enc_i.ravel()
    a_full_i = a_full_i.ravel()

    # Undefined correlation: return nan explicitly instead of relying on numpy
    # division-by-zero warnings.
    if a_enc_i.size < 2:
        return np.nan
    if a_enc_i.max() == a_enc_i.min() or a_full_i.max() == a_full_i.min():
        return np.nan

    return np.corrcoef(a_enc_i, a_full_i)[0, 1]

# Per-image Pearson correlation over the whole test set, ignoring zero-valued input pixels.
# A comprehension keeps the loop index local, so the notebook-level `i` chosen in an
# earlier cell is not silently overwritten.
rhos = np.array([
    image_pearson(idx, test_dataset, encoder_attributions, pipeline_attributions, mask_dead_pixels=True)
    for idx in range(len(test_dataset))
])


In [None]:
# Square of the mean per-image correlation (Series.mean skips NaN entries by default)
np.square(pd.Series(rhos).mean())

In [None]:
# Distribution of the per-image correlations
plt.hist(x=rhos, bins=50)

# Further extensions

Can we figure out what properties of the map make this happen? E.g. is it an issue of scale? To test this, we can look at properties of the decoder.

In [None]:
from lfxai.explanations.features import AuxiliaryFunction

In [None]:
# Independent copy of the classifier's `lin_output` module, used below as the
# "decoder" that maps latent vectors onward.
# NOTE(review): presumably this is everything after the encoder - confirm against ClassifierMnist.
decoder = copy.deepcopy(classifier.lin_output)

In [None]:
# Encode the whole test set once; the mean latent vector becomes the baseline
# for the decoder attributions.
hiddens = []

# Only values are needed here, so skip autograd bookkeeping. The loop names are
# chosen so the single-image `X` / `y` from earlier cells are not overwritten.
with torch.no_grad():
    for batch_inputs, _ in test_loader:
        batch_hidden = encoder(batch_inputs.to(device))
        hiddens.append(batch_hidden.cpu().numpy())

hiddens = np.concatenate(hiddens)
# keepdims leaves a leading batch axis of size 1, as the attribution call expects a batch
hidden_baseline = torch.as_tensor(
    hiddens.mean(axis=0, keepdims=True), dtype=torch.float32, device=device
)

In [None]:
decoder_attributions = []
# The baseline is the mean latent vector computed above; an all-zero latent
# (torch.zeros(1, 4)) would be the alternative choice.

gradshap_decoder = GradientShap(decoder)

# Loop names are chosen so the single-image `X` / `y` from earlier cells are not overwritten.
for batch_inputs, _ in test_loader:
    # The latent codes are the *inputs* being explained here, so there is no need
    # to keep the encoder's autograd graph attached to them.
    with torch.no_grad():
        hidden = encoder(batch_inputs.to(device))

    # Despite the variable name, this wraps the decoder evaluated at the latent codes
    auxiliary_encoder = AuxiliaryFunction(decoder, hidden)
    gradshap_decoder.forward_func = auxiliary_encoder

    attr = gradshap_decoder.attribute(hidden, baselines=hidden_baseline)
    attr = attr.detach().cpu().numpy()

    decoder_attributions.append(attr)

# Stack the batches and keep magnitudes only, as for the image-space attributions
decoder_attributions = np.abs(np.concatenate(decoder_attributions))


In [None]:
# Summary statistics of the decoder attributions, one column per latent dimension
df = pd.DataFrame(data=decoder_attributions)
df.describe()

### Now that we have obtained the decoder attributions, we need to home in on a pixel that failed our test before and try to understand what went wrong.

Note: this ended up being a bit of a failure - it seems like the scale is indeed the key, but since the explanations are locally linear, this is a pretty trivial observation.

In [None]:

# Select the first test image and its encoder, full-pipeline and decoder attributions
i = 0
X, y = test_dataset[i]
X = X.squeeze()
a_enc_i, a_full_i, a_dec_i = (
    attributions[i].squeeze()
    for attributions in (encoder_attributions, pipeline_attributions, decoder_attributions)
)

# Correlation between the encoder and full-pipeline maps for this image
np.corrcoef(a_enc_i.ravel(), a_full_i.ravel())

In [None]:
# Encoder map, full-pipeline map and their element-wise product, with the pixel
# under inspection marked on each panel.
fig, ax = plt.subplots(1, 3, figsize=[15, 4])
ax[0].imshow(a_enc_i, cmap='gray')
ax[0].set_title('encoder')

ax[1].imshow(a_full_i, cmap='gray')
ax[1].set_title('full')

ax[2].imshow(a_full_i*a_enc_i, cmap='gray')
ax[2].set_title('corr')

# (k, l) is the (row, col) of the pixel that the following cells read via [k, l]
k, l = 16, 17
# Axes.text takes (x, y) = (col, row), so the indices are swapped to put the marker
# on the same pixel that [k, l] selects. Iterating over the axes directly also
# avoids overwriting the notebook-level image index `i`.
for axis in ax:
    axis.text(l, k, 'h', c='red', ha='center', va='center')

In [None]:
# Product of the two attributions at the inspected pixel
a_full_i[k, l] * a_enc_i[k, l]

In [None]:
# Pixels of the selected image with a strictly positive input value
mask = torch.gt(X, 0)

In [None]:
# Attribution products over the non-zero pixels, in ascending order
all_corrs = pd.Series((a_full_i * a_enc_i)[mask]).sort_values()

In [None]:
# Fraction of non-zero pixels whose product exceeds that of the inspected pixel
all_corrs.gt((a_full_i * a_enc_i)[k, l]).sum() / len(all_corrs)

In [None]:
# Full-pipeline attribution at the inspected pixel (array index [k, l])
a_full_i[k, l]

In [None]:
# Encoder attribution at the inspected pixel (array index [k, l])
a_enc_i[k, l]

In [None]:
fig, ax = plt.subplots(2, 2)

# One histogram of decoder attributions per latent dimension (column of df).
# Pairing columns with axes avoids hard-coding the number of dimensions and
# avoids overwriting the notebook-level image index `i`; at most four columns
# fit in the 2x2 grid.
for column, axis in zip(df.columns, ax.flat):
    axis.hist(df[column])
    axis.set_title(column)

In [None]:
# Show the currently selected input image for reference
plt.imshow(X, cmap='gray')

In [None]:
# Decoder attributions for the selected image
# (presumably one value per latent dimension - confirm against decoder_attributions.shape)
a_dec_i