## Setup

In [None]:
import os
# Repository root: the parent of the notebook's working directory
# (``src/captum`` is expected to live under it). ``os.path.join`` avoids the
# ``//..`` -> ``//`` artifact that manual ``os.sep`` concatenation produces
# when the working directory is the filesystem root.
base_dir = os.path.normpath(os.path.join(os.getcwd(), os.pardir))

In [None]:
# import requests # request img from web
# import shutil # save img locally
# from pathlib import Path
# from PIL import Image
# import numpy as np
# import math
# import matplotlib.pyplot as plt
import torch
import torchvision
# from torchvision import transforms

In [None]:
import sys
# Make the local captum checkout importable. Insert it at the *front* of the
# search path so it takes precedence over any pip-installed captum, and skip
# the insert when the cell is re-run so sys.path does not accumulate duplicates.
_captum_src = os.path.join(base_dir, "src", "captum")
if _captum_src not in sys.path:
    sys.path.insert(0, _captum_src)
from captum import optim as optimviz

In [None]:
# Run on the first CUDA GPU when one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Utilities

In [None]:
def visualize(model: torch.nn.Module, target: torch.nn.Module, neuron: int = -1,
              neuron_obj: bool = True, lr: float = 0.025, *,
              image_size: int = 224, steps: int = 128) -> None:
    """Optimize an input image against a layer of ``model`` and display it.

    Args:
        model: Network to visualize. A captum ``optimviz.models.InceptionV1``
            skips the initial torchvision ImageNet normalization; every other
            model gets it prepended to the transform pipeline.
        target: Layer inside ``model`` whose activations define the objective.
        neuron: Unit index passed to the objective. Any value ``> -1`` selects
            a single-unit objective; the default ``-1`` (or any other negative
            value) falls back to ``DeepDream`` on the whole ``target`` layer.
        neuron_obj: When a unit is selected, use ``NeuronActivation`` if true,
            otherwise ``ChannelActivation``.
        lr: Learning rate forwarded to ``InputOptimization.optimize``.
        image_size: Height and width of the optimized image. The final center
            crop uses the same size, so the model always sees this resolution.
        steps: Number of optimization steps.
    """
    image = optimviz.images.NaturalImage((image_size, image_size)).to(device)

    # Random translation, scaling, and rotation help with visualization
    # quality. Reflection padding avoids transform edge artifacts and is
    # cropped away again by the final CenterCrop.
    augmentations = [
        torch.nn.ReflectionPad2d(16),
        optimviz.transforms.RandomSpatialJitter(16),
        optimviz.transforms.RandomScale(scale=(1, 0.975, 1.025, 0.95, 1.05)),
        torchvision.transforms.RandomRotation(degrees=(-5, 5)),
        optimviz.transforms.RandomSpatialJitter(8),
        optimviz.transforms.CenterCrop((image_size, image_size)),
    ]
    if not isinstance(model, optimviz.models.InceptionV1):
        # Normalization for torchvision models; the captum GoogLeNet model
        # does not need this initial normalization.
        augmentations.insert(
            0,
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        )
    transforms = torch.nn.Sequential(*augmentations)

    if neuron > -1:
        objective = (optimviz.loss.NeuronActivation if neuron_obj
                     else optimviz.loss.ChannelActivation)
        loss_fn = objective(target, neuron)
    else:
        loss_fn = optimviz.loss.DeepDream(target)

    obj = optimviz.InputOptimization(model, loss_fn, image, transforms)
    history = obj.optimize(
        optimviz.optimization.n_steps(steps, show_progress=False), lr=lr
    )

    print(f"""There are {len(history)} steps in the history.
    Initial loss is {history[0].item()}.
    Final loss is {history[-1].item()}.""")

    image().show()

## Load models

In [None]:
# torchvision's GoogLeNet fetched through torch.hub, pinned to the v0.10.0 tag.
googlenet_torchhub = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'googlenet',
    pretrained=True,
)
googlenet_torchhub = googlenet_torchhub.to(device)
googlenet_torchhub.eval()

In [None]:
# The same architecture loaded directly from the installed torchvision package.
googlenet_torchvision = torchvision.models.googlenet(pretrained=True)
googlenet_torchvision = googlenet_torchvision.to(device)
googlenet_torchvision.eval()

In [None]:
# Captum's GoogLeNet port, loaded via the public ``optimviz.models`` namespace
# (the same one ``visualize`` uses for ``InceptionV1``) rather than the private
# ``_image.inception_v1`` module path.
googlenet_captum = optimviz.models.googlenet(pretrained=True).to(device)
googlenet_captum.eval()

## Visualisations

In [None]:
# Unit 55 of the 1x1-conv branch of inception block 4e, addressed through each
# model's own layer naming (torchvision: inception4e.branch1.conv,
# captum: mixed4e.conv_1x1).
neuron = 55
target_torchhub, target_torchvision = (
    net.inception4e.branch1.conv
    for net in (googlenet_torchhub, googlenet_torchvision)
)
target_captum = googlenet_captum.mixed4e.conv_1x1

In [None]:
# Single-neuron objective on the torch.hub GoogLeNet.
visualize(googlenet_torchhub, target_torchhub, neuron=neuron)

In [None]:
# Same neuron on the torchvision-package GoogLeNet.
visualize(googlenet_torchvision, target_torchvision, neuron=neuron)

In [None]:
# Same neuron on captum's GoogLeNet port, for side-by-side comparison.
visualize(googlenet_captum, target_captum, neuron=neuron)

In [None]:
# A different unit (42) of the same captum layer.
visualize(googlenet_captum, target_captum, neuron=42)

In [None]:
# Identical call re-run: visualize applies random jitter/scale/rotation
# transforms, so repeating it shows run-to-run variation.
visualize(googlenet_captum, target_captum, neuron=42)