In [None]:
%load_ext autoreload
%autoreload 2

from buffer import AllActivationBuffer
from trainers.top_k import TrainerSCAE, AutoEncoderTopK
from trainers.scae import SCAESuite, TrainerSCAESuite, TrainerConfig
from training import train_scae_suite

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict

# Precision shared by the language model here and by later consumers of DTYPE.
DTYPE = t.bfloat16

# Prefer the first CUDA device; fall back to CPU when CUDA is unavailable.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# GPT-2 wrapped by nnsight, loaded directly onto `device` in DTYPE.
model = LanguageModel("gpt2", torch_dtype=DTYPE, device_map=device)
# model.eval()


# OpenWebText train split, streamed rather than downloaded up front.
# NOTE(review): trust_remote_code=True presumably allows the dataset
# repository's own loading code to run locally -- confirm this is acceptable.
dataset = load_dataset(
    'Skylion007/openwebtext',
    streaming=True,
    split='train',
    trust_remote_code=True,
)

class CustomData:
    """Iterator adapter that yields only the ``'text'`` field of each record.

    Wraps any iterable of records indexable by ``'text'`` and implements the
    iterator protocol, so it can be consumed with ``next()`` or a ``for`` loop.
    """

    def __init__(self, dataset):
        # A single iterator is taken up front, so the object is one-shot:
        # once the underlying iterator is exhausted, so is this wrapper.
        self.data = iter(dataset)

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self):
        # StopIteration raised by the underlying iterator propagates as-is.
        record = next(self.data)
        return record['text']

# Text-only view over the streaming dataset.
data = CustomData(dataset)

C = 10  # second dimension of the random connection index tensors built below
expansion = 16  # dictionary width as a multiple of the model's n_embd
k = 128 # TODO: automatically detect these

# Number of dictionary features per module.
num_features = expansion * model.config.n_embd


##
n_layer = model.config.n_layer

# Registry of hooked components: name -> (module, I/O spec string).
# Each layer contributes its MLP ("in_and_out") followed by its attention
# block ("out"), in that insertion order.
submodules = {}
for layer in range(n_layer):
    submodules.update({
        f"mlp_{layer}": (model.transformer.h[layer].mlp, "in_and_out"),
        f"attn_{layer}": (model.transformer.h[layer].attn, "out"),
    })

def random_up_features(down_layer):
    """Build a fake important-connection dictionary, for testing.

    For every layer index below ``down_layer`` the result holds two entries,
    ``"mlp_<layer>"`` then ``"attn_<layer>"``, each a ``(num_features, C)``
    ``long`` tensor of random indices in ``[0, num_features)``.

    Relies on the module-level ``num_features`` and ``C``. Returns an empty
    dict when ``down_layer`` is 0.
    """
    shape = (num_features, C)
    return {
        f"{kind}_{layer}": t.randint(0, num_features, shape, dtype=t.long)
        for layer in range(down_layer)
        for kind in ("mlp", "attn")
    }

def all_up_features(down_layer):
    """Build a fake important-connection dictionary over every layer, for testing.

    The result holds ``"mlp_<layer>"`` then ``"attn_<layer>"`` for each of the
    ``n_layer`` layers, each a ``(num_features, C)`` ``long`` tensor of random
    indices in ``[0, num_features)``.

    ``down_layer`` is accepted but never read: the set of layers depends only
    on the module-level ``n_layer``. Also relies on the module-level
    ``num_features`` and ``C``.
    """
    shape = (num_features, C)
    return {
        f"{kind}_{layer}": t.randint(0, num_features, shape, dtype=t.long)
        for layer in range(n_layer)
        for kind in ("mlp", "attn")
    }

# Fake upstream-connection tables, one per MLP, keyed by the downstream
# module name ("mlp_<layer>").
important_features = {}
for down_layer in range(n_layer):
    important_features[f"mlp_{down_layer}"] = random_up_features(down_layer)
# Alternative (disabled):
# important_features = {}


# Get submodule names from the submodules dictionary
submodule_names = list(submodules)

# Checkpoint location for each module, keyed "<module>_<layer>" with the
# attention entry of a layer inserted before its MLP entry.
pretrained_info = {
    f'{module}_{layer}': {
        'repo_id': 'jacobcd52/scae',
        'filename': f'ae_{module}_{layer}.pt',
    }
    for layer in range(n_layer)
    for module in ['attn', 'mlp']
}


