In [1]:
import sys
import timeit
import tracemalloc
import logging

import jax
from jax import custom_jvp
from jax.nn import sigmoid
from jax.config import config

config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.ops import index, index_update
from jax.interpreters import xla


import madjax
from madjax.phasespace.flat_phase_space_generator import FlatInvertiblePhasespace

# from madjax.phasespace.new_flat_phase_space_generator import generate_phase_space_inputs, generateKinematics, invertKinematics
from madjax.phasespace.vectors import Vector

key = jax.random.PRNGKey(0)

# numpy & friends
import numpy as onp
import matplotlib.pyplot as plt

%matplotlib inline
import scipy.optimize

import random as random_orig
from itertools import permutations



In [2]:
def print_memory_stuff(note=""):
    current, peak = tracemalloc.get_traced_memory()
    print(
        note, f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB"
    )


import warnings

warnings.simplefilter("ignore")

In [3]:
# Custom sigmoid, with identity gradient
# Easy way to implement natural gradient descent
#  on sigmoid reparameterization of unconstrained dual space


@custom_jvp
def SigmoidStraightThrough(x):
    return sigmoid(x)


SigmoidStraightThrough.defjvps(lambda t, ans, x: t)

In [4]:
# Gradient descent optimizer
# Since SigmoidStraightThrough takes care of jacobian cancellation
#  in natural gradient descent (more numerically stable)
#  ... so nothing extra has to be added here


def optim_rep_natural_st(loss, grad_loss, u_init, lr=0.1, max_iter=100, tol=0.001):

    u = jnp.array(u_init, copy=True)

    i_iter = 0

    loss_val = 0

    _success = True

    for i_iter in range(max_iter):
        u_curr = u - lr * grad_loss(u)
        loss_curr = loss(u_curr)

        if jnp.isnan(loss_curr):
            _success = False
            break

        if tol is not None and jnp.abs(loss_curr - loss_val) < tol:
            u = u_curr
            loss_val = loss_curr
            break
        else:
            u = u_curr
            loss_val = loss_curr

    return {"x": u, "f": loss_val, "n": i_iter, "success": _success}

In [5]:
# Setup MadJax
mj = madjax.MadJax(config_name="ttbar_bqq_bqq")

In [6]:
# Setup Process
process_name = "Matrix_1_gg_ttx_t_budx_tx_bxdux"
process = mj.processes[process_name]()

E_cm = 14000.0

nDimPhaseSpace = 14

sig_jet = 0.1

matrix_element = mj.matrix_element(
    E_cm=E_cm, process_name=process_name, return_grad=False, do_jit=False
)


external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)


ps_generator = FlatInvertiblePhasespace(
    external_masses[0],
    external_masses[1],
    beam_Es=(E_cm / 2.0, E_cm / 2.0),
    beam_types=(0, 0),
)

In [7]:
# Define Likelihood


def get_final_state(point):
    return jax.numpy.asarray([p.asarray() for p in point[-6:]])


def convert_to_PS_point(incoming, observed):
    _out = jax.numpy.vstack(([p.asarray() for p in incoming], [p for p in observed]))
    return [Vector(p) for p in _out]


def smear_final_state(subkey, final, sigma_jet=0.1):
    smear = jax.random.normal(subkey, (6, 4)) * sigma_jet
    return final * (1 + smear)


def smear_logpdf(observed, final, sigma_jet=0.1):
    log_pdf = jax.scipy.stats.norm.logpdf(observed, final, sigma_jet)
    return jax.numpy.sum(log_pdf)


def logpdf(observed, parameters, random_variables, sigma_jet=0.1):
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
    ME_value = matrix_element(parameters, random_variables)
    final = get_final_state(PS_point)
    return (
        smear_logpdf(observed, final, sigma_jet)
        + jax.numpy.log(ME_value)
        + jax.numpy.log(jacobian)
    )


def logpdf_STE(observed, parameters, unconstrained_random_variables, sigma_jet=0.1):
    random_variables = SigmoidStraightThrough(unconstrained_random_variables)
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
    ME_value = matrix_element(parameters, random_variables)
    final = get_final_state(PS_point)
    return (
        smear_logpdf(observed, final, sigma_jet)
        + jax.numpy.log(ME_value)
        + jax.numpy.log(jacobian)
    )

