From 025f4c178f14a969bf9f43c75a6929db58dc9253 Mon Sep 17 00:00:00 2001 From: Julien Rebetez Date: Tue, 15 May 2018 11:24:57 +0200 Subject: [PATCH] Fix problem with argmax on (0,0) arrays. Fix matterport/Mask_RCNN#170 --- mrcnn/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mrcnn/model.py b/mrcnn/model.py index a248de8762..76d329e1ca 100644 --- a/mrcnn/model.py +++ b/mrcnn/model.py @@ -560,7 +560,11 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config) # Assign positive ROIs to GT boxes. positive_overlaps = tf.gather(overlaps, positive_indices) - roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1) + roi_gt_box_assignment = tf.cond( + tf.greater(tf.shape(positive_overlaps)[1], 0), + true_fn = lambda: tf.argmax(positive_overlaps, axis=1), + false_fn = lambda: tf.cast(tf.constant([]),tf.int64) + ) roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment) roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)