In [1]:
%load_ext autoreload
%autoreload 2

### Use quick start data . . .

In [2]:
import pickle
dist_df = pickle.load(open('../quick_start/dist_df.pkl','rb'))
compare_df = pickle.load(open('../quick_start/compare_df.pkl','rb'))

Tesla K40c with CUDA capability sm_35 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_61 sm_70 sm_75 compute_37.
If you want to use the Tesla K40c GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



### . . .  or use your generated data

In [2]:
import os
import torch
import pandas as pd
import numpy as np
from scipy.stats import spearmanr, pearsonr
import pickle

In [3]:
def get_circuit_accuracy(acts,target_acts,metric = 'pearson'):
    target = target_acts.flatten().numpy()
    output = acts.flatten().numpy()
    if metric == 'spearman':
        out = spearmanr(output,target).correlation
        if out is np.nan:
            out = 0.
        return out
    elif metric == 'pearson':
        out = pearsonr(output,target)[0]
        if out is np.nan:
            out = 0.
        return out
    elif metric == 'avg_diff':
        return np.mean(np.abs(output - target))
    elif metric == 'normed_diff':
        norm_factor = np.std(target)
        return np.mean(np.abs(output - target))/norm_factor
        
    else:
        print('unknown metric %s, options ["spearman","pearson","avg_diff","normed_diff"]'%metric)
    

In [9]:
target_activations = {}
extracted_data = []

failure_list = []
columns = ['model','sparsity', 'unit', 'layer', 'cum_sal','masked_spearman','mask_pearson','masked_avg_diff','masked_normed_diff','pruned_spearman','pruned_pearson','pruned_avg_diff','pruned_normed_diff', 'method', 'structure', 'batch_size', 'data_path', 'config', 'masked_target_activations', 'feature_name', 'total_collapse', 'total_params', 'masked_k', 'effective_k', 'effective_sparsity']
failures = 0
for m_name in ['alexnet','alexnet_sparse']:
    target_activations[m_name] = torch.load('target_activations/%s/imagenet_2/orig_activations.pt'%m_name)['activations']
    
    for method in ['magnitude','actxgrad','snip','FORCE']:
        if os.path.exists('extracted_circuits/%s/imagenet_2/%s'%(m_name,method)):
            for f in sorted(os.listdir('extracted_circuits/%s/imagenet_2/%s'%(m_name,method))):
                
                try:
                    data = torch.load('extracted_circuits/%s/imagenet_2/%s/%s'%(m_name,method,f))
                    data['model'] = m_name
                    if not data['total_collapse']:
                        #get accuracy metrics
                        for extract in ['masked','pruned']:
                            for met in ["spearman","pearson","avg_diff","normed_diff"]:
                                if extract == 'masked':
                                    model_out = data['masked_target_activations'][data['feature_name']]
                                else:
                                    model_out = data['pruned_target_activations'][data['layer']+':0']

                                target = target_activations[m_name][data['feature_name']]
                                data[extract+'_'+met] = get_circuit_accuracy(model_out,target, metric = met)
                        del data['pruned_model']
                        del data['masked_target_activations']
                        del data['pruned_target_activations']
                    else:
                        for extract in ['masked','pruned']:
                            for met in ["spearman","pearson","avg_diff","normed_diff"]:
                                data[extract+'_'+met] = np.nan
                    extracted_data.append(data)

                except:
                    failures+=1
                    print(m_name+'   '+method+'  '+f)
                    failure_list.append([m_name,method,f])

print('failures :%s'%str(failures))
    
    
#make data frame

columns = ['model','sparsity', 'unit', 'layer', 'cum_sal','masked_spearman','masked_pearson',
           'masked_avg_diff','masked_normed_diff','pruned_spearman','pruned_pearson','pruned_avg_diff',
           'pruned_normed_diff', 'method', 'structure', 'batch_size', 'data_path', 'config', 
        'feature_name', 'total_collapse', 'total_params', 'masked_k', 'effective_k', 'effective_sparsity']

big_list = []


for d in extracted_data:
    l = []
    for c in columns:
        try:
            l.append(d[c])
        except:
            if c == 'cum_sal':
                l.append('nan')
        
    big_list.append(l)
            
df = pd.DataFrame(big_list,columns=columns)

#cleanup

df.fillna(0, inplace=True)  
df = df.replace([np.inf, -np.inf], 0) 


df['masked_pearson_abs'] = df['masked_pearson'].apply(abs)
df['pruned_pearson_abs'] = df['pruned_pearson'].apply(abs)


#save
import pickle

pickle.dump(df,open('./extracted_circuits/circuit_data_df.pkl','wb'))

dist_df = df
compare_df = df


failures :0


### Plotting

#### Full Data Distribution

In [55]:
import plotly.express as px
import plotly.graph_objs as go

fig = go.Figure()




method = 'actxgrad'
y_col = 'masked_pearson_abs'


df_plot = dist_df.loc[dist_df['sparsity']<.6]
df_plot = df_plot.replace({'alexnet_sparse': 'sparse alexnet'})

fig = px.box(df_plot, x="sparsity", y=y_col, color="model")

    
fig.update_xaxes(autorange="reversed",title="Sparsity") 
fig.update_yaxes(title="|Pearson R|")

fig.update_layout({ 'showlegend':True,
                    'width':1200,
                    'plot_bgcolor':'rgba(255,255,255,1)',
                    'paper_bgcolor':'rgba(255,255,255,1)'
                    })
   
fig.update_layout(
    xaxis = dict(
        tickmode = 'array'
    )
)
fig.update_layout(xaxis_type='category')
                
                
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.4,
    xanchor="left",
    x=0.05
))

fig.show()

#### methods comparison

In [56]:
#plot method diff average by sparsity

import plotly.graph_objects as go
import plotly.express as px