In [8]:
# Define objectives that only depend on random variables


def get_nll_STE(params, sigma_jet=0.1):
    def nll(observed, rv):
        return -logpdf_STE(observed, params, rv, sigma_jet)

    v_and_g = jax.jit(jax.value_and_grad(nll, argnums=1))
    # v_and_g = jax.value_and_grad(nll, argnums=1)

    # def objective(observed, rv):
    #    v,g = v_and_g(observed, rv)
    #    return v,g

    def objective(observed, rv):
        v, g = v_and_g(observed, rv)
        return v

    def objective_grad(observed, rv):
        v, g = v_and_g(observed, rv)
        return g

    return objective, objective_grad


def get_rv_objective_STE(observed, nll_objective, nll_objective_grad):
    def rv_objective(rv):
        return nll_objective(observed, rv)

    def rv_objective_grad(rv):
        return nll_objective_grad(observed, rv)

    return rv_objective, rv_objective_grad

In [9]:
# Create objective as function of observation and RVs
# ***  NOTE everytime you call this, you will have to re-JIT
# ***       so try not to call this alot!
nll_objective_STE, nll_objective_grad_STE = get_nll_STE(
    external_parameters, sigma_jet=sig_jet
)

In [10]:
# Test evaluation
# Also serve to do the JIT compiling, which takes awhile

uncon_random_variables = jnp.array(
    [onp.random.standard_normal() for _ in range(ps_generator.nDimPhaseSpace())]
)
random_variables = SigmoidStraightThrough(uncon_random_variables)
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key, subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

objective_STE, objective_grad_STE = get_rv_objective_STE(
    obs, nll_objective_STE, nll_objective_grad_STE
)

start_time = timeit.default_timer()
print("Objective value", objective_STE(uncon_random_variables))
elapsed = timeit.default_timer() - start_time
print("time", elapsed, "\n")

start_time = timeit.default_timer()
print("Objective Jacobian", objective_grad_STE(uncon_random_variables))
elapsed = timeit.default_timer() - start_time
print("time", elapsed)

Objective value 32396179.005501162
time 132.434018404 

Objective Jacobian [ 1.54484167e+08  4.04486883e+07 -5.40266097e+07  1.33077067e+07
  1.64471777e+07  3.74813859e+08 -1.06045014e+08 -1.50483961e+08
 -1.82975967e+07  9.02504409e+07 -2.63649413e+07  2.88387356e+08
  3.42118140e+08 -6.52595426e+08]
time 0.002853151000010712


In [11]:
# Test optimization

# Generate Observation
uncon_random_variables = jnp.array(
    [onp.random.standard_normal() for _ in range(ps_generator.nDimPhaseSpace())]
)
random_variables = SigmoidStraightThrough(uncon_random_variables)
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key, subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)


# Determine intial guess for unconstrained random variables
obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)
i_unc_rv = jax.scipy.special.logit(i_rv)


# Get Objective
objective_STE, objective_grad_STE = get_rv_objective_STE(
    obs, nll_objective_STE, nll_objective_grad_STE
)


# Do optimization
start_time = timeit.default_timer()

r = optim_rep_natural_st(
    objective_STE, objective_grad_STE, i_unc_rv, lr=1.0e-10, max_iter=100, tol=None
)

elapsed = timeit.default_timer() - start_time

print("time", elapsed, "\n")

print("- Fit Unconstrained rv, and its loss value -")
print("u_fit:", r["x"])
print("f(u_fit):", r["f"], "\n")

print("- True Unconstrained rv, and its loss value -")
print("u_true:", uncon_random_variables)
print("f(u_true):", objective_STE(uncon_random_variables), "\n")

print("- Hypercube rv, fit and true -")
print("rv_fit", SigmoidStraightThrough(r["x"]))
print("true rvs:", random_variables)

time 0.3715729129999943 

- Fit Unconstrained rv, and its loss value -
u_fit: [ 0.43518521  0.37232088  1.88711255 -0.59010379  0.31263717 -1.0106878
  0.07350352 -0.08202146 -0.48108224 -1.00197898 -2.26698346  0.07952424
 -0.13440073  1.92815405]
f(u_fit): 22333464.60726571 

