# Script to count Floops in a tensorflow model

In [1]:
import os
import tensorflow as tf
import numpy as np
def load_graph(frozen_graph_filename):

    # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Then, we import the graph_def into a new Graph and returns it
    with tf.Graph().as_default() as graph:
    # The name var will prefix every op/nodes in your graph
    # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="")
    return graph, graph_def

In [2]:
from tensorflow import keras
# keras imports
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Input, Conv2D, Dropout, Flatten
from tensorflow.keras.layers import Concatenate, BatchNormalization, Activation
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, GRU
from tensorflow.keras.utils import plot_model
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.regularizers import l1

## Define all the utilities

In [3]:
from tensorflow.python.framework import graph_util

In [4]:
keras.backend.set_learning_phase(0)

In [5]:
def load_graph(frozen_graph_filename):

    # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and returns it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/nodes in your graph
        # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="")
    return graph, graph_def

In [6]:
def getFLOPS(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    # We use the Keras session graph in the call to the profiler.
    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)

    print(flops.total_float_ops)  # Prints the "flops" of the model.

# DNN

In [7]:
# # best model from optimization
# DNN_neurons = 80
# DNN_layers = 3
# DNN_activation = 'elu'
# dropout = 0.10
# batch_size = 50
# n_epochs = 500
# labels = ['j_g', 'j_q', 'j_w', 'j_z', 'j_t']

# # #  model
# def myModel():
#     inputArray = Input(shape=(input_shape,))
#     x = Dense(DNN_neurons, activation=DNN_activation, 
#               kernel_initializer='lecun_uniform', name='dense_0')(inputArray)
#     x = Dropout(dropout)(x)
#     ####
#     for i in range(1,DNN_layers):
#         x = Dense(DNN_neurons, activation=DNN_activation, 
#                   kernel_initializer='lecun_uniform', name='dense_%i' %i)(x)
#         x = Dropout(dropout)(x)

#     output = Dense(5, activation='softmax', kernel_initializer='lecun_uniform', 
#                    name = 'output_softmax')(x)
#     ####
#     model = Model(inputs=inputArray, outputs=output)
#     model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
#     return model

In [8]:
# input_shape = len([12, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 52])
# thisModel = myModel()
# thisModel.summary()

In [9]:
# thisModel = myModel()
# thisModel.compile(loss='MSE', optimizer='Adam')

# getFLOPS(thisModel)

In [10]:
# keras.backend.set_learning_phase(0)
# tfsession = keras.backend.get_session()

# constant_graph = graph_util.convert_variables_to_constants(tfsession, 
#                                                            tfsession.graph.as_graph_def(),
#                                                            ['output_softmax/Softmax'])

# f = 'constantgraph.pb'
# tf.train.write_graph(constant_graph, './', f, as_text=False)

# graph, graph_def = load_graph('constantgraph.pb')

# ins = graph.get_tensor_by_name('input_1:0')
# ins_ones = np.ones([1, 16])

# pred = graph.get_tensor_by_name('output_softmax/Softmax:0')
# sess = tf.Session(graph=graph)
# run_metadata = tf.RunMetadata()
# op = sess.graph.get_operations()
# print([m.values() for m in op])
# with graph.as_default():
#     sess.run(tf.global_variables_initializer())
#     result = sess.run(pred, 
#                       feed_dict={ins:ins_ones}, 
#                       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#                       run_metadata=run_metadata)
#     flops = tf.profiler.profile(graph, 
#                                 options = tf.profiler.ProfileOptionBuilder.float_operation(),
#                                 run_meta=run_metadata)
#     print('FLOP after freezing', flops.total_float_ops)

# CNN

In [11]:
# CNN_filters = 10
# CNN_filter_size = 3
# CNN_MaxPool_size = 5
# CNN_layers = 1
# CNN_activation = 'elu'
# DNN_neurons = 50
# DNN_layers = 3
# DNN_activation = 'elu'
# dropout = 0.1
# batch_size = 500
# n_epochs = 500
# labels = ['j_g', 'j_q', 'j_w', 'j_z', 'j_t']
# nParticles = 50

