In [1]:
%load_ext autoreload
%autoreload 2

from buffer import AllActivationBuffer
from trainers.top_k import TrainerSCAE, AutoEncoderTopK
from training import trainSCAE

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm

# Numerical precision shared by the model, the activation buffer and the trainer.
DTYPE = t.float32

# Prefer the first GPU; fall back to CPU when CUDA is unavailable.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model = LanguageModel("gpt2", device_map=device, torch_dtype=DTYPE)

# Stream OpenWebText instead of downloading the whole corpus up front.
# NOTE(review): trust_remote_code=True allows the hub repository to execute
# code locally; keep it only if this dataset still needs a loading script.
dataset = load_dataset(
    'Skylion007/openwebtext',
    split='train',
    streaming=True,
    trust_remote_code=True,
)

class CustomData():
    """Iterate over the ``'text'`` field of each row of a (possibly streaming) dataset.

    The object is its own iterator and holds a single underlying iterator over
    ``dataset``, so it is consumed once: after exhaustion it stays exhausted.
    """

    def __init__(self, dataset):
        # One live iterator; every __next__ call advances it by exactly one row.
        self.data = iter(dataset)

    def __iter__(self):
        # Returning self lets callers use either ``iter(obj)`` or ``next(obj)``.
        return self

    def __next__(self):
        # StopIteration from the underlying iterator propagates and ends iteration.
        row = next(self.data)
        text = row['text']
        return text

# Iterator over raw document strings; passed to AllActivationBuffer as `data`.
data = CustomData(dataset)

In [2]:
# NOTE(review): `C` is only referenced in commented-out code; presumably the
# number of important upstream features kept per downstream feature.
C = 10
# Dictionary-size multiplier relative to the residual width n_embd.
expansion = 16
# Presumably the number of active latents per token in the top-k SAEs.
k = 128 # TODO: automatically detect these

# Latent dictionary size of every SAE.
num_features = expansion * model.config.n_embd

In [3]:
n_layer = model.config.n_layer

# SAE sites: one MLP and one attention site per transformer block. The string tag
# tells the buffer which activations to record (presumably "in_and_out" = the MLP's
# input and output, "out" = the attention output only -- TODO confirm in buffer.py).
submodules = {}
for layer in range(n_layer):
    submodules[f"mlp_{layer}"] = (model.transformer.h[layer].mlp, "in_and_out")
    submodules[f"attn_{layer}"] = (model.transformer.h[layer].attn, "out")

# No important-connection structure is supplied. The trainer is constructed with
# use_sparse_connections=False, so presumably this mapping is not consulted.
important_features = {}

# Submodule names in insertion order (mlp_0, attn_0, mlp_1, attn_1, ...).
submodule_names = list(submodules.keys())

# Pretrained SAE checkpoint for every site: file ae_<module>_<layer>.pt in the
# `jacobcd52/scae` repo (presumably a Hugging Face Hub repo read by the trainer).
pretrained_info = {}
for layer in range(n_layer):
    for module in ['attn', 'mlp']:
        pretrained_info[f'{module}_{layer}'] = {'repo_id': 'jacobcd52/scae', 'filename': f'ae_{module}_{layer}.pt'}

# Every SAE site must have exactly one checkpoint entry (guards against the two
# loops drifting apart if a new site type is added above).
assert set(pretrained_info) == set(submodule_names), "pretrained_info / submodules key mismatch"

In [4]:
# Activation buffer that streams documents from `data` through `model` and
# collects activations at every site in `submodules`.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd,  # output dimension of the model component
    n_ctxs=128,  # you can set this higher or lower depending on your available memory
    # Reuse the device chosen in the setup cell; a hard-coded "cuda" crashes on
    # CPU-only machines even though `device` already falls back to "cpu".
    device=device,
    out_batch_size=512,
    refresh_batch_size=256,
    dtype=DTYPE,
)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
# One SAE per site: each reads n_embd-dimensional activations and has
# `num_features` latents (presumably keeping the top-`k` active per token).
trainer = TrainerSCAE(
    activation_dims={name: model.config.n_embd for name in submodule_names},
    dict_sizes={name: num_features for name in submodule_names},
    ks={name: k for name in submodule_names},
    submodules=submodules,
    important_features=important_features,
    pretrained_info=pretrained_info,
    model_config=model.config,
    auxk_alpha=0,
    # NOTE(review): presumably only relevant when sparse connections are enabled.
    connection_sparsity_coeff=0.1,
    use_sparse_connections=False,
    seed=None,
    # Same device as the model and buffer; "cuda" was hard-coded before and
    # broke the CPU fallback chosen in the setup cell.
    device=device,
    wandb_name="SCAE",
    dtype=DTYPE,
)

