In [None]:
import tensorflow as tf
import awkward as ak
import numpy as np
import glob
import os

In [None]:
# Root directory of the dataset and the sub-directory with the preprocessed
# parquet output. Every entry directly under parquet_dir is later globbed and
# treated as one unit for the train/test/validation split.
data_dir = '/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn'
parquet_dir = os.path.join(data_dir, 'preprocessed/dev')

In [None]:
# --- Training settings ---
# epochs, loss and optimizer are passed to Keras compile()/fit() below;
# lr is assigned to the optimizer after compile().
epochs = 10
batch_size = 256
loss = 'mean_absolute_error'
optimizer = 'adam'
lr = 1.e-3

# --- Network settings, read by mlp() ---
# batch_norm and dropout are used as truthiness switches: False / 0 means the
# corresponding layer is not added at all.
activation = 'relu'
initializer = 'he_normal'
batch_norm = False
dropout = 0
# One Dense layer is created per entry; the value is the layer width.
units = [128, 128]

# --- Dataset split fractions (applied to the list of input directories) ---
# NOTE(review): val_size is not read by the split cell; validation receives
# whatever remains after the train and test slices.
train_size = 0.6
test_size = 0.2
val_size = 0.2

In [None]:
# Field names read from the jet table (jet.parquet) in read_parquet().
# Numerical and categorical fields are listed separately, but both groups are
# cast to float32 by read_parquet() without any further encoding.
jet_numerical = ['pt_log', 'eta', 'mass', 'phi', 'area', 'qgl_axis2', 'qgl_ptD', 'qgl_mult']
jet_categorical = ['puId', 'partonFlavour']

# Field names read from the pf table (pf.parquet), which holds a
# variable-length list of entries per row.
pf_numerical = ['rel_pt', 'rel_eta', 'rel_phi', 'd0', 'dz', 'd0Err', 'dzErr', 'trkChi2', 'vtxChi2', 'puppiWeight', 'puppiWeightNoLep']
pf_categorical = ['charge', 'lostInnerHits', 'pdgId', 'pvAssocQuality', 'trkQuality']

In [None]:
# Full, ordered field lists: numerical fields first, categorical fields last.
jet_fields = [*jet_numerical, *jet_categorical]
pf_fields = [*pf_numerical, *pf_categorical]

# Prefixed keys keep jet and pf columns apart once they share one namespace.
jet_keys = ['jet_' + name for name in jet_fields]
pf_keys = ['pf_' + name for name in pf_fields]

# Feature counts; these fix the last dimension of the two model inputs.
num_jet, num_pf = len(jet_keys), len(pf_keys)

In [None]:
# glob returns matches in arbitrary, filesystem-dependent order. Sort the
# listing so the same directories land in the same split on every run.
dirs = sorted(glob.glob(os.path.join(parquet_dir, '*')))
num_dirs = len(dirs)

# Split boundaries are indices into the sorted directory list.
train_split = int(train_size * num_dirs)
test_split = int(test_size * num_dirs) + train_split

train_dirs = dirs[:train_split]
test_dirs = dirs[train_split:test_split]
# Validation receives everything left after the train and test slices.
val_dirs = dirs[test_split:]

In [None]:
# Display the directories assigned to the training split as a sanity check.
train_dirs

In [None]:
def read_parquet(path):
    """Load one preprocessed directory into a flat list of NumPy arrays.

    Parameters
    ----------
    path : bytes
        Directory containing ``jet.parquet`` and ``pf.parquet``. It is decoded
        to ``str`` here.

    Returns
    -------
    list of numpy.ndarray
        In order: the number of pf entries per row (int32), the jet ``target``
        (float32), one float32 array per name in ``jet_fields``, and one
        float32 array per name in ``pf_fields``. The pf arrays are flattened
        over axis 1, so the row lengths are needed to restore the per-row
        structure.
    """
    directory = path.decode()

    jets = ak.from_parquet(os.path.join(directory, 'jet.parquet'))
    constituents = ak.from_parquet(os.path.join(directory, 'pf.parquet'))

    def as_float32(array):
        # Every feature, categorical ones included, is handed over as float32.
        return ak.to_numpy(array).astype(np.float32)

    counts = ak.to_numpy(ak.num(constituents, axis=1)).astype(np.int32)
    flat_constituents = ak.flatten(constituents, axis=1)

    columns = [counts, as_float32(jets.target)]
    columns.extend(as_float32(jets[name]) for name in jet_fields)
    columns.extend(as_float32(flat_constituents[name]) for name in pf_fields)
    return columns

In [None]:
def read_parquet_wrapper(path):
    """Wrap ``read_parquet`` for ``tf.data`` and assemble the model inputs.

    Parameters
    ----------
    path : string tensor
        Directory to load; forwarded to ``read_parquet`` through
        ``tf.numpy_function``.

    Returns
    -------
    tuple
        ``((pf_data, jet_data), target)``. ``pf_data`` is a ragged tensor with
        the pf features concatenated on axis 2, ``jet_data`` is a dense tensor
        with the jet features concatenated on axis 1, and ``target`` is a 1-D
        tensor.
    """
    # Column order must mirror read_parquet: row lengths, target, jet, pf.
    out_types = [tf.int32, tf.float32] + [tf.float32] * (num_jet + num_pf)
    columns = tf.numpy_function(read_parquet, inp=[path], Tout=out_types)

    # numpy_function yields tensors of unknown shape; every column is 1-D.
    for column in columns:
        column.set_shape((None,))

    row_lengths, target = columns[0], columns[1]
    jet_columns = columns[2:2 + num_jet]
    pf_columns = columns[2 + num_jet:]

    # (None,) -> (None, 1) for each feature, then joined to (None, num_jet).
    jet_data = tf.concat(
        [tf.expand_dims(column, axis=1) for column in jet_columns],
        axis=1,
    )

    # (None,) -> ragged (None, None) -> (None, None, 1) for each feature,
    # then joined to (None, None, num_pf).
    pf_data = tf.concat(
        [
            tf.expand_dims(
                tf.RaggedTensor.from_row_lengths(column, row_lengths=row_lengths),
                axis=2,
            )
            for column in pf_columns
        ],
        axis=2,
    )

    return (pf_data, jet_data), target

