In [1]:
'''
Visualization of the filters of a CNN, via gradient ascent in input space.
This script can run on CPU in a few minutes.
This script is meant for use only and is not discussed in great details as it is beyond the scope of the class.
'''
# using tf.2.1 in colab
# %tensorflow_version 2.x

# from __future__ import print_function

import numpy as np
import time
import tensorflow as tf
from tensorflow.keras.preprocessing.image import save_img
from tensorflow.keras import backend as K
# from google.colab import drive
# stopping eager execution mode to be allowed to use the gradient function
tf.compat.v1.disable_eager_execution()
# sanity check for tf version
print(tf.__version__)

###################################################################################################
# dimensions of the generated pictures for each filter.
###################################################################################################   
img_width = 150
img_height = 150

###################################################################################################
# the name of the layer we want to visualize 
###################################################################################################   
layer_name='conv2d_12'

###################################################################################################
# util function to convert a tensor into a valid image
###################################################################################################   
def deprocess_image(x):
    
    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + K.epsilon())
    x *= 0.1

    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # convert to RGB array
    x *= 255
    if K.image_data_format() == 'channels_first':
        x = x.transpose((1, 2, 0))
    x = np.clip(x, 0, 255).astype('uint8')
    return x

###################################################################################################
# Load the model
###################################################################################################   
from tensorflow.keras.models import load_model
model=load_model('./dogs_cats_birds_model_case01_30epoch')

###################################################################################################
# This is the placeholder for the input images
###################################################################################################   
input_img = model.input

###################################################################################################
# Get the symbolic outputs of each "key" layer (we gave them unique names).
###################################################################################################   
layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])

###################################################################################################
# Utility function to normalize a tensor by its L2 norm
###################################################################################################   
def normalize(x): 
    return x / (K.sqrt(K.mean(K.square(x))) + K.epsilon())

kept_filters = []

###################################################################################################
# Scan through some number of filters...
###################################################################################################   

for filter_index in range(64):

    print('Processing filter %d' % filter_index)
    start_time = time.time()

    # we build a loss function that maximizes the activation
    # of the nth filter of the layer considered
    layer_output = layer_dict[layer_name].output
    if K.image_data_format() == 'channels_first':
        loss = K.mean(layer_output[:, filter_index, :, :])
    else:
        loss = K.mean(layer_output[:, :, :, filter_index])

    # we compute the gradient of the input picture wrt this loss
    grads = K.gradients(loss, input_img)[0]

    # normalization trick: we normalize the gradient
    grads = normalize(grads)

    # this function returns the loss and grads given the input picture
    iterate = K.function([input_img], [loss, grads])

    # step size for gradient ascent
    step = 1.

    # we start from a gray image with some random noise
    if K.image_data_format() == 'channels_first':
        input_img_data = np.random.random((1, 3, img_width, img_height))
    else:
        input_img_data = np.random.random((1, img_width, img_height, 3))
    input_img_data = (input_img_data - 0.5) * 20 + 128

    # we run gradient ascent for 20 steps
    for i in range(100):
        loss_value, grads_value = iterate([input_img_data])
        input_img_data += grads_value * step

        print('Current loss value:', loss_value)

    # decode the resulting input image
    if True:
        
    #if loss_value > 0
        img = deprocess_image(input_img_data[0])
        kept_filters.append((img, loss_value))
    end_time = time.time()
    print('Filter %d processed in %ds' % (filter_index, end_time - start_time))

###################################################################################################
# we will stich the best n^2 filters on a n x n grid.
###################################################################################################
n = 5

###################################################################################################
# the filters that have the highest loss are assumed to be more intuitive
# we will only keep the top n filters.
###################################################################################################   
kept_filters.sort(key=lambda x: x[1], reverse=True)
kept_filters = kept_filters[:n * n]

###################################################################################################
# build a black picture with enough space for
# our n x n filters of size 128 x 128, with a 5px margin in between
###################################################################################################   
margin = 5
width = n * img_width + (n - 1) * margin
height = n * img_height + (n - 1) * margin
stitched_filters = np.zeros((width, height, 3))

###################################################################################################
# fill the picture with our saved filters 
###################################################################################################   
for i in range(n):
    for j in range(n):
        img, loss = kept_filters[i * n + j]
        stitched_filters[(img_width + margin) * i: (img_width + margin) * i + img_width,
                         (img_height + margin) * j: (img_height + margin) * j + img_height, :] = img

