In [None]:
import pyhf
import jax
from fax.implicit import twophase
# Route pyhf tensor operations (e.g. pyhf.tensorlib.normal_cdf used in cls_jax)
# through JAX so the CLs computation can be differentiated with jax.value_and_grad.
pyhf.set_backend(pyhf.tensor.jax_backend())


class _Config(object):
    def __init__(self):
        self.poi_index = 0
        self.npars = 2
    def suggested_init(self):
        return [1.0,1.0]
    def suggested_bounds(self):
        return [[0.,10.],[0.,10.]]

class Model(object):
    """Single-bin counting model with a Poisson-constrained background scale.

    A small JAX-differentiable stand-in for the subset of the pyhf model API
    used in this notebook (``expected_data``, ``logpdf``, ``config``).
    Parameters are ordered ``[mu, gamma]``: ``mu`` scales the signal yield and
    ``gamma`` scales the nominal background yield.
    """

    def __init__(self, spec):
        """
        Args:
            spec: ``(sig, nominal, uncert)`` -- signal yield, nominal background
                yield and background uncertainty (scalars; see hepdata_like).
        """
        self.sig, self.nominal, self.uncert = spec
        # Poisson-constraint scale factor tau = (nominal / uncert) ** 2.
        self.factor = (self.nominal / self.uncert) ** 2
        # Auxiliary measurement at the nominal background scale (gamma = 1).
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars, include_auxdata=True):
        """Return the expected observations for ``pars = [mu, gamma]``.

        Args:
            pars: two-element sequence ``[mu, gamma]``.
            include_auxdata: if True (default) append the auxiliary
                measurement after the main-channel expectation.

        Returns:
            ``[gamma * nominal + mu * sig, aux]`` when ``include_auxdata`` is
            True, otherwise ``[gamma * nominal + mu * sig]``.
        """
        mu, gamma = pars
        expected_main = jax.numpy.asarray([gamma * self.nominal + mu * self.sig])
        if not include_auxdata:
            return expected_main
        # NOTE(review): the aux entry is fixed at its gamma = 1 value rather than
        # gamma * factor (the constraint rate used in logpdf); the two only
        # coincide at gamma = 1 -- confirm this is intended.
        aux_data = jax.numpy.asarray([self.aux])
        return jax.numpy.concatenate([expected_main, aux_data])

    def logpdf(self, pars, data):
        """Log-likelihood of ``data = [main_obs, aux_obs]`` at ``pars``.

        Sum of the Poisson log-probability of the main-channel count and of the
        Poisson constraint term whose rate is ``gamma * factor``.

        Returns:
            One-element array holding the total log-probability.
        """
        main_obs, aux_obs = data
        expected_main, _ = self.expected_data(pars)
        _, gamma = pars
        main_logp = pyhf.probability.Poisson(expected_main).log_prob(main_obs)
        constraint_logp = pyhf.probability.Poisson(gamma * self.factor).log_prob(aux_obs)
        return jax.numpy.asarray([main_logp + constraint_logp])

def hepdata_like(signal_data, bkg_data, bkg_uncerts, batch_size=None):
    """Build a single-bin ``Model`` from per-bin signal/background inputs.

    The name and signature follow pyhf's ``simplemodels.hepdata_like``, but
    only the FIRST bin of each input is used, and ``batch_size`` is accepted
    purely for signature compatibility and ignored.

    NOTE(review): a two-bin variant would pass ``[[s0, s1], [b0, b1],
    [db0, db1]]``; ``Model`` as written appears to assume scalar yields, so
    it would need extending first.
    """
    first_bin = 0
    spec = [
        signal_data[first_bin],
        bkg_data[first_bin],
        bkg_uncerts[first_bin],
    ]
    return Model(spec)

