# Imports

In [None]:
# Reload imported modules whose source has changed before each cell runs
# (handy while editing the local dataset / XAI packages).
%reload_ext autoreload
%autoreload 2

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import json, pickle, array, datetime, pickle, os

from statistics import mean, stdev

import tensorflow as tf
from tensorflow.keras.models import load_model
from dataset.dataset_padchest import *

from deap import base, creator, tools

from collections import Counter, defaultdict
import seaborn as sns

import numpy as np

import matplotlib

from XAI.evaluation import *
from XAI.pareto_front import *
from XAI.image_utils import *
from XAI.genetic_algorithm import *

In [None]:
lat_dim = 500  # latent width of the trained model; kept as a module-level name for reference


@tf.function()
def sampling(args):
    """Reparameterisation step: z = z_mean + exp(z_log_sigma) * epsilon.

    Registered as the ``'sampling'`` custom object when the saved Keras models are loaded.

    Args:
        args: pair ``(z_mean, z_log_sigma)`` of tensors with the same shape
            (presumably ``(batch, lat_dim)`` — confirm against the encoder).

    Returns:
        A tensor with the shape of ``z_mean``.

    Note:
        epsilon uses stddev=0.1 rather than a unit Gaussian; presumably this must match
        how the models were trained, so do not change it independently.
    """
    z_mean, z_log_sigma = args
    # Draw noise with z_mean's own dynamic shape so the layer does not silently depend
    # on the global lat_dim matching the loaded model's latent width.
    epsilon = tf.random.normal(shape=tf.shape(z_mean), mean=0., stddev=0.1)
    return z_mean + tf.math.exp(z_log_sigma) * epsilon

# Load configuration and data

In [None]:
# Experiment configuration (dataset settings, label-name mapping, ...).
with open('config.json') as config_file:
    config = json.load(config_file)

In [None]:
# The pickled Dataset is the one this evaluation uses. Only fall back to building
# Dataset(config) (potentially expensive) when no pickle exists, instead of always
# building one and immediately discarding it.
# NOTE: pickle.load can execute arbitrary code — only load a dataset.pkl you created.
DATASET_PKL = 'dataset.pkl'

if os.path.exists(DATASET_PKL):
    with open(DATASET_PKL, 'rb') as f:
        data = pickle.load(f)
else:
    print(f"WARNING: {DATASET_PKL} not found - building Dataset(config); "
          "its contents may differ from the pickled dataset.")
    data = Dataset(config)

In [None]:
# Label names ordered by their value in config's label_names mapping (presumably the
# class index), so LABELS[k] names the k-th classifier output as indexed below.
LABELS = [name for name, _ in sorted(config["padchest"]["label_names"].items(),
                                     key=lambda item: item[1])]

# Load models

In [None]:
# Which trained run to evaluate; all of its artefacts live under MODEL_PATH.
# NOTE(review): absolute path at the filesystem root — confirm this is intended.
MODEL_VER = 'v2_MOCVAE'
MODEL_PATH = f'/results_padchest_{MODEL_VER}'

In [None]:
def _load_with_sampling(relative_path):
    """Load the Keras model at ``MODEL_PATH + relative_path`` with ``sampling`` registered."""
    return load_model(MODEL_PATH + relative_path, custom_objects={'sampling': sampling})


autoencoder = _load_with_sampling('/models/e_best_autoencoder.h5')
encoder = _load_with_sampling('/models/e_best_encoder.h5')
decoder = _load_with_sampling('/models/e_best_decoder.h5')
classifier = _load_with_sampling('/models/e_best_classifier.h5')

# Genetic optimization over every (base label, objective label) pair

In [None]:
# `datetime` is rebound by the star imports in the imports cell, so a bare
# `datetime.now()` only works by accident; use an explicit alias instead.
from datetime import datetime as DateTime

# --- Experiment configuration ---------------------------------------------------
n_tests = 10                                     # base images per label, to reduce variability
out_base_path = 'XAI_evaluation_undersampling/'
IMG_SHAPE = (224, 224)                           # spatial size images are reshaped to
MIN_BASE_CONFIDENCE = 0.8                        # base case must score > 80% on its own label
POP_SIZE = 80                                    # SPEA2 initial population size
NUM_GENS = 250                                   # SPEA2 generations
SEL_FACTOR_POP = 80
SEL_FACTOR_ARCH = 40


def _new_pair_record():
    """Return an empty result record for one (base label, objective label) pair."""
    return {
        'means': {label: [] for label in LABELS},
        'stds': {label: [] for label in LABELS},
        'histories': [],
        'population': [],
        'img_ids': [],
    }


def _find_base_case(label_base, start_id):
    """Return ``(img_id, latent_code)`` for the first training image at index >= start_id
    whose ground-truth label is ``label_base`` and whose latent code the classifier
    assigns to ``label_base`` with confidence > MIN_BASE_CONFIDENCE.

    Raises:
        ValueError: if no such image remains in the training set.
    """
    base_idx = LABELS.index(label_base)
    for candidate in range(start_id, len(data.y_train)):
        if LABELS[np.argmax(data.y_train[candidate])] != label_base:
            continue
        input_img = data.X_train[candidate].reshape(1, *IMG_SHAPE, 1)
        # Encoder output [2] is the sampled z used as latent space throughout; copy it so
        # the stored code does not share memory with the predict() result.
        latent_code = np.copy(encoder.predict(input_img)[2])
        conf = classifier.predict(latent_code)[0][base_idx]
        if conf > MIN_BASE_CONFIDENCE:
            return candidate, latent_code
    raise ValueError(
        f"No training image from index {start_id} on is labelled {label_base!r} "
        f"with confidence > {MIN_BASE_CONFIDENCE}")


