<a href="https://colab.research.google.com/github/k3larra/IKR/blob/master/ikr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# "When Can I Trust It?" Contextualising Explainability Methods for Classifiers
In experiment 3 for the paper with the above title, we compare the internal knowledge representations of a number of models pretrained on ImageNet-1k. For the comparison, the model-agnostic XAI method Occlusion is used. By doing this we can, from some perspective, compare the internal representations of the models. 

We can then draw the conclusion that the neural networks put emphasis on different areas in the images. This is no surprise, but it raises questions when we are to select models for our explanations: which model is the most trustworthy, and from what perspective? 

Analogously, in a human setting: how do we choose the best expert?



# Setup

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Test set, models and label 

In [2]:
#Testset
%%capture
from zipfile import ZipFile
! git clone https://github.com/k3larra/IKR
with ZipFile('/content/IKR/testset/testset.zip', 'r') as archive:
  archive.extractall('/content/testset')

In [3]:
#ImageNet1k labels 
%%capture
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
num_classes = len(categories)

def label_to_idx(label):
  """Return the position of `label` in the global `categories` list.

  Raises ValueError (from list.index) when the label is not present.
  """
  position = categories.index(label)
  return position

def idx_to_label(idx):
  """Return the entry of the global `categories` list stored at position `idx`."""
  label = categories[idx]
  return label

In [None]:
# Load the pretrained classifiers that are compared in the experiment.
# Every model is switched to eval mode, tagged with a `name` attribute (used
# later in output paths and as JSON key) and moved to `device`.
import torch
print(torch.__version__)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from torchvision.models import resnet101, ResNet101_Weights
model_resnet101 = resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
model_resnet101.eval()
model_resnet101.name = "ResNet101"
model_resnet101 = model_resnet101.to(device)

from torchvision.models import resnet152, ResNet152_Weights
model_resnet152 = resnet152(weights=ResNet152_Weights.IMAGENET1K_V2)
model_resnet152.eval()
model_resnet152.name = "ResNet152"
model_resnet152 = model_resnet152.to(device)

from torchvision.models import googlenet, GoogLeNet_Weights
model_googlenet = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)
model_googlenet.eval()
model_googlenet.name = "GoogLeNet"
model_googlenet = model_googlenet.to(device)

from torchvision.models import inception_v3, Inception_V3_Weights
model_inception_v3 = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
model_inception_v3.eval()
model_inception_v3.name = "Inception_V3"
model_inception_v3 = model_inception_v3.to(device)

from torchvision.models import efficientnet_v2_s,EfficientNet_V2_S_Weights
model_efficientnet_v2_s = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
model_efficientnet_v2_s.eval()
model_efficientnet_v2_s.name = "Efficientnet_V2_s"
model_efficientnet_v2_s = model_efficientnet_v2_s.to(device)

from torchvision.models import regnet_y_8gf,RegNet_Y_8GF_Weights
model_regnet_y_8gf = regnet_y_8gf(weights=RegNet_Y_8GF_Weights.IMAGENET1K_V2)
model_regnet_y_8gf.eval()
model_regnet_y_8gf.name = "RegNet_Y_8GF"
model_regnet_y_8gf = model_regnet_y_8gf.to(device)

from torchvision.models import swin_t,Swin_T_Weights
model_swin_t = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
model_swin_t.eval()
# NOTE(review): unlike the other names this one ends in "_Weights"; it ends up
# in output folder names and JSON keys - confirm that this is intended.
model_swin_t.name = "Swin_T_Weights"
model_swin_t = model_swin_t.to(device)

from torchvision.models import convnext_tiny,ConvNeXt_Tiny_Weights
model_convnext_tiny = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
model_convnext_tiny.eval()
model_convnext_tiny.name = "ConvNeXt_Tiny"
model_convnext_tiny = model_convnext_tiny.to(device)

# Models and their weight enums, listed in the same order so that position i
# of both lists refers to the same architecture.
experiment_models = [model_resnet101,
                     model_resnet152,
                     model_googlenet,
                     model_inception_v3,
                     model_efficientnet_v2_s,
                     model_regnet_y_8gf,
                     model_swin_t,
                     model_convnext_tiny]
