In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from jsputils import classes, nnutils, feature_extractor
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.stats as stats
import gc
import os
from os.path import exists
import json
import copy
from scipy.spatial.distance import squareform, pdist
import seaborn as sns
from fastprogress import progress_bar
from IPython.core.debugger import set_trace

torch.cuda.empty_cache()

In [None]:
# Model and image-set identifiers used throughout the notebook.
model_name = 'alexnet-barlow-twins'
floc_imageset_name = 'vpnl-floc'

# Every figure produced below is written into this directory.
figure_savedir = f'{os.getcwd()}/figure_outputs/Figure2-Lesioning'
# plt.savefig does not create missing directories, so make sure it exists
# before any plotting cell runs.
os.makedirs(figure_savedir, exist_ok=True)

In [None]:
# Description string identifying which saved read-out weights to load.
readout_description = 'mdl-alexnet-barlow-twins_from-relu7_mlr-0.05_ilr-0.001_eps-10_sparse-pos-True_l1p-1e-05_l1n-1e-05'
# Alternative read-out (swap in to use it instead):
# readout_description = 'mdl-alexnet-barlow-twins_from-relu7_mlr-0.05_ilr-0.001_eps-10_sparse-pos-False'

# Build the model, attach a read-out layer on top of 'relu7', and load the
# read-out weights onto the CPU.
DNN = classes.DNNModel(model_name)
DNN.append_readout_layer(readout_from='relu7')
DNN.load_readout_weights(description=readout_description, device='cpu')
                                         

In [None]:
# Pull the read-out layer's weight matrix out of the model as a NumPy array.
weights = DNN.readout_model.readout.weight.detach().numpy()

# Show the weights as a heat map; the colour range is limited to +/-0.01.
plt.imshow(weights,aspect='auto',cmap='RdBu_r',clim=(-0.01,0.01))
plt.colorbar()

In [None]:
# Fraction of read-out weights whose absolute value is below eps.
eps = 0.001
np.mean(np.abs(weights) < eps)

In [None]:
# Mean and standard deviation of the absolute read-out weights.
print(np.mean(np.abs(weights)))
print(np.std(np.abs(weights)))


In [None]:
# Identify selective units using the image set named in the config cell,
# with FDR_p = 0.05.
# NOTE(review): overwrite=False presumably reuses previously saved results — confirm in jsputils.
DNN.find_selective_units(floc_imageset_name, overwrite = False, verbose = False,
                        FDR_p = 0.05)

In [None]:
# From here on, use the read-out model as the DNN's model, and refresh the
# formatted layer-name list so that it is derived from that same model.
DNN.model = DNN.readout_model
DNN.layer_names_fmt, _ = feature_extractor.get_pretty_layer_names(DNN.readout_model)


In [None]:
# Wrap the DNN in a LesionModel, constructed with the 'cuda:0' device string.
LSN = classes.LesionModel(DNN, 'cuda:0')

In [None]:
# Baseline (unlesioned) configuration: activations are not returned and the
# masks are switched off.
LSN.model.return_acts = False
LSN.model.masks['apply'] = False

In [None]:
# Evaluate ImageNet accuracy with topk=5; the result is read back from
# LSN.imagenet_accs in the next cell.
LSN.get_imagenet_accs(topk=5)

In [None]:
# Keep the unlesioned accuracies as the baseline for all cost computations.
# NOTE(review): this is a reference to LSN.imagenet_accs, not a copy — it stays
# valid only if get_imagenet_accs rebinds the attribute rather than writing in place.
prelesion_accs = LSN.imagenet_accs
print(np.mean(prelesion_accs))
print(np.std(prelesion_accs))

In [None]:
# Collect selective-unit activations for every formatted layer except the last;
# later cells read them from LSN.selective_unit_acts[domain][layer].
LSN.get_selective_unit_acts(layers = DNN.layer_names_fmt[:-1])

In [None]:
# Release cached CUDA memory and run the garbage collector (done twice) before
# the lesioning loop.
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()

In [None]:

# `results` keeps the unlesioned accuracies under 'acc' and, for each lesioned
# domain, the post-lesion accuracies under results[domain]['lsn_acc'].
lesion_domains = ['faces', 'scenes', 'bodies', 'characters', 'objects', 'scrambled']

