In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy
import onnx
import os
from onnx import numpy_helper
from distutils.version import LooseVersion


# Preprocessing: create a small 2x3 float Numpy array to round-trip through ONNX.
numpy_array = numpy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=float)

# array2string only accepts `legacy='1.13'` on numpy >= 1.14; it reproduces the
# numpy 1.13 print style so the displayed output is the same on old and new numpy.
# numpy.lib.NumpyVersion is used instead of distutils' LooseVersion, which is
# deprecated (PEP 632) and unavailable on Python 3.12+.
if numpy.lib.NumpyVersion(numpy.__version__) < '1.14.0':
    print('Original Numpy array:\n{}\n'.format(numpy.array2string(numpy_array)))
else:
    print('Original Numpy array:\n{}\n'.format(numpy.array2string(numpy_array, legacy='1.13')))

Original Numpy array:
[[ 1.  2.  3.]
 [ 4.  5.  6.]]



In [2]:
# Wrap the Numpy array from the previous cell in an ONNX TensorProto message.
tensor = numpy_helper.from_array(numpy_array)

# The protobuf text rendering shows dims, data_type and the raw bytes of the payload.
print(
    'TensorProto:\n{}'.format(tensor)
)

TensorProto:
dims: 2
dims: 3
data_type: 11
raw_data: "\000\000\000\000\000\000\360?\000\000\000\000\000\000\000@\000\000\000\000\000\000\010@\000\000\000\000\000\000\020@\000\000\000\000\000\000\024@\000\000\000\000\000\000\030@"



In [3]:
# Convert the TensorProto back to a Numpy array.
new_array = numpy_helper.to_array(tensor)

# The round trip must preserve both the values and the dtype of the source array.
assert numpy.array_equal(new_array, numpy_array), 'round trip changed the array values'
assert new_array.dtype == numpy_array.dtype, 'round trip changed the array dtype'

# Display the round-tripped array (not the original). `legacy='1.13'` is only
# accepted by numpy >= 1.14 and keeps the numpy 1.13 print style across versions.
# NumpyVersion replaces distutils' deprecated LooseVersion (PEP 632).
if numpy.lib.NumpyVersion(numpy.__version__) < '1.14.0':
    print('After round trip, Numpy array:\n{}\n'.format(numpy.array2string(new_array)))
else:
    print('After round trip, Numpy array:\n{}\n'.format(numpy.array2string(new_array, legacy='1.13')))

After round trip, Numpy array:
[[ 1.  2.  3.]
 [ 4.  5.  6.]]



In [4]:
# Save the serialized TensorProto to resources/tensor.pb so the next cell can reload it.
# Create the directory first so the cell also works on a fresh checkout; the
# isdir guard (instead of exist_ok=True) keeps this Python 2 compatible.
if not os.path.isdir('resources'):
    os.makedirs('resources')
with open(os.path.join('resources', 'tensor.pb'), 'wb') as f:
    f.write(tensor.SerializeToString())

In [5]:
# Load the TensorProto: start from an empty message, then fill it by parsing
# the bytes written to resources/tensor.pb by the save cell.
new_tensor = onnx.TensorProto()
with open(os.path.join('resources', 'tensor.pb'), 'rb') as tensor_file:
    serialized_bytes = tensor_file.read()
    new_tensor.ParseFromString(serialized_bytes)

print(
    'After saving and loading, new TensorProto:\n{}'.format(new_tensor)
)

After saving and loading, new TensorProto:
dims: 2
dims: 3
data_type: 11
raw_data: "\000\000\000\000\000\000\360?\000\000\000\000\000\000\000@\000\000\000\000\000\000\010@\000\000\000\000\000\000\020@\000\000\000\000\000\000\024@\000\000\000\000\000\000\030@"

