In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from PROJECT_DNFFA.ANALYSES import selectivity
from PROJECT_DNFFA.HELPERS import plotting, paths, nnutils

import os

import torch
from torchvision import datasets
import torchlens as tl

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import pandas as pd

In [None]:
# --- Analysis configuration --------------------------------------------------
# model_name : identifier handed to nnutils / selectivity helpers.
# floc_set   : localizer set; combined with a domain name below to request
#              that domain's selective units.
# probe_set  : sub-folder of the image-set directory that is passed through
#              the network.
model_name = 'alexnet-barlow-twins'
floc_set = 'vpnl-floc'
probe_set = 'classic-categ'

# Domain names differ between localizer sets. `act_domains` drive the
# activation heatmap; `pie_domains` (which additionally contain the object
# domain) drive selective-unit loading and the pie charts.
if floc_set == 'vpnl-floc':
    act_domains = ['faces','bodies','scenes','characters']
    pie_domains = ['faces','bodies','objects','scenes','characters']
elif floc_set == 'classic-categ':
    act_domains = ['1-Faces','2-Bodies','3-Scenes','4-Words']
    pie_domains = ['1-Faces','2-Bodies','5-Objects','3-Scenes','4-Words']
else:
    # Fail here with a clear message instead of hitting a NameError (or
    # silently reusing stale domain lists) in a later cell.
    raise ValueError(f'unsupported floc_set: {floc_set!r}')

probe_set_dir = paths.imageset_dir()
figure_savedir = f'{paths.figure_savedir()}/Figure1-Categ-Selective-Units'
os.makedirs(figure_savedir, exist_ok=True)

In [None]:
# Load the network and the transform that is applied to every probe image;
# the third value returned by nnutils.load_model is not used in this notebook.
model, transforms, _ = nnutils.load_model(model_name)

# Wrap the probe image folder as a dataset that applies `transforms` on load.
probe_root = f'{probe_set_dir}/{probe_set}'
probe_dataset = datasets.ImageFolder(probe_root, transform=transforms)

In [None]:
# Sanity check: list the class names ImageFolder found for the probe set.
print(probe_dataset.classes)

In [None]:
# A DataLoader is needed to turn the dataset into image tensors. The batch
# size equals the dataset length, so the whole probe set arrives as a single
# batch; shuffling is off, so image order follows dataset order.
loader_kwargs = dict(
    batch_size=len(probe_dataset),
    num_workers=12,
    shuffle=False,
    pin_memory=False,
)
data_loader = torch.utils.data.DataLoader(probe_dataset, **loader_kwargs)

# Pull that one batch; the labels are discarded.
first_batch = next(iter(data_loader))
image_tensors, _ = first_batch

# Record the activations of every layer for the probe images with torchlens.
model_history = tl.get_model_activations(model, image_tensors, which_layers='all')

In [None]:
# Selectivity results for the main localizer set, stored as nested dicts that
# are indexed as <name>[domain][layer].
indices, props, masks, tvals = {}, {}, {}, {}

layer_list = nnutils.get_layer_group(model_name)
layer_list_fmt = nnutils.alexnet_layer_str_format(layer_list)

for domain in pie_domains:

    # Per-layer selectivity info for this localizer-set / domain combination.
    selective_unit_dict = selectivity.get_model_selective_units(
        model_name,
        f'{floc_set}-{domain}',
        overwrite=False,
        verbose=False,
    )

    # Split the per-layer records into one dict per quantity.
    props[domain] = {lyr: selective_unit_dict[lyr]['prop_selective'] for lyr in layer_list}
    indices[domain] = {lyr: selective_unit_dict[lyr]['selective_idx'] for lyr in layer_list}
    masks[domain] = {lyr: selective_unit_dict[lyr]['lesioning_mask'] for lyr in layer_list}
    tvals[domain] = {lyr: selective_unit_dict[lyr]['mean_tvals_unranked'] for lyr in layer_list}
        

In [None]:
# Layer whose activations are shown in the heatmap; the layer label depends
# on which model is loaded.
if 'barlow-twins' in model_name:
    act_plot_layer = 'linear_1_28'
elif 'alexnet-supervised' in model_name:
    act_plot_layer = 'linear_1_19'
