From 2b8723729bfb91275d598b55752f32d8a00a33ad Mon Sep 17 00:00:00 2001 From: Mahyar Najibi Date: Sun, 23 Feb 2020 16:51:28 +0000 Subject: [PATCH] Updating the symbol weight initialization for AutoFocus --- symbols/faster/resnet_mx_101_e2e.py | 15 ++++++++------- symbols/faster/resnet_mx_101_e2e_mask.py | 15 ++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/symbols/faster/resnet_mx_101_e2e.py b/symbols/faster/resnet_mx_101_e2e.py index 1b00f4a40..306f1ecfd 100644 --- a/symbols/faster/resnet_mx_101_e2e.py +++ b/symbols/faster/resnet_mx_101_e2e.py @@ -464,13 +464,14 @@ def init_weight_rcnn(self, cfg, arg_params, aux_params): arg_params['conv_new_1_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_1_weight']) arg_params['conv_new_1_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_1_bias']) - - arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight']) - arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias']) - arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight']) - arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias']) - arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight']) - arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias']) + + if cfg.TRAIN.AUTO_FOCUS: + arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight']) + arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias']) + arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight']) + arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias']) + arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight']) + arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias']) arg_params['offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_weight']) arg_params['offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_bias']) diff --git a/symbols/faster/resnet_mx_101_e2e_mask.py b/symbols/faster/resnet_mx_101_e2e_mask.py index eab2011ab..a63ab69e2 100644 --- a/symbols/faster/resnet_mx_101_e2e_mask.py +++ b/symbols/faster/resnet_mx_101_e2e_mask.py @@ -557,13 +557,14 @@ def init_weight_rcnn(self, cfg, arg_params, aux_params): arg_params['conv_new_1_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_1_weight']) arg_params['conv_new_1_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_1_bias']) - - arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight']) - arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias']) - arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight']) - arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias']) - arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight']) - arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias']) + + if cfg.TRAIN.AUTO_FOCUS: + arg_params['conv_new_2_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_2_weight']) + arg_params['conv_new_2_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_2_bias']) + arg_params['conv_new_3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_3_weight']) + arg_params['conv_new_3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_3_bias']) + arg_params['conv_new_out_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['conv_new_out_weight']) + arg_params['conv_new_out_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['conv_new_out_bias']) arg_params['offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_weight']) arg_params['offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['offset_bias'])