In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax import serialization
import optax

from SmoothNF import SmoothNormalizingFlow 

import numpy as np
import matplotlib.pyplot as plt
import h5py

from functools import partial
from typing import Sequence

import os
import timeit
import pickle


from jax.config import config; #config.update("jax_enable_x64", True)

# physics stuff
import pylhe

# Madjax
import madjax

In [None]:
# Global PRNG seed; later cells split it into init / training / evaluation keys.
key = jax.random.PRNGKey(42)

# Training schedule shared by all models.
batch_size = 100
num_epochs = 300
num_warm_up_epochs = 10   # leading epochs trained on the forward KL only (for runs with a warm-up)
steps_per_epoch = 75      # optimizer steps per epoch; batches are sampled with replacement

learning_rate = 5e-5      # Adam step size

# NOTE(review): not referenced elsewhere in this notebook; the samplers hard-code 1.0e-10.
eps_border = 1e-10


# Setup MadJax

In [None]:

# Process definition: e+ e- -> t tbar with fully hadronic top decays.
E_cm = 500.
config_name = "ee_ttbar_bqq_bqq"
process_name = "Matrix_1_epem_ttx_t_budx_tx_bxdux"
nDimPS = 14  # number of phase-space random variables per event (the flow's input width)

mj = madjax.MadJax(config_name=config_name)

# Every MadJax factory below is built for the same energy and process.
_proc_kwargs = dict(E_cm=E_cm, process_name=process_name)
matrix_element = mj.matrix_element(**_proc_kwargs, return_grad=False, do_jit=False)
jacobian = mj.jacobian(**_proc_kwargs, do_jit=False)
me_and_jac = mj.matrix_element_and_jacobian(**_proc_kwargs)
ps_gen = mj.phasespace_generator(**_proc_kwargs)({})
ps_vec = mj.phasespace_vectors(**_proc_kwargs)

# NOTE(review): matrix_element, jacobian, ps_vec, external_params and sigma_smear are not
# referenced by the code in this notebook; only me_and_jac and ps_gen are used.
external_params = {}
sigma_smear = 0.1

In [None]:
def scalar_rv_from_ps_point(ps_point, E_cm):
    """Map one event's four-momenta back to phase-space random variables.

    Each row of `ps_point` is wrapped in a MadJax `Vector` and handed to the phase-space
    generator's `invertKinematics`. The accompanying weight is discarded.
    """
    momenta = [madjax.phasespace.vectors.Vector(row) for row in ps_point]
    random_variables, _weight = ps_gen.invertKinematics(E_cm, momenta)
    return jnp.array(random_variables)


# Batched over events (axis 0). E_cm is a static jit argument, so it must be hashable.
vector_rv_from_ps_point = jax.jit(
    jax.vmap(scalar_rv_from_ps_point, in_axes=(0, None)),
    static_argnums=1,
)

In [None]:
def get_scalar_log_me(params, eps=0.0):
    """Build ``rv -> log(ME(rv)) + log(Jacobian(rv))`` for the configured MadJax process.

    Args:
        params: external parameters forwarded to ``me_and_jac`` (``{}`` throughout this notebook).
        eps: additive floor applied to both the matrix element and the Jacobian before the
            log. The default ``0.0`` gives the plain (un-floored) log, matching the previous
            behaviour. A small positive value (e.g. ``1.0e-10``) keeps the result finite
            when either factor is exactly zero.

    Returns:
        A function of one phase-space random-variable vector ``rv`` returning a scalar
        (it is differentiated with ``jax.grad`` elsewhere, so the output must be scalar).
    """
    def func(rv):
        me_val, jac_val = me_and_jac(params, rv)
        return jnp.log(me_val + eps) + jnp.log(jac_val + eps)
    return func

# log(ME x Jacobian) of a single random-variable vector, plus its batched (vmap),
# differentiated (grad) and compiled (jit) variants. All of them wrap one shared closure.
scalar_log_me = get_scalar_log_me(params={})
vector_log_me = jax.vmap(scalar_log_me)

grad_scalar_log_me = jax.grad(scalar_log_me)
grad_vector_log_me = jax.vmap(grad_scalar_log_me)

scalar_log_me_jit = jax.jit(scalar_log_me)
vector_log_me_jit = jax.jit(vector_log_me)

grad_scalar_log_me_jit = jax.jit(grad_scalar_log_me)
grad_vector_log_me_jit = jax.jit(grad_vector_log_me)

In [None]:

def prejit_matrix_elements():
    """Trigger compilation of the jitted log-ME wrappers and print how long each first call takes.

    A constant dummy input (10 events of ``nDimPS`` values equal to 0.9) is used. The scalar
    variants get its first row and the vector variants get the whole batch.
    """
    init_data = 0.9*jnp.ones((10, nDimPS))

    warmup_calls = (
        ("scalar_log_me", scalar_log_me_jit, init_data[0]),
        ("vector_log_me", vector_log_me_jit, init_data),
        ("grad_scalar_log_me", grad_scalar_log_me_jit, init_data[0]),
        ("grad_vector_log_me", grad_vector_log_me_jit, init_data),
    )
    for label, fn, arg in warmup_calls:
        start_time = timeit.default_timer()
        # JAX dispatches asynchronously. Block so the timing covers compile *and* execution.
        fn(arg).block_until_ready()
        elapsed = timeit.default_timer() - start_time
        print(f"{label} time", elapsed, "\n")


RUN_PREJIT = True  # set to False to skip the (slow) warm-up compilation on re-runs
if RUN_PREJIT:
    prejit_matrix_elements()

# LHE Event Manipulation

In [None]:
# Iterator over the unweighted LHE sample. The event-loading cell advances it, so re-run
# this cell before loading events again.
lhe_path = "./data/ee_ttbar_bqq_bqq/unweighted_events.lhe"
lhe_events = pylhe.readLHE(lhe_path)

In [None]:
def lhe_event_to_ps_point(event):
    """Convert one LHE event into an array of external-particle four-momenta.

    Only particles whose status is -1 or 1 are kept, in record order. In the LHE
    convention these are the incoming and final-state particles.

    Returns:
        jnp array of shape (n_kept, 4) whose rows are ``[E, px, py, pz]``.
    """
    kept = [[p.e, p.px, p.py, p.pz] for p in event.particles if p.status in (-1, 1)]
    return jnp.array(kept)