else:
    # Without this branch an unsupported model would surface as a NameError
    # (or a stale layer name from an earlier run) further down.
    raise ValueError(f'no heatmap layer configured for model {model_name!r}')

# Number of probe images (leading rows of the activation matrix) to display.
n_plot_images = 400

# .cpu() is a no-op for CPU tensors and keeps .numpy() working when the
# recorded activations live on another device.
plot_acts = model_history[act_plot_layer].tensor_contents.detach().cpu().numpy()
dims = plot_acts.shape

# Flatten every non-leading dimension so the matrix is 2-D (images x units).
if len(dims) > 2:
    plot_acts = plot_acts.reshape(dims[0], -1)

plot_acts = plot_acts[:n_plot_images]

print(plot_acts.shape)

domain_sel_masks = []   # one boolean unit mask per plotted domain
sorted_acts = []        # (units x images) chunks, stacked after the loop
nsel = []               # number of selective units per plotted domain

for domain in act_domains:

    # use the domain-selective subset: units whose lesioning-mask entry is 0
    mask = np.logical_not(masks[domain][act_plot_layer].astype(bool))
    n_selective = np.sum(mask)

    print(n_selective)

    domain_sel_masks.append(mask)
    nsel.append(n_selective)

    sorted_acts.append(plot_acts[:, mask].T)

# Units selective for at least one plotted domain, and the remaining units.
all_sel_idx = np.any(np.vstack(domain_sel_masks), axis=0)
non_sel_idx = np.logical_not(all_sel_idx)

# Rows: each domain's selective units (in act_domains order) followed by the
# non-selective units; columns: images. Each row is z-scored across images.
sorted_acts.append(plot_acts[:, non_sel_idx].T)
sorted_acts = stats.zscore(np.vstack(sorted_acts), axis=1)

In [None]:
plt.rcParams.update({'font.size': 24})

# Heatmap of the z-scored activations (rows: units, columns: images).
plt.figure(figsize=(12,14))
plt.imshow(sorted_acts, aspect='auto', clim=(-1.5,1.5), cmap='magma')
plt.colorbar()

# Take the image count from the matrix itself rather than hard-coding it, so
# the separator lines and ticks stay correct for any number of probe images.
n_images = sorted_acts.shape[1]
image_axis = np.arange(n_images)

# One horizontal separator at the end of each domain's group of selective
# units, i.e. at the running total of the per-domain unit counts.
for boundary in np.cumsum(nsel):
    plt.plot(image_axis, np.full(n_images, float(boundary)), 'cyan', linewidth=3)

# Tick every 80 images (0, 80, ..., 400 for the 400-image case).
plt.xticks(np.arange(0, n_images + 1, 80));

plt.savefig(f'{figure_savedir}/{model_name}_{act_plot_layer}-heatmap.tiff')

In [None]:
# Layers to summarise, paired with the short names used in the file names.
pie_plot_layers = ['conv2d_1_8', 'conv2d_2_12', 'conv2d_3_16', 'conv2d_4_19', 'conv2d_5_22', 'linear_1_28']
pie_plot_layers_fmt = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6']

# Wedge colours are only defined for the vpnl-floc localizer set.
if floc_set == 'vpnl-floc':
    colors = np.flip(['tomato','orange','dodgerblue','limegreen','purple','darkgray'])

# Every domain wedge is offset slightly and the remainder wedge is not; the
# list is flipped to line up with the flipped proportions built in the loop.
explodes = np.flip([0.05] * 5 + [0])

for layer, layer_fmt in zip(pie_plot_layers, pie_plot_layers_fmt):

    # Proportion of units selective for each domain in this layer, plus one
    # final entry holding whatever proportion is left over.
    domain_props = [props[domain][layer] for domain in pie_domains]
    sel_props = np.flip(domain_props + [1 - np.sum(domain_props)])
    assert np.isclose(np.sum(sel_props), np.array([1]))
    print(layer, sel_props)

    plt.figure(figsize=(12,12))
    plt.pie(sel_props, colors=colors, explode=explodes, startangle=70);
    #plt.title(layer_fmt,fontsize=48)
    plt.savefig(f'{figure_savedir}/{model_name}_{layer_fmt}-selectivity-pie.tiff')



In [None]:
# comparing the indices of selective units across localizer sets

