#### Set path

In [None]:
import sys

# Root of the local scaling_mlps checkout. It is added to the module search path so
# that `data_utils`, `models` and `utils` can be imported below. Machine-specific:
# adjust for your own setup.
SCALING_MLPS_ROOT = 'C:/Users/matth/Documents/ETHZ/01_DS/02_HS23/02_DeepLearning/03_Project/00_Testbed_DL/scaling_mlps_mirror'
# Guarded so that re-running this notebook cell does not append duplicate entries.
if SCALING_MLPS_ROOT not in sys.path:
    sys.path.append(SCALING_MLPS_ROOT)

#### Import libraries

In [None]:
# Basic
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image
import time

# Attibution maps
from captum.attr import IntegratedGradients, Saliency, InputXGradient, GuidedBackprop, NoiseTunnel, LRP, DeepLift

# MLP model
from data_utils.data_stats import *
from models.networks import *
from utils.download import *

# CNN model
from torchvision.models import resnet50, ResNet50_Weights

#### Load models

In [None]:
## Load MLP
dataset = 'imagenet'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'        # Other names tried: 'B_6-Wi_512', 'B-12_Wi-1024_res_64_imagenet_epochs_50'
resolution = 64                      # Input resolution of the fine-tuned model (64 for all provided checkpoints)
num_classes = CLASS_DICT[dataset]
# Presumably: pre-trained on ImageNet-21k and fine-tuned on ImageNet (name suffix
# '_imagenet', consistent with `dataset` above) - confirm against the checkpoint list.
checkpoint = 'in21k_imagenet'
model_mlp = get_model(architecture=architecture, resolution=resolution, num_classes=num_classes,
                      checkpoint=checkpoint, load_device='cpu', dropout=False)
model_mlp.eval()

In [None]:
## Load CNN: torchvision ResNet-50 with the IMAGENET1K_V1 weights, frozen, in eval mode
model_cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# Freeze every parameter in one call; the attribution methods below only need
# gradients with respect to the input, not the weights.
model_cnn.requires_grad_(False)
model_cnn.eval()


#### Define plotting functions

In [None]:
def create_attribution_map(model, model_type, input_batch, target_y, algorithm, smooth_bool=False, ax=None,
                           resolution=None, nt_samples=50):
    """Compute a Captum attribution for one input and draw it as a grayscale map.

    Each pixel shows the maximum absolute attribution over the three colour
    channels. The colour scale is clipped to the 1st-99th percentile of those
    values.

    Args:
        model: PyTorch model handed to ``algorithm``.
        model_type: ``"CNN"`` when ``input_batch`` is image-shaped with a
            batch of one and channels first, ``(1, 3, H, W)``. ``"MLP"`` when it
            is the flattened image ``(1, 3 * R * R)``.
        input_batch: Input tensor fed to ``model``.
        target_y: Target class index to attribute.
        algorithm: Captum attribution class whose constructor takes the model
            (e.g. ``Saliency``, ``InputXGradient``).
        smooth_bool: If truthy, wrap the algorithm in ``NoiseTunnel``
            (SmoothGrad). The noise std is one third of the input's std.
        ax: Matplotlib Axes to draw into. ``None`` or the ``pyplot`` module
            means the current Axes.
        resolution: Side length R of the MLP image. ``None`` infers R from the
            number of attribution values (``3 * R * R``). Ignored for ``"CNN"``.
        nt_samples: Number of noisy samples drawn by ``NoiseTunnel``.

    Returns:
        The 2-D array that was plotted.

    Raises:
        ValueError: If ``model_type`` is unknown, or the MLP attribution size
            is not ``3 * resolution ** 2``.
    """
    # `pyplot` itself has no `xaxis`/`spines`, so resolve to a real Axes object.
    if ax is None or ax is plt:
        ax = plt.gca()

    # Compute attributions
    attr_algo = algorithm(model)
    if smooth_bool:
        nt = NoiseTunnel(attr_algo)
        stdev = float(np.std(input_batch.detach().cpu().numpy())) / 3
        # Other nt_type options: 'smoothgrad_sq', 'vargrad'
        attributions = nt.attribute(input_batch, target=target_y, nt_type='smoothgrad',
                                    stdevs=stdev, nt_samples=nt_samples)
    else:
        attributions = attr_algo.attribute(inputs=input_batch, target=target_y)

    attr = attributions.detach().cpu().numpy()

    # Bring the attribution into channels-first (3, H, W) form.
    if model_type == "CNN":
        attr_chw = attr.squeeze(0)
    elif model_type == "MLP":
        if resolution is None:
            resolution = int(round(np.sqrt(attr.size / 3)))
        if attr.size != 3 * resolution * resolution:
            raise ValueError(
                f"MLP attribution has {attr.size} values; expected 3 * {resolution}**2")
        attr_chw = attr.reshape(3, resolution, resolution)
    else:
        raise ValueError(f"model_type must be 'CNN' or 'MLP', got {model_type!r}")

    # Collapse the channels: strongest absolute attribution per pixel.
    saliency = np.abs(attr_chw).max(axis=0)

    # Display the map, clipping outliers via percentiles.
    ax.imshow(saliency, cmap='gray', vmin=np.percentile(saliency, 1), vmax=np.percentile(saliency, 99))

    # Hide tick labels, tick marks and frame.
    ax.tick_params(labelbottom=False, labelleft=False)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    return saliency