def get_multiple_lhe_event_to_ps_point(lhe_event_generator, n_events):
    """Read up to ``n_events`` events and convert each with ``lhe_event_to_ps_point``.

    Args:
        lhe_event_generator: iterator (or any iterable) of LHE events, e.g. from
            ``pylhe.readLHE``. An iterator is advanced in place, so a second call continues
            where the previous one stopped.
        n_events: maximum number of events to read.

    Returns:
        np.ndarray of shape (n_read, n_particles, 4). It is always an array, including when
        the input runs out before ``n_events``; in that case "No More Events" is printed and
        the events read so far are returned.
    """
    events_iter = iter(lhe_event_generator)
    _evts = []
    for _ in range(n_events):
        try:
            _ev = next(events_iter)
        except StopIteration:
            # Only end-of-input is expected here; other errors (incl. KeyboardInterrupt) propagate.
            print("No More Events")
            break
        _evts.append(lhe_event_to_ps_point(_ev))

    return np.array(_evts)

In [None]:
# Load 10k LHE events and map them to phase-space random variables (one row per event).
events = get_multiple_lhe_event_to_ps_point(lhe_events, 10000)
dataset = vector_rv_from_ps_point(events, E_cm)

# The first 8000 events train every model; the remainder is the shared held-out set.
train_data, test_data = dataset[:8000], dataset[8000:]

# Setup Smooth NF

In [None]:
def model1():
    """Return a fresh, unbound SmoothNormalizingFlow for run 1 (forward KL + 0.1 * reverse KL).

    Parameters are not stored on the module; they live in the TrainState built below.
    """
    return SmoothNormalizingFlow(
        num_in_feat=nDimPS,
        num_flows=4,
        num_biject=4,
        cond_mlp_width=[[100, 100]],
    )


In [None]:

#def force_MSE(x, grad_flow_logprob):
#    grad_logME = grad_vector_log_me(x)
#    return jnp.square(grad_logME - grad_flow_logprob).sum(axis=tuple(range(1,x.ndim))).reshape(-1,1)

In [None]:
def _model1_logp(params, x):
    """Flow log-density of `x` under model1.

    `x` is pulled back through the inverse bijection. The U(0, 1) base log-density is summed
    over all non-batch axes, reshaped to a column, and the log-det-Jacobian term is added.
    """
    z, ldj = model1().apply({'params': params}, x, method=model1().inverse_bijection)
    base_logp = (jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1)
                 .sum(axis=tuple(range(1, z.ndim)))
                 .reshape(-1, 1))
    return base_logp + ldj


def _model1_logq_minus_logme(params, z):
    """Per-sample reverse-KL integrand for base samples `z`.

    The result is (base log-density - forward log-det) - log ME.
    """
    x, ldj = model1().apply({'params': params}, z, method=model1().forward_bijection)
    base_logp = (jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1)
                 .sum(axis=tuple(range(1, z.ndim)))
                 .reshape(-1, 1))
    return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)


@jax.jit
def train_step1(state, batch, key):
    """One optimizer step for model1 on forward KL + 0.1 * reverse KL.

    `key` seeds the uniform base samples used by the reverse-KL estimate.
    """
    def loss_fn(params):
        nll = (-1.0 * _model1_logp(params, batch)).mean()
        # Keep the base samples away from the lower border of the unit hypercube.
        z_synth = jax.random.uniform(key, shape=batch.shape, minval=1.0e-10, maxval=(1.0 - 1.0e-10))
        rkl = _model1_logq_minus_logme(params, z_synth).mean()
        return nll + 0.1 * rkl

    return state.apply_gradients(grads=jax.grad(loss_fn)(state.params))


@jax.jit
def train_step1_only_fkl(state, batch, key):
    """One optimizer step for model1 on the forward KL alone (warm-up). `key` is unused."""
    def loss_fn(params):
        return (-1.0 * _model1_logp(params, batch)).mean()

    return state.apply_gradients(grads=jax.grad(loss_fn)(state.params))

In [None]:


@jax.jit
def eval1(params, batch, key):
    """Compute held-out diagnostics for model1.

    Args:
        params: flax parameter tree of model1.
        batch: held-out phase-space random variables, one event per row.
        key: PRNG key for the uniform base samples of the reverse-KL estimate.

    Returns:
        dict of scalars:
          'fkld'  - mean negative flow log-likelihood of `batch`;
          'rkld'  - mean of (base log-density - forward log-det - log ME) over base samples;
          'force' - mean squared distance between grad_x log ME and grad_x log q_flow;
          'cos'   - mean cosine similarity between those two gradient vectors;
          'loss'  - fkld + 0.1 * rkld (model1's training objective).
    """
    def _logp(x):
        z, ldj = model1().apply({'params': params}, x, method=model1().inverse_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return base_logp + ldj

    def synth_samp_logp_minus_logME(z):
        x, ldj = model1().apply({'params': params}, z, method=model1().forward_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)

    def _force_MSE_and_cos(x, grad_flow_logprob):
        grad_logME = grad_vector_log_me_jit(x)
        axes = tuple(range(1, x.ndim))
        _force = jnp.square(grad_logME - grad_flow_logprob).sum(axis=axes).reshape(-1, 1)

        gme_norm = jnp.sqrt(jnp.square(grad_logME).sum(axis=axes).reshape(-1, 1))
        gfl_norm = jnp.sqrt(jnp.square(grad_flow_logprob).sum(axis=axes).reshape(-1, 1))
        # cos = <a, b> / (|a| |b|). The elementwise product must be summed over the feature
        # axes; otherwise the later .mean() divides the cosine by the number of features.
        dot = jnp.multiply(grad_logME, grad_flow_logprob).sum(axis=axes).reshape(-1, 1)
        _cos = dot / (gme_norm + 1e-10) / (gfl_norm + 1e-10)
        return _force, _cos

    forward_kld_loss = (-1.0 * _logp(batch)).mean()

    synth_z = jax.random.uniform(key, shape=batch.shape, minval=1e-10, maxval=(1.0 - 1e-10))
    reverse_kld_loss = synth_samp_logp_minus_logME(synth_z).mean()

    _, _gradx_logp = model1().apply({'params': params}, batch, method=model1().val_and_gradx_logprob)
    _f, _c = _force_MSE_and_cos(batch, _gradx_logp)

    return {'fkld': forward_kld_loss,
            'rkld': reverse_kld_loss,
            'force': _f.mean(),
            'cos': _c.mean(),
            'loss': forward_kld_loss + 0.1 * reverse_kld_loss,
            }




In [None]:
# lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")
# events = get_multiple_lhe_event_to_ps_point(lhe_events, 10000)
# dataset1 = vector_rv_from_ps_point(events, E_cm)
# train_data1 = dataset1[0:8000]
# test_data1 = dataset1[8000:]

# Every run trains on the same split loaded above.
train_data1, test_data1 = train_data, test_data

rng, key, eval_rng = jax.random.split(key, 3)

init_data = jnp.ones((batch_size, nDimPS))  # dummy input, only used by init to build parameters

state1_init = train_state.TrainState.create(
    apply_fn=model1().apply,
    params=model1().init(key, init_data)['params'],
    tx=optax.adam(learning_rate),
)
state1 = state1_init

# Pre-jit: the first call of each jitted step compiles it; the returned states are discarded.
print("pre-jit model1")
batch = train_data1[np.random.choice(len(train_data1), size=batch_size)]

for _msg, _step_fn in (("train_step1: elapsed time", train_step1),
                       ("train_step1_norkl: elapsed time", train_step1_only_fkl)):
    start_time = timeit.default_timer()
    _step_fn(state1, batch, key)
    elapsed = timeit.default_timer() - start_time
    print(_msg, elapsed, "\n")


    

In [None]:
print("pre-jit eval1")
start_time = timeit.default_timer()
metrics = eval1(state1.params, test_data1, eval_rng)
_report = [metrics[k] for k in ('fkld', 'rkld', 'force', 'cos')]
print('eval1, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, CosLoss: {:.4f} \n'.format(*_report))
# Formatting reads the values back from the device, so the timing includes execution.
elapsed = timeit.default_timer() - start_time
print("elapsed time", elapsed, "\n")

In [None]:
# Parameter snapshots: the initial state first, then one entry per training epoch.
saved_params1 = [state1.params]
losses1 = []

state1 = state1_init

In [None]:

config.update("jax_debug_nans", True)

for epoch in range(num_epochs):
    print("epoch",epoch)
    start_time = timeit.default_timer()

    for step in range(steps_per_epoch):
        batch = train_data1[np.random.choice(len(train_data1), size=batch_size)]
        rng, key = jax.random.split(rng)

        # Warm-up: forward KL only.
        if epoch < num_warm_up_epochs:
            state1 = train_step1_only_fkl(state1, batch, key)
            continue

        # With jax_debug_nans on, a NaN raises FloatingPointError. Skip that update.
        try:
            state_update = train_step1(state1, batch, key)
        except FloatingPointError:
            print("### GRAD FloatingPointError - continue without update")
        else:
            state1 = state_update

    saved_params1.append(state1.params)

    elapsed = timeit.default_timer() - start_time
    print("elapsed time", elapsed)

    metrics = eval1(state1.params, test_data1, eval_rng)
    summary = [metrics['fkld'], metrics['rkld'], metrics['force'], metrics['cos']]
    print('eval1 epoch: {}, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, CosLoss: {:.4f} \n'.format(
        epoch, *summary))

    losses1.append(summary)

In [None]:

outfile_name = "results_comp_smoothNF/model1_wk0.1_wf0_results2.pkl"
# Create the output directory so a fresh checkout does not fail here.
os.makedirs(os.path.dirname(outfile_name), exist_ok=True)

# Flax state dicts for every parameter snapshot, plus the per-epoch metrics.
# NOTE: losses1 holds JAX scalars, so unpickling needs jax installed.
saved_param_dict = [serialization.to_state_dict(ip) for ip in saved_params1]
outtuple = (saved_param_dict, losses1)

# Serialize first, then open. `with` closes the file even if pickling raises.
with open(outfile_name, 'wb') as outfile:
    pickle.dump(outtuple, outfile)

# No RKL (forward KL only)

In [None]:
def model2():
    """Return a fresh, unbound SmoothNormalizingFlow for run 2 (forward KL only).

    The architecture is identical to the other runs; only the training loss differs.
    """
    return SmoothNormalizingFlow(
        num_in_feat=nDimPS,
        num_flows=4,
        num_biject=4,
        cond_mlp_width=[[100, 100]],
    )

In [None]:
@jax.jit
def train_step2(state, batch, key):
    """One optimizer step for model2 on the forward KL (negative log-likelihood) only.

    `key` is unused. It is kept so every train_step* shares the same signature.
    """
    def loss_fn(params):
        z, ldj = model2().apply({'params': params}, batch, method=model2().inverse_bijection)
        feat_axes = tuple(range(1, z.ndim))
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=feat_axes).reshape(-1, 1)
        log_likelihood = base_logp + ldj
        return (-1.0 * log_likelihood).mean()

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

