In [None]:
import torch
import wandb
from tqdm.notebook import tqdm

from grad_buffer import GradBuffer, GradBufferConfig
from model import AttributionSAE, AttributionSAEConfig
import loss

In [None]:
# Set process-level environment variables through IPython's %env magic.
# Keep comments on their own lines: text after the value would become part of it.
# NOTE(review): presumably silences the HuggingFace tokenizers fork/parallelism
# warning — confirm against the data-loading code in grad_buffer.
%env TOKENIZERS_PARALLELISM=false
# NOTE(review): presumably suppresses W&B console output — confirm.
%env WANDB_SILENT=true

In [None]:
# Single source of truth for the experiment; the whole dict is also passed to
# wandb.init as the run config, so every value here is recorded with the run.
cfg = {
    # Source model and dataset handed to the gradient buffer.
    'model_name': 'pythia-70m',
    'dataset_name': 'wikitext',
    'dataset_split': 'train',
    'dataset_config': 'wikitext-103-v1',
    # SAE input width; the dictionary width is n_dim * expansion_factor.
    # NOTE(review): 512 presumably matches the hidden size of the source model — confirm.
    'n_dim': 512,
    'expansion_factor': 32,
    # Optimisation settings.
    'batch_size': 32,
    'total_steps': 10000,
    'learning_rate': 1e-4,
    # Loss-term weights used by the training loop:
    #   λ scales act_sparsity, α scales grad_sparsity, β scales unexplained.
    'λ': 1e-3,
    'α': 1e-3,
    'β': 1e-3,
    # Prefer Apple's MPS backend (the original hard-coded choice) and fall back
    # to CUDA, then CPU, so the notebook also runs on machines without MPS.
    'device': (
        'mps' if torch.backends.mps.is_available()
        else 'cuda' if torch.cuda.is_available()
        else 'cpu'
    ),
    'dtype': torch.float32,
    'seed': 42
}

# Seed torch's default RNG so runs with the same config start from the same state.
torch.manual_seed(cfg['seed'])

In [None]:
# Start a Weights & Biases run and attach the full experiment config to it.
# NOTE(review): `entity` is hard-coded to a personal account and `name` to a
# fixed label; anyone else running this notebook needs to change both.
wandb.init(
    project='AttributionSAE Experiments',
    entity='collingray',
    name='meaned loss terms',
    config=cfg,
)

In [None]:
# Describe where the gradient buffer gets its data: which model, which dataset
# (name / split / config) and which device to use.
buffer_config = GradBufferConfig(
    model_name=cfg['model_name'],
    # Layer indices 0..5.
    # NOTE(review): presumably every layer of the source model — confirm, and
    # consider deriving this from the model instead of the literal 6.
    layers=list(range(6)),
    dataset_name=cfg['dataset_name'],
    dataset_split=cfg['dataset_split'],
    dataset_config=cfg['dataset_config'],
    device=torch.device(cfg['device']),
)

In [None]:
# Build the buffer that the training loop pulls (x, grad) batches from.
# NOTE(review): construction presumably loads the model and dataset, so this
# cell may be slow — confirm in grad_buffer.
buffer = GradBuffer(buffer_config)

In [None]:
# SAE dimensions: input width n_dim and dictionary width
# m_dim = n_dim * expansion_factor (512 * 32 = 16384 with the values in cfg),
# created on the configured device with the configured dtype.
model_config = AttributionSAEConfig(
    n_dim=cfg['n_dim'],
    m_dim=cfg['n_dim']*cfg['expansion_factor'],
    device=torch.device(cfg['device']),
    dtype=cfg['dtype'],
)

In [None]:
# Instantiate the sparse autoencoder that the loop below trains.
model = AttributionSAE(model_config)

In [None]:
# Adam over all model parameters; only the learning rate is overridden, every
# other hyperparameter is left at the torch default.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['learning_rate'])

In [None]:
# Training loop: one optimiser update per iteration, metrics sent to W&B every
# `report_interval` steps.
report_interval = 100

for step in tqdm(range(cfg['total_steps'])):
    optimizer.zero_grad()

    # Pull the next batch from the buffer: `x` is fed to the SAE, `grad` is only
    # used by the loss terms below.
    x, grad = buffer.next(cfg['batch_size'])
    # Insert a singleton axis before the last dimension of `grad`.
    # NOTE(review): presumably needed for broadcasting inside the loss helpers — confirm.
    grad = grad.unsqueeze(-2)
    # NOTE(review): `y` is presumably the reconstruction of `x` and `f` the
    # feature activations, judging by how the loss helpers consume them — confirm.
    y, f = model(x)

    dictionary = model.W_d.weight

    # Individual loss terms; each is logged separately below.
    reconstruction = loss.reconstruction(x, y)
    act_sparsity = loss.act_sparsity(f)
    grad_sparsity = loss.grad_sparsity(f, grad, dictionary)
    unexplained = loss.unexplained(x, y, grad)

    # Weighted objective: reconstruction is unweighted, the other three terms
    # are scaled by λ, α and β from cfg.
    total_loss = reconstruction + cfg['λ']*act_sparsity + cfg['α']*grad_sparsity + cfg['β']*unexplained

    total_loss.backward()
    optimizer.step()

    if step % report_interval == 0:
        # Pass the training step explicitly. Without it W&B uses its own
        # counter, which advances once per log call, so the chart's step axis
        # would read 0, 1, 2, ... instead of 0, 100, 200, ...
        # W&B expects steps to only increase within a run: start a fresh run
        # with wandb.init before re-running this cell.
        wandb.log({
            'loss': total_loss.item(),
            'reconstruction': reconstruction.item(),
            'act_sparsity': act_sparsity.item(),
            'grad_sparsity': grad_sparsity.item(),
            'unexplained': unexplained.item(),
        }, step=step)

In [None]:
# Close the W&B run started by wandb.init above.
wandb.finish()