In [3]:
# Load cached feature activations, validation token strings/ids, and the trained sparse MLP

import torch
import numpy as np
from transformers import AutoTokenizer
from interp_utils import reload_module
from sparse_models import SparseMLP
from feature_kit import *
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

sparsity_levels = torch.arange(-5,5)
sparsity = 3
fname = f'mlp_F6000_S{sparsity}_R1.pt'

# Load analysis tensors onto the CPU regardless of the device they were saved from.
# A CPU-only machine cannot deserialize CUDA tensors without map_location.
# Later cells also index `docs` (presumably a numpy array, since `.tolist()` is
# called on it) with tensors derived from `acts`, which fails for CUDA tensors.
acts = torch.load(f'val_acts/{fname}', map_location='cpu')
per_doc_maxes = acts.max(dim=-1).values
# NOTE(review): the `+1` keeps this finite for all-zero rows, but it maps each
# row's maximum m to m/(m+1) rather than 1. Confirm this is intended.
per_doc_maxes = per_doc_maxes/(per_doc_maxes.max(dim=-1).values[:,None]+1)

# Trim the token strings AND the token ids to the docs that have activations,
# so that `acts`, `docs` and `doc_ids` stay aligned row-for-row.
docs = torch.load('val_acts/val_tok_strs.pt')[:acts.shape[-2]]
doc_ids = torch.load('val_acts/val_tok_ids.pt', map_location='cpu')[:acts.shape[-2]]

mlp = SparseMLP(n_features=6000, d_model=768, disable_comet=True)
mlp.load_state_dict(torch.load(f'./sparse-mlps/{fname}', map_location=device))

print(fname)

mlp_F6000_S3_R1.pt


In [None]:
# Join the first 128 token strings of each doc into a single string.
# filter_docs runs its substring searches against these.
doc_strings = list(map(lambda toks: ''.join(toks[:128]), docs.tolist()))

def filter_docs(and_=(), or_=(), not_=(), texts=None):
    '''
    Return the indices of documents that pass a set of substring filters.

    and_: strings that must ALL appear in a document
    or_: strings of which at least ONE must appear (no constraint when empty)
    not_: strings of which NONE may appear
    texts: optional sequence of document strings to search. Defaults to the
        notebook-global `doc_strings`.

    Returns a 1-D int64 numpy array of matching indices. When nothing matches,
    the array is empty but still integer-typed, so it can index arrays/tensors.
    A bare np.array([]) would be float64 and unusable as an index.
    '''
    if texts is None:
        texts = doc_strings
    matches = [
        i for i, doc in enumerate(texts)
        if all(s in doc for s in and_)
        and (not or_ or any(s in doc for s in or_))
        and not any(s in doc for s in not_)
    ]
    return np.array(matches, dtype=np.int64)

def get_feature_data(feature_idx, and_=(), or_=(), not_=(), reversed=False):
    '''
    Collect the docs that match the substring filters, along with one feature's
    activations on them, sorted by each doc's peak activation.

    feature_idx: index into the first dimension of the global `acts`
    and_, or_, not_: substring filters, forwarded to `filter_docs`
    reversed: sort order by per-doc peak activation. False sorts ascending
        (strongest docs last); True sorts descending. The name shadows the
        builtin but is kept for caller compatibility.

    Returns (feature_docs, feature_doc_ids, feature_acts, per_doc_feature_maxes),
    all in the same order. Activations are divided by the feature's max over
    the whole `acts[feature_idx]` slice (taken before filtering), so values are
    comparable across different filters.
    '''
    # Force an integer index. An empty match list could otherwise come back as a
    # float array, which cannot be used to index tensors/arrays.
    subset = np.asarray(filter_docs(and_=and_, or_=or_, not_=not_), dtype=np.int64)
    print(f'Found {len(subset)} docs')

    feature_acts = acts[feature_idx]
    # A dead feature (max == 0) would make the division produce NaNs, so leave it unscaled.
    feature_max = feature_acts.max()
    if feature_max != 0:
        feature_acts = feature_acts / feature_max
    feature_acts = feature_acts[subset]

    perm = feature_acts.max(dim=-1).values.argsort(descending=reversed)
    feature_acts = feature_acts[perm]
    feature_docs = docs[subset][perm]
    feature_doc_ids = doc_ids[subset][perm]
    per_doc_feature_maxes = feature_acts.max(dim=-1).values

    return feature_docs, feature_doc_ids, feature_acts, per_doc_feature_maxes

