In [1]:
import torch
import os

mask_folder = 'circuit_masks/alexnet_sparse/force/'



masks = []

for f in os.listdir(mask_folder):
    if '.pt' in f:
        masks.append(torch.load(mask_folder+f))


In [2]:
target_folder = 'target_activations/alexnet_sparse/imagenet_2/'

for m in masks:
    feature = m['layer']+':'+str(m['unit'])
    if not os.path.exists(target_folder+feature+'.pt'):
        torch.save(m['full_target_activations'],target_folder+feature+'.pt')


In [3]:
import numpy as np
from scipy.stats import spearmanr, pearsonr


def get_circuit_accuracy(mask_dict,metric = 'spearman'):
    target = mask_dict['full_target_activations'][mask_dict['layer']+':'+str(mask_dict['unit'])].flatten().numpy()
    output = mask_dict['pruned_target_activations'][mask_dict['layer']+':'+str(mask_dict['unit'])].flatten().numpy()
    if metric == 'spearman':
        out = spearmanr(output,target).correlation
        if out is np.nan:
            out = 0.
        return out
    elif metric == 'pearson':
        out = pearsonr(output,target)[0]
        if out is np.nan:
            out = 0.
        return out
    elif metric == 'avg_diff':
        return np.mean(np.abs(output - target))
    else:
        print('unknown metric %s, options ["spearman","pearson","avg_diff"]'%metric)
    

In [4]:
for i in range(len(masks)):
    print(masks[i]['layer'])
    print(masks[i]['structure'])
    print(masks[i]['method'])
    print(masks[i]['keep_ratio'])
    print(get_circuit_accuracy(masks[i],metric = 'pearson'))
    print('\n')

features_10
kernels
FORCE
0.2
0.7904956176007432


features_6
kernels
FORCE
0.1
0.8679699668849037


features_10
kernels
FORCE
0.001
0.182701910434121


features_8
kernels
FORCE
0.05
0.45673073395947905


features_3
kernels
FORCE
0.01
0.7623504227871774


features_8
kernels
FORCE
0.5
0.9999862468094265


features_8
kernels
FORCE
0.05
0.7535761722736042


features_3
kernels
FORCE
0.2
0.7196407410070141


features_8
kernels
FORCE
0.001
-0.2736830821034252


features_6
kernels
FORCE
0.1
0.6537393402424412


features_3
kernels
FORCE
0.001
0.06329375768732035


features_3
kernels
FORCE
0.2
0.9524578242749078


features_8
kernels
FORCE
0.005
-0.14299631817770853


features_10
kernels
FORCE
0.5
0.9999965799525401


features_3
kernels
FORCE
0.01
-0.24548852152282996


features_8
kernels
FORCE
0.1
0.7957460912094758


features_3
kernels
FORCE
0.2
0.6287352210296906


features_6
kernels
FORCE
0.5
0.9999985326649299


features_8
kernels
FORCE
0.5
0.999993599840705


features_8
kernels
FORCE
0.05




0.3210682630781655


features_6
kernels
FORCE
0.001
0.04895669874389751


features_10
kernels
FORCE
0.2
0.7762001005587295


features_10
kernels
FORCE
0.01
-0.09306644876653655


features_6
kernels
FORCE
0.5
0.9984069480359752


features_3
kernels
FORCE
0.005
0.12125788121884622


features_8
kernels
FORCE
0.01
0.47502049148410774


features_6
kernels
FORCE
0.1
0.7065540713843944


features_3
kernels
FORCE
0.2
0.8555936018683197


features_6
kernels
FORCE
0.5
0.9999998565078394


features_6
kernels
FORCE
0.01
-0.025664949374791024


features_6
kernels
FORCE
0.1
0.8080387062593645


features_10
kernels
FORCE
0.5
0.9998946352545348


features_8
kernels
FORCE
0.1
0.8926973818211792


features_3
kernels
FORCE
0.05
0.4214004542995777


features_10
kernels
FORCE
0.2
0.813367483699262


features_8
kernels
FORCE
0.01
0.32396437800959377


features_3
kernels
FORCE
0.5
0.9930439899131205


features_8
kernels
FORCE
0.001
0.2529027337574855


features_10
kernels
FORCE
0.2
0.7493848865841042


featu

-0.0933969338491892


features_8
kernels
FORCE
0.01
0.2070382090793795


features_10
kernels
FORCE
0.01
0.4122269015516997


features_6
kernels
FORCE
0.01
0.11938932911426645


features_6
kernels
FORCE
0.005
0.023468257953038654


features_3
kernels
FORCE
0.01
0.08832034196418857


features_6
kernels
FORCE
0.005
0.40006042910455514


features_8
kernels
FORCE
0.005
0.1599467957323781


features_6
kernels
FORCE
0.5
0.9999466315188545


features_3
kernels
FORCE
0.001
0.0


features_6
kernels
FORCE
0.2
0.998370860091506


features_8
kernels
FORCE
0.1
0.7031901913858118


features_8
kernels
FORCE
0.05
0.6362857886906327


features_6
kernels
FORCE
0.005
0.49464694628474226


features_10
kernels
FORCE
0.5
0.9999960414340999


features_3
kernels
FORCE
0.005
0.0


features_3
kernels
FORCE
0.05
0.7250385401394863


features_6
kernels
FORCE
0.01
0.18598328135450504


features_8
kernels
FORCE
0.5
0.9999906044815912


features_10
kernels
FORCE
0.05
0.3923925861689199


features_6
kernels
FORCE
0.1


-0.034887155990520655


features_3
kernels
FORCE
0.05
0.7184749480002753


