In [1]:
import os
import glob
import pickle
import itertools
import awkward as ak
import numpy as np
import tensorflow as tf

In [2]:
from coffea.nanoevents import NanoEventsFactory, PFNanoAODSchema
# Turn off PFNanoAODSchema's missing-cross-reference warnings (schema-level flag).
PFNanoAODSchema.warn_missing_crossrefs = False
import warnings

In [3]:
# Root of the raw/preprocessed data tree. Override with the JEC_DNN_DATA_DIR
# environment variable to run outside the original EOS area.
data_dir = os.environ.get('JEC_DNN_DATA_DIR', '/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn')

In [4]:
in_dir = os.path.join(data_dir, 'raw/dev')
out_dir = os.path.join(data_dir, 'preprocessed/dev')

# glob.glob returns matches in arbitrary, filesystem-dependent order; sort so the
# processing order (and the order of outputs derived from this list) is reproducible.
root_files = sorted(glob.glob(os.path.join(in_dir, '*.root')))
num_files = len(root_files)

In [5]:
# Create the output directory if it is missing. Unlike swallowing FileExistsError,
# exist_ok=True still raises when `out_dir` exists but is not a directory.
os.makedirs(out_dir, exist_ok=True)

In [6]:
# Open one file; below it is used to discover the available jet / PF-candidate fields.
# Filter the same duplicate-branch warning that `read_nanoaod` suppresses.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', message='found duplicate branch')
    events = NanoEventsFactory.from_root(os.path.join(in_dir, '1.root'), schemaclass=PFNanoAODSchema).events()

In [7]:
# Keep every field whose name does not contain 'IdxG' (presumably coffea's global
# cross-reference indices), then append the derived features computed in `preprocess`.
all_jet_fields = [name for name in events.Jet.fields if 'IdxG' not in name]
all_jet_fields.append('log_pt')

all_pf_fields = [name for name in events.Jet.constituents.pf.fields if 'IdxG' not in name]
all_pf_fields.extend(['rel_eta', 'rel_phi', 'rel_pt'])

# Feature names used as keys in the serialized tf.train.Example records.
all_jet_keys = ['jet_' + name for name in all_jet_fields]
all_pf_keys = ['pf_' + name for name in all_pf_fields]

In [8]:
def read_nanoaod(path, min_gen_pt=30, max_gen_abs_eta=5, track_fields=('dz', 'dzErr', 'd0', 'd0Err')):
    """
    Load a PFNanoAOD file and select up to two generator-matched jets per event.

    Selection steps:
      1. keep events with at least two jets that have a matched generator jet;
      2. sort each event's jets by matched generator-jet pt (descending), take the
         first and second jet, and concatenate them into one flat per-jet array;
      3. require matched generator-jet pt > ``min_gen_pt`` and
         |eta| < ``max_gen_abs_eta``;
      4. drop entries whose matched generator-jet pt is missing (None);
      5. drop jets for which any PF constituent has a non-finite value
         (+inf, -inf or NaN) in one of ``track_fields``.

    Args:
        path: path of the input ROOT file.
        min_gen_pt: lower cut on the matched generator-jet pt (default 30).
        max_gen_abs_eta: upper cut on the matched generator-jet |eta| (default 5).
        track_fields: PF-candidate fields that must be finite for every constituent.

    Returns:
        (jets, pf): the selected jets and their PF constituents
        (``jets.constituents.pf``).
    """
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message='found duplicate branch')
        events = NanoEventsFactory.from_root(path, schemaclass=PFNanoAODSchema).events()

    # ak.count skips None, so this counts jets that actually have a generator match.
    jets = events.Jet[ak.count(events.Jet.matched_gen.pt, axis=1) >= 2]

    sorted_jets = jets[ak.argsort(jets.matched_gen.pt, ascending=False, axis=1)]

    # Stack the first and second jet of every event into a single flat array of jets.
    leading_jets = ak.concatenate((sorted_jets[:, 0], sorted_jets[:, 1]), axis=0)

    gen = leading_jets.matched_gen
    selected_jets = leading_jets[(gen.pt > min_gen_pt) & (abs(gen.eta) < max_gen_abs_eta)]

    valid_jets = selected_jets[~ak.is_none(selected_jets.matched_gen.pt)]

    # One combined mask over all track fields (equivalent to filtering field by
    # field). np.isfinite rejects +inf, -inf and NaN; a `!= np.inf` test would only
    # catch +inf and let the other non-finite values into the training data.
    constituents = valid_jets.constituents.pf
    keep = None
    for field in track_fields:
        field_ok = ak.all(np.isfinite(constituents[field]), axis=1)
        keep = field_ok if keep is None else keep & field_ok
    if keep is not None:
        valid_jets = valid_jets[keep]

    return valid_jets, valid_jets.constituents.pf