def _plot_min_fitness(subdir, fitness_idx, fitnesses, individuals, img_id, stem):
    """Plot the individual with the smallest ``fitnesses[j][fitness_idx]``, saving to
    ``out_base_path + 'MOCVAE/' + subdir + '/' + stem`` (the folder is created first)."""
    out_dir = out_base_path + 'MOCVAE/' + subdir + '/'
    os.makedirs(out_dir, exist_ok=True)
    best = min(range(len(fitnesses)), key=lambda j: fitnesses[j][fitness_idx])
    plot_ind_changes(data, encoder, decoder, classifier,
                     img_id, individuals[best], LABELS, out_dir + stem)


def _run_objective(label_base, label_obj, i, img_id, base_latent):
    """Run SPEA2 to move the base case towards ``label_obj`` and record the results.

    Appends confidence-change statistics, the optimisation history, the final population
    and ``img_id`` to ``full_dict[label_base][label_obj]``. It then plots the
    objective-reaching individuals selected for the min_num / min_mag output folders.

    ``base_latent`` is the code that passed the confidence filter in ``_find_base_case``
    and is reused for every objective. Re-encoding would draw a fresh sampled z for each
    objective, so the optimised starting point would neither be the filtered one nor be
    shared across objectives.
    """
    print("\n\nLabel base:", label_base, '- Label objective:', label_obj)
    print("Iteration", i, "img_id:", img_id)
    print(DateTime.now())

    # Fresh copy per objective, in case the DEAP helpers modify the base individual in place.
    latent_code = np.copy(base_latent)
    _, toolbox = deap_configuration(latent_code)  # `_` avoids shadowing the deap `creator` import

    # Genetic optimization
    pop = toolbox.population(n=POP_SIZE)
    final_set, history = spea2(base_ind=latent_code, classifier=classifier,
                               label_obj=label_obj, lab_list=LABELS,
                               pop=pop, toolbox=toolbox, num_gens=NUM_GENS,
                               sel_factor_pop=SEL_FACTOR_POP, sel_factor_arch=SEL_FACTOR_ARCH,
                               mut_prob=MUTPB, mutrevpb=MUTREVPB, indrevpb=INDREVPB)

    record = full_dict[label_base][label_obj]

    # Confidence-change statistics over individuals that misclassify (any class change).
    class_ch = get_class_changes(latent_code, classifier, final_set, LABELS)
    misclassifying = [ind for ind, changed in zip(final_set, class_ch) if changed == 1]
    if misclassifying:
        changes_dict = get_conf_changes_dict(data, encoder, classifier, label_obj, LABELS,
                                             misclassifying)
        for label, value in changes_dict.items():
            record['means'][label].append(np.mean(value['confidence_ch']))
            record['stds'][label].append(np.std(value['confidence_ch']))

    record['histories'].append(history)
    record['population'].append(final_set)
    record['img_ids'].append(img_id)

    # Individuals whose change reaches the objective label specifically.
    class_chs = get_class_changes_obj(latent_code, classifier, final_set, label_obj, LABELS)
    fitnesses = evaluate_pop(final_set, latent_code, classifier, label_obj, LABELS)
    hit_ids = [j for j, changed in enumerate(class_chs) if changed == 1]
    if not hit_ids:
        # min() over an empty candidate list would otherwise abort the whole sweep.
        print("No individual reached", repr(label_obj), "- skipping plots")
        return
    fitnesses_hit = [list(fitnesses[j][1:]) for j in hit_ids]
    individuals_hit = [final_set[j] for j in hit_ids]

    stem = label_base + '-' + label_obj + '_' + str(i)
    # Columns 1 / 0 of the trimmed fitness (i.e. fitness[2] / fitness[1]). Per the folder
    # names these are the number / magnitude of changes — confirm in evaluate_pop.
    _plot_min_fitness('min_num', 1, fitnesses_hit, individuals_hit, img_id, stem)
    _plot_min_fitness('min_mag', 0, fitnesses_hit, individuals_hit, img_id, stem)


full_dict = dict()

for label_base in LABELS:
    os.makedirs(out_base_path + 'Input_Imgs/' + label_base, exist_ok=True)
    full_dict[label_base] = {label_obj: _new_pair_record() for label_obj in LABELS}

    img_id = 0
    # n_tests distinct base images per label to reduce variability.
    for i in range(n_tests):
        img_id, base_latent = _find_base_case(label_base, img_id)

        # Save the input image
        img = data.X_train[img_id].reshape(IMG_SHAPE)
        matplotlib.image.imsave(out_base_path + 'Input_Imgs/' + label_base + '/'
                                + label_base + '_' + str(i) + '.png', img, cmap='gray')

        for label_obj in LABELS:
            if label_obj != label_base:
                _run_objective(label_base, label_obj, i, img_id, base_latent)

        img_id += 1  # the next test searches after this image



Label base: cardiomegaly - Label objective: aortic elongation
Iteration 0 img_id: 7
2024-05-08 07:51:53.346143


In [None]:
# Persist the per-pair statistics, histories, final populations and image ids.
# NOTE(review): `population` holds individuals produced by the DEAP toolbox; unpickling
# them presumably requires the same DEAP creator classes to be registered first — confirm.
results_pkl = "full_dict_undersampling_" + MODEL_VER + ".pkl"
with open(results_pkl, 'wb') as results_file:
    pickle.dump(full_dict, results_file)