In [6]:
# Smoke test: draw one (input, target) activation batch from the buffer and run a
# single trainer.update call on it. The leading argument 1 is presumably the step
# index. The "frac variance explained" lines in the output are printed by this call.
input_acts, target_acts = next(buffer)
trainer.update(1, input_acts, target_acts)

frac variance explained for mlp_0: 0.9526596069335938
frac variance explained for attn_0: 0.9909620881080627
frac variance explained for mlp_1: 0.9785271883010864
frac variance explained for attn_1: 0.968472957611084
frac variance explained for mlp_2: 0.9956282377243042
frac variance explained for attn_2: 0.9545985460281372
frac variance explained for mlp_3: 0.902158796787262
frac variance explained for attn_3: 0.9296982288360596
frac variance explained for mlp_4: 0.8518295288085938
frac variance explained for attn_4: 0.9151685833930969
frac variance explained for mlp_5: 0.8135673999786377
frac variance explained for attn_5: 0.9261492490768433
frac variance explained for mlp_6: 0.7719327211380005
frac variance explained for attn_6: 0.917524516582489
frac variance explained for mlp_7: 0.7637858390808105
frac variance explained for attn_7: 0.9284847974777222
frac variance explained for mlp_8: 0.7614310383796692
frac variance explained for attn_8: 0.9162231087684631
frac variance explaine

3233.654052734375

In [12]:
def run_evaluation(trainer, buffer, n_batches=100, ce_batch_size=32, use_sparse_connections=False):
    """Average per-submodule reconstruction and cross-entropy metrics over ``n_batches`` batches.

    Each iteration draws one activation batch from ``buffer`` for
    ``trainer.evaluate_varexp_batch`` and one token batch for
    ``trainer.evaluate_ce_batch``. Every returned metric is summed per submodule,
    and the sums are divided by ``n_batches`` at the end.

    Args:
        trainer: exposes ``submodules`` (its keys name the submodules reported),
            ``evaluate_varexp_batch(input_acts, output_acts, use_sparse_connections=...)``
            and ``evaluate_ce_batch(model, tokens, use_sparse_connections=...)``.
            Both must return a dict with a ``{metric_name: value}`` entry for every
            key of ``trainer.submodules``.
        buffer: activation buffer. ``next(buffer)`` must yield
            ``(input_acts, output_acts)``, and ``token_batch()`` is called while
            ``refresh_batch_size`` is temporarily set to ``ce_batch_size``.
        n_batches: number of batches to average over. With 0, the per-submodule
            metric dicts are returned empty.
        ce_batch_size: refresh batch size used only while drawing the CE token
            batch (presumably the number of contexts in that batch).
        use_sparse_connections: forwarded to both evaluation calls. Defaults to
            ``False``, matching the previously hard-coded behaviour.

    Returns:
        ``(varexp_metrics, ce_metrics)``, each ``{submodule_name: {metric_name: mean}}``.

    Note:
        The CE evaluation uses the notebook-global ``model``.
    """
    names = list(trainer.submodules.keys())
    varexp_sums = {name: {} for name in names}
    ce_sums = {name: {} for name in names}

    for _ in tqdm(range(n_batches), desc="evaluating"):
        # Reconstruction / variance-explained metrics on one activation batch.
        input_acts, output_acts = next(buffer)
        batch_varexp = trainer.evaluate_varexp_batch(
            input_acts, output_acts, use_sparse_connections=use_sparse_connections
        )

        # CE metrics on a separately drawn token batch. refresh_batch_size is
        # overridden temporarily and restored in `finally`, so an exception here
        # cannot leave the shared buffer drawing batches of the wrong size afterwards.
        saved_refresh_batch_size = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        try:
            tokens = buffer.token_batch()
            batch_ce = trainer.evaluate_ce_batch(
                model, tokens, use_sparse_connections=use_sparse_connections
            )
        finally:
            buffer.refresh_batch_size = saved_refresh_batch_size

        for name in names:
            for metric, value in batch_ce[name].items():
                ce_sums[name][metric] = ce_sums[name].get(metric, 0) + value
            for metric, value in batch_varexp[name].items():
                varexp_sums[name][metric] = varexp_sums[name].get(metric, 0) + value

    def _mean(sums):
        # Divide every accumulated metric by the number of batches.
        return {
            name: {metric: total / n_batches for metric, total in metrics.items()}
            for name, metrics in sums.items()
        }

    return _mean(varexp_sums), _mean(ce_sums)

