In [1]:
import sys
import timeit
import tracemalloc
import logging
import warnings
warnings.simplefilter("ignore")

#numpy & friends
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.optimize

import random as random_orig
from itertools import permutations
from functools import partial

# jax and friends
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
key = jax.random.PRNGKey(0)

# physics stuff
import pylhe

# Madjax
import madjax


DO_TEST=False



In [2]:
# Memory tracing
def print_memory_stuff(note=""):
    current, peak = tracemalloc.get_traced_memory()
    print(note, f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")

# Setup Gradient Descent

In [4]:
#Custom sigmoid, with identity gradient
#Easy way to implement natural gradient descent 
#  on sigmoid reparameterization of unconstrained dual space

@jax.custom_jvp
def SigmoidStraightThrough(x):
    return jax.nn.sigmoid(x)

SigmoidStraightThrough.defjvps(lambda t, ans, x: t)

In [5]:
#Batched Gradient descent optimizer
#Since SigmoidStraightThrough takes care of jacobian cancellation
#  in natural gradient descent (more numerically stable)
#  ... so nothing extra has to be added here
#
# INPUT:
#   loss      - vectorized function to be optimized
#   grad_loss - vectorized function, gradient of loss
#   u_init    - batched initial guesses
#   lr        - float, learning rate / step size
#   max_iter  - int, maximum number of iterations before stop
#   tol       - float, tolerance of loss change before stop
#
# OUTPUT: dictionary (all values aer batched)
#   "x" - input value at end
#   "loss" - loss at end
#   "n" - number of iteration to finish

def get_vopt_grad_descent(loss, grad_loss):
    
    def _cond_fun(vals):
        
        _not_nan   = jnp.invert(jnp.isnan(vals['loss']))
        _check_tol = (vals['l_diff'] > vals['tol'])
        _check_n   = (vals['n'] < vals['max_iter'])
        
        
         # update what elements to keep updating
        _keep_going = (_not_nan & _check_tol & _check_n)
        
        return _keep_going
    
    def _body_fun(vals):     
        x_out      = vals['x'] - vals['lr']*grad_loss(vals['x'], vals['aux'])
        loss_out   = loss(x_out, vals['aux'])
        l_diff_out = jnp.abs(loss_out - vals['loss'])
        n_iter_out = vals['n']+1
        
        
        return {'x':x_out, 'aux':vals['aux'], 'loss':loss_out, 'l_diff':l_diff_out, 'n':n_iter_out,
               'lr':vals['lr'], 'max_iter':vals['max_iter'], 'tol':vals['tol']}
               
    
    
    def _while_fun(vals):but 
        return jax.lax.while_loop(_cond_fun, _body_fun, vals)

    
    dct_axes = {'l_diff': 0, 'loss': 0, 'n': 0, 'x': 0, 'aux':0, 'lr':None, 'max_iter':None, 'tol':None}
    
    _result = jax.jit( jax.vmap(_while_fun, in_axes=(dct_axes,)) )
    return _result


def run_vopt(optimizer, x_init, aux_input, lr=.1, max_iter=100, tol=0.001):

    loss_val  = 10*jnp.ones( (x_init.shape[0]) ) #loss(x_init, aux_input)
    loss_diff = 10*jnp.ones( (x_init.shape[0]) )
    n_iter    = jnp.zeros( (x_init.shape[0]), dtype=int)
        
    in_vals = {'x':x_init, 'aux':aux_input, 'loss':loss_val, 'l_diff':loss_diff, 'n':n_iter, 
                'lr':lr, 'max_iter':max_iter, 'tol':tol}
    
    return optimizer(jax.lax.stop_gradient(in_vals))

In [6]:
if DO_TEST:
    
    def f(x,y):
        return jnp.sum(x*x)+jnp.sum(y)
        
    f = jax.jit(f)
    df_dx = jax.jit(jax.grad(f, argnums=0))
    
    _x = jnp.array([[1.,0.],[2.,0.],[3.,0.],[4.,0.],[5.,0.]])
    _y = jnp.array([[1.,2.],[3.,4.],[1.,2.],[3.,4.],[1.,2.]])
    
    print(f(_x,_y))
    print(df_dx(_x,_y))
    
    optimizer = get_vopt_grad_descent(f, df_dx)
    
    for i in range(5):
        start_time = timeit.default_timer()
        run_vopt(optimizer, _x, _y, lr=.1, max_iter=100, tol=0.001)
        elapsed = timeit.default_timer() - start_time
        print(i, "time", elapsed,"\n")
      
    print("second \n")
    #optimizer = get_vopt_grad_descent(f, df_dx, lr=.1, max_iter=100, tol=0.001)
    
    for i in range(5):
        start_time = timeit.default_timer()
        _res= run_vopt(optimizer, _x, _y, lr=.1, max_iter=200, tol=0.001)
        elapsed = timeit.default_timer() - start_time
        print(i, "time", elapsed,"\n")
    
    print(_res)

# 4-Vector Operations

In [7]:
def energy_from_mass_basis(p):
    return jnp.sqrt(jnp.sum(jnp.square(p)))

def mass_from_energy_basis(p):
    _m = p[0]*p[0] - p[1]*p[1] - p[2]*p[2] - p[3]*p[3]
    _m = jnp.where(_m<1e-5, 1e-5, jnp.sqrt(_m))
    return _m

vector_energy_from_mass_basis = jax.jit(jax.vmap(energy_from_mass_basis, in_axes=0))
vector_mass_from_energy_basis = jax.jit(jax.vmap(mass_from_energy_basis, in_axes=0))


def scalar_convert_to_energy_basis(p):
    return jnp.hstack( ([energy_from_mass_basis(p)], p[1:4]) )
    #_energy = energy_from_mass_basis(p)
    #jax.ops.index_update(p, jax.ops.index[0], _energy)
    #return p
    
def scalar_convert_to_mass_basis(p):
    return jnp.hstack( ([mass_from_energy_basis(p)], p[1:4]) )
    #_mass = mass_from_energy_basis(p)
    #jax.ops.index_update(p, jax.ops.index[0], _mass)
    #return p

vector_convert_to_energy_basis = jax.jit(jax.vmap(scalar_convert_to_energy_basis, in_axes=0))
vector_convert_to_mass_basis   = jax.jit(jax.vmap(scalar_convert_to_mass_basis, in_axes=0))

# Setup MadJax

In [8]:

E_cm = 500.
config_name  = "ee_ttbar_bqq_bqq"
process_name = "Matrix_1_epem_ttx_t_budx_tx_bxdux"
nDimPS=14

mj = madjax.MadJax(config_name=config_name)
matrix_element = mj.matrix_element(E_cm=E_cm, process_name=process_name, return_grad=False, do_jit=False)
jacobian = mj.jacobian(E_cm=E_cm, process_name=process_name, do_jit=False)
ps_gen = mj.phasespace_generator(E_cm=E_cm, process_name=process_name)({})
ps_vec = mj.phasespace_vectors(E_cm=E_cm, process_name=process_name)

external_params={}

sigma_smear=0.1


## LHE Event Manipulation

In [9]:
# get LHE events generator
lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")

In [10]:
def lhe_event_to_ps_point(event):
    _particles = []
    for p in event.particles:
        #print(p.id, p.status, p.px, p.py, p.pz, p.e)
        if p.status==-1 or p.status==1:
            _particles += [[p.e, p.px, p.py, p.pz]]
    return jnp.array(_particles)

def get_multiple_lhe_event_to_ps_point(lhe_event_generator, n_events):
    _evts=[]
    for i in range(n_events):
        try:
            _ev = lhe_event_generator.__next__()
        except:
            print("No More Events")
            return _evts
        
        _ps = lhe_event_to_ps_point(_ev)
        _evts += [_ps]
    
    return jnp.array(_evts)

In [11]:
if DO_TEST:
    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)
    print(test_evts)

