In [None]:
import os
import sys
#from typing import Union, Any, Optional, Callable, Tuple

_notebook_dir = os.getcwd()
# Repository root = two directory levels above the notebook's working directory
# (presumably the folder that contains the `DeepSparseCoding` package).
ROOT_DIR = os.path.dirname(os.path.dirname(_notebook_dir))
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

import numpy as np
import pandas as pd
import proplot as plot
from scipy.stats import pearsonr as linear_correlation
import tensorflow as tf
#import eagerpy as ep
#import torch
#from torch import nn, optim
#from torch.nn import functional as F

from DeepSparseCoding.tf1x.utils.logger import Logger as tfLogger
import DeepSparseCoding.tf1x.analysis.analysis_picker as ap
#from DeepSparseCoding.tf1x.data.dataset import Dataset
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as tfdp
import DeepSparseCoding.tf1x.params.param_picker as pp
import DeepSparseCoding.tf1x.models.model_picker as mp

#from DeepSparseCoding.utils.file_utils import Logger
#import DeepSparseCoding.utils.dataset_utils as dataset_utils
#import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.plot_functions as pf

#import foolbox
#from foolbox import PyTorchModel
#from foolbox.attacks.projected_gradient_descent import LinfProjectedGradientDescentAttack
#from foolbox.types import Bounds
#from foolbox.models.base import Model
#from foolbox.attacks.base import T
#from foolbox.criteria import Misclassification
#from foolbox.attacks.base import raise_if_kwargs
#from foolbox.attacks.base import get_criterion

rand_seed = 123
# Dedicated, seeded generator for notebook-level sampling.
rand_state = np.random.RandomState(rand_seed)
# Also seed numpy's global RNG so that any code drawing from `np.random`
# directly (rather than from rand_state) is reproducible as well.
np.random.seed(rand_seed)

### Load DeepSparseCoding analyzers

In [None]:
class params(object):
  """Analysis and data-loading settings consumed by the DSC analyzer / data selector."""
  def __init__(self):
    # --- analysis bookkeeping ---
    self.device = "/gpu:0"
    self.analysis_dataset = "test"
    self.save_info = f"analysis_{self.analysis_dataset}"
    self.overwrite_analysis_log = False
    # All optional analyzer stages are switched off.
    for stage_flag in ("do_class_adversaries", "do_run_analysis", "do_evals",
                       "do_basis_analysis", "do_inference", "do_atas",
                       "do_recon_adversaries", "do_neuron_visualization",
                       "do_full_recon", "do_orientation_analysis",
                       "do_group_recons"):
      setattr(self, stage_flag, False)

    # --- dataset & preprocessing ---
    self.data_dir = os.path.join(ROOT_DIR, 'Datasets')
    self.data_type = 'vanhateren'
    self.vectorize_data = True
    self.rescale_data = self.standardize_data = self.contrast_normalize = False
    self.whiten_data = True
    self.whiten_method = "FT"
    self.whiten_batch_size = 2

    # --- patch extraction ---
    self.extract_patches = True
    self.num_patches = 1e4  # NOTE(review): float literal; presumably cast to int downstream
    self.patch_edge_size = 16
    self.overlapping_patches = self.randomize_patches = True
    self.patch_variance_threshold = 0.0
    self.lpf_data = False # whitening automatically includes lpf
    self.lpf_cutoff = 0.7

    # --- batching & reproducibility ---
    self.batch_size = 100
    self.random_seed = rand_seed

In [None]:
analysis_params = params()
analysis_params.projects_dir = os.path.expanduser("~") + "/Work/Projects/"

# Only the LCA model is enabled; the SAE ('sae_768_vh') and ICA ('rica_768_vh')
# variants can be re-enabled by extending both lists in parallel.
model_names = ['lca_768_vh']
model_types = ['LCA']

analyzers = []
for model_type, model_name in zip(model_types, model_names):
    # NOTE(review): one shared analysis_params object is mutated and passed to every
    # analyzer; if setup() keeps a reference, earlier analyzers would see the last
    # model's name/dir once more than one model is enabled -- confirm.
    analysis_params.model_name = model_name
    analysis_params.version = '0.0'
    analysis_params.model_dir = analysis_params.projects_dir + model_name
    model_log_file = f"{analysis_params.model_dir}/logfiles/{model_name}_v{analysis_params.version}.log"
    # Reading model params back from model_log_file via tfLogger is disabled;
    # analyzer.model_params is presumably populated by analyzer.setup() below.
    analysis_params.model_type = model_type
    analysis_params.save_info = "analysis_selectivity"
    analyzer = ap.get_analyzer(model_type)
    analyzer.setup(analysis_params)
    analyzer.model_type = model_type
    analyzer.setup_model(analyzer.model_params)
    analyzers.append(analyzer)

In [None]:
# Load the dataset described by analysis_params, then pass it through the first
# model's preprocessing and reshaping steps so it matches that model's input format.
data = ds.get_data(analysis_params)
for dataset_transform in (analyzers[0].model.preprocess_dataset,
                          analyzers[0].model.reshape_dataset):
    data = dataset_transform(data, analysis_params)