experiment_weights= [ResNet101_Weights.IMAGENET1K_V2,
                     ResNet152_Weights.IMAGENET1K_V2,
                     GoogLeNet_Weights.IMAGENET1K_V1,
                     Inception_V3_Weights.IMAGENET1K_V1,
                     EfficientNet_V2_S_Weights.IMAGENET1K_V1,
                     RegNet_Y_8GF_Weights.IMAGENET1K_V2,
                     Swin_T_Weights.IMAGENET1K_V1,
                     ConvNeXt_Tiny_Weights.IMAGENET1K_V1]

In [5]:
#Data transformation and inference
from torchvision import transforms
from torchvision.io import read_image
import torch.nn.functional as F

# Shared preprocessing for all models: Resize(256), CenterCrop(224) and
# conversion of the PIL image to a tensor.
eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

# Per-channel normalisation applied after `eval_transform`
# (presumably the usual ImageNet mean/std statistics - TODO confirm).
transform_normalize = transforms.Normalize( 
     mean=[0.485, 0.456, 0.406],
     std=[0.229, 0.224, 0.225]
 )

def transform_eval_data(img_path, eval_transform = None):
  """Load an image file as a float tensor.

  Parameters
  --------
  img_path: path of the image file; it is opened with PIL and converted to RGB.
  eval_transform: optional callable applied to the PIL image. When given, its
    result is additionally normalised with the global `transform_normalize`.
    When omitted, the image is only converted to a tensor (no resize, crop or
    normalisation); previously this case failed because a PIL image has no
    `.float()` method.

  Returns
  --------
  The image as a float tensor.
  """
  # The context manager releases the file handle; convert() returns an
  # independent, fully loaded copy that stays valid afterwards.
  with Image.open(img_path) as source:
    image = source.convert('RGB')
  if eval_transform is not None:
    image = transform_normalize(eval_transform(image))
  else:
    image = transforms.functional.to_tensor(image)
  return image.float()

def norm_image(image):
  """Min-max scale a 3-D numpy array to the range [0, 1].

  Minimum and maximum are taken over axes 1 and 2, i.e. separately for every
  index along axis 0. A slice whose values are all equal is mapped to 0
  instead of producing NaN through a division by zero.

  NOTE(review): the callers in this file pass height-width-channel arrays
  (after `permute(1,2,0)`), for which axes (1, 2) are width and channel, so
  the scaling is done per image row - confirm that this is intended.
  """
  data_min = np.min(image, axis=(1,2), keepdims=True)
  data_max = np.max(image, axis=(1,2), keepdims=True)
  value_range = data_max - data_min
  # Avoid 0/0 for constant slices: their numerator is 0 everywhere, so any
  # non-zero divisor yields the desired 0.
  value_range[value_range == 0] = 1
  return (image - data_min) / value_range

def show_image(image, title):
    """Display `image` with matplotlib under the heading `title`, axes hidden."""
    plt.imshow(image)
    plt.axis("off")
    plt.title(title)
    # A very short pause gives the plot a chance to be updated.
    plt.pause(0.001)

def eval_model(experiment_set, test_model, print_eval=False):
  """Run `test_model` on every sample and optionally print its top-5 classes.

  Parameters
  --------
  experiment_set: iterable of image tensors; each one gets a batch dimension
    added and is moved to the global `device`.
  test_model: the classifier; it is switched to eval mode. Its `name`
    attribute is used as plot title.
  print_eval: when True, the sample is shown and the five most probable
    labels are printed with their probability in whole percent.

  Returns nothing.
  """
  test_model.eval()   # Set model to evaluate mode
  with torch.no_grad():
    for experiment_sample in experiment_set:
      input_img = experiment_sample.unsqueeze(0).to(device)
      prediction = test_model(input_img).squeeze(0).softmax(0)
      values, indices = prediction.topk(5)
      if not print_eval:
        continue
      show_image(norm_image(input_img.squeeze().permute(1,2,0).cpu().numpy()), test_model.name)
      for value, index in zip(values.tolist(), indices.tolist()):
        # Round to the nearest percent. The former int(np.round(p,3)*100)
        # truncated, so e.g. 0.29 was printed as 28.
        percent = int(round(value * 100))
        print("Label",index,": class ",  idx_to_label(index)," probability :", str(percent),"%")

def get_all_files(experiment_path):
  """Return the sorted names of the image files directly inside `experiment_path`.

  Only names ending in .PNG, .png, .jpg or .JPG are kept (exact spelling).
  """
  image_suffixes = ('.PNG', '.png', '.jpg', '.JPG')
  return [entry
          for entry in sorted(os.listdir(experiment_path))
          if entry.endswith(image_suffixes)]