results = {'acc': prelesion_accs}

for domain in progress_bar(lesion_domains):
    results[domain] = {}

    # Set up the lesion for this domain, then switch the masks on.
    LSN.model.return_acts = False
    LSN.apply_channelized_lesions(domain, method='relus')
    LSN.model.masks['apply'] = True

    # Re-evaluate accuracy with the lesion in place and store the result.
    LSN.get_imagenet_accs(topk=5)
    results[domain]['lsn_acc'] = LSN.imagenet_accs
    
  

In [None]:
# Release cached CUDA memory and run the garbage collector after the lesioning loop.
torch.cuda.empty_cache()
gc.collect()

In [None]:
def scatter_corr(x,y,linecols = ['r'], bigks = [None], e1 = None, e2 = 0.5):
    """Draw a jittered scatter of ``y`` against ``x`` with a straight-line fit.

    Plots onto the current matplotlib figure and prints the Pearson
    correlation of the un-jittered data.

    Parameters
    ----------
    x, y : 1-D numeric arrays of equal length
    linecols : sequence of colours
        One colour per entry of ``bigks``, used for the highlighted points.
        With a single colour the fitted line uses it as well; with exactly two
        colours the line is drawn in black.
    bigks : sequence of int or None
        How many points to enlarge and colour. With one entry the points with
        the smallest ``y`` are highlighted. With two entries the first
        highlights the smallest ``x`` and the second the smallest ``y``.
        An empty sequence, or a falsy first entry (``None`` / ``0``), disables
        highlighting altogether.
    e1, e2 : float or None
        Standard deviation of the Gaussian jitter added to ``x`` / ``y`` for
        display. ``None`` means no jitter on that axis.

    Raises
    ------
    ValueError
        If ``bigks`` has more than two entries.

    Notes
    -----
    The global NumPy RNG is reseeded with 0 on every call so the jitter is
    identical between runs; this also affects any later code that draws from
    ``np.random``.
    """
    if len(bigks) > 2:
        raise ValueError('bigks supports at most two entries')

    np.random.seed(0)
    # Jitter is drawn for x first (when requested), then for y, so the random
    # stream matches between calls with the same arguments.
    if e1 is None:
        ep1 = np.zeros((len(x),))
    else:
        ep1 = np.random.normal(0,e1,len(x))
    if e2 is None:
        ep2 = np.zeros((len(y),))
    else:
        ep2 = np.random.normal(0,e2,len(y))

    sizes = np.array([150] * len(y))
    colors = ['darkgray'] * len(y)

    if len(bigks) > 0 and bigks[0]:
        for bk, bigk in enumerate(bigks):
            if not bigk:
                continue
            # Only the first of two entries ranks by x; everything else ranks by y.
            rank_by_x = (bk == 0 and len(bigks) == 2)
            order = np.argsort(x if rank_by_x else y)
            # Slicing clamps to the available points, so bigk may exceed len(y).
            chosen = order[:bigk]
            sizes[chosen] = 400
            for point in chosen:
                colors[point] = linecols[bk]
    colors = np.array(colors)

    if len(linecols) == 2:
        linecol_ = 'k'
    else:
        linecol_ = linecols[0]

    x_shown = x + ep1
    y_shown = y + ep2
    plt.scatter(x_shown,y_shown,sizes,c=colors)

    # The line is fitted to the displayed (jittered) points.
    line_x = np.unique(x_shown)
    line_fit = np.poly1d(np.polyfit(x_shown, y_shown, 1))
    plt.plot(line_x, line_fit(line_x),color=linecol_,linewidth=20)

    # The reported correlation uses the original, un-jittered data.
    print(f'r = {round(np.corrcoef(x,y)[1,0],3)}')

In [None]:
# For each domain, relate the selective units' activations in one layer to the
# change in accuracy caused by lesioning that domain, and save a scatter plot.
plot_layer = 'relu6'
rs = []
# Highlight colours, in the same order as the domains iterated below.
colors = ['red','dodgerblue','limegreen','purple']

