In [22]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
import torch as t
import gc
import numpy as np
from dictionary_learning.dictionary import GatedAutoEncoder, JumpReLUSAE
from dictionary_learning.buffer import ActivationBuffer
from sae_lens import SparseAutoencoder
from tokenizers.processors import TemplateProcessing
from huggingface_hub import hf_hub_download

model_name = "google/gemma-2-2b"
# model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=t.float16,
                             device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="right")
layer = 13
submodule_name = f"model.layers.{layer}"
submodule = model.model.layers[layer]

dictionaries = {}
# ae = GatedAutoEncoder(4096, 32768).half().to("cuda")
# ae.load_state_dict(t.load(f'llama_saes/layer{layer}/ae_81920.pt'))
# dictionaries[submodule] = ae
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename=f"layer_{layer}/width_16k/canonical/params.npz",
    force_download=False,
)
params = np.load(path_to_params)
pt_params = {k: t.from_numpy(v).cuda() for k, v in params.items()}
ae = JumpReLUSAE(params["W_enc"].shape[0], params["W_enc"].shape[1]).to("cuda")
ae.load_state_dict(pt_params)
ae = ae.half()
dictionaries[submodule] = ae

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

cuda


In [23]:
tokenizer.encode("This is a string.")

[2, 1596, 603, 476, 2067, 235265]

In [24]:
tokenizer.convert_ids_to_tokens(tokenizer.encode(" is", add_special_tokens=False))

['▁is']

In [35]:
import pandas as pd
from datasets import load_dataset
from collections import deque
from dictionary_learning.interp_utils import *

max_length = 64

def parse_and_load_text(indata, tokenizer, max_length=128, space_char="▁"):
    def _parse_morphosyn_feats(morphosyn_str, pos):
        if morphosyn_str == "_":
            return []
        
        features = []
        if "|" in morphosyn_str:
            morphosyn_list = morphosyn_str.split("|")
        else:
            morphosyn_list = [morphosyn_str]
        for feature in morphosyn_list:
            name, value = feature.split("=")
            features.append(f"{pos}:{name}_{value}")
        return features

    def _lookahead(lines, idx, word, tokens, sentence_to_labels, sentence_to_deps, morphosyn_feats, dep_label, max_lookahead=1):
        lookahead = 1
        while lookahead <= max_lookahead:
            matched = True
            next_word = lines[idx+lookahead].split("\t")[1]
            word = f"{word}{next_word}"
            word_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(word, add_special_tokens=False))
            # nested munging
            for word_token in word_tokens:
                if word_token == tokens[0]:
                    sentence_to_labels[sentence].append(morphosyn_feats)
                    sentence_to_deps[sentence].append(dep_label)
                    tokens.popleft()
                else:
                    lookahead += 1
                    matched = False
                    break
            if matched:
                break
        return tokens, lookahead

    sentences = []
    sentence_to_labels = {}
    sentence_to_deps = {}
    lines = indata.readlines()
    num_to_skip = 0
    num_sents = 0
    num_sents_skipped = 0
    sentence_unhandled = False
    for idx, line in enumerate(lines):
        if num_to_skip > 0:
            num_to_skip -= 1
            continue

        if line.startswith("# text"):
            sentence = line.strip().split("# text = ")[1]
            if sentence.startswith("http") or sentence == "Smokers Haven":
                sentence = None     # skip this one
                num_sents_skipped += 1
                continue
            sentence = sentence.replace(u"\xa0", " ").replace("  ", " ")
            num_sents += 1
            sentences.append(sentence)
            tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(sentence, add_special_tokens=False))
            tokens_len = len(tokens)
            tokens = deque(tokens)  # deques can pop from left much more efficiently than lists
            sentence_to_labels[sentence] = []
            sentence_to_deps[sentence] = []
            continue
        elif line.startswith("# "):
            continue
        elif len(line) < 2:     # Empty line means end-of-sentence
            if sentence is None:
                continue
            assert len(tokens) == 0, f"Not all tokens have been processed! Remainders: {tokens}"
            assert tokens_len == len(sentence_to_labels[sentence])
            continue
        
        if sentence is None:
            continue
        # munge sentence word-by-word
        row = line.split("\t")
        _id, word, lemma, pos, ptb_pos, morphosyn_feats, dep_to, dep_label, _, notes = row
        if _id.endswith(".1"):     # word not actually in sentence
            continue

        morphosyn_feats = _parse_morphosyn_feats(morphosyn_feats, pos)
        if tokens[0].startswith(space_char):
            word = f" {word}"
        word_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(word, add_special_tokens=False))
        for token in word_tokens:
            if token != tokens[0]:
                if pos == "PUNCT":  # Try lookahead
                    tokens, num_to_skip = _lookahead(lines, idx, word, tokens, sentence_to_labels, sentence_to_deps, morphosyn_feats, dep_label, max_lookahead=3)
                    continue
                elif idx != len(lines)-1 and len(lines[idx+1]) > 2 and lines[idx+1].split("\t")[3] == "PART":
                    tokens, num_to_skip = _lookahead(lines, idx, word, tokens, sentence_to_labels, sentence_to_deps, morphosyn_feats, dep_label)
                    continue
                else:
                    num_sents_skipped += 1
                    del sentence_to_labels[sentence]
                    del sentence_to_deps[sentence]
                    sentence = None
                    sentence_unhandled = True
                    break
                # raise Exception(f"Mismatched token lists for sentence:\n{sentence}\nWord tokens: {word_tokens}\nSent tokens: {tokens}")
            sentence_to_labels[sentence].append(morphosyn_feats)
            sentence_to_deps[sentence].append(dep_label)
            tokens.popleft()
            # If we're at max_length, stop
        #     if len(sentence_to_labels[sentence]) >= max_length:
        #         break
        # if sentence_unhandled:
        #     sentence_unhandled = False
        #     continue
        # if len(sentence_to_labels[sentence]) >= max_length:
        #     continue
    
    print(f"Unhandled sentences: {num_sents_skipped} / {num_sents} ({num_sents_skipped / num_sents * 100:.2f}%)")
    return sentences, sentence_to_labels, sentence_to_deps

