In [None]:
import warnings
warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf
import numpy as np
import dianna
import onnx
from onnx_tf.backend import prepare
import matplotlib.pyplot as plt
from pathlib import Path
import onnxruntime
from scipy.special import softmax
import os

from matplotlib import cm
from sklearn.preprocessing import StandardScaler
# import random
import torch
from scipy.stats import wasserstein_distance, kendalltau, pearsonr
from functools import partial
from torch.nn.functional import mse_loss
import seaborn as sns

In [2]:
def run_model(data, name):
    """Run an ONNX model on a batch and return softmax-normalised scores.

    Args:
        data: value fed to the session's first input.
        name: model file, resolved relative to the ``models`` directory. An
            absolute path is used as-is, because ``pathlib`` discards earlier
            parts when joined with an absolute one.

    Returns:
        ``scipy.special.softmax`` of the session's first output along axis 1.
    """
    session = onnxruntime.InferenceSession(str(Path('models', name)))
    first_input = session.get_inputs()[0]
    first_output = session.get_outputs()[0]

    outputs = session.run([first_output.name], {first_input.name: data})
    return softmax(outputs[0], axis=1)

def prepare_data(data_path, n_samples=100, image_size=64):
    """Load a shapes dataset ``.npz`` and prepare it for the ONNX models.

    The image scaling depends on the file name:
    - names containing ``color`` or ``rotation``: cast to float32 and divided by 255;
    - names containing ``roundedness``: cast to float32, not rescaled;
    - anything else: images are left in their stored dtype.

    Args:
        data_path: path to the ``.npz`` file (``str`` or ``pathlib.Path``).
        n_samples: number of leading images/labels to keep (default 100).
        image_size: side length used to reshape each image to
            ``(1, image_size, image_size)`` (default 64).

    Returns:
        dict with every array from the file. ``images`` is truncated and reshaped
        to ``(-1, 1, image_size, image_size)``, and ``labels`` is truncated and
        cast to ``int``. Other arrays are returned unmodified and NOT truncated.
    """
    file_name = Path(data_path).name
    # Materialise all arrays, then close the archive. A bare np.load(...) on an
    # .npz keeps the file handle open.
    with np.load(data_path) as archive:
        data = dict(archive)

    if 'color' in file_name or 'rotation' in file_name:
        data['images'] = np.array(data['images'], dtype='float32') / 255.
    elif 'roundedness' in file_name:
        data['images'] = np.array(data['images'], dtype='float32')

    data['images'] = data['images'][:n_samples].reshape(-1, 1, image_size, image_size)
    data['labels'] = np.array(data['labels'], dtype=int)[:n_samples]
    return data

def pca(x, n_components=2):
    """Project samples onto their leading principal components.

    Each feature is standardised with ``StandardScaler`` first, so the
    components are those of the standardised features.

    Args:
        x: array-like of shape (n_samples, n_features).
        n_components: number of leading components to keep (default 2).

    Returns:
        Real-valued ndarray of shape (n_samples, n_components), ordered by
        decreasing explained variance. The sign of each column is arbitrary,
        as with any eigen-decomposition.

    Raises:
        ValueError: if more components are requested than the data supports.
    """
    X_std = StandardScaler().fit_transform(np.asarray(x, dtype=np.float64))

    # The standardised data is already centred, so its right singular vectors
    # are the eigenvectors of its covariance matrix, ordered by decreasing
    # eigenvalue (= singular value**2 / (n - 1)). Using the SVD avoids building
    # the n_features x n_features covariance matrix (4096 x 4096 for 64x64
    # images). Unlike np.linalg.eig on that rank-deficient matrix, the SVD always
    # returns real components.
    _, _, vt = np.linalg.svd(X_std, full_matrices=False)
    if n_components > vt.shape[0]:
        raise ValueError(
            f"n_components={n_components} exceeds the {vt.shape[0]} components available"
        )

    return X_std.dot(vt[:n_components].T)

