In [74]:
import json
import pyhf
pyhf.set_backend('jax')

import pymc as pm

import numpy as np

import pytensor
from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph import Apply, Op

# import aesara
import aesara.tensor as at
# from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify

import jax
from jax import grad, jit, vmap, value_and_grad, random
import jax.numpy as jnp


# import sys
# sys.path.insert(1, '/Users/malinhorstmann/Documents/pyhf_pymc/src')
import MH_inference
import HMC_inference
import prepare_inference

### Model

In [34]:
with open('SRee_SRmm_Srem.json') as serialized:
    spec = json.load(serialized)

workspace = pyhf.Workspace(spec)
model = workspace.model()
obs = workspace.data(model, include_auxdata=False)
nBins = len(model.expected_actualdata(model.config.suggested_init()))

### General stuff

In [106]:
@jax.jit
def processed_expData(parameters):
    # a = jnp.stack([jax.jit(model.expected_actualdata(p))[i] for i in range(nBins)])
    a = jnp.stack([model.expected_actualdata(parameters)[i] for i in range(nBins)])
    return a

@jax.jit
def vjp_expData(parameters, vector):
    _,back = jax.vjp(processed_expData, parameters)
    return back(vector)

In [85]:
pars = prepare_inference.priors2pymc(prepared_model)
print(pars.type)

TensorType(float64, (18,))


## First version

### Creating the Ops

In [109]:
def make_op(func, itypes, otypes):
    """

    """
    @jax.jit
    def vjp_func(parameters, vector):
        _,back = jax.vjp(func, parameters)
        return back(vector)

    class JaxVJPOp(Op):
        __props__ = ("jax_vjp_func",)

        def __init__(self):
            self.jax_vjp_func = vjp_func
            self.itypes = itypes + otypes
            self.otypes = itypes
            super().__init__()

        def perform(self, node, inputs, outputs):

            results = self.jax_vjp_func(*(jnp.asarray(x) for x in inputs))

            if not isinstance(results, (list, tuple)):
                results = (results,)

            for i, r in enumerate(results):
                outputs[i][0] = np.asarray(r)


    jax_grad_op = JaxVJPOp()
                
    @jax_funcify.register(JaxVJPOp)
    def jax_funcify_JaxGradOp(op):
        return op.jax_vjp_func

    @jax.jit
    def fwd_func(fwd_inputs):
        return func(fwd_inputs)
    
    class JaxOp(Op):
        __props__ = ("fwd_func",)

        def __init__(self):
            self.fwd_func = fwd_func
            self.itypes = itypes
            self.otypes = otypes
            super().__init__()

        def perform(self, node, inputs, outputs):
            results = self.fwd_func(*(jnp.asarray(x) for x in inputs))
            if len(outputs) == 1:
                outputs[0][0] = np.asarray(results)
                return
            for i, r in enumerate(results):
                outputs[i][0] = np.asarray(r)

        def grad(self, inputs, vectors):
            return [jax_grad_op(inputs[0], vectors[0])]

    # @jax_funcify.register(JaxOp)
    # def jax_funcify_JaxOp(op):
    #     return op.fwd_func

    jax_op = JaxOp()
    
    return jax_op, jax_grad_op


In [110]:
### Appling the Op to model.expected_actualdata
op, grad_op = HMC_inference.make_op(
    processed_expData,
    (at.TensorType(dtype=np.float64, shape=(len(model.config.par_map),)),),
    (at.TensorType(dtype=np.float64, shape=(nBins,)),),
)

a = np.linspace(0.01, 1, len(model.config.par_names)).tolist()
pars = at.as_tensor_variable(a)
print(pars.type)
print(grad_op(pars, at.constant([1.0, 1.0, 1.0])).eval())

print(op(pars).eval())


TensorType(float64, (18,))
[ 5.22021957e-02  7.65032179e-03  9.50228104e-02  2.61874674e-05
 -3.44132928e-05 -4.50589683e-06  5.89919525e-06  2.11944936e-05
 -7.58614712e-06  1.74513037e-02  1.90408675e-01 -2.65716482e-05
 -3.39078458e-06  4.72531697e-06  3.81544561e-06  1.61576618e-05
  4.09487406e-06 -9.99838836e-07]
