In [None]:
import glob
import os

import awkward as ak
import numpy as np
import tensorflow as tf

In [None]:
# Feature columns consumed by read_parquet: numerical fields are cast to
# float32 and categorical fields to int32.
jet_numerical = [
    'pt', 'eta', 'mass', 'phi', 'area',
    'qgl_axis2', 'qgl_ptD', 'qgl_mult',
]
jet_categorical = [
    'puId', 'partonFlavour',
]

pf_numerical = [
    'pt', 'eta', 'phi',
    'd0', 'dz', 'd0Err', 'dzErr',
    'trkChi2', 'vtxChi2',
    'puppiWeight', 'puppiWeightNoLep',
]
pf_categorical = [
    'charge', 'lostInnerHits', 'pdgId', 'pvAssocQuality', 'trkQuality',
]

In [None]:
# Root of the preprocessed inputs. Defaults to the shared EOS area; set the
# JEC_DNN_DATA_DIR environment variable to point at a local copy instead of
# editing the notebook.
data_dir = os.environ.get('JEC_DNN_DATA_DIR') or '/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn'

# One sub-directory per chunk, each holding jet.parquet and pf.parquet.
parquet_dir = os.path.join(data_dir, 'preprocessed/dev')

In [None]:
# Fractions of the preprocessed directories assigned to each split; the
# validation split receives whatever remains after train and test.
train_size = 0.6
test_size = 0.2
val_size = 0.2
assert abs(train_size + test_size + val_size - 1.0) < 1e-9, 'split fractions must sum to 1'

# glob returns entries in arbitrary filesystem order, so sort to make the
# split reproducible across runs/machines. Keep only directories, since
# read_parquet expects a directory containing jet.parquet and pf.parquet.
dirs = sorted(d for d in glob.glob(os.path.join(parquet_dir, '*')) if os.path.isdir(d))
if not dirs:
    raise FileNotFoundError(f'No preprocessed directories found under {parquet_dir!r}')

num_dirs = len(dirs)
train_split = int(train_size * num_dirs)
test_split = int(test_size * num_dirs) + train_split

train_dirs = dirs[:train_split]
test_dirs = dirs[train_split:test_split]
val_dirs = dirs[test_split:]

In [None]:
# Show the training split as a quick sanity check of the directory split above.
train_dirs

In [None]:
def read_parquet(path):
    """Load one preprocessed directory into a flat list of numpy arrays.

    Meant to be called through ``tf.numpy_function``, which hands string
    tensors over as ``bytes``; a plain ``str`` path is accepted as well so the
    function can also be called interactively.

    Args:
        path: directory (``bytes`` or ``str``) containing ``jet.parquet`` and
            ``pf.parquet``.

    Returns:
        list of numpy arrays in this fixed order, which must match the
        ``Tout`` / key order used by ``read_parquet_wrapper``:
          1. per-jet number of pf entries (int32)
          2. jet ``target`` (float32)
          3. each ``jet_numerical`` field (float32), then each
             ``jet_categorical`` field (int32)
          4. each ``pf_numerical`` field flattened over jets (float32), then
             each ``pf_categorical`` field flattened over jets (int32)
    """
    if isinstance(path, bytes):
        path = path.decode()

    jet = ak.from_parquet(os.path.join(path, 'jet.parquet'))
    pf = ak.from_parquet(os.path.join(path, 'pf.parquet'))

    # pf is nested per jet: keep the per-jet counts so the flattened pf
    # columns can be re-nested into ragged tensors later.
    row_lengths = ak.num(pf, axis=1)
    flat_pf = ak.flatten(pf, axis=1)

    data = [
        ak.to_numpy(row_lengths).astype(np.int32),
        ak.to_numpy(jet['target']).astype(np.float32),
    ]
    data.extend(ak.to_numpy(jet[field]).astype(np.float32) for field in jet_numerical)
    data.extend(ak.to_numpy(jet[field]).astype(np.int32) for field in jet_categorical)
    data.extend(ak.to_numpy(flat_pf[field]).astype(np.float32) for field in pf_numerical)
    data.extend(ak.to_numpy(flat_pf[field]).astype(np.int32) for field in pf_categorical)

    return data

