In [1]:
import sys
import os

if not sys.version_info >= (3,):
  # This notebook requires Python 3; abort early with an explicit message.
  print('[ERROR] You need to run this with Python 3.')
  raise AssertionError

In [2]:
import numpy as np

from emtf_algos import *
from emtf_logger import *

In [3]:
# Set random seed
np.random.seed(2027)

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers as k_layers
from tensorflow.keras import backend as k_backend
import matplotlib as mpl
import matplotlib.pyplot as plt

# Set random seed
tf.random.set_seed(2027)

#import numba
#from numba import njit, vectorize
#import dask
#import dask.array as da

logger = get_logger()
logger.info('Using cmssw      : {0}'.format(os.environ['CMSSW_VERSION'] if 'CMSSW_VERSION' in os.environ else 'n/a'))
logger.info('Using python     : {0}'.format(sys.version.replace('\n', '')))
logger.info('Using numpy      : {0}'.format(np.__version__))
logger.info('Using tensorflow : {0}'.format(tf.__version__))
logger.info('Using keras      : {0}'.format(keras.__version__))
logger.info('.. list devices  : {0}'.format(tf.config.list_physical_devices()))
logger.info('Using matplotlib : {0}'.format(mpl.__version__))
#logger.info('Using numba      : {0}'.format(numba.__version__))
#logger.info('Using dask       : {0}'.format(dask.__version__))

assert k_backend.backend() == 'tensorflow'
assert k_backend.image_data_format() == 'channels_last'

%matplotlib inline

