In [None]:
# default_exp makers

# neos.makers

> Functions that define the workflow from a parametric observable to a statistical model.

This module contains example implementations of functions that are composed such that everything downstream is a function of the parameters of your observable.


- `hists_from_nn_three_blobs(predict)` uses the nn decision function `predict` to form histograms from signal and background data, all drawn from multivariate normal distributions with different means. Background events are drawn from two distributions. This mimics the situation in particle physics where one has a 'nominal' prediction for a nuisance parameter and an alternate one (e.g. from varying it up/down by one standard deviation) that modifies the background pdf. Here, we take that effect to be a shift of the mean of the distribution. The background histogram is then the mean of the counts from the two modes, and its uncertainty is quantified by the standard deviation of those counts.
- `kde_counts_from_nn_three_blobs(predict, bins)` works like the above, but expects `predict` to output a single number per event. It obtains the yields via kernel density estimation (KDE), so the binning must be specified up front via `bins`.
- `nn_hepdata_like(hmaker)` takes the histogram maker from either of the above (or your own!) methods, uses it to construct histograms, and feeds them into `pyhf.simplemodels.hepdata_like` to build a pyhf model. This model can then be used downstream to call things like `logpdf` and `expected_data` when CLs values are calculated.

In [None]:
#export
import jax
import jax.scipy as jsc
import jax.numpy as jnp
import numpy as np
from functools import partial

from neos import models

In [None]:
#export
def hists_from_nn_three_blobs(predict, NMC = 500, sig_mean = (-1, 1), b1_mean=(2, 2), b2_mean=(-1, -1), LUMI=10, sig_scale = 2, bkg_scale = 10, rng=None):
    '''
    Uses the nn decision function `predict` to form histograms from signal and background
    data, all drawn from 2D multivariate normal distributions (identity covariance) with
    different means. Two background distributions are sampled from, which is meant to mimic
    the situation in particle physics where one has a 'nominal' prediction for a nuisance
    parameter and then an alternate value (e.g. from varying up/down by one standard
    deviation), which then modifies the background pdf. Here, we take that effect to be a
    shift of the mean of the distribution. The value for the background histogram is then
    the mean of the resulting counts of the two modes, and the uncertainty is the
    (population) standard deviation of those counts.

    Args:
            predict: Decision function for a parameterized observable, called as
            `predict(network, events)`. Assumed softmax here: each event's output row is
            summed into the histogram bins.
            NMC: Number of events sampled for each of the three blobs.
            sig_mean: Mean of the signal blob.
            b1_mean: Mean of the first background blob.
            b2_mean: Mean of the second (shifted) background blob.
            LUMI: Luminosity factor applied to every yield.
            sig_scale: Overall normalisation of the signal yields.
            bkg_scale: Overall normalisation of the background yields.
            rng: Optional random source with a `multivariate_normal` method (e.g.
            `np.random.default_rng(seed)`) for reproducible sampling. Defaults to the
            global `np.random` state.

    Returns:
            hist_maker: A zero-argument function that samples the three blobs and returns
            `make(network)`. `make` evaluates `predict` on those fixed samples and returns
            `(signal_yields, background_yields, background_uncertainty)`. The samples are
            attached to `make` as `make.sig`, `make.bkg1` and `make.bkg2`.
    '''
    # identity covariance shared by all three blobs
    cov = [[1, 0], [0, 1]]
    sampler = np.random if rng is None else rng

    def yields(network, events, scale):
        # With softmax outputs the column sums are soft bin counts. Normalising by this
        # sample's own size makes the total yield equal scale * LUMI.
        return predict(network, events).sum(axis=0) * scale / len(events) * LUMI

    def get_hists(network, s, bs):
        s_hist = yields(network, s, sig_scale)
        b_hists = jnp.asarray([yields(network, b, bkg_scale) for b in bs])
        # central background prediction and its spread across the background modes
        b_mean = jnp.mean(b_hists, axis=0)
        b_unc = jnp.std(b_hists, axis=0)
        return s_hist, b_mean, b_unc

    def hist_maker():
        # sampling order (bkg1, bkg2, sig) kept fixed so seeded runs are reproducible
        bkg1 = sampler.multivariate_normal(b1_mean, cov, size=(NMC,))
        bkg2 = sampler.multivariate_normal(b2_mean, cov, size=(NMC,))
        sig = sampler.multivariate_normal(sig_mean, cov, size=(NMC,))

        def make(network):
            return get_hists(network, sig, (bkg1, bkg2))

        make.bkg1 = bkg1
        make.bkg2 = bkg2
        make.sig = sig
        return make

    return hist_maker



In [None]:
#export
# kde experiment

