#### Set path

In [None]:
import os
import sys

# Make the local `scaling_mlps_mirror` checkout importable (it provides the
# `data_utils`, `models` and `utils` packages imported further down).
# The default is the original author's machine-specific location; set the
# SCALING_MLPS_ROOT environment variable to point at your own checkout.
_DEFAULT_REPO_ROOT = 'C:/Users/matth/Documents/ETHZ/01_DS/02_HS23/02_DeepLearning/03_Project/00_Testbed_DL/scaling_mlps_mirror'
_repo_root = os.environ.get('SCALING_MLPS_ROOT', _DEFAULT_REPO_ROOT)

# Only append once, so re-running this cell does not pile up duplicate entries.
if _repo_root not in sys.path:
    sys.path.append(_repo_root)


#### Import libraries

In [None]:
# Basic
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Torch
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# MLP model
from data_utils.data_stats import *
from models.networks import *
from utils.download import *

# CNN model
from torchvision.models import resnet50, ResNet50_Weights

# Others
from tqdm import tqdm
from data_utils.data_stats import *
import torchvision.transforms as T
import random
from scipy.ndimage import gaussian_filter1d
from PIL import Image
import time
import os


### Load Model

In [None]:
## Load standard mlp model
# All names assigned in this cell are notebook globals used by later cells.
dataset = 'imagenet'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'        # Other identifiers tried: 'B_6-Wi_512', 'B-12_Wi-1024_res_64_imagenet_epochs_50'
resolution = 64                      # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]    # Number of output classes, looked up by dataset name
# Checkpoint identifier. Going by its name this is presumably ImageNet21k
# pre-training followed by fine-tuning on ImageNet (matching `dataset` above);
# an earlier comment said CIFAR10, which does not match -- TODO confirm.
checkpoint = 'in21k_imagenet'
# Build the MLP on CPU with dropout disabled and load the checkpoint weights.
model_mlp = get_model(architecture=architecture, resolution=resolution, num_classes=num_classes,
                  checkpoint=checkpoint, load_device='cpu', dropout=False)
# Put the MLP into evaluation mode.
model_mlp.eval()


In [None]:
## Load CNN
# ResNet-50 from torchvision with the IMAGENET1K_V1 pre-trained weights.
model_cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Freeze all weights: later cells only optimize the input image, never the model.
for param in model_cnn.parameters():
    param.requires_grad = False

# Switch to evaluation mode (consistent with model_mlp). Freezing the weights
# above does NOT affect BatchNorm: in training mode it would normalize with the
# statistics of the single synthesized image and keep updating its running
# statistics on every forward pass, so the "pre-trained" model would drift
# while images are being generated.
model_cnn.eval()

In [None]:
# Define normalization constants
# Per-channel statistics for ImageNet taken from data_utils.data_stats and
# divided by 255 -- presumably to convert them from the 0-255 pixel scale to
# the [0, 1] range of tensors produced by T.ToTensor(); TODO confirm.
# They are used by preprocess()/deprocess() and for clamping generated images.
MEAN = MEAN_DICT["imagenet"]/255
STD = STD_DICT["imagenet"]/255

#### Define plot generating functions

