In [1]:
import jax
import jax.numpy as jnp
import optax

# from nets import datasets
from datasets.nonlinear_gp import NonlinearGPDataset
from nets.launch import configs
from nets.models.feedforward import MLP
from models.feedforward import SimpleNet
from nets import samplers
# from nets.simulators.online_sgd import train_step, eval_step, evaluate, simulate
from experiments.online_sgd import simulate
# from nets.experiments.analyzable_online_sgd.launcher_local import SearchConfig

In [2]:
from dataclasses import dataclass
from dataclasses import field
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam

@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
  """Generic config for a hyperparameter search."""

  seed: Param = field(default_factory=lambda: FixedParam(0))

  # Model params.
  # num_ins: Param = field(default_factory=lambda: FixedParam(4))
  num_hiddens: Param = field(default_factory=lambda: FixedParam(100))
  init_scale: Param = field(default_factory=lambda: FixedParam(0.01))
  num_dimensions: Param = field(default_factory=lambda: FixedParam(40))

  # Training and evaluation params.
  optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
  learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
  batch_size: Param = field(default_factory=lambda: FixedParam(100))
  num_epochs: Param = field(default_factory=lambda: FixedParam(1))

  # Dataset params.
  dataset_cls: Param = field(default_factory=lambda: FixedParam(NonlinearGPDataset))
  xi1: Param = field(default_factory=lambda: FixedParam(0.1))
  xi2: Param = field(default_factory=lambda: FixedParam(1.1))
  gain: Param = field(default_factory=lambda: FixedParam(1.0))
  # num_dimensions: Param = field(default_factory=lambda: FixedParam(4))
  sampler_cls: Param = field(default_factory=lambda: FixedParam(samplers.EpochSampler))
  
#   dataset_cls: Param = field(init=False)
#   num_dimensions: Param = field(init=False)
#   num_exemplars_per_class: Param = field(init=False)
#   exemplar_noise_scale: Param = field(init=False)

  # Sampler params.
#   sampler_cls: Param = field(init=False)

In [3]:
search_config = SearchConfig(key=jax.random.PRNGKey(0), num_configs=1)
params = search_config[0]
simulate(**params)

Using JAX backend: cpu

Using configuration:
{'batch_size': 100,
 'dataset_cls': <class 'datasets.nonlinear_gp.NonlinearGPDataset'>,
 'gain': 1.0,
 'init_scale': 0.01,
 'learning_rate': 0.001,
 'num_dimensions': 40,
 'num_epochs': 1,
 'num_hiddens': 100,
 'optimizer_fn': <function adam at 0x13b186fc0>,
 'sampler_cls': <class 'nets.samplers.base.EpochSampler'>,
 'seed': 0,
 'xi1': 0.1,
 'xi2': 1.1}

simulate: len(dataset)=1000
Model:
SimpleNet(
  fc1=Linear(
    weight=f32[100,40],
    bias=f32[100],
    in_features=40,
    out_features=100,
    use_bias=True
  ),
  act=<wrapped function <lambda>>
)

Starting evaluation...


10it [00:01,  7.55it/s]