In [12]:
# #  model
# def myModel():
#     inputImage = Input(shape=(image_shape))
#     x = Conv2D(CNN_filters, kernel_size=(CNN_filter_size,CNN_filter_size), 
#                data_format="channels_last", strides=(1, 1), padding="same", input_shape=image_shape,
#                kernel_initializer='lecun_uniform', name='cnn2D_0')(inputImage)
#     x = BatchNormalization()(x)
#     x = Activation(CNN_activation)(x)
#     x = MaxPooling2D( pool_size = (CNN_MaxPool_size,CNN_MaxPool_size))(x)
#     x = Dropout(dropout)(x)
#     for i in range(1,CNN_layers):
#         x = Conv2D(CNN_filters, kernel_size=(CNN_filter_size,CNN_filter_size), 
#                    data_format="channels_last", strides=(1, 1), padding="same", input_shape=image_shape,
#                    kernel_initializer='lecun_uniform', name='cnn2D_%i' %i)(x)
#         x = BatchNormalization()(x)
#         x = Activation(CNN_activation)(x)
#         #x = MaxPooling2D( pool_size = (CNN_MaxPool_size,CNN_MaxPool_size))(x)
#         x = Dropout(dropout)(x)
        
#     ####
#     x = Flatten()(x)
#     #
#     for i in range(DNN_layers):
#         x = Dense(DNN_neurons, activation=DNN_activation, 
#                   kernel_initializer='lecun_uniform', name='dense_%i' %i)(x)
#         x = Dropout(dropout)(x)
#     #
#     output = Dense(5, activation='softmax', kernel_initializer='lecun_uniform', 
#                    name = 'output_softmax')(x)
#     ####
#     model = Model(inputs=inputImage, outputs=output)
#     model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

#     return model

In [13]:
# img_rows = 100#X.shape[1]
# img_cols = 100#X.shape[2]
# image_shape = (img_rows, img_cols, 1)

In [14]:
# thisModel = myModel()
# thisModel.summary()

In [15]:
# keras.backend.set_learning_phase(0)
# tfsession = keras.backend.get_session()

# constant_graph = graph_util.convert_variables_to_constants(tfsession, 
#                                                            tfsession.graph.as_graph_def(),
#                                                            ['output_softmax/Softmax'])

# f = 'constantgraph.pb'
# tf.train.write_graph(constant_graph, './', f, as_text=False)

# graph, graph_def = load_graph('constantgraph.pb')

# ins = graph.get_tensor_by_name('input_1:0')
# ins_ones = np.ones([1, 100, 100, 1])

# # batchnorm = graph.get_tensor_by_name('batch_normalization/keras_learning_phase:0')

# pred = graph.get_tensor_by_name('output_softmax/Softmax:0')
# sess = tf.Session(graph=graph)
# run_metadata = tf.RunMetadata()
# op = sess.graph.get_operations()
# print([m.values() for m in op])
# with graph.as_default():
#     sess.run(tf.global_variables_initializer())
#     result = sess.run(pred, 
#                       feed_dict={ins:ins_ones}, 
#                       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#                       run_metadata=run_metadata)
#     flops = tf.profiler.profile(graph, 
#                                 options = tf.profiler.ProfileOptionBuilder.float_operation(),
#                                 run_meta=run_metadata)
#     print('FLOP after freezing', flops.total_float_ops)
    
# raise

# GRU

In [16]:
# nParticles = 50
# GRU_units= 50
# DNN_neurons = 40
# DNN_layers = 3
# DNN_activation = 'relu'
# dropout = 0.22
# batch_size = 500
# n_epochs = 50
# #n_epochs = 1
# labels = ['j_g', 'j_q', 'j_w', 'j_z', 'j_t']
# input_shape = (nParticles,16)

In [18]:
#  model
# def myModel():
#     inputArray = Input(shape=(input_shape))
#     x = GRU(GRU_units, activation='tanh',
#             recurrent_activation='hard_sigmoid', name='gru')(inputArray)
#     x = Dropout(dropout)(x)
#     ####
#     for i in range(0,DNN_layers):
#         x = Dense(DNN_neurons, activation=DNN_activation, 
#                   kernel_initializer='lecun_uniform', name='dense_%i' %i)(x)
#         x = Dropout(dropout)(x)
#     #
#     output = Dense(5, activation='softmax', kernel_initializer='lecun_uniform', 
#                    name = 'output_softmax')(x)
#     ####
#     model = Model(inputs=inputArray, outputs=output)
#     model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
#     return model

In [19]:
# thisModel = myModel()
# thisModel.summary()

# thisModel.compile(loss='MSE', optimizer='Adam')
# thisModel.fit(x=np.ones([10, 50, 16]), y=np.ones([10, 5]))

In [20]:
# run_meta = tf.RunMetadata()
# opts = tf.profiler.ProfileOptionBuilder.float_operation()

# # We use the Keras session graph in the call to the profiler.
# flops = tf.profiler.profile(graph=K.get_session().graph,
#                             run_meta=run_meta, cmd='op', options=opts)