for d, domain in enumerate(['faces','bodies','scenes','characters']):

    x = copy.deepcopy(LSN.selective_unit_acts[domain][plot_layer])
    # Post-lesion minus pre-lesion accuracy, scaled by 100 (negative = drop).
    y = -100 * copy.deepcopy((results['acc'] - results[domain]['lsn_acc']))

    # Use only entries where both values are finite (neither NaN nor +/-inf).
    valid = np.logical_and(np.isfinite(x), np.isfinite(y))

    # Skip domains with too few usable entries.
    if np.sum(valid) > 750:
        r = stats.pearsonr(x[valid],y[valid])[0]
        rs.append(r)

        plt.figure(figsize=(24,24))
        scatter_corr(x[valid],y[valid],linecols = [colors[d]],bigks=[10])
        plt.xticks(fontsize=75)
        plt.yticks(fontsize=75)
        plt.savefig(f"{figure_savedir}/scatter_corr_{domain}.tiff")
        plt.close()

    else:
        # The original printed an undefined name `layer` here, which raised
        # NameError on a fresh kernel; the layer in use is `plot_layer`.
        print('skipping',domain,plot_layer,np.sum(valid))

print(np.mean(rs), np.std(rs))

In [None]:
# Layer-by-layer Pearson correlation between selective-unit activations and the
# lesion-induced accuracy change, one line per domain.
ft = 24
plot_layers = DNN.layer_names_fmt[:-1]
plt.figure(figsize=(16,9))

color_dict = {
    'faces': 'tomato',
    'bodies': 'dodgerblue',
    'objects': 'orange',
    'scenes': 'limegreen',
    'characters': 'purple',
    'scrambled': 'navy',
}

# Layers whose name contains 'flatten' are left out of the plot.
layer_list = []
for layer in plot_layers:
    if 'flatten' in layer:
        continue
    layer_list.append(layer)

for domain in ['faces','scenes','bodies','characters','objects']:

    # Post-lesion minus pre-lesion accuracy, scaled by 100.
    y = -100 * (results['acc'] - results[domain]['lsn_acc'])
    rs = []

    for plot_layer in layer_list:
        x = LSN.selective_unit_acts[domain][plot_layer]

        # Finite in both arrays == neither NaN nor +/-inf in either.
        valid = np.isfinite(x) & np.isfinite(y)

        if np.sum(valid) <= 750:
            # Too few usable entries: leave a gap in the line.
            rs.append(np.nan)
            continue

        r = stats.pearsonr(x[valid], y[valid])[0]
        rs.append(r)

    # The 'characters' domain is shown as 'words' in the legend.
    label = 'words' if domain == 'characters' else domain

    plt.plot(rs, label=label, color=color_dict[domain], linewidth=5)

tick_positions = np.arange(len(layer_list))
plt.xticks(tick_positions, np.array(layer_list), rotation=90, fontsize=ft)
plt.grid('on')

# Hide the axes frame.
for spine in plt.gca().spines.values():
    spine.set_visible(False)

plt.ylim([-1,0.2])
# Horizontal reference line at r = 0.
plt.plot(tick_positions, np.zeros((len(layer_list),)), color='k', linewidth=3)
plt.yticks(fontsize=ft)
plt.legend(fontsize=ft-3)
plt.tight_layout()
plt.savefig(f"{figure_savedir}/readout_effect_summary.tiff")

    

In [None]:
# One row of accuracy costs (pre-lesion minus post-lesion) per lesioned domain.
costs = np.vstack([
    results['acc'] - results[domain]['lsn_acc']
    for domain in ['faces','bodies','scenes','characters','objects']
])

# Condensed pairwise correlation *distances* between the domains' cost rows.
cost_corrs = pdist(costs,'correlation')

# 1 - distance gives the correlation itself.
pairwise_r = 1 - cost_corrs
print(np.mean(pairwise_r))
print(np.std(pairwise_r))

In [None]:
# Expand the condensed distances to a square matrix and convert to correlations
# (the diagonal of squareform is 0, so it becomes 1).
cost_corr_matrix = 1 - squareform(cost_corrs)

plt.figure(figsize=(20, 16))
sns.heatmap(cost_corr_matrix, annot=True, cmap='RdBu_r', vmin=-1, vmax=1,
            annot_kws={"size": 40})
# No tick labels on either axis.
plt.xticks([])
plt.yticks([])
plt.savefig(f"{figure_savedir}/cost_corr_domain_summary.png")