- True Unconstrained rv, and its loss value -
u_true: [ 0.51482218  0.54601361  1.75205516 -0.53956164  0.37868054 -1.00390846
  0.04253644 -0.0874629  -0.50753472 -1.03251062 -2.12959646  0.08832888
 -0.23857732  1.9754216 ]
f(u_true): 30680730.18138084 

- Hypercube rv, fit and true -
rv_fit [0.60711117 0.59201967 0.86842596 0.35661104 0.57752883 0.26684527
 0.51836761 0.47950612 0.3819966  0.26855251 0.09389454 0.51987059
 0.46645031 0.87304496]
true rvs: [0.62593622 0.63321022 0.85221183 0.36828956 0.59355483 0.26817367
 0.51063251 0.4781482  0.37577162 0.26259766 0.10625331 0.52206787
 0.44063698 0.87819226]


In [12]:
# Define potential parton permutation
#  for doing combinatorial likelihood

comb = [0, 1, 2, 3, 4, 5]
_perms = permutations(comb)

perms = onp.array([onp.array(i) for i in _perms])

# selecting b's in right place
_c1 = (perms[:, 0] == 0) & (perms[:, 3] == 3)
_c2 = (perms[:, 0] == 3) & (perms[:, 3] == 0)
_cb = (_c1) | (_c2)
perms_correct_b = perms[_cb]

# ignoring swapping jets in same W

_w1 = (perms[:, 1] == 2) & (perms[:, 2] == 2)
_w2 = (perms[:, 1] == 5) & (perms[:, 2] == 4)
_w3 = (perms[:, 4] == 2) & (perms[:, 5] == 2)
_w4 = (perms[:, 4] == 5) & (perms[:, 5] == 4)
_ww = (~_w1) & (~_w2) & (~_w3) & (~_w4)
perms_ignore_w_swap = perms[_ww]

# selecting b's in right place and ignoring swapping jets in same W
perms_correct_b_ignore_w_swap = perms[_cb & _ww]

In [13]:
# Combinatorial likelihood fit

n_evt = 2

all_results = []

# Loop over events, if we want to examine multiple events
for i_evt in range(n_evt):
    print("Starting event", i_evt)

    random_variables = [
        random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())
    ]
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

    fs = get_final_state(PS_point)

    key, subkey = jax.random.split(key)
    obs = smear_final_state(subkey, fs, sig_jet)

    obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
    i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)
    i_unc_rv = jax.scipy.special.logit(i_rv)

    combos = perms_correct_b

    results = []

    tracemalloc.start()

    total_start_time = timeit.default_timer()

    # Combinatorial likelihood, i.e. loop through potential permutations
    for i in range(combos.shape[0]):
        if i % 10 == 0:
            print("Starting Combination", i, "for event", i_evt)

        # Since parton definition is fixed in likelihood from ME,
        #  permutations are evaluated by rearranging observation
        #  the observation index is what assigns jet[i] to parton[i]
        new_obs = obs[combos[i]]

        new_objective, new_objective_grad = get_rv_objective_STE(
            new_obs, nll_objective_STE, nll_objective_grad_STE
        )

        # start_time = timeit.default_timer()
        r = optim_rep_natural_st(
            new_objective,
            new_objective_grad,
            i_unc_rv,
            lr=1.0e-10,
            max_iter=200,
            tol=None,
        )
        # elapsed = timeit.default_timer() - start_time
        # print("time", elapsed)
        # print_memory_stuff("4")

        results.append(r)

        del new_objective, new_objective_grad

        # pring some memory stuff if we want
        if False:
            snapshot = tracemalloc.take_snapshot()
            for i, stat in enumerate(snapshot.statistics('filename')[:5], 1):
                logging.info("top_current", i=i, stat=str(stat))

            current, peak = tracemalloc.get_traced_memory()
            print(
                f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB"
            )

    total_elapsed = timeit.default_timer() - total_start_time
    print("total_elapsed_timed", total_elapsed)

    print("-----------")
    r_fun = [(r["f"] if r["success"] else 1.0e100) for r in results]
    print(onp.argmin(r_fun))

    all_results.append((results, onp.argmin(r_fun)))

    current, peak = tracemalloc.get_traced_memory()
    print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
    tracemalloc.stop()