def load_experiment_data(experiment_path, test_model, weights, use_predifined_tranformation=False, plot_data=False, evaluate_model=False, print_evaluation=False):
  """Load and preprocess all images found in `experiment_path`.

  Parameters
  --------
  experiment_path: folder that is scanned with `get_all_files`.
  test_model: model handed to `eval_model`; only used when `evaluate_model` is True.
  weights: weights object whose `transforms()` preset is used when
    `use_predifined_tranformation` is True.
  use_predifined_tranformation: True -> read the file with `read_image` and
    apply the preset of `weights`; False -> use `transform_eval_data` with
    the global `eval_transform`.
  plot_data: show up to the first three preprocessed images.
  evaluate_model: run `eval_model` on the loaded images.
  print_evaluation: forwarded to `eval_model` as `print_eval`.

  Returns
  --------
  List with one preprocessed image tensor per file, in sorted file-name order.
  """
  # Build the preset once instead of per image, and under a name that does
  # not shadow the torchvision `transforms` module.
  preset_transform = weights.transforms() if use_predifined_tranformation else None
  experiment_set = []
  for file_name in get_all_files(experiment_path):
    # os.path.join also works when experiment_path has no trailing slash.
    image_path = os.path.join(experiment_path, file_name)
    if preset_transform is not None:
      experiment_set.append(preset_transform(read_image(image_path)))
    else:
      experiment_set.append(transform_eval_data(image_path, eval_transform))
  if plot_data and experiment_set:
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 1
    # Slicing avoids an IndexError when fewer than three images were loaded.
    for position, img in enumerate(experiment_set[:cols * rows], start=1):
        figure.add_subplot(rows, cols, position)
        plt.title(img.shape)
        plt.axis("off")
        plt.imshow(norm_image(img.permute(1,2,0).numpy()))
    plt.show()
  if evaluate_model:
    eval_model(experiment_set, test_model, print_evaluation)
  return experiment_set

# XAI Experiments

