diff --git a/keras_hub/src/models/flux/flux_maths.py b/keras_hub/src/models/flux/flux_maths.py index fad9bc9e00..1be3dbe86c 100644 --- a/keras_hub/src/models/flux/flux_maths.py +++ b/keras_hub/src/models/flux/flux_maths.py @@ -1,5 +1,4 @@ import keras -from einops import rearrange from keras import ops @@ -58,7 +57,7 @@ def call(self, pos, dim, theta): out = ops.stack( [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1 ) - out = rearrange(out, "... n d (i j) -> ... n d i j", i=2, j=2) + out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2)) return ops.cast(out, dtype="float32") @@ -122,9 +121,9 @@ def call(self, q, k, v, positional_encoding): x = scaled_dot_product_attention( q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal ) - - x = rearrange(x, "B H L D -> B L (H D)") - return x + x = ops.transpose(x, (0, 2, 1, 3)) + b, l, h, d = ops.shape(x) + return ops.reshape(x, (b, l, h * d)) # TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the original diff --git a/requirements-common.txt b/requirements-common.txt index c935b10f23..2bdc4a5720 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,6 +19,3 @@ sentencepiece tensorflow-datasets safetensors pillow -# Will be replaced once https://github.com/keras-team/keras/issues/20332 -# is resolved -einops