In [1]:
import yaml

from phd.feature_search.scripts.full_feature_search import *

# Initialize Hydra's global state once (re-running this cell is a no-op) so that
# `hydra.compose` can be called in the next cell. `version_base` is passed
# explicitly to silence Hydra's "version_base parameter is not specified"
# warning; '1.1' keeps the defaults Hydra was already assuming.
if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
    hydra.initialize(config_path='../conf', version_base='1.1')

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path='../conf', config_name='full_feature_search')
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path='../conf')


In [2]:
# Load hydra config: the sweep's base config plus local overrides.
# The two `$\{...\}` overrides are raw strings: in a normal literal `\{` / `\}`
# are invalid Python escape sequences (DeprecationWarning, SyntaxWarning on
# Python 3.12+). The resulting string values are identical — the backslashes
# are passed through to Hydra unchanged.
cfg = hydra.compose(
    config_name="comet_sweeps/nonlinear_geoff_ablation_v5/base_config",
    overrides=[
        "wandb=false",
        "comet_ml=false",
        "model.hidden_dim=5120",
        r"optimizer.learning_rate=$\{eval:0.003 / ${model.hidden_dim} ** 0.75\}",
        "train.log_freq=200",
        r"task.flip_rate=$\{eval:2**-16\}",
        "seed=20250902",
    ],
)

# Print the fully resolved config so the run's settings are recorded in the notebook.
yaml_cfg = omegaconf.OmegaConf.to_container(cfg, resolve=True)
print(yaml.dump(yaml_cfg, indent=2, width=80))

comet_ml: false
comet_ml_workspace: phd-research
device: cpu
feature_recycling:
  feature_protection_steps: 0
  initial_step_size_method: constant
  recycle_rate: 0.005
  use_cbp_utility: true
  use_signed_utility: false
  utility_decay: 0.99
input_recycling:
  distractor_chance: 0.0
  feature_protection_steps: 100
  n_start_real_features: -1
  recycle_rate: 0.0
  use_cbp_utility: false
  utility_decay: 0.99
model:
  activation: ltu
  hidden_dim: 5120
  n_frozen_layers: 1
  n_layers: 2
  output_dim: 1
  use_bias: true
  weight_init_method: binary
optimizer:
  autostep: true
  learning_rate: 4.956427797377644e-06
  meta_learning_rate: 0.005
  name: idbd
  step_size_decay: 0.0
  version: squared_grads
  weight_decay: 0
project: feature-search
representation_optimizer:
  learning_rate: 0.001
  name: null
  weight_decay: 0
seed: 20250902
task:
  activation: ltu
  distractor_chance: 0.0
  distractor_mean_range:
  - -0.5
  - 0.5
  distractor_std_range:
  - 0.1
  - 1.0
  flip_rate: 1.52587890

In [3]:
"""Run the feature recycling experiment."""
# Only the 2-layer architecture is handled by this notebook.
assert cfg.model.n_layers == 2, "Only 2-layer models are supported!"

cfg = init_experiment(cfg.project, cfg)

# Build every component of the LTU experiment from the config.
(
    task,
    task_iterator,
    model,
    criterion,
    optimizer,
    repr_optimizer,
    recycler,
    cbp_tracker,
) = prepare_ltu_geoff_experiment(cfg)

# Rebind `forward` on this model instance to the distractor-aware forward pass.
model.forward = model_distractor_forward_pass.__get__(model)

distractor_tracker = DistractorTracker(
    model,
    cfg.task.distractor_chance,
    tuple(cfg.task.distractor_mean_range),
    tuple(cfg.task.distractor_std_range),
    seed=seed_from_string(cfg.seed, 'distractor_tracker'),
)

# NOTE: the library's `run_experiment(...)` is intentionally not called here;
# the training loop is reimplemented in the cells below so it can use a
# replay buffer.


In [4]:
class RingBuffer:
    """Fixed-capacity circular buffer that overwrites its oldest slot once full.

    Unfilled slots hold ``None``; ``index`` is the slot the next ``append`` writes to.
    """

    def __init__(self, size):
        self.size = size
        self.buffer = [None for _ in range(size)]
        self.index = 0

    def append(self, item):
        """Write ``item`` into the current slot and advance the cursor, wrapping around."""
        slot = self.index
        self.buffer[slot] = item
        self.index = (slot + 1) % self.size

    def get_buffer(self):
        """Return the underlying slot list, including ``None`` for unfilled slots."""
        return self.buffer

    def sample(self, n):
        """Return up to ``n`` distinct stored items drawn with ``random.sample``.

        Returns an empty list when nothing is stored or ``n`` is 0.
        """
        filled = list(filter(lambda entry: entry is not None, self.buffer))
        k = min(n, len(filled))
        if k == 0:
            return []
        return random.sample(filled, k)