In [None]:
def create_dataset(paths):
    """Build a batched ``tf.data.Dataset`` from a list of parquet directories.

    Each directory is loaded in one piece by ``read_parquet_wrapper``, split
    into individual examples, and regrouped into batches of the module-level
    ``batch_size``.

    Parameters
    ----------
    paths : sequence of str
        Directories that each contain ``jet.parquet`` and ``pf.parquet``.

    Returns
    -------
    tf.data.Dataset
        Yields ``((pf_data, jet_data), target)`` batches.
    """
    ds = tf.data.Dataset.from_tensor_slices(paths)
    ds = ds.map(read_parquet_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    # One map output holds a whole directory. Unbatch to single examples, then
    # rebatch with the configured size instead of a hard-coded 256.
    ds = ds.unbatch().batch(batch_size)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds

In [None]:
# create_dataset() returns an already-batched dataset, so this shuffle
# reorders whole batches within a 64-element buffer rather than single examples.
train_ds = create_dataset(train_dirs).shuffle(64)
# Validation and test data keep their original order.
val_ds = create_dataset(val_dirs)
test_ds = create_dataset(test_dirs)

In [None]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Activation, Dense, TimeDistributed, BatchNormalization, Dropout, Concatenate, Add
from src.layers import Sum

In [None]:
def get_deepset():
    """Build the regression model and print its summaries.

    A shared MLP is applied to every constituent through ``TimeDistributed``.
    Its outputs are reduced over axis 1 by the project ``Sum`` layer
    (presumably a sum over constituents; confirm in ``src.layers``),
    concatenated with the jet-level features, and passed through a second MLP
    and a single-unit linear output layer.

    Returns
    -------
    tensorflow.keras.Model
        Uncompiled model named ``'dnn'`` with inputs
        ``[constituents, globals]``.
    """
    # Per-constituent branch: ragged input plus a nested model for one slice.
    particle_in = Input(shape=(None, num_pf), ragged=True, name='constituents')
    slice_in = Input(shape=(particle_in.shape[-1],), name='constituents_slice')
    slice_out = mlp(slice_in, name='deepset')
    per_particle = Model(inputs=slice_in, outputs=slice_out, name='deepset_model_slice')

    embedded = TimeDistributed(per_particle, name='deepset_distributed')(particle_in)
    pooled = Sum(axis=1, name='constituents_head')(embedded)

    # Jet-level branch. The local is called jet_in so it does not shadow the
    # builtin globals(); the layer itself is still named 'globals'.
    jet_in = Input(shape=(num_jet,), name='globals')
    merged = Concatenate(name='head')([pooled, jet_in])

    hidden = mlp(merged, name='head')
    prediction = Dense(1, name='head_dense_output')(hidden)

    model = Model(inputs=[particle_in, jet_in], outputs=prediction, name='dnn')
    model.summary()

    # The outer summary does not expand the wrapped model, so print it too.
    wrapped = [lyr.layer for lyr in model.layers if isinstance(lyr, TimeDistributed)]
    for inner in wrapped:
        inner.summary()

    return model


def mlp(x, name):
    """Apply a stack of fully connected stages to ``x``.

    One stage is added per entry of the module-level ``units`` list: Dense,
    then optional BatchNormalization, then Activation, then optional Dropout.
    The optional layers are added only when ``batch_norm`` / ``dropout`` are
    truthy.

    Parameters
    ----------
    x : Keras tensor
        Input to the first Dense layer.
    name : str
        Prefix for every layer name, which keeps names unique across stacks.

    Returns
    -------
    Keras tensor
        Output of the last stage.
    """
    for idx, width in enumerate(units, start=1):
        stage = [Dense(width, kernel_initializer=initializer, name=f'{name}_dense_{idx}')]
        if batch_norm:
            stage.append(BatchNormalization(name=f'{name}_batch_normalization_{idx}'))
        stage.append(Activation(activation, name=f'{name}_activation_{idx}'))
        if dropout:
            stage.append(Dropout(dropout, name=f'{name}_dropout_{idx}'))

        for layer in stage:
            x = layer(x)
    return x

In [None]:
dnn = get_deepset()
dnn.compile(optimizer=optimizer, loss=loss)
# compile() receives the optimizer by name only, so the configured learning
# rate is applied afterwards. `learning_rate` is the canonical attribute;
# `lr` is only a deprecated alias.
dnn.optimizer.learning_rate.assign(lr)

In [None]:
# tf.keras.utils.plot_model(dnn, dpi=100, show_shapes=True, expand_nested=True)

In [None]:
# Train on the training batches and evaluate on val_ds after each epoch.
# The object returned by fit() is kept in `fit`.
fit = dnn.fit(train_ds, validation_data=val_ds, epochs=epochs)