In [50]:
import yaml

import hydra

from phd.feature_search.scripts.ensemble_feature_search import *

In [103]:
# Guard the initialisation so that re-running this cell does not try to
# initialise Hydra a second time.
if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
    hydra.initialize(config_path="../../conf")

# Overrides applied on top of the sweep's base config, in the order given.
# (In Hydra's override grammar a leading "+" adds a key that the base config
# does not already define.)
experiment_overrides = [
    # Task and run length
    "task.n_features=30",
    "task.n_real_features=30",
    "task.noise_std=0.0",
    "train.total_steps=2_000_000",
    "seed=20250812",
    "train.log_freq=500",
    # Disable external experiment trackers
    "comet_ml=false",
    "wandb=false",
    # Model / ensemble settings: a single ensemble member as wide as the hidden layer
    "model.hidden_dim=1280",
    "+model.ensemble_feature_selection_method=random",
    "+model.feature_utility_mode=mean",
    "+model.ensemble_utility_mode=objective_improvement",
    "+model.prediction_mode=mean",
    "+feature_recycling.ensemble_recycle_rate=0.0",
    "+model.ensemble_dim=1280",
    "+model.n_ensemble_members=1",
]

# Load hydra config
cfg = hydra.compose(
    config_name="comet_sweeps/nonlinear_geoff_ablation_v5/base_config.yaml",
    overrides=experiment_overrides,
)

# Debug aid (needs `omegaconf` in scope): dump the fully resolved config as YAML.
# yaml_cfg = omegaconf.OmegaConf.to_container(cfg, resolve=True)
# print(yaml.dump(yaml_cfg, indent=2, width=80))

In [104]:
# Every time a new feature is introduced, check whether it should be kept.
# If it should be kept, track how long it takes for its utility to rise into the top 10%.
# But what is utility in an ensemble? For now, I will just take it to be the mean.

In [111]:
# Build every object the manual training loop needs: task, model, optimizers
# and the distractor tracker.
#
# NOTE(review): `cfg` is rebound to the return value of `init_experiment`, so
# re-running this cell feeds the already-initialised cfg back in. Re-run the
# config cell first if a clean start is wanted.
cfg = init_experiment(cfg.project, cfg)

task, task_iterator, model, criterion, optimizer, repr_optimizer, recycler = \
    prepare_ltu_geoff_experiment(cfg)
# Bind `model_distractor_forward_pass` to this model instance so that calling
# `model(...)` / `model.forward(...)` goes through it (the training loop relies
# on its `distractor_callback` keyword).
model.forward = model_distractor_forward_pass.__get__(model)

# Tracker configured from the task's distractor settings; its seed is derived
# from the experiment seed and the fixed tag 'distractor_tracker'.
distractor_tracker = DistractorTracker(
    model,
    cfg.task.distractor_chance,
    tuple(cfg.task.distractor_mean_range),
    tuple(cfg.task.distractor_std_range),
    seed = seed_from_string(cfg.seed, 'distractor_tracker'),
)

# Gives a 1:1 hidden dim to ensemble input mapping
# Ids 0 .. n_ensemble_members * ensemble_dim - 1 are laid out row by row, so
# member i receives the contiguous id range starting at i * ensemble_dim.
# The existing tensor is only consulted for its device.
model.ensemble_input_ids = torch.arange(
        0, model.n_ensemble_members * model.ensemble_dim,
        dtype = torch.long,
        device = model.ensemble_input_ids.device,
    ).reshape(model.n_ensemble_members, model.ensemble_dim)

# Distractor setup
# Hand the index of every hidden unit to the tracker as a "new" feature
# (presumably this is where each unit's real/distractor status is decided --
# confirm against DistractorTracker.process_new_features).
n_hidden_units = model.input_layer.out_features
distractor_tracker.process_new_features(torch.arange(n_hidden_units))

# Training loop state
#
# Everything in this section is a notebook-level global that the training-loop
# cell reads and updates in place. Re-running this section resets the step
# counter, the progress bar and all accumulators.
step = 0
prev_pruned_idxs = set()  # Feature indices replaced by the most recent prune
prune_layer = model.activation
pbar = tqdm(total=cfg.train.total_steps, desc='Training')

# Flags (optional config entries; all default to off)
log_utility_stats = cfg.train.get('log_utility_stats', False)
log_pruning_stats = cfg.train.get('log_pruning_stats', False)
log_model_stats = cfg.train.get('log_model_stats', False)

