
Multilabel training problem #29

Closed
vndee opened this issue Feb 27, 2019 · 2 comments

@vndee

vndee commented Feb 27, 2019

Hi,
I have used this code to train a Document Layout Analysis model. I set:
prediction_type = utils.PredictionType.MULTILABEL
My classes.txt file (9 classes) is:

0 0 0 1 0 0 0 0 0 0 0 0
25 255 255 0 1 0 0 0 0 0 0 0
142 130 255 0 0 1 0 0 0 0 0 0
191 130 74 0 0 0 1 0 0 0 0 0
191 14 74 0 0 0 0 1 0 0 0 0
191 181 74 0 0 0 0 0 1 0 0 0
36 13 249 0 0 0 0 0 0 1 0 0
110 49 7 0 0 0 0 0 0 0 1 0
250 246 7 0 0 0 0 0 0 0 0 1
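For reference, each row above appears to hold an RGB color followed by one 0/1 flag per class (12 values: 3 for RGB, 9 flags). A minimal parsing sketch under that assumption, inferred from the file rather than taken from dhSegment (`parse_multilabel_classes` is a hypothetical helper):

```python
# Assumed format (inferred from the file above): "R G B f1 ... fn" per line,
# where f1..fn are 0/1 flags, one per class.
CLASSES_TXT = """\
0 0 0 1 0 0 0 0 0 0 0 0
25 255 255 0 1 0 0 0 0 0 0 0
142 130 255 0 0 1 0 0 0 0 0 0
191 130 74 0 0 0 1 0 0 0 0 0
191 14 74 0 0 0 0 1 0 0 0 0
191 181 74 0 0 0 0 0 1 0 0 0
36 13 249 0 0 0 0 0 0 1 0 0
110 49 7 0 0 0 0 0 0 0 1 0
250 246 7 0 0 0 0 0 0 0 0 1
"""


def parse_multilabel_classes(text):
    """Split each row into an RGB tuple and a tuple of label flags."""
    colors, labels = [], []
    for line in text.strip().splitlines():
        values = [int(v) for v in line.split()]
        if len(values) < 4:
            raise ValueError(f"expected RGB plus at least one flag: {line!r}")
        colors.append(tuple(values[:3]))
        labels.append(tuple(values[3:]))
    if len({len(vec) for vec in labels}) != 1:
        raise ValueError("every row must have the same number of flags")
    return colors, labels


colors, labels = parse_multilabel_classes(CLASSES_TXT)
print(len(colors), "colors,", len(labels[0]), "label flags per color")
# 9 colors, 9 label flags per color
```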

But I got this error:

Caused by op 'Label2Img/GatherNd', defined at:
  File "train.py", line 47, in <module>
    @ex.automain
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 137, in automain
    self.run_commandline()
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 260, in run_commandline
    return self.run(cmd_name, config_updates, named_configs, {}, args)
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 209, in run
    run()
  File "/home/it/.local/lib/python3.6/site-packages/sacred/run.py", line 221, in __call__
    self.result = self.main_function(*args)
  File "/home/it/.local/lib/python3.6/site-packages/sacred/config/captured_function.py", line 46, in captured_function
    result = wrapped(*args, **kwargs)
  File "train.py", line 111, in run
    num_threads=32))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 356, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1181, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1208, in _train_model_default
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1049, in _get_features_and_labels_from_input_fn
    self._call_input_fn(input_fn, mode))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1136, in _call_input_fn
    return input_fn(**kwargs)
  File "/home/it/Projects/DLA/dhSegment/dh_segment/io/input.py", line 224, in fn
    label_export = utils.multiclass_to_label_image(label_export, classes_file)
  File "/home/it/Projects/DLA/dhSegment/dh_segment/utils/labels.py", line 67, in multiclass_to_label_image
    return tf.gather_nd(c, tf.cast(class_label_tensor, tf.int32))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3140, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
    op_def=op_def)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Only indices.shape[-1] values between 1 and 7 are currently supported. Requested rank: 9
    [[{{node Label2Img/GatherNd}} = GatherNd[Tindices=DT_INT32, Tparams=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Label2Img/GatherNd/params, Cast_1)]]

I guess we cannot train a multilabel classification model with more than 7 classes. Can anyone help me fix this problem?
Thanks.
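The error ties the limit to `indices.shape[-1]`, and in the traceback the indices are `class_label_tensor`, whose last dimension presumably equals the number of label flags per row in classes.txt (9 here). A minimal pre-flight check, assuming that relation and the limit of 7 quoted in the message (which may differ across TensorFlow versions); `label_flag_count`, `can_build_label_summary` and `one_hot_classes` are hypothetical helpers, not part of dhSegment:

```python
# Upper bound quoted from the InvalidArgumentError above; it may differ between
# TensorFlow versions and devices.
GATHER_ND_MAX_INDEX_DEPTH = 7


def label_flag_count(classes_text):
    """Number of 0/1 flags per row (everything after the 3 RGB values)."""
    counts = {len(line.split()) - 3 for line in classes_text.strip().splitlines()}
    if len(counts) != 1:
        raise ValueError("rows have different numbers of label flags")
    return counts.pop()


def can_build_label_summary(classes_text, max_depth=GATHER_ND_MAX_INDEX_DEPTH):
    """Return (ok, depth): is a gather_nd over `depth` index dims within the limit?"""
    depth = label_flag_count(classes_text)
    return depth <= max_depth, depth


def one_hot_classes(n):
    """Build a synthetic classes.txt with n one-hot rows (grey colors, illustration only)."""
    rows = []
    for i in range(n):
        flags = ["1" if j == i else "0" for j in range(n)]
        rows.append(" ".join([str(i)] * 3 + flags))
    return "\n".join(rows)


print(can_build_label_summary(one_hot_classes(9)))  # (False, 9)
print(can_build_label_summary(one_hot_classes(7)))  # (True, 7)
```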

@solivr
Member

solivr commented Feb 28, 2019

Hi @vndee,

It seems that this error is caused by multiclass_to_label_image, which is only used for visualization purposes (tf.summary) and not during training. I'll need to take a closer look to see if there is another/better way to visualize many labels, but for now you can comment out the lines related to image summaries (the ones using multiclass_to_label_image).
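One possible alternative for visualizing many labels (a sketch, not dhSegment's code): pack each binary label vector into a single integer, with flag i contributing 2**i, and look the color up in a flat table, so the lookup no longer needs one index dimension per class. In TensorFlow the same idea could presumably be expressed with `tf.gather` on a table of shape `(2**n, 3)`; the pure-Python version below only illustrates the encoding, and all function names are hypothetical.

```python
def pack_flags(flags):
    """Encode a 0/1 label vector as an integer; flag i contributes 2**i."""
    code = 0
    for i, flag in enumerate(flags):
        if flag not in (0, 1):
            raise ValueError(f"label flags must be 0 or 1, got {flag!r}")
        code |= flag << i
    return code


def build_color_table(colors, labels):
    """Map each packed label vector from classes.txt to its RGB color."""
    return {pack_flags(vec): rgb for rgb, vec in zip(colors, labels)}


def label_image_to_rgb(label_image, table, default=(0, 0, 0)):
    """label_image: rows of pixels, each pixel a 0/1 label vector.

    Combinations that do not appear in classes.txt are drawn with `default`.
    """
    return [[table.get(pack_flags(px), default) for px in row] for row in label_image]


def one_hot(i, n=9):
    """One-hot label vector of length n with flag i set."""
    return tuple(1 if j == i else 0 for j in range(n))


# Two of the nine colors from the classes.txt in the question.
table = build_color_table([(25, 255, 255), (250, 246, 7)], [one_hot(1), one_hot(8)])
image = [[one_hot(1), one_hot(8), (0,) * 9]]
print(label_image_to_rgb(image, table))
# [[(25, 255, 255), (250, 246, 7), (0, 0, 0)]]
```

With 9 flags the table has at most 2**9 = 512 entries, so its size stays small for label counts in this range.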
Also, regarding your classes.txt file: your classes seem to be mutually exclusive, i.e. only one label is given to each color, so you should be using PredictionType.CLASSIFICATION (with only the RGB codes in classes.txt).
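Following that suggestion, a small helper could check that every color carries exactly one flag and, if so, emit the RGB-only file. This is a sketch that assumes the CLASSIFICATION classes.txt is simply one "R G B" line per class, as described above; `to_classification_classes` is a hypothetical name.

```python
def to_classification_classes(multilabel_text):
    """Convert 'R G B f1 ... fn' rows to 'R G B' rows if every row is one-hot."""
    rgb_lines = []
    for line in multilabel_text.strip().splitlines():
        values = line.split()
        flags = [int(v) for v in values[3:]]
        if any(f not in (0, 1) for f in flags):
            raise ValueError(f"non-binary flag in row: {line!r}")
        if sum(flags) != 1:
            # Zero or several labels for one color: genuinely multilabel data.
            raise ValueError(f"row is not one-hot, keep MULTILABEL: {line!r}")
        rgb_lines.append(" ".join(values[:3]))
    return "\n".join(rgb_lines) + "\n"


example = "0 0 0 1 0 0\n25 255 255 0 1 0\n142 130 255 0 0 1\n"
print(to_classification_classes(example), end="")
# 0 0 0
# 25 255 255
# 142 130 255
```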

@vndee
Author

vndee commented Feb 28, 2019

I've just fixed my problem by reducing the number of classes to 7.
Thanks for your reply; it helped me clearly understand the workflow of the model.
