In [1]:
import sys
import jax
from jax.config import config
config.update('jax_enable_x64', True)

import jax
import jax.numpy as jnp

import madjax
from madjax.phasespace.flat_phase_space_generator import FlatInvertiblePhasespace
#from madjax.phasespace.new_flat_phase_space_generator import generate_phase_space_inputs, generateKinematics, invertKinematics
from madjax.phasespace.vectors import Vector

key = jax.random.PRNGKey(0)

#numpy & friends
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.optimize

import random as random_orig

import importlib
importlib.reload(jax)

#mj = madjax.MadJax(config_name=sys.argv[1])
mj = madjax.MadJax(config_name="ttbar_bqq_bqq")



In [2]:
process_name = "Matrix_1_gg_ttx_t_budx_tx_bxdux"
process = mj.processes[process_name]()

E_cm = 14000.

nDimPhaseSpace = 14

matrix_element = mj.matrix_element(E_cm=E_cm, process_name=process_name, return_grad=False)


external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)


ps_generator = FlatInvertiblePhasespace(
    external_masses[0], external_masses[1],
    beam_Es = (E_cm/2.,E_cm/2.),
    beam_types=(0,0)
)





In [7]:
def get_final_state(point):
    return jax.numpy.asarray( [p.asarray() for p in point[-6:]] )

def convert_to_PS_point(incoming, observed):
    _out = jax.numpy.vstack(([p.asarray() for p in incoming], [p for p in obs]))
    return [Vector(p) for p in _out]

def smear_final_state(subkey, final, sig_jet=0.1):
    smear = jax.random.normal(subkey,(6,4))*0.2
    return final + smear

def smear_logpdf(observed, final):
    log_pdf = jax.scipy.stats.norm.logpdf(observed,final,0.2)
    return jax.numpy.sum(log_pdf)

def logpdf(observed, parameters, random_variables):
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
    ME_value = matrix_element(parameters, random_variables)
    final = get_final_state(PS_point)
    return smear_logpdf(observed, final) + jax.numpy.log(ME_value) + jax.numpy.log(jacobian)
    

In [4]:
def get_nll(params):
    def nll(observed, rv):
        return -logpdf(observed, params, rv) 

    v_and_g = jax.jit(jax.value_and_grad(nll, argnums=1))
    #v_and_g = jax.value_and_grad(nll)
    
    def objective(observed, rv):
        v,g = v_and_g(observed, rv)
        return v,g
    return objective


def get_rv_objective(observed, nll_objective):
    
    def rv_objective(rv):
        return nll_objective(observed, rv)
    
    return rv_objective

In [5]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs)

nll_objective = get_nll(external_parameters)
objective = get_rv_objective(obs, nll_objective)

print(objective(random_variables)[1])

  lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])


[DeviceArray(-67036.73778069, dtype=float64), DeviceArray(-46780.37099833, dtype=float64), DeviceArray(-602.1353541, dtype=float64), DeviceArray(-12292.48677876, dtype=float64), DeviceArray(-35833.94575276, dtype=float64), DeviceArray(190159.65786819, dtype=float64), DeviceArray(30515.52456446, dtype=float64), DeviceArray(-74646.83618944, dtype=float64), DeviceArray(-8359.9457158, dtype=float64), DeviceArray(-82446.64742757, dtype=float64), DeviceArray(-5354.57501763, dtype=float64), DeviceArray(45050.63884748, dtype=float64), DeviceArray(-8196.58370924, dtype=float64), DeviceArray(29143.08114759, dtype=float64)]


In [8]:
obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)

r = scipy.optimize.minimize(
    objective,
    onp.asarray(i_rv),
    jac = True,
    method = 'SLSQP',
    bounds = [(0.0,1.0)]*14
)
print(r)
bestfit = r.x

     fun: -0.052323549719432094
     jac: array([  23209.39842793,   -2373.90358311,    4896.50384407,
          2457.33084551,  -32124.87493355, -135883.110809  ,
        -22414.37669318,   86296.08476771,  -64102.35367142,
        -73747.28669617,  -63671.85960357, -147807.70070856,
        -15844.24090645,  379749.95331651])
 message: 'Optimization terminated successfully.'
    nfev: 1
     nit: 5
    njev: 1
  status: 0
 success: True
       x: array([0.33337956, 0.54974943, 0.38719507, 0.24073918, 0.3357683 ,
       0.34047089, 0.2840093 , 0.91097052, 0.77513549, 0.25508162,
       0.1775508 , 0.38788803, 0.18046613, 0.14091495])


In [9]:
print(random_variables)

[0.3333012238564995, 0.5497305650886092, 0.3872104283590405, 0.24066640520746863, 0.3357395325595064, 0.3404992286392333, 0.2840465854061176, 0.9109509106976008, 0.7751252987704087, 0.25507142253210724, 0.17755850777668225, 0.38788706277447815, 0.18046540357901786, 0.1408610027268985]
