In [5]:
import sys
import timeit
import tracemalloc
import logging

import jax
from jax import custom_jvp
from jax.nn import sigmoid
from jax.config import config
config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.ops import index, index_update
from jax.interpreters import xla


import madjax
from madjax.phasespace.flat_phase_space_generator import FlatInvertiblePhasespace
#from madjax.phasespace.new_flat_phase_space_generator import generate_phase_space_inputs, generateKinematics, invertKinematics
from madjax.phasespace.vectors import Vector

key = jax.random.PRNGKey(0)

#numpy & friends
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.optimize

import random as random_orig
from itertools import permutations


In [None]:
def print_memory_stuff(note=""):
    """Print tracemalloc's current and peak traced memory, in MB.

    `note` is printed first, separated from the message by a single space.
    """
    current_bytes, peak_bytes = tracemalloc.get_traced_memory()
    bytes_per_mb = 10**6
    message = f"Current memory usage is {current_bytes / bytes_per_mb}MB; Peak was {peak_bytes / bytes_per_mb}MB"
    print(note, message)

import warnings
# Silence every Python warning for the remainder of the kernel session.
warnings.simplefilter("ignore")

In [6]:
#Sigmoid with an identity derivative (straight-through estimator).
#Forward pass: sigmoid(x). Backward/forward-mode derivative: the incoming
#  tangent is passed through unchanged, as if d sigmoid/dx were 1.
#Intended as an easy way to do natural-gradient-style descent on a sigmoid
#  reparameterization of the unconstrained dual space.

@custom_jvp
def SigmoidStraightThrough(x):
    """Elementwise sigmoid(x); its JVP is overridden below to the identity."""
    return sigmoid(x)


@SigmoidStraightThrough.defjvp
def _sigmoid_straight_through_jvp(primals, tangents):
    # Primal output is the real sigmoid; the tangent is forwarded untouched.
    (x,), (x_dot,) = primals, tangents
    return sigmoid(x), x_dot



In [2]:
#Gradient descent optimizer
#Objectives built on SigmoidStraightThrough already supply the straight-through
#  gradient, so plain fixed-step descent is used here with no extra
#  preconditioning.

def optim_rep_natural_st(loss, grad_loss, u_init, lr=.1, max_iter=100, tol=0.001):
    """Fixed-step gradient descent with an optional loss-change stopping rule.

    Args:
        loss: callable u -> scalar loss.
        grad_loss: callable u -> gradient of `loss` at u (same shape as u).
        u_init: starting point; copied, never mutated.
        lr: step size.
        max_iter: maximum number of gradient steps.
        tol: stop once |loss(new) - loss(previous)| < tol. None disables the
            check, so all max_iter steps run unless a NaN loss appears.

    Returns:
        dict with
          "x": final iterate,
          "f": loss at "x" (always; costs one extra loss evaluation at start),
          "n": index of the last iteration attempted (0 if none ran),
          "success": False if a step produced a NaN loss. In that case "x"/"f"
                     hold the last iterate whose loss was not NaN.
    """
    u = jnp.array(u_init, copy=True)
    # Compare against the loss at the start point rather than an arbitrary 0:
    # otherwise a first step landing near loss == 0 looks "converged".
    loss_val = loss(u)
    success = True
    i_iter = 0

    for i_iter in range(max_iter):
        u_next = u - lr * grad_loss(u)
        loss_next = loss(u_next)

        if jnp.isnan(loss_next):
            success = False
            break

        converged = tol is not None and jnp.abs(loss_next - loss_val) < tol
        u, loss_val = u_next, loss_next
        if converged:
            break

    return {"x": u, "f": loss_val, "n": i_iter, "success": success}

In [None]:
#Setup MadJax
# Load the MadJax configuration named "ttbar_bqq_bqq"; the process class used
#  in the next cell is looked up from mj.processes.
mj = madjax.MadJax(config_name="ttbar_bqq_bqq")

In [8]:
#Setup Process
# Process class looked up by name in the MadJax config and instantiated.
process_name = "Matrix_1_gg_ttx_t_budx_tx_bxdux"
process = mj.processes[process_name]()