In [9]:
def preprocess(jet, pf):
    """
    Attach derived training features to jets and their PF candidates.

    Adds to ``jet``:
        target  -- matched generator-jet pt divided by the jet pt.
        log_pt  -- natural logarithm of the jet pt.
    Adds to ``pf``:
        rel_eta -- (candidate eta - jet eta) multiplied by the sign of the jet eta.
        rel_pt  -- candidate pt divided by the jet pt.
        rel_phi -- candidate phi - jet phi, wrapped into [-pi, pi).

    Both inputs are updated through item assignment and then returned.
    """
    jet_pt = jet.pt
    jet['target'] = jet.matched_gen.pt / jet_pt
    jet['log_pt'] = np.log(jet_pt)

    eta_sign = np.sign(jet.eta)
    pf['rel_eta'] = (pf.eta - jet.eta) * eta_sign
    pf['rel_pt'] = pf.pt / jet_pt

    delta_phi = pf.phi - jet.phi
    pf['rel_phi'] = (delta_phi + np.pi) % (2 * np.pi) - np.pi

    return jet, pf

In [19]:
def float_feature(value):
    """Wrap an iterable of floats in a ``tf.train.Feature`` holding a FloatList.

    Callers pass a sequence; wrap a single scalar in a list first.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [20]:
def int64_feature(value):
    """Wrap an iterable of integers (bool / enum / int / uint) in a ``tf.train.Feature``.

    The values are stored in an Int64List; wrap a single scalar in a list first.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)