# Initialize accumulators
cumulant_stats = StandardizationStats(gamma=0.99)
pruning_state = PruningState()
# Extended precision for the long-running sum of per-step losses.
# `np.longdouble` is used instead of `np.float128` because the latter name is
# only defined on platforms whose long double is wider than 64 bits (it is
# missing on e.g. Windows and Apple Silicon); where `np.float128` exists it is
# the very same type.
cumulative_loss = np.longdouble(0.0)
loss_accum = 0.0
ensemble_loss_accum = 0.0
mean_pred_loss_accum = 0.0
pruned_accum = 0
pruned_newest_feature_accum = 0
n_steps_since_log = 0
total_features_pruned = 0
total_ensembles_pruned = 0
prune_thresholds = []
target_buffer = []  # Targets seen since the last log, flattened to Python floats

# Feature-tracking results and bookkeeping
times_to_top_10_percent = [] # Tuples of feature creation step and time to top 10% (-1 if pruned first)
tracked_features = {} # Maps feature idx to the step it was introduced

Training:   2%|▏         | 44500/2000000 [04:33<3:20:17, 162.72it/s, loss=0.443]


In [112]:
# Get, for each target feature, how closely the closest learning network hidden unit matches it
# `optimal_feature_weights` is the transposed first weight matrix of the task.
# It is kept as a notebook global because the training-loop cell compares the
# weights of freshly created hidden units against it.
# (presumably one row per target feature after the transpose -- confirm)
optimal_feature_weights = task.weights[0].T

def compute_best_feature_match_counts(model, task):
    """Return, for each target feature, the best weight agreement of any hidden unit.

    Every row of ``model.input_layer.weight`` is compared element-wise (exact
    equality) with every row of ``task.weights[0].T``. A (unit, target) pair is
    scored by the larger of its number of equal positions and its number of
    unequal positions, so a unit that differs from a target everywhere scores
    as high as one that equals it everywhere.

    Args:
        model: Object exposing ``input_layer.weight`` as a 2-D tensor.
        task: Object exposing ``weights[0]`` as a 2-D tensor whose transpose
            has the same number of columns as ``input_layer.weight``.

    Returns:
        1-D integer tensor with one entry per row of ``task.weights[0].T``:
        the maximum pair score over all rows of the input layer.
    """
    target_rows = task.weights[0].T
    unit_rows = model.input_layer.weight

    # Broadcast to (n_units, n_targets, n_inputs) so every unit meets every target.
    agreement = unit_rows[:, None] == target_rows[None]

    n_equal = agreement.sum(dim=2)
    n_unequal = agreement.logical_not().sum(dim=2)
    pair_scores = torch.maximum(n_equal, n_unequal)

    # Reduce over the unit axis, keeping only the winning score per target.
    best_per_target, _ = pair_scores.max(dim=0)
    return best_per_target

# Baseline scores for the network as it is when this cell runs. The
# training-loop cell compares newly created features against this tensor and
# recomputes it after every prune.
best_feature_match_counts = compute_best_feature_match_counts(model, task)  

print(f'Best feature match counts: {best_feature_match_counts.tolist()}')

Best feature match counts: [24, 24, 25, 24, 24, 23, 24, 24, 25, 25, 23, 24, 24, 26, 24, 23, 23, 24, 26, 24]