fig = go.Figure()

measure = 'masked_pearson_abs'
color_col = 'method'

method_colors = {'actgrad':px.colors.qualitative.T10[1],
                'SNIP':px.colors.qualitative.T10[3],
                'FORCE':px.colors.qualitative.T10[4],
                'magnitude':px.colors.qualitative.T10[6],
               }


df_plot = compare_df

sparsities =  list(df_plot['sparsity'].unique())
sparsities.sort()

df_plot = df_plot.replace({'alexnet_sparse': 'sparse alexnet'})
df_plot = df_plot.replace({'snip': 'SNIP'})
df_plot = df_plot.replace({'actxgrad': 'actgrad'})

models =  list(df_plot['model'].unique())
methods = list(df_plot['method'].unique())

for model in models:
    linetype = None
    if model == 'sparse alexnet':
        linetype = 'dot' 
    for method in methods:
        x = sparsities
        y = []
        for sparsity in sparsities:
            df_sel = df_plot.loc[(df_plot['sparsity'] == sparsity) & (df_plot['method'] == method) & (df_plot['model'] == model)]
            y.append(np.median(df_sel[measure]))

        fig.add_trace(go.Scatter(x=x, y=y,
                        mode='lines',
                        name=model+':'+method,
                        line=dict(color=method_colors[method],width=4,dash=linetype)
                            )
                        )

fig.update_xaxes(autorange="reversed",title = 'sparsity',gridcolor='rgb(210,210,210)')
fig.update_yaxes(title = '|Pearson R|',gridcolor='rgb(210,210,210)')

fig.update_layout({ 'showlegend':True,
                    'width':900,
                    'plot_bgcolor':'rgba(255,255,255,1)',
                    'paper_bgcolor':'rgba(255,255,255,1)'
                    })
   
fig.update_layout(
    xaxis = dict(
        tickmode = 'array',
        tickvals = [.9, .7, .5, .3, .1, .01]
    )
)
     

                
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.7,
    xanchor="left",
    x=0.1
))
    
fig.show()

#### Disconnect

In [8]:
total = len(dist_df.feature_name.unique())

x = list(dist_df.sparsity.unique())

x



array([0.001, 0.005, 0.01 , 0.05 , 0.1  , 0.2  , 0.3  , 0.4  , 0.5  ,
       0.6  , 0.7  , 0.8  , 0.9  ])

In [15]:
#total disconnect

y = []
for sparsity in x:
    y.append(float(dist_df.loc[(dist_df['sparsity']==sparsity) & (dist_df['model']=='alexnet_sparse')]['total_collapse'].sum())/float(total))
    
y

[0.8069196428571429,
 0.29799107142857145,
 0.056919642857142856,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [35]:
#effective sparsity versus sparsity

import plotly.graph_objects as go

fig = go.Figure()



df_plot = dist_df.loc[(dist_df['sparsity']<.6) & (dist_df['model']=='alexnet_sparse') & (dist_df['method']=='actxgrad')]

x=[]
y=[]

for i, row in df_plot.iterrows():
    x.append(row['effective_sparsity'])
    y.append(row['sparsity'])
    



fig.add_trace(go.Scatter(x=x, y=y,
                    mode='markers',
                    name='points',
              ))
    


# add midline  
fig.add_trace(go.Scatter(x=[0,.001,.01,.1,.2,.3,.4,.5], y=[0,.001,.01,.1,.2,.3,.4,.5],
                    mode='lines',
                    name='equivalence',
              ))
    

    
fig.update_xaxes(autorange="reversed",title="Effective Sparsity",type='log') 
fig.update_yaxes(title="Sparsity",type='log')

fig.update_layout({
                    'plot_bgcolor':'rgba(255,255,255,1)',
                    'paper_bgcolor':'rgba(255,255,255,1)',
                    'showlegend':False,
                    })




In [36]:
df_plot.columns

Index(['model', 'sparsity', 'unit', 'layer', 'cum_sal', 'masked_spearman',
       'masked_pearson', 'masked_avg_diff', 'masked_normed_diff',
       'pruned_spearman', 'pruned_pearson', 'pruned_avg_diff',
       'pruned_normed_diff', 'method', 'structure', 'batch_size', 'data_path',
       'config', 'feature_name', 'total_collapse', 'total_params', 'masked_k',
       'effective_k', 'effective_sparsity', 'masked_pearson_abs',
       'pruned_pearson_abs'],
      dtype='object')

In [40]:
#masked versus pruned circuit correlations

import plotly.graph_objects as go

fig = go.Figure()



df_plot = dist_df.loc[(dist_df['sparsity']<.6) & (dist_df['model']=='alexnet_sparse') & (dist_df['method']=='actxgrad')]

x=[]
y=[]

for i, row in df_plot.iterrows():
    x.append(row['effective_sparsity'])
    y.append(row['pruned_pearson_abs']-row['masked_pearson_abs'])
    



fig.add_trace(go.Scatter(x=x, y=y,
                    mode='markers',
                    name='points',
              ))
    

    
    # add midline  
fig.add_trace(go.Scatter(x=[0,.001,.01,.1,.2,.3,.4,.5], y=[0,0,0,0,0,0,0,0],
                    mode='lines',
                    name='equivalence',
              ))

    
fig.update_xaxes(autorange="reversed",title="Effective Sparsity",type='log') 
fig.update_yaxes(title=r'$\Huge{\Delta|R|}$')

fig.update_layout({
                    'plot_bgcolor':'rgba(255,255,255,1)',
                    'paper_bgcolor':'rgba(255,255,255,1)',
                    'showlegend':False,
                    'height':1000,
                    'width':2000,
                    'font':{'size':45}
                    })


In [41]:
fig.write_image("../plots/pruned_vs_masked.png")