In [1]:
import sys
import timeit
import tracemalloc
import logging
import warnings

warnings.simplefilter("ignore")

# numpy & friends
import numpy as onp
import matplotlib.pyplot as plt

%matplotlib inline
import scipy.optimize

import random as random_orig
from itertools import permutations
from functools import partial

# jax and friends
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)
key = jax.random.PRNGKey(0)

# physics stuff
import pylhe

# Madjax
import madjax


DO_TEST = False



In [2]:
# Memory tracing
def print_memory_stuff(note=""):
    current, peak = tracemalloc.get_traced_memory()
    print(
        note, f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB"
    )

# Setup Gradient Descent

In [3]:
# Custom sigmoid, with identity gradient
# Easy way to implement natural gradient descent
#  on sigmoid reparameterization of unconstrained dual space


@jax.custom_jvp
def SigmoidStraightThrough(x):
    return jax.nn.sigmoid(x)


SigmoidStraightThrough.defjvps(lambda t, ans, x: t)

In [4]:
# Batched Gradient descent optimizer
# Since SigmoidStraightThrough takes care of jacobian cancellation
#  in natural gradient descent (more numerically stable)
#  ... so nothing extra has to be added here
#
# INPUT:
#   loss      - vectorized function to be optimized
#   grad_loss - vectorized function, gradient of loss
#   u_init    - batched initial guesses
#   lr        - float, learning rate / step size
#   max_iter  - int, maximum number of iterations before stop
#   tol       - float, tolerance of loss change before stop
#
# OUTPUT: dictionary (all values aer batched)
#   "x" - input value at end
#   "loss" - loss at end
#   "n" - number of iteration to finish


def get_vopt_grad_descent(loss, grad_loss):
    def _cond_fun(vals):

        _not_nan = jnp.invert(jnp.isnan(vals['loss']))
        _check_tol = vals['l_diff'] > vals['tol']
        _check_n = vals['n'] < vals['max_iter']

        # update what elements to keep updating
        _keep_going = _not_nan & _check_tol & _check_n

        return _keep_going

    def _body_fun(vals):
        jax.lax.stop_gradient(vals)

        lr = vals['lr']
        grad = jnp.clip(
            grad_loss(vals['x'], vals['aux']),
            a_min=-1.0 * vals['clip_grad'],
            a_max=vals['clip_grad'],
        )
        x_out_0 = vals['x'] - lr * grad
        loss_out_0 = loss(x_out_0, vals['aux'])

        lr = jax.lax.cond(jnp.isnan(loss_out_0), lambda x: 0.1 * x, lambda x: x, lr)
        x_out = vals['x'] - lr * grad
        loss_out = loss(x_out, vals['aux'])

        del x_out_0
        del loss_out_0

        # try to fix nan from too big step size, if this lower LR still fails then give up
        # lr = jnp.where(jnp.isnan(loss_out_0), 0.1*lr_0, lr_0)
        # lr = jnp.isnan(loss_out_0)*0.1*lr_0 + jnp.invert(jnp.isnan(loss_out_0))*lr_0
        # x_out      = vals['x'] - lr*grad_loss(vals['x'], vals['aux'])
        # loss_out   = loss(x_out, vals['aux'])

        l_diff_out = jnp.abs(loss_out - vals['loss'])
        n_iter_out = vals['n'] + 1

        # print("n", vals['n'])
        # print("x_in", vals['x'])
        # print("x_out", x_out)
        # print("loss_in", vals['loss'])
        # print("loss_out", loss_out)
        # print("grad", grad_loss(vals['x'], vals['aux']))

        return {
            'x': x_out,
            'aux': vals['aux'],
            'loss': loss_out,
            'l_diff': l_diff_out,
            'n': n_iter_out,
            'lr': lr,
            'max_iter': vals['max_iter'],
            'tol': vals['tol'],
            'clip_grad': vals['clip_grad'],
        }

    def _while_fun(vals):
        return jax.lax.while_loop(_cond_fun, _body_fun, vals)

    dct_axes = {
        'l_diff': 0,
        'loss': 0,
        'n': 0,
        'x': 0,
        'aux': 0,
        'lr': 0,
        'max_iter': None,
        'tol': None,
        'clip_grad': None,
    }

    _result = jax.jit(jax.vmap(_while_fun, in_axes=(dct_axes,)))
    return _result


def run_vopt(
    optimizer, x_init, aux_input, lr=0.1, max_iter=100, tol=0.001, clip_grad=1e8
):

    loss_val = 10.0 * jnp.ones(x_init.shape[0])  # loss(x_init, aux_input)
    loss_diff = 10.0 * jnp.ones(x_init.shape[0])
    n_iter = jnp.zeros((x_init.shape[0]), dtype=int)
    lr_init = lr * jnp.ones(x_init.shape[0])

    in_vals = {
        'x': x_init,
        'aux': aux_input,
        'loss': loss_val,
        'l_diff': loss_diff,
        'n': n_iter,
        'lr': lr_init,
        'max_iter': max_iter,
        'tol': tol,
        'clip_grad': clip_grad,
    }

    return optimizer(jax.lax.stop_gradient(in_vals))

In [5]:
if DO_TEST:

    def f(x, y):
        return jnp.sum(x * x) + jnp.sum(y)

    f = jax.jit(f)
    df_dx = jax.jit(jax.grad(f, argnums=0))

    _x = jnp.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0], [5.0, 0.0]])
    _y = jnp.array([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [3.0, 4.0], [1.0, 2.0]])

    print(f(_x, _y))
    print(df_dx(_x, _y))

    optimizer = get_vopt_grad_descent(f, df_dx)

    for i in range(5):
        start_time = timeit.default_timer()
        run_vopt(optimizer, _x, _y, lr=0.1, max_iter=100, tol=0.001)
        elapsed = timeit.default_timer() - start_time
        print(i, "time", elapsed, "\n")

    print("second \n")
    # optimizer = get_vopt_grad_descent(f, df_dx, lr=.1, max_iter=100, tol=0.001)

    for i in range(5):
        start_time = timeit.default_timer()
        _res = run_vopt(optimizer, _x, _y, lr=0.1, max_iter=200, tol=0.001)
        elapsed = timeit.default_timer() - start_time
        print(i, "time", elapsed, "\n")

    print(_res)