def minmax(values):
    """Linearly rescale ``values`` so the minimum maps to 0 and the maximum to ~1.

    Works for anything with ``min()``/``max()`` and arithmetic, e.g. NumPy arrays
    and torch tensors. The 10e-12 term in the denominator avoids division by
    zero, so constant input maps to all zeros.
    """
    low = values.min()
    span = values.max() - low
    return (values - low) / (span + 10e-12)

# Paint every pixel with the (e.g. SHAP) value of the segment it belongs to.
def fill_segmentation(values, segmentation):
    """Map per-segment values back onto the pixel grid.

    Pixels whose label equals ``k`` receive ``values[k]`` for
    ``k in range(len(values))``. Pixels with any other label stay 0.

    Returns:
        float ndarray with the same shape as ``segmentation``.
    """
    filled = np.zeros(segmentation.shape)
    n_segments = len(values)
    for label in range(n_segments):
        filled[np.equal(segmentation, label)] = values[label]
    return filled

In [None]:
# TODO: Change this variable according to the path to the repository
# (or set the DIANNA_EXPLORATION_REPO environment variable instead of editing this cell).
repo_path = os.environ.get("DIANNA_EXPLORATION_REPO", "/home/your_username/dianna-exploration")

# Dataset evaluated for each perturbation task.
TASK_DATA_PATHS = {
    "color": f'{repo_path}/example_data/dataset_preparation/geometric_shapes/test_colors.npz',
    "rotation": f'{repo_path}/example_data/dataset_preparation/geometric_shapes/test_rotation.npz',
    "roundedness": f'{repo_path}/lorentz_workshop/dataset_roundedness.npz',
}
# ONNX file (inside lorentz_workshop/) for each model variant.
MODEL_FILES = {
    "original": "geometric_shapes_model.onnx",
    "pytorch": "pytorch_model.onnx",
}
# Explainers _explain_instance knows how to run ("shap" is still a TODO).
SUPPORTED_EXPLAINERS = ("rise", "lime", "shap")


def _explain_instance(model_path, test_sample, label, explainer):
    """Explain one image with a dianna method and min-max normalise the result.

    Returns the first relevance map dianna produces for ``label``, rescaled so
    its minimum is 0 and its maximum is just below 1. The +10e-7 guards against
    constant maps.
    """
    if explainer == "lime":
        # NOTE(review): n_masks / feature_res / p_keep look like RISE parameters;
        # confirm dianna's LIME actually uses them. Kept unchanged.
        relevances = dianna.explain_image(model_path, test_sample,
                                          method="LIME", labels=[label], nsamples=2000,
                                          n_masks=1000, feature_res=12, p_keep=0.7,
                                          axis_labels=('channels', 'height', 'width'))
        grayscale_hm = relevances[0]
    elif explainer == "rise":
        relevances = dianna.explain_image(model_path, test_sample,
                                          method="RISE", labels=[label], n_masks=10000,
                                          feature_res=8, p_keep=0.8,
                                          axis_labels=('channels', 'height', 'width'))
        grayscale_hm = relevances[0]
    elif explainer == "shap":
        # NOTE: explains class 0 rather than `label`, and the returned SHAP values
        # are not yet mapped back onto pixels (presumably via fill_segmentation).
        shap_values, segments_slic = dianna.explain_image(model_path, test_sample,
                                                          method="KernelSHAP", labels=[0], nsamples=2000,
                                                          n_segments=300, sigma=0.2,
                                                          axis_labels=('channels', 'height', 'width'))
        grayscale_hm = shap_values  # TODO: This is not correct!
    else:
        raise ValueError(f"Wrong explainer: {explainer}")

    return (grayscale_hm - grayscale_hm.min()) / (grayscale_hm.max() - grayscale_hm.min() + 10e-7)


