Skip to content

Commit

Permalink
[Opset 10] updates for thresholded relu (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinitra committed Jun 4, 2019
1 parent 5e70d8e commit 0b883fe
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
Expand Up @@ -43,11 +43,11 @@ def convert_activation(scope, operator, container):
elif activation_type =='scaledTanh':
apply_scaled_tanh(scope, inputs[0], outputs[0], container, operator_name=attrs['name'],
alpha=params.scaledTanh.alpha, beta=params.scaledTanh.beta)
elif activation_type == 'thresholdedReLU':
apply_thresholded_relu(scope, inputs, outputs, container, operator_name=attrs['name'],
alpha=params.thresholdedReLU.alpha)
else:
if activation_type == 'thresholdedReLU':
op_type = 'ThresholdedRelu'
attrs['alpha'] = params.thresholdedReLU.alpha
elif activation_type == 'softsign':
if activation_type == 'softsign':
op_type = 'Softsign'
elif activation_type == 'softplus':
op_type = 'Softplus'
Expand Down
15 changes: 15 additions & 0 deletions onnxutils/onnxconverter_common/onnx_ops.py
Expand Up @@ -580,6 +580,21 @@ def apply_sum(scope, input_names, output_name, container, operator_name=None):
def apply_tanh(scope, input_name, output_name, container, operator_name=None):
_apply_unary_operation(scope, 'Tanh', input_name, output_name, container, operator_name)

def apply_thresholded_relu(scope, input_name, output_name, container, operator_name=None, alpha=None):
if alpha == None:
alpha = [1.0]

name = _create_name_or_use_existing_one(scope, 'ThresholdedRelu', operator_name)
attrs = {'name': name, 'alpha': alpha[0]}
if container.target_opset < 10:
# ThresholdedRelu graduated from an experimental op to a full op in opset 10
# onnxruntime maintains support in the ONNX domain for ThresholdedRelu as a contrib op
attrs['op_domain'] = "ai.onnx"
op_version = 1
else:
op_version = 10
container.add_node('ThresholdedRelu', input_name, output_name, op_version=op_version, **attrs)

def apply_tile(scope, input_name, output_name, container, operator_name=None, repeats=None):
name = _create_name_or_use_existing_one(scope, 'Tile', operator_name)

Expand Down

0 comments on commit 0b883fe

Please sign in to comment.