In [None]:
import os
import sys

# Location of the NJet BLHA python interface.  Set the NJET_BLHA environment
# variable to point at another installation; the previously hard-coded path
# is kept as the default so existing setups keep working unchanged.
NJET_BLHA = os.environ.get('NJET_BLHA', '/Users/simon/packages/njet-3.1.1-1L/blha/')
sys.path.append(NJET_BLHA)

import numpy as np
from pstools.rambo import generate, dot
from njettools.njet_interface import *
from nntools.model import Model

# Set up the NJet interface to ee->3j amplitudes #

In [None]:
# BLHA contract file through which NJet provides the squared amplitudes;
# it is handed to njet_init when the interface is started further down
contract_file = "NJ_contract_ee3j_tree.lh"

In [None]:
# number of training points (split again into NN train/test sets later on)
n_training_points = 10**5
# number of points kept aside for the interpolation tests after training
# NB - not the same thing as the training/validation split during training
n_test_points = 5*10**5
# both samples are generated together and sliced apart afterwards
n_points = sum((n_training_points, n_test_points))
# value forwarded to the phase-space generator as its `delta` argument
delta_cut = 2e-2

# Generate phase space points (RAMBO) and run NJet #

In [None]:
# generate 2 -> 3 phase-space points: n_points covers the training AND the
# test sample, which are sliced apart after the data have been dumped/reloaded
# n_trials is only used below to print the ratio n_points/n_trials
# NOTE(review): no random seed is set here — confirm whether `generate` is reproducible
momenta, n_trials = generate(3, n_points, rts=1000., delta=delta_cut)
# convert to nested Python lists; individual points are passed to the NJet interface below
momenta = momenta.tolist()

In [None]:
# number of entries (momenta) in a single phase-space point
n_legs = len(momenta[0])
# legs, requested points, and requested points per generator trial
print(n_legs, n_points, n_points/n_trials)

In [None]:
# check phase-space point satisfies momentum conservation
# `momenta` holds nested Python lists (it was converted with .tolist()), and a
# unary minus on a list raises TypeError, so do the arithmetic on an array
p = np.asarray(momenta[0])
# outgoing minus incoming momenta: every component should vanish up to rounding
-p[0]-p[1]+p[2]+p[3]+p[4]

In [None]:
# start the NJet interface: create the OLP handle, then read the contract file
olp = njet.OLP()
status = njet_init(contract_file)

# report the outcome of reading the contract (same comparison with True as before)
print("OLP read in correctly" if status == True
      else "seems to be a problem with the contract file...")

In [None]:
# scale `mur` and couplings `alphas`, `alpha`: passed as keyword arguments
# to every OLP_EvalSubProcess call below
mur, alphas, alpha = 100., 0.118, 1/137.

In [None]:
# demonstrating the evaluation of tree-level matrix element squared
# on the first two phase-space points; both calls share the same settings
eval_settings = dict(alphas=alphas, alpha=alpha, mur=mur, retlen=1)

testval0 = olp.OLP_EvalSubProcess(1, momenta[0], **eval_settings)
print(testval0)

testval1 = olp.OLP_EvalSubProcess(1, momenta[1], **eval_settings)
print(testval1)

In [None]:
# checking against analytic formula
# a,b -> 1q,2qb,3g

