In [None]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.graph_objects as go
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from ablation import run_with_ablations
from scipy import interpolate
import math
from statistics import stdev


In [None]:
# Prefer the first GPU, but fall back to the CPU so the notebook can still be
# executed (slowly) on a machine without CUDA instead of failing at model load.
device = 'cuda:0' if t.cuda.is_available() else 'cpu'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device, dispatch=True)

# Index of the first transformer layer to explain; a negative value additionally
# includes the embedding (see the submodule-loading cell).
start_layer = 2

In [None]:
# load submodules
# `submodules` lists, in order, every component whose activations are explained:
# optionally the embedding, then (attention, mlp, whole layer) per transformer layer.
submodules = []
if start_layer < 0:
    submodules.append(model.gpt_neox.embed_in)
# A negative start_layer only signals "include the embedding". Clamp it to 0
# for the layer loop: range(-1, n) would otherwise yield index -1 and append
# the last layer's submodules a second time.
first_layer = max(start_layer, 0)
for i in range(first_layer, len(model.gpt_neox.layers)):
    submodules.extend([
        model.gpt_neox.layers[i].attention,
        model.gpt_neox.layers[i].mlp,
        model.gpt_neox.layers[i]
    ])

# Map every submodule (not only the selected ones) to the short name used as
# a key in the saved circuit files.
submod_names = {
    model.gpt_neox.embed_in : 'embed'
}
for i in range(len(model.gpt_neox.layers)):
    submod_names[model.gpt_neox.layers[i].attention] = f'attn_{i}'
    submod_names[model.gpt_neox.layers[i].mlp] = f'mlp_{i}'
    submod_names[model.gpt_neox.layers[i]] = f'resid_{i}'

In [None]:
# load dictionaries
dict_id = 10

activation_dim = 512
expansion_factor = 64
dict_size = expansion_factor * activation_dim


def _load_autoencoder(subdir):
    """Build an AutoEncoder on `device` and load its weights from `subdir`.

    `subdir` is the per-submodule directory name under
    ../dictionaries/pythia-70m-deduped/ (e.g. 'embed', 'attn_out_layer0').
    """
    ae = AutoEncoder(activation_dim, dict_size).to(device)
    path = f'../dictionaries/pythia-70m-deduped/{subdir}/{dict_id}_{dict_size}/ae.pt'
    # map_location makes the load independent of the device the checkpoint was
    # saved from (e.g. a GPU index that does not exist on this machine).
    # NOTE(review): t.load unpickles the file; only load checkpoints you trust.
    ae.load_state_dict(t.load(path, map_location=device))
    return ae


# One sparse autoencoder per submodule: embedding, then attn / mlp / resid for each layer.
feat_dicts = {model.gpt_neox.embed_in : _load_autoencoder('embed')}
for i in range(len(model.gpt_neox.layers)):
    layer = model.gpt_neox.layers[i]
    feat_dicts[layer.attention] = _load_autoencoder(f'attn_out_layer{i}')
    feat_dicts[layer.mlp] = _load_autoencoder(f'mlp_out_layer{i}')
    feat_dicts[layer] = _load_autoencoder(f'resid_out_layer{i}')

# Neuron baseline: identity "dictionaries" of width activation_dim for the selected submodules.
neuron_dicts = {
    submod : IdentityDict(activation_dim).to(device) for submod in submodules
}


In [None]:
# use mean ablation
def ablation_fn(x):
    """Replace every slice along dim 0 with the mean taken over dim 0.

    The result has the same shape as `x` (dim 0 is presumably the batch
    dimension — confirm against run_with_ablations).
    """
    mean_over_first_dim = x.mean(dim=0)
    return mean_over_first_dim.expand_as(x)

