In [None]:
import sys
import timeit

import jax
# Turn on 64-bit (double-precision) mode in JAX. This is done immediately
# after importing jax so it is in effect before any arrays are created below.
# NOTE(review): `from jax.config import config` and `jax.ops.index` /
# `index_update` (which also appear unused in this notebook) are legacy APIs
# that later JAX releases removed — presumably an older JAX is pinned; confirm,
# or switch to `jax.config.update(...)`.
from jax.config import config
config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.ops import index, index_update


import madjax
from madjax.phasespace.flat_phase_space_generator import FlatInvertiblePhasespace
# Commented-out import of an alternative phase-space module; inactive.
#from madjax.phasespace.new_flat_phase_space_generator import generate_phase_space_inputs, generateKinematics, invertKinematics
from madjax.phasespace.vectors import Vector

# Fixed-seed JAX PRNG key; later cells split it before each smearing draw.
key = jax.random.PRNGKey(0)

# Plain NumPy is imported as `onp` so it is not confused with `jax.numpy`
# (`jnp`); plus plotting and SciPy's optimizers.
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.optimize

# Stdlib RNG, aliased so it is not confused with `jax.random`; used to draw
# the "true" phase-space random variables.
import random as random_orig

# MadJax project for the "ttbar_bqq_bqq" configuration; provides the
# processes, parameters and matrix elements used below.
mj = madjax.MadJax(config_name="ttbar_bqq_bqq")

In [None]:
# --- Process and run configuration -----------------------------------------
process_name = "Matrix_1_gg_ttx_t_budx_tx_bxdux"
process = mj.processes[process_name]()

# Centre-of-mass energy; the two beams below each carry E_cm / 2.
E_cm = 14000.

# Seed the stdlib RNG used to draw the "true" phase-space points so that a
# Restart & Run All reproduces the same events (the JAX key is seeded in the
# first cell).
SEED = 0
random_orig.seed(SEED)

matrix_element = mj.matrix_element(E_cm=E_cm, process_name=process_name, return_grad=False, do_jit=False)

external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)

ps_generator = FlatInvertiblePhasespace(
    external_masses[0], external_masses[1],
    beam_Es=(E_cm / 2., E_cm / 2.),
    beam_types=(0, 0),
)

# Number of random variables (unit-hypercube dimensions) per phase-space
# point. It is taken from the generator rather than hard-coded, so it stays
# correct if the process or its masses change.
nDimPhaseSpace = ps_generator.nDimPhaseSpace()





In [None]:
def get_final_state(point):
    """Stack the last six entries of a phase-space point into one array.

    Each entry's ``asarray()`` becomes one row of the result, in order.
    """
    final_particles = point[-6:]
    rows = [particle.asarray() for particle in final_particles]
    return jax.numpy.asarray(rows)

def convert_to_PS_point(incoming, observed):
    """Assemble a full phase-space point from incoming and observed momenta.

    ``incoming`` holds objects exposing ``asarray()`` (e.g. the incoming
    ``Vector``s of a generated point); their rows come first. The rows of
    ``observed`` follow. Every row of the stacked result is wrapped back into
    a ``Vector``, and the list of vectors is returned.
    """
    incoming_rows = [vec.asarray() for vec in incoming]
    observed_rows = [row for row in observed]
    stacked = jax.numpy.vstack((incoming_rows, observed_rows))
    return list(map(Vector, stacked))

def smear_final_state(subkey, final, sig_jet=0.2):
    """Smear final-state components with i.i.d. zero-mean Gaussian noise.

    Args:
        subkey: JAX PRNG key used for this draw.
        final: array-like of final-state components (e.g. the stacked rows
            returned by ``get_final_state``); noise of the same shape is added.
        sig_jet: standard deviation of the noise. It was previously accepted
            but ignored in favor of a hard-coded 0.2. The default is 0.2, so
            default calls are unchanged and match the width assumed by
            ``smear_logpdf``.

    Returns:
        ``final + sig_jet * N(0, 1)``, with the same shape as ``final``.
    """
    final = jax.numpy.asarray(final)
    # Match the input's shape instead of assuming a (6, 4) final state.
    noise = jax.random.normal(subkey, final.shape) * sig_jet
    return final + noise

def smear_logpdf(observed, final, sig_jet=0.2):
    """Gaussian log-likelihood of ``observed`` given the unsmeared ``final``.

    Each component is treated as independently normal, with mean ``final``
    and standard deviation ``sig_jet``. The per-component log-densities are
    summed into a scalar. ``sig_jet`` defaults to 0.2, the width previously
    hard-coded here, and should match the width used when smearing
    (``smear_final_state``).
    """
    log_pdf = jax.scipy.stats.norm.logpdf(observed, final, sig_jet)
    return jax.numpy.sum(log_pdf)

def logpdf(observed, parameters, random_variables):
    """Unnormalized log-density of an observation at a phase-space point.

    ``random_variables`` are mapped to momenta by the global ``ps_generator``
    at ``E_cm``. The result is the sum of three terms:
    the smearing log-likelihood of ``observed`` given the generated final
    state, the log of ``matrix_element(parameters, random_variables)``, and
    the log of the generator's phase-space weight.
    """
    momenta, ps_weight = ps_generator.generateKinematics(E_cm, random_variables)
    me_value = matrix_element(parameters, random_variables)
    log_smear = smear_logpdf(observed, get_final_state(momenta))
    return log_smear + jax.numpy.log(me_value) + jax.numpy.log(ps_weight)
    

