In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict
import json

from buffer import AllActivationBuffer
from trainers.scae import SCAESuite
from utils import load_model_with_folded_ln2, load_iterable_dataset

# Precision passed to the model, the SCAE suite and the activation buffer.
DTYPE = t.float32

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Seed torch's global RNG so reruns are comparable.
t.manual_seed(42)

# Model and streaming dataset come from the project helpers imported from `utils`.
model = load_model_with_folded_ln2("gpt2", torch_dtype=DTYPE, device=device)
data = load_iterable_dataset('Skylion007/openwebtext')

In [2]:
# Hyperparameters for this notebook.
C = 10          # NOTE(review): not used in the visible cells — confirm purpose
expansion = 16  # dictionary width as a multiple of the model's embedding size
k = 128         # presumably the top-k sparsity of the autoencoders — TODO confirm

# Sizes derived from the loaded model's config.
n_layer = model.config.n_layer
num_features = expansion * model.config.n_embd

In [3]:
# Load the pretrained SCAE suite onto the same device/dtype as the model.
suite = SCAESuite.from_pretrained('jacobcd52/gpt2_suite_folded_ln', device=device, dtype=DTYPE)

  checkpoint = t.load(checkpoint_path, map_location='cpu')


In [4]:
# The first transformer block is handed to the buffer as the "initial" submodule.
initial_submodule = model.transformer.h[0]

# For every layer, register the MLP (inputs and outputs) and the attention
# module (outputs only), plus the layernorm that precedes each MLP.
submodules = {}
layernorm_submodules = {}
for layer in range(n_layer):
    transformer_block = model.transformer.h[layer]
    submodules.update({
        f"mlp_{layer}": (transformer_block.mlp, "in_and_out"),
        f"attn_{layer}": (transformer_block.attn, "out"),
    })
    layernorm_submodules[f"mlp_{layer}"] = transformer_block.ln_2

# Activation buffer that streams batches for all registered submodules.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    initial_submodule=initial_submodule,
    layernorm_submodules=layernorm_submodules,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size=32,
    refresh_batch_size=256,
    device=device,
    dtype=DTYPE,
)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
# # Alternative: load precomputed connections from top_connections.pkl.
# # This uses pickle, so only load files from a trusted source.
# import pickle
# with open('/root/dictionary_learning/top_connections.pkl', 'rb') as f:
#     top_connections = pickle.load(f)
# suite.connections = top_connections

In [5]:
# Index tensor of shape (num_features, num_features) in which every row is
# 0..num_features-1. `expand` returns a view, so the rows share storage.
all_conns = t.arange(num_features, dtype=t.int64, device=device).expand(num_features, num_features)

# Give every MLP the same dense index tensor for each attention and MLP module
# in layers 0..down_layer. Attention modules get no entry of their own.
# NOTE(review): range(down_layer + 1) also lists mlp_{down_layer} as its own
# upstream — confirm this is intended.
connections = {}
for down_layer in range(n_layer):
    upstream = {}
    for up_layer in range(down_layer + 1):
        upstream[f"attn_{up_layer}"] = all_conns
        upstream[f"mlp_{up_layer}"] = all_conns
    connections[f"mlp_{down_layer}"] = upstream

suite.connections = connections

In [6]:
# Sanity check: the connection index tensor should be (num_features, num_features).
all_conns.shape

torch.Size([12288, 12288])

