In [1]:
import json
import pymc as pm
import arviz as az
import jax.config as config
config.update('jax_enable_x64', True)
import numpy as np

import pytensor
from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph import Apply, Op

import jax
from jax import grad, jit, vmap, value_and_grad, random
import jax.numpy as jnp

import prepare_inference

import pyhf
pyhf.set_backend('jax')



In [2]:
# Toy pyhf model with N identical bins; `obs` holds the per-bin observed counts
# later used as `observed=` for the Poisson likelihood.
N = 5
obs = jnp.array([70.] * N)
model = pyhf.simplemodels.correlated_background(
    [10] * N,  # signal
    [50] * N,  # bkg
    [45] * N,  # bkg_up
    [55] * N,  # bkg_down
)
init_pars = model.config.suggested_init()
model.expected_actualdata(init_pars), init_pars

(Array([60., 60., 60., 60., 60.], dtype=float64), [0.0, 1.0])

In [3]:
class VJPOp(pt.Op):
    """PyTensor Op evaluating a vector-Jacobian product with a (JAX) function.

    Inputs are two float64 vectors, ``at`` (the linearisation point) and
    ``vector`` (the cotangent of the forward output). The single output is a
    float64 vector.

    ``vjp_func(at, vector)`` may return either the cotangent array itself or,
    like a ``jax.vjp`` pullback, a tuple holding one cotangent per primal
    argument. Both forms are accepted.
    """

    itypes = [pt.dvector, pt.dvector]
    otypes = [pt.dvector]

    def __init__(self, vjp_func):
        # Callable (at, vector) -> cotangent(s); e.g. a jax.jit-ed pullback.
        self.vjp_func = vjp_func

    def perform(self, node, inputs, outputs):
        at, vector = inputs
        cotangents = self.vjp_func(at, vector)

        # A jax.vjp pullback returns a tuple. Converting that whole tuple with
        # np.asarray would give a 2-D array, and iterating a bare array would
        # spread its elements over the outputs. Normalise to a sequence with
        # exactly one cotangent per output instead.
        if not isinstance(cotangents, (tuple, list)):
            cotangents = (cotangents,)
        if len(cotangents) != len(outputs):
            raise ValueError(
                f"vjp_func returned {len(cotangents)} cotangent(s) "
                f"for {len(outputs)} output(s)"
            )

        for storage, cotangent in zip(outputs, cotangents):
            storage[0] = np.asarray(cotangent, dtype=np.float64)


class ExpDataOp(pt.Op):
    """PyTensor Op wrapping a forward model: parameter vector -> expected data.

    Parameters
    ----------
    fwd_func : callable
        Maps a float64 parameter vector to the expected-data vector.
    vjp_op : Op, optional
        Op computing the vector-Jacobian product of ``fwd_func``; called as
        ``vjp_op(inputs[0], output_grads[0])`` in :meth:`grad`. When omitted,
        the notebook-global ``vjp_op`` is looked up at gradient time
        (backward-compatible with earlier usage). Pass it explicitly whenever
        ``fwd_func`` is not the pyhf forward model, otherwise the gradient is
        that of the wrong function.
    """

    itypes = [pt.dvector]
    otypes = [pt.dvector]

    def __init__(self, fwd_func, vjp_op=None):
        self.fwd_func = fwd_func
        self.vjp_op = vjp_op

    def perform(self, node, inputs, outputs):
        (parameters,) = inputs
        # `otypes` declares exactly one output, so there is one storage cell.
        outputs[0][0] = np.asarray(self.fwd_func(parameters), dtype=np.float64)

    def grad(self, at_vector, vector):
        """Symbolic gradient via the VJP Op.

        ``at_vector`` is the list of symbolic inputs and ``vector`` the list of
        output gradients; each has one entry since this Op is 1-in/1-out.
        """
        # Fall back to the module-level `vjp_op` to keep the old behaviour.
        vjp = self.vjp_op if self.vjp_op is not None else vjp_op
        return [vjp(at_vector[0], vector[0])]
    
def _pyhf_forward(x):
    """Return the pyhf expected data for the parameter vector ``x``.

    Reads the notebook-global ``model``; this plain function is what gets
    wrapped by ``jax.jit`` and differentiated with ``jax.vjp``.
    """
    expected_data = model.expected_actualdata(x)
    return expected_data

def _pyhf_vjp(at, vector):
    """Pull ``vector`` back through ``_pyhf_forward`` linearised at ``at``.

    Returns whatever the ``jax.vjp`` pullback returns: a tuple with one
    cotangent per primal argument (a 1-tuple here).
    """
    _, pullback = jax.vjp(_pyhf_forward, at)
    return pullback(vector)


# Compile forward model and VJP once, then expose them as PyTensor Ops.
pyhf_fwd_func = jax.jit(_pyhf_forward)
pyhf_vjp_func = jax.jit(_pyhf_vjp)

fwd_op = ExpDataOp(pyhf_fwd_func)
vjp_op = VJPOp(pyhf_vjp_func)

