# Reproduce Figures and Tables

In [None]:
import warnings
from reproduce import reproduce_results

# -- Regenerate the figures and tables by running the project's
#    `reproduce_results` entry point.
# -- Catching pandas slice warnings.
# NOTE(review): with `record=True`, `catch_warnings` redirects every warning
# that would normally be displayed inside this block into a list that is
# never read. So ALL such warnings are hidden here, not only the pandas
# slice warnings. Warnings configured as errors still raise. If other warnings
# should stay visible, narrow this with
# `warnings.simplefilter("ignore", category=...)` for the specific category.
with warnings.catch_warnings(record=True):
    reproduce_results()

# Reproduce XAI Figures

In [None]:
import os.path
import json
import matplotlib.pyplot as plt
import numpy as np
import cv2
import yaml
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from glob import glob
from PIL import Image
# -- internal
from utils import get_model, get_image_data_list, get_augmentations, XAI_model_wrapper, MAPPINGS 

# -- Path collection (relative to the notebook's working directory)
# YAML config for the XAI run; its 'checkpoints' entry lists the checkpoint UUIDs.
XAI_CONFIG: str = 'config/xai_config.yaml'
# Local checkpoint store: one sub-folder per checkpoint UUID holding
# metadata.json and state_dict.pth.
LOCAL_CHECKPOINTS: str = 'res/checkpoints/'
# Root folder under which the generated XAI figures are written.
RESULT_PATH: str = 'plots/xai/'

In [None]:
# -- Helpers for the XAI figure generation below ------------------------------


def _resize_to_image(heatmap, image):
    """Bilinearly resize a 2-D map to the pixel size of the PIL ``image``.

    ``PIL.Image.size`` is ``(width, height)``, which is the ``dsize`` order
    ``cv2.resize`` expects. The result therefore has shape ``(height, width)``
    and can be overlaid on, or stacked with, ``np.array(image)``.
    """
    return cv2.resize(heatmap, dsize=image.size, interpolation=cv2.INTER_LINEAR)


def _grad_cam(model, model_name, input_tensor):
    """Compute the Grad-CAM map for the first sample of ``input_tensor``.

    Returns ``(cam, pred)``:
    - ``cam`` is the channel-weighted sum of the feature maps captured by
      ``XAI_model_wrapper`` (float64, same spatial orientation as the maps).
    - ``pred`` is the arg-max class index of the first sample's output.

    NOTE(review): standard Grad-CAM applies a ReLU to this sum. This pipeline
    does not, and that is kept unchanged.
    """
    xai_model = XAI_model_wrapper(base_model=model, model_name=model_name).eval()
    out = xai_model(input_tensor)[0]
    pred = int(np.argmax(out.detach().cpu().numpy()))
    out[pred].backward()  # gradients of the predicted-class score
    fma = xai_model.fma[0].detach().cpu().numpy()
    fmg = xai_model.fmg[0][0].detach().cpu().numpy()
    fmg_weights = xai_model.get_fmg_weights()[0].detach().cpu().numpy()

    n_channels = fmg.shape[0]
    # sum_c w_c * A_c. Computed in float64 like the former accumulator loop,
    # but the result keeps fma's own (H, W) shape, so non-square maps also work.
    cam = np.tensordot(
        fmg_weights[:n_channels, 0, 0].astype(np.float64),
        fma[:n_channels].astype(np.float64),
        axes=1,
    )
    return cam, pred


def _guided_relu_hook(module, grad_in, grad_out):
    """Backward hook: let only positive gradients flow back through a ReLU."""
    return (F.relu(grad_in[0]),)


def _guided_backprop(model, model_name, input_tensor):
    """Compute the Guided-Backprop saliency map for the first sample.

    Returns ``(saliency, pred)``:
    - ``saliency`` is the per-pixel maximum over colour channels of the
      absolute input gradient.
    - ``pred`` is the arg-max class index.

    The ReLU hooks are removed again before returning. The wrapper
    (presumably) exposes the shared ``model``'s submodules, so hooks left in
    place would pile up one set per image. They would also stay active during
    the Grad-CAM pass of every later image.
    """
    xai_model = XAI_model_wrapper(base_model=model, model_name=model_name).eval()
    # Only ReLU modules are hooked. On other modules the hook was a no-op
    # (it returned None), so the guided gradient is unchanged.
    handles = [
        module.register_backward_hook(_guided_relu_hook)
        for module in xai_model.modules()
        if isinstance(module, nn.ReLU)
    ]
    try:
        guided_input = input_tensor.detach().clone().requires_grad_(True)
        out = xai_model(guided_input)[0]
        pred = int(np.argmax(out.detach().cpu().numpy()))
        gradient = t.autograd.grad(out[pred], guided_input, allow_unused=True)[0][0]
    finally:
        for handle in handles:
            handle.remove()

    abs_gradient = np.abs(gradient.detach().cpu().numpy().transpose(1, 2, 0))
    return abs_gradient.max(axis=2), pred


