In [11]:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax
from sklearn.model_selection import train_test_split
import numpy.random as npr
import pyhf
from pyhf.simplemodels import correlated_background
from typing import Any, Callable
from jaxopt import OptaxSolver
import optax
import relaxed
from neos import data

pyhf.set_backend("jax")

def run_neos(
    bandwidth,
    bins,
    epochs,
    loss="log(cls)",
    rng=PRNGKey(0),
    nn=None,
    batch_size=500,
    reflect=False,
    num_points=100000,
    animate=False,
    plot=True,
    test_size=0.1,
    predict=None,
    LUMI=10,
    sig_mean=jnp.asarray([-1, 1]),
    bup_mean=jnp.asarray([2.5, 2]),
    bdown_mean=jnp.asarray([-2.5, -1.5]),
    b_mean=jnp.asarray([1, -1]),
    sig_scale=2,
    bkg_scale=10,
):

    def gen_blobs():
        sig = jax.random.multivariate_normal(
            rng, sig_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_up = jax.random.multivariate_normal(
            rng, bup_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_down = jax.random.multivariate_normal(
            rng, bdown_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_nom = jax.random.multivariate_normal(
            rng, b_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )

        return sig, bkg_nom, bkg_up, bkg_down

    def histogram_maker(nn, data):
        assert data
        s, b_nom, b_up, b_down = data

        nn_s, nn_b_nom, nn_b_up, nn_b_down = (
            predict(nn, s).ravel(),
            predict(nn, b_nom).ravel(),
            predict(nn, b_up).ravel(),
            predict(nn, b_down).ravel(),
        )

        kde_counts = [
            relaxed.hist_kde(nn_s, bins, bandwidth, reflect_infinities=reflect)
            * sig_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_nom, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_up, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(
                nn_b_down,
                bins,
                bandwidth,
                reflect_infinities=reflect,
            )
            * bkg_scale
            / num_points
            * LUMI,
        ]

        return [k for k in kde_counts]

    def model(nn, data):
        yields = histogram_maker(nn, data)
        m = correlated_background(*yields)
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray([x for x in nompars])
        bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)
        return m, bonlypars

    d = data.generate_blobs(rng, blobs=4, NMC=num_points)()

    split = train_test_split(*d, test_size=test_size, shuffle=False, random_state=1)
    train, test = split[::2], split[1::2]

    num_train = train[0].shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield [points[batch_idx] for points in train]

    stream = data_stream()

    # regression net
    init_random_params, predict = stax.serial(
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1),
        stax.Sigmoid,
    )

    _, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

    Array = Any

    def fit_objective(
        model_pars, model_kwargs, constrained_mu
    ) -> tuple[float, Callable[[Array], float]]:

        m, bonlypars = model(model_pars, **model_kwargs)


        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(
            pars: Array,
        ) -> tuple[float]:  # maps pars to bounded space if pdf_transform = True

            return m.logpdf(pars, exp_bonly_data)

        def constrained_fit_objective(nuis_par: Array) -> float:  # NLL
            pars = jnp.concatenate([jnp.asarray([constrained_mu]), jnp.array(nuis_par)])
            return -expected_logpdf(pars)[0]

        return constrained_fit_objective, m.config.suggested_init()[1:]


    def fit(model_pars, model_kwargs, constrained_mu=1., lr=3e-4):
        obj, suggested_init = fit_objective(model_pars, model_kwargs, constrained_mu)
        solver = OptaxSolver(fun=obj, opt=optax.adam(lr), implicit_diff=True)

        nps, state = solver.init(suggested_init)

        current = nps

        for _ in range(10):
            nps, state = solver.update(params=nps, state=state)
            print(f"[Step {state.iter_num}] NLL: {state.value} nps: {nps[0]}")

        return nps[0]

    def loss(nn, data):
        return fit(nn, dict(data=data))

    return jax.value_and_grad(loss)(network, test)

In [12]:
run_neos(bandwidth=0.1, bins=jnp.linspace(0,1,4), epochs=2)

[Step 1] NLL: Traced<ConcreteArray(5.071598482172044, dtype=float64)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(5.07159848, dtype=float64)
       tangent = Traced<ShapedArray(float64[]):JaxprTrace(level=1/0)> nps: Traced<ConcreteArray(0.9997000000089716, dtype=float64)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(0.9997, dtype=float64)
       tangent = Traced<ShapedArray(float64[]):JaxprTrace(level=1/0)>
[Step 2] NLL: Traced<ConcreteArray(5.071498178603554, dtype=float64)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(5.07149818, dtype=float64)
       tangent = Traced<ShapedArray(float64[]):JaxprTrace(level=1/0)> nps: Traced<ConcreteArray(0.9994000019672078, dtype=float64)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(0.9994, dtype=float64)
       tangent = Traced<ShapedArray(float64[]):JaxprTrace(level=1/0)>
[Step 3] NLL: Traced<ConcreteArray(5.071397900661253, dtype=float64)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(5.0713979, dtype=floa

(DeviceArray(0.99700024, dtype=float64),
 [(DeviceArray([[ 8.84507332e-09, -2.80899314e-09,  4.91090997e-09, ...,
                  5.02365077e-09,  2.23502715e-10, -9.39729922e-09],
                [ 3.91631491e-09,  2.24027784e-09, -5.19029750e-09, ...,
                 -1.82871359e-09, -6.84650754e-10, -1.65184816e-09]],            dtype=float64),
   DeviceArray([-1.17318005e-10,  2.11433991e-09, -4.07399499e-09, ...,
                -2.51542833e-09,  7.09901938e-10, -7.02659851e-10],            dtype=float64)),
  (),
  (DeviceArray([[-2.46092372e-10,  1.25536133e-11, -1.32011604e-13, ...,
                 -1.63681144e-11, -6.84514852e-10, -1.97945495e-10],
                [-9.42006610e-11, -3.59981367e-15, -3.11857414e-14, ...,
                 -1.99167609e-11, -3.27152959e-10, -1.07366533e-10],
                [-1.48190900e-11,  2.94450836e-18, -3.34948538e-15, ...,
                 -2.75314950e-12, -5.15276008e-11, -1.66257780e-11],
                ...,
                [-9.519125