Skip to content

Commit

Permalink
changed according to
Browse files Browse the repository at this point in the history
  • Loading branch information
crimefightingllama committed Apr 2, 2023
1 parent 25f53a7 commit 5e1d265
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mrcnn/model.py
Expand Up @@ -699,7 +699,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# Class IDs per ROI
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
# Class probability of the top class of each ROI
indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
indices = tf.stack([tf.range(tf.shape(probs)[0]), class_ids], axis = 1)
class_scores = tf.gather_nd(probs, indices)
# Class-specific bounding box deltas
deltas_specific = tf.gather_nd(deltas, indices)
Expand Down Expand Up @@ -948,7 +948,10 @@ def fpn_classifier_graph(rois, feature_maps, image_meta,
name='mrcnn_bbox_fc')(shared)
# Reshape to [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))]
s = K.int_shape(x)
mrcnn_bbox = KL.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)
if s[1]==None:
mrcnn_bbox = layers.Reshape((-1, num_classes, 4), name="mrcnn_bbox")(x)
else:
mrcnn_bbox = layers.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)

return mrcnn_class_logits, mrcnn_probs, mrcnn_bbox

Expand Down

0 comments on commit 5e1d265

Please sign in to comment.