In [None]:
def get_fcs(
        dataset,
        model,
        submodules,
        dictionaries,
        ablation_fn,
        thresholds,
        length,
        handle_errors = 'default', # also 'remove' or 'resid_only'
        use_neurons = False,
        random = False
):
    """Measure faithfulness and completeness of a saved circuit at several thresholds.

    Args:
        dataset: dataset name; selects both the saved circuit file and the
            `{dataset}_test.json` examples file.
        model: traced language model; the metric reads `model.embed_out.output`.
        submodules: submodules whose nodes may be ablated.
        dictionaries: mapping submodule -> dictionary, forwarded to `run_with_ablations`.
        ablation_fn: forwarded to `run_with_ablations`.
        thresholds: iterable of node-score thresholds; a node is in the circuit
            when the absolute value of its saved score exceeds the threshold.
        length: forwarded to `load_examples` as `length`.
        handle_errors: 'default' keeps the thresholded error (`resc`) nodes,
            'remove' clears all of them, 'resid_only' clears them for every
            submodule that is not a whole layer in `model.gpt_neox.layers`.
        use_neurons: load the neuron circuit file (`dictid`) and size the empty
            masks with `activation_dim` instead of `dict_size`.
        random: replace each `act` mask with Bernoulli samples whose probability
            equals the circuit's fraction of selected nodes, and set every
            `resc` mask to all-True. Sampling is not seeded here.

    Returns:
        dict with 'fm' (metric of the unablated model), 'fempty' (metric with
        all-False node masks) and, for each threshold, a dict with 'n_nodes',
        'fc', 'fccomp', 'faithfulness' and 'completeness'.

    Raises:
        ValueError: if `handle_errors` is not one of the three supported values.
    """
    # Fail fast on a misspelled mode: previously any unknown value silently
    # behaved like 'default', which would mislabel the resulting curves.
    valid_modes = ('default', 'remove', 'resid_only')
    if handle_errors not in valid_modes:
        raise ValueError(
            f"handle_errors must be one of {valid_modes}, got {handle_errors!r}"
        )

    # load data
    # NOTE(review): t.load unpickles the file; only load circuits you trust.
    circuit_tag = 'dictid' if use_neurons else 'dict10'
    circuit = t.load(f'../circuits/{dataset}_train_{circuit_tag}_node0.1_edge0.01_n100_aggnone.pt')['nodes']
    # NOTE(review): absolute, machine-specific path; 40 is presumably the number of examples — confirm.
    examples = load_examples(f'/share/projects/dictionary_circuits/data/phenomena/{dataset}_test.json', 40, model, length=length)
    # Use the notebook-wide `device` (as the rest of this function does) rather than a literal 'cuda:0'.
    clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to(device)
    clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device=device)
    patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to(device)
    patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device=device)

    def metric_fn(model):
        # Final-position logit difference: clean answer minus patch answer.
        return (
            - t.gather(model.embed_out.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) + \
            t.gather(model.embed_out.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )

    def run(nodes, **extra):
        # Shared call into run_with_ablations; returns the metric averaged to a float.
        return run_with_ablations(
            clean_inputs,
            patch_inputs,
            model,
            submodules,
            dictionaries,
            nodes=nodes,
            metric_fn=metric_fn,
            ablation_fn=ablation_fn,
            **extra
        ).mean().item()

    with t.no_grad():
        out = {}

        # get F(M)
        with model.trace(clean_inputs):
            metric = metric_fn(model).save()
        fm = metric.value.mean().item()
        out['fm'] = fm

        # get F(∅): every node mask is all-False
        mask_size = activation_dim if use_neurons else dict_size
        fempty = run({
            submod : SparseAct(
                act=t.zeros(mask_size, dtype=t.bool),
                resc=t.zeros(1, dtype=t.bool)).to(device)
            for submod in submodules
        })
        out['fempty'] = fempty

        # Both scores are normalised by the gap between the full and the empty model.
        denom = fm - fempty

        for threshold in thresholds:
            out[threshold] = {}
            nodes = {
                submod : circuit[submod_names[submod]].abs() > threshold for submod in submodules
            }

            if handle_errors == 'remove':
                for k in nodes: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)
            elif handle_errors == 'resid_only':
                for k in nodes:
                    if k not in model.gpt_neox.layers: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)

            n_nodes = sum(n.act.sum() + n.resc.sum() for n in nodes.values()).item()
            if random:
                # Size-matched random baseline: keep each feature with the circuit's selection rate.
                total_nodes = sum(n.act.numel() + n.resc.numel() for n in nodes.values())
                p = n_nodes / total_nodes
                for k in nodes:
                    nodes[k].act = t.bernoulli(t.ones_like(nodes[k].act, dtype=t.float) * p).to(device).to(dtype=t.bool)
                    nodes[k].resc = t.ones_like(nodes[k].resc, dtype=t.bool).to(device)
                # Recount, since the sampled masks differ in size from the circuit.
                n_nodes = sum(n.act.sum() + n.resc.sum() for n in nodes.values()).item()
            out[threshold]['n_nodes'] = n_nodes

            out[threshold]['fc'] = run(nodes)
            out[threshold]['fccomp'] = run(nodes, complement=True)
            out[threshold]['faithfulness'] = (out[threshold]['fc'] - fempty) / denom
            out[threshold]['completeness'] = (out[threshold]['fccomp'] - fempty) / denom
    return out


In [None]:
# dataset name -> value passed to get_fcs as `length`
datasets = dict(rc=6, nounpp=5, simple=2, within_rc=5)

thresholds = t.logspace(-4, 0, 15, 10).tolist()

