In [None]:
# default_exp models

# neos.models

> Surrogate module to temporarily replace pyhf functionality.

This module implements a very lightweight version of pyhf-like model building. For now, some numbers are hard-coded (parameter bounds and initial values) to suit the three-Gaussian-blobs demonstration. It is not meant to be used in practice, as we plan to interface with full pyhf, but it is easy to customize if you want to make some modifications and try something out :)

In [None]:
# export
import jax
from jax.config import config

import pyhf

# enable 64-bit floats in JAX to avoid float32 precision errors
config.update("jax_enable_x64", True)

pyhf.set_backend(pyhf.tensor.jax_backend())

### Define model and config classes analogous to pyhf syntax

In [None]:
# export
# class-based
class _Config(object):
    """Fixed two-parameter configuration used by :class:`Model`.

    Exposes the attribute/method names a pyhf model config provides
    (``poi_index``, ``npars``, ``suggested_init``, ``suggested_bounds``), with
    hard-coded values chosen for the three-Gaussian-blobs demo.
    """

    def __init__(self):
        # two parameters, ordered [mu, gamma]; the parameter of interest is index 0
        self.npars = 2
        self.poi_index = 0

    def suggested_init(self):
        """Starting values for ``[mu, gamma]``: 1.0 for both."""
        return jax.numpy.asarray([1.0] * 2)

    def suggested_bounds(self):
        """Per-parameter ``[lower, upper]`` bounds: ``[0.0, 10.0]`` for both."""
        return jax.numpy.asarray([[0.0, 10.0], [0.0, 10.0]])


class Model(object):
    """Lightweight stand-in for a pyhf model with one shared background nuisance.

    Parameters are ``pars = [mu, gamma]``: ``mu`` scales the signal template and
    ``gamma`` scales the nominal background template. ``gamma`` is constrained by
    a per-bin Poisson auxiliary term with rate ``gamma * factor``, where
    ``factor = (nominal / uncert) ** 2``.

    Args:
        spec: sequence ``(sig, nominal, uncert)`` of per-bin signal counts,
            nominal background counts and background uncertainties.
            NOTE(review): presumably equal-length 1-D arrays; not validated here,
            and a zero uncertainty makes ``factor`` infinite.
    """

    def __init__(self, spec):
        self.sig, self.nominal, self.uncert = spec
        # Poisson-constraint scale factor per bin
        self.factor = (self.nominal / self.uncert) ** 2
        # nominal auxiliary measurement (value at gamma = 1); `1.0 *` forces a float copy
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars, include_auxdata=True):
        """Expected data at ``pars = [mu, gamma]``.

        Args:
            pars: ``[mu, gamma]``.
            include_auxdata: if True (default), also return the auxiliary data.

        Returns:
            With ``include_auxdata=True``: an array of two stacked rows, where row 0
            is the expected main-channel counts ``gamma * nominal + mu * sig`` and
            row 1 is ``self.aux``. With ``include_auxdata=False``: only the
            main-channel counts, with no extra leading axis.
        """
        mu, gamma = pars
        expected_main = gamma * self.nominal + mu * self.sig
        if not include_auxdata:
            # honour the flag (pyhf semantics): main-channel expectation only
            return jax.numpy.asarray(expected_main)
        # NOTE(review): the aux row is the fixed nominal measurement, not
        # gamma * factor, so it equals the expected aux data only at gamma = 1.
        return jax.numpy.concatenate(
            [jax.numpy.asarray([expected_main]), jax.numpy.asarray([self.aux])]
        )

    def logpdf(self, pars, data):
        """Log-likelihood of ``data`` at ``pars = [mu, gamma]``.

        Args:
            pars: ``[mu, gamma]``.
            data: unpacks along its first axis into ``(maindata, auxdata)``, e.g.
                the output of ``expected_data(...)`` with aux data included.

        Returns:
            Length-1 array holding the sum over axis 0 (bins) of the main-channel
            and constraint Poisson log-probabilities.
        """
        maindata, auxdata = data
        _, gamma = pars
        # row 0 of the full expected data is the expected main-channel rate
        expected_main, _ = self.expected_data(pars)
        main_logprob = pyhf.probability.Poisson(expected_main).log_prob(maindata)
        constraint_logprob = pyhf.probability.Poisson(gamma * self.factor).log_prob(
            auxdata
        )
        # sum log probs over bins
        return jax.numpy.asarray(
            [jax.numpy.sum(main_logprob + constraint_logprob, axis=0)]
        )


def hepdata_like(signal_data, bkg_data, bkg_uncerts, batch_size=None):
    """Build a surrogate :class:`Model`, keeping the pyhf-style helper signature.

    Args:
        signal_data: per-bin signal counts.
        bkg_data: per-bin nominal background counts.
        bkg_uncerts: per-bin background uncertainties.
        batch_size: accepted for signature compatibility only; ignored.

    Returns:
        Model: built from the spec ``[signal_data, bkg_data, bkg_uncerts]``.
    """
    spec = [signal_data, bkg_data, bkg_uncerts]
    return Model(spec)

### Build an example model and get its gradients

In [None]:
# three-bin toy example: signal, nominal background, background uncertainties
sig = jax.numpy.asarray([20, 40, 3])
bkg = jax.numpy.asarray([40, 20, 3])
un = jax.numpy.asarray([3, 3, 3])
m = hepdata_like(signal_data=sig, bkg_data=bkg, bkg_uncerts=un)
# expected data at mu = 1, gamma = 1, used below as the observed data
d = m.expected_data([1.0, 1.0])


def logpdf_unlisted(pars):
    """Scalar log-likelihood of ``d`` under ``m`` (unwraps the length-1 array so it can be differentiated)."""
    (value,) = m.logpdf(pars, d)
    return value


jax.value_and_grad(logpdf_unlisted)([2.0, 1.0])

(DeviceArray(-27.74804929, dtype=float64),
 [DeviceArray(-22., dtype=float64), DeviceArray(-19., dtype=float64)])

### A bonus functional implementation!

In [None]:
# hide
# # functional
# from collections import namedtuple

# _Config = namedtuple("_Config", ["poi_index","npars","suggested_init","suggested_bounds"])

# def init_config():
#     return _Config(0,2,jax.numpy.asarray([1.0, 1.0]),jax.numpy.asarray(
#             [jax.numpy.asarray([0.0, 10.0]), jax.numpy.asarray([0.0, 10.0])]
#         ))

# Model = namedtuple("Model", ["sig", "nominal", "uncert", "factor", "aux", "config"])

# def init_model(spec):
#     sig, nominal, uncert = spec
#     factor = (nominal / uncert) ** 2
#     aux = 1.0 * factor
#     config = init_config()
#     return Model(sig, nominal, uncert, factor, aux, config)

# def expected_data(model, pars, include_auxdata=True):
#     mu, gamma = pars
#     expected_main = jax.numpy.asarray([gamma * model.nominal + mu * model.sig])
#     aux_data = jax.numpy.asarray([model.aux])
#     return jax.numpy.concatenate([expected_main, aux_data])

# @jax.jit
# def logpdf(model, pars, data):
#     maindata, auxdata = data
#     main, _ = expected_data(model,pars)
#     mu, gamma = pars
#     main = pyhf.probability.Poisson(main).log_prob(maindata)
#     constraint = pyhf.probability.Poisson(gamma * model.factor).log_prob(auxdata)
#     # sum log probs over bins
#     return jax.numpy.asarray([jax.numpy.sum(main + constraint,axis=0)])


# def hepdata_like(signal_data, bkg_data, bkg_uncerts, batch_size=None):
#     return init_model([signal_data, bkg_data, bkg_uncerts])