In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import time
import pytensor 
import pymc as pm
import arviz as az
import scipy.stats as sps
import pyhf
pyhf.set_backend('jax')

from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph import Apply, Op

import aesara
import aesara.tensor as at
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify
import jax
from jax import grad, jit, vmap, value_and_grad, random
import jax.numpy as jnp

# Using JAX and aesara for HMC sampling of the model.expected_actualdata()

## Set-Up

### Model

In [3]:
### Choose model
n = "DisplacedLeptons"

# Known workspaces: model name -> (path of the serialized pyhf spec, value used for nBins).
# NOTE(review): these are machine-specific absolute paths; consider a configurable
# data directory so the notebook can run on another machine.
workspace_files = {
    "ttbar": (
        "/Users/malinhorstmann/Documents/pyhf_pymc/PredictiveChecks/ttbar_ljets_xsec_inclusive_pruned.json",
        37,
    ),
    "DisplacedLeptons": (
        "/Users/malinhorstmann/Documents/pyhf_pymc/PredictiveChecks/SRee_SRmm_Srem.json",
        3,
    ),
}

# Fail right here for an unknown name instead of hitting a NameError on `spec` further down.
if n not in workspace_files:
    raise ValueError(f"Unknown model {n!r}; choose one of {sorted(workspace_files)}")

spec_path, nBins = workspace_files[n]
with open(spec_path) as serialized:
    spec = json.load(serialized)


### Create pyhf model
workspace = pyhf.Workspace(spec)

model = workspace.model()

### Observations
# Stand-in observations: the model's expected data at its suggested initial parameters.
obs = model.expected_data(model.config.suggested_init())

### Expected data

In [4]:
### Preprocess the model.expected_actualdata
def processedData(p):
    """Stack the first ``nBins`` entries of ``model.expected_actualdata(p)`` into one JAX array.

    Parameters
    ----------
    p
        Parameter values handed straight to ``model.expected_actualdata``.

    Returns
    -------
    The result of ``jnp.stack`` over entries ``0 .. nBins-1`` of the expected data.

    Relies on the notebook-level globals ``model`` and ``nBins``.
    """
    # Evaluate the model a single time and index into that result, rather than
    # re-running the full model once for every bin.
    expected = model.expected_actualdata(p)
    return jnp.stack([expected[i] for i in range(nBins)])

### 

## Op 1

In [5]:
from aesara.graph.op import Op
def make_op(func, itypes, otypes):
    """Wrap a JAX function in a pair of aesara Ops: a forward Op and its VJP Op.

    Parameters
    ----------
    func : callable
        Function of a single argument; it is traced by ``jax.jit`` / ``jax.vjp``.
    itypes : tuple
        Types of the forward Op's inputs.
    otypes : tuple
        Types of the forward Op's outputs. ``itypes + otypes`` is used as the
        input signature of the VJP Op, so both must support ``+`` together.

    Returns
    -------
    (jax_op, jax_grad_op)
        ``jax_op`` evaluates ``func``; ``jax_grad_op`` takes the forward input
        followed by a cotangent vector and returns the vector-Jacobian product.
        ``jax_op.grad`` delegates to ``jax_grad_op`` using only the first input
        and the first output gradient.
    """

    def _as_declared(value, variable):
        # Cast to the dtype declared on the node's output variable so the stored
        # array matches the Op's type even if JAX computed in another precision.
        # Types without a ``dtype`` attribute fall back to numpy's default.
        return np.asarray(value, dtype=getattr(variable.type, "dtype", None))

    @jax.jit
    def vjp_func(fwd_inputs, vector):
        # jax.vjp returns (primal output, pullback); only the pullback is needed.
        _, back = jax.vjp(func, fwd_inputs)
        return back(vector)

    class JaxVJPOp(Op):
        """Op evaluating the vector-Jacobian product of ``func``."""

        __props__ = ("jax_vjp_func",)

        def __init__(self):
            self.jax_vjp_func = vjp_func
            # Inputs: the forward inputs followed by one cotangent per forward output.
            self.itypes = itypes + otypes
            # Outputs: one gradient per forward input.
            self.otypes = itypes
            super().__init__()

        def perform(self, node, inputs, outputs):
            results = self.jax_vjp_func(*(jnp.asarray(x) for x in inputs))

            # Normalise a bare result to a 1-tuple so the loop below handles both cases.
            if not isinstance(results, (list, tuple)):
                results = (results,)

            for i, r in enumerate(results):
                outputs[i][0] = _as_declared(r, node.outputs[i])

    jax_grad_op = JaxVJPOp()

    @jax_funcify.register(JaxVJPOp)
    def jax_funcify_JaxGradOp(op):
        # When the graph is compiled through the JAX linker, use the jitted function directly.
        return op.jax_vjp_func

    @jax.jit
    def fwd_func(fwd_inputs):
        return func(fwd_inputs)

    class JaxOp(Op):
        """Op evaluating ``func`` itself."""

        __props__ = ("fwd_func",)

        def __init__(self):
            self.fwd_func = fwd_func
            self.itypes = itypes
            self.otypes = otypes
            super().__init__()

        def perform(self, node, inputs, outputs):
            results = self.fwd_func(*(jnp.asarray(x) for x in inputs))
            # A single output receives the whole result as one array.
            if len(outputs) == 1:
                outputs[0][0] = _as_declared(results, node.outputs[0])
                return
            for i, r in enumerate(results):
                outputs[i][0] = _as_declared(r, node.outputs[i])

        def grad(self, inputs, vectors):
            # Only the first input and the first output gradient are used.
            return [jax_grad_op(inputs[0], vectors[0])]

    @jax_funcify.register(JaxOp)
    def jax_funcify_JaxOp(op):
        return op.fwd_func

    jax_op = JaxOp()

    return jax_op, jax_grad_op