features_3
kernels
FORCE
0.005
0.30920178986392766


features_8
kernels
FORCE
0.01
0.31476789029903857


features_3
kernels
FORCE
0.001
0.0


features_8
kernels
FORCE
0.05
0.10504476789166696


features_6
kernels
FORCE
0.01
0.23129219572395288


features_6
kernels
FORCE
0.001
-0.08156934558939734


features_3
kernels
FORCE
0.01
0.0


features_6
kernels
FORCE
0.01
0.60213918469601


features_10
kernels
FORCE
0.5
0.9999379834580204


features_8
kernels
FORCE
0.5
0.9999293576563634


features_10
kernels
FORCE
0.2
0.8018968720742004


features_10
kernels
FORCE
0.2
0.8106509689917465


features_3
kernels
FORCE
0.005
0.06961045386991258


features_3
kernels
FORCE
0.1
0.8887919117980769


features_6
kernels
FORCE
0.1
0.5761544298044515


features_6
kernels
FORCE
0.2
0.8555368587956175


features_3
kernels
FORCE
0.5
0.9687882564045747


features_3
kernels
FORCE
0.05
0.741456546443299


features_3
kernels
FORCE
0.5
0.967

0.9672833819755993


features_3
kernels
FORCE
0.1
0.5933503650873075


features_6
kernels
FORCE
0.1
0.5838355193108475


features_10
kernels
FORCE
0.01
0.13624742225130718


features_6
kernels
FORCE
0.2
0.9289779717076436


features_10
kernels
FORCE
0.05
0.24635720181181997


features_10
kernels
FORCE
0.005
-0.02019008114487635


features_3
kernels
FORCE
0.01
-0.008729066399888283


features_8
kernels
FORCE
0.05
0.5710473398125322


features_3
kernels
FORCE
0.005
0.4212752529252484


features_10
kernels
FORCE
0.01
0.13735982676663794


features_10
kernels
FORCE
0.5
0.9999568742230129


features_6
kernels
FORCE
0.01
0.19489364435664702


features_6
kernels
FORCE
0.001
0.0753837058572544


features_8
kernels
FORCE
0.1
0.7439927638181418


features_3
kernels
FORCE
0.005
0.12371833933480449


features_8
kernels
FORCE
0.1
0.8460741157993923


features_3
kernels
FORCE
0.01
0.20581648138478237


features_6
kernels
FORCE
0.1
0.8565356815062778


features_6
kernels
FORCE
0.01
0.2817671890974205

0.772690665958508


features_6
kernels
FORCE
0.2
0.8614157312324774


features_6
kernels
FORCE
0.05
0.7206967826701853


features_8
kernels
FORCE
0.01
0.010668971789019436


features_6
kernels
FORCE
0.2
0.9914286447732372


features_6
kernels
FORCE
0.005
0.22421455276727192


features_6
kernels
FORCE
0.2
0.9106575514869408


features_6
kernels
FORCE
0.05
0.21721473227784946


features_8
kernels
FORCE
0.005
0.4493765126312408


features_6
kernels
FORCE
0.01
0.6921164541468904


features_10
kernels
FORCE
0.1
0.6778019173991354


features_10
kernels
FORCE
0.005
0.09880873983358153


features_6
kernels
FORCE
0.1
0.5334030661345874


features_8
kernels
FORCE
0.005
0.19428989489681808


features_10
kernels
FORCE
0.1
0.45749386442713147


features_8
kernels
FORCE
0.01
0.16573121390906181


features_6
kernels
FORCE
0.05
0.7233142067996179


features_8
kernels
FORCE
0.2
0.9380760975767107


features_3
kernels
FORCE
0.001
0.0


features_6
kernels
FORCE
0.1
0.751743618348857


features_3
kernels


0.6699876488404081


features_10
kernels
FORCE
0.005
0.3378752545040201


features_3
kernels
FORCE
0.01
-0.04510538341518529


features_8
kernels
FORCE
0.1
0.7709053484128265


features_8
kernels
FORCE
0.2
0.9239730078067903


features_10
kernels
FORCE
0.5
0.9999845882003924


features_8
kernels
FORCE
0.01
0.14363527017236197


features_3
kernels
FORCE
0.1
0.2831840129625759


features_3
kernels
FORCE
0.05
0.8642526894714757


features_10
kernels
FORCE
0.1
0.5178788438711083


features_6
kernels
FORCE
0.001
0.04082258122287392


features_6
kernels
FORCE
0.5
0.9987856408325331


features_10
kernels
FORCE
0.001
0.12895269028497217


features_3
kernels
FORCE
0.005
-0.09211793117670354


features_6
kernels
FORCE
0.05
0.7288195213962668


features_6
kernels
FORCE
0.1
0.6772114883956468


features_10
kernels
FORCE
0.001
-0.04798004056547491


features_3
kernels
FORCE
0.005
0.040272039620573935


features_10
kernels
FORCE
0.001
-0.015525649963988423


features_3
kernels
FORCE
0.1
0.7971181250

0.7233055654951442


features_8
kernels
FORCE
0.5
0.9999917042836377


features_8
kernels
FORCE
0.5
0.9999909232453763


features_8
kernels
FORCE
0.05
0.44068279696917473


features_3
kernels
FORCE
0.1
0.9214843733844441


features_10
kernels
FORCE
0.01
-0.029660042764897664


features_6
kernels
FORCE
0.1
0.8719145332162088


features_6
kernels
FORCE
0.001
-0.03807253466703523


features_8
kernels
FORCE
0.1
0.672548598897828


