In [119]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from buffer import AllActivationBuffer, NNsightActivationBuffer, ActivationBuffer
from trainers.scae import TrainerConfig
from training import train_scae_suite

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from collections import defaultdict
from tqdm import tqdm

DTYPE = t.bfloat16
device = "cuda:0" if t.cuda.is_available() else "cpu"
model = LanguageModel("gpt2", device_map=device, torch_dtype=DTYPE)
model.eval()

dataset = load_dataset(
    'Skylion007/openwebtext', 
    split='train', 
    streaming=True,
    trust_remote_code=True
    )

class CustomData():
    '''dumb helper class to make the dataset iterable'''
    def __init__(self, dataset):
        self.data = iter(dataset)
    def __iter__(self):
        return self
    def __next__(self):
        return next(self.data)['text']

data = CustomData(dataset)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
C = 10
expansion = 16
k = 128

num_features = model.config.n_embd * expansion
n_layer = model.config.n_layer

In [3]:
pretrained_configs = {}
connections = defaultdict(dict)

for down_layer in range(n_layer):
    for module in ['attn', 'mlp']:
        down_name = f'{module}_{down_layer}'
        pretrained_configs[f'{module}_{down_layer}'] = {
            'repo_id': 'jacobcd52/scae', 
            'filename': f'ae_{module}_{down_layer}.pt',
            'k' : k
            }
        
        # Use random connections for testing
        if module=='mlp':
            for up_layer in range(down_layer+1): # mlp sees attn from same layer
                up_name = f'{module}_{up_layer}'
                connections[down_name][up_name] = t.randint(0, num_features, (num_features, C), dtype=t.long)

In [124]:
submodules = {}
for layer in range(n_layer):
    submodules[f"mlp_{layer}"] = {"input_point" : (model.transformer.h[layer].ln_2, "out"),
                                  "output_point" : (model.transformer.h[layer].mlp, "in")}
    submodules[f"attn_{layer}"] = {"input_point" : (model.transformer.h[layer].attn, "in"),
                                    "output_point" : (model.transformer.h[layer].attn, "out")}
    
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size = 32,
    refresh_batch_size = 256,
    device=device,
    dtype=DTYPE,
)

  t.cuda.amp.autocast(dtype=self.dtype)


In [125]:
x = next(buffer)[0]['mlp_10']
y = next(buffer)[1]['mlp_10']

In [126]:
x-y

tensor([[  6.7500,   2.0625,  -6.8438,  ...,  -3.0781,   1.8203,  -4.9062],
        [-11.3750,   0.7188,  -2.5938,  ...,   0.9297,   7.2500,  -6.2500],
        [  2.8906,  -5.3125,  -3.1875,  ...,  -0.1562,   6.8438,  -2.7656],
        ...,
        [  0.5469,   1.2969,  -5.0938,  ...,   2.7500,  -3.1562,   5.6875],
        [ -3.5312,  -3.9844,  13.0000,  ...,  -6.6875,  -0.4453,   4.7188],
        [ -4.4688,   4.3750,  -3.2500,  ...,   0.4219,  -5.5625,   3.5000]],
       device='cuda:0', dtype=torch.bfloat16)

In [122]:
nnbuffer = ActivationBuffer(
    data=data,
    model=model,
    submodule=model.transformer.h[10].mlp,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size = 32,
    refresh_batch_size = 256,
    io="out",
    device=device,
)

In [123]:
next(nnbuffer)

IndexError: Above exception when execution Node: 'getitem_0' in Graph: '139515238658448'

In [30]:
trainer_cfg = TrainerConfig(
    connection_sparsity_coeff=0.1,
    steps=10,
)

trainer = train_scae_suite(
    buffer,
    module_specs=pretrained_configs,
    trainer_config=trainer_cfg,
    connections=connections,
    steps=10,
    save_steps = 10,
    dtype = DTYPE,
    device=device,
    # save_dir: Optional[str] = None,
    # log_steps: Optional[int] = None,
    # use_wandb: bool = False,
    # hf_repo_id: Optional[str] = None,
    # seed: Optional[int] = None,
)

ae_attn_0.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

  state_dict = t.load(weights_path, map_location='cpu')


ae_mlp_0.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_1.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_1.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_2.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_2.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_3.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_3.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_4.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_4.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_5.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_5.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_6.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_6.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_7.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_7.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_8.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_8.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_9.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_9.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_10.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_10.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_attn_11.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

ae_mlp_11.pt:   0%|          | 0.00/75.6M [00:00<?, ?B/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [8]:
suite = trainer.suite
def run_evaluation(
        suite, 
        buffer, 
        n_batches=10, 
        ce_batch_size=32,
        use_sparse_connections=False
        ):
    '''Simple function to run evaluation on several batches, and return the average metrics'''
    
    varexp_metrics = {name : {} for name in buffer.submodules.keys()}
    ce_metrics = {name : {} for name in buffer.submodules.keys()}

    for i in tqdm(range(n_batches)):
        # get varexp metrics
        input_acts, output_acts = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            input_acts, 
            output_acts,
            use_sparse_connections=use_sparse_connections
            )

        # get CE metrics
        b = buffer.refresh_batch_size
        buffer.refresh_batch_size = ce_batch_size
        tokens = buffer.token_batch()
        batch_ce_metrics = suite.evaluate_ce_batch(
            model, 
            tokens, 
            buffer.submodules,
            use_sparse_connections=use_sparse_connections
            )
        buffer.refresh_batch_size = b

        for name in ce_metrics.keys():
            for metric in batch_ce_metrics[name].keys():
                ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + batch_ce_metrics[name][metric] / n_batches
            for metric in batch_varexp_metrics[name].keys():
                varexp_metrics[name][metric] = varexp_metrics[name].get(metric, 0) + batch_varexp_metrics[name][metric] / n_batches
           
    return varexp_metrics, ce_metrics

In [12]:
varexp_metrics, ce_metrics = run_evaluation(
    suite, 
    buffer, 
    n_batches=10, 
    ce_batch_size=32,
    use_sparse_connections=False
    )

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]


ReferenceError: weakly-referenced object no longer exists