## Simulation and InverseKinematics

In [12]:
def convert_to_PS_point(incoming, observed):
    _out = jax.numpy.vstack((incoming, observed))
    return [madjax.phasespace.vectors.Vector(p) for p in _out]

def scalar_unc_rv_from_obs(obs, E_cm):
    incoming = jnp.array([[E_cm/2.0,0,0,E_cm/2.0], [E_cm/2.0,0,0,-E_cm/2.0]]) 
    obs_vec = convert_to_PS_point(incoming=incoming, observed=obs)
    i_rv, i_wt = ps_gen.invertKinematics(E_cm, obs_vec)
    i_unc_rv = jax.scipy.special.logit(i_rv)
    
    return i_unc_rv

vector_unc_rv_from_obs = jax.jit(jax.vmap(scalar_unc_rv_from_obs, in_axes=(0,None)), static_argnums=1)

#if DO_TEST:
#    key,subkey = jax.random.split(key)
#    urv = jax.random.normal(subkey, shape=(5,14))
#    rv  = SigmoidStraightThrough(urv)

#    key,subkey = jax.random.split(key)   
#    obs = vector_sim(jax.random.split(subkey,5), rv, 0.1)
    
#    print(vector_unc_rv_from_obs(obs,E_cm))

In [13]:
def gauss_smear(key, final_state, sigma_all):
    
    key,subkey = jax.random.split(key)
    smear      = jax.random.normal(subkey, final_state.shape)*sigma_all
    observed   = final_state*(1 + smear)
    
    return observed

def gauss_p_smear(key, final_state, sigma_p):
    _fs_m = vector_convert_to_mass_basis(final_state)
    _n = final_state.shape[0]
    
    key,subkey = jax.random.split(key)
    p_smear = jax.random.normal(subkey,(_n,3))*sigma_p
    new_p = _fs_m[:,1:4]*(1+p_smear)
    
    observed = vector_convert_to_energy_basis( jnp.hstack((jnp.reshape(_fs_m[:,0], (_n,1)), new_p)) )
    return observed


