In [None]:
# Reload edited modules (e.g. the jsputils helpers imported below) automatically
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from jsputils import classes, feature_extractor
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import gc
import scipy.stats as stats
import pandas as pd
from IPython.core.debugger import set_trace
from fastprogress import progress_bar
from scipy.spatial.distance import pdist, squareform
import seaborn as sns


In [None]:
model_name = 'alexnet-barlow-twins'
floc_imageset_name = 'vpnl-floc'
probe_imageset_name = 'classic-categ'
figure_savedir = os.path.join(os.getcwd(), 'figure_outputs', 'Figure1-Categ-Selective-Units')

# plt.savefig does not create missing directories; make sure the output folder exists
# so the figure-saving cells below work on a fresh checkout.
os.makedirs(figure_savedir, exist_ok=True)

In [None]:
DNN = classes.DNNModel(model_name)

# Both image sets are preprocessed with the model's own input transforms.
floc, probe = (
    classes.ImageSet(imageset_name, transforms=DNN.transforms)
    for imageset_name in (floc_imageset_name, probe_imageset_name)
)

In [None]:
# Find units selective for each floc domain (FDR threshold 0.05); the results are
# read below from DNN.selective_units. overwrite=False presumably reuses previously
# saved results -- TODO confirm in jsputils.
DNN.find_selective_units(
    floc_imageset_name,
    overwrite=False,
    verbose=False,
    FDR_p=0.05,
)

In [None]:
# Print the mean and std of the t-values of each domain's selective units,
# pooled across all layers.
for domain in ['faces', 'scenes', 'bodies', 'characters']:
    # The dict is only read here, so no defensive deepcopy is needed.
    floc_dict = DNN.selective_units[domain]

    all_tvals = []
    for layer, layer_info in floc_dict.items():
        # Cast to bool so a non-boolean 0/1 mask selects units instead of
        # fancy-indexing positions 0 and 1.
        mask = np.asarray(layer_info['mask'], dtype=bool)
        tvals = layer_info['tval'][mask]
        if len(tvals) > 0:
            all_tvals.append(tvals)

    if not all_tvals:
        # np.concatenate([]) would raise ValueError; report the empty domain instead.
        print(domain, 'no selective units')
        continue

    all_tvals = np.concatenate(all_tvals)
    domain_mean_tval = np.mean(all_tvals)
    domain_std_tval = np.std(all_tvals)

    print(domain, domain_mean_tval, domain_std_tval)

In [None]:
# Extract model activations for the probe image set; they are stored under the
# 'probe_features' field and read below as DNN.probe_features[layer].
DNN.get_floc_features(
    probe,
    field='probe_features',
    device='cuda:0',
    invert=False,
)


# Figure 1

In [None]:
# Domains shown in the pie charts ('scrambled' is deliberately excluded).
pie_domains = ['faces', 'bodies', 'objects', 'scenes', 'characters']
# Domains shown in the activation heatmap: the same list without 'objects'.
act_domains = [d for d in pie_domains if d != 'objects']
print(pie_domains)

In [None]:
# Per-domain, per-layer lookups built from DNN.selective_units:
#   masks   - the unit selectivity mask
#   props   - fraction of units in the layer that are selective (mean of the mask)
#   indices - squeezed np.argwhere positions of the selective units
#   tvals   - t-value of every unit in the layer (not only the selective ones)
indices, props, masks, tvals = {}, {}, {}, {}

for domain in pie_domains:
    unit_info = DNN.selective_units[domain]
    masks[domain] = {lay: unit_info[lay]['mask'] for lay in DNN.layer_names_fmt}
    props[domain] = {lay: np.mean(m) for lay, m in masks[domain].items()}
    indices[domain] = {lay: np.squeeze(np.argwhere(m)) for lay, m in masks[domain].items()}
    tvals[domain] = {lay: unit_info[lay]['tval'] for lay in DNN.layer_names_fmt}

In [None]:
act_plot_layer = 'fc6'
n_plot_images = 400  # number of probe images (heatmap columns) to show

# Rows are probe images, columns are units; flatten any extra (spatial) dims.
plot_acts = DNN.probe_features[act_plot_layer]
if len(plot_acts.shape) > 2:
    plot_acts = plot_acts.reshape(plot_acts.shape[0], np.prod(plot_acts.shape[1:]))
plot_acts = plot_acts[:n_plot_images]

print(plot_acts.shape)