def amp0sq(alphas, p):
    """Analytic tree-level squared amplitude used to cross-check NJet.

    Momentum labelling follows a,b -> 1q,2qb,3g: p[0] and p[1] are a and b,
    p[2], p[3], p[4] are 1 (quark), 2 (anti-quark) and 3 (gluon).

    Parameters
    ----------
    alphas : float
        Coupling that multiplies the result linearly.
    p : sequence of five momenta
        Each entry must be a valid argument of `dot`.

    Returns
    -------
    The squared amplitude, including the overall factor `norm` chosen to
    match the NJet conventions.
    """
    # match some (not very sensible) conventions in NJet
    norm = (16*np.pi)**3/16

    Qu = 2./3. # fractional quark charge
    Nc = 3
    CF = (Nc**2-1)/(2.*Nc)

    # invariants appearing in the denominator
    sab = 2*dot(p[0],p[1])
    s13 = 2*dot(p[2],p[4])
    s23 = 2*dot(p[3],p[4])

    # invariants between the (anti-)quark and the two incoming legs
    s1a = -2*dot(p[2],p[0])
    s1b = -2*dot(p[2],p[1])
    s2a = -2*dot(p[3],p[0])
    s2b = -2*dot(p[3],p[1])

    numerator = s1a**2+s1b**2+s2a**2+s2b**2
    return Qu**2*norm*alphas*CF*Nc*numerator/(sab*s13*s23)

In [None]:
# compare the analytic formula with the NJet values for the first two points
for point, reference in ((momenta[0], testval0), (momenta[1], testval1)):
    analytic = amp0sq(alphas, point)
    print("|A|^2  =", analytic)
    print("ratio = ", analytic/reference[0])

    # blank separation is not printed: output is four consecutive lines as before

In [None]:
# evaluate every generated point with NJet and keep the first entry of each result
NJ_treevals = [
    olp.OLP_EvalSubProcess(1, momenta[pt], alphas=alphas, alpha=alpha, mur=mur, retlen=1)[0]
    for pt in range(n_points)
]

In [None]:
# dump generated data in case NJet and interface not available
# np.save does not create missing directories, so make sure data/ exists first
os.makedirs("data", exist_ok=True)
np.save("data/NJbasic_ee3j_tree_momenta.npy", momenta)
np.save("data/NJbasic_ee3j_tree_values.npy", NJ_treevals)

# Train NN with amplitude data #

In [None]:
# reload the dumped phase-space points and NJet values as numpy arrays
# (training and test points together; they are sliced apart below)
momentaALL = np.load("data/NJbasic_ee3j_tree_momenta.npy")
NJ_treevalsALL = np.load("data/NJbasic_ee3j_tree_values.npy")

In [None]:
# the first n_training_points entries form the training sample
training_slice = slice(None, n_training_points)
momenta = momentaALL[training_slice]
NJ_treevals = NJ_treevalsALL[training_slice]

In [None]:
# build the network wrapper from the training sample
# NOTE(review): 5*4 is hard-coded — presumably 5 legs x 4 momentum components; confirm
# it matches the shape of `momenta` if the process is changed
NN = Model(
    5*4, # train with all momenta components 
    momenta, # input data from Rambo PS generator
    np.array(NJ_treevals) # data points from NJet evaluations
)

In [None]:
# fit the network; besides the model, keep the x/y mean and std values, which are
# passed to process_testing_data and destandardise_data in the cells below
model, x_mean, x_std, y_mean, y_std = NN.fit(layers=[16,32,16], epoch_interval=100)

In [None]:
# first look at the network on the training sample itself
testmoms = momenta

In [None]:
# prepare the momenta for the network using the mean/std values returned by the fit
x_standardized = NN.process_testing_data(
    moms=testmoms, x_mean=x_mean, x_std=x_std, y_mean=y_mean, y_std=y_std
)

In [None]:
# network output for the prepared inputs
mpred = model.predict(x_standardized)

In [None]:
# flatten the prediction to 1-D and convert it back with the fit's mean/std values,
# so that it can be compared with the NJet values
amp_pred = NN.destandardise_data(mpred.reshape(-1),x_mean=x_mean,x_std=x_std,y_mean=y_mean,y_std=y_std)

In [None]:
# symmetric relative difference: (prediction - NJet) divided by the average of the two
# the factor 2 makes this agree with the definition used for the test sample and
# with the (-2, 2) window of the histograms
diff = 2.*(amp_pred-np.array(NJ_treevals))/(amp_pred+np.array(NJ_treevals))

