diff --git a/mrcnn/model.py b/mrcnn/model.py index a248de8762..76d329e1ca 100644 --- a/mrcnn/model.py +++ b/mrcnn/model.py @@ -560,7 +560,11 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config) # Assign positive ROIs to GT boxes. positive_overlaps = tf.gather(overlaps, positive_indices) - roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1) + roi_gt_box_assignment = tf.cond( + tf.greater(tf.shape(positive_overlaps)[1], 0), + true_fn = lambda: tf.argmax(positive_overlaps, axis=1), + false_fn = lambda: tf.cast(tf.constant([]),tf.int64) + ) roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment) roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)