features_6
kernels
FORCE
0.01
0.29956940076455163


features_3
kernels
FORCE
0.2
0.8561231316728752


features_10
kernels
FORCE
0.01
0.3011714115311854


features_10
kernels
FORCE
0.5
0.9997338393930493


features_6
kernels
FORCE
0.001
-0.02614248521369794


features_3
kernels
FORCE
0.01
0.11439319079570888


features_8
kernels
FORCE
0.05
0.42768837470668125


features_8
kernels
FORCE
0.5
0.9997604965051559


features_3
kernels
FORCE
0.001
0.03874754896526753


features_8
kernels
FORCE
0.2
0.8420191299432793


features_3
kernels
FORCE
0.1
-0.07788909491284346


f

In [31]:
masks[100]['batch_size']

200

In [5]:
import pandas as pd
import os

def gen_accuracies_df_from_masks_folder(folder_path):
    
    big_list = []
    columns = ['model_name','method','keep_ratio','T','layer','unit','id','structure','rank_field','data_path','spearman','pearson','avg_diff']
    fs = os.listdir(folder_path)
    for f in fs:

        mask = torch.load(folder_path+'/'+f)
        
        if 'structure' not in mask.keys():
            print(f)
            continue
        
        

        model_name = folder_path.split('/')[-1]
        if 'method' in mask.keys():
            method = mask['method']
        else:
            method = 'FORCE'
        
        if 'T' in mask.keys():
            T = mask['T']
        else:
            T = 1
            
        spearman = get_circuit_accuracy(mask,metric = 'spearman')
        pearson = get_circuit_accuracy(mask,metric = 'pearson')
        avg_diff = get_circuit_accuracy(mask,metric = 'avg_diff')
        
        f_id = mask['layer']+':'+str(mask['unit'])
        
        big_list.append([model_name,method,mask['keep_ratio'],T,mask['layer'],mask['unit'],f_id,mask['structure'],mask['rank_field'],mask['data_path'],spearman,pearson,avg_diff])
            
    df = pd.DataFrame(big_list,columns=columns)
    return df

In [33]:
acc_df = gen_accuracies_df_from_masks_folder(mask_folder)


An input array is constant; the correlation coefficent is not defined.



In [34]:
acc_df

Unnamed: 0,model_name,method,keep_ratio,T,layer,unit,id,structure,rank_field,data_path,spearman,pearson,avg_diff
0,,FORCE,0.200,1,features_10,13,features_10:13,kernels,image,image_data/imagenet_2/,0.747674,0.790496,1.869828
1,,FORCE,0.100,1,features_6,1,features_6:1,kernels,image,image_data/imagenet_2/,0.847219,0.867970,2.267001
2,,FORCE,0.001,8,features_10,13,features_10:13,kernels,image,image_data/imagenet_2/,0.186479,0.182702,2.526336
3,,FORCE,0.050,1,features_8,2,features_8:2,kernels,image,image_data/imagenet_2/,0.433226,0.456731,3.065464
4,,FORCE,0.010,8,features_3,4,features_3:4,kernels,image,image_data/imagenet_2/,0.758324,0.762350,5.562233
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1115,,FORCE,0.500,8,features_6,19,features_6:19,kernels,image,image_data/imagenet_2/,1.000000,1.000000,0.000869
1116,,FORCE,0.050,8,features_10,3,features_10:3,kernels,image,image_data/imagenet_2/,0.496394,0.509706,1.897411
1117,,FORCE,0.200,8,features_3,13,features_3:13,kernels,image,image_data/imagenet_2/,0.885753,0.953293,1.973063
1118,,FORCE,0.005,8,features_6,7,features_6:7,kernels,image,image_data/imagenet_2/,0.091203,0.111557,5.156613


In [35]:
#compare t8 and t1

t8_df = acc_df.loc[acc_df['T']==8]
t1_df = acc_df.loc[acc_df['T']==1]

print('total')
print(np.mean(t8_df.pearson))
print(np.mean(t1_df.pearson))

for r in [.5,.2,.1,.05,.01,.005,.001]:
    print('r: %s'%r)
    print(np.mean(t8_df.loc[t8_df['keep_ratio']==r].pearson))
    print(np.mean(t1_df.loc[t1_df['keep_ratio']==r].pearson))


total
0.5134276231992592
0.4151002603420854
r: 0.5
0.9920266083318298
0.9873781184693513
r: 0.2
0.8734611923457141
0.808966083474607
r: 0.1
0.7003036756975304
0.541398506104701
r: 0.05
0.477032520900881
0.36879792701723196
r: 0.01
0.2925920997174039
0.12447133276550311
r: 0.005
0.17665923933557087
0.05294488449967873
r: 0.001
0.08191802606588365
0.02174497006352468


In [None]:


print('total')
print(np.mean(df.correlations))


for r in [.5,.2,.1,.05,.01,.005,.001]:
    print('r: %s'%r)
    print(np.mean(df.loc[df['keep_ratios']==r].correlations))



In [None]:
import plotly.express as px

fig = px.scatter(acc_df, x="keep_ratio", y="pearson", color="id",
                 log_x=True,custom_data=['id','structure','method','T'])
fig.update_traces(
    hovertemplate="<br>".join([
        "id: %{customdata[0]}",
        "structure: %{customdata[1]}",
        "method: %{customdata[2]}",
        "T: %{customdata[3]}"
    ])
)
fig.show()

In [None]:
import plotly.express as px

fig = px.scatter(acc_df, x="keep_ratio", y="pearson", color="structure",
                 log_x=True))
fig.show()

In [None]:

import matplotlib.pyplot as plt
   

acc_df.loc[acc_df['structure']=='kernels'].plot(x ='keep_ratio', y='pearson', kind = 'scatter', logx=True)
plt.show()

In [None]:
acc_df.loc[acc_df['structure']=='kernels'].sort_values(by=['id'])

In [None]:
m2 = []

for m in masks:
    if m['layer'] == 'features_6' and m['unit'] ==0 and m['T'] == 1:
        m2.append(m)

In [None]:
m3 = []

for m in masks:
    if m['layer'] == 'features_6' and m['unit'] ==0:
        m3.append(m)

In [None]:
for m in m2:
    print(str(m['keep_ratio'])+ ' '+str(m['structure'])+' '+str(m['T']))
    

for m in m2:    
    print(m['pruned_target_activations']['features_6:0'][0,0,0])

In [None]:
print(m2[0]['mask'])

In [6]:
def get_mask_ratios_by_layer(mask):
    out = []
    for l in range(len(mask)):
        arr = mask[l]
        out.append(float((arr == 1.).sum())/float(arr.flatten().size()[0]))
    return out
        


In [7]:
def get_mask_ratio_all_layers(mask):
    total = 0
    ones = 0
    for l in range(len(mask)-2):
        arr = mask[l]
        ones += float((arr == 1.).sum())
        total+= float(arr.flatten().size()[0])
    return ones/total
        

In [None]:
for mask in m2:
    print(mask['keep_ratio'])
    print(mask['method'])
    print(mask['structure'])
    print(get_mask_ratios_by_layer(mask['mask']))
    print(get_mask_ratio_all_layers(mask['mask']))
    print(mask['mask'][0].sum()+mask['mask'][1].sum()+mask['mask'][2].sum())


In [None]:
for mask in m3:
    print(mask['keep_ratio'])
    print(mask['method'])
    print(mask['structure'])
    print(get_mask_ratios_by_layer(mask['mask']))
    print(get_mask_ratio_all_layers(mask['mask']))
    print(mask['mask'][0].sum()+mask['mask'][1].sum()+mask['mask'][2].sum())
    


In [None]:
for m in m3:
    print(m['mask'])

## cummulative sparsity versus sparsity

### generate MASK

In [8]:
import torch
from circuit_pruner.force import *
from circuit_pruner.custom_exceptions import TargetReached
import time
import os
from circuit_pruner.utils import update_sys_path
import torch.utils.data as data
import torchvision.datasets as datasets
from circuit_pruner.data_loading import rank_image_data
from circuit_pruner.dissected_Conv2d import *
from copy import deepcopy

def mask_from_sparsity(rank_list, k):

    all_scores = torch.cat([torch.flatten(x) for x in rank_list])
    norm_factor = torch.sum(all_scores)
    all_scores.div_(norm_factor)

    threshold, _ = torch.topk(all_scores, k, sorted=True)
    acceptable_score = threshold[-1]

    mask = []

    for g in rank_list:
        mask.append(((g / norm_factor) >= acceptable_score).float())
        
    return mask

def mask_from_cum_salience(rank_list, cum_sal):

    all_scores = torch.cat([torch.flatten(x) for x in rank_list])
    norm_factor = torch.sum(all_scores)
    all_scores.div_(norm_factor)

    
    all_scores_sorted = torch.sort(all_scores, descending=True).values
    
    cum_total = 0.
    for i in range(len(all_scores_sorted)):
        cum_total += all_scores_sorted[i]
        if cum_total > cum_sal:
            print(i)
            threshold = all_scores_sorted[i]
            print(threshold)
            break
            

    mask = []

    for g in rank_list:
        mask.append(((g / norm_factor) >= threshold).float())
        
    return mask,i




In [None]:

cum_sparsities = [.99,.95]
for i in range(18):
    cum_sparsities.append(round(cum_sparsities[-1]-.05,2))

## setup

In [None]:
layer = 'features_3'
unit = 11

feature_name = layer+':'+str(unit)
method = 'actxgrad'
structure = 'edges'
device = 'cuda:0'
batch_size = 200


#get ranks

ranks_folder = 'circuit_ranks/alexnet_sparse/actgrad/'

for f in os.listdir(ranks_folder):
    if feature_name in f:
        layer_ranks = torch.load(ranks_folder+f)
        break
        
rank_list = []

for l in range(len(layer_ranks['ranks'][structure][method])):
    print(layer_ranks['ranks'][structure][method][l][0])
    rank_list.append(torch.tensor(layer_ranks['ranks'][structure][method][l][1]))

    
#params

config = layer_ranks['config']

if '/' in config:
    config_root_path = ('/').join(config.split('/')[:-1])
    update_sys_path(config_root_path)
config_module = config.split('/')[-1].replace('.py','')
params = __import__(config_module)



#target_activations

target_activations = torch.load('target_activations/alexnet_sparse/imagenet_2/'+feature_name+'.pt')
    

    
#model

model = params.model

feature_target = {layer:[unit]}


pruned_model = deepcopy(model)
pruned_model = pruned_model.to(device)

setup_net_for_circuit_prune(pruned_model, feature_targets=feature_target, rank_field = 'image',save_target_activations=True)

pruned_model = pruned_model.to(device)

reset_masks_in_net(pruned_model)


#dataloader
kwargs = {'num_workers': params.num_workers, 'pin_memory': True, 'sampler':None} if 'cuda' in device else {}
dataloader = data.DataLoader(rank_image_data(params.data_path,
                                            params.preprocess,
                                            label_file_path = params.label_file_path,class_folders=True),
                                            batch_size=batch_size,
                                            shuffle=False,
                                            **kwargs)