[0.47245307 0.01439832 0.2079121 ]


In [111]:
# Prepare the priors for sampling
    # Unconstrained parameters
unconstr_dict = {
    'uncon1': {'type': 'unconstrained', 'type2': 'normal', 'input': [[1], [0.1]]}
    }

    # Create dictionary with all priors (unconstrained, constrained by normal and poisson)
prior_dict = prepare_inference.prepare_priors(model, unconstr_dict)

    # dictionary with keys 'model', 'obs', 'priors', 'precision'
prepared_model = prepare_inference.prepare_model(model=model, observations=obs, precision=0.10, priors=prior_dict)

In [115]:
# Sampling
with pm.Model():
    
    pars = prepare_inference.priors2pymc(prepared_model)
    print(pars.type)

    mu = op(pars)
    main = pm.Normal("main", mu=mu)#, observed=obs)


TensorType(float64, (18,))


TypeError: Invalid input types for Op JaxOp{fwd_func=<CompiledFunction of <function make_op.<locals>.fwd_func at 0x28bcee0d0>>}:
Input 1/1: Expected TensorType(float64, (18,)), got TensorType(float64, (18,))

## Second version

### The gradient Op

In [107]:
class VJPCustomOp(Op):

    def make_node(self, vjp_func, parameters, vector):
        self.func = vjp_expData
        inputs = [pt.as_tensor_variable(x), pt.as_tensor_variable(vector)]
        outputs = [inputs[0].type()]
        
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (parameters, vector) = inputs
        results = vjp_expData(parameters, vector)

        if not isinstance(results, (list, tuple)):
                results = (results,)
                
        for i, r in enumerate(results):
            outputs[i][0] = np.asarray(r)

vjp_custom_op = VJPCustomOp()

In [120]:
# Testing if at-type tensors work
a = np.linspace(0.01, 1, len(model.config.par_names)).tolist()
pars = at.as_tensor_variable(a)
print(pars.type)

# Testing if pm-type Tensors work
pars = prepare_inference.priors2pymc(prepared_model)
print(pars.type)

print(vjp_custom_op(vjp_func=vjp_expData, parameters=pars, vector=pt.as_tensor_variable([1.0, 1.0, 1.0])).eval())

# Sampling the gradient
with pm.Model():
    pars = prepare_inference.priors2pymc(prepared_model)
    mu = vjp_custom_op(vjp_func=vjp_expData, parameters=pars, vector=pt.as_tensor_variable([1.0, 1.0, 1.0])).eval()
    pm.Normal("test", mu=mu, sigma=0.1)
    pm.sample(100)


Only 100 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


TensorType(float64, (18,))
TensorType(float64, (18,))
[-0.45348924  0.40669046  0.09532455]


Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Unconstrained, Normals, test]
NUTS: [Unconstrained, Normals, test]


Sampling 4 chains for 1_000 tune and 100 draw iterations (4_000 + 400 draws total) took 1 seconds.


### The non-gradient Op (with grad node)

