In [9]:
import equinox as eqx
import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig, OmegaConf

from phd.feature_search.scripts.jax_full_feature_search import train_step, train_multi_step, run_experiment, TrainState
from phd.feature_search.jax_core.experiment_helpers import rng_from_string, seed_from_string
from phd.feature_search.jax_core.models import MLP
from phd.feature_search.jax_core.tasks.geoff import NonlinearGEOFFTask

# Initialise Hydra's global state at most once per kernel, so re-running this
# cell does not attempt a second initialisation.
_global_hydra = hydra.core.global_hydra.GlobalHydra()
if not _global_hydra.is_initialized():
    hydra.initialize(config_path="../conf")

In [2]:
# Root seed and PRNG key for the notebook. Per-component keys/seeds are later
# derived from these by label (rng_from_string / seed_from_string).
seed = 20_250_925
rng = jax.random.key(seed)

In [12]:
# Inspect the task section of the composed config.
# NOTE(review): `cfg` is created in a later cell (hydra.compose + overrides);
# on a fresh kernel this cell raises NameError — move it below that cell.
cfg["task"]

{'name': 'static_linear_geoff', 'type': 'regression', 'n_features': 20, 'n_real_features': 10}

In [None]:
# Compose the base `nonlinear_geoff` Hydra config, then merge notebook-level
# overrides on top of it with OmegaConf.
#
# Alternative Hydra-style overrides (currently disabled) that could be passed
# as `hydra.compose(..., overrides=[...])`:
#     "seed=200",
#     "model.hidden_dim=20_000",
#     "task.distractor_chance=0.95", # 0.9
#     "task.noise_std=0.0",
#     "feature_recycling.recycle_rate=0.005",
#     "train.total_steps=75_000",
#     "train.standardize_cumulants=true",
#     "optimizer.learning_rate=$\{eval:0.03 / ${model.hidden_dim} ** 0.75\}",
cfg = hydra.compose(config_name='nonlinear_geoff')

overrides = DictConfig({
    'task': {},  # no task-level overrides at the moment
    'train': {'standardize_cumulants': True},
    'model': {'use_bias': True},
    'optimizer': {'name': 'rmsprop', 'learning_rate': 0.001},
})

cfg = OmegaConf.merge(cfg, overrides)



In [None]:
# Standalone settings describing the nonlinear GEOFF task, model and optimizer.
# NOTE(review): no visible cell reads `config`, and a later cell (In [5])
# rebinds the same name to a smaller dict without these task settings.
# Also, 'sparsity' here (0.99) differs from the sparsity=0.01 passed to
# NonlinearGEOFFTask in the task-creation cell — confirm which is intended.
config = DictConfig({
    'task': {
        'n_features': 128,
        'flip_rate': 0.0,
        'n_layers': 4,
        'n_stationary_layers': 4,
        'hidden_dim': 128,
        'activation': 'ltu',
        'sparsity': 0.99,
        'weight_init': 'binary',
        'noise_std': 0.0,
    },
    'train': {'standardize_cumulants': True},
    'model': {'use_bias': True},
    'optimizer': {'name': 'rmsprop', 'learning_rate': 0.001},
})

In [None]:
### Create task ###

# Nonlinear GEOFF target network over 128 input features; its seed is derived
# from the notebook seed under the 'task' label.
task = NonlinearGEOFFTask(
    n_features=128,
    flip_rate=0.0,
    n_layers=4,
    n_stationary_layers=4,
    hidden_dim=128,
    activation='ltu',
    sparsity=0.01,  # NOTE(review): the config cell uses 'sparsity': 0.99 — confirm which is intended
    weight_init='binary',
    seed=seed_from_string(seed, 'task'),
)

# Re-draw the target's last weight matrix uniformly from [-b, b], where
# b = sqrt(6 / shape[0]).
# NOTE(review): this assigns into `task.weights` in place, so it assumes
# `weights` is a mutable list — confirm against NonlinearGEOFFTask.
task_init_key = rng_from_string(rng, 'task_init_key')
_out_w_shape = task.weights[-1].shape
_out_w_bound = jnp.sqrt(6 / _out_w_shape[0])
task.weights[-1] = jax.random.uniform(
    task_init_key,
    _out_w_shape,
    minval=-_out_w_bound,
    maxval=_out_w_bound,
)