In [107]:
def serialize_example(jet_row, pf_row, target_value, row_lengths):
    """
    Serialize one jet into a ``tf.train.Example`` byte string.

    Args:
        jet_row: 1-D array-like of per-jet feature values, ordered like ``all_jet_keys``.
        pf_row: 2-D array-like of shape (num_constituents, num_pf_features), whose
            columns are ordered like ``all_pf_keys``.
        target_value: scalar regression target for this jet.
        row_lengths: number of PF constituents of this jet.

    Returns:
        bytes: the serialized ``tf.train.Example`` protobuf.
    """
    # Convert to NumPy / Python scalars once. Transposing with a TF op per row and
    # letting protobuf convert EagerTensor scalars one at a time is far slower; the
    # stored values are identical because FloatList keeps float32 either way.
    jet_values = np.asarray(jet_row).tolist()
    pf_columns = np.asarray(pf_row).T.tolist()  # one list of values per PF feature

    feature = {
        'row_lengths': int64_feature([int(row_lengths)]),
        'target': float_feature([float(target_value)]),
    }
    feature.update({key: float_feature([value]) for key, value in zip(all_jet_keys, jet_values)})
    feature.update({key: float_feature(values) for key, values in zip(all_pf_keys, pf_columns)})

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [108]:
def create_record(root_file, record_file):
    """
    Convert one NanoAOD ROOT file into a TFRecord file with one example per jet.

    Jets are read and selected with ``read_nanoaod``, derived features are added by
    ``preprocess``, and each jet is serialized with ``serialize_example`` and
    written to ``record_file``.

    Args:
        root_file: path of the input ROOT file.
        record_file: path of the output ``.tfrecords`` file.
    """
    print(record_file + '\n')
    
    jet, pf = read_nanoaod(root_file)
    jet, pf = preprocess(jet, pf)
    
    # Number of PF candidates per jet; used to rebuild the ragged structure after
    # flattening and stored in each example.
    row_lengths = ak.num(pf, axis=1)
    flat_pf = ak.flatten(pf, axis=1)
    
    # Dense (num_jets, num_jet_features) float32 tensor, columns ordered like all_jet_fields.
    jet_tensor = tf.stack([tf.constant(jet[field], dtype=tf.float32) for field in all_jet_fields], axis=1)
    # Ragged (num_jets, None, num_pf_features) tensor, last axis ordered like all_pf_fields.
    pf_tensor = tf.stack([tf.RaggedTensor.from_row_lengths(np.array(flat_pf[field]).astype(np.float32), row_lengths=row_lengths) for field in all_pf_fields], axis=2)
    target_tensor = tf.constant(jet.target)
    
    with tf.io.TFRecordWriter(record_file) as writer:
        # One serialized tf.train.Example per jet.
        for i in range(target_tensor.shape[0]):
            example = serialize_example(jet_tensor[i], pf_tensor[i], target_tensor[i], row_lengths[i])
            writer.write(example)

In [109]:
# Each input `<name>.root` maps to `<out_dir>/<name>.tfrecords`.
root_names = list(map(os.path.basename, root_files))
record_names = [os.path.splitext(name)[0] + '.tfrecords' for name in root_names]
record_files = [os.path.join(out_dir, name) for name in record_names]

In [110]:
# Convert each ROOT file to its TFRecord counterpart, one file at a time.
# NOTE(review): the ThreadPoolExecutor cell later in the notebook writes the same
# files again; running both duplicates the work.
for root_file, record_file in zip(root_files, record_files):
    create_record(root_file, record_file)

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/1.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/2.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/3.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/4.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/5.tfrecords



In [25]:
# import multiprocessing
# with multiprocessing.Pool(processes=2) as pool:
#     pool.starmap(create_record, list(zip(root_files, record_files)))

In [116]:
from concurrent.futures import ThreadPoolExecutor

# Convert the files concurrently. Executor.map only re-raises a worker's exception
# when that result is retrieved, so consume the iterator with list(); otherwise a
# failed conversion would go unnoticed.
with ThreadPoolExecutor(max_workers=2) as executor:
    results = list(executor.map(create_record, root_files, record_files))

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/1.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/2.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/3.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/4.tfrecords

/eos/cms/store/group/phys_jetmet/dholmber/jec-dnn/preprocessed/dev/5.tfrecords



In [115]:
# Shell command (IPython `!`): show system memory usage in megabytes.
!free -m

              total        used        free      shared  buff/cache   available
Mem:          15629        5967         303          37        9359        9351
Swap:             0           0           0


In [62]:
# Parsing spec matching the keys written by serialize_example: the target, the
# constituent count and every jet feature are fixed-length scalars; PF features
# are variable-length lists (one value per constituent).
scalar_float = dict(dtype=tf.float32, default_value=0.0)
features = {
    'target': tf.io.FixedLenFeature([], **scalar_float),
    'row_lengths': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0),
    **{key: tf.io.FixedLenFeature([], **scalar_float) for key in all_jet_keys},
    **{key: tf.io.VarLenFeature(dtype=tf.float32) for key in all_pf_keys},
}

# Stored with pickle: only unpickle this metadata from a trusted location.
metadata_path = os.path.join(out_dir, 'metadata.pkl')
with open(metadata_path, 'wb') as metadata_file:
    pickle.dump(features, metadata_file)