In [1]:
from jax.example_libraries import stax
from jax.random import PRNGKey
import jax.numpy as jnp
import neos
from neos.experiments.nn_observable import (
    nn_summary_stat,
    make_model,
    generate_data,
    first_epoch,
    last_epoch,
    per_epoch,
    plot_setup,
)

rng_state = 0  # single seed reused for parameter initialisation and the pipeline

# Network architecture: two 1024-wide ReLU hidden layers and one sigmoid output.
layers = [
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
]
init_random_params, nn = stax.serial(*layers)

# The stax initialiser returns (output_shape, params); only the params are kept.
# Input shape: two features per example (leading -1 presumably the batch axis).
input_shape = (-1, 2)
_, init_pars = init_random_params(PRNGKey(rng_state), input_shape)

# Extra keyword arguments passed to the pipeline as `yield_kwargs`.
# Five equally spaced edges on [0, 1] define four bins, matching the
# four-entry yield arrays printed in the training log.
yield_kwargs = {
    "nn": nn,
    "bandwidth": 0.1,
    "bins": jnp.linspace(0, 1, 5),
}


def cls_loss(x):
    """Return the "CLs" entry of the pipeline state, used as the training loss."""
    return x["CLs"]


pipeline_options = {
    "yields_from_pars": nn_summary_stat,
    "model_from_yields": make_model,
    "init_pars": init_pars,
    "data": generate_data(),
    "yield_kwargs": yield_kwargs,
    "random_state": rng_state,
    "loss": cls_loss,
    "first_epoch_callback": first_epoch,
    "last_epoch_callback": last_epoch,
    "per_epoch_callback": per_epoch,
    "plot_setup": plot_setup,
    "num_epochs": 50,
}

p = neos.Pipeline(**pipeline_options)
p.run()



epoch 0: took 18.7435s. state:
{'1-pull_width**2': DeviceArray(0.26492218, dtype=float64),
 'CLs': DeviceArray(0.06000888, dtype=float64),
 'loss': DeviceArray(0.06000888, dtype=float64),
 'mu_uncert': DeviceArray(0.4999013, dtype=float64),
 'pull': DeviceArray([0.03141377], dtype=float64),
 'pull_width': DeviceArray(0.48529408, dtype=float64),
 'yields': [DeviceArray([ 0.09774804,  9.17568   , 10.562891  ,  0.16366616], dtype=float32),
            DeviceArray([ 0.5660533 , 47.763428  , 50.952538  ,  0.71791816], dtype=float32),
            DeviceArray([ 0.903627  , 54.386677  , 44.282684  ,  0.42693916], dtype=float32),
            DeviceArray([ 0.26240483, 37.854694  , 60.441505  ,  1.4412428 ], dtype=float32)]}
epoch 1: took 14.0610s. state:
{'1-pull_width**2': DeviceArray(0.00608275, dtype=float64),
 'CLs': DeviceArray(0.00037576, dtype=float64),
 'loss': DeviceArray(0.00037576, dtype=float64),
 'mu_uncert': DeviceArray(0.13373315, dtype=float64),
 'pull': DeviceArray([-0.00838465]

In [2]:
# IPython shell escape: print the current working directory.
# NOTE(review): looks like a leftover sanity check; the output exposes a local absolute path.
!pwd

/Users/phinate/gradhep/neos
