In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict

from buffer import AllActivationBuffer
from trainers.scae import SCAESuite

# Precision and device used for the model and, further down, for the activation buffer.
DTYPE = t.bfloat16
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model = LanguageModel("gpt2", device_map=device, torch_dtype=DTYPE)

# Stream the training split instead of downloading the whole corpus up front.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script execute locally — confirm this is acceptable for the environment.
dataset = load_dataset(
    'Skylion007/openwebtext',
    split='train',
    streaming=True,
    trust_remote_code=True,
)

class CustomData():
    '''Iterator adapter that yields only the 'text' field of each dataset record.'''

    def __init__(self, dataset):
        # Hold a single underlying iterator so successive next() calls keep
        # advancing through the dataset rather than restarting it.
        self.data = iter(dataset)

    def __iter__(self):
        # The object acts as its own iterator.
        return self

    def __next__(self):
        # StopIteration from the underlying iterator propagates unchanged.
        record = next(self.data)
        return record['text']

# Iterator of raw text strings drawn from the streamed dataset; handed to the activation buffer as `data=` below.
data = CustomData(dataset)

In [2]:
# Hyperparameters shared by the cells below.
C = 10  # second dimension of each random connection tensor, shape (num_features, C)
expansion = 16  # dictionary size as a multiple of the model's embedding width
k = 128  # stored under 'k' in every pretrained config (presumably the top-k sparsity level — confirm)

# Sizes derived from the loaded model's config.
num_features = model.config.n_embd * expansion
n_layer = model.config.n_layer


In [17]:
# Debug inspection: display the bias parameter of block 0's `ln_2` module.
model.transformer.h[0].ln_2.bias

Parameter containing:
tensor([ 4.2480e-02,  3.2715e-02,  4.4861e-03,  1.5747e-02,  5.9509e-03,
        -2.6855e-02,  1.1719e-02,  5.1270e-02,  1.3245e-02, -4.2419e-03,
        -1.1658e-02,  2.1973e-03, -1.3367e-02, -2.0599e-03, -2.6489e-02,
         1.1902e-02, -4.2419e-03, -1.7822e-02,  4.7852e-02, -3.5553e-03,
         1.0490e-04, -1.2268e-02,  1.0864e-02, -5.9509e-03,  6.3171e-03,
         5.8289e-03,  6.2012e-02,  7.7057e-04, -2.4567e-03,  3.7994e-03,
         4.1504e-03,  3.8574e-02,  7.1106e-03, -1.9775e-02, -8.7280e-03,
         1.0559e-02,  1.3281e-01,  6.4087e-03,  8.8501e-03,  3.2715e-02,
         1.2146e-02, -6.0730e-03,  1.4160e-02, -6.6833e-03,  3.1250e-02,
        -2.2705e-02,  3.3936e-02, -1.2695e-02,  8.4473e-02,  3.5645e-02,
         4.0894e-03, -4.0894e-03, -7.2266e-02,  6.3782e-03,  3.6621e-02,
         5.6641e-02, -2.8839e-03,  7.2754e-02,  3.3875e-03, -1.8677e-02,
         4.3701e-02, -8.9722e-03,  1.7242e-03,  4.8828e-03,  6.4844e-01,
         2.2461e-02, -1.6098e

In [5]:
# Build one pretrained-SAE config per (module, layer) pair, plus random
# connection index tensors for the MLP dictionaries, then load the suite.
pretrained_configs = {}
connections = defaultdict(dict)

# Dedicated, seeded generator so the random test connections are identical on
# every Restart & Run All instead of changing with the global RNG state.
connection_rng = t.Generator().manual_seed(0)

for down_layer in range(n_layer):
    for module in ['attn', 'mlp']:
        down_name = f'{module}_{down_layer}'
        pretrained_configs[down_name] = {
            'repo_id': 'jacobcd52/scae', 
            'filename': f'ae_{module}_{down_layer}.pt',
            'k' : k,
            'layernorm_gamma': 100
            }
        
        # Use random connections for testing
        # NOTE(review): range(down_layer+1) includes down_layer itself, so every
        # MLP gets a connection entry keyed by its own name, and only MLP
        # upstreams are generated — confirm this is what SCAESuite expects.
        if module=='mlp':
            for up_layer in range(down_layer+1):
                up_name = f'{module}_{up_layer}'
                # One row of C indices in [0, num_features) per feature.
                connections[down_name][up_name] = t.randint(
                    0, num_features, (num_features, C),
                    dtype=t.long, generator=connection_rng,
                )

# NOTE(review): the saved `suite.configs` output later in this notebook shows
# layernorm_gamma=1.0 and upstream_connections=None, which does not match what
# is requested here — confirm whether that output is stale or whether
# from_pretrained ignores these settings.
suite = SCAESuite.from_pretrained(pretrained_configs, connections=connections)

  state_dict = t.load(weights_path, map_location='cpu')


In [6]:
# Register, per transformer block, the MLP and attention modules under the
# same '<module>_<layer>' names used for the pretrained configs.
# NOTE(review): the second tuple element ("in_and_out" / "out") presumably
# tells the buffer which side(s) of the module to record — confirm against
# AllActivationBuffer.
submodules = {}
for layer in range(n_layer):
    block = model.transformer.h[layer]
    submodules[f"mlp_{layer}"] = (block.mlp, "in_and_out")
    submodules[f"attn_{layer}"] = (block.attn, "out")

# Activation buffer fed by the streamed text iterator.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size=32,
    refresh_batch_size=256,
    device=device,
    dtype=DTYPE,
)

  with t.cuda.amp.autocast(dtype=self.dtype):
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [7]:
def run_evaluation(
        suite, 
        buffer, 
        n_batches=10, 
        ce_batch_size=32,
        use_sparse_connections=False
        ):
    '''Average variance-explained and cross-entropy metrics over several batches.

    Relies on the notebook-level `model`, which is passed to
    `suite.evaluate_ce_batch`.

    Args:
        suite: object exposing `evaluate_varexp_batch` and `evaluate_ce_batch`.
        buffer: activation buffer. Must support `next(buffer)` returning
            `(input_acts, output_acts)`, provide `token_batch()`, and expose
            `submodules` plus a writable `refresh_batch_size`.
        n_batches: number of batches to average over.
        ce_batch_size: value temporarily assigned to
            `buffer.refresh_batch_size` while tokens are drawn and the CE
            metrics are computed.
        use_sparse_connections: forwarded to both suite evaluation calls.

    Returns:
        `(varexp_metrics, ce_metrics)`: two dicts keyed by submodule name,
        each mapping a metric name to its mean over the evaluated batches.
    '''
    
    varexp_metrics = {name : {} for name in buffer.submodules.keys()}
    ce_metrics = {name : {} for name in buffer.submodules.keys()}

    for _ in tqdm(range(n_batches)):
        # get varexp metrics
        input_acts, output_acts = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            input_acts, 
            output_acts,
            use_sparse_connections=use_sparse_connections
            )

        # get CE metrics
        # The buffer's refresh batch size is overridden only for this step.
        # try/finally guarantees it is restored even if token sampling or the
        # CE evaluation raises; otherwise the buffer would silently keep
        # using ce_batch_size for every later refresh.
        saved_refresh_batch_size = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        try:
            tokens = buffer.token_batch()
            batch_ce_metrics = suite.evaluate_ce_batch(
                model, 
                tokens, 
                buffer.submodules,
                use_sparse_connections=use_sparse_connections
                )
        finally:
            buffer.refresh_batch_size = saved_refresh_batch_size

        # Accumulate running means: each batch contributes value / n_batches.
        for name in ce_metrics.keys():
            for metric in batch_ce_metrics[name].keys():
                ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + batch_ce_metrics[name][metric] / n_batches
            for metric in batch_varexp_metrics[name].keys():
                varexp_metrics[name][metric] = varexp_metrics[name].get(metric, 0) + batch_varexp_metrics[name][metric] / n_batches
           
    return varexp_metrics, ce_metrics

