In [1]:
import torch
from torchvision.models import alexnet
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import warnings
import os
import pdb
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import itertools
import pickle
import warnings
from scipy.optimize import fmin
from scipy.stats import norm
import random
import pdb
import matplotlib.colors as mcolors
import timm
import itertools
import pandas as pd
import dill
from scipy.optimize import curve_fit
from collections import OrderedDict
import cornet
from cornet.cornet_z import CORnet_Z

warnings.filterwarnings("ignore")



In [13]:
def get_model(model_name):
    """Load the CORnet-Z checkpoint of one epoch and hook its layer outputs.

    Parameters
    ----------
    model_name : int
        Epoch number. Selects the checkpoint file
        ``Zepochs/epoch_<model_name>.pth.tar`` and, multiplied by 5005, the
        key into ``Zepochs/results.pkl`` from which the accuracy is read.
        (5005 is presumably the number of training steps per epoch -- TODO
        confirm against the training script.)

    Returns
    -------
    model : CORnet_Z
        The network with the checkpoint weights loaded, in eval mode, with a
        forward hook registered on every layer in ``layers_to_extract``.
    activations : list
        The list the hooks append to. Every forward pass appends one output
        per hooked layer, so callers should ``clear()`` it between passes.
    layers_to_extract : list of str
        Names (as reported by ``model.named_modules()``) of the hooked layers.
    INacc
        The value stored at ``results[model_name * 5005]['val']['top1']``.

    Raises
    ------
    KeyError
        If a name in ``layers_to_extract`` is not a module of the model, or
        if the results file has no entry for this epoch.
    """
    activations = []

    def get_activation(name):
        # `name` is not used by the hook; every hook appends to the same list.
        def hook(model, input, output):
            activations.append(output)
        return hook

    layers_to_extract = ['V1.output', 'V2.output', 'V4.output', 'IT.output', 'decoder.output']

    # Load the checkpoint onto the CPU regardless of where it was saved.
    checkpoint_path = 'Zepochs/epoch_' + str(model_name) + '.pth.tar'
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Parameter names saved from a wrapped model (e.g. torch.nn.DataParallel)
    # carry a leading "module." that a bare CORnet_Z does not expect. Strip it
    # only where it is a prefix so the rest of the key is never altered.
    prefix = "module."
    state_dict = checkpoint['state_dict']
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Recreate the CORnet-Z model and load the weights.
    model = CORnet_Z()
    model.load_state_dict(new_state_dict)
    model.eval()

    # NOTE: pickle.load can execute arbitrary code; only use a results file
    # produced by a trusted training run.
    with open('Zepochs/results.pkl', 'rb') as f:
        acc = pickle.load(f)

    INacc = acc[model_name * 5005]['val']['top1']

    # Register hooks to capture the outputs of the selected layers. The
    # name -> module mapping is built once instead of once per layer.
    named_modules = dict(model.named_modules())
    for layer_name in layers_to_extract:
        named_modules[layer_name].register_forward_hook(get_activation(layer_name))
    return model, activations, layers_to_extract, INacc


def preprocess_image(img_path, input_size):
    """Load an image file and turn it into a normalised model input.

    The image is converted to RGB, resized and centre-cropped to
    ``input_size``, converted to a tensor and normalised per channel with the
    mean/std values below (the commonly used ImageNet statistics).

    Parameters
    ----------
    img_path : str or path-like
        Location of the image file.
    input_size : int or sequence
        Size handed to both ``transforms.Resize`` and ``transforms.CenterCrop``.

    Returns
    -------
    torch.Tensor
        The processed image with a leading batch dimension of size 1.
    """
    preprocess = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    with Image.open(img_path) as image:
        # PNG files may be RGBA, palette or greyscale; Normalize above needs
        # exactly three channels, so force RGB before transforming.
        tensor = preprocess(image.convert('RGB'))
    # Create a mini-batch as expected by the model.
    return tensor.unsqueeze(0)