In [5]:
def train_step(
    model, criterion, optimizer, repr_optimizer, use_bias,
    cumulant_stats, distractor_tracker, inputs, targets,
    effective_lr_accum, standardize_cumulants=None,
):
    """Run one forward/backward/optimizer update on a single batch.

    Args:
        model: Model called as ``model(inputs, replace_features_fn, use_bias)``,
            returning ``(outputs, param_inputs)``.
        criterion: Loss function applied as ``criterion(predictions, targets)``.
        optimizer: Main optimizer. If it is an ``IDBD`` instance, ``step`` is called
            with the outputs and the batch-averaged ``param_inputs``.
        repr_optimizer: Optional second optimizer; zeroed and stepped alongside
            ``optimizer`` when not ``None``.
        use_bias: Forwarded to the model's forward pass.
        cumulant_stats: Running target statistics; its (scalar) ``running_mean`` is
            the constant baseline prediction when targets are not standardized.
        distractor_tracker: Supplies ``replace_features`` to the forward pass.
        inputs: Batch inputs.
        targets: Batch targets.
        effective_lr_accum: Running sum; for ``IDBD`` the batch-mean effective
            step size of the first parameter group's state is added to it.
            It is returned unchanged for other optimizers.
        standardize_cumulants: Whether targets were standardized (baseline
            prediction is 0). Defaults to ``cfg.train.standardize_cumulants``
            from the notebook's global config.

    Returns:
        ``(loss, mean_pred_loss, effective_lr_accum)``. ``loss`` is the model's
        loss on this batch before the update, and ``mean_pred_loss`` is the loss
        of the constant baseline predictor on the same targets.
    """
    if standardize_cumulants is None:
        standardize_cumulants = cfg.train.standardize_cumulants

    # Forward pass
    outputs, param_inputs = model(
        inputs, distractor_tracker.replace_features, use_bias)
    loss = criterion(outputs, targets)

    with torch.no_grad():
        if standardize_cumulants:
            # Presumably standardized targets are centred on 0, so 0 is the baseline.
            baseline_pred = torch.zeros_like(targets)
        else:
            # Build the constant baseline directly with the targets' device, dtype and
            # shape. A CPU (1, 1) tensor fails against non-CPU targets and relies on
            # broadcasting; the loss value is the same either way.
            baseline_pred = torch.full_like(targets, cumulant_stats.running_mean.item())
        mean_pred_loss = criterion(baseline_pred, targets)

    # Backward pass
    optimizer.zero_grad()
    if repr_optimizer is not None:
        repr_optimizer.zero_grad()

    if isinstance(optimizer, IDBD):
        # Mean over batch dimension
        param_inputs = {k: v.mean(dim=0) for k, v in param_inputs.items()}
        # The 'squared_grads' IDBD variant needs the graph again inside `step`.
        retain_graph = optimizer.version == 'squared_grads'
        loss.backward(retain_graph=retain_graph)
        stats = optimizer.step(outputs, param_inputs)
        effective_lr_accum += list(stats.values())[0]['effective_step_size'].mean().item()
    else:
        loss.backward()
        optimizer.step()

    if repr_optimizer is not None:
        repr_optimizer.step()

    return loss.item(), mean_pred_loss.item(), effective_lr_accum

In [6]:
use_bias = cfg.model.get('use_bias', True)

# Distractor setup: every hidden feature feeding the output layer may become a
# distractor, except index 0 when it is the bias feature.
n_hidden_units = model.layers[-1].in_features
first_feature_idx = 1 if use_bias else 0  # First feature is bias if enabled
distractor_tracker.process_new_features(list(range(first_feature_idx, n_hidden_units)))

# Training loop state
step = 0
prev_pruned_idxs = set()
prune_layer = model.layers[-2]

# Optional logging flags
log_utility_stats = cfg.train.get('log_utility_stats', False)
log_pruning_stats = cfg.train.get('log_pruning_stats', False)
log_model_stats = cfg.train.get('log_model_stats', False)
log_optimizer_stats = cfg.train.get('log_optimizer_stats', False)