[INFO    ] Using cmssw      : CMSSW_10_6_3
[INFO    ] Using python     : 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21) [GCC 7.3.0]
[INFO    ] Using numpy      : 1.19.1
[INFO    ] Using tensorflow : 2.2.0
[INFO    ] Using keras      : 2.3.0-tf
[INFO    ] .. list devices  : [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU')]
[INFO    ] Using matplotlib : 3.2.2


In [4]:
# Settings

# zone: (0,1,2) -> eta=(1.98..2.5, 1.55..1.98, 1.2..1.55)
zone = 0  # other choices: 1, 2

# timezone: (0,1,2) -> BX=(-1,0,+1)
timezone = 1

# Number of events to process (create_inputs() treats -1 as "no limit")
maxevents = 10

# Input files for the selected zone
patterns_fname = 'patterns_zone{0}.npz'.format(zone)
zone_images_fname = 'zone_images_zone{0}.h5'.format(zone)

logger.info('Processing zone {0} timezone {1}'.format(zone, timezone))
logger.info('.. maxevents        : {0}'.format(maxevents))

[INFO    ] Processing zone 0 timezone 1
[INFO    ] .. maxevents        : 10


In [5]:
# Styling
plt.style.use('tdrstyle.mplstyle')

# Color maps
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

# viridis whose 'under' color is white, so values below vmin render as white
viridis_mod = ListedColormap(plt.cm.viridis.colors, name='viridis_mod')
viridis_mod.set_under('w', 1)

# blue_hot: black -> blue -> cyan -> white
cdict = dict(
  red=((0.0, 0.0, 0.0), (0.746032, 0.0, 0.0), (1.0, 1.0, 1.0)),
  green=((0.0, 0.0, 0.0), (0.365079, 0.0, 0.0), (0.746032, 1.0, 1.0), (1.0, 1.0, 1.0)),
  blue=((0.0, 0.0416, 0.0416), (0.365079, 1.0, 1.0), (1.0, 1.0, 1.0)),
)
blue_hot = LinearSegmentedColormap('blue_hot', cdict)

# green_hot: black -> green -> cyan -> white
cdict = dict(
  red=((0.0, 0.0, 0.0), (0.746032, 0.0, 0.0), (1.0, 1.0, 1.0)),
  green=((0.0, 0.0416, 0.0416), (0.365079, 1.0, 1.0), (1.0, 1.0, 1.0)),
  blue=((0.0, 0.0, 0.0), (0.365079, 0.0, 0.0), (0.746032, 1.0, 1.0), (1.0, 1.0, 1.0)),
)
green_hot = LinearSegmentedColormap('green_hot', cdict)

# red_binary: white -> red
cdict = dict(
  red=((0.0, 1.0, 1.0), (1.0, 1.0, 1.0)),
  green=((0.0, 1.0, 1.0), (1.0, 0.0, 0.0)),
  blue=((0.0, 1.0, 1.0), (1.0, 0.0, 0.0)),
)
red_binary = LinearSegmentedColormap('red_binary', cdict)

### Load data

In [6]:
def load_patterns():
  """Load the pattern definitions of every EMTF zone and stack them.

  The file for zone `izone` is derived from the global `patterns_fname` by
  replacing 'zone<zone>' (the currently selected zone) with 'zone<izone>'.

  Returns:
    (patterns, boxes_act, hitmap_quality): numpy arrays built from the per-zone
    'patterns', 'boxes_act' and 'hitmap_quality_ranks' entries, stacked along a
    new leading zone axis.
  """
  collected = {'patterns': [], 'boxes_act': [], 'hitmap_quality_ranks': []}
  for izone in range(num_emtf_zones):
    fname = patterns_fname.replace('zone%i' % zone, 'zone%i' % izone)
    logger.info('Loading from {0}'.format(fname))
    with np.load(fname) as loaded:
      for key, per_zone in collected.items():
        per_zone.append(loaded[key])

  patterns = np.asarray(collected['patterns'])
  boxes_act = np.asarray(collected['boxes_act'])
  hitmap_quality = np.asarray(collected['hitmap_quality_ranks'])  # hitmap_quality_ranks -> hitmap_quality
  logger.info('patterns: {0} boxes_act: {1} hitmap_quality: {2}'.format(patterns.shape, boxes_act.shape, hitmap_quality.shape))
  return patterns, boxes_act, hitmap_quality

import h5py
loaded_h5 = None  # cached h5py.File handle, kept open so the returned datasets stay readable
loaded_h5_fname = None  # file name that `loaded_h5` was opened from

def _get_loaded_h5(fname):
  """Return a read-only h5py.File for `fname`, reusing the cached handle if it is for the same file.

  The previous version opened the file only once and then ignored `fname`, so a later
  call with a different file silently returned data from the first one.
  The handle is intentionally not closed: callers keep h5py datasets that read from it.
  """
  global loaded_h5, loaded_h5_fname
  if loaded_h5 is None or loaded_h5_fname != fname:
    logger.info('Loading from {0}'.format(fname))
    loaded_h5 = h5py.File(fname, 'r')
    loaded_h5_fname = fname
  return loaded_h5

def load_zone_sparse_images(fname):
  """Load the zone box anchors and the sparse zone images from the h5 file `fname`.

  Returns:
    (zone_box_anchors, zone_sparse_images): the 'zone_box_anchors' dataset and a
    SparseTensorValue wrapping the 'zone_sparse_images_*' datasets.
  """
  h5 = _get_loaded_h5(fname)
  zone_box_anchors = h5['zone_box_anchors']
  zone_sparse_images = SparseTensorValue(indices=h5['zone_sparse_images_indices'],
                                         values=h5['zone_sparse_images_values'],
                                         dense_shape=h5['zone_sparse_images_dense_shape'])
  logger.info('zone_box_anchors: {0} zone_sparse_images: {1}'.format(zone_box_anchors.shape, zone_sparse_images.dense_shape))
  return zone_box_anchors, zone_sparse_images

def load_zone_hits(fname):
  """Load the particle info, hits and sim-hits from the h5 file `fname`.

  Returns:
    (zone_part, zone_hits, zone_simhits): the 'zone_part' dataset and two
    RaggedTensorValue objects built from the '*_values' / '*_row_splits' datasets.
  """
  h5 = _get_loaded_h5(fname)
  zone_part = h5['zone_part']
  zone_hits = RaggedTensorValue(values=h5['zone_hits_values'],
                                row_splits=h5['zone_hits_row_splits'])
  zone_simhits = RaggedTensorValue(values=h5['zone_simhits_values'],
                                   row_splits=h5['zone_simhits_row_splits'])
  logger.info('zone_part: {0} zone_hits: {1} zone_simhits: {2}'.format(zone_part.shape, zone_hits.shape, zone_simhits.shape))
  return zone_part, zone_hits, zone_simhits

In [7]:
def sparse_to_dense_quick(sparse, maxevents, chunk_size=4096):
  """Densify the first `maxevents` events of a sparse tensor.

  Entries are scanned in order and the scan stops at the first entry whose event
  index (first index column) is >= maxevents. This assumes `sparse.indices` is
  sorted by event index — TODO confirm for every producer of these files.

  Indices/values are read `chunk_size` rows at a time. When they are h5py datasets,
  this avoids issuing one HDF5 read per entry.

  Args:
    sparse: SparseTensorValue-like object with `indices` (nnz, ndim), `values` (nnz, ...),
      `dense_shape` (tuple or array) and `dtype`.
    maxevents: number of events to keep; a negative value (e.g. -1) keeps all events
      (previously -1 crashed with a negative dimension in np.zeros).
    chunk_size: number of sparse entries read per step (must be > 0).

  Returns:
    numpy array of shape (maxevents,) + dense_shape[1:] with dtype `sparse.dtype`.
  """
  if maxevents < 0:
    maxevents = int(sparse.dense_shape[0])
  dense_shape = (maxevents,) + tuple(int(d) for d in sparse.dense_shape[1:])
  dense = np.zeros(dense_shape, dtype=sparse.dtype)

  nnz = len(sparse.indices)
  for start in range(0, nnz, chunk_size):
    idx = np.asarray(sparse.indices[start:start + chunk_size])
    val = np.asarray(sparse.values[start:start + chunk_size])
    # Keep the entries before the first one that belongs to an event >= maxevents
    past_end = idx[:, 0] >= maxevents
    stop = int(np.argmax(past_end)) if past_end.any() else len(idx)
    if stop > 0:
      dense[tuple(idx[:stop].T)] = val[:stop]
    if stop < len(idx):
      break
  return dense

In [8]:
patterns, boxes_act, hitmap_quality = load_patterns()

# Reshape boxes_act: per zone, keep only the prompt patterns (index 3 on axis 1),
# ordered by straightness, then transpose so that the kernel shape is HWCX
boxes_act_reshaped = np.transpose(
    np.asarray([boxes_act[izone, 3, [3, 2, 4, 1, 5, 0, 6]] for izone in range(num_emtf_zones)]),
    [0, 4, 3, 2, 1])
logger.info('boxes_act_reshaped: {0}'.format(boxes_act_reshaped.shape))

# Modify hitmap_quality: drop the two lowest bits (8-bit -> 6-bit)
hitmap_quality_reshaped = hitmap_quality // 4
assert hitmap_quality_reshaped.max() == 63
logger.info('hitmap_quality_reshaped: {0}'.format(hitmap_quality_reshaped.shape))

[INFO    ] Loading from patterns_zone0.npz
[INFO    ] Loading from patterns_zone1.npz
[INFO    ] Loading from patterns_zone2.npz
[INFO    ] patterns: (3, 7, 7, 8, 3) boxes_act: (3, 7, 7, 8, 111, 1) hitmap_quality: (3, 256)
[INFO    ] boxes_act_reshaped: (3, 1, 111, 8, 7)
[INFO    ] hitmap_quality_reshaped: (3, 256)


In [9]:
zone_box_anchors, zone_sparse_images = load_zone_sparse_images(zone_images_fname)

# Densify the test events. maxevents == -1 means "all events"; slicing with
# [:maxevents] in that case would silently drop the last event, so use [:] instead.
zone_images_test = sparse_to_dense_quick(zone_sparse_images, maxevents)
zone_box_anchors_test = zone_box_anchors[:] if maxevents == -1 else zone_box_anchors[:maxevents]
logger.info('zone_box_anchors_test: {0} zone_images_test: {1}'.format(zone_box_anchors_test.shape, zone_images_test.shape))

[INFO    ] Loading from zone_images_zone0.h5
[INFO    ] zone_box_anchors: (652055,) zone_sparse_images: (652055, 8, 288, 1)
[INFO    ] zone_box_anchors_test: (10,) zone_images_test: (10, 8, 288, 1)


In [10]:
zone_part, zone_hits, zone_simhits = load_zone_hits(fname=zone_images_fname)  # same h5 file as the zone images

[INFO    ] zone_part: (652055, 9) zone_hits: (652055, None, 16) zone_simhits: (652055, None, 16)


### Create inputs

In [11]:
# Image geometry (rows, cols, channels), taken from the test zone images
image_format = zone_images_test.shape[1:]
num_rows, num_cols, num_channels = image_format

# Model dimensions
num_patterns = 7
num_emtf_out_tracks = 4
num_emtf_out_variables = 36
num_embedding_input_dim = 2 ** num_rows  # number of distinct num_rows-bit packed hitmaps

# Column index of each hit variable in zone_hits
hits_metadata = {name: col for col, name in enumerate([
    'emtf_layer', 'ri_layer', 'zones', 'timezones',
    'emtf_chamber', 'emtf_segment', 'detlayer', 'bx',
    'emtf_phi', 'emtf_bend', 'emtf_theta', 'emtf_theta_alt',
    'emtf_qual', 'emtf_time', 'fr', 'rsvd'])}

# Position columns
ind_emtf_chamber = hits_metadata['emtf_chamber']
ind_emtf_segment = hits_metadata['emtf_segment']

# Value columns
ind_emtf_phi = hits_metadata['emtf_phi']
ind_emtf_bend = hits_metadata['emtf_bend']
ind_emtf_theta = hits_metadata['emtf_theta']
ind_emtf_theta_alt = hits_metadata['emtf_theta_alt']
ind_emtf_qual = hits_metadata['emtf_qual']
ind_emtf_time = hits_metadata['emtf_time']
ind_zones = hits_metadata['zones']
ind_timezones = hits_metadata['timezones']
ind_bx = hits_metadata['bx']
ind_valid = hits_metadata['rsvd']  # CUIDADO: 'rsvd' stands in for the valid flag for now

In [12]:
def create_inputs():
  """Build dense and sparse per-event network inputs from the global `zone_hits`.

  For each event (the first `maxevents`, or all when maxevents == -1):
    - position columns: (emtf_chamber, emtf_segment)
    - value columns: (emtf_phi, emtf_bend, emtf_theta, emtf_theta_alt, emtf_qual,
      emtf_time, zones, timezones, bx, valid); the valid column is forced to 1.
  Hits whose segment index is >= num_emtf_segments are dropped.

  Returns:
    inputs: array of shape (num_events, num_emtf_chambers, num_emtf_segments, num_emtf_variables).
    sparse_inputs: list with one (num_hits, 12) array per event, the two position
      columns followed by the ten value columns.
  """
  # Loop-invariant setup, built once instead of once per event
  dense_shape = np.array([num_emtf_chambers, num_emtf_segments, num_emtf_variables], dtype=np.int32)
  index_columns = [ind_emtf_chamber, ind_emtf_segment]
  value_columns = [ind_emtf_phi, ind_emtf_bend, ind_emtf_theta, ind_emtf_theta_alt, ind_emtf_qual, ind_emtf_time,
                   ind_zones, ind_timezones, ind_bx, ind_valid]

  inputs = []
  sparse_inputs = []
  for ievt in range(zone_hits.shape[0]):
    if maxevents != -1 and ievt == maxevents:
      break

    # Read the event only once (zone_hits may be backed by datasets in an h5 file)
    hits = zone_hits[ievt]
    indices = hits[:, index_columns]
    values = hits[:, value_columns]  # fancy indexing returns a copy, so it is safe to modify
    values[:, -1] = 1  # CUIDADO: set valid to 1

    # Apply truncation: drop hits beyond the last supported segment
    valid = indices[:, 1] < num_emtf_segments
    indices = indices[valid]
    values = values[valid]

    # Mimic sparse_to_dense(): scatter each hit's values to [chamber, segment, :]
    dense = np.zeros(dense_shape, dtype=values.dtype)
    dense[tuple(indices.T)] = values
    inputs.append(dense)
    sparse_inputs.append(np.concatenate((indices, values), axis=-1))
  return np.asarray(inputs), sparse_inputs

In [13]:
inputs, sparse_inputs = create_inputs()
logger.info('inputs: {0} sparse_inputs: ({1}, None, {2})'.format(
    inputs.shape, len(sparse_inputs), sparse_inputs[0].shape[-1]))

[INFO    ] inputs: (10, 115, 8, 10) sparse_inputs: (10, None, 12)


In [14]:
# Debug: print the sparse inputs of two example events
for ievt in (0, 2):
  print(np.array2string(sparse_inputs[ievt], separator=', ', threshold=1000))

[[   2,    0, 2548,    5,   18,   17,   -6,    0,    4,    3,    0,    1],
 [   2,    1, 2548,    5,   17,   18,   -6,    0,    4,    3,    0,    1],
 [  19,    0, 2684,    2,   16,   16,    6,    0,    4,    3,    0,    1],
 [  28,    0, 2819,   15,   17,   16,   -5,    0,    4,    3,    0,    1],
 [  28,    1, 2728,    0,   16,   17,   -5,    0,    4,    3,    0,    1],
 [  37,    0, 2736,    0,   16,   16,   -5,    0,    4,    3,    0,    1],
 [  55,    0, 2505,    0,   18,   18,    0,    0,    4,    2,    0,    1],
 [  73,    0, 2675,    0,   19,   19,    0,    0,    4,    2,    0,    1],
 [  82,    0, 2888,    0,   17,   17,    0,    1,    4,    2,    0,    1],
 [  82,    1, 2714,    0,   17,   17,    0,    1,    4,    2,    0,    1],
 [  91,    0, 2737,    0,   17,   17,    0,    0,    4,    2,    0,    1],
 [ 109,    0, 2479,   15,   17,   17,    6,    0,    4,    2,    0,    1]]
[[  29,    0, 4503,    1,   11,   11,    5,    0,    4,    6,   -1,    1],
 [  29,    1, 4503,    1,

### Create model

In [15]:
def build_zone_images(x, zone=zone, image_format=image_format):
  """Convert per-chamber hit inputs into boolean zone images.

  Args:
    x: integer array indexed as x[ievt][chamber, segment, variable]. The variable axis
      follows the value-column order of create_inputs(); this function reads
      variable 0 (emtf_phi), 6 (zones), 7 (timezones) and 9 (valid).
    zone: zone index; selects the zone bit (1 << (2 - zone)) and the zone column of
      find_emtf_zo_layer_lut().
    image_format: (rows, cols, channels) of each output image.

  The timezone bit (1 << (2 - timezone)) uses the global `timezone`, read at call time.

  Returns:
    Boolean array of shape (num_events,) + image_format. For every hit that is valid
    and has both its zone and timezone bits set, pixel (zo_layer, zo_phi, 0) is set.
    zo_layer is the image row the hit's chamber maps to (chamber -> ri_layer -> zo_layer
    LUTs), and zo_phi = find_emtf_zo_phi(emtf_phi). zo_phi is assumed to be an integer
    column index within image_format[1] — TODO confirm.
  """
  # Utility functions & LUTs
  def inverse_fn(F, y):
    # For each value y_j in y, list the positions i where F[i] == y_j
    return [[i for (i, y_i) in enumerate(F) if y_i == y_j] for y_j in y]

  def to_array(seqs):
    # Rectangular input -> regular array; ragged input -> 1-D object array of arrays
    ragged = len(set(len(s) for s in seqs)) > 1
    if ragged:
      return np.asarray([np.asarray(s) for s in seqs], dtype=object)
    return np.asarray(list(seqs))

  num_emtf_ri_layers = 19
  ri_layer_to_chamber_lut = to_array(inverse_fn(chamber_to_ri_layer_lut, range(num_emtf_ri_layers)))

  num_emtf_zo_layers = 8
  ri_layer_to_zo_layer_lut = find_emtf_zo_layer_lut()[:, zone]
  zo_layer_to_ri_layer_lut = to_array(inverse_fn(ri_layer_to_zo_layer_lut, range(num_emtf_zo_layers)))

  zo_layer_to_chamber_lut = to_array([
      [c for ri_layer in ri_layers for c in ri_layer_to_chamber_lut[ri_layer]]
      for ri_layers in zo_layer_to_ri_layer_lut
  ])

  # Chamber selection mask of each image row. These do not depend on the event,
  # so they are built once here rather than once per event and row.
  chamber_masks = []
  for zo_layer in range(image_format[0]):
    mask = np.zeros(num_emtf_chambers, dtype=bool)
    mask[np.asarray(zo_layer_to_chamber_lut[zo_layer], dtype=np.intp)] = True
    chamber_masks.append(mask)

  zone_bit = 1 << (2 - zone)
  timezone_bit = 1 << (2 - timezone)

  # Prepare zone images
  num_events = x.shape[0]
  zone_images = np.zeros((num_events,) + tuple(image_format), dtype=bool)

  # Loop over events
  for ievt in range(num_events):
    x_evt = x[ievt]
    x_emtf_phi = x_evt[..., 0]
    x_zones = x_evt[..., 6]
    x_timezones = x_evt[..., 7]
    x_valid = x_evt[..., 9]

    # A hit is used if it is valid and both its zone and timezone bits are set
    valid = (x_valid == 1) & ((x_zones & zone_bit) != 0) & ((x_timezones & timezone_bit) != 0)
    zo_phi = find_emtf_zo_phi(x_emtf_phi)

    # Collect (row, col) of every selected hit; the channel is always 0
    rows = []
    cols = []
    for zo_layer, mask in enumerate(chamber_masks):
      layer_phi = zo_phi[mask][valid[mask]]
      rows.append(np.full(layer_phi.shape, zo_layer, dtype=np.intp))
      cols.append(layer_phi)

    # Fill zone image
    zone_images[ievt, np.concatenate(rows), np.concatenate(cols), 0] = True
  return zone_images

In [16]:
# Creating custom layers
# See: https://www.tensorflow.org/tutorials/customization/custom_layers

class Zoning(k_layers.Layer):
  """Keras layer that turns per-chamber hit inputs into zone images for one zone.

  The conversion runs build_zone_images() as Python/NumPy code through
  tf.numpy_function, so no gradient flows through this layer.
  """
  def __init__(self, zone, image_format=image_format, **kwargs):
    super(Zoning, self).__init__(**kwargs)
    self.zone = zone
    self.image_format = image_format

    # Bind the zone-specific arguments so the wrapped function only takes the inputs
    import functools
    build_fn = functools.partial(build_zone_images, zone=self.zone, image_format=self.image_format)
    self.py_func = k_layers.Lambda(lambda x: tf.numpy_function(build_fn, [x], tf.bool))

  def call(self, inputs):
    x = tf.cast(inputs, dtype=tf.int32)
    x = self.py_func(x)
    x = tf.cast(x, dtype=inputs.dtype)
    # tf.numpy_function does not carry static shape information; restore it
    x.set_shape((None,) + tuple(self.image_format))
    return x

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(Zoning, self).get_config()
    config.update({'zone': self.zone, 'image_format': self.image_format})
    return config

class Pooling(k_layers.Layer):
  """Match zone images against the frozen pattern boxes of one zone.

  The image is transposed so that its rows become channels. A non-trainable
  depthwise convolution then applies, per row, `num_patterns` filters initialized
  from boxes_act_reshaped[zone]. The activations are clipped to [0, 1] and the
  per-row results are packed into one number per (column, pattern),
  sum_h act[h] * 2**h. Output shape: (N, W, num_patterns).
  """
  def __init__(self, zone, image_format=image_format, num_patterns=num_patterns, **kwargs):
    super(Pooling, self).__init__(**kwargs)
    self.zone = zone
    self.image_format = image_format
    self.num_patterns = num_patterns

    # SeparableConv2D but only the depthwise conv (i.e. without the pointwise conv)
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/DepthwiseConv2D
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/SeparableConv2D
    from k_layers_separable_conv2d import SeparableConv2D as MySeparableConv2D
    w_init = tf.keras.initializers.Constant(boxes_act_reshaped[self.zone])
    conv2d_kwargs = dict(filters=1, kernel_size=(boxes_act_reshaped.shape[1], boxes_act_reshaped.shape[2]), depth_multiplier=self.num_patterns,
                         strides=(1, 1), padding='same', activation=None, use_bias=False,
                         depthwise_initializer=w_init, pointwise_initializer='ones', trainable=False)
    self.conv2d = MySeparableConv2D(**conv2d_kwargs)

    # Dot product coeffs for packing the last axis
    self.po2_coeffs = (2 ** np.arange(self.image_format[0]))  # [1,2,4,8,16,32,64,128]

  def call(self, inputs):
    # Conv
    x = inputs
    x = tf.transpose(x, perm=(0, 3, 2, 1))  # NHWC -> NCWH
    x = self.conv2d(x)
    x = tf.clip_by_value(x, 0, 1)
    x = tf.reshape(x, [-1, self.image_format[2], self.image_format[1], self.image_format[0], self.num_patterns])  # NCWHX
    x = tf.transpose(x, perm=(0, 1, 2, 4, 3))  # NCWHX -> NCWXH

    # Pack the last axis
    x = tf.reduce_sum(x * self.po2_coeffs, axis=-1)  # pack the 8 bits from H into a single number
    x = tf.reduce_sum(x, axis=1)  # NCWX -> NWX, C is dim of size 1
    return x

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(Pooling, self).get_config()
    config.update({'zone': self.zone, 'image_format': self.image_format, 'num_patterns': self.num_patterns})
    return config

class Erosion(k_layers.Layer):
  """Pick the best pattern per column and suppress columns that are not local maxima.

  Each packed hitmap is mapped through a frozen embedding holding
  hitmap_quality_reshaped[zone] (6-bit quality). The two lowest bits are then
  truncated. For every column the pattern with the highest quality is selected.
  A column is kept only if its quality is strictly greater than that of the
  columns at offsets -3 and -1, and >= that of the columns at +1 and +3
  (zero-padded at the edges).
  Output: (N, W) packed hitmap of the selected pattern, 0 where suppressed.
  """
  def __init__(self, zone, image_format=image_format, **kwargs):
    super(Erosion, self).__init__(**kwargs)
    self.zone = zone
    self.image_format = image_format

    # Embedding
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
    w_init = tf.keras.initializers.Constant(hitmap_quality_reshaped[self.zone])
    embedding_kwargs = dict(input_dim=num_embedding_input_dim, output_dim=1, input_length=1,
                            embeddings_initializer=w_init, trainable=False)
    self.embedding = k_layers.Embedding(**embedding_kwargs)

  def call(self, inputs):
    # Embedding
    x = inputs
    x = self.embedding(x)
    x = tf.reduce_sum(x, axis=-1)  # NWXY -> NWX, Y is dim of size 1
    x = tf.floor(x / 4)  # truncate the last two bits

    # Non max suppression
    # Regarding the usage of tf.gather_nd(), see https://stackoverflow.com/q/50578544
    indices = tf.argmax(x, axis=-1, output_type=tf.int32)
    indices_0 = tf.meshgrid(tf.range(tf.shape(indices)[0]), tf.range(indices.shape[1]), indexing='ij')
    indices = tf.stack(indices_0 + [indices], axis=-1)
    x = tf.gather_nd(x, indices)  # like x = x[indices]
    x_padded = tf.pad(x, paddings=[[0, 0], [3, 3]])  # use wider receptive field, note: x == x_padded[:, 3:-3]
    # keep column w iff x[w] > x[w-3], x[w] > x[w-1], x[w] >= x[w+1] and x[w] >= x[w+3]
    mask = (x > x_padded[:, :-6]) & (x > x_padded[:, 2:-4]) & (x >= x_padded[:, 4:-2]) & (x >= x_padded[:, 6:])
    mask = tf.cast(mask, dtype=x.dtype)

    # Apply indices and mask to inputs
    x = inputs
    x = tf.gather_nd(x, indices)
    x = x * mask
    return x

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(Erosion, self).get_config()
    config.update({'zone': self.zone, 'image_format': self.image_format})
    return config

class ZoneSorting(k_layers.Layer):
  """Keep the `num_emtf_out_tracks` highest-scoring columns of one zone.

  Each input value is mapped through a frozen 1-d embedding to a score. The input
  values with the highest scores (stable descending sort) are returned, shape
  (N, min(W, num_emtf_out_tracks)).
  NOTE: the embedding is initialized to zeros (FIXME), so all scores currently tie
  and the stable sort keeps the first `num_emtf_out_tracks` columns.
  """
  def __init__(self, zone, num_emtf_out_tracks=num_emtf_out_tracks, **kwargs):
    super(ZoneSorting, self).__init__(**kwargs)
    self.zone = zone
    self.num_emtf_out_tracks = num_emtf_out_tracks

    # Embedding
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
    w_init = 'zeros'  #FIXME
    embedding_kwargs = dict(input_dim=num_embedding_input_dim, output_dim=1, input_length=1,
                            embeddings_initializer=w_init, trainable=False)
    self.embedding = k_layers.Embedding(**embedding_kwargs)

  def call(self, inputs):
    # Score every column: NW -> NWY -> NW, Y is dim of size 1
    scores = tf.reduce_sum(self.embedding(inputs), axis=-1)

    # Columns of the highest scores
    indices = tf.argsort(scores, axis=-1, direction='DESCENDING', stable=True)
    indices = indices[:, :self.num_emtf_out_tracks]

    # Per-row gather, out[n, j] = inputs[n, indices[n, j]]; unlike a tiled index grid
    # this does not assume there are at least num_emtf_out_tracks columns
    return tf.gather(inputs, indices, batch_dims=1)

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(ZoneSorting, self).get_config()
    config.update({'zone': self.zone, 'num_emtf_out_tracks': self.num_emtf_out_tracks})
    return config

class ZoneMerging(k_layers.Layer):
  """Keep the `num_emtf_out_tracks` highest-scoring entries of the concatenated zone outputs.

  Each input value is mapped through a frozen 1-d embedding to a score. The input
  values with the highest scores (stable descending sort) are returned.
  NOTE: the embedding is initialized to zeros (FIXME), so all scores currently tie
  and the stable sort keeps the first `num_emtf_out_tracks` entries.
  """
  def __init__(self, num_emtf_out_tracks=num_emtf_out_tracks, **kwargs):
    super(ZoneMerging, self).__init__(**kwargs)
    self.num_emtf_out_tracks = num_emtf_out_tracks

    # Embedding
    # See: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
    w_init = 'zeros'  #FIXME
    embedding_kwargs = dict(input_dim=num_embedding_input_dim, output_dim=1, input_length=1,
                            embeddings_initializer=w_init, trainable=False)
    self.embedding = k_layers.Embedding(**embedding_kwargs)

  def call(self, inputs):
    # Score every entry: NW -> NWY -> NW, Y is dim of size 1
    scores = tf.reduce_sum(self.embedding(inputs), axis=-1)

    # Entries with the highest scores
    indices = tf.argsort(scores, axis=-1, direction='DESCENDING', stable=True)
    indices = indices[:, :self.num_emtf_out_tracks]

    # Per-row gather, out[n, j] = inputs[n, indices[n, j]]; unlike a tiled index grid
    # this does not assume there are at least num_emtf_out_tracks entries
    return tf.gather(inputs, indices, batch_dims=1)

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(ZoneMerging, self).get_config()
    config.update({'num_emtf_out_tracks': self.num_emtf_out_tracks})
    return config

class TrackBuilding(k_layers.Layer):
  """Expand each selected track into `num_emtf_out_variables` output variables.

  Currently every output variable of a track is a copy of the track's input value,
  out[n, t, v] = inputs[n, t], i.e. (N, T) -> (N, T, num_emtf_out_variables).
  """
  def __init__(self, num_emtf_out_variables=num_emtf_out_variables, **kwargs):
    super(TrackBuilding, self).__init__(**kwargs)
    self.num_emtf_out_variables = num_emtf_out_variables

  def call(self, inputs):
    # Add a trailing axis and repeat each value along it. This is equivalent to
    # gathering with an (n, t) index grid, without building the index tensors.
    x = tf.expand_dims(inputs, axis=-1)
    return tf.tile(x, multiples=(1, 1, self.num_emtf_out_variables))

  def get_config(self):
    # Required for model serialization because __init__ takes extra arguments
    config = super(TrackBuilding, self).get_config()
    config.update({'num_emtf_out_variables': self.num_emtf_out_variables})
    return config

In [17]:
def create_model():
  """Build the EMTF model as a Keras functional model.

  For each zone: Zoning -> Pooling -> Erosion -> ZoneSorting, all fed from the same
  input. The per-zone outputs are concatenated, merged by ZoneMerging and passed to
  TrackBuilding. Prints the model summary and returns the keras.Model.
  """
  # Input
  inputs = keras.Input(shape=(num_emtf_chambers, num_emtf_segments, num_emtf_variables), name='inputs')
  x = inputs

  # Loop over zones
  x_list = []
  for i in range(num_emtf_zones):
    # Make zone images
    x_i = Zoning(zone=i, name='zoning_{0}'.format(i))(x)

    # Pattern recognition
    x_i = Pooling(zone=i, name='pooling_{0}'.format(i))(x_i)
    x_i = Erosion(zone=i, name='erosion_{0}'.format(i))(x_i)

    # Zone sorter
    x_i = ZoneSorting(zone=i, name='zone_sorting_{0}'.format(i))(x_i)
    x_list.append(x_i)

  # Merge zone outputs
  x = k_layers.Concatenate(axis=-1)(x_list)
  x = ZoneMerging(name='zone_merging')(x)

  # Track builder (no per-zone suffix: the old '.format(i)' only relied on the leaked loop variable)
  x = TrackBuilding(name='track_building')(x)

  # Model
  model = keras.Model(inputs=inputs, outputs=x, name='awesome_model')

  # Summary
  model.summary()
  return model

In [18]:
model = create_model()

for label, weights in (('trainable weights:', model.trainable_weights), ('all weights:', model.weights)):
  print(label, len(weights))

Model: "awesome_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inputs (InputLayer)             [(None, 115, 8, 10)] 0                                            
__________________________________________________________________________________________________
zoning_0 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
__________________________________________________________________________________________________
zoning_1 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
__________________________________________________________________________________________________
zoning_2 (Zoning)               (None, 8, 288, 1)    0           inputs[0][0]                     
______________________________________________________________________________________

### Evaluate model

In [19]:
outputs = model(inputs)  # eager forward pass over all prepared events
logger.info('outputs: {0} type: {1}'.format(outputs.shape, type(outputs)))

[INFO    ] outputs: (10, 4, 36) type: <class 'tensorflow.python.framework.ops.EagerTensor'>


In [20]:
# Debug: inspect the output of layer 'zoning_0' for events 0 and 2
model_zoning_0 = keras.Model(inputs=model.input, outputs=model.get_layer('zoning_0').output)
outputs = model_zoning_0(inputs)
print('outputs: {0} type: {1}'.format(outputs.shape, type(outputs)))

x = outputs.numpy() if isinstance(outputs, tf.Tensor) else outputs

with np.printoptions(linewidth=100, threshold=1000):
  for ievt in (0, 2):
    print(x[ievt].nonzero())

outputs: (10, 8, 288, 1) type: <class 'tensorflow.python.framework.ops.EagerTensor'>
(array([0, 1, 2, 3, 4, 5, 5, 6, 6, 7]), array([127, 129, 132, 140, 140, 143, 149, 142, 153, 144]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
(array([0, 2, 3, 4, 5, 6, 7]), array([264, 263, 257, 256, 254, 253, 253]), array([0, 0, 0, 0, 0, 0, 0]))


In [21]:
# Debug: inspect the output of layer 'pooling_0' for events 0 and 2
model_pooling_0 = keras.Model(inputs=model.input, outputs=model.get_layer('pooling_0').output)
outputs = model_pooling_0(inputs)
print('outputs: {0} type: {1}'.format(outputs.shape, type(outputs)))

x = outputs.numpy() if isinstance(outputs, tf.Tensor) else outputs

with np.printoptions(linewidth=100, threshold=1000):
  for ievt in (0, 2):
    print(x[ievt].nonzero())

outputs: (10, 288, 7) type: <class 'tensorflow.python.framework.ops.EagerTensor'>
(array([ 93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 106, 107, 107, 108,
       108, 109, 109, 110, 110, 111, 111, 112, 112, 113, 113, 114, 114, 115, 115, 115, 116, 116,
       116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122,
       122, 123, 123, 123, 123, 124, 124, 124, 124, 125, 125, 125, 125, 126, 126, 126, 127, 127,
       127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 133, 133, 133,
       134, 134, 134, 134, 135, 135, 135, 135, 135, 136, 136, 136, 136, 137, 137, 137, 137, 138,
       138, 138, 138, 138, 138, 139, 139, 139, 139, 139, 139, 139, 140, 140, 140, 140, 140, 140,
       140, 141, 141, 141, 141, 141, 141, 141, 142, 142, 142, 142, 142, 142, 142, 143, 143, 143,
       143, 143, 143, 143, 144, 144, 144, 144, 144, 144, 144, 145, 145, 145, 145, 145, 145, 146,
       146, 146, 146, 146, 147, 147, 147, 14

In [22]:
# Debug: inspect the output of layer 'erosion_0' (positions and values) for events 0 and 2
model_erosion_0 = keras.Model(inputs=model.input, outputs=model.get_layer('erosion_0').output)
outputs = model_erosion_0(inputs)
print('outputs: {0} type: {1}'.format(outputs.shape, type(outputs)))

x = outputs.numpy() if isinstance(outputs, tf.Tensor) else outputs

with np.printoptions(linewidth=100, threshold=1000):
  for ievt in (0, 2):
    nz = x[ievt].nonzero()
    print(nz, x[ievt][nz])

outputs: (10, 288) type: <class 'tensorflow.python.framework.ops.EagerTensor'>
(array([ 93,  97, 106, 129, 139]),) [  1.   3.   7.   7. 247.]
(array([230, 237, 244, 254]),) [  1.   5. 133. 237.]
