From 0b883fea1058d7e26c21d4848514e66c649549f4 Mon Sep 17 00:00:00 2001
From: Vinitra Swamy
Date: Tue, 4 Jun 2019 10:29:57 -0700
Subject: [PATCH] [Opset 10] updates for thresholded relu (#308)

---
 .../neural_network/Activation.py           |  8 ++++----
 onnxutils/onnxconverter_common/onnx_ops.py | 15 +++++++++++++++
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/onnxmltools/convert/coreml/operator_converters/neural_network/Activation.py b/onnxmltools/convert/coreml/operator_converters/neural_network/Activation.py
index a511737b..da1319e5 100644
--- a/onnxmltools/convert/coreml/operator_converters/neural_network/Activation.py
+++ b/onnxmltools/convert/coreml/operator_converters/neural_network/Activation.py
@@ -43,11 +43,11 @@ def convert_activation(scope, operator, container):
     elif activation_type == 'scaledTanh':
         apply_scaled_tanh(scope, inputs[0], outputs[0], container, operator_name=attrs['name'],
                           alpha=params.scaledTanh.alpha, beta=params.scaledTanh.beta)
+    elif activation_type == 'thresholdedReLU':
+        apply_thresholded_relu(scope, inputs, outputs, container, operator_name=attrs['name'],
+                               alpha=params.thresholdedReLU.alpha)
     else:
-        if activation_type == 'thresholdedReLU':
-            op_type = 'ThresholdedRelu'
-            attrs['alpha'] = params.thresholdedReLU.alpha
-        elif activation_type == 'softsign':
+        if activation_type == 'softsign':
             op_type = 'Softsign'
         elif activation_type == 'softplus':
             op_type = 'Softplus'
diff --git a/onnxutils/onnxconverter_common/onnx_ops.py b/onnxutils/onnxconverter_common/onnx_ops.py
index f9f4e2b8..cd753b99 100644
--- a/onnxutils/onnxconverter_common/onnx_ops.py
+++ b/onnxutils/onnxconverter_common/onnx_ops.py
@@ -580,6 +580,21 @@ def apply_sum(scope, input_names, output_name, container, operator_name=None):
 def apply_tanh(scope, input_name, output_name, container, operator_name=None):
     _apply_unary_operation(scope, 'Tanh', input_name, output_name, container, operator_name)
 
+def apply_thresholded_relu(scope, input_name, output_name, container, operator_name=None, alpha=None):
+    if alpha is None:
+        alpha = [1.0]
+
+    name = _create_name_or_use_existing_one(scope, 'ThresholdedRelu', operator_name)
+    attrs = {'name': name, 'alpha': alpha[0]}
+    if container.target_opset < 10:
+        # ThresholdedRelu graduated from an experimental op to a full op in opset 10;
+        # onnxruntime maintains support for ThresholdedRelu in the ONNX domain as a contrib op.
+        attrs['op_domain'] = "ai.onnx"
+        op_version = 1
+    else:
+        op_version = 10
+    container.add_node('ThresholdedRelu', input_name, output_name, op_version=op_version, **attrs)
+
 
 def apply_tile(scope, input_name, output_name, container, operator_name=None, repeats=None):
     name = _create_name_or_use_existing_one(scope, 'Tile', operator_name)
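
Note on the op being wired up above: ONNX ThresholdedRelu computes y = x for
x > alpha and y = 0 otherwise, with alpha defaulting to 1.0 (hence the [1.0]
fallback in apply_thresholded_relu). A minimal NumPy sketch of those semantics,
for reference only; the helper name thresholded_relu_reference is illustrative
and not part of this patch:

    import numpy as np

    def thresholded_relu_reference(x, alpha=1.0):
        # ONNX ThresholdedRelu: y = x if x > alpha, else 0
        return np.where(x > alpha, x, np.zeros_like(x))

    x = np.array([-1.5, 0.5, 1.0, 2.5], dtype=np.float32)
    print(thresholded_relu_reference(x))             # [0.  0.  0.  2.5] (1.0 is not > alpha)
    print(thresholded_relu_reference(x, alpha=0.0))  # [0.  0.5 1.  2.5]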