# print(flops.total_float_ops)  # Prints the "flops" of the model.

# JEDI-net

In [21]:
import torch
import torch.nn as nn
from torch.autograd.variable import *
import torch.optim as optim
import itertools

import onnx
print(onnx.__version__)
import onnx_tf
from onnx_tf.backend import prepare
# print(onnx_tf.__version__)

W0722 13:00:29.203435 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/ceil.py:10: The name tf.ceil is deprecated. Please use tf.math.ceil instead.

W0722 13:00:29.211855 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/depth_to_space.py:12: The name tf.depth_to_space is deprecated. Please use tf.compat.v1.depth_to_space instead.

W0722 13:00:29.215927 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/erf.py:9: The name tf.erf is deprecated. Please use tf.math.erf instead.



1.5.0


W0722 13:00:30.328505 139635486955328 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0722 13:00:30.332672 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/is_nan.py:9: The name tf.is_nan is deprecated. Please use tf.math.is_nan instead.

W0722 13:00:30.335104 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/log.py:10: The name tf.log is deprecated. Please use tf.math.log instead.

W0722 13:00:30.371943 139635486955328 deprecation_wrapper.py:119] From /nfshome/ocerri/minicon

In [22]:
class GraphNet(nn.Module):
    def __init__(self, n_constituents, n_targets, params, hidden, De, Do, 
                 fr_activation=0, fo_activation=0, fc_activation=0, optimizer = 0, verbose = False):
        super(GraphNet, self).__init__()
        self.hidden = hidden
        self.P = len(params)
        self.N = n_constituents
        print(self.P, self.N)
        self.Nr = self.N * (self.N - 1)
        self.Dr = 0
        self.De = De
        self.Dx = 0
        self.Do = Do
        self.n_targets = n_targets
        self.fr_activation = fr_activation
        self.fo_activation = fo_activation
        self.fc_activation = fc_activation
        self.optimizer = optimizer
        self.verbose = verbose
        self.assign_matrices()

        self.Ra = torch.ones(self.Dr, self.Nr)
        self.fr1 = nn.Linear(2 * self.P + self.Dr, hidden)
        self.fr2 = nn.Linear(hidden, int(hidden/2))
        self.fr3 = nn.Linear(int(hidden/2), self.De)
        self.fo1 = nn.Linear(self.P + self.Dx + self.De, hidden)
        self.fo2 = nn.Linear(hidden, int(hidden/2))
        self.fo3 = nn.Linear(int(hidden/2), self.Do)
        self.fc1 = nn.Linear(self.Do * self.N, hidden)
        self.fc2 = nn.Linear(hidden, int(hidden/2))
        self.fc3 = nn.Linear(int(hidden/2), self.n_targets)

    def assign_matrices(self):
        self.Rr = torch.zeros(self.N, self.Nr)
        self.Rs = torch.zeros(self.N, self.Nr)
        receiver_sender_list = [i for i in itertools.product(range(self.N), range(self.N)) if i[0]!=i[1]]
        for i, (r, s) in enumerate(receiver_sender_list):
            self.Rr[r, i] = 1
            self.Rs[s, i] = 1
        self.Rr = Variable(self.Rr)
        self.Rs = Variable(self.Rs)

    def forward(self, x):
        Orr = self.tmul(x, self.Rr)
        Ors = self.tmul(x, self.Rs)
        B = torch.cat([Orr, Ors], 1)
        ### First MLP ###
        B = torch.transpose(B, 1, 2).contiguous()
        if self.fr_activation ==2:
            B = nn.functional.selu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.selu(self.fr2(B))
            E = nn.functional.selu(self.fr3(B).view(-1, self.Nr, self.De))            
        elif self.fr_activation ==1:
            B = nn.functional.elu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.elu(self.fr2(B))
            E = nn.functional.elu(self.fr3(B).view(-1, self.Nr, self.De))
        else:
            B = nn.functional.relu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.relu(self.fr2(B))
            E = nn.functional.relu(self.fr3(B).view(-1, self.Nr, self.De))
        del B
        E = torch.transpose(E, 1, 2).contiguous()
        Ebar = self.tmul(E, torch.transpose(self.Rr, 0, 1).contiguous())
        del E
        C = torch.cat([x, Ebar], 1)
        del Ebar
        C = torch.transpose(C, 1, 2).contiguous()
        ### Second MLP ###
        if self.fo_activation ==2:
            C = nn.functional.selu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.selu(self.fo2(C))
            O = nn.functional.selu(self.fo3(C).view(-1, self.N, self.Do))
        elif self.fo_activation ==1:
            C = nn.functional.elu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.elu(self.fo2(C))
            O = nn.functional.elu(self.fo3(C).view(-1, self.N, self.Do))
        else:
            C = nn.functional.relu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.relu(self.fo2(C))
            O = nn.functional.relu(self.fo3(C).view(-1, self.N, self.Do))
        del C
        ### Classification MLP ###
        if self.fc_activation ==2:
            N = nn.functional.selu(self.fc1(O.view(-1, self.Do * self.N)))
            N = nn.functional.selu(self.fc2(N))       
        elif self.fc_activation ==1:
            N = nn.functional.elu(self.fc1(O.view(-1, self.Do * self.N)))
            N = nn.functional.elu(self.fc2(N))
        else:
            N = nn.functional.relu(self.fc1(O.view(-1, self.Do * self.N)))
            N = nn.functional.relu(self.fc2(N))
        del O
        #N = nn.functional.relu(self.fc3(N))
        N = self.fc3(N)
        return N

    def tmul(self, x, y):  #Takes (I * J * K)(K * L) -> I * J * L 
        x_shape = x.size()
        y_shape = y.size()
        return torch.mm(x.view(-1, x_shape[2]), y).view(-1, x_shape[1], y_shape[1])