# Centre-of-mass energy (presumably GeV - confirm); each beam gets E_cm/2 below.
E_cm = 14000.

# NOTE(review): hard-coded and not referenced by the visible cells, which call
#  ps_generator.nDimPhaseSpace() instead.
nDimPhaseSpace = 14

# Jet smearing width, passed as sigma_jet to the smearing/likelihood helpers.
sig_jet = 0.1

# Matrix-element callable for this process. do_jit=False, presumably because the
#  likelihood built from it is jitted later (get_nll_STE) - confirm.
matrix_element = mj.matrix_element(E_cm=E_cm, process_name=process_name, return_grad=False, do_jit=False)


# Empty dict: presumably "use the model's default external parameters" - confirm.
# NOTE(review): `parameters` (the full set) is only used for the masses here;
#  later cells pass `external_parameters` to get_nll_STE -> matrix_element.
#  Confirm which form matrix_element expects.
external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)


# Flat, invertible phase-space generator. external_masses[0] / [1] presumably
#  hold the initial- and final-state masses - confirm against FlatInvertiblePhasespace.
ps_generator = FlatInvertiblePhasespace(
    external_masses[0], external_masses[1],
    beam_Es = (E_cm/2.,E_cm/2.),
    beam_types=(0,0)
)



In [9]:
#Define Likelihood

def get_final_state(point):
    """Stack the last six momenta of a phase-space point into one array.

    Every element of point[-6:] must provide .asarray(); the result has one row
    per momentum (presumably shape (6, 4) for four-vectors - confirm).
    """
    last_six = point[-6:]
    rows = list(map(lambda momentum: momentum.asarray(), last_six))
    return jax.numpy.asarray(rows)

def convert_to_PS_point(incoming, observed):
    """Build a list of Vector objects: incoming momenta first, then observed rows.

    `incoming` elements must provide .asarray(); `observed` is iterated row by row.
    """
    incoming_rows = [vec.asarray() for vec in incoming]
    observed_rows = list(observed)
    stacked = jax.numpy.vstack((incoming_rows, observed_rows))
    return list(map(Vector, stacked))

def smear_final_state(subkey, final, sigma_jet=0.1):
    """Apply independent multiplicative Gaussian noise to every entry of `final`.

    Each entry x becomes x * (1 + eps) with eps ~ N(0, sigma_jet^2), i.e. a
    relative resolution of sigma_jet.

    Args:
        subkey: JAX PRNG key consumed for the noise draw.
        final: array of values to smear (any shape; (6, 4) in this notebook).
        sigma_jet: relative width of the smearing.
    """
    # Noise shape follows the input instead of being hard-coded to (6, 4);
    # for a (6, 4) input the draw is identical for the same key.
    smear = jax.random.normal(subkey, jax.numpy.shape(final)) * sigma_jet
    return final * (1 + smear)

def smear_logpdf(observed, final, sigma_jet=0.1):
    """Sum of elementwise Gaussian log-densities N(observed | final, sigma_jet^2).

    sigma_jet is used as an absolute scale for every element.

    NOTE(review): smear_final_state generates observations with a *relative*
    smear (final * (1 + N(0, sigma_jet))), whose matching density would have
    scale sigma_jet * |final|. Confirm which resolution model is intended.
    """
    per_element = jax.scipy.stats.norm.logpdf(observed, loc=final, scale=sigma_jet)
    return per_element.sum()

def logpdf(observed, parameters, random_variables, sigma_jet=0.1):
    """Joint log-density of an observation and a phase-space point.

    log p = smear_logpdf(observed | final state, sigma_jet)
            + log(matrix element) + log(phase-space jacobian),
    where the final state and jacobian come from
    ps_generator.generateKinematics(E_cm, random_variables).
    random_variables presumably lie in the unit hypercube - confirm.
    """
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
    ME_value = matrix_element(parameters, random_variables)
    final = get_final_state(PS_point)
    return smear_logpdf(observed, final, sigma_jet) + jax.numpy.log(ME_value) + jax.numpy.log(jacobian)