### Create model ###

# Learner MLP: 128 inputs (matching the task's n_features) -> 1 output.
model = MLP(
    input_dim=128,
    output_dim=1,
    n_layers=3,
    hidden_dim=128,
    weight_init_method='lecun_uniform',  # Input layer only
    activation='ltu',
    n_frozen_layers=0,
    key=rng_from_string(rng, 'model'),
)

# Start the output layer from all-zero weights. eqx.tree_at returns a new
# pytree with the selected leaf replaced rather than mutating `model`.
_zero_output_weight = jnp.zeros_like(model.layers[-1].weight)
model = eqx.tree_at(
    lambda mdl: mdl.layers[-1].weight,
    model,
    replace=_zero_output_weight,
)


### Create optimizer ###

# NOTE(review): Equinox does not provide an optimizer module (it is normally
# paired with Optax), so `eqx.nn.optimizers.Adam` will most likely raise
# AttributeError when this cell runs. The config cells also select 'rmsprop'
# with learning_rate 0.001, whereas this builds Adam — replace with the
# optimizer object TrainState actually expects (e.g. built from
# cfg.optimizer) before running.
optimizer = eqx.nn.optimizers.Adam(
    model,
    learning_rate=0.001,
)


### Prepare train state ###

# Bundle the model, optimizers, trackers, config, loss and PRNG key into a
# single TrainState for the training step functions.
# NOTE(review): `repr_optimizer`, `cbp_tracker`, `distractor_tracker` and
# `criterion` are not defined in any visible cell, so this raises NameError on
# a fresh kernel — define them (or whatever placeholders TrainState accepts)
# first. `cfg` comes from the Hydra-config cell.
train_state = TrainState(
    model = model,
    optimizer = optimizer,
    repr_optimizer = repr_optimizer,
    cbp_tracker = cbp_tracker,
    distractor_tracker = distractor_tracker,
    cfg = cfg,
    criterion = criterion,
    rng = rng,
)


In [5]:
# Minimal settings: task.noise_std=0.0, standardize_cumulants on, biased model,
# RMSProp with learning_rate 0.001.
# NOTE(review): this rebinds `config`, discarding the fuller task settings
# defined by the earlier config cell if that cell had been run.
config = DictConfig({
    'task': {'noise_std': 0.0},
    'train': {'standardize_cumulants': True},
    'model': {'use_bias': True},
    'optimizer': {'name': 'rmsprop', 'learning_rate': 0.001},
})

In [238]:
# Draw a single example. generate_batch returns the updated task together with
# the (inputs, targets) pair, so the task is rebound to the returned one.
task, batch = task.generate_batch(1)
x, y = batch
(y, x)

(Array([[-0.10118932]], dtype=float32),
 Array([[ 0.36812264,  0.14396961,  0.23775421, -2.1664    ,  1.4979659 ,
         -1.2783132 , -1.1702579 , -0.4902626 , -1.0994087 ,  0.4753345 ,
         -0.02034544,  0.18264541, -0.20153946,  0.48281607, -1.4713173 ,
         -0.02804903,  0.4123233 , -0.34940764,  1.110215  ,  0.6521102 ,
          0.15252645,  0.37967673,  1.518812  ,  0.04493431,  1.7105712 ,
         -0.5803446 ,  0.6708027 ,  3.1807697 ,  0.2173414 ,  0.85257137,
          1.3861659 ,  0.47228897, -0.05502069,  0.22800943,  0.04969765,
         -0.30049673,  0.3231354 , -0.77608275, -0.41071197,  0.85540485,
          0.25477242, -0.9231067 , -1.1086105 ,  1.1790812 ,  0.7669939 ,
         -0.26520517, -1.1834105 , -0.13719901, -1.9254636 ,  0.61254644,
         -0.09367189,  2.0645673 ,  1.1909419 , -0.2551069 ,  0.76494515,
         -0.4698652 , -0.6349504 , -0.10417689,  0.16939937,  0.3544758 ,
         -0.21521348, -0.7152432 ,  1.2963883 ,  1.3865126 ,  0.36697778