In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict

from buffer import AllActivationBuffer
from trainers.scae import SCAESuite
from utils import load_model_with_folded_ln2, load_iterable_dataset

# Shared precision for the model, the activation buffer and the SCAE suite.
DTYPE = t.bfloat16

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# GPT-2 loaded via the project helper (per its name, with ln_2 folded into the weights).
model = load_model_with_folded_ln2("gpt2", device=device, torch_dtype=DTYPE)

# OpenWebText as an iterable dataset; this is the text source for the activation buffer.
data = load_iterable_dataset('Skylion007/openwebtext')

In [2]:
# SCAE hyperparameters.
C = 10          # upstream feature indices per downstream feature in each connection table
expansion = 16  # dictionary size as a multiple of the residual-stream width (n_embd)
k = 128         # `k` passed to every pretrained autoencoder config (presumably TopK sparsity)

n_layer = model.config.n_layer
num_features = expansion * model.config.n_embd

In [3]:
# Build the per-module pretrained-weight configs and, for testing, random sparse
# connections between the MLP autoencoders, then load the SCAE suite.
CONNECTION_SEED = 0  # fixed so the random connection tables (and all metrics) are reproducible

pretrained_configs = {}
connections = defaultdict(dict)
connection_rng = t.Generator().manual_seed(CONNECTION_SEED)

for down_layer in range(n_layer):
    for module in ['attn', 'mlp']:
        down_name = f'{module}_{down_layer}'
        pretrained_configs[down_name] = {
            'repo_id': 'jacobcd52/scae',
            'filename': f'ae_{module}_{down_layer}.pt',
            'k': k,
        }

        # Random connections for testing only: each MLP autoencoder gets a table per
        # MLP at layers 0..down_layer, holding C uniformly drawn upstream feature
        # indices for every one of its num_features downstream features.
        # NOTE(review): range(down_layer + 1) includes down_layer itself, i.e. a
        # self-connection — confirm SCAESuite expects that.
        if module == 'mlp':
            for up_layer in range(down_layer + 1):
                up_name = f'{module}_{up_layer}'
                connections[down_name][up_name] = t.randint(
                    0, num_features, (num_features, C),
                    dtype=t.long,
                    generator=connection_rng,
                )

suite = SCAESuite.from_pretrained(
    pretrained_configs,
    connections=connections,
    dtype=DTYPE,
)

ae_attn_0.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

  state_dict = t.load(weights_path, map_location='cpu')


ae_mlp_0.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_1.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_1.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_2.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_2.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_3.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_3.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_4.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_4.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_5.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_5.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_6.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_6.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_7.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_7.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_8.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_8.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_9.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_9.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_10.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_10.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_11.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_11.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

In [5]:
# Activation sites: per layer, the MLP is hooked with mode "in_and_out" and the
# attention with mode "out"; each MLP also registers its ln_2 layernorm.
initial_submodule = model.transformer.h[0]

submodules = {}
layernorm_submodules = {}
for layer in range(n_layer):
    block = model.transformer.h[layer]
    submodules[f"mlp_{layer}"] = (block.mlp, "in_and_out")
    submodules[f"attn_{layer}"] = (block.attn, "out")
    layernorm_submodules[f"mlp_{layer}"] = block.ln_2

buffer = AllActivationBuffer(
    # source data and model
    data=data,
    model=model,
    device=device,
    dtype=DTYPE,
    # hook points
    initial_submodule=initial_submodule,
    submodules=submodules,
    layernorm_submodules=layernorm_submodules,
    d_submodule=model.config.n_embd,
    # sizing (semantics defined by AllActivationBuffer — see buffer.py)
    n_ctxs=128,
    refresh_batch_size=2,
    out_batch_size=2,
)