In [4]:
import prepare_inference
from prepare_inference import get_target
def concat_prior_inputs(values):
    """Flatten a list of per-parameter prior inputs into one 1-D float array.

    Each entry may be a scalar or a sequence such as ``[1.]``; scalars are
    promoted to length-1 arrays so they can be concatenated. An empty list
    yields an empty array.
    """
    if not values:
        return np.array([], dtype=np.float64)
    return np.concatenate(
        [np.atleast_1d(np.asarray(v, dtype=np.float64)) for v in values]
    )


def priors2pymc(prepared_model):
    """Create PyMC prior random variables for a prepared pyhf model.

    Must be called inside a ``pm.Model()`` context. It creates at most one
    vector RV per prior type: ``'Unconstrained'`` (Gamma), ``'Normals'``
    (Normal) and ``'Gammas'`` (Gamma).

    Parameters
    ----------
    prepared_model : dict
        Output of ``prepare_inference.prepare_model``. Only the ``'model'``
        and ``'priors'`` keys are read. Each value of ``'priors'`` is a dict
        with a ``'type'`` and an ``'input'`` whose first two entries are the
        distribution's two parameters (alpha/beta or mu/sigma).

    Returns
    -------
    TensorVariable
        1-D tensor of all prior parameters, reordered by
        ``argsort(get_target(model))`` (presumably the pyhf parameter order —
        TODO confirm against ``prepare_inference.get_target``).

    Raises
    ------
    ValueError
        If no 'unconstrained', 'normal' or 'poisson' prior is present.
    """
    pyhf_model = prepared_model['model']
    prior_dict = prepared_model['priors']

    # prior type -> (first-parameter inputs, second-parameter inputs)
    collected = {
        'unconstrained': ([], []),
        'normal': ([], []),
        'poisson': ([], []),
    }
    for sub_dict in prior_dict.values():
        bucket = collected.get(sub_dict['type'])
        if bucket is None:
            # NOTE(review): unrecognised prior types are skipped, as before.
            continue
        bucket[0].append(sub_dict['input'][0])
        bucket[1].append(sub_dict['input'][1])

    # (prior type, distribution, RV name, first kwarg, second kwarg).
    # One vector RV per type: creating an RV per key (as the 'unconstrained'
    # case used to) reuses the same name and PyMC rejects duplicate names.
    rv_specs = (
        ('unconstrained', pm.Gamma, 'Unconstrained', 'alpha', 'beta'),
        ('normal', pm.Normal, 'Normals', 'mu', 'sigma'),
        ('poisson', pm.Gamma, 'Gammas', 'alpha', 'beta'),
    )
    groups = []
    for prior_type, dist, rv_name, first_kw, second_kw in rv_specs:
        firsts = concat_prior_inputs(collected[prior_type][0])
        if firsts.size == 0:
            continue
        seconds = concat_prior_inputs(collected[prior_type][1])
        groups.append(
            dist(rv_name, **{first_kw: firsts.tolist(), second_kw: seconds.tolist()})
        )

    if not groups:
        raise ValueError(
            "prepared_model['priors'] contains no 'unconstrained', "
            "'normal' or 'poisson' priors"
        )

    # Combine symbolically instead of building numpy object arrays of tensors.
    pars = groups[0] if len(groups) == 1 else pt.concatenate(groups)
    order = np.asarray(get_target(pyhf_model)).argsort()
    return pars[order]

In [5]:
import prepare_inference
# User-specified prior for `mu`; 'input' holds the two distribution
# parameters as lists (for the 'unconstrained' type these presumably end up
# as Gamma alpha/beta in priors2pymc — confirm in prepare_inference).
unconstr_dict = {
    'mu': {
        'type': 'unconstrained',
        'input': [[1.], [1.]],
    },
}

prior_dict = prepare_inference.prepare_priors(model, unconstr_dict)
prepared_model = prepare_inference.prepare_model(
    model=model,
    observations=obs,
    precision=1,
    priors=prior_dict,
)

In [7]:
RANDOM_SEED = 1234  # fixed seed so Restart & Run All reproduces the draws
N_DRAWS = 150
N_TUNE = 100

with pm.Model() as pyhf_pymc_model:
    pars = priors2pymc(prepared_model)
    params = pm.Deterministic('params', pars)
    expected = pm.Deterministic('expected', fwd_op(params))
    # `obs` was built with jnp in cell 2; hand PyMC a plain NumPy array.
    data = pm.Poisson('data', expected, observed=np.asarray(obs))

    # NOTE(review): the saved output of this cell ends with
    # "ValueError: different number of dimensions on data and dims: 3 vs 2"
    # after sampling completed; root cause not yet diagnosed — TODO.
    post_data = pm.sample(N_DRAWS, chains=1, tune=N_TUNE, random_seed=RANDOM_SEED)
    post_pred = pm.sample_posterior_predictive(post_data, random_seed=RANDOM_SEED)
    prior_pred = pm.sample_prior_predictive(N_DRAWS, random_seed=RANDOM_SEED)
    

Only 150 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


Sampling 1 chain for 100 tune and 150 draw iterations (100 + 150 draws total) took 0 seconds.


ValueError: different number of dimensions on data and dims: 3 vs 2