In [None]:
# Switch to the control localizer set for the cross-set comparison below.
# NOTE: this rebinds floc_set / act_domains / pie_domains. The pie-chart cell
# above looks up `props` (keyed by the first set's domain names) using
# `pie_domains`, so it will not work if re-run after this cell unless the
# first configuration cell is run again.
floc_set = 'classic-categ'

if floc_set == 'vpnl-floc':
    act_domains = ['faces','bodies','scenes','characters']
    pie_domains = ['faces','bodies','objects','scenes','characters']
elif floc_set == 'classic-categ':
    act_domains = ['1-Faces','2-Bodies','3-Scenes','4-Words']
    pie_domains = ['1-Faces','2-Bodies','5-Objects','3-Scenes','4-Words']
else:
    # Fail here with a clear message instead of silently keeping the domain
    # lists of the previous localizer set.
    raise ValueError(f'unsupported floc_set: {floc_set!r}')



In [None]:
# Selectivity results for the control localizer set, stored as nested dicts
# that are indexed as control_<name>[domain][layer].
control_indices, control_props, control_masks, control_tvals = {}, {}, {}, {}

layer_list = nnutils.get_layer_group(model_name)

for domain in pie_domains:

    # Per-layer selectivity info for this localizer-set / domain combination.
    control_selective_unit_dict = selectivity.get_model_selective_units(
        model_name,
        f'{floc_set}-{domain}',
        verbose=False,
        overwrite=False,
    )

    # Split the per-layer records into one dict per quantity.
    control_props[domain] = {
        lyr: control_selective_unit_dict[lyr]['prop_selective'] for lyr in layer_list
    }
    control_indices[domain] = {
        lyr: control_selective_unit_dict[lyr]['selective_idx'] for lyr in layer_list
    }
    control_masks[domain] = {
        lyr: control_selective_unit_dict[lyr]['lesioning_mask'] for lyr in layer_list
    }
    control_tvals[domain] = {
        lyr: control_selective_unit_dict[lyr]['mean_tvals_unranked'] for lyr in layer_list
    }
        

In [None]:
# (name in the first localizer set, name in the control set) for each domain.
# Set "A" below refers to `masks` / `tvals`, set "B" to the `control_*` dicts.
domains = [('faces','1-Faces'),
           ('bodies','2-Bodies'),
           ('scenes','3-Scenes'),
           ('characters','4-Words'),
           ('objects','5-Objects')]

# One summary dict per quantity, keyed by the set-A domain name. Each holds
# parallel lists with one entry per layer.
unit_summary = dict()
IoU_summary = dict()
total_summary = dict()
tval_summary = dict()