Starting event 0
Starting Combination 0 for event 0
Starting Combination 10 for event 0
Starting Combination 20 for event 0
Starting Combination 30 for event 0
Starting Combination 40 for event 0
total_elapsed_timed 25.37512263900001
-----------
13
Current memory usage is 0.172119MB; Peak was 0.233362MB
Starting event 1
Starting Combination 0 for event 1
Starting Combination 10 for event 1
Starting Combination 20 for event 1
Starting Combination 30 for event 1
Starting Combination 40 for event 1
total_elapsed_timed 23.654974951000014
-----------
0
Current memory usage is 0.112598MB; Peak was 0.175924MB


Chi2 Using only kinematics

In [14]:
def energy_from_mass_basis(p):
    return jnp.sqrt(jnp.sum(jnp.square(p)))


def convert_to_energy_basis(particles):
    _energies = jnp.array([energy_from_mass_basis(p) for p in particles])
    jax.ops.index_update(particles, jax.ops.index[:, 0], _energies)
    return particles


def mass_from_energy_basis(p):
    return jnp.sqrt(p[0] * p[0] - p[1] * p[1] - p[2] * p[2] - p[3] * p[3])


def convert_to_mass_basis(particles):
    _masses = jnp.array([mass_from_energy_basis(p) for p in particles])
    jax.ops.index_update(particles, jax.ops.index[:, 0], _masses)
    return particles


def chi2(observed, final, sigma_jet=0.1, is_final_mass_basis=False):
    _mt = 173.0
    _gt = 1.5
    _mw = 80.4
    _gw = 2.1

    _final = jnp.reshape(final, (6, 4))

    _final = jax.lax.cond(
        is_final_mass_basis, _final, convert_to_energy_basis, _final, lambda x: x
    )

    w1 = _final[1] + _final[2]
    t1 = _final[0] + _final[1] + _final[2]
    w2 = _final[4] + _final[5]
    t2 = _final[3] + _final[4] + _final[5]

    mw1 = jnp.sqrt(w1[0] * w1[0] - w1[1] * w1[1] - w1[2] * w1[2] - w1[3] * w1[3])
    mt1 = jnp.sqrt(t1[0] * t1[0] - t1[1] * t1[1] - t1[2] * t1[2] - t1[3] * t1[3])

    mw2 = jnp.sqrt(w2[0] * w2[0] - w2[1] * w2[1] - w2[2] * w2[2] - w2[3] * w2[3])
    mt2 = jnp.sqrt(t2[0] * t2[0] - t2[1] * t2[1] - t2[2] * t2[2] - t2[3] * t2[3])

    return 8.0 * jnp.log(2.0) * (
        jnp.square((mw1 - _mw) / _gw)
        + jnp.square((mt1 - _mt) / _gt)
        + jnp.square((mw2 - _mw) / _gw)
        + jnp.square((mt2 - _mt) / _gt)
    ) - 2.0 * smear_logpdf(observed, _final, sigma_jet)

In [15]:
def get_chi2_objective(observed, sigma_jet=0.1, is_final_mass_basis=False):

    v_and_g = jax.jit(jax.value_and_grad(chi2, argnums=1), static_argnums=2)
    # v_and_g = jax.value_and_grad(chi2, argnums=1)

    def chi2_objective(final):
        v, g = v_and_g(observed, final, sigma_jet, is_final_mass_basis)
        return v, g

    return chi2_objective

In [16]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
fs = get_final_state(PS_point)

key, subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

print(fs)

ch2_obj = get_chi2_objective(obs, sig_jet)
print(ch2_obj(fs))

[[ 3745.91885301  -264.10954712 -2701.55238969 -2581.42340506]
 [ 3755.89453316  -789.77958686  3606.4733523    690.17527305]
 [ 1494.15724483   923.36889374   304.99642096 -1134.40422318]
 [ 2620.35970562 -1333.51404385   296.07986949  2236.14398073]
 [  646.9175827    545.02061547  -333.04896585  -102.63173906]
 [ 1736.75208068   919.01366863 -1172.94828721   892.14011352]]