###################################################################################################
# save the result to drive
###################################################################################################   
save_img('./dogs_cats_birds_model_case01_30epoch_%dx%d.png' % (n, n), stitched_filters)
print('saved')

2.1.0
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Processing filter 0
Current loss value: -67.29149
Current loss value: -65.5718
Current loss value: -64.15102
Current loss value: -63.005642
Current loss value: -62.089314
Current loss value: -61.350346
Current loss value: -60.72833
Current loss value: -60.18909
Current loss value: -59.70511
Current loss value: -59.2678
Current loss value: -58.866867
Current loss value: -58.495068
Current loss value: -58.144352
Current loss value: -57.8135
Current loss value: -57.497753
Current loss value: -57.191254
Current loss value: -56.90367
Current loss value: -56.635303
Current loss value: -56.38055
Current loss value: -56.146538
Current loss value: -55.92504
Current loss value: -55.7126
Current loss value: -55.507866
Current loss value: -55.298985
Current loss value: -55.093327
Current loss value: -54.894325
Current loss value: -54.691914
Current loss value: -54.489124
Current loss value: -54.288586
Current lo

Current loss value: 14.373785
Current loss value: 18.501123
Current loss value: 22.926733
Current loss value: 27.601156
Current loss value: 32.491634
Current loss value: 37.56093
Current loss value: 42.778572
Current loss value: 48.124302
Current loss value: 53.581272
Current loss value: 59.135014
Current loss value: 64.77197
Current loss value: 70.481575
Current loss value: 76.25938
Current loss value: 82.10327
Current loss value: 88.010414
Current loss value: 93.980736
Current loss value: 100.01429
Current loss value: 106.10825
Current loss value: 112.26314
Current loss value: 118.47889
Current loss value: 124.75521
Current loss value: 131.08818
Current loss value: 137.47234
Current loss value: 143.90857
Current loss value: 150.3972
Current loss value: 156.93393
Current loss value: 163.5155
Current loss value: 170.13979
Current loss value: 176.80447
Current loss value: 183.50844
Current loss value: 190.24712
Current loss value: 197.0177
Current loss value: 203.82045
Current loss valu

Current loss value: -51.377247
Current loss value: -51.196247
Filter 5 processed in 0s
Processing filter 6
Current loss value: 54.56442
Current loss value: 57.013348
Current loss value: 58.89477
Current loss value: 60.40523
Current loss value: 61.727184
Current loss value: 62.950504
Current loss value: 64.117905
Current loss value: 65.248924
Current loss value: 66.35929
Current loss value: 67.45064
Current loss value: 68.53327
Current loss value: 69.60459
Current loss value: 70.65773
Current loss value: 71.71208
Current loss value: 72.76583
Current loss value: 73.81251
Current loss value: 74.86426
Current loss value: 75.91102
Current loss value: 76.96963
Current loss value: 78.038605
Current loss value: 79.10285
Current loss value: 80.17749
Current loss value: 81.26085
Current loss value: 82.34068
Current loss value: 83.43059
Current loss value: 84.512535
Current loss value: 85.59453
Current loss value: 86.668884
Current loss value: 87.735344
Current loss value: 88.7971
Current loss va

Current loss value: 349.1864
Current loss value: 353.15488
Current loss value: 357.12576
Current loss value: 361.09625
Current loss value: 365.06882
Current loss value: 369.04388
Current loss value: 373.01764
Current loss value: 376.9941
Current loss value: 380.97153
Current loss value: 384.94974
Current loss value: 388.92902
Current loss value: 392.91025
Filter 8 processed in 0s
Processing filter 9
Current loss value: -10.127985
Current loss value: -8.470048
Current loss value: -7.2050056
Current loss value: -6.2680445
Current loss value: -5.5261884
Current loss value: -4.912182
Current loss value: -4.392767
Current loss value: -3.9374325
Current loss value: -3.5263047
Current loss value: -3.1561468
Current loss value: -2.8060493
Current loss value: -2.4611328
Current loss value: -2.1149647
Current loss value: -1.779764
Current loss value: -1.4399165
Current loss value: -1.101472
Current loss value: -0.7733988
Current loss value: -0.43116856
Current loss value: -0.09920066
Current los

