In [1]:
import jax
import jax.numpy as jnp
import optax

# from nets import datasets
from datasets.nonlinear_gp import NonlinearGPDataset
from nets.launch import configs
from nets.models.feedforward import MLP
from models.feedforward import SimpleNet
from nets import samplers
# from nets.simulators.online_sgd import train_step, eval_step, evaluate, simulate
from experiments.online_sgd import simulate
# from nets.experiments.analyzable_online_sgd.launcher_local import SearchConfig

In [4]:
from dataclasses import dataclass
from dataclasses import field
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam

@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
  """Generic config for a hyperparameter search."""

  seed: Param = field(default_factory=lambda: FixedParam(0))

  # Model params.
  num_ins: Param = field(default_factory=lambda: FixedParam(10))
  num_hiddens: Param = field(default_factory=lambda: FixedParam(8))
  init_scale: Param = field(default_factory=lambda: FixedParam(1.0))

  # Training and evaluation params.
  optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.sgd))
  learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
  batch_size: Param = field(default_factory=lambda: FixedParam(32))
  num_epochs: Param = field(default_factory=lambda: FixedParam(5))

  # Dataset params.
  dataset_cls: Param = field(default_factory=lambda: FixedParam(NonlinearGPDataset))
  xi1: Param = field(default_factory=lambda: FixedParam(0.1))
  xi2: Param = field(default_factory=lambda: FixedParam(1.1))
  gain: Param = field(default_factory=lambda: FixedParam(1.0))
  num_dimensions: Param = field(default_factory=lambda: FixedParam(2))
  sampler_cls: Param = field(default_factory=lambda: FixedParam(samplers.EpochSampler))
  
#   dataset_cls: Param = field(init=False)
#   num_dimensions: Param = field(init=False)
#   num_exemplars_per_class: Param = field(init=False)
#   exemplar_noise_scale: Param = field(init=False)

  # Sampler params.
#   sampler_cls: Param = field(init=False)

In [5]:
search_config = SearchConfig(key=jax.random.PRNGKey(0), num_configs=1, num_hiddens=FixedParam(4), init_scale=FixedParam(0.1))
params = search_config.__getitem__(0)
simulate(**params)

Using JAX backend: cpu

Using configuration:
{'batch_size': 32,
 'dataset_cls': <class 'datasets.nonlinear_gp.NonlinearGPDataset'>,
 'gain': 1.0,
 'init_scale': 0.1,
 'learning_rate': 0.001,
 'num_dimensions': 2,
 'num_epochs': 5,
 'num_hiddens': 4,
 'num_ins': 10,
 'optimizer_fn': <function sgd at 0x145adb9c0>,
 'sampler_cls': <class 'nets.samplers.base.EpochSampler'>,
 'seed': 0,
 'xi1': 0.1,
 'xi2': 1.1}

Model:
SimpleNet(
  fc1=Linear(
    weight=f32[4,10],
    bias=f32[4],
    in_features=10,
    out_features=4,
    use_bias=True
  ),
  act=<wrapped function <lambda>>
)



TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 0 + 1), got 1.

## Dataset

In [6]:
key = jax.random.PRNGKey(0)
# model = SimpleNet( #MLP(
#     in_features=2,
#     hidden_features=4,
#     out_features=1,
#     # act=jnp.tanh,
#     # drop=0.,
#     key=key,
#     init_scale=1.
# )
model = MLP(**params)

TypeError: MLP.__init__() got an unexpected keyword argument 'seed'

In [3]:
batch = jnp.ones(2)
(key,) = jax.random.split(key, 1)
model(batch, key=key)

Array([0.33544254], dtype=float32)