Multi Class Segmentation #81

Closed
nargeshn opened this issue Jul 28, 2017 · 14 comments

@nargeshn

I think other people have asked this question before, but I cannot find the issue or your response.
I am trying to use U-Net to segment medical images. The segmentations contain more than one label. I converted the labels to binary for now, but I am curious whether U-Net can handle multi-class segmentation.

@nargeshn nargeshn changed the title Multi_Class Segmentation Multi Class Segmentation Jul 28, 2017
@jakeret
Owner

jakeret commented Jul 29, 2017

Yes, it can handle multi-class segmentation; see this example, where I segment stars, galaxies, and background pixels.

Training, however, can be hard. People have had success training one model per class and combining their predictions.

@nargeshn
Author

nargeshn commented Aug 2, 2017

Thanks for your response!
I am still a bit confused and just want to make sure that I have understood the process correctly.
For multi-class segmentation (for example, 10 classes) I need to create a mask for every class in every image, so each image gets 10 masks. As a result, I will have 10 predictions for every image. At classification time, for every pixel, I need to choose the class with the highest probability (the highest prediction value).
Please correct me if I am wrong.
With this approach, if we have 10 classes but not every image contains all 10 (for example, some images contain only 7 or 8 classes), would that cause any issues?
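
The process described above can be sketched in NumPy. This is only an illustration under assumptions: the function names `masks_from_label_map` and `classify_pixels` are hypothetical and not part of tf_unet, and the label map is assumed to hold integer class indices. It also shows that a class missing from an image simply yields an all-zero mask for that class.

```python
import numpy as np

def masks_from_label_map(label_map, n_class):
    """Build one mask per class from an integer label map of shape (H, W).

    Returns an array of shape (H, W, n_class). (Hypothetical helper.)
    """
    # label_map == k is a boolean mask for class k; stack them on the last axis.
    masks = [label_map == k for k in range(n_class)]
    return np.stack(masks, axis=-1).astype(np.float32)

def classify_pixels(prediction):
    """Pick, for every pixel, the class with the highest predicted value."""
    return np.argmax(prediction, axis=-1)

# A tiny 2x2 image that contains classes 0 and 1, but not class 2.
label_map = np.array([[0, 1],
                      [1, 0]])
masks = masks_from_label_map(label_map, n_class=3)
# The channel of the absent class is all zeros, which is a valid target.
absent_class_pixels = masks[..., 2].sum()

# Made-up per-class probabilities of shape (H, W, n_class).
prediction = np.array([[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
                       [[0.3, 0.6, 0.1], [0.5, 0.3, 0.2]]])
predicted_labels = classify_pixels(prediction)
```

Whether images that lack some classes hurt training in practice depends on the data and the loss weighting, so this only shows that the encoding itself stays well defined.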

@jakeret
Owner

jakeret commented Aug 3, 2017

Frankly, my experience with multi-class segmentation problems is rather limited. Anyway, as is often the case with DL, there is no single correct answer, which means you have to experiment to find what works best for your problem and data set. What you describe is probably what I would try. I recommend reading the reports of the Kaggle DSTL challenge winners.

@nargeshn
Author

Hello jakeret,

I tried multi-label segmentation with U-Net. Following your previous post about multi-label segmentation, I created my training data and the corresponding labels with the dimensions below, and this is what I pass to the data provider (n_class = 3).
data = [number_of_trainedData , image.shape(1), image.shape(2) , channels]
label = [number_of_trainedData , image.shape(1), image.shape(2) , n_class]
The network can be trained, but the problem appears when I want to test. When I initialize the network as follows:

net = unet.Unet(channels=1, n_class=3, cost="cross_entropy" , features_root=64)

I receive the following error:

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [1,1,64,3] rhs shape= [1,1,64,2]
[[Node: save/Assign_17 = Assign[T=DT_FLOAT, _class=["loc:@Variable_24"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_24, save/RestoreV2_17)]]

When I change the initialization of the network as follows, it works:
net = unet.Unet(channels=1, n_class=2, cost="cross_entropy" , features_root=64)

So it seems that although I used 3 labels in my training data, the network is still being trained on just 2 labels. Have you ever trained the network on multiple labels?
I am trying to track down the problem and have made some changes to the network, but no success yet :(
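
In the error above, the left-hand side is the variable in the newly built graph (last dimension 3) and the right-hand side is the value stored in the checkpoint (last dimension 2), so the checkpoint being restored was written by a 2-class model. A minimal sketch for listing such mismatches follows. It assumes you already have two name-to-shape dictionaries; the helper `find_shape_mismatches` and the entry `Variable_0` are made up for illustration, while `Variable_24` and its shapes are taken from the error message.

```python
def find_shape_mismatches(graph_shapes, checkpoint_shapes):
    """Return {name: (graph_shape, checkpoint_shape)} for variables that differ.

    Both arguments map variable names to shapes (lists or tuples of ints).
    Hypothetical helper, not part of tf_unet or TensorFlow.
    """
    mismatches = {}
    for name, graph_shape in graph_shapes.items():
        saved_shape = checkpoint_shapes.get(name)
        # Only report variables present in both places with different shapes.
        if saved_shape is not None and list(saved_shape) != list(graph_shape):
            mismatches[name] = (list(graph_shape), list(saved_shape))
    return mismatches

# Shapes of the graph built with n_class=3.
graph_shapes = {
    "Variable_0": [3, 3, 1, 64],    # made-up entry that matches
    "Variable_24": [1, 1, 64, 3],   # from the error message (lhs)
}
# Shapes stored in the checkpoint. In TF 1.x this dictionary could be read
# with tf.train.NewCheckpointReader(path).get_variable_to_shape_map().
checkpoint_shapes = {
    "Variable_0": [3, 3, 1, 64],
    "Variable_24": [1, 1, 64, 2],   # from the error message (rhs)
}

mismatches = find_shape_mismatches(graph_shapes, checkpoint_shapes)
for name, (in_graph, in_checkpoint) in sorted(mismatches.items()):
    print(name, "graph:", in_graph, "checkpoint:", in_checkpoint)
```

If only the variables of the final 1x1 convolution show up, the checkpoint was most likely saved by a run that used n_class=2, or an older checkpoint file is being restored from the same path.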

@jakeret
Owner

jakeret commented Sep 28, 2017

Hard to tell what is going on without the traceback.
Have you checked the shapes of what your data provider is returning?
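
One way to run that check is sketched below. It assumes the provider is callable with a batch size and returns a `(data, labels)` pair of 4-D arrays; `fake_provider` is a made-up stand-in so the snippet runs on its own, and `check_batch` is a hypothetical helper rather than a tf_unet function.

```python
import numpy as np

def check_batch(x, y, channels, n_class):
    """Return a list of problems found in one (data, labels) batch."""
    problems = []
    if x.ndim != 4 or x.shape[-1] != channels:
        problems.append("data shape %s does not end in channels=%d" % (x.shape, channels))
    if y.ndim != 4 or y.shape[-1] != n_class:
        problems.append("label shape %s does not end in n_class=%d" % (y.shape, n_class))
    elif not np.allclose(y.sum(axis=-1), 1.0):
        # One-hot labels should sum to 1 across the class axis at every pixel.
        problems.append("labels are not one-hot along the last axis")
    return problems

def fake_provider(n, size=8, channels=1, n_class=3):
    """Made-up stand-in that mimics a provider returning (data, labels)."""
    rng = np.random.RandomState(0)
    x = rng.rand(n, size, size, channels).astype(np.float32)
    class_ids = rng.randint(0, n_class, size=(n, size, size))
    y = np.eye(n_class, dtype=np.float32)[class_ids]
    return x, y

x_batch, y_batch = fake_provider(2)
print(x_batch.shape, y_batch.shape)
problems = check_batch(x_batch, y_batch, channels=1, n_class=3)
```

An empty `problems` list means the batch has the expected channel and class dimensions; replacing `fake_provider(2)` with a call to the real provider would test the actual training input.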

@nargeshn
Author

nargeshn commented Sep 28, 2017

Thanks for your reply. Yes, I checked what I receive from the data provider, and I also visualized the outputs to confirm they are what I want. The data provider creates the correct input.
I changed the code slightly, and here is the complete traceback:

Traceback (most recent call last):
File "output_creation.py", line 90, in <module>
patch2Image_2D(inputImg, labelMaskImg, slice_number, patchSize_2d, acceptRate , stride)
File "output_creation.py", line 53, in patch2Image_2D
prediction = net.predict(path, x_test)
File "/Users/Narges/Documents/SummerResearch/tf_unet/tf_unet/unet.py", line 262, in predict
self.restore(sess, model_path)
File "/Users/Narges/Documents/SummerResearch/tf_unet/tf_unet/unet.py", line 290, in restore
saver.restore(sess, model_path)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1548, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 789, in run
run_metadata_ptr)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 997, in _run
feed_dict_string, options, run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1132, in _do_run
target_list, options, run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1152, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [3] rhs shape= [2]
[[Node: save/Assign_18 = Assign[T=DT_FLOAT, _class=["loc:@Variable_25"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_25, save/RestoreV2_18)]]

Caused by op 'save/Assign_18', defined at:
File "output_creation.py", line 90, in <module>
patch2Image_2D(inputImg, labelMaskImg, slice_number, patchSize_2d, acceptRate , stride)
File "output_creation.py", line 53, in patch2Image_2D
prediction = net.predict(path, x_test)
File "/Users/Narges/Documents/SummerResearch/tf_unet/tf_unet/unet.py", line 262, in predict
self.restore(sess, model_path)
File "/Users/Narges/Documents/SummerResearch/tf_unet/tf_unet/unet.py", line 289, in restore
saver = tf.train.Saver()
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
self.build()
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1170, in build
restore_sequentially=self._restore_sequentially)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 691, in build
restore_sequentially, reshape)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 419, in _AddRestoreOps
assign_ops.append(saveable.restore(tensors, shapes))
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 155, in restore
self.op.get_shape().is_fully_defined())
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/state_ops.py", line 271, in assign
validate_shape=validate_shape)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/gen_state_ops.py", line 45, in assign
use_locking=use_locking, name=name)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3] rhs shape= [2]
[[Node: save/Assign_18 = Assign[T=DT_FLOAT, _class=["loc:@Variable_25"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_25, save/RestoreV2_18)]]

@nargeshn
Author

nargeshn commented Sep 28, 2017

Also, this is the code that I used for training the network:

data_provider = image_util.SimpleDataProvider(data, label, channels=1, n_class=3)
net = unet.Unet(channels=data_provider.channels,
                n_class=data_provider.n_class,
                layers=3,
                cost_kwargs=dict(regularizer=0.001,
                                 class_weights=weights))

@jakeret
Owner

jakeret commented Oct 4, 2017

Well, the training seems to work fine. What about the x_test that you pass in here?

File "output_creation.py", line 53, in patch2Image_2D
prediction = net.predict(path, x_test)

@sindhurk

@nargeshn May I know which tool you used to generate the masks for training?

@dhinkris

> @nargeshn May I know which tool you used to generate the masks for training?

ITK-SNAP may work if you are dealing with NIfTI images.

@dhinkris

> Also, this is the code that I used for training the network:
>
> data_provider = image_util.SimpleDataProvider(data, label, channels=1, n_class=3)
> net = unet.Unet(channels=data_provider.channels,
>                 n_class=data_provider.n_class,
>                 layers=3,
>                 cost_kwargs=dict(regularizer=0.001,
>                                  class_weights=weights))

Any luck with this implementation? I am currently dealing with the same issue.

@meijie0401

@dhinkris You should edit the image_util.py file. In it, set the number of classes to 3 and change the code in the _process_labels part.
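
The thread does not show what the changed `_process_labels` looks like, so the following is only a guess at one workable shape for it, not the library's code. It is written as a standalone function named `process_labels` (hypothetical) and assumes each label is either an integer class map of shape (H, W) or an already one-hot array of shape (H, W, n_class).

```python
import numpy as np

def process_labels(label, n_class):
    """Return a float32 one-hot label array of shape (H, W, n_class).

    Hypothetical sketch: accepts an integer class map (H, W) or an
    already one-hot array (H, W, n_class).
    """
    label = np.asarray(label)
    if label.ndim == 3:
        # Already one mask per class: only verify the class dimension.
        if label.shape[-1] != n_class:
            raise ValueError("expected %d label channels, got %d"
                             % (n_class, label.shape[-1]))
        return label.astype(np.float32)
    if label.ndim != 2:
        raise ValueError("label must have 2 or 3 dimensions")
    if label.min() < 0 or label.max() >= n_class:
        raise ValueError("class indices must lie in [0, n_class)")
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(n_class, dtype=np.float32)[label.astype(np.int64)]

class_map = np.array([[0, 2],
                      [1, 1]])
one_hot = process_labels(class_map, n_class=3)
print(one_hot.shape)
```

The explicit range check is a design choice: a label value outside the expected classes fails loudly here instead of silently producing a wrong or wrapped-around mask.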

@ikramabdel

@meijie0401 Hi, I am currently attempting the same change, but I don't understand what should be changed in the _process_labels part. Thanks!

@q841496770

> @meijie0401 Hi, I am currently attempting the same change, but I don't understand what should be changed in the _process_labels part. Thanks!

Did you solve it? I have the same problem. Do you have any ideas? Thanks!
