In [1]:
from test.models.keras.mnist_conv_net import model as keras_conv_model
from test.models.pytorch.mnist_conv_net import model as pytorch_conv_net
from torchsummary import summary
from typing import Any
from generator.loader.onnx_loader import KerasGraphExtractor, PyTorchGraphExtractor

import onnx

In [10]:
# Layer table for the Keras reference CNN; line_length only sets the printed
# table width. summary() prints and returns None, so ';' suppresses the echo.
keras_conv_model.summary(line_length=75);

Model: "sequential"
___________________________________________________________________________
 Layer (type)                    Output Shape                  Param #     
 conv2d (Conv2D)                 (None, 26, 26, 32)            320         
                                                                           
 max_pooling2d (MaxPooling2D)    (None, 13, 13, 32)            0           
                                                                           
 conv2d_1 (Conv2D)               (None, 11, 11, 64)            18496       
                                                                           
 max_pooling2d_1 (MaxPooling2D)  (None, 5, 5, 64)              0           
                                                                           
 flatten (Flatten)               (None, 1600)                  0           
                                                                           
 dropout (Dropout)               (None, 1600)                  0    

In [7]:
# Layer table for the PyTorch counterpart. input_size excludes the batch
# dimension (torchsummary reports it as -1 in the output shapes).
summary(
    pytorch_conv_net,
    input_size=(1, 28, 28),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
              ReLU-2           [-1, 32, 26, 26]               0
         MaxPool2d-3           [-1, 32, 13, 13]               0
            Conv2d-4           [-1, 64, 11, 11]          18,496
              ReLU-5           [-1, 64, 11, 11]               0
         MaxPool2d-6             [-1, 64, 5, 5]               0
           Flatten-7                 [-1, 1600]               0
           Dropout-8                 [-1, 1600]               0
            Linear-9                   [-1, 10]          16,010
          Softmax-10                   [-1, 10]               0
Total params: 34,826
Trainable params: 34,826
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.53
Params size (MB): 0.13
Estimated Tot

In [4]:
# Load the ONNX exports of the same CNN from each framework.
# MODELS_DIR is relative to the notebook's working directory; change it here
# (once) if the notebook or the exported models are moved.
MODELS_DIR = '../test/models'
keras_conv_onnx = onnx.load(f'{MODELS_DIR}/keras/saved_models/my_cnn.onnx')
pytorch_conv_onnx = onnx.load(f'{MODELS_DIR}/pytorch/my_cnn_params.onnx')

In [5]:
# Raw `graph.node` sequences of each ONNX model, kept for ad-hoc inspection
# (not used by the comparison cells below).
keras_nodes, pytorch_nodes = (
    keras_conv_onnx.graph.node,
    pytorch_conv_onnx.graph.node,
)

In [6]:
# One framework-specific extractor per loaded ONNX model; each exposes
# `graph_structure` as a list of (node_name, layer_attributes) pairs.
keras_extractor, pytorch_extractor = (
    KerasGraphExtractor(keras_conv_onnx),
    PyTorchGraphExtractor(pytorch_conv_onnx),
)

In [7]:
# Display the Keras-side (node_name, attributes) pairs; node names carry the
# TF "StatefulPartitionedCall/..." prefixes seen in the output.
keras_extractor.graph_structure

[('StatefulPartitionedCall/sequential/conv2d/BiasAdd',
  {'type': 'conv',
   'filter': 32,
   'pads': [0, 0, 0, 0],
   'strides': [1, 1],
   'kernel_shape': [3, 3]}),
 ('StatefulPartitionedCall/sequential/conv2d/Relu', {'type': 'relu'}),
 ('StatefulPartitionedCall/sequential/max_pooling2d/MaxPool',
  {'type': 'maxpool', 'strides': [2, 2], 'kernel_shape': [2, 2]}),
 ('StatefulPartitionedCall/sequential/conv2d_1/BiasAdd',
  {'type': 'conv',
   'filter': 64,
   'pads': [0, 0, 0, 0],
   'strides': [1, 1],
   'kernel_shape': [3, 3]}),
 ('StatefulPartitionedCall/sequential/conv2d_1/Relu', {'type': 'relu'}),
 ('StatefulPartitionedCall/sequential/max_pooling2d_1/MaxPool',
  {'type': 'maxpool', 'strides': [2, 2], 'kernel_shape': [2, 2]}),
 ('StatefulPartitionedCall/sequential/flatten/Reshape', {'type': 'flatten'}),
 ('StatefulPartitionedCall/sequential/dense/MatMul_StatefulPartitionedCall/sequential/dense/BiasAdd',
  {'type': 'dense', 'input': 1600, 'output': 10}),
 ('StatefulPartitionedCall/se

In [8]:
# Display the PyTorch-side (node_name, attributes) pairs; node names are
# ONNX op names such as "Conv_0", so they differ from the Keras names.
pytorch_extractor.graph_structure

[('Conv_0',
  {'type': 'conv',
   'filter': 32,
   'kernel_shape': [3, 3],
   'pads': [0, 0, 0, 0],
   'strides': [1, 1]}),
 ('Relu_1', {'type': 'relu'}),
 ('MaxPool_2', {'type': 'maxpool', 'kernel_shape': [2, 2], 'strides': [2, 2]}),
 ('Conv_3',
  {'type': 'conv',
   'filter': 64,
   'kernel_shape': [3, 3],
   'pads': [0, 0, 0, 0],
   'strides': [1, 1]}),
 ('Relu_4', {'type': 'relu'}),
 ('MaxPool_5', {'type': 'maxpool', 'kernel_shape': [2, 2], 'strides': [2, 2]}),
 ('Flatten_6', {'type': 'flatten'}),
 ('Gemm_10', {'type': 'dense', 'input': 1600, 'output': 10}),
 ('Softmax_11', {'type': 'softmax'})]

In [9]:
# Drop the framework-specific node names and keep only the layer attribute
# dicts, so the two graphs can be compared structurally.
keras_graph_unified = [attrs for _, attrs in keras_extractor.graph_structure]
keras_graph_unified

[{'type': 'conv',
  'filter': 32,
  'pads': [0, 0, 0, 0],
  'strides': [1, 1],
  'kernel_shape': [3, 3]},
 {'type': 'relu'},
 {'type': 'maxpool', 'strides': [2, 2], 'kernel_shape': [2, 2]},
 {'type': 'conv',
  'filter': 64,
  'pads': [0, 0, 0, 0],
  'strides': [1, 1],
  'kernel_shape': [3, 3]},
 {'type': 'relu'},
 {'type': 'maxpool', 'strides': [2, 2], 'kernel_shape': [2, 2]},
 {'type': 'flatten'},
 {'type': 'dense', 'input': 1600, 'output': 10},
 {'type': 'softmax'}]

In [10]:
# Same normalisation for the PyTorch side: discard node names, keep the
# per-layer attribute dicts in graph order.
pytorch_graph_unified = [attrs for _, attrs in pytorch_extractor.graph_structure]
pytorch_graph_unified

[{'type': 'conv',
  'filter': 32,
  'kernel_shape': [3, 3],
  'pads': [0, 0, 0, 0],
  'strides': [1, 1]},
 {'type': 'relu'},
 {'type': 'maxpool', 'kernel_shape': [2, 2], 'strides': [2, 2]},
 {'type': 'conv',
  'filter': 64,
  'kernel_shape': [3, 3],
  'pads': [0, 0, 0, 0],
  'strides': [1, 1]},
 {'type': 'relu'},
 {'type': 'maxpool', 'kernel_shape': [2, 2], 'strides': [2, 2]},
 {'type': 'flatten'},
 {'type': 'dense', 'input': 1600, 'output': 10},
 {'type': 'softmax'}]

In [11]:
# Core check of this notebook: after dropping node names, the Keras and
# PyTorch exports must yield the same layer sequence. Asserting it makes
# "Restart & Run All" fail loudly if the extractors ever diverge; the boolean
# is still displayed as the cell's output.
graphs_match = pytorch_graph_unified == keras_graph_unified
assert graphs_match, (
    'Keras and PyTorch unified graphs differ '
    f'(keras: {len(keras_graph_unified)} layers, '
    f'pytorch: {len(pytorch_graph_unified)} layers)'
)
graphs_match

True