### Testing the `grad`

In [6]:
### Applying the Op to model.expected_actualdata
# Length of the parameter vector the model expects. This is taken from the
# suggested initial point instead of len(model.config.par_map): par_map is keyed
# per parameter set, so its length need not equal the number of parameters.
n_pars = len(model.config.suggested_init())

op, grad_op = make_op(
    processedData,
    (at.TensorType(dtype=np.float64, shape=(n_pars,)),),
    (at.TensorType(dtype=np.float64, shape=(nBins,)),),
)

### Test for some array of input parameters
a = np.linspace(0.01, 1, n_pars).tolist()
pars = at.as_tensor_variable(a)
# Cotangent of ones with one entry per bin, so it always matches the Op's output length.
grad_op(pars, at.constant(np.ones(nBins))).eval()

array([ 5.22021957e-02,  7.65032179e-03,  9.50228104e-02,  2.61874674e-05,
       -3.44132928e-05, -4.50589683e-06,  5.89919525e-06,  2.11944936e-05,
       -7.58614712e-06,  1.74513037e-02,  1.90408675e-01, -2.65716482e-05,
       -3.39078458e-06,  4.72531697e-06,  3.81544561e-06,  1.61576618e-05,
        4.09487406e-06, -9.99838836e-07])

## Op 2

In [7]:
def func(x):
    """Toy two-output function: ``(x[0]**2 + x[1]**3, x[2]**4)``."""
    first_output = x[0] ** 2 + x[1] ** 3
    second_output = x[2] ** 4
    return first_output, second_output

# Point at which the function and its vector-Jacobian product are evaluated
eval_point = [1.0, 1.0, 1.0]

# jax.vjp hands back the function value together with a pullback function
primals, func_vjp = jax.vjp(func, eval_point)

# Pull the cotangent (2.0, 1.0) back through func; the result has one entry per
# positional argument of func, hence the [0] when printing below
vjp1 = func_vjp((2.0, 1.0))

for label, value in (("primals", primals), ("vjp", vjp1[0])):
    print(f'{label}: {value}')

primals: (Array(2., dtype=float64, weak_type=True), Array(1., dtype=float64, weak_type=True))
vjp: [Array(4., dtype=float64, weak_type=True), Array(6., dtype=float64, weak_type=True), Array(4., dtype=float64, weak_type=True)]


## Inference

- **Attention:** Using `pm.Normal()` tensor variables as input to `grad_op` does not work yet.
- `op` yields approximately the same results as sampling with the "normal" `Op` (see the comparison section below).

In [110]:
### PyMC model built around the JAX-backed forward Op
start_time = time.time()
with pm.Model():
    ## Fixed parameter point: evenly spaced values in [0.01, 1], one per model parameter
    n_model_pars = len(model.config.par_names)
    a = np.linspace(0.01, 1, n_model_pars).tolist()
    pars = at.as_tensor_variable(a)

    ## Normal centred on the forward Op's output; .eval() turns it into concrete numbers first
    # (alternative tried: mu=grad_op(pars, at.constant([1.0, 1.0, 1.0])).eval(), sigma=1)
    # (`observed=obs` is not passed yet)
    expected_counts = op(pars).eval()
    main = pm.Normal("main", mu=expected_counts, sigma=0.001)

    ## Draw 500 samples, progress bar switched off
    post_data = pm.sample(500, progressbar=False)
    # Not run here: pm.sample_prior_predictive(500), pm.sample_posterior_predictive(post_data)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [main]
Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 0 seconds.


