In [28]:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax
from sklearn.model_selection import train_test_split
import numpy.random as npr
import pyhf
from pyhf.simplemodels import correlated_background
from typing import Any, Callable
from jaxopt import OptaxSolver
import functools
import optax
import relaxed
from neos import data

from inspect import getclosurevars

pyhf.set_backend("jax")

def run_neos(
    bandwidth,
    bins,
    epochs,
    loss="log(cls)",
    rng=PRNGKey(0),
    nn=None,
    batch_size=500,
    reflect=False,
    num_points=100000,
    animate=False,
    plot=True,
    test_size=0.1,
    predict=None,
    LUMI=10,
    sig_mean=jnp.asarray([-1, 1]),
    bup_mean=jnp.asarray([2.5, 2]),
    bdown_mean=jnp.asarray([-2.5, -1.5]),
    b_mean=jnp.asarray([1, -1]),
    sig_scale=2,
    bkg_scale=10,
):

    ## helper fn for data gen ##
    def gen_blobs():
        sig = jax.random.multivariate_normal(
            rng, sig_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_up = jax.random.multivariate_normal(
            rng, bup_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_down = jax.random.multivariate_normal(
            rng, bdown_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_nom = jax.random.multivariate_normal(
            rng, b_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )

        return sig, bkg_nom, bkg_up, bkg_down

    ## nn --> yields ##
    def histogram_maker(nn, data):
        assert data
        s, b_nom, b_up, b_down = data

        nn_s, nn_b_nom, nn_b_up, nn_b_down = (
            predict(nn, s).ravel(),
            predict(nn, b_nom).ravel(),
            predict(nn, b_up).ravel(),
            predict(nn, b_down).ravel(),
        )

        kde_counts = [
            relaxed.hist_kde(nn_s, bins, bandwidth, reflect_infinities=reflect)
            * sig_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_nom, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_up, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(
                nn_b_down,
                bins,
                bandwidth,
                reflect_infinities=reflect,
            )
            * bkg_scale
            / num_points
            * LUMI,
        ]

        return [k for k in kde_counts]

    ## yields --> model ##
    def model(nn, data):
        yields = histogram_maker(nn, data)
        m = correlated_background(*yields)
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray([x for x in nompars])
        bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)
        return m, bonlypars


    ## Data generation + train/test ##
    d = gen_blobs()

    split = train_test_split(*d, test_size=test_size, shuffle=False, random_state=1)
    train, test = split[::2], split[1::2]

    num_train = train[0].shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield [points[batch_idx] for points in train]

    stream = data_stream()

    # regression net
    init_random_params, predict = stax.serial(
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1),
        stax.Sigmoid,
    )

    _, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

    Array = Any

    ## -NLL from model ##
    def fit_objective(model_pars, model_kwargs, constrained_mu):
        m, bonlypars = model(model_pars, **model_kwargs)
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def constrained_fit_objective(nuis_par: Array) -> float:  # NLL
            pars = jnp.concatenate([jnp.asarray([constrained_mu]), jnp.array(nuis_par)])
            return -m.logpdf(pars, data = exp_bonly_data)[0]

        return constrained_fit_objective

    # try wrapping obj with closure_convert
    def minimize(objective_fn, np, lr):
        converted_fn, m_pars = jax.closure_convert(objective_fn, np) 
        # m_pars seems to be empty, took that line from docs example
        solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
        return solver.run(np, *m_pars)[0][0]

    # constrained fit with grad descent via Optax
    def fit(model_pars, model_kwargs, constrained_mu=1., lr=3e-4):
        m, _ = model(model_pars, **model_kwargs) 
        suggested_init = [m.config.suggested_init()[0]]
        obj = fit_objective(model_pars, model_kwargs, constrained_mu=constrained_mu)

        return minimize(obj, suggested_init, lr)

    # use fit result as loss
    def loss(nn, data):
        return fit(nn, dict(data=data))
    
    print(loss(network, test)) # fwd pass
    return jax.value_and_grad(loss)(network, test)

In [29]:
run_neos(bandwidth=0.1, bins=jnp.linspace(0,1,4), epochs=2)

-0.0011910244856098077


(DeviceArray(-0.00119102, dtype=float64),
 [(DeviceArray([[ 0.03059435, -0.00975196,  0.01705084, ...,  0.01742486,
                  0.00082944, -0.03244367],
                [ 0.01347071,  0.00778273, -0.01803691, ..., -0.00638081,
                 -0.00243054, -0.00561185]], dtype=float64),
   DeviceArray([-0.0004559 ,  0.00733674, -0.01414513, ..., -0.00874408,
                 0.00250869, -0.00237924], dtype=float64)),
  (),
  (DeviceArray([[-8.54850084e-04,  4.33896450e-05, -4.56369895e-07, ...,
                 -5.68307526e-05, -2.38097885e-03, -6.88376352e-04],
                [-3.26957951e-04, -1.23679453e-08, -1.07857719e-07, ...,
                 -6.89681392e-05, -1.13554108e-03, -3.72788048e-04],
                [-5.14473857e-05,  1.04615312e-11, -1.15888590e-08, ...,
                 -9.54691167e-06, -1.78892791e-04, -5.77282605e-05],
                ...,
                [-3.30678975e-04,  1.05126151e-05, -3.43993512e-11, ...,
                 -1.35945844e-05, -1.00762982e

In [None]:
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(init_params = initial)

    return result

jax.value_and_grad(pipeline)(2., data=6.)



CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.