In [None]:
# Automatically reload edited modules (e.g. the PROJECT_DNFFA helpers) before executing cell code.
%load_ext autoreload
%autoreload 2

In [None]:
from PROJECT_DNFFA.ANALYSES import selectivity
from PROJECT_DNFFA.HELPERS import plotting, paths, nnutils

import torch
from torchvision import datasets
import torchlens as tl

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

In [None]:
model_name = 'alexnet-barlow-twins'
floc_set = 'vpnl-floc'
probe_set = 'classic-5categ'

# Localizer categories per supported floc set:
#   act_domains -> domains whose selective units are grouped in the activation heatmap
#   pie_domains -> domains drawn as wedges in the per-layer pie charts
FLOC_DOMAINS = {
    'vpnl-floc': {
        'act_domains': ['faces', 'bodies', 'scenes', 'characters'],
        'pie_domains': ['faces', 'bodies', 'objects', 'scenes', 'characters'],
    },
}

# Fail fast here rather than with a NameError several cells later.
if floc_set not in FLOC_DOMAINS:
    raise ValueError(f'unsupported floc_set {floc_set!r}; expected one of {sorted(FLOC_DOMAINS)}')

act_domains = FLOC_DOMAINS[floc_set]['act_domains']
pie_domains = FLOC_DOMAINS[floc_set]['pie_domains']

# Root directory of the probe image sets; the set used here lives in the `probe_set` subfolder.
probe_set_dir = paths.probe_imageset_dir()

In [None]:
# Load the network together with its preprocessing transforms; the third
# value returned by load_model is not used in this notebook.
model, transforms, _ = nnutils.load_model(model_name)

# Labelled image dataset built from the selected probe-set folder,
# preprocessed with the model's own transforms.
probe_dataset = datasets.ImageFolder(
    root=f'{probe_set_dir}/{probe_set}',
    transform=transforms,
)

In [None]:
# Category names of the probe set (displayed as the cell's output).
probe_dataset.classes

In [None]:
# Push the entire probe set through the network as ONE batch
# (batch_size equals the number of images); tune batch size / workers here.
# NOTE(review): with a single batch, most of the 12 workers likely sit idle.
data_loader = torch.utils.data.DataLoader(
    probe_dataset,
    batch_size=len(probe_dataset),
    shuffle=False,
    num_workers=12,
    pin_memory=False,
)

# Grab that single batch of images; the class labels are not needed.
image_tensors, _ = next(iter(data_loader))

# Record the activations of every layer for the probe images.
model_history = tl.get_model_activations(model, image_tensors, which_layers='all')

In [None]:
# For each pie domain, fetch the per-layer selectivity summary and keep three
# of its entries, keyed as props/indices/masks[domain][layer]:
#   'prop_selective' -> props, 'selective_idx' -> indices, 'lesioning_mask' -> masks
layer_list = nnutils.get_layer_group(model_name)

props, indices, masks = {}, {}, {}

for domain in pie_domains:
    selective_unit_dict = selectivity.get_model_selective_units(
        model_name, f'{floc_set}-{domain}', verbose=False
    )

    props[domain] = {lyr: selective_unit_dict[lyr]['prop_selective'] for lyr in layer_list}
    indices[domain] = {lyr: selective_unit_dict[lyr]['selective_idx'] for lyr in layer_list}
    masks[domain] = {lyr: selective_unit_dict[lyr]['lesioning_mask'] for lyr in layer_list}
        

In [None]:
act_plot_layer = 'linear_1_28'
domain_masks = dict()

# Activations of the chosen layer as (n_images, n_units); any trailing
# dimensions (e.g. conv feature maps) are flattened so each unit is one column.
plot_acts = model_history[act_plot_layer].tensor_contents.detach().cpu().numpy()
dims = plot_acts.shape
if len(dims) > 2:
    plot_acts = plot_acts.reshape(dims[0], -1)

print(plot_acts.shape)