In [11]:
def run_evaluation(
        suite,
        buffer,
        n_batches=10,
        ce_batch_size=32,
        use_sparse_connections=False,
        model=None,
        initial_submodule=None,
        submodules=None,
        layernorm_submodules=None,
        ):
    '''Run evaluation on several batches and return per-module metrics averaged over them.

    Each iteration draws one activation batch from ``buffer`` for the
    variance-explained metrics and one token batch (of ``ce_batch_size``) for the
    cross-entropy metrics, then folds both into running means.

    Args:
        suite: object exposing ``evaluate_varexp_batch`` and ``evaluate_ce_batch``.
        buffer: activation buffer; ``next(buffer)`` must yield
            ``(initial_acts, input_acts, output_acts, layernorm_scales)``, and
            ``buffer.token_batch()`` presumably returns ``buffer.refresh_batch_size``
            sequences (assumed from how it is used here).
        n_batches: number of batches to average over.
        ce_batch_size: token batch size for the CE evaluation. ``buffer.refresh_batch_size``
            is overridden while fetching tokens and always restored, even on error.
        use_sparse_connections: forwarded to both suite evaluation methods.
        model, initial_submodule, submodules, layernorm_submodules: forwarded to
            ``suite.evaluate_ce_batch``. When ``None``, each falls back to the
            notebook-level variable of the same name (the original behavior).

    Returns:
        ``(varexp_metrics, ce_metrics)``: dicts keyed by ``buffer.submodules`` names,
        each mapping metric name -> mean over the ``n_batches`` batches.
    '''
    # These used to be read implicitly from the notebook namespace; keep that as the
    # default so existing calls still work, but allow them to be passed explicitly.
    notebook_ns = globals()
    if model is None:
        model = notebook_ns['model']
    if initial_submodule is None:
        initial_submodule = notebook_ns['initial_submodule']
    if submodules is None:
        submodules = notebook_ns['submodules']
    if layernorm_submodules is None:
        layernorm_submodules = notebook_ns['layernorm_submodules']

    module_names = list(buffer.submodules.keys())
    varexp_metrics = {name: {} for name in module_names}
    ce_metrics = {name: {} for name in module_names}

    for _ in tqdm(range(n_batches)):
        # Variance-explained metrics on one activation batch.
        initial_acts, input_acts, output_acts, layernorm_scales = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            initial_acts,
            input_acts,
            output_acts,
            layernorm_scales,
            use_sparse_connections=use_sparse_connections
            )

        # CE metrics on a token batch of the requested size. Restore the buffer's
        # refresh size even if token_batch() raises, so the buffer is left usable.
        saved_refresh_batch_size = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        try:
            tokens = buffer.token_batch()
        finally:
            buffer.refresh_batch_size = saved_refresh_batch_size

        batch_ce_metrics = suite.evaluate_ce_batch(
            model,
            tokens,
            initial_submodule,
            submodules,
            layernorm_submodules,
            use_sparse_connections=use_sparse_connections
            )

        # Running mean: add each batch's contribution divided by n_batches.
        for name in module_names:
            for totals, batch in ((ce_metrics[name], batch_ce_metrics[name]),
                                  (varexp_metrics[name], batch_varexp_metrics[name])):
                for metric, value in batch.items():
                    totals[metric] = totals.get(metric, 0) + value / n_batches

    return varexp_metrics, ce_metrics

In [26]:
# Evaluate with the (random, testing-only) sparse connections enabled.
varexp_metrics, ce_metrics = run_evaluation(
    suite, buffer,
    use_sparse_connections=True,
    n_batches=10,
    ce_batch_size=2,
)

100%|██████████| 10/10 [00:14<00:00,  1.44s/it]


In [27]:
def print_metrics_table(ce_metrics, varexp_metrics, module_type):
    '''Print one aligned row for every module whose name contains ``module_type``.

    Columns: CE increase (``loss_reconstructed - loss_original``), CE explained
    (``frac_recovered``) and ``FVU``, the last two shown as percentages.
    '''
    for name, ce in ce_metrics.items():
        if module_type not in name:
            continue
        ce_increase = ce['loss_reconstructed'] - ce['loss_original']
        print(
            f"{name:<8}"
            f"{ce_increase:>12.3f}"
            f"{ce['frac_recovered']:>10.0%}"
            f"{varexp_metrics[name]['FVU']:>10.0%}"
        )


# The clean loss is presumably the same for every module (same tokens, no patching);
# take it from the first module entry (mlp_0, given how `submodules` is built).
clean_loss = next(iter(ce_metrics.values()))['loss_original']
print(f"Clean loss = {clean_loss:.3f}\n")

print(f"{'Module':<8}{'CE increase':>12}{'CE expl':>10}{'FVU':>10}")
print_metrics_table(ce_metrics, varexp_metrics, 'mlp')
print()
print_metrics_table(ce_metrics, varexp_metrics, 'attn')

Clean loss = 3.391

Module  CE increase  CE expl FVU
mlp_0   2.119        51%     162%
mlp_1   2.544        -1838%     255%
mlp_2   2.597        -5706%     49649%
mlp_3   5.634        -13572%     7220%
mlp_4   5.716        -13395%     4721%
mlp_5   6.272        -11866%     2631%
mlp_6   5.428        -7154%     1641%
mlp_7   4.650        -7210%     1071%
mlp_8   4.384        -6736%     716%
mlp_9   3.950        -5369%     562%
mlp_10   3.416        -2698%     296%
mlp_11   3.825        -1902%     167%

attn_0   -0.000        100%     1%
attn_1   -0.002        77%     2%
attn_2   0.005        70%     6%
attn_3   -0.003        112%     7%
attn_4   0.000        96%     5%
attn_5   0.000        71%     6%
attn_6   0.009        58%     7%
attn_7   -0.006        117%     6%
attn_8   0.003        88%     5%
attn_9   0.003        86%     6%
attn_10   0.005        85%     5%
attn_11   0.000        100%     0%