def load_image(img_path, input_size):
    """Load an image file as an un-normalised tensor for display.

    Applies the same resize and centre-crop to ``input_size`` as
    ``preprocess_image`` but skips the channel normalisation.

    Parameters
    ----------
    img_path : str or path-like
        Location of the image file.
    input_size : int or sequence
        Size handed to both ``transforms.Resize`` and ``transforms.CenterCrop``.

    Returns
    -------
    torch.Tensor
        The image tensor with a leading batch dimension of size 1.
    """
    resize_crop_to_tensor = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ])
    picture = resize_crop_to_tensor(Image.open(img_path))
    # Prepend the batch axis so the shape matches a one-image mini-batch.
    return picture.unsqueeze(0)

# define loss function to compute scaling values
def compute_loss(params, all_combinations, biggerDiff):
    """Negative log-likelihood of paired-difference responses (MLDS fit).

    Each row ``(a, b, c, d)`` of ``all_combinations`` compares the interval
    ``|psi[a] - psi[b]|`` with ``|psi[c] - psi[d]|``. The difference of the two
    intervals is treated as the mean of a zero-centred Gaussian decision
    variable with standard deviation ``sigma``.

    Parameters
    ----------
    params : sequence of float
        The scale values ``psi`` followed by ``sigma`` as the last element
        (12 numbers for the 11 levels used in this notebook).
    all_combinations : array-like, shape (n_trials, 4)
        Level indices for every trial. A level without a ``psi`` entry keeps
        the value ``level / 10``.
    biggerDiff : sequence of int, length n_trials
        Response per trial: 1 if the first interval was judged larger, 2 if
        the second was, anything else (e.g. 0 for a tie) is ignored.

    Returns
    -------
    float
        The summed negative log-likelihood, or ``inf`` when ``sigma`` is not
        strictly positive (which is not a valid standard deviation and would
        otherwise produce NaN and confuse the optimiser).
    """
    params = np.asarray(params, dtype=float)
    n_levels = len(params) - 1
    psi = params[:n_levels]
    sigma = params[n_levels]

    # `not sigma > 0` also catches NaN.
    if not sigma > 0:
        return np.inf

    combos = np.asarray(all_combinations)

    # True division yields a fresh float array, so the caller's array is never
    # modified and psi values are not truncated to integers.
    scale_values = combos / 10
    for level in range(n_levels):
        scale_values[combos == level] = psi[level]

    # Difference between the two interval lengths, one value per trial.
    diffs = (np.abs(scale_values[:, 0] - scale_values[:, 1])
             - np.abs(scale_values[:, 2] - scale_values[:, 3]))

    responses = np.asarray(biggerDiff)
    first_larger = responses == 1
    second_larger = responses == 2

    # logcdf / logsf stay finite in the tails where log(cdf) and log(1 - cdf)
    # would underflow to log(0) = -inf.
    log_likelihood = (norm.logcdf(diffs[first_larger], 0, sigma).sum()
                      + norm.logsf(diffs[second_larger], 0, sigma).sum())
    return -log_likelihood

# Every unordered pair (low, high) with low < high drawn from the indices 0..10
pairs = [(low, high) for low in range(11) for high in range(low + 1, 11)]

# Every unordered pair of those pairs, flattened into one row of four indices
all_combinations = np.array(
    [[*first, *second] for first, second in itertools.combinations(pairs, 2)]
)

In [21]:
# Which epochs to analyse. range(13, 29) yields 13 through 28 (the end value
# is excluded); each value is passed to get_model() as the epoch number.
model_names = range(13, 29)

In [None]:
# Explicit import: scipy sub-packages are not guaranteed to be loaded by a
# bare `import scipy` / `import scipy.io`.
import scipy.spatial.distance

# Result containers persist across runs of this cell so epochs can be processed
# in several batches or restored from a saved session. Any container that does
# not exist yet is created, so the cell also runs on a fresh kernel. To wipe
# everything and start over, set these names to {} before running the cell.
for _resultName in ('Psis', 'mldsSigmas', 'allSpecs', 'allINacc'):
    globals().setdefault(_resultName, {})

# Which stimulus set to read: the default 10x10 images, the activation-loss
# images (activationLoss = 1) or the fades (fade = 1; wins if both are set).
activationLoss = 0
fade = 0

