Skip to content
Permalink
Browse files

Added elu emit function to keras2_emitter and caffe_emitter (#544)

  • Loading branch information...
BlaiseRitchie authored and rainLiuplus committed Jun 11, 2019
1 parent 5ba8b93 commit d7e40c48e3f93310232acecf717957763e3bed3c
@@ -501,6 +501,15 @@ def emit_Relu(self, IR_node):
self.parent_variable_name(IR_node),
in_place))


def emit_Elu(self, IR_node):
in_place = True
self.add_body(1, "n.{:<15} = L.ELU(n.{}, in_place={}, ntop=1)".format(
IR_node.variable_name,
self.parent_variable_name(IR_node),
in_place))


def emit_LeakyRelu(self, IR_node):
in_place = True
self.add_body(1, "n.{:<15} = L.ReLU(n.{}, in_place={}, negative_slope={}, ntop=1)".format(
@@ -141,12 +141,20 @@ def map_crop(cls, node):
return Node.create('Crop', **kwargs)


@classmethod
def map_elu(cls, node):
kwargs = {}
cls._convert_output_shape(kwargs, node)
return Node.create('ELU', **kwargs)


@classmethod
def map_relu(cls, node):
kwargs = {}
cls._convert_output_shape(kwargs, node)
return Node.create('Relu', **kwargs)


@classmethod
def map_p_re_lu(cls, node):
# print(node.parameters)
@@ -4,6 +4,7 @@
#----------------------------------------------------------------------------------------------

from __future__ import absolute_import
import os
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.extractor import base_extractor
from mmdnn.conversion.common.utils import download_file
@@ -59,6 +60,7 @@ def download(cls, architecture, path="./"):
if not weight_file:
return None


print("Caffe Model {} saved as [{}] and [{}].".format(architecture, architecture_file, weight_file))
return (architecture_file, weight_file)

@@ -72,6 +74,7 @@ def inference(cls, architecture_name, architecture, path, image_path):
import caffe
import numpy as np
net = caffe.Net(architecture[0], architecture[1], caffe.TEST)

func = TestKit.preprocess_func['caffe'][architecture_name]
img = func(image_path)
img = np.transpose(img, (2, 0, 1))
@@ -25,7 +25,7 @@ class keras_extractor(base_extractor):
'xception' : lambda : keras.applications.xception.Xception(input_shape=(299, 299, 3)),
'inception_resnet_v2' : lambda : keras.applications.inception_resnet_v2.InceptionResNetV2(input_shape=(299, 299, 3)),
'densenet' : lambda : keras.applications.densenet.DenseNet201(),
'nasnet' : lambda : keras.applications.nasnet.NASNetLarge()
'nasnet' : lambda : keras.applications.nasnet.NASNetLarge(),
}

thirdparty_map = {
@@ -490,6 +490,14 @@ def emit_Reshape(self, IR_node, in_scope=False):
return code


def emit_Elu(self, IR_node):
self._emit_activation(IR_node, 'elu')


def emit_Relu(self, IR_node):
self._emit_activation(IR_node, 'relu')


def emit_Tanh(self, IR_node, in_scope=False):
code = self._emit_activation(IR_node, 'tanh', in_scope)
return code
@@ -948,6 +948,7 @@ def _test_function(self, original_framework, parser):
# get original model prediction result
original_predict = parser(network_name, test_input)


IR_file = TestModels.tmpdir + original_framework + '_' + network_name + "_converted"
for emit in self.test_table[original_framework][network_name]:
if isinstance(emit, staticmethod):
@@ -966,6 +967,7 @@ def _test_function(self, original_framework, parser):
IR_file + ".npy",
test_input)


self._compare_outputs(
original_framework,
target_framework,

0 comments on commit d7e40c4

Please sign in to comment.
You can’t perform that action at this time.