In [None]:
def plot_original_image(original_image, ax=None):
    """Draw ``original_image`` without tick labels, tick marks or frame.

    Args:
        original_image: Any image ``Axes.imshow`` accepts (e.g. a PIL image).
        ax: Matplotlib Axes to draw into. ``None`` or the ``pyplot`` module
            means the current Axes.
    """
    # `pyplot` itself has no `xaxis`/`spines`, so resolve to a real Axes object.
    if ax is None or ax is plt:
        ax = plt.gca()

    # Plot the argument itself, not a notebook-global image variable.
    ax.imshow(original_image)

    # Hide tick labels, tick marks and frame.
    ax.tick_params(labelbottom=False, labelleft=False)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

#### Load input

In [None]:
# Directory containing the source images (must end with a path separator if joined by '+')
# Alternative local source (model-vs-human colour dataset):
#path = "C:\\Users\\matth\Documents\\ETHZ\\01_DS\\02_HS23\\02_DeepLearning\\03_Project\\00_Working\\model-vs-human\\datasets\\colour\\dnn\\session-1\\"
path = './source_images/'


## Selection short
filename_arr_short = np.array([
    "0088_cl_dnn_cr_bear_40_n02133161_2202.png",      # 295 American black bear
    "0110_cl_dnn_cr_elephant_40_n02504458_1600.png",  # 386 African elephant
    "0509_cl_dnn_cr_chair_40_n04099969_4543.png",     # 765 rocking chair
    "0079_cl_dnn_cr_clock_40_n04548280_77.png",       # 892 wall clock
])
# ImageNet-1k class index for each file above, in the same order.
# 386 (African elephant) matches WNID n02504458; 385 would be the Indian elephant.
target_y_arr_short = np.array([295, 386, 765, 892])


## Selection long
filename_arr_long = np.array([
    "0088_cl_dnn_cr_bear_40_n02133161_2202.png",      # 295 American black bear
    "0528_cl_dnn_cr_bear_40_n02133161_1362.png",      # 297 sloth bear
    "0110_cl_dnn_cr_elephant_40_n02504458_1600.png",  # 386 African elephant
    "0133_cl_dnn_cr_elephant_40_n02504013_4892.png",  # 385 Indian elephant
    "0447_cl_dnn_cr_chair_40_n03376595_1302.png",     # 559 folding chair
    "0509_cl_dnn_cr_chair_40_n04099969_4543.png",     # 765 rocking chair
    "0279_cl_dnn_cr_clock_40_n04548280_24041.png",    # 892 wall clock
    "0079_cl_dnn_cr_clock_40_n04548280_77.png",       # 892 wall clock
])
target_y_arr_long = np.array([295, 297, 386, 385, 559, 765, 892, 892])
# Previous name of the long label array, kept so existing references still work.
target_y_arr_selection = target_y_arr_long

# Sanity check: every file needs exactly one target label.
assert len(filename_arr_short) == len(target_y_arr_short), "short selection: files/labels mismatch"
assert len(filename_arr_long) == len(target_y_arr_long), "long selection: files/labels mismatch"