for name_a, name_b in domains:

    # Key order matters: these dicts become DataFrames whose column order
    # determines the stacking order of the bar plots.
    unit_summary[name_a] = {'layer': [],
                            'prop_uniqueA': [],
                            'prop_overlap': [],
                            'prop_uniqueB': [],
                            'prop_not_sel': []}
    IoU_summary[name_a] = {'layer': [], 'IoU': []}
    total_summary[name_a] = {'layer': [], 'total_prop_selective': []}
    tval_summary[name_a] = {'layer': [], 'tval_r': []}

    units = unit_summary[name_a]
    ious = IoU_summary[name_a]
    totals = total_summary[name_a]
    tval_rs = tval_summary[name_a]

    for lay, layer in enumerate(layer_list):

        for summary in (units, ious, totals, tval_rs):
            summary['layer'].append(layer_list_fmt[lay])

        # Selective units are the ones whose lesioning-mask entry equals 0.
        n_units = len(masks[name_a][layer])
        idxAll = np.arange(n_units)
        idxA = np.argwhere(masks[name_a][layer] == 0)
        idxB = np.argwhere(control_masks[name_b][layer] == 0)
        tvalA = tvals[name_a][layer]
        tvalB = control_tvals[name_b][layer]

        # Only units with a defined t-value in both sets enter the correlation.
        valid = np.logical_and(np.logical_not(np.isnan(tvalA)),
                               np.logical_not(np.isnan(tvalB)))

        if len(idxA) == 0 and len(idxB) == 0:
            # Neither set has selective units in this layer: nothing to compare.
            units['prop_uniqueA'].append(np.nan)
            units['prop_uniqueB'].append(np.nan)
            units['prop_overlap'].append(np.nan)
            units['prop_not_sel'].append(np.nan)
            ious['IoU'].append(np.nan)
            totals['total_prop_selective'].append([np.nan, np.nan])
            tval_rs['tval_r'].append(np.nan)
            continue

        idxA_only = np.setdiff1d(idxA, idxB)
        idxB_only = np.setdiff1d(idxB, idxA)
        idxBoth = np.intersect1d(idxA, idxB)
        idxEither = np.union1d(idxA, idxB)
        idxNone = np.setdiff1d(idxAll, idxEither)
        IoU = len(idxBoth) / len(idxEither)

        try:
            ranking_r = stats.pearsonr(tvalA[valid], tvalB[valid])[0]
        except ValueError:
            # pearsonr rejects inputs it cannot correlate (e.g. fewer than two
            # valid units). Only that case is mapped to NaN; interrupts and
            # unrelated errors are no longer swallowed.
            ranking_r = np.nan

        # The four groups must partition the layer's units exactly.
        assert len(idxA_only) + len(idxB_only) + len(idxBoth) + len(idxNone) == n_units

        units['prop_uniqueA'].append(len(idxA_only) / n_units)
        units['prop_uniqueB'].append(len(idxB_only) / n_units)
        units['prop_overlap'].append(len(idxBoth) / n_units)
        units['prop_not_sel'].append(len(idxNone) / n_units)
        ious['IoU'].append(IoU)
        totals['total_prop_selective'].append([len(idxA) / n_units,
                                               len(idxB) / n_units])
        tval_rs['tval_r'].append(ranking_r)

In [None]:
# Face t-values of the two localizer sets against each other for one layer,
# drawn with very small markers.
scatter_layer = 'maxpool2d_3_25'
plt.scatter(tvals['faces'][scatter_layer],
            control_tvals['1-Faces'][scatter_layer],
            s=0.1)

In [None]:
# Per-domain colours: (colours of the four stacked segments, single colour
# used for the IoU and pearson-r bars).
bar_colors = {
    'faces': (['tomato','firebrick','pink','whitesmoke'], 'firebrick'),
    'bodies': (['dodgerblue','darkblue','lightblue','whitesmoke'], 'darkblue'),
    'scenes': (['limegreen','darkgreen','lightgreen','whitesmoke'], 'darkgreen'),
    'characters': (['darkviolet','indigo','orchid','whitesmoke'], 'indigo'),
    'objects': (['yellow','darkorange','navajowhite','whitesmoke'], 'darkorange'),
}

plt.rcParams.update({'font.size': 28})

# Values shared by every figure.
tick_positions = np.arange(len(layer_list_fmt))
tick_labels = np.array(layer_list_fmt)
zero_line_x = np.arange(len(layer_list_fmt)+1)-0.5
zero_line_y = np.zeros((len(layer_list_fmt)+1,))
descriptions = ['unique to vpnl-floc', 'overlapping', 'unique to classic-categ', 'non-selective']
order = [3,2,1,0]   # legend lists the segments top-to-bottom

