In [1]:
import os

import torch
import wandb
from tqdm.notebook import tqdm

from grad_buffer import GradBuffer, GradBufferConfig
from model import AttributionSAE, AttributionSAEConfig
import loss

In [2]:
# Process environment settings applied through IPython's %env magic:
# - TOKENIZERS_PARALLELISM=false: disables tokenizer parallelism (presumably to
#   avoid the HuggingFace tokenizers fork warning — confirm).
# - WANDB_SILENT=true: presumably quiets wandb's console output — confirm.
%env TOKENIZERS_PARALLELISM=false
%env WANDB_SILENT=true

env: TOKENIZERS_PARALLELISM=false
env: WANDB_SILENT=true


In [3]:
# Experiment configuration. Every tunable knob lives in this one dict, which is
# also handed to wandb as the run config.
cfg = dict(
    # base model and the text corpus
    model_name='pythia-70m',
    dataset_name='wikitext',
    dataset_split='train',
    dataset_config='wikitext-103-v1',
    # SAE width: the dictionary size is n_dim * expansion_factor
    n_dim=512,
    expansion_factor=32,
    # optimisation schedule
    batch_size=32,
    total_steps=10_000,
    learning_rate=1e-4,
    # loss-term weights used in total_loss
    λ=1e-3,  # weight on act_sparsity
    α=1e-3,  # weight on grad_sparsity
    β=1e-3,  # weight on unexplained
    # runtime placement and reproducibility
    device='mps',
    dtype='float32',
    seed=42,
)

torch.manual_seed(cfg['seed'])

<torch._C.Generator at 0x108faf650>

In [4]:
# Start a wandb run. The whole `cfg` dict is recorded as the run config, so every
# hyperparameter is logged alongside the metrics.
# The run name is taken from $WANDB_NAME when set, so "Restart & Run All" can run
# unattended; otherwise we fall back to prompting interactively as before.
run_name = os.environ.get('WANDB_NAME') or input("Wandb Run Name: ")

wandb.init(
    project='AttributionSAE Experiments',
    entity='collingray',
    name=run_name,
    config=cfg,
)

In [5]:
# Configure the gradient buffer: which base model to hook, on which device,
# which of its layers to read from, and which dataset name/config/split to use.
# NOTE(review): range(6) is hard-coded; presumably it covers every layer of
# pythia-70m — revisit if `cfg['model_name']` changes.
buffer_config = GradBufferConfig(
    model_name=cfg['model_name'],
    device=cfg['device'],
    layers=[*range(6)],
    dataset_name=cfg['dataset_name'],
    dataset_config=cfg['dataset_config'],
    dataset_split=cfg['dataset_split'],
)

In [6]:
# Build the gradient buffer. The training loop pulls (x, grad) batches from it
# via `buffer.next(batch_size)`. Construction also loads the base model (see the
# "Loaded pretrained model ... into HookedTransformer" output below).
buffer = GradBuffer(buffer_config)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Loaded pretrained model pythia-70m into HookedTransformer


In [7]:
# SAE hyperparameters: `n_dim` comes straight from cfg, and `m_dim` (the
# dictionary width) is n_dim scaled by expansion_factor. Device and dtype
# follow the global config.
model_config = AttributionSAEConfig(
    device=cfg['device'],
    dtype=cfg['dtype'],
    n_dim=cfg['n_dim'],
    m_dim=cfg['expansion_factor'] * cfg['n_dim'],
)

In [8]:
# Instantiate the sparse autoencoder being trained. Note: `model` is the SAE,
# not the base language model (that is presumably held by `buffer` — confirm).
model = AttributionSAE(model_config)

In [9]:
# Plain Adam over every SAE parameter at the configured learning rate.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=cfg['learning_rate'],
)

In [10]:
# Training loop. Each step draws a batch of (x, grad) from the buffer, runs the
# SAE, and minimises a weighted sum of four loss terms from the `loss` module.
# Diagnostics are logged to wandb every `report_interval` steps, keyed by the
# training step so the wandb x-axis is the actual step number.
report_interval = 10

for step in tqdm(range(cfg['total_steps'])):
    optimizer.zero_grad()
    x, grad = buffer.next(cfg['batch_size'])
    # Insert a singleton axis before the last dim for the gradient-aware loss
    # terms. NOTE(review): the exact shapes are defined by GradBuffer / loss.
    grad = grad.unsqueeze(-2)
    y, f = model(x)

    # Decoder weight matrix, consumed by the gradient-sparsity term.
    dictionary = model.W_d.weight

    reconstruction = loss.reconstruction(x, y)
    act_sparsity = loss.act_sparsity(f)
    grad_sparsity = loss.grad_sparsity(f, grad, dictionary)
    unexplained = loss.unexplained(x, y, grad)

    total_loss = (
        reconstruction
        + cfg['λ'] * act_sparsity
        + cfg['α'] * grad_sparsity
        + cfg['β'] * unexplained
    )

    total_loss.backward()
    optimizer.step()

    if step % report_interval == 0:
        # Diagnostics are only needed on reporting steps, and never for autograd.
        with torch.no_grad():
            # Mean count of non-zero feature activations.
            l0 = (f != 0).sum(-1).float().mean()
            # Fraction of variance unexplained. NOTE(review): x.var() is the
            # variance over all elements (one global mean), not per-dimension
            # variance; confirm this matches how loss.reconstruction normalises.
            fvu = reconstruction / x.var()

        wandb.log({
            'loss': total_loss.item(),
            'reconstruction': reconstruction.item(),
            'act_sparsity': act_sparsity.item(),
            'grad_sparsity': grad_sparsity.item(),
            'unexplained': unexplained.item(),
            'l0': l0.item(),
            'fvu': fvu.item(),
        }, step=step)

        # Release cached allocator memory. torch.mps is only valid when the
        # configured device is MPS, so guard it rather than crash on CPU/CUDA.
        if str(cfg['device']).startswith('mps'):
            torch.mps.empty_cache()

  0%|          | 0/10000 [00:00<?, ?it/s]

In [11]:
# Mark the wandb run as finished so it stops collecting and uploads what remains.
wandb.finish()

In [25]:
# Persist the trained SAE via AttributionSAE.save (defined outside this notebook),
# passing 'model' as its target name/path.
# NOTE(review): this cell's execution count (In [25]) is out of sequence with the
# previous cell (In [11]); re-run the notebook top-to-bottom on a fresh kernel.
model.save('model')