def gauss_p_lognormal_m_smear(key, final_state, sigma_p, sigma_m):
    
    _fs_m = vector_convert_to_mass_basis(final_state)
    
    _n = final_state.shape[0]
    
    key,subkey = jax.random.split(key)
    p_smear = jax.random.normal(subkey,(_n,3))*sigma_p
    new_p = _fs_m[:,1:4]*(1+p_smear)
    
    key,subkey = jax.random.split(key)
    m_smear = jax.random.normal(subkey,(_n,))*sigma_m
    new_m = jnp.exp( jnp.log(_fs_m[:,0]+1.)*(1+m_smear))
    
    observed = vector_convert_to_energy_basis( jnp.hstack((jnp.reshape(new_m, (_n,1)), new_p)) )
    return observed

In [14]:
if DO_TEST:
    key,subkey = jax.random.split(key)   
    rv = jax.random.uniform(subkey, shape=(14,))
    
    PS_point    = ps_vec({}, rv)
    final_state = PS_point[-6:]
    
    print(PS_point)
    
    print(final_state)
    
    print(gauss_smear(subkey, final_state, 0.1))
    
    print(gauss_p_smear(subkey, final_state, 0.1))
    
    print(gauss_p_lognormal_m_smear(subkey, final_state, 0.1, 0.1))
    
    print("LHE test")
    
    ev = lhe_events.__next__()
    
    lhe_ps = lhe_event_to_ps_point(ev)
    lhe_fs = lhe_ps[-6:]
    
    print(lhe_ps)
    
    print(lhe_fs)
    
    print(gauss_smear(subkey, lhe_fs, 0.1))
    
    print(gauss_p_smear(subkey, lhe_fs, 0.1))
    
    print(gauss_p_lognormal_m_smear(subkey, lhe_fs, 0.1, 0.1))
    
    
    

In [15]:


def scalar_sim(key, rv, sigma_smear):
    PS_point    = ps_vec({}, rv)
    final_state = PS_point[-6:]
    
    return gauss_p_smear(key, final_state, sigma_smear)

vector_sim = jax.jit(jax.vmap(scalar_sim, in_axes=(0,0,None)), static_argnums=(2,))



def scalar_sim_from_ucon_rv(key, sigma_smear):
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    PS_point    = ps_vec({}, rv)
    final_state = PS_point[-6:]
    
    observed   = gauss_p_smear(key, final_state, sigma_smear)
    
    return observed, final_state, urv, rv

vector_sim_from_ucon_rv = jax.jit(jax.vmap(scalar_sim_from_ucon_rv, in_axes=(0,None)), static_argnums=(1,))

def batch_sim_from_ucon_rv(key, Nevt, sigma_smear):
    key,subkey = jax.random.split(key)
    return vector_sim_from_ucon_rv(jax.random.split(subkey,Nevt),sigma_smear) 


if DO_TEST:
    key,subkey = jax.random.split(key)   
    rv = jax.random.uniform(subkey, shape=(5,14))

    key,subkey = jax.random.split(key)   
    print(vector_sim(jax.random.split(subkey,5), rv, 0.1))

In [16]:
def scalar_sim_from_ps_point(key, ps_point, sigma_smear):
    rv, _ = ps_gen.invertKinematics(E_cm, [madjax.phasespace.vectors.Vector(p) for p in ps_point])
    urv = jax.scipy.special.logit(rv)
    
    final_state = ps_point[-6:]
    observed   = gauss_p_smear(key, final_state, sigma_smear)
    
    return observed, final_state, urv, rv
    
vector_sim_from_ps_point = jax.jit(jax.vmap(scalar_sim_from_ps_point, in_axes=(0,0,None)), static_argnums=(2,))

def batch_sim_from_ps_point(key, ps_points, sigma_smear):
    key,subkey = jax.random.split(key)
    _Nevt = ps_points.shape[0]
    return vector_sim_from_ps_point(jax.random.split(subkey,_Nevt), ps_points, sigma_smear) 


if DO_TEST:
    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)

    key,subkey = jax.random.split(key)   
    print(vector_sim_from_ps_point(jax.random.split(subkey,3), test_evts, 0.1))
    
    print("test events",test_evts)

# Matrix Element Likelihood

In [17]:
def get_scalar_log_me(params):
    def func(rv):
        return jnp.log(matrix_element(params, rv))+jnp.log(jacobian(params, rv))
    return func

scalar_log_me = get_scalar_log_me(params={})
vector_log_me = jax.jit(jax.vmap(scalar_log_me))


In [21]:
if DO_TEST:
    key,subkey = jax.random.split(key)   
    rv = jax.random.uniform(subkey, shape=(6,14))
    
    print(vector_log_me(rv))
    print(jax.vmap(jax.grad(scalar_log_me))(rv))

