In [None]:
# default_exp makers

In [None]:
#export
import jax
import jax.scipy as jsc
import jax.numpy as jnp
import numpy as np
from functools import partial

from neos import models

In [None]:
#export
def hists_from_nn_three_blobs(predict, NMC = 500, sig_mean = [-1, 1], b1_mean=[2, 2], b2_mean=[-1, -1], LUMI=10, sig_scale = 2, bkg_scale = 10):
    """Build a histogram-maker factory for a toy signal + two-background problem.

    Three 2D Gaussian blobs with unit covariance are sampled: a signal blob at
    ``sig_mean`` and two background blobs at ``b1_mean`` and ``b2_mean``.
    Histograms are formed by summing the network output over its leading axis
    (presumably the event axis, e.g. per-event class probabilities -- TODO
    confirm), normalising by the sample size and scaling to ``LUMI``.

    Args:
        predict: callable ``predict(network, events) -> array``.
        NMC: number of Monte-Carlo events sampled per blob.
        sig_mean, b1_mean, b2_mean: 2D means of the signal / background blobs.
        LUMI: luminosity factor applied to every histogram.
        sig_scale, bkg_scale: overall yield multipliers for signal / background.

    Returns:
        ``hist_maker``: calling it samples the blobs once (with NumPy's global
        RNG) and returns ``make(network) -> (signal_hist, bkg_mean, bkg_unc)``.
        The background histogram is the mean over the background samples and
        the uncertainty is their population standard deviation. The sampled
        events are attached as ``make.sig``, ``make.bkg1`` and ``make.bkg2``.
    """

    def _hist(network, events, scale):
        # normalise by this sample's own size, then scale to the luminosity
        return predict(network, events).sum(axis=0) * scale / len(events) * LUMI

    def _get_hists(network, s, *bs):
        """Return (signal_hist, mean background hist, background spread)."""
        s_hist = _hist(network, s, sig_scale)
        b_hists = jnp.stack([_hist(network, b, bkg_scale) for b in bs])
        # the spread between background samples serves as the uncertainty
        b_mean = jnp.mean(b_hists, axis=0)
        b_unc = jnp.std(b_hists, axis=0)
        return s_hist, b_mean, b_unc

    def hist_maker(get_hists=_get_hists):
        """Sample the three blobs once and return a ``make(network)`` closure.

        ``get_hists`` may be overridden by the caller; it is invoked as
        ``get_hists(network, sig, bkg1, bkg2)``.
        """
        cov = [[1, 0], [0, 1]]
        # sampling order (bkg1, bkg2, sig) kept stable for RNG reproducibility
        bkg1 = np.random.multivariate_normal(b1_mean, cov, size=(NMC,))
        bkg2 = np.random.multivariate_normal(b2_mean, cov, size=(NMC,))
        sig = np.random.multivariate_normal(sig_mean, cov, size=(NMC,))

        def make(network):
            return get_hists(network, sig, bkg1, bkg2)

        make.bkg1 = bkg1
        make.bkg2 = bkg2
        make.sig = sig
        return make

    return hist_maker



In [None]:
#export
# kde experiment