#total params


total_params = 0
for l in pruned_model.modules():
    if isinstance(l, nn.Conv2d):
        if not l.last_layer:  #all params potentially relevant
            if structure in ['kernels','edges']:
                total_params += int(l.weight.shape[0]*l.weight.shape[1])
            else:
                total_params += int(l.weight.shape[0])

        else: #only weights leading into feature targets are relevant
            if structure in ['kernels','edges']:
                total_params += int(len(l.feature_targets_indices)*l.weight.shape[1])
            else:
                total_params += len(l.feature_targets_indices)

            break



### set up mask

In [None]:
#set up cum sals




cum_sals = [.99,.95]
for i in range(18):
    cum_sals.append(round(cum_sals[-1]-.05,2))

    
pearsons = []
sparsities = []
for cum_sal in cum_sals:

    #setup mask
    mask,k = mask_from_cum_salience(rank_list, cum_sal)
    sparsities.append(float(k)/total_params)


    if structure is not 'weights':
        expanded_mask = expand_structured_mask(mask,pruned_model) #this weight mask will get applied to the network on the next iteration
    else:
        expanded_mask = mask

    for l in expanded_mask:
        l = l.to(device)

    #import pdb; pdb.set_trace()
    ###GET ACTIVATIONS FROM PRUNED MODEL
    #get feature outputs from pruned model
    if structure == 'filters':
        reset_masks_in_net(pruned_model)
        apply_filter_mask(pruned_model,mask) #different than masking weights, because it also masks biases
    else:
        apply_mask(pruned_model,expanded_mask) 
        
        
    #run model
    save_target_activations_in_net(pruned_model,save=True)

    iter_dataloader = iter(dataloader)
    iters = len(iter_dataloader)



    pruned_target_activations = {}

    for it in range(iters):
        #clear_feature_targets_from_net(pruned_model)

        # Grab a single batch from the training dataset
        inputs, targets = next(iter_dataloader)
        inputs = inputs.to(device)

        pruned_model.zero_grad()

        #Run model forward until all targets reached
        try:
            outputs = pruned_model.forward(inputs)
        except:
            #except:
            pass


        activations = get_saved_target_activations_from_net(pruned_model)
        for l in activations:
            activations[l] = activations[l].to('cpu')
            if l not in pruned_target_activations.keys():
                pruned_target_activations[l] = activations[l]
            else:
                pruned_target_activations[l] = torch.cat((pruned_target_activations[l],activations[l]),dim=0)

    #compare
    cor = pearsonr(target_activations[feature_name].flatten().numpy(),pruned_target_activations[feature_name].flatten().numpy())[0]
    if cor == np.nan:
        cor = 0.
    print(cor)
    
    pearsons.append(cor)

In [None]:
#setup mask
cum_sal = .4
mask,k = mask_from_cum_salience(rank_list, cum_sal)
#sparsities.append(float(k)/total_params)


if structure is not 'weights':
    expanded_mask = expand_structured_mask(mask,pruned_model) #this weight mask will get applied to the network on the next iteration
else:
    expanded_mask = mask

for l in expanded_mask:
    l = l.to(device)

#import pdb; pdb.set_trace()
###GET ACTIVATIONS FROM PRUNED MODEL
#get feature outputs from pruned model
if structure == 'filters':
    reset_masks_in_net(pruned_model)
    apply_filter_mask(pruned_model,mask) #different than masking weights, because it also masks biases
else:
    apply_mask(pruned_model,expanded_mask) 


#run model
save_target_activations_in_net(pruned_model,save=True)

iter_dataloader = iter(dataloader)
iters = len(iter_dataloader)



pruned_target_activations = {}

for it in range(iters):
    #clear_feature_targets_from_net(pruned_model)

    # Grab a single batch from the training dataset
    inputs, targets = next(iter_dataloader)
    inputs = inputs.to(device)

    pruned_model.zero_grad()

    #Run model forward until all targets reached
    try:
        outputs = pruned_model.forward(inputs)
    except:
        #except:
        pass


    activations = get_saved_target_activations_from_net(pruned_model)
    for l in activations:
        activations[l] = activations[l].to('cpu')
        if l not in pruned_target_activations.keys():
            pruned_target_activations[l] = activations[l]
        else:
            pruned_target_activations[l] = torch.cat((pruned_target_activations[l],activations[l]),dim=0)

#compare
cor = pearsonr(target_activations[feature_name].flatten().numpy(),pruned_target_activations[feature_name].flatten().numpy())[0]
if cor == np.nan:
    cor = 0.
print(cor)

In [None]:
import plotly.express as px

fig = px.scatter(x=target_activations[feature_name].flatten().numpy(), y=pearsons)
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
import plotly.express as px

fig = px.scatter(x=cum_sals, y=pearsons)
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
import plotly.express as px

fig = px.scatter(x=sparsities, y=pearsons)
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
total_params

### run model

In [None]:
start = time.time()

save_target_activations_in_net(pruned_model,save=True)

iter_dataloader = iter(dataloader)
iters = len(iter_dataloader)



pruned_target_activations = {}

for it in range(iters):
    #clear_feature_targets_from_net(pruned_model)

    # Grab a single batch from the training dataset
    inputs, targets = next(iter_dataloader)
    inputs = inputs.to(device)

    pruned_model.zero_grad()

    #Run model forward until all targets reached
    try:
        outputs = pruned_model.forward(inputs)
    except:
        #except:
        pass


    activations = get_saved_target_activations_from_net(pruned_model)
    for l in activations:
        activations[l] = activations[l].to('cpu')
        if l not in pruned_target_activations.keys():
            pruned_target_activations[l] = activations[l]
        else:
            pruned_target_activations[l] = torch.cat((pruned_target_activations[l],activations[l]),dim=0)

        