[-0.39823289  2.58311715  8.00678916  9.11183296  1.56843601  1.07406641]
[[ 3.91309208e+01  2.63827764e+01  8.69408554e+01 -1.01493099e+01
  -1.20057197e+01  6.27361322e+00 -4.34817617e+00 -5.70039718e+01
   1.93039106e+00  5.05312885e+01 -1.72416168e+00 -2.81208666e+00
  -3.64708748e+00  3.01115538e+00]
 [-1.12602413e+01 -1.86602562e+01 -1.91804533e+01  7.94995608e+00
   5.65122350e+00  8.71535019e+01  1.59598423e+00 -6.31361378e+01
  -4.94515941e+00 -2.40972648e+01  2.60365365e+00 -1.82113222e+00
  -1.90839441e+00  1.90104282e+00]
 [ 4.55217790e+01  2.56199778e+01  3.10010006e+01  1.85702392e+01
  -7.82398415e+00  7.03565398e+00 -1.69072107e+01  2.42003349e+00
   2.40029788e+01 -8.67998537e+00 -1.71147138e+00 -3.47992376e-01
  -2.14121668e+00 -4.27708355e-01]
 [ 4.88543991e+01  1.03297824e+01 -5.37947458e-01 -2.37406456e+01
   2.75719532e+01 -1.32373586e+01  5.11499504e+01 -1.77313026e+02
  -1.09950421e+01  1.91764967e+02 -2.47089527e-01 -6.35461365e-01
  -3.53595926e+00 -5.79125096

In [22]:


def get_scalar_neg_log_pdf(sigma_smear):
    def _scalar_neg_log_pdf(ucon_rv, obs):
        rv  = SigmoidStraightThrough(ucon_rv)
    
        PS_point      = ps_vec({}, rv)
        final_state   = PS_point[-6:]
    
        log_smear_pdf = jnp.sum(jax.scipy.stats.norm.logpdf(obs[:,1:4],final_state[:,1:4], 
                                                            sigma_smear*final_state[:,1:4]))
    
        return -1.0*log_smear_pdf -1.0*scalar_log_me(rv)
    
    return _scalar_neg_log_pdf


scalar_neg_log_pdf = get_scalar_neg_log_pdf(sigma_smear)
scalar_grad_neg_log_pdf = jax.grad(scalar_neg_log_pdf, argnums=0)

scalar_neg_log_pdf=jax.jit(scalar_neg_log_pdf)
scalar_grad_neg_log_pdf=jax.jit(scalar_grad_neg_log_pdf)
    
vector_neg_log_pdf = jax.jit( jax.vmap(scalar_neg_log_pdf, in_axes=(0,0)) )
vector_grad_neg_log_pdf = jax.jit( jax.vmap(jax.grad(scalar_neg_log_pdf, argnums=0), in_axes=(0,0)), )


In [25]:

if DO_TEST:
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    key,subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)
    
    print("scalar_log_me", scalar_log_me(rv))
    print("grad_scalar_log_me",jax.grad(scalar_log_me)(rv))
    
    print("scalar_neg_log_pdf", scalar_neg_log_pdf(urv, obs))
    print("scalar_grad_neg_log_pdf", scalar_grad_neg_log_pdf(urv, obs))
    
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(5,14))
    rv  = SigmoidStraightThrough(urv)

    key,subkey = jax.random.split(key)   
    obs = vector_sim(jax.random.split(subkey,5), rv, 0.1)

    #print("nll", vector_neg_log_pdf(urv, obs))
    #print("grad_nll", vector_grad_neg_log_pdf(urv, obs))


scalar_neg_log_pdf 45.23359799331856
scalar_grad_neg_log_pdf [ -49.50908539   36.27800611   34.07986498  -12.5290706   -15.94060717
   79.35361302 -103.58978229  129.34188054 -117.59686584  469.53723295
 -140.04301258  181.30158984  240.47120832 -223.75090463]


# Combinatorics

In [26]:
#Define potential parton permutation
#  for doing combinatorial likelihood

comb = [0,1,2,3,4,5]
_perms = permutations(comb)

perms = onp.array([onp.array(i) for i in _perms])

#selecting b's in right place
_c1 = (perms[:,0]==0) & (perms[:,3]==3)
_c2 = (perms[:,0]==3) & (perms[:,3]==0)
_cb = (_c1)|(_c2)
perms_correct_b = perms[_cb]

#ignoring swapping jets in same W

_w1 = (perms[:,1]==2) & (perms[:,2]==2)
_w2 = (perms[:,1]==5) & (perms[:,2]==4)
_w3 = (perms[:,4]==2) & (perms[:,5]==2)
_w4 = (perms[:,4]==5) & (perms[:,5]==4)
_ww = (~_w1)&(~_w2)&(~_w3)&(~_w4)
perms_ignore_w_swap = perms[_ww]

#selecting b's in right place and ignoring swapping jets in same W
perms_correct_b_ignore_w_swap = perms[_cb&_ww]

match_combo=onp.array([[0,1,2,3,4,5],
                       [0,2,1,3,4,5],
                       [0,1,2,3,5,4],
                       [0,2,1,3,5,4],
                       [3,4,5,0,1,2],
                       [3,4,5,0,2,1],
                       [3,5,4,0,1,2],
                       [3,5,4,0,2,1]])

# Combinatorial Likelihood

In [27]:

def single_event_urv_opt(optimizer, obs, combos, lr=.1, max_iter=100, tol=0.001):
    
    obs_combos = obs[combos]
    obs_combos_unc_rv_init = vector_unc_rv_from_obs(obs_combos, E_cm)
    
    return run_vopt(optimizer, obs_combos_unc_rv_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol)


In [28]:
optimizer = get_vopt_grad_descent(scalar_neg_log_pdf, scalar_grad_neg_log_pdf)

