In [3]:

# module import
import os
import sys
from datetime import datetime
import pickle

import numpy as np
import PIL.Image
import torch
import torchvision
                
sys.path.append('../cnn_preferred')
from utils import normalise_img, clip_extreme_pixel, save_video, normalise_vid, get_cnn_features, img_deprocess, get_target_feature_shape
from activation_maximization import generate_preferred

In [4]:
# network load
# The model is switched to inference mode right away: AlexNet's classifier
# contains Dropout layers (see the module tree printed further down), which
# stay active in the default training mode and would make forward passes
# through them stochastic. `eval()` returns the module itself, so `net` is
# still the AlexNet instance.
net = torchvision.models.alexnet(pretrained=True).eval()

In [5]:
# Per-channel image mean and std used to pre-process / de-process the network
# input. Both are plain 1-D float arrays with one entry per colour channel.
# (A trailing comma previously turned `img_mean` into a 1-tuple wrapping the
# array, and the `np.float` alias no longer exists in recent NumPy releases;
# the builtin `float` yields the same float64 dtype.)
img_mean = np.array([0.485, 0.456, 0.406], dtype=float)
img_std = np.array([0.229, 0.224, 0.225], dtype=float)

# if the model input is for 0-1 range, norm = 255, elif 0-255, norm = 1
norm = 255

In [6]:
# Output location: every run writes into its own timestamped folder below
# `save_dir`, so results from earlier runs are never overwritten.
save_dir = '../result'
run_stamp = datetime.now().strftime('%Y%m%dT%H%M%S')
save_folder = '_'.join(['jupyter_demo_torch_simpleCNN_conv', run_stamp])
save_path = os.path.join(save_dir, save_folder)
# create the folder (and any missing parents); harmless if it already exists
os.makedirs(save_path, exist_ok=True)

In [7]:
# initial image for the optimization: uniform integer noise in [0, 255] with
# shape (height, width, 3)
h, w = 224, 224
# Draw from a dedicated, seeded generator so that restarting the kernel and
# re-running the notebook starts the optimisation from the same image. A local
# RandomState is used so NumPy's global random state is left untouched.
initial_input_seed = 0
initial_input = np.random.RandomState(initial_input_seed).randint(0, 256, (h, w, 3))

In [8]:
# Display the module tree; the indices listed under `features` are the ones
# referenced by layer strings such as "features[10]".
net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feature

In [12]:
# Layer to visualise, written as an expression relative to `net`
# (features[10] is a Conv2d(256, 256) layer in the module tree printed above).
# "features[8]" used to be assigned here first and was overwritten on the very
# next line, so only the effective value is kept.
target_layer = "features[10]"

In [13]:
# target layer setting: the helpers take the target layer(s) as a list of strings
exec_str_list = [target_layer]

## obtain target feature shape
# Rearrange the (H, W, C) image to channel-first order and prepend a
# length-1 axis, giving the (1, C, H, W) array that is wrapped in a tensor.
chw_input = np.transpose(initial_input, (2, 0, 1))
initial_torch_input = torch.Tensor(np.expand_dims(chw_input, axis=0))
# ask the helper for the shape of the target layer's activation for this input
feat_shape = get_target_feature_shape(net, initial_torch_input, exec_str_list)

In [14]:
# Options forwarded to `generate_preferred` as keyword arguments.
# The defaults quoted in the comments are the ones stated in the original notes.
opts = dict(
    # --- pre-processing of the network input ---
    img_mean=img_mean,  # mean used to pre-process the input image (default [0.485, 0.456, 0.406])
    img_std=img_std,    # std used to pre-process the input image (default [0.229, 0.224, 0.225])
    norm=norm,          # 255 if the model expects inputs in the 0-1 range, 1 if it expects 0-255 (default 255)

    # --- optimisation length and logging ---
    iter_n=200,     # total number of gradient-descent iterations (default 200)
    disp_every=1,   # print progress every n iterations (default 1)

    # --- intermediate results ---
    save_intermediate=True,            # whether to save intermediate results (default None)
    save_intermediate_every=10,        # save an intermediate result every n iterations (default 10)
    save_intermediate_path=save_path,  # where intermediate results are written (default None)

    # --- learning rate: changed linearly from start to end ---
    lr_start=1.0,  # (default 1.)
    lr_end=1.0,

    # --- gradient momentum: changed linearly from start to end ---
    momentum_start=0.001,  # (default 0.001)
    momentum_end=0.001,

    # --- pixel decay per iteration: changed linearly from start to end ---
    decay_start=0.001,  # (default 0.001)
    decay_end=0.001,

    # --- image smoothing ---
    image_blur=True,  # smooth the image during optimisation (default True)
    sigma_start=2.5,  # size of the Gaussian filter used for smoothing (default 2.5)
    sigma_end=0.5,

    # --- image jitter ---
    image_jitter=True,  # jitter the image during optimisation (default True)
    jitter_size=4,      # size of the jitter (default 32)

    # --- p-norm regularisation ---
    use_p_norm_reg=False,  # (default False)
    p=2,

    # --- total-variation norm regularisation ---
    use_TV_norm_reg=False,  # (default False)
    TVbeta1=1,              # order for the spatial domain
    TVbeta2=1.2,            # order for the temporal domain (video input)

    # --- clipping of pixels with extreme high or low values ---
    clip_small_norm=True,  # (default True)
    clip_small_norm_every=1,
    n_pct_start=5,
    n_pct_end=5,

    # --- clipping of pixels with a small contribution (norm over RGB channels) ---
    clip_small_contribution=True,
    clip_small_contribution_every=1,
    c_pct_start=5,
    c_pct_end=5,

    # --- input ---
    input_size=(224, 224, 3),
    initial_input=initial_input,  # starting image; None makes the optimisation start from random noise
)



