diff --git a/keras_nlp/src/layers/modeling/transformer_decoder.py b/keras_nlp/src/layers/modeling/transformer_decoder.py
index a0fefa1ae7..7d1f410ffa 100644
--- a/keras_nlp/src/layers/modeling/transformer_decoder.py
+++ b/keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -160,10 +160,16 @@ def build(
             dtype=self.dtype_policy,
             name="self_attention",
         )
-        self._self_attention_layer.build(
-            query_shape=decoder_sequence_shape,
-            value_shape=decoder_sequence_shape,
-        )
+        if hasattr(self._self_attention_layer, "_build_from_signature"):
+            self._self_attention_layer._build_from_signature(
+                query=decoder_sequence_shape,
+                value=decoder_sequence_shape,
+            )
+        else:
+            self._self_attention_layer.build(
+                query_shape=decoder_sequence_shape,
+                value_shape=decoder_sequence_shape,
+            )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
             dtype=self.dtype_policy,
@@ -189,10 +195,16 @@ def build(
             dtype=self.dtype_policy,
             name="cross_attention",
         )
-        self._cross_attention_layer.build(
-            query_shape=decoder_sequence_shape,
-            value_shape=encoder_sequence_shape,
-        )
+        if hasattr(self._cross_attention_layer, "_build_from_signature"):
+            self._cross_attention_layer._build_from_signature(
+                query=decoder_sequence_shape,
+                value=encoder_sequence_shape,
+            )
+        else:
+            self._cross_attention_layer.build(
+                query_shape=decoder_sequence_shape,
+                value_shape=encoder_sequence_shape,
+            )
         self._cross_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
             dtype=self.dtype_policy,
diff --git a/keras_nlp/src/layers/modeling/transformer_encoder.py b/keras_nlp/src/layers/modeling/transformer_encoder.py
index 84272bfce8..a861dffba4 100644
--- a/keras_nlp/src/layers/modeling/transformer_encoder.py
+++ b/keras_nlp/src/layers/modeling/transformer_encoder.py
@@ -128,10 +128,16 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="self_attention_layer",
         )
-        self._self_attention_layer.build(
-            query_shape=inputs_shape,
-            value_shape=inputs_shape,
-        )
+        if hasattr(self._self_attention_layer, "_build_from_signature"):
+            self._self_attention_layer._build_from_signature(
+                query=inputs_shape,
+                value=inputs_shape,
+            )
+        else:
+            self._self_attention_layer.build(
+                query_shape=inputs_shape,
+                value_shape=inputs_shape,
+            )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
             dtype=self.dtype_policy,
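
The `hasattr` dispatch above appears to be a version-compatibility shim: legacy tf.keras (Keras 2) `MultiHeadAttention` is built eagerly only through the private `_build_from_signature(query, value)` method, whereas Keras 3's `MultiHeadAttention` exposes a public `build(query_shape, value_shape)` that takes explicit shape keywords. A minimal standalone sketch of the same dispatch follows; the `build_attention` helper name and the example shapes are illustrative, not part of the patch.

import keras

def build_attention(layer, query_shape, value_shape):
    """Build an attention layer eagerly under either Keras version."""
    if hasattr(layer, "_build_from_signature"):
        # Legacy tf.keras path: `_build_from_signature` accepts
        # shapes (or tensors) and creates the projection weights.
        layer._build_from_signature(query=query_shape, value=value_shape)
    else:
        # Keras 3 path: the public `build` takes explicit shape kwargs.
        layer.build(query_shape=query_shape, value_shape=value_shape)

# Usage: weights are created immediately rather than on first call,
# which is what the decoder/encoder `build` methods rely on.
layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)
build_attention(layer, query_shape=(1, 8, 64), value_shape=(1, 8, 64))

Probing with `hasattr` rather than checking a version string keeps the layer working against any backend that implements either interface, at the cost of relying on a private Keras 2 method.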