Overriding `Layer.forward` unexpectedly changes the signature of `Layer.__call__` under the torch backend
#19730
Comments
Hi @LarsKue, thanks for reporting the issue. I have replicated it and observed the same behavior with PyTorch as the backend. We need to check with the Keras dev team whether this is intended or an oversight. Thanks!
@haifeng-jin any thoughts on this? I'm not familiar enough with torch to know what our intended behavior is.
This is needed to have a Keras layer work with torch `Module`s and training loops, which would call the `forward` method.
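As context for the point above, here is a toy sketch (an assumption for illustration, not torch's actual implementation) of why a torch training loop ends up in `forward`: `torch.nn.Module.__call__` dispatches to `forward`, so any torch-compatible layer must route calls through it.

```python
# Toy mimic of torch.nn.Module's call-to-forward dispatch (illustration only).
class Module:
    def __call__(self, *args, **kwargs):
        # torch's Module.__call__ runs hooks and then dispatches to forward()
        return self.forward(*args, **kwargs)


class Double(Module):
    def forward(self, x):
        return 2 * x


# A typical torch training loop calls the module directly,
# which lands in forward() via __call__.
assert Double()(3) == 6
```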
@LarsKue, please let us know your specific use case and whether it can be solved by adding a [...]. Thanks!
@haifeng-jin The use case is invertible networks. My current work-around is not to expose [...]:

```python
import os

os.environ["KERAS_BACKEND"] = "torch"

import keras


class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self._inverse(xz)
        return self._forward(xz)

    def _forward(self, x):
        pass

    def _inverse(self, z):
        pass


layer = MyLayer()
x = keras.ops.zeros((128, 2))

# works now
layer(x, inverse=True)
```

Adding a [...]
In torch, one typically writes layer `__call__` methods by overriding the `forward` method. Under keras, we instead use the `call` method. I would not expect overriding `forward` to have any effect on how a `keras.Layer` is called, even under the torch backend, since this is the purpose of `call`. However, it seems that when the `forward` method is overridden, it takes priority over overriding `Layer.call`.

Minimal Example:
My keras version: 3.3.3