Skip to content

Commit

Permalink
add caffe scale (#732)
Browse files Browse the repository at this point in the history
fix #720
  • Loading branch information
rainLiuplus committed Sep 11, 2019
1 parent 24f272d commit b5feba9
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 5 deletions.
10 changes: 7 additions & 3 deletions mmdnn/conversion/caffe/mapper.py
Expand Up @@ -259,11 +259,15 @@ def map_scale(cls, node):
# TODO: The gamma parameter has to be set (in node.data?) and this should work.
# Also, mean should be set to 0, and var to 1, just to be safe.
if node.data:
raise NotImplementedError
scale_value = float(node.parameters.filler.value)
kwargs = {'scale' : False, 'bias' : False, 'gamma' : scale_value, 'epsilon': 0}
if node.parameters.bias_term:
bias_value = float(node.parameters.bias_filler.value)
kwargs = {'use_scale' : True, 'use_bias' : node.parameters.bias_term, 'gamma' : scale_value, 'beta': bias_value, 'epsilon': 0}
else:
kwargs = {'use_scale' : True, 'use_bias' : node.parameters.bias_term, 'gamma' : scale_value, 'epsilon': 0}

cls._convert_output_shape(kwargs, node)
return Node.create('Scale', **kwargs)
return Node.create('Affine', **kwargs)
else:
return Node.create('Mul')

Expand Down
4 changes: 3 additions & 1 deletion mmdnn/conversion/examples/caffe/extractor.py
Expand Up @@ -44,6 +44,8 @@ class caffe_extractor(base_extractor):
'caffemodel' : 'http://dl.caffe.berkeleyvision.org/fcn16s-heavy-pascal.caffemodel'},
'voc-fcn32s' : {'prototxt' : MMDNN_BASE_URL + "caffe/voc-fcn32s_deploy.prototxt",
'caffemodel' : 'http://dl.caffe.berkeleyvision.org/fcn32s-heavy-pascal.caffemodel'},
'trailnet_sresnet': {'prototxt': 'https://raw.githubusercontent.com/NVIDIA-AI-IOT/redtail/master/models/pretrained/TrailNet_SResNet-18.prototxt',
'caffemodel': 'https://raw.githubusercontent.com/NVIDIA-AI-IOT/redtail/master/models/pretrained/TrailNet_SResNet-18.caffemodel'}
}


Expand Down Expand Up @@ -80,7 +82,7 @@ def inference(cls, architecture_name, architecture, path, image_path):
img = np.transpose(img, (2, 0, 1))
img = np.expand_dims(img, 0)
net.blobs['data'].data[...] = img
predict = np.squeeze(net.forward()[net._layer_names[-1]][0])
predict = np.squeeze(net.forward()[net._output_list[-1]][0])
predict = np.squeeze(predict)
return predict

Expand Down
7 changes: 6 additions & 1 deletion mmdnn/conversion/examples/imagenet_test.py
Expand Up @@ -88,6 +88,7 @@ class TestKit(object):
'voc-fcn8s' : lambda path : TestKit.ZeroCenter(path, 500, True),
'voc-fcn16s' : lambda path : TestKit.ZeroCenter(path, 500, True),
'voc-fcn32s' : lambda path : TestKit.ZeroCenter(path, 500, True),
'trailnet_sresnet': lambda path: TestKit.ZeroCenter(path, (320, 180), True)
},

'tensorflow' : {
Expand Down Expand Up @@ -246,7 +247,11 @@ def __init__(self):
@staticmethod
def ZeroCenter(path, size, BGRTranspose=False):
img = Image.open(path)
img = img.resize((size, size))
if isinstance(size, tuple):
h, w = size[0], size[1]
else:
h, w = size, size
img = img.resize((h, w))
x = np.array(img, dtype=np.float32)

# Reference: 1) Keras image preprocess: https://github.com/keras-team/keras/blob/master/keras/applications/imagenet_utils.py
Expand Down
39 changes: 39 additions & 0 deletions mmdnn/conversion/keras/keras2_emitter.py
Expand Up @@ -862,6 +862,23 @@ def emit_PRelu(self, IR_node, in_scope=False):
)
return code

def emit_Affine(self, IR_node, in_scope=False):
if in_scope:
raise NotImplementedError
else:
self.used_layers.add('Affine')
if IR_node.layer.attr.get('beta', None) is None:
bias = None
else:
bias = IR_node.layer.attr['beta'].f
code = "{:<15} = Affine(name='{}', scale={}, bias={})({})".format(
IR_node.variable_name,
IR_node.name,
IR_node.layer.attr['gamma'].f,
bias,
self.parent_variable_name(IR_node))
return code

def emit_yolo(self, IR_node, in_scope=False):
self.used_layers.add('Yolo')
self.yolo_parameter = [IR_node.get_attr('anchors'),
Expand Down Expand Up @@ -1183,6 +1200,28 @@ def compute_output_shape(self, input_shape):
return input_shape""")


def _layer_Affine(self):
self.add_body(0, '''
from keras.engine import Layer, InputSpec
from keras import initializers
from keras import backend as K
class Affine(Layer):
def __init__(self, scale, bias=None, **kwargs):
super(Affine, self).__init__(**kwargs)
self.gamma = scale
self.beta = bias
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
# Prepare broadcasting shape.
return self.gamma * inputs + self.beta
def compute_output_shape(self, input_shape):
return input_shape
''')


def _layer_Split(self):
self.add_body(0, '''
def __split(input, split_num, axis):
Expand Down

0 comments on commit b5feba9

Please sign in to comment.