In [103]:
class CustomOp(Op):
    
    def make_node(self, func, parameters):
        self.func = processed_expData
        inputs = [pt.as_tensor_variable(parameters)]
        outputs = [inputs[0].type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        results = vjp_expData(parameters, vector)
        
        



In [29]:
def make_op(func, itypes, otypes):
    """

    """
    @jax.jit
    def vjp_func(fwd_inputs, vector):
        _,back = jax.vjp(func,fwd_inputs)
        return back(vector)

    class JaxVJPOp(Op):
        __props__ = ("jax_vjp_func",)

        def __init__(self):
            self.jax_vjp_func = vjp_func
            self.itypes = itypes + otypes
            self.otypes = itypes
            super().__init__()

        def perform(self, node, inputs, outputs):

            results = self.jax_vjp_func(*(jnp.asarray(x) for x in inputs))

            if not isinstance(results, (list, tuple)):
                results = (results,)

            for i, r in enumerate(results):
                outputs[i][0] = np.asarray(r)


    jax_grad_op = JaxVJPOp()
                
    @jax_funcify.register(JaxVJPOp)
    def jax_funcify_JaxGradOp(op):
        return op.jax_vjp_func

    @jax.jit
    def fwd_func(fwd_inputs):
        return func(fwd_inputs)
    
    class JaxOp(Op):
        __props__ = ("fwd_func",)

        def __init__(self):
            self.fwd_func = fwd_func
            self.itypes = itypes
            self.otypes = otypes
            super().__init__()

        def perform(self, node, inputs, outputs):
            results = self.fwd_func(*(jnp.asarray(x) for x in inputs))
            if len(outputs) == 1:
                outputs[0][0] = np.asarray(results)
                return
            for i, r in enumerate(results):
                outputs[i][0] = np.asarray(r)

        def grad(self, inputs, vectors):
            return [jax_grad_op(inputs[0], vectors[0])]

    @jax_funcify.register(JaxOp)
    def jax_funcify_JaxOp(op):
        return op.fwd_func

    jax_op = JaxOp()
    
    return jax_op, jax_grad_op

## Test

In [23]:
def custom_op_jax(x):
    return jnp.exp(x)

jitted_custom_op_jax = jax.jit(custom_op_jax)

def vjp_custom_op_jax(x, gz):
    _, vjp_fn = jax.vjp(custom_op_jax, x)
    return vjp_fn(gz)[0]

jitted_vjp_custom_op_jax = jax.jit(vjp_custom_op_jax)

In [26]:
class CustomOp(Op):
    def make_node(self, ):
        # Create a PyTensor node specifying the number and type of inputs and outputs

        # We convert the input into a PyTensor tensor variable
        inputs = [pt.as_tensor_variable(x)]
        # Output has the same type and shape as `x`
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # Evaluate the Op result for a specific numerical input

        # The inputs are always wrapped in a list
        (x,) = inputs
        result = jitted_custom_op_jax(x)
        # The results should be assigned inplace to the nested list
        # of outputs provided by PyTensor. If you have multiple
        # outputs and results, you should assign each at outputs[i][0]
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        # Create a PyTensor expression of the gradient
        (x,) = inputs
        (gz,) = output_gradients
        # We reference the VJP Op created below, which encapsulates
        # the gradient operation
        return [vjp_custom_op(x, gz)]


class VJPCustomOp(Op):
    def make_node(self, x, gz):
        # Make sure the two inputs are tensor variables
        inputs = [pt.as_tensor_variable(x), pt.as_tensor_variable(gz)]
        # Output has the shape type and shape as the first input
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (x, gz) = inputs
        result = jitted_vjp_custom_op_jax(x, gz)
        outputs[0][0] = np.asarray(result, dtype="float64")

# Instantiate the Ops
custom_op = CustomOp()
vjp_custom_op = VJPCustomOp()

In [27]:
pytensor.gradient.verify_grad(custom_op, (np.arange(5, dtype="float64"),), rng=np.random.default_rng())

In [29]:
with pm.Model() as model:
    x = pm.Normal("x", shape=(3,))
    y = pm.Deterministic("y", custom_op(x))  # HERE IS WHERE WE USE THE CUSTOM OP!
    z = pm.Normal("z", y, observed=[1, 2, 0])
    pm.sample(500)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Initializing NUTS using jitter+adapt_diag...
Ambiguities exist in dispatched function _unify

The following signatures may result in ambiguous behavior:
	[ConstrainedVar, object, Mapping], [object, ConstrainedVar, Mapping]
	[ConstrainedVar, object, Mapping], [object, ConstrainedVar, Mapping]
	[ConstrainedVar, Var, Mapping], [object, ConstrainedVar, Mapping]
	[ConstrainedVar, Var, Mapping], [object, ConstrainedVar, Mapping]


Consider making the following additions:

@dispatch(ConstrainedVar, ConstrainedVar, Mapping)
def _unify(...)

@dispatch(ConstrainedVar, ConstrainedVar, Mapping)
def _unify(...)

@dispatch(ConstrainedVar, ConstrainedVar, Mapping)
def _unify(...)

@dispatch(ConstrainedVar, ConstrainedVar, Mapping)
def _unify(...)
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]


Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 1 seconds.
