## Set-up

In [None]:
import ast
import re

import requests
import torch
from torchvision import transforms
from torchvision.models import resnet50
from torchvision.utils import make_grid

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Build ResNet-50 with its pretrained weights, then move it onto the selected device.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version before switching.
resnet = resnet50(pretrained=True)
resnet = resnet.to(device)

In [None]:
# Switch the network to evaluation mode. Module.eval() returns the module itself,
# so this single expression also displays the architecture as the cell result.
resnet.eval()

In [None]:
# Download the sample image and the ImageNet class-index -> label mapping.
# The image is requested with stream=True because a later cell reads the raw byte
# stream directly.
# A timeout keeps the cell from hanging indefinitely, and raise_for_status() reports
# HTTP failures here instead of as an obscure decode/parse error further down.
image_response = requests.get(
    'https://upload.wikimedia.org/wikipedia/commons/5/56/White_shark.jpg',
    stream=True,
    timeout=30,
)
image_response.raise_for_status()
label_response = requests.get(
    'https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt',
    timeout=30,
)
label_response.raise_for_status()

In [None]:
# Decode the streamed response with PIL. `functional_pil` is a private torchvision
# module that newer releases no longer expose, so PIL's Image module is reached through
# the public `transforms.functional` namespace instead. convert('RGB') guarantees the
# three channels the network's first convolution expects.
image = transforms.functional.Image.open(image_response.raw).convert('RGB')
# PIL reports size as (width, height) while resize() takes (height, width); halve both.
width, height = image.size
image = transforms.functional.resize(image, (height // 2, width // 2))
# NOTE(review): the tensor is left in [0, 1] without ImageNet mean/std normalization,
# which pretrained torchvision classifiers are normally fed -- confirm this is
# intentional (the occlusion heatmap later displays a copy of this tensor as an image).
image_tensor = transforms.functional.to_tensor(image).to(device)
image

## Visualize the Activations (Outputs) of Selected Layers

#### Plot the output of every channel for the first layer/activation

In [None]:
# Run only the network stem (conv1 -> bn1 -> relu) on a batch of one image, keeping
# each intermediate result; gradients are not needed for visualization.
# NOTE(review): torchvision's ResNet presumably builds its ReLU with inplace=True, in
# which case bn1_output and first_activation_output share storage -- confirm.
single_image_batch = image_tensor.unsqueeze(0)
with torch.no_grad():
    conv1_output = resnet.conv1(single_image_batch)
    bn1_output = resnet.bn1(conv1_output)
    first_activation_output = resnet.relu(bn1_output)

In [None]:
# Display the shape of the first activation (batch axis first, since the input was
# unsqueezed into a batch of one).
first_activation_output.shape

In [None]:
# Swap the batch and channel axes so make_grid tiles one single-channel image per
# feature map. ReLU outputs are unbounded above, and to_pil_image() scales floats by 255
# before casting to uint8, so values above 1 would wrap around; normalize=True rescales
# the whole grid into [0, 1] first.
channel_images = first_activation_output.permute(1, 0, 2, 3)
transforms.functional.to_pil_image(make_grid(channel_images, normalize=True))

#### Plot the output for every channel in the last CNN layer/activation

In [None]:
# Collect the top-level sub-modules up to and including 'layer4', then feed the image
# through them in order; whatever follows layer4 is skipped.
feature_stages = []
for stage_name, stage in resnet.named_children():
    feature_stages.append(stage)
    if stage_name == 'layer4':
        break

with torch.no_grad():
    x = image_tensor.unsqueeze(0)
    for stage in feature_stages:
        x = stage(x)
x.shape

In [None]:
# One tile per layer4 output channel, 56 tiles per row. The activations are not confined
# to [0, 1], so normalize=True rescales them before to_pil_image() casts to uint8
# (otherwise out-of-range values wrap around and corrupt the picture).
layer4_channel_images = x.permute(1, 0, 2, 3)
transforms.functional.to_pil_image(make_grid(layer4_channel_images, nrow=56, normalize=True))

## Visualize the weights (of each filter)

In [None]:
# Display the shape of the first convolution's weight tensor.
resnet.conv1.weight.shape

In [None]:
# Tile the first-layer filters as small RGB images. Weights are signed and not limited
# to [0, 1], so normalize=True rescales them before the uint8 conversion in
# to_pil_image(); detach() keeps the parameter's autograd state out of the picture.
transforms.functional.to_pil_image(make_grid(resnet.conv1.weight.detach(), normalize=True))

#### Now zoom in on the above filters

In [None]:
# Same filter grid as above, enlarged so the individual kernels are visible.
# normalize=True rescales the signed weights into [0, 1] so the uint8 conversion in
# to_pil_image() does not wrap; detach() drops the parameter's autograd tracking.
conv1_filter_grid = make_grid(resnet.conv1.weight.detach(), normalize=True)
transforms.functional.to_pil_image(transforms.functional.resize(conv1_filter_grid, 56 * 8))

#### Now visualize **a small portion** (the first 3 of 512 input channels) of the `layer4[1].conv2` weights

In [None]:
# Display the shape of the weight tensor of conv2 in the second block of layer4.
resnet.layer4[1].conv2.weight.shape

In [None]:
# Keep only the first three slices along dim 1 of every filter so each one can be drawn
# as an RGB tile, 32 tiles per row. The weights are signed, so normalize=True rescales
# them into [0, 1] before to_pil_image() casts to uint8; detach() drops autograd tracking.
layer4_filters = resnet.layer4[1].conv2.weight.detach()[:, :3, :, :]
transforms.functional.to_pil_image(make_grid(layer4_filters, nrow=32, normalize=True))

#### Now zoom in on the above image

In [None]:
# Enlarged view of the layer4[1].conv2 filter grid (first three slices along dim 1 of
# each filter, 32 tiles per row). normalize=True rescales the signed weights into [0, 1]
# so the uint8 conversion in to_pil_image() does not wrap; detach() drops autograd tracking.
layer4_filters = resnet.layer4[1].conv2.weight.detach()[:, :3, :, :]
layer4_filter_grid = make_grid(layer4_filters, nrow=32, normalize=True)
transforms.functional.to_pil_image(transforms.functional.resize(layer4_filter_grid, 128 * 3))

## Occlude the image one block at a time to find the regions that drive the most probable classification

In [None]:
# Display the image tensor's shape; the occlusion loop below slides over its last two axes.
image_tensor.shape

In [None]:
# Occlusion sensitivity: slide a zeroed square over the image, re-classify, and record
# how the probability of the originally predicted class changes.
block_size = 20    # side length of the occluding square, in pixels
stride = 10        # step between successive block positions
threshold = 0.03   # minimum probability change that counts as significant
heatmap = torch.clone(image_tensor)  # start from the image so hotspots are overlaid on it
count = 0  # not used in this cell; kept so the notebook namespace is unchanged

with torch.no_grad():
    # Baseline: the class the model predicts for the unoccluded image, and its probability.
    initial_prediction = torch.softmax(resnet(image_tensor.unsqueeze(0)), dim=-1)
    prediction_index = torch.argmax(initial_prediction, dim=-1).item()
    prediction_probability = initial_prediction[0][prediction_index]

    # The "+ 1" makes the last fully-contained block position reachable; without it the
    # block flush against the bottom/right edge is skipped whenever it lands on the stride.
    for i in range(0, image_tensor.shape[1] - block_size + 1, stride):
        for j in range(0, image_tensor.shape[2] - block_size + 1, stride):
            blocked_image = torch.clone(image_tensor)  # clone stays on image_tensor's device
            # Assign a scalar rather than building and copying a CPU zeros tensor each time.
            blocked_image[:, i:i + block_size, j:j + block_size] = 0.
            prediction = torch.softmax(resnet(blocked_image.unsqueeze(0)), dim=-1)
            occluded_probability = prediction[0][prediction_index]
            # Red channel: hiding this block made the class noticeably MORE likely.
            red = 1. if occluded_probability > prediction_probability + threshold else 0.
            # Blue channel: hiding this block made the class noticeably LESS likely.
            blue = 1. if occluded_probability < prediction_probability - threshold else 0.
            heatmap[0, i:i + block_size, j:j + block_size] += red
            heatmap[2, i:i + block_size, j:j + block_size] += blue

# Overlapping blocks can push a channel above 1; clamp back to a displayable range.
heatmap = torch.clamp(heatmap, min=0, max=1)
heatmap

In [None]:
# Render the clamped heatmap tensor as a PIL image.
heatmap_grid = make_grid(heatmap)
transforms.functional.to_pil_image(heatmap_grid)

## Making a classification prediction

In [None]:
# Classify the image. The model takes batches, so add a leading axis to form a batch of one.
image_batch = image_tensor.unsqueeze(0)
with torch.no_grad():
    prediction = resnet(image_batch)
print('size of response for classification:', prediction.shape)

In [None]:
# The downloaded file is a Python dict literal mapping class index -> label text.
# It comes from the network, so parse it with ast.literal_eval (literals only) rather
# than eval(), which would execute arbitrary code embedded in the response.
labels = ast.literal_eval(label_response.text)
len(labels), list(labels.items())[:5]

In [None]:
# Turn the raw scores into probabilities, then keep the five most likely classes.
class_probabilities = torch.softmax(prediction, dim=-1)
topk = class_probabilities.topk(5, dim=-1)
topk

In [None]:
# Report the single most likely class. .item() unwraps the 0-dim tensors so the message
# shows a plain number instead of a `tensor(...)` repr, matching the label lookup below.
top_probability = topk.values[0][0].item()
top_index = topk.indices[0][0].item()
print(f'Made the following prediction with a softmax probability of {top_probability}:')
labels[top_index]

## TO DO: Retrieving images that maximally activate a neuron

## TO DO: Embedding the codes with t-SNE