In [45]:
if True: #DO_TEST:
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    key,subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)
    
    
    start_time = timeit.default_timer()
    results = single_event_urv_opt(optimizer, obs, perms_correct_b, lr=1.0e-6, max_iter=500, tol=0.1)
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed,"\n")


time 0.003752490999886504 



In [46]:
if True: #DO_TEST:
    start_time = timeit.default_timer()
    i_x_min = onp.nanargmin(results["loss"])
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed,"\n")
    print(i_x_min)
    print(perms_correct_b[i_x_min])
    print(results['loss'][i_x_min])
    print(results['loss'][0])
    print(results['l_diff'][i_x_min])
    print(results['n'])
    print(results['x'][i_x_min])
    print(urv)

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(results['x'][i_x_min]))
    best_match_kin = jnp.array([p.vector for p in psp])

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(results['x'][0]))
    zero_kin = jnp.array([p.vector for p in psp])

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(urv))
    true_kin = jnp.array([p.vector for p in psp])
    print('best_match_kin', best_match_kin)
    print('zero_kin', zero_kin)
    print('true_kin', true_kin)


time 3.6110227439999107 

23
[0 5 4 3 2 1]
40.8541306235044
55.90916299405863
0.09997274421905189
[ 46  33  12  12  63  66  28  22   6   8  47 106  10  10   6   7   6   6
  89  94 105 109 116 118   2   2  18  17   5   4   2   2  14  14   3   3
  24  24  17  17  16  16   2   2   4   4  10  10]
[-0.99930647 -4.16366677  0.9966058  -1.35351904 -0.34983634 -0.71120091
 -1.15974196 -1.4486906   1.68605053 -1.11978722  1.29164872 -1.1243995
 -1.41053849 -0.49838958]
[-0.99556443 -1.27703371  0.40746619  0.5186473  -0.41511019 -0.63299566
  0.19135159  0.96979224 -0.35881098  0.65964965  2.12391881  0.55903056
  1.73695582  0.56314504]