# 4-Vector Operations

In [6]:
def energy_from_mass_basis(p):
    return jnp.sqrt(jnp.sum(jnp.square(p)))


def mass_from_energy_basis(p):
    _m = p[0] * p[0] - p[1] * p[1] - p[2] * p[2] - p[3] * p[3]
    _m = jnp.where(_m < 1e-5, 1e-5, jnp.sqrt(_m))
    return _m


vector_energy_from_mass_basis = jax.jit(jax.vmap(energy_from_mass_basis, in_axes=0))
vector_mass_from_energy_basis = jax.jit(jax.vmap(mass_from_energy_basis, in_axes=0))


def scalar_convert_to_energy_basis(p):
    return jnp.hstack(([energy_from_mass_basis(p)], p[1:4]))
    # _energy = energy_from_mass_basis(p)
    # jax.ops.index_update(p, jax.ops.index[0], _energy)
    # return p


def scalar_convert_to_mass_basis(p):
    return jnp.hstack(([mass_from_energy_basis(p)], p[1:4]))
    # _mass = mass_from_energy_basis(p)
    # jax.ops.index_update(p, jax.ops.index[0], _mass)
    # return p


vector_convert_to_energy_basis = jax.jit(
    jax.vmap(scalar_convert_to_energy_basis, in_axes=0)
)
vector_convert_to_mass_basis = jax.jit(
    jax.vmap(scalar_convert_to_mass_basis, in_axes=0)
)

# Setup MadJax

In [7]:
E_cm = 500.0
config_name = "ee_ttbar_bqq_bqq"
process_name = "Matrix_1_epem_ttx_t_budx_tx_bxdux"
nDimPS = 14

mj = madjax.MadJax(config_name=config_name)
matrix_element = mj.matrix_element(
    E_cm=E_cm, process_name=process_name, return_grad=False, do_jit=False
)
jacobian = mj.jacobian(E_cm=E_cm, process_name=process_name, do_jit=False)
ps_gen = mj.phasespace_generator(E_cm=E_cm, process_name=process_name)({})
ps_vec = mj.phasespace_vectors(E_cm=E_cm, process_name=process_name)

external_params = {}

sigma_smear = 0.1

## LHE Event Manipulation

In [8]:
# get LHE events generator
lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")

In [9]:
def lhe_event_to_ps_point(event):
    _particles = []
    for p in event.particles:
        # print(p.id, p.status, p.px, p.py, p.pz, p.e)
        if p.status == -1 or p.status == 1:
            _particles += [[p.e, p.px, p.py, p.pz]]
    return jnp.array(_particles)


def get_multiple_lhe_event_to_ps_point(lhe_event_generator, n_events):
    _evts = []
    for i in range(n_events):
        try:
            _ev = lhe_event_generator.__next__()
        except:
            print("No More Events")
            return _evts

        _ps = lhe_event_to_ps_point(_ev)
        _evts += [_ps]

    return jnp.array(_evts)

In [10]:
if DO_TEST:
    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)
    print(test_evts)

## Simulation and InverseKinematics

In [11]:
def convert_to_PS_point(incoming, observed):
    _out = jax.numpy.vstack((incoming, observed))
    return [madjax.phasespace.vectors.Vector(p) for p in _out]


def scalar_unc_rv_from_obs(obs, E_cm, clip=None):
    incoming = jnp.array(
        [[E_cm / 2.0, 0, 0, E_cm / 2.0], [E_cm / 2.0, 0, 0, -E_cm / 2.0]]
    )
    obs_vec = convert_to_PS_point(incoming=incoming, observed=obs)
    i_rv, i_wt = ps_gen.invertKinematics(E_cm, obs_vec)
    i_unc_rv = jax.scipy.special.logit(i_rv)

    if clip is not None:
        i_unc_rv = jnp.clip(i_unc_rv, a_min=-1.0 * clip, a_max=1.0 * clip)

    return i_unc_rv


vector_unc_rv_from_obs = jax.jit(
    jax.vmap(scalar_unc_rv_from_obs, in_axes=(0, None, None)), static_argnums=(1, 2)
)

# if DO_TEST:
#    key,subkey = jax.random.split(key)
#    urv = jax.random.normal(subkey, shape=(5,14))
#    rv  = SigmoidStraightThrough(urv)

#    key,subkey = jax.random.split(key)
#    obs = vector_sim(jax.random.split(subkey,5), rv, 0.1)

#    print(vector_unc_rv_from_obs(obs,E_cm))

In [12]:
def gauss_smear(key, final_state, sigma_all):

    key, subkey = jax.random.split(key)
    smear = jax.random.normal(subkey, final_state.shape) * sigma_all
    observed = final_state * (1 + smear)

    return observed


def gauss_p_smear(key, final_state, sigma_p):
    _fs_m = vector_convert_to_mass_basis(final_state)
    _n = final_state.shape[0]

    key, subkey = jax.random.split(key)
    p_smear = jax.random.normal(subkey, (_n, 3)) * sigma_p
    new_p = _fs_m[:, 1:4] * (1 + p_smear)

    observed = vector_convert_to_energy_basis(
        jnp.hstack((jnp.reshape(_fs_m[:, 0], (_n, 1)), new_p))
    )
    return observed


