Skip to content
This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Support exporting for CPU Mask & Keypoint nets #449

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions detectron/utils/model_convert_utils.py
Expand Up @@ -315,11 +315,11 @@ def gen_init_net_from_blobs(blobs, blobs_to_use=None, excluded_blobs=None):
blobs_to_use = [x for x in blobs_to_use if x not in excluded_blobs]
for name in blobs_to_use:
blob = blobs[name]
if isinstance(blob, str):
if isinstance(blob, np.ndarray):
add_tensor(ret, name, blob)
else:
print('Blob {} with type {} is not supported in generating init net,'
' skipped.'.format(name, type(blob)))
continue
add_tensor(ret, name, blob)

return ret

Expand Down
206 changes: 145 additions & 61 deletions tools/convert_pkl_to_pb.py
Expand Up @@ -52,9 +52,13 @@
from detectron.utils.model_convert_utils import op_filter
import detectron.utils.blob as blob_utils
import detectron.core.test_engine as test_engine
import detectron.core.test as test
import detectron.utils.c2 as c2_utils
import detectron.utils.model_convert_utils as mutils
import detectron.utils.vis as vis_utils
import detectron.utils.blob as blob_utils
import detectron.utils.keypoints as keypoint_utils
import pycocotools.mask as mask_utils

c2_utils.import_contrib_ops()
c2_utils.import_detectron_ops()
Expand Down Expand Up @@ -334,8 +338,8 @@ def add_bbox_ops(args, net, blobs):
new_ops.extend([op_nms])
new_external_outputs.extend(['score_nms', 'bbox_nms', 'class_nms'])

net.Proto().op.extend(new_ops)
net.Proto().external_output.extend(new_external_outputs)
net.op.extend(new_ops)
net.external_output.extend(new_external_outputs)


def convert_model_gpu(args, net, init_net):
Expand Down Expand Up @@ -392,23 +396,23 @@ def gen_init_net(net, blobs, empty_blobs):
def _save_image_graphs(args, all_net, all_init_net):
print('Saving model graph...')
mutils.save_graph(
all_net.Proto(), os.path.join(args.out_dir, "model_def.png"),
all_net.Proto(), os.path.join(args.out_dir, all_net.Proto().name + '.png'),
op_only=False)
print('Model def image saved to {}.'.format(args.out_dir))


def _save_models(all_net, all_init_net, args):
print('Writing converted model to {}...'.format(args.out_dir))
fname = "model"
fname = all_net.Proto().name

if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)

with open(os.path.join(args.out_dir, fname + '.pb'), 'w') as f:
with open(os.path.join(args.out_dir, fname + '.pb'), 'wb') as f:
f.write(all_net.Proto().SerializeToString())
with open(os.path.join(args.out_dir, fname + '.pbtxt'), 'w') as f:
f.write(str(all_net.Proto()))
with open(os.path.join(args.out_dir, fname + '_init.pb'), 'w') as f:
with open(os.path.join(args.out_dir, fname + '_init.pb'), 'wb') as f:
f.write(all_init_net.Proto().SerializeToString())

_save_image_graphs(args, all_net, all_init_net)
Expand Down Expand Up @@ -457,13 +461,14 @@ def run_model_cfg(args, im, check_blobs):
cls_boxes, cls_segms, cls_keyps = test_engine.im_detect_all(
model, im, None, None,
)

boxes, segms, keypoints, classes = vis_utils.convert_from_cls_format(
boxes, segms, keypoints, classids = vis_utils.convert_from_cls_format(
cls_boxes, cls_segms, cls_keyps)

segms = mask_utils.decode(segms) if segms else None

# sort the results based on score for comparision
boxes, segms, keypoints, classes = _sort_results(
boxes, segms, keypoints, classes)
boxes, segms, keypoints, classids = _sort_results(
boxes, segms, keypoints, classids)

# write final results back to workspace
def _ornone(res):
Expand All @@ -472,12 +477,16 @@ def _ornone(res):
workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes))
workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms))
workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints))
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classes))
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids))

# get result blobs
with c2_utils.NamedCudaScope(0):
ret = _get_result_blobs(check_blobs)

print('result_boxes', _ornone(boxes))
print('result_segms', _ornone(segms))
print('result_keypoints', _ornone(keypoints))
print('result_classids', _ornone(classids))
return ret