sel_masks = []    # one boolean unit mask per domain in act_domains
sorted_acts = []  # (n_units_in_group, n_images) blocks, stacked top to bottom
nsel = []         # number of units in each domain group (used for plot dividers)

for domain in act_domains:
    # Domain-selective subset: units where the lesioning mask is zero.
    # NOTE(review): assumes the lesioning mask zeroes exactly the selective units,
    # and that its layout matches one flattened activation sample -- confirm.
    mask = np.logical_not(masks[domain][act_plot_layer].astype(bool)).ravel()

    sel_masks.append(mask)
    nsel.append(int(np.sum(mask)))
    sorted_acts.append(plot_acts[:, mask].T)

print(dict(zip(act_domains, nsel)))

# Units selective for at least one act_domain. Units selective only for domains
# outside act_domains land in the final (non-selective) block; a unit selective
# for several act_domains appears once in each of those groups.
all_sel_idx = np.any(np.vstack(sel_masks), axis=0)
non_sel_idx = np.logical_not(all_sel_idx)

sorted_acts.append(plot_acts[:, non_sel_idx].T)

# z-score each unit (row) across images so rows share a colour scale in the heatmap.
sorted_acts = stats.zscore(np.vstack(sorted_acts), axis=1)

In [None]:
plt.rcParams.update({'font.size': 24})

# sorted_acts is (n_units, n_images): rows grouped by domain, columns are probe images.
n_units, n_images = sorted_acts.shape
tick_step = 80  # x-tick spacing in images

fig, ax = plt.subplots(figsize=(12, 18))
im = ax.imshow(sorted_acts, aspect='auto', vmin=-1.5, vmax=1.5, cmap='magma')
fig.colorbar(im, ax=ax)

# Divider after each domain-selective group. Row i of the image spans
# [i - 0.5, i + 0.5], so the boundary after the first `n` rows is at n - 0.5.
for boundary in np.cumsum(nsel):
    ax.axhline(boundary - 0.5, color='k')

ax.set_xticks(np.arange(0, n_images + 1, tick_step))
ax.set(xlabel='probe image', ylabel='unit (grouped by selectivity)', title=act_plot_layer);

In [None]:
# Show the raw layer identifiers, then their display-formatted counterparts.
print(layer_list)
layer_list_fmt = nnutils.alexnet_layer_str_format(layer_list)
print(layer_list_fmt)

In [None]:
pie_plot_layers = ['conv2d_1_8', 'conv2d_2_12', 'conv2d_3_16', 'conv2d_4_19', 'conv2d_5_22', 'linear_1_28']
pie_plot_layers_fmt = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6']

# One wedge colour per entry of pie_domains, then one for the leftover
# (non-selective) wedge; flipped to match the flipped order of sel_props below.
if floc_set == 'vpnl-floc':
    colors = np.flip(['tomato', 'orange', 'dodgerblue', 'limegreen', 'purple', 'darkgray'])
else:
    raise ValueError(f'no pie-chart colours defined for floc_set={floc_set!r}')

assert len(colors) == len(pie_domains) + 1, 'need one colour per domain plus one for non-selective units'

# Offset every selective wedge slightly; the non-selective wedge stays in place.
explodes = np.flip([0.05] * len(pie_domains) + [0])

for layer, layer_fmt in zip(pie_plot_layers, pie_plot_layers_fmt):

    sel_props = [props[domain][layer] for domain in pie_domains]

    # Proportion of units not selective for any pie domain. If the selective
    # proportions sum past 1, this wedge would be negative and plt.pie would
    # reject it, so check explicitly (tolerating float round-off).
    remaining = 1 - np.sum(sel_props)
    assert remaining > -1e-9, f'{layer}: selective proportions sum to {1 - remaining:.4f} > 1'
    sel_props.append(max(remaining, 0.0))

    sel_props = np.flip(sel_props)
    print(layer, sel_props)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pie(sel_props, colors=colors, explode=explodes, startangle=70)
    ax.set_title(layer_fmt)
    plt.title(pie_plot_layers_fmt[lay])