def kde_counts_from_nn_three_blobs(predict, bins, bandwidth=.3, NMC = 500, sig_mean = (-1, 1), b1_mean=(2, 2), b2_mean=(-1, -1), LUMI=10, sig_scale = 2, bkg_scale = 10, rng=None):
    '''
    Exactly the same as `hists_from_nn_three_blobs`, but takes in a regression network, and
    forms a kernel density estimate (kde) for the output. The yields are then calculated as
    the integral of the kde between the bin edges (a difference of Gaussian cdfs), which
    should be specified as an argument to the function.

    Args:
            predict: Decision function for a parameterized observable, called as
            `predict(network, events)`. When evaluated, the output should be one number
            per event, i.e. a regression network or similar.
            bins: Array-like of bin edges (a plain list is accepted).
            bandwidth: Standard deviation of the Gaussian kernel placed on each event.
            NMC: Number of events sampled for each of the three blobs.
            sig_mean: Mean of the signal blob.
            b1_mean: Mean of the first background blob.
            b2_mean: Mean of the second (shifted) background blob.
            LUMI: Luminosity factor applied to every yield.
            sig_scale: Overall normalisation of the signal yields.
            bkg_scale: Overall normalisation of the background yields.
            rng: Optional random source with a `multivariate_normal` method (e.g.
            `np.random.default_rng(seed)`) for reproducible sampling. Defaults to the
            global `np.random` state.

    Returns:
            hist_maker: A zero-argument function that samples the three blobs and returns
            `make(network)`. `make` evaluates `predict` on those fixed samples and returns
            `(signal_yields, background_yields, background_uncertainty)`. The samples are
            attached to `make` as `make.sig`, `make.bkg1` and `make.bkg2`.
    '''
    # Accept any array-like of edges. `.reshape` below would fail on a plain list.
    bins = jnp.asarray(bins)
    # lower/upper bin edges as column vectors of shape (n_bins, 1)
    edge_lo = bins[:-1].reshape(-1, 1)
    edge_hi = bins[1:].reshape(-1, 1)
    # identity covariance shared by all three blobs
    cov = [[1, 0], [0, 1]]
    sampler = np.random if rng is None else rng

    def to_hist(events):
        # Each event contributes a Gaussian of width `bandwidth` centred on its output.
        # Its share of a bin is the Gaussian mass between that bin's edges. Broadcasting
        # (n_bins, 1) edges against (n_events,) gives (n_bins, n_events), summed over events.
        cdf_up = jsc.stats.norm.cdf(edge_hi, loc=events, scale=bandwidth)
        cdf_dn = jsc.stats.norm.cdf(edge_lo, loc=events, scale=bandwidth)
        return (cdf_up - cdf_dn).sum(axis=1)

    def yields(network, events, scale):
        # normalise by this sample's own size, then scale to the target luminosity
        return to_hist(predict(network, events).ravel()) * scale / len(events) * LUMI

    def get_hists(network, s, b1, b2):
        s_hist = yields(network, s, sig_scale)
        b_hists = jnp.asarray([
            yields(network, b1, bkg_scale),
            yields(network, b2, bkg_scale),
        ])
        # central background prediction and its spread across the background modes
        b_mean = jnp.mean(b_hists, axis=0)
        b_unc = jnp.std(b_hists, axis=0)
        return s_hist, b_mean, b_unc

    def hist_maker():
        # sampling order (bkg1, bkg2, sig) kept fixed so seeded runs are reproducible
        bkg1 = sampler.multivariate_normal(b1_mean, cov, size=(NMC,))
        bkg2 = sampler.multivariate_normal(b2_mean, cov, size=(NMC,))
        sig = sampler.multivariate_normal(sig_mean, cov, size=(NMC,))

        def make(network):
            return get_hists(network, sig, bkg1, bkg2)

        make.bkg1 = bkg1
        make.bkg2 = bkg2
        make.sig = sig
        return make

    return hist_maker

In [None]:
#export
import pyhf
# Switch pyhf to its JAX backend so the models built by `nn_hepdata_like` operate on the
# JAX arrays produced by the histogram makers. Note: this is a module-level side effect —
# importing this module changes pyhf's global backend.
pyhf.set_backend(pyhf.tensor.jax_backend())

def nn_hepdata_like(histogram_maker):
    '''
    Returns a function that constructs a typical 'hepdata-like' statistical model
    with signal, background, and background uncertainty yields when evaluated at
    the parameters of the observable.

    Args:
            histogram_maker: A function that, when called, returns a secondary function
            that takes the observable's parameters as argument, and returns the yields
            `(signal, background, background_uncertainty)`.

    Returns:
            nn_model_maker: A function that, when evaluated at the observable's parameters,
            returns `(model, bonlypars)`. `model` is a `pyhf.simplemodels.hepdata_like`
            model built from the yields. `bonlypars` holds the model's suggested initial
            parameters with the parameter of interest set to 0 (background-only). The
            underlying yield function is attached as `nn_model_maker.hm`.
    '''
    # called once: the samples are drawn here and reused for every evaluation
    hm = histogram_maker()

    def nn_model_maker(network):
        s, b, db = hm(network)
        m = pyhf.simplemodels.hepdata_like(s, b, db)
        nompars = jax.numpy.asarray(m.config.suggested_init())
        # Background-only hypothesis: nominal parameters with the POI set to 0.
        # `.at[...].set` is JAX's functional update; `jax.ops.index_update` was removed.
        bonlypars = nompars.at[m.config.poi_index].set(0.0)
        return m, bonlypars

    nn_model_maker.hm = hm
    return nn_model_maker