# setting name -> (dictionaries to use, extra keyword arguments for get_fcs)
setting_specs = {
    'features' : (feat_dicts, {}),
    'features_wo_errs' : (feat_dicts, {'handle_errors' : 'remove'}),
    'features_wo_some_errs' : (feat_dicts, {'handle_errors' : 'resid_only'}),
    'neurons' : (neuron_dicts, {'use_neurons' : True}),
}

# outs[setting][dataset] holds the get_fcs result for that combination.
outs = {
    setting_name : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            setting_dicts,
            ablation_fn=ablation_fn,
            thresholds=thresholds,
            length=length,
            **extra_kwargs
        )
        for dataset, length in datasets.items()
    }
    for setting_name, (setting_dicts, extra_kwargs) in setting_specs.items()
}

In [None]:
%load_ext autoreload
%autoreload 2
import plotly.graph_objects as go


fig = go.Figure()

# line colour for each experimental setting
colors = dict(
    features='blue',
    features_wo_errs='red',
    features_wo_some_errs='green',
    neurons='purple',
)

for setting, subouts in outs.items():
    color = colors[setting]

    # Per-dataset curves: node count (x) against faithfulness (y), one point per threshold.
    node_counts = {
        name : [subouts[name][thr]['n_nodes'] for thr in thresholds] for name in datasets
    }
    scores = {
        name : [subouts[name][thr]['faithfulness'] for thr in thresholds] for name in datasets
    }

    # Restrict the averaged curve to the x-range covered by every dataset,
    # shrunk by one node on each side.
    x_min = max(min(counts) for counts in node_counts.values()) + 1
    x_max = min(max(counts) for counts in node_counts.values()) - 1
    fs = {
        name : interpolate.interp1d(node_counts[name], scores[name]) for name in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    # faint line for each individual dataset
    for name in datasets:
        fig.add_trace(go.Scatter(
            x=node_counts[name],
            y=scores[name],
            mode='lines', line=dict(color=color), opacity=0.17, showlegend=False
        ))

    # solid line: mean of the interpolated per-dataset curves
    mean_curve = [sum([interp(x) for interp in fs.values()]) / len(fs) for x in xs]
    fig.add_trace(go.Scatter(
        x=xs,
        y=mean_curve,
        mode='lines', line=dict(color=color), name=setting
    ))

fig.update_xaxes(range=(0, 1700))
fig.update_yaxes(range=(0, 1.1))

# shared axis styling: grey gridlines with a mirrored, outward-ticked frame
axis_style = dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True)
fig.update_layout(
    xaxis_title='Nodes',
    yaxis_title='Faithfulness',
    width=800,
    height=375,
    # fully transparent plot background
    plot_bgcolor='rgba(0,0,0,0)',
    yaxis=dict(axis_style),
    xaxis=dict(axis_style),
)

# fig.show()
fig.write_image('faithfulness.pdf')

In [None]:
fig = go.Figure()

# line colour for each experimental setting
colors = {
    'features' : 'blue',
    'features_wo_errs' : 'red',
    'features_wo_some_errs' : 'green',
    'neurons' : 'purple'
}

for setting, subouts in outs.items():

    # Per-dataset curves: node count (x) against completeness (y), one point per threshold.
    node_counts = {
        dataset : [subouts[dataset][thr]['n_nodes'] for thr in thresholds] for dataset in datasets
    }
    scores = {
        dataset : [subouts[dataset][thr]['completeness'] for thr in thresholds] for dataset in datasets
    }

    # Average only over the x-range that every dataset covers (shrunk by one
    # node on each side) so that each interpolator is evaluated in-bounds.
    x_min = max(min(counts) for counts in node_counts.values()) + 1
    x_max = min(max(counts) for counts in node_counts.values()) - 1
    fs = {
        dataset : interpolate.interp1d(node_counts[dataset], scores[dataset])
        for dataset in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    # faint line for each individual dataset
    for dataset in datasets:
        fig.add_trace(go.Scatter(
            x = node_counts[dataset],
            y = scores[dataset],
            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
        ))
    # solid line: mean of the interpolated per-dataset curves
    fig.add_trace(go.Scatter(
        x=xs,
        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0,300))
fig.update_yaxes(range=(-.15, 1))

fig.update_layout(
    xaxis_title='Nodes',
    # This figure plots the 'completeness' scores; the label was previously a
    # copy-paste of the faithfulness figure's.
    yaxis_title='Completeness',
    width=800,
    height=375,
    # fully transparent plot background
    plot_bgcolor='rgba(0,0,0,0)',
    # add grey gridlines
    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
)
# fig.show()
fig.write_image('completeness.pdf')