In [111]:
# Summary table for the sampling result held in `post_data`
az.summary(post_data)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
main[0],0.472,0.001,0.471,0.474,0.0,0.0,3020.0,1479.0,1.0
main[1],0.014,0.001,0.013,0.016,0.0,0.0,3195.0,1623.0,1.0
main[2],0.208,0.001,0.206,0.21,0.0,0.0,2392.0,1678.0,1.0


## Comparison with "normal" Op

... yields approximately the same results for the expected bin counts as `op` (from `op, grad_op = make_op()`).

In [116]:
from aesara.graph.op import Op

In [118]:
### Op that forwards its input vector to a user-supplied Python callable
class Op(pt.Op):
    """Evaluate ``func`` on one double vector and return one double vector.

    Parameters
    ----------
    name
        Stored on the instance as ``self.name``.
    func
        Callable invoked with the input vector each time the Op is performed
        (in this notebook: ``model.expected_actualdata``).
    """

    # Single double-precision vector in (the parameter values) ...
    itypes = [pt.dvector]
    # ... and a single double-precision vector out (the values returned by func)
    otypes = [pt.dvector]

    def __init__(self, name, func):
        self.name = name
        self.func = func

    def perform(self, node, inputs, outputs):
        """Run ``func`` on the only input and write the result into the output storage."""
        # Exactly one input is expected; the unpacking fails loudly otherwise
        [theta] = inputs

        values = self.func(theta)

        # Store the result converted to the dtype of the node's output variable
        storage = outputs[0]
        storage[0] = np.asarray(values, dtype=node.outputs[0].dtype)

In [119]:
### Applying the Op with arguments (name, function)
mainOp = Op("mainOp", jax.jit(model.expected_actualdata))
# mainOp = Op("mainOp", model.expected_actualdata)

### Opening the PyMC model space
with pm.Model():
    ## One narrow Normal per parameter, centred on the values in `a`
    ## (`a` is the parameter list built in an earlier cell). The count follows
    ## len(a) instead of a hard-coded 18, so it stays correct for other models.
    mu = list(a)
    sigma = [0.0001] * len(mu)
    pars = pm.Normal('n', mu=mu, sigma=sigma)

    ## The vector-valued variable goes to the Op as it is; there is no need to
    ## split it into elements and stitch them back together.
    final = pt.as_tensor_variable(pars)

    ## Model for the model.expected_actualdata()
    # Attention: pm.Poisson breaks down here, because lambda < 0 occurs occasionally if mu=0.0
    main = pm.Normal("main", mu=mainOp(final))  # observed=obs is not passed yet

    ## Sampling
    post_data = pm.sample(500, progressbar=False)
    # prior_data = pm.sample_prior_predictive(500)
    # post_pred = pm.sample_posterior_predictive(post_data)

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Metropolis: [n]
>NUTS: [main]
Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 2 seconds.


In [120]:
# Summary table for the sampling result held in `post_data`
az.summary(post_data)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
n[0],0.01,0.0,0.01,0.01,0.0,0.0,231.0,206.0,1.02
n[1],0.068,0.0,0.068,0.068,0.0,0.0,180.0,330.0,1.02
n[2],0.126,0.0,0.126,0.127,0.0,0.0,318.0,387.0,1.01
n[3],0.185,0.0,0.185,0.185,0.0,0.0,258.0,244.0,1.02
n[4],0.243,0.0,0.243,0.243,0.0,0.0,244.0,283.0,1.03
n[5],0.301,0.0,0.301,0.301,0.0,0.0,252.0,205.0,1.01
n[6],0.359,0.0,0.359,0.36,0.0,0.0,252.0,310.0,1.03
n[7],0.418,0.0,0.417,0.418,0.0,0.0,333.0,222.0,1.03
n[8],0.476,0.0,0.476,0.476,0.0,0.0,270.0,242.0,1.03
n[9],0.534,0.0,0.534,0.534,0.0,0.0,299.0,325.0,1.01