In [None]:
@jax.jit
def eval2(params, batch, key):
    """Compute held-out diagnostics for model2.

    Args:
        params: flax parameter tree of model2.
        batch: held-out phase-space random variables, one event per row.
        key: PRNG key for the uniform base samples of the reverse-KL estimate.

    Returns:
        dict of scalars:
          'fkld'  - mean negative flow log-likelihood of `batch`;
          'rkld'  - mean of (base log-density - forward log-det - log ME) over base samples;
          'force' - mean squared distance between grad_x log ME and grad_x log q_flow;
          'cos'   - mean cosine similarity between those two gradient vectors;
          'loss'  - fkld + 0.1 * rkld. This is the common composite reported by every eval;
                    it is not model2's training objective, which is the forward KL alone.
    """
    def _logp(x):
        z, ldj = model2().apply({'params': params}, x, method=model2().inverse_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return base_logp + ldj

    def synth_samp_logp_minus_logME(z):
        x, ldj = model2().apply({'params': params}, z, method=model2().forward_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)

    def _force_MSE_and_cos(x, grad_flow_logprob):
        grad_logME = grad_vector_log_me_jit(x)
        axes = tuple(range(1, x.ndim))
        _force = jnp.square(grad_logME - grad_flow_logprob).sum(axis=axes).reshape(-1, 1)

        gme_norm = jnp.sqrt(jnp.square(grad_logME).sum(axis=axes).reshape(-1, 1))
        gfl_norm = jnp.sqrt(jnp.square(grad_flow_logprob).sum(axis=axes).reshape(-1, 1))
        # cos = <a, b> / (|a| |b|). Sum the elementwise product over the feature axes;
        # otherwise the later .mean() divides the cosine by the number of features.
        dot = jnp.multiply(grad_logME, grad_flow_logprob).sum(axis=axes).reshape(-1, 1)
        _cos = dot / (gme_norm + 1e-10) / (gfl_norm + 1e-10)
        return _force, _cos

    forward_kld_loss = (-1.0 * _logp(batch)).mean()

    # Same lower floor as eval1/train_step1, so no base sample is exactly 0.
    synth_z = jax.random.uniform(key, shape=batch.shape, minval=1e-10, maxval=(1.0 - 1e-10))
    reverse_kld_loss = synth_samp_logp_minus_logME(synth_z).mean()

    _, _gradx_logp = model2().apply({'params': params}, batch, method=model2().val_and_gradx_logprob)
    _f, _c = _force_MSE_and_cos(batch, _gradx_logp)

    return {'fkld': forward_kld_loss,
            'rkld': reverse_kld_loss,
            'force': _f.mean(),
            'cos': _c.mean(),
            'loss': forward_kld_loss + 0.1 * reverse_kld_loss,
            }



In [None]:
# lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")
# events = get_multiple_lhe_event_to_ps_point(lhe_events, 10000)
# dataset2 = vector_rv_from_ps_point(events, E_cm)
# train_data2 = dataset2[0:8000]
# test_data2 = dataset2[8000:]

# Every run trains on the same split loaded above.
train_data2, test_data2 = train_data, test_data

rng, key, eval_rng = jax.random.split(key, 3)

init_data = jnp.ones((batch_size, nDimPS))  # dummy input, only used by init to build parameters

state2_init = train_state.TrainState.create(
    apply_fn=model2().apply,
    params=model2().init(key, init_data)['params'],
    tx=optax.adam(learning_rate),
)
state2 = state2_init

# Pre-jit: the first call compiles train_step2; the returned state is discarded.
print("pre-jit model2")
start_time = timeit.default_timer()
batch = train_data2[np.random.choice(len(train_data2), size=batch_size)]
train_step2(state2, batch, key)
elapsed = timeit.default_timer() - start_time
print("elapsed time", elapsed, "\n")


    

In [None]:
print("pre-jit eval2")
start_time = timeit.default_timer()
metrics = eval2(state2.params, test_data2, eval_rng)
_report = [metrics[k] for k in ('fkld', 'rkld', 'force', 'cos')]
print('eval2,  fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(*_report))
# Formatting reads the values back from the device, so the timing includes execution.
elapsed = timeit.default_timer() - start_time
print("elapsed time", elapsed, "\n")

In [None]:
# Parameter snapshots: the initial state first, then one entry per training epoch.
saved_params2 = [state2.params]
losses2 = []

In [None]:
config.update("jax_debug_nans", True)

for epoch in range(num_epochs):
    print("epoch",epoch)
    start_time = timeit.default_timer()

    # model2 has no warm-up phase: every step is forward-KL only.
    for step in range(steps_per_epoch):
        batch = train_data2[np.random.choice(len(train_data2), size=batch_size)]
        rng, key = jax.random.split(rng)
        state2 = train_step2(state2, batch, key)

    saved_params2.append(state2.params)

    elapsed = timeit.default_timer() - start_time
    print("elapsed time", elapsed)

    metrics = eval2(state2.params, test_data2, eval_rng)
    summary = [metrics['fkld'], metrics['rkld'], metrics['force'], metrics['cos']]
    print('eval2 epoch: {}, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(
        epoch, *summary))

    losses2.append(summary)

In [None]:
outfile_name = "results_comp_smoothNF/model2_wk0_wf0_results2.pkl"
# Create the output directory so a fresh checkout does not fail here.
os.makedirs(os.path.dirname(outfile_name), exist_ok=True)

# Flax state dicts for every parameter snapshot, plus the per-epoch metrics.
# NOTE: losses2 holds JAX scalars, so unpickling needs jax installed.
saved_param_dict = [serialization.to_state_dict(ip) for ip in saved_params2]
outtuple = (saved_param_dict, losses2)

# Serialize first, then open. `with` closes the file even if pickling raises.
with open(outfile_name, 'wb') as outfile:
    pickle.dump(outtuple, outfile)

In [None]:
# x_epoch = jnp.array([range(num_epochs)]).squeeze()

# columns of losses*: 0 = fkld, 1 = rkld, 2 = force, 3 = cos
# m1_fkl = jnp.array(losses1)[:,0]
# m1_rkl = jnp.array(losses1)[:,1]

# m2_fkl = jnp.array(losses2)[:,0]
# m2_rkl = jnp.array(losses2)[:,1]

# plt.plot(x_epoch, m1_fkl, '-', c='#1f77b4', label=r'$\omega_k=0.1$: Likelihood')
# plt.plot(x_epoch, m1_rkl, '--', c='#1f77b4', alpha=0.5, label=r'$\omega_k=0.1$: Rev DKL')
# plt.plot(x_epoch, m2_fkl, '-', c='#ff7f0e', label=r'$\omega_k=0$: Likelihood')
# plt.plot(x_epoch, m2_rkl, '--', c='#ff7f0e', alpha=0.5, label=r'$\omega_k=0$: Rev DKL')
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.legend()

# With Force (forward KL + force matching)

In [None]:
def model3():
    """Return a fresh, unbound SmoothNormalizingFlow for run 3 (forward KL + 1e-6 * force MSE).

    The architecture is identical to the other runs; only the training loss differs.
    """
    return SmoothNormalizingFlow(
        num_in_feat=nDimPS,
        num_flows=4,
        num_biject=4,
        cond_mlp_width=[[100, 100]],
    )

In [None]:

# def force_MSE(x, grad_flow_logprob):
#     grad_logME = grad_vector_log_me_jit(x)
#     return jnp.square(grad_logME - grad_flow_logprob).sum(axis=tuple(range(1,x.ndim))).reshape(-1,1)


In [None]:
def _model3_logp(params, x):
    """Flow log-density of `x` under model3.

    `x` is pulled back through the inverse bijection. The U(0, 1) base log-density is summed
    over all non-batch axes, reshaped to a column, and the log-det-Jacobian term is added.
    """
    z, ldj = model3().apply({'params': params}, x, method=model3().inverse_bijection)
    base_logp = (jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1)
                 .sum(axis=tuple(range(1, z.ndim)))
                 .reshape(-1, 1))
    return base_logp + ldj


@jax.jit
def train_step3(state, batch, key):
    """One optimizer step for model3 on forward KL + 1e-6 * force MSE.

    The force term is the squared distance between the flow's grad_x log q and grad_x log ME
    on the data batch. `key` is unused.
    """
    def loss_fn(params):
        nll = (-1.0 * _model3_logp(params, batch)).mean()

        _, gradx_logp = model3().apply({'params': params}, batch, method=model3().val_and_gradx_logprob)
        grad_logME = grad_vector_log_me_jit(batch)
        per_event_force = (jnp.square(grad_logME - gradx_logp)
                           .sum(axis=tuple(range(1, batch.ndim)))
                           .reshape(-1, 1))

        return nll + (1.0e-6) * per_event_force.mean()

    return state.apply_gradients(grads=jax.grad(loss_fn)(state.params))


@jax.jit
def train_step3_only_fkl(state, batch, key):
    """One optimizer step for model3 on the forward KL alone (warm-up). `key` is unused."""
    def loss_fn(params):
        return (-1.0 * _model3_logp(params, batch)).mean()

    return state.apply_gradients(grads=jax.grad(loss_fn)(state.params))