for model_name in ["pytorch", "original"]:
    for explainer in ["rise", "lime"]:  # shap
        for task_name in ["color", "rotation", "roundedness"]:
            OUT_PATH = f"{repo_path}/lorentz_workshop/explainer_results/{model_name}_{explainer}_{task_name}_heatmaps.npz"
            if os.path.exists(OUT_PATH):
                print(f"{OUT_PATH} already exists!")
                continue

            # Validate the whole configuration before doing any work. A bad
            # explainer used to `break` the instance loop and still save an empty
            # OUT_PATH, which the `exists` check above then treated as finished.
            if task_name not in TASK_DATA_PATHS:
                print("Wrong task name, aborting: ", task_name)
                continue
            if model_name not in MODEL_FILES:
                print("Wrong model name, aborting: ", model_name)
                continue
            if explainer not in SUPPORTED_EXPLAINERS:
                print("Wrong explainer: ", explainer)
                continue

            print("Generating", OUT_PATH)

            # load dataset
            DATA_PATH = TASK_DATA_PATHS[task_name]
            data = prepare_data(DATA_PATH)

            MODEL_PATH = Path(f'{repo_path}/lorentz_workshop', MODEL_FILES[model_name])

            # Predict with onnx model; the predicted class is what gets explained.
            pred_onnx = run_model(data['images'], MODEL_PATH)
            pred_ids = pred_onnx.argmax(axis=1)

            print("Accuracy:", sum(pred_ids == data['labels']) / len(data['labels']), "\n")

            # GENERATE HEATMAPS AND SAVE
            hms = []
            for i_instance in range(len(pred_ids)):
                # astype returns a fresh float32 copy, so dianna cannot modify `data`.
                test_sample = data['images'][i_instance].astype(np.float32)
                grayscale_hm = _explain_instance(MODEL_PATH, test_sample, pred_ids[i_instance], explainer)
                hms.append(grayscale_hm.tolist())

            hms = np.array(hms)
            os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
            np.savez(OUT_PATH, heatmaps=hms,
                     color=data['color'],
                     rotation=data['rotation'],
                     roundedness=data['roundedness'],
                     data_path=DATA_PATH)


In [None]:
statistic_results = []

RESULTS_DIR = f"{repo_path}/lorentz_workshop/explainer_results"

# Dataset used for each perturbation task (same files as in the heatmap-generation cell).
TASK_DATA_PATHS = {
    "color": f'{repo_path}/example_data/dataset_preparation/geometric_shapes/test_colors.npz',
    "rotation": f'{repo_path}/example_data/dataset_preparation/geometric_shapes/test_rotation.npz',
    "roundedness": f'{repo_path}/lorentz_workshop/dataset_roundedness.npz',
}

# Distance between two flattened heatmaps.
DISTANCE_FUNCTIONS = {
    "wasserstein": wasserstein_distance,
    "mse_loss": mse_loss,
}

# Normalisation applied to each flattened heatmap before measuring distances.
# softmax gets an explicit dim=0 because the inputs are 1-D. Omitting `dim` relies
# on PyTorch's deprecated implicit-dimension choice, which also picks dim 0 for
# 1-D input (so results are unchanged) and emits a warning on every call.
TRANSFORM_FUNCTIONS = {
    "minmax": minmax,
    "softmax": partial(torch.nn.functional.softmax, dim=0),
}


def _save_dim_red_plot(images, heatmaps, out_file):
    """Save side-by-side 2-D PCA scatter plots of the images (left) and heatmaps (right)."""
    embedding_img = pca(images.reshape(len(images), -1))
    embedding_xai = pca(heatmaps.reshape(len(heatmaps), -1))

    fig, ax = plt.subplots(1, 2)
    for axis, embedding in zip(ax, (embedding_img, embedding_xai)):
        # Colour points by sample index so both panels can be compared point by
        # point. plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
        colors = plt.get_cmap('viridis', len(embedding))(np.arange(len(embedding)))
        axis.scatter(embedding[:, 0], embedding[:, 1], c=colors, alpha=0.5)
    fig.savefig(out_file)
    plt.close(fig)


