-
Notifications
You must be signed in to change notification settings - Fork 5.3k
/
spherenet.py
102 lines (87 loc) · 3.72 KB
/
spherenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import mxnet as mx
import numpy as np
import math
from mxnet.base import _Null
def conv_main(data, units, filters, workspace):
    """Build the SphereNet convolutional trunk as an MXNet symbol.

    Each stage starts with a stride-2 3x3 convolution (with explicit weight/bias
    variables so their lr_mult/wd_mult can differ), followed by `units[i]`
    residual blocks of two no-bias 3x3 convolutions with PReLU activations.

    Parameters
    ----------
    data : mx.sym.Symbol
        Input symbol (caller is expected to have normalized it already).
    units : list of int
        Number of residual blocks per stage.
    filters : list of int
        Output channels per stage; must be at least as long as `units`.
    workspace : int
        MXNet convolution workspace limit (MB).

    Returns
    -------
    mx.sym.Symbol
        Output symbol of the final stage.
    """
    body = data
    # NOTE: was `xrange` (Python-2-only); `range` behaves identically here on
    # both Python 2 and 3.
    for i in range(len(units)):
        f = filters[i]
        # Downsampling conv gets explicit variables: bias learns at 2x the
        # base lr and is excluded from weight decay.
        _weight = mx.symbol.Variable("conv%d_%d_weight" % (i + 1, 1), lr_mult=1.0)
        _bias = mx.symbol.Variable("conv%d_%d_bias" % (i + 1, 1), lr_mult=2.0, wd_mult=0.0)
        body = mx.sym.Convolution(data=body, weight=_weight, bias=_bias, num_filter=f,
                                  kernel=(3, 3), stride=(2, 2), pad=(1, 1),
                                  name="conv%d_%d" % (i + 1, 1), workspace=workspace)
        body = mx.sym.LeakyReLU(data=body, act_type='prelu', name="relu%d_%d" % (i + 1, 1))
        idx = 2
        for j in range(units[i]):
            # Residual block: two stride-1 no-bias convs, then identity shortcut.
            _body = mx.sym.Convolution(data=body, no_bias=True, num_filter=f,
                                       kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                                       name="conv%d_%d" % (i + 1, idx), workspace=workspace)
            _body = mx.sym.LeakyReLU(data=_body, act_type='prelu', name="relu%d_%d" % (i + 1, idx))
            idx += 1
            _body = mx.sym.Convolution(data=_body, no_bias=True, num_filter=f,
                                       kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                                       name="conv%d_%d" % (i + 1, idx), workspace=workspace)
            _body = mx.sym.LeakyReLU(data=_body, act_type='prelu', name="relu%d_%d" % (i + 1, idx))
            idx += 1
            body = body + _body
    return body
def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):
    """Construct the full SphereNet symbol: normalized input, conv trunk, fc1.

    Parameters
    ----------
    num_classes : int
        Width of the final fully-connected embedding/classification layer.
    num_layers : int
        Network depth; one of 20, 36, 60, 64, 104.
    conv_workspace : int, optional
        MXNet convolution workspace limit (MB).
    **kwargs
        Ignored; accepted for interface compatibility with sibling network
        builders.

    Returns
    -------
    mx.sym.Symbol
        The `fc1` output symbol.

    Raises
    ------
    ValueError
        If `num_layers` is not a supported depth.
    """
    if num_layers == 64:
        units = [3, 8, 16, 3]
        filters = [64, 128, 256, 512]
    elif num_layers == 20:
        units = [1, 2, 4, 1]
        filters = [64, 128, 256, 512]
    elif num_layers == 36:
        units = [2, 4, 8, 2]
        filters = [64, 128, 256, 512]
    elif num_layers == 60:
        units = [3, 8, 14, 3]
        filters = [64, 128, 256, 512]
    elif num_layers == 104:
        units = [3, 8, 36, 3]
        filters = [64, 128, 256, 512]
    else:
        # Previously an unsupported depth fell through and crashed later with
        # a confusing NameError on `units`; fail fast with a clear message.
        raise ValueError("unsupported num_layers %s; expected one of 20, 36, 60, 64, 104"
                         % num_layers)
    data = mx.symbol.Variable('data')
    # Normalize uint8 pixel range [0, 255] to roughly [-1, 1]:
    # (x - 127.5) * (1/128).
    data = data - 127.5
    data = data * 0.0078125
    body = conv_main(data=data, units=units, filters=filters, workspace=conv_workspace)
    # fc1 bias: 2x base lr, no weight decay (matches the conv bias convention).
    _weight = mx.symbol.Variable("fc1_weight", lr_mult=1.0)
    _bias = mx.symbol.Variable("fc1_bias", lr_mult=2.0, wd_mult=0.0)
    fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias,
                                num_hidden=num_classes, name='fc1')
    return fc1
def init_weights(sym, data_shape_dict, num_layers):
    """Build custom initial parameters for the SphereNet symbol.

    Residual-block conv weights (names containing neither '_1_' nor other
    stage-head markers) are drawn from N(0, 0.01); all biases are zeroed.
    Stage-head convs ('conv<i>_1_*') and, for deep nets (num_layers >= 100),
    all conv weights are left to the framework's default initializer.

    Parameters
    ----------
    sym : mx.sym.Symbol
        Network symbol whose arguments are to be initialized.
    data_shape_dict : dict
        Name -> shape mapping passed to `sym.infer_shape`.
    num_layers : int
        Network depth; weights are only custom-initialized when < 100.

    Returns
    -------
    (dict, None)
        `arg_params` with the initialized NDArrays, and `aux_params`
        (always None here: no auxiliary states are initialized).
    """
    arg_name = sym.list_arguments()
    arg_shape, _out_shape, _aux_shape = sym.infer_shape(**data_shape_dict)
    arg_params = {}
    aux_params = None
    arg_shape_dict = dict(zip(arg_name, arg_shape))
    # NOTE: was `iteritems()` (Python-2-only); `items()` works on both 2 and 3.
    for k, v in arg_shape_dict.items():
        if k.startswith('conv') and k.endswith('_weight'):
            # Skip the explicitly-declared stage-head convs ('conv<i>_1_weight');
            # equivalent to the original `not k.find('_1_') >= 0`.
            if '_1_' not in k:
                if num_layers < 100:
                    arg_params[k] = mx.random.normal(0, 0.01, shape=v)
                    print('init', k)
        if k.endswith('_bias'):
            arg_params[k] = mx.nd.zeros(shape=v)
            print('init', k)
    return arg_params, aux_params