In [1]:
%load_ext autoreload
%autoreload 2

# import IPython
# IPython.get_ipython().run_line_magic('cd', '..')  # Go up one directory

from buffer import AllActivationBuffer
from trainers.top_k import TrainerSCAE, AutoEncoderTopK
from training import trainSCAE

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel


# Prefer the first CUDA device when one is visible; otherwise fall back to CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# GPT-2 ("gpt2" checkpoint) wrapped by nnsight, loaded in bfloat16 on the chosen device.
model = LanguageModel("gpt2", device_map=device, torch_dtype=t.bfloat16)

# Stream the OpenWebText train split instead of materialising it up front.
dataset = load_dataset(
    "Skylion007/openwebtext",
    split="train",
    streaming=True,
    trust_remote_code=True,
)

class CustomData:
    """Iterator adapter yielding only the ``'text'`` field of each dataset record.

    Wraps any iterable of mapping-like records. The wrapper is its own
    iterator (``__iter__`` returns ``self``), so it can be consumed once;
    iterating it again after exhaustion yields nothing.
    """

    def __init__(self, dataset):
        # One live iterator over the records; shared by every consumer of this wrapper.
        self.data = iter(dataset)

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration from the underlying iterator propagates unchanged.
        record = next(self.data)
        return record['text']

# Single-pass iterator of raw document strings; handed to AllActivationBuffer as `data` below.
data = CustomData(dataset)

In [2]:
# --- Experiment configuration ---
# Fix torch's global RNG so random draws later in the notebook (e.g. the placeholder
# important-feature indices) are identical on every Restart & Run All.
SEED = 0
t.manual_seed(SEED)

C = 10          # indices drawn per feature: second dimension of each important_features tensor
expansion = 16  # autoencoder dictionary size as a multiple of the model width (n_embd)
k = 64          # per-submodule `ks` value for TrainerSCAE (presumably TopK active latents).
                # TODO(original): value not finalised — confirm.

# Latents per autoencoder: n_embd * expansion.
num_features = model.config.n_embd * expansion

In [3]:
n_layer = model.config.n_layer

# Register every transformer block's MLP and attention module for activation capture.
# Insertion order is mlp_0, attn_0, mlp_1, attn_1, ...
# "in_and_out" / "out" select which activations AllActivationBuffer records
# (presumably inputs+outputs vs outputs only — confirm in buffer.py).
submodules = {}
for layer in range(n_layer):
    submodules[f"mlp_{layer}"] = (model.transformer.h[layer].mlp, "in_and_out")
    submodules[f"attn_{layer}"] = (model.transformer.h[layer].attn, "out")


buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd,  # output dimension of the model component
    n_ctxs=128,  # you can set this higher or lower depending on your available memory
    # Reuse the device picked when loading the model; a hard-coded "cuda" breaks on CPU-only hosts.
    device=device,
    out_batch_size=128,
    refresh_batch_size=256,
    dtype=t.bfloat16,
)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [4]:
# Names of every registered submodule, in insertion order (mlp_0, attn_0, mlp_1, ...).
submodule_names = [*submodules]

# Placeholder "important features": for each MLP layer, a (num_features, C) tensor of random
# feature indices in [0, num_features).
# NOTE(review): these indices are random stand-ins (original TODO: "This doesn't make sense"),
# and the trainer cell currently passes important_features={} rather than this dict.
# Values depend on torch's RNG state, so they change between runs unless a seed is fixed first.
important_features = {
    f"mlp_{i}": t.randint(0, num_features, (num_features, C))
    for i in range(model.config.n_layer)
}

In [5]:
# Per-submodule location of a pretrained autoencoder checkpoint: presumably a Hugging Face Hub
# repo/filename pair consumed by TrainerSCAE — TODO confirm. Keys follow the "<attn|mlp>_<layer>"
# naming used for `submodules`; within each layer, attn is inserted before mlp.
pretrained_info = {
    f"{module}_{layer}": {
        "repo_id": "jacobcd52/scae",
        "filename": f"ae_{module}_{layer}.pt",
    }
    for layer in range(model.config.n_layer)
    for module in ("attn", "mlp")
}