In [23]:
# ### Prepare Dataset                                                                                                                                                  
nParticles = 100
x = []
x.append(30) # hinned nodes                                                                                                                                            
x.append(10) # De                                                                                                                                                      
x.append(10) # Do                                                                                                                                                      
x.append(1) # fr_activation_index                                                                                                                                      
x.append(1) # fo_activation_index                                                                                                                                      
x.append(1) # fc_activation_index                                                                                                                                      
x.append(0) # optmizer_index           

params = ['j1_px', 'j1_py' , 'j1_pz' , 'j1_e' , 'j1_erel' , 'j1_pt' , 'j1_ptrel', 'j1_eta' , 'j1_etarel' , 
          'j1_etarot' , 'j1_phi' , 'j1_phirel' , 'j1_phirot', 'j1_deltaR' , 'j1_costheta' , 'j1_costhetarel']

mymodel = GraphNet(nParticles, len(params), params, int(x[0]), int(x[1]), int(x[2]), 
                       int(x[3]),  int(x[4]),  int(x[5]), int(x[6]), 0)

16 100


In [24]:
print(mymodel)
trainablePars = sum(p.numel() for p in mymodel.parameters() if p.requires_grad)
print('\nTrainable parameters:', trainablePars)

GraphNet(
  (fr1): Linear(in_features=32, out_features=30, bias=True)
  (fr2): Linear(in_features=30, out_features=15, bias=True)
  (fr3): Linear(in_features=15, out_features=10, bias=True)
  (fo1): Linear(in_features=26, out_features=30, bias=True)
  (fo2): Linear(in_features=30, out_features=15, bias=True)
  (fo3): Linear(in_features=15, out_features=10, bias=True)
  (fc1): Linear(in_features=1000, out_features=30, bias=True)
  (fc2): Linear(in_features=30, out_features=15, bias=True)
  (fc3): Linear(in_features=15, out_features=16, bias=True)
)

Trainable parameters: 33801


In [25]:
dummy_input = torch.ones((1,16,100))
out_test = mymodel(dummy_input)

torch.onnx.export(mymodel, dummy_input, 
                  "test.onnx", 
                  verbose=True,
                  input_names = ['input'], 
                  output_names = ['output'])

graph(%input : Float(1, 16, 100),
      %fr1.weight : Float(30, 32),
      %fr1.bias : Float(30),
      %fr2.weight : Float(15, 30),
      %fr2.bias : Float(15),
      %fr3.weight : Float(10, 15),
      %fr3.bias : Float(10),
      %fo1.weight : Float(30, 26),
      %fo1.bias : Float(30),
      %fo2.weight : Float(15, 30),
      %fo2.bias : Float(15),
      %fo3.weight : Float(10, 15),
      %fo3.bias : Float(10),
      %fc1.weight : Float(30, 1000),
      %fc1.bias : Float(30),
      %fc2.weight : Float(15, 30),
      %fc2.bias : Float(15),
      %fc3.weight : Float(16, 15),
      %fc3.bias : Float(16)):
  %19 : Long() = onnx::Constant[value={1}](), scope: GraphNet
  %20 : Tensor = onnx::Shape(%input), scope: GraphNet
  %21 : Long() = onnx::Gather[axis=0](%20, %19), scope: GraphNet
  %22 : Long() = onnx::Constant[value={2}](), scope: GraphNet
  %23 : Tensor = onnx::Shape(%input), scope: GraphNet
  %24 : Long() = onnx::Gather[axis=0](%23, %22), scope: GraphNet
  %25 : Float(100, 9900) 