In [None]:
@jax.jit
def eval3(params, batch, key):
    """Compute held-out diagnostics for model3.

    Args:
        params: flax parameter tree of model3.
        batch: held-out phase-space random variables, one event per row.
        key: PRNG key for the uniform base samples of the reverse-KL estimate.

    Returns:
        dict of scalars:
          'fkld'  - mean negative flow log-likelihood of `batch`;
          'rkld'  - mean of (base log-density - forward log-det - log ME) over base samples;
          'force' - mean squared distance between grad_x log ME and grad_x log q_flow;
          'cos'   - mean cosine similarity between those two gradient vectors;
          'loss'  - fkld + 0.1 * rkld. This is the common composite reported by every eval,
                    not model3's training objective (fkld + 1e-6 * force).
    """
    def _logp(x):
        z, ldj = model3().apply({'params': params}, x, method=model3().inverse_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return base_logp + ldj

    def synth_samp_logp_minus_logME(z):
        x, ldj = model3().apply({'params': params}, z, method=model3().forward_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)

    def _force_MSE_and_cos(x, grad_flow_logprob):
        grad_logME = grad_vector_log_me_jit(x)
        axes = tuple(range(1, x.ndim))
        _force = jnp.square(grad_logME - grad_flow_logprob).sum(axis=axes).reshape(-1, 1)

        gme_norm = jnp.sqrt(jnp.square(grad_logME).sum(axis=axes).reshape(-1, 1))
        gfl_norm = jnp.sqrt(jnp.square(grad_flow_logprob).sum(axis=axes).reshape(-1, 1))
        # cos = <a, b> / (|a| |b|). Sum the elementwise product over the feature axes;
        # otherwise the later .mean() divides the cosine by the number of features.
        dot = jnp.multiply(grad_logME, grad_flow_logprob).sum(axis=axes).reshape(-1, 1)
        _cos = dot / (gme_norm + 1e-10) / (gfl_norm + 1e-10)
        return _force, _cos

    forward_kld_loss = (-1.0 * _logp(batch)).mean()

    # Same lower floor as eval1/train_step1, so no base sample is exactly 0.
    synth_z = jax.random.uniform(key, shape=batch.shape, minval=1e-10, maxval=(1.0 - 1e-10))
    reverse_kld_loss = synth_samp_logp_minus_logME(synth_z).mean()

    _, _gradx_logp = model3().apply({'params': params}, batch, method=model3().val_and_gradx_logprob)
    _f, _c = _force_MSE_and_cos(batch, _gradx_logp)

    return {'fkld': forward_kld_loss,
            'rkld': reverse_kld_loss,
            'force': _f.mean(),
            'cos': _c.mean(),
            'loss': forward_kld_loss + 0.1 * reverse_kld_loss,
            }

In [None]:
# lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")
# events = get_multiple_lhe_event_to_ps_point(lhe_events, 10000)
# dataset3 = vector_rv_from_ps_point(events, E_cm)
# train_data3 = dataset3[0:8000]
# test_data3 = dataset3[8000:]

# Every run trains on the same split loaded above.
train_data3, test_data3 = train_data, test_data

rng, key, eval_rng = jax.random.split(key, 3)

init_data = jnp.ones((batch_size, nDimPS))  # dummy input, only used by init to build parameters

state3_init = train_state.TrainState.create(
    apply_fn=model3().apply,
    params=model3().init(key, init_data)['params'],
    tx=optax.adam(learning_rate),
)
state3 = state3_init

# Pre-jit: the first call of each jitted step compiles it; the returned states are discarded.
print("pre-eval model3")
batch = train_data3[np.random.choice(len(train_data3), size=batch_size)]

for _msg, _step_fn in (("train_step3: elapsed time", train_step3),
                       ("train_step3_nograd:  elapsed time", train_step3_only_fkl)):
    start_time = timeit.default_timer()
    _step_fn(state3, batch, key)
    elapsed = timeit.default_timer() - start_time
    print(_msg, elapsed, "\n")


    

In [None]:
print("pre-jit eval3")
start_time = timeit.default_timer()
metrics = eval3(state3.params, test_data3, eval_rng)
_report = [metrics[k] for k in ('fkld', 'rkld', 'force', 'cos')]
print('eval3,  fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(*_report))
# Formatting reads the values back from the device, so the timing includes execution.
elapsed = timeit.default_timer() - start_time
print("elapsed time", elapsed, "\n")

In [None]:
# Parameter snapshots: the initial state first, then one entry per training epoch.
saved_params3 = [state3_init.params]
losses3 = []

state3 = state3_init


In [None]:
# num_epochs = 100
# steps_per_epoch = 75 #20000 // batch_size


for epoch in range(num_epochs):
    print("epoch",epoch)
    start_time = timeit.default_timer()

    # Forward-KL warm-up first, then the force-regularised step.
    step_fn = train_step3_only_fkl if epoch < num_warm_up_epochs else train_step3
    for step in range(steps_per_epoch):
        batch = train_data3[np.random.choice(len(train_data3), size=batch_size)]
        rng, key = jax.random.split(rng)
        state3 = step_fn(state3, batch, key)

    saved_params3.append(state3.params)

    elapsed = timeit.default_timer() - start_time
    print("elapsed time", elapsed)

    metrics = eval3(state3.params, test_data3, eval_rng)
    summary = [metrics['fkld'], metrics['rkld'], metrics['force'], metrics['cos']]
    print('eval3 epoch: {}, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(
        epoch, *summary))

    losses3.append(summary)

In [None]:
outfile_name = "results_comp_smoothNF/model3_wk0_wf1Em6_results2.pkl"
# Create the output directory so a fresh checkout does not fail here.
os.makedirs(os.path.dirname(outfile_name), exist_ok=True)

# Flax state dicts for every parameter snapshot, plus the per-epoch metrics.
# NOTE: losses3 holds JAX scalars, so unpickling needs jax installed.
saved_param_dict = [serialization.to_state_dict(ip) for ip in saved_params3]
outtuple = (saved_param_dict, losses3)

# Serialize first, then open. `with` closes the file even if pickling raises.
with open(outfile_name, 'wb') as outfile:
    pickle.dump(outtuple, outfile)

In [None]:
# ne=200

# x_epoch = jnp.array([range(ne)]).squeeze()

# columns of losses*: 0 = fkld, 1 = rkld, 2 = force, 3 = cos
# m1_fkl = jnp.array(losses1)[:ne,0]
# m1_rkl = jnp.array(losses1)[:ne,1]

# m2_fkl = jnp.array(losses2)[:ne,0]
# m2_rkl = jnp.array(losses2)[:ne,1]

# m3_fkl = jnp.array(losses3)[:ne,0]
# m3_rkl = jnp.array(losses3)[:ne,1]

# plt.plot(x_epoch, m1_fkl, '-', c='#1f77b4', label=r'$\omega_k=0.1$: Likelihood')
# plt.plot(x_epoch, m1_rkl, '--', c='#1f77b4', alpha=0.5, label=r'$\omega_k=0.1$: Rev DKL')
# plt.plot(x_epoch, m2_fkl, '-', c='#ff7f0e', label=r'$\omega_k=0$: Likelihood')
# plt.plot(x_epoch, m2_rkl, '--', c='#ff7f0e', alpha=0.5, label=r'$\omega_k=0$: Rev DKL')
# plt.plot(x_epoch, m3_fkl, '-', c='#15b01a', label=r'$\omega_f=10^{-6}$: Likelihood')
# plt.plot(x_epoch, m3_rkl, '--', c='#15b01a', alpha=0.5, label=r'$\omega_f=10^{-6}$: Rev DKL')
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.legend()

# With Force and RKL

In [None]:
def model4():
    """Return a fresh, unbound SmoothNormalizingFlow for run 4.

    Run 4 trains on forward KL + 0.1 * reverse KL + 1e-6 * force MSE. The architecture is
    identical to the other runs.
    """
    return SmoothNormalizingFlow(
        num_in_feat=nDimPS,
        num_flows=4,
        num_biject=4,
        cond_mlp_width=[[100, 100]],
    )

In [None]:
def _model4_logp(params, x):
    """Flow log-density of `x` under model4.

    `x` is pulled back through the inverse bijection. The U(0, 1) base log-density is summed
    over all non-batch axes, reshaped to a column, and the log-det-Jacobian term is added.
    """
    z, ldj = model4().apply({'params': params}, x, method=model4().inverse_bijection)
    base_logp = (jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1)
                 .sum(axis=tuple(range(1, z.ndim)))
                 .reshape(-1, 1))
    return base_logp + ldj