def logpdf_STE(observed, parameters, unconstrained_random_variables, sigma_jet=0.1):
    """logpdf evaluated at SigmoidStraightThrough(unconstrained_random_variables).

    Delegates to logpdf, so the two likelihoods cannot drift apart. Gradients
    w.r.t. the unconstrained variables pass straight through the sigmoid.
    """
    random_variables = SigmoidStraightThrough(unconstrained_random_variables)
    return logpdf(observed, parameters, random_variables, sigma_jet)

In [10]:
#Define objectives that only depend on random variables

def get_nll_STE(params, sigma_jet=0.1):
    """Build NLL value and gradient closures for fixed model parameters.

    The NLL is -logpdf_STE(observed, params, rv, sigma_jet). Both returned
    callables take (observed, rv) and share a single jitted value_and_grad
    (differentiating w.r.t. rv), so each call computes value and gradient.

    Returns:
        (objective, objective_grad): NLL value and d NLL / d rv.
    """
    def nll(observed, rv):
        return -logpdf_STE(observed, params, rv, sigma_jet)

    value_and_grad_fn = jax.jit(jax.value_and_grad(nll, argnums=1))

    def objective(observed, rv):
        return value_and_grad_fn(observed, rv)[0]

    def objective_grad(observed, rv):
        return value_and_grad_fn(observed, rv)[1]

    return objective, objective_grad


def get_rv_objective_STE(observed, nll_objective, nll_objective_grad):
    """Bind `observed` into two (observed, rv) callables, giving rv-only callables.

    Returns:
        (rv_objective, rv_objective_grad) where each maps rv -> fn(observed, rv).
    """
    def bind(fn):
        # Closure over `observed` so the optimizer only sees rv.
        return lambda rv: fn(observed, rv)

    return bind(nll_objective), bind(nll_objective_grad)




In [None]:
#Test evaluation
#Also serves to do the JIT compiling, which takes a while

# Draw the start point from the notebook's seeded JAX key (instead of the
#  unseeded global NumPy RNG) so re-running the notebook is reproducible.
key, subkey = jax.random.split(key)
uncon_random_variables = jax.random.normal(subkey, (ps_generator.nDimPhaseSpace(),))
random_variables = SigmoidStraightThrough(uncon_random_variables)
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

nll_objective_STE, nll_objective_grad_STE = get_nll_STE(external_parameters, sigma_jet=sig_jet)
objective_STE, objective_grad_STE = get_rv_objective_STE(obs, nll_objective_STE, nll_objective_grad_STE)

# The first call pays the JIT compilation cost; the gradient call reuses the
# same compiled value_and_grad function.
for timed_fn in (objective_STE, objective_grad_STE):
    start_time = timeit.default_timer()
    print(timed_fn(uncon_random_variables))
    elapsed = timeit.default_timer() - start_time
    print("time", elapsed)

In [None]:
#Test optimization

#Generate Observation
# Seeded from the notebook's JAX key (rather than the unseeded global NumPy RNG)
#  so the result is reproducible on re-run.
key, subkey = jax.random.split(key)
uncon_random_variables = jax.random.normal(subkey, (ps_generator.nDimPhaseSpace(),))
random_variables = SigmoidStraightThrough(uncon_random_variables)
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)


#Determine initial guess for unconstrained random variables
# Invert the kinematics of the observed (smeared) jets, then map back to the
#  unconstrained space with logit (the inverse of the sigmoid).
obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=obs)
i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)
i_unc_rv = jax.scipy.special.logit(i_rv)


#Get Objective (reuses the jitted NLL built in the test-evaluation cell)
objective_STE, objective_grad_STE = get_rv_objective_STE(obs, nll_objective_STE, nll_objective_grad_STE)


#Do optimization
start_time = timeit.default_timer()
r = optim_rep_natural_st(objective_STE, objective_grad_STE, i_unc_rv, lr=1.0e-10, max_iter=100, tol=None)
elapsed = timeit.default_timer() - start_time

print("time",elapsed)
print("u_fit:", r["x"])
print("f(u_fit):", r["f"] )

print("u_true:", uncon_random_variables)
print("f(u_true):", objective_STE(uncon_random_variables) )

