
Add PReLU layer support for the Caffe convert tool (#4277)

* add prelu support

* fix params
1 parent d6328a5 · commit 541d10924362dfaee8d80ea6115e83293314b46f · @fengshikun committed with piiswrong on Jan 11, 2017
Showing with 14 additions and 1 deletion.
  1. +9 −1 tools/caffe_converter/convert_model.py
  2. +5 −0 tools/caffe_converter/convert_symbol.py
tools/caffe_converter/convert_model.py
@@ -60,7 +60,15 @@ def main():
     first_conv = True
     for layer_name, layer_type, layer_blobs in iter:
-        if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14:
+        if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14 \
+                or layer_type == 'PReLU':
+            if layer_type == 'PReLU':
+                assert(len(layer_blobs) == 1)
+                wmat = layer_blobs[0].data
+                weight_name = layer_name + '_gamma'
+                arg_params[weight_name] = mx.nd.zeros(wmat.shape)
+                arg_params[weight_name][:] = wmat
+                continue
             assert(len(layer_blobs) == 2)
             wmat_dim = []
             if getattr(layer_blobs[0].shape, 'dim', None) is not None:
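In convert_model.py, the new branch copies the single PReLU blob (the learned per-channel slopes) into arg_params under the key layer_name + '_gamma', which matches the gamma argument that MXNet's LeakyReLU exposes when act_type='prelu'. A minimal sketch of that copy in isolation; the layer name and slope values below are made up for illustration:

```python
import mxnet as mx
import numpy as np

# Hypothetical caffe PReLU layer: the name and per-channel slopes are illustrative.
layer_name = 'relu1'
wmat = np.array([0.25, 0.10, 0.30], dtype=np.float32)  # stands in for layer_blobs[0].data

arg_params = {}
weight_name = layer_name + '_gamma'                # LeakyReLU(act_type='prelu') learns 'gamma'
arg_params[weight_name] = mx.nd.zeros(wmat.shape)  # allocate the NDArray
arg_params[weight_name][:] = wmat                  # copy the caffe blob into it
```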
tools/caffe_converter/convert_symbol.py
@@ -165,6 +165,11 @@ def proto2script(proto_file):
             type_string = 'mx.symbol.BatchNorm'
             param = layer[i].batch_norm_param
             param_string = 'use_global_stats=%s' % param.use_global_stats
+        if layer[i].type == 'PReLU':
+            type_string = 'mx.symbol.LeakyReLU'
+            param = layer[i].prelu_param
+            param_string = "act_type='prelu', slope=%f" % param.filler.value
+            need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
         if type_string == '':
             raise Exception('Unknown Layer %s!' % layer[i].type)
         if type_string != 'split':
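In convert_symbol.py, a caffe PReLU layer is mapped to mx.symbol.LeakyReLU with act_type='prelu', the filler value is passed along as slope, and the need_flatten flag is inherited from the layer's input. Roughly, the script generated by the converter then builds a symbol like the following; the layer names, kernel size, and slope value here are illustrative, not taken from the commit:

```python
import mxnet as mx

data = mx.symbol.Variable('data')
conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, name='conv1')
# The converter's output for a PReLU layer: LeakyReLU in 'prelu' mode.
relu1 = mx.symbol.LeakyReLU(data=conv1, act_type='prelu', slope=0.25, name='relu1')

# The learnable per-channel slopes show up as the argument 'relu1_gamma',
# which is the key that convert_model.py fills from the caffe blob above.
print(relu1.list_arguments())  # ['data', 'conv1_weight', 'conv1_bias', 'relu1_gamma']
```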