In [None]:
# Scatter the accuracy change from lesioning one domain against the change from
# lesioning another, for selected domain pairs.
domains = ['faces','bodies','scenes','characters','objects']

# Highlight colour per domain, spelled out here so the cell does not depend on
# the global `colors`, which other cells rebind to a different list and to a dict.
highlight_colors = {'faces': 'red',
                    'bodies': 'dodgerblue',
                    'scenes': 'limegreen',
                    'characters': 'purple'}

# Index pairs into `domains`: faces vs scenes, faces vs characters.
dom_pairs = [[0,2],
             [0,3]]

for dom_pair in dom_pairs:
    dA, dB = dom_pair
    domain_a, domain_b = domains[dA], domains[dB]

    # Recompute the costs from `results` rather than reading the global
    # `costs`, which later cells overwrite with a 1-D array.
    cost_a = results['acc'] - results[domain_a]['lsn_acc']
    cost_b = results['acc'] - results[domain_b]['lsn_acc']

    plt.figure(figsize=(24,24))

    # With two bigks entries, scatter_corr highlights the 10 smallest x values
    # in the first colour and the 10 smallest y values in the second.
    scatter_corr(-100*cost_a,
                 -100*cost_b, e1 = 0.5, e2 = 0.5, bigks=[10,10],
                 linecols=[highlight_colors[domain_a],
                           highlight_colors[domain_b]])
    plt.xticks(fontsize=75)
    plt.yticks(fontsize=75)

    plt.savefig(f"{figure_savedir}/dissociation_{domain_a}-{domain_b}.tiff")

    plt.show()

In [None]:

# Accuracies computed with cv=True, stored separately for the two splits.
results_cv = dict()

# Unlesioned case: masks off, activations not returned.
LSN.model.masks['apply'] = False
LSN.model.return_acts = False
LSN.get_imagenet_accs(topk = 5, cv = True)

# With cv=True the result is indexed as a 2-D array: column 0 is stored as
# split A and column 1 as split B.
results_cv['acc_splitA'] =  LSN.imagenet_accs[:,0]
results_cv['acc_splitB'] =  LSN.imagenet_accs[:,1]


In [None]:
# Mean unlesioned accuracy on split B.
np.mean(results_cv['acc_splitB'])

In [None]:
# Read the ImageNet class labels from the JSON file in the working directory.
labels_path = f'{os.getcwd()}/imagenet_class_labels.json'
with open(labels_path, 'r') as f:
    class_index = json.load(f)

# Collect the labels by integer position into an array, so that they can be
# indexed with arrays of category indices.
categories = np.array([class_index[idx] for idx in range(len(class_index))])

In [None]:

# Repeat the lesioning per domain with cv=True, keeping post-lesion accuracies
# for each of the two splits.
for domain in progress_bar(['faces','scenes','bodies','characters']):

    results_cv[domain] = {}

    # Set up the lesion for this domain, then switch the masks on.
    LSN.model.return_acts = False
    LSN.apply_channelized_lesions(domain, method='relus')
    LSN.model.masks['apply'] = True

    LSN.get_imagenet_accs(topk=5, cv=True)

    # Column 0 is stored as split A, column 1 as split B.
    split_accs = LSN.imagenet_accs
    results_cv[domain]['lsn_acc_splitA'] = split_accs[:, 0]
    results_cv[domain]['lsn_acc_splitB'] = split_accs[:, 1]
    

In [None]:
# Cross-validated dissociation bars: for each k, pick every domain's k most
# costly categories on split A, then measure the accuracy change on split B
# under each domain's lesion. One subplot per k.
domains = ['faces','bodies','scenes','characters']
c=1  # 1-based subplot index, incremented once per k
ks = [5,10,25,50,75,100]
plt.figure(figsize=(12,len(ks)*5))

# If True, costs are expressed as a proportion of the unlesioned accuracy.
prop = False

