From 16ab243a7340cb6c305a73b26050638dde36b1ef Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 12 Jul 2024 11:51:52 +0800
Subject: [PATCH] Fix dtype bugs in `ReversibleEmbedding` and `LayerNorm`

---
 keras_nlp/src/layers/modeling/reversible_embedding.py | 2 +-
 keras_nlp/src/models/llama/llama_layernorm.py          | 2 +-
 keras_nlp/src/models/mistral/mistral_layer_norm.py     | 2 +-
 keras_nlp/src/models/phi3/phi3_layernorm.py            | 2 +-
 keras_nlp/src/tests/test_case.py                       | 6 ++++++
 5 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/keras_nlp/src/layers/modeling/reversible_embedding.py b/keras_nlp/src/layers/modeling/reversible_embedding.py
index c30bf609d1..b923742ccc 100644
--- a/keras_nlp/src/layers/modeling/reversible_embedding.py
+++ b/keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -180,7 +180,7 @@ def compute_output_spec(self, inputs, reverse=False):
             output_shape[-1] = self.input_dim
         else:
             output_shape += [self.output_dim]
-        return keras.KerasTensor(output_shape, dtype=self.dtype)
+        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
 
     # Quantization-related (int8) methods
 
diff --git a/keras_nlp/src/models/llama/llama_layernorm.py b/keras_nlp/src/models/llama/llama_layernorm.py
index fc5a52d188..4352252146 100644
--- a/keras_nlp/src/models/llama/llama_layernorm.py
+++ b/keras_nlp/src/models/llama/llama_layernorm.py
@@ -40,7 +40,7 @@ def call(self, x):
         x = ops.cast(x, "float32")
         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
         x = x * ops.rsqrt(var + self.epsilon)
-        return ops.cast(x, self.compute_dtype) * self.scale
+        return ops.cast(x * self.scale, self.compute_dtype)
 
     def get_config(self):
         config = super().get_config()
diff --git a/keras_nlp/src/models/mistral/mistral_layer_norm.py b/keras_nlp/src/models/mistral/mistral_layer_norm.py
index 0ad68d7193..b12c4a835e 100644
--- a/keras_nlp/src/models/mistral/mistral_layer_norm.py
+++ b/keras_nlp/src/models/mistral/mistral_layer_norm.py
@@ -40,7 +40,7 @@ def call(self, x):
         x = ops.cast(x, "float32")
         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
         x = x * ops.rsqrt(var + self.epsilon)
-        return ops.cast(x, self.compute_dtype) * self.scale
+        return ops.cast(x * self.scale, self.compute_dtype)
 
     def get_config(self):
         config = super().get_config()
diff --git a/keras_nlp/src/models/phi3/phi3_layernorm.py b/keras_nlp/src/models/phi3/phi3_layernorm.py
index 9a63460632..3ff62b386f 100644
--- a/keras_nlp/src/models/phi3/phi3_layernorm.py
+++ b/keras_nlp/src/models/phi3/phi3_layernorm.py
@@ -40,7 +40,7 @@ def call(self, x):
         x = ops.cast(x, "float32")
         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
         x = x * ops.rsqrt(var + self.epsilon)
-        return ops.cast(x, self.compute_dtype) * self.scale
+        return ops.cast(x * self.scale, self.compute_dtype)
 
     def get_config(self):
         config = super().get_config()
diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py
index 253c8a7692..7e8e0cec95 100644
--- a/keras_nlp/src/tests/test_case.py
+++ b/keras_nlp/src/tests/test_case.py
@@ -314,13 +314,19 @@ def run_precision_test(self, cls, init_kwargs, input_data):
             layer = cls(**{**init_kwargs, "dtype": policy})
             if isinstance(layer, keras.Model):
                 output_data = layer(input_data)
+                output_spec = layer.compute_output_spec(input_data)
             elif isinstance(input_data, dict):
                 output_data = layer(**input_data)
+                output_spec = layer.compute_output_spec(**input_data)
             else:
                 output_data = layer(input_data)
+                output_spec = layer.compute_output_spec(input_data)
             for tensor in tree.flatten(output_data):
                 if is_float_dtype(tensor.dtype):
                     self.assertDTypeEqual(tensor, policy.compute_dtype)
+            for spec in tree.flatten(output_spec):
+                if is_float_dtype(spec.dtype):
+                    self.assertDTypeEqual(spec, policy.compute_dtype)
             for weight in layer.weights:
                 if is_float_dtype(weight.dtype):
                     self.assertDTypeEqual(weight, policy.variable_dtype)
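
Not part of the patch: a minimal, standalone sketch of the dtype pattern the three layer-norm fixes converge on under mixed precision. The math runs in float32, the float32 `scale` weight is applied while still in float32, and only the final product is cast to the layer's `compute_dtype`; casting before the multiply (the old code) let the float32 `scale` promote the output back to float32. The class name `RMSNormSketch` and the toy shapes below are illustrative assumptions, not code from the repository, and the sketch assumes the Keras 3 `keras.ops` API.

# Standalone sketch, not part of the patch. Mirrors the fixed cast ordering
# in the Llama/Mistral/Phi3 layer norms.
# `RMSNormSketch` is an illustrative name, not a class from keras_nlp.
import keras
from keras import ops


class RMSNormSketch(keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Under "mixed_float16" this weight is created in float32
        # (the policy's variable dtype).
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones"
        )

    def call(self, x):
        # Do the reduction in float32 for numerical stability.
        x = ops.cast(x, "float32")
        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + self.epsilon)
        # Multiply by `scale` while still in float32, then cast once.
        # Casting before the multiply would let the float32 `scale`
        # promote the result back to float32.
        return ops.cast(x * self.scale, self.compute_dtype)


layer = RMSNormSketch(dtype="mixed_float16")
y = layer(ops.ones((2, 8)))
print(y.dtype)            # float16 under mixed_float16 (repr depends on backend)
print(layer.scale.dtype)  # float32

The `ReversibleEmbedding` change follows the same reasoning on the symbolic side: `compute_output_spec` now reports `self.compute_dtype` instead of `self.dtype`, so the spec matches the runtime output dtype under mixed-precision policies, which is exactly what the new `output_spec` assertions in `run_precision_test` verify.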