# Initialize accumulators
cumulant_stats = StandardizationStats(gamma=0.99)
# Extended precision for a sum over many steps. `np.float128` only exists where
# long double is wider than 64 bits (not on Windows or macOS/arm64);
# `np.longdouble` is the portable name for the same type.
cumulative_loss = np.longdouble(0.0)
loss_accum = 0.0
mean_pred_loss_accum = 0.0
effective_lr_accum = 0.0
pruned_accum = 0
pruned_newest_feature_accum = 0
n_steps_since_log = 0
total_pruned = 0
prune_thresholds = []
target_buffer = []

# Replay settings
replay_buffer = RingBuffer(size=164)  # number of past batches retained for replay
# Past batches replayed per step; when replay is disabled, the number of repeated
# updates on the current batch instead.
n_replay_steps = 16
use_replay_buffer = True

In [7]:
# Main training loop. Each iteration draws one batch from the task stream,
# optionally prunes/recycles hidden features, performs one or more optimizer
# updates, and logs metrics every `cfg.train.log_freq` steps.
while step < cfg.train.total_steps:

    # Generate batch of data
    inputs, targets = next(task_iterator)

    # Add noise to targets
    if cfg.task.noise_std > 0:
        targets += torch.randn_like(targets) * cfg.task.noise_std

    # The running target statistics are always updated; the standardized targets
    # are only used for training when `standardize_cumulants` is enabled.
    with torch.no_grad():
        standardized_targets, cumulant_stats = standardize_targets(targets, cumulant_stats)

    if cfg.train.standardize_cumulants:
        targets = standardized_targets
    target_buffer.extend(targets.view(-1).tolist())

    features, targets = inputs.to(cfg.device), targets.to(cfg.device)

    # Reset weights and optimizer states for recycled features
    if cbp_tracker is not None:
        if log_pruning_stats:
            # Snapshot utilities before pruning so the pruned units' utilities can be logged.
            pre_prune_utilities = cbp_tracker.get_statistics(prune_layer)['utility']

        if isinstance(cbp_tracker, SignedCBPTracker):
            pruned_idxs = cbp_tracker.prune_features(targets)
        else:
            pruned_idxs = cbp_tracker.prune_features()
        n_pruned = sum(len(idxs) for idxs in pruned_idxs.values())
        total_pruned += n_pruned

        if prune_layer in pruned_idxs and len(pruned_idxs[prune_layer]) > 0:
            new_feature_idxs = pruned_idxs[prune_layer].tolist()
            distractor_process_idxs = new_feature_idxs

            # Don't turn bias into a distractor
            if use_bias:
                distractor_process_idxs = [idx for idx in distractor_process_idxs if idx != 0]

            # Turn some features into distractors
            distractor_tracker.process_new_features(distractor_process_idxs)

            # Log pruning statistics: count how many of the units pruned now were
            # also pruned at the previous pruning event.
            pruned_accum += len(new_feature_idxs)
            n_new_pruned_features = len(set(new_feature_idxs).intersection(prev_pruned_idxs))
            pruned_newest_feature_accum += n_new_pruned_features
            prev_pruned_idxs = set(new_feature_idxs)

            if log_pruning_stats:
                prune_thresholds.append(pre_prune_utilities[new_feature_idxs].max().item())

    # Train step. `train_step` adds each update's effective step size to the
    # accumulator it is given, and one loop iteration runs several updates.
    # Average over this iteration's updates before adding to `effective_lr_accum`.
    # Otherwise the logged `effective_lr` (divided by loop steps) is inflated by
    # the number of updates per step.
    step_effective_lr = 0.0
    if use_replay_buffer:
        # Replay up to `n_replay_steps` stored batches, with the current batch
        # last, so the reported loss is the current batch's pre-update loss.
        batches = replay_buffer.sample(n_replay_steps)
        batches.append((features, targets))
        for batch_features, batch_targets in batches:
            loss, mean_pred_loss, step_effective_lr = train_step(
                model, criterion, optimizer, repr_optimizer, use_bias, cumulant_stats,
                distractor_tracker, batch_features, batch_targets, step_effective_lr)
        n_updates = len(batches)
        replay_buffer.append((features.clone(), targets.clone()))

    else:
        # Repeat the update on the current batch; report the loss from the first
        # (pre-update) pass.
        losses = []
        for _ in range(n_replay_steps):
            loss, mean_pred_loss, step_effective_lr = train_step(
                model, criterion, optimizer, repr_optimizer, use_bias, cumulant_stats,
                distractor_tracker, features, targets, step_effective_lr)
            losses.append(loss)
        loss = losses[0]
        n_updates = n_replay_steps

    effective_lr_accum += step_effective_lr / n_updates

    # Accumulate metrics
    loss_accum += loss
    cumulative_loss += loss
    mean_pred_loss_accum += mean_pred_loss
    n_steps_since_log += 1

    # Log metrics
    if step % cfg.train.log_freq == 0:
        n_distractors = distractor_tracker.distractor_mask.sum().item()
        n_real_features = distractor_tracker.distractor_mask.numel() - n_distractors
        metrics = {
            'step': step,
            'samples': step * cfg.train.batch_size,
            'loss': loss_accum / n_steps_since_log,
            'cumulative_loss': float(cumulative_loss),
            'mean_prediction_loss': mean_pred_loss_accum / n_steps_since_log,
            'squared_targets': torch.tensor(target_buffer).square().mean().item(),
            'n_distractors': n_distractors,
            'n_real_features': n_real_features,
        }

        if log_pruning_stats:
            if pruned_accum > 0:
                metrics['fraction_pruned_were_new'] = pruned_newest_feature_accum / pruned_accum
                pruned_newest_feature_accum = 0
                pruned_accum = 0
            metrics['units_pruned'] = total_pruned
            if len(prune_thresholds) > 0:
                metrics['prune_threshold'] = np.mean(prune_thresholds)
            prune_thresholds.clear()

        if log_utility_stats:
            all_utilities = cbp_tracker.get_statistics(prune_layer)['utility']
            distractor_mask = distractor_tracker.distractor_mask
            real_utilities = all_utilities[~distractor_mask]
            distractor_utilities = all_utilities[distractor_mask]

            cumulative_utility = all_utilities.sum().item()
            metrics['cumulative_utility'] = cumulative_utility

            if len(real_utilities) > 0:
                metrics['real_utility_median'] = real_utilities.median().item()
                metrics['real_utility_25th'] = real_utilities.quantile(0.25).item()
                metrics['real_utility_75th'] = real_utilities.quantile(0.75).item()

            if len(distractor_utilities) > 0:
                metrics['distractor_utility_median'] = distractor_utilities.median().item()
                metrics['distractor_utility_25th'] = distractor_utilities.quantile(0.25).item()
                metrics['distractor_utility_75th'] = distractor_utilities.quantile(0.75).item()

        if log_optimizer_stats and isinstance(optimizer, IDBD):
            states = list(optimizer.state.values())
            assert len(states) == 1, "There should not be more than one optimizer state!"
            state = states[0]
            step_sizes = torch.exp(state['beta'])
            metrics['mean_step_size'] = step_sizes.mean().item()
            metrics['median_step_size'] = step_sizes.median().item()
            metrics['effective_lr'] = effective_lr_accum / n_steps_since_log
        effective_lr_accum = 0.0

        log_metrics(metrics, cfg, step=step)

        print(f'step: {step} | loss: {metrics["loss"]:.4f}')

        # Reset accumulators
        loss_accum = 0.0
        mean_pred_loss_accum = 0.0
        n_steps_since_log = 0
        target_buffer = []

    step += 1