##
# Build the AllActivationBuffer over the text stream for every registered
# submodule. The buffer is placed on the same `device` that the model was
# loaded onto (and that is later handed to the trainer) instead of a
# hard-coded "cuda", so the script's CPU fallback stays consistent when CUDA
# is unavailable.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd, # output dimension of the model component
    n_ctxs=128,  # you can set this higher or lower depending on your available memory
    device=device,
    out_batch_size=64,
    refresh_batch_size=256,
    dtype=DTYPE,
)


# def run_evaluation(trainer, buffer, n_batches=100, ce_batch_size=32):
#     varexp_metrics = {name : {} for name in trainer.submodules.keys()}
#     ce_metrics = {name : {} for name in trainer.submodules.keys()}
#     for i in tqdm(range(n_batches)):
#         # get varexp metrics
#         input_acts, output_acts = next(buffer)
#         batch_varexp_metrics = trainer.evaluate_varexp_batch(input_acts, output_acts, use_sparse_connections=False)

#         # get CE metrics
#         b = buffer.refresh_batch_size
#         buffer.refresh_batch_size = ce_batch_size
#         tokens = buffer.token_batch()
#         batch_ce_metrics = trainer.evaluate_ce_batch(model, tokens, use_sparse_connections=False)
#         buffer.refresh_batch_size = b

#         for name in ce_metrics.keys():
#             for metric in batch_ce_metrics[name].keys():
#                 ce_metrics[name][metric] = ce_metrics[name].get(metric, 0) + batch_ce_metrics[name][metric]
#             for metric in batch_varexp_metrics[name].keys():
#                 varexp_metrics[name][metric] = varexp_metrics[name].get(metric, 0) + batch_varexp_metrics[name][metric]
    
#     for name in ce_metrics.keys():
#         for metric in ce_metrics[name].keys():
#             ce_metrics[name][metric] = ce_metrics[name][metric] / n_batches
#         for metric in varexp_metrics[name].keys():
#             varexp_metrics[name][metric] = varexp_metrics[name][metric] / n_batches
        
#     return varexp_metrics, ce_metrics

# varexp_metrics, ce_metrics = run_evaluation(trainer, buffer, n_batches=2, ce_batch_size=32)

# print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}")
# print()
# for name in [k for k in ce_metrics.keys() if 'mlp' in k]:
#     print(f"{name}   {ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']:.3f}   {varexp_metrics[name]['frac_variance_explained']*100:.0f}%     {varexp_metrics[name]['l2_loss']:.1f}")

# print()

# for name in [k for k in ce_metrics.keys() if 'attn' in k]:
#     print(f"{name}   {ce_metrics[name]['loss_reconstructed'] - ce_metrics[name]['loss_original']:.3f}   {varexp_metrics[name]['frac_variance_explained']*100:.0f}%    {varexp_metrics[name]['l2_loss']:.1f}")


# Per-module spec: where to fetch the pretrained autoencoder from, plus the
# `k` setting, keyed by "<module>_<layer>".
pretrained_configs = {}
# Nested mapping: downstream name -> upstream name -> index tensor.
# Entries are only created on assignment, so layer-0 modules (which have no
# earlier layers) get no key at all.
connections = defaultdict(dict)

for down_layer in range(n_layer):
    for module in ('attn', 'mlp'):
        down_name = f'{module}_{down_layer}'
        pretrained_configs[down_name] = dict(
            repo_id='jacobcd52/scae',
            filename=f'ae_{down_name}.pt',
            k=k,
        )

        # Use random connections for testing
        # NOTE(review): only same-type upstream modules are wired up here
        # (attn -> attn, mlp -> mlp) -- confirm this is intended.
        for up_layer in range(down_layer):
            up_name = f'{module}_{up_layer}'
            connections[down_name][up_name] = t.randint(
                0, num_features, (num_features, C), dtype=t.long
            )

# Load the suite from the pretrained specs with the random connections attached.
suite = SCAESuite.from_pretrained(pretrained_configs, connections=connections)

# Trainer settings; only the connection sparsity coefficient is overridden.
trainer_cfg = TrainerConfig(connection_sparsity_coeff=0.1)

# Short run: 100 steps with save_steps=10, using the same specs and random
# connections as above. The result is bound to `trainer`.
trainer = train_scae_suite(
    buffer,
    module_specs=pretrained_configs,
    connections=connections,
    trainer_config=trainer_cfg,
    steps=100,
    save_steps=10,
    dtype=DTYPE,
    device=device,
    # Optional arguments currently left unset:
    # save_dir: Optional[str] = None,
    # log_steps: Optional[int] = None,
    # use_wandb: bool = False,
    # hf_repo_id: Optional[str] = None,
    # seed: Optional[int] = None,
)