In [13]:
import json
import pyhf
pyhf.set_backend('jax')

import pymc as pm
import arviz as az

import numpy as np

import pytensor
from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph import Apply, Op

# import aesara
import aesara.tensor as at
# from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify

import jax
from jax import grad, jit, vmap, value_and_grad, random
import jax.numpy as jnp


# import sys
# sys.path.insert(1, '/Users/malinhorstmann/Documents/pyhf_pymc/src')
import MH_inference
import HMC_inference
import prepare_inference

### Model

In [14]:
with open('SRee_SRmm_Srem.json') as serialized:
# with open('ttbar_ljets_xsec_inclusive_pruned.json') as serialized:
    spec = json.load(serialized)

workspace = pyhf.Workspace(spec)
model = workspace.model()
obs = workspace.data(model, include_auxdata=False)
nBins = len(model.expected_actualdata(model.config.suggested_init()))
nPars = len(model.config.suggested_init())

# Prepare the priors for sampling
    # Unconstrained parameters
unconstr_dict = {
    'uncon1': {'type': 'unconstrained', 'type2': 'normal', 'input': [[1], [0.1]]}
    }

    # Create dictionary with all priors (unconstrained, constrained by normal and poisson)
prior_dict = prepare_inference.prepare_priors(model, unconstr_dict)

    # dictionary with keys 'model', 'obs', 'priors', 'precision'

prepared_model = prepare_inference.prepare_model(model=model, observations=obs, precision=0.10, priors=prior_dict)

### General Stuff

In [15]:
@jax.jit
def processed_expData(parameters):
    # a = jnp.stack([jax.jit(model.expected_actualdata(p))[i] for i in range(nBins)])
    a = jnp.stack([model.expected_actualdata(parameters)[i] for i in range(nBins)])
    return a

@jax.jit
def vjp_expData(parameters, vector):
    _,back = jax.vjp(processed_expData, parameters)
    return back(vector)

In [16]:
one_vector = np.full(nBins, 1., dtype='float64')

pars = prepare_inference.priors2pymc(prepared_model)
print(pars.type)



TensorType(float64, (18,))


## Second Version

 - Attention: Input and Output length of `VJPCustomOp(Op)` are super random. I have to add this weird `x`
 by hand, in order to have a three-dim output for `VJPCustomOp()`.

### The gradient Op

In [21]:
class VJPCustomOp(Op):

    def make_node(self, vjp_func, parameters, vector):
        a = np.linspace(0.01, 1, nBins).tolist()
        # a = model.config.suggested_init()
        # pars = at.as_tensor_variable(a)
        self.vjp_func = vjp_expData
        inputs = [pt.as_tensor_variable(a), pt.as_tensor_variable(vector)]
        outputs = [inputs[0].type()]
        
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (parameters, vector) = inputs
        results = vjp_expData(parameters, vector)

        if not isinstance(results, (list, tuple)):
                results = (results,)
                
        for i, r in enumerate(results):
            outputs[i][0] = np.asarray(r)

vjp_custom_op = VJPCustomOp()

In [22]:
# Testing if at-type tensors work
a = np.linspace(0.01, 1, len(model.config.par_names)).tolist()
pars = at.as_tensor_variable(a)
print(pars.type)

# Testing if pm-type Tensors work
pars = prepare_inference.priors2pymc(prepared_model)
print(pars.type)

print(vjp_custom_op(vjp_func=vjp_expData, parameters=pars, vector=one_vector).eval())

# Sampling the gradient
with pm.Model():
    # pars = prepare_inference.priors2pymc(prepared_model)
    mu = vjp_custom_op(vjp_func=vjp_expData, parameters=pars, vector=one_vector).eval()
    pm.Normal("ehh", mu=mu, sigma=0.1)
    post_data = pm.sample(500)
    post_pred = pm.sample_posterior_predictive(post_data)
    prior_pred = pm.sample_prior_predictive(500)


TensorType(float64, (18,))
TensorType(float64, (18,))


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


[0.3844388  0.00761265 0.10609929]


Multiprocess sampling (4 chains in 4 jobs)
NUTS: [ehh]
NUTS: [ehh]


Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 0 seconds.
Sampling: [ehh]


### The non-gradient Op (with grad node)

In [23]:
class CustomOp(Op):
    
    def make_node(self, func, parameters):
        self.func = processed_expData
        inputs = [pt.as_tensor_variable(parameters)]
        outputs = [inputs[0].type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (parameters, ) = inputs
        results = processed_expData(parameters)

        if len(outputs) == 1:
                outputs[0][0] = np.asarray(results)
                return
        for i, r in enumerate(results):
                outputs[i][0] = np.asarray(r)

    def grad(self, vjp_func, parameters, vector):
        return [vjp_custom_op(vjp_func, parameters, vector)]
        
custom_op = CustomOp()


In [24]:
# Testing if at-type tensors work
# a = np.linspace(0.01, 1, len(model.config.par_names)).tolist()
# pars = at.as_tensor_variable(a)
# print(pars.type)

# Testing if pm-type Tensors work
pars = prepare_inference.priors2pymc(prepared_model)
print(pars.type)

print(custom_op.grad(processed_expData, pars, one_vector)[0].eval())

with pm.Model():
    pars = prepare_inference.priors2pymc(prepared_model)
    mu = custom_op.grad(processed_expData, pars, one_vector)[0].eval()
    pm.Normal("ExpData", mu=mu, sigma=0.1)
    
    post_data = pm.sample(500)
    post_pred = pm.sample_posterior_predictive(post_data)
    prior_pred = pm.sample_prior_predictive(500)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


TensorType(float64, (18,))
[0.3844388  0.00761265 0.10609929]


Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Unconstrained, Normals, ExpData]
NUTS: [Unconstrained, Normals, ExpData]


Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 1 seconds.
Sampling: [ExpData, Normals, Unconstrained]


In [25]:
print(model.expected_actualdata(a))
az.summary(post_data, var_names="ExpData")

[0.47245307 0.01439832 0.2079121 ]


Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
ExpData[0],0.384,0.1,0.194,0.565,0.002,0.001,4037.0,1616.0,1.01
ExpData[1],0.008,0.101,-0.185,0.187,0.002,0.003,4340.0,1449.0,1.0
ExpData[2],0.105,0.103,-0.09,0.288,0.001,0.002,5539.0,1489.0,1.0