In [None]:
# mean of the relative difference on the training sample
np.average(diff)

In [None]:
# largest relative difference; np.max reduces the array in one vectorised call
# (the builtin max iterates element by element and gives order-dependent results with NaN)
np.max(diff)

In [None]:
# smallest relative difference; np.min reduces the array in one vectorised call
# (the builtin min iterates element by element and gives order-dependent results with NaN)
np.min(diff)

In [None]:
# median of the relative difference on the training sample
np.median(diff)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.hist(diff, density=False, bins=500)
plt.xlim([-2,2])
plt.ylabel('Accuracy')
plt.xlabel('Data');

# Test the trained network #

In [None]:
# everything after the first n_training_points entries is the held-out test sample
test_slice = slice(n_training_points, None)
momenta_test = momentaALL[test_slice]
NJ_treevals_test = NJ_treevalsALL[test_slice]

In [None]:
# number of held-out test points
len(NJ_treevals_test)

In [None]:
# same chain as for the training sample, now applied to the held-out test points
# (this overwrites x_standardized, mpred, amp_pred and diff from the cells above)
x_standardized = NN.process_testing_data(moms=momenta_test,x_mean=x_mean,x_std=x_std,y_mean=y_mean,y_std=y_std)

mpred = model.predict(x_standardized)

amp_pred = NN.destandardise_data(mpred.reshape(-1),x_mean=x_mean,x_std=x_std,y_mean=y_mean,y_std=y_std)

# symmetric relative difference: (prediction - NJet) divided by the average of the two
diff = 2.*(amp_pred-np.array(NJ_treevals_test))/(amp_pred+np.array(NJ_treevals_test))

In [None]:
# 100 equal-width bins spanning exactly the displayed window
mybins = np.histogram_bin_edges(diff, bins=100, range=(-2,2))

# pass the edges to hist: with a plain bins=100 the bins are spread over the full
# data range, so a few outliers leave hardly any bins inside the window
plt.hist(diff, density=False, bins=mybins)
plt.xlim([-2,2])
# histogrammed quantity on the x axis, number of data points per bin on the y axis
plt.xlabel('Accuracy')
plt.ylabel('Data');

In [None]:
# log10 of the absolute relative difference (-inf wherever diff is exactly zero)
logdiff = np.log10(abs(diff))

# 100 equal-width bins over the same window that is displayed below
mybins = np.histogram_bin_edges(logdiff, bins=100, range=(-5,1))
# explicit edges are required here: automatic binning fails on a non-finite
# data range, and values outside the edges are simply left out
plt.hist(logdiff, density=False, bins=mybins)
plt.xlim([-5,1])
# histogrammed quantity on the x axis, number of data points per bin on the y axis
plt.xlabel('Log Accuracy')
plt.ylabel('Data');

In [None]:
# log10 of the ratio prediction / NJet value
logratio = np.log10(amp_pred/np.array(NJ_treevals_test))

# 100 equal-width bins over the same window that is displayed below
mybins = np.histogram_bin_edges(logratio, bins=100, range=(-1,1))
# use the edges so that all 100 bins lie inside the displayed window
plt.hist(logratio, density=False, bins=mybins)
plt.xlim([-1,1])
# histogrammed quantity on the x axis, number of data points per bin on the y axis
plt.xlabel('Log Accuracy')
plt.ylabel('Data');

# cross-section check #

In [None]:
# mean of the NJet values on the test sample and the Monte Carlo error on that mean,
# std/sqrt(N)  (the product std*mean used before is not an uncertainty estimate)
np.mean(NJ_treevals_test), np.std(NJ_treevals_test)/np.sqrt(len(NJ_treevals_test))

In [None]:
# mean of the network predictions on the test sample and the Monte Carlo error on
# that mean, std/sqrt(N)  (the product std*mean used before is not an uncertainty estimate)
np.mean(amp_pred), np.std(amp_pred)/np.sqrt(len(amp_pred))