def get_solvers(model_constructor, learning_rate=0.01, asimov_shift=0.2,
                rtol=1e-10, atol=1e-10, max_iter=1000000):
    """Build differentiable global and constrained maximum-likelihood fitters.

    Both fitters minimise the negative log-likelihood by fixed-step gradient
    descent, wrapped in ``fax.implicit.twophase.two_phase_solver`` so the
    fitted parameters can be differentiated w.r.t. the hyper-parameters.

    Args:
        model_constructor: callable ``nn_pars -> (model, bonlypars)``.
        learning_rate: gradient-descent step size.
        asimov_shift: constant added to every entry (main and auxiliary) of
            the expected background-only data that the fits are performed on.
        rtol, atol, max_iter: convergence settings forwarded to the solvers.

    Returns:
        ``(g_fitter, c_fitter)``; each has signature
        ``(init, hyper_pars) -> fitted_pars`` with
        ``hyper_pars = [constrained_mu, nn_pars]``. ``g_fitter`` fits all
        parameters; ``c_fitter`` keeps ``pars[0]`` fixed to ``constrained_mu``.
    """
    def make_model(hyper_pars):
        """Return ``(constrained_mu, global_nll, constrained_nll)``."""
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)

        # NOTE(review): fits use the background-only expectation shifted by
        # ``asimov_shift``, while cls_jax evaluates its test statistic on the
        # unshifted expectation -- confirm this mismatch is intended.
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True) + asimov_shift

        def expected_logpdf(pars):
            return m.logpdf(pars, exp_bonly_data)

        def global_fit_objective(pars):
            # Negative log-likelihood over all parameters [mu, gamma].
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):
            # Negative log-likelihood with mu pinned to constrained_mu.
            pars = jax.numpy.concatenate([jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return constrained_mu, global_fit_objective, constrained_fit_objective

    def global_bestfit_minimized(hyper_param):
        _, nll, _ = make_model(hyper_param)
        grad_nll = jax.grad(nll)

        def bestfit_via_grad_descent(i, param):
            # ``i`` is presumably the solver's iteration index; unused here.
            return param - grad_nll(param) * learning_rate

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        mu, _, cnll = make_model(hyper_param)
        grad_cnll = jax.grad(cnll)

        def bestfit_via_grad_descent(i, param):
            # Only the nuisance parameters param[1:] are updated; param[0] is
            # reset to the constrained mu on every step.
            nuis = param[1:]
            nuis = nuis - grad_cnll(nuis) * learning_rate
            return jax.numpy.concatenate([jax.numpy.asarray([mu]), nuis])

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=rtol,
        default_atol=atol,
        default_max_iter=max_iter,
    )
    constrained_solver = twophase.two_phase_solver(
        param_func=constrained_bestfit_minimized,
        default_rtol=rtol,
        default_atol=atol,
        default_max_iter=max_iter,
    )

    def g_fitter(init, hyper_pars):
        return global_solve(init, hyper_pars).value

    def c_fitter(init, hyper_pars):
        return constrained_solver(init, hyper_pars).value

    return g_fitter, c_fitter


def cls_jax(hyper_pars, test_mu, mmaker):
    """Expected CLs at ``test_mu`` as a differentiable function of NN params.

    Args:
        hyper_pars: neural-network parameters, forwarded to ``mmaker``.
        test_mu: signal strength being tested.
        mmaker: callable ``nn_pars -> (model, bonlypars)``.

    Returns:
        Scalar ``CLs = CLsb / CLb`` computed from the profile-likelihood test
        statistic evaluated on the background-only expected data.
    """
    g_fitter, c_fitter = get_solvers(mmaker)
    nn_pars = hyper_pars

    m, bonlypars = mmaker(nn_pars)
    exp_data = m.expected_data(bonlypars)
    initval = jax.numpy.asarray([test_mu, 1.0])

    # constrained fit (mu fixed to test_mu) and global fit (mu floating)
    constrained_pars = c_fitter(initval, [test_mu, hyper_pars])
    global_pars = g_fitter(initval, [test_mu, hyper_pars])

    print(f'constrained fit: {constrained_pars}')
    print(f'global fit: {global_pars}')

    # test statistic -2 ln lambda(mu).
    # NOTE(review): evaluated on unshifted exp_data, whereas get_solvers fits
    # a shifted copy of it -- confirm intended.
    profile_likelihood = -2 * (
        m.logpdf(constrained_pars, exp_data)[0] - m.logpdf(global_pars, exp_data)[0]
    )

    # Exclusion convention: the statistic is zeroed when muhat >= test_mu.
    # Non-positive values (e.g. tiny negative fit noise) are also zeroed rather
    # than producing sqrt(negative) = NaN. The inner where keeps sqrt's argument
    # away from 0, so the zeroed branch has gradient 0 instead of NaN (inf * 0).
    muhat = global_pars[0]
    use_qmu = jax.numpy.logical_and(muhat < test_mu, profile_likelihood > 0.0)
    safe_qmu = jax.numpy.where(use_qmu, profile_likelihood, 1.0)
    sqrtqmu = jax.numpy.where(use_qmu, jax.numpy.sqrt(safe_qmu), 0.0)

    # Asymptotic CLs; with the alternative value at 0, CLb = 1 - Phi(0) = 0.5.
    nullval = sqrtqmu
    altval = 0
    CLsb = 1 - pyhf.tensorlib.normal_cdf(nullval)
    CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
    CLs = CLsb / CLb
    return CLs

#def cls_test(hyper_pars,mmaker):
#    test_mu, line_pars = hyper_pars[0], hyper_pars[1:]
#    m,bonlypars = mmaker(line_pars)
#    exp_data = m.expected_data(bonlypars)
#    return pyhf.infer.hypotest(test_mu,exp_data,m)


def get_event_data(keys, n_samples=5000):
    """Sample toy signal and two background event samples in a 2-D feature space.

    Args:
        keys: three integer seeds ``(k1, k2, k3)`` for the signal, first
            background and second background samples respectively.
        n_samples: number of events drawn per sample (default 5000).

    Returns:
        ``(sig, bkg1, bkg2)``, each an array of shape ``(n_samples, 2)``.
    """
    k1, k2, k3 = keys

    def _draw(seed, mean, cov):
        # Sample with a leading batch axis of 1, then drop that axis.
        return jax.random.multivariate_normal(
            jax.random.PRNGKey(seed),
            jax.numpy.asarray(mean),
            jax.numpy.asarray(cov),
            shape=(1, n_samples),
        )[0]

    sig = _draw(k1, [2, 5], [[1, 0.], [0., 1]])
    bkg1 = _draw(k2, [4, 6], [[1, 0.6], [0.6, 1]])
    bkg2 = _draw(k3, [5.5, 4.5], [[1.7, 0.2], [0.2, 1]])
    return sig, bkg1, bkg2


def real_model_maker(nn_params, hist_scale=10.0):
    """Build the single-bin ``Model`` from neural-network-derived histograms.

    Args:
        nn_params: network parameters (the driver cell passes ``nn.params`` of
            a ``three_blob_classifier``).
        hist_scale: divisor applied to the signal, background and background
            uncertainty histograms (default 10).

    Returns:
        ``(m, bonlypars)`` where ``bonlypars`` is the model's suggested initial
        point with the parameter of interest set to 0.
    """
    nn = three_blob_classifier()
    sig_events, bkg1_events, bkg2_events = get_event_data([1, 2, 3])
    s, b, db = hists_from_nn_uncert(nn, nn_params, sig_events, bkg1_events, bkg2_events)
    s, b, db = s / hist_scale, b / hist_scale, db / hist_scale
    print(f'model: {s},{b},{db}')
    m = hepdata_like(s, b, db)
    # Background-only point: suggested init with the POI zeroed. Done on a plain
    # list so it does not depend on jax.ops.index_update (removed in newer JAX).
    nompars = list(m.config.suggested_init())
    nompars[m.config.poi_index] = 0.0
    bonlypars = jax.numpy.asarray(nompars)
    return m, bonlypars



In [None]:
from fullstream.nn import three_blob_classifier
from fullstream.stats import hists_from_nn_uncert

In [None]:
# Train the classifier briefly before computing CLs. To use untrained,
# randomly initialised parameters instead, skip the train() call.
N_TRAIN_EPOCHS = 1
nn = three_blob_classifier()
nn.train(num_epochs=N_TRAIN_EPOCHS)

In [None]:
# Expected CLs at mu = 1 and its gradient w.r.t. the network parameters
# (jax.value_and_grad differentiates w.r.t. the first argument).
TEST_MU = 1.
cls_value_and_grad = jax.value_and_grad(cls_jax)
cls_value_and_grad(nn.params, TEST_MU, real_model_maker)

# Alternative with untrained, randomly initialised parameters:
# _, init_params = nn.init_random_params(jax.random.PRNGKey(12), (-1, 2))
# cls_jax(init_params, 1., real_model_maker)

model: Traced<ConcreteArray([3.20238519 6.79761481])>with<JVPTrace(level=2/0)>,Traced<ConcreteArray([43.13707402  6.86292598])>with<JVPTrace(level=2/0)>,Traced<ConcreteArray([4.44113382 4.44113382])>with<JVPTrace(level=2/0)>
model: Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>
model: Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>
model: Traced<ConcreteArray([3.20238519 6.79761481])>with<JVPTrace(level=5/0)>,Traced<ConcreteArray([43.13707402  6.86292598])>with<JVPTrace(level=5/0)>,Traced<ConcreteArray([4.44113382 4.44113382])>with<JVPTrace(level=5/0)>
model: Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>,Traced<ShapedArray(float64[2]):JaxprTrace(level=3/0)>
model: Traced<ShapedArray(float64[2]):Jaxpr

(DeviceArray(0.68926019, dtype=float64),
 [(DeviceArray([[ 3.2891934e-05,  7.4151361e-01, -1.3925934e-02,
                 -5.5577046e-01,  1.7432596e-01],
                [-7.0972915e-04,  1.5240345e+00, -2.8729323e-02,
                 -1.1507896e+00,  3.6045146e-01]], dtype=float32),
   DeviceArray([-1.4457165e-04,  2.9077259e-01, -5.4704603e-03,
                -2.1916847e-01,  6.8680577e-02], dtype=float32)),
  (),
  (DeviceArray([[-1.4179809e-05,  1.4179809e-05],
                [-2.8190127e-01,  2.8190127e-01],
                [-3.6950070e-02,  3.6950070e-02],
                [-6.2024558e-01,  6.2024558e-01],
                [-5.3603518e-01,  5.3603518e-01]], dtype=float32),
   DeviceArray([-0.17807451,  0.17807451], dtype=float32)),
  ()])