In [2]:
from functools import partial
import html
import os

from IPython.display import display, HTML, IFrame
from ipyannotations.generic import FreetextAnnotator
from superintendent import Superintendent
import jsonlines
import torch

In [3]:
model_name = 'pythia-70m-deduped'  # name of the model according to neuronpedia
existing_annotation_file = 'feature-circuits-gp/annotations/pythia_annotations_new.jsonl'  # .jsonl file with existing annotations ('' to start from scratch)
circuit_file = 'feature-circuits-gp/circuits/pythia-70m-deduped_NPZ_ambiguous_samelen_n24_aggnone_node0.1.pt'  # .pt file with the circuit
# Never hardcode credentials in a notebook: read the Neuronpedia key from the environment.
API_KEY = os.environ.get('NEURONPEDIA_API_KEY', '')
# The circuit file presumably pickles Python objects (its nodes expose an `.act` tensor),
# which torch>=2.6 refuses to load under its default `weights_only=True`.
# SECURITY: unpickling can execute arbitrary code -- only load circuit files you trust.
d = torch.load(circuit_file, weights_only=False)

In [4]:
# Feature name -> annotation text, seeded from the optional existing .jsonl file.
annotations = {}

if existing_annotation_file != '':
    with jsonlines.open(existing_annotation_file) as reader:
        # Records are applied in file order, so a later record for the same
        # 'Name' replaces an earlier one.
        annotations.update((record['Name'], record['Annotation']) for record in reader)

def get_existing_annotation(feature_info):
    """Return the stored annotation for `feature_info`, or '' when none exists."""
    return annotations.get(feature_info, '')


In [5]:
def process_feature_info(feature_info, model):
    """Map a circuit feature name to a Neuronpedia (SAE id, feature index) pair.

    Args:
        feature_info: Feature name of the form '<layer_info>/<feature>', where
            <layer_info> is either 'embed' or '<component>_<layer>'
            (e.g. 'resid_3/10').
        model: Neuronpedia model name; must contain 'pythia' or 'gemma-2'.

    Returns:
        Tuple ``(sae_id, feature)`` of strings, e.g. ``('3-res-sm', '10')`` for
        'resid_3/10' on a Pythia model.

    Raises:
        ValueError: for an embedding feature on Gemma (no embedding SAEs), for
            an unsupported model, or for a malformed `feature_info`.
    """
    layer_info, feature = feature_info.split('/')
    if 'pythia' in model:
        if layer_info == 'embed':
            return 'e-res-sm', feature
        template = '{layer}-{comp}-sm'
    elif 'gemma-2' in model:
        if layer_info == 'embed':
            raise ValueError('No embedding SAEs for Gemma')
        template = '{layer}-gemmascope-{comp}-16k'
    else:
        # Previously this fell through and returned None, which made the caller
        # fail with an opaque unpacking error.
        raise ValueError(f"Unsupported model {model!r}: expected a Pythia or Gemma-2 model")
    comp, layer = layer_info.split('_')
    # The SAE id abbreviates the component to its first three letters ('resid' -> 'res').
    return template.format(layer=layer, comp=comp[:3]), feature

def display_examples(feature_info, model):
    """Show the Neuronpedia dashboard and any existing annotation for one feature.

    Args:
        feature_info: Feature name of the form '<layer_info>/<feature>'.
        model: Neuronpedia model name; passed to `process_feature_info` and used
            in the embedded dashboard URL.
    """
    sae_id, feature_idx = process_feature_info(feature_info, model)
    url = f"https://neuronpedia.org/{model}/{sae_id}/{feature_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    display(IFrame(url, width=800, height=400))
    # Annotations are free text typed by the user: escape them (and the feature
    # name) so characters like '<' or '&' render literally instead of as markup.
    feature_text = html.escape(feature_info)
    annotation_text = html.escape(get_existing_annotation(feature_info))
    markup = (
        f'<span style="background-color:white; color:black"><strong>Feature: </strong>{feature_text}</span>'
        f'<span style="background-color:white; color:black"><strong> Annotation: </strong>{annotation_text}</span>'
    )
    display(HTML(markup))
    