def gauss_p_lognormal_m_smear(key, final_state, sigma_p, sigma_m):

    _fs_m = vector_convert_to_mass_basis(final_state)

    _n = final_state.shape[0]

    key, subkey = jax.random.split(key)
    p_smear = jax.random.normal(subkey, (_n, 3)) * sigma_p
    new_p = _fs_m[:, 1:4] * (1 + p_smear)

    key, subkey = jax.random.split(key)
    # m_smear = jax.random.normal(subkey,(_n,))*sigma_m
    # new_m = jnp.exp( jnp.log(_fs_m[:,0]+1.)*(1+m_smear))
    mode = _fs_m[:, 0]
    var = jnp.square(_fs_m[:, 0] * sigma_m)
    mean2 = jnp.power(1.0 / mode, 2.0 / 3.0) - var

    print("mode", mode)
    print("var", var)
    print("mean2", mean2)

    sigma2 = jnp.log(1 + var / (mean2))
    mu = jnp.log(mean2 / jnp.sqrt(mean2 + var))
    new_m = jnp.exp(jax.random.normal(subkey, (_n,)) * jnp.sqrt(sigma2) + mu)

    observed = vector_convert_to_energy_basis(
        jnp.hstack((jnp.reshape(new_m, (_n, 1)), new_p))
    )
    return observed

In [13]:
if DO_TEST:
    key, subkey = jax.random.split(key)
    rv = jax.random.uniform(subkey, shape=(14,))

    PS_point = ps_vec({}, rv)
    final_state = PS_point[-6:]

    print(PS_point)

    print(final_state)

    print(gauss_smear(subkey, final_state, 0.1))

    print(gauss_p_smear(subkey, final_state, 0.1))

    print(gauss_p_lognormal_m_smear(subkey, final_state, 0.1, 0.2))

    print("LHE test")

    ev = lhe_events.__next__()

    lhe_ps = lhe_event_to_ps_point(ev)
    lhe_fs = lhe_ps[-6:]

    print(lhe_ps)

    print(lhe_fs)

    print(gauss_smear(subkey, lhe_fs, 0.1))

    print(gauss_p_smear(subkey, lhe_fs, 0.1))

    print(gauss_p_lognormal_m_smear(subkey, lhe_fs, 0.1, 0.2))

In [14]:
def scalar_sim(key, rv, sigma_smear):
    PS_point = ps_vec({}, rv)
    final_state = PS_point[-6:]

    return gauss_p_smear(key, final_state, sigma_smear)


vector_sim = jax.jit(jax.vmap(scalar_sim, in_axes=(0, 0, None)), static_argnums=(2,))


def scalar_sim_from_ucon_rv(key, sigma_smear):
    key, subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv = SigmoidStraightThrough(urv)

    PS_point = ps_vec({}, rv)
    final_state = PS_point[-6:]

    observed = gauss_p_smear(key, final_state, sigma_smear)

    return observed, final_state, urv, rv


vector_sim_from_ucon_rv = jax.jit(
    jax.vmap(scalar_sim_from_ucon_rv, in_axes=(0, None)), static_argnums=(1,)
)


def batch_sim_from_ucon_rv(key, Nevt, sigma_smear):
    key, subkey = jax.random.split(key)
    return vector_sim_from_ucon_rv(jax.random.split(subkey, Nevt), sigma_smear)


if DO_TEST:
    key, subkey = jax.random.split(key)
    rv = jax.random.uniform(subkey, shape=(5, 14))

    key, subkey = jax.random.split(key)
    print(vector_sim(jax.random.split(subkey, 5), rv, 0.1))

In [79]:
def scalar_sim_from_ps_point(key, ps_point, sigma_smear):
    rv, _ = ps_gen.invertKinematics(
        E_cm, [madjax.phasespace.vectors.Vector(p) for p in ps_point]
    )
    urv = jax.scipy.special.logit(rv)

    final_state = ps_point[-6:]
    observed = gauss_p_smear(key, final_state, sigma_smear)

    return observed, final_state, urv, rv


vector_sim_from_ps_point = jax.jit(
    jax.vmap(scalar_sim_from_ps_point, in_axes=(0, 0, None)), static_argnums=(2,)
)


def batch_sim_from_ps_point(key, ps_points, sigma_smear):
    key, subkey = jax.random.split(key)
    _Nevt = ps_points.shape[0]
    return vector_sim_from_ps_point(
        jax.random.split(subkey, _Nevt), ps_points, sigma_smear
    )


if True:  # DO_TEST:
    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)

    key, subkey = jax.random.split(key)
    print(vector_sim_from_ps_point(jax.random.split(subkey, 3), test_evts, 0.1))

    print("test events", test_evts)