def convert_to_dataset(sentences, tokenizer, max_length=128, num_datapoints=None):
    if(num_datapoints):
        split_sentences[:num_datapoints]
    else:
        split_sentences = sentences
    df = pd.DataFrame(split_sentences)
    dataset = Dataset.from_pandas(df.rename(columns={0: "text"}), split="train")
    tokenized_dataset = dataset.map(
        lambda x: tokenizer(x["text"], padding=True, truncation=True,
                            max_length=max_length),
        batched=True,
    )
    # ).filter(
    #     lambda x: len(x['input_ids']) > max_length
    # ).map(
    #     lambda x: {'input_ids': x['input_ids'][:max_length]}
    # )
    return tokenized_dataset

    # dataset = load_dataset(dataset_name, split=split_text).map(
    #     lambda x: tokenizer(x['text']),
    #     batched=True,
    # ).filter(
    #     lambda x: len(x['input_ids']) > max_length
    # ).map(
    #     lambda x: {'input_ids': x['input_ids'][:max_length]}
    # )
    # return dataset

with open("data/ud/UD_English/en-ud-train.conllu", 'r') as indata:
    sentences, sentence_to_labels, sentence_to_deps = parse_and_load_text(indata, tokenizer, max_length=max_length)
                                                        # space_char="Ġ")
dataset = convert_to_dataset(sentences, tokenizer, max_length=max_length)
# dataset = download_dataset(dataset_name, tokenizer=tokenizer, max_length=max_seq_length, num_datapoints=7000)

Unhandled sentences: 306 / 12460 (2.46%)


Map:   0%|          | 0/12460 [00:00<?, ? examples/s]

In [None]:
dataset[15]['input_ids']

[2,
 1596,
 50276,
 603,
 573,
 1872,
 5830,
 575,
 573,
 1758,
 576,
 187987,
 575,
 573,
 8432,
 685,
 21240,
 577,
 573,
 5086,
 235290,
 36622,
 576,
 573,
 3170,
 2330,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [None]:
import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
from baukit import Trace

max_seq_length = 64

def get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32):
    # num_features, d_model = autoencoder.encoder.weight.shape
    num_features, d_model = params["W_enc"].shape[1], params["W_enc"].shape[0]
    datapoints = dataset.num_rows
    dictionary_activations = torch.zeros((datapoints*max_seq_length, num_features))
    token_list = torch.zeros((datapoints*max_seq_length), dtype=torch.int64)
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(tqdm(dl)):
            batch = batch.to(model.device)
            token_list[i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length] = rearrange(batch, "b s -> (b s)")
            with Trace(model, cache_name) as ret:
                _ = model(batch).logits
                internal_activations = ret.output
                # check if instance tuple
                if(isinstance(internal_activations, tuple)):
                    internal_activations = internal_activations[0]
            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n" )
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            dictionary_activations[i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length,:] = batched_dictionary_activations.cpu()
    return dictionary_activations, token_list

batch_size = 64
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, submodule_name, max_seq_length, ae, batch_size=batch_size)

100%|██████████| 195/195 [01:10<00:00,  2.77it/s]


In [None]:
num_seqs = int(dictionary_activations.shape[0] / max_length)
num_feats = dictionary_activations.shape[-1]
dictionary_activations = dictionary_activations.reshape((num_seqs, max_length, num_feats))

tokens_for_each_datapoint = tokens_for_each_datapoint.reshape((num_seqs, max_length))

In [None]:
feature_freqs = defaultdict(int)
for sentence in sentence_to_labels:
    morphosyn_feats = sentence_to_labels[sentence]
    dep_feats = sentence_to_deps[sentence]
    for i in range(len(morphosyn_feats)):
        for feat in morphosyn_feats[i]:
            feature_freqs[feat] += 1
        feature_freqs[dep_feats[i]] += 1