print("rv_fit", SigmoidStraightThrough(r["x"]))
print("true rvs:",random_variables)

In [None]:
#Define potential parton permutations
#  for doing the combinatorial likelihood.
#Each row maps slots to jet indices (used below as obs[row]).
#  Slots 0 and 3 are the b jets; slots (1, 2) and (4, 5) are the two W-jet pairs.

comb = [0,1,2,3,4,5]
perms = onp.array(list(permutations(comb)))

#selecting b's in right place (b jets 0 and 3 in slots 0 and 3, either order)
_c1 = (perms[:,0]==0) & (perms[:,3]==3)
_c2 = (perms[:,0]==3) & (perms[:,3]==0)
_cb = (_c1)|(_c2)
perms_correct_b = perms[_cb]

#ignoring swapping jets in same W:
#  keep one canonical (ascending) order inside each W slot pair. The earlier
#  per-case masks included tests such as (perms[:,1]==2) & (perms[:,2]==2),
#  which no permutation satisfies, so those swaps were never removed, and
#  mixed pairs such as (4, 1) vs (1, 4) were not handled at all.
_ww = (perms[:,1] < perms[:,2]) & (perms[:,4] < perms[:,5])
perms_ignore_w_swap = perms[_ww]

#selecting b's in right place and ignoring swapping jets in same W:
#  2 (b order) x 3 (jet pairings) x 2 (which pair sits with which b) = 12 rows
perms_correct_b_ignore_w_swap = perms[_cb&_ww]

In [None]:
#Combinatorial likelihood fit
#For each generated event, fit every b-correct jet permutation with the STE
#  likelihood and report the permutation with the lowest NLL.

n_evt = 10

all_results = []

for i_evt in range(n_evt):
    print("Starting event", i_evt)

    # Seeded from the notebook's JAX key; the stdlib `random` module was never
    #  seeded, so the events changed on every run.
    key, subkey = jax.random.split(key)
    random_variables = jax.random.uniform(subkey, (ps_generator.nDimPhaseSpace(),))
    PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

    fs = get_final_state(PS_point)

    key,subkey = jax.random.split(key)
    obs = smear_final_state(subkey, fs, sig_jet)

    combos = perms_correct_b

    results = []

    tracemalloc.start()

    total_start_time = timeit.default_timer()

    for i_combo, combo in enumerate(combos):
        if i_combo % 10 == 0:
            print("Starting Combination", i_combo, "for event", i_evt)

        new_obs = obs[combo]

        # Initial guess from the inverse kinematics of *this* permutation.
        #  Reusing the inverse of the unpermuted `obs` for every permutation
        #  starts the identity permutation (the generated jet order) from a
        #  better point than the others, which can bias the argmin below.
        #  The chi2 fit further down likewise starts from the permuted obs.
        obs_vec = convert_to_PS_point(incoming=PS_point[0:2], observed=new_obs)
        i_rv, i_wt = ps_generator.invertKinematics(E_cm, obs_vec)
        i_unc_rv = jax.scipy.special.logit(i_rv)

        new_objective, new_objective_grad = get_rv_objective_STE(new_obs, nll_objective_STE, nll_objective_grad_STE)

        r = optim_rep_natural_st(new_objective, new_objective_grad, i_unc_rv, lr=1.0e-10, max_iter=200, tol=None)

        results.append(r)

    total_elapsed = timeit.default_timer() - total_start_time
    print("total_elapsed_timed", total_elapsed)

    print("-----------")
    # Failed fits (NaN loss) are ranked last.
    r_fun = [(res["f"] if res["success"] else 1.0e100) for res in results]
    best_combo = onp.argmin(r_fun)
    print(best_combo)

    all_results.append((results, best_combo))

    current, peak = tracemalloc.get_traced_memory()
    print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
    tracemalloc.stop()


Chi2 fit using only kinematics (W and top mass constraints plus jet smearing; no matrix element)

In [None]:
def energy_from_mass_basis(p):
    """Energy of one [m, px, py, pz] vector: sqrt(m^2 + px^2 + py^2 + pz^2).

    Squares and sums *every* entry of `p`, so pass a single 4-vector.
    """
    components = jnp.asarray(p)
    return jnp.sqrt((components * components).sum())

