Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
363 changes: 44 additions & 319 deletions README.md

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions configs/panet/e2e_panet_R-50-FPN_1x_det.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
FASTER_RCNN: True
NUM_GPUS: 8
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 90000
STEPS: [0, 60000, 80000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.roi_Xconv1fc_gn_head_panet # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
TRAIN:
SCALES: (1200,1200,1000,800,600,400)
MAX_SIZE: 1400
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
SCALE: 1000
MAX_SIZE: 1400
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
35 changes: 35 additions & 0 deletions configs/panet/e2e_panet_R-50-FPN_1x_det_2fc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
FASTER_RCNN: True
NUM_GPUS: 8
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 90000
STEPS: [0, 60000, 80000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.roi_2mlp_head_gn_panet # Note: this is a 2-FC (MLP) GN head, not a Conv head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
TRAIN:
SCALES: (1200,1200,1000,800,600,400)
MAX_SIZE: 1400
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
SCALE: 1000
MAX_SIZE: 1400
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
45 changes: 45 additions & 0 deletions configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 180000
STEPS: [0, 120000, 160000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.roi_Xconv1fc_gn_head_panet # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn_adp_ff # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
SCALES: (1200, 1200, 1000, 800, 600, 400)
MAX_SIZE: 1400
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
SCALE: 1000
MAX_SIZE: 1400
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000

60 changes: 56 additions & 4 deletions lib/modeling/FPN.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def fpn_ResNet50_conv5_body():
ResNet.ResNet50_conv5_body, fpn_level_info_ResNet50_conv5()
)

def fpn_ResNet50_conv5_body_bup():
    """Build the ResNet-50 conv5 FPN body with the PANet bottom-up path on.

    Same backbone function and level info as the plain ResNet-50 conv5 FPN
    body; the only difference is that ``panet_buttomup=True`` is passed to
    ``fpn``.
    """
    backbone_func = ResNet.ResNet50_conv5_body
    level_info = fpn_level_info_ResNet50_conv5()
    return fpn(backbone_func, level_info, panet_buttomup=True)


def fpn_ResNet50_conv5_P2only_body():
return fpn(
Expand Down Expand Up @@ -77,10 +83,11 @@ class fpn(nn.Module):
similarly for fpn_level_info.dims: e.g [2048, 1024, 512, 256]
similarly for spatial_scale: e.g [1/32, 1/16, 1/8, 1/4]
"""
def __init__(self, conv_body_func, fpn_level_info, P2only=False):
def __init__(self, conv_body_func, fpn_level_info, P2only=False, panet_buttomup=False):
super().__init__()
self.fpn_level_info = fpn_level_info
self.P2only = P2only
self.panet_buttomup = panet_buttomup

self.dim_out = fpn_dim = cfg.FPN.DIM
min_level, max_level = get_min_max_levels()
Expand Down Expand Up @@ -125,6 +132,35 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False):

self.spatial_scale.append(fpn_level_info.spatial_scales[i])

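# PANet bottom-up path augmentation: per level transition, a stride-2 3x3
# conv (conv1) followed, after the lateral sum, by a stride-1 3x3 conv (conv2).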
if self.panet_buttomup:
self.panet_buttomup_conv1_modules = nn.ModuleList()
self.panet_buttomup_conv2_modules = nn.ModuleList()
for i in range(self.num_backbone_stages - 1):
if cfg.FPN.USE_GN:
self.panet_buttomup_conv1_modules.append(nn.Sequential(
nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1, bias=True),
nn.GroupNorm(net_utils.get_group_gn(fpn_dim), fpn_dim,
eps=cfg.GROUP_NORM.EPSILON),
nn.ReLU(inplace=True)
))
self.panet_buttomup_conv2_modules.append(nn.Sequential(
nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=True),
nn.GroupNorm(net_utils.get_group_gn(fpn_dim), fpn_dim,
eps=cfg.GROUP_NORM.EPSILON),
nn.ReLU(inplace=True)
))
else:
self.panet_buttomup_conv1_modules.append(
nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1)
)
self.panet_buttomup_conv2_modules.append(
nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1)
)

#self.spatial_scale.append(fpn_level_info.spatial_scales[i])


#
# Step 2: build up starting from the coarsest backbone level
#
Expand Down Expand Up @@ -160,6 +196,7 @@ def _init_weights(self):
def init_func(m):
if isinstance(m, nn.Conv2d):
mynn.init.XavierFill(m.weight)
#mynn.init.MSRAFill(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)

Expand Down Expand Up @@ -236,10 +273,25 @@ def forward(self, x):
self.topdown_lateral_modules[i](fpn_inner_blobs[-1], conv_body_blobs[-(i+2)])
)
fpn_output_blobs = []
if self.panet_buttomup:
fpn_middle_blobs = []
for i in range(self.num_backbone_stages):
fpn_output_blobs.append(
self.posthoc_modules[i](fpn_inner_blobs[i])
)
if not self.panet_buttomup:
fpn_output_blobs.append(
self.posthoc_modules[i](fpn_inner_blobs[i])
)
else:
fpn_middle_blobs.append(
self.posthoc_modules[i](fpn_inner_blobs[i])
)
if self.panet_buttomup:
fpn_output_blobs.append(fpn_middle_blobs[-1])
for i in range(2, self.num_backbone_stages + 1):
fpn_tmp = self.panet_buttomup_conv1_modules[i - 2](fpn_output_blobs[0])
#print(fpn_middle_blobs[self.num_backbone_stages - i].size())
fpn_tmp = fpn_tmp + fpn_middle_blobs[self.num_backbone_stages - i]
fpn_tmp = self.panet_buttomup_conv2_modules[i - 2](fpn_tmp)
fpn_output_blobs.insert(0, fpn_tmp)

if hasattr(self, 'maxpool_p6'):
fpn_output_blobs.insert(0, self.maxpool_p6(fpn_output_blobs[0]))
Expand Down
Loading