In [None]:
from xai import *
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import os
from tqdm import tqdm

In [None]:
# Size of the autoencoder's latent space; each latent dimension gets its own
# SHAP column in the explanation figures further down.
latent_dim = 8

# Reuse the model cached on disk when it exists. Otherwise build it, train it
# in two stages and cache it, so a later "Restart & Run All" skips training.
try:
    mnist = MNIST.load("mnist.model")
except FileNotFoundError:
    mnist = MNIST(
        hidden_layers=2,
        latent_shape=(latent_dim,),
        classifier_head_output_activation="Softmax",
    )

    # Stage 1: the autoencoder, trained with an MSE loss.
    # early_stop_cont is presumably the early-stopping patience — confirm in xai.
    mnist.fit_autoencoder(
        info="Mnist autoencoder train",
        loss_criterion="MSELoss",
        epochs=10_000,
        batch_size=256,
        early_stop_cont=750,
        verbose=True,
    )

    # Stage 2: the classifier head (uses the default loss criterion).
    mnist.fit_classifier_head(
        info="Mnist classifier-head train",
        epochs=10_000,
        batch_size=256,
        early_stop_cont=750,
        verbose=True,
    )

    mnist.save("mnist.model")
    

In [None]:
# Plot the autoencoder's training history (the same title that fit_autoencoder
# was given via info=). When the model came from MNIST.load this presumably
# relies on the history being saved with the model — TODO confirm in xai.
mnist.autoencoder.train_history.figure("Mnist autoencoder train")

In [None]:
# Plot the classifier head's training history (the same title that
# fit_classifier_head was given via info=). As above, a loaded model presumably
# carries its history — TODO confirm in xai.
mnist.classifier_head.train_history.figure("Mnist classifier-head train")

In [None]:
# Every explainer writes its figures into its own sub-folder of save_folder.
save_folder = "./images/mnist-feature-space-explanation"
# makedirs(..., exist_ok=True) also creates missing parents (including "./images",
# which the "images/samples" cache below needs) and only tolerates "already
# exists" — permission or path errors still surface instead of being swallowed.
os.makedirs(save_folder, exist_ok=True)
for algorithm in ["exact", "permutation", "deep", "kernel", "gradient"]:
    os.makedirs(f"{save_folder}/{algorithm}", exist_ok=True)

# One figure row per sample; samples are drawn as one per digit 0..rows-1,
# so rows must not exceed 10.
rows = 10
# Cache the samples on disk so every explainer, and every re-run of the
# notebook, explains exactly the same inputs.
try:
    samples = Stream.load("images/samples", Tuple[torch.Tensor, torch.Tensor]).list()
except Exception:
    # No usable cache yet: take one (image, label) pair per digit and store it.
    # A plain list is used deliberately: a module-level annotation such as
    # List[torch.Tensor, torch.Tensor] is evaluated and raises TypeError.
    samples = [mnist.get_sample(digit=i) for i in range(rows)]
    Stream(samples).save("images/samples")

assert len(samples) >= rows, f"expected at least {rows} samples, got {len(samples)}"

In [None]:
def _save_mnist(image: np.ndarray, path: str, *, scale: int = 16, **kwargs) -> None:
    """Save ``image`` to ``path``, upscaled ``scale`` times per axis by pixel repetition.

    Extra keyword arguments (e.g. ``cmap``) are forwarded to ``plt.imsave``.
    """
    plt.imsave(path, image.repeat(scale, 0).repeat(scale, 1), **kwargs)


def _diverging_rgb(values: np.ndarray, norm: float) -> np.ndarray:
    """Map signed attributions to an RGB image: positive -> red, negative -> blue.

    Args:
        values: Per-pixel SHAP values, assumed 2-D (H, W), e.g. 28x28 for MNIST.
        norm: Scale factor, normally the largest absolute value to display.
            A non-positive norm (all-zero attributions) yields a black image
            instead of a 0/0 division.

    Returns:
        A ``float32`` array of shape ``values.shape + (3,)``. The green channel
        is always zero.
    """
    im = np.zeros((*values.shape, 3), dtype=np.float32)
    if norm > 0:
        im[..., 0] = np.where(values > 0, values / norm, 0.0)
        im[..., 2] = np.where(values < 0, -values / norm, 0.0)
    return im