for k in ks:

    domain_costs = dict()
    top_k_indices = dict()
    mean_drop = dict()
    sem_drop = dict()

    for lsn_domain in domains:

        mean_drop[lsn_domain] = dict()
        sem_drop[lsn_domain] = dict()
        
        for sp in ['A','B']:

            # Cost = unlesioned accuracy minus post-lesion accuracy on this split.
            if prop:
                domain_costs[f'{lsn_domain}_split{sp}'] = (results_cv[f'acc_split{sp}'] - results_cv[lsn_domain][f'lsn_acc_split{sp}']) / results_cv[f'acc_split{sp}']
            else:
                domain_costs[f'{lsn_domain}_split{sp}'] = (results_cv[f'acc_split{sp}'] - results_cv[lsn_domain][f'lsn_acc_split{sp}'])

            # Non-finite costs are replaced by 0 so they can be sorted and averaged.
            domain_costs[f'{lsn_domain}_split{sp}'][np.isinf(domain_costs[f'{lsn_domain}_split{sp}'])] = 0
            domain_costs[f'{lsn_domain}_split{sp}'][np.isnan(domain_costs[f'{lsn_domain}_split{sp}'])] = 0
        
        # Rank categories by split-A cost (ascending) ...
        cost_sort_idx = np.argsort(domain_costs[f'{lsn_domain}_splitA'])
        assert(np.sum(np.isnan(domain_costs[f'{lsn_domain}_splitA'])) == 0)

        # ... and keep the k with the largest cost.
        top_k_indices[lsn_domain] = cost_sort_idx[-k:]

    for lsn_domain in domains:
        for probe_domain in domains:
            # Split-B cost of the lsn_domain lesion, restricted to the categories
            # selected (on split A) for probe_domain.
            mean_drop[lsn_domain][probe_domain] = np.mean(domain_costs[f'{lsn_domain}_splitB'][top_k_indices[probe_domain]])
            sem_drop[lsn_domain][probe_domain] = np.std(domain_costs[f'{lsn_domain}_splitB'][top_k_indices[probe_domain]]) / np.sqrt(k)

            #print(lsn_domain, probe_domain, mean_drop[lsn_domain][probe_domain])

        # For k = 10, also report the split-B cost on the lesioned domain's own categories.
        if k == 10:
            vals = domain_costs[f'{lsn_domain}_splitB'][top_k_indices[lsn_domain]]
            print(lsn_domain, np.mean(vals), np.std(vals))
            
    # x position of each group of bars (one group per lesioned domain)
    x = np.arange(len(domains))

    # Bar geometry and colours (one colour per probe domain, in `domains` order)
    bar_width = 0.18
    gap = 0.1
    colors = ['tomato', 'dodgerblue', 'limegreen', 'purple']#'navy']#, 'orange']

    plt.subplot(len(ks),1,c)
    for i, probe_domain in enumerate(domains):
        # Negate and scale by 100 so that an accuracy drop is plotted as a negative bar.
        means = -100 * copy.deepcopy(np.array([mean_drop[lsn_domain][probe_domain] for lsn_domain in domains]))
        plt.bar(x + i*bar_width, means, width = bar_width, color = colors[i],
                yerr =  100*np.array([sem_drop[lsn_domain][probe_domain] for lsn_domain in domains]))


    # Axis cosmetics: zero line, no x ticks, no frame.
    #plt.title(f'Lesioning impact on most domain-relevant categories, k = {k}')
    #plt.xticks(x + bar_width*1.5, [f'{domain}\nlesions' for domain in domains])
    plt.hlines(0,0,3.5,'k',linewidth=0.5)
    plt.xticks([])
    plt.box('off')
    for spine in plt.gca().spines.values():
        spine.set_visible(False)
    plt.yticks(fontsize=25)
 
    plt.grid('on')
    
    # Fixed y-limits, chosen separately for proportional and absolute costs.
    if prop:
        #plt.ylabel('Mean change in accuracy (proportion)')
        plt.ylim([-110, 25])
    else:
        #plt.ylabel('Mean change in accuracy (absolute %)')
        plt.ylim([-70, 15])
    c+=1

# Save and show the full stack of subplots.
plt.tight_layout()
plt.savefig(f"{figure_savedir}/dissociation_bars_k_summary.tiff")

plt.show()




In [None]:
# List, for each domain, the k categories with the largest split-A accuracy
# drop (largest first), together with the size of the drop.
k = 6
domains = ['faces','bodies','scenes','characters']

indices = {}

for domain in domains:
    # Only split A is used to pick each domain's most affected categories.
    costs = results_cv['acc_splitA'] - results_cv[domain]['lsn_acc_splitA']
    # argsort is ascending: take the last k entries and reverse them.
    indices[domain] = np.argsort(costs)[-k:][::-1]
    print(domain, categories[indices[domain]], '\naccuracy drops:', costs[indices[domain]])
    print('\n')