In [113]:
# Manual training loop with feature tracking.
#
# Per step: draw a batch, optionally prune/replace hidden units, decide which
# of the replacement units are worth tracking, check whether any tracked unit
# has reached the top 10% of utilities, then run the forward/backward pass and
# log every `cfg.train.log_freq` steps.
#
# All state (step, pbar, accumulators, tracked_features, ...) lives in notebook
# globals created by the setup cell, so interrupting and re-running this cell
# continues from the current `step`.
while step < cfg.train.total_steps:

    ### Data Processing ###

    # Generate batch of data
    inputs, targets = next(task_iterator)

    # Add noise to targets
    if cfg.task.noise_std > 0:
        targets += torch.randn_like(targets) * cfg.task.noise_std

    # The running statistics are always updated, even when the standardized
    # targets themselves are not used.
    with torch.no_grad():
        standardized_targets, cumulant_stats = standardize_targets(targets, cumulant_stats)

    if cfg.train.standardize_cumulants:
        targets = standardized_targets
    target_buffer.extend(targets.view(-1).tolist())

    features, targets = inputs.to(cfg.device), targets.to(cfg.device)


    ### Pruning ###

    # Snapshot utilities before pruning so the prune threshold can be logged.
    if log_pruning_stats:
        pre_prune_feature_utilities = model.feature_utilities.ravel().cpu().clone().numpy()
        pre_prune_ensemble_utilities = model.ensemble_utilities.cpu().clone().numpy()

    # Prune the model if necessary
    prune_results = prune_model(
        model, optimizer, distractor_tracker, pruning_state, cfg)

    # Update pruning metrics
    total_features_pruned += prune_results['n_features_pruned']
    total_ensembles_pruned += prune_results['n_ensembles_pruned']

    if prune_results['n_features_pruned'] > 0:
        feature_idxs_pruned = prune_results['feature_idxs_pruned'].tolist()

        # Log pruning statistics
        pruned_accum += len(feature_idxs_pruned)
        # Indices pruned now that were also pruned by the previous prune,
        # i.e. the newest features being discarded again.
        n_new_pruned_features = len(set(feature_idxs_pruned).intersection(prev_pruned_idxs))
        pruned_newest_feature_accum += n_new_pruned_features
        prev_pruned_idxs = set(feature_idxs_pruned)

        if log_pruning_stats:
            prune_thresholds.append(pre_prune_feature_utilities[feature_idxs_pruned].max())

        # A feature pruned while still being tracked never made it to the
        # top 10%; record it with the sentinel -1 and stop tracking it.
        for feature_idx in feature_idxs_pruned:
            if feature_idx in tracked_features:
                times_to_top_10_percent.append(
                    (tracked_features.pop(feature_idx), -1))

        # Score the weights now stored at the pruned indices against every
        # target feature, using the same measure as
        # compute_best_feature_match_counts: max(#equal, #unequal) positions.
        new_feature_weights = model.input_layer.weight[feature_idxs_pruned]
        feature_matches = new_feature_weights.unsqueeze(dim=1) == optimal_feature_weights.unsqueeze(dim=0)
        positive_match_counts = feature_matches.sum(dim=2)
        negative_match_counts = (~feature_matches).sum(dim=2)
        match_counts = torch.maximum(positive_match_counts, negative_match_counts) # (n_features_pruned, n_target_features)
        feature_improvement_matrix = match_counts > best_feature_match_counts.unsqueeze(dim=0)

        # Bool for each new feature that indicates if it is an overall improvement
        new_best_feature = feature_improvement_matrix.any(dim=1) # (n_features_pruned,)

        # Start tracking every new feature that beats the previous best for
        # at least one target feature.
        for feature_idx, is_improvement in zip(feature_idxs_pruned, new_best_feature.tolist()):
            if is_improvement:
                tracked_features[feature_idx] = step

        # Recompute from the current input layer instead of max-merging with
        # the old values, so the reference reflects only weights that are
        # still present in the model.
        best_feature_match_counts = compute_best_feature_match_counts(model, task)


    ### Track time to top 10% ###

    # Only pay for the quantile and the device-to-host copy when there is
    # something to check.
    if tracked_features:
        # Flattened, matching how the pruning section indexes utilities by
        # feature idx.
        flat_feature_utilities = model.feature_utilities.ravel()
        utility_threshold = torch.quantile(flat_feature_utilities, 0.9).item()
        feature_utilities = flat_feature_utilities.cpu().clone().numpy()
        for feature_idx in list(tracked_features.keys()):
            if feature_utilities[feature_idx] > utility_threshold:
                creation_step = tracked_features.pop(feature_idx)
                times_to_top_10_percent.append(
                    (creation_step, step - creation_step))


    ### Forward Pass ###

    outputs, param_inputs, aux = model(
        features, targets, update_state=True,
        distractor_callback=distractor_tracker.replace_features,
    )
    loss = aux['loss']
    ensemble_loss_sum = aux['ensemble_losses'].sum()

    # Loss of a constant baseline prediction: zero when targets are
    # standardized, otherwise the running mean of the targets.
    with torch.no_grad():
        if cfg.train.standardize_cumulants:
            baseline_pred = torch.zeros_like(targets)
        else:
            # Keep the baseline on the same device as the targets; forcing it
            # to CPU breaks the criterion call whenever cfg.device is not CPU.
            baseline_pred = cumulant_stats.running_mean.to(targets.device).view(1, 1)
        mean_pred_loss = criterion(baseline_pred, targets)


    ### Backward Pass ###

    # Backward pass
    optimizer.zero_grad()
    if repr_optimizer is not None:
        repr_optimizer.zero_grad()

    if isinstance(optimizer, IDBD):
        # Mean over batch dimension
        param_inputs = {k: v.mean(dim=0) for k, v in param_inputs.items()}
        retain_graph = optimizer.version == 'squared_grads'
        ensemble_loss_sum.backward(retain_graph=retain_graph)
        optimizer.step(aux['ensemble_predictions'], param_inputs)
    else:
        ensemble_loss_sum.backward()
        optimizer.step()

    if repr_optimizer is not None:
        repr_optimizer.step()


    ### Metrics ###

    # Accumulate metrics
    loss_accum += loss.item()
    ensemble_loss_accum += ensemble_loss_sum.item()
    cumulative_loss += loss.item()
    mean_pred_loss_accum += mean_pred_loss.item()
    n_steps_since_log += 1


    ### Logging ###

    if step % cfg.train.log_freq == 0:
        n_distractors = distractor_tracker.distractor_mask.sum().item()
        n_real_features = distractor_tracker.distractor_mask.numel() - n_distractors
        metrics = {
            'step': step,
            'samples': step * cfg.train.batch_size,
            'loss': loss_accum / n_steps_since_log,
            'avg_ensemble_loss': ensemble_loss_accum / model.n_ensemble_members / n_steps_since_log,
            'cumulative_loss': float(cumulative_loss),
            'mean_prediction_loss': mean_pred_loss_accum / n_steps_since_log,
            'squared_targets': torch.tensor(target_buffer).square().mean().item(),
            'n_distractors': n_distractors,
            'n_real_features': n_real_features,
        }

        # Add model statistics separately for real and distractor features
        if log_model_stats:
            # First mask keeps every input of layer 0; second mask selects the
            # hidden units of the respective kind.
            real_feature_masks = [
                torch.ones(model.layers[0].weight.shape[1], dtype=torch.bool, device=model.layers[0].weight.device),
                ~distractor_tracker.distractor_mask,
            ]
            metrics.update(get_model_statistics(
                model, features, param_inputs, real_feature_masks, metric_prefix='real_'))

            distractor_feature_masks = [
                real_feature_masks[0],
                distractor_tracker.distractor_mask,
            ]
            metrics.update(get_model_statistics(
                model, features, param_inputs, distractor_feature_masks, metric_prefix='distractor_'))

        log_metrics(metrics, cfg, step=step)

        pbar.set_postfix(loss=metrics['loss'])
        # Advance by the steps actually executed since the last log; a fixed
        # `log_freq` would report 500 steps after the very first one.
        pbar.update(n_steps_since_log)

        # Reset accumulators
        loss_accum = 0.0
        mean_pred_loss_accum = 0.0
        ensemble_loss_accum = 0.0
        n_steps_since_log = 0
        target_buffer = []

    step += 1

