@rcasero

@rcasero rcasero commented Jul 25, 2018

Summary

Add element-wise weighting of the loss function.
Described in Issue #10561

Keras API Design Review google doc with comments enabled here
https://docs.google.com/document/d/19BDXgNmeTgpgb9xYKzNboXyM7XX2PeM3mlvCFCdIQj0/edit?usp=sharing

Related Issues

None, as far as I know.

PR Overview

  • [y] This PR requires new unit tests [y/n] (make sure tests are included)

As described in the API Design Review, I've had some trouble with this, and could use some guidance.

  • [y] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)

Help added in the code, but as noted in the API Design Review doc, I also need a bit of guidance with that.

  • [?] This PR is backwards compatible [y/n]

I don't know how to test this.

  • [y] This PR changes the current API [y/n] (all API changes need to be approved by @fchollet )

It adds a new possible value 'element' to the option sample_weight_mode in model.compile().
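As a rough illustration of what element-wise weighting means, here is a NumPy sketch. It is not the code in this PR: the function name `weighted_mae` is hypothetical, and normalising by the number of non-zero weights is a simplification chosen for the example. Each entry of the per-element score gets its own weight, rather than one weight per sample or per timestep.

```python
import numpy as np

def weighted_mae(y_true, y_pred, weights):
    """Hypothetical helper: MAE with one weight per element of the score array."""
    # Per-element score: absolute error averaged over the last (channel) axis
    score = np.mean(np.abs(y_true - y_pred), axis=-1)
    if weights.shape != score.shape:
        raise ValueError('weights %s must match score %s'
                         % (weights.shape, score.shape))
    # Average the weighted scores over the elements with non-zero weight
    n_used = np.count_nonzero(weights)
    return float(np.sum(score * weights) / max(n_used, 1))

y_true = np.zeros((2, 4, 4, 1), dtype='float32')
y_pred = np.ones((2, 4, 4, 1), dtype='float32')
y_pred[:, :2] = 3.0  # error of 3 in the top half, 1 in the bottom half

weights = np.ones((2, 4, 4), dtype='float32')
weights[:, :2] = 0.0  # ignore the top half of every sample

loss_all = weighted_mae(y_true, y_pred, np.ones((2, 4, 4), dtype='float32'))
loss_masked = weighted_mae(y_true, y_pred, weights)
```

With all weights equal to 1 the loss averages both halves; with the top half zeroed out, only the bottom-half errors contribute.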

@rcasero
Author

rcasero commented Aug 1, 2018

I have asked for guidance in the thread "How to debug a keras pull request" in the mailing list.

rcasero added 2 commits August 5, 2018 17:40
…work.errors_impl.InvalidArgumentError: Incompatible shapes
Some text lines were too long
@fchollet
Collaborator

fchollet commented Aug 5, 2018

@pavithrasv what's your take on this feature?

@fchollet
Collaborator

fchollet commented Aug 5, 2018

Thank you for the PR @rcasero. Will review soon.

Contributor

@Dapid Dapid left a comment

Looks great! We need some tests, though.

@Dapid
Contributor

Dapid commented Aug 6, 2018

I think this feature is very interesting, and I will be using it a lot.

My use case is predictions of some experimental results. When I am predicting 1D features, I can use "temporal" mode and mask out the points where I am missing data, but when I am predicting 2D, I can no longer use that trick. My solution now is to wrap my model and manually force the outputs at the missing pixels to 0; essentially multiplying the outputs by the "present" mask.

This mode would allow me to have simpler models, avoid the extra complication of extracting the inner model when I need to use it, and also use the same structure for 1 and 2D.
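To make the comparison concrete, the following NumPy sketch contrasts the workaround described above with element-wise weighting. It assumes a binary "present" mask and that targets are stored as 0 wherever data is missing; both are assumptions made for the example, not details given in this thread.

```python
import numpy as np

rng = np.random.RandomState(0)
y_true = rng.rand(3, 8, 8).astype('float32')
y_pred = rng.rand(3, 8, 8).astype('float32')
present = (rng.rand(3, 8, 8) > 0.3).astype('float32')  # 1 = measured, 0 = missing

# Assumption: targets are stored as 0 where the measurement is missing
y_true_stored = y_true * present

# Workaround: force the model outputs to 0 at the missing pixels
err_masked_output = np.abs(y_pred * present - y_true_stored)

# Element-wise weighting: leave the outputs alone and weight the error instead
err_weighted = present * np.abs(y_pred - y_true_stored)

same = np.allclose(err_masked_output, err_weighted)
```

Under these assumptions the two per-pixel errors agree, which is why a weight array could stand in for the wrapper model.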

@rcasero
Author

rcasero commented Aug 6, 2018

I wrote a test script, but I'd need help to turn it into unit tests. (As I mention in the Keras API Design Review Google doc, I tried following the instructions to test Keras, but the tests fail for me even on the unpatched main Keras branch.)

@pavithrasv
Contributor

I can see this feature being useful for use cases like the one @Dapid has mentioned. It could also be used to mask unknown elements in a sample. Did a first pass through the code, will review again after unit tests have been added.

@rcasero
Author

rcasero commented Aug 22, 2018

Suggestions by @pavithrasv pushed in commit 157244c

@rcasero
Author

rcasero commented Aug 23, 2018

I implemented this feature because I generate training data with unknown elements, as @pavithrasv mentions. Training a network this way is working for me.

As mentioned above, I have a testing script, but I don't know how to generate unit tests in keras.

@fchollet
Collaborator

@pavithrasv Could you please take a look at the recent changes? Thank you.

Contributor

@pavithrasv pavithrasv left a comment

Thank you for the changes. Can you add unit tests to test_training.py? https://github.com/keras-team/keras/blob/master/tests/keras/engine/test_training.py

@rcasero
Author

rcasero commented Sep 18, 2018

@pavithrasv What I see in "View changes" proposed by you is the patch I've submitted, right? Do I need to do anything about that?

Roger about the unit tests.

@pavithrasv
Contributor

Yes, the only change I had requested after that commit was the unit tests. Thank you!

@rcasero
Author

rcasero commented Sep 18, 2018

@pavithrasv I've found an error if one has two outputs, and one is e.g. (None, 22, 22, 2) instead of (None, 22, 22, 1). Code and error below. Any help appreciated!

import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import keras.backend as K
import numpy as np

from keras.models import Model, Sequential
from keras.layers import Activation, Conv2D, Input
from keras.layers.normalization import BatchNormalization

from keras.utils import multi_gpu_model

# remove warning "Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Just disables the warning, doesn't enable AVX/FMA

# image_data_format = 'channels_first'
image_data_format = 'channels_last'

K.set_image_data_format(image_data_format)

# simulate input images
im = np.zeros(shape=(10, 64, 64, 3), dtype='uint8')

# simulate network output
out = 2 * np.ones(shape=(10, 64, 64, 1), dtype='float32')
aux_out = 5 * np.ones(shape=(10, 22, 22, 1), dtype='float32')
# simulate training weights for network output
# weight = np.ones(shape=(10, 64, 64, 1), dtype='float32')
weight = np.ones(shape=(10, 64, 64, 1), dtype='float32')
aux_weight = np.ones(shape=(10, 22, 22, 1), dtype='float32')

# simulate validation data
im_validation = 3 * np.ones(shape=(5, 64, 64, 3), dtype='uint8')
out_validation = 4 * np.ones(shape=(5, 64, 64, 1), dtype='float32')

validation_data = (im_validation, out_validation)

# optimizer
optimizer = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

'''Multi-output CNN with outputs of different number of features
'''

# create network model
input = Input(shape=im.shape[1:], dtype='float32')
x = Conv2D(filters=32, kernel_size=(3, 3), strides=1, padding='same')(input)
x = BatchNormalization(axis=3)(x)
x = Activation('relu')(x)

main_output = Conv2D(filters=1, kernel_size=(1, 1), strides=1, padding='same', name='main_output')(x)
aux_output = Conv2D(filters=2, kernel_size=(1, 1), strides=3, padding='same', name='aux_output')(x)

model = Model(inputs=input, outputs=[main_output, aux_output])

'''list format (sample_weight_mode=['element', 'element'])
'''

model.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'],
              sample_weight_mode=['element', 'element'])

model.fit(im, [out, np.repeat(aux_out, repeats=2, axis=3)],
          sample_weight=[weight, np.repeat(aux_weight, repeats=2, axis=3)],
          batch_size=3, epochs=3)
Epoch 1/3
Traceback (most recent call last):
  File "<input>", line 3, in <module>
  File "/home/rcasero/Software/keras_branch_sample_weight/keras/engine/training.py", line 1070, in fit
    validation_steps=validation_steps)
  File "/home/rcasero/Software/keras_branch_sample_weight/keras/engine/training_arrays.py", line 199, in fit_loop
    outs = f(ins_batch)
  File "/home/rcasero/Software/keras_branch_sample_weight/keras/backend/tensorflow_backend.py", line 2661, in __call__
    return self._call(inputs)
  File "/home/rcasero/Software/keras_branch_sample_weight/keras/backend/tensorflow_backend.py", line 2631, in _call
    fetched = self._callable_fn(*array_vals)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1454, in __call__
    self._session._session, self._handle, args, status, None)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 519, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 2904 values, but the requested shape has 1452
	 [[Node: loss/aux_output_loss/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](_arg_aux_output_sample_weights_0_4/_113, loss/aux_output_loss/Shape_1)]]
	 [[Node: loss/add/_151 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_723_loss/add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

@gabrieldemarmiesse
Contributor

I think something is going wrong around those lines:

# reduce weight array to same ndim as score_array (needed for
# sample_weight_mode='element')
if weight_ndim > K.ndim(score_array):
    weights = K.reshape(weights, K.shape(score_array))

If you do print(K.int_shape(weights)) and print(K.int_shape(score_array)) before the reshape, the error should be obvious. I didn't pull your branch though, so no guarantees.
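The numbers in the traceback can be reproduced with plain NumPy. This sketch assumes a batch of 3 samples (the script's batch_size) and uses the fact that the 'mae' loss averages over the last axis; it only mirrors the shape arithmetic and does not run the branch's code.

```python
import numpy as np

batch_size = 3
y_pred = np.zeros((batch_size, 22, 22, 2), dtype='float32')  # one aux_output batch
y_true = np.ones_like(y_pred)
weights = np.ones((batch_size, 22, 22, 2), dtype='float32')  # shaped like the output

# 'mae' averages over the last axis, so the score array loses that axis
score_array = np.mean(np.abs(y_true - y_pred), axis=-1)

print(weights.shape, weights.size)          # (3, 22, 22, 2) 2904
print(score_array.shape, score_array.size)  # (3, 22, 22) 1452

# Reshaping 2904 values into a 1452-element shape cannot succeed
try:
    np.reshape(weights, score_array.shape)
    reshape_ok = True
except ValueError:
    reshape_ok = False
```

The weights carry twice as many values as the score array because the output has two channels, which matches the 2904 versus 1452 in the error message.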

@rcasero
Author

rcasero commented Sep 19, 2018

Thanks @gabrieldemarmiesse . You were right, the problem was there. Furthermore, the patch had a conceptual error, because the element-wise weights should have the size of score_array, not the output. Both have been fixed with commit d5701ea.

I still need to do some tests and write the unit tests.

@rcasero
Author

rcasero commented Sep 19, 2018

My last commits broke job 10782.8 (Python 2.7, KERAS_BACKEND=cntk) in the Travis CI build, in particular tests/test_multiprocessing.py::test_multiprocessing_predict_error.

Can't debug for the moment, because keras docker requires nvidia-docker, which doesn't accept Ubuntu 17.10 (my current distribution).

On a related note, I had to:

  • Edit keras/docker/Makefile, adding --network host to the build, because otherwise my computer cannot reach ubuntu servers
docker build -t keras --build-arg python_version=$(PYTHON_VERSION) ...
  • Edit keras/docker/Dockerfile, replacing
git clone git://github.com/

with

git clone http://github.com/

and replacing

pip install git+git://github.com/keras-team/keras.git

with

pip install git+http://github.com/keras-team/keras.git

because otherwise the git and pip commands time out without completing.

@gabrieldemarmiesse
Contributor

The error is not due to your commit. It is a flaky test. I'll restart the build for you. You can also install the CPU version of Keras by following the instructions in .travis.yml to debug things locally.

@rcasero
Author

rcasero commented Sep 19, 2018

Thanks, @gabrieldemarmiesse . I'll look at .travis.yml for the testing.

@gabrieldemarmiesse
Contributor

I'll try to fix the test. Please use git pull before working again on your branch.

@gabrieldemarmiesse
Contributor

I can't review this anymore since I added some commits. We need new reviewers. The build is passing, so it's ready for review.

@rcasero
Author

rcasero commented Nov 2, 2018

Thanks. Anything else that needs to be done? Do we need that error message if the weights ndim is not 1 less than the output's ndim?

@gabrieldemarmiesse
Contributor

We need this error message. You already wrote it. It looks like this now:

if sample_weight is not None and sample_weight.shape != score_array_shape:
    raise ValueError('Found a `sample_weight` array with shape ' +
                     str(sample_weight.shape) +
                     ' for output with shape ' +
                     str(y.shape) +
                     '. When sample_weight_mode="element", ' +
                     'weights and score_array must have the same size. ' +
                     'Your `sample_weight` array should have the ' +
                     'following shape: ' + str(score_array_shape))
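For anyone who wants to try the check quoted above outside Keras, here is a standalone sketch. The function name `check_element_weights` is hypothetical, and `score_array_shape` is derived on the assumption that the loss reduces over the last axis (as 'mae' does); the PR may compute it differently.

```python
import numpy as np

def check_element_weights(sample_weight, y, score_array_shape):
    """Hypothetical standalone version of the shape check quoted above."""
    if sample_weight is not None and sample_weight.shape != score_array_shape:
        raise ValueError('Found a `sample_weight` array with shape ' +
                         str(sample_weight.shape) +
                         ' for output with shape ' + str(y.shape) +
                         '. Your `sample_weight` array should have the '
                         'following shape: ' + str(score_array_shape))
    return sample_weight

y = np.zeros((10, 22, 22, 2), dtype='float32')
score_array_shape = y.shape[:-1]  # assumption: the loss reduces over the last axis

# Weights shaped like the score array pass the check
ok = check_element_weights(np.ones((10, 22, 22)), y, score_array_shape)

# Weights shaped like the full output are rejected
try:
    check_element_weights(np.ones((10, 22, 22, 2)), y, score_array_shape)
    message = None
except ValueError as e:
    message = str(e)
```

Raising early with the expected shape in the message avoids the much less readable reshape error from the backend.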

@gabrieldemarmiesse
Contributor

I think we just need to wait for the build, and if it passes, for a review from another member of the keras team. Thanks for your work @rcasero !

@rcasero
Author

rcasero commented Nov 2, 2018

Thanks for your help, @gabrieldemarmiesse

@gabrieldemarmiesse
Contributor

@pavithrasv, this PR is ready. Could you please take a look?

@rcasero
Author

rcasero commented Dec 5, 2018

@pavithrasv Just bumping this up

@gabrieldemarmiesse
Contributor

@fchollet please take a look at it when you have the time.

@rcasero
Author

rcasero commented Feb 4, 2019

@fchollet @gabrieldemarmiesse bumping this up, as it seems to have been forgotten

@gabrieldemarmiesse
Contributor

I haven't forgotten, but each PR changing the API needs @fchollet 's approval. We need to be better organised for this.

@rcasero
Author

rcasero commented Feb 7, 2019

@gabrieldemarmiesse @fchollet Not blaming anyone. :) I'd just love this to be merged so that updates to keras don't keep breaking the branch, and we can refer to it in a publication. I've just merged the official keras into my branch, as there were some new conflicts. Now it passes the tests again.

@rcasero
Author

rcasero commented Feb 7, 2019

@todiketan Perhaps you could ask that question in the thread you mention (just reply to my message there), as it's off-topic here? (This is an issue about adding an element-wise weighting feature to keras).

@fchollet
Collaborator

Thank you for preparing this PR.

We are no longer adding new features to multi-backend Keras, as we are refocusing development efforts on tf.keras. If you are still interested in submitting this PR, please direct it to tf.keras in the TensorFlow repository instead.

@fchollet fchollet closed this Sep 11, 2019
@rcasero
Author

rcasero commented Sep 12, 2019

@fchollet Thanks for letting me know. Could you give me a couple of pointers of where the code goes within https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/keras? The code structure seems to be different to regular keras, although the filenames are the same.

The PR makes the following changes:

  • keras/engine/training.py

    • In compile(), add an 'element' option to sample_weight_modes
    • Instantiate a K.placeholder for the weights, with the shape of the loss function's output
  • keras/engine/training_utils.py: In weighted_masked_objective(),

    • check that the shape of the weights array matches the shape of the loss function's output
    • calculate score_array_shape
  • tests/keras/engine/test_training.py

    • unit test