In [7]:
def run_evaluation(
        suite,
        buffer,
        n_batches=10,
        ce_batch_size=32,
        use_sparse_connections=False
        ):
    '''Evaluate the suite on several batches and return the averaged metrics.

    For each batch this pulls one set of activations from ``buffer`` for the
    variance-explained metrics, then draws a token batch for the cross-entropy
    metrics.

    Besides its arguments, the function reads ``model``, ``initial_submodule``,
    ``submodules`` and ``layernorm_submodules`` from the notebook's global
    scope, so the cells defining them must have been run first.

    Args:
        suite: object providing ``evaluate_varexp_batch`` and
            ``evaluate_ce_batch``.
        buffer: activation buffer; must support ``next(buffer)``,
            ``token_batch()``, and expose ``submodules`` and
            ``refresh_batch_size``.
        n_batches: number of batches to average over.
        ce_batch_size: value temporarily assigned to
            ``buffer.refresh_batch_size`` while the token batch is drawn and
            the CE metrics are computed.
        use_sparse_connections: forwarded unchanged to both suite evaluation
            methods.

    Returns:
        ``(varexp_metrics, ce_metrics)``: two dicts mapping submodule name ->
        metric name -> mean of that metric over ``n_batches`` batches.
    '''

    varexp_metrics = {name: {} for name in buffer.submodules.keys()}
    ce_metrics = {name: {} for name in buffer.submodules.keys()}

    for _ in tqdm(range(n_batches)):
        # Variance-explained metrics on one batch of buffered activations.
        initial_acts, input_acts, output_acts, layernorm_scales = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            initial_acts,
            input_acts,
            output_acts,
            layernorm_scales,
            use_sparse_connections=use_sparse_connections
            )

        # Cross-entropy metrics. The buffer's refresh batch size is swapped to
        # ce_batch_size for this step (presumably token_batch() sizes its batch
        # from it — TODO confirm). try/finally guarantees the original value is
        # restored even if the evaluation raises (e.g. CUDA OOM), so a failed
        # run cannot leave the shared buffer misconfigured for later cells.
        original_refresh_batch_size = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        try:
            tokens = buffer.token_batch()
            batch_ce_metrics = suite.evaluate_ce_batch(
                model,
                tokens,
                initial_submodule,
                submodules,
                layernorm_submodules,
                use_sparse_connections=use_sparse_connections
                )
        finally:
            buffer.refresh_batch_size = original_refresh_batch_size

        # Accumulate value / n_batches so the totals are running means.
        for name in ce_metrics.keys():
            for metric, value in batch_ce_metrics[name].items():
                ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + value / n_batches
            for metric, value in batch_varexp_metrics[name].items():
                varexp_metrics[name][metric] = varexp_metrics[name].get(metric, 0) + value / n_batches

    return varexp_metrics, ce_metrics

In [8]:
# Smallest possible run (one batch, CE batch size 1) with sparse connections enabled.
eval_kwargs = dict(n_batches=1, ce_batch_size=1, use_sparse_connections=True)
varexp_metrics, ce_metrics = run_evaluation(suite, buffer, **eval_kwargs)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 GiB. GPU 0 has a total capacity of 44.45 GiB of which 3.13 GiB is free. Process 2130634 has 41.31 GiB memory in use. Of the allocated memory 37.24 GiB is allocated by PyTorch, and 3.77 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [13]:
# Report the clean loss once, read from the mlp_0 entry.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl FVU")
# One table section per module kind (MLPs first, then attention), separated by
# a blank line. Each row shows the CE increase, the fraction of CE recovered,
# and the FVU for that module.
for section_index, kind in enumerate(('mlp', 'attn')):
    if section_index > 0:
        print()
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        ce_row = ce_metrics[name]
        fvu = varexp_metrics[name]['FVU']
        print(f"{name}   {ce_row['loss_reconstructed'] - ce_row['loss_original']:.3f}        {ce_row['frac_recovered']*100:.0f}%     {fvu*100:.0f}%")

Clean loss = 2.603

Module  CE increase  CE expl FVU
mlp_0   2.052        63%     156%
mlp_1   3.324        -9061%     501%
mlp_2   4.008        -1937%     1613%
mlp_3   4.290        -2960%     3161%
mlp_4   5.054        -4547%     2836%
mlp_5   4.334        -2453%     1951%
mlp_6   3.739        -2183%     1089%
mlp_7   3.900        -3212%     641%
mlp_8   3.856        -3087%     457%
mlp_9   4.637        -3104%     478%
mlp_10   1.015        -373%     206%
mlp_11   2.167        -365%     265%

attn_0   0.008        100%     1%
attn_1   -0.005        156%     3%
attn_2   0.008        88%     5%
attn_3   0.007        93%     6%
attn_4   -0.005        104%     7%
attn_5   0.003        91%     7%
attn_6   0.004        96%     7%
attn_7   0.001        120%     7%
attn_8   0.007        91%     8%
attn_9   -0.006        108%     7%
attn_10   0.005        90%     6%
attn_11   -0.000        100%     1%