In [15]:
# Channel indices of the target layer to visualise; one optimisation run per channel.
channel_list = [14,56]

In [16]:
for channel in channel_list:
    # announce the channel, with one blank line before and after
    print('\n' + 'channel=' + str(channel) + '\n')

    # target unit: the centre position of this channel's feature map
    feat_size = feat_shape
    y_index, x_index = int(feat_size[2] / 2), int(feat_size[3] / 2)
    # mask that is zero everywhere except at the single target unit
    feature_mask = np.zeros(feat_size)
    feature_mask[0, channel, y_index, x_index] = 1

    # activation maximization
    preferred_stim = generate_preferred(net, exec_str_list, feature_mask=feature_mask, **opts)

    # save the raw result; the .npy and .jpg files share one file stem
    result_stem = 'preferred_img_layer_{}_channel_{}'.format(target_layer, channel)
    save_name = result_stem + '.npy'
    np.save(os.path.join(save_path, save_name), preferred_stim)

    # Save a viewable copy. For display, pixels with extreme values are clipped
    # first (0.02% at the low end and 0.02% at the high end, per the original
    # note on pct=0.04) and the image is then normalised to [0, 255].
    save_name = result_stem + '.jpg'
    display_img = normalise_img(clip_extreme_pixel(preferred_stim, pct=0.04))
    PIL.Image.fromarray(display_img).save(os.path.join(save_path, save_name))


channel=14

iter=1; mean(abs(feat))=4.70402;
iter=2; mean(abs(feat))=16.5821;
iter=3; mean(abs(feat))=18.0405;
iter=4; mean(abs(feat))=14.6219;
iter=5; mean(abs(feat))=1.6333;
iter=6; mean(abs(feat))=29.4826;
iter=7; mean(abs(feat))=56.1883;
iter=8; mean(abs(feat))=8.13925;
iter=9; mean(abs(feat))=25.4075;
iter=10; mean(abs(feat))=23.2243;
iter=11; mean(abs(feat))=18.4643;
iter=12; mean(abs(feat))=47.6797;
iter=13; mean(abs(feat))=52.1476;
iter=14; mean(abs(feat))=81.0078;
iter=15; mean(abs(feat))=87.2539;
iter=16; mean(abs(feat))=66.9968;
iter=17; mean(abs(feat))=130.678;
iter=18; mean(abs(feat))=60.7477;
iter=19; mean(abs(feat))=100.525;
iter=20; mean(abs(feat))=141.018;
iter=21; mean(abs(feat))=101.869;
iter=22; mean(abs(feat))=108.077;
iter=23; mean(abs(feat))=144.94;
iter=24; mean(abs(feat))=89.3926;
iter=25; mean(abs(feat))=175.798;
iter=26; mean(abs(feat))=174.166;
iter=27; mean(abs(feat))=180.347;
iter=28; mean(abs(feat))=170.286;
iter=29; mean(abs(feat))=187.793;
iter=30; mea

iter=40; mean(abs(feat))=143.824;
iter=41; mean(abs(feat))=145.082;
iter=42; mean(abs(feat))=169.278;
iter=43; mean(abs(feat))=148.422;
iter=44; mean(abs(feat))=134.879;
iter=45; mean(abs(feat))=167.532;
iter=46; mean(abs(feat))=170.534;
iter=47; mean(abs(feat))=171.83;
iter=48; mean(abs(feat))=202.019;
iter=49; mean(abs(feat))=199.236;
iter=50; mean(abs(feat))=234.775;
iter=51; mean(abs(feat))=197.186;
iter=52; mean(abs(feat))=204.855;
iter=53; mean(abs(feat))=219.935;
iter=54; mean(abs(feat))=236.22;
iter=55; mean(abs(feat))=304.613;
iter=56; mean(abs(feat))=249.135;
iter=57; mean(abs(feat))=240.361;
iter=58; mean(abs(feat))=250.466;
iter=59; mean(abs(feat))=316.2;
iter=60; mean(abs(feat))=336.536;
iter=61; mean(abs(feat))=331.273;
iter=62; mean(abs(feat))=348.515;
iter=63; mean(abs(feat))=276.856;
iter=64; mean(abs(feat))=313.127;
iter=65; mean(abs(feat))=343.839;
iter=66; mean(abs(feat))=389.519;
iter=67; mean(abs(feat))=308.752;
iter=68; mean(abs(feat))=412.901;
iter=69; mean(abs(

In [25]:
# NOTE(review): `channel`, `y_index` and `x_index` are left over from the last
# iteration of the `for channel in channel_list` loop, so this records the
# final unit that was optimised and only works after that loop has run.
target = [channel, y_index, x_index]

In [26]:
# Show the [channel, y, x] triple of the recorded target unit.
target


[56, 6, 6]

In [31]:
# Show the mask plane of the recorded channel; the single 1 marks the target unit.
feature_mask[0, target[0]]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])