def convert_to_energy_basis(particles):
    """Return a copy of `particles` with column 0 replaced by each row's energy.

    Rows are read as [m, px, py, pz]; column 0 becomes
    E = sqrt(m^2 + px^2 + py^2 + pz^2). Columns 1-3 are unchanged.
    Expects a 2-D array of rows.
    """
    particles = jnp.asarray(particles)
    # Row-wise version of energy_from_mass_basis.
    energies = jnp.sqrt(jnp.sum(jnp.square(particles), axis=-1))
    # JAX arrays are immutable: the functional update must be *returned*
    # (the previous version discarded it and returned the input unchanged).
    return particles.at[:, 0].set(energies)

def mass_from_energy_basis(p):
    """Invariant mass of one [E, px, py, pz] vector: sqrt(E^2 - px^2 - py^2 - pz^2).

    The result is NaN when the argument of the square root is negative.
    """
    energy, px, py, pz = p[0], p[1], p[2], p[3]
    return jnp.sqrt(energy*energy - px*px - py*py - pz*pz)

def convert_to_mass_basis(particles):
    """Return a copy of `particles` with column 0 replaced by each row's mass.

    Rows are read as [E, px, py, pz]; column 0 becomes
    m = sqrt(E^2 - px^2 - py^2 - pz^2) (NaN for spacelike rows).
    Columns 1-3 are unchanged. Expects a 2-D array of rows.
    """
    particles = jnp.asarray(particles)
    # Row-wise version of mass_from_energy_basis, same operation order.
    e, px, py, pz = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3]
    masses = jnp.sqrt(e*e - px*px - py*py - pz*pz)
    # JAX arrays are immutable: the functional update must be *returned*
    # (the previous version discarded it and returned the input unchanged).
    return particles.at[:, 0].set(masses)


    
def chi2(observed, final, sigma_jet=0.1, is_final_mass_basis=False):
    """Kinematic chi^2 for assigning six jets to a t tbar -> (b W)(b W) topology.

    `final` is reshaped to (6, 4). Rows 0-2 form the first top candidate
    (rows 1-2 its W) and rows 3-5 the second (rows 4-5 its W). Component 0 of
    each row is used as the energy in the invariant masses. When
    is_final_mass_basis is True, the rows first go through
    convert_to_energy_basis.

    Returns:
        8 ln2 * sum over (W1, t1, W2, t2) of ((m - M) / Gamma)^2
        - 2 * smear_logpdf(observed, rows, sigma_jet).
        The 8 ln2 factor makes each mass term (m - M)^2 / s^2 for a Gaussian
        whose full width at half maximum is Gamma.
    """
    # Reference masses and widths (presumably GeV - confirm).
    top_mass, top_width = 173.0, 1.5
    w_mass, w_width = 80.4, 2.1

    def invariant_mass(v):
        return jnp.sqrt(v[0]*v[0] - v[1]*v[1] - v[2]*v[2] - v[3]*v[3])

    def pull_sq(mass, ref_mass, width):
        return jnp.square((mass - ref_mass) / width)

    jets = jnp.reshape(final, (6, 4))
    # 3-argument form of lax.cond; same semantics as the old 5-argument form.
    jets = jax.lax.cond(is_final_mass_basis, convert_to_energy_basis, lambda x: x, jets)

    w1 = jets[1] + jets[2]
    t1 = jets[0] + jets[1] + jets[2]
    w2 = jets[4] + jets[5]
    t2 = jets[3] + jets[4] + jets[5]

    mass_terms = (pull_sq(invariant_mass(w1), w_mass, w_width)
                  + pull_sq(invariant_mass(t1), top_mass, top_width)
                  + pull_sq(invariant_mass(w2), w_mass, w_width)
                  + pull_sq(invariant_mass(t2), top_mass, top_width))

    return 8.0*jnp.log(2.0)*mass_terms - 2.0*smear_logpdf(observed, jets, sigma_jet)
    

