## In this notebook, we will load a CNN called [VGG16](https://arxiv.org/abs/1409.1556) with weights pre-trained on the [ImageNet](http://www.image-net.org) dataset, and use *activation maximization* to visualize the features that this model has learnt. Before starting this exercise, please go to: <br/> Runtime -> Change runtime type <br/> and make sure that the notebook is set to use *Python 3* and *GPU*.

### Step 1 Install and import all dependencies.

In [0]:
import warnings
warnings.filterwarnings('ignore')
!pip install Keras-Applications
!pip install --quiet --force-reinstall git+https://github.com/raghakot/keras-vis.git -U
!pip install --quiet --force-reinstall scipy==1.2

In [0]:
from keras.applications.vgg16 import VGG16
from vis.visualization import visualize_activation
from vis.utils import utils
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras import activations
import matplotlib.pyplot as plt

### Step 2 First, let's understand the ImageNet dataset. Below, we download the description of each of the 1000 image classes in ImageNet. The descriptions can also be viewed [here](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).

In [0]:
import pickle
from urllib.request import urlopen
classidx_to_description_dict = pickle.load(urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl'))
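
To find the index of a class by name rather than scrolling through all 1000 descriptions, a small lookup helper can be useful. The sketch below is only an illustration: `find_class_indices` and `sample_dict` are hypothetical names not used elsewhere in this notebook, and `sample_dict` is a three-entry stand-in for the downloaded dictionary so that the cell runs without network access.

```python
def find_class_indices(classidx_to_description, keyword):
    """Return the sorted class indices whose description contains `keyword`."""
    keyword = keyword.lower()
    return sorted(
        idx
        for idx, description in classidx_to_description.items()
        if keyword in description.lower()
    )

# Stand-in for classidx_to_description_dict (assumption: same int -> str layout)
sample_dict = {
    0: "tench, Tinca tinca",
    1: "goldfish, Carassius auratus",
    980: "volcano",
}

matches = find_class_indices(sample_dict, "Volcano")
print(matches)
```

With the real dictionary loaded, calling `find_class_indices(classidx_to_description_dict, "baseball")` should list the candidate indices for the question in Step 4.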

### Step 3 Download and load the pre-trained model. For this exercise, we are using the [VGG16](https://arxiv.org/abs/1409.1556) architecture with weights pre-trained on the [ImageNet](http://www.image-net.org) dataset. Note that the last layer (the dense layer named predictions) outputs one score for each of the 1000 classes.

In [0]:
model = VGG16()
model.summary()

### Step 4 Use activation maximization to visualize the images that maximize the output of each filter in the prediction layer. The prediction layer is the layer in the model with 1000 filters (output units), each representing a separate class of the dataset.
### Question: What features has the model learnt for the classes *Volcano* and *Baseball*? How about other classes? Share the interesting visualizations you get with your table!

In [0]:
layer_name = "predictions" # The prediction layer. Refer to model.summary() for the names of all layers.

In [0]:
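def get_layer_index_from_layer_name(layer_name):
  # Return the position of the layer called layer_name in model.layers
  for idx, layer in enumerate(model.layers):
    if layer.name == layer_name:
      return idx
  # Fail loudly instead of silently returning None for a misspelled name
  raise ValueError("No layer named " + layer_name)

layer_index = get_layer_index_from_layer_name(layer_name)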

In [0]:
# Replace the activation of the selected layer with a linear one,
# then rebuild the model so that the change takes effect
model.layers[layer_index].activation = activations.linear
model = utils.apply_modifications(model)
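
Why switch the prediction layer to a linear activation? With softmax, the output for one class can grow simply because the scores of the other classes shrink, so maximizing it does not necessarily produce an image that looks more like the target class. The NumPy sketch below illustrates this with made-up scores; it is independent of Keras, and `softmax`, `logits_a` and `logits_b` are names introduced only for this example.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array of class scores."""
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()

# Two made-up score vectors with the SAME score for class 0
logits_a = np.array([2.0, 1.0, 0.0])
logits_b = np.array([2.0, -5.0, -5.0])  # only the other classes were pushed down

p_a = softmax(logits_a)[0]
p_b = softmax(logits_b)[0]
print("softmax output for class 0:", p_a, "vs", p_b)
print("linear output for class 0: ", logits_a[0], "vs", logits_b[0])
```

The softmax output for class 0 rises even though its raw score is unchanged, whereas the linear output only rises when the class score itself does.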

In [0]:
filter_idx = 980 # Index of the filter you want to visualize. In this case 980 is Volcano. For the predictions layer, this corresponds to the class indexes seen in classidx_to_description_dict

In [0]:
fig = plt.figure(figsize=(20, 20))
print("Visualizing image for class index " + str(filter_idx) + ": " + classidx_to_description_dict[filter_idx])
# Generate an input image that maximizes the chosen output.
# tv_weight weights the total variation penalty; lp_norm_weight=0. disables the Lp-norm penalty.
img = visualize_activation(model, layer_index, filter_indices=[filter_idx], max_iter=100, tv_weight=1., lp_norm_weight=0.)
fig.add_subplot(2, 2, 1)
plt.imshow(img)
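
Conceptually, activation maximization keeps the network weights fixed and runs gradient ascent on the *input* until the chosen output is as large as possible. The toy sketch below shows the idea for a single linear unit with an L2 penalty on the input. It is an assumption-laden illustration, not the keras-vis implementation: `maximize_activation`, `toy_weights` and all hyperparameters are invented for this example.

```python
import numpy as np

def maximize_activation(weights, steps=200, learning_rate=0.1, l2_weight=0.5, seed=0):
    """Gradient ascent on the input x of a toy unit with activation weights . x.

    Objective: weights . x - l2_weight * ||x||^2
    (the penalty keeps x from growing without bound).
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=0.01, size=weights.shape)  # start from a small random input
    for _ in range(steps):
        grad = weights - 2.0 * l2_weight * x  # gradient of the objective w.r.t. x
        x = x + learning_rate * grad          # ascent step: move x uphill
    return x

toy_weights = np.array([1.0, -2.0, 0.5])
best_input = maximize_activation(toy_weights)
print(best_input)
```

For this toy objective the optimum is `weights / (2 * l2_weight)`, so with `l2_weight=0.5` the input converges to the weight vector itself: the "image" that excites the unit most is the pattern the unit has learnt to detect.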

### Step 5 Try modifying *layer_name* and *filter_idx* above to explore the features learned by different layers.
### Question: What features has the model learnt in layers block1_conv1, block2_conv1, block3_conv1, block4_conv1 and block5_conv1? Experiment with other layers yourself: from shallow to deeper layers, how do the learnt features evolve?

In [0]:
layer_name = "block1_conv1" # Name of the layer whose filters you want to visualize. Refer to model.summary() for the names of all layers.
layer_index = get_layer_index_from_layer_name(layer_name)
model.layers[layer_index].activation = activations.linear
model = utils.apply_modifications(model)

columns = 5
rows = 2
fig = plt.figure(figsize=(20, 20))

# Visualize the first 10 filters of the chosen layer
for i in range(10):
  img = visualize_activation(model, layer_index, filter_indices=[i], max_iter=100, tv_weight=1., lp_norm_weight=0.)
  fig.add_subplot(rows, columns, i+1)
  plt.imshow(img)
plt.show()
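
The `tv_weight` argument used above controls a total variation penalty, which discourages large differences between neighbouring pixels and so favours smoother, less noisy visualizations. The sketch below computes a simple anisotropic total variation in NumPy to show what the penalty measures. The library's exact formulation may differ; `total_variation`, `smooth` and `noisy` are names introduced only for this illustration.

```python
import numpy as np

def total_variation(image):
    """Sum of absolute differences between vertically and horizontally adjacent pixels."""
    vertical = np.abs(np.diff(image, axis=0)).sum()
    horizontal = np.abs(np.diff(image, axis=1)).sum()
    return vertical + horizontal

# A smooth left-to-right gradient versus uniform random noise, both 8x8
smooth = np.tile(np.linspace(0.0, 1.0, 8), (8, 1))
rng = np.random.default_rng(0)
noisy = rng.uniform(0.0, 1.0, size=(8, 8))

print("smooth:", total_variation(smooth))
print("noisy: ", total_variation(noisy))
```

The noisy image scores much higher, so a larger `tv_weight` pushes the optimization away from high-frequency noise. Try changing `tv_weight` in the cells above and compare the results.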