In [None]:
def read_parquet_wrapper(path, jet_fields=None, pf_fields=None):
    """Graph-mode wrapper around ``read_parquet`` for use in ``Dataset.map``.

    Args:
        path: scalar ``tf.string`` tensor naming one preprocessed directory.
        jet_fields: jet-level feature names to keep. Defaults to
            ``jet_numerical + jet_categorical``; any subset (in any order) of
            those names may be given.
        pf_fields: pf-level feature names to keep. Defaults to
            ``pf_numerical + pf_categorical``; any subset of those names may
            be given.

    Returns:
        ``(features, target)`` where ``features`` maps each selected jet field
        to a dense tensor of shape (n_jets, 1) and each selected pf field to a
        ragged tensor of shape (n_jets, None, 1), and ``target`` has shape
        (n_jets,).

    Raises:
        ValueError: if a requested field is not one that ``read_parquet``
            produces.
    """
    all_jet_fields = jet_numerical + jet_categorical
    all_pf_fields = pf_numerical + pf_categorical
    jet_fields = list(all_jet_fields if jet_fields is None else jet_fields)
    pf_fields = list(all_pf_fields if pf_fields is None else pf_fields)

    unknown = (set(jet_fields) - set(all_jet_fields)) | (set(pf_fields) - set(all_pf_fields))
    if unknown:
        raise ValueError(f'Unknown fields requested: {sorted(unknown)}')

    # Dtypes and keys must mirror the order of the arrays read_parquet returns,
    # which is always driven by the global field lists (not by the selection).
    Tout = (
        [tf.int32, tf.float32]
        + [tf.float32] * len(jet_numerical)
        + [tf.int32] * len(jet_categorical)
        + [tf.float32] * len(pf_numerical)
        + [tf.int32] * len(pf_categorical)
    )
    keys = ['row_lengths', 'target'] + all_jet_fields + all_pf_fields

    # read_parquet takes only the path; the field lists are Python-side config.
    cols = tf.numpy_function(read_parquet, inp=[path], Tout=Tout)
    columns = dict(zip(keys, cols))

    target = columns['target']
    target.set_shape((None,))

    row_lengths = columns['row_lengths']
    row_lengths.set_shape((None,))

    features = {}
    for field in jet_fields:
        column = columns[field]
        # Shape from <unknown> to (None,)
        column.set_shape((None,))
        # Shape from (None,) to (None, 1)
        features[field] = tf.expand_dims(column, axis=1)

    for field in pf_fields:
        column = columns[field]
        # Shape from <unknown> to (None,) -- all pf entries of the chunk, flattened
        column.set_shape((None,))
        # Shape from (None,) to (None, None): re-nest per jet
        ragged = tf.RaggedTensor.from_row_lengths(column, row_lengths=row_lengths)
        # Shape from (None, None) to (None, None, 1)
        features[field] = tf.expand_dims(ragged, axis=2)

    return features, target

In [None]:
# One dataset element per training directory path. Pin the dtype so that an
# empty train_dirs still yields a tf.string dataset (tf.constant([]) would
# otherwise infer float32 and break the string-based map below).
ds = tf.data.Dataset.from_tensor_slices(tf.constant(train_dirs, dtype=tf.string))

In [None]:
#list(ds.as_numpy_iterator())

In [None]:
import tensorflow_io as tfio
# Exploratory: check whether tensorflow_io can read one pf.parquet directly.
# Path derives from parquet_dir instead of repeating the hard-coded EOS path.
tfio.IOTensor.from_parquet(os.path.join(parquet_dir, '1', 'pf.parquet'))

In [None]:
# Exploratory: peek at one jet.parquet with awkward.
# Path derives from parquet_dir instead of repeating the hard-coded EOS path.
ak.from_parquet(os.path.join(parquet_dir, '1', 'jet.parquet'))

In [None]:
# read_parquet does Python/awkward I/O and cannot be applied to symbolic
# tensors inside Dataset.map (path.decode() fails on a tf.Tensor), so map
# through read_parquet_wrapper, which bridges it via tf.numpy_function.
# NOTE: this rebinds ds; re-running the cell would map the dataset twice.
ds = ds.map(
    lambda path: read_parquet_wrapper(
        path,
        jet_numerical + jet_categorical,
        pf_numerical + pf_categorical,
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
# Exploratory: displays the tensorflow_io IOTensor class object; nothing in
# this notebook uses the result (appears to be a leftover scratch cell).
tfio.IOTensor