Current loss value: 475.7258
Current loss value: 481.63797
Current loss value: 487.55652
Current loss value: 493.4804
Current loss value: 499.40698
Current loss value: 505.33884
Current loss value: 511.27466
Current loss value: 517.21545
Current loss value: 523.1598
Current loss value: 529.1085
Filter 11 processed in 0s
Processing filter 12
Current loss value: -50.309483
Current loss value: -48.06437
Current loss value: -46.295876
Current loss value: -45.005974
Current loss value: -44.102623
Current loss value: -43.471455
Current loss value: -42.979767
Current loss value: -42.59137
Current loss value: -42.259327
Current loss value: -41.96497
Current loss value: -41.70263
Current loss value: -41.469917
Current loss value: -41.23866
Current loss value: -41.02758
Current loss value: -40.82459
Current loss value: -40.633854
Current loss value: -40.442123
Current loss value: -40.263134
Current loss value: -40.069313
Current loss value: -39.89125
Current loss value: -39.70917
Current loss va

Current loss value: 37.390312
Current loss value: 39.28875
Current loss value: 40.83701
Current loss value: 42.195694
Current loss value: 43.450645
Current loss value: 44.628807
Current loss value: 45.77772
Current loss value: 46.891644
Current loss value: 47.988434
Current loss value: 49.06416
Current loss value: 50.140114
Current loss value: 51.217308
Current loss value: 52.28703
Current loss value: 53.355198
Current loss value: 54.442135
Current loss value: 55.52515
Current loss value: 56.628025
Current loss value: 57.731487
Current loss value: 58.847126
Current loss value: 59.972645
Current loss value: 61.105137
Current loss value: 62.249596
Current loss value: 63.40202
Current loss value: 64.56128
Current loss value: 65.72118
Current loss value: 66.88442
Current loss value: 68.04579
Current loss value: 69.21059
Current loss value: 70.35923
Current loss value: 71.49744
Current loss value: 72.64472
Current loss value: 73.75785
Current loss value: 74.87407
Current loss value: 75.9674

Current loss value: -24.094866
Current loss value: -23.857542
Current loss value: -23.636414
Current loss value: -23.40082
Current loss value: -23.178198
Current loss value: -22.950314
Filter 17 processed in 0s
Processing filter 18
Current loss value: -69.16905
Current loss value: -67.12868
Current loss value: -65.50125
Current loss value: -64.268135
Current loss value: -63.33438
Current loss value: -62.619286
Current loss value: -62.03743
Current loss value: -61.537277
Current loss value: -61.1102
Current loss value: -60.714436
Current loss value: -60.36274
Current loss value: -60.0371
Current loss value: -59.727074
Current loss value: -59.43422
Current loss value: -59.15165
Current loss value: -58.884953
Current loss value: -58.62958
Current loss value: -58.377117
Current loss value: -58.138588
Current loss value: -57.90017
Current loss value: -57.66912
Current loss value: -57.443882
Current loss value: -57.212395
Current loss value: -56.99951
Current loss value: -56.77057
Current lo

Current loss value: -72.633575
Current loss value: -70.100494
Current loss value: -68.06512
Current loss value: -66.53104
Current loss value: -65.45949
Current loss value: -64.71629
Current loss value: -64.18771
Current loss value: -63.781
Current loss value: -63.44838
Current loss value: -63.180725
Current loss value: -62.935734
Current loss value: -62.720776
Current loss value: -62.52984
Current loss value: -62.346214
Current loss value: -62.1784
Current loss value: -62.01681
Current loss value: -61.865887
Current loss value: -61.719883
Current loss value: -61.57589
Current loss value: -61.43956
Current loss value: -61.302433
Current loss value: -61.17501
Current loss value: -61.043083
Current loss value: -60.915733
Current loss value: -60.793777
Current loss value: -60.6637
Current loss value: -60.54249
Current loss value: -60.42365
Current loss value: -60.308014
Current loss value: -60.187176
Current loss value: -60.06433
Current loss value: -59.95431
Current loss value: -59.831406

