# Computing heatmaps for saliency methods

In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib
import matplotlib.pyplot as plt
plt.rc('image', cmap='Purples')

import numpy as np
from keras.utils import np_utils
from keras.backend.tensorflow_backend import set_session, clear_session

from scripts.analyzers import run_interpretation_methods
from scripts.models import create_model_llr, train_model

import pickle as pkl

import warnings

import os

In [None]:
# Restrict TensorFlow's v1-compat logger to ERROR-level messages so that
# the training cells below do not flood the notebook output.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
# Record the TensorFlow version used for this run (reproducibility).
print(tf.__version__)

In [None]:
# Record the Keras version used for this run.
# NOTE(review): the name `keras` is bound by `from tensorflow import keras`,
# so this reports the tf.keras version. `np_utils` and `clear_session` are
# imported from the standalone `keras` package, whose version is not printed
# here -- confirm this is the version you want to record.
print(keras.__version__)

In [None]:
# Record the NumPy version used for this run (reproducibility).
print(np.__version__)

In [None]:
# Record the Matplotlib version used for this run (reproducibility).
print(matplotlib.__version__)

In [None]:
# Suppress every Python warning for the rest of the session. This also hides
# deprecation notices, so re-enable warnings when debugging.
warnings.filterwarnings("ignore") 

In [None]:
# Directory that receives the pickled outputs; handed to dump_results() as
# `output_dir` in the final cell.
result_dir = '../results/saliency_methods'

## File path

In [None]:
# Pickled input data set. The `pattern_type_<n>` part of the file name is
# parsed by extract_pattern_type() to label the output files.
file_path = '../data/data_vary_signal_exact_2021-04-27-21-29-44_pattern_type_5.pkl'

## Load data

In [None]:
def test_saved_data(data_path):
    """Load and return the object pickled at ``data_path``.

    Despite the ``test_`` prefix this is a plain loader, not a test: it
    opens the file in binary mode, unpickles its contents and returns them
    unchanged.

    Args:
        data_path: Path of the pickle file to read.

    Returns:
        Whatever object was stored in the pickle file.
    """
    # NOTE: unpickling can execute arbitrary code -- only use this on
    # data files you produced yourself or otherwise trust.
    with open(data_path, 'rb') as handle:
        return pkl.load(handle)

In [None]:
# Keys of the parameter settings that are summarised after training; they are
# used to index `acc_dict`, so they are expected to match the keys of the
# loaded data dictionary.
keys = [
    '0.00_0.50_0.50',
    '0.02_0.49_0.49',
    '0.04_0.48_0.48',
    '0.06_0.47_0.47',
    '0.08_0.46_0.46',
]
# Alternative key set with a larger step between settings (kept for reference):
# keys = ['0.00_0.50_0.50', '0.04_0.48_0.48', '0.08_0.46_0.46', '0.12_0.44_0.44', '0.16_0.42_0.42']

In [None]:
# Each entry is (method name, keyword arguments for that method). The full
# list is passed to run_interpretation_methods() in the training loop.
methods_params = [
    ('gradient', {}),
    ('deep_taylor', {}),
    ('lrp.z', {}),
    ('lrp.alpha_beta', {'alpha' : 2, 'beta' : 1}),
    ('pattern.net', {}),
    ('pattern.attribution', {}),
    ('input_t_gradient', {}),
]
# Method names only, in the same order as `methods_params`.
methods = [name for name, _ in methods_params]
print(methods)

In [None]:
# Experiment configuration read by the cells below.
params = {
    'input_dim' : 64,       # passed to create_model_llr
    'output_dim' : 2,       # passed to create_model_llr
    'regularizer' : None,   # passed to create_model_llr
    'epochs' : 200,         # passed to train_model
    'runs' : 100,           # not read in the visible cells
    'save_data' : True,     # whether the final cell writes the results to disk
}

In [None]:
# Load the pickled data set. The training loop below iterates it as a mapping
# from a weight key to a list of runs, each run holding 'train' and 'val'.
data = test_saved_data(data_path=file_path)

## Define helper functions

In [None]:
def generate_empty_results_dict():
    """Create a fresh, empty container for the experiment results.

    Returns:
        A new dict with two entries: ``'results'`` (an empty dict) and
        ``'method_names'`` (an empty list). Every call returns new objects,
        so callers can mutate them freely.
    """
    empty = {}
    empty['results'] = {}
    empty['method_names'] = []
    return empty