def _plot_compute_times(compute_times: list, algorithm: str, path: str) -> None:
    """Plot the per-sample explanation times and their mean, save to ``path``, then close the figure."""
    fig, ax = plt.subplots()
    iterations = range(len(compute_times))
    ax.plot(iterations, compute_times, label="Time used")
    if compute_times:  # the mean is undefined for zero rows
        average = sum(compute_times) / len(compute_times)
        ax.plot(iterations, [average] * len(compute_times), label="Average time")
    ax.legend()
    ax.set_ylabel("Seconds")
    ax.set_xlabel("Iteration")
    ax.set_title(f"Compute times: {algorithm}")
    fig.savefig(path)
    plt.close(fig)


def mk_shap_latent_plot(algorithm: Explainers):
    """Render and save the per-latent SHAP explanations of ``samples`` for one explainer.

    Builds a grid with one row per sample (``rows`` rows) and ``latent_dim + 3``
    columns:
    - the input;
    - its reconstruction;
    - one red/blue attribution map per entry of
      ``predict.reconstruction.explain(...).invert().shap_values`` (presumably
      the contribution of latent ``l_j`` to each pixel — TODO confirm in xai);
    - the sum of those maps.

    Each panel is also saved as its own PNG under ``{save_folder}/{algorithm}/``.
    The full grid is saved as ``shap-latents.png`` and the per-sample compute
    times as ``compute-times.png``.

    Reads the notebook globals ``mnist``, ``samples``, ``rows``, ``latent_dim``
    and ``save_folder``.

    Args:
        algorithm: Explainer name passed to ``explain``, e.g. ``"exact"`` or ``"kernel"``.
    """
    columns = latent_dim + 3
    out_dir = f"{save_folder}/{algorithm}"
    fig = plt.figure(dpi=100, figsize=(20, 28))
    fig.suptitle(f"Algorithm: {algorithm.capitalize()}Explainer", size=30)

    # The SHAP background (embeddings of the validation set) does not depend on
    # the sample, so compute it once rather than once per row.
    background = mnist(mnist.val_data).embedding()

    compute_times = []
    for row in tqdm(range(rows), desc=f"{algorithm.capitalize()}Explainer"):
        slot = row * columns

        sample, label = samples[row]
        sample = sample.numpy(force=True)
        label = int(label.item())
        predict = mnist(sample)
        recon = predict.reconstruction().numpy(force=True)

        latent_explanation = predict.reconstruction.explain(algorithm, background).invert()
        shap_values: np.ndarray = latent_explanation.shap_values
        compute_times.append(latent_explanation.compute_time.float())

        # One scale shared by all latent maps of this row, so their intensities are comparable.
        norm = np.max(np.abs(shap_values))

        ax = fig.add_subplot(rows, columns, slot + 1)
        ax.imshow(sample, cmap="gray")
        ax.axis("off")
        ax.set_title(f"$x={{{label}}}$", size=25)
        _save_mnist(sample, f"{out_dir}/sample{row}.png", cmap="gray")

        ax = fig.add_subplot(rows, columns, slot + 2)
        ax.imshow(recon, cmap="gray")
        ax.axis("off")
        ax.set_title(r"$\hat{x}$", size=25)
        _save_mnist(recon, f"{out_dir}/reconstruction{row}.png", cmap="gray")

        for j, shap in enumerate(shap_values):
            im = _diverging_rgb(shap, norm)
            ax = fig.add_subplot(rows, columns, slot + j + 3)
            ax.imshow(im)
            ax.axis("off")
            # Braces keep multi-digit indices inside the subscript.
            ax.set_title(f"$l_{{{j}}}$", size=25)
            # The row index keeps each sample's maps from overwriting the previous row's.
            _save_mnist(im, f"{out_dir}/latent{j}-explanation{row}.png")

        shap_sum = shap_values.sum(0)
        im = _diverging_rgb(shap_sum, np.max(np.abs(shap_sum)))
        # Last column of the row, computed directly rather than from the loop variable.
        ax = fig.add_subplot(rows, columns, slot + columns)
        ax.imshow(im)
        ax.axis("off")
        # Latent indices run 0..latent_dim-1, so that is the sum's upper bound.
        ax.set_title(rf"$\sum_{{i=0}}^{{{latent_dim - 1}}} l_i$", size=25)
        _save_mnist(im, f"{out_dir}/latents-sum{row}.png")

    fig.tight_layout()
    fig.savefig(f"{out_dir}/shap-latents.png", bbox_inches="tight")
    plt.close(fig)

    _plot_compute_times(compute_times, algorithm, f"{out_dir}/compute-times.png")




# Render the latent-SHAP figures once per explainer backend.
algorithm: Explainers
for algorithm in (
    "exact",
    "permutation",
    "deep",
    "kernel",
    "gradient",
):
    mk_shap_latent_plot(algorithm)