(DeviceArray([[[ 1.50269019e+02,  1.23169266e+02,  1.26457309e+01,
                8.50182649e+01],
              [ 5.35975966e+01, -2.19749802e+01,  4.47050029e+01,
                1.97804277e+01],
              [ 3.24789237e+01,  2.66603288e+01, -1.78050898e+01,
               -5.20443397e+00],
              [ 1.72892643e+02, -1.25105094e+02, -4.70599076e+01,
               -1.09562113e+02],
              [ 4.88254478e+01, -1.87172138e+01,  3.50048394e+01,
                2.84297640e+01],
              [ 4.09418048e+01, -8.63805124e+00, -3.14377235e+01,
               -2.47645916e+01]],

             [[ 2.87715540e+01,  2.64778536e+00,  2.22858856e-01,
                2.82604297e+01],
              [ 1.64999061e+02, -3.72323650e+01, -8.84243428e+01,
               -1.34237018e+02],
              [ 6.05891023e+01,  3.33877062e+01, -1.96187156e+01,
               -4.65983518e+01],
              [ 1.38756973e+02, -3.97523279e+00,  1.23493913e+02,
                6.29671236e+01],
       

# Matrix Element Likelihood

In [16]:
def get_scalar_log_me(params):
    def func(rv):
        return jnp.log(matrix_element(params, rv)) + jnp.log(jacobian(params, rv))

    return func


scalar_log_me = get_scalar_log_me(params={})
vector_log_me = jax.jit(jax.vmap(scalar_log_me))

In [17]:
if DO_TEST:
    key, subkey = jax.random.split(key)
    rv = jax.random.uniform(subkey, shape=(6, 14))

    print(vector_log_me(rv))
    print(jax.vmap(jax.grad(scalar_log_me))(rv))

In [18]:
def get_scalar_neg_log_pdf(sigma_smear):
    def _scalar_neg_log_pdf(ucon_rv, obs):
        rv = SigmoidStraightThrough(ucon_rv)

        PS_point = ps_vec({}, rv)
        final_state = PS_point[-6:]

        log_smear_pdf = jnp.sum(
            jax.scipy.stats.norm.logpdf(
                obs[:, 1:4], final_state[:, 1:4], sigma_smear * final_state[:, 1:4]
            )
        )

        return -1.0 * log_smear_pdf - 1.0 * scalar_log_me(rv)

    return _scalar_neg_log_pdf


scalar_neg_log_pdf = get_scalar_neg_log_pdf(sigma_smear)
scalar_grad_neg_log_pdf = jax.grad(scalar_neg_log_pdf, argnums=0)

scalar_neg_log_pdf = jax.jit(scalar_neg_log_pdf)
scalar_grad_neg_log_pdf = jax.jit(scalar_grad_neg_log_pdf)

vector_neg_log_pdf = jax.jit(jax.vmap(scalar_neg_log_pdf, in_axes=(0, 0)))
vector_grad_neg_log_pdf = jax.jit(
    jax.vmap(jax.grad(scalar_neg_log_pdf, argnums=0), in_axes=(0, 0)),
)

In [19]:
if True:
    key, subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv = SigmoidStraightThrough(urv)

    key, subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)

    print("scalar_log_me", scalar_log_me(rv))
    print("grad_scalar_log_me", jax.grad(scalar_log_me)(rv))

    print("scalar_neg_log_pdf", scalar_neg_log_pdf(urv, obs))
    print("scalar_grad_neg_log_pdf", scalar_grad_neg_log_pdf(urv, obs))

    key, subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(5, 14))
    rv = SigmoidStraightThrough(urv)

    key, subkey = jax.random.split(key)
    obs = vector_sim(jax.random.split(subkey, 5), rv, 0.1)

    # print("nll", vector_neg_log_pdf(urv, obs))
    # print("grad_nll", vector_grad_neg_log_pdf(urv, obs))

scalar_log_me -3.098937378964017
grad_scalar_log_me [-7.59001504 -3.45517603  7.89240469 -6.75377815 -1.33683325 -5.86853228
  0.47109826  4.02900106  1.31403862  2.51051036  0.49600606  1.22637648
 -1.00222307 -1.89735701]
scalar_neg_log_pdf 39.07368150437258
scalar_grad_neg_log_pdf [ -430.67742709  -914.16154711   362.65095001  -734.93131328
   115.02317085 -1018.51098536  -152.01062104 -2350.80315405
   159.42983539  1138.72282113  -171.15810957 -3333.64886468
   243.99291581 -6790.02782951]


# Combinatorics

In [20]:
# Define potential parton permutation
#  for doing combinatorial likelihood

comb = [0, 1, 2, 3, 4, 5]
_perms = permutations(comb)

perms = onp.array([onp.array(i) for i in _perms])

# selecting b's in right place
_c1 = (perms[:, 0] == 0) & (perms[:, 3] == 3)
_c2 = (perms[:, 0] == 3) & (perms[:, 3] == 0)
_cb = (_c1) | (_c2)
perms_correct_b = perms[_cb]

# ignoring swapping jets in same W

# _w1 = (perms[:,1]==2) & (perms[:,2]==2)
# _w2 = (perms[:,1]==5) & (perms[:,2]==4)
# _w3 = (perms[:,4]==2) & (perms[:,5]==2)
# _w4 = (perms[:,4]==5) & (perms[:,5]==4)
# _ww = (~_w1)&(~_w2)&(~_w3)&(~_w4)
# perms_ignore_w_swap = perms[_ww]

_w1 = (perms[:, 1] == 1) & (perms[:, 2] == 2)
_w2 = (perms[:, 1] == 2) & (perms[:, 2] == 1)
_w3 = (perms[:, 4] == 4) & (perms[:, 5] == 5)
_w4 = (perms[:, 4] == 5) & (perms[:, 5] == 4)
_ww = ((_w1) | (_w2)) & ((_w3) | (_w4))
perms_ignore_w_swap = perms[_ww]

# selecting b's in right place and ignoring swapping jets in same W
perms_correct_b_ignore_w_swap = perms[_cb & _ww]

match_combo = onp.array(
    [
        [0, 1, 2, 3, 4, 5],
        [0, 2, 1, 3, 4, 5],
        [0, 1, 2, 3, 5, 4],
        [0, 2, 1, 3, 5, 4],
        [3, 4, 5, 0, 1, 2],
        [3, 4, 5, 0, 2, 1],
        [3, 5, 4, 0, 1, 2],
        [3, 5, 4, 0, 2, 1],
    ]
)

perms_test = onp.array(
    [
        [0, 1, 2, 3, 4, 5],
        [0, 1, 4, 3, 2, 5],
        [0, 1, 5, 3, 2, 4],
        [0, 4, 2, 3, 1, 5],
        [0, 5, 2, 3, 1, 4],
        [3, 1, 2, 0, 4, 5],
    ]
)

# Combinatorial Likelihood

In [21]:
def single_event_urv_opt(
    optimizer,
    obs,
    combos,
    lr=0.1,
    max_iter=100,
    tol=0.001,
    clip_grad=1e3,
    clip_urv_init=5,
):

    obs_combos = obs[combos]
    obs_combos_unc_rv_init = vector_unc_rv_from_obs(obs_combos, E_cm, clip_urv_init)

    return run_vopt(
        optimizer, obs_combos_unc_rv_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol
    )

In [22]:
optimizer = get_vopt_grad_descent(scalar_neg_log_pdf, scalar_grad_neg_log_pdf)

In [82]:
if True:  # DO_TEST:
    # key,subkey = jax.random.split(key)
    # urv = jax.random.normal(subkey, shape=(14,))
    # rv  = SigmoidStraightThrough(urv)

    # key,subkey = jax.random.split(key)
    # obs = scalar_sim(subkey, rv, 0.1)

    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)

    key, subkey = jax.random.split(key)
    obs, f_state, f_urv, f_rv = vector_sim_from_ps_point(
        jax.random.split(subkey, 3), test_evts, 0.1
    )

    start_time = timeit.default_timer()
    results = single_event_urv_opt(
        optimizer, obs[0], perms_test, lr=1.0e-7, max_iter=15000, tol=0.001
    )
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed, "\n")

