In [None]:
import tensorflow as tf
import awkward as ak
import numpy as np
import pickle
import glob
import sys
import os

# Make the parent of the current working directory importable so that the
# project package `src` (imported further down) can be resolved.
parent_dir = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not pile up duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
# Directory holding the preprocessed TFRecord files and their metadata.
data_dir = '/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev'
# glob returns paths in arbitrary (filesystem-dependent) order; sort them so the
# positional train/test/val split below is reproducible across runs and machines.
record_files = sorted(glob.glob(os.path.join(data_dir, '*.tfrecords')))

In [None]:
# --- Training configuration -------------------------------------------------
# Number of passes over the training dataset (passed to `dnn.fit`).
epochs = 10
# Number of examples per batch (used by `create_dataset`).
batch_size = 256
# Buffer size handed to `.shuffle()` on the training dataset. The shuffle is
# applied to the already-batched dataset, so it counts batches, not examples.
shuffle_buffer = 64
# Loss and optimizer identifiers passed to `dnn.compile`.
loss = 'mean_absolute_error'
optimizer = 'adam'
# Learning rate assigned to the optimizer after compilation.
lr = 1.e-3

# --- Network configuration (read by `mlp`) ----------------------------------
activation = 'relu'
initializer = 'he_normal'
# When truthy, a BatchNormalization layer is inserted after each Dense layer.
batch_norm = False
# Dropout rate; a falsy value (0) disables the Dropout layers entirely.
dropout = 0
# Width of each Dense layer; one layer is created per entry.
units = [128, 128]

# --- Dataset split -----------------------------------------------------------
# Fractions of the record *files* assigned to each subset.
# NOTE(review): `val_size` is not read anywhere in the visible code; the
# validation set simply receives whatever files remain after train and test.
train_size = 0.6
test_size = 0.2
val_size = 0.2

In [None]:
# Jet-level feature names. They are later prefixed with 'jet_' to form the keys
# used to look features up in the parsed records.
jet_numerical = ['log_pt', 'eta', 'mass', 'phi', 'area', 'qgl_axis2', 'qgl_ptD', 'qgl_mult']
jet_categorical = ['puId', 'partonFlavour']

# Per-constituent feature names. They are later prefixed with 'pf_' to form the
# keys used to look features up in the parsed records.
# NOTE(review): numerical and categorical fields are concatenated and fed to the
# network without any visible encoding step — confirm this is intended.
pf_numerical = ['rel_pt', 'rel_eta', 'rel_phi', 'd0', 'dz', 'd0Err', 'dzErr', 'trkChi2', 'vtxChi2', 'puppiWeight', 'puppiWeightNoLep']
pf_categorical = ['charge', 'lostInnerHits', 'pdgId', 'pvAssocQuality', 'trkQuality']

In [None]:
# Full ordered field lists: numerical features first, categorical ones after.
jet_fields = [*jet_numerical, *jet_categorical]
pf_fields = [*pf_numerical, *pf_categorical]

# Record keys are the field names with a 'jet_' / 'pf_' prefix.
jet_keys = [f'jet_{name}' for name in jet_fields]
pf_keys = [f'pf_{name}' for name in pf_fields]

# Feature counts, used as the last dimension of the model inputs.
num_jet, num_pf = len(jet_keys), len(pf_keys)

In [None]:
# `metadata` is later passed as the `features` spec to `tf.io.parse_example`.
# NOTE(review): unpickling executes arbitrary code; only load this file from a
# trusted location.
metadata_path = os.path.join(data_dir, 'metadata.pkl')
with open(metadata_path, 'rb') as f:
    metadata = pickle.load(f)

In [None]:
# Split the record files positionally into train / test / validation subsets.
num_files = len(record_files)

# Slice boundaries: [0, train_split) -> train, [train_split, test_split) -> test,
# and everything from test_split onwards -> validation.
train_split = int(train_size * num_files)
test_split = train_split + int(test_size * num_files)

train_files, test_files, val_files = (
    record_files[:train_split],
    record_files[train_split:test_split],
    record_files[test_split:],
)

In [None]:
def parse_record(example_proto):
    """Decode serialized example(s) using the feature spec held in `metadata`.

    Args:
        example_proto: Serialized record(s) as yielded by the TFRecord dataset.

    Returns:
        The result of `tf.io.parse_example` for the given input.
    """
    parsed = tf.io.parse_example(example_proto, features=metadata)
    return parsed

In [None]:
def select_features(data):
    """Assemble model inputs and the regression target from a parsed record.

    Args:
        data: Mapping from feature key to a tensor exposing `.values`
            (presumably sparse tensors from variable-length features — confirm
            against the stored metadata).

    Returns:
        A tuple `((pf_data, jet_data), target)` where `pf_data` is a ragged
        tensor built from the flat constituent features and the record's
        'row_lengths', `jet_data` is a 2-D tensor with one column per jet key,
        and `target` holds the values stored under 'target'.
    """
    def stack_columns(keys):
        # Turn each feature's flat values into a column, then join them side by side.
        columns = [tf.expand_dims(data[key].values, axis=1) for key in keys]
        return tf.concat(columns, axis=1)

    jet_data = stack_columns(jet_keys)
    flat_pf = stack_columns(pf_keys)

    # Regroup the flat constituent rows using the stored row lengths.
    row_lengths = data['row_lengths'].values
    pf_data = tf.RaggedTensor.from_row_lengths(flat_pf, row_lengths=row_lengths)

    target = data['target'].values
    return (pf_data, jet_data), target

