In [1]:
import sys
import os

# Abort early when the kernel is not Python 3; everything below assumes it.
if sys.version_info.major < 3:
  print('[ERROR] You need to run this with Python 3.')
  raise AssertionError

In [2]:
import numpy as np

from emtf_algos import *
from emtf_logger import get_logger
from emtf_colormap import get_colormap

In [3]:
# Seed NumPy's global random number generator (done before TensorFlow is imported)
np.random.seed(2027)

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers as k_layers
from tensorflow.keras import backend as k_backend
import matplotlib as mpl
import matplotlib.pyplot as plt

# Seed TensorFlow's global random number generator with the same value
tf.random.set_seed(2027)

# numba is currently not used; the imports are kept here for reference only
#import numba
#from numba import njit, vectorize
import dask
import dask.array as da

# Record the software environment next to the results so that runs can be compared
logger = get_logger()
logger.info('Using cmssw      : {0}'.format(os.environ.get('CMSSW_VERSION', 'n/a')))
logger.info('Using python     : {0}'.format(''.join(sys.version.split('\n'))))
logger.info('Using numpy      : {0}'.format(np.__version__))
logger.info('Using tensorflow : {0}'.format(tf.__version__))
logger.info('Using keras      : {0}'.format(keras.__version__))
logger.info('.. list devices  : {0}'.format(tf.config.list_physical_devices()))
logger.info('Using matplotlib : {0}'.format(mpl.__version__))
#logger.info('Using numba      : {0}'.format(numba.__version__))
logger.info('Using dask       : {0}'.format(dask.__version__))

# The layers defined below assume the TensorFlow backend and NHWC image layout
assert k_backend.backend() == 'tensorflow'
assert k_backend.image_data_format() == 'channels_last'

%matplotlib inline