In [None]:
import pysvelte

FEATURE_IDX = 3000

# Fetch every doc (all filters empty). Docs are ordered by this feature's peak
# activation in ascending order (reversed=False), so the strongest come last.
feature_docs, feature_doc_ids, feature_weights, feature_maxes = get_feature_data(
    feature_idx=FEATURE_IDX,
    and_=[],
    or_=[],
    not_=[],
    reversed=False,
)
print(FEATURE_IDX)

pysvelte.WeightedDocs(
    tokens=feature_docs.tolist(),
    weights=feature_weights.tolist(),
    per_doc_maxes=feature_maxes.tolist(),
).show()

In [None]:
# Shape of the normalized activations returned by the previous cell.
# `feature_acts` only exists after a later cell runs, so this uses the
# previous cell's output to stay runnable on a fresh kernel.
feature_weights.shape

In [None]:
# Inspect the dimensions of the loaded activation tensor.
acts.size()

In [None]:
# Sanity check that ' swim' is part of the 'water_words' label set. The label
# is loaded here so this cell does not depend on a later cell having run.
from tok_labelling import load_tok_label
' swim' in load_tok_label('water_words')

In [None]:
import pysvelte

reload_module('tok_labelling')
from tok_labelling import new_tok_label, load_tok_label

# The label set was originally created with:
# new_tok_label(title='water_words', description="Strings that are related to water. For example, 'water', 'bath', 'overflow', 'tank', 'wet', 'ducks'")
water_words = load_tok_label('water_words')

# Explanation under test: a single Match feature over the water-word strings.
water_word_match = Match(*water_words)
feats_fn = Stack(
    water_word_match
)

MAX_DOCS_TO_PYSVELTE = 100000  # render at most this many of the highest-ranked docs

FIND_HOLES = False            # True: rank docs by their most negative signed error (pred > actual)
HIDE_POS_ERRS = False         # True: zero out positive signed errors (actual > pred) before rendering
USE_MSE_IN_RENDERER = False   # True: render squared error; False: render absolute error

feature_acts = acts[FEATURE_IDX]
# Scale so the feature's global max is 1. A dead feature (max == 0) would make
# the division produce NaNs, so leave it unscaled in that case.
feature_max = feature_acts.max()
if feature_max != 0:
    feature_acts = feature_acts / feature_max
reg_weights, reg_bias, pred = pred_feature(feature_acts, feats_fn, doc_ids)
print(reg_weights)

residual = feature_acts - pred
mse_errs = residual**2
abs_errs = residual.abs()

signed_errs = (mse_errs if USE_MSE_IN_RENDERER else abs_errs) * residual.sign()
doc_maxes = mse_errs.max(dim=-1).values
err_docs = doc_maxes.topk(k=min(10, doc_maxes.numel())).indices

# Sort docs by error (ascending) and keep the last MAX_DOCS_TO_PYSVELTE, i.e. the worst.
if FIND_HOLES:
    sort_key = signed_errs.min(dim=-1).values.clamp(max=0).abs()
else:
    sort_key = doc_maxes
perm = sort_key.argsort(descending=False)[-MAX_DOCS_TO_PYSVELTE:]
signed_errs = signed_errs[perm]

if HIDE_POS_ERRS:
    signed_errs = signed_errs.clamp(max=0)

mse = mse_errs.mean()
print(f'mse: {mse:.2E}')
abs_errs = abs_errs[perm]
feature_doc_subset = docs[perm]
# The per-doc maxes must be reordered with the same permutation as the docs and
# weights. Otherwise each rendered doc is paired with a different doc's max.
# NOTE(review): these maxes are always MSE-based, even when the weights are
# absolute errors (USE_MSE_IN_RENDERER=False). Confirm this is intended.
doc_maxes_subset = doc_maxes[perm]

component = pysvelte.WeightedDocs(tokens=feature_doc_subset.tolist(), weights=signed_errs.tolist(), per_doc_maxes=doc_maxes_subset.tolist(), reversed=False)
component.show()