### van Hateren (VH) data examples

In [None]:
# Display a few consecutive preprocessed training patches starting at a random index.
num_imgs = 6
patch_edge_size = analysis_params.patch_edge_size
train_images = data['train'].images
# Draw from the seeded rand_state (not the unseeded global np.random) so the figure
# is reproducible, and keep img_idx + num_imgs - 1 inside the image array.
max_start_idx = min(analysis_params.batch_size, train_images.shape[0] - num_imgs + 1)
img_idx = rand_state.randint(max_start_idx)
fig, axs = plot.subplots(ncols=num_imgs)
for inc_img in range(num_imgs):
    # Images are vectorized (vectorize_data=True); fold each back into a square patch.
    patch = train_images[img_idx + inc_img, ...].reshape(patch_edge_size, patch_edge_size)
    axs[inc_img].imshow(patch, cmap='greys_r')
axs.format(suptitle='DSC van hateren example')
pf.clear_axes(axs)
plot.show()

In [None]:
# Fetch the weight tensor from each analyzer's graph (one image is fed only to
# satisfy eval_analysis). NOTE(review): this tensor name only exists for LCA models.
weight_tensor_name = 'lca/weights/w:0'
weights = []
for loaded_analyzer in analyzers:
    evaluated = loaded_analyzer.eval_analysis(
        data['train'].images[0, ...][None, ...],
        [weight_tensor_name],
        loaded_analyzer.analysis_params.save_info)
    weights.append(np.squeeze(evaluated[weight_tensor_name]))
lca_weights = weights[0]

In [None]:
# Encode the first 100 training patches with each analyzer's model encodings.
activations = []
for loaded_analyzer in analyzers:
    encodings = loaded_analyzer.compute_activations(
        data['train'].images[0:100, ...],
        activation_operation=loaded_analyzer.model.get_encodings)
    activations.append(np.squeeze(encodings))

In [None]:
  def compute_lambda_activations(images, model, weights, batch_size=None, activation_operation=None):
    """
    Computes the output code for a set of images using externally supplied weights.
    Outputs:
      evaluated activation_operation on the input images; when batched, the
        per-batch outputs are concatenated along axis 0 in input order
    Inputs:
      images [np.ndarray] of shape (num_imgs, num_img_pixels)
      model [tf1x DSC model] providing `graph`, `init_op`, `input_placeholder`,
        `weight_placeholder` and `get_feed_dict`
      weights [np.ndarray] value fed to `model.weight_placeholder` on every run
      batch_size [int] how many inputs to use in a batch; must divide num_imgs
        evenly. If None (or >= num_imgs) all images are evaluated in one run.
      activation_operation [callable] returning the tf tensor/op that produces
        the output activation; if None then it defaults to `model.get_encodings`
    """
    if activation_operation is None:
        activation_operation = model.get_encodings
    images_shape = list(images.shape)
    num_images = images_shape[0]
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config, graph=model.graph) as sess:
        # Resolve the output tensor once instead of once per batch.
        output_op = activation_operation()
        if batch_size is not None and batch_size < num_images:
            assert num_images % batch_size == 0, (
                "batch_size=%g must divide evenly into num_images=%g"%(batch_size, num_images))
            num_batches = int(num_images // batch_size)
            batch_image_shape = [batch_size] + images_shape[1:]
            # Initialize using a dummy input with the same shape as the real batches.
            sess.run(model.init_op, {model.input_placeholder:np.zeros(batch_image_shape)})
            batch_outputs = []
            for batch_idx in range(num_batches):
                im_batch_start_idx = int(batch_idx * batch_size)
                im_batch_end_idx = int(im_batch_start_idx + batch_size)
                batch_images = images[im_batch_start_idx:im_batch_end_idx, ...]
                feed_dict = model.get_feed_dict(batch_images, is_test=True)
                feed_dict[model.weight_placeholder] = weights
                batch_outputs.append(sess.run(output_op, feed_dict).copy())
            # Concatenating along the batch axis equals the old stack + reshape for
            # 2-D outputs, but also works for outputs of any rank.
            activations = np.concatenate(batch_outputs, axis=0)
        else:
            feed_dict = model.get_feed_dict(images, is_test=True)
            feed_dict[model.weight_placeholder] = weights
            sess.run(model.init_op, feed_dict)
            activations = sess.run(output_op, feed_dict)
    return activations

In [None]:
def lamb_activation(x):
    """Linear (identity) activation for the lambda model."""
    return tf.identity(x)

lambda_params = pp.get_params("lambda")
lambda_params.set_data_params("vanhateren")
lambda_params.batch_size = 100
# assumes vector inputs (i.e. not convolutional)
lambda_params.data_shape = [lambda_params.patch_edge_size ** 2]
lambda_params.activation_function = lamb_activation

lambda_model = mp.get_model("lambda")
lambda_model.setup(lambda_params)

# Encode the first 100 training patches, feeding lca_weights as the lambda model's weights.
linear_activations = compute_lambda_activations(
    data['train'].images[0:100, ...], lambda_model, lca_weights, batch_size=10)