In [7]:
import logging
from collections import defaultdict
from typing import Callable, Iterator, List, Optional, Tuple

import hydra
from IPython.display import HTML
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from torch.optim import Optimizer
from tqdm import tqdm
import yaml

from phd.feature_search.core.experiment_helpers import (
    get_model_statistics,
    prepare_task,
    prepare_optimizer,
    seed_from_string,
    set_seed,
    standardize_targets,
    StandardizationStats,
)
from phd.feature_search.core.idbd import IDBD
from phd.feature_search.core.models import LTU, MLP
from phd.feature_search.core.feature_recycling import InputRecycler, CBPTracker
from phd.feature_search.core.tasks import NonlinearGEOFFTask
from phd.feature_search.scripts.full_feature_search import *
from phd.research_utils.logging import *

%matplotlib inline

if not hydra.core.global_hydra.GlobalHydra().is_initialized():
    hydra.initialize(config_path="../conf")

In [8]:
# Compose the `full_feature_search` config (from ../conf) with this run's overrides,
# then print the fully resolved config for reference.
cfg = hydra.compose(
    config_name = "full_feature_search",
    overrides = [
        "seed=200",
        "model.hidden_dim=20_000",
        "task.distractor_chance=0.95", # NOTE(review): old comment read "0.9" — presumably a previously used value
        "task.noise_std=0.0",
        "feature_recycling.recycle_rate=0.005",
        "train.total_steps=75_000",
        "train.standardize_cumulants=true",
        # Raw string: the backslashes are Hydra override-grammar escapes, not Python
        # escapes ("\{" is an invalid Python escape sequence and warns on 3.12+).
        # Resolves (via an `eval` resolver registered elsewhere) to
        # 0.03 / hidden_dim ** 0.75.
        r"optimizer.learning_rate=$\{eval:0.03 / ${model.hidden_dim} ** 0.75\}",
    ]
)

yaml_cfg = OmegaConf.to_container(cfg, resolve=True)
print(yaml.dump(yaml_cfg, indent=2, width=80))

comet_ml: false
comet_ml_workspace: phd-research
device: cpu
feature_recycling:
  feature_protection_steps: 0
  recycle_rate: 0.005
  use_cbp_utility: true
  utility_decay: 0.99
input_recycling:
  distractor_chance: 0.0
  feature_protection_steps: 100
  n_start_real_features: -1
  recycle_rate: 0.0
  use_cbp_utility: false
  utility_decay: 0.99
model:
  activation: ltu
  hidden_dim: 20000
  log_model_stats: false
  n_frozen_layers: 1
  n_layers: 2
  output_dim: 1
  weight_init_method: binary
optimizer:
  autostep: true
  learning_rate: 1.7838106725040817e-05
  meta_learning_rate: 0.005
  name: idbd
  version: squared_grads
  weight_decay: 0
project: feature-search
representation_optimizer:
  learning_rate: 0.001
  name: null
  weight_decay: 0
seed: 200
task:
  activation: ltu
  distractor_chance: 0.95
  distractor_mean_range:
  - -0.5
  - 0.5
  distractor_std_range:
  - 0.1
  - 1.0
  flip_rate: 0.0
  hidden_dim: 20
  n_features: 20
  n_layers: 2
  n_real_features: 20
  n_stationary_lay

In [9]:
# Build the task, model, optimizers and CBP tracker from the config, patch the
# model's forward pass to accept the distractor hook, and initialise all state
# used by the training-loop cell below.
# cfg = init_experiment(cfg.project, cfg)

task, task_iterator, model, criterion, optimizer, repr_optimizer, recycler, cbp_tracker = \
    prepare_ltu_geoff_experiment(cfg)
# Bind the distractor-aware forward pass as a method of this model instance
model.forward = model_distractor_forward_pass.__get__(model)

distractor_tracker = DistractorTracker(
    model,
    cfg.task.distractor_chance,
    tuple(cfg.task.distractor_mean_range),
    tuple(cfg.task.distractor_std_range),
    seed = seed_from_string(cfg.seed, 'distractor_tracker'),
)

# Alternative: run the packaged training loop instead of the cells below.
# run_experiment(
#     cfg, task, task_iterator, model, criterion, optimizer,
#     repr_optimizer, cbp_tracker, distractor_tracker,
# )

# finish_experiment(cfg)

# Register every hidden unit feeding the output layer with the distractor tracker
n_hidden_units = model.layers[-1].in_features
distractor_tracker.process_new_features(list(range(n_hidden_units)))

# Training-loop state. Kept global so the training cell can be interrupted and
# re-run to resume from `step`.
step = 0
prev_pruned_idxs = set()
prune_layer = model.layers[-2]  # layer whose pruned units the training loop tracks

# Initialize accumulators
cumulant_stats = StandardizationStats(gamma=0.99)
# Extended precision for the loss summed over the whole run. np.longdouble is the
# portable spelling; np.float128 does not exist on e.g. Windows or Apple-silicon numpy.
cumulative_loss = np.longdouble(0.0)
loss_accum = 0.0
mean_pred_loss_accum = 0.0
pruned_accum = 0
pruned_newest_feature_accum = 0
n_steps_since_log = 0
total_pruned = 0
target_buffer = []

