In [25]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict
import json

from buffer import AllActivationBuffer
from trainers.scae import SCAESuite
from utils import load_model_with_folded_ln2, load_iterable_dataset

# This notebook only runs forward passes, so gradient tracking is switched off
# globally, and the torch RNG is seeded for repeatable runs.
t.set_grad_enabled(False)
t.manual_seed(42)

# Everything below (model, suite, buffer) shares this dtype and device.
DTYPE = t.bfloat16
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# GPT-2 via the project helper (name suggests ln_2 is folded into the weights
# -- TODO confirm against utils.load_model_with_folded_ln2).
model = load_model_with_folded_ln2("gpt2", torch_dtype=DTYPE, device=device)

# Text source for the activation buffer.
data = load_iterable_dataset('Skylion007/openwebtext')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
# Hyperparameters.
C = 10            # only referenced by the commented-out random-connections scaffold
expansion = 16    # dictionary size as a multiple of the model width (see num_features)
k = 128           # not used in the visible cells; presumably top-k sparsity -- TODO confirm

# Sizes derived from the loaded model's config.
n_layer = model.config.n_layer
num_features = expansion * model.config.n_embd

In [27]:
# Load the pretrained suite with the same device and dtype as the model.
suite = SCAESuite.from_pretrained('jacobcd52/gpt2_suite_folded_ln', dtype=DTYPE, device=device)

  checkpoint = t.load(checkpoint_path, map_location='cpu')


In [28]:
# The first transformer block is handed to the buffer as its "initial" submodule.
initial_submodule = model.transformer.h[0]

# Per-layer registry consumed by AllActivationBuffer. Each value is a
# (module, io-mode) pair; the mode strings are interpreted by the buffer
# (presumably which activations to record -- TODO confirm in buffer.py).
submodules = {}
# Only MLP entries get an associated layernorm (ln_2 of the same block).
layernorm_submodules = {}
for layer in range(n_layer):
    block = model.transformer.h[layer]
    mlp_name, attn_name = f"mlp_{layer}", f"attn_{layer}"

    # Insertion order is mlp_i then attn_i, layer by layer.
    submodules[mlp_name] = (block.mlp, "in_and_out")
    submodules[attn_name] = (block.attn, "out")
    layernorm_submodules[mlp_name] = block.ln_2

buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    initial_submodule=initial_submodule,
    layernorm_submodules=layernorm_submodules,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size=32,
    refresh_batch_size=256,
    device=device,
    dtype=DTYPE,
)

  t.cuda.amp.autocast(dtype=self.dtype)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [29]:
# Restore the precomputed connections from top_connections.pkl and attach them
# to the suite.
# NOTE(review): pickle.load executes arbitrary code from the file, so only use a
# file you produced yourself. The path is an absolute location on the original
# machine and must be adjusted elsewhere.
import pickle

with open('/root/dictionary_learning/top_connections.pkl', 'rb') as f:
    top_connections = pickle.load(f)

suite.connections = top_connections

In [10]:
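# Disabled scaffold (kept for reference, nothing in this cell runs): instead of
# the pickled connections, assign one shared random connection tensor to every
# upstream entry of each mlp_* module.
# # all_conns = t.arange(0, num_features, device=device, dtype=t.int64).expand(C, num_features)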
# random_conns = t.randint(0, num_features, (num_features, C), device=device, dtype=t.int64)
# connections = {}
# for down_layer in range(n_layer):
#     # connections[f"attn_{down_layer}"] = {}
#     connections[f"mlp_{down_layer}"] = {}
#     for up_layer in range(down_layer):
#         connections[f"mlp_{down_layer}"][f"attn_{up_layer}"] = random_conns
#         connections[f"mlp_{down_layer}"][f"mlp_{up_layer}"] = random_conns
#     connections[f"mlp_{down_layer}"][f"attn_{down_layer}"] = random_conns

# suite.connections = connections

In [30]:
def run_evaluation(
        suite,
        buffer,
        n_batches=10,
        ce_batch_size=32,
        use_sparse_connections=False
        ):
    """Average the suite's variance-explained metrics over several buffer batches.

    Args:
        suite: object exposing ``evaluate_varexp_batch``; it is called once per
            batch and must return ``{module_name: {metric_name: value}}``.
        buffer: iterator whose ``next()`` yields
            ``(initial_acts, input_acts, output_acts, layernorm_scales)`` and
            whose ``submodules`` mapping supplies the module names.
        n_batches: number of batches to draw and average over.
        ce_batch_size: only relevant to the CE evaluation, which is currently
            commented out, so this argument has no effect.
        use_sparse_connections: forwarded unchanged to ``evaluate_varexp_batch``.

    Returns:
        ``(varexp_metrics, ce_metrics)``. ``varexp_metrics[name][metric]`` is the
        mean of that metric over the batches. ``ce_metrics`` has one *empty*
        dict per module because CE evaluation is disabled.
    """
    module_names = list(buffer.submodules.keys())
    varexp_metrics = {name: {} for name in module_names}
    ce_metrics = {name: {} for name in module_names}

    for _ in tqdm(range(n_batches)):
        # Variance-explained metrics for one batch of cached activations.
        initial_acts, input_acts, output_acts, layernorm_scales = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            initial_acts,
            input_acts,
            output_acts,
            layernorm_scales,
            use_sparse_connections=use_sparse_connections,
        )

        # --- CE evaluation (disabled) -------------------------------------
        # b = buffer.refresh_batch_size
        # buffer.refresh_batch_size = ce_batch_size
        # tokens = buffer.token_batch()
        # batch_ce_metrics = suite.evaluate_ce_batch(
        #     model,
        #     tokens,
        #     initial_submodule,
        #     submodules,
        #     layernorm_submodules,
        #     use_sparse_connections=use_sparse_connections
        #     )
        # buffer.refresh_batch_size = b
        # for metric in batch_ce_metrics[name].keys():
        #     ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + batch_ce_metrics[name].get(metric, 0) / n_batches
        # -------------------------------------------------------------------

        # Add each batch's contribution pre-divided by n_batches, so the
        # running totals end up as means.
        for name in module_names:
            running = varexp_metrics[name]
            for metric, value in batch_varexp_metrics[name].items():
                running[metric] = running.get(metric, 0) + value / n_batches

    return varexp_metrics, ce_metrics

In [34]:
# Average the variance-explained metrics over 10 batches with
# use_sparse_connections switched off. ce_batch_size is passed for completeness,
# but run_evaluation's CE path is commented out, so ce_metrics comes back as
# empty per-module dicts.
varexp_metrics, ce_metrics = run_evaluation(
    suite, buffer, n_batches=10, ce_batch_size=1, use_sparse_connections=False,
)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 81.46it/s]


In [35]:
# Placeholder CE values: run_evaluation returns empty CE dicts, so fill in
# dummy 1s to let the summary table below print. The resulting "Clean loss",
# "CE increase" and "CE expl" figures are NOT measurements.
for name in ce_metrics.keys():
    ce_metrics[name].update(loss_reconstructed=1, loss_original=1, frac_recovered=1)

In [33]:
# Summary table: MLP modules first, a blank line, then attention modules.
# FVU comes from varexp_metrics; the CE columns come from ce_metrics.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl FVU")
for kind in ("mlp", "attn"):
    if kind == "attn":
        print()  # separator between the two groups
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        ce_increase = ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']
        ce_explained = ce_metrics[name]['frac_recovered']*100
        fvu_percent = varexp_metrics[name]['FVU']*100
        print(f"{name}   {ce_increase:.3f}        {ce_explained:.0f}%     {fvu_percent:.0f}%")

Clean loss = 1.000

Module  CE increase  CE expl FVU
mlp_0   0.000        100%     5%
mlp_1   0.000        100%     18%
mlp_2   0.000        100%     51%
mlp_3   0.000        100%     23%
mlp_4   0.000        100%     24%
mlp_5   0.000        100%     24%
mlp_6   0.000        100%     25%
mlp_7   0.000        100%     24%
mlp_8   0.000        100%     23%
mlp_9   0.000        100%     20%
mlp_10   0.000        100%     277%
mlp_11   0.000        100%     405%

attn_0   0.000        100%     1%
attn_1   0.000        100%     3%
attn_2   0.000        100%     4%
attn_3   0.000        100%     6%
attn_4   0.000        100%     8%
attn_5   0.000        100%     7%
attn_6   0.000        100%     8%
attn_7   0.000        100%     7%
attn_8   0.000        100%     8%
attn_9   0.000        100%     7%
attn_10   0.000        100%     6%
attn_11   0.000        100%     1%


In [36]:
# Summary table: MLP modules first, a blank line, then attention modules.
# FVU comes from varexp_metrics; the CE columns come from ce_metrics.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl FVU")
for kind in ("mlp", "attn"):
    if kind == "attn":
        print()  # separator between the two groups
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        ce_increase = ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']
        ce_explained = ce_metrics[name]['frac_recovered']*100
        fvu_percent = varexp_metrics[name]['FVU']*100
        print(f"{name}   {ce_increase:.3f}        {ce_explained:.0f}%     {fvu_percent:.0f}%")

Clean loss = 1.000

Module  CE increase  CE expl FVU
mlp_0   0.000        100%     4%
mlp_1   0.000        100%     11%
mlp_2   0.000        100%     19%
mlp_3   0.000        100%     12%
mlp_4   0.000        100%     15%
mlp_5   0.000        100%     16%
mlp_6   0.000        100%     17%
mlp_7   0.000        100%     17%
mlp_8   0.000        100%     17%
mlp_9   0.000        100%     15%
mlp_10   0.000        100%     11%
mlp_11   0.000        100%     8%

attn_0   0.000        100%     1%
attn_1   0.000        100%     3%
attn_2   0.000        100%     4%
attn_3   0.000        100%     6%
attn_4   0.000        100%     8%
attn_5   0.000        100%     7%
attn_6   0.000        100%     8%
attn_7   0.000        100%     7%
attn_8   0.000        100%     8%
attn_9   0.000        100%     7%
attn_10   0.000        100%     6%
attn_11   0.000        100%     1%


In [38]:
# Inspect the connection tensors stored for mlp_10, keyed by upstream module.
# This displays the whole nested dict; -1 entries are presumably padding for
# "no connection" -- TODO confirm.
top_connections["mlp_10"]

{'mlp_0': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [  250,    -1, 11139,  ...,    -1,  7932,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
         [  250,    -1, 11139,  ..., 12242,    -1,    -1],
         [  250,    -1,    -1,  ..., 11132,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0'),
 'mlp_1': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,  3208,    -1,  ...,    -1,    -1,  2056],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
         [   -1,  3208,    -1,  ...,    -1,  3180,    -1],
         [   -1,    -1, 11580,  ...,    -1,    -1,    -1],
         [ 3208,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0'),
 'attn_0': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
       