# Heatmap rows: the units selective for each domain (in act_domains order), followed
# by the units selective for none of them. A unit selective for several domains is
# repeated in each of those blocks.
sel_masks = []
nsel = []
sorted_acts = []
for domain in act_domains:
    mask = masks[domain][act_plot_layer].astype(bool)
    sel_masks.append(mask)
    nsel.append(np.sum(mask))
    print(nsel[-1])
    sorted_acts.append(plot_acts[:, mask].T)

all_sel_idx = np.any(np.vstack(sel_masks), axis=0)
non_sel_idx = np.logical_not(all_sel_idx)
sorted_acts.append(plot_acts[:, non_sel_idx].T)

# z-score each unit (row) across the plotted images.
sorted_acts = stats.zscore(np.vstack(sorted_acts), axis=1)

In [None]:
n_images = sorted_acts.shape[1]  # columns = plotted probe images

plt.figure(figsize=(12, 14))
plt.imshow(sorted_acts, aspect='auto', clim=(-1.5, 1.5), cmap='magma')
plt.colorbar()

# Cyan separator after each domain's block of rows. imshow centres row i at y = i,
# so the edge between rows b-1 and b lies at y = b - 0.5.
boundary = 0
for ns in nsel:
    boundary += ns
    plt.plot(np.arange(n_images), np.full(n_images, boundary - 0.5), 'cyan', linewidth=3)

plt.xticks(np.arange(0, n_images + 1, 80));

plt.savefig(f'{figure_savedir}/{model_name}_{act_plot_layer}-heatmap.tiff')

In [None]:
pie_plot_layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7']

# Wedge order after the flip below: non-selective first, then pie_domains reversed.
pie_colors = ['darkgray', 'purple', 'limegreen', 'orange', 'dodgerblue', 'tomato']
# Every domain wedge is pulled out slightly; the non-selective wedge is not.
explodes = np.flip([0.05, 0.05, 0.05, 0.05, 0.05, 0])
print(np.flip(pie_domains))

for layer in pie_plot_layers:
    domain_props = [props[domain][layer] for domain in pie_domains]
    nonsel_prop = 1 - np.sum(domain_props)
    # If units can be selective for several domains they are counted in several
    # wedges; make sure the remainder is still a valid (non-negative) wedge size.
    assert nonsel_prop >= 0 or np.isclose(nonsel_prop, 0), \
        f'{layer}: domain proportions sum to more than 1'
    nonsel_prop = max(nonsel_prop, 0.0)

    sel_props = np.flip(domain_props + [nonsel_prop])
    print(layer, sel_props)

    plt.figure(figsize=(12, 12))
    plt.pie(sel_props, colors=pie_colors, explode=explodes, startangle=70);
    plt.savefig(f'{figure_savedir}/{model_name}_{layer}-selectivity-pie.tiff')

# Total selective proportion in the last plotted layer (entry 0 is the non-selective wedge).
print(np.sum(sel_props[1:]))

# Supplementary Figure 1

In [None]:
# Models compared in the supplementary figure.
model_names = [
    'alexnet-barlow-twins',
    'alexnet-ipcl',
    'alexnet-supervised',
    'alexnet-barlow-twins-random',
]

# Domains to plot, and the line colour used for each.
domains = ['faces', 'bodies', 'scenes', 'characters', 'objects']

colors = dict(
    faces='tomato',
    bodies='dodgerblue',
    objects='orange',
    scenes='limegreen',
    characters='purple',
)

# Layer names grouped by stage (conv1-conv5, fc6-fc8).
layer_list = [
    'conv1', 'groupnorm1', 'relu1', 'maxpool1',
    'conv2', 'groupnorm2', 'relu2', 'maxpool2',
    'conv3', 'groupnorm3', 'relu3',
    'conv4', 'groupnorm4', 'relu4',
    'conv5', 'groupnorm5', 'relu5', 'maxpool5',
    'fc6', 'batchnorm6', 'relu6',
    'fc7', 'batchnorm7', 'relu7',
    'fc8', 'norm8',
]


In [None]:

ft = 24  # font size for tick labels and legends

