Layernorm not supporting axis [-2, 3] #19642

Closed
lllllllllaa opened this issue Apr 30, 2024 · 2 comments · Fixed by #19643
lllllllllaa commented Apr 30, 2024

Hi,
I wanted to normalise my output over the -2 and -3 axes (image height and width). However, it seems that with rms_scaling=True, self.gamma is not broadcast to the same shape as the layer input, which causes this error:

inputs shape: (1, 1920, 1200, 3)
inv shape: (1, 1, 1, 3)
gamma_cast shape: (1920, 1200)
(inputs * inv) shape: (1, 1920, 1200, 3)
2024-04-30 13:50:54.238379: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: Incompatible shapes: [1,1920,1200,3] vs. [1920,1200]
Traceback (most recent call last):
  File "C:\Users\88bbh\PycharmProjects\AI\tempt.py", line 10, in <module>
    layer(np.zeros((1, 1920, 1200, 3)))
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 5983, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling LayerNormalization.call().

{{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,1920,1200,3] vs. [1920,1200] [Op:Mul] name: 

Arguments received by LayerNormalization.call():
  • inputs=tf.Tensor(shape=(1, 1920, 1200, 3), dtype=float32)

Code to reproduce:

import numpy as np
import keras

layer = keras.layers.LayerNormalization(axis=[-3, -2], rms_scaling=True)
layer.build([None, 1920, 1200, 3])
layer(np.zeros((1, 1920, 1200, 3)))
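
For comparison, the default axis=-1 case should not hit this, since gamma has shape (3,) and broadcasts against the channel axis (a quick sketch using the same API as above):

layer = keras.layers.LayerNormalization(axis=-1, rms_scaling=True)
layer.build([None, 1920, 1200, 3])
layer(np.zeros((1, 1920, 1200, 3)))  # no error: (1, 1920, 1200, 3) * (3,) broadcasts fine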

The error is in the LayerNormalization call() method:

        if self.rms_scaling:
            # Calculate outputs with only variance and gamma if rms scaling
            # is enabled
            # Calculate the variance along self.axis (layer activations).
            variance = ops.var(inputs, axis=self.axis, keepdims=True)
            inv = ops.rsqrt(variance + self.epsilon)
            print("inputs shape:", inputs.shape)
            print("inv shape:", inv.shape)
            print("gamma_cast shape:", self.gamma.shape)
            print("inv shape:", (inputs * inv).shape)
            outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)

The error can be fixed by changing

outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)

to

outputs = inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
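
For illustration, the _broadcast helper presumably reshapes gamma so its dimensions line up with the input, inserting size-1 dimensions for the axes that are not normalised (batch and channels here). A minimal NumPy sketch of the idea, with small stand-in shapes rather than the real Keras helper:

import numpy as np

inputs = np.zeros((1, 4, 5, 3), dtype="float32")  # stand-in for (1, 1920, 1200, 3)
gamma = np.ones((4, 5), dtype="float32")          # stand-in for gamma built over axis=[-3, -2]

# inputs * gamma would fail: trailing dimensions 3 vs 5 are incompatible, as in the traceback above
gamma_b = gamma.reshape(1, 4, 5, 1)  # insert size-1 dims for the batch and channel axes
outputs = inputs * gamma_b           # broadcasts cleanly to (1, 4, 5, 3)
print(outputs.shape)                 # (1, 4, 5, 3)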

Please fix it in the next update.
Thank you.

SuryanarayanaY (Collaborator) commented

Hi @lllllllllaa,

Thanks for reporting. I acknowledge the issue, and the proposed fix seems correct. The fix is proposed in the attached PR.

Thanks!
