In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from nnsight import NNsight
from utils import load_gemma_sae


# Inference-only notebook: switch autograd off globally for every cell below.
torch.set_grad_enabled(False)


tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

# Load the checkpoint first, then cast it to bfloat16 and move it onto the GPU.
lm = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
lm = lm.to(torch.bfloat16).cuda()
# model = NNsight(lm)



In [None]:
from datasets import load_dataset

# Load the "sample-10BT" configuration of FineWeb-Edu; later cells read ds['train'].
ds = load_dataset(path="HuggingFaceFW/fineweb-edu", name="sample-10BT")

In [None]:
import numpy as np

# One display string per token id, with whitespace made visible:
# ' ' is shown as '·' and '\n' as '⤶'.
# NOTE(review): 256000 is presumably the tokenizer's vocabulary size — confirm.
_VISIBLE_WHITESPACE = str.maketrans({' ': '·', '\n': '⤶'})
tok_strs = np.array(
    [tokenizer.decode([token_id]).translate(_VISIBLE_WHITESPACE) for token_id in range(256000)]
)



In [None]:
# Import the reload helper, reload noa_tools through it, then re-import the
# names used by later cells so they come from the reloaded module.
# NOTE(review): presumably this picks up on-disk edits to noa_tools without a
# kernel restart — confirm against reload_module's implementation.
from noa_tools import reload_module
reload_module('noa_tools')
from noa_tools import register_hook, remove_hooks, reload_module

In [None]:
# LAYER = 0
# L0=43

# sae = load_gemma_sae('att', filename=f'layer_{LAYER}/width_65k/average_l0_{L0}').to(torch.bfloat16).cuda()

# sae.W_dec.shape

In [None]:
# Display the `mlp` submodule of layer 0 (bare expression, shown as the cell output).
lm.model.layers[0].mlp

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
# Reload the local `utils` module before importing from it.
# NOTE(review): presumably so edits to utils.py take effect without restarting
# the kernel — confirm. Relies on `reload_module` imported in an earlier cell.
reload_module('utils')
from utils import CustomBreakError, register_sae
from noa_tools import clear_cache

from einops import rearrange

NUM_DOCS = 10000
BATCH_SIZE = 300
# Passed to sae.encode(indices=...); its length is the size of the last axis of `acts`.
FEATURE_INDICES = range(0, 200)

num_batches = NUM_DOCS // BATCH_SIZE
dataloader = DataLoader(ds['train'], batch_size=BATCH_SIZE)

clear_cache(lm)
module, sae = register_sae(lm, layer=2, l0=100, type='att', width='16k')


all_acts = []
all_toks = []
# Zipping with range(num_batches) caps the loop at exactly `num_batches`
# batches (the previous `len(all_acts) > num_batches` check ran one extra),
# and keeps it bounded even when batches are skipped below.
for _, batch in tqdm(zip(range(num_batches), dataloader), total=num_batches):
    out = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True, max_length=128)
    tok_ids, attn_mask = out['input_ids'], out['attention_mask']
    # Keep only rows whose first position is a real (non-padding) token.
    # NOTE(review): with left padding this presumably keeps only documents that
    # fill all 128 positions — confirm the tokenizer's padding side.
    tok_ids = tok_ids[attn_mask[:,0].bool()].cuda()
    if tok_ids.shape[0] == 0:
        # Nothing survived the filter; there is nothing to run through the model.
        continue

    try:
        lm.forward(tok_ids)
    except CustomBreakError:
        # NOTE(review): presumably raised by the hook installed by register_sae
        # to stop the forward pass once the SAE input is cached — confirm.
        pass
    except Exception as e:
        # Skip this batch: the cache would still hold the previous batch's
        # input, so encoding it would pair stale activations with these tokens.
        print(f"Unexpected error: {e}")
        continue

    sae_inp = module.cache['sae_inp']
    acts = sae.encode(sae_inp, indices=FEATURE_INDICES)

    # Move results off the GPU right away so they do not pile up in device memory.
    all_acts.append(acts.cpu())
    all_toks.append(tok_ids.cpu())

acts = torch.cat(all_acts, dim=0)
toks = tok_strs[np.concatenate([batch_toks.numpy() for batch_toks in all_toks], axis=0)]

# Scale every feature by its maximum over all (document, position) pairs;
# the 1e-10 keeps features that never fire from dividing by zero.
acts = acts/(rearrange(acts, 'b s a -> ( b s ) a').max(dim=0).values[None,None]+1e-10)


In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import CustomBreakError, register_sae
from noa_tools import clear_cache
from einops import rearrange