time 10.14837414900012 



In [83]:
if True:  # DO_TEST:
    start_time = timeit.default_timer()
    i_x_min = onp.nanargmin(results["loss"])
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed, "\n")
    print(i_x_min)
    print(perms_test[i_x_min])
    print(results['loss'][i_x_min])
    print(results['loss'][0])
    print(results['l_diff'][i_x_min])
    print(results['n'])
    print(results['x'][i_x_min])
    print(urv)

    psp, _ = ps_gen.generateKinematics(
        E_cm, SigmoidStraightThrough(results['x'][i_x_min])
    )
    best_match_kin = jnp.array([p.vector for p in psp])

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(results['x'][0]))
    zero_kin = jnp.array([p.vector for p in psp])

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(urv))
    true_kin = jnp.array([p.vector for p in psp])
    print('best_match_kin', best_match_kin)
    print('zero_kin', zero_kin)
    print('true_kin', true_kin)
    print('obs', obs)

time 0.00012145499931648374 

0
[0 1 2 3 4 5]
21.215145938340182
21.215145938340182
0.0009933005380950988
[2622 2284 2449 1193  964 2778]
[-0.7431657  -2.65419897  2.38136785 -0.56426397  0.0928289   1.72372921
 -1.12640251  4.46181157  0.86765856  1.36047706  0.17352249 -0.49376091
  0.61769525  0.81547844]
[-0.89308976 -1.81007424 -0.29823167  0.72342139  0.60786931 -2.31401922
 -0.09852325 -0.0337255  -0.69973225  0.6953217   0.34054268  1.72478678
 -0.40460646 -2.71801418]
best_match_kin [[ 250.            0.            0.          250.        ]
 [ 250.            0.            0.         -250.        ]
 [ 104.10296337   60.32490767  -84.57527078    4.82349153]
 [ 122.78353906   97.66412821   19.82370684  -71.72542208]
 [  17.745517     -7.18874771  -10.66957624   12.22233296]
 [ 170.10745348 -137.53893693   92.94425544   36.86274458]
 [  55.33906214  -21.03572655  -46.06467239   22.31492694]
 [  29.92146495    7.77437532   28.54155714   -4.49807393]]