# NOTE(review): matplotlib documents `pad_inches` as the padding applied when
# bbox_inches='tight'; without that argument it is probably ignored - confirm.
for domain, (segment_colors, accent_color) in bar_colors.items():

    prop_df = pd.DataFrame(unit_summary[domain])
    IoU_df = pd.DataFrame(IoU_summary[domain])
    tval_df = pd.DataFrame(tval_summary[domain])

    # 1) stacked proportions: unique to A / overlap / unique to B / neither
    prop_df.plot(kind='bar', stacked=True, color=segment_colors, figsize=(24,10))
    handles = plt.gca().get_legend_handles_labels()[0]
    plt.legend([handles[i] for i in order], [descriptions[i] for i in order], loc='upper right')
    plt.xticks(tick_positions, tick_labels, rotation=90);
    plt.title(f'comparing {domain}-selective unit indices between vpnl-floc and classic-categ')
    plt.ylabel('prop. units in layer')
    plt.savefig(f'{figure_savedir}/{model_name}_{domain}-floc-set-comparison.png', pad_inches=2)

    # 2) intersection over union of the two selective-unit sets
    IoU_df.plot(kind='bar', color=accent_color, figsize=(24,10))
    plt.legend([f'IoU ({domain})'], loc='upper right')
    plt.title(f'intersection over union of {domain}-selective unit indices')
    plt.xticks(tick_positions, tick_labels, rotation=90);
    plt.ylim([0, 1])
    plt.ylabel('intersection over union')
    plt.savefig(f'{figure_savedir}/{model_name}_{domain}-floc-set-IoU.png', pad_inches=2)

    # 3) correlation of t-values between the two sets, with a line at zero
    tval_df.plot(kind='bar', color=accent_color, figsize=(24,10))
    plt.legend([f'pearson r ({domain})'], loc='upper right')
    plt.title(f'pearson r between tvals of {domain}-selective comparisons')
    plt.xticks(tick_positions, tick_labels, rotation=90);
    plt.ylim([-0.3, 1])
    plt.plot(zero_line_x, zero_line_y, 'k', linewidth=2)
    plt.ylabel('pearson r')
    plt.savefig(f'{figure_savedir}/{model_name}_{domain}-floc-set-tvals-pearsonr.png', pad_inches=2)

    

In [None]:
# Per layer, add up the proportion of selective units over all domains,
# separately for each localizer set. NaN entries are skipped by nansum.
domain_names = [pair[0] for pair in domains]

setA_totals = []
setB_totals = []

for lay in range(len(layer_list)):

    # [set A proportion, set B proportion] of every domain for this layer
    layer_pairs = [total_summary[name]['total_prop_selective'][lay] for name in domain_names]

    setA_totals.append(np.nansum([pair[0] for pair in layer_pairs]))
    setB_totals.append(np.nansum([pair[1] for pair in layer_pairs]))

# Side-by-side bars: set A just left and set B just right of each layer tick.
bar_width = 0.25
layer_positions = np.arange(len(layer_list))

plt.figure(figsize=(24,10))
plt.bar(layer_positions - bar_width/2, setA_totals, width=bar_width)
plt.bar(layer_positions + bar_width/2, setB_totals, width=bar_width)
plt.ylim([0,1])
plt.legend(['vpnl-floc','classic-categ'])
plt.title('comparing total proportions of selective units between floc sets')
plt.xticks(np.arange(len(layer_list_fmt)), np.array(layer_list_fmt), rotation=90);
plt.ylabel('proportion units selective within layer')
plt.savefig(f'{figure_savedir}/{model_name}_floc-sets-total-prop-summary.png', pad_inches=2)


In [None]:
# Correlation, across layers, between the two sets' total selective proportions.
totals_corr = stats.pearsonr(setA_totals, setB_totals)
totals_corr[0]

In [None]:
# Reshape a wide per-layer table into long format: one row per
# (layer, column) pair, with the value in 'proportion layer units'.
# NOTE(review): `plot_df` is not assigned anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel - confirm the intended
# source frame or remove the cell.
plot_df_melted = pd.melt(plot_df, id_vars=['layer'], var_name='props', value_name='proportion layer units')
plot_df_melted

In [None]:
# NOTE(review): `sns` is not imported and `df_melted` is not assigned anywhere
# in the visible notebook, so this cell fails on a fresh kernel. It looks like
# leftover scratch code - confirm and remove.
sns.barplot(x='year', y='count', hue='fruit', data=df_melted, palette='bright')

In [None]:
# Small toy table: one row per year, one column of counts per fruit.
years = list(range(2010, 2015))
fruit_counts = {
    'apples': [5, 2, 6, 1, 10],
    'oranges': [10, 12, 8, 6, 3],
    'pears': [3, 8, 2, 5, 7],
}
data = {'year': years, **fruit_counts}
df = pd.DataFrame(data)
df

In [None]:
# Fraction of entries in `maskA` that equal 0.
# NOTE(review): `maskA` is not assigned anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel - confirm and remove.
np.mean(maskA==0)

In [None]:
# Number of entries in the control-set face lesioning mask for layer
# 'relu_7_33' that are not equal to 1.
np.sum(control_masks['1-Faces']['relu_7_33'] != 1)

In [None]:
# 0 0 1 1 0 0 1 0
# 0 0 1 0 0 0 1 0
