In [None]:
from datasets import load_dataset
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Only forward passes are run below, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Load query2sae Model

In [None]:
# Load the query2sae checkpoint from the Hub.
# SECURITY: trust_remote_code=True executes Python code shipped in the repo;
# review that code (and consider pinning `revision=`) before running.
repo = "mksethi/gemma-query2sae"
# Named query2sae_cfg rather than `cfg`: the SAE-loading cell further down
# rebinds `cfg` to the SAE's config, which would silently shadow this one.
query2sae_cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
query2sae = AutoModel.from_pretrained(repo, trust_remote_code=True)

print(type(query2sae_cfg))    # -> Query2SAEConfig
print(type(query2sae))  # -> Query2SAEModel


In [None]:
# Module.to moves the parameters and returns the same module object.
query2sae = query2sae.to(device)
query2sae

# Load gemma-2b-it

In [None]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# Pad with EOS, and pad/truncate on the right so each sequence keeps its
# leading tokens and any padding is appended at the end.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = tokenizer.truncation_side = "right"

# Hugging Face copy of gemma-2b-it in bfloat16.
# NOTE(review): not used by the cells visible here; it duplicates the
# HookedTransformer copy loaded below and costs extra device memory.
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
gemma.to(device)


# Load TruthfulQA

In [None]:
# No split is requested, so this returns a DatasetDict keyed by split name.
ds = load_dataset(path="domenicrosati/TruthfulQA")

In [None]:
# Inspect the splits and columns. The metadata columns are dropped later,
# during tokenization (see remove_columns in the ds.map call).
ds

# Load the SAE and the TransformerLens model

In [None]:
# Residual-stream SAE for gemma-2b-it at blocks.12.hook_resid_post.
# It is materialized on the CPU first and then moved to the compute device.
# `cfg` presumably holds the SAE's config; the third return value is unused.
sae, cfg, _ = SAE.from_pretrained(
    release="gemma-2b-it-res-jb", sae_id="blocks.12.hook_resid_post", device="cpu"
)
sae = sae.to(device)

In [None]:
# Pin TransformerLens to the same `device` as the SAE. Left unset, it picks its
# own default device, which is not guaranteed to match, and sae.encode would
# then receive activations on a different device.
# NOTE(review): confirm TransformerLens' default weight processing matches how
# the gemma-2b-it-res-jb SAE was trained.
model = HookedTransformer.from_pretrained("google/gemma-2b-it", device=device).eval()

In [None]:
import torch
from transformer_lens import HookedTransformer

def _sae_features_for(token_ids, hook_name):
    """Run `model` on pre-tokenized sequences and SAE-encode one hook's activations.

    Args:
        token_ids: equal-length token-id sequences for one batch (the padded
            lists produced by ``tok_fn``, or an equivalent integer tensor).
        hook_name: TransformerLens hook whose activations ``sae`` encodes.

    Returns:
        CPU tensor of shape (batch, seq, n_features).
    """
    # datasets.map passes plain Python lists. HookedTransformer treats a list
    # input as text and would try to re-tokenize it, so build the token tensor
    # explicitly.
    tokens = torch.as_tensor(token_ids, dtype=torch.long, device=model.cfg.device)
    # Cache only the hook that is read, not every activation in the model.
    _, cache = model.run_with_cache(tokens, names_filter=hook_name)
    resid = cache[hook_name]
    batch_size, seq_len, d_model = resid.shape
    features = sae.encode(resid.reshape(-1, d_model))
    # Move to CPU so GPU memory is released between map batches.
    return features.reshape(batch_size, seq_len, -1).cpu()


def sae_fn(batch, hook_name='blocks.12.hook_resid_post'):
    """Batched ``datasets.map`` function: SAE features for correct and incorrect answers.

    Args:
        batch: dict with 'correct_answer' and 'incorrect_answer' token-id
            columns, as produced by ``tok_fn``.
        hook_name: hook whose residual activations are encoded; it should be
            the hook the loaded ``sae`` was fetched for.

    Returns:
        dict with 'correct_sae_features' and 'incorrect_sae_features', each a
        (batch, seq, n_features) CPU tensor. Padding positions are encoded
        too; mask them downstream with the attention masks if needed.
    """
    with torch.no_grad():
        correct = _sae_features_for(batch['correct_answer'], hook_name)
        incorrect = _sae_features_for(batch['incorrect_answer'], hook_name)
    return {
        'correct_sae_features': correct,
        'incorrect_sae_features': incorrect,
    }

In [None]:
def _tokenize_column(tok, texts, max_length):
    """Tokenize `texts`, padding and truncating every sequence to exactly `max_length` tokens."""
    return tok(texts, truncation=True, padding="max_length", max_length=max_length)


def tok_fn(batch, tok=None, max_length=256):
    """Batched ``datasets.map`` function that tokenizes questions and answers.

    Args:
        batch: dict of columns containing 'Question', 'Best Answer' and
            'Incorrect Answers'. Each is assumed to be a list with one string
            per row; TODO confirm against the dataset schema.
        tok: Hugging Face tokenizer to use. Defaults to the notebook-level
            ``tokenizer``.
        max_length: fixed length that every output sequence is padded or
            truncated to.

    Returns:
        dict of token-id lists and attention masks for the question, the best
        (correct) answer and the incorrect answers.
    """
    tok = tokenizer if tok is None else tok
    q_out = _tokenize_column(tok, batch['Question'], max_length)
    c_out = _tokenize_column(tok, batch['Best Answer'], max_length)
    # Tokenize the incorrect answers from their own column. The previous code
    # re-tokenized the correct answers here, so the "incorrect" columns were
    # just a copy of the correct ones.
    i_out = _tokenize_column(tok, batch['Incorrect Answers'], max_length)
    return {
        'question': q_out['input_ids'], 'q_attention_mask': q_out['attention_mask'],
        'correct_answer': c_out['input_ids'], 'c_attention_mask': c_out['attention_mask'],
        'incorrect_answer': i_out['input_ids'], 'i_attention_mask': i_out['attention_mask'],
    }

In [None]:
# Tokenize every split in batches. The raw text and metadata columns listed
# below are dropped in favor of tok_fn's token ids and attention masks.
ds_tok = ds.map(
    tok_fn,
    batched=True,
    desc="Tokenizing Questions",
    remove_columns=[
        'Type', 'Category', 'Question', 'Best Answer',
        'Correct Answers', 'Incorrect Answers', 'Source',
    ],
)

In [None]:
# Inspect the tokenized dataset: token ids and attention masks per split.
ds_tok

In [None]:
# Encode answers in small batches. With datasets' default batch_size (1000),
# one call would push up to 1000 rows of 256-token sequences through
# gemma-2b-it and the SAE at once and hold all the per-token feature tensors
# in memory.
# NOTE(review): per-token SAE features for every row make the stored dataset
# very large; consider pooling or masking padding before storing.
ds_sae = ds_tok.map(
    sae_fn,
    batched=True,
    batch_size=8,
    desc="Getting Answer SAE's",
)