Current loss value: 48.981667
Current loss value: 49.31854
Current loss value: 49.649693
Current loss value: 49.98605
Current loss value: 50.315968
Current loss value: 50.656094
Current loss value: 50.986134
Filter 23 processed in 0s
Processing filter 24
Current loss value: -18.840302
Current loss value: -17.475311
Current loss value: -16.098818
Current loss value: -14.6859045
Current loss value: -13.221677
Current loss value: -11.69892
Current loss value: -10.120727
Current loss value: -8.483553
Current loss value: -6.7854548
Current loss value: -5.025988
Current loss value: -3.2011094
Current loss value: -1.3099533
Current loss value: 0.65162027
Current loss value: 2.688021
Current loss value: 4.800401
Current loss value: 6.9916644
Current loss value: 9.268354
Current loss value: 11.63274
Current loss value: 14.0898695
Current loss value: 16.63431
Current loss value: 19.267761
Current loss value: 21.991596
Current loss value: 24.798243
Current loss value: 27.68845
Current loss value:

Current loss value: 404.0105
Current loss value: 408.9827
Current loss value: 413.9588
Current loss value: 418.93942
Current loss value: 423.92413
Current loss value: 428.91214
Filter 26 processed in 0s
Processing filter 27
Current loss value: -86.387024
Current loss value: -82.75461
Current loss value: -79.649025
Current loss value: -76.98122
Current loss value: -74.65257
Current loss value: -72.502686
Current loss value: -70.46439
Current loss value: -68.49911
Current loss value: -66.60694
Current loss value: -64.725746
Current loss value: -62.858578
Current loss value: -61.03021
Current loss value: -59.190754
Current loss value: -57.374294
Current loss value: -55.549194
Current loss value: -53.734207
Current loss value: -51.90892
Current loss value: -50.09502
Current loss value: -48.26259
Current loss value: -46.422188
Current loss value: -44.60924
Current loss value: -42.755924
Current loss value: -40.920296
Current loss value: -39.08435
Current loss value: -37.264378
Current loss 

Current loss value: 186.96864
Current loss value: 189.991
Current loss value: 193.01526
Current loss value: 196.04396
Current loss value: 199.0754
Filter 29 processed in 0s
Processing filter 30
Current loss value: -49.999058
Current loss value: -48.029682
Current loss value: -46.334026
Current loss value: -44.883152
Current loss value: -43.65139
Current loss value: -42.6009
Current loss value: -41.681656
Current loss value: -40.865982
Current loss value: -40.122837
Current loss value: -39.43724
Current loss value: -38.788284
Current loss value: -38.173424
Current loss value: -37.583504
Current loss value: -37.011326
Current loss value: -36.452263
Current loss value: -35.910236
Current loss value: -35.371758
Current loss value: -34.84548
Current loss value: -34.31646
Current loss value: -33.801643
Current loss value: -33.284397
Current loss value: -32.775764
Current loss value: -32.268616
Current loss value: -31.763956
Current loss value: -31.2635
Current loss value: -30.768904
Current 

ValueError: slice index 32 of dimension 3 out of bounds. for 'strided_slice_32' (op: 'StridedSlice') with input shapes: [?,72,72,32], [4], [4], [4] and with computed input tensors: input[1] = <0 0 0 32>, input[2] = <0 0 0 33>, input[3] = <1 1 1 1>.

In [6]:
layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])
print(layer_dict)

{'activation_17': <tensorflow.python.keras.saving.saved_model.load.Activation object at 0x000001713002B908>, 'max_pooling2d_11': <tensorflow.python.keras.saving.saved_model.load.MaxPooling2D object at 0x000001713002D288>, 'conv2d_12': <tensorflow.python.keras.saving.saved_model.load.Conv2D object at 0x0000017130031088>, 'activation_18': <tensorflow.python.keras.saving.saved_model.load.Activation object at 0x0000017130031848>, 'max_pooling2d_12': <tensorflow.python.keras.saving.saved_model.load.MaxPooling2D object at 0x00000171300302C8>, 'conv2d_13': <tensorflow.python.keras.saving.saved_model.load.Conv2D object at 0x0000017130030DC8>, 'activation_19': <tensorflow.python.keras.saving.saved_model.load.Activation object at 0x00000171300327C8>, 'max_pooling2d_13': <tensorflow.python.keras.saving.saved_model.load.MaxPooling2D object at 0x0000017130033188>, 'conv2d_14': <tensorflow.python.keras.saving.saved_model.load.Conv2D object at 0x0000017130033F88>, 'activation_20': <tensorflow.python.