In [1]:
%load_ext autoreload
%autoreload 2

import json


import pymc as pm
import arviz as az

import numpy as np

import pytensor
from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph import Apply, Op

# import aesara
import aesara.tensor as at
# from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify

import jax
from jax import grad, jit, vmap, value_and_grad, random
import jax.numpy as jnp


# import sys
# sys.path.insert(1, '/Users/malinhorstmann/Documents/pyhf_pymc/src')
import MH_inference
import HMC_inference
import prepare_inference

import matplotlib.pyplot as plt

import pyhf
pyhf.set_backend('jax')

## Model

In [2]:
### Simple pyhf model
# Toy single-channel, three-bin model: a signal sample scaled by the
# normfactor `mu` and a background sample carrying all constrained modifiers.

# Three correlated shape variations that share the same up/down templates.
histosys_modifiers = [
    {
        'name': name,
        'type': 'histosys',
        'data': {'hi_data': [65, 56, 67], 'lo_data': [40, 40, 43]},
    }
    for name in ('corr_bkg', 'corr_bkg1', 'corr_bkg2')
]

# Three uncorrelated per-bin shape uncertainties with identical inputs.
shapesys_modifiers = [
    {'name': name, 'type': 'shapesys', 'data': [7, 8, 7.17]}
    for name in ('uncorr_bkg', 'uncorr_bkg1', 'uncorr_bkg2')
]

signal_sample = {
    'name': 'signal',
    'data': [6, 6, 60],
    'modifiers': [{'name': 'mu', 'type': 'normfactor', 'data': None}],
}

background_sample = {
    'name': 'background',
    'data': [55, 55, 55],
    'modifiers': [
        ## Staterror / Normal
        {'name': 'my_staterror', 'type': 'staterror', 'data': [2.0, 2.0, 2.0]},
        ## Lumi / Normal
        {'name': 'lumi', 'type': 'lumi', 'data': None},
        ## Correlated / Normal
        *histosys_modifiers,
        ## Uncorrelated / Poisson
        *shapesys_modifiers,
    ],
}

model = pyhf.Model(
    {
        'channels': [
            {
                'name': 'singlechannel',
                'samples': [signal_sample, background_sample],
            },
        ],
        'parameters': [
            {
                'name': 'lumi',
                'auxdata': [1.0],
                'sigmas': [0.017],
                'bounds': [[0.915, 1.085]],
                'inits': [1.0],
            },
        ],
    }
)

# Pseudo-data: the model expectation at the suggested initial parameter values.
obs = model.expected_actualdata(model.config.suggested_init())

# Number of bins and number of model parameters.
nBins = len(obs)
nPars = len(model.config.suggested_init())

In [None]:
# Alternative to the toy model: build the model from a serialized pyhf workspace.
# Running this cell rebinds `model`, `obs`, `nBins` and `nPars`, so every cell
# below then works on the workspace model instead of the toy model.
# with open('SRee_SRmm_Srem.json') as serialized:
with open('ttbar_ljets_xsec_inclusive_pruned.json') as serialized:
    spec = json.load(serialized)

workspace = pyhf.Workspace(spec)
model = workspace.model()

# Observed data of the workspace, with the auxiliary data left out.
obs = workspace.data(model, include_auxdata=False)

# Number of model parameters and number of bins.
nPars = len(model.config.suggested_init())
nBins = len(model.expected_actualdata(model.config.suggested_init()))

In [3]:
# Prepare the priors for sampling.

# Unconstrained parameters.
# NOTE(review): 'input' presumably holds [[mean], [sigma]] of the normal prior — confirm against prepare_inference.
unconstr_dict = {
    'uncon1': {
        'type': 'unconstrained',
        'type2': 'normal',
        'input': [[1], [0.1]],
    },
}

# Dictionary with all priors (unconstrained, constrained by normal and poisson).
prior_dict = prepare_inference.prepare_priors(model, unconstr_dict)

# Dictionary with keys 'model', 'obs', 'priors', 'precision'.
prepared_model = prepare_inference.prepare_model(
    model=model,
    observations=obs,
    precision=1,
    priors=prior_dict,
)

## General Functions

In [4]:
# Jax expected data
@jax.jit
def processed_expData(parameters):
    """Return the model's expected data for ``parameters`` as a stacked JAX array.

    Parameters
    ----------
    parameters : array-like
        Model parameter values, forwarded unchanged to
        ``model.expected_actualdata``.

    Returns
    -------
    jax.Array
        Entries ``0 .. nBins - 1`` of the expected data, stacked along a new
        leading axis.
    """
    # Evaluate the model once and index the result; the previous version
    # re-evaluated the full model for every single bin.
    expected = model.expected_actualdata(parameters)
    return jnp.stack([expected[i] for i in range(nBins)])

