diff --git a/keras/layers/attention/multi_head_attention.py b/keras/layers/attention/multi_head_attention.py
index d78585f3e91..ffb740ea240 100644
--- a/keras/layers/attention/multi_head_attention.py
+++ b/keras/layers/attention/multi_head_attention.py
@@ -217,6 +217,9 @@ class MultiHeadAttention(Layer):
             training mode (adding dropout) or in inference mode (no dropout).
             Defaults to either using the training mode of the parent
             layer/model, or False (inference) if there is no parent layer.
+        use_causal_mask: A boolean to indicate whether to apply a causal mask to
+            prevent tokens from attending to future tokens (e.g., used in a decoder
+            Transformer).
 
     Returns:
         attention_output: The result of the computation, of shape `(B, T, E)`,
@@ -246,6 +249,7 @@ def __init__(
         **kwargs
     ):
         super().__init__(**kwargs)
+        self.supports_masking = True
         self._num_heads = num_heads
         self._key_dim = key_dim
         self._value_dim = value_dim if value_dim else key_dim
@@ -449,7 +453,7 @@ def _build_attention(self, rank):
         """Builds multi-head dot-product attention computations.
 
         This function builds attributes necessary for `_compute_attention` to
-        costomize attention computation to replace the default dot-product
+        customize attention computation to replace the default dot-product
         attention.
 
         Args:
@@ -502,7 +506,8 @@ def _compute_attention(
             key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
             value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
             attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
-                attention to certain positions.
+                attention to certain positions. It is generally not needed if the
+                `query` and `value` (and/or `key`) are masked.
             training: Python boolean indicating whether the layer should behave
                 in training mode (adding dropout) or in inference mode (doing
                 nothing).
@@ -543,7 +548,16 @@ def call(
         attention_mask=None,
         return_attention_scores=False,
         training=None,
+        use_causal_mask=False,
     ):
+        attention_mask = self._compute_attention_mask(
+            query,
+            value,
+            key=key,
+            attention_mask=attention_mask,
+            use_causal_mask=use_causal_mask,
+        )
+
         if not self._built_from_signature:
             self._build_from_signature(query=query, value=value, key=key)
         if key is None:
@@ -592,3 +606,92 @@ def call(
         if return_attention_scores:
             return attention_output, attention_scores
         return attention_output
+
+    def _compute_attention_mask(
+        self, query, value, key=None, attention_mask=None, use_causal_mask=False
+    ):
+        """Computes the attention mask, using the Keras masks of the inputs.
+
+        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
+        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
+        * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
+          mask is ignored if `key` is `None` or if `key is value`.
+        * If `use_causal_mask=True`, then the causal mask is computed. Its shape
+          is [1, T, S].
+
+        All defined masks are merged using a logical AND operation (`&`).
+
+        In general, if the `query` and `value` are masked, then there is no need
+        to define the `attention_mask`.
+
+        Args:
+            query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
+            key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
+            value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
+            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+                attention to certain positions.
+            use_causal_mask: A boolean to indicate whether to apply a causal mask
+                to prevent tokens from attending to future tokens (e.g., used in a
+                decoder Transformer).
+
+        Returns:
+            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+                attention to certain positions, based on the Keras masks of the
+                `query`, `key`, `value`, and `attention_mask` tensors, and the
+                causal mask if `use_causal_mask=True`.
+        """
+        query_mask = getattr(query, "_keras_mask", None)
+        value_mask = getattr(value, "_keras_mask", None)
+        key_mask = getattr(key, "_keras_mask", None)
+        auto_mask = None
+        if query_mask is not None:
+            query_mask = tf.cast(query_mask, tf.bool)  # defensive casting
+            # B = batch size, T = max query length
+            auto_mask = query_mask[:, :, tf.newaxis]  # shape is [B, T, 1]
+        if value_mask is not None:
+            value_mask = tf.cast(value_mask, tf.bool)  # defensive casting
+            # B = batch size, S == max value length
+            mask = value_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
+            auto_mask = mask if auto_mask is None else auto_mask & mask
+        if key_mask is not None:
+            key_mask = tf.cast(key_mask, tf.bool)  # defensive casting
+            # B == batch size, S == max key length == max value length
+            mask = key_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
+            auto_mask = mask if auto_mask is None else auto_mask & mask
+        if use_causal_mask:
+            # the shape of the causal mask is [1, T, S]
+            mask = self._compute_causal_mask(query, value)
+            auto_mask = mask if auto_mask is None else auto_mask & mask
+        if auto_mask is not None:
+            # merge attention_mask & automatic mask, to shape [B, T, S]
+            attention_mask = (
+                auto_mask
+                if attention_mask is None
+                else tf.cast(attention_mask, bool) & auto_mask
+            )
+        return attention_mask
+
+    def _compute_causal_mask(self, query, value=None):
+        """Computes a causal mask (e.g., for masked self-attention layers).
+
+        For example, if query and value both contain sequences of length 4,
+        this function returns a boolean `Tensor` equal to:
+        [[[True, False, False, False],
+          [True, True, False, False],
+          [True, True, True, False],
+          [True, True, True, True]]]
+
+        Args:
+            query: query `Tensor` of shape `(B, T, ...)`.
+            value: value `Tensor` of shape `(B, S, ...)` (optional, defaults to
+                query).
+
+        Returns:
+            mask: a boolean `Tensor` of shape [1, T, S] containing a lower
+                triangular matrix of shape [T, S].
+ """ + q_seq_length = tf.shape(query)[1] + v_seq_length = q_seq_length if value is None else tf.shape(value)[1] + return tf.linalg.band_part( # creates a lower triangular matrix + tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0 + ) diff --git a/keras/layers/attention/multi_head_attention_test.py b/keras/layers/attention/multi_head_attention_test.py index f88cbb2791f..59ddb1a03c8 100644 --- a/keras/layers/attention/multi_head_attention_test.py +++ b/keras/layers/attention/multi_head_attention_test.py @@ -328,6 +328,67 @@ def test_ragged_tensor(self, ragged_query, ragged_value, ragged_key): results = test_layer(query, value, key) self.assertAllEqual(results.shape.as_list(), query.shape.as_list()) + def test_query_mask_progagation(self): + """Test automatic propagation of the query's mask.""" + test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2) + self.assertTrue(test_layer.supports_masking) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query) + value = np.random.random((3, 3, 8)) + output = test_layer(query=masked_query, value=value) + self.assertTrue(hasattr(output, "_keras_mask")) + self.assertAllEqual(masked_query._keras_mask, output._keras_mask) + + @parameterized.named_parameters(("causal", True), ("not_causal", False)) + def test_value_mask(self, use_causal_mask): + """Test that the value and causal masks are taken into account.""" + test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = keras.layers.Embedding(6, 8, mask_zero=True)(value) + output = test_layer( + query=masked_query, + value=masked_value, + use_causal_mask=use_causal_mask, + ) + mask = np.array( + [[[True, True, False]] * 3 + [[False, False, False]] * 2] + + [[[True, False, False]] * 5] + + [[[True, True, True]] + [[False, False, False]] * 4] + ) + if use_causal_mask: + mask = mask & np.array( + [ + [[True, False, False], [True, True, False]] + + [[True, True, True]] * 3 + ] + ) + del masked_query._keras_mask + del masked_value._keras_mask + output_with_manual_mask = test_layer( + query=masked_query, value=masked_value, attention_mask=mask + ) + self.assertAllClose(output, output_with_manual_mask) + + def test_masks_are_cast_to_bool(self): + """Test that the implicit and explicit masks are cast to bool.""" + test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query) + masked_query._keras_mask = tf.cast(masked_query._keras_mask, tf.float32) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = keras.layers.Embedding(6, 8, mask_zero=True)(value) + masked_value._keras_mask = tf.cast(masked_value._keras_mask, tf.float32) + float_mask = tf.constant([[[1.0]]]) + # if all works well, the following should not raise any exception: + _ = test_layer( + query=masked_query, + value=masked_value, + use_causal_mask=True, + attention_mask=float_mask, + ) + class SubclassAttention(keras.layers.MultiHeadAttention): def _build_attention(self, qkv_rank):