In [None]:
from VisionEngine.datasets import guppies
from VisionEngine.utils.config import process_config
from VisionEngine.utils import factory
import sys
import os
from PIL import Image
from itertools import product
from dotenv import load_dotenv
from pathlib import Path

import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Show the working directory: the relative `../.env` path loaded below resolves against it.
# Plain Python instead of the `pwd` automagic, which only works when IPython automagic is on.
working_dir = Path.cwd()
working_dir

In [None]:
# Checkpoint file whose weights are loaded into the model further down via `model.load(checkpoint_path)`.
# NOTE(review): absolute, machine-specific path — consider deriving it from a project root or env var.
checkpoint_path = '/home/etheredge/Workspace/VisionEngine/checkpoints/guppy_periodic/2020-213-11/guppy_periodic.hdf5'

In [None]:
# Experiment configuration; `config.model.name` and `config.data_loader.name` choose which
# VisionEngine classes are instantiated below.
# NOTE(review): absolute, machine-specific path — consider making it relative/configurable.
config_file = '/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/guppy_periodic_config.json'
config = process_config(config_file)

In [None]:
# Load the .env file that sits one directory above the current working directory
# into the process environment.
env_path = Path('..', '.env')
load_dotenv(dotenv_path=env_path)

In [None]:
# Build the model: `factory.create` is given the dotted path
# "VisionEngine.models.<config.model.name>" and its result is called with `config`.
model = factory.create(
            "VisionEngine.models."+config.model.name
            )(config)

In [None]:
# Load the checkpoint at `checkpoint_path` into the model.
model.load(checkpoint_path)

In [None]:
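# Optional: uncomment to change the data loader's source flags (use_generated / use_real)
# before the data loader is built from `config` below.
# config.data_loader.use_generated = True
# config.data_loader.use_real = False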

In [None]:
# Build the data loader the same way as the model: `factory.create` is given
# "VisionEngine.data_loaders.<config.data_loader.name>" and its result is called with `config`.
data_loader = factory.create(
            "VisionEngine.data_loaders."+config.data_loader.name
            )(config)

In [None]:
def plot_im(img, last_activation=None):
    """Map a model-space image into the range expected by ``imshow``.

    When the last activation is ``'tanh'``, the image is rescaled with
    ``img * 0.5 + 0.5``. This maps [-1, 1] onto [0, 1], assuming the image is in
    tanh range. Otherwise the image is returned unchanged.

    Args:
        img: Image array/tensor supporting ``*`` and ``+`` with scalars.
        last_activation: Activation name to branch on. Defaults to
            ``config.model.last_activation`` when ``None``.

    Returns:
        The rescaled image for ``'tanh'``, otherwise ``img`` itself.
    """
    if last_activation is None:
        last_activation = config.model.last_activation
    if last_activation == 'tanh':
        # Previously this expression was evaluated and thrown away, so the raw image was returned.
        return img * 0.5 + 0.5
    return img

def plot_img_attributions(image,
                          attribution_mask,
                          H=0,
                          z_i=0,
                          cmap=None,
                          overlay_alpha=0.4):
    """Show a sample, its attribution mask, and the two overlaid, side by side.

    Args:
        image: Image to display (passed straight to ``imshow``).
        attribution_mask: Attribution map drawn with ``cmap``; presumably 2-D
            (height, width) — TODO confirm against caller.
        H: Hierarchy level being traversed (used only in panel titles).
        z_i: Latent component being traversed (used only in panel titles).
        cmap: Colormap for the attribution mask.
        overlay_alpha: Opacity of ``image`` when drawn on top of the mask.

    Returns:
        The matplotlib figure containing the three panels.
    """
    fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(12, 4))

    # The 1x3 grid has columns 0, 1, 2; indexing 1..3 left column 0 empty and
    # raised IndexError on the last panel.
    axs[0, 0].set_title('Original sample Output')
    axs[0, 0].imshow(image)
    axs[0, 0].axis('off')

    axs[0, 1].set_title(f'Attribution mask: {H}, {z_i}')
    axs[0, 1].imshow(attribution_mask, cmap=cmap)
    axs[0, 1].axis('off')

    axs[0, 2].set_title(f'Overlay: {H}, {z_i}')
    axs[0, 2].imshow(attribution_mask, cmap=cmap)
    axs[0, 2].imshow(image, alpha=overlay_alpha)
    axs[0, 2].axis('off')

    plt.tight_layout()
    return fig

def plot_overlay(image,
                 attribution_mask,
                 H=0,
                 z_i=0,
                 cmap=None,
                 overlay_alpha=0.4):
    """Draw ``attribution_mask`` with ``image`` blended over it in a single 4x4 panel.

    Args:
        image: Image drawn on top of the mask with opacity ``overlay_alpha``.
        attribution_mask: Attribution map drawn first, using ``cmap``.
        H: Hierarchy level being traversed (used only in the title).
        z_i: Latent component being traversed (used only in the title).
        cmap: Colormap for the attribution mask.
        overlay_alpha: Opacity of ``image``.

    Returns:
        The created matplotlib figure.
    """
    fig, grid = plt.subplots(nrows=1, ncols=1, squeeze=False, figsize=(4, 4))
    panel = grid[0, 0]
    panel.set_title(f'Overlay: {H}, {z_i}')
    # Mask underneath, image on top so the attribution shows through.
    panel.imshow(attribution_mask, cmap=cmap)
    panel.imshow(image, alpha=overlay_alpha)
    panel.axis('off')
    plt.tight_layout()
    return fig

In [None]:
def embed_images(x):
    """Return the outputs of the four ``normal_variational*`` layers for input ``x``.

    On every call, a Keras sub-model is built that maps ``model.model``'s inputs
    to those layers, in the order listed below. Its ``predict`` result for ``x``
    is returned.
    """
    latent_layer_names = (
        'normal_variational',
        'normal_variational_1',
        'normal_variational_2',
        'normal_variational_3',
    )
    latent_outputs = [model.model.get_layer(name).output for name in latent_layer_names]
    return tf.keras.Model(model.model.inputs, latent_outputs).predict(x)

def reconstruct_images(x):
    """Return the full model's ``predict`` output for ``x`` (used as the reconstructions)."""
    keras_model = model.model
    return keras_model.predict(x)

In [None]:
# Draw one shuffled batch from the training data and freeze it. `shuffle` reshuffles on
# every new iteration by default, so iterating the lazy dataset separately for encoding,
# reconstruction and plotting would give each step a *different* random batch.
train_sample = data_loader.get_train_data().shuffle(1000).take(1)
x = tf.data.Dataset.from_tensors(next(iter(train_sample)))

# Encode the samples: one output per latent layer, converted to a single tensor
# (assumes all latent layers share a shape so they can be stacked — TODO confirm).
z = tf.convert_to_tensor(embed_images(x))

# Reconstruct the very same samples.
x_hat = reconstruct_images(x)

In [None]:
# Shape of the first component (the input batch) of the dataset's single element.
next(iter(x))[0].shape

In [None]:
# Compare one input sample with its reconstruction.
sample_id = 8  # index into the single batch drawn above; must be smaller than the batch size
x_inputs = next(iter(x))[0]
assert 0 <= sample_id < len(x_inputs), (
    f"sample_id={sample_id} is outside the batch of {len(x_inputs)} samples")

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(8, 4))
ax_in.imshow(x_inputs[sample_id])
ax_in.set_title('Input')
ax_in.axis('off')
ax_out.imshow(x_hat[sample_id])
ax_out.set_title('Reconstruction')
ax_out.axis('off');

In [None]:
# Latent to traverse: the hierarchy level (index into the four latent layers used by
# `embed_images`) and the component within that level.
hierarchical_level, encoding_axis = 3, 0

In [None]:
# `z` is already a tensor if the encoding cell above ran (it uses tf.convert_to_tensor),
# so this should be a no-op; it only guards against `z` having been replaced by an array.
z = tf.convert_to_tensor(z)

In [None]:
# Intended to freeze the model's weights: attributions need gradients, not training.
# NOTE(review): the Keras model is accessed elsewhere as `model.model`; confirm that
# setting `trainable` on this outer object reaches it, otherwise this only sets a plain attribute.
model.trainable = False

In [None]:
def interpolate_latentvar(Z, H, z_i, alphas, zdim=10):
    """Build a batch of latent codes that traverse one latent component.

    For every hierarchy level ``h`` in ``Z``, the code ``Z[h]`` is tiled once per
    entry of ``alphas``. At level ``H``, component ``z_i`` of row ``k`` is replaced
    by ``alphas[k]``. All other components and levels are left unchanged.

    Args:
        Z: Sequence (or leading-axis tensor) of per-level latent codes. Each
            ``Z[h]`` is treated as a 1-D vector — TODO confirm against caller.
        H: Index of the hierarchy level to traverse.
        z_i: Non-negative index of the component within level ``H`` to replace.
        alphas: 1-D tensor of traversal values; its length sets the batch size.
        zdim: Unused; kept for backward compatibility.

    Returns:
        List with one tensor per level, each with ``len(alphas)`` rows and the
        same width as ``Z[h]``.
    """
    # Use the actual batch length: the last tf.data batch can be shorter than batch_size.
    n_steps = tf.shape(alphas)[0]
    mods = []
    for h in range(len(Z)):
        tiled = tf.repeat([Z[h]], n_steps, axis=0)
        if h == H:
            # Keep the components before and after z_i and slot the alpha column in between.
            tiled = tf.concat([
                tiled[:, :z_i],
                tf.cast(alphas[:, tf.newaxis], tiled.dtype),
                tiled[:, z_i + 1:],
            ], axis=1)
        mods.append(tiled)
    return mods

def compute_gradients(latent_vars):
    """Decode a batch of latent codes and return d tanh(logits) / d logits.

    Args:
        latent_vars: Sequence of per-level latent batches; elements 0-3 are
            passed to ``model.decoder`` as a list, in that order.

    Returns:
        Gradient of ``tanh(logits)`` (summed over all outputs, as tape.gradient
        does for non-scalar targets) with respect to ``logits``, i.e. the same
        shape as the decoder output.
    """
    # NOTE(review): the tape watches `latent_vars`, but the gradient is taken w.r.t.
    # `logits`, so the result is just the elementwise tanh derivative (1 - tanh(logits)**2)
    # and never flows back to the latent codes. If attributions w.r.t. the traversed latent
    # were intended, the source should be `latent_vars` (or a per-pixel Jacobian) — TODO confirm.
    # NOTE(review): if `model.decoder` already ends in a tanh, this applies it a second time — verify.
    with tf.GradientTape() as tape:
        tape.watch(latent_vars)
        logits = model.decoder([latent_vars[0], latent_vars[1], latent_vars[2], latent_vars[3]])
        images = tf.nn.tanh(logits)
    return tape.gradient(images, logits)


def integral_approximation(gradients):
    """Average ``gradients`` along axis 0 with the trapezoidal rule.

    Each adjacent pair of path points contributes the mean of its two gradients.
    Those per-interval values are then averaged over the whole path.
    """
    # riemann_trapezoidal
    left_edges = gradients[:-1]
    right_edges = gradients[1:]
    trapezoids = (left_edges + right_edges) / tf.constant(2.0)
    return tf.math.reduce_mean(trapezoids, axis=0)


# @tf.function
def integrated_gradients(encoding, H=0, z_i=0, m_steps=300, batch_size=10, lim=1.):
    """Integrate decoder gradients along a traversal of one latent component.

    The path works as follows:
    1. Component ``z_i`` of hierarchy level ``H`` in ``encoding`` is swept over
       ``m_steps`` evenly spaced values in ``[0, lim]``, decoded ``batch_size``
       points at a time.
    2. The gradients from ``compute_gradients`` are collected for the whole path.
    3. They are averaged with the trapezoidal rule and scaled by ``|lim|``.

    Args:
        encoding: Per-level latent codes for one sample, as accepted by
            ``interpolate_latentvar``.
        H: Hierarchy level to traverse.
        z_i: Latent component within level ``H`` to traverse.
        m_steps: Number of points on the traversal path (must be >= 2).
        batch_size: Number of path points decoded per batch.
        lim: End point of the traversal; the start point is 0.

    Returns:
        Tensor shaped like one element of ``compute_gradients``' output (the
        leading path axis is averaged away).

    Raises:
        ValueError: If ``m_steps`` < 2 (the trapezoidal rule needs at least one interval).
    """
    if m_steps < 2:
        raise ValueError(f"m_steps must be >= 2, got {m_steps}")

    # Generate traversal steps
    traversal_steps = tf.linspace(start=0.0, stop=lim, num=m_steps)

    # Collect gradients for the whole path before integrating. Averaging each batch
    # separately and summing would drop the trapezoids that span batch boundaries and
    # inflate the result by the number of batches.
    gradient_batches = []
    for alphas in tf.data.Dataset.from_tensor_slices(traversal_steps).batch(batch_size):
        latent_batch = interpolate_latentvar(Z=encoding, H=H, z_i=z_i, alphas=alphas)
        gradient_batches.append(compute_gradients(latent_batch))

    path_gradients = tf.concat(gradient_batches, axis=0)
    return tf.abs(lim) * integral_approximation(gradients=path_gradients)

In [None]:
# Debug: print the shape of each item in `attributions` (or the item itself if it has no shape).
# This cell sits above the one that computes `attributions`, so on a fresh kernel it is
# skipped instead of raising NameError.
if 'attributions' in globals():
    for att in attributions:
        print(getattr(att, 'shape', att))

In [None]:
# Compute attributions by integrating the gradients along the latent traversal.
attributions = integrated_gradients(z[:, sample_id, :], H=hierarchical_level, z_i=encoding_axis)

# Sum absolute attributions over the last axis (presumably colour channels) to get
# one value per pixel for display.
attributions_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)

# `plt.cm.viridis` — pyplot has no `cmap` attribute and the colormap is spelled "viridis".
_ = plot_img_attributions(image=x_hat[sample_id],
                          attribution_mask=attributions_mask,
                          H=hierarchical_level,
                          z_i=encoding_axis,
                          cmap=plt.cm.viridis,
                          overlay_alpha=0.4)