# Cotangent used for the VJP below: a vector of ones with one entry per bin.
one_vector = np.ones(nBins, dtype='float64')

# Vector-Jacobian product of the expected data, contracted with `one_vector`.
@jax.jit
def vjp_expData(parameters):
    """Pull the all-ones cotangent back through ``processed_expData``.

    Because the cotangent is a vector of ones, each entry of the result is the
    derivative of the expected data with respect to one parameter, summed over
    the bins; it is not the full per-bin Jacobian.

    Parameters
    ----------
    parameters : array-like
        Point at which ``processed_expData`` is linearised.

    Returns
    -------
    tuple
        Whatever the pullback returned by ``jax.vjp`` yields for
        ``one_vector``: a tuple with one cotangent per primal input, i.e. a
        single array here.
    """
    _, pullback = jax.vjp(processed_expData, parameters)
    return pullback(one_vector)

In [5]:
# Parameter options
eval_point = model.config.suggested_init()
pars = prepare_inference.priors2pymc(prepared_model)

## Hamiltonian MC

### Gradient Op

The Op takes a single input, which can be either the concrete parameter vector `eval_point` or the PyMC prior variables `pars`.

In [6]:
class VJPOp(pt.Op):
    """PyTensor Op that evaluates the JAX function ``vjp_expData``.

    The Op has one tensor input (the model parameters) and one vector output
    (the result of ``vjp_expData`` at those parameters).
    """

    def make_node(self, parameters):
        """Create the Apply node for ``parameters``.

        ``parameters`` is converted with ``pt.as_tensor_variable``, so both
        concrete values and symbolic variables are accepted. The node has a
        single vector output.
        """
        inputs = [pt.as_tensor_variable(parameters)]
        outputs = [pt.vector()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        """Evaluate ``vjp_expData`` numerically and store the result.

        Parameters
        ----------
        node : Apply
            The node being evaluated; used to look up the declared output dtype.
        inputs : list
            One element: the numerical parameter values.
        outputs : list
            Output storage cells; ``outputs[i][0]`` receives the i-th result.
        """
        (parameters,) = inputs
        results = vjp_expData(parameters)

        # Accept both a bare array and a tuple/list of arrays.
        if not isinstance(results, (list, tuple)):
            results = (results,)

        for i, r in enumerate(results):
            # Cast to the dtype declared in make_node, as ExpDataOp_MH1 does,
            # so the stored array always matches the output variable's type.
            outputs[i][0] = np.asarray(r, dtype=node.outputs[i].dtype)

vjp_op = VJPOp()

### Non-Gradient Op

The Op takes a single input, which can be either the concrete parameter vector `eval_point` or the PyMC prior variables `pars`.

In [11]:
class ExpDataOp(pt.Op):
    """PyTensor Op that evaluates the JAX function ``processed_expData``.

    The Op has one tensor input (the model parameters) and one vector output
    (the expected data at those parameters).
    """

    def make_node(self, parameters):
        """Create the Apply node for ``parameters``.

        ``parameters`` is converted with ``pt.as_tensor_variable``, so both
        concrete values and symbolic variables are accepted. The node has a
        single vector output.
        """
        inputs = [pt.as_tensor_variable(parameters)]
        outputs = [pt.vector()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        """Evaluate ``processed_expData`` numerically and store the result.

        Parameters
        ----------
        node : Apply
            The node being evaluated; used to look up the declared output dtype.
        inputs : list
            One element: the numerical parameter values.
        outputs : list
            Output storage cells; ``outputs[0][0]`` receives the expected data.
        """
        (parameters,) = inputs
        results = processed_expData(parameters)

        # make_node always declares exactly one output, so the whole array goes
        # into that single storage cell (the former multi-output loop could
        # never run). Cast to the declared dtype, as ExpDataOp_MH1 does.
        outputs[0][0] = np.asarray(results, dtype=node.outputs[0].dtype)

    # NOTE(review): no grad() is defined, so PyTensor cannot differentiate
    # through this Op. The earlier stub `grad(self, parameters)` returning
    # `[vjp_op(parameters)]` does not match PyTensor's
    # `grad(self, inputs, output_grads)` signature, and `vjp_op` always uses an
    # all-ones cotangent instead of `output_grads`, so it cannot be enabled as
    # written.

expData_op = ExpDataOp()

In [9]:
# class ExpDataOp(pt.Op):
#     """

#     """
#     itypes = [pt.dvector]  
#     otypes = [pt.dvector]  

#     # def __init__(self, name, func):
#     #     ## Add inputs as class attributes
#     #     self.func = func
#     #     self.name = name

#     def perform(self, node, inputs, outputs):
#         ## Method that is used when calling the Op
#         (theta,) = inputs  # Contains my variables

#         ## Calling input function (in our case the model.expected_actualdata)
#         result = self.func(theta)

#         ## Output values of model.expected_actualdata
#         outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

#     def grad(self, parameters):
#         return [vjp_op(parameters)]

### Sampling

In [13]:
with pm.Model():
    # Create the priors inside this model context. The return value is not
    # kept because `pars` is set to a fixed point on the next line.
    prepare_inference.priors2pymc(prepared_model)
    # NOTE(review): `pars` is the constant suggested-init point here and the Op
    # output is evaluated eagerly with `.eval()`, so `mu` is a fixed array and
    # the likelihood does not depend on the sampled priors — confirm intended.
    pars = model.config.suggested_init()

    expData_op = ExpDataOp()
    nominal_mu = expData_op(pars).eval()
    ExpData = pm.Normal("ExpData", mu=nominal_mu, sigma=1, observed=obs)

    # Posterior draws, then posterior- and prior-predictive samples.
    post_data = pm.sample(1500)
    post_pred = pm.sample_posterior_predictive(post_data)
    prior_pred = pm.sample_prior_predictive(1500)

# Nominal expectation, shown for comparison with the predictive summary below.
print(model.expected_actualdata(model.config.suggested_init()))
az.summary(post_pred, var_names="ExpData")

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Unconstrained, Normals, Gammas]


Sampling 4 chains for 1_000 tune and 1_500 draw iterations (4_000 + 6_000 draws total) took 2 seconds.


Sampling: [ExpData, Gammas, Normals, Unconstrained]


[ 61.  61. 115.]




Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
ExpData[0],60.99,0.994,59.186,62.935,0.013,0.009,6053.0,5928.0,1.0
ExpData[1],60.983,0.993,59.14,62.814,0.013,0.009,5896.0,5930.0,1.0
ExpData[2],115.011,0.993,113.126,116.858,0.013,0.009,5704.0,5712.0,1.0


## Metropolis-Hastings

### Non-Gradient Op

In [26]:
class ExpDataOp_MH1(pt.Op):
    """Gradient-free PyTensor Op around a Python callable.

    The callable is given at construction time and is applied to the numerical
    parameter vector whenever the Op is evaluated. No ``grad`` is defined.
    """

    # One double-precision vector in, one double-precision vector out.
    itypes = [pt.dvector]
    otypes = [pt.dvector]

    def __init__(self, name, func):
        """Store the Op's label and the callable to evaluate.

        Parameters
        ----------
        name : str
            Label stored on the Op.
        func : callable
            Function applied to the parameter vector in ``perform``
            (in this notebook, the model's ``expected_actualdata``).
        """
        self.name = name
        self.func = func

    def perform(self, node, inputs, outputs):
        """Apply ``self.func`` to the input vector and store the result.

        The result is converted to a NumPy array with the dtype declared for
        the Op's output.
        """
        (parameter_values,) = inputs

        expected = self.func(parameter_values)

        out_dtype = node.outputs[0].dtype
        outputs[0][0] = np.asarray(expected, dtype=out_dtype)

### Sampling

In [31]:
with pm.Model():

    # Priors created inside this model context; they feed the Op below.
    pars = prepare_inference.priors2pymc(prepared_model)
    # pars = eval_point

    # JIT-compile the model expectation and wrap it in the gradient-free Op.
    jitted_expected_data = jax.jit(model.expected_actualdata)
    ExpData_op_MH = ExpDataOp_MH1('mainOp', jitted_expected_data)

    # Likelihood: unit-width normal around the Op output, observed at `obs`.
    ExpData_MH = pm.Normal('ExpData', mu=ExpData_op_MH(pars), sigma=1, observed=obs)

    # No step method is passed; the recorded output below shows PyMC
    # assigning Metropolis steps.
    post_data_MH = pm.sample(500)
    post_pred_MH = pm.sample_posterior_predictive(post_data_MH)
    prior_pred_MH = pm.sample_prior_predictive(500)

# Nominal expectation, shown for comparison with the predictive summary below.
print(model.expected_actualdata(model.config.suggested_init()))
az.summary(post_pred_MH, var_names="ExpData")

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Metropolis: [Unconstrained]
>Metropolis: [Normals]
>Metropolis: [Gammas]
CompoundStep
>Metropolis: [Unconstrained]
>Metropolis: [Normals]
>Metropolis: [Gammas]


Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 1 seconds.


Sampling: [ExpData, Gammas, Normals, Unconstrained]


[ 61.  61. 115.]




Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
ExpData[0],61.179,1.403,58.425,63.645,0.05,0.035,798.0,1381.0,1.0
ExpData[1],60.708,1.454,58.094,63.426,0.065,0.046,507.0,726.0,1.0
ExpData[2],115.07,1.423,112.616,117.905,0.057,0.041,614.0,1068.0,1.0