# Quick evaluation over only 2 batches: a sanity check, not a stable estimate.
# Raise n_batches for reportable numbers.
varexp_metrics, ce_metrics = run_evaluation(trainer, buffer, n_batches=2, ce_batch_size=32)

# loss_original is read from the mlp_0 entry. It is presumably the unpatched
# model's loss and therefore the same for every submodule -- TODO confirm in
# evaluate_ce_batch.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}")
print()

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:09<00:00,  4.85s/it]

Clean loss = 3.413






In [13]:
# Per-submodule summary, MLP sites first, then attention sites. Columns:
#   dCE = loss_reconstructed - loss_original (from the CE evaluation)
#   FVE = frac_variance_explained, as a percentage
#   L2  = l2_loss (from the variance-explained evaluation)
# One loop over both groups replaces two copy-pasted loops with mismatched column
# spacing. The loop variable is `name`, not `k`, to avoid confusion with the
# top-k hyperparameter `k`.
print(f"{'module':<8} {'dCE':>7} {'FVE':>5} {'L2':>8}")
for group_index, group in enumerate(['mlp', 'attn']):
    if group_index > 0:
        print()
    for name in ce_metrics:
        if group not in name:
            continue
        delta_ce = ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']
        fve_pct = varexp_metrics[name]['frac_variance_explained'] * 100
        l2 = varexp_metrics[name]['l2_loss']
        print(f"{name:<8} {delta_ce:>7.3f} {fve_pct:>4.0f}% {l2:>8.1f}")

mlp_0   0.009   96%     4.0
mlp_1   0.007   98%     5.1
mlp_2   0.048   100%     11.0
mlp_3   0.007   92%     5.9
mlp_4   0.008   87%     7.0
mlp_5   0.012   82%     8.3
mlp_6   0.017   78%     9.9
mlp_7   0.017   77%     11.8
mlp_8   0.017   76%     14.4
mlp_9   0.021   78%     17.7
mlp_10   0.026   82%     24.1
mlp_11   0.046   91%     32.3

attn_0   0.003   99%    1.5
attn_1   -0.000   97%    1.6
attn_2   -0.000   95%    1.8
attn_3   -0.000   93%    2.3
attn_4   0.001   91%    2.8
attn_5   0.000   93%    3.0
attn_6   0.002   92%    3.6
attn_7   0.001   93%    3.8
attn_8   0.001   92%    4.8
attn_9   0.001   93%    5.6
attn_10   0.002   94%    6.8
attn_11   0.003   100%    9.1


In [7]:
# NOTE(review): this cell prints the same summary as the In [13] cell. Its saved
# output was produced before the In [12] evaluation run and no longer reflects the
# current metrics -- consider deleting this cell.
#
# Per-submodule summary, MLP sites first, then attention sites. Columns:
#   dCE = loss_reconstructed - loss_original (from the CE evaluation)
#   FVE = frac_variance_explained, as a percentage
#   L2  = l2_loss (from the variance-explained evaluation)
print(f"{'module':<8} {'dCE':>7} {'FVE':>5} {'L2':>8}")
for group_index, group in enumerate(['mlp', 'attn']):
    if group_index > 0:
        print()
    for name in ce_metrics:
        if group not in name:
            continue
        delta_ce = ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']
        fve_pct = varexp_metrics[name]['frac_variance_explained'] * 100
        l2 = varexp_metrics[name]['l2_loss']
        print(f"{name:<8} {delta_ce:>7.3f} {fve_pct:>4.0f}% {l2:>8.1f}")

mlp_0   0.016   92%     6.3
mlp_1   0.023   97%     6.9
mlp_2   4.117   38%     1264.0
mlp_3   0.023   84%     8.2
mlp_4   0.023   80%     9.3
mlp_5   0.023   71%     10.8
mlp_6   0.031   69%     12.5
mlp_7   0.031   68%     14.8
mlp_8   0.031   68%     17.6
mlp_9   0.039   69%     21.5
mlp_10   0.047   75%     30.4
mlp_11   0.062   86%     44.8

attn_0   0.016   95%    3.4
attn_1   0.008   91%    2.7
attn_2   0.008   89%    2.8
attn_3   0.016   85%    3.4
attn_4   0.016   81%    4.2
attn_5   0.008   83%    4.4
attn_6   0.016   81%    5.4
attn_7   0.016   83%    5.7
attn_8   0.016   81%    7.0
attn_9   0.016   84%    8.1
attn_10   0.016   88%    9.5
attn_11   0.023   82%    22.1