# metric name -> list of values, one entry per logging step
run_metrics = defaultdict(list)

In [10]:
# Online training loop. Each iteration draws one batch from the task, optionally
# recycles hidden features via the CBP tracker (handing recycled units to the
# distractor tracker), takes one optimizer step, and every `cfg.train.log_freq`
# steps logs averaged metrics plus per-feature snapshots into `run_metrics`.
# The loop continues from the global `step`, so re-running this cell after an
# interrupt resumes training (the progress bar starts at `step` too).
with tqdm(total=cfg.train.total_steps, initial=step, desc='Training') as pbar:
    while step < cfg.train.total_steps:
        # Generate batch of data
        inputs, targets = next(task_iterator)

        # Add noise to targets
        if cfg.task.noise_std > 0:
            targets += torch.randn_like(targets) * cfg.task.noise_std

        # The running target statistics are always updated; the standardized
        # targets are only trained on when cfg.train.standardize_cumulants is set.
        with torch.no_grad():
            standardized_targets, cumulant_stats = standardize_targets(targets, cumulant_stats)

        if cfg.train.standardize_cumulants:
            targets = standardized_targets
        target_buffer.extend(targets.view(-1).tolist())

        features, targets = inputs.to(cfg.device), targets.to(cfg.device)

        # Reset weights and optimizer states for recycled features
        if cbp_tracker is not None:
            pruned_idxs = cbp_tracker.prune_features()
            n_pruned = sum(len(idxs) for idxs in pruned_idxs.values())
            total_pruned += n_pruned

            if prune_layer in pruned_idxs and len(pruned_idxs[prune_layer]) > 0:
                new_feature_idxs = pruned_idxs[prune_layer].tolist()

                # Turn some features into distractors
                distractor_tracker.process_new_features(new_feature_idxs)

                # Count how many units pruned now were also pruned in the previous
                # pruning event, i.e. were the most recently recycled features.
                pruned_accum += len(new_feature_idxs)
                n_new_pruned_features = len(set(new_feature_idxs) & prev_pruned_idxs)
                pruned_newest_feature_accum += n_new_pruned_features
                prev_pruned_idxs = set(new_feature_idxs)

        # Forward pass (the patched forward receives the tracker's replacement hook)
        outputs, param_inputs = model(
            features, distractor_tracker.replace_features)
        loss = criterion(outputs, targets)

        # Reference loss of a constant predictor, for comparison with `loss`
        with torch.no_grad():
            if cfg.train.standardize_cumulants:
                baseline_pred = torch.zeros_like(targets)
            else:
                # Match the targets' device/dtype and shape so the loss neither fails
                # on a non-CPU device nor depends on implicit broadcasting.
                target_running_mean = cumulant_stats.running_mean.to(
                    device=targets.device, dtype=targets.dtype)
                baseline_pred = torch.zeros_like(targets) + target_running_mean.reshape(-1)
            mean_pred_loss = criterion(baseline_pred, targets)

        # Backward pass
        optimizer.zero_grad()
        if repr_optimizer is not None:
            repr_optimizer.zero_grad()

        if isinstance(optimizer, IDBD):
            # Mean over batch dimension
            param_inputs = {k: v.mean(dim=0) for k, v in param_inputs.items()}
            # The 'squared_grads' IDBD variant needs the graph again inside step()
            retain_graph = optimizer.version == 'squared_grads'
            loss.backward(retain_graph=retain_graph)
            optimizer.step(outputs, param_inputs)
        else:
            loss.backward()
            optimizer.step()

        if repr_optimizer is not None:
            repr_optimizer.step()

        # Accumulate metrics
        loss_value = loss.item()
        loss_accum += loss_value
        cumulative_loss += loss_value
        mean_pred_loss_accum += mean_pred_loss.item()
        n_steps_since_log += 1

        # Log metrics
        if step % cfg.train.log_freq == 0:
            n_distractors = distractor_tracker.distractor_mask.sum().item()
            n_real_features = distractor_tracker.distractor_mask.numel() - n_distractors
            metrics = {
                'step': step,
                'samples': step * cfg.train.batch_size,
                'loss': loss_accum / n_steps_since_log,
                'cumulative_loss': float(cumulative_loss),
                'mean_prediction_loss': mean_pred_loss_accum / n_steps_since_log,
                'squared_targets': torch.tensor(target_buffer).square().mean().item(),
                'units_pruned': total_pruned,
                'n_distractors': n_distractors,
                'n_real_features': n_real_features,
            }

            if pruned_accum > 0:
                metrics['fraction_pruned_were_new'] = pruned_newest_feature_accum / pruned_accum
                pruned_newest_feature_accum = 0
                pruned_accum = 0

            # Add model statistics separately for real and distractor features
            if cfg.model.get('log_model_stats', False):
                real_feature_masks = [
                    torch.ones(model.layers[0].weight.shape[1], dtype=torch.bool, device=model.layers[0].weight.device),
                    ~distractor_tracker.distractor_mask,
                ]
                metrics.update(get_model_statistics(
                    model, features, param_inputs, real_feature_masks, metric_prefix='real_'))

                distractor_feature_masks = [
                    real_feature_masks[0],
                    distractor_tracker.distractor_mask,
                ]
                metrics.update(get_model_statistics(
                    model, features, param_inputs, distractor_feature_masks, metric_prefix='distractor_'))

            log_metrics(metrics, cfg, step=step)
            for metric_name, value in metrics.items():
                run_metrics[metric_name].append(value)

            # Per-feature snapshots as CPU numpy copies, for later visualization
            run_metrics['weights'].append(model.layers[-1].weight.detach().clone().cpu().squeeze(0).numpy())
            if cbp_tracker is not None:
                run_metrics['utilities'].append(
                    cbp_tracker.get_statistics(prune_layer)['utility'].detach().clone().cpu().numpy())
            run_metrics['distractor_masks'].append(distractor_tracker.distractor_mask.detach().clone().cpu().numpy())

            # Advance the bar to the number of completed steps (step + 1); a fixed
            # `log_freq` increment over-counts at step 0.
            pbar.set_postfix(loss=metrics['loss'])
            pbar.update(step + 1 - pbar.n)

            # Reset accumulators
            loss_accum = 0.0
            mean_pred_loss_accum = 0.0
            n_steps_since_log = 0
            target_buffer = []

        step += 1

    # Count the steps taken after the last logging point
    pbar.update(step - pbar.n)

