Goal:

Compute the main/control loss for:

- GSAE
- GSAE finetune
- GSAE-phys (sweep L0)
- GSAE + SSAE (sweep L0)
- clean loss

In [28]:
import sys
sys.path.append("/root/specialised-SAEs")
from datasets import load_dataset
from transformer_lens import utils, HookedTransformer
import gc
import torch

from sae_lens.jacob.load_sae_from_hf import load_sae_from_hf

In [25]:
# Precision, given as a string, that is forwarded as `dtype=` when the model
# and the SAEs are loaded below.
DTYPE = "float32"

In [24]:
# Load GPT-2 small onto the GPU. Later cells rely on this global for
# `model.tokenizer` (tokenization) and `model.run_with_hooks` (patched forward passes).
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda", dtype=DTYPE)

Loaded pretrained model gpt2-small into HookedTransformer


In [29]:
# Both GSAE checkpoints come from the same Hugging Face repo and share the
# device/precision settings; only the expansion factor in the file names differs.
_gsae_load_kwargs = dict(device="cuda", dtype=DTYPE)

gsae_64 = load_sae_from_hf(
    "jacobcd52/gpt2-gsae",
    "expansion=64.safetensors",
    "expansion=64_cfg.json",
    **_gsae_load_kwargs,
)

gsae_16 = load_sae_from_hf(
    "jacobcd52/gpt2-gsae",
    "expansion=16.safetensors",
    "expansion=16_cfg.json",
    **_gsae_load_kwargs,
)



Downloading weights from Hugging Face Hub


expansion=64.safetensors:   0%|          | 0.00/303M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


expansion=64_cfg.json:   0%|          | 0.00/2.48k [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors


: 

In [22]:
# OpenWebText sample: tokenize into rows of 256 tokens, shuffle the rows
# with seed 42, then move the token tensor to the GPU.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=256
).shuffle(42)
owt_tokens = tokenized_data["tokens"].cuda()
print("owt_tokens has shape", owt_tokens.shape)
print("total number of tokens:", int(owt_tokens.numel()//1e6), "million")
print()

# physics-papers: only the first 10% of the train split is requested.
data = load_dataset("jacobcd52/physics-papers", split="train[:10%]")
# Define a filter function to remove null entries
def remove_null_entries(example):
    """Return True only when no field of `example` is None or an empty string."""
    for field in example.values():
        # A single missing/empty field is enough to reject the whole row.
        if field is None or field == '':
            return False
    return True
# Drop rows that contain a missing or empty field, then tokenize into rows of
# 256 tokens, shuffle the rows with seed 42 and move the tensor to the GPU.
data = data.filter(remove_null_entries)
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=256
).shuffle(42)
phys_tokens = tokenized_data["tokens"].cuda()
print("phys_tokens has shape", phys_tokens.shape)
print("total number of tokens:", int(phys_tokens.numel()//1e6), "million")

# Drop the references to the intermediate dataset objects; only the token
# tensors are needed from here on.
del data, tokenized_data
gc.collect()

owt_tokens has shape torch.Size([44086, 256])
total number of tokens: 11 million
phys_tokens has shape torch.Size([62879, 256])
total number of tokens: 16 million


0

In [26]:
def get_loss(
        sae_list, 
        tokens, 
        num_tokens, 
        hook_pt = 'blocks.8.hook_resid_pre', 
        batch_size=16
        ):
    """Average model loss and L0 when SAE reconstructions replace an activation.

    At `hook_pt` the activation is replaced by the sum of the reconstructions
    of every SAE in `sae_list`; each SAE encodes the original activation.
    Uses the notebook-level `model` for the forward passes.

    Args:
        sae_list: SAEs exposing `encode_fn(act)` and `decode(feature_acts)`.
        tokens: 2-D tensor of token ids; rows are sliced into batches.
        num_tokens: Token budget. Only whole multiples of
            `batch_size * tokens.shape[1]` are used, and the number of batches
            is capped so that no empty batch is ever sent to the model.
        hook_pt: Name of the hook point whose activation is patched.
        batch_size: Number of rows of `tokens` per forward pass.

    Returns:
        `(mean_loss, mean_l0)`, both averaged over the evaluated batches. The
        L0 counts strictly positive feature activations along the last
        dimension, averaged over the remaining dimensions and summed over SAEs.

    Raises:
        ValueError: If the arguments do not allow a single non-empty batch.
        RuntimeError: If `hook_pt` was never reached during a forward pass.
    """
    tokens_per_batch = tokens.shape[1] * batch_size
    if tokens_per_batch <= 0:
        raise ValueError("batch_size and tokens.shape[1] must both be positive")

    # Never index past the end of `tokens`: a slice beyond the last row would
    # silently yield an empty batch.
    available_batches = -(-tokens.shape[0] // batch_size)  # ceil division
    num_batches = min(num_tokens // tokens_per_batch, available_batches)
    if num_batches <= 0:
        raise ValueError(
            f"num_tokens={num_tokens} does not cover one batch of "
            f"{tokens_per_batch} tokens (or `tokens` is empty)"
        )

    # The hook stashes the L0 of the most recent forward pass here.
    latest_l0 = {}

    def patch_hook(act, hook):
        # Start from a tensor so `.item()` also works when `sae_list` is empty.
        l0 = torch.zeros((), dtype=torch.float32, device=act.device)
        out = torch.zeros_like(act)

        for sae in sae_list:
            feature_acts = sae.encode_fn(act)
            # `.float()` rather than `.to(DTYPE)`: `Tensor.to` parses a string
            # argument as a device, so a dtype name such as "float32" raises.
            l0 = l0 + (feature_acts > 0).float().sum(dim=-1).mean()
            out += sae.decode(feature_acts)

        latest_l0[0] = l0
        return out

    total_l0 = 0.0
    total_loss = 0.0

    # Pure evaluation: skip building autograd graphs to save memory.
    with torch.no_grad():
        for b in range(num_batches):
            batch = tokens[b*batch_size:(b+1)*batch_size]
            loss = model.run_with_hooks(
                batch,
                return_type="loss",
                fwd_hooks=[(hook_pt, patch_hook)]
            )
            if 0 not in latest_l0:
                raise RuntimeError(f"hook point {hook_pt!r} was never reached")
            total_loss += loss.item()
            # pop so a stale value can never be reused by a later batch
            total_l0 += latest_l0.pop(0).item()

    return total_loss / num_batches, total_l0 / num_batches

In [None]:
get_loss(sae_list, tokens, num_tokens