In [1]:
from datasets import load_dataset
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer

In [2]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Switch autograd off globally. As the cell's last expression, the returned
# set_grad_enabled object is what the notebook displays.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7ba55817ed10>

# Load query2sae Model

In [3]:
repo = "mksethi/gemma-query2sae"
# NOTE(review): trust_remote_code=True loads the config/model classes from
# Python files shipped in the Hub repo, and no `revision=` is pinned here --
# consider pinning the commit once the intended one is confirmed.
cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
query2sae = AutoModel.from_pretrained(repo, trust_remote_code=True)

# Print the concrete classes that were resolved, one per line
# (expected: Query2SAEConfig, then Query2SAEModel).
print(type(cfg), type(query2sae), sep="\n")


<class 'transformers_modules.mksethi.gemma-query2sae.b722da8493f29c204bce3a980e816d6ce939def1.configuration_query2sae.Query2SAEConfig'>
<class 'transformers_modules.mksethi.gemma-query2sae.b722da8493f29c204bce3a980e816d6ce939def1.modeling_query2sae.Query2SAEModel'>


In [4]:
# Move the query2sae model onto `device`. As the cell's last expression, the
# value returned by .to() is displayed, which is the module summary shown below.
query2sae.to(device)

Query2SAEModel(
  (backbone): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (head): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=

# Load gemma-2b-it

In [5]:
# Tokenizer for google/gemma-2b-it, configured to pad and truncate on the right.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): this overrides whatever pad token the checkpoint ships with --
# confirm that treating padding as EOS is intended downstream.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side="right"
tokenizer.truncation_side="right"
# Hugging Face causal-LM copy of Gemma, loaded in bfloat16.
gemma = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16
)
# Move the model onto `device`; as the last expression, its repr is displayed.
gemma.to(device)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((2048,), 

# Load TruthfulQA

In [6]:
# Load the TruthfulQA dataset by its repository id; the result is displayed
# in the next cell.
ds = load_dataset("domenicrosati/TruthfulQA")

In [7]:
# Disabled: dropping the metadata columns would rebind `ds` in place, so all
# columns are currently kept.
# ds = ds.remove_columns(['Type', 'Category', 'Source'])
# Display the splits, column names and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['Type', 'Category', 'Question', 'Best Answer', 'Correct Answers', 'Incorrect Answers', 'Source'],
        num_rows: 817
    })
})

# Load the SAE and the hooked Gemma model

In [8]:
# Load the pretrained sparse autoencoder for the residual stream after block 12
# of gemma-2b-it. It is loaded on the CPU first and moved to `device` below.
# NOTE(review): this rebinds `cfg`, which the query2sae cell above assigned to
# the Query2SAE config -- that earlier value is no longer reachable by name.
# NOTE(review): the third return value is discarded into `_`, which is also the
# name IPython uses for the previous cell output.
sae, cfg, _ = SAE.from_pretrained(
    release="gemma-2b-it-res-jb",
    sae_id="blocks.12.hook_resid_post",
    device = "cpu"
)
sae = sae.to(device)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
  sae, cfg, _ = SAE.from_pretrained(


In [9]:
# TransformerLens copy of gemma-2b-it in bfloat16, set to eval mode; sae_fn uses
# it to capture activations.
# NOTE(review): the SAE loader's warning (shown under the SAE cell) recommends
# HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs);
# confirm that the default from_pretrained processing used here matches the
# activations the SAE expects.
# NOTE(review): this is a second in-memory copy of Gemma alongside `gemma`, and
# no device is passed explicitly -- confirm where the weights end up.
model = HookedTransformer.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16).eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b-it into HookedTransformer


In [15]:
import torch
from transformer_lens import HookedTransformer

def _sae_features_for(texts, hook_name='blocks.12.hook_resid_post'):
    """SAE-encode the activations that `model` produces for `texts` at one hook.

    Args:
        texts: Batch of inputs passed unchanged to ``model.run_with_cache``.
        hook_name: Name of the cached activation to read and encode.

    Returns:
        The output of ``sae.encode`` reshaped to
        ``(len(texts), resid.shape[1], -1)``, where ``resid`` is the cached
        activation (assumed to be 3-D: batch, position, width -- TODO confirm).
    """
    with torch.no_grad():
        # Cache only the hook that is read below. Without names_filter every
        # hook's activation is stored for each batch and then thrown away.
        _, cache = model.run_with_cache(texts, names_filter=hook_name)
        resid = cache[hook_name]
        # Collapse all leading dimensions so each row is one activation vector
        # for the encoder, then restore the per-text layout afterwards.
        flat = resid.reshape(-1, resid.shape[-1])
        features = sae.encode(flat)
        return features.reshape(len(texts), resid.shape[1], -1)


