In [1]:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax
from sklearn.model_selection import train_test_split
import numpy.random as npr
import pyhf
from pyhf.simplemodels import correlated_background
from jaxopt import OptaxSolver 
import optax
import relaxed

pyhf.set_backend("jax")

def run_4blobs(
    bandwidth,
    bins,
    epochs,
    loss_expr=lambda metrics: jnp.log(metrics['CLs']),
    rng=PRNGKey(0),
    nn=None,
    batch_size=500,
    reflect=False,
    num_points=100000,
    animate=False,
    plot=True,
    test_size=0.1,
    predict=None,
    LUMI=10,
    sig_mean=jnp.asarray([-1, 1]),
    bup_mean=jnp.asarray([2.5, 2]),
    bdown_mean=jnp.asarray([-2.5, -1.5]),
    b_mean=jnp.asarray([1, -1]),
    sig_scale=2,
    bkg_scale=10,
):

    ## helper fn for data gen ##
    def gen_blobs():
        sig = jax.random.multivariate_normal(
            rng, sig_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_up = jax.random.multivariate_normal(
            rng, bup_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_down = jax.random.multivariate_normal(
            rng, bdown_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )
        bkg_nom = jax.random.multivariate_normal(
            rng, b_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)
        )

        return sig, bkg_nom, bkg_up, bkg_down

    ## nn --> yields ##
    def histogram_maker(nn, data):
        assert data
        s, b_nom, b_up, b_down = data

        nn_s, nn_b_nom, nn_b_up, nn_b_down = (
            predict(nn, s).ravel(),
            predict(nn, b_nom).ravel(),
            predict(nn, b_up).ravel(),
            predict(nn, b_down).ravel(),
        )

        kde_counts = [
            relaxed.hist_kde(nn_s, bins, bandwidth, reflect_infinities=reflect)
            * sig_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_nom, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(nn_b_up, bins, bandwidth, reflect_infinities=reflect)
            * bkg_scale
            / num_points
            * LUMI,
            relaxed.hist_kde(
                nn_b_down,
                bins,
                bandwidth,
                reflect_infinities=reflect,
            )
            * bkg_scale
            / num_points
            * LUMI,
        ]

        return [k for k in kde_counts]

    ## yields --> model ##
    def model_with_bonlypars(nn, data):
        yields = histogram_maker(nn, data)
        m = correlated_background(*yields)
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray([x for x in nompars])
        bonlypars = bonlypars.at[m.config.poi_index].set(0.0)
        return m, bonlypars


    ## Data generation + train/test (thanks to jax docs for batching code!) ##
    d = gen_blobs()

    split = train_test_split(*d, test_size=test_size, shuffle=False, random_state=1)
    train, test = split[::2], split[1::2]

    num_train = train[0].shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield [points[batch_idx] for points in train]

    stream = data_stream()

    # regression net
    init_random_params, predict = stax.serial(
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1024),
        stax.Relu,
        stax.Dense(1),
        stax.Sigmoid,
    )

    _, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

    ## -NLL from model ##
    def fit_objective(model_pars, model_kwargs, constrained_mu=None):
        m, bonlypars = model_with_bonlypars(model_pars, **model_kwargs)
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def constrained_fit_objective(nuis_par) -> float:  # NLL
            pars = jnp.concatenate([jnp.asarray([constrained_mu]), jnp.array(nuis_par)])
            return -m.logpdf(pars, data = exp_bonly_data)[0]

        return constrained_fit_objective

    # try wrapping obj with closure_convert
    def _minimize(objective_fn, lhood_pars, lr):
        converted_fn, m_pars = jax.closure_convert(objective_fn, lhood_pars) 
        # m_pars seems to be empty, took that line from docs example
        solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
        return solver.run(lhood_pars, *m_pars)[0]

    # constrained fit with grad descent via Optax
    def fit(model_pars, model_kwargs, init_vals, constrained_mu=None, lr=3e-4):
        obj = fit_objective(model_pars, model_kwargs, constrained_mu=constrained_mu)
        fit_res = _minimize(obj, init_vals, lr)
        new_pars = jnp.array([constrained_mu, fit_res[0]]) if constrained_mu else fit_res 
        return new_pars

    # # use fit result as loss
    # def loss(nn, data):
    #     return fit(nn, dict(data=data))
    
    # print(loss(network, test)) # fwd pass

    
    def expected_CLs(nn, model_kwargs=dict(), test_mu=1., fit_lr=3e-4, loss_expr=lambda metrics: metrics['CLs']):
        m, bonlypars = model_with_bonlypars(nn, **model_kwargs) 
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)
        suggested_init = m.config.suggested_init()
        del suggested_init[0] # don't need init for mu since we're not fitting it
        
        # we know that the global MLE pars for expected bkg-only data are the bkg-only pars!
        denominator = bonlypars

        numerator = fit(nn, model_kwargs, init_vals=suggested_init, constrained_mu=test_mu, lr=fit_lr)

        # compute test statistic (lambda(µ))
        profile_likelihood = -2 * (
            m.logpdf(numerator, exp_bonly_data)[0] - m.logpdf(denominator, exp_bonly_data)[0]
        )

        # in exclusion fit zero out test stat if best fit µ^ is larger than test µ
        muhat = denominator[0]
        sqrtqmu = jnp.sqrt(jnp.where(muhat < test_mu, profile_likelihood, 0.0))
        CLsb = 1 - pyhf.tensorlib.normal_cdf(sqrtqmu)
        altval = 0
        CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
        CLs = CLsb / CLb
        pull = jnp.array(
            [
                (numerator - jnp.array(m.config.suggested_init()))[
                    m.config.par_order.index(k)
                ]
                / m.config.param_set(k).width()[0]
                for k in m.config.par_order
                if m.config.param_set(k).constrained
            ]
        )
        # should use global mle pars -- here we know them since exp_data came from bonlypars
        errors = relaxed.cramer_rao_uncert(m, bonlypars, exp_bonly_data)

        pull_err = jnp.array(
            [
                errors[m.config.par_slice(k)] / m.config.param_set(k).width()[0]
                for k in m.config.par_order
                if m.config.param_set(k).constrained
            ]
        )

        metrics = dict(
            CLs=CLs,
            CLsb=CLsb,
            CLb=CLb,
            profile_likelihood=profile_likelihood,
            pull=pull,
            pull_err=pull_err,
            errors=errors,
        )

        return loss_expr(metrics), metrics

    def loss(nn, data, loss_expr):
        return expected_CLs(nn, dict(data=data), test_mu=1., loss_expr=loss_expr)

    return jax.value_and_grad(loss, has_aux = True)(network, test, loss_expr)