## Occlusion
Using the Occlusion implementation from [captum.ai](https://captum.ai)

In [6]:
%%capture
! pip install captum 
from captum.attr import Occlusion
from captum.attr import LayerAttribution
from captum.attr import visualization as viz
def calculate_occlusion(experiment_model, target_idx, input_img, top_candidates, top_probs, save_path=""):
  """Compute and visualise Captum Occlusion attributions for one target class.

  Parameters
  --------
  experiment_model: the classifier; its `name` attribute is printed.
  target_idx: class index (int or single-element tensor) to explain.
  input_img: input tensor with a leading batch dimension; moved to the global `device`.
  top_candidates: class indices ranked by the model, best first.
  top_probs: probabilities belonging to `top_candidates`.
  save_path: file the blended heat map is written to; nothing is saved when empty.

  Returns
  --------
  The attribution tensor returned by `Occlusion.attribute`.
  """
  input_img = input_img.to(device)
  occlusion = Occlusion(experiment_model)
  attributions = occlusion.attribute(input_img,
                                    sliding_window_shapes=(1, 32, 32),
                                    strides=(1, 32, 32),
                                    target=target_idx,
                                    baselines = 0)
  input_img = input_img.squeeze()
  figure, _ = viz.visualize_image_attr(attributions[0].cpu().permute(1,2,0).detach().numpy(),
                                       input_img.cpu().permute(1,2,0).detach().numpy(),
                                       method="blended_heat_map",
                                       sign="all",
                                       fig_size=(6,6))
  target_label = int(target_idx)
  prob = 0
  # One bulk conversion instead of one .item() call per candidate.
  for position, candidate in enumerate(top_candidates.tolist()):
    if candidate == target_label:
      # Round to the nearest percent; int(np.round(p,2)*100) truncated
      # values such as 0.29 to 28.
      prob = str(int(round(top_probs[position].item() * 100)))
      break
  top_percent = str(int(round(top_probs[0].item() * 100)))
  print("Top candidate:",str(idx_to_label(top_candidates[0].item())), "with prob:",top_percent," model:",experiment_model.name)
  print("occlusion with target: ",str(idx_to_label(target_label)), "prob: ", prob,"%")
  if save_path:
    figure.savefig(save_path,bbox_inches='tight', pad_inches = 0)
  # Release the figure explicitly so that figures do not pile up when many
  # maps are produced in one run (a no-op if the backend already closed it).
  plt.close(figure)
  return attributions

# Experiments

In [7]:
def run_experiment(experiment_models, experiment_weights, experiment_path, experiment_name, use_predifined_tranformation=False, save_path="", debug=False):
  """Run every model on every test image and store predictions and saliency maps.

  For each image the original and the transformed image are saved, each model
  predicts its classes, occlusion maps are created for the top candidates and
  mean / standard deviation maps over the models are computed. The collected
  data is written to `structure<idx>.json` per image and to `structure.json`
  for the whole run, all below `save_path + experiment_name` (spaces removed).

  Parameters
  --------
  experiment_models: models to compare; each needs a `name` attribute.
  experiment_weights: weight objects matching `experiment_models`.
  experiment_path: folder with the test images; must end with a path separator.
  experiment_name: name of the run, used for the output folder.
  use_predifined_tranformation: NOTE(review): currently ignored - the images
    are always loaded with the standard `eval_transform`.
  save_path: prefix of the output folder.
  debug: print intermediate data.

  Returns nothing.
  """
  import json  # local import: the notebook imports json only in a later cell

  path_to_experiment = save_path + experiment_name.replace(" ", "") + "/"
  print("Experiment name: ",experiment_name)
  print("Save Path: ", path_to_experiment)
  print("------------------")
  images=get_all_files(experiment_path)
  experiment_set = load_experiment_data(experiment_path,  experiment_models[0], experiment_weights[0],use_predifined_tranformation=False, plot_data=False, evaluate_model = False, print_evaluation=False)
  # Make sure the run folder exists even when no image was found, so that the
  # final structure.json can always be written.
  os.makedirs(path_to_experiment, exist_ok=True)
  nbr_maps=len(experiment_models)
  experiment_json_data={}
  for idx, experiment_sample in enumerate(experiment_set):
      path_to_save = path_to_experiment+"image"+str(idx) + "/"
      os.makedirs(path_to_save, exist_ok=True)
      image_info={}
      image_path_transformed = path_to_save+"imagetransformed"+str(idx)+".PNG"
      image_path = path_to_save+"image"+str(idx)+".PNG"
      # The context manager closes the source file; the former `image.close`
      # lacked the call parentheses and never closed it.
      with Image.open(experiment_path+images[idx]) as image:
        image.save(image_path)
      im = norm_image(experiment_sample.cpu().permute(1,2,0).detach().numpy())
      plt.imsave(image_path_transformed,im)
      image_info["image_path"] = image_path
      image_info["image_path_transformed"] = image_path_transformed
      # One slot per (class, model) holding a flattened 3x7x7 saliency map.
      saliency_map_aggregate=torch.zeros(1000,nbr_maps,147)
      model_input = experiment_sample.unsqueeze(0)
      for model_index,model in enumerate(experiment_models):
        top_candidates, top_probs, jdata = process_input(model_input, model, debug=debug)
        image_info[model.name]=jdata
        saliency_json = create_saliency_maps(model,model_input,top_candidates,top_probs,path_to_save,saliency_map_aggregate,model_index,nbr_maps=nbr_maps,debug=debug)
        image_info[model.name]["xai"]=saliency_json
      image_info["diff_mean_maps"] = calc_show_difference(saliency_map_aggregate,experiment_path,path_to_save,idx,debug=debug)
      experiment_json_data[idx]=image_info
      # Write the per-image result right away so that it is on disk even if a
      # later image fails.
      with open(path_to_experiment + 'structure'+str(idx)+'.json', 'w') as outfile:
        outfile.write(json.dumps(image_info))
  if debug:
    print(json.dumps(experiment_json_data, indent=2))
  with open(path_to_experiment + 'structure.json', 'w') as outfile:
      outfile.write(json.dumps(experiment_json_data))

def process_input(input_img, experiment_model, debug=False):
  """Classify one image and describe the ten most probable classes.

  Parameters
  --------
  input_img: input tensor with a leading batch dimension of one.
  experiment_model: the classifier; moved to the global `device`.
  debug: print the JSON data.

  Returns
  --------
  top_catid: all class indices ordered by descending probability.
  top_prob: the probabilities belonging to `top_catid`.
  jsonData: dict rank -> {"probability", "label", "labelid"} for the first
    ten ranks (fewer if the model has fewer classes).
  """
  input_img = input_img.to(device)
  experiment_model = experiment_model.to(device)
  # Pure inference: no autograd graph is needed, which saves memory.
  with torch.no_grad():
    output = experiment_model(input_img)
    probabilities = F.softmax(output[0], dim=0)
  top_prob, top_catid = torch.topk(probabilities, num_classes)
  jsonData={}
  for rank in range(min(10, top_prob.numel())):
      label_id = top_catid[rank].item()
      jsonData[rank] = {
          "probability": np.round(top_prob[rank].item(),9),
          "label": idx_to_label(label_id),
          "labelid": label_id,
      }
  if debug:
    print(jsonData)
  return top_catid, top_prob, jsonData

def create_saliency_maps(experiment_model,experiment_sample,top_candidates,top_probs,path_to_save,saliency_map_aggregate,model_index,nbr_maps=5,debug=False):
    """Create occlusion saliency maps for the `nbr_maps` highest ranked classes.

    Each map is saved below `<path_to_save><model name>/occlusion/`, its
    metrics are computed and it is stored in `saliency_map_aggregate`.

    Parameters
    --------
    experiment_model: the model to explain; its `name` is part of the save path.
    experiment_sample: the input handed to `calculate_occlusion`.
    top_candidates: class indices ranked by the model, best first.
    top_probs: probabilities belonging to `top_candidates`.
    path_to_save: folder of the current image.
    saliency_map_aggregate: tensor updated through `add_saliency_row`.
    model_index: position of the model, forwarded to `add_saliency_row`.
    nbr_maps: the number of top_candidates saliency maps are created for.
    debug: print intermediate data.

    Returns
    --------
    Dict with the method description, one image path per rank and the
    metrics of every map under "metrix_<rank>".
    """
    occlusion_dir = path_to_save+experiment_model.name+"/occlusion/"
    if not os.path.exists(occlusion_dir):
      os.makedirs(occlusion_dir)
    json_data = {
        "XAI-method": "Occlusion",
        "image_path": occlusion_dir,
        "settings": "sliding_window_shapes=(1, 32, 32),strides=(1, 32, 32)",
        "code_ref": "https://captum.ai/api/occlusion.html",
    }
    for rank in range(nbr_maps):
      candidate = top_candidates[rank]
      # The file is named after the numeric class index.
      map_path = occlusion_dir+str(candidate.item())+".PNG"
      json_data[rank] = map_path
      saliency_map = calculate_occlusion(experiment_model,
                                         target_idx=candidate,
                                         input_img=experiment_sample,
                                         top_candidates=top_candidates,
                                         top_probs=top_probs,
                                         save_path=map_path)
      json_data["metrix_"+str(rank)] = calculate_metrix_for_saliency_map(saliency_map,candidate)
      if debug:
       print("saliency_map.shape",saliency_map.shape)
       print(json.dumps(json_data, indent=2))
      add_saliency_row(saliency_map_aggregate,saliency_map,candidate,model_index,debug=debug)
    return json_data

def calculate_metrix_for_saliency_map(saliency_map,target_idx):
  """Summarise a saliency map on a 7x7 grid.

  The map is resized to 7x7, averaged over dimension 1 (presumably the
  colour channels - TODO confirm) and flattened to 49 values.

  Returns a dict with
    "target_idx": the class index as string,
    "mean", "max", "min": statistics of the values after negative ones were set to 0,
    "mean_values": comma-separated means of the k largest values, k = 1..49,
    "raw_string": the 49 unclipped values, comma-separated.
  All numbers are rounded to 3 decimals.
  """
  def rounded(scalar):
    return np.round(scalar.item(),3)

  grid = transforms.functional.resize(saliency_map,(7,7))
  cells = torch.flatten(torch.mean(grid,1))
  # The raw values are recorded before the negative attributions are removed.
  raw_string = ",".join(str(rounded(cells[i])) for i in range(0,49))
  cells[cells<0]=0
  mean_string = ",".join(str(rounded(torch.mean(torch.topk(cells,k)[0]))) for k in range(1,50))
  return {
      "target_idx": str(target_idx.item()),
      "mean": rounded(torch.mean(cells)),
      "max": rounded(torch.max(cells)),
      "min": rounded(torch.min(cells)),
      "mean_values": mean_string,
      "raw_string": raw_string,
  }
    


## Similarity and dispersion measurement between models.

In [8]:
def add_saliency_row(saliency_map_aggregate,saliency_map,label_index,model_index,debug=False):
  """Store a saliency map, resized to 7x7 and flattened, in the aggregate tensor.

  Parameters
  --------
  saliency_map_aggregate: tensor indexed as [label, model]; modified in place.
  saliency_map: the attribution map to store.
  label_index: class index selecting the first dimension.
  model_index: model position selecting the second dimension.
  debug: print the shape and value range of the stored row.

  Returns
  --------
  The (in place updated) `saliency_map_aggregate`.
  """
  saliency_map = transforms.functional.resize(saliency_map,(7,7))
  saliency_map = torch.flatten(saliency_map)
  saliency_map=saliency_map[None, :]
  saliency_map_aggregate[label_index,model_index]=saliency_map
  if debug:
    # Print the shape as the message says; previously the whole tensor was dumped.
    print("Adding saliency row with shape ",saliency_map.shape," to index:",label_index,":",idx_to_label(int(label_index)))
    print("Min value is:",torch.min(saliency_map)," and max is ",torch.max(saliency_map))
  return saliency_map_aggregate

def _topk_mean_csv(values):
  """Return the means of the k largest entries of `values` for k = 1..len(values) as a comma-separated string (3 decimals)."""
  return ",".join(
      str(np.round(torch.mean(torch.topk(values,k)[0]).item(),3))
      for k in range(1, values.numel() + 1))

def _save_blended_heat_map(small_map, background, sign, target_path):
  """Upsample `small_map` to 224x224, blend it over `background`, save it to `target_path` and release the figure."""
  attributions = LayerAttribution.interpolate(small_map, [224,224])
  figure, _ = viz.visualize_image_attr(attributions[0].cpu().permute(1,2,0).detach().numpy(),
                                       background,
                                       method="blended_heat_map",
                                       sign=sign,
                                       fig_size=(6,6))
  figure.savefig(target_path,bbox_inches='tight', pad_inches = 0)
  # Close explicitly so figures do not pile up over many labels and images.
  plt.close(figure)

def calc_show_difference(saliency_map_aggregate,experiment_path,save_path,image_index,debug=False):
  """Compare the saliency maps that different models produced for the same label.

  For every label for which at least two models stored a map, the mean map
  and the (population) standard deviation map over the models are computed,
  saved as blended heat maps and summarised.

  Parameters
  --------
  saliency_map_aggregate: tensor [label, model, 147]; a row that is entirely
    zero counts as "no map stored".
  experiment_path: folder with the test images.
  save_path: folder the heat maps are written to.
  image_index: position of the image (sorted file order) used as background.
  debug: print intermediate data.

  Returns
  --------
  Dict with image paths and statistics, keyed by "<metric>_<label index>".
  """
  import json  # local import: the notebook imports json only in a later cell

  json_data={}
  # Only the background image is needed, so load that single file instead of
  # re-reading the whole test set for every call.
  image_file = get_all_files(experiment_path)[image_index]
  input_img = transform_eval_data(os.path.join(experiment_path, image_file), eval_transform)
  background = input_img.permute(1,2,0).detach().numpy()

  # A slot is used when any of its values is non-zero. The former test
  # int(torch.sum(row)) != 0 truncated towards zero and therefore dropped
  # every map whose values summed to something between -1 and 1.
  filled = (saliency_map_aggregate != 0).any(dim=2)
  for index in torch.nonzero(filled.any(dim=1)).flatten().tolist():
    label = idx_to_label(index)
    saliency_candidate = saliency_map_aggregate[index][filled[index]]
    if debug:
      print("final",saliency_candidate)
      print("final",saliency_candidate.shape)
      print("saliency_candidate.size(dim=0)",saliency_candidate.size(dim=0))
      print("label_index",index,":",label)
    if saliency_candidate.size(dim=0) < 2:
      if debug:
        print("Only one saliency map found for the label,",index,":",label," no mean or std can be calculated.")
      continue

    # --- mean over the models ---
    mean_saliency_map = torch.mean(saliency_candidate, 0)
    mean_image = torch.reshape(mean_saliency_map, (1,3,7,7))
    save_path_mean = save_path+'mean_image'+str(image_index)+'_candidate'+str(index)+'_occ.PNG'
    _save_blended_heat_map(mean_image, background, "positive", save_path_mean)
    print("Mean saliency map for prediction ",label,":",index)
    json_data["mean_image_path_candidate_"+str(index)]=save_path_mean
    # mean_image already is 7x7, so no further resize is needed. Average over
    # dimension 1 and remove negative attributions before summarising.
    mean_flat = torch.clamp(torch.flatten(torch.mean(mean_image,1)), min=0)
    if debug:
      print("mean saliency map for:",label," with index:",index)
      print("mean:",np.round(torch.mean(mean_flat).item(),3))
      print("mean 5:",np.round(torch.mean(torch.topk(mean_flat,5)[0]).item(),3))
      print("mean 10:",np.round(torch.mean(torch.topk(mean_flat,10)[0]).item(),3))
    json_data["mean_for_candidate_"+str(index)] = np.round(torch.mean(mean_flat).item(),3)
    # Max and min are taken from the unclipped mean map.
    json_data["max_for_candidate_"+str(index)] = np.round(torch.max(mean_image).item(),3)
    json_data["min_for_candidate_"+str(index)] = np.round(torch.min(mean_image).item(),3)
    json_data["mean_average_csv_for_candidate_"+str(index)] = _topk_mean_csv(mean_flat)

    # --- dispersion between the models ---
    # NOTE(review): unbiased=False gives the population standard deviation;
    # the original author questioned whether the sample version fits better.
    std_saliency_map = torch.std(saliency_candidate, 0, unbiased=False)
    std_image=torch.reshape(std_saliency_map, (1,3,7,7))
    savepath_diff = save_path+'diff_image'+str(image_index)+'_candidate'+str(index)+'_occ.PNG'
    # The deviations are negated and drawn with sign="negative".
    _save_blended_heat_map(torch.negative(std_image), background, "negative", savepath_diff)
    print("Mean standard deviation saliency map for prediction ",label,":",index)
    json_data["diff_image_path_candidate_"+str(index)]=savepath_diff
    std_flat = torch.flatten(torch.mean(std_image,1))
    if debug:
      print("mean std:",np.round(torch.mean(std_flat).item(),3))
      print("mean std 5:",np.round(torch.mean(torch.topk(std_flat,5)[0]).item(),3))
      print("mean std 10:",np.round(torch.mean(torch.topk(std_flat,10)[0]).item(),3))
    json_data["mean_std_image_path_candidate_"+str(index)] =  np.round(torch.mean(std_flat).item(),3)
    json_data["std_average_csv_for_candidate_"+str(index)] = _topk_mean_csv(std_flat)
    if debug:
      print(json.dumps(json_data, indent=2))
  return json_data


# Run experiments!

In [None]:
import json
# Folder with the extracted test images. The trailing "/" is required because
# file paths are built by plain string concatenation.
experiment_path = "/content/testset/"
# Name of this run; becomes the sub-folder below save_path.
experiment_name= "version07"
# Root folder for all results (relative to the working directory).
save_path = 'trust_1/'
if not os.path.exists(save_path):
        os.makedirs(save_path)

run_experiment(experiment_models, experiment_weights, experiment_path, experiment_name, use_predifined_tranformation=False, save_path=save_path, debug=False)
# Short description of the run, stored next to the results.
json_data={}
# NOTE(review): the text says "Five models" while experiment_models lists
# eight entries - confirm which one is right.
json_data["description"] = "Five models"
json_data["image tranformation"] ="standard image transformation: resize256 centercrop 224 and imagenet transforms"
with open(save_path+experiment_name+"/description.json", "w") as outfile:
    outfile.write(json.dumps(json_data))

In [None]:
# Pack the result folder /content/trust_1 into /content/<zip_name>.zip.
zip_name = "trust_1"
# NOTE(review): XAI_compress is not used anywhere in the visible code.
XAI_compress = "/"
import shutil
shutil.make_archive("/content/"+zip_name, 'zip', "/content/trust_1")

'/content/trust_1.zip'

In [11]:
# Optional clean-up: remove the result folder, the extracted test set and the
# downloaded label file.
# NOTE(review): testshapes2 is not created anywhere in the visible code.
! rm -r trust_1
! rm -r testset
! rm -r testshapes2
! rm imagenet_classes.txt

rm: cannot remove 'IKR': No such file or directory
rm: cannot remove 'testshapes2': No such file or directory
rm: cannot remove 'imagenet_classes.txt': No such file or directory
