An error raise when positive_overlaps has shape (0,0), how to fix? #170

Open
ypflll opened this issue Jan 5, 2018 · 15 comments · Fixed by #561

Comments

@ypflll

ypflll commented Jan 5, 2018

Sometimes an image yields no positive ROIs, and this raises an error:

W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Reduction axis 1 is empty in shape [0,0]
[[Node: proposal_targets/ArgMax = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](proposal_targets/Gather_5, rpn_class_loss/Equal/y)]]
Traceback (most recent call last):
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
return fn(*args)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
status, run_metadata)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reduction axis 1 is empty in shape [0,0]
[[Node: proposal_targets/ArgMax = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](proposal_targets/Gather_5, rpn_class_loss/Equal/y)]]
[[Node: roi_align_classifier/Cast_2/_7079 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_8551_roi_align_classifier/Cast_2", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "portrait_seg.py", line 206, in <module>
layers='4+')
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 2211, in train
use_multiprocessing=True,
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/engine/training.py", line 2096, in fit_generator
class_weight=class_weight)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/engine/training.py", line 1814, in train_on_batch
outputs = self.train_function(ins)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2352, in __call__
**self.session_kwargs)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 889, in run
run_metadata_ptr)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
options, run_metadata)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reduction axis 1 is empty in shape [0,0]
[[Node: proposal_targets/ArgMax = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](proposal_targets/Gather_5, rpn_class_loss/Equal/y)]]
[[Node: roi_align_classifier/Cast_2/_7079 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_8551_roi_align_classifier/Cast_2", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

Caused by op 'proposal_targets/ArgMax', defined at:
File "portrait_seg.py", line 168, in <module>
model_dir=MODEL_DIR)
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 1744, in __init__
self.keras_model = self.build(mode=mode, config=config)
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 1885, in build
target_rois, input_gt_class_ids, gt_boxes, input_gt_masks])
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/engine/topology.py", line 603, in __call__
output = self.call(inputs, **kwargs)
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 641, in call
self.config.IMAGES_PER_GPU, names=names)
File "/home/xxx/Desktop/keras_Mask_RCNN/utils.py", line 673, in batch_slice
output_slice = graph_fn(*inputs_slice)
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 640, in <lambda>
w, x, y, z, self.config),
File "/home/xxx/Desktop/keras_Mask_RCNN/model.py", line 544, in detection_targets_graph
roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 316, in new_func
return func(*args, **kwargs)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 205, in argmax
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 441, in arg_max
name=name)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/home/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Reduction axis 1 is empty in shape [0,0]
[[Node: proposal_targets/ArgMax = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](proposal_targets/Gather_5, rpn_class_loss/Equal/y)]]
[[Node: roi_align_classifier/Cast_2/_7079 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_8551_roi_align_classifier/Cast_2", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

I tried to fix this by adding a check for whether there are any positive_indices, in model.py, line 528:

# Subsample ROIs. Aim for 33% positive
# Positive ROIs
positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                     config.ROI_POSITIVE_RATIO)
positive_indices = tf.random_shuffle(positive_indices)[:positive_count]
a = positive_indices.get_shape()[0]
if a == 0:
    positive_indices = tf.constant([1])
positive_count = tf.shape(positive_indices)[0]

This avoids some of the failing cases, but the error is still raised in others.
How can I fix this thoroughly?

@horvitzs

horvitzs commented Jan 22, 2018

I had the same problem. The following code worked for me:
# Assign positive ROIs to GT boxes.
positive_overlaps = tf.gather(overlaps, positive_indices)
roi_gt_box_assignment = tf.cond(
    tf.greater(tf.shape(positive_overlaps)[1], 0),
    true_fn=lambda: tf.argmax(positive_overlaps, axis=1),
    false_fn=lambda: tf.cast(tf.constant([]), tf.int64)
)
(https://github.com/tensorflow/models/pull/1986/files)
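
The same guard can be tried outside a TensorFlow graph. The sketch below is a NumPy analogue, not the code from the linked PR: the function name `assign_gt_boxes` and the sample values are made up for illustration. It relies on NumPy's `argmax` also rejecting an empty reduction axis.

```python
import numpy as np


def assign_gt_boxes(overlaps, positive_indices):
    """Pick, for each positive ROI, the index of its best-matching GT box.

    overlaps: [num_rois, num_gt_boxes] IoU matrix.
    positive_indices: 1-D integer array with the rows of the positive ROIs.
    """
    positive_overlaps = overlaps[positive_indices]
    # Only reduce over axis 1 when it is non-empty; otherwise return an
    # empty int64 assignment, mirroring the false_fn branch of the tf.cond.
    if positive_overlaps.shape[1] > 0:
        return positive_overlaps.argmax(axis=1)
    return np.zeros((0,), dtype=np.int64)


# Normal case: two positive ROIs matched against two GT boxes.
overlaps = np.array([[0.1, 0.7], [0.6, 0.2], [0.0, 0.3]])
normal = assign_gt_boxes(overlaps, np.array([0, 1]))

# Degenerate case: no GT boxes and no positive ROIs, so the gathered
# matrix has shape (0, 0).
empty = assign_gt_boxes(np.zeros((3, 0)), np.array([], dtype=np.int64))
```

The check is made on the number of columns (GT boxes), not the number of rows, because the empty reduction axis is what fails.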

@ypflll
Author

ypflll commented Jan 24, 2018

Thanks for sharing.
I tried your code and it worked, but only partially:
Setting the tensor to size 0 causes another problem, like this:
tensorflow/tensorflow#14962

What is your TF version? 1.5.0 has fixed this, but I am using 1.4.1.

@horvitzs

That's strange. My tf version is also 1.4.1

@ypflll
Author

ypflll commented Jan 24, 2018

The error is:
F tensorflow/stream_executor/cuda/cuda_dnn.cc:444] could not convert BatchDescriptor {count: 0 feature_map_count: 1 spatial: 28 28 value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX} to cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM
Aborted (core dumped)

My CUDA version: 8.0.44, cuDNN version: 5.1.10.

@ppwwyyxx Is your PR tensorflow/tensorflow#14657 included in TF version 1.4.1?
I also found that after the error occurred, my GPU (GeForce GTX Titan X, 12 GB) memory stays occupied even though no process is listed. Maybe your PR did not solve the problem thoroughly. I'm not sure.

@ppwwyyxx

No. The fix will probably be in 1.6. You can use tf.cond to work around the bug like this: https://github.com/ppwwyyxx/tensorpack/blob/6bdd046057e507087f6da3af909d4bcf1726cff2/examples/FasterRCNN/train.py#L122-L133

@ypflll
Author

ypflll commented Jan 29, 2018

@ppwwyyxx Thanks for your code. I thought it would be easy to follow, but I'm still stuck.

The original code is:
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
                                  config.IMAGE_SHAPE,
                                  config.MASK_POOL_SIZE,
                                  config.NUM_CLASSES)

Like your code, I changed it to:
def ff_true():
    mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
                                      config.IMAGE_SHAPE,
                                      config.MASK_POOL_SIZE,
                                      config.NUM_CLASSES)
    return mrcnn_mask

def ff_false():
    return target_mask

mrcnn_mask = tf.cond(tf.equal(tf.reduce_mean(rois), 0), ff_true, ff_true)

This raises an error:

ValueError: Initializer for variable cond/mrcnn_class_conv1/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

Searching for it suggests that it's a matter of datatype, and that TF also has a bug in its error reporting, like this:
tensorflow/tensorflow#14729

So, I tried this:
a = build_fpn_mask_graph(rois, mrcnn_feature_maps,
                         config.IMAGE_SHAPE,
                         config.MASK_POOL_SIZE,
                         config.NUM_CLASSES)

def ff_true():
    return a

def ff_false():
    return target_mask

mrcnn_mask = tf.cond(tf.equal(tf.reduce_mean(rois), 0), ff_true, ff_true)

This gives me another error:

Traceback (most recent call last):
File "coco.py", line 453, in <module>
model_dir=args.logs)
File "/xxx/model.py", line 1775, in __init__
self.keras_model = self.build(mode=mode, config=config)
File "/xxx/model.py", line 2022, in build
model = KM.Model(inputs, outputs, name='mask_rcnn')
File "/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/xxx/anaconda2/envs/py36/lib/python3.6/site-packages/keras/engine/topology.py", line 1579, in __init__
'Keras tensors. Found: ' + str(x))
TypeError: Output tensors to a Model must be Keras tensors. Found: Tensor("cond/Merge:0", shape=(?, 200, 28, 28, 81), dtype=float32)

I'm a little clueless about this.
Any clue is welcome!

@Prausome It seems strange to me that you didn't hit this on TF version 1.4.1. Did you run coco.py with no modification?
If you raise the IoU threshold (to 2, for example), maybe you can reproduce this:

positive_roi_bool = (roi_iou_max >= 0.5)
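
Raising the threshold above 1 forces the failure because IoU can never exceed 1, so no ROI passes the test. A plain-Python sketch of that selection step follows. The helper names `iou` and `positive_roi_indices`, the sample boxes, and the (y1, x1, y2, x2) box layout are assumptions for illustration, not code from model.py.

```python
def iou(box_a, box_b):
    """Intersection over union of two (y1, x1, y2, x2) boxes."""
    y1 = max(box_a[0], box_b[0])
    x1 = max(box_a[1], box_b[1])
    y2 = min(box_a[2], box_b[2])
    x2 = min(box_a[3], box_b[3])
    intersection = max(0.0, y2 - y1) * max(0.0, x2 - x1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


def positive_roi_indices(rois, gt_boxes, threshold):
    """Indices of ROIs whose best IoU with any GT box reaches the threshold."""
    positives = []
    for index, roi in enumerate(rois):
        # default=0.0 covers the case where there are no GT boxes at all.
        roi_iou_max = max((iou(roi, gt) for gt in gt_boxes), default=0.0)
        if roi_iou_max >= threshold:
            positives.append(index)
    return positives


rois = [(0, 0, 10, 10), (0, 0, 6, 10), (20, 20, 30, 30)]
gt_boxes = [(0, 0, 10, 10)]

# With the usual threshold the first two ROIs are positive.
with_default_threshold = positive_roi_indices(rois, gt_boxes, 0.5)
# IoU is at most 1, so a threshold of 2 leaves no positive ROIs.
with_impossible_threshold = positive_roi_indices(rois, gt_boxes, 2.0)
```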

@ppwwyyxx

TypeError: Output tensors to a Model must be Keras tensors. Found: Tensor("cond/Merge:0", shape=(?, 200, 28, 28, 81), dtype=float32)

I think the error is saying that Keras has restrictions on the type of models you can use. But I don't know enough about Keras to say more.

This issue would only happen very occasionally in my experience. So I'm not surprised if someone doesn't see the same error.

@ypflll
Author

ypflll commented Jan 31, 2018

@ppwwyyxx Many thanks for your useful advice. I figured this out. It is indeed a difference between TF and Keras tensors.

Keras tensors are theano/tf tensors with additional information included. You get keras tensors from keras.layers.Input or any time you pass an Input to a keras.layers.Layer.
If you're just using the tensor in a loss calculation or something else, you don't have to wrap it in Lambdas. Refer to: keras-team/keras#6263
This code works:
mrcnn_mask = KL.Lambda(lambda x: tf.cond(tf.equal(tf.reduce_mean(x), 0), ff_true, ff_true))(rois)

However, after fixing this, the earlier error still remains:

ValueError: Initializer for variable cond/mrcnn_class_conv1/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

I've asked about this on Stack Overflow:
https://stackoverflow.com/questions/48515034/keras-tensorflow-initializer-for-variable-is-from-inside-a-control-flow-con
Maybe I can finally get a clue on how to fix this bug.

@ppwwyyxx

This is just another problem with Keras. Keras uses tensors to initialize variables, which is not legal inside a conditional. tf.keras does not have this issue, btw.

@waleedka
Collaborator

@ypflll Is the issue solved for you?

I tried to reproduce the error you got, but I can't reproduce it at the moment. I forced positive_overlaps to be [], and the training is working as usual for me without any change to the code. I'm testing on TF 1.5 at the moment. Probably the latest version of TF solved it?

@ypflll
Author

ypflll commented Feb 24, 2018

Sorry for the late reply; I was on holiday.

I hit the problem 'Reduction axis 1 is empty in shape [0,0]' on the Adobe portrait dataset, not on COCO: http://xiaoyongshen.me/webpage_portrait/index.html
On COCO, I tried setting positive_overlaps to [] as you did, and no problem was raised.

Actually, I followed Prausome's code and solved this problem:
# Assign positive ROIs to GT boxes.
positive_overlaps = tf.gather(overlaps, positive_indices)
roi_gt_box_assignment = tf.cond(
    tf.greater(tf.shape(positive_overlaps)[1], 0),
    true_fn=lambda: tf.argmax(positive_overlaps, axis=1),
    false_fn=lambda: tf.cast(tf.constant([]), tf.int64)
)

Code for portrait segmentation is here:
https://github.com/ypflll/portrait_seg_maskrcnn/blob/master/portrait.py
What is still unclear to me is what causes this problem and how to reproduce it.

The second problem I found, as mentioned above, is:
F tensorflow/stream_executor/cuda/cuda_dnn.cc:444] could not convert BatchDescriptor {count: 0 feature_map_count: 1 spatial: 28 28 value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX} to cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM
Aborted (core dumped)

I thought it was caused by TF's Conv2D backward pass not supporting a zero batch size, as in:
tensorflow/tensorflow#14657.
After debugging by adding my code on top of your original code, I found that this occurs when I add a new loss; the code is here:
https://github.com/ypflll/portrait_seg_maskrcnn/blob/master/model.py
When the first problem occurs, some tensors are empty, and this causes the second problem when running the code I added.
This may be beyond the scope of this issue. I will explore it later when I have time.

@kstseng

kstseng commented May 1, 2018

I followed the answer from @horvitzs, and it works!
Thanks!

@ziyigogogo
Contributor

ziyigogogo commented May 14, 2018

@waleedka I am using TensorFlow 1.7.0 and hit the same problem, so I don't think it's related to the TensorFlow version.
I am now trying the code from @horvitzs. I'm not sure if it works, but I will post an update.

By the way, I hit this problem while trying to run this project: https://github.com/crowdAI/crowdai-mapping-challenge-mask-rcnn.
The full dataset has this problem while the small subset does not.
I am trying to locate the particular image. If I succeed, I will post that image here to see if you can reproduce the same problem.

julienr pushed a commit to Picterra/Mask_RCNN that referenced this issue May 15, 2018
waleedka pushed a commit that referenced this issue Jun 5, 2018
LexLuc pushed a commit to LexLuc/Mask_RCNN that referenced this issue Jun 5, 2018
@waleedka
Collaborator

waleedka commented Jun 5, 2018

I merged a PR by @julienr that might help with this issue. It's similar to the fix suggested above by @horvitzs.

I did more testing on this case focusing on detection_targets_graph(). This function tries to match ROIs with ground truth boxes. I focused on these two edge cases:

  1. No positive ROIs
  2. No ground truth boxes

1. No positive ROIs

If I update the code to simulate not finding any positive ROIs, I get a crash and core dump on TF 1.8.

./tensorflow/core/util/cuda_launch_config.h:127] Check failed: work_element_count > 0 (0 vs. 0)
Aborted (core dumped)

I verified that detection_targets_graph() is running all the way through without problems. So the error is happening at some point after that.

2. No ground truth boxes

When I update the code to simulate not having any ground truth boxes, I get the same error reported by @ypflll above. The recently merged PR fixes this issue. So now detection_targets_graph() runs through correctly, and then I get a crash and core dump after that, just like the above.

So, the good news: we fixed the error reported above. The bad news: one of the TF operations is crashing when it receives a tensor with one of its dimensions equal to 0.
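
One way to see why only the second edge case hit the argmax error is to look at the shape of the gathered overlaps matrix in each case. The sketch below is a NumPy analogue under stated assumptions: `raw_assignment` is a made-up name for the unguarded gather-then-argmax step, the array sizes are arbitrary, and NumPy's `argmax` is used in place of `tf.argmax` because it also refuses an empty reduction axis.

```python
import numpy as np


def raw_assignment(overlaps, positive_indices):
    """Unguarded version: gather the positive rows, then argmax over GT boxes."""
    positive_overlaps = overlaps[positive_indices]
    return positive_overlaps.argmax(axis=1)


no_positives = np.zeros((0,), dtype=np.int64)

# Case 1: GT boxes exist but no ROI is positive.
# The gathered matrix has shape (0, 3); axis 1 is non-empty, so argmax
# simply returns an empty result.
overlaps_case1 = np.zeros((5, 3), dtype=np.float32)
case1 = raw_assignment(overlaps_case1, no_positives)

# Case 2: no GT boxes at all.
# overlaps is (5, 0), the gathered matrix is (0, 0), and the reduction
# axis is empty, so argmax raises.
overlaps_case2 = np.zeros((5, 0), dtype=np.float32)
try:
    raw_assignment(overlaps_case2, no_positives)
    case2_error = None
except ValueError as exc:
    case2_error = str(exc)
```

In this analogue, case 1 produces an empty assignment without complaint, while case 2 fails at the argmax, which matches the two behaviours described above.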

@waleedka waleedka reopened this Jun 5, 2018
LexLuc added a commit to LexLuc/Mask_RCNN that referenced this issue Jun 6, 2018
* Small typo fix

* loss weights

* Fix multi-GPU training.

A previous fix to let validation run across more
than one batch caused an issue with multi-GPU
training. The issue seems to be in how Keras
averages loss and metric values, where it expects
them to be scalars rather than arrays. This fix
causes scalar outputs from a model to remain
scalar in multi-GPU training.

* Replace keep_dims with keepdims in TF calls.

TF replaced keep_dims with keepdims a while ago
and now shows a warning when using the old name.

* Headline typo fix in README.md

Fixed the typo in the headline of the README.md file. "Spash" should be "Splash"

* Splash sample: fix filename and link to blog post

* Update utils.py

* Minor cleanup in compute_overlaps_masks()

* Fix: color_splash when no masks are detected

Reported here: matterport#500

* fix typo

fix typo

* fix "No such file or directory" if not use: "keras.callbacks.TensorBoard"

* Allow dashes in model name.
Print a message when re-starting from saved epoch

* Fix problem with argmax on (0,0) arrays.

Fix matterport#170

* Allow configuration of FPN layers size and top-down pyramid size

* Allow custom backbone implementation through Config.BACKBONE

This allows one to set a callable in Config.BACKBONE to use a custom
backbone model.

* modified comment for image augmentation line import to include correct 'pip3 install imgaug' instructions

* Raise clear error if last training weights are not found

If using the --weights=last (or --model=last) to resume training but the weights are not found now it raises a clear error message.

@fedebayle

By the way, I hit this problem while trying to run this project: https://github.com/crowdAI/crowdai-mapping-challenge-mask-rcnn.
The full dataset has this problem while the small subset does not.
I am trying to locate the particular image. If I succeed, I will post that image here to see if you can reproduce the same problem.

@ziyigogogo Did you locate the particular image? I'm experiencing the same problem with the full dataset.

LackesLab pushed a commit to LackesLab/Mask_RCNN that referenced this issue Aug 24, 2018
LexLuc added a commit to LexLuc/Mask_RCNN that referenced this issue Sep 28, 2018

* Fix Keras engine topology to saving

* Fix load_weights() for Keras versions before 2.2

Improve previous commit to not break on older versions of Keras.

* Update README.md

* Add custom callbacks to model training

Add an optional parameter for calling a list of keras.callbacks to be add to the original list.

* Add no augmentation sources

Add the possibility to exclude some sources from augmentation by passing a list of sources. This is useful when you want to retrain a model having few images.

* Improve previous commit to avoid mutable default arguments

* Updated Coco Example

* edit loss desc

* spellcheck config.py

* doublecheck on config.py

* spellcheck utils.py

* spellcheck visualize.py

* Links to two more projects in README

* Add Bibtex to README

* make pre_nms_limit configurable

* Make pre_nms_limit configurable

* Made compatible to new version of VIA JSON format

VIA has changed JSON formatting in later versions. Now instead of a dictionary, "regions" has a list, see the issue matterport#928

* Comments to explain VIA 2.0 JSON change

* Fix the comment on output shape in RPN

* Bugfix for MaskRCNN creating empty log_dir that breaks find_last()
- Current implementation creates self.log_dir in set_log_dir() function,
  which creates an empty log directory if none exists. This causes
  find_last() to fail after creating a model because it finds this new
  empty directory instead of the previous training directory.
- New implementation moves log_dir creation to the train() function to
  ensure it is only created when it will be used.

* Added automated epoch recognition for Windows. (matterport#798)

Unified regex expression for both, Linux and Windows.

* Fixed tabbing issue in previous commit

* bug fix: the output_shape of roi_gt_class_ids is incorrect

* Bug fix: inspect_balloon_model.ipynb

Fix bugs of not showing boxes in 1.b RPN Predictions.
TF 1.9 introduces "ROI/rpn_non_max_suppression/NonMaxSuppressionV3:0", so the original code can't work.

* Apply previous commit to the other notebooks

* Fixed comment on GPU_COUNT (matterport#878)

Fixed comment on GPU_COUNT

* add IMAGE_CHANNEL_COUNT class variable to config to make it easier to use Mask_RCNN for non 3-channel images

* Additional comments for the previous commit

* Link to new projects in README

* Tiny correction in README.

* Adjust PyramidROIAlign layer shape comment

For PyramidROIAlign's output shape, use pool_height and pool_width instead of height and width to avoid confusion with those of feature_maps.

* fix output shape of fpn_classifier_graph

1. fix the comment on output shape in fpn_classifier_graph
2. unify NUM_CLASSES and num_classes to NUM_CLASSES
3. unify boxes, num_boxes, num_rois, roi_count to num_rois
4. use more specific POOL_SIZE and MASK_ POOL_SIZE to replace pool_height and pool_width

* Fix PyramidROIAlign output shape

As discussed in: matterport#919

* Fix comments in Detection Layer

1. fix description on window
2. fix output shape of detection layer

* use smooth_l1_loss() to reduce code duplication

* A wrapper for skimage resize() to avoid warnings

skimage generates different warnings depending on the version. This wrapper function calls skimage.tranform.resize() with the right parameter for each version.

* Remove unused method: append_data()
Cpruce pushed a commit to Cpruce/Mask_RCNN that referenced this issue Jan 17, 2019
withyou53 pushed a commit to withyou53/mask_r-cnn_for_object_detection that referenced this issue Sep 27, 2023
7 participants