In [8]:
# Evaluate a single batch with sparse connections switched on.
varexp_metrics, ce_metrics = run_evaluation(
    suite, buffer, n_batches=1, ce_batch_size=32, use_sparse_connections=True,
)

100%|██████████| 1/1 [00:06<00:00,  6.97s/it]


In [9]:
# The original (clean) loss is read from the 'mlp_0' entry.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl Var expl")
# One table section per module kind: MLP rows first, then a blank line, then attention rows.
for kind in ('mlp', 'attn'):
    if kind == 'attn':
        print()
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        print(f"{name}   {ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']:.3f}        {ce_metrics[name]['frac_recovered']*100:.0f}%     {varexp_metrics[name]['frac_variance_explained']*100:.0f}%")

Clean loss = 3.438

Module  CE increase  CE expl Var expl
mlp_0   7.625        -77%     -1670532%
mlp_1   7.875        -12500%     -3465992%
mlp_2   4.531        -7150%     -1822076%
mlp_3   14.438        -18380%     -2973179%
mlp_4   14.188        -22600%     -2956593%
mlp_5   14.938        -19020%     -1985468%
mlp_6   16.312        -17300%     -1583495%
mlp_7   17.438        -18500%     -1205486%
mlp_8   17.562        -18633%     -729978%
mlp_9   16.062        -14586%     -474382%
mlp_10   11.000        -8700%     -174055%
mlp_11   7.375        -3531%     -31672%

attn_0   0.016        99%     99%
attn_1   0.000        100%     96%
attn_2   0.016        67%     96%
attn_3   0.000        100%     93%
attn_4   0.016        67%     92%
attn_5   0.000        100%     92%
attn_6   0.000        100%     91%
attn_7   0.000        100%     93%
attn_8   0.000        100%     92%
attn_9   0.016        80%     93%
attn_10   0.016        75%     93%
attn_11   0.016        92%     99%


In [28]:
# Display the per-submodule configuration objects held by the loaded suite.
suite.configs

{'attn_0': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'mlp_0': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'attn_1': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'mlp_1': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'attn_2': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'mlp_2': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'attn_3': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'mlp_3': SubmoduleConfig(activation_dim=768, dict_size=12288, k=128, upstream_connections=None, layernorm_gamma=1.0),
 'attn_4': SubmoduleConfig(activation_dim=76