# Texture pairs and interpolation values (identical for every epoch).
image1Names = ['acorns', 'grass', 'lemons', 'pebbles', 'petals', 'bees', 'iceCream', 'corn', 'guacamole', 'rubies', 'blueberries']
image2Names = ['redwood', 'leaves', 'bananas', 'granite', 'buttercream', 'pineapple', 'gooseFeathers', 'balloons', 'brainCoral', 'cherries', 'beads']
interpValues = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
interpNames = [f"{string1}_{string2}" for string1, string2 in zip(image1Names, image2Names)]
nLevels = len(interpValues)

# Trials as plain int rows, plus the distinct (i, j) pairs that occur in them.
# Each pair distance is computed once per layer/texture pair instead of once
# per trial.
combos = all_combinations.astype(int).tolist()
neededPairs = sorted({(a, b) for a, b, _, _ in combos} | {(c, d) for _, _, c, d in combos})

# Loop through all model epochs (model names)
for model_name in model_names:
    print("epoch:" + str(model_name))

    # Load the model, initialise the activation hooks, and get the ImageNet accuracy
    model, activations, layers_to_extract, INacc = get_model(model_name)
    allINacc[model_name] = INacc

    # Per-epoch stores: activations and display image per stimulus, in order
    allActivations = {}
    allImages = {}
    imageNames = []

    # Extract the activations for every stimulus
    for imageIndex, interpName in enumerate(interpNames):
        for interpValue in interpValues:
            stem = f"{image1Names[imageIndex]}_{interpValue}_{image2Names[imageIndex]}"
            if fade == 1:
                img_path = 'fades/' + stem + '_pool4_smp1.png'
            elif activationLoss == 1:
                img_path = 'out_activations_bal/' + stem + '_1x1' + '_pool4_smp1.png'
            else:
                img_path = os.path.join('out', interpName, '10x10', stem + '_10x10_pool4_smp1.png')
            imageName = f"{interpName}_{interpValue}"

            x = preprocess_image(img_path, 224)
            unprocessed_image = load_image(img_path, 224)

            # The hooks append on every forward pass, so start from empty
            activations.clear()

            # Forward pass through the model
            with torch.no_grad():
                model(x)

            # Save outputs (one array per hooked layer) and the display image
            allActivations[imageName] = [activation.numpy() for activation in activations]
            allImages[imageName] = unprocessed_image.squeeze(0).permute(1, 2, 0).numpy()
            imageNames.append(imageName)

    ## Now do the MLDS
    # Pick which layers to look at (all hooked layers)
    layers = range(len(layers_to_extract))

    # One figure per epoch: a row per layer, a column per texture pair
    fig = plt.figure(figsize=(15, 4 * len(layers)))

    # Empty dicts to store the fitted values for this model
    Psis[model_name] = {}
    mldsSigmas[model_name] = {}

    for sub, layer in enumerate(layers):
        for interpPair in range(len(interpNames)):
            # Index of this texture pair's first stimulus in imageNames
            offset = interpPair * nLevels

            # Flatten each stimulus' activation once
            vectors = [allActivations[imageNames[offset + level]][layer].flatten()
                       for level in range(nLevels)]

            # Cosine dissimilarity for every distinct pair of stimuli
            pairDist = {(a, b): scipy.spatial.distance.cosine(vectors[a], vectors[b])
                        for a, b in neededPairs}

            # Simulate the 2AFC decisions: 1 if the first pair is further
            # apart, 2 if the second is, 0 on a tie
            biggerDiff = []
            for im1, im2, im3, im4 in combos:
                dist12 = pairDist[(im1, im2)]
                dist34 = pairDist[(im3, im4)]
                if dist12 > dist34:
                    biggerDiff.append(1)
                elif dist34 > dist12:
                    biggerDiff.append(2)
                else:
                    biggerDiff.append(0)

            # Initial parameters: a linear scale and sigma = 0.2
            psi = np.arange(nLevels) / (nLevels - 1)
            sigma = 0.2
            initial_params = np.concatenate((psi, [sigma]))

            # Search for the parameters
            optimal_params = fmin(compute_loss, initial_params, args=(all_combinations, biggerDiff))

            # Rescale psi to span [0, 1]; skip the division for a flat scale
            psi = optimal_params[:nLevels]
            psi = psi - np.min(psi)
            psiSpan = np.max(psi)
            if psiSpan > 0:
                psi = psi / psiSpan

            # Save the fit parameters (psi and sigma values)
            Psis[model_name][(layer, interpPair)] = psi
            mldsSigmas[model_name][(layer, interpPair)] = optimal_params[nLevels]

            # Plot: each stimulus image is drawn at (interpolation value, psi)
            ax = fig.add_subplot(len(layers), len(image1Names), interpPair + 1 + sub * len(image1Names))
            for interp_value in range(nLevels):
                im = offset + interp_value
                ax.imshow(allImages[imageNames[im]],
                          extent=[interp_value / 10 - 0.05, interp_value / 10 + 0.05,
                                  psi[interp_value] - 0.05, psi[interp_value] + 0.05])

            # Set axis limits and draw the identity line for reference
            ax.set_xlim([-0.05, 1.05])
            ax.set_ylim([-0.05, 1.05])
            ax.plot([0, 1], [0, 1], 'r')

    # Label and show this epoch's figure, then release it so that figures do
    # not pile up in memory over the epochs
    fig.supylabel('Perceptual distance value')
    fig.supxlabel('Synthesized interpolation value')
    plt.show()
    plt.close(fig)
    