def sae_fn(batch):
    """Compute SAE features for the correct and incorrect answers of a batch.

    Written for ``Dataset.map(..., batched=True)``. Relies on the notebook
    globals ``model`` and ``sae``.

    Args:
        batch: Mapping with the columns ``'Correct Answers'`` and
            ``'Incorrect Answers'``.

    Returns:
        dict with ``'correct_sae_features'`` and ``'incorrect_sae_features'``,
        one feature block per input text.

    NOTE(review): features are returned for every position of the batch,
    presumably including any padding the model adds -- confirm whether
    downstream code masks those positions.
    """
    # Read both columns up front so a missing column fails before any model run.
    correct_ans = batch['Correct Answers']
    inc_ans = batch['Incorrect Answers']

    return {
        'correct_sae_features': _sae_features_for(correct_ans),
        'incorrect_sae_features': _sae_features_for(inc_ans),
    }

In [11]:
def tok_fn(batch, max_length=256):
    """Tokenize the question, best answer and incorrect answers of a batch.

    Written for ``Dataset.map(..., batched=True)``. Relies on the notebook
    global ``tokenizer``.

    Args:
        batch: Mapping with the columns ``'Question'``, ``'Best Answer'`` and
            ``'Incorrect Answers'``; each column is passed to the tokenizer
            unchanged.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 256, the value previously hard-coded.

    Returns:
        dict with ``'question'``, ``'correct_answer'`` and
        ``'incorrect_answer'`` (token ids) plus the matching
        ``'q_attention_mask'``, ``'c_attention_mask'`` and
        ``'i_attention_mask'``.

    NOTE(review): the correct answer comes from ``'Best Answer'`` here, while
    ``sae_fn`` in this notebook reads ``'Correct Answers'`` -- confirm which
    column is intended.
    """
    def encode(texts):
        # One place for the shared padding/truncation settings.
        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length
        )

    q_out = encode(batch['Question'])
    c_out = encode(batch['Best Answer'])
    # Fix: this previously tokenized the correct answers a second time, so the
    # 'incorrect_answer' columns were copies of the 'correct_answer' columns.
    i_out = encode(batch['Incorrect Answers'])

    return {
        'question': q_out['input_ids'], 'q_attention_mask': q_out['attention_mask'],
        'correct_answer': c_out['input_ids'], 'c_attention_mask': c_out['attention_mask'],
        'incorrect_answer': i_out['input_ids'], 'i_attention_mask': i_out['attention_mask'],
        }

In [17]:
# Apply tok_fn to the dataset in batches. The tokenized columns it returns are
# added next to the original columns, as the ds_tok display below shows.
ds_tok = ds.map(
    tok_fn,
    batched=True,
    # Disabled: would drop the original text columns after tokenization.
    # remove_columns=['Type', 'Category', 'Question', 'Best Answer', 'Correct Answers', 'Incorrect Answers', 'Source'],
    desc="Tokenizing Questions"
)

Tokenizing Questions:   0%|          | 0/817 [00:00<?, ? examples/s]

In [18]:
# Display the columns after tokenization to confirm the six new fields exist.
ds_tok

DatasetDict({
    train: Dataset({
        features: ['Type', 'Category', 'Question', 'Best Answer', 'Correct Answers', 'Incorrect Answers', 'Source', 'question', 'q_attention_mask', 'correct_answer', 'c_attention_mask', 'incorrect_answer', 'i_attention_mask'],
        num_rows: 817
    })
})

In [19]:
# Run sae_fn over the tokenized dataset eight rows at a time. sae_fn reads the
# raw 'Correct Answers' / 'Incorrect Answers' text columns, not the token ids.
ds_sae = ds_tok.map(
    sae_fn,
    batched=True,
    batch_size=8,
    desc="Getting Answer SAE's"
)

Getting Answer SAE's:   0%|          | 0/817 [00:00<?, ? examples/s]

In [23]:
# Fetch row 0 once and read both feature columns from it, instead of indexing
# the dataset twice (each row lookup presumably materialises the whole row,
# including both large feature columns -- TODO confirm).
first_row = ds_sae['train'][0]
c_sae = first_row['correct_sae_features']
i_sae = first_row['incorrect_sae_features']

list