In [None]:
def dump_results(output_dir : str, results: dict, suffix: str) -> None:
    """Pickle ``results`` to ``<output_dir>/results_<suffix>.pkl``.

    The output directory is created first if it does not exist, so a missing
    results folder cannot cause the write to fail after a long training run.

    Args:
        output_dir: Directory to write into; an empty string means the
            current working directory.
        results: Object to serialise with pickle.
        suffix: Text inserted into the file name after ``results_``.

    Returns:
        None. The chosen output path is printed before writing.
    """
    # os.makedirs('') raises, so only create a directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, f'results_{suffix}.pkl')
    print(f'Output path: {output_path}')
    with open(output_path, 'wb') as f:
        pkl.dump(results, f)

## 100 runs for all five parameter combinations

In [None]:
# Start from an empty results container and record which saliency methods
# are being evaluated. The per-weight results are filled in by the training
# loop.
results = generate_empty_results_dict()
results['method_names'] = methods

In [None]:
# Train one model per data run and compute its saliency heatmaps.
#
# For every weight key in `data`:
#   * results['results'][weights] receives one dict per run holding the
#     trained model weights ('model') and the heatmaps ('explanations');
#   * acc_dict[weights] receives the last training / validation accuracy of
#     every run under 'acc' / 'val_acc'.
#
# Variable names are kept as they are because later notebook cells may read
# them from the global namespace.
acc_dict = dict()
for weights, data_list in data.items():
    print(f'Weight: {weights}')

    results_per_weight = list()
    acc_per_weight = list()
    val_acc_per_weight = list()

    for data_run in data_list:
        # Reset the Keras backend session before building the next model.
        clear_session()

        data_train = data_run['train']
        data_val = data_run['val']

        # One-hot encode the labels. The number of classes follows
        # params['output_dim'] so the labels always match the model output.
        X_train = data_train['x']
        y_train_bin = data_train['y']
        y_train = np_utils.to_categorical(y_train_bin, num_classes = params['output_dim'])

        X_val = data_val['x']
        y_val_bin = data_val['y']
        y_val = np_utils.to_categorical(y_val_bin, num_classes = params['output_dim'])

        model = create_model_llr(output_dim = params['output_dim'], activation = 'softmax', regularizer = params['regularizer'], input_dim = params['input_dim'])
        model_trained, acc, val_acc = train_model(model, X_train, y_train, X_val, y_val, epochs = params['epochs'], verbose = False)
        model_weights = model_trained.get_weights()

        # Heatmaps are computed on the validation inputs; the training inputs
        # are passed as `X_train_blob` (presumably for the pattern-based
        # methods -- confirm in scripts.analyzers).
        heatmaps = run_interpretation_methods(model_trained, methods = methods_params, data = X_val, X_train_blob = X_train, normalize = False)

        output = {
            'model': model_weights,  # TODO write function to load model + weights
            'explanations': heatmaps,
        }

        results_per_weight.append(output)
        # Keep only the last entry of each accuracy history (presumably the
        # final epoch -- confirm against train_model).
        acc_per_weight.append(acc[-1])
        val_acc_per_weight.append(val_acc[-1])

    results['results'][weights] = results_per_weight
    acc_dict[weights] = {'acc' : acc_per_weight, 'val_acc' : val_acc_per_weight}


In [None]:
# Sanity check: number of training accuracies stored for the first weight
# setting. The loop appends one value per data run, so this is the number of
# runs that were processed.
print(len(acc_dict['0.00_0.50_0.50']['acc']))


In [None]:
# Summarise each setting listed in `keys`. The printed value is the mean over
# all runs of the stored validation accuracy, although the label only says
# "Final accuracy".
for key in keys:
    print(f'Final accuracy for {key}: {np.mean(acc_dict[key]["val_acc"]):.2f}')


In [None]:
def extract_pattern_type(data_path: str) -> str:
    """Return the text following ``pattern_type_`` in the file name.

    Only the file-name stem (directories and extension removed) is examined,
    so the result does not depend on how many dots the path contains. This
    makes relative, absolute and ``..``-prefixed paths all behave the same.

    Args:
        data_path: Path to a data file whose name contains
            ``pattern_type_<id>``, e.g. ``'../data/foo_pattern_type_5.pkl'``.

    Returns:
        The part of the stem after the last ``pattern_type_`` marker
        (``'5'`` in the example). If the marker is absent, the whole stem is
        returned.
    """
    stem = os.path.splitext(os.path.basename(data_path))[0]
    return stem.split('pattern_type_')[-1]

In [None]:
# Write the heatmaps and the accuracies to disk when saving is enabled. Both
# output files are tagged with the pattern type taken from the input file name.
if params['save_data']:
    pattern_type = f'pattern_type_{extract_pattern_type(data_path=file_path)}'
    outputs = (
        (results, f'heatmapping_methods_{pattern_type}'),
        (acc_dict, f'accuracies_{pattern_type}'),
    )
    for payload, suffix in outputs:
        dump_results(output_dir=result_dir, results=payload, suffix=suffix)