# Account for the steps run after the last log before closing the bar.
pbar.update(n_steps_since_log)
pbar.close()

Training:   1%|          | 14500/2000000 [00:56<2:12:38, 249.47it/s, loss=0.448]

KeyboardInterrupt: 

In [None]:
# Each entry is (step at which the feature was introduced,
#                number of steps until its utility exceeded the 90th percentile).
# A second element of -1 means the feature was pruned before getting there.
times_to_top_10_percent

[(40, -1),
 (40, -1),
 (57, -1),
 (65, -1),
 (145, -1),
 (208, -1),
 (230, -1),
 (249, -1),
 (274, -1),
 (310, -1),
 (307, -1),
 (333, -1),
 (357, -1),
 (331, -1),
 (12, -1),
 (411, -1),
 (421, -1),
 (306, -1),
 (455, -1),
 (503, -1),
 (550, -1),
 (542, -1),
 (604, -1),
 (587, -1),
 (649, -1),
 (694, -1),
 (695, -1),
 (753, -1),
 (754, -1),
 (776, -1),
 (797, -1),
 (843, -1),
 (855, -1),
 (861, -1),
 (905, -1),
 (137, -1),
 (917, -1),
 (913, -1),
 (977, -1),
 (978, -1),
 (992, -1),
 (1011, -1),
 (1086, -1),
 (1106, -1),
 (1117, -1),
 (1150, -1),
 (1177, -1),
 (1210, -1),
 (1296, -1),
 (1303, -1),
 (1317, -1),
 (1406, -1),
 (1425, -1),
 (1424, -1),
 (1470, -1),
 (1487, -1),
 (1542, -1),
 (1574, -1),
 (1581, -1),
 (1597, -1),
 (1614, -1),
 (1602, -1),
 (1579, -1),
 (1614, -1),
 (1646, -1),
 (1721, -1),
 (1739, -1),
 (1745, -1),
 (1764, -1),
 (1781, -1),
 (180, -1),
 (1796, -1),
 (1821, -1),
 (1831, -1),
 (1833, -1),
 (1847, -1),
 (1877, -1),
 (1968, -1),
 (1982, -1),
 (97, 1972),
 (1984,

Training:   1%|          | 14500/2000000 [01:07<2:12:38, 249.47it/s, loss=0.448]

In [71]:
# TODO: Plot how many of the good features were discarded.
# TODO: Separately, plot the times to top 10%.

In [None]:


# Hand the objects built in the setup cell to the project's `run_experiment`
# entry point, then finalise the run with `finish_experiment`.
# NOTE(review): if the manual training-loop cell has already been executed,
# `model`, `optimizer` and `distractor_tracker` carry its state into this call
# -- confirm that continuing from that state is intended, otherwise re-run the
# setup cell first.
run_experiment(
    cfg, task, task_iterator, model, criterion, optimizer,
    repr_optimizer, distractor_tracker,
)

finish_experiment(cfg)