In [None]:
## Choose the selection(s) to generate attribution maps for.
# Plain lists rather than np.array, so selections of different lengths can be
# combined (e.g. [filename_arr_short, filename_arr_long]). A ragged np.array
# raises ValueError on NumPy >= 1.24.
filename_arr_arr = [filename_arr_short]
target_y_arr_arr = [target_y_arr_short]

#### Generate attribution maps for the selected images

In [None]:
import os

# Preprocessing is identical for every image, so build the pipelines once.
# MLP: resize to the MLP input resolution, then ImageNet mean/std normalization.
preprocess_mlp = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Resize(size=(resolution, resolution), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# CNN: no resize (image used at its native size), same normalization.
preprocess_cnn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Titles for columns 1-4 (column 0 shows the original image).
column_titles = [
    'Gradient: \n B-12/Wi-1024+DA',
    'Input * Gradient: \n B-12/Wi-1024+DA',
    'Gradient: \n ResNet-50',
    'Input * Gradient: \n ResNet-50',
]
fontsize = 10

# Folder where the figures are saved
output_folder = './output_images'

for sel_idx, (filename_arr, target_y_arr) in enumerate(zip(filename_arr_arr, target_y_arr_arr)):

    # One row per image, five columns. squeeze=False keeps `axs` 2-D even for a
    # single-image selection, so `axs[row, col]` always works.
    imax = filename_arr.shape[0]
    fig, axs = plt.subplots(imax, 5, figsize=(2*5, 2*imax), squeeze=False)
    fig.subplots_adjust(wspace=0.1, hspace=0.1)

    for row, (filename, target_y) in enumerate(zip(filename_arr, target_y_arr)):
        target_y = int(target_y)
        print("Iterator:", row)
        print("Filename:", filename)
        print("Target Y:", target_y)

        # Load the image. The context manager closes the file handle, and
        # convert('RGB') guarantees the 3 channels both Normalize steps expect.
        with Image.open(os.path.join(path, filename)) as img:
            input_image = img.convert('RGB')

        # MLP input: flatten (1, 3, R, R) -> (1, 3*R*R) as a leaf tensor requiring grad.
        input_batch_mlp = preprocess_mlp(input_image).unsqueeze(0)
        input_batch_mlp_reshaped = (
            input_batch_mlp.reshape(input_batch_mlp.shape[0], -1).detach().requires_grad_(True)
        )

        # CNN input: (1, 3, H, W)
        input_batch_cnn = preprocess_cnn(input_image).unsqueeze(0)

        # Columns: original | MLP gradient | MLP input*gradient | CNN gradient | CNN input*gradient
        plot_original_image(original_image=input_image, ax=axs[row, 0])
        create_attribution_map(model=model_mlp, model_type="MLP", input_batch=input_batch_mlp_reshaped, target_y=target_y, algorithm=Saliency, smooth_bool=True, ax=axs[row, 1])
        create_attribution_map(model=model_mlp, model_type="MLP", input_batch=input_batch_mlp_reshaped, target_y=target_y, algorithm=InputXGradient, smooth_bool=True, ax=axs[row, 2])
        create_attribution_map(model=model_cnn, model_type="CNN", input_batch=input_batch_cnn, target_y=target_y, algorithm=Saliency, smooth_bool=True, ax=axs[row, 3])
        create_attribution_map(model=model_cnn, model_type="CNN", input_batch=input_batch_cnn, target_y=target_y, algorithm=InputXGradient, smooth_bool=True, ax=axs[row, 4])

    # Column titles go on the first row only.
    for col, title in enumerate(column_titles, start=1):
        axs[0, col].set_title(title, fontsize=fontsize)

    # Save the figure under a timestamped name. When several selections are
    # processed, the selection index is added so files saved within the same
    # second do not overwrite each other.
    os.makedirs(output_folder, exist_ok=True)
    timestr = time.strftime("%Y%m%d-%H%M%S")
    suffix = f"_{sel_idx}" if len(filename_arr_arr) > 1 else ""
    output_filepath = os.path.join(output_folder, timestr + suffix + '.png')
    fig.savefig(output_filepath)