epoch:13
Optimization terminated successfully.
         Current function value: 207.509517
         Iterations: 740
         Function evaluations: 1042
Optimization terminated successfully.
         Current function value: 221.431437
         Iterations: 660
         Function evaluations: 935
Optimization terminated successfully.
         Current function value: 217.738424
         Iterations: 780
         Function evaluations: 1108
Optimization terminated successfully.
         Current function value: 341.832491
         Iterations: 953
         Function evaluations: 1329
Optimization terminated successfully.
         Current function value: 153.280415
         Iterations: 835
         Function evaluations: 1178
Optimization terminated successfully.
         Current function value: 150.646214
         Iterations: 667
         Function evaluations: 944
Optimization terminated successfully.
         Current function value: 213.318185
         Iterations: 683
         Function evaluation

Optimization terminated successfully.
         Current function value: 153.280415
         Iterations: 835
         Function evaluations: 1178
Optimization terminated successfully.
         Current function value: 150.646214
         Iterations: 667
         Function evaluations: 944
Optimization terminated successfully.
         Current function value: 213.318185
         Iterations: 683
         Function evaluations: 964
Optimization terminated successfully.
         Current function value: 301.475839
         Iterations: 525
         Function evaluations: 755
Optimization terminated successfully.
         Current function value: 221.280803
         Iterations: 649
         Function evaluations: 926
Optimization terminated successfully.
         Current function value: 206.246197
         Iterations: 588
         Function evaluations: 858
Optimization terminated successfully.
         Current function value: 160.509933
         Iterations: 847
         Function evaluations: 1166
Opti

Optimization terminated successfully.
         Current function value: 305.373081
         Iterations: 748
         Function evaluations: 1056
Optimization terminated successfully.
         Current function value: 219.297014
         Iterations: 625
         Function evaluations: 886
Optimization terminated successfully.
         Current function value: 206.246197
         Iterations: 588
         Function evaluations: 858
Optimization terminated successfully.
         Current function value: 162.182210
         Iterations: 877
         Function evaluations: 1207
Optimization terminated successfully.
         Current function value: 132.745166
         Iterations: 586
         Function evaluations: 846
Optimization terminated successfully.
         Current function value: 131.391782
         Iterations: 721
         Function evaluations: 1022
Optimization terminated successfully.
         Current function value: 122.419805
         Iterations: 649
         Function evaluations: 911
Opt

In [8]:
# IPython automagic form of %who: prints the names defined in the interactive namespace.
who

CORnet_Z	 Image	 OrderedDict	 activationLoss	 alexnet	 all_combinations	 compute_loss	 cornet	 curve_fit	 
dill	 fade	 fmin	 get_model	 itertools	 load_image	 mcolors	 model_name	 model_names	 
models	 nn	 norm	 np	 os	 pairs	 pd	 pdb	 pickle	 


In [None]:
# Restoring an earlier session is left disabled. NOTE: loading a session
# unpickles the file, so only ever load files you created yourself.
#dill.load_session('1119alexnet.db')
# Save the current interpreter session to '1126cornet.db'.
dill.dump_session('1126cornet.db')