Add Keras end-to-end tests #36

Merged 5 commits on May 11, 2018

1 change: 1 addition & 0 deletions onnxmltools/__init__.py
@@ -19,6 +19,7 @@

 from .convert import convert_coreml
 from .convert import convert_sklearn
+from .convert import convert_keras

 from .utils import load_model
 from .utils import save_model

6 changes: 3 additions & 3 deletions onnxmltools/convert/keras/_parse.py
@@ -65,9 +65,10 @@ def determine_tensor_type(tensor, default_batch_size, keras_shape=None):
     raise ValueError('Unable to find out a correct type for tensor %s' % tensor)


-def parse_keras(model):
+def parse_keras(model, initial_types=None):
     raw_model_container = KerasModelContainer(model)
-    topology = Topology(raw_model_container, default_batch_size=1)
+
+    topology = Topology(raw_model_container, default_batch_size=1, initial_types=initial_types)
     scope = topology.declare_scope('__root__')

     for node in model.inbound_nodes:

@@ -134,4 +135,3 @@ def _parse_keras(topology, parent_scope, model, inbound_node):

     else:
         raise RuntimeError('Unsupported Keras component %s' % type(model))
-

40 changes: 14 additions & 26 deletions onnxmltools/convert/keras/convert.py
@@ -5,40 +5,29 @@
 # --------------------------------------------------------------------------

 from uuid import uuid4
-from ...proto import onnx_proto
-from ..common import utils
-from ..common._container import RawModelContainer
 from ..common._topology import convert_topology
 from ._parse import parse_keras

 # Register conversion functions and shape inference functions
 from . import operator_converters
 from . import shape_calculators

-class KerasModelContainer(RawModelContainer):
-
-    def __init__(self, keras_model):
-        super(KerasModelContainer, self).__init__(keras_model)
-        self._input_raw_names = set()
-        self._output_raw_names = set()
-
-    def add_input_name(self, name):
-        self._input_raw_names.add(name)
-
-    def add_output_name(self, name):
-        self._output_raw_names.add(name)
-
-    @property
-    def input_names(self):
-        return [name for name in self._input_raw_names]
-
-    @property
-    def output_names(self):
-        return [name for name in self._output_raw_names]
-
-
-def convert(model, name=None, doc_string=''):
-    topology = parse_keras(model)
+def convert(model, name=None, initial_types=None, doc_string=''):
+    '''
+    Convert Keras-TensorFlow Model and Sequential objects into a Topology. Note that the default batch size here is 1,
+    instead of the `None` used in the CoreML conversion framework. To override this behavior, specify initial_types.
+    Assume a Keras tensor is named input:0 and its shape is [None, 3]. If the desired batch size is 10, we can specify
+    >>> from onnxmltools.convert.common.data_types import FloatTensorType
+    >>> initial_types=[('input:0', FloatTensorType([10, 3]))]
+
+    :param model: A Keras model (Model or Sequential object)
+    :param name: Optional graph name of the produced ONNX model
+    :param initial_types: A list providing types for some input variables. Each element is a tuple of a variable name
+                          and a type defined in data_types.py.
+    :param doc_string: A string attached onto the produced ONNX model
+    '''
+    topology = parse_keras(model, initial_types)

     topology.compile()

@@ -48,4 +37,3 @@ def convert(model, name=None, doc_string=''):
     onnx_model = convert_topology(topology, name, doc_string)

     return onnx_model
-

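For context, a minimal usage sketch of the new initial_types override. The model, the graph name, and the 'input:0' tensor name follow the docstring's example and are illustrative, not part of this PR:

    import onnxmltools
    from onnxmltools.convert.common.data_types import FloatTensorType
    from keras.layers import Dense, Input
    from keras.models import Model

    # A toy one-layer model; with the TF backend, Keras names this
    # input placeholder something like 'input:0' (name is illustrative)
    input = Input(shape=(3,), name='input')
    model = Model(input=input, output=Dense(2)(input))

    # Pin the batch dimension to 10 instead of the converter's default of 1
    onnx_model = onnxmltools.convert_keras(
        model, name='dense_demo',
        initial_types=[('input:0', FloatTensorType([10, 3]))])
    onnxmltools.save_model(onnx_model, 'dense_demo.onnx')
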
18 changes: 11 additions & 7 deletions onnxmltools/convert/keras/operator_converters/BatchNorm.py
@@ -13,7 +13,7 @@

 def convert_keras_batch_normalization(scope, operator, container):
     op = operator.raw_operator
-    if op.axis != 3:
+    if op.axis != 3 and op.axis != -1:
         adjusted_input_name = operator.inputs[0].full_name
     else:
         adjusted_input_name = scope.get_unique_variable_name(operator.inputs[0].full_name + '_transposed')

@@ -30,28 +30,31 @@ def convert_keras_batch_normalization(scope, operator, container):
     if not op.center:
         params.insert(1, np.zeros(params[1].shape, dtype=float))

+    gamma = params[0] / np.sqrt(params[3] + op.epsilon)
+    beta = params[1] - params[0] * params[2] / np.sqrt(params[3] + op.epsilon)
+
     scale_tensor_name = scope.get_unique_variable_name('scale')
-    container.add_initializer(scale_tensor_name, onnx_proto.TensorProto.FLOAT, params[0].shape, params[0])
+    container.add_initializer(scale_tensor_name, onnx_proto.TensorProto.FLOAT, params[0].shape, gamma)
     input_tensor_names.append(scale_tensor_name)

     bias_tensor_name = scope.get_unique_variable_name('bias')
-    container.add_initializer(bias_tensor_name, onnx_proto.TensorProto.FLOAT, params[1].shape, params[1])
+    container.add_initializer(bias_tensor_name, onnx_proto.TensorProto.FLOAT, params[1].shape, beta)
     input_tensor_names.append(bias_tensor_name)

     mean_tensor_name = scope.get_unique_variable_name('mean')
-    container.add_initializer(mean_tensor_name, onnx_proto.TensorProto.FLOAT, params[2].shape, params[2])
+    container.add_initializer(mean_tensor_name, onnx_proto.TensorProto.FLOAT, params[2].shape, 0 * params[2])
     input_tensor_names.append(mean_tensor_name)

     var_tensor_name = scope.get_unique_variable_name('var')
-    container.add_initializer(var_tensor_name, onnx_proto.TensorProto.FLOAT, params[3].shape, params[3])
+    container.add_initializer(var_tensor_name, onnx_proto.TensorProto.FLOAT, params[3].shape, 1 + 0 * params[3])
     input_tensor_names.append(var_tensor_name)

-    attrs['epsilon'] = op.epsilon
+    attrs['epsilon'] = 0.
     attrs['momentum'] = op.momentum
     attrs['spatial'] = 1
     attrs['is_test'] = 1

-    if op.axis != 3:
+    if op.axis != 3 and op.axis != -1:
         # If no transpose is required, we can simply use the output of ONNX BatchNorm as the final outcome
         container.add_node(op_type, input_tensor_names, operator.output_full_names, **attrs)
     else:

@@ -64,3 +67,4 @@ def convert_keras_batch_normalization(scope, operator, container):


 register_converter(BatchNormalization, convert_keras_batch_normalization)
+

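The initializer values above implement a constant-folding trick: the Keras statistics are absorbed into the scale and bias tensors, so the emitted ONNX BatchNormalization (mean 0, variance 1, epsilon 0) reduces to a pure per-channel affine transform. A small numpy sketch of the identity being relied on, with illustrative values:

    import numpy as np

    # Per-channel Keras BatchNorm parameters (illustrative)
    gamma, beta = np.array([1.5, 0.8]), np.array([0.1, -0.2])
    mean, var, eps = np.array([0.3, -0.5]), np.array([2.0, 0.7]), 1e-3
    x = np.random.rand(4, 2)  # 4 samples, 2 channels

    # What Keras computes
    y_keras = gamma * (x - mean) / np.sqrt(var + eps) + beta

    # Folded form: BatchNorm with mean=0, var=1, eps=0 degenerates
    # into scale * x + bias
    scale = gamma / np.sqrt(var + eps)
    bias = beta - gamma * mean / np.sqrt(var + eps)
    y_folded = scale * x + bias

    assert np.allclose(y_keras, y_folded)

This is also why the converter can set epsilon to 0 without losing the original epsilon: it is already baked into the folded scale and bias.
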
2 changes: 1 addition & 1 deletion onnxmltools/convert/keras/operator_converters/Dense.py
@@ -11,7 +11,7 @@

 _activation_map = {_get_activation('sigmoid'): 'Sigmoid',
                    _get_activation('softmax'): 'Softmax',
-                   _get_activation('linear'): 'Affine',
+                   _get_activation('linear'): 'Identity',
                    _get_activation('relu'): 'Relu'}

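This remapping reflects that the Keras 'linear' activation is a no-op, so the standard ONNX Identity op is a more appropriate target than the experimental Affine op (which computes alpha * x + beta). A quick check of that assumption:

    import numpy as np
    from keras.activations import linear

    x = np.array([-1.0, 0.0, 2.5])
    assert np.allclose(linear(x), x)  # 'linear' returns its input unchanged
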
4 changes: 2 additions & 2 deletions onnxmltools/convert/main.py
@@ -23,9 +23,9 @@ def convert_coreml(model, name=None, initial_types=None, doc_string=''):
     return convert(model, name=name, initial_types=initial_types, doc_string=doc_string)


-def convert_keras(model, name=None):
+def convert_keras(model, name=None, initial_types=None, doc_string=''):
     if not utils.keras_installed():
         raise RuntimeError('keras is not installed. Please install keras to use this feature.')

     from .keras.convert import convert
-    return convert(model, name)
+    return convert(model, name, initial_types=initial_types, doc_string=doc_string)

2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -2,7 +2,7 @@ numpy
 protobuf
 codecov
 tensorflow
-keras
+keras==2.0.9
 coremltools
 pandas
 pytest

86 changes: 62 additions & 24 deletions tests/end2end/test_single_operator_with_cntk_backend.py
@@ -79,6 +79,7 @@ def _create_tensor(N, C, H=None, W=None):
 class TestKeras2CoreML2ONNX(unittest.TestCase):

     def _test_one_to_one_operator_core(self, keras_model, x):
+        # Verify Keras-to-CoreML-to-ONNX path
         coreml_model = coremltools.converters.keras.convert(keras_model)
         onnx_model = onnxmltools.convert_coreml(coreml_model)

@@ -87,8 +88,18 @@ def _test_one_to_one_operator_core(self, keras_model, x):

         self.assertTrue(np.allclose(y_reference, y_produced))

+        # Verify Keras-to-ONNX path
+        onnx_model = onnxmltools.convert_keras(keras_model)
+        y_produced = _evaluate(onnx_model, x)
+
+        self.assertTrue(np.allclose(y_reference, y_produced))

     def _test_one_to_one_operator_core_channels_last(self, keras_model, x):
         '''
+        There are two test paths. One is Keras-->CoreML-->ONNX and the other one is Keras-->ONNX.
+
+        Keras-->CoreML-->ONNX:
+
         Keras computation path:
         [N, C, H, W] ---> numpy transpose ---> [N, H, W, C] ---> keras convolution --->
         [N, H, W, C] ---> numpy transpose ---> [N, C, H, W]

@@ -98,25 +109,44 @@ def _test_one_to_one_operator_core_channels_last(self, keras_model, x):

         The reason for having extra transposes in the Keras path is that coremltools does not handle the
         channels_last flag properly. Precisely, coremltools always converts Conv2D under channels_first mode.
+
+        Keras-->ONNX:
+
+        Keras computation path:
+        [N, C, H, W] ---> numpy transpose ---> [N, H, W, C] ---> keras convolution --->
+        [N, H, W, C]
+
+        ONNX computation path:
+        [N, C, H, W] ---> numpy transpose ---> [N, H, W, C] ---> ONNX convolution ---> [N, H, W, C]
+
         '''
+        # Verify Keras-to-CoreML-to-ONNX path
         coreml_model = coremltools.converters.keras.convert(keras_model)
-        onnx_model = onnxmltools.convert_coreml(coreml_model)
+        onnx_model_p1 = onnxmltools.convert_coreml(coreml_model)
+        onnx_model_p2 = onnxmltools.convert_keras(keras_model)

+        if isinstance(x, list):
+            x_t = [np.transpose(_, [0, 2, 3, 1]) for _ in x]
+        else:
+            x_t = np.transpose(x, [0, 2, 3, 1])
         y_reference = np.transpose(keras_model.predict(x_t), [0, 3, 1, 2])

-        y_produced = _evaluate(onnx_model, x)
+        y_produced = _evaluate(onnx_model_p1, x)

         self.assertTrue(np.allclose(y_reference, y_produced))

+        # Verify Keras-to-ONNX path
+        y_reference = np.transpose(y_reference, [0, 2, 3, 1])
+        y_produced = _evaluate(onnx_model_p2, x_t)
+
+        self.assertTrue(np.allclose(y_reference, y_produced, atol=1e-6))
+
     def test_dense(self):
         N, C = 2, 3
         x = _create_tensor(N, C)
-        model = Sequential()
-        model.add(Dense(2, input_dim=C))
+
+        input = Input(shape=(C,))
+        result = Dense(2)(input)
+        model = Model(input=input, output=result)
         model.compile(optimizer='adagrad', loss='mse')

         self._test_one_to_one_operator_core(model, x)
@@ -125,9 +155,10 @@ def test_conv_4d(self):
         N, C, H, W = 1, 2, 4, 3
         x = _create_tensor(N, C, H, W)

-        model = Sequential()
-        model.add(Conv2D(2, kernel_size=(1, 2), strides=(1, 1), padding='valid', input_shape=(H, W, C),
-                         data_format='channels_last'))
+        input = Input(shape=(H, W, C))
+        result = Conv2D(2, kernel_size=(1, 2), strides=(1, 1), padding='valid',
+                        data_format='channels_last')(input)
+        model = Model(input=input, output=result)
         model.compile(optimizer='adagrad', loss='mse')

         self._test_one_to_one_operator_core_channels_last(model, x)
@@ -137,8 +168,9 @@ def test_pooling_4d(self):
         N, C, H, W = 1, 2, 4, 3
         x = _create_tensor(N, C, H, W)
         for layer in layers_to_be_tested:
-            model = Sequential()
-            model.add(layer(2, input_shape=(H, W, C), data_format='channels_last'))
+            input = Input(shape=(H, W, C))
+            result = layer(2, data_format='channels_last')(input)
+            model = Model(input=input, output=result)
             model.compile(optimizer='adagrad', loss='mse')

             self._test_one_to_one_operator_core_channels_last(model, x)
@@ -148,8 +180,9 @@ def test_convolution_transpose_2d(self):
         N, C, H, W = 2, 2, 1, 1
         x = _create_tensor(N, C, H, W)

-        model = Sequential()
-        model.add(Conv2DTranspose(2, (2, 1), input_shape=(H, W, C), data_format='channels_last'))
+        input = Input(shape=(H, W, C))
+        result = Conv2DTranspose(2, (2, 1), data_format='channels_last')(input)
+        model = Model(input=input, output=result)
         model.compile(optimizer='adagrad', loss='mse')

         self._test_one_to_one_operator_core_channels_last(model, x)
@@ -190,11 +223,12 @@ def test_activation_2d(self):
         x = _create_tensor(N, C)

         for activation in activation_to_be_tested:
-            model = Sequential()
+            input = Input(shape=(C,))
             if isinstance(activation, str):
-                model.add(Activation(activation, input_shape=(3,)))
+                result = Activation(activation)(input)
             else:
-                model.add(activation(input_shape=(3,)))
+                result = activation()(input)
+            model = Model(input=input, output=result)
             model.compile(optimizer='adagrad', loss='mse')

             self._test_one_to_one_operator_core(model, x)
@@ -206,11 +240,12 @@ def test_activation_4d(self):
         x = _create_tensor(N, C, H, W)

         for activation in activation_to_be_tested:
-            model = Sequential()
+            input = Input(shape=(H, W, C))
             if isinstance(activation, str):
-                model.add(Activation(activation, input_shape=(H, W, C)))
+                result = Activation(activation)(input)
             else:
-                model.add(activation(input_shape=(H, W, C)))
+                result = activation()(input)
+            model = Model(input=input, output=result)
             model.compile(optimizer='adagrad', loss='mse')

             self._test_one_to_one_operator_core_channels_last(model, x)
@@ -234,10 +269,12 @@ def test_batch_normalization(self):
         N, C, H, W = 2, 2, 3, 4
         x = _create_tensor(N, C, H, W)
-        model = Sequential()
-        model.add(BatchNormalization(beta_initializer='random_uniform', gamma_initializer='random_uniform',
-                                     moving_mean_initializer='random_uniform',
-                                     moving_variance_initializer=RandomUniform(minval=0.1, maxval=0.5),
-                                     input_shape=(H, W, C)))
+        input = Input(shape=(H, W, C))
+        result = BatchNormalization(beta_initializer='random_uniform', gamma_initializer='random_uniform',
+                                    moving_mean_initializer='random_uniform',
+                                    moving_variance_initializer=RandomUniform(minval=0.1, maxval=0.5)
+                                    )(input)
+        model = Model(input=input, output=result)
         model.compile(optimizer='adagrad', loss='mse')

         self._test_one_to_one_operator_core_channels_last(model, x)
@@ -262,8 +299,9 @@ def test_upsample(self):
         N, C, H, W = 2, 3, 1, 2
         x = _create_tensor(N, C, H, W)

-        model = Sequential()
-        model.add(UpSampling2D(input_shape=(H, W, C)))
+        input = Input(shape=(H, W, C))
+        result = UpSampling2D()(input)
+        model = Model(input=input, output=result)
         model.compile(optimizer='adagrad', loss='mse')

         self._test_one_to_one_operator_core_channels_last(model, x)
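As a reference for the layout bookkeeping these channels_last tests depend on, a small numpy sketch of the NCHW/NHWC round trip (shapes and values are illustrative; no Keras required):

    import numpy as np

    N, C, H, W = 1, 2, 4, 3
    x = np.random.rand(N, C, H, W)            # channels_first, as _create_tensor builds it

    x_t = np.transpose(x, [0, 2, 3, 1])       # NCHW -> NHWC, what the Keras model consumes
    assert x_t.shape == (N, H, W, C)

    back = np.transpose(x_t, [0, 3, 1, 2])    # NHWC -> NCHW, undoing the transpose
    assert np.array_equal(back, x)
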