In [2]:
run_4blobs(bandwidth=0.1, bins=jnp.linspace(0,1,4), epochs=2)

((DeviceArray(-1.36451481, dtype=float64),
  {'CLb': DeviceArray(0.5, dtype=float64),
   'CLs': DeviceArray(0.25550461, dtype=float64),
   'CLsb': DeviceArray(0.12775231, dtype=float64),
   'errors': DeviceArray([0.96647312, 1.57786485], dtype=float64),
   'profile_likelihood': DeviceArray(1.29295218, dtype=float64),
   'pull': DeviceArray([1.], dtype=float64),
   'pull_err': DeviceArray([[0.96647312]], dtype=float64)}),
 [(DeviceArray([[-0.07969591,  0.0009669 , -0.00181384, ..., -0.01683938,
                  0.02609834,  0.08186976],
                [-0.09461673,  0.00064717, -0.00298575, ..., -0.02308915,
                 -0.02503954,  0.06435425]], dtype=float64),
   DeviceArray([-0.0341202 , -0.00259844,  0.00235879, ..., -0.00706263,
                 0.02045382,  0.02705943], dtype=float64)),
  (),
  (DeviceArray([[-1.17804279e-04, -1.65407495e-04,  8.04878312e-07, ...,
                  3.25869044e-05, -2.83761319e-03, -7.70001130e-04],
                [ 1.09952157e-04,  5.7094