<a href="https://colab.research.google.com/github/nickchak21/QuarkGluonClassifiers/blob/master/Executable_Colab_Notebooks/PFN_example_herwig.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install EnergyFlow (provides the PFN architecture and the quark/gluon jet datasets used below) and h5py.
!pip install energyflow
!pip install h5py

Collecting energyflow
[?25l  Downloading https://files.pythonhosted.org/packages/35/ba/f598bafbde78553b962dc1f693ef95365cc752ddbdb448856858093579eb/EnergyFlow-1.0.0-py2.py3-none-any.whl (679kB)
[K     |████████████████████████████████| 686kB 6.8MB/s 
[?25hCollecting h5py>=2.9.0
[?25l  Downloading https://files.pythonhosted.org/packages/60/06/cafdd44889200e5438b897388f3075b52a8ef01f28a17366d91de0fa2d05/h5py-2.10.0-cp36-cp36m-manylinux1_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 50.0MB/s 
Installing collected packages: h5py, energyflow
  Found existing installation: h5py 2.8.0
    Uninstalling h5py-2.8.0:
      Successfully uninstalled h5py-2.8.0
Successfully installed energyflow-1.0.0 h5py-2.10.0


In [2]:
# Record the interpreter that `!python` resolves to; the example script below is run with it.
!python --version

Python 3.6.8


In [3]:
# NOTE(review): POT is not imported by the PFN script written below; presumably it is only
# needed for other EnergyFlow utilities (e.g. EMD) — confirm before keeping this install.
!pip install POT

Collecting POT
[?25l  Downloading https://files.pythonhosted.org/packages/15/36/07d3c0960a590b88b81fa1837e666cc7479b90c7e9fd1063024ce9331122/POT-0.6.0-cp36-cp36m-manylinux1_x86_64.whl (305kB)
[K     |█                               | 10kB 17.2MB/s eta 0:00:01[K     |██▏                             | 20kB 4.3MB/s eta 0:00:01[K     |███▏                            | 30kB 6.1MB/s eta 0:00:01[K     |████▎                           | 40kB 7.7MB/s eta 0:00:01[K     |█████▍                          | 51kB 5.0MB/s eta 0:00:01[K     |██████▍                         | 61kB 5.6MB/s eta 0:00:01[K     |███████▌                        | 71kB 6.3MB/s eta 0:00:01[K     |████████▋                       | 81kB 6.9MB/s eta 0:00:01[K     |█████████▋                      | 92kB 7.6MB/s eta 0:00:01[K     |██████████▊                     | 102kB 6.5MB/s eta 0:00:01[K     |███████████▉                    | 112kB 6.5MB/s eta 0:00:01[K     |████████████▉                   | 122kB 6.5MB/

In [4]:
# Download EnergyFlow's bundled example scripts into /root/.energyflow/examples.
!python -c "import energyflow; energyflow.utils.get_examples()"

Downloading efp_example.py from https://github.com/pkomiske/EnergyFlow/raw/master/examples/efp_example.py to /root/.energyflow/examples
Downloading dnn_example.py from https://github.com/pkomiske/EnergyFlow/raw/master/examples/dnn_example.py to /root/.energyflow/examples
Downloading cnn_example.py from https://github.com/pkomiske/EnergyFlow/raw/master/examples/cnn_example.py to /root/.energyflow/examples
Downloading efn_example.py from https://github.com/pkomiske/EnergyFlow/raw/master/examples/efn_example.py to /root/.energyflow/examples
Downloading pfn_example.py from https://github.com/pkomiske/EnergyFlow/raw/master/examples/pfn_example.py to /root/.energyflow/examples

Summary of examples:
efp_example.py exists at /root/.energyflow/examples
dnn_example.py exists at /root/.energyflow/examples
cnn_example.py exists at /root/.energyflow/examples
efn_example.py exists at /root/.energyflow/examples
pfn_example.py exists at /root/.energyflow/examples



In [0]:
# Show the stock PFN example for reference before replacing it with the Herwig variant.
%pycat /root/.energyflow/examples/pfn_example.py

In [0]:
# Remove the stock example so the %%writefile cell below writes a fresh copy
# (%%writefile without -a would overwrite it anyway).
rm /root/.energyflow/examples/pfn_example.py

In [7]:
%%writefile /root/.energyflow/examples/pfn_example.py
"""An example involving Particle Flow Networks (PFNs), which were
introduced in [1810.05165](https://arxiv.org/abs/1810.05165). The
[`PFN`](../docs/archs/#pfn) class is used to construct the
network architecture. The output of the example is a plot of the
ROC curves obtained by the PFN as well as the jet mass and
constituent multiplicity observables.

This is a modified copy of EnergyFlow's stock `pfn_example.py` that
trains on the Herwig-generated quark/gluon jet samples
(`qg_jets.load(..., generator='herwig')`) rather than the loader's
default generator.
"""

# standard library imports
from __future__ import absolute_import, division, print_function

# standard numerical library imports
import numpy as np

# energyflow imports
import energyflow as ef
from energyflow.archs import PFN
from energyflow.datasets import qg_jets
from energyflow.utils import data_split, remap_pids, to_categorical

# attempt to import sklearn; without it the ROC/AUC evaluation below is skipped
# (`roc_curve` doubles as the "sklearn available" flag)
try:
    from sklearn.metrics import roc_auc_score, roc_curve
except ImportError:
    print('please install scikit-learn in order to make ROC curves')
    roc_curve = False

# attempt to import matplotlib; without it the plotting below is skipped
# (`plt` doubles as the "matplotlib available" flag).
# Catch Exception rather than only ImportError: configuring a backend in a
# headless environment may raise other errors, and plotting is best-effort.
# A bare `except:` would also have swallowed KeyboardInterrupt/SystemExit.
try:
    import matplotlib.pyplot as plt
except Exception:
    print('please install matplotlib in order to make plots')
    plt = False

# ================================== SETTINGS ==================================
# Trailing "paper:" comments give the values used in 1810.05165.
# ==============================================================================

# number of jets per split; the full dataset can go up to 2000000 jets
train = 675000   # paper: 1000000
val = 90000      # paper: 200000
test = 135000    # paper: 200000

# keep the particle-id column as an extra per-particle feature
use_pids = True

# layer widths of the per-particle (Phi) and jet-level (F) networks
Phi_sizes = (100, 100, 128)   # paper: (100, 100, 256)
F_sizes = (100, 100, 100)     # paper: (100, 100, 100)

# optimisation
num_epoch = 20
batch_size = 500

# ==============================================================================

# ---- load the Herwig quark/gluon jets ----------------------------------------
# presumably returns the first train + val + test jets and their 0/1 labels
n_jets = train + val + test
X, y = qg_jets.load(n_jets, generator='herwig')

# two-column one-hot version of the labels
Y = to_categorical(y, num_classes=2)

print('Loaded quark and gluon jets')

# ---- preprocessing (in place) -----------------------------------------------
# Column 0 is pt and columns 1:3 are (y, phi). Each jet is shifted to its
# pt-weighted (y, phi) centroid, and its pts are divided by the jet's total pt.
# Rows with pt == 0 (presumably padding) are not modified.
for jet in X:
    real = jet[:, 0] > 0
    centroid = np.average(jet[real, 1:3], weights=jet[real, 0], axis=0)
    jet[real, 1:3] -= centroid
    jet[real, 0] /= jet[:, 0].sum()

# Column 3 carries particle ids: drop it, or remap it with EnergyFlow's helper
# (the return value is unused, so remap_pids presumably works in place).
if not use_pids:
    X = X[:, :, :3]
else:
    remap_pids(X, pid_i=3)

print('Finished preprocessing')

# ---- shuffle and partition into train / validation / test -------------------
X_train, X_val, X_test, Y_train, Y_val, Y_test = data_split(X, Y, val=val, test=test)

print('Done train/val/test split')
print('Model summary:')

# ---- build the PFN ------------------------------------------------------------
# per-particle input width = number of feature columns kept during preprocessing
# (3 if the pid column was dropped)
n_features = X.shape[-1]
pfn = PFN(input_dim=n_features, Phi_sizes=Phi_sizes, F_sizes=F_sizes)

# ---- train, monitoring the validation split ---------------------------------
fit_kwargs = dict(epochs=num_epoch,
                  batch_size=batch_size,
                  validation_data=(X_val, Y_val),
                  verbose=1)
pfn.fit(X_train, Y_train, **fit_kwargs)

# get predictions on test data
preds = pfn.predict(X_test, batch_size=1000)

# get ROC curve if we have sklearn
if roc_curve:
    # column 1 of the one-hot labels is the positive ("quark", per the axis
    # label below) class; the PFN's column-1 output is its score
    pfn_fp, pfn_tp, _ = roc_curve(Y_test[:,1], preds[:,1])

    # get area under the ROC curve
    auc = roc_auc_score(Y_test[:,1], preds[:,1])
    print()
    print('PFN AUC:', auc)
    print()

    # make ROC curve plot if we have matplotlib
    if plt:

        # Baseline observables, computed over the full dataset X/Y (they need no
        # training), whereas the PFN curve above uses only the test split.
        # Only (pt, y, phi) of real (pt > 0) particles go into the four-momenta:
        # with use_pids=True column 3 holds remapped particle ids, and
        # p4s_from_ptyphims would read a fourth column as particle masses,
        # corrupting the jet mass.
        masses = np.asarray([ef.ms_from_p4s(ef.p4s_from_ptyphims(x[x[:,0] > 0, :3]).sum(axis=0))
                             for x in X])
        mults = np.asarray([np.count_nonzero(x[:,0]) for x in X])

        # negate so that lighter / lower-multiplicity jets score as more quark-like
        mass_fp, mass_tp, _ = roc_curve(Y[:,1], -masses)
        mult_fp, mult_tp, _ = roc_curve(Y[:,1], -mults)

        # some nicer plot settings
        plt.rcParams['figure.figsize'] = (4,4)
        plt.rcParams['font.family'] = 'serif'
        plt.rcParams['figure.autolayout'] = True

        # plot the ROC curves
        plt.plot(pfn_tp, 1-pfn_fp, '-', color='black', label='PFN')
        plt.plot(mass_tp, 1-mass_fp, '-', color='blue', label='Jet Mass')
        plt.plot(mult_tp, 1-mult_fp, '-', color='red', label='Multiplicity')

        # axes labels
        plt.xlabel('Quark Jet Efficiency')
        plt.ylabel('Gluon Jet Rejection')

        # axes limits
        plt.xlim(0, 1)
        plt.ylim(0, 1)

        # make legend; save the figure to the working directory as well, since
        # when this script runs via `!python` in a subprocess, plt.show() alone
        # leaves no figure in the notebook
        plt.legend(loc='lower left', frameon=False)
        plt.savefig('pfn_roc_herwig.png')
        plt.show()

Writing /root/.energyflow/examples/pfn_example.py


In [8]:
# Run the Herwig PFN example written above; the Herwig QG jet files are downloaded on first use.
!python /root/.energyflow/examples/pfn_example.py

Using TensorFlow backend.
Downloading QG_jets_herwig_0.npz from https://www.dropbox.com/s/xizexr2tjq2bm59/QG_jets_herwig_0.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_1.npz from https://www.dropbox.com/s/ym675q2ui3ik3n9/QG_jets_herwig_1.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_2.npz from https://www.dropbox.com/s/qic6ejl27y6vpqj/QG_jets_herwig_2.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_3.npz from https://www.dropbox.com/s/ea5a9wruo7sf3zy/QG_jets_herwig_3.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_4.npz from https://www.dropbox.com/s/5iz5q2pjcys74tb/QG_jets_herwig_4.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_5.npz from https://www.dropbox.com/s/6zha7fka0dl7t30/QG_jets_herwig_5.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_herwig_6.npz from https://www.dropbox.com/s/vljp5nhoocv2zmf/QG_jets_herwig_6.npz?dl=1 to /root/.energyflow/datasets
Downloading QG_jets_he