In [None]:
from keras.models import load_model
import numpy as np

# Load the trained classifier and its class labels.
# Assumes label i corresponds to output unit i: later cells use the same index
# both as `filter_indices` and to look up `train_classes`.
model = load_model('whale.flukes.4250_classes.weights.best.hdf5')
train_classes = np.load('train_classes.npy')

num_classes = len(train_classes)

# Fail fast if the label file and the model disagree on the number of classes;
# otherwise every whale id printed or drawn below would be silently mislabeled.
assert model.output_shape[-1] == num_classes, (
    f'model has {model.output_shape[-1]} outputs but train_classes has {num_classes} entries'
)

In [None]:
# Print the layer table. Later cells look layers up by name ('dense_3',
# 'block4_conv3'), so those names must appear in this listing.
model.summary()

In [None]:
from vis.utils import utils
from keras import activations

# Find the classification layer by name. Passing -1 would also work,
# provided 'dense_3' really is the last layer.
layer_idx = utils.find_layer_idx(model, 'dense_3')

# Replace that layer's softmax with a linear activation, then let keras-vis
# apply the modification; it hands back the model to use from here on.
setattr(model.layers[layer_idx], 'activation', activations.linear)
model = utils.apply_modifications(model)

In [None]:
from vis.visualization import visualize_activation
from matplotlib import pyplot as plt
from matplotlib import font_manager as font_manager
import random

font_manager._rebuild()

%matplotlib inline

plt.rcParams['figure.figsize'] = (18, 6)

whale_index = random.randint(0, num_classes)

print(f'Whale id: {train_classes[whale_index]}')
img = visualize_activation(model, layer_idx, filter_indices=whale_index)

plt.imshow(img)

In [None]:
# Same class as before, optimized for 500 iterations with verbose output off.
img = visualize_activation(
    model,
    layer_idx,
    verbose=False,
    max_iter=500,
    filter_indices=whale_index,
)
plt.imshow(img)

In [None]:
from vis.input_modifiers import Jitter

# Jitter the input by 16 pixels along all dimensions during the optimization process.
img = visualize_activation(
    model,
    layer_idx,
    input_modifiers=[Jitter(16)],
    max_iter=100,
    filter_indices=whale_index,
)
plt.imshow(img)

In [None]:
import numpy as np
from vis.input_modifiers import Jitter

# How many randomly chosen classes to visualize. Each one runs its own
# 100-iteration optimization, so keep this small (e.g. 15 fills a 5-column
# grid with 3 rows).
n_categories = 1
categories = np.random.permutation(num_classes)[:n_categories]

vis_images = []
image_modifiers = [Jitter(16)]
for idx in categories:
    whale_id = train_classes[idx]

    img = visualize_activation(model, layer_idx, filter_indices=idx,
                               max_iter=100, input_modifiers=image_modifiers)

    # Overlay the whale id for this class index on the generated image.
    img = utils.draw_text(img, whale_id, font="VeraMono.ttf")
    vis_images.append(img)

# Stitch the generated images into a grid with 5 columns
# (ceil(n_categories / 5) rows).
stitched = utils.stitch_images(vis_images, cols=5)
plt.axis('off')
plt.imshow(stitched)
plt.show()

# Visualizing Conv filters

In [None]:
from vis.visualization import get_num_filters

# Re-imported here so this cell also runs on its own.
from vis.input_modifiers import Jitter

# The name of the layer we want to visualize; it must appear in model.summary().
# NOTE(review): 'block4_conv3' looks like a VGG16 layer name — confirm this
# model actually contains it, otherwise find_layer_idx will fail.
layer_name = 'block4_conv3'
# NOTE(review): this rebinds `layer_idx`, which earlier cells use for the
# 'dense_3' output layer. Re-running those cells after this one would target
# this conv layer instead; re-run the 'dense_3' cell first.
layer_idx = utils.find_layer_idx(model, layer_name)

# Visualize all filters in this layer.
filters = np.arange(get_num_filters(model.layers[layer_idx]))

# Generate one input image per filter. verbose=False keeps the optimizer from
# printing its verbose output once for every filter in the layer.
vis_images = []
for idx in filters:
    img = visualize_activation(model, layer_idx, filter_indices=idx,
                               input_modifiers=[Jitter(16)], verbose=False)

    # Overlay the filter number on the image.
    img = utils.draw_text(img, 'Filter {}'.format(idx), font="VeraMono.ttf")
    vis_images.append(img)

# Generate stitched image palette with 8 cols.
stitched = utils.stitch_images(vis_images, cols=8)
plt.axis('off')
plt.imshow(stitched)
plt.title(layer_name)
plt.show()