In [1]:
import sys
import timeit

import jax
from jax.config import config
config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.ops import index, index_update


import madjax
from madjax.phasespace.flat_phase_space_generator import FlatInvertiblePhasespace
#from madjax.phasespace.new_flat_phase_space_generator import generate_phase_space_inputs, generateKinematics, invertKinematics
from madjax.phasespace.vectors import Vector

key = jax.random.PRNGKey(0)

#numpy & friends
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.optimize

import random as random_orig

#import importlib
#importlib.reload(jax)

#mj = madjax.MadJax(config_name=sys.argv[1])
mj = madjax.MadJax(config_name="ttbar_bqq_bqq")



In [2]:
process_name = "Matrix_1_gg_ttx_t_budx_tx_bxdux"
process = mj.processes[process_name]()

E_cm = 14000.

nDimPhaseSpace = 14

matrix_element = mj.matrix_element(E_cm=E_cm, process_name=process_name, do_jit=False)


external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)


ps_generator = FlatInvertiblePhasespace(
    external_masses[0], external_masses[1],
    beam_Es = (E_cm/2.,E_cm/2.),
    beam_types=(0,0)
)





In [3]:
def get_final_state(point):
    return jax.numpy.asarray( [p.asarray() for p in point[-6:]] )

def convert_to_PS_point(incoming, observed):
    _out = jax.numpy.vstack(([p.asarray() for p in incoming], [p for p in observed]))
    return [Vector(p) for p in _out]

def smear_final_state(subkey, final, sig_jet=0.1):
    smear = jax.random.normal(subkey,(6,4))*0.2
    return final + smear

def smear_logpdf(observed, final):
    log_pdf = jax.scipy.stats.norm.logpdf(observed,final,0.2)
    return jax.numpy.sum(log_pdf)

def logpdf(observed, parameters, random_variables):
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
    ME_value = matrix_element(parameters, random_variables)
    final = get_final_state(PS_point)
    return smear_logpdf(observed, final) + jax.numpy.log(ME_value) + jax.numpy.log(jacobian)
    

In [4]:
def get_nll(params):
    def nll(observed, rv):
        return -logpdf(observed, params, rv) 

    v_and_g = jax.jit(jax.value_and_grad(nll, argnums=1))
    #v_and_g = jax.value_and_grad(nll, argnums=1)
    
    def objective(observed, rv):
        v,g = v_and_g(observed, rv)
        return v,g
    return objective


def get_rv_objective(observed, nll_objective):
    
    def rv_objective(rv):
        return nll_objective(observed, rv)
    
    return rv_objective

In [14]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs)

nll_objective = get_nll(external_parameters)
objective = get_rv_objective(obs, nll_objective)

start_time = timeit.default_timer()
print(objective(random_variables))
elapsed = timeit.default_timer() - start_time
print("jit time", elapsed)

(DeviceArray(-14.8815418, dtype=float64), [DeviceArray(-60085.99617098, dtype=float64), DeviceArray(420261.99587517, dtype=float64), DeviceArray(16840.82836004, dtype=float64), DeviceArray(14707.48934722, dtype=float64), DeviceArray(13475.15093016, dtype=float64), DeviceArray(6535.80252126, dtype=float64), DeviceArray(9115.98832613, dtype=float64), DeviceArray(-269356.6993851, dtype=float64), DeviceArray(2472.36807693, dtype=float64), DeviceArray(-2785.20797573, dtype=float64), DeviceArray(-10634.66296716, dtype=float64), DeviceArray(-1950.4600314, dtype=float64), DeviceArray(-1794.29981612, dtype=float64), DeviceArray(-16069.63505263, dtype=float64)])
jit time 660.3369751080008


In [18]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs)

In [45]:
obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)

objective = get_rv_objective(obs, nll_objective)

start_time = timeit.default_timer()
r = scipy.optimize.minimize(
    objective,
    jnp.asarray(i_rv),
    jac = True,
    method = 'Newton-CG',
    bounds = [(0.0,1.0)]*14
)
elapsed = timeit.default_timer() - start_time