Expand Down Expand Up @@ -519,7 +528,6 @@ def run_model_pb(args, net, init_net, im, check_blobs):
mutils.create_input_blobs_for_net(net.Proto())
workspace.CreateNet(net)

# input_blobs, _ = core_test._get_blobs(im, None)
input_blobs = _prepare_blobs(
im,
cfg.PIXEL_MEANS,
Expand All @@ -538,36 +546,76 @@ def run_model_pb(args, net, init_net, im, check_blobs):

try:
workspace.RunNet(net)
scores = workspace.FetchBlob('score_nms')
classids = workspace.FetchBlob('class_nms')
boxes = workspace.FetchBlob('bbox_nms')
scores = workspace.FetchBlob(core.ScopedName('score_nms'))
classids = workspace.FetchBlob(core.ScopedName('class_nms'))
boxes = workspace.FetchBlob(core.ScopedName('bbox_nms'))
except Exception as e:
print('Running pb model failed.\n{}'.format(e))
# may not detect anything at all
logger.warn('Running pb model failed.\n{}'.format(e))
R = 0
scores = np.zeros((R,), dtype=np.float32)
boxes = np.zeros((R, 4), dtype=np.float32)
classids = np.zeros((R,), dtype=np.float32)

cls_segms, cls_keyps = None, None

if net.BlobIsDefined(core.ScopedName('kps_score')):
pred_heatmaps = workspace.FetchBlob(core.ScopedName('kps_score')).squeeze()
# In case of 1
if pred_heatmaps.ndim == 3:
pred_heatmaps = np.expand_dims(pred_heatmaps, axis=0)
xy_preds = keypoint_utils.heatmaps_to_keypoints(pred_heatmaps, boxes)
cls_keyps = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
cls_keyps[1] = [xy_preds[i] for i in range(xy_preds.shape[0])]
else:
logger.info('Keypoint blob is not defined')

if net.BlobIsDefined(core.ScopedName('mask_fcn_probs')):
# Fetch masks
pred_masks = workspace.FetchBlob(core.ScopedName('mask_fcn_probs')).squeeze()
M = cfg.MRCNN.RESOLUTION
if cfg.MRCNN.CLS_SPECIFIC_MASK:
pred_masks = pred_masks.reshape([-1, cfg.MODEL.NUM_CLASSES, M, M])
else:
pred_masks = pred_masks.reshape([-1, 1, M, M])
cls_boxes = [np.empty(list(classids).count(i)) for i in range(cfg.MODEL.NUM_CLASSES)]
cls_segms = test.segm_results(cls_boxes, pred_masks, boxes, im.shape[0], im.shape[1])
else:
logger.info('Mask blob is not defined')

boxes = np.column_stack((boxes, scores))

_, segms, keypoints, _ = vis_utils.convert_from_cls_format([], cls_segms, cls_keyps)
segms = mask_utils.decode(segms) if segms else None

# sort the results based on score for comparision
boxes, _, _, classids = _sort_results(
boxes, None, None, classids)
boxes, segms, keypoints, classids = _sort_results(
boxes, segms, keypoints, classids)

# write final result back to workspace
workspace.FeedBlob('result_boxes', boxes)
workspace.FeedBlob('result_classids', classids)
def _ornone(res):
return np.array(res) if res is not None else np.array([], dtype=np.float32)
workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes))
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids))
workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms))
workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints))

ret = _get_result_blobs(check_blobs)

print('result_boxes', _ornone(boxes))
print('result_segms', _ornone(segms))
print('result_keypoints', _ornone(keypoints))
print('result_classids', _ornone(classids))
return ret


def verify_model(args, model_pb, test_img_file):
check_blobs = [
"result_boxes", "result_classids", # result
]
def verify_model(args, net, init_net, test_img_file):
check_blobs = ['result_boxes', 'result_classids']

if cfg.MODEL.MASK_ON:
check_blobs.append('result_segms')

if cfg.MODEL.KEYPOINTS_ON:
check_blobs.append('result_keypoints')

print('Loading test file {}...'.format(test_img_file))
test_img = cv2.imread(test_img_file)
Expand All @@ -577,13 +625,49 @@ def _run_cfg_func(im, blobs):
return run_model_cfg(args, im, check_blobs)

def _run_pb_func(im, blobs):
return run_model_pb(args, model_pb[0], model_pb[1], im, check_blobs)
return run_model_pb(args, net, init_net, im, check_blobs)