def _model4_logq_minus_logme(params, z):
    """Per-sample reverse-KL integrand for base samples `z`.

    The result is (base log-density - forward log-det) - log ME.
    """
    x, ldj = model4().apply({'params': params}, z, method=model4().forward_bijection)
    base_logp = (jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1)
                 .sum(axis=tuple(range(1, z.ndim)))
                 .reshape(-1, 1))
    return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)


@jax.jit
def train_step4(state, batch, key):
    """One optimizer step for model4 on forward KL + 0.1 * reverse KL + 1e-6 * force MSE.

    `key` seeds the uniform base samples used by the reverse-KL estimate.
    """
    def loss_fn(params):
        # FKL loss
        forward_kld_loss = (-1.0 * _model4_logp(params, batch)).mean()

        # RKL loss. Base samples use the same 1e-10 floor as train_step1, so z is never exactly 0.
        # jax.random.uniform samples [minval, maxval); in float32, 1 - 1e-10 rounds to 1.0.
        synth_z = jax.random.uniform(key, shape=batch.shape, minval=1.0e-10, maxval=(1.0 - 1.0e-10))
        reverse_kld_loss = _model4_logq_minus_logme(params, synth_z).mean()

        # Force loss: squared distance between grad_x log q_flow and grad_x log ME on the data.
        _, gradx_logp = model4().apply({'params': params}, batch, method=model4().val_and_gradx_logprob)
        grad_logME = grad_vector_log_me_jit(batch)
        force_mse = (jnp.square(grad_logME - gradx_logp)
                     .sum(axis=tuple(range(1, batch.ndim)))
                     .reshape(-1, 1)
                     .mean())

        return forward_kld_loss + 0.1 * reverse_kld_loss + (1.0e-6) * force_mse

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


@jax.jit
def train_step4_only_fkl(state, batch, key):
    """One optimizer step for model4 on the forward KL alone (warm-up). `key` is unused."""
    def loss_fn(params):
        return (-1.0 * _model4_logp(params, batch)).mean()

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

In [None]:
@jax.jit
def eval4(params, batch, key):
    """Compute held-out diagnostics for model4.

    Args:
        params: flax parameter tree of model4.
        batch: held-out phase-space random variables, one event per row.
        key: PRNG key for the uniform base samples of the reverse-KL estimate.

    Returns:
        dict of scalars:
          'fkld'  - mean negative flow log-likelihood of `batch`;
          'rkld'  - mean of (base log-density - forward log-det - log ME) over base samples;
          'force' - mean squared distance between grad_x log ME and grad_x log q_flow;
          'cos'   - mean cosine similarity between those two gradient vectors;
          'loss'  - fkld + 0.1 * rkld. This is the common composite reported by every eval;
                    it omits model4's 1e-6 * force training term.
    """
    def _logp(x):
        z, ldj = model4().apply({'params': params}, x, method=model4().inverse_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return base_logp + ldj

    def synth_samp_logp_minus_logME(z):
        x, ldj = model4().apply({'params': params}, z, method=model4().forward_bijection)
        base_logp = jax.scipy.stats.uniform.logpdf(z, loc=0, scale=1).sum(axis=tuple(range(1, z.ndim))).reshape(-1, 1)
        return (base_logp - ldj) - vector_log_me(x).reshape(-1, 1)

    def _force_MSE_and_cos(x, grad_flow_logprob):
        grad_logME = grad_vector_log_me_jit(x)
        axes = tuple(range(1, x.ndim))
        _force = jnp.square(grad_logME - grad_flow_logprob).sum(axis=axes).reshape(-1, 1)

        gme_norm = jnp.sqrt(jnp.square(grad_logME).sum(axis=axes).reshape(-1, 1))
        gfl_norm = jnp.sqrt(jnp.square(grad_flow_logprob).sum(axis=axes).reshape(-1, 1))
        # cos = <a, b> / (|a| |b|). Sum the elementwise product over the feature axes;
        # otherwise the later .mean() divides the cosine by the number of features.
        dot = jnp.multiply(grad_logME, grad_flow_logprob).sum(axis=axes).reshape(-1, 1)
        _cos = dot / (gme_norm + 1e-10) / (gfl_norm + 1e-10)
        return _force, _cos

    forward_kld_loss = (-1.0 * _logp(batch)).mean()

    # Same lower floor as eval1/train_step1, so no base sample is exactly 0.
    synth_z = jax.random.uniform(key, shape=batch.shape, minval=1e-10, maxval=(1.0 - 1e-10))
    reverse_kld_loss = synth_samp_logp_minus_logME(synth_z).mean()

    _, _gradx_logp = model4().apply({'params': params}, batch, method=model4().val_and_gradx_logprob)
    _f, _c = _force_MSE_and_cos(batch, _gradx_logp)

    return {'fkld': forward_kld_loss,
            'rkld': reverse_kld_loss,
            'force': _f.mean(),
            'cos': _c.mean(),
            'loss': forward_kld_loss + 0.1 * reverse_kld_loss,
            }

In [None]:
# lhe_events = pylhe.readLHE("./data/ee_ttbar_bqq_bqq/unweighted_events.lhe")
# events = get_multiple_lhe_event_to_ps_point(lhe_events, 10000)
# dataset4 = vector_rv_from_ps_point(events, E_cm)
# train_data4 = dataset4[0:8000]
# test_data4 = dataset4[8000:]

# Every run trains on the same split loaded above.
train_data4, test_data4 = train_data, test_data

rng, key, eval_rng = jax.random.split(key, 3)

init_data = jnp.ones((batch_size, nDimPS))  # dummy input, only used by init to build parameters

state4_init = train_state.TrainState.create(
    apply_fn=model4().apply,
    params=model4().init(key, init_data)['params'],
    tx=optax.adam(learning_rate),
)
state4 = state4_init

# Pre-jit: the first call of each jitted step compiles it; the returned states are discarded.
print("pre-eval model4")
batch = train_data4[np.random.choice(len(train_data4), size=batch_size)]

for _msg, _step_fn in (("train_step4: elapsed time", train_step4),
                       ("train_step4_only_fkl:  elapsed time", train_step4_only_fkl)):
    start_time = timeit.default_timer()
    _step_fn(state4, batch, key)
    elapsed = timeit.default_timer() - start_time
    print(_msg, elapsed, "\n")

  

In [None]:
# Run eval4 once on the test set (this also compiles it) and report the
# four monitored metrics together with the wall-clock time taken.
print("pre-jit eval4")
start_time = timeit.default_timer()
metrics = eval4(state4.params, test_data4, eval_rng)
metric_values = [metrics[name] for name in ('fkld', 'rkld', 'force', 'cos')]
print('eval4, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(
    *metric_values))