In [26]:
model = onnx.load('test.onnx')
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

tf_rep = prepare(model)

# Input nodes to the model
print('inputs:', tf_rep.inputs)

# Output nodes from the model
print('outputs:', tf_rep.outputs)

# All nodes in the model
print('tensor_dict:')
print(tf_rep.tensor_dict)

tf_rep.export_graph('constantgraph.pb')

  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
W0722 13:00:31.272864 139635486955328 deprecation.py:323] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/reshape.py:26: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0722 13:00:31.277809 139635486955328 deprecation.py:323] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/onnx_tf/handlers/backend/reshape.py:31: sparse_to_dense (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` 

graph torch-jit-export (
  %input[FLOAT, 1x16x100]
) initializers (
  %fr1.weight[FLOAT, 30x32]
  %fr1.bias[FLOAT, 30]
  %fr2.weight[FLOAT, 15x30]
  %fr2.bias[FLOAT, 15]
  %fr3.weight[FLOAT, 10x15]
  %fr3.bias[FLOAT, 10]
  %fo1.weight[FLOAT, 30x26]
  %fo1.bias[FLOAT, 30]
  %fo2.weight[FLOAT, 15x30]
  %fo2.bias[FLOAT, 15]
  %fo3.weight[FLOAT, 10x15]
  %fo3.bias[FLOAT, 10]
  %fc1.weight[FLOAT, 30x1000]
  %fc1.bias[FLOAT, 30]
  %fc2.weight[FLOAT, 15x30]
  %fc2.bias[FLOAT, 15]
  %fc3.weight[FLOAT, 16x15]
  %fc3.bias[FLOAT, 16]
) {
  %19 = Constant[value = <Scalar Tensor []>]()
  %20 = Shape(%input)
  %21 = Gather[axis = 0](%20, %19)
  %22 = Constant[value = <Scalar Tensor []>]()
  %23 = Shape(%input)
  %24 = Gather[axis = 0](%23, %22)
  %25 = Constant[value = <Tensor>]()
  %26 = Constant[value = <Scalar Tensor []>]()
  %27 = Constant[value = <Scalar Tensor []>]()
  %28 = Unsqueeze[axes = [0]](%27)
  %29 = Unsqueeze[axes = [0]](%24)
  %30 = Concat[axis = 0](%28, %29)
  %31 = Reshape(%input,

In [32]:
tf.reset_default_graph()
keras.backend.set_learning_phase(0)
tfsession = keras.backend.get_session()

graph, graph_def = load_graph('constantgraph.pb')

ins = graph.get_tensor_by_name('input:0')
ins_ones = np.ones([1, 16, 100])

# batchnorm = graph.get_tensor_by_name('batch_normalization/keras_learning_phase:0')

pred = graph.get_tensor_by_name('add_22:0')
sess = tf.Session(graph=graph)
run_metadata = tf.RunMetadata()
op = sess.graph.get_operations()
print([m.values() for m in op])
with graph.as_default():
    sess.run(tf.global_variables_initializer())
    result = sess.run(pred, 
                      feed_dict={ins:ins_ones}, 
                      options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                      run_metadata=run_metadata)
    flops = tf.profiler.profile(graph, 
                                options = tf.profiler.ProfileOptionBuilder.float_operation(),
                                run_meta=run_metadata)
    print('FLOP after freezing', flops.total_float_ops)
    
raise

[(<tf.Tensor 'Const:0' shape=(30,) dtype=float32>,), (<tf.Tensor 'Const_1:0' shape=(30, 1000) dtype=float32>,), (<tf.Tensor 'Const_2:0' shape=(15,) dtype=float32>,), (<tf.Tensor 'Const_3:0' shape=(15, 30) dtype=float32>,), (<tf.Tensor 'Const_4:0' shape=(16,) dtype=float32>,), (<tf.Tensor 'Const_5:0' shape=(16, 15) dtype=float32>,), (<tf.Tensor 'Const_6:0' shape=(30,) dtype=float32>,), (<tf.Tensor 'Const_7:0' shape=(30, 26) dtype=float32>,), (<tf.Tensor 'Const_8:0' shape=(15,) dtype=float32>,), (<tf.Tensor 'Const_9:0' shape=(15, 30) dtype=float32>,), (<tf.Tensor 'Const_10:0' shape=(10,) dtype=float32>,), (<tf.Tensor 'Const_11:0' shape=(10, 15) dtype=float32>,), (<tf.Tensor 'Const_12:0' shape=(30,) dtype=float32>,), (<tf.Tensor 'Const_13:0' shape=(30, 32) dtype=float32>,), (<tf.Tensor 'Const_14:0' shape=(15,) dtype=float32>,), (<tf.Tensor 'Const_15:0' shape=(15, 30) dtype=float32>,), (<tf.Tensor 'Const_16:0' shape=(10,) dtype=float32>,), (<tf.Tensor 'Const_17:0' shape=(10, 15) dtype=floa

W0722 13:02:43.987157 139635486955328 deprecation.py:323] From /nfshome/ocerri/miniconda2/envs/PartAN/lib/python3.7/site-packages/tensorflow/python/profiler/internal/flops_registry.py:142: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`


FLOP after freezing 116121330


RuntimeError: No active exception to reraise

# JEDI-net with Sum over O

In [33]:
class GraphNet(nn.Module):
    def __init__(self, n_constituents, n_targets, params, hidden, De, Do, 
                 fr_activation=0, fo_activation=0, fc_activation=0, optimizer = 0, verbose = False):
        super(GraphNet, self).__init__()
        self.hidden = hidden
        self.P = len(params)
        self.N = n_constituents
        self.Nr = self.N * (self.N - 1)
        self.Dr = 0
        self.De = De
        self.Dx = 0
        self.Do = Do
        self.n_targets = n_targets
        self.fr_activation = fr_activation
        self.fo_activation = fo_activation
        self.fc_activation = fc_activation
        self.optimizer = optimizer
        self.verbose = verbose
        self.assign_matrices()

        self.Ra = torch.ones(self.Dr, self.Nr)
        self.fr1 = nn.Linear(2 * self.P + self.Dr, hidden)
        self.fr2 = nn.Linear(hidden, int(hidden/2))
        self.fr3 = nn.Linear(int(hidden/2), self.De)
        self.fo1 = nn.Linear(self.P + self.Dx + self.De, hidden)
        self.fo2 = nn.Linear(hidden, int(hidden/2))
        self.fo3 = nn.Linear(int(hidden/2), self.Do)
        self.fc1 = nn.Linear(self.Do, hidden)
        self.fc2 = nn.Linear(hidden, int(hidden/2))
        self.fc3 = nn.Linear(int(hidden/2), self.n_targets)

    def assign_matrices(self):
        self.Rr = torch.zeros(self.N, self.Nr)
        self.Rs = torch.zeros(self.N, self.Nr)
        receiver_sender_list = [i for i in itertools.product(range(self.N), range(self.N)) if i[0]!=i[1]]
        for i, (r, s) in enumerate(receiver_sender_list):
            self.Rr[r, i] = 1
            self.Rs[s, i] = 1
        self.Rr = Variable(self.Rr)
        self.Rs = Variable(self.Rs)

    def forward(self, x):
        Orr = self.tmul(x, self.Rr)
        Ors = self.tmul(x, self.Rs)
        B = torch.cat([Orr, Ors], 1)
        ### First MLP ###
        B = torch.transpose(B, 1, 2).contiguous()
        if self.fr_activation ==2:
            B = nn.functional.selu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.selu(self.fr2(B))
            E = nn.functional.selu(self.fr3(B).view(-1, self.Nr, self.De))            
        elif self.fr_activation ==1:
            B = nn.functional.elu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.elu(self.fr2(B))
            E = nn.functional.elu(self.fr3(B).view(-1, self.Nr, self.De))
        else:
            B = nn.functional.relu(self.fr1(B.view(-1, 2 * self.P + self.Dr)))
            B = nn.functional.relu(self.fr2(B))
            E = nn.functional.relu(self.fr3(B).view(-1, self.Nr, self.De))
        del B
        E = torch.transpose(E, 1, 2).contiguous()
        Ebar = self.tmul(E, torch.transpose(self.Rr, 0, 1).contiguous())
        del E
        C = torch.cat([x, Ebar], 1)
        del Ebar
        C = torch.transpose(C, 1, 2).contiguous()
        ### Second MLP ###
        if self.fo_activation ==2:
            C = nn.functional.selu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.selu(self.fo2(C))
            O = nn.functional.selu(self.fo3(C).view(-1, self.N, self.Do))
        elif self.fo_activation ==1:
            C = nn.functional.elu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.elu(self.fo2(C))
            O = nn.functional.elu(self.fo3(C).view(-1, self.N, self.Do))
        else:
            C = nn.functional.relu(self.fo1(C.view(-1, self.P + self.Dx + self.De)))
            C = nn.functional.relu(self.fo2(C))
            O = nn.functional.relu(self.fo3(C).view(-1, self.N, self.Do))
        del C
        ## sum over the O matrix  
        O = torch.sum( O, dim=1)
        ### Classification MLP ###
        if self.fc_activation ==2:
            N = nn.functional.selu(self.fc1(O.view(-1, self.Do)))
            N = nn.functional.selu(self.fc2(N))       
        elif self.fc_activation ==1:
            N = nn.functional.elu(self.fc1(O.view(-1, self.Do)))
            N = nn.functional.elu(self.fc2(N))
        else:
            N = nn.functional.relu(self.fc1(O.view(-1, self.Do)))
            N = nn.functional.relu(self.fc2(N))
        del O
        #N = nn.functional.relu(self.fc3(N))
        N = self.fc3(N)
        return N

    def tmul(self, x, y):  #Takes (I * J * K)(K * L) -> I * J * L 
        x_shape = x.size()
        y_shape = y.size()
        return torch.mm(x.view(-1, x_shape[2]), y).view(-1, x_shape[1], y_shape[1])

In [34]:
# ### Prepare Dataset                                                                                                                                                  
nParticles = 150
x = []
x.append(50) # hinned nodes                                                                                                                                            
x.append(14) # De                                                                                                                                                      
x.append(12) # Do                                                                                                                                                      
x.append(2) # fr_activation_index                                                                                                                                      
x.append(2) # fo_activation_index                                                                                                                                      
x.append(2) # fc_activation_index                                                                                                                                      
x.append(0) # optmizer_index           

In [36]:
mymodel = GraphNet(nParticles, len(params), params, int(x[0]), int(x[1]), int(x[2]), 
                       int(x[3]),  int(x[4]),  int(x[5]), int(x[6]), 0)

In [37]:
print(mymodel)
trainablePars = sum(p.numel() for p in mymodel.parameters() if p.requires_grad)
print('\nTrainable parameters:', trainablePars)

GraphNet(
  (fr1): Linear(in_features=32, out_features=50, bias=True)
  (fr2): Linear(in_features=50, out_features=25, bias=True)
  (fr3): Linear(in_features=25, out_features=14, bias=True)
  (fo1): Linear(in_features=30, out_features=50, bias=True)
  (fo2): Linear(in_features=50, out_features=25, bias=True)
  (fo3): Linear(in_features=25, out_features=12, bias=True)
  (fc1): Linear(in_features=12, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=25, bias=True)
  (fc3): Linear(in_features=25, out_features=16, bias=True)
)

Trainable parameters: 8767


In [38]:
dummy_input = torch.ones((1,16,150))
out_test = mymodel(dummy_input)

torch.onnx.export(mymodel, dummy_input, 
                  "test.onnx", 
                  verbose=True,
                  input_names = ['input'], 
                  output_names = ['output'])

graph(%input : Float(1, 16, 150),
      %fr1.weight : Float(50, 32),
      %fr1.bias : Float(50),
      %fr2.weight : Float(25, 50),
      %fr2.bias : Float(25),
      %fr3.weight : Float(14, 25),
      %fr3.bias : Float(14),
      %fo1.weight : Float(50, 30),
      %fo1.bias : Float(50),
      %fo2.weight : Float(25, 50),
      %fo2.bias : Float(25),
      %fo3.weight : Float(12, 25),
      %fo3.bias : Float(12),
      %fc1.weight : Float(50, 12),
      %fc1.bias : Float(50),
      %fc2.weight : Float(25, 50),
      %fc2.bias : Float(25),
      %fc3.weight : Float(16, 25),
      %fc3.bias : Float(16)):
  %19 : Long() = onnx::Constant[value={1}](), scope: GraphNet
  %20 : Tensor = onnx::Shape(%input), scope: GraphNet
  %21 : Long() = onnx::Gather[axis=0](%20, %19), scope: GraphNet
  %22 : Long() = onnx::Constant[value={2}](), scope: GraphNet
  %23 : Tensor = onnx::Shape(%input), scope: GraphNet
  %24 : Long() = onnx::Gather[axis=0](%23, %22), scope: GraphNet
  %25 : Float(150, 22350) =

In [39]:
model = onnx.load('test.onnx')
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

tf_rep = prepare(model)

# Input nodes to the model
print('inputs:', tf_rep.inputs)

# Output nodes from the model
print('outputs:', tf_rep.outputs)

# All nodes in the model
print('tensor_dict:')
print(tf_rep.tensor_dict)

tf_rep.export_graph('constantgraph.pb')

graph torch-jit-export (
  %input[FLOAT, 1x16x150]
) initializers (
  %fr1.weight[FLOAT, 50x32]
  %fr1.bias[FLOAT, 50]
  %fr2.weight[FLOAT, 25x50]
  %fr2.bias[FLOAT, 25]
  %fr3.weight[FLOAT, 14x25]
  %fr3.bias[FLOAT, 14]
  %fo1.weight[FLOAT, 50x30]
  %fo1.bias[FLOAT, 50]
  %fo2.weight[FLOAT, 25x50]
  %fo2.bias[FLOAT, 25]
  %fo3.weight[FLOAT, 12x25]
  %fo3.bias[FLOAT, 12]
  %fc1.weight[FLOAT, 50x12]
  %fc1.bias[FLOAT, 50]
  %fc2.weight[FLOAT, 25x50]
  %fc2.bias[FLOAT, 25]
  %fc3.weight[FLOAT, 16x25]
  %fc3.bias[FLOAT, 16]
) {
  %19 = Constant[value = <Scalar Tensor []>]()
  %20 = Shape(%input)
  %21 = Gather[axis = 0](%20, %19)
  %22 = Constant[value = <Scalar Tensor []>]()
  %23 = Shape(%input)
  %24 = Gather[axis = 0](%23, %22)
  %25 = Constant[value = <Tensor>]()
  %26 = Constant[value = <Scalar Tensor []>]()
  %27 = Constant[value = <Scalar Tensor []>]()
  %28 = Unsqueeze[axes = [0]](%27)
  %29 = Unsqueeze[axes = [0]](%24)
  %30 = Concat[axis = 0](%28, %29)
  %31 = Reshape(%input, %

In [40]:
tf.reset_default_graph()
keras.backend.set_learning_phase(0)
tfsession = keras.backend.get_session()

graph, graph_def = load_graph('constantgraph.pb')

ins = graph.get_tensor_by_name('input:0')
ins_ones = np.ones([1, 16, 150])

# batchnorm = graph.get_tensor_by_name('batch_normalization/keras_learning_phase:0')

pred = graph.get_tensor_by_name('add_22:0')
sess = tf.Session(graph=graph)
run_metadata = tf.RunMetadata()
op = sess.graph.get_operations()
print([m.values() for m in op])
with graph.as_default():
    sess.run(tf.global_variables_initializer())
    result = sess.run(pred, 
                      feed_dict={ins:ins_ones}, 
                      options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                      run_metadata=run_metadata)
    flops = tf.profiler.profile(graph, 
                                options = tf.profiler.ProfileOptionBuilder.float_operation(),
                                run_meta=run_metadata)
    print('FLOP after freezing', flops.total_float_ops)
    
raise

[(<tf.Tensor 'Const:0' shape=(50,) dtype=float32>,), (<tf.Tensor 'Const_1:0' shape=(50, 12) dtype=float32>,), (<tf.Tensor 'Const_2:0' shape=(25,) dtype=float32>,), (<tf.Tensor 'Const_3:0' shape=(25, 50) dtype=float32>,), (<tf.Tensor 'Const_4:0' shape=(16,) dtype=float32>,), (<tf.Tensor 'Const_5:0' shape=(16, 25) dtype=float32>,), (<tf.Tensor 'Const_6:0' shape=(50,) dtype=float32>,), (<tf.Tensor 'Const_7:0' shape=(50, 30) dtype=float32>,), (<tf.Tensor 'Const_8:0' shape=(25,) dtype=float32>,), (<tf.Tensor 'Const_9:0' shape=(25, 50) dtype=float32>,), (<tf.Tensor 'Const_10:0' shape=(12,) dtype=float32>,), (<tf.Tensor 'Const_11:0' shape=(12, 25) dtype=float32>,), (<tf.Tensor 'Const_12:0' shape=(50,) dtype=float32>,), (<tf.Tensor 'Const_13:0' shape=(50, 32) dtype=float32>,), (<tf.Tensor 'Const_14:0' shape=(25,) dtype=float32>,), (<tf.Tensor 'Const_15:0' shape=(25, 50) dtype=float32>,), (<tf.Tensor 'Const_16:0' shape=(14,) dtype=float32>,), (<tf.Tensor 'Const_17:0' shape=(14, 25) dtype=float3

FLOP after freezing 457830794


RuntimeError: No active exception to reraise