Completed evaluation over 1000 examples in 1.35 secs.
####
ITERATION 0
eval set:

	loss:			1.0002896785736084
	accuracy:		0.0
	BASELINE:		50.00%
	GT labels:		[ 1. -1. -1.  1. -1.  1. -1.  1. -1. -1.  1. -1.  1.  1. -1. -1.  1. -1.
 -1.  1.  1. -1. -1. -1.  1.  1.  1. -1.  1.  1.  1. -1.  1.  1.  1.  1.
  1. -1.  1.  1. -1. -1.  1.  1. -1. -1. -1. -1. -1. -1. -1.  1. -1.  1.
  1.  1.  1.  1.  1. -1.  1.  1.  1. -1. -1.  1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1.  1.  1.  1. -1. -1. -1.  1.  1. -1. -1.  1.  1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1.  1. -1.  1.  1.  1. -1.  1. -1. -1. -1. -1. -1.
 -1.  1. -1.  1.  1.  1. -1.  1. -1.  1. -1.  1.  1. -1. -1.  1.  1. -1.
 -1.  1.  1.  1.  1.  1.  1.  1. -1. -1.  1.  1.  1.  1. -1.  1. -1.  1.
  1.  1.  1.  1.  1. -1.  1. -1.  1. -1.  1.  1.  1.  1.  1. -1.  1. -1.
  1. -1. -1.  1. -1.  1.  1.  1.  1. -1.  1.  1. -1.  1. -1.  1.  1. -1.
  1.  1.  1. -1.  1.  1.  1. -1. -1. -1. -1.  1. -1.  1. -1.  1.  1. -1.
 -1.  1. -1.  1.  1. -1. -1.  1. -1.  1

0it [00:00, ?it/s]


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

## Dataset

In [3]:
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array
from jax.random import KeyArray

key = jax.random.PRNGKey(0)
model = SimpleNet(
    in_features=2,
    hidden_features=4,
    out_features=1,
    act=jnp.tanh,
    # drop=0.,
    key=key,
    init_scale=1.
)
# model = MLP(**params)

In [4]:
batch_size = 2
x, y = jnp.ones((batch_size, 2)), jnp.ones(batch_size)
(key,) = jax.random.split(key, 1)
model(x[0], key=key)

<PjitFunction of <function jax.numpy.tanh at 0x11c28e2a0>>


Array(0.28076008, dtype=float32)

In [5]:
def mse(pred_y: Array, y: Array) -> Array:
  """Compute elementwise mean squared error."""
  return jnp.square(pred_y - y).mean(axis=-1)


@eqx.filter_value_and_grad
def compute_loss(model: eqx.Module, x: Array, y: Array, key: KeyArray) -> Array:
  """Compute cross-entropy loss on a single example."""
  # print(x.shape, y.shape)
  keys = jax.random.split(key, x.shape[0])
  pred_y = jax.vmap(model)(x, key=keys)
  loss = mse(pred_y, y)
  # print(jnp.mean(jnp.abs(pred_y)).item())
  return loss.mean()

In [6]:
loss, grads = compute_loss(model, x, y, key=key)

<PjitFunction of <function jax.numpy.tanh at 0x11c28e2a0>>


In [7]:
grads(x[0], key=key)

None


TypeError: 'NoneType' object is not callable

In [8]:
jax.grad(model.__call__)(x[0], key=key)

Array([-0.3679621 ,  0.56394064], dtype=float32)

In [8]:
jax.grad(model.act)(2.)

Array(0.07065082, dtype=float32, weak_type=True)

In [52]:
from jax.scipy.special import erf as gain_function

def Z(g):
    return jnp.sqrt( (2/jnp.pi) * jnp.arcsin( (g**2) / (1 + (g**2)) ) )

def generate_non_gaussian(key, xi, L, g):
    C = jnp.abs(jnp.tile(jnp.arange(L)[:, jnp.newaxis], (1, L)) - jnp.tile(jnp.arange(L), (L, 1)))
    C = -C ** 2 / (xi ** 2)
    print(C.min(), C.max())
    C = jnp.exp(C)#-C ** 2 / (xi ** 2))
    print(C.min(), C.max())
    print(jnp.linalg.norm(C - C.T))
    z = jax.random.multivariate_normal(key, jnp.zeros(L, dtype=jnp.float32), C, method="svd")
    # print(z)
    x = gain_function(g * z) / Z(g)
    return x

key = jax.random.PRNGKey(10)
generate_non_gaussian(key, 4.47, 100, 1.0)

-490.51846 0.0
0.0 1.0
0.0


Array([ 0.13280278,  0.67017764,  0.91039586,  0.82608676,  0.5739829 ,
        0.4835764 ,  0.74268484,  1.1424717 ,  1.4003053 ,  1.5015821 ,
        1.5188626 ,  1.4570681 ,  1.2281711 ,  0.697571  ,  0.00707251,
       -0.41849583, -0.45780444, -0.20584969,  0.19321951,  0.5506563 ,
        0.71367085,  0.6083238 ,  0.21559878, -0.3209988 , -0.73741674,
       -0.910966  , -0.88841873, -0.72893816, -0.49133992, -0.27254838,
       -0.19947001, -0.3623984 , -0.6866402 , -0.93569905, -0.9519884 ,
       -0.6739141 , -0.12738809,  0.45165548,  0.8314931 ,  1.0004594 ,
        1.0933323 ,  1.2417078 ,  1.4571066 ,  1.6044921 ,  1.6422333 ,
        1.5736152 ,  1.2690396 ,  0.7049215 ,  0.39836252,  0.61761546,
        1.0420426 ,  1.2518497 ,  1.139631  ,  0.5592598 , -0.4611649 ,
       -1.2714162 , -1.5962808 , -1.681137  , -1.6906425 , -1.6432804 ,
       -1.4295843 , -0.9486383 , -0.52866083, -0.46047392, -0.56591   ,
       -0.5086306 , -0.04066715,  0.79095155,  1.4145088 ,  1.61

In [47]:
import numpy as np
from scipy.special import erf as gain_function

def Z(g):
    return np.sqrt( (2/np.pi) * np.arcsin( (g**2) / (1 + (g**2)) ) )

def generate_non_gaussian(xi, L, g):
    C = np.abs(np.tile(jnp.arange(L)[:, np.newaxis], (1, L)) - np.tile(np.arange(L), (L, 1)))
    C = -C ** 2 / (xi ** 2)
    print(C.min(), C.max())
    C = np.exp(C)#-C ** 2 / (xi ** 2))
    print(C.min(), C.max())
    print(np.linalg.norm(C - C.T))
    z = np.random.multivariate_normal(np.zeros(L, dtype=np.float32), C)
    # print(z)
    x = gain_function(g * z) / Z(g)
    return x

generate_non_gaussian(4.47, 100, 1.0)

-490.51844511508494 0.0
9.344283683987794e-214 1.0
0.0


array([-1.54891939, -1.51287565, -1.48423348, -1.4549705 , -1.37988648,
       -1.17491775, -0.75829416, -0.22930115,  0.10143989,  0.01759758,
       -0.49526407, -1.1405603 , -1.51700782, -1.63815109, -1.65584137,
       -1.61534692, -1.48108181, -1.15562726, -0.52919494,  0.27638698,
        0.81959246,  0.85928576,  0.31014739, -0.6276994 , -1.17190218,
       -1.1413023 , -0.39232522,  0.98435388,  1.63800183,  1.72327758,
        1.72940663,  1.72540715,  1.66188252,  1.13662467, -0.06246403,
       -0.8875065 , -1.14644109, -1.11390957, -0.92821984, -0.72658798,
       -0.64141659, -0.64801138, -0.56271989, -0.16161808,  0.61157206,
        1.30576444,  1.58365093,  1.63551361,  1.58172505,  1.37797921,
        0.96661578,  0.37972625, -0.37193023, -1.13257522, -1.55322124,
       -1.67026464, -1.67974384, -1.61017492, -1.2964218 , -0.47742614,
        0.51555853,  1.0602738 ,  1.13652464,  0.79209087, -0.03192705,
       -0.90073183, -1.31342751, -1.36873313, -1.19169719, -0.85