In [1]:
import tensorflow as tf
import awkward as ak
import numpy as np
import pickle
import glob
import sys
import os

# Make the repository root (the parent of the notebook's working directory)
# importable so that `src.*` modules resolve.
parent_dir = os.path.dirname(os.getcwd())
# Guarded so that re-running this cell does not keep appending duplicates.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
# Directory holding the preprocessed TFRecord shards and their metadata.
data_dir = '/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev'
# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# the positional train/test/val split below is the same on every run.
record_files = sorted(glob.glob(os.path.join(data_dir, '*.tfrecords')))

In [3]:
# --- Training schedule ------------------------------------------------------
epochs = 10
batch_size = 256
shuffle_buffer = 64  # buffer size handed to Dataset.shuffle
optimizer = 'adam'
lr = 1e-3
loss = 'mean_absolute_error'

# --- Network architecture ---------------------------------------------------
activation = 'relu'
initializer = 'he_normal'
pooling = 'average'  # 'max' selects max pooling; any other value selects mean pooling
batch_norm = False
shortcut = False
dropout = 0  # a falsy value disables the Dropout layers in the dense head
K = 16  # number of neighbours gathered per point
# One entry per EdgeConv block; each entry lists the widths of that block's
# 1x1 convolutions (three equally wide convolutions per block here).
channels = [[width] * 3 for width in (64, 128, 256)]
units = [128] * 2  # widths of the dense layers in the head

# --- Dataset split (fractions of the record files) --------------------------
train_size = 0.6
test_size = 0.2
val_size = 0.2  # the split cell gives validation whatever files remain

# Number of constituents each jet is padded or truncated to.
num_points = 100

In [4]:
# Jet-level feature names.
jet_numerical = [
    'log_pt', 'eta', 'mass', 'phi', 'area',
    'qgl_axis2', 'qgl_ptD', 'qgl_mult',
]
jet_categorical = ['puId', 'partonFlavour']

# Constituent-level feature names. The first three must stay in this order:
# select_features reads column 0 to build the mask and columns 1-2 as the
# point coordinates.
mandatory_pf = ['rel_pt', 'rel_eta', 'rel_phi']
pf_numerical = [
    *mandatory_pf,
    'd0', 'dz', 'd0Err', 'dzErr',
    'trkChi2', 'vtxChi2',
    'puppiWeight', 'puppiWeightNoLep',
]
pf_categorical = ['charge', 'lostInnerHits', 'pdgId', 'pvAssocQuality', 'trkQuality']

In [5]:
# Full feature lists: numerical features first, categorical ones after.
jet_fields = [*jet_numerical, *jet_categorical]
pf_fields = [*pf_numerical, *pf_categorical]

# Record keys are the field names prefixed by their collection.
jet_keys = ['jet_' + field for field in jet_fields]
pf_keys = ['pf_' + field for field in pf_fields]

# Feature counts; these set the model's input widths.
num_jet, num_pf = len(jet_keys), len(pf_keys)

In [6]:
# Feature specification saved alongside the records; parse_record passes it to
# tf.io.parse_single_example as `features`.
# NOTE(review): pickle.load can execute arbitrary code, so data_dir must only
# ever point at a trusted location.
metadata_path = os.path.join(data_dir, 'metadata.pkl')
with open(metadata_path, 'rb') as f:
    metadata = pickle.load(f)

In [7]:
# Split at file granularity: the first `train_size` fraction of the files is
# used for training, the next `test_size` fraction for testing, and whatever
# is left for validation. int() truncates, so the fractions are rounded down.
num_files = len(record_files)
train_split = int(train_size * num_files)
test_split = train_split + int(test_size * num_files)

train_files, test_files, val_files = (
    record_files[:train_split],
    record_files[train_split:test_split],
    record_files[test_split:],
)

In [8]:
def parse_record(example_proto):
    """Parse one serialized example using the module-level ``metadata`` spec.

    Parameters
    ----------
    example_proto
        A single serialized record, as yielded by the TFRecord dataset.

    Returns
    -------
    The value returned by ``tf.io.parse_single_example`` for the features
    described in ``metadata``.
    """
    parsed = tf.io.parse_single_example(example_proto, features=metadata)
    return parsed