best_match_kin [[ 250.            0.            0.          250.        ]
 [ 250.            0.            0.         -250.        ]
 [ 112.13581186  -52.75367503   96.91771715  -19.3999014 ]
 [ 131.776418     68.10710589   85.78233929  -73.26552215]
 [  30.98384512   -1.65802489   -6.75877008   30.192195  ]
 [  52.27325171   -2.85472237   -8.438839     51.2936

# Chi2 approach

In [47]:

def get_scalar_chi2(sigma_jet=0.1, is_parton_mass_basis=False):
    def _chi2(final_state, obs):
        _mt = 173.0
        _gt = 1.5
        _mw = 80.4
        _gw = 2.1
        _mq = 5.0e-3
        _mb = 4.7
        
        _fs_masses = jnp.array([_mb, _mq, _mq, _mb, _mq, _mq])
        
        

        #_parton = jnp.reshape(final_state, (6,4))

        #_parton = jax.lax.cond(is_parton_mass_basis, _parton, vector_convert_to_energy_basis, _parton, lambda x: x)
        
        _parton = vector_convert_to_energy_basis( jnp.hstack((_fs_masses.reshape(6,1),final_state)) )

        w1 = _parton[1]+_parton[2]
        t1 = _parton[0]+_parton[1]+_parton[2]
        w2 = _parton[4]+_parton[5]
        t2 = _parton[3]+_parton[4]+_parton[5]

        mw1 = jnp.sqrt(w1[0]*w1[0] - w1[1]*w1[1] - w1[2]*w1[2] - w1[3]*w1[3])
        mt1 = jnp.sqrt(t1[0]*t1[0] - t1[1]*t1[1] - t1[2]*t1[2] - t1[3]*t1[3])

        mw2 = jnp.sqrt(w2[0]*w2[0] - w2[1]*w2[1] - w2[2]*w2[2] - w2[3]*w2[3])
        mt2 = jnp.sqrt(t2[0]*t2[0] - t2[1]*t2[1] - t2[2]*t2[2] - t2[3]*t2[3])
        
        
        
        #log_smear_pdf = jnp.sum(jax.scipy.stats.norm.logpdf(obs,_parton, _parton*sigma_smear))
        log_smear_pdf = jnp.sum( jnp.square((obs[:,1:4] - _parton[:,1:4]))/jnp.square((_parton[:,1:4]*sigma_smear)) )

        
        chi2_out = (jnp.square((mw1-_mw)/_gw)+jnp.square((mt1-_mt)/_gt)+
                    jnp.square((mw2-_mw)/_gw)+jnp.square((mt2-_mt)/_gt)) 
        #chi2_out = 8.0*jnp.log(2.0)*chi2_out - 2.0*log_smear_pdf
        
        chi2_out = chi2_out + log_smear_pdf
                                               
        return chi2_out
    
    return _chi2
    
scalar_chi2 = jax.jit(get_scalar_chi2(sigma_jet=sigma_smear, is_parton_mass_basis=False))
scalar_grad_chi2 = jax.jit(jax.grad(get_scalar_chi2(sigma_jet=sigma_smear, is_parton_mass_basis=False), argnums=0))

                                                                                             

In [49]:
DO_TEST=True
if DO_TEST:
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    PS_point    = ps_vec({}, rv)
    final_state = PS_point[-6:,1:4]
    
    key,subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)
    
    print(scalar_chi2(final_state, obs))
    print(scalar_grad_chi2(final_state, obs))

3850.029828852817
[[ 24.36114502  -4.73328201 -36.08506191]
 [ 67.02685351 -17.38592002  59.32339561]
 [-79.55795511  19.53675543  -9.47421831]
 [-40.15663951 -32.22110977   2.78226001]
 [-80.85334212 -48.50850753  -3.67035613]
 [ 31.88082857  20.76512586   0.43435586]]


In [50]:
def single_event_chi2_opt(optimizer, obs, combos, lr=.1, max_iter=100, tol=0.001):
    
    obs_combos = obs[combos]
    x_init = obs_combos[:,:,1:4]
        
    return run_vopt(optimizer, x_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol)


def get_scalar_event_chi2_opt(optimizer, lr=.1, max_iter=100, tol=0.001):
    def _scalar_event_chi2_opt(obs, combos):
        obs_combos = obs[combos]
        x_init = obs_combos[:,:,1:4]
        return run_vopt(optimizer, x_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol)
    return _scalar_event_chi2_opt


        

In [51]:
optimizer_chi2 = get_vopt_grad_descent(scalar_chi2, scalar_grad_chi2)

#scalar_event_chi2_opt = get_scalar_event_chi2_opt(optimizer_chi2, lr=.001, max_iter=3000, tol=0.001)
#vector_event_chi2_opt = jax.vmap(scalar_event_chi2_opt)

In [53]:
DO_TEST=True
if DO_TEST:
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    key,subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)
    
    print(obs)
    
    start_time = timeit.default_timer()
    results = single_event_chi2_opt(optimizer_chi2, obs, perms_correct_b, lr=5e-4, max_iter=300000, tol=0.1)
    #results = scalar_event_chi2_opt(obs, perms_correct_b)
    elapsed = timeit.default_timer() - start_time
    print("results time", elapsed,"\n")
    
    start_time = timeit.default_timer()
    i_x_min = onp.nanargmin(results["loss"])
    elapsed = timeit.default_timer() - start_time
    print("access time", elapsed,"\n")
    print("i_x_min", i_x_min)
    print("perms_correct_b[i_x_min]", perms_correct_b[i_x_min])
    print("results['loss'][i_x_min]", results['loss'][i_x_min])
    print("results['loss'][0]", results['loss'][0])
    print("results['l_diff'][i_x_min]", results['l_diff'][i_x_min])
    print("results['n']", results['n'])
        
    fs_masses = jnp.array([4.7, 5.0e-3, 5.0e-3, 4.7, 5.0e-3, 5.0e-3])
    partons_i_x_min = vector_convert_to_energy_basis(jnp.hstack( (fs_masses.reshape(6,1), results['x'][i_x_min]) ))
    partons_0 = vector_convert_to_energy_basis(jnp.hstack( (fs_masses.reshape(6,1),results['x'][0]) ))
    
    print("best urv", scalar_unc_rv_from_obs(partons_i_x_min, E_cm))
    print("true urv", urv)

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(urv))
    true_kin = jnp.array([p.vector for p in psp])
    print('best_match_kin', partons_i_x_min)
    print('zero_kin', partons_0)
    print('true_kin', true_kin[-6:])

[[ 58.77849951  25.0838613   52.3159971   -8.16445678]
 [ 76.25891735 -75.71786691  -5.93856852  -6.85277393]
 [ 81.28861709 -68.84936061 -43.16451525  -2.1046224 ]
 [117.22646636  73.59310264 -67.06410334 -61.69615631]
 [ 99.42841719  33.32865315  92.17379676  16.70934512]
 [ 79.6812573   -4.53822812 -24.90191312  75.55396729]]
results time 0.004797403999873495 

access time 0.0651016689998869 

i_x_min 10
perms_correct_b[i_x_min] [0 2 5 3 1 4]
results['loss'][i_x_min] 312.8243097482939
results['loss'][0] 397.2525355344082
results['l_diff'][i_x_min] 0.09989159684238302
results['n'] [2111 2111 1650 1650 1690 1690 2111 2111 1796 1796 1665 1665 1650 1650
 1796 1796 2563 2563 1690 1690 1665 1665 2563 2563 2563 2563 1665 1665
 1796 1796 2563 2563 1690 1690 1650 1650 1665 1665 1690 1690 2111 2111
 1796 1796 1650 1650 2111 2111]