In [None]:
# Jitted value-and-gradient of chi2 w.r.t. `final`, built once. jax.jit caches
#  compilations per jitted function, so building a fresh wrapper inside
#  get_chi2_objective (called once per jet combination) recompiled chi2 on
#  every call. Re-run this cell if chi2 is redefined.
# sigma_jet (argument 2) is static; is_final_mass_basis stays traced and is
#  handled by lax.cond inside chi2.
_chi2_value_and_grad = jax.jit(jax.value_and_grad(chi2, argnums=1), static_argnums=2)


def get_chi2_objective(observed, sigma_jet=0.1, is_final_mass_basis=False):
    """Return f(final) -> (chi2 value, d chi2 / d final) for a fixed observation.

    The pair is returned as (Python float, NumPy array), so the callable can be
    passed directly to scipy.optimize.minimize with jac=True. The gradient has
    the same shape as the `final` passed in.
    """
    def chi2_objective(final):
        value, grad = _chi2_value_and_grad(observed, final, sigma_jet, is_final_mass_basis)
        return float(value), onp.asarray(grad)

    return chi2_objective

In [None]:
# Sanity check: chi2 objective (default energy basis) evaluated at the true final state.
# Seeded from the notebook's JAX key instead of the unseeded stdlib `random`.
key, subkey = jax.random.split(key)
random_variables = jax.random.uniform(subkey, (ps_generator.nDimPhaseSpace(),))
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)
fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

print(fs)

ch2_obj = get_chi2_objective(obs, sig_jet)
print(ch2_obj(fs))

In [None]:


# Kinematic chi2 fit over a few jet orderings of one generated event.
# Seeded from the notebook's JAX key instead of the unseeded stdlib `random`.
key, subkey = jax.random.split(key)
random_variables = jax.random.uniform(subkey, (ps_generator.nDimPhaseSpace(),))
PS_point, jacobian = ps_generator.generateKinematics(E_cm, random_variables)

fs = get_final_state(PS_point)

key,subkey = jax.random.split(key)
obs = smear_final_state(subkey, fs, sig_jet)

#obs = convert_to_mass_basis(obs)

print("fs",fs)
print("obs",obs)


# Jet orderings (each row applied as obs[row]): the b jets in slots 0/3 in
#  either order, and jet 2 paired in the first W with jet 1, 4 or 5.
# A NumPy array, because indexing a JAX array with a plain Python list is
#  rejected by recent JAX versions.
combos = onp.array([[0,1,2,3,4,5],
                    [0,4,2,3,1,5],
                    [0,5,2,3,4,1],
                    [3,1,2,0,4,5],
                    [3,4,2,0,1,5],
                    [3,5,2,0,4,1]])

# Per-row bounds: component 0 non-negative, every component within the beam
#  energy E_cm/2 (7000 for E_cm = 14000).
p_max = E_cm / 2.
bounds = [(0.0, p_max), (-p_max, p_max), (-p_max, p_max), (-p_max, p_max)] * 6

results = []
for i, combo in enumerate(combos):
    print("Starting Combination", i)

    new_obs = obs[combo]

    new_objective = get_chi2_objective(new_obs, sig_jet, True)

    # NOTE(review): is_final_mass_basis=True declares the fit rows as
    #  [m, px, py, pz], but the initial guess is the observed jets in
    #  [E, px, py, pz] form. Confirm whether new_obs should be converted to the
    #  mass basis first (see the commented-out convert_to_mass_basis line above).
    start_time = timeit.default_timer()
    r = scipy.optimize.minimize(
        new_objective,
        # scipy.optimize.minimize requires a 1-D x0; chi2 reshapes to (6, 4).
        onp.asarray(new_obs).ravel(),
        jac = True,
        method = 'SLSQP',
        bounds = bounds,
    )
    elapsed = timeit.default_timer() - start_time

    results.append(r)
    print(r)
    print("-----")
    for v in jnp.reshape(r.x, (6,4)):
        print("[", energy_from_mass_basis(v), v[1], v[2], v[3], "]")
    print("fs", fs)
    print("time",elapsed)
    print("")

print("-----------")
r_fun = [res.fun for res in results]
print(onp.argmin(r_fun))