print(time.time()-start)

### compare

In [None]:
pearsonr(target_activations[feature].flatten().numpy(),pruned_target_activations[feature].flatten().numpy())[0]

### Get feature responses

In [None]:
#params

config = layer_ranks['config']

if '/' in config:
    config_root_path = ('/').join(args.config.split('/')[:-1])
    update_sys_path(config_root_path)
config_module = args.config.split('/')[-1].replace('.py','')
params = __import__(config_module)




layer = args.layer
unit = args.unit
ratios = args.ratio
device= args.device
rank_field = args.rank_field
structure = args.structure


if args.data_path is None:
    data_path = params.data_path
else:
    data_path = args.data_path

if args.batch_size is None:
    batch_size = params.batch_size
else:
    batch_size = args.batch_size












#target_activations

target_activations = torch.load('target_activations/alexnet_sparse/imagenet_2/'+feature+'.pt')






#model

model = params.model



## plotting cumulative salience

In [19]:
import os
import torch
import pandas as pd
import numpy as np

def load_acc_df_from_folder(folder,keep_ratios = None):
    cum_sal_datas = []
    for f in os.listdir(folder):

        cum_sal_datas.append(torch.load(folder+f))
        

        
        
    
    #fix

    cum_sals = [.99,.95]
    for i in range(18):
        cum_sals.append(round(cum_sals[-1]-.05,2))


    for f in os.listdir(cumulative_folder):
        d = torch.load(cumulative_folder+f)
        if 'cum_sals' not in d.keys():
            d['cum_sals'] = cum_sals
            torch.save(d,cumulative_folder+f)    
            
            
    biglist = []
            
    columns = ['correlations','keep_ratios','cum_sals','method','layer','unit','feature_name','batch_size','data_path']


    ind = -1
    for d in cum_sal_datas:
        ind+=1
        feature_name = d['layer']+':'+str(d['unit'])
        if not (0 <= d['correlations'][0] <= 1):
            print(ind)
            continue
        for i in range(len(d['correlations'])):
            cor = d['correlations'][i]
            ratio = d['keep_ratios'][i]
            cum_sal = d['cum_sals'][i]
            biglist.append([cor,ratio,cum_sal,d['method'],d['layer'],
                 d['unit'],feature_name,d['batch_size'],d['data_path']])



    df = pd.DataFrame(biglist,columns=columns)


    
    #cleanup

    df.fillna(0, inplace=True)  
    df = df.replace([np.inf, -np.inf], 0) 
    
    if keep_ratios:
        df_out = pd.DataFrame([],columns=columns)
        for i in keep_ratios:
            df_out = df_out.append(df.loc[df['keep_ratios']==i]) 
    else:
        df_out = df
        
    return df_out
    
    

cumulative_folder = 'cum_salience_accuracies/alexnet_sparse/imagenet_2/actxgrad/'



cum_sal_datas = []
for f in os.listdir(cumulative_folder):
 
    cum_sal_datas.append(torch.load(cumulative_folder+f))

In [20]:
#fix

cum_sals = [.99,.95]
for i in range(18):
    cum_sals.append(round(cum_sals[-1]-.05,2))


for f in os.listdir(cumulative_folder):
    d = torch.load(cumulative_folder+f)
    if 'cum_sals' not in d.keys():
        d['cum_sals'] = cum_sals
        torch.save(d,cumulative_folder+f)

In [21]:
import pandas as pd

biglist = []

columns = ['correlations','keep_ratios','cum_sals','method','layer','unit','feature_name','batch_size','data_path']


ind = -1
for d in cum_sal_datas:
    ind+=1
    feature_name = d['layer']+':'+str(d['unit'])
    if not (0 <= d['correlations'][0] <= 1):
        print(ind)
        continue
    for i in range(len(d['correlations'])):
        cor = d['correlations'][i]
        ratio = d['keep_ratios'][i]
        cum_sal = d['cum_sals'][i]
        biglist.append([cor,ratio,cum_sal,d['method'],d['layer'],
             d['unit'],feature_name,d['batch_size'],d['data_path']])
    

            
df = pd.DataFrame(biglist,columns=columns)

16
33
37
46
75
101
144
157


In [22]:
#cleanup

df.fillna(0, inplace=True)  
df = df.replace([np.inf, -np.inf], 0) 
df

Unnamed: 0,correlations,keep_ratios,cum_sals,method,layer,unit,feature_name,batch_size,data_path
0,0.999999,0.500000,tensor(0.9995),actxgrad,features_6,4,features_6:4,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
1,0.983820,0.200000,tensor(0.9484),actxgrad,features_6,4,features_6:4,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
2,0.896475,0.100000,tensor(0.8449),actxgrad,features_6,4,features_6:4,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
3,0.674877,0.050000,tensor(0.7279),actxgrad,features_6,4,features_6:4,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
4,0.276994,0.010000,tensor(0.4995),actxgrad,features_6,4,features_6:4,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
...,...,...,...,...,...,...,...,...,...
2047,0.233624,0.031250,0.25,actxgrad,features_3,0,features_3:0,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
2048,0.000000,0.023438,0.2,actxgrad,features_3,0,features_3:0,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
2049,0.000000,0.011719,0.15,actxgrad,features_3,0,features_3:0,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...
2050,0.000000,0.003906,0.1,actxgrad,features_3,0,features_3:0,200,/mnt/data/chris/dropbox/Research-Hamblin/Proje...