print('Checking models...')
assert mutils.compare_model(
_run_cfg_func, _run_pb_func, test_img, check_blobs)


def convert_to_pb(args, net, blobs, input_blobs):
pb_net = core.Net('')
pb_net.Proto().op.extend(copy.deepcopy(net.op))

pb_net.Proto().external_input.extend(
copy.deepcopy(net.external_input))
pb_net.Proto().external_output.extend(
copy.deepcopy(net.external_output))
pb_net.Proto().type = args.net_execution_type
pb_net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4

# Reset the device_option, change to unscope name and replace python operators
convert_net(args, pb_net.Proto(), blobs)

# add operators for bbox
add_bbox_ops(args, pb_net.Proto(), blobs)

if args.fuse_af:
print('Fusing affine channel...')
pb_net, blobs = mutils.fuse_net_affine(pb_net, blobs)

if args.use_nnpack:
mutils.update_mobile_engines(pb_net.Proto())

# generate init net
pb_init_net = gen_init_net(pb_net, blobs, input_blobs)

if args.device == 'gpu':
[pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net)

pb_net.Proto().name = args.net_name + '_net'
pb_init_net.Proto().name = args.net_name + '_net_init'

return pb_net, pb_init_net


def main():
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
args = parse_args()
Expand All @@ -598,55 +682,55 @@ def main():
logger.info('Converting model with config:')
logger.info(pprint.pformat(cfg))

# script will stop when it can't find an operator rather
# than stopping based on these flags
#
# assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
# assert not cfg.MODEL.MASK_ON, "Mask model not supported."
# assert not cfg.FPN.FPN_ON, "FPN not supported."
# assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."

# load model from cfg
model, blobs = load_model(args)

net = core.Net('')
net.Proto().op.extend(copy.deepcopy(model.net.Proto().op))
net.Proto().external_input.extend(
copy.deepcopy(model.net.Proto().external_input))
net.Proto().external_output.extend(
copy.deepcopy(model.net.Proto().external_output))
net.Proto().type = args.net_execution_type
net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4
input_net = ['data', 'im_info']

# Reset the device_option, change to unscope name and replace python operators
convert_net(args, net.Proto(), blobs)
if cfg.MODEL.KEYPOINTS_ON:
model_kps = model.keypoint_net.Proto()

# add operators for bbox
add_bbox_ops(args, net, blobs)
# Connect rois blobs
for op in model_kps.op:
for i, input_name in enumerate(op.input):
op.input[i] = input_name.replace("keypoint_rois", "rois")

if args.fuse_af:
print('Fusing affine channel...')
net, blobs = mutils.fuse_net_affine(
net, blobs)
# Remove external input defined in main net
kps_external_input = []
for i in model_kps.external_input:
if not model.net.BlobIsDefined(i) and \
not "keypoint_rois" in i:
kps_external_input.append(i)

if args.use_nnpack:
mutils.update_mobile_engines(net.Proto())
model.net.Proto().op.extend(model_kps.op)
model.net.Proto().external_output.extend(model_kps.external_output)
model.net.Proto().external_input.extend(kps_external_input)

# generate init net
empty_blobs = ['data', 'im_info']
init_net = gen_init_net(net, blobs, empty_blobs)
if cfg.MODEL.MASK_ON:
model_mask = model.mask_net.Proto()

if args.device == 'gpu':
[net, init_net] = convert_model_gpu(args, net, init_net)
# Connect rois blobs
for op in model_mask.op:
for i, input_name in enumerate(op.input):
op.input[i] = input_name.replace("mask_rois", "rois")

net.Proto().name = args.net_name
init_net.Proto().name = args.net_name + "_init"
# Remove external input defined in main net
mask_external_input = []
for i in model_mask.external_input:
if not model.net.BlobIsDefined(i) and \
not "mask_rois" in i:
mask_external_input.append(i)

if args.test_img is not None:
verify_model(args, [net, init_net], args.test_img)
model.net.Proto().op.extend(model_mask.op)
model.net.Proto().external_output.extend(model_mask.external_output)
model.net.Proto().external_input.extend(mask_external_input)

net, init_net = convert_to_pb(args, model.net.Proto(), blobs, input_net)

_save_models(net, init_net, args)

if args.test_img is not None:
verify_model(args, net, init_net, args.test_img)

if __name__ == '__main__':
main()