def _heatmap_changes(heatmaps, metric, transform):
    """Distances between consecutive heatmaps, and between each heatmap and the first.

    Returns two arrays of length ``len(heatmaps) - 1``. Entry ``i`` compares
    heatmap ``i + 1`` with heatmap ``i`` (first array) and with heatmap 0
    (second array).
    """
    n_steps = len(heatmaps) - 1
    consecutive = np.zeros(n_steps)
    from_first = np.zeros(n_steps)
    with torch.no_grad():
        first = transform(heatmaps[0].flatten())
        previous = first
        for i in range(n_steps):
            current = transform(heatmaps[i + 1].flatten())
            consecutive[i] = metric(current, previous)
            from_first[i] = metric(current, first)
            previous = current
    return consecutive, from_first


def _save_lineplot(x, y, xlabel, ylabel, out_file):
    """Save a seaborn line plot of ``y`` against ``x`` to ``out_file``."""
    fig, ax = plt.subplots()
    sns.lineplot(x=x, y=y, ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.savefig(out_file)
    plt.close(fig)


# Now that we have all the heatmaps, we can create all the plots
for model_name in ["pytorch", "original"]:
    for explainer in ["rise", "lime"]:  # shap
        for task_name in ["color", "rotation", "roundedness"]:
            prefix = f"{RESULTS_DIR}/{model_name}_{explainer}_{task_name}"
            OUT_PATH = f"{prefix}_heatmaps.npz"
            print("Working on", OUT_PATH)
            if not os.path.exists(OUT_PATH):
                print(f"{OUT_PATH} not found, skipping. Run the heatmap-generation cell first.")
                continue

            # load dataset
            DATA_PATH = TASK_DATA_PATHS[task_name]
            data = prepare_data(DATA_PATH)

            # Load the heatmaps. Only the plain float 'heatmaps' array is read, so
            # pickle support is not needed; `with` closes the archive afterwards.
            with np.load(OUT_PATH) as output:
                heatmaps_np = output['heatmaps']

            # DIMENSION REDUCTION PLOTS. Taken from dim_red_plots.py
            _save_dim_red_plot(data['images'], heatmaps_np, f"{prefix}_dim_red_plots.png")
            print("Finished dimension reduction plots.")

            # QI'S PLOTS. Taken from gradcam/read.ipynb
            # One task value per heatmap (prepare_data keeps only the first images).
            x_values = data[task_name][:len(heatmaps_np)]
            heatmaps = torch.tensor(heatmaps_np)

            for distance_function, metric in DISTANCE_FUNCTIONS.items():
                for transform_function, transform in TRANSFORM_FUNCTIONS.items():
                    changes, changes_original = _heatmap_changes(heatmaps, metric, transform)

                    ylabel = f"DF = {distance_function}, TF = {transform_function}"
                    suffix = f"{distance_function}_{transform_function}.png"
                    _save_lineplot(x_values[1:], changes, "Consecutive changes", ylabel,
                                   f"{prefix}_consecutive_changes_{suffix}")
                    _save_lineplot(x_values[1:], changes_original, "Changes wrt to the original input", ylabel,
                                   f"{prefix}_original_changes_{suffix}")

                    # Rank (Kendall) and linear (Pearson) association between the
                    # task value and the change w.r.t. the first heatmap. NaN
                    # distances are dropped.
                    valid = ~np.isnan(changes_original)
                    res_kendall = kendalltau(x_values[1:][valid], changes_original[valid])
                    res_pearson = pearsonr(x_values[1:][valid], changes_original[valid])
                    statistic_results.append({"model": model_name, "explainer": explainer, "change": task_name, "df": distance_function,
                                              "tf": transform_function, "kendall_pval": res_kendall.pvalue, "kendall_val": res_kendall.statistic,
                                              "pearson_pval": res_pearson.pvalue, "pearson_val": res_pearson.statistic})

            print("Finished Qi's plots.")


# Write statistic_results to a html table
import pandas as pd
df = pd.DataFrame(statistic_results)
df.index = df.apply(lambda row: f"{row['model']}_{row['explainer']}_{row['change']}_{row['df']}_{row['tf']}", axis=1)
df = df[['kendall_pval', 'kendall_val', 'pearson_pval', 'pearson_val']]
df.to_html(f"{RESULTS_DIR}/original_changes_statistics.html")

In [68]:
# For exporting the pytorch model to onnx. No need to run this if you already have pytorch_model.onnx

import torch.nn as nn
import torch.nn.functional as F

class ShapesNet(nn.Module):
    """Two-layer CNN with max pooling, sized for 1-channel 64x64 images.

    Args:
        kernels: output channels of the first and second convolution. The list
            is copied, so instances never alias the shared default list.
        dropout: NOTE(review): currently unused. Both ``nn.Dropout()`` layers are
            built without an argument, i.e. with PyTorch's default p=0.5. Wiring
            this through would change how the network trains, so it is left as-is.
        classes: number of output classes.
    """

    def __init__(self, kernels=[8, 16], dropout=0.2, classes=2):
        super().__init__()
        # Copy: storing the default list itself would let one instance's
        # mutation leak into every later ShapesNet() call.
        self.kernels = list(kernels)
        # 1st layer: 64x64 -> 32x32
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, self.kernels[0], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout()
        )
        # 2nd layer: 32x32 -> 16x16
        self.layer2 = nn.Sequential(
            nn.Conv2d(self.kernels[0], self.kernels[1], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout()
        )
        self.fc1 = nn.Linear(16 * 16 * self.kernels[-1], self.kernels[-1])  # 64 px / (2 * 2 max pooling) = 16
        self.fc2 = nn.Linear(self.kernels[-1], classes)

    def forward(self, x, mode='train'):
        """Return log-probabilities if ``mode == 'train'``, otherwise probabilities.

        ``x`` presumably has shape (N, 1, 64, 64), given the conv/linear sizes.
        Both outputs are normalised over dim 1 (the class dimension).
        """
        x = self.layer2(self.layer1(x))
        x = x.reshape(x.size(0), -1)
        # NOTE: there is no activation between fc1 and fc2. Adding one would
        # change what existing checkpoints compute.
        x = self.fc2(self.fc1(x))

        if mode == 'train':
            return F.log_softmax(x, dim=1)
        return F.softmax(x, dim=1)

# Rebuild the network, load the retrained weights and export it to ONNX.
model = ShapesNet().to('cpu')
CHECKPOINT_PATH = f'{repo_path}/lorentz_workshop/gradcam/retrain_geometric_shapes_model.pt'
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a CPU-only one.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# set the model to inference mode (dropout layers become no-ops)
model.eval()


class _ProbabilityExport(nn.Module):
    """Wrap a ShapesNet so the exported graph runs it with ``mode='eval'``.

    ``torch.onnx.export`` calls ``forward(dummy_input)`` with default arguments,
    i.e. ``mode='train'``, which would bake ``log_softmax`` into the ONNX graph.
    The notebook hands the ONNX file straight to the dianna explainers, which
    presumably treat its output as class probabilities, so export ``softmax``
    instead. NOTE: a pytorch_model.onnx exported by the previous version of this
    cell outputs log-probabilities; re-export it and regenerate its heatmaps.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        return self.net(x, mode='eval')


export_model = _ProbabilityExport(model).eval()

# Dummy input that fixes the traced shape; the batch axis is made dynamic below.
dummy_input = torch.randn(1, 1, 64, 64)

# Export the model
torch.onnx.export(export_model,        # model being run
         dummy_input,                  # model input (or a tuple for multiple inputs)
         f"{repo_path}/lorentz_workshop/pytorch_model.onnx",   # where to save the model
         export_params=True,           # store the trained parameter weights inside the model file
         input_names=['modelInput'],   # the model's input names
         output_names=['modelOutput'], # the model's output names
         dynamic_axes={'modelInput': {0: 'batch_size'},    # variable length axes
                       'modelOutput': {0: 'batch_size'}})