In [29]:
import plotly.express as px

fig = px.scatter(df, x="keep_ratios", y="correlations", color="feature_name",
                 log_x=True, custom_data=['feature_name','layer','cum_sals','keep_ratios'])
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[3]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")

fig.show()

In [None]:
fig = px.scatter(df, x="keep_ratios", y="correlations", color="layer", log_x = True)
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[2]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")

fig.show()

In [None]:
fig = px.scatter(df, x="keep_ratios", y="correlations", color="layer", log_x = True)
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[2]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")

fig.show()

In [None]:
highend= df.loc[df['cum_sals']>.75]

pearsonr(highend['cum_sals'],highend['keep_ratios'])

In [None]:
pearsonr(highend['cum_sals'],highend['correlations'])

In [None]:
fig = px.scatter(df, x="cum_sals", y="keep_ratios", color="layer",log_y = True)
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[2]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")
fig.update_yaxes(autorange="reversed")
fig.show()

In [None]:
import plotly.express as px

fig = px.scatter(highend, x="keep_ratios", y="correlations", color="feature_name",
                 log_x=True, custom_data=['feature_name','layer','cum_sals','keep_ratios'])
fig.update_traces(
    hovertemplate="<br>".join([
        "feature_name: %{customdata[0]}",
        "layer: %{customdata[1]}",
        "cum_sals: %{customdata[2]}",
        "keep_ratios: %{customdata[3]}"
    ])
)
fig.update_xaxes(autorange="reversed")

fig.show()

In [None]:
fig = px.scatter(df, x="cum_sals", y="correlations", color="feature_name")
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[2]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")

fig.show()

In [None]:
fig = px.scatter(df, x="cum_sals", y="keep_ratios", color="layer",log_y = True)
# fig.update_traces(
#     hovertemplate="<br>".join([
#         "feature_name: %{customdata[0]}",
#         "layer: %{customdata[1]}",
#         "cum_sals: %{customdata[2]}",
#         "keep_ratios: %{customdata[2]}"
#     ])
# )
fig.update_xaxes(autorange="reversed")
fig.update_yaxes(autorange="reversed")
fig.show()

In [None]:
pearsonr(df['keep_ratios'],df['correlations'])

In [None]:
log_ratios = df['keep_ratios'].replace(np.inf, 10e10).replace(-np.inf, -10e10)

In [None]:
pearsonr(log_ratios,df['correlations'])

In [None]:
pearsonr(df['cum_sals'],df['correlations'])

In [None]:
np.log(df['keep_ratios'])

In [None]:
acts = torch.load('./circuit_ranks/alexnet_sparse/imagenet_2/actxgrad/alexnet_sparse_features_6:11_1636834866.3155613.pt')

In [None]:
df = load_acc_df_from_folder('cum_salience_accuracies/alexnet/imagenet_2/snip/')

### average plots

In [14]:
###Plotting all data
import plotly.graph_objs as go
from plotly.graph_objs import *



# fig = go.Figure([
#     go.Scatter(
#         x=x,
#         y=y,
#         line=dict(color='rgb(0,100,80)'),
#         mode='lines'
#     ),
#     go.Scatter(
#         x=x+x[::-1], # x, then x reversed
#         y=y_upper+y_lower[::-1], # upper, then lower reversed
#         fill='toself',
#         fillcolor='rgba(0,100,80,0.2)',
#         line=dict(color='rgba(255,255,255,0)'),
#         hoverinfo="skip",
#         showlegend=False
#     )
# ])


data = {'actxgrad':{},
        'snip':{}
       }


folders = {'actxgrad':'cum_salience_accuracies/alexnet/imagenet_2/actxgrad/',
           'snip':'cum_salience_accuracies/alexnet/imagenet_2/snip/'}



colors = {'actxgrad':[45, 55, 196],
          'snip':[201, 32, 32]}



for model in folders:
    df = load_acc_df_from_folder(folders[model],keep_ratios = [.5,.2,.1,.05,.01,.005,.001])
    data[model]['x'] = list(df['keep_ratios'])
    data[model]['y'] = list(df['correlations'])

print(len(data[model]['x']))

   
def bin_data(data_x,data_y,bin_size=20):
    xs = []
    ys = []
    stds = []
    zipped_data = zip(data_x, data_y)
    sorted_zipped_data = sorted(zipped_data)
    bin_i = 1
    curr_ys = []
    x_sum=0
    for d in sorted_zipped_data:
        bin_i+=1
        curr_ys.append(d[1])
        x_sum+=d[0]
        if bin_i == bin_size:
            xs.append(x_sum/bin_size)
            ys.append(np.mean(np.array(curr_ys)))
            stds.append(np.std(np.array(curr_ys)))
            bin_i = 1
            curr_ys = []
            x_sum=0
    return xs,ys,stds



binned_data = {}

for folder in folders:
    xs,ys,stds = bin_data(data[folder]['x'],data[folder]['y'],bin_size=80)
    binned_data[folder] = [xs,ys,stds]


    

    