elapsed = timeit.default_timer() - start_time
print("elapsed time", elapsed, "\n")

In [None]:
# Parameter snapshots, starting with the current (pre-training) parameters;
# the training loop appends one snapshot per epoch.
saved_params4 = [state4.params]

# Per-epoch evaluation metrics, filled by the training loop.
losses4 = []

# Start training from the freshly initialised state.
state4 = state4_init

In [None]:
# state4 = state4_init
# state4 = train_step1_only_fkl(state4, batch, key)



In [None]:
# num_epochs = 100
# steps_per_epoch = 75 #20000 // batch_size


# Training loop for model4: forward-KL-only warm-up for the first
# `num_warm_up_epochs` epochs, then the full objective (train_step4).
# After each epoch the parameters are snapshotted and evaluated on the test set.
n_train4 = len(train_data4)  # loop-invariant; hoisted out of the step loop

for epoch in range(num_epochs):
    print("epoch",epoch)
    start_time = timeit.default_timer()

    step_fn = train_step4_only_fkl if epoch < num_warm_up_epochs else train_step4

    for step in range(steps_per_epoch):
        # Uniform mini-batch, sampled with replacement (same draws as
        # np.random.choice(np.arange(n_train4), ...)).
        batch = train_data4[np.random.choice(n_train4, size = batch_size)]
        rng, key = jax.random.split(rng)
        state4 = step_fn(state4, batch, key)

    # JAX dispatches asynchronously: wait for the last step to finish so that
    # `elapsed` measures this epoch's training time, not just dispatch time.
    jax.tree_util.tree_map(lambda a: a.block_until_ready(), state4.params)
    saved_params4.append(state4.params)

    elapsed = timeit.default_timer() - start_time
    print("elapsed time", elapsed)

    # eval_rng is deliberately reused every epoch so evaluations are comparable.
    metrics = eval4(state4.params, test_data4, eval_rng)
    print('eval4 epoch: {}, fwdKLD: {:.4f}, revKLD: {:.4f}, forceMSE: {:.4f}, cosLoss: {:.4f} \n'.format(
        epoch, metrics['fkld'], metrics['rkld'], metrics['force'], metrics['cos']))

    # Store plain Python floats so the pickled loss history does not hold
    # device arrays (and does not need jax to be unpickled).
    losses4.append([float(metrics[name]) for name in ('fkld', 'rkld', 'force', 'cos')])

In [None]:
import os

# Save all per-epoch parameter snapshots and the loss history of model4.
outfile_name = "results_comp_smoothNF/model4_wk0.1_wf1Em6_results2.pkl"
# Create the results directory if needed instead of failing inside open().
os.makedirs(os.path.dirname(outfile_name), exist_ok=True)

# Convert each Flax parameter tree to its state-dict form before pickling.
saved_param_dict = [serialization.to_state_dict(ip) for ip in saved_params4]

outtuple = (saved_param_dict, losses4)
# `with` guarantees the file is closed (and flushed) even if pickling fails.
with open(outfile_name, 'wb') as outfile:
    pickle.dump(outtuple, outfile)

# Model Evaluation




