# Visualizing Convolutional Layers in a Trained VGG Network

### What have the various feature maps in a CNN been trained to look for?

## Description

---

## Methods

---

### Imports

In [None]:
import torch
import torch.optim as optim
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

### Defining the Model

In [None]:

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained VGG16. (Newer torchvision releases deprecate
# `pretrained=` in favour of `weights=`; kept here for compatibility.)
vgg = torchvision.models.vgg16(pretrained=True).to(device)
vgg.eval()

# Freeze the network: only the input image is optimized later on.
for param in vgg.parameters():
    param.requires_grad = False

# Layer outputs captured by the forward hooks, keyed by display name
# (e.g. 'Conv1_1').
vgg.activation = {}


def get_activation(name):
    """Return a forward hook that stores a module's output in ``vgg.activation[name]``.

    The stored tensor is the output for the first image of the batch
    (batch size 1 in this notebook), i.e. ``(channels, height, width)`` for a
    conv layer. Indexing the batch axis, rather than calling ``squeeze()``,
    keeps the spatial axes intact even if a feature map shrinks to 1x1.
    The tensor is stored without detaching so gradients can flow back to
    the input image.
    """
    def hook(model, input, output):
        vgg.activation[name] = output[0]
    return hook


# Mapping between layer names and their index in ``vgg.features``.
layers = {'conv1_1': 0,
          'conv2_1': 5,
          'conv3_1': 10,
          'conv4_1': 17,
          'conv5_1': 24}

# One hook per layer. Activations are stored under the capitalized name
# ('conv1_1' -> 'Conv1_1'), which is the name passed to `show` below.
for layer_name, layer_index in layers.items():
    vgg.features[layer_index].register_forward_hook(
        get_activation(layer_name.capitalize()))

### Helpers

In [None]:
def load_image(img_path):
    """Load an image file as a normalized batch-of-one tensor for VGG.

    Args:
        img_path: path to an image file readable by PIL. It is converted to RGB.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` normalized with the ImageNet
        per-channel mean/std. These are the same constants that
        ``im_convert`` uses to undo the normalization.
    """
    # Open the file as a context manager so the handle is closed promptly.
    # ``convert`` returns a fully loaded copy that outlives the ``with``.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    in_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    # Add a leading batch dimension.
    return in_transform(image).unsqueeze(0)

In [None]:
def im_convert(tensor):
    """Turn a normalized image tensor back into a displayable NumPy array.

    Undoes the per-channel mean/std normalization and moves the channel axis
    last, so the result can be passed straight to ``plt.imshow``.

    Args:
        tensor: image tensor that becomes ``(3, H, W)`` after size-1 axes are
            squeezed away (e.g. ``(1, 3, H, W)``). It may live on any device
            and may require grad.

    Returns:
        A float array of shape ``(H, W, 3)`` with values clipped to ``[0, 1]``.
    """
    channel_mean = np.array((0.485, 0.456, 0.406))
    channel_std = np.array((0.229, 0.224, 0.225))

    # Host copy, cut from the autograd graph, with singleton axes dropped.
    chw = tensor.to("cpu").clone().detach().numpy().squeeze()
    hwc = chw.transpose(1, 2, 0)
    return np.clip(hwc * channel_std + channel_mean, 0, 1)

### Putting it all together

In [None]:
def _optimize_feature_map(layer_name, map_number, resolution, steps, lr, dist, rng):
    """Return an input image (on ``device``) optimized to excite one feature map.

    Starts from uniform noise in ``dist`` and takes ``steps`` Adam steps on the
    image. At each step the loss pulls every activation of feature map
    ``map_number`` in ``layer_name`` toward the square of that map's current
    maximum activation.
    """
    init = rng.uniform(*dist, size=(3, resolution, resolution))
    target = (torch.from_numpy(init).unsqueeze(0)
              .float().to(device).requires_grad_(True))
    optimizer = optim.Adam([target], lr=lr)

    for _ in range(steps):
        # Only the convolutional trunk is needed. The hooks that fill
        # `vgg.activation` are on `vgg.features`, so skipping the classifier
        # head saves work without changing the captured activations.
        vgg.features(target)
        output = vgg.activation[layer_name][map_number]

        # Constant target: (current peak activation)^2 everywhere. Built in
        # float64 so the difference and loss are computed in double precision.
        peak = float(output.detach().max())
        expected = torch.full_like(output, peak ** 2, dtype=torch.float64)

        loss = torch.mean((output - expected) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return target


def show(layer_to_visualize='Conv1_1',
         filers_to_show=(2, 2),
         resolution=100,
         steps=200, lr=0.01,
         shift=0, dist=(-1., 1.),
         random_state=None):
    """Plot images optimized to excite individual feature maps of a VGG layer.

    Each cell of a ``rows x cols`` grid holds an image optimized (with the
    network frozen) for one feature map. Cells are labelled with the
    feature-map index and drawn on a black figure.

    Args:
        layer_to_visualize: key into ``vgg.activation`` (e.g. 'Conv1_1').
        filers_to_show: ``(rows, cols)`` of the grid; ``rows * cols`` maps
            are visualized. (The parameter name's spelling is kept for
            backward compatibility.)
        resolution: side length, in pixels, of each square input image.
        steps: number of optimization steps per feature map.
        lr: Adam learning rate.
        shift: index of the first feature map shown; cells are filled in
            row-major order starting from it.
        dist: ``(low, high)`` bounds of the uniform noise used to
            initialize each image.
        random_state: seed for a private ``np.random.RandomState``; ``None``
            uses NumPy's global generator.

    Raises:
        IndexError: if a requested feature-map index is not smaller than the
            layer's channel count.
    """
    rng = np.random.RandomState(random_state) if random_state is not None else np.random
    rows, cols = filers_to_show[0], filers_to_show[1]

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols,
                            figsize=(cols * 2, rows * 2 + 0.3),
                            constrained_layout=True,
                            squeeze=False)

    # The mode never changes between feature maps, so set it once.
    vgg.eval()

    for row in range(rows):
        for col in range(cols):
            map_number = row * cols + col + shift
            target = _optimize_feature_map(layer_to_visualize, map_number,
                                           resolution, steps, lr, dist, rng)

            ax = axs[row, col]
            ax.imshow(im_convert(target))
            ax.axis('off')
            ax.annotate(str(map_number),
                        xy=(0, 0), color='white',
                        fontweight='bold',
                        verticalalignment='top',
                        bbox=dict(boxstyle="round", fc="black"))

    fig.suptitle(layer_to_visualize, color='white', fontweight='bold')
    fig.patch.set_facecolor('black')
    plt.show()


### Testing it out

In [None]:
# One figure per hooked layer: (layer, image resolution, first feature map,
# init-noise range). Every figure is a 2x5 grid seeded with random_state=100.
demo_runs = [
    ('Conv1_1', 100, 20, (0., 2.)),
    ('Conv2_1', 110, 20, (0., 1.)),
    ('Conv3_1', 120, 20, (-0.1, 0.2)),
    ('Conv4_1', 140, 20, (-0.5, 0.5)),
    ('Conv5_1', 160, 25, (-0.2, 0.2)),
]
for layer_name, res, first_map, init_range in demo_runs:
    show(layer_name, (2, 5), resolution=res, shift=first_map,
         random_state=100, dist=init_range)

## Results

---