(DeviceArray(2.77270515e+08, dtype=float64), DeviceArray([[  -217.33023758,   7741.5679748 ,  62620.56975618,
               88260.88019092],
             [ 65120.48260186,  -6683.80865924, -70469.10044417,
                9604.14268265],
             [ 84852.05070553,  23442.90550079, -19956.08488974,
               44747.92635135],
             [ 67931.5588203 ,  38855.81576421,   6433.58342313,
              -63837.02984645],
             [ 61484.89611439,   5173.50120338,  12343.08659797,
              -14326.95864354],
             [ 40792.79833936,  14620.17885725,  17035.78843683,
                1255.23058

In [17]:
random_variables = [random_orig.random() for _ in range(ps_generator.nDimPhaseSpace())]
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key, subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

# obs = convert_to_mass_basis(obs)

print("fs", fs)
print("obs", obs)


combos = [
    [0, 1, 2, 3, 4, 5],
    [0, 4, 2, 3, 1, 5],
    [0, 5, 2, 3, 4, 1],
    [3, 1, 2, 0, 4, 5],
    [3, 4, 2, 0, 1, 5],
    [3, 5, 2, 0, 4, 1],
]

results = []
for i in range(6):
    print("Starting Combination", i)

    new_obs = obs[combos[i]]

    new_objective = get_chi2_objective(new_obs, sig_jet, True)

    start_time = timeit.default_timer()
    r = scipy.optimize.minimize(
        new_objective,
        onp.asarray(new_obs),
        jac=True,
        method='SLSQP',
        bounds=[(0.0, 7000.0), (-7000.0, 7000), (-7000.0, 7000), (-7000.0, 7000)] * 6,
    )
    elapsed = timeit.default_timer() - start_time

    results.append(r)
    print(r)
    print("-----")
    for v in jnp.reshape(r.x, (6, 4)):
        print("[", energy_from_mass_basis(v), v[1], v[2], v[3], "]")
    print("fs", fs)
    print("time", elapsed)
    print("")

print("-----------")
r_fun = [r.fun for r in results]
print(onp.argmin(r_fun))

fs [[ 2073.19725333  1868.58339312   770.03159846  -462.13872543]
 [  965.63454333   504.5676478   -569.50935876   594.57602571]
 [ 1638.05632449   -23.7488696  -1570.63385881   464.51457992]
 [ 2676.88540534   304.01205331  2435.33920126  1068.8278765 ]
 [ 3855.32452113 -1448.49717765   643.50465211 -3514.4394791 ]
 [ 2790.90195238 -1204.91704698 -1708.73223426  1848.65972241]]
obs [[ 1819.2991895   2046.40566739   861.88621033  -557.80014494]
 [  781.53809843   440.62100062  -473.0225599    674.39560325]
 [ 1725.28909474   -27.27152179 -1480.45973008   397.37162189]
 [ 2462.0296645    329.54218726  2283.46322316  1185.07381045]
 [ 3894.35284898 -1229.68954545   603.92281025 -4284.92030743]
 [ 3027.02445317 -1288.65390546 -1705.1642168   2182.9143924 ]]
Starting Combination 0
     fun: 231612225.0572351
     jac: array([ 0.00038646,  0.00039525, -0.00118965,  0.00014231, -0.02599639,
        0.00650593, -0.02536224,  0.01403018, -0.03084985,  0.00648783,
       -0.02540685,  0.0140264

     fun: 174906027.84219682
     jac: array([-4.51743421e-04,  8.79641204e-04,  5.70686866e-04,  3.27204252e-04,
        9.00642495e-04, -8.11948963e-04, -1.42936773e-03, -7.70434148e-04,
        2.98730692e-03,  1.12246889e-05, -1.40844292e-03, -6.95302984e-04,
       -9.52644852e-04,  1.06206196e-03,  4.50299202e-04,  9.65971392e-04,
       -1.62596387e-02, -1.19011217e-03, -7.82656348e-03,  1.39342572e-02,
       -1.70484014e-02, -1.46888388e-03, -8.85765410e-03,  1.24518108e-02])
 message: 'Optimization terminated successfully.'
    nfev: 107
     nit: 40
    njev: 40
  status: 0
 success: True
       x: array([ 2283.95293834,   304.72136979,  2319.31707231,  1112.6011825 ,
        3653.65300744, -1270.83265469,   629.72259796, -4407.76747937,
        1484.58926363,   -68.41462692, -1454.65994226,   274.52445034,
        1692.09893078,  2077.61446419,   830.19441305,  -501.64728267,
         716.51100749,   484.67388777,  -468.06674485,   681.99799179,
        2961.99735829, -1244