In [None]:
def _load_results(path):
    """Load a ``(param_snapshots, loss_history)`` tuple written by a save cell.

    The file is opened with ``with`` so its handle is always closed.
    Only load files this notebook produced: ``pickle.load`` can execute
    arbitrary code embedded in an untrusted pickle.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


all_params1, all_losses1 = _load_results("results_comp_smoothNF/model1_wk0.1_wf0_results2.pkl")
all_params2, all_losses2 = _load_results("results_comp_smoothNF/model2_wk0_wf0_results2.pkl")
all_params3, all_losses3 = _load_results("results_comp_smoothNF/model3_wk0_wf1Em6_results2.pkl")
all_params4, all_losses4 = _load_results("results_comp_smoothNF/model4_wk0.1_wf1Em6_results2.pkl")

In [None]:
ns = 12    # first epoch shown (presumably to skip the early warm-up transient)
ne = 212   # one past the last epoch shown
alpha = 0.1  # NOTE(review): not used by the plots below, which use alpha=0.7
# The curves are sliced [ns:ne], so their epoch indices are ns..ne-1;
# starting the x axis at 0 would mislabel every point by ns epochs.
x_epoch = jnp.arange(ns, ne)

m1_fkl = jnp.array(all_losses1)[ns:ne,0]
m2_fkl = jnp.array(all_losses2)[ns:ne,0]
m3_fkl = jnp.array(all_losses3)[ns:ne,0]
m4_fkl = jnp.array(all_losses4)[ns:ne,0]

# omega_f labels follow the result file names (wf1Em6 => 1e-6).
fkl_curves = [
    (m2_fkl, '#ff7f0e', r'$\omega_k=0$,    $\omega_f=0$'),
    (m1_fkl, '#1f77b4', r'$\omega_k=0.1$, $\omega_f=0$'),
    (m3_fkl, '#15b01a', r'$\omega_k=0$,    $\omega_f=10^{-6}$'),
    (m4_fkl, '#c20078', r'$\omega_k=0.1$, $\omega_f=10^{-6}$'),
]

fig, ax = plt.subplots(figsize=(6, 4))
for curve, color, label in fkl_curves:
    ax.plot(x_epoch, curve, '-', c=color, alpha=0.7, label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Likelihood Loss")
ax.legend()
fig.savefig('flow_likelihood.pdf')


In [None]:
# Reverse-KL loss per epoch for the four models, over the same [ns:ne] window
# and x axis as the likelihood plot.
m1_rkl = jnp.array(all_losses1)[ns:ne,1]
m2_rkl = jnp.array(all_losses2)[ns:ne,1]
m3_rkl = jnp.array(all_losses3)[ns:ne,1]
m4_rkl = jnp.array(all_losses4)[ns:ne,1]

# omega_f labels follow the result file names (wf1Em6 => 1e-6).
rkl_curves = [
    (m2_rkl, '#ff7f0e', r'$\omega_k=0$,    $\omega_f=0$'),
    (m1_rkl, '#1f77b4', r'$\omega_k=0.1$, $\omega_f=0$'),
    (m3_rkl, '#15b01a', r'$\omega_k=0$,    $\omega_f=10^{-6}$'),
    (m4_rkl, '#c20078', r'$\omega_k=0.1$, $\omega_f=10^{-6}$'),
]

fig, ax = plt.subplots(figsize=(6, 4))
for curve, color, label in rkl_curves:
    ax.plot(x_epoch, curve, '-', c=color, alpha=0.7, label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Reverse KL Loss")
ax.legend()
fig.savefig('flow_reverseKL.pdf')

In [None]:
# Index of the saved parameter snapshot used for sampling in the next cell.
# Note: this reuses the name `ne`, which the plotting cells used as an end bound.
ne = 200

In [None]:
# Draw samples from each trained flow at snapshot `ne`: uniform base samples
# z are pushed through the flow's forward bijection to give x.
n_flow_samples = 10000

flow_specs = [
    (model1, all_params1),
    (model2, all_params2),
    (model3, all_params3),
    (model4, all_params4),
]
base_samples, flow_samples = [], []
for make_model, param_snapshots in flow_specs:
    rng, key = jax.random.split(rng)
    z_base = jax.random.uniform(key, (n_flow_samples, nDimPS))
    # `apply` and `method` must come from the same model; model3's samples
    # were previously generated through model4().apply.
    flow_model = make_model()
    x_gen, _ = flow_model.apply({'params': param_snapshots[ne]}, z_base,
                                method=flow_model.forward_bijection)
    base_samples.append(z_base)
    flow_samples.append(x_gen)

z1, z2, z3, z4 = base_samples
x1, x2, x3, x4 = flow_samples



In [None]:
#idx = 1

# Per-dimension marginal histograms: training data (thick black) versus
# samples from each flow. Bins cover [0, 1] — assumes every coordinate lies in
# the unit interval (true for uniform z; TODO confirm for all x).
hbins = np.linspace(0,1,11)

hist_series = [
    (train_data4, {'ec': 'black', 'lw': 2}, 'training data'),
    (x1, {'ec': '#1f77b4'}, r'model 1: $\omega_k=0.1$, $\omega_f=0$'),
    (x2, {'ec': '#ff7f0e'}, r'model 2: $\omega_k=0$, $\omega_f=0$'),
    (x3, {'ec': '#15b01a'}, r'model 3: $\omega_k=0$, $\omega_f=10^{-6}$'),
    (x4, {'ec': '#c20078'}, r'model 4: $\omega_k=0.1$, $\omega_f=10^{-6}$'),
]

# One figure per dimension of the data (previously hard-coded as 14).
for idx in range(train_data4.shape[1]):
    fig, ax = plt.subplots()
    for samples, style, label in hist_series:
        ax.hist(samples[:, idx], bins=hbins, density=True, histtype='step',
                label=label, **style)
    ax.set_xlabel(f"dimension {idx}")
    ax.set_ylabel("density")
    ax.legend(fontsize='small')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures inside the loop




In [None]:
# ne=19
# metric = ["total", 'NLL', 'RKL $\\times 0.1$', 'FME $\\times 10^{-6}$']
# weight = [1.0, 1.0, 0.1, 1.0e-6]

# print('Metric', '&', '$\\omega_{RKL}=0$', '&', '$\\omega_{RKL}=0.1$', '&', '$\\omega_{RKL}=0$',       '&', '$\\omega_{RKL}=0.1$', "\\\\")
# print(''      , '&', '$\\omega_{FME}=0$', '&', '$\\omega_{FME}=0$'  , '&', '$\\omega_{FME}=10^{-6}$', '&', '$\\omega_{FME}=10^{-6}$$', '\\\\ \\hline')
# for iloss in range(1,4):
#     print(metric[iloss], '&', 
#           "{:.2f}".format(weight[iloss]*jnp.asarray(losses2[ne][iloss])), '&', 
#           "{:.2f}".format(weight[iloss]*jnp.asarray(losses1[ne][iloss])), '&', 
#           "{:.2f}".format(weight[iloss]*jnp.asarray(losses3[ne][iloss])), '&', 
#           "{:.2f}".format(weight[iloss]*jnp.asarray(losses4[ne][iloss])), '\\\\ \\hline')

# Extra

In [None]:
# _logp, _gradx_logp = model1().apply({'params':state3.params}, batch, method=model3().val_and_gradx_logprob)
# print(_gradx_logp)
# print(" ")
# me_gradlogp = grad_vector_log_me_jit(batch)
# print(jnp.square(me_gradlogp-_gradx_logp).sum(axis=-1))


# print(jnp.square(me_gradlogp - _gradx_logp).sum(axis=tuple(range(1,batch.ndim))).reshape(-1,1))

In [None]:
#ff = lambda x: model3().apply({'params':state3.params}, x, method=model3().forward_bijection)
#jax.jax.make_jaxpr(ff)(batch[0:2])

In [None]:
#jax.jax.make_jaxpr(scalar_log_me)(batch[0])