In [None]:
# Environment variables for this kernel process, set before the import cell runs.
# TOKENIZERS_PARALLELISM=false: presumably read by Hugging Face tokenizers to turn
# off their internal parallelism -- confirm against the tokenizers docs.
%env TOKENIZERS_PARALLELISM=false
# WANDB_SILENT=true: presumably read by wandb to suppress its console output -- confirm.
%env WANDB_SILENT=true

In [None]:
import itertools

import torch
import wandb
from tqdm.notebook import tqdm

from grad_buffer import GradBuffer, GradBufferConfig
from model import AttributionSAE, AttributionSAEConfig
import loss

In [None]:
# Run-wide settings read by the buffer, model, logging and training cells.
cfg = dict(
    # Source model and dataset, forwarded to GradBufferConfig.
    model_name='pythia-70m',
    dataset_name='wikitext',
    dataset_split='train',
    dataset_config='wikitext-103-v1',
    # SAE sizing: the model config receives n_dim and m_dim = n_dim * expansion_factor.
    n_dim=512,
    expansion_factor=32,
    # Training loop: batch size requested from the buffer and number of steps.
    batch_size=32,
    total_steps=50000,
    # Maximum number of sweep configs handed to a single train_models call.
    parallelism=48,
    # Device / dtype strings shared by the buffer and model configs.
    device='cuda',
    dtype='bfloat16',
    # Shared seed: used for torch below and forwarded to GradBufferConfig.
    seed=42,
)

# Seed torch's global RNG from the shared seed.
torch.manual_seed(cfg['seed'])

In [None]:
# Hyperparameter grid: each key maps to the list of values to try.
#   lr - peak learning rate (Adam lr and OneCycleLR max_lr in train_models)
#   λ  - weight on the activation-sparsity loss term
#   α  - weight on the gradient-sparsity loss term
#   β  - weight on the unexplained loss term
sweep_config = {
    'lr': [1e-2, 7e-3, 5.5e-3, 4e-3, 3e-3, 2e-3, 1.6e-3, 1.2e-3, 8.6e-4, 6.3e-4, 4.6e-4, 3.4e-4, 2.5e-4, 1.8e-4, 1.4e-4, 1e-4],
    'λ': [3e-2],
    'α': [0],
    'β': [1],
}

# Cartesian product of the grid, one dict per combination. The grid is derived
# from sweep_config's own keys, so adding a hyperparameter needs no change here.
# itertools.product varies the last key fastest, so the first key ('lr') is the
# slowest-varying one, matching a nested loop written in key order.
sweep_configs = [
    dict(zip(sweep_config, values))
    for values in itertools.product(*sweep_config.values())
]

In [None]:
# Settings for the GradBuffer that supplies training batches.
buffer_config = GradBufferConfig(
    # Model and layers: indices 0..5.
    model_name=cfg['model_name'],
    layers=[*range(6)],
    # Dataset selection, taken from the shared cfg.
    dataset_name=cfg['dataset_name'],
    dataset_config=cfg['dataset_config'],
    dataset_split=cfg['dataset_split'],
    max_seq_length=512,
    # Buffer sizing: 65536 entries kept on the CPU; min_capacity is 49152,
    # i.e. three quarters of buffer_size.
    # NOTE(review): presumably the fill level below which the buffer refills -- confirm in grad_buffer.
    buffer_size=1 << 16,
    min_capacity=3 << 14,
    buffer_device='cpu',
    # Device / dtype / seed shared with the rest of the run.
    device=cfg['device'],
    dtype=cfg['dtype'],
    seed=cfg['seed'],
)

In [None]:
# Build the buffer from buffer_config; train_models pulls (x, grad) batches
# from it with buffer.next(batch_size).
buffer = GradBuffer(buffer_config)

In [None]:
# Shared configuration for every AttributionSAE built in train_models.
# m_dim is n_dim scaled by the expansion factor (512 * 32 = 16384 with the cfg values above).
model_config = AttributionSAEConfig(
    dtype=cfg['dtype'],
    device=cfg['device'],
    m_dim=cfg['expansion_factor'] * cfg['n_dim'],
    n_dim=cfg['n_dim'],
)