In [9]:
def select_features(batch):
    """Assemble the model inputs and the target from a batch of parsed records.

    Parameters
    ----------
    batch : dict
        Parsed features keyed by name. Must contain every key in ``jet_keys``
        and ``pf_keys`` as well as ``'row_lengths'`` and ``'target'``.
        Assumes the ``pf_*`` entries expose a flat ``.values`` tensor
        (sparse or ragged) -- TODO confirm against ``metadata``.

    Returns
    -------
    tuple
        ``((pf_data, jet_data, points, coord_shift, mask), target)`` where
        ``pf_data`` is dense with shape ``(batch, num_points, num_pf)``,
        ``mask`` is 1.0 where the first constituent feature is non-zero and
        0.0 elsewhere, ``coord_shift`` is 1e6 where ``mask`` is zero, and
        ``points`` holds constituent feature columns 1 and 2.
    """
    jet_columns = [batch[key] for key in jet_keys]
    jet_data = tf.stack(jet_columns, axis=1)

    # Stack the flat per-constituent values into one (total, num_pf) matrix,
    # then regroup the rows per jet using the stored row lengths.
    pf_columns = [batch[key].values for key in pf_keys]
    flat_pf = tf.stack(pf_columns, axis=1)
    ragged_pf = tf.RaggedTensor.from_row_lengths(flat_pf, row_lengths=batch['row_lengths'])
    # Pad / truncate every jet to exactly num_points constituents.
    pf_data = ragged_pf.to_tensor(shape=(None, num_points, num_pf))

    # A constituent counts as valid when its first feature is non-zero.
    is_valid = tf.math.not_equal(pf_data[:, :, 0:1], 0)
    mask = tf.cast(is_valid, dtype=tf.float32)

    # Large offset for the masked-out slots; the model adds it to the
    # coordinates used for the neighbour search.
    is_padding = tf.math.equal(mask, 0)
    coord_shift = tf.multiply(1e6, tf.cast(is_padding, dtype=tf.float32))

    # Columns 1 and 2 of the constituent features serve as point coordinates.
    points = pf_data[:, :, 1:3]

    target = batch['target']
    return (pf_data, jet_data, points, coord_shift, mask), target