In [None]:
def create_dataset(paths):
    """Build a batched, prefetched `tf.data` pipeline from TFRecord files.

    Args:
        paths: List of TFRecord file paths to read.

    Returns:
        A `tf.data.Dataset` yielding `((pf_data, jet_data), target)` batches of
        `batch_size` examples.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # Read only the files handed in. The previous code ignored `paths` and read
    # every file in `record_files`, so the train, validation and test datasets
    # were identical and evaluation data leaked into training.
    ds = tf.data.TFRecordDataset(filenames=paths, num_parallel_reads=autotune)
    ds = ds.map(parse_record, num_parallel_calls=autotune)
    ds = ds.map(select_features, num_parallel_calls=autotune)
    # Each parsed record presumably holds several jets (TODO confirm); split them
    # into single examples and regroup into fixed-size batches.
    ds = ds.unbatch().batch(batch_size)
    ds = ds.prefetch(autotune)
    return ds

In [None]:
# Evaluation datasets are consumed in file order.
test_ds = create_dataset(test_files)
val_ds = create_dataset(val_files)

# Only the training dataset is shuffled. The shuffle acts on the batches that
# `create_dataset` returns, with a buffer of `shuffle_buffer` batches.
train_ds = create_dataset(train_files).shuffle(shuffle_buffer)

In [None]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Activation, Dense, TimeDistributed, BatchNormalization, Dropout, Concatenate, Add
from src.layers import Sum

In [None]:
def get_deepset():
    """Build the deep-set style regression model.

    A shared MLP is applied to every constituent of the ragged 'constituents'
    input, the per-constituent outputs are pooled with `Sum` over axis 1, the
    pooled vector is concatenated with the 'globals' jet features, and a second
    MLP followed by a single-unit Dense layer produces the output.

    Returns:
        A Keras `Model` named 'dnn' with inputs `[constituents, globals]` and a
        single-unit output. Summaries of the model and of the wrapped
        per-constituent sub-model are printed as a side effect.
    """
    # Ragged input with a variable number of constituents and num_pf features each.
    constituents = Input(shape=(None, num_pf), ragged=True, name='constituents')

    # Sub-model for a single constituent; TimeDistributed applies it to each one.
    constituents_slice = Input(shape=(constituents.shape[-1],), name='constituents_slice')
    deepset_outputs_slice = mlp(constituents_slice, name='deepset')
    deepset_model_slice = Model(inputs=constituents_slice, outputs=deepset_outputs_slice, name='deepset_model_slice')

    deepset_outputs = TimeDistributed(deepset_model_slice, name='deepset_distributed')(constituents)

    # Pool over the constituent axis (project-local `Sum` layer; presumably a
    # ragged-aware reduce-sum — confirm in src.layers).
    constituents_head = Sum(axis=1, name='constituents_head')(deepset_outputs)

    # Local name avoids shadowing the builtin `globals`; the layer name is unchanged.
    jet_globals = Input(shape=(num_jet,), name='globals')

    inputs_head = Concatenate(name='head')([constituents_head, jet_globals])

    x = mlp(inputs_head, name='head')

    outputs = Dense(1, name='head_dense_output')(x)

    model = Model(inputs=[constituents, jet_globals], outputs=outputs, name='dnn')

    model.summary()

    # The top-level summary does not expand wrapped models, so print them too.
    for layer in model.layers:
        if isinstance(layer, TimeDistributed):
            layer.layer.summary()

    return model


def mlp(x, name):
    """Stack Dense blocks on top of `x`, configured by the notebook settings.

    One block is created per entry in `units`: Dense, then optional
    BatchNormalization (if `batch_norm`), then Activation, then optional Dropout
    (if `dropout` is non-zero).

    Args:
        x: Input Keras tensor.
        name: Prefix for the layer names, which are numbered from 1.

    Returns:
        The output tensor of the last block.
    """
    for idx, n in enumerate(units, start=1):
        x = Dense(n, kernel_initializer=initializer, name=f'{name}_dense_{idx}')(x)
        if batch_norm:
            x = BatchNormalization(name=f'{name}_batch_normalization_{idx}')(x)
        x = Activation(activation, name=f'{name}_activation_{idx}')(x)
        if dropout:
            x = Dropout(dropout, name=f'{name}_dropout_{idx}')(x)
    return x

In [None]:
# Build the model and compile it with the configured optimizer and loss.
dnn = get_deepset()
dnn.compile(optimizer=optimizer, loss=loss)
# The optimizer is created from its string name with default settings, so set
# the learning rate afterwards. `learning_rate` is used because `lr` is a
# deprecated alias that newer Keras releases no longer provide.
dnn.optimizer.learning_rate.assign(lr)

In [None]:
# Train for `epochs` passes over `train_ds`, evaluating on `val_ds` after each
# epoch; `fit` keeps the object returned by `Model.fit`.
fit = dnn.fit(train_ds, validation_data=val_ds, epochs=epochs)