def get_acts(lm, layer, l0, sae_type, width='16k', batch_size=200, indices=range(0,100), num_docs=10_000):
    """Collect max-normalised SAE activations and the matching token strings.

    Reads documents from the notebook-level ``ds['train']``, tokenizes them with
    the notebook-level ``tokenizer`` (truncated/padded to 128 tokens), runs them
    through ``lm`` with an SAE registered via ``register_sae``, and encodes the
    cached SAE input.

    Args:
        lm: Model passed to ``clear_cache`` / ``register_sae`` and run with
            ``lm.forward``.
        layer: Layer argument forwarded to ``register_sae``.
        l0: ``l0`` argument forwarded to ``register_sae``.
        sae_type: Forwarded to ``register_sae`` as ``type``.
        width: Forwarded to ``register_sae`` as ``width`` (previously ignored
            in favour of a hard-coded ``'16k'``).
        batch_size: Number of documents per DataLoader batch.
        indices: Forwarded to ``sae.encode(indices=...)`` (previously ignored
            in favour of a hard-coded ``range(0, 200)``, so default callers
            now get 100 entries on the feature axis instead of 200).
        num_docs: Document budget; ``num_docs // batch_size`` batches are read.

    Returns:
        ``(acts, toks)``: ``acts`` is a CPU tensor laid out as
        (document, position, feature), each feature divided by its maximum
        over all documents and positions; ``toks`` is the array of
        ``tok_strs`` entries for the corresponding token ids.

    Raises:
        ValueError: If no batch produced any activations.
    """
    num_batches = num_docs // batch_size
    dataloader = DataLoader(ds['train'], batch_size=batch_size)

    clear_cache(lm)
    module, sae = register_sae(lm, layer=layer, l0=l0, type=sae_type, width=width)

    all_acts = []
    all_toks = []
    # Zipping with range(num_batches) caps the loop at exactly `num_batches`
    # batches (the previous `len(all_acts) > num_batches` check ran one extra),
    # and keeps it bounded even when batches are skipped below.
    for _, batch in tqdm(zip(range(num_batches), dataloader), total=num_batches):
        out = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True, max_length=128)
        tok_ids, attn_mask = out['input_ids'], out['attention_mask']
        # Keep only rows whose first position is a real (non-padding) token.
        # NOTE(review): with left padding this presumably keeps only documents
        # that fill all 128 positions — confirm the tokenizer's padding side.
        tok_ids = tok_ids[attn_mask[:,0].bool()].cuda()
        if tok_ids.shape[0] == 0:
            # Nothing survived the filter; there is nothing to run through the model.
            continue

        try:
            lm.forward(tok_ids)
        except CustomBreakError:
            # NOTE(review): presumably raised by the hook installed by
            # register_sae to stop the forward pass early — confirm.
            pass
        except Exception as e:
            # Skip this batch: the cache would still hold the previous batch's
            # input, so encoding it would pair stale activations with these tokens.
            print(f"Unexpected error: {e}")
            continue

        sae_inp = module.cache['sae_inp']
        batch_acts = sae.encode(sae_inp, indices=indices)

        # Move results off the GPU right away so they do not pile up in device memory.
        all_acts.append(batch_acts.cpu())
        all_toks.append(tok_ids.cpu())

    if not all_acts:
        raise ValueError("no activations collected; check num_docs, batch_size and the dataset")

    acts = torch.cat(all_acts, dim=0)
    toks = tok_strs[np.concatenate([batch_toks.numpy() for batch_toks in all_toks], axis=0)]

    # Scale every feature by its maximum over all (document, position) pairs;
    # the 1e-10 keeps features that never fire from dividing by zero.
    acts = acts/(rearrange(acts, 'b s a -> ( b s ) a').max(dim=0).values[None,None]+1e-10)

    return acts, toks


In [None]:
import pysvelte

# Number of documents to visualise; None means all of them.
# (The previous value of -1 made the [:N_DOCS] slices below drop the last document.)
N_DOCS = None

FEAT_START = 100
N_FEATS = 20

for FEAT in range(FEAT_START, FEAT_START + N_FEATS):
    print(f'Feat {FEAT}')

    feat_acts = acts[:N_DOCS,:,FEAT]
    feat_toks = toks[:N_DOCS]
    # Only show documents in which this feature is active at least once.
    feat_mask = feat_acts.max(dim=-1).values > 0
    # Index the numpy token array with a numpy mask rather than a torch tensor.
    docs = feat_toks[feat_mask.cpu().numpy()].tolist()
    feat_acts = feat_acts[feat_mask].cpu().tolist()
    pysvelte.WeightedDocs(docs=docs, acts=feat_acts, start=0.8, k=4).show()



In [None]:
# model.generate(input_ids=tok_ids, attention_mask=attn_mask, max_new_tokens=128)

In [None]:
prompt = 'Once upon a time there was a giant'

inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")

# Generate with `lm`, the model loaded in the first cell; `model` is never
# assigned in this notebook (its only definition is commented out), so the
# previous `model.generate` call failed with NameError on a fresh kernel.
# NOTE(review): if hooks from register_sae are still attached to `lm` they may
# interrupt generation — confirm whether they need removing first.
outputs = lm.generate(input_ids=inputs, max_new_tokens=50, temperature=1.0, do_sample=True)

In [None]:
# Decode the first generated sequence one token id at a time, so each token's text is shown separately.
[tokenizer.decode([token_id]) for token_id in outputs[0]]

In [None]:
from transformers import GenerationConfig

# Inspect the default generation settings of the checkpoint this notebook
# actually loads ("google/gemma-2-2b"); the previous id "google/gemma-2b"
# names a different model.
GenerationConfig.from_pretrained("google/gemma-2-2b")