In [48]:
def feature_precisions(dictionary_activations, tokens, feature_idx, sentences, sentence_to_labels):
    morphosyn_acts = defaultdict(float)
    dep_acts = defaultdict(float)
    
    for idx, sentence in tqdm(enumerate(sentences), total=len(sentences), desc="Examples"):
        if sentence not in sentence_to_labels:
            continue
        sentence_len = len(sentence_to_labels[sentence])
        tokens_sent = tokens[idx]
        if tokens_sent[0] == 2:
            idx_offset = 1
        else:
            idx_offset = 0
        dictionary_sent_acts = dictionary_activations[idx, idx_offset : sentence_len+idx_offset, feature_idx]
        
        nonzero_idxs = dictionary_sent_acts.nonzero().flatten().tolist()
        morphosyn_feats = [sentence_to_labels[sentence][j] for j in nonzero_idxs]
        dep_feats = [sentence_to_deps[sentence][j] for j in nonzero_idxs]
        for j, feat_list in enumerate(morphosyn_feats):
            dep_label = dep_feats[j]
            for feat in feat_list:
                morphosyn_acts[feat] += dictionary_sent_acts[nonzero_idxs[j]].item()
            dep_acts[dep_label] += dictionary_sent_acts[nonzero_idxs[j]].item()
    
    for feat in morphosyn_acts:
        morphosyn_acts[feat] /= dictionary_activations.shape[0]
    for dep in dep_acts:
        dep_acts[dep] /= dictionary_activations.shape[0]

    return morphosyn_acts, dep_acts

morphosyn_acts, dep_acts = feature_precisions(dictionary_activations, tokens_for_each_datapoint, 3883, sentences, sentence_to_labels)

Examples:   0%|          | 0/12460 [00:00<?, ?it/s]

In [49]:
sorted(dep_acts.items(), key=lambda x: x[1], reverse=True)

[('punct', 72.62273650682182),
 ('nsubj', 62.049668940609955),
 ('root', 54.59590063202247),
 ('obl', 42.58081109550562),
 ('compound', 42.41500175561798),
 ('obj', 38.44035162520064),
 ('nmod', 36.894063503210276),
 ('case', 36.42236531902087),
 ('det', 28.52585899879615),
 ('conj', 25.323205507624397),
 ('amod', 24.603417184991976),
 ('advmod', 24.106315208667738),
 ('nummod', 22.44045570826645),
 ('aux', 17.57957338483146),
 ('cc', 15.14309540529695),
 ('appos', 14.464808637640449),
 ('mark', 13.27491974317817),
 ('flat', 11.193316111556982),
 ('cop', 9.043392606340289),
 ('nmod:poss', 8.709218248394864),
 ('advcl', 8.551315459470304),
 ('xcomp', 7.389261888041734),
 ('list', 6.3388794141252),
 ('ccomp', 5.584127959470305),
 ('acl:relcl', 5.505793539325842),
 ('nsubj:pass', 4.545368930577849),
 ('parataxis', 4.344797100722311),
 ('acl', 4.020584620786517),
 ('discourse', 3.4201921147672554),
 ('aux:pass', 2.3526798254414127),
 ('goeswith', 2.3021706962279294),
 ('obl:tmod', 2.256221

In [50]:
sorted(morphosyn_acts.items(), key=lambda x: x[1], reverse=True)

[('NOUN:Number_Sing', 115.8342621388443),
 ('PROPN:Number_Sing', 79.06908482142858),
 ('NUM:NumType_Card', 47.24079429173355),
 ('PRON:PronType_Prs', 45.21140273876404),
 ('NOUN:Number_Plur', 38.85983396869984),
 ('PRON:Number_Sing', 33.19192164927769),
 ('PRON:Case_Nom', 32.79439205457464),
 ('ADJ:Degree_Pos', 31.908334169341895),
 ('AUX:VerbForm_Fin', 27.078156350321027),
 ('DET:PronType_Art', 23.44992350521669),
 ('PRON:Person_1', 22.584531751605137),
 ('VERB:VerbForm_Fin', 22.191040078250403),
 ('VERB:Tense_Past', 18.744511185794543),
 ('AUX:Mood_Ind', 18.6844878611557),
 ('VERB:Mood_Ind', 17.9780547752809),
 ('DET:Definite_Def', 17.83414551565008),
 ('PRON:Person_3', 16.518686045345106),
 ('AUX:Tense_Pres', 14.65222838081862),
 ('VERB:VerbForm_Inf', 13.478255417335474),
 ('VERB:VerbForm_Part', 12.68010132423756),
 ('VERB:Tense_Pres', 11.913644913723916),
 ('PRON:Number_Plur', 9.511666081460675),
 ('AUX:Number_Sing', 9.497331460674157),
 ('AUX:Person_3', 9.08046749598716),
 ('PRON:

In [None]:
features_list = list(sentence_to_labels.values())
unique_features = set()
for dict_list in features_list:
    for dict in dict_list:
        unique_features.update([f"{k}_{v}" for k, v in dict.items()])
print(len(unique_features))

35
