In [None]:
import os; os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # filter out info and warning messages
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf

In [None]:
def deprocess_img(processed_img):
    '''
    Invert VGG-16 style preprocessing so an image can be displayed.

    The means are added back in BGR order, the channel axis is flipped back
    to RGB, values are clipped to [0, 255], rounded to the nearest integer
    and cast to uint8.
    Arguments:
        processed_img:tensor or array-like
            Preprocessed image whose last axis holds the 3 BGR channels,
            e.g. shape (224, 224, 3). Leading batch axes also work because
            the means broadcast over them. Non-float32 numeric inputs are
            cast to float32 first.
    Returns:
        img:tensor
            RGB image in tf.uint8 format with the same shape as processed_img.
    '''
    # Per-channel ImageNet means in BGR order, shaped to broadcast over (..., H, W, 3).
    imagenet_means = [103.939, 116.779, 123.68]
    means = tf.reshape(tf.constant(imagenet_means, dtype=tf.float32), [1, 1, 3])

    # Cast first so e.g. float64 inputs don't fail on a dtype mismatch with the float32 means.
    img = tf.cast(processed_img, tf.float32) + means
    img = tf.reverse(img, axis=[-1])  # BGR -> RGB
    img = tf.clip_by_value(img, 0.0, 255.0)
    # Round before casting: tf.cast truncates toward zero, so a value like 4.9999995
    # produced by the float32 round trip (p - mean) + mean would otherwise become 4.
    img = tf.cast(tf.round(img), tf.uint8)

    return img

In [None]:
def show_example(img, sal_map, cmap='plasma', alpha=0.7):
    '''
    Display an image, its saliency map, and the map overlaid on the image.

    Three side-by-side subplots are drawn with matplotlib: the deprocessed
    image, the saliency map alone, and the saliency map blended over the image.
    Arguments:
        img:tensor
            Preprocessed image; converted for display with deprocess_img.
        sal_map:tensor
            Saliency map corresponding to the input image. Colours are mapped
            with vmin=0 and vmax=1, so values are presumably normalised to
            [0, 1] -- TODO confirm.
        cmap:str, optional
            Matplotlib colormap used for the saliency map (default 'plasma').
        alpha:float, optional
            Opacity of the saliency map in the overlay panel (default 0.7).
    '''
    # Deprocess once and reuse the result for both image panels.
    rgb_img = deprocess_img(img)

    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(8, 4))

    ax[0].imshow(rgb_img)
    ax[0].set_title('Image')

    ax[1].imshow(sal_map, cmap=cmap, vmin=0, vmax=1)
    ax[1].set_title('Saliency map')

    ax[2].imshow(rgb_img)
    ax[2].imshow(sal_map, cmap=cmap, vmin=0, vmax=1, alpha=alpha, interpolation='bilinear')
    ax[2].set_title('Overlay')

    for axis in ax:
        axis.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Load the SALICON training split (GZIP-compressed, read with tf.data.Dataset.load)
salicon_train_ds_path = './SALICON/tfds_salicon/train2014'
salicon_train_ds = tf.data.Dataset.load(salicon_train_ds_path, compression='GZIP')

# Show the first 3 images of the training split with their saliency maps.
# Each element is a triple; the middle item is not needed for display.
for img, _, sal_map in salicon_train_ds.take(3):
    show_example(img, sal_map)

In [None]:
# Load the SALICON validation split (GZIP-compressed, read with tf.data.Dataset.load)
salicon_val_ds_path = './SALICON/tfds_salicon/val2014'
salicon_val_ds = tf.data.Dataset.load(salicon_val_ds_path, compression='GZIP')

# Show the first 3 images of the validation split with their saliency maps.
# Each element is a triple; the middle item is not needed for display.
for img, _, sal_map in salicon_val_ds.take(3):
    show_example(img, sal_map)

In [None]:
# Load the CapGaze1 dataset (GZIP-compressed, read with tf.data.Dataset.load),
# which we use to test how well the model generalizes beyond SALICON.
capgaze1_ds_path = './capgaze1/tfds_capgaze1'
capgaze1_ds = tf.data.Dataset.load(capgaze1_ds_path, compression='GZIP')

# Show the first 3 images of the CapGaze1 dataset with their saliency maps.
# Each element is a triple; the middle item is not needed for display.
for img, _, sal_map in capgaze1_ds.take(3):
    show_example(img, sal_map)