zero_kin [[ 250.            0.

# Chi2 approach

In [45]:
def get_scalar_chi2(sigma_jet=0.1, is_parton_mass_basis=False):
    def _chi2(final_state, obs):
        _mt = 173.0
        _gt = 1.5
        _mw = 80.4
        _gw = 2.1
        _mq = 5.0e-3
        _mb = 4.7

        _fs_masses = jnp.array([_mb, _mq, _mq, _mb, _mq, _mq])

        # _parton = jnp.reshape(final_state, (6,4))

        # _parton = jax.lax.cond(is_parton_mass_basis, _parton, vector_convert_to_energy_basis, _parton, lambda x: x)

        _parton = vector_convert_to_energy_basis(
            jnp.hstack((_fs_masses.reshape(6, 1), final_state))
        )

        w1 = _parton[1] + _parton[2]
        t1 = _parton[0] + _parton[1] + _parton[2]
        w2 = _parton[4] + _parton[5]
        t2 = _parton[3] + _parton[4] + _parton[5]

        mw1 = jnp.sqrt(w1[0] * w1[0] - w1[1] * w1[1] - w1[2] * w1[2] - w1[3] * w1[3])
        mt1 = jnp.sqrt(t1[0] * t1[0] - t1[1] * t1[1] - t1[2] * t1[2] - t1[3] * t1[3])

        mw2 = jnp.sqrt(w2[0] * w2[0] - w2[1] * w2[1] - w2[2] * w2[2] - w2[3] * w2[3])
        mt2 = jnp.sqrt(t2[0] * t2[0] - t2[1] * t2[1] - t2[2] * t2[2] - t2[3] * t2[3])

        log_smear_pdf = jnp.sum(
            jax.scipy.stats.norm.logpdf(
                obs[:, 1:4], _parton[:, 1:4], sigma_smear * _parton[:, 1:4]
            )
        )
        # log_smear_pdf = jnp.sum( jnp.square((obs[:,1:4] - _parton[:,1:4]))/jnp.square((_parton[:,1:4]*sigma_smear)) )

        chi2_out = (
            jnp.square((mw1 - _mw) / _gw)
            + jnp.square((mt1 - _mt) / _gt)
            + jnp.square((mw2 - _mw) / _gw)
            + jnp.square((mt2 - _mt) / _gt)
        )
        # chi2_out = 8.0*jnp.log(2.0)*chi2_out - 2.0*log_smear_pdf

        chi2_out = chi2_out - log_smear_pdf

        return chi2_out

    return _chi2


scalar_chi2 = jax.jit(
    get_scalar_chi2(sigma_jet=sigma_smear, is_parton_mass_basis=False)
)
scalar_grad_chi2 = jax.jit(
    jax.grad(
        get_scalar_chi2(sigma_jet=sigma_smear, is_parton_mass_basis=False), argnums=0
    )
)

In [46]:
DO_TEST = True
if DO_TEST:
    key, subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv = SigmoidStraightThrough(urv)

    PS_point = ps_vec({}, rv)
    final_state = PS_point[-6:, 1:4]

    key, subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)

    print(scalar_chi2(final_state, obs))
    print(scalar_grad_chi2(final_state, obs))

7524.545223895405
[[ 1.17916732e+02 -5.21441196e+01  4.94896001e+01]
 [-3.69618315e+01 -4.20300309e+01 -6.38620154e+01]
 [-7.94770204e+01  6.78093152e+01 -6.99613811e+00]
 [-6.08846576e+00 -1.02178065e+01 -5.83907908e-01]
 [ 4.68746817e+00  1.02546070e+01  1.19278911e+00]
 [ 9.89054706e-01  4.13617149e-01  7.06629320e-02]]


In [47]:
def single_event_chi2_opt(optimizer, obs, combos, lr=0.1, max_iter=100, tol=0.001):

    obs_combos = obs[combos]
    x_init = obs_combos[:, :, 1:4]

    return run_vopt(optimizer, x_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol)


def get_scalar_event_chi2_opt(optimizer, lr=0.1, max_iter=100, tol=0.001):
    def _scalar_event_chi2_opt(obs, combos):
        obs_combos = obs[combos]
        x_init = obs_combos[:, :, 1:4]
        return run_vopt(
            optimizer, x_init, obs_combos, lr=lr, max_iter=max_iter, tol=tol
        )

    return _scalar_event_chi2_opt

In [48]:
optimizer_chi2 = get_vopt_grad_descent(scalar_chi2, scalar_grad_chi2)

# scalar_event_chi2_opt = get_scalar_event_chi2_opt(optimizer_chi2, lr=.001, max_iter=3000, tol=0.001)
# vector_event_chi2_opt = jax.vmap(scalar_event_chi2_opt)

In [51]:
DO_TEST = True
if DO_TEST:
    key, subkey = jax.random.split(key)
    urv = jax.random.normal(subkey, shape=(14,))
    rv = SigmoidStraightThrough(urv)

    key, subkey = jax.random.split(key)
    obs = scalar_sim(subkey, rv, 0.1)

    print(obs)

    start_time = timeit.default_timer()
    results = single_event_chi2_opt(
        optimizer_chi2, obs, perms_test, lr=1e-5, max_iter=300000, tol=0.01
    )
    # results = scalar_event_chi2_opt(obs, perms_correct_b)
    elapsed = timeit.default_timer() - start_time
    print("results time", elapsed, "\n")

    start_time = timeit.default_timer()
    i_x_min = onp.nanargmin(results["loss"])
    elapsed = timeit.default_timer() - start_time
    print("access time", elapsed, "\n")
    print("i_x_min", i_x_min)
    print("perms_test[i_x_min]", perms_test[i_x_min])
    print("results['loss'][i_x_min]", results['loss'][i_x_min])
    print("results['loss'][0]", results['loss'][0])
    print("results['l_diff'][i_x_min]", results['l_diff'][i_x_min])
    print("results['n']", results['n'])

    fs_masses = jnp.array([4.7, 5.0e-3, 5.0e-3, 4.7, 5.0e-3, 5.0e-3])
    partons_i_x_min = vector_convert_to_energy_basis(
        jnp.hstack((fs_masses.reshape(6, 1), results['x'][i_x_min]))
    )
    partons_0 = vector_convert_to_energy_basis(
        jnp.hstack((fs_masses.reshape(6, 1), results['x'][0]))
    )

    print("best urv", scalar_unc_rv_from_obs(partons_i_x_min, E_cm))
    print("true urv", urv)

    psp, _ = ps_gen.generateKinematics(E_cm, SigmoidStraightThrough(urv))
    true_kin = jnp.array([p.vector for p in psp])
    print('best_match_kin', partons_i_x_min)
    print('zero_kin', partons_0)
    print('true_kin', true_kin[-6:])

[[ 6.97294122e+01  2.91061479e+01  5.87476154e+01 -2.32733918e+01]
 [ 4.38119168e+01  1.79736507e+01 -2.40855834e+01  3.18797208e+01]
 [ 5.90155881e+01  4.75032866e+01 -1.44141232e+01 -3.19141106e+01]
 [ 1.53515508e+02 -1.51717573e+02 -2.08834113e+01  9.51746787e+00]
 [ 8.99400002e+01  3.92934652e+01 -8.09025413e+01 -7.77680179e-02]
 [ 8.18867164e+01  2.32189987e+01  7.80141764e+01  8.94989984e+00]]
results time 0.48513628600017 

access time 0.00019970099992860924 

i_x_min 3
perms_test[i_x_min] [0 4 2 3 1 5]
results['loss'][i_x_min] 506.4986709749503
results['loss'][0] 518.3136253186576
results['l_diff'][i_x_min] 0.009999873270601256
results['n'] [55365 47091 50331 34452 42644 46583]
best urv [-0.70315955 -0.3284962  -0.19140896 -1.11842216 -0.69749419 -1.54088027
 -0.05739309  1.69750584 -1.00904745  3.12572861 -0.27316835  0.26191125
  0.59765708  1.1078236 ]
true urv [ 0.43680324  1.56728802  1.08748092 -0.22709576 -0.68265569 -1.51604668
  1.85509383  2.14418051 -1.06477907  3.59

In [52]:
if DO_TEST:
    key, subkey = jax.random.split(key)
    Nevt = 10
    observed, final_state, urv, rv = batch_sim_from_ucon_rv(
        subkey, Nevt=Nevt, sigma_smear=0.1
    )

    all_results = []
    for i in range(Nevt):
        all_results.append(
            single_event_chi2_opt(
                optimizer_chi2,
                observed[i],
                perms_test,
                lr=5e-4,
                max_iter=300000,
                tol=0.1,
            )
        )

In [53]:
def reco_ttbar(partons):
    w1 = partons[1] + partons[2]
    t1 = partons[0] + partons[1] + partons[2]
    w2 = partons[4] + partons[5]
    t2 = partons[3] + partons[4] + partons[5]

    return jnp.vstack((w1, t1, w2, t2))

In [54]:
for i in range(Nevt):
    r = all_results[i]
    i_x_min = onp.nanargmin(r["loss"])
    print("i_x_min", i_x_min)
    print("perms_test[i_x_min]", perms_test[i_x_min])

    print(final_state[i])
    reco_tt = reco_ttbar(final_state[i])
    print(reco_tt)
    print(mass_from_energy_basis(reco_tt[0]))

    print("masses", vector_mass_from_energy_basis(reco_tt))

i_x_min 4
perms_test[i_x_min] [0 5 2 3 1 4]
[[103.18803145  74.78693967  66.15199806 -25.62238555]
 [ 54.35683615 -49.1293404  -18.65708282  13.88836956]
 [ 25.23716611 -14.38738575   7.06235377  19.49463628]
 [ 79.66686076 -77.04300771  19.29502751   4.09824101]
 [109.82027578 -48.93156714 -88.73727537  42.33072962]
 [127.73082975 114.70436133  14.88497886 -54.18959093]]
[[ 79.59400226 -63.51672615 -11.59472905  33.38300584]
 [182.78203371  11.27021352  54.557269     7.76062029]
 [237.55110553  65.77279419 -73.85229651 -11.85886131]
 [317.21796629 -11.27021352 -54.557269    -7.76062029]]
32.43405424418026
masses [ 32.43405424 173.91242424 215.66101408 312.19144382]
i_x_min 5
perms_test[i_x_min] [3 1 2 0 4 5]
[[ 74.67708939 -12.77202714  73.38337076   2.51672355]
 [ 93.7681238  -61.07806583 -63.01631908  33.02838847]
 [ 97.45438165 -69.9719317  -32.11430688 -59.74911355]
 [ 58.47564847  36.54652787  39.15672526  22.98724934]
 [ 55.77323186   6.17117535  18.31912416 -52.3161512 ]
 [119.

# Comparisons
    

In [57]:
if DO_TEST:
    # key,subkey = jax.random.split(key)
    # urv = jax.random.normal(subkey, shape=(14,))
    # rv  = SigmoidStraightThrough(urv)

    # key,subkey = jax.random.split(key)
    # obs = scalar_sim(subkey, rv, 0.1)

    test_evts = get_multiple_lhe_event_to_ps_point(lhe_events, 3)

    key, subkey = jax.random.split(key)
    obs, f_state, f_urv, f_rv = vector_sim_from_ps_point(
        jax.random.split(subkey, 3), test_evts, 0.1
    )

    print(obs)

    start_time = timeit.default_timer()
    results_lk = single_event_urv_opt(
        optimizer, obs, perms_test, lr=1e-7, max_iter=4000, tol=0.001
    )
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed, "\n")

    start_time = timeit.default_timer()
    results_ch2 = single_event_chi2_opt(
        optimizer_chi2, obs, perms_test, lr=1e-6, max_iter=300000, tol=0.01
    )
    # results = scalar_event_chi2_opt(obs, perms_correct_b)
    elapsed = timeit.default_timer() - start_time
    print("results time", elapsed, "\n")

[[  46.32935718  -22.15578338   11.36992328   38.78357183]
 [ 149.11238759   -6.99281059 -148.7078118    -8.46117263]
 [  80.18447234  -27.47656876   60.92506899   44.30263809]
 [  55.08547259  -37.04883419   12.9691958   -38.36017643]
 [  31.39364201  -26.81081023    7.67881638   14.41447164]
 [ 149.02130917  133.79950085   41.55193651  -50.77874291]]
time 10.719119711000076 

results time 4.184902576000013 



In [84]:
if DO_TEST:
    Nevts = 51
    events = get_multiple_lhe_event_to_ps_point(lhe_events, Nevts)

    key, subkey = jax.random.split(key)
    obs_evts, fs_evts, urv_evts, rv_evts = vector_sim_from_ps_point(
        jax.random.split(subkey, Nevts), events, 0.1
    )

    all_results_lk = []
    all_results_ch2 = []
    for i in range(Nevts):
        obs = obs_evts[i]

        # print(obs)
        if i % 10 == 0:
            print("### Event", i, "### \n")
            if i == 0:
                start_time10 = timeit.default_timer()
            else:
                print("")

        start_time = timeit.default_timer()
        results_lk = single_event_urv_opt(
            optimizer, obs, perms_test, lr=1e-7, max_iter=5000, tol=0.001
        )  # , clip_grad=1.0e4, clip_urv_init=None)
        elapsed = timeit.default_timer() - start_time
        # print("lk time", elapsed,"\n")
        results_lk["time"] = elapsed
        all_results_lk.append(results_lk)

        start_time = timeit.default_timer()
        results_ch2 = single_event_chi2_opt(
            optimizer_chi2, obs, perms_test, lr=1e-6, max_iter=300000, tol=0.01
        )
        # results = scalar_event_chi2_opt(obs, perms_correct_b)
        elapsed = timeit.default_timer() - start_time
        # print("ch2 time", elapsed,"\n")
        results_ch2["time"] = elapsed
        all_results_ch2.append(results_ch2)

### Event 0 ### 

### Event 10 ### 


### Event 20 ### 


### Event 30 ### 


### Event 40 ### 


### Event 50 ### 




In [85]:
n_match_lk = 0
n_match_ch2 = 0

print(match_combo)

for i in range(Nevts):
    print("event=", i)

    r_lk = all_results_lk[i]
    print("lk time", r_lk["time"])

    i_x_min_lk = onp.nanargmin(r_lk["loss"])
    print("i_x_min_lk", i_x_min_lk)
    print("perms_test[i_x_min_lk]", perms_test[i_x_min_lk])
    print("loss", r_lk['loss'][i_x_min_lk])
    print("loss[0]", r_lk['loss'][0])
    print("l_diff", r_lk['l_diff'][i_x_min_lk])
    print("all loss", r_lk["loss"])
    print("all n", r_lk["n"])
    if perms_test[i_x_min_lk].tolist() in match_combo.tolist():
        n_match_lk += 1

    r_ch2 = all_results_ch2[i]
    print("ch2 time", r_ch2["time"])
    i_x_min_ch2 = onp.nanargmin(r_ch2["loss"])
    print("i_x_min_ch2", i_x_min_ch2)
    print("perms_test[i_x_min_ch2]", perms_test[i_x_min_ch2])

    if perms_test[i_x_min_ch2].tolist() in match_combo.tolist():
        n_match_ch2 += 1

    print("")

print("n_match_lk", n_match_lk)
print("n_match_ch2", n_match_ch2)

[[0 1 2 3 4 5]
 [0 2 1 3 4 5]
 [0 1 2 3 5 4]
 [0 2 1 3 5 4]
 [3 4 5 0 1 2]
 [3 4 5 0 2 1]
 [3 5 4 0 1 2]
 [3 5 4 0 2 1]]
event= 0
lk time 11.219258185999934
i_x_min_lk 0
perms_test[i_x_min_lk] [0 1 2 3 4 5]
loss 10.609188604398724
loss[0] 10.609188604398724
l_diff 0.0009963321398203107
all loss [10.6091886  40.54182001 27.96798497 33.61336356 42.51355712 25.5743059 ]
all n [2929 1381 1628  947 1384 1811]
ch2 time 2.8127499339998394
i_x_min_ch2 0
perms_test[i_x_min_ch2] [0 1 2 3 4 5]

event= 1
lk time 20.11746801300069
i_x_min_lk 0
perms_test[i_x_min_lk] [0 1 2 3 4 5]
loss 62.89607989940983
loss[0] 62.89607989940983
l_diff 0.0009614181891066664
all loss [62.8960799  90.43119578 82.73507956 75.30440809 97.22887644 80.20843349]
all n [2777 2747 1419 1195 1034 4858]
ch2 time 2.8000844680000228
i_x_min_ch2 0
perms_test[i_x_min_ch2] [0 1 2 3 4 5]

event= 2
lk time 5.56533709499945
i_x_min_lk 0
perms_test[i_x_min_lk] [0 1 2 3 4 5]
loss 20.870015048636375
loss[0] 20.870015048636375
l_diff 0.00

loss[0] 13.762781220535356
l_diff 0.0009958226969928319
all loss [ 13.76278122  33.34060552  38.07298131  34.07024778  29.47162475
 161.72383338]
all n [4448 4545 3336 2791 4035 5000]
ch2 time 1.5376400029999786
i_x_min_ch2 0
perms_test[i_x_min_ch2] [0 1 2 3 4 5]

event= 28
lk time 18.606694113999765
i_x_min_lk 5
perms_test[i_x_min_lk] [3 1 2 0 4 5]
loss 20.671405261966854
loss[0] 96.3515747546989
l_diff 0.0011185455095059638
all loss [ 96.35157475 107.4977225  106.77085532 113.63637051 109.93417817
  20.67140526]
all n [5000 5000 5000 5000 5000 5000]
ch2 time 2.674505941999996
i_x_min_ch2 0
perms_test[i_x_min_ch2] [0 1 2 3 4 5]

event= 29
lk time 11.134252533999643
i_x_min_lk 0
perms_test[i_x_min_lk] [0 1 2 3 4 5]
loss 19.867199500835454
loss[0] 19.867199500835454
l_diff 0.0009784577533231698
all loss [19.8671995  31.90081886 41.19177508 36.27520494 31.92296207 28.0084226 ]
all n [1031 1259 1104 2563 2979 2634]
ch2 time 2.3692661259992747
i_x_min_ch2 0
perms_test[i_x_min_ch2] [0 1 2 3

In [None]:
print(obs_evts[8], "\n", fs_evts[8], "\n", urv_evts[8], "\n", rv_evts[8])

obs = obs_evts[8]
urv_0 = scalar_unc_rv_from_obs(obs, E_cm, 2)

print("urv_0", urv_0)
print("urv_diff", urv_0 - urv_evts[8])

print("loss_0", scalar_neg_log_pdf(urv_0, obs))
print("grrad loss_0", scalar_grad_neg_log_pdf(urv_0, obs))

start_time = timeit.default_timer()

r_x = single_event_urv_opt(
    optimizer,
    obs,
    perms_test,
    lr=1e-7,
    max_iter=500,
    tol=0.05,
    clip_grad=1.0e5,
    clip_urv_init=2,
)

elapsed = timeit.default_timer() - start_time
print("lk time", elapsed, "\n")


i_x_min_x = onp.nanargmin(r_x["loss"])
print("i_x_min_lk", i_x_min_x)
print("perms_test[i_x_min_lk]", perms_test[i_x_min_x])
print("loss", r_x['loss'][i_x_min_x])
print("l_diff", r_x['l_diff'][i_x_min_x])
print("loss[0]", r_x['loss'][0])
print("l_diff[0]", r_x['l_diff'][0])

print("all loss", r_x["loss"])
print("n", r_x["n"])
print("x", r_x["x"])
print("lr", r_x["lr"])