class DictSuperintendent(Superintendent):
    """Superintendent labeller that mirrors submitted labels into a caller-owned dict.

    Args:
        annotations: dict updated in place as ``annotations[item] = label`` for
            every non-empty label submitted through the widget. The same object is
            kept (not a copy), so the caller sees the updates.
        *args, **kwargs: forwarded unchanged to ``Superintendent.__init__``.
    """
    def __init__(self, annotations, *args, **kwargs):
        self.annotations = annotations
        super().__init__(*args, **kwargs)

    def _annotation_iterator(self):
        """The annotation loop.

        Generator: for each ``(id_, item)`` in ``self.queue`` it renders the item in
        the labelling widget, then pauses at ``yield`` until a value is sent back in.
        A non-None label is submitted to the queue and, if non-empty, also stored
        as ``self.annotations[item]``. When the queue is exhausted it yields the
        result of ``self._render_finished()``.

        NOTE(review): this overrides a private Superintendent method and relies on
        its internals (``top_bar``, ``labelling_widget``, ``progressbar``, ``queue``,
        ``_render_hold_message``, ``_render_finished``). It is presumably a copy of
        the library's own loop with the dict update added; re-check it when
        upgrading superintendent.
        """
        # Lay out the top bar above the labelling widget and reset the progress bar style.
        self.children = [self.top_bar, self.labelling_widget]
        self.progressbar.bar_style = ""
        for id_, x in self.queue:

            with self._render_hold_message("Loading..."):
                self.labelling_widget.display(x)
            # Pause until a value is sent into the generator (None if resumed via next()).
            y = yield
            if y is not None:
                self.queue.submit(id_, y)
                # Empty labels go to the queue but are not stored, so an existing
                # annotation in the dict is never overwritten with ''.
                if y != '':
                    self.annotations[x] = y
            self.progressbar.value = self.queue.progress

        yield self._render_finished()

class FreeTextAnnotatorWithEnter(FreetextAnnotator):
    """FreetextAnnotator with a Shift + Enter keyboard shortcut.

    On Shift + Enter it stores the textbox contents minus their last character in
    ``self.data`` (presumably dropping the newline the Enter key just typed). It
    then passes the event to ``FreetextAnnotator._handle_keystroke``.
    """
    # By default, this won't work because Shift + Enter is used to run cells in VSCode.
    # But if you change either your VSCode settings or the keys used here, it'll probably work.
    def _handle_keystroke(self, event):
        # NOTE(review): only Shift + Enter events reach the parent handler. Every
        # other keystroke is dropped here, so any key bindings the parent defines
        # are disabled; confirm this is intended.
        if event["key"] == "Enter" and event["shiftKey"]:
            self.data = self.freetext_widget.value[:-1]
            super()._handle_keystroke(event)


In [6]:
# A node's feature is queued for annotation when |activation| exceeds this at any position.
thresh = 0.1
feature_set = set()
for loc, nodes in d['nodes'].items():
    # `nodes.act` is unpacked as 2-D (torch.where returns one index tensor per dim);
    # the second index is taken as the feature id.
    _, features = torch.where(nodes.act.abs() > thresh)
    # The set drops features that fire at several positions.
    feature_set.update(f"{loc}/{f}" for f in features.tolist())

# Sort so the annotation order is reproducible across kernel restarts
# (iteration order of a set of strings depends on hash randomization).
flist = sorted(feature_set)
# comment out if you'd like to re-annotate things for which you already have annotations
flist = [x for x in flist if x not in annotations]

annotator = FreeTextAnnotatorWithEnter(display_function=partial(display_examples, model=model_name), textbox_placeholder='Type your annotation here')
data_labeller = DictSuperintendent(annotations, features=flist, labelling_widget=annotator)

data_labeller



DictSuperintendent(children=(HBox(children=(HBox(children=(FloatProgress(value=0.0, description='Progress:', m…

In [8]:
# Save every annotation (loaded and newly added) as one JSON object per line.
out_file = 'feature-circuits-gp/annotations/pythia_annotations_complete.jsonl'
records = ({'Name': name, 'Annotation': text} for name, text in annotations.items())
with jsonlines.open(out_file, 'w') as writer:
    writer.write_all(records)