# One figure per model: y = proportion of units in each layer that are selective
# for each domain.
for model_name in model_names:

    plt.figure(figsize=(16, 9))

    DNN = classes.DNNModel(model_name)
    DNN.find_selective_units(floc_imageset_name, overwrite=False, verbose=False,
                             FDR_p=0.05)

    # Take the layer list from the first plotted domain instead of a `domain` variable
    # left over from an earlier cell (presumably all domains share the same layers).
    layers = list(DNN.selective_units[domains[0]].keys())
    layer_labels = [layer for layer in layers
                    if 'flatten' not in layer and 'dropout' not in layer]

    for domain in domains:
        domain_props = [np.mean(DNN.selective_units[domain][layer]['mask'])
                        for layer in layer_labels]
        label = 'words' if domain == 'characters' else domain
        plt.plot(domain_props, label=label, color=colors[domain], linewidth=4)

    plt.xticks(np.arange(len(layer_labels)), np.array(layer_labels), rotation=90, fontsize=ft)
    plt.grid(True)
    # get rid of the frame
    for spine in plt.gca().spines.values():
        spine.set_visible(False)
    plt.ylim([0, 0.4])
    plt.yticks(fontsize=ft)
    plt.legend(fontsize=ft, loc='upper left')
    plt.tight_layout()
    plt.savefig(f'{figure_savedir}/{model_name}_{floc_imageset_name}_summary.tiff')

    plt.show()


In [None]:
# Domain-selective units in the trained vs. untrained (alexnet-barlow-twins-random) model

In [None]:
model_names = ['alexnet-barlow-twins', 'alexnet-barlow-twins-random']
image_sets = ['vpnl-floc', 'classic-categ']

# floc_info[model_name][image_set] -> DNN.selective_units computed on that image set.
# A fresh DNNModel is built for every (model, image set) pair, so each stored
# selective_units object comes from its own model instance.
floc_info = {}
for model_name in model_names:
    per_set = floc_info.setdefault(model_name, {})
    for image_set in image_sets:
        DNN = classes.DNNModel(model_name)
        DNN.find_selective_units(image_set, overwrite=False, verbose=False,
                                 FDR_p=0.05)
        per_set[image_set] = DNN.selective_units


In [None]:
layer = 'fc6'

# For each model at `layer`: the sum of the per-domain selective proportions, and the
# average of the per-domain mean t-values of the selective units (vpnl-floc).
for model_name in model_names:
    floc_units = floc_info[model_name]['vpnl-floc']
    prop_selective = []
    # Separate name so the `tvals` lookup dict built earlier is not overwritten.
    domain_mean_tvals = []
    for domain in ['faces', 'bodies', 'scenes', 'characters']:
        # Cast to bool so indexing selects units rather than integer positions.
        mask = np.asarray(floc_units[domain][layer]['mask'], dtype=bool)
        tval = floc_units[domain][layer]['tval']
        prop_selective.append(np.nanmean(mask))
        domain_mean_tvals.append(np.nanmean(tval[mask]))
    print(model_name, np.sum(prop_selective), np.nanmean(domain_mean_tvals))


In [None]:
layers = DNN.layer_names_fmt
# Skip flatten and dropout layers, matching the filter used for the
# proportion-by-layer figures.
plot_layers = [layer for layer in layers
               if 'flatten' not in layer and 'dropout' not in layer]
n_layers = len(plot_layers)

# One figure per (model, probe image set): y = mean t-value, measured on `image_set`,
# of the units that were found selective on vpnl-floc.
for model_name in model_names:
    floc_units = floc_info[model_name]['vpnl-floc']

    for image_set in image_sets:
        probe_units = floc_info[model_name][image_set]

        plt.figure(figsize=(16, 9))

        for domain in ['faces', 'bodies', 'scenes', 'characters', 'objects']:
            mean_tvals = []
            for layer in plot_layers:
                # Cast to bool so indexing selects units rather than integer positions.
                mask = np.asarray(floc_units[domain][layer]['mask'], dtype=bool)
                sel_tvals = probe_units[domain][layer]['tval'][mask]
                mean_tvals.append(np.nanmean(sel_tvals))

            plt.plot(mean_tvals, label=domain, color=colors[domain], linewidth=4)

        plt.xticks(np.arange(n_layers), np.array(plot_layers), rotation=90, fontsize=ft)
        plt.grid(True)
        # get rid of the frame
        for spine in plt.gca().spines.values():
            spine.set_visible(False)
        plt.ylim([-8, 17])
        plt.yticks(fontsize=ft)
        # Thick reference line at t = 0.
        plt.plot(np.arange(n_layers), np.zeros((n_layers,)), color='k', linewidth=6)
        plt.tight_layout()
        plt.savefig(f'{figure_savedir}/{model_name}_{image_set}_tvalues.tiff')
        plt.show()
        

            
            
        
        
    