print(r)
print("time",elapsed)
bestfit = r.x
print("true rvs:",random_variables)

     fun: -8.719950193933414
     jac: DeviceArray([ 118800.02115817,  -11154.26588206,  -20576.64062697,
              -25767.01651583,  -93023.3076097 ,   14449.25821559,
              -33705.89940403,  -88277.59596754,  -24875.58980591,
                -986.18193972,  -23219.99844456, -161500.42701992,
              -65984.7003016 ,  412553.76160724], dtype=float64)
 message: 'Optimization terminated successfully.'
    nfev: 2
    nhev: 0
     nit: 1
    njev: 4
  status: 0
 success: True
       x: DeviceArray([0.59480519, 0.58971939, 0.48156639, 0.44012385, 0.93200458,
             0.35267831, 0.38137671, 0.51886891, 0.23973291, 0.39124666,
             0.19193181, 0.49067666, 0.90717061, 0.78677016],            dtype=float64)
time 0.015068435999637586
true rvs: [0.5946855411787582, 0.589784601154183, 0.48161640450153487, 0.440117055990217, 0.9320075921373164, 0.35271666696798, 0.38135748404038217, 0.5188755444334604, 0.23968599595396956, 0.39124635073454905, 0.1919393341834036, 0.

In [41]:
jnp.asarray(i_rv)

DeviceArray([0.59481747, 0.58971824, 0.48156426, 0.44012119, 0.93199496,
             0.3526798 , 0.38137323, 0.51885978, 0.23973033, 0.39124656,
             0.19192941, 0.49065995, 0.90716379, 0.78681283],            dtype=float64)

In [46]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs)

combos = [[0,1,2,3,4,5],
          [0,4,2,3,1,5],
          [0,5,2,3,4,1],
          [3,1,2,0,4,5],
          [3,4,2,0,1,5],
          [3,5,2,0,4,1]]

results = []
for i in range(6):
    print("Starting Combination", i)
    
    new_obs = obs[ combos[i] ]
    
    obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=new_obs)
    i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)
    
    new_objective = get_rv_objective(new_obs, nll_objective)
    
    start_time = timeit.default_timer()
    r = scipy.optimize.minimize(
    new_objective,
    onp.asarray(i_rv),
    jac = True,
    method = 'Newton-CG',
    bounds = [(0.0,1.0)]*14
    )
    elapsed = timeit.default_timer() - start_time

    results.append(r)
    print(r)
    if i==0:
        print("rvs", random_variables)
    print("time",elapsed)
    print("")

Starting Combination 0
     fun: -9.370558308130839
     jac: DeviceArray([-16230.60877106,   3820.9689206 ,  -5329.63901009,
               -730.60131997,    846.66869122,  14947.46341143,
               1022.17729131,  -5530.35833604,  -7674.09052773,
             -19910.10978662,  45822.68536553,  -7337.99754921,
               2364.81144326, -24312.82763115], dtype=float64)
 message: 'Optimization terminated successfully.'
    nfev: 3
    nhev: 0
     nit: 2
    njev: 10
  status: 0
 success: True
       x: DeviceArray([0.37356685, 0.66163744, 0.36552485, 0.76437303, 0.56057858,
             0.73302956, 0.63668933, 0.13483139, 0.65585128, 0.46628274,
             0.02085885, 0.65350994, 0.1765288 , 0.5946134 ],            dtype=float64)
rvs [0.3736468804618248, 0.6616159365747475, 0.365559785545863, 0.7643754277301225, 0.5605364666037468, 0.7330235841782851, 0.6366982570997428, 0.13484100311004077, 0.6558287713119924, 0.46627364735563237, 0.020863202496359423, 0.6536273788443114, 0

In [13]:
print(random_variables)

[0.5369316069803622, 0.6441100322294254, 0.7056283186512375, 0.2944479240389073, 0.11527650853269467, 0.3913974949833604, 0.7265434668879712, 0.029688331380537614, 0.09412554128720063, 0.755978949188364, 0.7743317785305885, 0.5219326666455739, 0.9996664445851675, 0.16990592622490897]