def _transparency_alpha(cam):
    """Turn a Grad-CAM map into per-pixel alpha values in [0, 1].

    The map is min-max normalised first. Pixels above 60 % of the maximum are
    pushed to fully opaque, and pixels below the mean are pushed to fully
    transparent. A flat map (zero range) yields all-zero alpha instead of
    NaNs from a 0/0 division.
    """
    span = np.ptp(cam)
    if span == 0:
        norm = np.zeros_like(cam)
    else:
        norm = (cam - cam.min()) / span
    strength_add = (norm > 0.6 * norm.max()) * 1.0
    weaken_sub = (norm < norm.mean()) * 1.0
    return np.clip(norm + strength_add - weaken_sub, 0.0, 1.0)


# -- Generate Grad-CAM / Guided-Backprop figures for every configured checkpoint

with open(XAI_CONFIG, "r") as file:
    xai_data_config = yaml.load(file, Loader=yaml.FullLoader)

checkpoint_uuid_list = xai_data_config['checkpoints']  # all on downstream kather for 1-Layer-MLP

for checkpoint_uuid in checkpoint_uuid_list:
    # NOTE: Determined checkpoint fetch replaced for docker container
    checkpoint_path = Path(LOCAL_CHECKPOINTS, checkpoint_uuid)
    with open(Path(checkpoint_path, 'metadata.json'), "r") as metadata_file:
        mf_tmp = json.load(metadata_file)
    checkpoint_configuration = mf_tmp['experiment_config']
    checkpoint_hparam = mf_tmp['hparams']

    # -- get the augmentation that was used for testing in the downstream task
    augmentations = get_augmentations(checkpoint_hparam)

    # -- load checkpoint state dict and init model
    state_dict = t.load(Path(checkpoint_path, 'state_dict.pth'), map_location="cpu")['models_state_dict']
    model = get_model(checkpoint_uuid, checkpoint_hparam, state_dict)

    try:
        model_name = checkpoint_hparam['model_name']
    except KeyError:
        model_name = checkpoint_hparam['encoder']

    # -- getting test images to explain
    xai_data_list, classes, dataset_name = get_image_data_list(
        xai_config=xai_data_config, dataset=checkpoint_hparam['dataset'])

    # -- figure label and freeze sub-folder depend only on the checkpoint
    if checkpoint_hparam['checkpoint_uuid']:
        general_label = (checkpoint_hparam['method'] + '_'
                         + MAPPINGS['encoder_dataset'][checkpoint_hparam['checkpoint_uuid']] + '_')
    else:
        general_label = 'supervised_'
    general_label += MAPPINGS['dataset'][checkpoint_hparam['dataset']] + '_' + checkpoint_hparam['pred_head_structure']
    freeze_label = MAPPINGS['freeze'][checkpoint_hparam['freeze_encoder']]

    # -- do XAI
    for image_path in xai_data_list:
        test_image = Image.open(image_path)
        test_img_tensor = augmentations(test_image).unsqueeze(0)
        # NOTE(review): an all-ones second sample is stacked onto the batch, and
        # only sample 0 is explained. ones_like keeps it the same shape as the
        # real sample; the former fixed 1x3x96x96 tensor broke vstack for any
        # other input size.
        test_tensor = t.vstack([test_img_tensor, t.ones_like(test_img_tensor)])

        # -- Grad-CAM, resized to the image so it can be overlaid / used as alpha
        Grad_CAM_map, _ = _grad_cam(model, model_name, test_tensor)
        Grad_CAM_map = _resize_to_image(Grad_CAM_map, test_image)

        # -- Guided Backprop
        saliency_map, pred = _guided_backprop(model, model_name, test_tensor)
        Guided_Backprop_map = _resize_to_image(saliency_map, test_image)

        # -- Create plots
        fig = plt.figure(figsize=(18, 4))
        # -- Plain test image
        plt.subplot(151)
        plt.title('Original Image:' + image_path.stem, size=12)
        plt.xlabel('Prediction: ' + classes[pred], size=14)
        plt.imshow(test_image)
        # -- Grad-CAM heatmap overlay
        plt.subplot(152)
        plt.title('Grad CAM map', size=14)
        plt.imshow(test_image, alpha=0.5)
        plt.imshow(Grad_CAM_map, cmap="jet", alpha=0.5)

        # -- Image made transparent where the Grad-CAM map is weak
        alphas_ = _transparency_alpha(Grad_CAM_map)
        test_image_trans = np.concatenate((np.array(test_image) / 255, alphas_[:, :, np.newaxis]), axis=2)
        plt.subplot(153)
        plt.title('Grad CAM Transparency', size=14)
        plt.imshow(test_image_trans)

        # -- Guided-Backprop map
        plt.subplot(154)
        plt.title('Guided Backprop map', size=14)
        plt.imshow(test_image, alpha=0.5)
        plt.imshow(Guided_Backprop_map, cmap="jet", alpha=0.5)

        # -- Guided Grad-CAM map (element-wise product of both maps)
        Guided_Grad_CAM_map = Grad_CAM_map * Guided_Backprop_map
        plt.subplot(155)
        plt.title('Guided Grad CAM map', size=14)
        plt.imshow(test_image, alpha=0.6)
        plt.imshow(Guided_Grad_CAM_map, cmap="jet", alpha=0.6)

        # -- Name and persist
        fig.suptitle(general_label, size=16)
        image_folder = Path(RESULT_PATH, dataset_name, image_path.stem, freeze_label)
        image_folder.mkdir(parents=True, exist_ok=True)
        fig.savefig(image_folder / (general_label + '.png'), dpi=100)
        plt.close(fig)
        