def kde_counts_from_nn_three_blobs(predict, bins=np.linspace(0,1,3), bandwidth=.3, NMC = 500, sig_mean = [-1, 1], b1_mean=[2, 2], b2_mean=[-1, -1], LUMI=10, sig_scale = 2, bkg_scale = 10):
    """Build a factory of KDE-smoothed ("soft") histograms for the three-blob toy.

    Each event's network output is replaced by a Gaussian of width
    ``bandwidth`` centred on it. A bin's count is the sum over events of that
    Gaussian's probability mass between the bin's edges, computed as a
    difference of normal CDFs.

    Args:
        predict: callable ``predict(network, events) -> array``. Its output is
            flattened with ``ravel``, so presumably one score per event
            (TODO confirm).
        bins: array-like of bin edges (a plain list is accepted).
        bandwidth: standard deviation of the per-event Gaussian kernel.
        NMC: number of Monte-Carlo events sampled per blob.
        sig_mean, b1_mean, b2_mean: 2D means of the unit-covariance signal
            and background blobs.
        LUMI: luminosity factor applied to every histogram.
        sig_scale, bkg_scale: overall yield multipliers for signal / background.

    Returns:
        ``hist_maker``: a zero-argument callable that samples the blobs once
        (with NumPy's global RNG) and returns
        ``make(network) -> (signal_counts, bkg_mean, bkg_unc)``. The background
        is the mean of the two background samples and the uncertainty is their
        population standard deviation. The sampled events are attached as
        ``make.sig``, ``make.bkg1`` and ``make.bkg2``.
    """
    # jnp.asarray lets plain lists (which have no .reshape) be used as edges
    bins = jnp.asarray(bins)
    edge_lo = bins[:-1]
    edge_hi = bins[1:]

    def to_hist(events):
        """Soft counts per bin: sum over events of the kernel mass in each bin."""
        # edges as a column, so the cdf arrays broadcast to (n_bins, n_events)
        cdf_up = jsc.stats.norm.cdf(edge_hi.reshape(-1, 1), loc=events, scale=bandwidth)
        cdf_dn = jsc.stats.norm.cdf(edge_lo.reshape(-1, 1), loc=events, scale=bandwidth)
        return (cdf_up - cdf_dn).sum(axis=1)

    def get_hists(network, s, b1, b2):
        """Return (signal counts, mean background counts, background spread)."""
        def counts(events, scale):
            # normalise by this sample's own size, then scale to the luminosity
            return to_hist(predict(network, events).ravel()) * scale / len(events) * LUMI

        kde_counts = jnp.stack([
            counts(s, sig_scale),
            counts(b1, bkg_scale),
            counts(b2, bkg_scale),
        ])
        b_mean = jnp.mean(kde_counts[1:], axis=0)
        b_unc = jnp.std(kde_counts[1:], axis=0)
        return kde_counts[0], b_mean, b_unc

    def hist_maker():
        """Sample the three blobs once and return a ``make(network)`` closure."""
        cov = [[1, 0], [0, 1]]
        # sampling order (bkg1, bkg2, sig) kept stable for RNG reproducibility
        bkg1 = np.random.multivariate_normal(b1_mean, cov, size=(NMC,))
        bkg2 = np.random.multivariate_normal(b2_mean, cov, size=(NMC,))
        sig = np.random.multivariate_normal(sig_mean, cov, size=(NMC,))

        def make(network):
            return get_hists(network, sig, bkg1, bkg2)

        make.bkg1 = bkg1
        make.bkg2 = bkg2
        make.sig = sig
        return make

    return hist_maker

In [None]:
#export
import pyhf
pyhf.set_backend(pyhf.tensor.jax_backend())
def nn_hepdata_like(histogram_maker, verbose=False):
    """Wrap a histogram-maker factory into a model builder driven by a network.

    Args:
        histogram_maker: zero-argument callable (for example the ``hist_maker``
            returned by the factories above) that returns
            ``hm(network) -> (signal, bkg, bkg_unc)``.
        verbose: if True, print the histograms every time a model is built.
            This is a debugging aid and is off by default, because the builder
            may be called repeatedly (e.g. inside a training loop).

    Returns:
        ``nn_model_maker(network) -> (model, bonly_pars)``. ``model`` is built
        with ``models.hepdata_like``. ``bonly_pars`` is the model's suggested
        initial parameters with the parameter at ``poi_index`` set to 0
        (background-only). The underlying histogram maker is exposed as
        ``nn_model_maker.hm``.
    """
    hm = histogram_maker()

    def nn_model_maker(network):
        s, b, db = hm(network)
        if verbose:
            print(s, b, db)
        m = models.hepdata_like(s, b, db)  # cf. pyhf.simplemodels.hepdata_like
        bonlypars = jnp.asarray(list(m.config.suggested_init))
        # functional index update; jax.ops.index_update was deprecated and
        # later removed from JAX in favour of the .at[] syntax
        bonlypars = bonlypars.at[m.config.poi_index].set(0.0)
        return m, bonlypars

    nn_model_maker.hm = hm
    return nn_model_maker