In [None]:
# Data loader constructed with the 'val' argument; its .data_loader and
# .batch_size attributes are used in the next cell.
ValLoader = classes.DataLoaderFFCV('val')

In [None]:
# Copy the whole validation set into two preallocated tensors.
n_val = 50000
val_images = torch.empty(n_val, 3, 224, 224)
# Targets start at -1 so that any slot the loader never fills cannot be
# mistaken for a real class id when images are looked up by class later.
val_targets = torch.full((n_val,), -1.0)
batch_size = ValLoader.batch_size

c = 0
for images, targets, _, _ in progress_bar(ValLoader.data_loader):
    # Advance by the number of items actually received, so a batch shorter than
    # the nominal batch_size does not leave gaps or misalign later writes.
    n_batch = images.shape[0]
    val_images[c:c+n_batch] = images
    val_targets[c:c+n_batch] = targets
    c += n_batch


In [None]:
# Keep the first few validation images of every class, stored contiguously:
# class i occupies rows i*n_per_class ... i*n_per_class + n_per_class - 1.
n_classes = 1000
n_per_class = 5

probe_images = torch.Tensor(n_classes * n_per_class, 3, 224, 224)
probe_targets = torch.arange(1000)

c = 0
for i in range(n_classes):
    # Positions of class i in the validation targets, found with torch itself
    # rather than by passing a tensor through np.argwhere.
    idx = torch.nonzero(val_targets == i).flatten()
    # Raises a shape-mismatch error if a class has fewer than n_per_class images.
    probe_images[c:c+n_per_class] = val_images[idx[:n_per_class]]
    c += n_per_class

# The full validation tensors are no longer needed.
del val_images, val_targets


In [None]:
# Run the garbage collector and release cached CUDA memory after deleting the
# validation tensors.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Re-read the ImageNet class labels from the JSON file in the working directory.
with open(f'{os.getcwd()}/imagenet_class_labels.json', 'r') as f:
    class_index = json.load(f)

# NOTE: this rebinds `categories` as a plain Python list (the earlier label
# cell builds a NumPy array), so it supports single-index lookups only.
categories = [class_index[idx] for idx in range(len(class_index))]

In [None]:
# Define the mean and standard deviation values used for normalization
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


In [None]:
# For each lesioned domain, save example images of the categories whose
# split-A accuracy dropped the most.
n_top_categories = 8
# probe_images holds exactly five images per category, stored contiguously.
# The original looped over six, so the sixth read the first image of the NEXT
# category (and ran past the end of the tensor for the last category).
n_probe_per_class = 5

for domain in ['faces','scenes','bodies','characters']:

    costs = results_cv['acc_splitA'] - results_cv[domain]['lsn_acc_splitA']

    # Category indices ordered from largest to smallest accuracy drop.
    rankings = np.flip(np.argsort(costs))

    for i in range(n_top_categories):

        categ_idx = rankings[i]
        loss = costs[categ_idx]

        for j in range(n_probe_per_class):

            # Channels-first tensor -> channels-last array for imshow.
            img = probe_images[categ_idx*n_probe_per_class+j].numpy().transpose(1,2,0)

            # Undo the normalization, then clip to the displayable range [0, 1].
            restored_image = np.clip(img * std + mean, 0, 1)

            plt.figure(figsize=(20,20))
            plt.imshow(restored_image)
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(f"{figure_savedir}/{domain}-impaired-{i}-{categories[categ_idx]}-{round(loss,2)}-{j}.tiff")
            # Close each figure so that figures do not accumulate in memory.
            plt.close()

In [None]:
# Randomized lesions: repeat the lesioning analysis after randomizing the selective-unit indices.


In [None]:
# Control analysis: randomize the selective-unit indices n_iters times, lesion
# each domain, and record post-lesion accuracies and layer-wise correlations.
# Results are cached on disk. Relies on `layer_list` and `prelesion_accs`
# defined in earlier cells.
overwrite = False

savefn = f'{os.getcwd()}/analysis_outputs/2-Lesioning/randomized_lesion_results.npy'
n_iters = 10

