This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Support exporting for CPU Mask & Keypoint nets #449

Closed
wants to merge 5 commits into from


gadcam
Contributor

@gadcam gadcam commented May 28, 2018

Prerequisite: #372
Purpose: enable exporting all the models for CPU by exporting 2 separate nets: one for the bboxes and one for the rest of the inference.

Two main modifications:

  • Refactor main(): it calls a function convert_to_pb for each sub-network
  • run_model_pb: always run the bbox inference first, then call the mask or keypoint part if needed. The exact same approach is adopted.

The helper functions are only lightly modified to fit the new objective of exporting 2 pb files.
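A minimal sketch of that two-stage flow, assuming each exported net can be treated as a function from a dict of input blobs to a dict of output blobs (a stand-in for feeding blobs and calling Caffe2's workspace.RunNet). The key "bbox_net" and the fake nets are hypothetical; "mask_net" and "keypoint_net" are the part names used in this thread. This is an illustration of the control flow, not the PR's run_model_pb.

```python
from typing import Callable, Dict

Blobs = Dict[str, object]
Net = Callable[[Blobs], Blobs]


def run_two_stage(models: Dict[str, Net], inputs: Blobs) -> Blobs:
    """Run the bbox net first, then any optional second-stage nets."""
    # Stage 1: the box branch always runs.
    blobs = dict(inputs)
    blobs.update(models["bbox_net"](blobs))

    # Stage 2: optional heads read stage-1 outputs (e.g. rois) as well as
    # the original inputs (e.g. im_info) from the shared blob dict.
    for part in ("mask_net", "keypoint_net"):
        if part in models:
            blobs.update(models[part](blobs))
    return blobs


# Hypothetical stand-ins for the exported nets.
def fake_bbox_net(blobs: Blobs) -> Blobs:
    return {"rois": [(0.0, 0.0, 10.0, 10.0)], "cls_prob": [0.9]}


def fake_mask_net(blobs: Blobs) -> Blobs:
    # One mask per roi produced by stage 1.
    return {"mask_prob": ["mask"] * len(blobs["rois"])}


if __name__ == "__main__":
    out = run_two_stage(
        {"bbox_net": fake_bbox_net, "mask_net": fake_mask_net},
        {"data": "image", "im_info": (800, 1066, 1.0)},
    )
    print(sorted(out))
```

Supporting only one kind of head then amounts to leaving the other key out of models, which has the same effect as deleting its branch.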

@gadcam
Contributor Author

gadcam commented May 28, 2018

@daquexian I would more than welcome a review from you on this pull request if you have some time to put into it.

@gadcam
Contributor Author

gadcam commented Jun 4, 2018

@daquexian @ir413 @rbgirshick @orionr
I understand that reviewing a +178/−61 commit can take a long time (even if there are not many paths in this case), so if I can help you in any way, or if you have already spotted a bug, do not hesitate to tell me.

@gadcam gadcam mentioned this pull request Jul 12, 2018
@newstzpz
Contributor

We have open-sourced the keypoint operator (https://github.com/pytorch/pytorch/blob/master/caffe2/operators/heatmap_max_keypoint_op.cc), so it is now feasible to convert the keypoint model to a single Caffe2 model.

@gadcam
Contributor Author

gadcam commented Jul 22, 2018

@newstzpz thank you! That is really good to hear!

@ir413 @rbgirshick do you plan to change the code on master to provide a way to run directly with a single model (with a CFG parameter)?

If I want to enable exporting a single Caffe2 model for keypoint inference, what is the best way to proceed?

@sidnav

sidnav commented Jul 24, 2018

Hi @gadcam, I tried using this branch to obtain the .pb files for the e2e_mask_rcnn_R-101-FPN_2x.yaml config file with --device cpu. How do I integrate the two to get an end-to-end mask output from image data?

@gadcam
Contributor Author

gadcam commented Jul 24, 2018

@sidnav Currently it is a two-stage process: run_model_pb implements the inference.

If you only want to support mask models you can remove the part starting with

if 'keypoint_net' in models_pb:

@sidnav

sidnav commented Jul 25, 2018

Thank you @gadcam. You have mentioned #372 as a prerequisite for this, do I merge the commits? Will that allow me to run the GPU model of faster RCNN from #372 along with the CPU Mask Net from your PR?

@gadcam
Contributor Author

gadcam commented Jul 25, 2018

@sidnav Yes, you need to merge them.

along with

I do not really know what you mean by that.
It will allow you to run either inference on GPU or on CPU if that is the question.

@sidnav

sidnav commented Jul 25, 2018

@gadcam I merged the commits and tried the GPU option to create the pb files. I get the following error: AssertionError: input rpn_rois should be defined in the net. Given the two-stage process in run_model_pb that you mentioned earlier, do I need to implement a GPU-to-CPU converter to transfer the net blobs to the mask_net as external_input?

@gadcam
Contributor Author

gadcam commented Jul 25, 2018

@sidnav Could you describe in detail what you are doing and also where is the error ? I do not have enough information to understand what is going on :)

AssertionError: input rpn_rois should be defined in the net.

do I need to implement a GPUtoCPU Converter to transfer the net blobs to the mask_net as external_input?

The GPU to CPU conversion is implemented.
"Transferring the blobs as external_input" is handled by this piece of code in the case of a mask model.
https://github.com/gadcam/Detectron-1/blob/763ff296958acd9d7ab0b6160716b7c1dffef524/tools/convert_pkl_to_pb.py#L531-L550
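The bookkeeping behind "transferring the blobs as external_input" can be sketched without Caffe2: a second-stage net must declare as external inputs every blob that one of its ops reads before any op in that same net has produced it. In the minimal sketch below, ops are hypothetical (type, inputs, outputs) tuples rather than Caffe2 OperatorDef protos, and the functions are illustrative names, not the code linked above. It mirrors the idea behind errors like "input rpn_rois should be defined in the net", not Caffe2's actual implementation.

```python
from typing import Iterable, List, NamedTuple, Sequence


class Op(NamedTuple):
    """Hypothetical stand-in for a Caffe2 OperatorDef."""
    op_type: str
    inputs: Sequence[str]
    outputs: Sequence[str]


def required_external_inputs(ops: Iterable[Op]) -> List[str]:
    """Blobs that some op reads before any op in this net has written them."""
    defined = set()
    required: List[str] = []
    for op in ops:
        for name in op.inputs:
            if name not in defined and name not in required:
                required.append(name)
        # Outputs become available to every later op.
        defined.update(op.outputs)
    return required


def undefined_inputs(ops: Iterable[Op], external_inputs: Iterable[str]) -> List[str]:
    """Required inputs that are missing from the declared external inputs."""
    declared = set(external_inputs)
    return [name for name in required_external_inputs(ops) if name not in declared]


if __name__ == "__main__":
    # First op of detectron_mask_net as printed in an error reported in this
    # thread, followed by a hypothetical op consuming its output.
    mask_net_ops = [
        Op("BBoxTransform", ["rpn_rois", "bbox_pred", "im_info"], ["pred_bbox"]),
        Op("HypotheticalOp", ["pred_bbox", "cls_prob"], ["out"]),
    ]
    print(undefined_inputs(mask_net_ops, ["bbox_pred", "im_info", "cls_prob"]))
    # prints ['rpn_rois']
```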

@sidnav

sidnav commented Jul 25, 2018

Hi @gadcam, I ran `convert_pkl_to_pb.py` with the device option set to gpu and get the following error:
File "convert_pkl_to_pb.py", line 752, in <module>
main()
File "convert_pkl_to_pb.py", line 736, in main
models_pb['mask_net'] = convert_to_pb(args, model.mask_net.Proto(), blobs, part_name='mask_net')
File "convert_pkl_to_pb.py", line 706, in convert_to_pb
[pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net)
File "convert_pkl_to_pb.py", line 361, in convert_model_gpu
ret = core.InjectDeviceCopiesAmongNets([ret_init_net, ret_net])
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2420, in InjectDeviceCopiesAmongNets
blob_remap=blob_remap,
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2311, in InjectCrossDeviceCopies
"input {} should be defined in the net.".format(input)
AssertionError: input rpn_rois should be defined in the net

@gadcam
Contributor Author

gadcam commented Jul 25, 2018

@sidnav
I have never had this error, but OK, let's debug this!

Could you please share the following?

  • the full command line you use
  • the model you want to convert
  • anything which is not default/you changed
  • the full output of the script

I know I am asking a lot, but if I cannot reproduce the bug on my side, I will have a hard time debugging it.

@sidnav

sidnav commented Jul 26, 2018

@gadcam

  • Command- python2 tools/convert_pkl_to_pb.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml --out_dir ./model_files --device gpu

  • Model - e2e_mask_rcnn_R-101-FPN_2x

  • No changes to the default script

  • Output ->
    Found Detectron ops lib: /usr/local/lib/libcaffe2_detectron_ops_gpu.so
    E0726 10:37:14.902185 12560 init_intrinsics_check.cc:43] CPU feature avx is present on your machine, but the Caffe2 binary is not compiled with it. It means you may not get the full speed of your CPU.
    E0726 10:37:14.902205 12560 init_intrinsics_check.cc:43] CPU feature avx2 is present on your machine, but the Caffe2 binary is not compiled with it. It means you may not get the full speed of your CPU.
    E0726 10:37:14.902209 12560 init_intrinsics_check.cc:43] CPU feature fma is present on your machine, but the Caffe2 binary is not compiled with it. It means you may not get the full speed of your CPU.
    WARNING convert_pkl_to_pb.py: 121: Should not use mobile engine for gpu model.
    INFO convert_pkl_to_pb.py: 719: Called with args:
    INFO convert_pkl_to_pb.py: 720: Namespace(cfg_file='../configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml', device='gpu', fuse_af=1, net_execution_type='simple', net_name='detectron', opts=[], out_dir='/home/siddharthnavin/Detectron/model_files', test_img='/home/siddharthnavin/Downloads/TensorRT-4.0.1.6/data/faster-rcnn/16004479832_a748d55f21_k.jpg', use_nnpack=0)
    INFO convert_pkl_to_pb.py: 727: Conerting model with config:
    INFO convert_pkl_to_pb.py: 728: {'BBOX_XFORM_CLIP': 4.135166556742356,
    'CLUSTER': {'ON_CLUSTER': False},
    'DATA_LOADER': {'BLOBS_QUEUE_CAPACITY': 8,
    'MINIBATCH_QUEUE_SIZE': 64,
    'NUM_THREADS': 4},
    'DEDUP_BOXES': 0.0625,
    'DOWNLOAD_CACHE': '/tmp/detectron-download-cache',
    'EPS': 1e-14,
    'EXPECTED_RESULTS': [],
    'EXPECTED_RESULTS_ATOL': 0.005,
    'EXPECTED_RESULTS_EMAIL': '',
    'EXPECTED_RESULTS_RTOL': 0.1,
    'FAST_RCNN': {'CONV_HEAD_DIM': 256,
    'MLP_HEAD_DIM': 1024,
    'NUM_STACKED_CONVS': 4,
    'ROI_BOX_HEAD': 'fast_rcnn_heads.add_roi_2mlp_head',
    'ROI_XFORM_METHOD': 'RoIAlign',
    'ROI_XFORM_RESOLUTION': 7,
    'ROI_XFORM_SAMPLING_RATIO': 2},
    'FPN': {'COARSEST_STRIDE': 32,
    'DIM': 256,
    'EXTRA_CONV_LEVELS': False,
    'FPN_ON': True,
    'MULTILEVEL_ROIS': True,
    'MULTILEVEL_RPN': True,
    'ROI_CANONICAL_LEVEL': 4,
    'ROI_CANONICAL_SCALE': 224,
    'ROI_MAX_LEVEL': 5,
    'ROI_MIN_LEVEL': 2,
    'RPN_ANCHOR_START_SIZE': 32,
    'RPN_ASPECT_RATIOS': (0.5, 1, 2),
    'RPN_MAX_LEVEL': 6,
    'RPN_MIN_LEVEL': 2,
    'USE_GN': False,
    'ZERO_INIT_LATERAL': False},
    'GROUP_NORM': {'DIM_PER_GP': -1, 'EPSILON': 1e-05, 'NUM_GROUPS': 32},
    'KRCNN': {'CONV_HEAD_DIM': 256,
    'CONV_HEAD_KERNEL': 3,
    'CONV_INIT': 'GaussianFill',
    'DECONV_DIM': 256,
    'DECONV_KERNEL': 4,
    'DILATION': 1,
    'HEATMAP_SIZE': -1,
    'INFERENCE_MIN_SIZE': 0,
    'KEYPOINT_CONFIDENCE': 'bbox',
    'LOSS_WEIGHT': 1.0,
    'MIN_KEYPOINT_COUNT_FOR_VALID_MINIBATCH': 20,
    'NMS_OKS': False,
    'NORMALIZE_BY_VISIBLE_KEYPOINTS': True,
    'NUM_KEYPOINTS': -1,
    'NUM_STACKED_CONVS': 8,
    'ROI_KEYPOINTS_HEAD': '',
    'ROI_XFORM_METHOD': 'RoIAlign',
    'ROI_XFORM_RESOLUTION': 7,
    'ROI_XFORM_SAMPLING_RATIO': 0,
    'UP_SCALE': -1,
    'USE_DECONV': False,
    'USE_DECONV_OUTPUT': False},
    'MATLAB': 'matlab',
    'MEMONGER': True,
    'MEMONGER_SHARE_ACTIVATIONS': False,
    'MODEL': {'BBOX_REG_WEIGHTS': (10.0, 10.0, 5.0, 5.0),
    'CLS_AGNOSTIC_BBOX_REG': False,
    'CONV_BODY': 'FPN.add_fpn_ResNet101_conv5_body',
    'EXECUTION_TYPE': 'dag',
    'FASTER_RCNN': True,
    'KEYPOINTS_ON': False,
    'MASK_ON': True,
    'NUM_CLASSES': 81,
    'RPN_ONLY': False,
    'TYPE': 'generalized_rcnn'},
    'MRCNN': {'CLS_SPECIFIC_MASK': True,
    'CONV_INIT': 'MSRAFill',
    'DILATION': 1,
    'DIM_REDUCED': 256,
    'RESOLUTION': 28,
    'ROI_MASK_HEAD': 'mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs',
    'ROI_XFORM_METHOD': 'RoIAlign',
    'ROI_XFORM_RESOLUTION': 14,
    'ROI_XFORM_SAMPLING_RATIO': 2,
    'THRESH_BINARIZE': 0.5,
    'UPSAMPLE_RATIO': 1,
    'USE_FC_OUTPUT': False,
    'WEIGHT_LOSS_MASK': 1.0},
    'NUM_GPUS': 1,
    'OUTPUT_DIR': '.',
    'PIXEL_MEANS': array([[[102.9801, 115.9465, 122.7717]]]),
    'RESNETS': {'NUM_GROUPS': 1,
    'RES5_DILATION': 1,
    'SHORTCUT_FUNC': 'basic_bn_shortcut',
    'STEM_FUNC': 'basic_bn_stem',
    'STRIDE_1X1': True,
    'TRANS_FUNC': 'bottleneck_transformation',
    'WIDTH_PER_GROUP': 64},
    'RETINANET': {'ANCHOR_SCALE': 4,
    'ASPECT_RATIOS': (0.5, 1.0, 2.0),
    'BBOX_REG_BETA': 0.11,
    'BBOX_REG_WEIGHT': 1.0,
    'CLASS_SPECIFIC_BBOX': False,
    'INFERENCE_TH': 0.05,
    'LOSS_ALPHA': 0.25,
    'LOSS_GAMMA': 2.0,
    'NEGATIVE_OVERLAP': 0.4,
    'NUM_CONVS': 4,
    'POSITIVE_OVERLAP': 0.5,
    'PRE_NMS_TOP_N': 1000,
    'PRIOR_PROB': 0.01,
    'RETINANET_ON': False,
    'SCALES_PER_OCTAVE': 3,
    'SHARE_CLS_BBOX_TOWER': False,
    'SOFTMAX': False},
    'RFCN': {'PS_GRID_SIZE': 3},
    'RNG_SEED': 3,
    'ROOT_DIR': '/home/siddharthnavin/Detectron/tools',
    'RPN': {'ASPECT_RATIOS': (0.5, 1, 2),
    'RPN_ON': True,
    'SIZES': (64, 128, 256, 512),
    'STRIDE': 16},
    'SOLVER': {'BASE_LR': 0.02,
    'GAMMA': 0.1,
    'LOG_LR_CHANGE_THRESHOLD': 1.1,
    'LRS': [],
    'LR_POLICY': 'steps_with_decay',
    'MAX_ITER': 180000,
    'MOMENTUM': 0.9,
    'SCALE_MOMENTUM': True,
    'SCALE_MOMENTUM_THRESHOLD': 1.1,
    'STEPS': [0, 120000, 160000],
    'STEP_SIZE': 30000,
    'WARM_UP_FACTOR': 0.3333333333333333,
    'WARM_UP_ITERS': 500,
    'WARM_UP_METHOD': u'linear',
    'WEIGHT_DECAY': 0.0001,
    'WEIGHT_DECAY_GN': 0.0},
    'TEST': {'BBOX_AUG': {'AREA_TH_HI': 32400,
    'AREA_TH_LO': 2500,
    'ASPECT_RATIOS': (),
    'ASPECT_RATIO_H_FLIP': False,
    'COORD_HEUR': 'UNION',
    'ENABLED': False,
    'H_FLIP': False,
    'MAX_SIZE': 4000,
    'SCALES': (),
    'SCALE_H_FLIP': False,
    'SCALE_SIZE_DEP': False,
    'SCORE_HEUR': 'UNION'},
    'BBOX_REG': True,
    'BBOX_VOTE': {'ENABLED': False,
    'SCORING_METHOD': 'ID',
    'SCORING_METHOD_BETA': 1.0,
    'VOTE_TH': 0.8},
    'COMPETITION_MODE': True,
    'DATASETS': ('coco_2014_minival',),
    'DETECTIONS_PER_IM': 100,
    'FORCE_JSON_DATASET_EVAL': False,
    'KPS_AUG': {'AREA_TH': 32400,
    'ASPECT_RATIOS': (),
    'ASPECT_RATIO_H_FLIP': False,
    'ENABLED': False,
    'HEUR': 'HM_AVG',
    'H_FLIP': False,
    'MAX_SIZE': 4000,
    'SCALES': (),
    'SCALE_H_FLIP': False,
    'SCALE_SIZE_DEP': False},
    'MASK_AUG': {'AREA_TH': 32400,
    'ASPECT_RATIOS': (),
    'ASPECT_RATIO_H_FLIP': False,
    'ENABLED': False,
    'HEUR': 'SOFT_AVG',
    'H_FLIP': False,
    'MAX_SIZE': 4000,
    'SCALES': (),
    'SCALE_H_FLIP': False,
    'SCALE_SIZE_DEP': False},
    'MAX_SIZE': 1333,
    'NMS': 0.5,
    'PRECOMPUTED_PROPOSALS': False,
    'PROPOSAL_FILES': (),
    'PROPOSAL_LIMIT': 2000,
    'RPN_MIN_SIZE': 0,
    'RPN_NMS_THRESH': 0.7,
    'RPN_POST_NMS_TOP_N': 1000,
    'RPN_PRE_NMS_TOP_N': 1000,
    'SCALE': 800,
    'SCORE_THRESH': 0.05,
    'SOFT_NMS': {'ENABLED': False, 'METHOD': 'linear', 'SIGMA': 0.5},
    'WEIGHTS': u'/tmp/detectron-download-cache/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl'},
    'TRAIN': {'ASPECT_GROUPING': True,
    'AUTO_RESUME': True,
    'BATCH_SIZE_PER_IM': 512,
    'BBOX_THRESH': 0.5,
    'BG_THRESH_HI': 0.5,
    'BG_THRESH_LO': 0.0,
    'CROWD_FILTER_THRESH': 0.7,
    'DATASETS': ('coco_2014_train', 'coco_2014_valminusminival'),
    'FG_FRACTION': 0.25,
    'FG_THRESH': 0.5,
    'FREEZE_CONV_BODY': False,
    'GT_MIN_AREA': -1,
    'IMS_PER_BATCH': 2,
    'MAX_SIZE': 1333,
    'PROPOSAL_FILES': (),
    'RPN_BATCH_SIZE_PER_IM': 256,
    'RPN_FG_FRACTION': 0.5,
    'RPN_MIN_SIZE': 0,
    'RPN_NEGATIVE_OVERLAP': 0.3,
    'RPN_NMS_THRESH': 0.7,
    'RPN_POSITIVE_OVERLAP': 0.7,
    'RPN_POST_NMS_TOP_N': 2000,
    'RPN_PRE_NMS_TOP_N': 2000,
    'RPN_STRADDLE_THRESH': 0,
    'SCALES': (800,),
    'SNAPSHOT_ITERS': 20000,
    'USE_FLIPPED': True,
    'WEIGHTS': u'/tmp/detectron-download-cache/ImageNetPretrained/MSRA/R-101.pkl'},
    'USE_NCCL': False,
    'VIS': False,
    'VIS_TH': 0.9}
    WARNING cnn.py: 25: [====DEPRECATE WARNING====]: you are creating an object from CNNModelHelper class which will be deprecated soon. Please use ModelHelper object with brew module. For more information, please refer to caffe2.ai and python/brew.py, python/brew_test.py for more information.
    INFO net.py: 59: Loading weights from: /tmp/detectron-download-cache/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl
    I0726 10:37:17.566745 12560 operator.cc:169] Engine CUDNN is not available for operator MaxPool.
    I0726 10:37:17.574590 12560 operator.cc:169] Engine CUDNN is not available for operator MaxPool.
    I0726 10:37:17.576123 12560 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 0.000145365 secs
    I0726 10:37:17.580385 12560 operator.cc:169] Engine CUDNN is not available for operator MaxPool.
    I0726 10:37:17.588212 12560 operator.cc:169] Engine CUDNN is not available for operator MaxPool.
    I0726 10:37:17.588510 12560 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 0.000126763 secs
    I0726 10:37:17.590132 12560 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 1.5253e-05 secs
    Removing op StopGradient:
    input: "gpu_0/res2_2_sum"
    output: "gpu_0/res2_2_sum"
    name: ""
    type: "StopGradient"
    device_option {
    device_type: 1
    cuda_gpu_id: 0
    }

Converting GenerateProposals Python -> C++:
input: "gpu_0/rpn_cls_probs_fpn2"
input: "gpu_0/rpn_bbox_pred_fpn2"
input: "gpu_0/im_info"
output: "gpu_0/rpn_rois_fpn2"
output: "gpu_0/rpn_roi_probs_fpn2"
name: "GenerateProposalsOp:gpu_0/rpn_cls_probs_fpn2,gpu_0/rpn_bbox_pred_fpn2,im_info"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward"
}
arg {
name: "spatial_scale"
f: 0.25
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

anchors [[-22. -10. 25. 13.]
[-14. -14. 17. 17.]
[-10. -22. 13. 25.]]
Converting GenerateProposals Python -> C++:
input: "gpu_0/rpn_cls_probs_fpn3"
input: "gpu_0/rpn_bbox_pred_fpn3"
input: "gpu_0/im_info"
output: "gpu_0/rpn_rois_fpn3"
output: "gpu_0/rpn_roi_probs_fpn3"
name: "GenerateProposalsOp:gpu_0/rpn_cls_probs_fpn3,gpu_0/rpn_bbox_pred_fpn3,im_info"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward:1"
}
arg {
name: "spatial_scale"
f: 0.125
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

anchors [[-40. -20. 47. 27.]
[-28. -28. 35. 35.]
[-20. -44. 27. 51.]]
Converting GenerateProposals Python -> C++:
input: "gpu_0/rpn_cls_probs_fpn4"
input: "gpu_0/rpn_bbox_pred_fpn4"
input: "gpu_0/im_info"
output: "gpu_0/rpn_rois_fpn4"
output: "gpu_0/rpn_roi_probs_fpn4"
name: "GenerateProposalsOp:gpu_0/rpn_cls_probs_fpn4,gpu_0/rpn_bbox_pred_fpn4,im_info"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward:2"
}
arg {
name: "spatial_scale"
f: 0.0625
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

anchors [[-84. -40. 99. 55.]
[-56. -56. 71. 71.]
[-36. -80. 51. 95.]]
Converting GenerateProposals Python -> C++:
input: "gpu_0/rpn_cls_probs_fpn5"
input: "gpu_0/rpn_bbox_pred_fpn5"
input: "gpu_0/im_info"
output: "gpu_0/rpn_rois_fpn5"
output: "gpu_0/rpn_roi_probs_fpn5"
name: "GenerateProposalsOp:gpu_0/rpn_cls_probs_fpn5,gpu_0/rpn_bbox_pred_fpn5,im_info"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward:3"
}
arg {
name: "spatial_scale"
f: 0.03125
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

anchors [[-164. -72. 195. 103.]
[-112. -112. 143. 143.]
[ -76. -168. 107. 199.]]
Converting GenerateProposals Python -> C++:
input: "gpu_0/rpn_cls_probs_fpn6"
input: "gpu_0/rpn_bbox_pred_fpn6"
input: "gpu_0/im_info"
output: "gpu_0/rpn_rois_fpn6"
output: "gpu_0/rpn_roi_probs_fpn6"
name: "GenerateProposalsOp:gpu_0/rpn_cls_probs_fpn6,gpu_0/rpn_bbox_pred_fpn6,im_info"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward:4"
}
arg {
name: "spatial_scale"
f: 0.015625
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

anchors [[-332. -152. 395. 215.]
[-224. -224. 287. 287.]
[-148. -328. 211. 391.]]
Converting CollectAndDistributeFpnRpnProposals Python -> C++:
input: "gpu_0/rpn_rois_fpn2"
input: "gpu_0/rpn_rois_fpn3"
input: "gpu_0/rpn_rois_fpn4"
input: "gpu_0/rpn_rois_fpn5"
input: "gpu_0/rpn_rois_fpn6"
input: "gpu_0/rpn_roi_probs_fpn2"
input: "gpu_0/rpn_roi_probs_fpn3"
input: "gpu_0/rpn_roi_probs_fpn4"
input: "gpu_0/rpn_roi_probs_fpn5"
input: "gpu_0/rpn_roi_probs_fpn6"
output: "gpu_0/rois"
output: "gpu_0/rois_fpn2"
output: "gpu_0/rois_fpn3"
output: "gpu_0/rois_fpn4"
output: "gpu_0/rois_fpn5"
output: "gpu_0/rois_idx_restore_int32"
name: "CollectAndDistributeFpnRpnProposalsOp:gpu_0/rpn_rois_fpn2,gpu_0/rpn_rois_fpn3,gpu_0/rpn_rois_fpn4,gpu_0/rpn_rois_fpn5,gpu_0/rpn_rois_fpn6,gpu_0/rpn_roi_probs_fpn2,gpu_0/rpn_roi_probs_fpn3,gpu_0/rpn_roi_probs_fpn4,gpu_0/rpn_roi_probs_fpn5,gpu_0/rpn_roi_probs_fpn6"
type: "Python"
arg {
name: "grad_input_indices"
}
arg {
name: "token"
s: "forward:5"
}
arg {
name: "grad_output_indices"
}
device_option {
device_type: 1
cuda_gpu_id: 0
}

Converting op CollectAndDistributeFpnRpnProposals output name: rois -> rpn_rois:
input: "rpn_rois_fpn2"
input: "rpn_rois_fpn3"
input: "rpn_rois_fpn4"
input: "rpn_rois_fpn5"
input: "rpn_rois_fpn6"
input: "rpn_roi_probs_fpn2"
input: "rpn_roi_probs_fpn3"
input: "rpn_roi_probs_fpn4"
input: "rpn_roi_probs_fpn5"
input: "rpn_roi_probs_fpn6"
output: "rois"
output: "rois_fpn2"
output: "rois_fpn3"
output: "rois_fpn4"
output: "rois_fpn5"
output: "rois_idx_restore_int32"
name: ""
type: "CollectAndDistributeFpnRpnProposals"
arg {
name: "roi_max_level"
i: 5
}
arg {
name: "rpn_post_nms_topN"
i: 1000
}
arg {
name: "roi_canonical_scale"
i: 224
}
arg {
name: "rpn_min_level"
i: 2
}
arg {
name: "roi_canonical_level"
i: 4
}
arg {
name: "roi_min_level"
i: 2
}
arg {
name: "rpn_max_level"
i: 6
}

Fusing affine channel...
Fusing affine channel...
Blob fpn_res2_2_sum with type <type 'str'> is not supported in generating init net, skipped.
Blob mask_rois_fpn2 with type <type 'str'> is not supported in generating init net, skipped.
Blob fpn_res3_3_sum with type <type 'str'> is not supported in generating init net, skipped.
Blob mask_rois_fpn3 with type <type 'str'> is not supported in generating init net, skipped.
Blob fpn_res4_22_sum with type <type 'str'> is not supported in generating init net, skipped.
Blob mask_rois_fpn4 with type <type 'str'> is not supported in generating init net, skipped.
Blob fpn_res5_2_sum with type <type 'str'> is not supported in generating init net, skipped.
Blob mask_rois_fpn5 with type <type 'str'> is not supported in generating init net, skipped.
Blob mask_rois_idx_restore_int32 with type <type 'str'> is not supported in generating init net, skipped.
Traceback (most recent call last):
File "convert_pkl_to_pb.py", line 756, in <module>
main()
File "convert_pkl_to_pb.py", line 740, in main
models_pb['mask_net'] = convert_to_pb(args, model.mask_net.Proto(), blobs, part_name='mask_net')#,input_blobs=['rpn_rois', 'bbox_pred', 'im_info', 'cls_prob'])
File "convert_pkl_to_pb.py", line 708, in convert_to_pb
[pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net)
File "convert_pkl_to_pb.py", line 363, in convert_model_gpu
ret = core.InjectDeviceCopiesAmongNets([ret_init_net, ret_net])
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2420, in InjectDeviceCopiesAmongNets
blob_remap=blob_remap,
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2311, in InjectCrossDeviceCopies
"input {} should be defined in the net.".format(input)
AssertionError: input rpn_rois should be defined in the net.

  • A similar error occurs when I run 'python2 tools/convert_pkl_to_pb.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml --out_dir ./model_files --device cpu --test_img demo/16004479832_a748d55f21_k.jpg' and try to use args.test_img in run_model_pb. In this case the pb model files are created without issues, but when run_model_pb is called, the following error occurs:
    Model def image saved to /home/siddharthnavin/Detectron/model_files.
    I0726 10:52:59.913588 13278 ThreadPool.cc:79] Constructing thread pool with 20 threads
    WARNING workspace.py: 187: Original python traceback for operator 1754176616 in network detectron_mask_net in exception above (most recent call last):
    Traceback (most recent call last):
    File "tools/convert_pkl_to_pb.py", line 755, in <module>
    main()
    File "tools/convert_pkl_to_pb.py", line 752, in main
    run_model_pb(args, models_pb, im, check_blobs)
    File "tools/convert_pkl_to_pb.py", line 590, in run_model_pb
    workspace.CreateNet(mask_net)
    File "/usr/local/lib/python2.7/dist-packages/caffe2/python/workspace.py", line 154, in CreateNet
    StringifyProto(net), overwrite,
    File "/usr/local/lib/python2.7/dist-packages/caffe2/python/workspace.py", line 180, in CallWithExceptionIntercept
    return func(*args, **kwargs)
    RuntimeError: [enforce fail at net.cc:69] . op BBoxTransform: Source for input rpn_rois is unknown for net detectron_mask_net, operator input: "rpn_rois" input: "bbox_pred" input: "im_info" output: "pred_bbox" name: "" type: "BBoxTransform" arg { name: "correct_transform_coords" i: 1 } arg { name: "apply_scale" i: 0 } arg { name: "weights" floats: 10 floats: 10 floats: 5 floats: 5 }

@gadcam
Contributor Author

gadcam commented Jul 27, 2018

Command- python2 tools/convert_pkl_to_pb.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml --out_dir ./model_files --device gpu

@sidnav The key is right here: --device gpu. (In your previous reply you also mentioned GPU, but I did not understand what you meant.)
In its current state, this pull request only supports GPU-to-CPU conversion.
But I do not really understand what you are trying to do; it should not really be a limitation, as you can already use the raw pkl model on GPU :)

@gadcam
Contributor Author

gadcam commented Aug 11, 2018

@rbgirshick Now that #372 is in, there is no longer any technical barrier to merging this one.
Tell me if you want me to run some tests or take any other step to ease a potential review.
I fixed 2 minor details to enable Python 3 support (in case #110 were to be merged, for example).

@daquexian do not hesitate to comment; since you worked on the same piece of code, it should be easier for you to spot any mistakes.

@rbgirshick
Contributor

@gadcam thanks for flagging this. If @newstzpz signs off on this, I'll merge it.

@newstzpz
Contributor

@gadcam Thanks for working on this. Since we have open-sourced the heatmap_max_keypoint_op (a fast implementation of keypoint_utils.heatmaps_to_keypoints), and now we have the FPN support in pb, do you think it would be easier if we create a single model directly (instead of multiple pbs)?
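For context on what heatmaps_to_keypoints computes: each keypoint's location is taken from the argmax of its heatmap and mapped back into the detection box. The sketch below is a simplified pure-Python illustration with hypothetical names. As far as I know, Detectron's keypoint_utils.heatmaps_to_keypoints first resizes each heatmap to the box size (bicubic) before taking the argmax; this version skips that step, so its coordinates will differ slightly from the real operator.

```python
from typing import List, Sequence, Tuple

Heatmap = Sequence[Sequence[float]]  # H rows of W scores
Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in image pixels


def heatmap_to_keypoint(heatmap: Heatmap, box: Box) -> Tuple[float, float, float]:
    """Return (x, y, score) for one keypoint heatmap predicted inside box."""
    x1, y1, x2, y2 = box
    grid_h, grid_w = len(heatmap), len(heatmap[0])
    # Highest-scoring cell (ties resolved towards larger row/column indices).
    score, row, col = max(
        (value, r, c) for r, cells in enumerate(heatmap) for c, value in enumerate(cells)
    )
    box_w, box_h = max(x2 - x1, 1.0), max(y2 - y1, 1.0)
    # Take the cell centre (+0.5) and rescale from grid units to box units.
    x = x1 + (col + 0.5) * box_w / grid_w
    y = y1 + (row + 0.5) * box_h / grid_h
    return x, y, score


def heatmaps_to_keypoints_simple(
    heatmaps: Sequence[Heatmap], box: Box
) -> List[Tuple[float, float, float]]:
    """Apply heatmap_to_keypoint to each of the K keypoint heatmaps of one box."""
    return [heatmap_to_keypoint(h, box) for h in heatmaps]
```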

@gadcam
Contributor Author

gadcam commented Aug 16, 2018

@newstzpz I started modifying this PR to merge the nets. Then I will get to the heatmaps_to_keypoints question. I cannot share the code yet.
Currently I get this error:

op RoIAlign: Source for input keypoint_rois_fpn2 is unknown for net detectron_net operator input: "fpn_res2_2_sum"
input: "keypoint_rois_fpn2"
output: "_[pose]_roi_feat_fpn2"
name: "" type: "RoIAlign"
arg { name: "pooled_w" i: 14 }
arg { name: "pooled_h" i: 14 }
arg { name: "spatial_scale" f: 0.25 }
arg { name: "sampling_ratio" i: 2 }
device_option { }
engine: ""

As you can see, there is a naming problem in the input/output names: keypoint_rois_fpn2 and rois_fpn2 should be the same (at least as I understand it).
Do you know how to solve it, or what I could try?

part_net

- WIP: the models run, but there are differences in the results
@gadcam
Contributor Author

gadcam commented Aug 19, 2018

@newstzpz @rbgirshick @ir413
Here is a commit that merges the nets:

  • Step 1 - Connect the rois blobs
  • Step 2 - Set the correct external_input/external_output from the kps/mask_net

It runs, but I get different results from the original network, and I cannot find where I went wrong. Could you check the code?
The part where I am probably making a mistake is here:

if cfg.MODEL.KEYPOINTS_ON:
    model_kps = model.keypoint_net.Proto()
    # Connect rois blobs
    for op in model_kps.op:
        for i, input_name in enumerate(op.input):
            op.input[i] = input_name.replace("keypoint_rois", "rois")
    # Remove external input defined in main net
    kps_external_input = []
    for blob_name in model_kps.external_input:
        if not model.net.BlobIsDefined(blob_name) and \
                "keypoint_rois" not in blob_name:
            kps_external_input.append(blob_name)
    model.net.Proto().op.extend(model_kps.op)
    model.net.Proto().external_output.extend(model_kps.external_output)
    model.net.Proto().external_input.extend(kps_external_input)

PS: before this commit (i.e. before trying to merge the nets), the models produced the same output.

@gadcam
Contributor Author

gadcam commented Sep 3, 2018

@newstzpz @rbgirshick @ir413 Have you had time to look at the code?
I think I need to apply some transformation to the rois blobs before passing them on, but I cannot figure out what. Could you check whether something needs to be done here?
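One possible direction, offered as an assumption rather than something verified against this PR: in Detectron's own test path (im_detect_mask / im_detect_keypoints), the second-stage heads are fed the final detections (after score thresholding and NMS), projected by the image scale and redistributed across FPN levels, rather than the raw RPN proposals held in rois. If that holds, renaming keypoint_rois to rois would explain the differing results. The sketch below shows only the redistribution step, using Eq. (1) of the FPN paper with the FPN values from the config dump earlier in this thread (ROI_CANONICAL_SCALE 224, ROI_CANONICAL_LEVEL 4, ROI_MIN_LEVEL 2, ROI_MAX_LEVEL 5). The per-level lists play the role of blobs like keypoint_rois_fpn2 to keypoint_rois_fpn5, and the restore list that of a *_idx_restore_int32 blob; the function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Roi = Tuple[float, ...]  # (batch_idx, x1, y1, x2, y2)


def fpn_level(box: Box, canonical_scale: float = 224, canonical_level: int = 4,
              min_level: int = 2, max_level: int = 5) -> int:
    """FPN Eq. (1): pick the pyramid level for a box from its size."""
    x1, y1, x2, y2 = box
    # Inclusive pixel convention (+1) for the box area.
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    level = math.floor(canonical_level + math.log2(math.sqrt(area) / canonical_scale + 1e-6))
    return int(min(max(level, min_level), max_level))


def distribute_detections(boxes: Sequence[Box], im_scale: float,
                          min_level: int = 2, max_level: int = 5
                          ) -> Tuple[Dict[int, List[Roi]], List[int]]:
    """Scale detections to the network input size and split them by FPN level.

    Returns the per-level rois (with batch index 0 prepended) and a list
    where restore[i] is the position of box i once the per-level lists are
    concatenated from min_level to max_level.
    """
    scaled = [tuple(c * im_scale for c in box) for box in boxes]
    levels = [fpn_level(b, min_level=min_level, max_level=max_level) for b in scaled]

    per_level: Dict[int, List[Roi]] = {}
    order: List[int] = []
    for lvl in range(min_level, max_level + 1):
        idx = [i for i, box_lvl in enumerate(levels) if box_lvl == lvl]
        per_level[lvl] = [(0.0,) + scaled[i] for i in idx]
        order.extend(idx)

    restore = [0] * len(order)
    for pos, i in enumerate(order):
        restore[i] = pos
    return per_level, restore
```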

@Houd1ny

Houd1ny commented Dec 25, 2018

Is there any update on this?
As I understand it, you cannot export the keypoint part of the network using convert_pkl_to_pb.py?

@Houd1ny

Houd1ny commented Dec 25, 2018

@gadcam is the current version of your pull request working?

@Houd1ny

Houd1ny commented Dec 26, 2018

@gadcam I merged your PR and was able to convert the keypoint model to CPU.
Thanks for your work.
But I have a few questions:

  1. In this comment you said that the keypoints are not the same. Are the keypoints generated by the converted model good?
  2. I still do not understand how to use the converted model in an end-to-end fashion. Investigating the Detectron code (im_detect_all), I found that the models run in several forward passes. The question is how to correctly run the converted model, since it is only one model.
    I tried to run it and got inconsistent shapes.
    When I print the shape of kps_score during infer_simple.py, it is (7, 56, 56),
    but when I fetch it using my code, it is
    (1000, 7, 56, 56).
    How should I deal with this?
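A guess about the extra leading dimension of 1000, not verified against the merged net: the config dump earlier in this thread has TEST.RPN_POST_NMS_TOP_N = 1000, so a head that runs on every proposal would output one result per proposal, whereas infer_simple.py goes through Detectron's usual path, which first reduces the proposals to final detections. The sketch below is a simplified, class-agnostic version of that reduction (score threshold, greedy NMS, per-image limit) using the values from the same config (TEST.SCORE_THRESH 0.05, TEST.NMS 0.5, TEST.DETECTIONS_PER_IM 100). Detectron applies NMS per class, and the function names here are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union with the inclusive-pixel (+1) convention."""
    iw = max(min(a[2], b[2]) - max(a[0], b[0]) + 1, 0.0)
    ih = max(min(a[3], b[3]) - max(a[1], b[1]) + 1, 0.0)
    inter = iw * ih
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def select_detections(boxes: Sequence[Box], scores: Sequence[float],
                      score_thresh: float = 0.05, nms_thresh: float = 0.5,
                      max_dets: int = 100) -> List[int]:
    """Indices of boxes kept after thresholding, greedy NMS and a top-k cap."""
    # Drop low scores, then visit the remaining boxes from best to worst.
    candidates = sorted(
        (i for i, s in enumerate(scores) if s > score_thresh),
        key=lambda i: scores[i],
        reverse=True,
    )
    keep: List[int] = []
    for i in candidates:
        # Suppress a box whose overlap with an already kept box exceeds nms_thresh.
        if all(iou(boxes[i], boxes[j]) <= nms_thresh for j in keep):
            keep.append(i)
            if len(keep) == max_dets:
                break
    return keep
```

Only the kept boxes would then be passed to the keypoint head, so its first output dimension would equal the number of detections rather than the number of proposals.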

Thanks for your time!

@pkuxwguan

pkuxwguan commented Feb 27, 2019

Hi @gadcam, I ran `convert_pkl_to_pb.py` with the device option set to gpu and get the following error:
File "convert_pkl_to_pb.py", line 752, in <module>
main()
File "convert_pkl_to_pb.py", line 736, in main
models_pb['mask_net'] = convert_to_pb(args, model.mask_net.Proto(), blobs, part_name='mask_net')
File "convert_pkl_to_pb.py", line 706, in convert_to_pb
[pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net)
File "convert_pkl_to_pb.py", line 361, in convert_model_gpu
ret = core.InjectDeviceCopiesAmongNets([ret_init_net, ret_net])
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2420, in InjectDeviceCopiesAmongNets
blob_remap=blob_remap,
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 2311, in InjectCrossDeviceCopies
"input {} should be defined in the net.".format(input)
AssertionError: input rpn_rois should be defined in the net

Hi @sidnav, have you solved this issue? Could you post the answer here? Thank you.

@gadcam
Contributor Author

gadcam commented Sep 22, 2019

It looks like there is not enough interest in this work to eventually merge this PR, or part of it.

@gadcam gadcam closed this Sep 22, 2019
@chenhangcal

Very helpful work!

8 participants