In [6]:
# Build one trainer covering every registered submodule. `pretrained_info` lists per-submodule
# checkpoints (presumably loaded as initial autoencoder weights — confirm in TrainerSCAE).
trainer = TrainerSCAE(
        activation_dims={name: model.config.n_embd for name in submodule_names},
        dict_sizes={name: num_features for name in submodule_names},  # n_embd * expansion
        ks={name: k for name in submodule_names},
        submodules=submodules,
        # NOTE(review): the random `important_features` dict from the earlier cell is not used here.
        important_features={},
        pretrained_info=pretrained_info,
        model_config=model.config,
        auxk_alpha=0,
        connection_sparsity_coeff=0,
        use_sparse_connections=False,
        seed=None,
        # Match the model/buffer device instead of hard-coding "cuda" (fails on CPU-only hosts).
        device=device,
        wandb_name="SCAE",
        dtype=t.bfloat16,
)

# NOTE(review): a bare 40-hex-character string was left in a comment here. That is the format of a
# Weights & Biases API key (and also of a git commit SHA), so it has been removed. If it was a
# W&B key, rotate it — it remains in the notebook's version history.

In [10]:
def run_evaluation(trainer, buffer, n_batches=10, ce_batch_size=16,
                   use_sparse_connections=False, lm=None):
    """Average variance-explained and cross-entropy metrics over several batches.

    Each of the ``n_batches`` iterations does two things:

    * draws one activation batch via ``next(buffer)`` and passes it to
      ``trainer.evaluate_varexp_batch``;
    * draws one token batch of ``ce_batch_size`` contexts via
      ``buffer.token_batch()`` and passes it to ``trainer.evaluate_ce_batch``.

    Every metric is then averaged per submodule.

    Args:
        trainer: object exposing ``submodules`` (mapping whose keys name the evaluated
            submodules), ``evaluate_varexp_batch`` and ``evaluate_ce_batch``. Both
            evaluate methods must return ``{submodule_name: {metric_name: value}}``.
        buffer: activation buffer. ``next(buffer)`` yields ``(input_acts, output_acts)``.
            ``token_batch()`` returns a token batch; ``refresh_batch_size`` is temporarily
            set to ``ce_batch_size`` around that call and always restored afterwards.
        n_batches: number of batches to average over.
        ce_batch_size: contexts per cross-entropy batch.
        use_sparse_connections: forwarded to both trainer evaluation methods.
        lm: language model passed to ``evaluate_ce_batch``. Defaults to the notebook's
            global ``model``.

    Returns:
        ``(varexp_metrics, ce_metrics)``, each ``{submodule_name: {metric_name: mean}}``.
        If ``n_batches <= 0``, both are ``{name: {}}`` for every submodule.
    """
    if lm is None:
        lm = model  # notebook-global language model loaded in the first cell

    names = list(trainer.submodules.keys())
    varexp_metrics = {name: {} for name in names}
    ce_metrics = {name: {} for name in names}

    for _ in range(n_batches):
        # Reconstruction / variance-explained metrics on one activation batch.
        input_acts, output_acts = next(buffer)
        batch_varexp = trainer.evaluate_varexp_batch(
            input_acts, output_acts, use_sparse_connections=use_sparse_connections)

        # token_batch() appears to size its output by refresh_batch_size, hence the temporary
        # swap. `finally` guarantees the buffer is restored even if tokenisation/evaluation raises.
        saved_size = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        try:
            tokens = buffer.token_batch()
            batch_ce = trainer.evaluate_ce_batch(
                lm, tokens, use_sparse_connections=use_sparse_connections)
        finally:
            buffer.refresh_batch_size = saved_size

        for name in names:
            for metric, value in batch_ce[name].items():
                ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + value
            for metric, value in batch_varexp[name].items():
                varexp_metrics[name][metric] = varexp_metrics[name].get(metric, 0) + value

    # Out-of-place division so integer tensors (if any) are promoted rather than failing in-place.
    for name in names:
        ce_metrics[name] = {m: total / n_batches for m, total in ce_metrics[name].items()}
        varexp_metrics[name] = {m: total / n_batches for m, total in varexp_metrics[name].items()}

    return varexp_metrics, ce_metrics

In [None]:
# NOTE(review): this call only runs if `run_evaluation` accepts a `use_sparse_connections` keyword.
# Also, `trainer` was constructed with use_sparse_connections=False and important_features={},
# so confirm that evaluating *with* sparse connections is intended here.
varexp_metrics, ce_metrics = run_evaluation(trainer, buffer, n_batches=2, ce_batch_size=16, use_sparse_connections=True)