In [None]:
# default_exp makers

# neos.makers

> Functions that define the workflow from parametric observable --> statistical model.

This module contains example workflows to go from the output of a neural network to a differentiable histogram, and to then use that as a basis for statistical modelling via the [HistFactory likelihood specification](https://scikit-hep.org/pyhf/intro.html#histfactory).

These functions are designed to be composed such that a final metric (e.g. expected p-value) is explicitly made a function of the parameters of the neural network. You can see this behaviour through the nested function design; one can specify all other hyperparameters ahead of time when initializing the functions, and the nn weights don't have to be specified until the inner function is called. Keep reading for examples!

## differentiable histograms from neural networks

In [None]:
#export
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

from relaxed import hist_kde as hist 

In [None]:
#export
def hists_from_nn(
    data_generator, predict, hpar_dict, method="softmax", LUMI=10, sig_scale=2, bkg_scale=10,
):
    """Initialize a function `hist_maker` that returns 'soft' histogram yields
    built from the output of a parameterized observable (e.g. a neural
    network).

    The example problem is selected by the number of arrays returned by
    `data_generator`:

        3 arrays (signal, background 'up', background 'down'): `hist_maker`
            returns signal yields, the mean of the two background yields, and
            their standard deviation (used as the background uncertainty).

        4 arrays (signal, nominal background, background 'up', background
            'down'): `hist_maker` returns the four sets of yields, from which
            a HistFactory 'histosys' nuisance parameter can be built
            downstream.

    Args:
        data_generator: Callable that returns generated data (in jax array
                        format) as a sequence of 3 or 4 arrays. It is called
                        once, when `hists_from_nn` is called.

        predict: Decision function for a parameterized observable, e.g. neural
                 network. Called as `predict(nn, events)`.

        hpar_dict: Dictionary of histogram hyperparameters. Only read when
                   `method="kde"`, in which case it must contain:

                   "bins": Array of bin edges, e.g. np.linspace(0,1,3)
                           defines a two-bin histogram with edges at 0, 0.5,
                           1.

                   "bandwidth": Float that controls the 'smoothness' of the
                                kde. It's recommended to keep this fairly
                                similar to the bin width to avoid
                                oversmoothing the distribution. Going too low
                                will cause things to break, as the gradients
                                of the kde become unstable.

        method: A string to specify the method to use for constructing soft
                histograms. Either "softmax" (sum the output of `predict` over
                events, intended for a network with a softmax output) or
                "kde" (kernel density estimate of the flattened output).

        LUMI: 'Luminosity' scaling factor for the yields.

        sig_scale: Individual scaling factor for the signal yields.

        bkg_scale: Individual scaling factor for the background yields.

    Returns:
        hist_maker: A callable that takes the parameters of the observable
                    (`hm_params`, passed to `predict` unchanged) and returns
                    the yields described above.

    Raises:
        AssertionError: If `data_generator` does not return 3 or 4 arrays, or
                        if `method` is neither "softmax" nor "kde".
    """

    data = data_generator()
    n_blobs = len(data)

    # Validate with explicit raises rather than `assert False` so the checks
    # still run under `python -O`. AssertionError is kept as the exception type
    # for backward compatibility with the original asserts.
    if n_blobs not in (3, 4):
        raise AssertionError(
            f"Unsupported number of blobs: {n_blobs}"
            " (only using 3 or 4 blobs for these examples)."
        )
    if method not in ("softmax", "kde"):
        raise AssertionError(
            f"Unsupported method: {method}"
            " (only using kde or softmax for these examples)."
        )

    def _yields(nn, events, scale, n_mc):
        """Soft histogram of `predict(nn, events)`, scaled to expected yields."""
        scores = predict(nn, events)
        if method == "softmax":
            counts = scores.sum(axis=0)
        else:
            counts = hist(scores.ravel(), hpar_dict["bins"], hpar_dict["bandwidth"])
        # Same left-to-right arithmetic as before, so numerical results are
        # unchanged.
        return counts * scale / n_mc * LUMI

    if n_blobs == 3:

        def hist_maker(hm_params):
            """Uses the nn decision function `predict` to form histograms
            from signal and background data. Two background distributions
            are used, which is meant to mimic the situation in particle
            physics where one has a 'nominal' prediction for a nuisance
            parameter (taken here as the mean of two modes) and then
            alternate values (e.g. from varying up/down by one standard
            deviation), which then modifies the background pdf. The value for
            the background histogram is the mean of the resulting counts of
            the two modes, and the uncertainty is quantified through the
            count standard deviation.

            Arguments:
                hm_params: The observable parameters (e.g. nn weights),
                           passed to `predict` as its first argument.

            Returns:
                Signal yields, mean background yields, and background
                uncertainty (standard deviation over the two modes).
            """
            nn = hm_params
            s, b_up, b_down = data
            NMC = len(s)

            s_hist = _yields(nn, s, sig_scale, NMC)
            b_hists = jnp.asarray(
                [
                    _yields(nn, b_up, bkg_scale, NMC),
                    _yields(nn, b_down, bkg_scale, NMC),
                ]
            )
            yields = (s_hist, jnp.mean(b_hists, axis=0), jnp.std(b_hists, axis=0))

            # Container types are preserved from the previous implementation:
            # a tuple for "softmax", a list for "kde".
            return yields if method == "softmax" else list(yields)

    else:

        def hist_maker(hm_params):
            """Uses the nn decision function `predict` to form histograms
            from signal and background data. Three background distributions
            are used, which mimics the situation in particle physics where
            one has a 'nominal' prediction for a nuisance parameter and then
            alternate values (e.g. from varying up/down by one standard
            deviation), which then modifies the background pdf. The
            HistFactory 'histosys' nuisance parameter will then be
            constructed from the yields downstream by interpolating between
            them using pyhf.

            Arguments:
                hm_params: The observable parameters (e.g. nn weights),
                           passed to `predict` as its first argument.

            Returns:
                List of 4 counts for signal, background, and up/down modes.
            """
            nn = hm_params
            s, b_nom, b_up, b_down = data
            NMC = len(s)

            return [
                _yields(nn, s, sig_scale, NMC),
                _yields(nn, b_nom, bkg_scale, NMC),
                _yields(nn, b_up, bkg_scale, NMC),
                _yields(nn, b_down, bkg_scale, NMC),
            ]

    return hist_maker


### Usage:

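Begin by instantiating `hists_from_nn` with a function that generates a 3 or 4-tuple of data (we have `generate_blobs` for this!), a neural network `predict` method (takes weights & inputs, returns output), and a dictionary of histogram hyperparameters (`bins` and `bandwidth`, which are used by the `kde` method).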

In [None]:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax

import neos
from neos.makers import hists_from_nn
from neos.data import generate_blobs

# data generator
gen_data = generate_blobs(rng=PRNGKey(1),blobs=4)

# nn
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)

# histogram hyperparameters (bin edges and kde bandwidth); `hists_from_nn`
# requires them as its `hpar_dict` argument
hyperpars = dict(bandwidth=0.5, bins=jnp.linspace(0,1,3))

hist_maker = hists_from_nn(gen_data, predict, hpar_dict=hyperpars, method='kde')

Now, when we initialize our neural network's weights and pass them to `hist_maker`, we should get back a set of event yields (the histogram hyperparameters -- binning and bandwidth -- are fixed when calling `hists_from_nn`):

In [None]:
_, network = init_random_params(jax.random.PRNGKey(13), (-1, 2))

# `hist_maker` takes only the network weights; it forwards its argument
# unchanged to `predict` as the network parameters
hist_maker(network)

[DeviceArray([6.76080181, 6.8832221 ], dtype=float64),
 DeviceArray([34.14177765, 34.12072566], dtype=float64),
 DeviceArray([35.10574108, 33.04116608], dtype=float64),
 DeviceArray([32.32513054, 35.63212448], dtype=float64)]

## statistical models

In [None]:
#export

import pyhf

# Create pyhf's jax tensor backend with 64-bit precision and make it the
# active pyhf backend.
jax_backend = pyhf.tensor.jax_backend(precision="64b")
pyhf.set_backend(jax_backend)

from neos.models import hepdata_like

In [None]:
#export
def hepdata_like_from_hists(histogram_maker):
    """Returns a function that constructs a typical 'hepdata-like' statistical
    model with signal, background, and background uncertainty yields when
    evaluated at the parameters of the observable.

    Args:
        histogram_maker: A function that takes the observable's parameters as
        argument and returns three sets of yields (signal, background,
        background uncertainty), e.g. the `hist_maker` returned by
        `hists_from_nn` for three-blob data.

    Returns:
        nn_model_maker: A function that returns a Model object (either from
        `neos.models` or from `pyhf`) when evaluated at the observable's parameters,
        along with the background-only parameters for use in downstream inference.
    """

    def nn_model_maker(hm_params):
        s, b, db = histogram_maker(hm_params)
        m = hepdata_like(s, b, db)  # neos 'pyhf' model
        # Background-only parameters: the model's suggested initial values with
        # the parameter of interest set to zero.
        bonlypars = jnp.asarray(m.config.suggested_init())
        # `.at[...].set(...)` is the functional (out-of-place) update that
        # replaces `jax.ops.index_update`, which has been removed from JAX.
        bonlypars = bonlypars.at[m.config.poi_index].set(0.0)
        return m, bonlypars

    return nn_model_maker

### Usage:

In [None]:
# define a hist_maker as above

import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax

import neos
from neos.makers import hists_from_nn, hepdata_like_from_hists
from neos.data import generate_blobs

# data generator, three blobs only for this model
gen_data = generate_blobs(rng=PRNGKey(1),blobs=3)

# nn
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)

# histogram hyperparameters (bin edges and kde bandwidth); `hists_from_nn`
# requires them as its `hpar_dict` argument
hyperpars = dict(bandwidth=0.5, bins=jnp.linspace(0,1,3))

hist_maker = hists_from_nn(gen_data, predict, hpar_dict=hyperpars, method='kde')

# then use this to define your model:
model = hepdata_like_from_hists(hist_maker)

Similar to above, we can get output at this stage by initializing the neural network. `hepdata_like_from_hists` will return a `Model` object with callable `logpdf` method, as well as the model parameters in the background-only scenario for convenience. See [this link](https://scikit-hep.org/pyhf/_generated/pyhf.simplemodels.hepdata_like.html) for more about the type of model being used here, as well as the rest of the `pyhf` docs for added physics context.

In [None]:
_, network = init_random_params(jax.random.PRNGKey(13), (-1, 2))

# the model maker forwards its argument to the hist maker, which expects only
# the network weights
m, bkg_only_pars = model(network)
m.logpdf(bkg_only_pars,data=[1,1])

DeviceArray([-1338.66123891], dtype=float64)

In [None]:
#export
import sys
from unittest.mock import patch
# Create pyhf's jax tensor backend with 64-bit precision and make it the
# active pyhf backend. `jax_backend` is also the object that
# `histosys_model_from_hists` patches in as `default_backend` below.
jax_backend = pyhf.tensor.jax_backend(precision='64b')
pyhf.set_backend(jax_backend)

def histosys_model_from_hists(histogram_maker):
    """Returns a function that constructs a HEP statistical model using a
    'histosys' uncertainty for the background (nominal background, up and down
    systematic variations) when evaluated at the parameters of the observable.

    Args:
        histogram_maker: A function that takes the observable's parameters as
        argument and returns four sets of yields (signal, nominal background,
        background up, background down), e.g. the `hist_maker` returned by
        `hists_from_nn` for four-blob data.

    Returns:
        nn_model_maker: A function that returns a `pyhf.Model` object when 
        evaluated at the observable's parameters (nn weights), along with the 
        background-only parameters for use in downstream inference.
    """

    # While `from_spec` runs, the `default_backend` attribute of pyhf and of
    # the listed interpolator/modifier modules is replaced by `jax_backend`
    # (presumably so that building the model keeps the yields as jax arrays --
    # confirm against the pyhf version in use). The `sys.modules[...]` lookups
    # are evaluated each time `histosys_model_from_hists` is called, so these
    # pyhf submodules must already be imported by then.
    @patch('pyhf.default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.interpolators.code0'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.interpolators.code1'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.interpolators.code2'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.interpolators.code4'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.interpolators.code4p'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.modifiers.shapefactor'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.modifiers.shapesys'], 'default_backend', new=jax_backend)
    @patch.object(sys.modules['pyhf.modifiers.staterror'], 'default_backend', new=jax_backend)
    def from_spec(yields):
        """Build a single-channel `pyhf.Model` from the four sets of yields."""

        s, b, bup, bdown = yields

        spec = {
            "channels": [
                {
                    "name": "nn",
                    "samples": [
                        {
                            "name": "signal",
                            "data": s,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None}
                            ],
                        },
                        {
                            "name": "bkg",
                            "data": b,
                            "modifiers": [
                                {
                                    "name": "nn_histosys",
                                    "type": "histosys",
                                    "data": {
                                        "lo_data": bdown,
                                        "hi_data": bup,
                                    },
                                }
                            ],
                        },
                    ],
                },
            ],
        }

        return pyhf.Model(spec)

    def nn_model_maker(hm_params):
        yields = histogram_maker(hm_params)
        m = from_spec(yields)
        # Background-only parameters: the model's suggested initial values with
        # the parameter of interest set to zero.
        bonlypars = jnp.asarray(m.config.suggested_init())
        # `.at[...].set(...)` is the functional (out-of-place) update that
        # replaces `jax.ops.index_update`, which has been removed from JAX.
        bonlypars = bonlypars.at[m.config.poi_index].set(0.0)
        return m, bonlypars

    return nn_model_maker


### Usage:

In [None]:
# define a hist_maker as above

import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental import stax

import neos
from neos.makers import hists_from_nn, histosys_model_from_hists
from neos.data import generate_blobs

# data generator, four blobs only for this model
gen_data = generate_blobs(rng=PRNGKey(1),blobs=4)

# nn
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)

# histogram hyperparameters (bin edges and kde bandwidth); `hists_from_nn`
# requires them as its `hpar_dict` argument
hyperpars = dict(bandwidth=0.5, bins=jnp.linspace(0,1,3))

hist_maker = hists_from_nn(gen_data, predict, hpar_dict=hyperpars, method='kde')

# then use this to define your model:
model = histosys_model_from_hists(hist_maker)

_, network = init_random_params(jax.random.PRNGKey(13), (-1, 2))

# instantiate model (a function of the network weights only) and eval logpdf
m, bkg_only_pars = model(network)
m.logpdf(bkg_only_pars,data=[1,1,1])

DeviceArray([-62.62101506], dtype=float64)