In [None]:
# Helper functions
def preprocess(img, size=224):
    """
    Turn a PIL image into a normalized, batched tensor for the models.

    Inputs
    - img: PIL image
    - size: Target size passed to T.Resize (default 224)

    Returns: The resized image as a tensor normalized with the module-level
    MEAN / STD, with a leading batch dimension of 1 added.
    """
    transform = T.Compose([
        # T.Scale was removed from torchvision; T.Resize is its replacement
        # and takes the same `size` argument.
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=MEAN.tolist(),
                    std=STD.tolist()),
        # Add the batch dimension: x -> x[None]
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

def deprocess(img, should_rescale=True):
    """
    Convert a normalized, batched image tensor back into a PIL image.

    Inputs
    - img: Batched image tensor; only the first element (img[0]) is used
    - should_rescale: If True, min-max rescale the result with rescale()
      before converting it to a PIL image

    Returns: PIL image produced by T.ToPILImage()
    """
    steps = [
        # Drop the batch dimension
        T.Lambda(lambda batch: batch[0]),
        # Undo the normalization in two steps: first multiply by STD
        # (dividing by 1/STD), then add MEAN (subtracting -MEAN)
        T.Normalize(mean=[0, 0, 0], std=(1.0 / STD).tolist()),
        T.Normalize(mean=(-MEAN).tolist(), std=[1, 1, 1]),
    ]
    if should_rescale:
        steps.append(T.Lambda(rescale))
    steps.append(T.ToPILImage())

    pipeline = T.Compose(steps)
    return pipeline(img)

def rescale(x):
    """
    Min-max rescale x so that its smallest value maps to 0 and its largest to 1.

    Inputs
    - x: Array-like supporting .min(), .max() and arithmetic (e.g. a tensor)

    Returns: (x - min) / (max - min). For a constant input (max == min) the
    result is all zeros instead of the NaNs a 0 / 0 division would produce.
    """
    low, high = x.min(), x.max()
    span = high - low
    # Constant input: x - low is already all zeros, so divide by 1 to avoid 0 / 0
    denom = span if span != 0 else 1.0
    x_rescaled = (x - low) / denom
    return x_rescaled
    
def blur_image(X, sigma=1):
    """
    Gaussian-blur a tensor in place along its axes 2 and 3.

    Inputs
    - X: PyTorch Tensor with at least 4 dimensions (presumably (N, C, H, W),
      so that axes 2 and 3 are the spatial ones)
    - sigma: Standard deviation of the Gaussian kernel

    Returns: X itself, overwritten with the blurred values
    """
    # Work on a NumPy copy on the CPU, filtering one axis after the other
    blurred = X.cpu().clone().numpy()
    for axis in (2, 3):
        blurred = gaussian_filter1d(blurred, sigma, axis=axis)

    # Write the result back into X, matching X's tensor type
    X.copy_(torch.Tensor(blurred).type_as(X))
    return X

def jitter(X, ox, oy):
    """
    Circularly shift an image tensor along its last two axes.

    Inputs
    - X: PyTorch Tensor of shape (N, C, H, W)
    - ox: Integer shift along the W axis (dim 3); the last ox columns wrap
      around to the front. A negative value shifts the other way.
    - oy: Integer shift along the H axis (dim 2), same convention as ox.

    Returns: A PyTorch Tensor of shape (N, C, H, W). It is a new tensor when
    at least one offset is non-zero; when both are zero, X itself is returned.
    """
    shifted = X
    if ox:
        # Move the trailing ox columns in front of the remaining ones
        shifted = torch.cat([shifted[:, :, :, -ox:], shifted[:, :, :, :-ox]], dim=3)
    if oy:
        # Move the trailing oy rows in front of the remaining ones
        shifted = torch.cat([shifted[:, :, -oy:], shifted[:, :, :-oy]], dim=2)
    return shifted

In [None]:
def _show_and_save_snapshot(pil_image, title, filename, output_folder='./output_images'):
    """
    Display an intermediate image with matplotlib and save the figure as a PNG.

    Inputs
    - pil_image: Image to draw with plt.imshow
    - title: Figure title
    - filename: File name (without folder) of the PNG to write
    - output_folder: Folder the PNG is written to; created if missing
    """
    plt.imshow(pil_image)
    plt.title(title)
    plt.gcf().set_size_inches(4, 4)
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Ensure the output folder exists, save the figure, then display it
    os.makedirs(output_folder, exist_ok=True)
    plt.savefig(os.path.join(output_folder, filename))

    plt.show()


def create_activation_maximization(target_y, model, dtype, model_type, **kwargs):
    """
    Generate an image to maximize the score of target_y under a pretrained model.

    Starting from Gaussian noise, the image is updated by gradient ascent on
    the model's score for class target_y, with L2 regularization, random
    jitter, clamping and periodic blurring as regularizers.

    Inputs:
    - target_y: Integer index of the class whose score is maximized
      (used as scores[:, target_y])
    - model: The pretrained model used to generate the image. It is cast to
      dtype in place and is used in whatever train/eval mode it is currently in.
    - dtype: Torch tensor type (e.g. torch.FloatTensor) used for the model and
      the image
    - model_type: "mlp" or "cnn". "mlp" optimizes a 64x64 image that is fed to
      the model flattened to shape (1, 3*64*64); "cnn" optimizes a 224x224
      image fed as a (1, 3, 224, 224) tensor. Also selects the tuned defaults
      for the keyword arguments below.

    Keyword arguments (defaults depend on model_type; any other keyword is ignored):
    - l2_reg: Strength of L2 regularization on the image
    - learning_rate: How big of a step to take
    - num_iterations: How many iterations to use
    - blur_every: How often to blur the image as an implicit regularizer
      (0 disables blurring)
    - max_jitter: How much to jitter the image as an implicit regularizer
    - show_every: How often to show the intermediate result
      (0 shows only the first and the last iteration)

    Returns: The generated image as a PIL image (output of deprocess).

    Side effects: re-seeds the `random`, NumPy and PyTorch RNGs with a fixed
    seed, displays the intermediate images and saves each displayed figure as
    a PNG in ./output_images.
    """

    assert (model_type in ["cnn", "mlp"]), "Error: specified model type not implemented"

    # Fixed seed for all RNGs used below, so repeated calls give the same image
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Tuned parameter sets per model type; explicit keyword arguments win
    tuned_defaults = {
        "mlp": {'l2_reg': 1e-3, 'learning_rate': 20, 'num_iterations': 800,
                'blur_every': 5, 'max_jitter': 16, 'show_every': 25},
        "cnn": {'l2_reg': 1e-3, 'learning_rate': 5, 'num_iterations': 800,
                'blur_every': 40, 'max_jitter': 0, 'show_every': 25},
    }[model_type]
    l2_reg = kwargs.pop('l2_reg', tuned_defaults['l2_reg'])
    learning_rate = kwargs.pop('learning_rate', tuned_defaults['learning_rate'])
    num_iterations = kwargs.pop('num_iterations', tuned_defaults['num_iterations'])
    blur_every = kwargs.pop('blur_every', tuned_defaults['blur_every'])
    max_jitter = kwargs.pop('max_jitter', tuned_defaults['max_jitter'])
    show_every = kwargs.pop('show_every', tuned_defaults['show_every'])

    model.type(dtype)

    # Randomly initialize the image. All updates below are applied in place to
    # `img` in (N, C, H, W) layout.
    is_mlp = model_type == "mlp"
    side = 64 if is_mlp else 224
    img = torch.randn(1, 3, side, side).mul_(1.0).type(dtype)

    # The MLP takes the image flattened to (N, C*H*W); the reshape is a view of
    # `img`. The loop relies on img_var sharing memory with `img`, so that the
    # in-place updates of `img` are what the model sees in the next forward pass.
    model_input = img.reshape(img.shape[0], -1) if is_mlp else img
    img_var = Variable(model_input, requires_grad=True)

    # Per-channel clamp range: the normalized values of pixel intensities 0 and 1
    clamp_bounds = [
        (float(-MEAN[c] / STD[c]), float((1.0 - MEAN[c]) / STD[c]))
        for c in range(3)
    ]

    # Perform activation maximization
    for t in range(num_iterations):

        # Randomly jitter the image a bit; this gives slightly nicer results
        ox, oy = random.randint(0, max_jitter), random.randint(0, max_jitter)
        img.copy_(jitter(img, ox, oy))

        # Perform forward and backward pass
        scores = model(img_var)
        scores[:, target_y].backward()

        # Compute the regularized gradient and make an update step
        # (the gradient has the shape of the model input, so bring it back to
        # the image layout first)
        grad = img_var.grad.data.reshape(img.shape) - 2 * l2_reg * img
        img += learning_rate * grad
        img_var.grad.data.zero_()

        # Undo the random jitter
        img.copy_(jitter(img, -ox, -oy))

        # As regularizer, clamp the image to the valid pixel range
        for c, (lo, hi) in enumerate(clamp_bounds):
            img[:, c].clamp_(min=lo, max=hi)

        # As regularizer, periodically blur the image
        if blur_every and t % blur_every == 0:
            blur_image(img, sigma=0.5)

        # Periodically show the image (always for the first and last iteration)
        is_last = t == num_iterations - 1
        if t == 0 or is_last or (show_every and (t + 1) % show_every == 0):
            # A timestamp alone has one-second resolution, so snapshots taken
            # within the same second would overwrite each other; the class
            # index and iteration number keep the file names unique.
            timestr = time.strftime("%Y%m%d-%H%M%S")
            output_filename = '%s_class%d_iter%04d.png' % (timestr, int(target_y), t + 1)
            _show_and_save_snapshot(
                deprocess(img.clone().cpu()),
                'Iteration %d / %d' % (t + 1, num_iterations),
                output_filename,
            )

    return deprocess(img.cpu())

#### Generate activation maximization images

In [None]:
# Configure matplotlib for this notebook: render figures inline (IPython
# magic) and set the default figure size, image interpolation and colormap.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest' # no smoothing between pixels
plt.rcParams['image.cmap'] = 'gray' # default colormap

In [None]:
# Select model for activation maximization.
# `model` and `model_type` must be switched together:
# model_cnn with "cnn", or model_mlp with "mlp".
model = model_cnn # alternative: model_mlp
model_type = "cnn" # alternative: "mlp"

# Define target output: class indices to generate images for, one image per
# entry. Commented-out entries are other candidate classes that are currently
# disabled; the trailing comments give the class name.
target_y_arr = np.array([
                         #404, # airliner (airplane)
                         294, # brown bear (bear)
                         #671, # mountain bike (bicycle)
                         #15, # robin (bird)
                         #814, # speedboat (boat)
                         #440, # beer bottle (bottle)
                         #817, # sports car (car)
                         #284, # siamese cat (cat)
                         559, # folding chair (chair)
                         892, # wall clock (clock)
                         #235, # german shepherd (dog)
                         386, # african elephant (elephant)
                         #508, # computer keyboard (keyboard)
                         #623, # cleaver (knife)
                         #766, # rotisserie (oven)
                         #867 # trailer truck (truck)
                         ])

In [None]:
# Tensor type used for both the model and the generated image (CPU floats).
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to use GPU
# Cast the selected model to that type (create_activation_maximization does
# the same again on every call).
model.type(dtype)

# Generate one image per target class. The function displays and saves the
# intermediate images itself; `out` is overwritten on every iteration, so only
# the image of the last class stays bound after the loop.
for target_y in target_y_arr:
    out = create_activation_maximization(target_y, model, dtype, model_type = model_type)