step: 0 | loss: 1.0806
step: 200 | loss: 1.1469
step: 400 | loss: 0.3765
step: 600 | loss: 0.2714
step: 800 | loss: 0.2876
step: 1000 | loss: 0.3304
step: 1200 | loss: 0.2591
step: 1400 | loss: 0.3240
step: 1600 | loss: 0.2661
step: 1800 | loss: 0.2862
step: 2000 | loss: 0.3267
step: 2200 | loss: 0.2830
step: 2400 | loss: 0.3011
step: 2600 | loss: 0.3120
step: 2800 | loss: 0.3277
step: 3000 | loss: 0.3170
step: 3200 | loss: 0.3102
step: 3400 | loss: 0.2948
step: 3600 | loss: 0.2911
step: 3800 | loss: 0.3302
step: 4000 | loss: 0.2798


KeyboardInterrupt: 

In [None]:
# Finalize the experiment run. This presumably closes the logging run started by
# `init_experiment` — TODO confirm.
finish_experiment(cfg)

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f44b941b490>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7f44b95e1cc0, raw_cell="finish_experiment(cfg)" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu-24.04/home/ejmejm/local_projects/phd_research/phd/feature_search/notebooks/feature_search_with_replay.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f44b941b490>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f44b95e1390, execution_count=9 error_before_exec=None error_in_exec=[Errno 32] Broken pipe info=<ExecutionInfo object at 7f44b95e1cc0, raw_cell="finish_experiment(cfg)" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu-24.04/home/ejmejm/local_projects/phd_research/phd/feature_search/notebooks/feature_search_with_replay.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe