Skip to content

Commit

Permalink
Updating the symbol weight initialization for AutoFocus
Browse files Browse the repository at this point in the history
  • Loading branch information
mahyarnajibi committed Feb 23, 2020
1 parent 6ccd2fd commit 2b87237
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
15 changes: 8 additions & 7 deletions symbols/faster/resnet_mx_101_e2e.py
Expand Up @@ -464,13 +464,14 @@ def init_weight_rcnn(self, cfg, arg_params, aux_params):

arg_params['conv_new_1_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_1_weight'])
arg_params['conv_new_1_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_1_bias'])

arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight'])
arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias'])
arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight'])
arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias'])
arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight'])
arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias'])

if cfg.TRAIN.AUTO_FOCUS:
arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight'])
arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias'])
arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight'])
arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias'])
arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight'])
arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias'])

arg_params['offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_weight'])
arg_params['offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_bias'])
Expand Down
15 changes: 8 additions & 7 deletions symbols/faster/resnet_mx_101_e2e_mask.py
Expand Up @@ -557,13 +557,14 @@ def init_weight_rcnn(self, cfg, arg_params, aux_params):

arg_params['conv_new_1_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_1_weight'])
arg_params['conv_new_1_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_1_bias'])

arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight'])
arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias'])
arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight'])
arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias'])
arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight'])
arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias'])

if cfg.TRAIN.AUTO_FOCUS:
arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight'])
arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias'])
arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight'])
arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias'])
arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight'])
arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias'])

arg_params['offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_weight'])
arg_params['offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_bias'])
Expand Down

0 comments on commit 2b87237

Please sign in to comment.