In [None]:
from VisionEngine.utils.config import process_config
from VisionEngine.utils import factory
from VisionEngine.utils.plotting import plot_img_attributions

from VisionEngine.utils.eval import embed_images

from VisionEngine.extensions.feature_attribution import calculate_gradients

import os
from dotenv import load_dotenv
from pathlib import Path

import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Restrict PDF output to the 14 standard PDF core fonts, so exported figures
# reference fonts every PDF viewer provides instead of embedding them.
plt.rcParams.update({'pdf.use14corefonts': True})

In [None]:
# Index of the GPU used by every `tf.device` block below.
# On a single-GPU machine the only valid index is 0, so keep it at 0 there.
GPU: int = 0

In [None]:
# Load environment variables (such as VISIONENGINE_HOME, read in the next cell)
# from the .env file one directory above the current working directory
# (normally the directory containing this notebook).
env_path = Path('..', '.env')
load_dotenv(dotenv_path=env_path)

In [None]:
# DHRL-Trained Guppies (Original Data): locate the trained checkpoint and the
# matching config under the VisionEngine installation directory.
visionengine_home = os.getenv("VISIONENGINE_HOME")
if not visionengine_home:
    # Without this check, os.path.join(None, ...) fails with an opaque TypeError.
    raise RuntimeError(
        "VISIONENGINE_HOME is not set; export it in your shell or add it to "
        "the .env file loaded above."
    )

checkpoint_path = os.path.join(
    visionengine_home, "checkpoints", "guppies_DHRL_model.hdf5"
)
config_file = os.path.join(
    visionengine_home, "VisionEngine", "configs", "guppies_DHRL_config.json"
)

# Fail here with a clear message rather than later inside
# process_config / model.load.
missing_files = [
    path for path in (checkpoint_path, config_file) if not os.path.isfile(path)
]
if missing_files:
    raise FileNotFoundError(f"Missing VisionEngine file(s): {missing_files}")

In [None]:
config = process_config(config_file)

# Evaluation-time overrides of the training config: read only real
# (not generated) images, in a fixed, unshuffled order.
loader_cfg = config.data_loader
loader_cfg.shuffle = False
loader_cfg.use_generated = False
loader_cfg.use_real = True

In [None]:
# Look up the data loader named in the config via the factory and
# instantiate it with the (overridden) config.
loader_module = "VisionEngine.data_loaders." + config.data_loader.name
data_loader = factory.create(loader_module)(config)

In [None]:
# Build the model named in the config on the selected GPU, restore it from the
# checkpoint, and mark it non-trainable (no training happens in this notebook).
model_module = "VisionEngine.models." + config.model.name
with tf.device(f'/device:GPU:{GPU}'):
    model = factory.create(model_module)(config)
    model.load(checkpoint_path)
    model.trainable = False

In [None]:
# Embed the test data with the loaded model; the resulting tensor `z` is the
# input to the attribution step below.
with tf.device(f'/device:GPU:{GPU}'):
    test_data = data_loader.get_test_data()
    z = tf.convert_to_tensor(embed_images(test_data, model))

In [None]:
# Which latent code to explain:
#   hierarchical_level -> latent level index (passed as H when plotting)
#   encoding_axis      -> dimension within that level (passed as z_i when plotting)
#   sample_id          -> which embedded sample to explain
#                         (presumably an index into the test data — confirm)
hierarchical_level, encoding_axis, sample_id = 0, 3, 20

In [None]:
# Compute attributions for one latent dimension of one sample on the GPU.
with tf.device(f'/device:GPU:{GPU}'):
    attributions = calculate_gradients(
        z,
        model,
        sample_id,
        encoding_axis,
        hierarchical_level,
    )

In [None]:
# As used here, the first three entries of `attributions` are the attribution
# mask, the input image, and its reconstruction; name them for readability.
attribution_mask = attributions[0]
input_image = attributions[1]
recon_image = attributions[2]

_ = plot_img_attributions(
    image=input_image,
    recon_img=recon_image,
    attribution_mask=attribution_mask,
    H=hierarchical_level,
    z_i=encoding_axis,
    cmap=plt.cm.jet,
    overlay_alpha=0.5,
)