In [None]:
# Start a single wandb run that holds the whole sweep. The run name is read
# interactively, so this cell blocks until a name is typed in.
wandb.init(
    project='AttributionSAE Experiments',
    entity='collingray',
    group='sweeps',
    name=input("Wandb Run Name: "),
    # Run config = shared cfg plus every sweep config stored under 'sweeps.<index>'.
    config=dict(
        cfg,
        **{f'sweeps.{idx}': sweep for idx, sweep in enumerate(sweep_configs)},
    ),
)

In [None]:
def train_models(configs, buffer, offset, report_interval=10):
    """Train one AttributionSAE per sweep config, all on the same batches.

    A single batch is drawn from ``buffer`` per step and every model takes one
    optimizer step on it, so the buffer advances once per step regardless of
    how many configs are given. Reads the notebook-level ``cfg`` (batch size,
    step count, device) and ``model_config`` globals.

    Args:
        configs: Sequence of dicts with keys ``'lr'`` (Adam learning rate and
            OneCycleLR peak), ``'λ'`` (weight on activation sparsity), ``'α'``
            (weight on gradient sparsity) and ``'β'`` (weight on the
            unexplained term).
        buffer: Object whose ``next(batch_size)`` returns an ``(x, grad)``
            pair of tensors.
        offset: Integer added to a config's position in ``configs`` to form
            the wandb key its metrics are logged under.
        report_interval: Metrics are logged on steps divisible by this value.

    Returns:
        None. The models are local to this function and are not returned or
        saved here.

    NOTE(review): the wandb step passed to ``wandb.log`` restarts at 0 on every
    call. wandb documents that steps must not decrease within a run, so if this
    function is called more than once per run (more sweep configs than
    ``cfg['parallelism']``) confirm the later groups' logs are not dropped.
    """
    models = [AttributionSAE(model_config) for _ in configs]
    optimizers = [
        torch.optim.Adam(model.parameters(), lr=config['lr'])
        for model, config in zip(models, configs)
    ]
    # One-cycle schedule spanning the whole run, peaking at the config's lr.
    schedulers = [
        torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['lr'], total_steps=cfg['total_steps'])
        for optimizer, config in zip(optimizers, configs)
    ]

    for step in tqdm(range(cfg['total_steps'])):
        x, grad = buffer.next(cfg['batch_size'])
        x = x.to(cfg['device'])
        # Insert a singleton axis before the last dimension of grad.
        grad = grad.unsqueeze(-2).to(cfg['device'])

        should_report = step % report_interval == 0
        # The input variance is the same for every model, so compute it once
        # per step, and only on steps where it is actually logged.
        x_var = x.var() if should_report else None

        for i, (config, model, optimizer, scheduler) in enumerate(
            zip(configs, models, optimizers, schedulers)
        ):
            y, f = model(x)
            dictionary = model.W_d.weight

            reconstruction = loss.reconstruction(x, y)
            act_sparsity = loss.act_sparsity(f)
            grad_sparsity = loss.grad_sparsity(f, grad, dictionary)
            unexplained = loss.unexplained(x, y, grad)

            total_loss = (
                reconstruction
                + config['λ'] * act_sparsity
                + config['α'] * grad_sparsity
                + config['β'] * unexplained
            )

            total_loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if should_report:
                # Logging-only metrics: kept out of the autograd graph and
                # skipped entirely on non-reporting steps.
                with torch.no_grad():
                    # Mean number of non-zero entries of f along its last axis.
                    l0 = (f != 0).sum(-1).float().mean()
                    # Reconstruction loss relative to the batch variance.
                    fvu = reconstruction / x_var

                wandb.log(
                    {
                        str(offset + i): {
                            'loss': total_loss.item(),
                            'reconstruction': reconstruction.item(),
                            'act_sparsity': act_sparsity.item(),
                            'grad_sparsity': grad_sparsity.item(),
                            'unexplained': unexplained.item(),
                            'l0': l0.item(),
                            'fvu': fvu.item(),
                        }
                    },
                    step=step // report_interval,
                )


In [None]:
# Train the sweep in consecutive groups of at most cfg['parallelism'] configs.
# Each group's starting index is also the offset used for its wandb metric keys.
for start in range(0, len(sweep_configs), cfg['parallelism']):
    train_models(sweep_configs[start:start + cfg['parallelism']], buffer, start)

In [None]:
# Close the wandb run opened by wandb.init once every sweep group has trained.
wandb.finish()