Training:  53%|█████▎    | 40000/75000 [04:10<03:39, 159.69it/s, loss=0.345]


KeyboardInterrupt: 

In [11]:
# Zero the tracker's distractor means and standard deviations in place.
# NOTE(review): presumably this makes distractor features emit a constant 0 in later
# forward passes — confirm against DistractorTracker. This mutates shared state, so
# re-running the training cell afterwards no longer uses the configured distractors.
for distractor_param in (distractor_tracker.distractor_means, distractor_tracker.distractor_stds):
    distractor_param[:] = 0

In [5]:
# Animate how a per-feature statistic logged in `run_metrics` evolves over training
# for a random subset of hidden features (red bars = distractors, blue = real features).
METRIC_KEY = 'utilities'  # per-feature snapshot to plot: 'utilities' or 'weights'
METRIC_LABELS = {'utilities': 'Utility', 'weights': 'Weight'}
metric_label = METRIC_LABELS.get(METRIC_KEY, METRIC_KEY)

n_steps = 100     # max number of logged snapshots (animation frames) to show
n_features = 100  # number of hidden features to sample for display

# Stack per-log snapshots into (n_snapshots, n_total_features) arrays.
# `weights` holds whichever statistic METRIC_KEY selects.
weights = np.stack(run_metrics[METRIC_KEY])[:n_steps]
distractor_masks = np.stack(run_metrics['distractor_masks'])[:n_steps]

# Sample from *all* hidden features; choice(n_features, size=n_features) would only
# permute features 0..n_features-1. A local RNG avoids reseeding numpy globally.
n_total_features = weights.shape[1]
n_features = min(n_features, n_total_features)
rng = np.random.default_rng(13)
sampled_indices = rng.choice(n_total_features, size=n_features, replace=False)
sampled_weights = weights[:, sampled_indices]
sampled_distractor_masks = distractor_masks[:, sampled_indices]

# Fixed y-limits across frames so bar heights are comparable over time
min_weight = sampled_weights.min()
max_weight = sampled_weights.max()
weight_range = max_weight - min_weight
y_margin = 0.05 * weight_range if weight_range > 0 else 1.0  # avoid identical ylim bounds

# Create animated bar plot
fig, ax = plt.subplots(figsize=(12, 6))
legend_handles = [
    Patch(color='blue', label='Real feature'),
    Patch(color='red', label='Distractor'),
]

def update(frame: int) -> None:
    """Redraw the bar chart for logged snapshot index ``frame``."""
    ax.clear()
    # Create bars with different colors based on distractor mask
    colors = ['red' if is_distractor else 'blue' for is_distractor in sampled_distractor_masks[frame]]
    ax.bar(range(n_features), sampled_weights[frame], color=colors)
    ax.set_ylim(min_weight - y_margin, max_weight + y_margin)
    # run_metrics['step'] is appended in the same logging block as the snapshots,
    # so it gives the training step at which this frame was recorded
    ax.set_title(f'Feature {metric_label} Over Time (Step {run_metrics["step"][frame]})')
    ax.set_xlabel('Sampled Feature')
    ax.set_ylabel(f'{metric_label} Value')
    ax.set_xticks([])  # sampled indices are arbitrary, so hide tick labels
    ax.legend(handles=legend_handles, loc='upper right')

# Create animation
anim = animation.FuncAnimation(
    fig,
    update,
    frames = len(sampled_weights),
    interval = 100, # ms between frames
    repeat = False,
)

# Close the static figure so only the animation is rendered
plt.close(fig)
HTML(anim.to_jshtml())