In [None]:
# Optional formatter: `lab_black` (presumably from the nb_black package) re-formats cells with Black.
# NOTE(review): this errors on a fresh kernel if nb_black is not installed - it is not needed for the lab itself.
%load_ext lab_black
# Render matplotlib figures inline in the notebook output
%matplotlib inline

<h1> Trained model Manipulation </h1>
In this lab we are going to see what we can do with a pre-trained classifier model (other than classify images) and hopefully get a better idea of what is going on inside our models!<br>
First we will try to visualise what our trained network is "looking" at when it makes a classification <br>
To do this we are going to take a pre-trained model from PyTorch's "Model Zoo", VGG19 in this case, backprop the gradients from a single output to the input image, and visualise the magnitudes of the gradients <br>
Next we will look at what happens when we change our input image with these gradients
<img src="https://glassboxmedicine.files.wordpress.com/2019/06/greater-swiss-mountain-dog.jpeg" width="800" align="center">

[CAM](https://glassboxmedicine.com/2019/06/11/cnn-heat-maps-class-activation-mapping-cam/)

In [None]:
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as Datasets
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.models as models

from IPython.display import clear_output
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import copy
import os

In [None]:
# Use the first GPU if one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch.cuda.set_device raises when given a CPU device, so only select a GPU when we have one
if device.startswith("cuda"):
    torch.cuda.set_device(device)

The VGG19 model we will be using was trained on the ImageNet challenge dataset, so let's load a file of the class names

In [None]:
# Human-readable names for the 1000 ImageNet-challenge classes; later cells index this
# array with the network's predicted class id.
# Source list: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
image_net_classes = np.loadtxt(
    "Imagenet_classes.csv",
    delimiter=", ",
    dtype=str,
)

In [None]:
# Peek at the first ten class names
image_net_classes[0:10]

Load our test image to experiment with

In [None]:
# Load our test image, forcing 3-channel RGB
test_img = Image.open("Pupper.jpg").convert("RGB")
# Resize, convert the PIL image to a (C, H, W) float tensor, and normalize with the ImageNet
# mean/std that the pretrained VGG19 weights were trained with
transform = T.Compose(
    [
        T.Resize(512),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Add the batch dimension -> (1, C, H, W) and move the tensor to the chosen device
test_img1 = transform(test_img).unsqueeze(0).to(device)

A few helper functions

In [None]:
# This Function will allow us to scale an images pixel values to a value between 0 and 1
def normalize_img(img):
    mins = img.min(0, keepdims=True).min(1, keepdims=True)
    maxs = img.max(0, keepdims=True).max(1, keepdims=True)
    return (img - mins) / (maxs - mins)


# Clamp a normalised image batch so every pixel stays inside the range that a [0, 1] image
# occupies after normalisation with the ImageNet mean/std (the same values as `transform`).
def clip(image_tensor):
    """Clamp each colour channel of a normalised image batch in place.

    Args:
        image_tensor: tensor of shape (N, 3, H, W) normalised with the ImageNet mean/std.

    Returns:
        The same tensor object, modified in place (returned for convenience).
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for c in range(3):
        m, s = mean[c], std[c]
        # x in [0, 1]  =>  (x - m) / s in [-m / s, (1 - m) / s]
        # Clamp the channel for every image in the batch, not just index 0.
        image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
    return image_tensor

# Create a pretrained VGG19 Model

In [None]:
# Create a VGG19 from torchvision's model zoo and download the pretrained weights
# https://pytorch.org/vision/stable/models.html
# These weights were trained on the ImageNet challenge dataset (1.3 million images, 1000 classes)
if hasattr(models, "VGG19_Weights"):
    # Newer torchvision deprecates `pretrained=`; IMAGENET1K_V1 is the checkpoint `pretrained=True` loads
    vgg_net = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).to(device)
else:
    vgg_net = models.vgg19(pretrained=True).to(device)
# We only ever optimise the input image, never the weights: freezing them means backward()
# no longer computes or stores weight gradients (less memory, faster), while gradients
# still flow through the network to any input that requires grad.
vgg_net.requires_grad_(False)
# We're not training it so put it in eval mode
vgg_net.eval()

In [None]:
# Count every scalar weight across all of the model's parameter tensors
num_params = sum(p.numel() for p in vgg_net.parameters())
print(f"This model has approximately {num_params} Parameters!")

Visualise the shape of the output

In [None]:
# Inspect the output shape only; no gradients needed, so don't build an autograd graph
with torch.no_grad():
    scores = vgg_net(test_img1)
scores.shape

<h3> Visualise the test image!</h3>

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# (1, C, H, W) normalised tensor -> (H, W, C) array, rescaled per channel to [0, 1] for display
np_img = test_img1[0].cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)

<h3>  What does VGG19 think our test image is?</h3>

In [None]:
# Get the index of the largest output of the network (its top-1 class).
# Inference only, so skip building an autograd graph.
with torch.no_grad():
    idx = vgg_net(test_img1).argmax(1).item()
# Use the index to look up the class name
print(f"This image is class {idx} which is a {image_net_classes[idx]}")

<h3> What are you looking at???</h3>
Now that we know what it thinks it is, we can try to work out which part of the image VGG19 used to make its decision by simply looking at the gradients

In [None]:
# Independent copy of the test image that autograd will track, so we can read d(output)/d(pixel)
image = test_img1.detach().clone().requires_grad_(True)
# Forward pass: one score per ImageNet class
output = vgg_net(image)
# Backpropagate from the predicted class's score alone (backward() needs a single scalar).
# The resulting image.grad says how each input pixel influences that score; note you can
# backprop from anywhere in the network, not only the final output.
output[0, idx].backward()

In [None]:
# Collapse the input gradient into a 2-D saliency map: absolute value, then max over colour channels
grad_values, _ = image.grad.detach()[0].cpu().abs().max(0)
# Downsample then upsample as a quick and dirty way of generating a smooth heatmap.
# Pool a (1, 1, H, W) batch explicitly, and use F.interpolate with align_corners=True,
# the documented equivalent of the deprecated F.upsample_bilinear.
grad_scale = F.avg_pool2d(grad_values[None, None], 10)
grad_scale = F.interpolate(
    grad_scale,
    size=tuple(grad_values.shape),
    mode="bilinear",
    align_corners=True,
)[0, 0]

<h3> Visualise </h3>
This method is a crude way of visualising what the network is paying attention to; brighter areas correspond to higher gradients <br>
If you are interested, check out this implementation of Class Activation Mapping (CAM) for a better method <br> 

[Class Activation Mapping](http://snappishproductions.com/blog/2018/01/03/class-activation-mapping-in-pytorch.html)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# brighter pixels = larger (smoothed) input-gradient magnitude
ax.imshow(grad_scale)

<h2> Generating Art </h2>
Once we have the gradients of our image with respect to the output, what can we do with them? <br>
These gradients tell us how to change the input image to INCREASE a given output (or whatever value you backpropagated from) <br>
So what happens when we use these gradients to change our image?


<img src="https://b2h3x3f6.stackpathcdn.com/assets/landing/img/gallery/4.jpg" width="800" align="center">

[Deep Dream Generator](https://deepdreamgenerator.com/)

In [None]:
class_indx = 33  # ImageNet class id whose score we will amplify
print(f"This class is a {image_net_classes[class_indx]}")

Let's select an ImageNet class. We don't have to backpropagate from the real class; in fact we can backprop from any feature anywhere in our network!

Like before, let's make another copy of the image

In [None]:
# Independent, gradient-tracking copy of the test image to modify
image2 = test_img1.detach().clone().requires_grad_(True)

In [None]:
# Gradient-ascent step size for the single-scale loop below.
# Author's experiment notes: the aim is to make the network predict the chosen class with as
# little visible change as possible; lr = 0.05 still left visible "scales", while lr = 1.5 was
# recorded as imperceptible.
# NOTE(review): a larger step being less visible is surprising - re-check these observations.
lr = 1.5

We will now backprop from the class activation indexed by the class we chose earlier <br>
Using the gradients collected we will take a "step" in the direction of the gradient by adding the gradient to our image <br>
As a result we will be enhancing any features of our image that look like they belong to our chosen class

In [None]:
for _ in range(100):
    # Forward pass: class scores for the current version of the image
    out = vgg_net(image2)
    # The network's own gradients are never used; clearing them just keeps them from piling up
    vgg_net.zero_grad()
    # Backprop from the chosen class score to get d(score)/d(pixel)
    out[0, class_indx].backward()
    # Gradient ASCENT on the image. Edit the leaf tensor under no_grad instead of through
    # the discouraged `.data` attribute, so autograd neither records nor rejects the update.
    with torch.no_grad():
        image2 += lr * image2.grad
        # keep pixel values within the original (normalised) range; clip works in place
        clip(image2)
        # .grad accumulates across backward() calls, so reset it before the next step
        image2.grad.zero_()

Now that we've updated our image, what does VGG19 think our image is?

In [None]:
# image2 requires grad, so without no_grad this forward pass would build a full autograd graph
with torch.no_grad():
    indx = vgg_net(image2).argmax(1).item()
print("This image is now a", image_net_classes[indx])

Let's visualize our altered image!!!

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# detach from autograd, move to CPU, and reorder to (H, W, C) for matplotlib
np_img = image2[0].detach().cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)

<h3>Using Multiple Scales</h3>
As we can see in our altered image, the updates have mainly changed fine details of the image. This is because many of the layers of the network only operate on small regions of the image; only the final layers have a "receptive field" that encompasses the whole image. If we want to make large-scale changes to our image (i.e. modify the general shape of objects in the image), we need more layers of our network to "view" larger regions of the image. We can do this by simply downsampling our input image, but then the resolution of the output image will be low. Instead, we can perform some steps of gradient ascent on the low-res image, then upsample the modified image and perform gradient ascent again. This lets us start by making large-scale changes to the image and then make finer and finer changes. We can repeat this upsampling multiple times.

In [None]:
# Step size for the multi-scale gradient ascent below
lr = 1e-2

In [None]:
# Start from the test image downsampled by a factor of 8 (a new tensor, so test_img1 is untouched)
image3 = F.avg_pool2d(test_img1, 8)
image3.requires_grad = True

# 6 scales: 10 ascent steps at the current size, then upsample by 1.25x
for _ in range(6):
    for _ in range(10):
        # Forward pass of network
        out = vgg_net(image3)
        vgg_net.zero_grad()
        # Backprop from chosen class activation
        out[0, class_indx].backward()
        # Gradient ascent on the image, done under no_grad instead of via the discouraged `.data`
        with torch.no_grad():
            image3 += lr * image3.grad
            # clip the image to keep the pixel values within the original range (in place)
            clip(image3)
            image3.grad.zero_()

    # F.upsample_bilinear is deprecated; F.interpolate with align_corners=True is its equivalent.
    # The result is a fresh tensor, so gradient tracking has to be switched back on.
    with torch.no_grad():
        image3 = F.interpolate(
            image3, scale_factor=1.25, mode="bilinear", align_corners=True
        )
    image3.requires_grad = True

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# detach from autograd, move to CPU, and reorder to (H, W, C) for matplotlib
np_img = image3[0].detach().cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)

<h3> Slicing our network</h3>
This VGG19 implementation is mainly made up of two nn.Sequential blocks. <br>Let's only take one of them, the initial "features" block, and from it only take some of the initial layers

In [None]:
# Display the model's module structure; its convolutional part is `vgg_net.features`
vgg_net

In [None]:
# Take the convolutional feature extractor. `[:]` keeps ALL of its layers (use e.g.
# `vgg_net.features[:-10]` to stop earlier). Slicing an nn.Sequential builds a new
# Sequential that shares - rather than copies - the pretrained layer modules.
features_net = vgg_net.features[:]
features_net

In [None]:
# Step size for the single-channel feature-maximisation loop below
lr = 10

In [None]:
# 8x-downsampled copy of the test image to optimise, with gradient tracking on
image4 = F.avg_pool2d(test_img1, 8).requires_grad_(True)

In [None]:
# What do the feature maps at this layer look like? Shape only, so don't build a graph.
with torch.no_grad():
    feature_maps = features_net(image4)
feature_maps.shape

Let's update our image by backpropagating from the mean of a single feature map (channel) of the last layer of our "features" block <br>
What would happen if we backpropagated from only one feature in this layer? 

In [None]:
# Which channel (feature map) of the last `features_net` layer to amplify
channel_id = 3
for _ in range(6):
    for _ in range(50):
        # Forward pass through the feature extractor only
        out = features_net(image4)
        features_net.zero_grad()
        # Backprop from the mean activation of the chosen channel (not from a class score)
        out[0, channel_id].mean().backward()
        # Gradient ascent on the image, done under no_grad instead of via the discouraged `.data`
        with torch.no_grad():
            image4 += lr * image4.grad
            # clip the image to keep the pixel values within the original range (in place)
            clip(image4)
            image4.grad.zero_()
    # Upsample 1.3x; F.interpolate(..., align_corners=True) replaces the deprecated upsample_bilinear
    with torch.no_grad():
        image4 = F.interpolate(
            image4, scale_factor=1.3, mode="bilinear", align_corners=True
        )
    image4.requires_grad = True

In [None]:
# image4 requires grad, so without no_grad this forward pass would build a full autograd graph
with torch.no_grad():
    indx = vgg_net(image4).argmax(1).item()
print("This image is now a", image_net_classes[indx])

By backpropagating from an earlier layer in our network, we exaggerate "lower-level" features of our image. 

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# detach from autograd, move to CPU, and reorder to (H, W, C) for matplotlib
np_img = image4[0].detach().cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)

<h3>Using target features</h3>
Instead of maximising individual features of our image, let's get the features of a "target" image at some layer of our network and make the features of our source image match them 

In [None]:
# Load our target image
test_img2 = Image.open("pattern.jpg").convert("RGB")
test_img2 = transform(test_img2).unsqueeze(0).to(device)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# (1, C, H, W) target tensor -> (H, W, C) array, rescaled per channel to [0, 1] for display
np_img = test_img2[0].detach().cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)

In [None]:
# What does VGG19 think the (unmodified) target image is? Inference only, so no graph.
with torch.no_grad():
    indx = vgg_net(test_img2).argmax(1).item()
print("The target image is a", image_net_classes[indx])

In [None]:
# Another slice of the feature extractor: every layer except the final one
# (use e.g. `vgg_net.features[:-20]` to cut deeper). Like any nn.Sequential slice,
# it shares the pretrained layer modules rather than copying them.
features_net_sliced = vgg_net.features[:-1]
features_net_sliced

In [None]:
# Feature maps of the target image at the last layer of `features_net_sliced`
# (per-channel means are taken later). They are a fixed target, so compute them without
# tracking gradients; the matching loss then never backpropagates into this graph.
with torch.no_grad():
    target_features = features_net_sliced(test_img2)

In [None]:
# shape of the target feature maps
target_features.size()

In [None]:
# Averaging each feature map over dims 2 and 3 leaves one value per (batch, channel)
target_features.mean(dim=(2, 3)).size()

In [None]:
# Step size for the feature-statistics matching (gradient descent) loop below
lr = 100

In [None]:
# 8x-downsampled copy of the test image to optimise, with gradient tracking on
image5 = F.avg_pool2d(test_img1, 8).requires_grad_(True)

In [None]:
# The target statistic never changes, so compute it once and detach it from any autograd
# graph: backward() then only walks this iteration's graph and retain_graph is unnecessary.
target = target_features.mean(dim=[2, 3]).detach()

for _ in range(6):
    for _ in range(50):
        # Forward pass of network
        out = features_net_sliced(image5)
        features_net_sliced.zero_grad()
        # Loss: squared difference between the source's and target's per-channel mean features
        current = out.mean(dim=[2, 3])
        (current - target).pow(2).mean().backward()

        # Gradient DESCENT this time (we minimise the mismatch), done under no_grad
        # instead of via the discouraged `.data` attribute
        with torch.no_grad():
            image5 -= lr * image5.grad
            # clip the image to keep the pixel values within the original range (in place)
            clip(image5)
            image5.grad.zero_()

    # Upsample 1.3x; F.interpolate(..., align_corners=True) replaces the deprecated upsample_bilinear
    with torch.no_grad():
        image5 = F.interpolate(
            image5, scale_factor=1.3, mode="bilinear", align_corners=True
        )

    image5.requires_grad = True

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# detach from autograd, move to CPU, and reorder to (H, W, C) for matplotlib
np_img = image5[0].detach().cpu().permute(1, 2, 0).numpy()
image_norm = normalize_img(np_img)
ax.imshow(image_norm)