[INFO    ] Using cmssw      : CMSSW_10_6_3
[INFO    ] Using python     : 3.6.10 |Anaconda, Inc.| (default, Mar 25 2020, 23:51:54) [GCC 7.3.0]
[INFO    ] Using numpy      : 1.19.2
[INFO    ] Using tensorflow : 2.4.0
[INFO    ] Using keras      : 2.4.0
[INFO    ] .. list devices  : [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
[INFO    ] Using matplotlib : 3.3.2
[INFO    ] Using dask       : 2020.12.0


In [4]:
# Settings
# These are module-level names that the functions and layers defined later in
# this notebook read directly, so this cell must run before them.

# zone: (0,1,2) -> eta=(1.98..2.5, 1.55..1.98, 1.2..1.55)
# The selected zone determines which input files are named below.
#zone = 0
#zone = 1
zone = 2

# timezone: (0,1,2) -> BX=(0,-1,-2)
# Passed to the Zoning layers and also read as a global by build_track_cands().
timezone = 0

# masked array filling value
# Written into track candidate slots that are invalid or removed as duplicates.
ma_fill_value = 999999

# Number of events to process; -1 means "all events in zone_hits".
#maxevents = 10
#maxevents = 20000
maxevents = -1

# Number of event batches and of dask threads used when building the inputs.
#workers = 1
workers = 8

# Input files
patterns_fname = 'patterns_zone%i.npz' % zone
zone_images_fname = 'zone_images_zone%i.h5' % zone

# Styling
plt.style.use('tdrstyle.mplstyle')
cm = get_colormap()

# Log the chosen configuration
logger.info('Processing zone {0} timezone {1}'.format(zone, timezone))
logger.info('.. maxevents        : {0}'.format(maxevents))
logger.info('.. workers          : {0}'.format(workers))

[INFO    ] Processing zone 2 timezone 0
[INFO    ] .. maxevents        : -1
[INFO    ] .. workers          : 8


### Load data

In [5]:
def load_patterns():
  """Load the pattern arrays of every zone and stack them along a new leading axis.

  The per-zone file name is derived from the module-level `patterns_fname` by
  swapping the currently selected `zone` tag for each zone index in turn.

  Returns:
    Tuple `(patterns, boxes_act, hitmap_quality_ranks)` of NumPy arrays, each
    with one leading entry per zone.
  """
  keys = ('patterns', 'boxes_act', 'hitmap_quality_ranks')
  collected = {key: [] for key in keys}
  selected_tag = 'zone%i' % zone

  for izone in range(num_emtf_zones):
    # Point the file name at zone `izone` instead of the selected zone
    fname = patterns_fname.replace(selected_tag, 'zone%i' % izone)
    logger.info('Loading from {0}'.format(fname))
    with np.load(fname) as loaded:
      for key in keys:
        collected[key].append(loaded[key])

  patterns, boxes_act, hitmap_quality_ranks = (np.asarray(collected[key]) for key in keys)
  logger.info('patterns: {0} boxes_act: {1} hitmap_quality_ranks: {2}'.format(patterns.shape, boxes_act.shape, hitmap_quality_ranks.shape))
  return patterns, boxes_act, hitmap_quality_ranks

import h5py
loaded_h5 = None  # hdf5 file handle

_loaded_h5_fname = None  # file name that the cached `loaded_h5` handle was opened with

def _get_loaded_h5(fname):
  """Return a read-only HDF5 handle for `fname`, opening the file only when needed.

  The handle is cached in the module-level `loaded_h5`. The file is (re)opened
  when there is no cached handle (including after `loaded_h5` has been reset to
  None by hand) or when the cached handle was opened for a different file name;
  previously the cached handle was returned whatever `fname` was given.

  A handle that gets replaced is deliberately not closed here, because datasets
  and lazy arrays returned earlier may still refer to it.
  """
  global loaded_h5, _loaded_h5_fname
  if loaded_h5 is None or _loaded_h5_fname != fname:
    logger.info('Loading from {0}'.format(fname))
    loaded_h5 = h5py.File(fname, 'r')
    _loaded_h5_fname = fname
  return loaded_h5

def load_zone_sparse_images(fname):
  """Return `(zone_box_anchors, zone_sparse_images)` backed by the HDF5 file `fname`.

  `zone_box_anchors` is the HDF5 dataset itself; `zone_sparse_images` is a
  SparseTensorValue whose fields are the corresponding HDF5 datasets.
  """
  h5 = _get_loaded_h5(fname)
  zone_box_anchors = h5['zone_box_anchors']
  zone_sparse_images = SparseTensorValue(indices=h5['zone_sparse_images_indices'],
                                         values=h5['zone_sparse_images_values'],
                                         dense_shape=h5['zone_sparse_images_dense_shape'])
  logger.info('zone_box_anchors: {0} zone_sparse_images: {1}'.format(zone_box_anchors.shape, zone_sparse_images.dense_shape))
  return zone_box_anchors, zone_sparse_images

def load_zone_hits(fname):
  """Return `(zone_part, zone_hits, zone_simhits)` backed by the HDF5 file `fname`.

  `zone_part` is the HDF5 dataset itself; `zone_hits` and `zone_simhits` are
  RaggedTensorValue objects built from the values/row_splits datasets.
  """
  h5 = _get_loaded_h5(fname)
  zone_part = h5['zone_part']
  zone_hits = RaggedTensorValue(values=h5['zone_hits_values'],
                                row_splits=h5['zone_hits_row_splits'])
  zone_simhits = RaggedTensorValue(values=h5['zone_simhits_values'],
                                   row_splits=h5['zone_simhits_row_splits'])
  logger.info('zone_part: {0} zone_hits: {1} zone_simhits: {2}'.format(zone_part.shape, zone_hits.shape, zone_simhits.shape))
  return zone_part, zone_hits, zone_simhits

def load_zone_hits_lazy(fname):
  """Like load_zone_hits(), but wrap the datasets in dask arrays without reading them.

  Returns:
    `(zone_part, (zone_hits_values, zone_hits_row_splits),
    (zone_simhits_values, zone_simhits_row_splits))`, all dask arrays. The
    ragged shapes that are logged are derived from the row_splits length and
    the trailing dimensions of the values.
  """
  h5 = _get_loaded_h5(fname)
  zone_part = da.from_array(h5['zone_part'])
  zone_hits_values = da.from_array(h5['zone_hits_values'])
  zone_hits_row_splits = da.from_array(h5['zone_hits_row_splits'])
  zone_hits_shape = (zone_hits_row_splits.shape[0] - 1,) + (None,) + zone_hits_values.shape[1:]
  zone_simhits_values = da.from_array(h5['zone_simhits_values'])
  zone_simhits_row_splits = da.from_array(h5['zone_simhits_row_splits'])
  zone_simhits_shape = (zone_simhits_row_splits.shape[0] - 1,) + (None,) + zone_simhits_values.shape[1:]
  logger.info('zone_part: {0} zone_hits: {1} zone_simhits: {2}'.format(zone_part.shape, zone_hits_shape, zone_simhits_shape))
  return zone_part, (zone_hits_values, zone_hits_row_splits), (zone_simhits_values, zone_simhits_row_splits)

In [6]:
# Load patterns (one entry per zone along the leading axis)
patterns, boxes_act, hitmap_quality_ranks = load_patterns()

# Create patterns_reshaped
# For every zone, keep index 3 on the second axis and reorder the third axis.
patterns_reshaped = []
for i in range(num_emtf_zones):
  p = patterns[i, 3, [3, 2, 4, 1, 5, 0, 6]]  # order by straightness, only prompt patterns
  patterns_reshaped.append(p)
patterns_reshaped = np.asarray(patterns_reshaped)
logger.info('patterns_reshaped: {0}'.format(patterns_reshaped.shape))

# Create boxes_act_reshaped
# Same selection as above, then the axes are reversed so that each zone's array
# can be used directly as a convolution kernel by the Pooling layer.
boxes_act_reshaped = []
for i in range(num_emtf_zones):
  b = boxes_act[i, 3, [3, 2, 4, 1, 5, 0, 6]]  # order by straightness, only prompt patterns
  b = np.transpose(b, [3, 2, 1, 0])  # kernel shape is HWCD
  boxes_act_reshaped.append(b)
boxes_act_reshaped = np.asarray(boxes_act_reshaped)
logger.info('boxes_act_reshaped: {0}'.format(boxes_act_reshaped.shape))

# Create boxes_qual_reshaped
# Integer division by 4 drops the two lowest bits; the assert checks that the
# largest resulting rank is exactly 63, i.e. the 6-bit maximum.
boxes_qual_reshaped = hitmap_quality_ranks // 4  # from 8-bit to 6-bit
assert boxes_qual_reshaped.max() == 63
logger.info('boxes_qual_reshaped: {0}'.format(boxes_qual_reshaped.shape))

[INFO    ] Loading from patterns_zone0.npz
[INFO    ] Loading from patterns_zone1.npz
[INFO    ] Loading from patterns_zone2.npz
[INFO    ] patterns: (3, 7, 7, 8, 3) boxes_act: (3, 7, 7, 8, 111, 1) hitmap_quality_ranks: (3, 256)
[INFO    ] patterns_reshaped: (3, 7, 8, 3)
[INFO    ] boxes_act_reshaped: (3, 1, 111, 8, 7)
[INFO    ] boxes_qual_reshaped: (3, 256)


In [7]:
# Load zone_hits (lazily)
# Only dask wrappers are created here; zone_hits_l and zone_simhits_l are each a
# (values, row_splits) pair of dask arrays.
zone_part_l, zone_hits_l, zone_simhits_l = load_zone_hits_lazy(zone_images_fname)

[INFO    ] Loading from zone_images_zone2.h5
[INFO    ] zone_part: (435533, 9) zone_hits: (435533, None, 18) zone_simhits: (435533, None, 18)


### Create inputs

In [8]:
# Input data columns: column name -> column position in a zone_hits row
hits_metadata = {
  name: position for position, name in enumerate([
    'emtf_site', 'emtf_host', 'emtf_chamber',
    'emtf_segment', 'zones', 'timezones',
    'emtf_phi', 'emtf_bend', 'emtf_theta',
    'emtf_theta_alt', 'emtf_qual', 'emtf_qual_alt',
    'emtf_time', 'strip', 'wire',
    'fr', 'detlayer', 'bx',
  ])
}

# Image format: (rows, columns, channels)
num_channels = 1
num_cols = 288  # 80 degrees
num_rows = 8
image_format = (num_rows, num_cols, num_channels)

# Array indices used to address a hit inside the dense input array
ind_emtf_chamber, ind_emtf_segment = (
  hits_metadata[name] for name in ('emtf_chamber', 'emtf_segment')
)

# Array indices of the per-hit variables that are copied into the input array
(ind_emtf_phi, ind_emtf_bend, ind_emtf_theta1, ind_emtf_theta2,
 ind_emtf_qual1, ind_emtf_qual2, ind_emtf_time) = (
  hits_metadata[name] for name in (
    'emtf_phi', 'emtf_bend', 'emtf_theta', 'emtf_theta_alt',
    'emtf_qual', 'emtf_qual_alt', 'emtf_time',
  )
)
ind_zones, ind_tzones, ind_fr, ind_dl, ind_bx = (
  hits_metadata[name] for name in ('zones', 'timezones', 'fr', 'detlayer', 'bx')
)
# CAUTION: there is no dedicated 'valid' column yet, so 'bx' is used as a
# placeholder for the moment.
ind_valid = hits_metadata['bx']

# Names of the variables in the input array, in the order they are stored:
# variable name -> position along the last axis
new_hits_metadata = {
  name: position for position, name in enumerate([
    'emtf_phi', 'emtf_bend', 'emtf_theta1',
    'emtf_theta2', 'emtf_qual1', 'emtf_qual2',
    'emtf_time', 'zones', 'tzones',
    'fr', 'dl', 'bx',
    'valid',
  ])
}

In [9]:
@dask.delayed
def build_inputs_batch(batch):
  """Build the dense and sparse inputs for one batch of events (dask-delayed).

  Reads the module-level `zone_hits`, which must be defined before the delayed
  task is computed.

  Args:
    batch: iterable of event indices into `zone_hits`.

  Returns:
    List with one `(dense, sparse_rows)` tuple per event. `dense` is the output
    of sparse_to_dense(); `sparse_rows` is the (chamber, segment) index pair of
    every kept hit concatenated with its selected variables.
  """
  # Column selections do not depend on the event, so build them once
  index_columns = [ind_emtf_chamber, ind_emtf_segment]
  value_columns = [ind_emtf_phi, ind_emtf_bend, ind_emtf_theta1, ind_emtf_theta2,
                   ind_emtf_qual1, ind_emtf_qual2, ind_emtf_time, ind_zones,
                   ind_tzones, ind_fr, ind_dl, ind_bx,
                   ind_valid]

  results = []
  for ievt in batch:
    event_hits = zone_hits[ievt]
    indices = event_hits[:, index_columns]
    values = event_hits[:, value_columns]
    values[:, -1] = 1  # CAUTION: the last column is a placeholder; force 'valid' to 1

    # Apply truncation: drop hits whose segment number does not fit
    keep = indices[:, 1] < num_emtf_segments
    indices = indices[keep]
    values = values[keep]

    # Sparse -> Dense
    dense_shape = np.array([num_emtf_chambers, num_emtf_segments, num_emtf_variables], dtype=np.int32)
    sparse = SparseTensorValue(indices=indices, values=values, dense_shape=dense_shape)
    dense = sparse_to_dense(sparse)

    results.append((dense, np.concatenate((indices, values), axis=-1)))
  return results

def build_inputs_all_batches():
  """Split the events into `workers` batches and schedule one task per batch.

  All events of the module-level `zone_hits` are used when `maxevents` is -1,
  otherwise only the first `maxevents` events.

  Returns:
    List of whatever build_inputs_batch() returns for each batch (delayed tasks).
  """
  num_events = maxevents if maxevents != -1 else zone_hits.shape[0]
  event_batches = np.array_split(np.arange(num_events), workers)
  return [build_inputs_batch(event_batch) for event_batch in event_batches]

def build_inputs():
  """Compute all input batches with dask threads and merge them.

  Returns:
    Tuple `(inputs, sparse_inputs)`: `inputs` is a NumPy array stacking the
    dense input of every event, `sparse_inputs` is a tuple holding the sparse
    rows of every event.
  """
  delayed_batches = build_inputs_all_batches()
  with dask.config.set(scheduler='threads', num_workers=workers):
    (computed_batches,) = dask.compute(delayed_batches)  # blocks until all batches are done

  # Flatten the list of per-batch lists into one list of per-event tuples
  per_event = [result for batch_results in computed_batches for result in batch_results]
  inputs, sparse_inputs = zip(*per_event)
  return (np.asarray(inputs), sparse_inputs)

In [10]:
# Build inputs
zone_hits = RaggedTensorValue(values=np.array(zone_hits_l[0]), row_splits=np.array(zone_hits_l[1]))  # actually load zone_hits

inputs, sparse_inputs = build_inputs()
sparse_inputs_shape = (len(sparse_inputs), None, sparse_inputs[0].shape[-1])
logger.info('inputs: {0} sparse_inputs: {1}'.format(inputs.shape, sparse_inputs_shape))

[INFO    ] inputs: (435533, 115, 2, 13) sparse_inputs: (435533, None, 15)


In [11]:
# Debug
def tf_constant_value(tensor):
  """Return the NumPy value of a TensorFlow tensor when it can be evaluated.

  Args:
    tensor: any object.

  Returns:
    `tensor.numpy()` if `tensor` is a TensorFlow tensor and the call succeeds,
    the tensor itself if calling `.numpy()` raises, and None if `tensor` is not
    a TensorFlow tensor at all.
  """
  if not tf.is_tensor(tensor):
    return None
  try:
    return tensor.numpy()
  except Exception:
    # Catch Exception rather than everything, so that KeyboardInterrupt and
    # SystemExit are no longer swallowed by this debugging helper.
    return tensor

import functools
# Compact array printer: integers are right-aligned in at least 4 characters,
# lines are wrapped at 100 characters and arrays above 1000 elements are summarized.
my_array2string = functools.partial(np.array2string, separator=', ', formatter={'int':lambda x: '% 4i' % x}, max_line_width=100, threshold=1000)

# Show the sparse rows of a few events as a sanity check.
# Each row is (chamber, segment) followed by the variables listed in new_hits_metadata.
print(my_array2string(sparse_inputs[0]))
print(my_array2string(sparse_inputs[2]))
print(my_array2string(sparse_inputs[5]))

[[  14,    0,  4444,   -7,   65,   65,    5,    5,    0,    1,    4,    1,    0,    0,    1],
 [  26,    0,  4352,    0,   64,   64,    5,    5,    0,    1,    4,    1,    0,    0,    1],
 [  35,    0,  4345,    1,   62,   62,   -5,   -5,    0,    1,    4,    0,    0,    0,    1],
 [  44,    0,  4361,    3,   61,   61,   -4,   -4,    0,    1,    4,    0,    0,    0,    1],
 [  68,    0,  4440,    0,   56,   56,    0,    0,   -1,    3,    4,    1,    0,    0,    1],
 [  80,    0,  4368,    0,   68,   68,    0,    0,    0,    1,    4,    1,    0,    0,    1],
 [  89,    0,  4348,    0,   60,   60,    0,    0,    1,    1,    4,    0,    0,    0,    1],
 [  98,    0,  4368,    0,   64,   64,    0,    0,    2,    1,    4,    0,    0,    0,    1]]
[[   4,    0,  1973,    5,   72,   72,    6,    6,    0,    1,    4,    1,    0,    0,    1],
 [  22,    0,  2028,    0,   72,   72,    6,    6,    0,    1,    4,    1,    0,    0,    1],
 [  31,    0,  2024,    0,   72,   72,   -6,   -6,    0,    

### Create model

In [12]:
def prepare_constants():
  """Build the lookup tables (LUTs) used when making zone images and track candidates.

  Returns:
    An attribute container with:
      site_to_img_row_luts: the module-level `site_to_img_row_luts`, unchanged.
      site_to_chamber_lut: for every site, the chambers of all hosts mapped to that site.
      img_row_to_chamber_luts: list with one LUT per zone; each LUT gives, for
        every image row, the chambers of all hosts mapped to that row.
  """
  # Utility functions

  def inverse_fn(F, y):
    # For every value y_j in y, list the positions i where F[i] == y_j
    return [[i for (i, y_i) in enumerate(F) if y_i == y_j] for y_j in y]

  def to_array(x):
    # Rows of unequal length cannot form a rectangular array, so they are stored
    # as an object array of per-row arrays. The builtin `object` is used because
    # the `np.object` alias was deprecated in NumPy 1.20 and later removed.
    is_ragged = len(set([len(x_i) for x_i in x])) > 1
    if is_ragged:
      return np.asarray([np.asarray(x_i) for x_i in x], dtype=object)
    else:
      return np.asarray([x_i for x_i in x])

  class Constants:
    pass

  cc = Constants()

  cc.site_to_img_row_luts = site_to_img_row_luts

  # host -> chambers, by inverting chamber -> host
  host_to_chamber_lut = to_array(inverse_fn(chamber_to_host_lut, np.arange(num_emtf_hosts)))

  # site -> hosts, by inverting host -> site; then site -> chambers
  site_to_host_lut = to_array(inverse_fn(host_to_site_lut, np.arange(num_emtf_sites)))
  site_to_chamber_lut = to_array([
      [c for host in hosts for c in host_to_chamber_lut[host]] \
      for hosts in site_to_host_lut
  ])
  cc.site_to_chamber_lut = site_to_chamber_lut

  cc.img_row_to_chamber_luts = []

  for i in range(num_emtf_zones):
    # image row -> hosts for zone i, by inverting host -> image row; then row -> chambers
    host_to_img_row_lut = find_emtf_img_row_lut()[:, i]
    img_row_to_host_lut = to_array(inverse_fn(host_to_img_row_lut, np.arange(num_rows)))
    img_row_to_chamber_lut = to_array([
        [c for host in hosts for c in host_to_chamber_lut[host]] \
        for hosts in img_row_to_host_lut
    ])
    cc.img_row_to_chamber_luts.append(img_row_to_chamber_lut)
  return cc

# Build the lookup tables once; `cc` is read as a module-level global by
# build_zone_images() and build_track_cands().
cc = prepare_constants()

In [13]:
def build_zone_images(x, zone, timezone, image_format):
  """Fill one boolean zone image per event from the hits of the given zone and timezone.

  Args:
    x: array of hits indexed as x[event][chamber, segment, variable], with the
      variables ordered as in `new_hits_metadata`.
    zone: zone index; selects the image-row LUT and the bit tested in 'zones'.
    timezone: timezone index; selects the bit tested in 'tzones'.
    image_format: tuple (rows, columns, channels) of a single image.

  Returns:
    Boolean array of shape (number of events,) + image_format. A pixel is set
    when a valid hit of the requested zone and timezone falls on that row and
    on the column given by find_emtf_img_col() of its emtf_phi.
  """
  # Prepare zone images. The builtin `bool` is used because the `np.bool` alias
  # was deprecated in NumPy 1.20 and later removed.
  zone_images = np.zeros((x.shape[0],) + image_format, dtype=bool)

  def get_boolean_mask(zone, row):
    # Mask over chambers: True for the chambers that feed image row `row`
    indices = cc.img_row_to_chamber_luts[zone][row]
    boolean_mask = np.zeros(num_emtf_chambers, dtype=bool)
    boolean_mask[indices] = 1
    return boolean_mask

  # Everything below is the same for all events, so it is computed only once
  new_ind_emtf_phi = new_hits_metadata['emtf_phi']
  new_ind_zones = new_hits_metadata['zones']
  new_ind_tzones = new_hits_metadata['tzones']
  new_ind_valid = new_hits_metadata['valid']
  zone_bit = 1 << ((num_emtf_zones - 1) - zone)
  timezone_bit = 1 << ((num_emtf_timezones - 1) - timezone)
  row_masks = [get_boolean_mask(zone, row) for row in range(zone_images.shape[1])]

  # Loop over events
  for ievt in range(zone_images.shape[0]):
    x_evt = x[ievt]
    x_emtf_phi = x_evt[..., new_ind_emtf_phi]
    x_zones = x_evt[..., new_ind_zones]
    x_tzones = x_evt[..., new_ind_tzones]
    x_valid = x_evt[..., new_ind_valid]

    # A hit is used if it is valid and has the bits of this zone and timezone set
    _valid = (x_valid == 1) & \
             (x_zones & zone_bit).astype(bool) & \
             (x_tzones & timezone_bit).astype(bool)

    # Loop over rows
    rows = []
    cols = []
    channels = []
    for row, boolean_mask in enumerate(row_masks):
      valid = _valid[boolean_mask]
      col = find_emtf_img_col(x_emtf_phi[boolean_mask][valid])
      rows.extend((col * 0) + row)  # same row index repeated once per hit
      cols.extend(col)
      channels.extend((col * 0))  # always channel 0

    # Fill zone image
    zone_images[ievt][(rows, cols, channels)] = 1
  return zone_images

In [14]:
def build_track_cands(x, x_patt, idx_h, idx_w, idx_z, num_features):
  """Build the feature vector of every track candidate and remove duplicates.

  For every event and every fired pattern, the hit closest in phi to the
  pattern centre is selected in each site, a median theta is computed, sites
  failing the theta window are invalidated, and the selected features are
  packed. Afterwards, candidates sharing a segment with an earlier candidate of
  the same event are dropped and the remaining ones are moved to the front.

  Besides its arguments, this function reads the module-level `timezone`,
  `patterns_reshaped`, `cc`, `new_hits_metadata` and `ma_fill_value`.

  Args:
    x: array of hits indexed as x[event][chamber, segment, variable], with the
      variables ordered as in `new_hits_metadata`.
    x_patt: (events, tracks) array, stored as the pattern quality feature.
    idx_h: (events, tracks) array of pattern indices.
    idx_w: (events, tracks) array of pattern positions (image columns).
    idx_z: (events, tracks) array of zone indices.
    num_features: length of one feature vector. The packing below produces
      36 selected values plus 4 additional features.

  Returns:
    int32 array of shape (events, tracks, num_features). Invalid sites and
    the slots freed by duplicate removal hold `ma_fill_value`.
  """
  # Prepare track cands
  num_feature_axes = 7  # (emtf_phi, emtf_bend, emtf_theta1, emtf_theta2, emtf_qual1, emtf_qual2, emtf_time)
  # track_cands_tmp is laid out feature-major: entry (feature * num_emtf_sites + site)
  track_cands_tmp = np.zeros((x_patt.shape[0], x_patt.shape[1], num_feature_axes * num_emtf_sites), dtype=np.int32)
  track_cands = np.zeros((x_patt.shape[0], x_patt.shape[1], num_features), dtype=np.int32)

  # Segment ids run over 0..(chambers * segments - 1), so this value marks "no segment"
  invalid_marker_ph_seg = (num_emtf_chambers * num_emtf_segments)
  track_cands_seg = np.zeros((x_patt.shape[0], x_patt.shape[1], num_emtf_sites), dtype=np.int32) + invalid_marker_ph_seg
  track_cands_seg_rm = np.zeros((x_patt.shape[0], x_patt.shape[1], num_emtf_sites_rm), dtype=np.int32) + invalid_marker_ph_seg

  def get_phi_patt(zone, patt, row, phi_patt, _which):
    # In the following, (0, 0, 4) is zone 0 patt 'straightest' row 'ME2/1'
    if _which == 'start':
      phi_patt_corr = patterns_reshaped[zone, patt, row, 0] - patterns_reshaped[0, 0, 4, 1]
    elif _which == 'mid':
      phi_patt_corr = patterns_reshaped[zone, patt, row, 1] - patterns_reshaped[0, 0, 4, 1]
    elif _which == 'stop':
      phi_patt_corr = patterns_reshaped[zone, patt, row, 2] - patterns_reshaped[0, 0, 4, 1]
    else:
      raise ValueError('Invalid: %s' % _which)
    phi_patt = phi_patt + phi_patt_corr
    phi_patt = find_emtf_img_col_inverse(phi_patt)
    return phi_patt

  def get_img_row(zone, site):
    return cc.site_to_img_row_luts[zone][site]

  def get_boolean_mask(zone, site):
    # Mask over chambers: True for the chambers that belong to `site`.
    # The builtin `bool` is used because the `np.bool` alias was deprecated in
    # NumPy 1.20 and later removed.
    indices = cc.site_to_chamber_lut[site]
    boolean_mask = np.zeros(num_emtf_chambers, dtype=bool)
    boolean_mask[indices] = 1
    return boolean_mask

  # Loop over events
  for ievt in range(x_patt.shape[0]):
    new_ind_emtf_phi = new_hits_metadata['emtf_phi']
    new_ind_emtf_bend = new_hits_metadata['emtf_bend']
    new_ind_emtf_theta1 = new_hits_metadata['emtf_theta1']
    new_ind_emtf_theta2 = new_hits_metadata['emtf_theta2']
    new_ind_emtf_qual1 = new_hits_metadata['emtf_qual1']
    new_ind_emtf_qual2 = new_hits_metadata['emtf_qual2']
    new_ind_emtf_time = new_hits_metadata['emtf_time']
    new_ind_zones = new_hits_metadata['zones']
    new_ind_tzones = new_hits_metadata['tzones']
    new_ind_valid = new_hits_metadata['valid']

    x_emtf_phi = x[ievt][..., new_ind_emtf_phi]
    x_emtf_bend = x[ievt][..., new_ind_emtf_bend]
    x_emtf_theta1 = x[ievt][..., new_ind_emtf_theta1]
    x_emtf_theta2 = x[ievt][..., new_ind_emtf_theta2]
    x_emtf_qual1 = x[ievt][..., new_ind_emtf_qual1]
    x_emtf_qual2 = x[ievt][..., new_ind_emtf_qual2]
    x_emtf_time = x[ievt][..., new_ind_emtf_time]
    x_zones = x[ievt][..., new_ind_zones]
    x_tzones = x[ievt][..., new_ind_tzones]
    x_valid = x[ievt][..., new_ind_valid]

    # Unique id of every (chamber, segment) slot, used for duplicate removal
    x_seg = np.arange(num_emtf_chambers * num_emtf_segments).reshape((num_emtf_chambers, num_emtf_segments))

    # Loop over fired patterns
    for itrk in range(x_patt.shape[1]):
      zone = idx_z[ievt, itrk]
      #timezone = idx_tz[ievt, itrk]
      patt = idx_h[ievt, itrk]
      phi_patt = idx_w[ievt, itrk]
      qual_patt = x_patt[ievt, itrk]
      bx_patt = 0

      num_theta_values = 9
      theta_values = np.zeros(num_theta_values, dtype=np.int32)
      theta_values_s1 = np.zeros(num_theta_values, dtype=np.int32)

      # NOTE: `timezone` is the module-level setting, not a per-track value
      _valid = (x_valid == 1) & \
               (x_zones & (1<<((num_emtf_zones - 1) - zone))).astype(bool) & \
               (x_tzones & (1<<((num_emtf_timezones - 1) - timezone))).astype(bool)

      # Loop over sites
      for site in range(num_emtf_sites):
        row = get_img_row(zone, site)
        phi_patt_start = get_phi_patt(zone, patt, row, phi_patt, 'start')
        phi_patt_mid = get_phi_patt(zone, patt, row, phi_patt, 'mid')
        phi_patt_stop = get_phi_patt(zone, patt, row, phi_patt, 'stop')

        # Keep the hits whose image column lies inside the pattern window (inclusive)
        _valid_phi = (find_emtf_img_col(phi_patt_start) <= find_emtf_img_col(x_emtf_phi)) & \
                     (find_emtf_img_col(x_emtf_phi) <= find_emtf_img_col(phi_patt_stop))

        boolean_mask = get_boolean_mask(zone, site)
        valid = (_valid & _valid_phi)[boolean_mask]
        phi_values = x_emtf_phi[boolean_mask][valid]

        if phi_values.size > 0:
          # Select min dphi
          dphi = np.abs(phi_values - phi_patt_mid)
          idx = np.argmin(dphi)
          # Same order as the first 7 entries of new_hits_metadata, so `_vars`
          # can be indexed with the new_ind_* positions below
          _vars = [
            x_emtf_phi[boolean_mask][valid][idx],
            x_emtf_bend[boolean_mask][valid][idx],
            x_emtf_theta1[boolean_mask][valid][idx],
            x_emtf_theta2[boolean_mask][valid][idx],
            x_emtf_qual1[boolean_mask][valid][idx],
            x_emtf_qual2[boolean_mask][valid][idx],
            x_emtf_time[boolean_mask][valid][idx],
          ]
          track_cands_tmp[ievt, itrk, np.arange(num_feature_axes) * num_emtf_sites + site] = _vars
          track_cands_seg[ievt, itrk, site] = x_seg[boolean_mask][valid][idx]

          # Set theta_values, theta_values_s1
          if site == 0:
            theta_values_s1[1] = _vars[new_ind_emtf_theta1]
            theta_values_s1[4] = _vars[new_ind_emtf_theta2]
          elif site == 1:
            theta_values_s1[0] = _vars[new_ind_emtf_theta1]
            theta_values_s1[3] = _vars[new_ind_emtf_theta2]
          elif site == 2:
            theta_values[0] = _vars[new_ind_emtf_theta1]
            theta_values[3] = _vars[new_ind_emtf_theta2]
          elif site == 3:
            theta_values[1] = _vars[new_ind_emtf_theta1]
            theta_values[4] = _vars[new_ind_emtf_theta2]
          elif site == 4:
            theta_values[2] = _vars[new_ind_emtf_theta1]
            theta_values[5] = _vars[new_ind_emtf_theta2]
          elif site == 5:
            theta_values_s1[6] = _vars[new_ind_emtf_theta1]
          elif site == 6:
            theta_values[6] = _vars[new_ind_emtf_theta1]
          elif site == 7:
            theta_values[7] = _vars[new_ind_emtf_theta1]
          elif site == 8:
            theta_values[8] = _vars[new_ind_emtf_theta1]
          elif site == 9:
            theta_values_s1[7] = _vars[new_ind_emtf_theta1]
          elif site == 10:
            # Only used when site 6 did not already provide a value
            theta_values[6] = theta_values[6] if theta_values[6] != 0 else _vars[new_ind_emtf_theta1]
          elif site == 11:
            theta_values_s1[8] = _vars[new_ind_emtf_theta1]
          # End if phi_values.size > 0
        # End loop over site

      # Find phi_median and theta_median
      def find_theta_median():
        def find_median_of_three(x):
          # Zeros mean "no value": take the median of the positive entries if any
          x = np.sort(x)
          return pick_the_median(x[x > 0]) if np.any(x > 0) else pick_the_first(x)
        #
        def find_median_of_nine(x):
          return find_median_of_three([find_median_of_three(x[0:3]), find_median_of_three(x[3:6]), find_median_of_three(x[6:9])])
        #
        # theta_values_s1 is only the fallback when theta_values has no positive entry
        if np.any(theta_values > 0):
          return find_median_of_nine(theta_values)
        else:
          return find_median_of_nine(theta_values_s1)

      phi_median = find_emtf_img_col_inverse(phi_patt)
      theta_median = find_theta_median()

      # Require theta window, find best theta values
      th_window = 8
      for site in range(num_emtf_sites):
        theta1 = track_cands_tmp[ievt, itrk, (new_ind_emtf_theta1 * num_emtf_sites) + site]
        theta2 = track_cands_tmp[ievt, itrk, (new_ind_emtf_theta2 * num_emtf_sites) + site]
        dtheta1 = np.abs(theta1 - theta_median)
        dtheta2 = np.abs(theta2 - theta_median)
        valid_site = (theta1 != 0 and theta2 != 0) and (dtheta1 < th_window or dtheta2 < th_window)
        # Keep whichever of the two theta values is closer to the median
        if dtheta2 < dtheta1:
          track_cands_tmp[ievt, itrk, (new_ind_emtf_theta1 * num_emtf_sites) + site] = theta2

        # Store phi and theta relative to the medians
        track_cands_tmp[ievt, itrk, (new_ind_emtf_phi * num_emtf_sites) + site] -= phi_median
        track_cands_tmp[ievt, itrk, (new_ind_emtf_theta1 * num_emtf_sites) + site] -= theta_median
        if not valid_site:
          track_cands_tmp[ievt, itrk, np.arange(num_feature_axes) * num_emtf_sites + site] = ma_fill_value

      # Extract features
      track_cands_tmp_indices = [
         0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  # emtf_phi
        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,  # emtf_theta
        12, 13, 14, 15, 16, 23,                          # emtf_bend
        48, 49, 50, 51, 52, 59,                          # emtf_qual
      ]
      additional_features = [
        (phi_median - (max_emtf_strip // 2)), theta_median, qual_patt, bx_patt,
      ]
      additional_features = np.array(additional_features, dtype=np.int32)
      track_cands[ievt, itrk] = np.hstack([track_cands_tmp[ievt, itrk, track_cands_tmp_indices], additional_features])

      # Duplicate removal
      def find_seg_rm(x):
        # First valid segment id of the group, or the invalid marker if there is none
        invalid = invalid_marker_ph_seg
        return pick_the_first(x[x != invalid]) if np.any(x != invalid) else pick_the_first(x)

      # Sites grouped for duplicate removal; one representative segment per group
      track_cands_seg_indices = [[0, 9, 1, 5], [2, 10, 6], [3, 7], [4, 8], [11]]
      for site in range(num_emtf_sites_rm):
        track_cands_seg_rm[ievt, itrk, site] = find_seg_rm(track_cands_seg[ievt, itrk, track_cands_seg_indices[site]])
      # End loop over fired patterns

    # Duplicate removal
    # Track j is removed when it shares a valid segment with any earlier track i
    num_tracks = x_patt.shape[1]
    removal = np.zeros(num_tracks, dtype=bool)
    for i in range(num_tracks-1):
      for j in range(i+1, num_tracks):
        removal[j] |= np.any(
            (track_cands_seg_rm[ievt, i] != invalid_marker_ph_seg) & \
            (track_cands_seg_rm[ievt, j] != invalid_marker_ph_seg) & \
            (track_cands_seg_rm[ievt, i] == track_cands_seg_rm[ievt, j])
        )

    # Move the surviving tracks to the front and fill the freed slots
    i = 0
    for j in range(num_tracks):
      if not removal[j]:
        track_cands[ievt, i] = track_cands[ievt, j]
        i += 1
    if i != num_tracks:
      for i in range(i, num_tracks):
        track_cands[ievt, i] = ma_fill_value
    # End loop over events

  return track_cands

In [15]:
# Creating custom layers
# See: https://www.tensorflow.org/tutorials/customization/custom_layers

class Zoning(k_layers.Layer):
  """Keras layer that turns the input hits into boolean zone images.

  The work is delegated to build_zone_images(), which runs as NumPy code
  through tf.numpy_function with the zone, timezone and image format bound
  at construction time.
  """

  def __init__(self, zone, timezone, image_format=image_format, **kwargs):
    super(Zoning, self).__init__(**kwargs)
    self.zone = zone
    self.timezone = timezone
    self.image_format = image_format

    # Set up to call build_zone_images()
    import functools
    bound_build_zone_images = functools.partial(
        build_zone_images, zone=self.zone, timezone=self.timezone, image_format=self.image_format)
    self.py_func = k_layers.Lambda(
        lambda x: tf.numpy_function(bound_build_zone_images, [x], tf.bool))

  def call(self, inputs):
    # Call build_zone_images() on the inputs cast to int32
    images = self.py_func(tf.cast(inputs, dtype=tf.int32))
    # The output shape of a numpy_function is unknown to TensorFlow; restore it
    images.set_shape((None,) + self.image_format)
    return images

class Pooling(k_layers.Layer):
  """Keras layer that matches zone images against the patterns of one zone.

  The layer convolves the image with fixed (non-trainable) pattern kernels,
  packs the per-row results into one integer per column and pattern, maps that
  integer to a quality through a fixed (non-trainable) embedding table, and
  finally returns, for every column, the best quality together with the index
  of the pattern that produced it.
  """

  def __init__(self, zone, image_format=image_format, num_patterns=num_emtf_patterns, **kwargs):
    super(Pooling, self).__init__(**kwargs)
    self.zone = zone
    self.image_format = image_format
    self.num_patterns = num_patterns

    # SeparableConv2D but only the depthwise conv (i.e. without the pointwise conv)
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/DepthwiseConv2D
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/SeparableConv2D
    # The depthwise kernel is initialized from the module-level boxes_act_reshaped
    # for this zone and is frozen.
    from k_layers_separable_conv2d import SeparableConv2D as MySeparableConv2D
    w_init = keras.initializers.Constant(boxes_act_reshaped[self.zone])
    conv2d_kwargs = dict(filters=1, kernel_size=(boxes_act_reshaped.shape[1], boxes_act_reshaped.shape[2]), depth_multiplier=self.num_patterns,
                         strides=(1, 1), padding='same', activation=None, use_bias=False,
                         depthwise_initializer=w_init, pointwise_initializer='ones', trainable=False)
    self.conv2d = MySeparableConv2D(**conv2d_kwargs)

    # Embedding
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
    # Used as a frozen lookup table from the packed row bits (one of 2**num_rows
    # values) to a quality taken from the module-level boxes_qual_reshaped.
    w_init = keras.initializers.Constant(boxes_qual_reshaped[self.zone])
    num_embedding_input_dim = (2 ** num_rows)
    embedding_kwargs = dict(input_dim=num_embedding_input_dim, output_dim=1, input_length=1,
                            embeddings_initializer=w_init, trainable=False)
    self.embedding = k_layers.Embedding(**embedding_kwargs)

    # Dot product coeffs for packing the last axis
    self.po2_coeffs = (2 ** np.arange(self.image_format[0]))  # [1,2,4,8,16,32,64,128]
    self.po2_coeffs = self.po2_coeffs.astype(np.int32)

  def call(self, inputs):
    """Return `(x, idx_h)`: best quality per column and the index of the winning pattern."""
    # Conv
    x = inputs  # NHWC, which is (None, 8, 288, 1)
    x = tf.cast(x, dtype=tf.float32)
    x = tf.transpose(x, perm=(0, 3, 2, 1))  # NHWC -> NCWH
    x = self.conv2d(x)  # NCWH -> NCWH', H' is dim of size H * D, D is depth_multipler
    x = tf.reshape(x, [-1, self.image_format[2], self.image_format[1], self.image_format[0], self.num_patterns])  # NCWH' -> NCWHD
    x = tf.transpose(x, perm=(0, 1, 2, 4, 3))  # NCWHD -> NCWDH
    x = tf.reduce_sum(x, axis=1)  # NCWDH -> NWDH, C is dim of size 1 and has been dropped

    # Pack 8 bits into a single number
    # Clipping to [0, 1] turns every convolution result into a single bit first
    x = tf.clip_by_value(x, 0, 1)
    x = tf.cast(x, dtype=tf.int32)
    x = tf.reduce_sum(x * self.po2_coeffs, axis=-1)  # NWDH -> NWD, H has been packed into a single number and dropped
    x = tf.cast(x, dtype=tf.float32)

    # Embedding
    x = self.embedding(x)  # NWD -> NWDE, E is embedding output dim
    x = tf.reduce_sum(x, axis=-1)  # NWDE -> NWD, E is dim of size 1 and has been dropped
    x = tf.cast(x, dtype=tf.int32)

    # Gather max element
    idx_h = tf.argmax(x, axis=-1, output_type=tf.int32)  # NWD -> NW
    x = tf.gather(x, idx_h, axis=-1, batch_dims=2)  # NWD -> NW
    return (x, idx_h)

class Suppression(k_layers.Layer):
  """Keras layer applying non-max suppression along the last axis.

  A value is kept only when it is strictly greater than its left neighbour
  and greater than or equal to its right neighbour; all other values are set
  to zero. The borders are compared against zero padding.
  """

  def __init__(self, zone, image_format=image_format, **kwargs):
    super(Suppression, self).__init__(**kwargs)
    self.zone = zone
    self.image_format = image_format

  def call(self, inputs):
    x, idx_h = inputs

    # Pad one zero on each side of the last axis: ((pad_t, pad_b), (pad_l, pad_r))
    padded = tf.pad(x, paddings=((0, 0), (1, 1)))
    left_neighbour = padded[:, :-2]
    right_neighbour = padded[:, 2:]

    # Non-max suppression: x > x_left && x >= x_right
    is_peak = tf.logical_and(tf.greater(x, left_neighbour),
                             tf.greater_equal(x, right_neighbour))
    x = x * tf.cast(is_peak, dtype=x.dtype)
    return (x, idx_h)

class ZoneSorting(k_layers.Layer):
  """Keras layer keeping the `num_tracks` best entries of one zone.

  Returns the selected values, the matching `idx_h` entries and the positions
  (`idx_w`) the values were taken from, ordered by descending value with a
  stable sort.
  """

  def __init__(self, zone, num_tracks=num_emtf_tracks, **kwargs):
    super(ZoneSorting, self).__init__(**kwargs)
    self.zone = zone
    self.num_tracks = num_tracks

  def call(self, inputs):
    x, idx_h = inputs

    # Sort (descending) and keep the first num_tracks positions of every row
    order = tf.argsort(x, axis=-1, direction='DESCENDING', stable=True)
    order.set_shape(x.shape)
    idx_w = order[:, :self.num_tracks]  # truncate

    # Gather max elements: NW -> NW', W' is dim of size num_tracks
    best_x = tf.gather(x, idx_w, axis=-1, batch_dims=1)
    best_idx_h = tf.gather(idx_h, idx_w, axis=-1, batch_dims=1)
    return (best_x, best_idx_h, idx_w)

class ZoneMerging(k_layers.Layer):
  """Keras layer keeping the `num_tracks` best entries across the concatenated zones.

  The inputs are the per-zone outputs concatenated along the last axis. The
  zone of every selected entry is recovered by dividing its position in the
  concatenated axis by `num_tracks`.
  """

  def __init__(self, num_tracks=num_emtf_tracks, **kwargs):
    super(ZoneMerging, self).__init__(**kwargs)
    self.num_tracks = num_tracks

  def call(self, inputs):
    x, idx_h, idx_w = inputs

    # Sort (descending) and keep the first num_tracks positions of every row
    order = tf.argsort(x, axis=-1, direction='DESCENDING', stable=True)
    order.set_shape(x.shape)
    selected = order[:, :self.num_tracks]  # truncate

    # Gather max elements: NW' -> NW", W" is dim of size num_tracks
    best_x = tf.gather(x, selected, axis=-1, batch_dims=1)
    best_idx_h = tf.gather(idx_h, selected, axis=-1, batch_dims=1)
    best_idx_w = tf.gather(idx_w, selected, axis=-1, batch_dims=1)

    # Position in the concatenated axis -> zone index
    idx_z = selected // self.num_tracks
    return (best_x, best_idx_h, best_idx_w, idx_z)

class TrkBuilding(k_layers.Layer):
  """Keras layer that builds the track candidate features.

  The work is delegated to build_track_cands(), which runs as NumPy code
  through tf.numpy_function with `num_features` bound at construction time.
  """

  def __init__(self, num_features=num_emtf_features, **kwargs):
    super(TrkBuilding, self).__init__(**kwargs)
    self.num_features = num_features

    # Set up to call build_track_cands()
    import functools
    bound_build_track_cands = functools.partial(build_track_cands, num_features=self.num_features)
    self.py_func = k_layers.Lambda(
        lambda x: tf.numpy_function(bound_build_track_cands, x, tf.int32))

  def call(self, inputs):
    x, x_patt, idx_h, idx_w, idx_z = inputs

    # Call build_track_cands() with the hits cast to int32
    hits = tf.cast(x, dtype=tf.int32)
    track_cands = self.py_func((hits, x_patt, idx_h, idx_w, idx_z))
    # The output shape of a numpy_function is unknown to TensorFlow; restore it
    track_cands.set_shape((None,) + (x_patt.shape[1], self.num_features))
    return track_cands

In [16]:
def create_model():
  """Build, compile and summarise the functional Keras model.

  For each of the `num_emtf_zones` zones the input goes through
  Zoning -> Pooling -> Suppression -> ZoneSorting. The three per-zone outputs
  are concatenated across zones along the last axis, reduced by ZoneMerging,
  and passed together with the raw input to TrkBuilding.

  Reads the module-level `timezone` setting. Prints the model summary and the
  number of (trainable) weights, then returns the `keras.Model`.
  """
  # Input
  inputs = keras.Input(shape=(num_emtf_chambers, num_emtf_segments, num_emtf_variables), name='inputs')

  def _zone_branch(i):
    # Make zone images
    branch = Zoning(zone=i, timezone=timezone, name='zoning_{0}'.format(i))(inputs)
    # Pattern recognition
    branch = Pooling(zone=i, name='pooling_{0}'.format(i))(branch)
    branch = Suppression(zone=i, name='suppression_{0}'.format(i))(branch)
    # Sort zone outputs
    return ZoneSorting(zone=i, name='zonesorting_{0}'.format(i))(branch)

  # Loop over zones
  x_list = [_zone_branch(i) for i in range(num_emtf_zones)]

  # Merge zone outputs. The merging/building layers are tagged with the first zone index.
  i = next(iter(range(num_emtf_zones)))
  per_component = tuple(
      k_layers.Concatenate(axis=-1)([zone_out[k] for zone_out in x_list])
      for k in range(3))
  merged = ZoneMerging(name='zonemerging_{0}'.format(i))(per_component)

  # Track builder: raw inputs go in front of the merged tuple
  outputs = TrkBuilding(name='trkbuilding_{0}'.format(i))((inputs,) + merged)

  # Model
  model = keras.Model(inputs=inputs, outputs=outputs, name='awesome_model')
  model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")

  # Summary
  model.summary()
  print('trainable weights:', len(model.trainable_weights))
  print('all weights:', len(model.weights))
  return model

# Build the model once at notebook scope; the prediction and debug cells below use this `model`.
model = create_model()

Model: "awesome_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inputs (InputLayer)             [(None, 115, 2, 13)] 0                                            
__________________________________________________________________________________________________
zoning_0 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
__________________________________________________________________________________________________
zoning_1 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
__________________________________________________________________________________________________
zoning_2 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
______________________________________________________________________________________

### Evaluate model

In [17]:
class DataGenerator(keras.utils.Sequence):
  """Serve `x` in consecutive, fixed-size batches (inputs only, no targets).

  Args:
    x: array-like exposing `.shape[0]` and supporting `x[start:stop]` slicing.
    batch_size: number of samples per batch. When falsy (None or 0) it is
      derived from `steps`, or defaults to 32 if `steps` is falsy too.
    steps: desired number of batches; only used when `batch_size` is falsy.

  Raises:
    ValueError: if `batch_size` or `steps` is negative.
  """

  def __init__(self, x, batch_size=None, steps=None):
    # The base class used here needs no arguments; calling it keeps the base
    # initialised should it carry state in other Keras versions.
    super(DataGenerator, self).__init__()
    self.x = x
    self.num_samples = int(x.shape[0])
    if not batch_size:
      if steps:
        if steps < 0:
          raise ValueError('steps must be positive, got {0}'.format(steps))
        # Clamp to 1: for an empty `x` the ceil() yields 0, which would make
        # the num_batches division below raise ZeroDivisionError.
        batch_size = max(1, int(np.ceil(self.num_samples / float(steps))))
      else:
        batch_size = 32
    if batch_size < 0:
      # A negative size would otherwise surface later as a negative __len__.
      raise ValueError('batch_size must be positive, got {0}'.format(batch_size))
    self.batch_size = batch_size
    self.num_batches = int(np.ceil(self.num_samples / float(batch_size)))

  def __len__(self):
    """Number of batches (0 for an empty `x`)."""
    return self.num_batches

  def __getitem__(self, index):
    """Return batch `index` as the slice `x[start:stop]`; the last batch may be shorter."""
    start = index * self.batch_size
    stop = min(self.num_samples, start + self.batch_size)
    return self.x[start:stop]

In [18]:
# Make predictions
if maxevents != -1:
  # Let predict() batch the in-memory array itself.
  outputs = model.predict(inputs, batch_size=10000)
else:
  # maxevents == -1 (presumably "no event limit" - confirm against the settings cell):
  # feed the data through DataGenerator batches instead.
  outputs = model.predict(DataGenerator(inputs, batch_size=10000), workers=workers, use_multiprocessing=False)  # now wait...
logger.info('outputs: {0} type: {1}'.format(outputs.shape, type(outputs)))

[INFO    ] outputs: (435533, 4, 40) type: <class 'numpy.ndarray'>


In [19]:
# Write to file
# Only done when maxevents == -1. The output file name is derived from the
# module-level `zone` setting.
if maxevents == -1:
  outfile = 'features_zone{0}.h5'.format(zone)
  # 'features' keeps only index 0 along axis 1 of the predictions (the first
  # entry per sample); 'truths' is `zone_part_l`, defined in an earlier cell.
  # NOTE(review): `outputs` must still be the predict() result here - the debug
  # cell below rebinds the same name, so this cell has to run before it.
  outdict = {'features': da.from_array(outputs[:, 0, :]), 'truths': zone_part_l}
  da.to_hdf5(outfile, outdict, compression='lzf')
  logger.info('Wrote to {0}'.format(outfile))

[INFO    ] Wrote to features_zone2.h5


In [20]:
# Debug
# Inspect the intermediate layer outputs on the first 10 samples by building
# one sub-model per named layer, all sharing the full model's input.
# NOTE: this cell rebinds `outputs` (previously the predict() result) and `x`.
tiny_inputs = inputs[:10]
model_zoning_0 = keras.Model(inputs=model.input, outputs=model.get_layer('zoning_0').output)
model_pooling_0 = keras.Model(inputs=model.input, outputs=model.get_layer('pooling_0').output)
model_suppression_0 = keras.Model(inputs=model.input, outputs=model.get_layer('suppression_0').output)
model_zonesorting_0 = keras.Model(inputs=model.input, outputs=model.get_layer('zonesorting_0').output)
model_zonemerging_0 = keras.Model(inputs=model.input, outputs=model.get_layer('zonemerging_0').output)
model_trkbuilding_0 = keras.Model(inputs=model.input, outputs=model.get_layer('trkbuilding_0').output)

# NOTE(review): the model summary lists zoning_0's output as a single
# (None, 8, 288, 1) tensor, so `outputs[0]` here selects the first sample and
# x[0]/x[2]/x[5] index its next axis. For layers that return a tuple (e.g.
# zonemerging_0) `outputs[0]` selects the first tuple element instead and
# x[0]/x[2]/x[5] are samples. Confirm which one is intended for zoning.
outputs = model_zoning_0(tiny_inputs)
x = tf_constant_value(outputs[0])
with np.printoptions(linewidth=100, threshold=1000):
  print('zoning_0_out:')
  print(x[0].nonzero())
  print(x[2].nonzero())
  print(x[5].nonzero())

# Pooling: print only the positions of the nonzero entries.
outputs = model_pooling_0(tiny_inputs)
x = tf_constant_value(outputs[0])
with np.printoptions(linewidth=100, threshold=1000):
  print('pooling_0_out:')
  print(x[0].nonzero())
  print(x[2].nonzero())
  print(x[5].nonzero())

# From here on, print the nonzero positions together with their values.
outputs = model_suppression_0(tiny_inputs)
x = tf_constant_value(outputs[0])
with np.printoptions(linewidth=100, threshold=1000):
  print('suppression_0_out:')
  print(x[0].nonzero(), x[0][x[0].nonzero()])
  print(x[2].nonzero(), x[2][x[2].nonzero()])
  print(x[5].nonzero(), x[5][x[5].nonzero()])

outputs = model_zonesorting_0(tiny_inputs)
x = tf_constant_value(outputs[0])
with np.printoptions(linewidth=100, threshold=1000):
  print('zonesorting_0_out:')
  print(x[0].nonzero(), x[0][x[0].nonzero()])
  print(x[2].nonzero(), x[2][x[2].nonzero()])
  print(x[5].nonzero(), x[5][x[5].nonzero()])

# ZoneMerging returns (x, idx_h, idx_w, idx_z); only the first element is shown.
outputs = model_zonemerging_0(tiny_inputs)
x = tf_constant_value(outputs[0])
with np.printoptions(linewidth=100, threshold=1000):
  print('zonemerging_0_out:')
  print(x[0].nonzero(), x[0][x[0].nonzero()])
  print(x[2].nonzero(), x[2][x[2].nonzero()])
  print(x[5].nonzero(), x[5][x[5].nonzero()])

# TrkBuilding returns a single tensor, so the whole output is converted.
outputs = model_trkbuilding_0(tiny_inputs)
x = tf_constant_value(outputs)
# Replace `ma_fill_value` entries by 0 (in place) before printing;
# presumably the fill value marks invalid/missing entries - confirm.
x[x == ma_fill_value] = 0
with np.printoptions(linewidth=100, threshold=1000):
  print('trkbuilding_0_out:')
  print(x[0])
  print(x[2])
  print(x[5])

zoning_0_out:
(array([], dtype=int64), array([], dtype=int64))
(array([], dtype=int64), array([], dtype=int64))
(array([], dtype=int64), array([], dtype=int64))
pooling_0_out:
(array([], dtype=int64),)
(array([], dtype=int64),)
(array([], dtype=int64),)
suppression_0_out:
(array([], dtype=int64),) []
(array([], dtype=int64),) []
(array([], dtype=int64),) []
zonesorting_0_out:
(array([], dtype=int64),) []
(array([], dtype=int64),) []
(array([], dtype=int64),) []
zonemerging_0_out:
(array([0, 1, 2, 3]),) [63 53 53 21]
(array([0, 1, 2, 3]),) [63 53 49 18]
(array([0, 1, 2, 3]),) [63 53 49 18]
trkbuilding_0_out:
[[   0  100    8    1   17   96   24    4   24    0    0    0    0    3    2    0   -1   -6    6
    -2    2    0    0    0    0   -7    0    1    3    0    0    5    5   -5   -4    0 1824   62
    63    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0