In [None]:
def running_mean_std(values, n_prefixes):
    """Mean and standard deviation of values[0:i] for i = 1 .. n_prefixes.

    Returns an array of shape (n_prefixes, 2) with columns [mean, std]
    (population std, as np.std computes by default).  All prefixes are
    obtained from cumulative sums in one pass; calling np.mean/np.std on each
    prefix separately costs O(n^2), which is impractical for 5*10^5 points.
    The numbers agree with the per-prefix calls up to floating-point rounding.
    `values` is expected to hold at least n_prefixes entries.
    """
    values = np.asarray(values, dtype=float)[:max(n_prefixes, 0)]
    counts = np.arange(1, len(values)+1)
    # shift by a constant first: the std is unaffected and the subtraction
    # in the variance formula below loses far less precision
    centre = values.mean() if len(values) else 0.
    shifted = values - centre
    mean_shifted = np.cumsum(shifted)/counts
    variance = np.cumsum(shifted**2)/counts - mean_shifted**2
    # rounding can leave tiny negative variances
    std = np.sqrt(np.maximum(variance, 0.))
    return np.column_stack((mean_shifted + centre, std))

# running estimates over the first i test points, i = 1 .. n_test_points-1
# (same set of prefixes as range(1, n_test_points))
xs_NJ = running_mean_std(NJ_treevals_test, n_test_points-1)
xs_NN = running_mean_std(amp_pred, n_test_points-1)

In [None]:
# Running [mean, std] over the first i test points, i = 1 .. n_test_points-1.
# Cumulative sums give every prefix in one pass; calling np.mean/np.std on each
# prefix separately costs O(n^2), which is impractical for 5*10^5 points.
prefix_sizes = np.arange(1, int(n_test_points))

def _prefix_mean_std(sample):
    """Return an (n, 2) array with columns [mean, std] of sample[0:i], i = 1..n."""
    sample = np.asarray(sample, dtype=float)[:len(prefix_sizes)]
    counts = prefix_sizes[:len(sample)]
    # shift by a constant: the std is unchanged and the variance formula is better conditioned
    centre = sample.mean() if len(sample) else 0.
    shifted = sample - centre
    mean_shifted = np.cumsum(shifted)/counts
    variance = np.cumsum(shifted**2)/counts - mean_shifted**2
    # clip tiny negative values produced by rounding before taking the root
    return np.column_stack((mean_shifted + centre, np.sqrt(np.maximum(variance, 0.))))

# 2-D arrays instead of lists of [mean, std] pairs; np.array() on them is a plain copy
xs_NJ = _prefix_mean_std(NJ_treevals_test)
xs_NN = _prefix_mean_std(amp_pred)

In [None]:
# make sure both are 2-D arrays (column 0: running mean, column 1: running std)
# so that they can be sliced with [rows, column] below
xs_NJ = np.array(xs_NJ)
xs_NN = np.array(xs_NN)

In [None]:
# keep every 100th running mean (column 0) of the NJet values for plotting
plotdata1 = xs_NJ[0::100,0]
len(plotdata1)

In [None]:
# keep every 100th running mean (column 0) of the network predictions for plotting
plotdata2 = xs_NN[0::100,0]
len(plotdata2)

In [None]:
# every 100th running estimate was kept, so entry k is drawn at x = 100*k
plt.plot(100*np.arange(len(plotdata1)), plotdata1, 'b-', label='NJet')
plt.plot(100*np.arange(len(plotdata2)), plotdata2, 'r-', label='NN')
# the running estimates only cover the test sample, so the axis ends at
# n_test_points (n_points also counts the training points and left part of the plot empty)
plt.xlim([0,n_test_points])
plt.ylim([0.040,0.055])
plt.ylabel('sigma')
plt.xlabel('iteration')
plt.text(n_test_points/10,0.052,'delta cut = '+str(delta_cut))
plt.legend()