fig = go.Figure()
for model in binned_data:
    color = 'rgb(%s, %s, %s)'%(colors[model][0],colors[model][1],colors[model][2])
    color_fill = 'rgba(%s, %s, %s, 0.3)'%(colors[model][0],colors[model][1],colors[model][2])
    
    fig.add_trace(
        go.Scatter(
        name=model,
        x=binned_data[model][0],
        y=binned_data[model][1],
        mode='lines',
        line=dict(color=color),
        )                 
    )
    fig.add_trace(
        go.Scatter(
        name='Upper Bound %s'%model,
        x=binned_data[model][0],
        y=np.array(binned_data[model][1])+np.array(binned_data[model][2]),
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        showlegend=False
    )          
    )
    fig.add_trace(
        go.Scatter(
        name='Lower Bound %s'%model,
        x=binned_data[model][0],
        y=np.array(binned_data[model][1])-np.array(binned_data[model][2]),
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor=color_fill,
        fill='tonexty',
        showlegend=False
    )       
    )
    
    
# fig = go.Figure([
#     go.Scatter(
#         name='worst subgraphs',
#         x=worst_xs,
#         y=worst_ys,
#         mode='lines',
#         line=dict(color='rgb(201, 32, 32)'),
#     ),
#     go.Scatter(
#         name='Upper Bound worst',
#         x=xs,
#         y=np.array(worst_ys)+np.array(worst_stds),
#         mode='lines',
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         showlegend=False
#     ),
#     go.Scatter(
#         name='Lower Bound worst',
#         x=xs,
#         y=np.array(worst_ys)-np.array(worst_stds),
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         mode='lines',
#         fillcolor='rgba(201, 32, 32, 0.3)',
#         fill='tonexty',
#         showlegend=False
#     ),
#     go.Scatter(
#         name='random subgraphs',
#         x=random_xs,
#         y=random_ys,
#         mode='lines',
#         line=dict(color='rgb(45, 55, 196)'),
#     ),
#     go.Scatter(
#         name='Upper Bound random',
#         x=xs,
#         y=np.array(random_ys)+np.array(random_stds),
#         mode='lines',
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         showlegend=False
#     ),
#     go.Scatter(
#         name='Lower Bound random',
#         x=xs,
#         y=np.array(random_ys)-np.array(random_stds),
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         mode='lines',
#         fillcolor='rgba(45, 55, 196, 0.3)',
#         fill='tonexty',
#         showlegend=False
#     ),

#     go.Scatter(
#         name='best subgraphs',
#         x=best_xs,
#         y=best_ys,
#         mode='lines',
#         line=dict(color='rgb(50, 173, 61)'),
#     ),
#     go.Scatter(
#         name='Upper Bound best',
#         x=best_xs,
#         y=np.array(best_ys)+np.array(best_stds),
#         mode='lines',
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         showlegend=False
#     ),
#     go.Scatter(
#         name='Lower Bound best',
#         x=best_xs,
#         y=np.array(best_ys)-np.array(best_stds),
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         mode='lines',
#         fillcolor='rgba(50, 173, 61, 0.3)',
#         fill='tonexty',
#         showlegend=False
#     ),


# ])
# fig.update_layout(
#     yaxis_title='Wind speed (m/s)',
#     title='Continuous, variable value error bars',
#     hovermode="x"
# )
# fig.show()


# fig = go.Figure([
#     go.Scatter(x = data['best']['x'],
#                y = data['best']['y'],
#                mode='markers',
#                name = 'best subgraph'),
#     go.Scatter(x = data['random']['x'],
#                y = data['random']['y'],
#                mode='markers',
#                name = 'random subgraph'),
#     go.Scatter(x = data['worst']['x'],
#                y = data['worst']['y'],
#                mode='markers',
#                name='worst subgraph'),   
# ])



layout = Layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)

fig.layout = layout
fig.update_xaxes(type="log",title_text='Size Ratio (log scale)',autorange="reversed")
fig.update_yaxes(title_text='Pearson R of Activations')
fig.update_layout(yaxis_range=[-.3,1],legend=dict(
    yanchor="top",
    y=0.99,
    xanchor="right",
    x=0.99),
    width=900,
    height=600)
fig

6
11
57
59
67
88
112
132
0
30
105
115
143
147
156
158
532


## check for mask 'collapse'

In [None]:
import torch
import os

mask_folder = 'circuit_masks/alexnet_sparse/force/'



masks = []

for f in os.listdir(mask_folder):
    if '.pt' in f:
        masks.append(torch.load(mask_folder+f))

In [None]:
from circuit_pruner.force import *
import time
import os
from circuit_pruner.utils import update_sys_path



##DATA LOADER###
import torch.utils.data as data
import torchvision.datasets as datasets
from circuit_pruner.data_loading import rank_image_data
from circuit_pruner.dissected_Conv2d import *
from copy import deepcopy


config = 'configs/alexnet_sparse_config.py'

if '/' in config:
    config_root_path = ('/').join(config.split('/')[:-1])
    update_sys_path(config_root_path)
config_module = config.split('/')[-1].replace('.py','')
params = __import__(config_module)




device= 'cuda:0'
model = deepcopy(params.model).to(device)

for m in masks:
    if m['keep_ratio'] == .001:
        break
        



T = 1
ratio = m['keep_ratio']


structure = m['structure']

data_path = m['data_path']


batch_size = 200



feature_target = {m['layer']:[m['unit']]}


kwargs = {'num_workers': params.num_workers, 'pin_memory': True, 'sampler':None} if 'cuda' in device else {}
dataloader = data.DataLoader(rank_image_data(data_path,
                                            params.preprocess,
                                            label_file_path = params.label_file_path,class_folders=True),
                                            batch_size=batch_size,
                                            shuffle=False,
                                            **kwargs)




ranks  = circuit_snip_rank(model, dataloader, feature_targets = feature_target, feature_targets_coefficients = None, full_dataset = True, device=device, criterion= None, setup_net=True,rank_field='image',mask=m['mask'])

In [None]:
x = list(range(10))