best urv [-1.88749395  0.00522156 -0.05978499 -0.3598741  -0.31064048 -1.51611898
 -0.17618421  0.2886887   3.22079463  0.57596079 -0.94415154  1.74820943
 -0.425571

In [54]:
if DO_TEST:
    key,subkey = jax.random.split(key)
    Nevt=10
    observed, final_state, urv, rv = batch_sim_from_ucon_rv(subkey, Nevt=Nevt, sigma_smear=0.1)
    
    all_results = []
    for i in range(Nevt):
        all_results.append( single_event_chi2_opt(optimizer_chi2, observed[i], perms_correct_b, 
                                                  lr=5e-4, max_iter=300000, tol=0.1) )
        

In [55]:
def reco_ttbar(partons):
    w1 = partons[1]+partons[2]
    t1 = partons[0]+partons[1]+partons[2]
    w2 = partons[4]+partons[5]
    t2 = partons[3]+partons[4]+partons[5]
    
    return jnp.vstack((w1,t1,w2,t2))

In [56]:
for i in range(Nevt):
    r=all_results[i]
    i_x_min = onp.nanargmin(r["loss"])
    print("i_x_min", i_x_min)
    print("perms_correct_b[i_x_min]", perms_correct_b[i_x_min])
    
    print(final_state[i])
    reco_tt = reco_ttbar(final_state[i])
    print(reco_tt)
    print(mass_from_energy_basis(reco_tt[0]))
    
    print("masses", vector_mass_from_energy_basis(reco_tt))

i_x_min 24
perms_correct_b[i_x_min] [3 1 2 0 4 5]
[[ 115.37033167   15.36197631   90.34456704  -69.92919504]
 [  86.36914734  -60.33482275  -47.16507533   39.93487756]
 [  57.50556518   43.2452963    37.87655599   -1.44944178]
 [ 128.10080033  -65.65150012 -108.90300721  -14.75603595]
 [  31.15651789    1.75162854  -16.27922111   26.50749639]
 [  81.49763758   65.62742172   44.12618062   19.69229883]]
[[143.87471253 -17.08952645  -9.28851935  38.48543578]
 [259.2450442   -1.72755014  81.0560477  -31.44375926]
 [112.65415547  67.37905026  27.84695951  46.19979522]
 [240.7549558    1.72755014 -81.0560477   31.44375926]]
137.26061208965322
masses [137.26061209 244.2257473   72.39577404 224.50205222]
i_x_min 10
perms_correct_b[i_x_min] [0 2 5 3 1 4]
[[ 120.81985239 -100.77817204   66.35714608   -3.97943921]
 [  45.49223013  -41.59143826   14.5757851    11.28014869]
 [  74.53808582  -13.65473133   12.10427843  -72.270056  ]
 [ 135.30564095   62.84613863 -113.18036068   39.06526959]
 [  80.8

# Comparisons
    

In [None]:
if DO_TEST:
    key,subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv  = SigmoidStraightThrough(urv)
    
    key,subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)
    
    print(obs)
    
    
    start_time = timeit.default_timer()
    results_lk = single_event_urv_opt(optimizer, obs, perms_correct_b, lr=1.0e-6, max_iter=500, tol=0.1)
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed,"\n")
    
    start_time = timeit.default_timer()
    results_ch2 = single_event_chi2_opt(optimizer_chi2, obs, perms_correct_b, lr=1e-6, max_iter=300000, tol=0.1)
    #results = scalar_event_chi2_opt(obs, perms_correct_b)
    elapsed = timeit.default_timer() - start_time
    print("results time", elapsed,"\n")

In [66]:
if DO_TEST:
    Nevts = 100
    events = get_multiple_lhe_event_to_ps_point(lhe_events,Nevts)
    
    key,subkey = jax.random.split(key)   
    obs_evts, fs_evts, urv_evts, rv_evts = vector_sim_from_ps_point(jax.random.split(subkey,Nevts), events, 0.1)
   

    all_results_lk = []
    all_results_ch2 = []
    for i in range(Nevts):
        obs = obs_evts[i]
        
        #print(obs)
        print("/n", "### Event", i, "### \n")
    
        start_time = timeit.default_timer()
        results_lk = single_event_urv_opt(optimizer, obs, perms_correct_b, lr=1.0e-6, max_iter=500, tol=0.1)
        elapsed = timeit.default_timer() - start_time
        print("lk time", elapsed,"\n")
        all_results_lk.append(results_lk)
    
        start_time = timeit.default_timer()
        results_ch2 = single_event_chi2_opt(optimizer_chi2, obs, perms_correct_b, lr=1e-6, max_iter=300000, tol=0.1)
        #results = scalar_event_chi2_opt(obs, perms_correct_b)
        elapsed = timeit.default_timer() - start_time
        print("ch2 time", elapsed,"\n")
        all_results_ch2.append(results_ch2)
       

time 0.004032816999824718 

results time 0.005564683000557125 

time 0.4577588550000655 

results time 0.6771440200000143 

time 3.316393168000104 

results time 1.2703770199996143 

time 0.4582820779996837 

results time 2.5104766129998097 

time 6.66825284100014 

results time 2.3272176339996804 

time 1.4794173880000017 

results time 0.005367049000597035 

time 3.452600736000022 

results time 0.013590427999588428 

time 1.5583997419998923 

results time 0.005946567000137293 

time 14.945070813999337 

results time 1.75231644699943 

time 5.38435832399955 

results time 0.01614917400002014 

time 1.581507676000001 

results time 1.0908294520004347 

time 11.75907335699958 

results time 0.7844059349999952 

time 5.607579401000294 

results time 1.036305898000137 

time 17.494528422999792 

results time 0.006542780000017956 

time 0.8671530329993402 

results time 1.785588925000411 

time 13.711452180999913 

results time 0.0058324159999756375 

time 0.9824965800007703 

results tim

In [80]:
n_match_lk = 0
n_match_ch2 = 0

print(match_combo)

for i in range(Nevts):
    r_lk=all_results_lk[i]
    i_x_min_lk = onp.nanargmin(r_lk["loss"])
    print("i_x_min_lk", i_x_min_lk)
    print("perms_correct_b[i_x_min_lk]", perms_correct_b[i_x_min_lk])
    
    if perms_correct_b[i_x_min_lk].tolist() in match_combo.tolist(): n_match_lk += 1 
        
    
    r_ch2=all_results_ch2[i]
    i_x_min_ch2 = onp.nanargmin(r_ch2["loss"])
    print("i_x_min_ch2", i_x_min_ch2)
    print("perms_correct_b[i_x_min_ch2]", perms_correct_b[i_x_min_ch2])
    
    if perms_correct_b[i_x_min_ch2].tolist() in match_combo.tolist(): n_match_ch2 += 1
    
    print("")
    
print("n_match_lk",n_match_lk)
print("n_match_ch2",n_match_ch2)

[[0 1 2 3 4 5]
 [0 2 1 3 4 5]
 [0 1 2 3 5 4]
 [0 2 1 3 5 4]
 [3 4 5 0 1 2]
 [3 4 5 0 2 1]
 [3 5 4 0 1 2]
 [3 5 4 0 2 1]]
i_x_min_lk 46
perms_correct_b[i_x_min_lk] [3 5 4 0 1 2]
i_x_min_ch2 0
perms_correct_b[i_x_min_ch2] [0 1 2 3 4 5]

i_x_min_lk 0
perms_correct_b[i_x_min_lk] [0 1 2 3 4 5]
i_x_min_ch2 0
perms_correct_b[i_x_min_ch2] [0 1 2 3 4 5]

i_x_min_lk 7
perms_correct_b[i_x_min_lk] [0 2 1 3 5 4]
i_x_min_ch2 0
perms_correct_b[i_x_min_ch2] [0 1 2 3 4 5]

i_x_min_lk 23
perms_correct_b[i_x_min_lk] [0 5 4 3 2 1]
i_x_min_ch2 23
perms_correct_b[i_x_min_ch2] [0 5 4 3 2 1]

i_x_min_lk 0
perms_correct_b[i_x_min_lk] [0 1 2 3 4 5]
i_x_min_ch2 0
perms_correct_b[i_x_min_ch2] [0 1 2 3 4 5]

i_x_min_lk 1
perms_correct_b[i_x_min_lk] [0 1 2 3 5 4]
i_x_min_ch2 0
perms_correct_b[i_x_min_ch2] [0 1 2 3 4 5]

i_x_min_lk 46
perms_correct_b[i_x_min_lk] [3 5 4 0 1 2]
i_x_min_ch2 46
perms_correct_b[i_x_min_ch2] [3 5 4 0 1 2]

i_x_min_lk 41
perms_correct_b[i_x_min_lk] [3 4 5 0 2 1]
i_x_min_ch2 0
perms_correct

In [81]:
perms_correct_b

array([[0, 1, 2, 3, 4, 5],
       [0, 1, 2, 3, 5, 4],
       [0, 1, 4, 3, 2, 5],
       [0, 1, 4, 3, 5, 2],
       [0, 1, 5, 3, 2, 4],
       [0, 1, 5, 3, 4, 2],
       [0, 2, 1, 3, 4, 5],
       [0, 2, 1, 3, 5, 4],
       [0, 2, 4, 3, 1, 5],
       [0, 2, 4, 3, 5, 1],
       [0, 2, 5, 3, 1, 4],
       [0, 2, 5, 3, 4, 1],
       [0, 4, 1, 3, 2, 5],
       [0, 4, 1, 3, 5, 2],
       [0, 4, 2, 3, 1, 5],
       [0, 4, 2, 3, 5, 1],
       [0, 4, 5, 3, 1, 2],
       [0, 4, 5, 3, 2, 1],
       [0, 5, 1, 3, 2, 4],
       [0, 5, 1, 3, 4, 2],
       [0, 5, 2, 3, 1, 4],
       [0, 5, 2, 3, 4, 1],
       [0, 5, 4, 3, 1, 2],
       [0, 5, 4, 3, 2, 1],
       [3, 1, 2, 0, 4, 5],
       [3, 1, 2, 0, 5, 4],
       [3, 1, 4, 0, 2, 5],
       [3, 1, 4, 0, 5, 2],
       [3, 1, 5, 0, 2, 4],
       [3, 1, 5, 0, 4, 2],
       [3, 2, 1, 0, 4, 5],
       [3, 2, 1, 0, 5, 4],
       [3, 2, 4, 0, 1, 5],
       [3, 2, 4, 0, 5, 1],
       [3, 2, 5, 0, 1, 4],
       [3, 2, 5, 0, 4, 1],
       [3, 4, 1, 0, 2, 5],
 

True

AttributeError: 'numpy.ndarray' object has no attribute 'count'