plot_layers = DNN.layer_names_fmt[:-1]

colors = {'faces':'tomato',
          'bodies':'dodgerblue',
          'objects':'orange',
          'scenes':'limegreen',
          'characters':'purple',
          'scrambled':'navy'}

if exists(savefn) and not overwrite:

    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
    # cache files that this notebook wrote itself.
    results_rand = np.load(savefn,allow_pickle=True).item()

else:

    # Make sure the cache directory exists BEFORE the expensive loop; np.save
    # does not create it, and a failure at the end would discard all iterations.
    os.makedirs(os.path.dirname(savefn), exist_ok=True)

    results_rand = dict()
    results_rand['acc'] = prelesion_accs
    for domain in ['faces','scenes','bodies','characters',
                                    'objects']:
        results_rand[domain] = dict()
        results_rand[domain]['lsn_acc'] = []
        results_rand[domain]['rs'] = []

    for i in progress_bar(range(n_iters)):

        # New random unit assignment, then recompute the corresponding activations.
        LSN.randomize_selective_unit_indices()
        LSN.get_selective_unit_acts(layers = DNN.layer_names_fmt[:-1])

        torch.cuda.empty_cache()
        gc.collect()

        LSN.model.return_acts = False

        for domain in progress_bar(['faces','scenes','bodies','characters',
                                    'objects']):

            LSN.apply_channelized_lesions(domain,
                                          method = 'relus')

            LSN.model.masks['apply'] = True

            LSN.get_imagenet_accs(topk=5)

            postlesion_accs = LSN.imagenet_accs

            results_rand[domain]['lsn_acc'].append(postlesion_accs)

            rs = []
            # Cost = pre-lesion minus post-lesion accuracy (unscaled, and of
            # opposite sign to the accuracy change plotted for the real lesions).
            y = (results_rand['acc'] - postlesion_accs)

            for plot_layer in layer_list:

                x = LSN.selective_unit_acts[domain][plot_layer]

                # Use only entries where both values are finite.
                valid = np.logical_and(np.isfinite(x), np.isfinite(y))

                if np.sum(valid) > 750:
                    r = stats.pearsonr(x[valid],y[valid])[0]
                    rs.append(r)
                else:
                    # Too few usable entries for this layer.
                    rs.append(np.nan)

            # One list of per-layer correlations per iteration.
            results_rand[domain]['rs'].append(rs)

    np.save(savefn, results_rand, allow_pickle=True)
            

In [None]:
# Layer-wise correlation summary for the randomized lesions: mean across
# iterations with a +/- SEM band, one line per domain.
plt.figure(figsize=(16,9))
ft = 24

for domain in ['faces','scenes','bodies','characters','objects']:
    # Stack the per-iteration lists into (n_layers, n_iterations). The sign is
    # flipped because these correlations were computed against the cost
    # (pre minus post), whereas the summary plot for the real lesions uses the
    # accuracy change (post minus pre).
    rs = -1 * np.stack(results_rand[domain]['rs'],axis=1)
    rs_mean = np.mean(rs,axis=1)
    # Use the number of iterations actually present: results loaded from the
    # cache may hold a different count than the current n_iters.
    n_done = rs.shape[1]
    rs_sem = np.std(rs,axis=1) / np.sqrt(n_done)

    # The 'characters' domain is shown as 'words' in the legend.
    if domain == 'characters':
        label = 'words'
    else:
        label = domain

    plt.plot(rs_mean,label=label,color=colors[domain],linewidth=3);
    plt.fill_between(np.arange(len(layer_list)), rs_mean - rs_sem, rs_mean + rs_sem, color=colors[domain],
                     alpha=0.3)

plt.xticks(np.arange(len(layer_list)),np.array(layer_list),rotation=90,fontsize=ft);
plt.grid('on')
# Hide the axes frame.
for spine in plt.gca().spines.values():
    spine.set_visible(False)
plt.ylim([-1,0.2])
# Horizontal reference line at r = 0.
plt.plot(np.arange(len(layer_list)), np.zeros((len(layer_list),)), color='k',linewidth=3)
plt.yticks(fontsize=ft)
plt.legend(fontsize=ft-3)
plt.tight_layout()
plt.savefig(f"{figure_savedir}/readout_effect_summary_randomized.tiff")

    