In [None]:
def get_nll(params):
    """Build a value-and-gradient function of the negative log-likelihood.

    Returns ``objective(observed, rv) -> (nll, d nll / d rv)`` with ``params``
    held fixed. Only ``rv`` is differentiated, and the computation is wrapped
    in ``jax.jit``, so it compiles on the first call.
    """
    def negative_log_likelihood(observed, rv):
        return -logpdf(observed, params, rv)

    nll_and_grad = jax.jit(jax.value_and_grad(negative_log_likelihood, argnums=1))

    def objective(observed, rv):
        value, gradient = nll_and_grad(observed, rv)
        return value, gradient

    return objective


def get_rv_objective(observed, nll_objective):
    """Bind ``observed`` so the objective depends only on the random variables.

    The returned callable maps ``rv`` to ``nll_objective(observed, rv)``,
    which is the single-argument form ``scipy.optimize.minimize`` expects.
    """
    return lambda rv: nll_objective(observed, rv)

In [None]:
# Draw a "true" point uniformly from the unit hypercube of phase-space random
# variables and map it to momenta.
random_variables = [random_orig.random() for _draw in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

# Pseudo-observation: the final-state momenta plus Gaussian smearing, using a
# fresh subkey split off the global JAX key.
key, subkey = jax.random.split(key)
fs = get_final_state(PS_point)
obs = smear_final_state(subkey, fs)

# NLL objective in the random variables for this observation. This is the
# first call of the freshly jitted function, so the timing below includes
# JIT compilation.
nll_objective = get_nll(external_parameters)
objective = get_rv_objective(obs, nll_objective)

start_time = timeit.default_timer()
print(objective(random_variables))
elapsed = timeit.default_timer() - start_time
print("time", elapsed)

In [None]:
# Draw a fresh "true" phase-space point and smear its final state with a new
# subkey; this observation is what the fit in the next cell starts from.
random_variables = [random_orig.random() for _draw in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

key, subkey = jax.random.split(key)
fs = get_final_state(PS_point)
obs = smear_final_state(subkey, fs)

In [None]:
# Invert the smeared observation back to the unit hypercube to get a starting
# point, then minimize the negative log-likelihood over the random variables.
obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)

objective = get_rv_objective(obs, nll_objective)

start_time = timeit.default_timer()
# Newton-CG ignores `bounds` (scipy only emits a warning), so the fit could
# leave the unit hypercube. TNC is a truncated-Newton method that honors box
# constraints. x0 is passed as a NumPy float array, as scipy expects, and the
# bounds length is taken from the generator instead of a hard-coded 14.
r = scipy.optimize.minimize(
    objective,
    onp.asarray(i_rv, dtype=float),
    jac=True,
    method='TNC',
    bounds=[(0.0, 1.0)] * ps_generator.nDimPhaseSpace(),
)
elapsed = timeit.default_timer() - start_time


print(r)
print("time", elapsed)
bestfit = r.x
print("true rvs:", random_variables)

In [None]:
# Display the inverted starting point `i_rv` from the previous cell as a JAX array.
jnp.asarray(i_rv)

In [None]:
# Draw a fresh "true" phase-space point and smear its final state into a
# pseudo-observation.
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key, subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs)

# Candidate orderings of the six observed objects. Every row keeps slot 2
# fixed, either keeps or swaps slots 0 and 3, and swaps slot 1 with 4, with 5,
# or with nothing.
# NOTE(review): presumably these enumerate object-to-parton assignments for
# the fit; confirm against the process's final-state particle order.
combos = [[0,1,2,3,4,5],
          [0,4,2,3,1,5],
          [0,5,2,3,4,1],
          [3,1,2,0,4,5],
          [3,4,2,0,1,5],
          [3,5,2,0,4,1]]

# Box constraints for the unit hypercube, sized from the generator rather
# than a hard-coded 14.
unit_bounds = [(0.0, 1.0)] * ps_generator.nDimPhaseSpace()

results = []
for i, combo in enumerate(combos):
    print("Starting Combination", i)

    # Index with an array: JAX rejects a plain Python list as an index.
    new_obs = obs[jnp.asarray(combo)]

    obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=new_obs)
    i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)

    new_objective = get_rv_objective(new_obs, nll_objective)

    start_time = timeit.default_timer()
    # Newton-CG ignores `bounds` (scipy only warns); TNC is a bound-aware
    # truncated-Newton method, so the random variables stay in [0, 1].
    r = scipy.optimize.minimize(
        new_objective,
        onp.asarray(i_rv, dtype=float),
        jac=True,
        method='TNC',
        bounds=unit_bounds,
    )
    elapsed = timeit.default_timer() - start_time

    results.append(r)
    print(r)
    if i == 0:
        print("rvs", random_variables)
    print("time", elapsed)
    print("")

In [None]:
# Show the "true" random variables of the most recently drawn phase-space point.
print(random_variables)