ValueError when channels > 3 #4

Closed
tdrobbins opened this issue May 14, 2020 · 1 comment

@tdrobbins
Contributor

Training on images with more than three channels raises a ValueError. I fixed the bug with a simple patch to unet.utils.to_rgb() and can submit a pull request if you like.

Here's the full error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-220be538e83e> in <module>
----> 1 trainer.fit(model4, data)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/trainer.py in fit(self, model, train_dataset, validation_dataset, test_dataset, epochs, batch_size, **fit_kwargs)
     94                             epochs=epochs,
     95                             callbacks=callbacks,
---> 96                             **fit_kwargs)
     97 
     98         self.evaluate(model, test_dataset, prediction_shape)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    874           epoch_logs.update(val_logs)
    875 
--> 876         callbacks.on_epoch_end(epoch, epoch_logs)
    877         if self.stop_training:
    878           break

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    363     logs = self._process_logs(logs)
    364     for callback in self.callbacks:
--> 365       callback.on_epoch_end(epoch, logs)
    366 
    367   def on_train_batch_begin(self, batch, logs=None):

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in on_epoch_end(self, epoch, logs)
     32         self._log_histogramms(epoch, predictions)
     33 
---> 34         self._log_image_summaries(epoch, predictions)
     35 
     36         self.file_writer.flush()

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in _log_image_summaries(self, epoch, predictions)
     50                                  utils.to_rgb(cropped_labels[..., :1].numpy()),
     51                                  utils.to_rgb(mask)),
---> 52                                 axis=2)
     53 
     54         with self.file_writer.as_default():

<__array_function__ internals> in concatenate(*args, **kwargs)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 4 and the array at index 1 has size 3
```
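The error comes from np.concatenate: arrays joined along axis=2 (the image width) must agree in every other dimension, including the channel axis (dimension 3). A minimal NumPy sketch with small, hypothetical shapes (not the actual unet callback code) reproduces the mismatch between a 4-channel image batch and 3-channel RGB renderings, and shows that keeping only the first three channels makes the shapes compatible:

```python
import numpy as np

# Hypothetical shapes mirroring the traceback: a batch of 4-channel images
# next to 3-channel RGB renderings of the labels and the predicted mask.
images = np.zeros((2, 8, 8, 4), dtype=np.float32)
labels_rgb = np.zeros((2, 8, 8, 3), dtype=np.float32)
mask_rgb = np.zeros((2, 8, 8, 3), dtype=np.float32)

# Joining side by side along the width axis (axis=2) requires all other
# dimensions, including the channel axis, to match exactly.
try:
    np.concatenate((images, labels_rgb, mask_rgb), axis=2)
    failed = False
except ValueError as err:
    failed = True
    print(err)

# Slicing the images down to their first three channels fixes the mismatch.
combined = np.concatenate((images[..., :3], labels_rgb, mask_rgb), axis=2)
print(combined.shape)  # (2, 8, 24, 3)
```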

And a quick script to reproduce the error:

```python
import unet
import tensorflow as tf

# 100 random 256x256 images with 4 channels
X = tf.random.normal((100, 256, 256, 4))

# Random binary label per pixel, one-hot encoded into 2 classes
Y_flat = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 100 * 256 * 256)
Y = tf.reshape(Y_flat, (100, 256, 256))
Y_onehot = tf.one_hot(Y, 2)

data = tf.data.Dataset.from_tensor_slices((X, Y_onehot))
train_data = data.take(75)
test_data = data.skip(75)

# U-Net expecting 4 input channels
model4 = unet.build_model(256, 256, channels=4, padding="same")
unet.finalize_model(model4, loss=tf.keras.losses.categorical_crossentropy)
trainer = unet.Trainer()

# Fails at the end of the first epoch, when the image summaries are logged
trainer.fit(model4, data)
```
@jakeret
Owner

jakeret commented May 15, 2020

Hi, thanks for reporting this. Yes, I somehow always had 1 or 3 channels in mind when writing to_rgb.
It would be great if you could send me a PR.
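The patch itself is not shown in this thread. As a minimal sketch, assuming the goal is for the conversion to always return exactly three channels, a to_rgb-style helper could tile single-channel inputs and keep only the first three channels of wider inputs. The name to_rgb_any_channels and the zero-padding of 2-channel inputs are hypothetical choices, not the unet library's code:

```python
import numpy as np


def to_rgb_any_channels(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert a [batch, nx, ny, channels] array into a
    [batch, nx, ny, 3] array with values scaled to [0, 1]."""
    img = np.nan_to_num(img.astype(np.float32))
    channels = img.shape[-1]
    if channels == 1:
        # Grayscale: repeat the single channel three times.
        img = np.repeat(img, 3, axis=-1)
    elif channels == 2:
        # Assumption: pad with an all-zero third channel.
        zeros = np.zeros(img.shape[:-1] + (1,), dtype=img.dtype)
        img = np.concatenate((img, zeros), axis=-1)
    else:
        # Three or more channels: keep only the first three.
        img = img[..., :3]
    # Shift the minimum to zero and scale by the maximum, skipping the
    # division for an all-zero array.
    img = img - img.min()
    max_value = img.max()
    if max_value != 0:
        img = img / max_value
    return img
```

Per the traceback, to_rgb is called here only to build the TensorBoard image summaries, so dropping channels beyond the third would affect the logged preview, not the data the model trains on.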