In [10]:
def create_dataset(paths):
    """Build the batched input pipeline for a set of TFRecord files.

    Parameters
    ----------
    paths : list of str
        TFRecord files to read. Only these files are read, so the train,
        validation and test pipelines stay disjoint when given disjoint lists.

    Returns
    -------
    tf.data.Dataset
        Yields ``(inputs, target)`` batches of up to ``batch_size`` records,
        with ``inputs`` laid out as produced by ``select_features``.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # Read from the files passed in, not from the global list of all records.
    ds = tf.data.TFRecordDataset(filenames=paths, num_parallel_reads=autotune)
    ds = ds.map(parse_record, num_parallel_calls=autotune)
    # Batch before select_features: that function operates on whole batches.
    ds = ds.batch(batch_size)
    ds = ds.map(select_features, num_parallel_calls=autotune)
    ds = ds.prefetch(autotune)
    return ds

In [11]:
# One pipeline per file list. create_dataset returns already-batched elements,
# so the shuffle on the training pipeline reorders whole batches through a
# buffer of `shuffle_buffer` batches.
train_ds = create_dataset(train_files)
train_ds = train_ds.shuffle(shuffle_buffer)
test_ds = create_dataset(test_files)
val_ds = create_dataset(val_files)

In [13]:
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv2D, Dense, Dropout, Layer, Multiply, Concatenate
from src.layers import Mean, Max, Expand, Squeeze

In [14]:
def get_particle_net():
    """
    Build the ParticleNet-style Keras model and print its summary.

    ParticleNet: Jet Tagging via Particle Clouds
    arxiv.org/abs/1902.08570

    Takes no arguments; the input shapes come from the module-level
    ``num_points``, ``num_pf`` and ``num_jet``.

    Returns
    -------
    tensorflow.keras.Model
        Model whose inputs are ordered
        ``[features, globals, points, coord_shift, mask]`` and whose output is
        the tensor returned by ``particle_net_base``.
    """

    pf_in = Input(name='features', shape=(num_points, num_pf))
    jet_in = Input(name='globals', shape=(num_jet,))
    points_in = Input(name='points', shape=(num_points, 2))
    shift_in = Input(name='coord_shift', shape=(num_points, 1))
    mask_in = Input(name='mask', shape=(num_points, 1))

    prediction = particle_net_base(points_in, pf_in, mask_in, shift_in, jet_in)

    # Input order matches the tuple produced by select_features.
    net = Model(
        inputs=[pf_in, jet_in, points_in, shift_in, mask_in],
        outputs=prediction,
    )
    net.summary()

    return net


def particle_net_base(points, features, mask, coord_shift, globals):
    """
    Wire the EdgeConv blocks, masked pooling and dense head together.

    points : (N, P, C_coord)
        Coordinates used for the neighbour search of the first block.
    features : (N, P, C_features)
        Per-point features fed to the first block.
    mask : (N, P, 1)
        Multiplied onto the per-point features before pooling.
    coord_shift : (N, P, 1)
        Added to whatever tensor is used for the neighbour search.
    globals : (N, num_jet)
        Jet-level features concatenated to the pooled point features.

    Returns the output of the final ``Dense(1)`` layer, shape (N, 1).
    """

    # fts = tf.squeeze(BatchNormalization(name='fts_bn')(tf.expand_dims(features, axis=2)), axis=2)
    fts = features
    for block_idx, block_channels in enumerate(channels, start=1):
        # The first block searches neighbours in coordinate space, later
        # blocks in the feature space produced by the previous block.
        if block_idx == 1:
            knn_space = points
        else:
            knn_space = fts
        pts = Add(name=f'add_{block_idx}')([coord_shift, knn_space])
        fts = edge_conv(pts, fts, num_points, block_channels, name=f'edge_conv_{block_idx}')

    # Zero out the masked points, then pool over the point axis.
    masked_fts = Multiply()([fts, mask])
    pooled = Mean(axis=1)(masked_fts) # (N, C)

    x = Concatenate(name='head')([pooled, globals])

    for width in units:
        x = Dense(width)(x)
        x = Activation(activation)(x)
        if dropout:
            x = Dropout(dropout)(x)

    return Dense(1, name='out')(x) # (N, 1)


def edge_conv(points, features, num_points, sub_channels, name):
    """EdgeConv block.

    Args:
        points: (N, P, C_p) tensor the nearest-neighbour search runs on.
        features: (N, P, C_0) per-point features.
        num_points: number of points P, forwarded to KNearestNeighbors.
        sub_channels: sequence of output channel counts, one 1x1 convolution each.
        name: prefix for the names of all layers created here.

    Reads the module-level settings ``K``, ``batch_norm``, ``activation``,
    ``initializer``, ``pooling`` and ``shortcut``.

    Returns:
        transformed features: (N, P, C_out), C_out = sub_channels[-1]
    """

    # Settings shared by every 1x1 convolution of this block. The bias is
    # dropped whenever a batch-norm layer follows the convolution.
    conv_kwargs = dict(
        kernel_size=(1, 1),
        strides=1,
        data_format='channels_last',
        use_bias=not batch_norm,
        kernel_initializer=initializer,
    )

    x = KNearestNeighbors(num_points, K, name=f'{name}_knn')([points, features])

    for idx, n_filters in enumerate(sub_channels, start=1):
        x = Conv2D(n_filters, name=f'{name}_conv_{idx}', **conv_kwargs)(x)
        if batch_norm:
            x = BatchNormalization(name=f'{name}_batchnorm_{idx}')(x)
        if activation:
            x = Activation(activation, name=f'{name}_activation_{idx}')(x)

    # Collapse the neighbour axis.
    if pooling == 'max':
        pooled = Max(axis=2, name=f'{name}_max')(x) # (N, P, C')
    else:
        pooled = Mean(axis=2, name=f'{name}_mean')(x) # (N, P, C')

    if not shortcut:
        out = pooled
    else:
        # Project the block input to C' channels and add it to the pooled output.
        sc = Expand(axis=2, name=f'{name}_shortcut_expand')(features)
        sc = Conv2D(sub_channels[-1], name=f'{name}_shortcut_conv', **conv_kwargs)(sc)
        if batch_norm:
            sc = BatchNormalization(name=f'{name}_shortcut_batchnorm')(sc)
        sc = Squeeze(axis=2, name=f'{name}_shortcut_squeeze')(sc)

        out = Add(name=f'{name}_add')([sc, pooled])

    # The closing activation is applied on both the shortcut and plain paths.
    return Activation(activation, name=f'{name}_activation')(out) # (N, P, C')


class KNearestNeighbors(Layer):
    """Gather, for every point, edge features built from its k nearest neighbours.

    Parameters
    ----------
    num_points : int
        Number of points P per example; used to tile the batch indices.
    k : int
        Number of neighbours gathered per point.
    **kwargs
        Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_points, k, **kwargs):
        super().__init__(**kwargs)
        self.num_points = num_points
        self.k = k

    def call(self, inputs):
        """Compute edge features from ``[points, features]``.

        ``points`` (N, P, C_p) is used for the distance computation and
        ``features`` (N, P, C) is what gets gathered. Returns a tensor of
        shape (N, P, K, 2*C) holding, for each point and neighbour, the
        point's own features concatenated with (neighbour - point).
        """
        points, features = inputs
        # distance
        D = batch_distance_matrix_general(points, points) # (N, P, P)
        # top_k on the negated distances selects the smallest ones. k + 1 are
        # requested and the first is dropped -- presumably the query point
        # itself at distance zero; TODO confirm for coincident points.
        _, top_k_indices = tf.math.top_k(-D, k=self.k + 1) # (N, P, K+1)
        top_k_indices = top_k_indices[:, :, 1:] # (N, P, K)

        # Pair every neighbour index with its batch index for gather_nd.
        n_batch = tf.shape(features)[0]
        batch_indices = tf.tile(tf.reshape(tf.range(n_batch), (-1, 1, 1, 1)), (1, self.num_points, self.k, 1))
        indices = tf.concat([batch_indices, tf.expand_dims(top_k_indices, axis=3)], axis=3) # (N, P, K, 2)

        knn_fts = tf.gather_nd(features, indices) # (N, P, K, C)
        knn_fts_center = tf.tile(tf.expand_dims(features, axis=2), (1, 1, self.k, 1)) # (N, P, K, C)

        return tf.concat([knn_fts_center, tf.subtract(knn_fts, knn_fts_center)], axis=-1) # (N, P, K, 2*C)

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Needed so that a model containing this layer can be serialised and
        re-created from its config.
        """
        config = super().get_config()
        config.update({'num_points': self.num_points, 'k': self.k})
        return config


def batch_distance_matrix_general(A, B):
    """Pairwise squared Euclidean distances between two batched point sets.

    A has shape (N, P_A, C) and B has shape (N, P_B, C). The result D has
    shape (N, P_A, P_B), with ``D[n, i, j] = |A[n, i]|^2 - 2 A[n, i].B[n, j]
    + |B[n, j]|^2``.
    """
    sq_norm_a = tf.math.reduce_sum(A * A, axis=2, keepdims=True)  # (N, P_A, 1)
    sq_norm_b = tf.math.reduce_sum(B * B, axis=2, keepdims=True)  # (N, P_B, 1)
    b_transposed = tf.transpose(B, perm=(0, 2, 1))  # (N, C, P_B)
    cross = tf.linalg.matmul(A, b_transposed)  # (N, P_A, P_B)
    # Broadcasting expands the two norm terms over the missing point axis.
    return sq_norm_a - 2 * cross + tf.transpose(sq_norm_b, perm=(0, 2, 1))

In [15]:
# Build the network (this also prints the model summary) and compile it with
# the optimizer and loss chosen in the configuration cell.
dnn = get_particle_net()
dnn.compile(optimizer=optimizer, loss=loss)
# The optimizer was created from its string name with default settings, so the
# configured learning rate is assigned afterwards. `learning_rate` is the
# canonical attribute; `lr` is only a deprecated alias for it.
dnn.optimizer.learning_rate.assign(lr)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
coord_shift (InputLayer)        [(None, 100, 1)]     0                                            
__________________________________________________________________________________________________
points (InputLayer)             [(None, 100, 2)]     0                                            
__________________________________________________________________________________________________
add_1 (Add)                     (None, 100, 2)       0           coord_shift[0][0]                
                                                                 points[0][0]                     
__________________________________________________________________________________________________
features (InputLayer)           [(None, 100, 16)]    0                                        

<tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=0.001>

In [17]:
# fit = dnn.fit(train_ds, validation_data=val_ds, epochs=epochs)