In [1]:
import jax
import jax.numpy as jnp
import optax

# from nets import datasets
from datasets.nonlinear_gp import NonlinearGPDataset
from nets.launch import configs
from nets.models.feedforward import MLP
from models.feedforward import SimpleNet
from nets import samplers
# from nets.simulators.online_sgd import train_step, eval_step, evaluate, simulate
from experiments.online_sgd import simulate
# from nets.experiments.analyzable_online_sgd.launcher_local import SearchConfig

In [2]:
from dataclasses import dataclass
from dataclasses import field
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam

@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
  """Generic config for a hyperparameter search."""

  seed: Param = field(default_factory=lambda: FixedParam(0))

  # Model params.
  # num_ins: Param = field(default_factory=lambda: FixedParam(4))
  num_hiddens: Param = field(default_factory=lambda: FixedParam(100))
  init_scale: Param = field(default_factory=lambda: FixedParam(0.01))
  num_dimensions: Param = field(default_factory=lambda: FixedParam(40))

  # Training and evaluation params.
  optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
  learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
  batch_size: Param = field(default_factory=lambda: FixedParam(100))
  num_epochs: Param = field(default_factory=lambda: FixedParam(1))

  # Dataset params.
  dataset_cls: Param = field(default_factory=lambda: FixedParam(NonlinearGPDataset))
  xi1: Param = field(default_factory=lambda: FixedParam(0.1))
  xi2: Param = field(default_factory=lambda: FixedParam(1.1))
  gain: Param = field(default_factory=lambda: FixedParam(1.0))
  # num_dimensions: Param = field(default_factory=lambda: FixedParam(4))
  sampler_cls: Param = field(default_factory=lambda: FixedParam(samplers.EpochSampler))
  
#   dataset_cls: Param = field(init=False)
#   num_dimensions: Param = field(init=False)
#   num_exemplars_per_class: Param = field(init=False)
#   exemplar_noise_scale: Param = field(init=False)

  # Sampler params.
#   sampler_cls: Param = field(init=False)

In [3]:
search_config = SearchConfig(key=jax.random.PRNGKey(0), num_configs=1)
params = search_config[0]
simulate(**params)

Using JAX backend: cpu

Using configuration:
{'batch_size': 100,
 'dataset_cls': <class 'datasets.nonlinear_gp.NonlinearGPDataset'>,
 'gain': 1.0,
 'init_scale': 0.01,
 'learning_rate': 0.001,
 'num_dimensions': 40,
 'num_epochs': 1,
 'num_hiddens': 100,
 'optimizer_fn': <function adam at 0x13b186fc0>,
 'sampler_cls': <class 'nets.samplers.base.EpochSampler'>,
 'seed': 0,
 'xi1': 0.1,
 'xi2': 1.1}

simulate: len(dataset)=1000
Model:
SimpleNet(
  fc1=Linear(
    weight=f32[100,40],
    bias=f32[100],
    in_features=40,
    out_features=100,
    use_bias=True
  ),
  act=<wrapped function <lambda>>
)

Starting evaluation...


10it [00:01,  7.55it/s]


Completed evaluation over 1000 examples in 1.35 secs.
####
ITERATION 0
eval set:

	loss:			1.0002896785736084
	accuracy:		0.0
	BASELINE:		50.00%
	GT labels:		[ 1. -1. -1.  1. -1.  1. -1.  1. -1. -1.  1. -1.  1.  1. -1. -1.  1. -1.
 -1.  1.  1. -1. -1. -1.  1.  1.  1. -1.  1.  1.  1. -1.  1.  1.  1.  1.
  1. -1.  1.  1. -1. -1.  1.  1. -1. -1. -1. -1. -1. -1. -1.  1. -1.  1.
  1.  1.  1.  1.  1. -1.  1.  1.  1. -1. -1.  1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1.  1.  1.  1. -1. -1. -1.  1.  1. -1. -1.  1.  1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1.  1. -1.  1.  1.  1. -1.  1. -1. -1. -1. -1. -1.
 -1.  1. -1.  1.  1.  1. -1.  1. -1.  1. -1.  1.  1. -1. -1.  1.  1. -1.
 -1.  1.  1.  1.  1.  1.  1.  1. -1. -1.  1.  1.  1.  1. -1.  1. -1.  1.
  1.  1.  1.  1.  1. -1.  1. -1.  1. -1.  1.  1.  1.  1.  1. -1.  1. -1.
  1. -1. -1.  1. -1.  1.  1.  1.  1. -1.  1.  1. -1.  1. -1.  1.  1. -1.
  1.  1.  1. -1.  1.  1.  1. -1. -1. -1. -1.  1. -1.  1. -1.  1.  1. -1.
 -1.  1. -1.  1.  1. -1. -1.  1. -1.  1

0it [00:00, ?it/s]


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

## Dataset

In [3]:
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array
from jax.random import KeyArray

key = jax.random.PRNGKey(0)
model = SimpleNet(
    in_features=2,
    hidden_features=4,
    out_features=1,
    act=jnp.tanh,
    # drop=0.,
    key=key,
    init_scale=1.
)
# model = MLP(**params)

In [4]:
batch_size = 2
x, y = jnp.ones((batch_size, 2)), jnp.ones(batch_size)
(key,) = jax.random.split(key, 1)
model(x[0], key=key)

<PjitFunction of <function jax.numpy.tanh at 0x11c28e2a0>>


Array(0.28076008, dtype=float32)

In [5]:
def mse(pred_y: Array, y: Array) -> Array:
  """Compute elementwise mean squared error."""
  return jnp.square(pred_y - y).mean(axis=-1)


@eqx.filter_value_and_grad
def compute_loss(model: eqx.Module, x: Array, y: Array, key: KeyArray) -> Array:
  """Compute cross-entropy loss on a single example."""
  # print(x.shape, y.shape)
  keys = jax.random.split(key, x.shape[0])
  pred_y = jax.vmap(model)(x, key=keys)
  loss = mse(pred_y, y)
  # print(jnp.mean(jnp.abs(pred_y)).item())
  return loss.mean()

In [6]:
loss, grads = compute_loss(model, x, y, key=key)

<PjitFunction of <function jax.numpy.tanh at 0x11c28e2a0>>


In [7]:
grads(x[0], key=key)

None


TypeError: 'NoneType' object is not callable

In [8]:
jax.grad(model.__call__)(x[0], key=key)

Array([-0.3679621 ,  0.56394064], dtype=float32)

In [8]:
jax.grad(model.act)(2.)

Array(0.07065082, dtype=float32, weak_type=True)