Merge pull request #16619 from ageron:mha_automask
PiperOrigin-RevId: 460320823
tensorflower-gardener committed Jul 11, 2022
2 parents c48142d + 7605567 commit 645b361
Showing 4 changed files with 172 additions and 4 deletions.
@@ -157,7 +157,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\', \'use_causal_mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "compute_mask"
@@ -157,7 +157,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\', \'use_causal_mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "compute_mask"
109 changes: 107 additions & 2 deletions keras/layers/attention/multi_head_attention.py
@@ -217,6 +217,9 @@ class MultiHeadAttention(Layer):
training mode (adding dropout) or in inference mode (no dropout).
Defaults to either using the training mode of the parent layer/model,
or False (inference) if there is no parent layer.
use_causal_mask: A boolean to indicate whether to apply a causal mask to
prevent tokens from attending to future tokens (e.g., used in a decoder
Transformer).
Returns:
attention_output: The result of the computation, of shape `(B, T, E)`,
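The new `use_causal_mask` argument documented above can be exercised with a short, self-contained sketch (the layer sizes and tensor shapes below are illustrative, not taken from this PR):

```python
import keras
import tensorflow as tf

# Decoder-style self-attention: each position may attend only to itself
# and to earlier positions.
layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
x = tf.random.normal((4, 10, 32))  # (batch B, query length T, features E)

out = layer(query=x, value=x, use_causal_mask=True)
print(out.shape)  # (4, 10, 32) -- same as the query, i.e. (B, T, E)
```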
@@ -246,6 +249,7 @@ def __init__(
**kwargs
):
super().__init__(**kwargs)
self.supports_masking = True
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
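Setting `self.supports_masking = True` above means the layer now participates in Keras mask propagation: the query's mask is carried through to the attention output. A rough sketch of the effect (vocabulary size, embedding width, and head configuration here are made up):

```python
import keras
import tensorflow as tf

ids = tf.constant([[1, 2, 3, 0, 0]])  # 0 is the padding token
query = keras.layers.Embedding(10, 8, mask_zero=True)(ids)  # carries a Keras mask
value = tf.random.normal((1, 3, 8))

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
out = layer(query=query, value=value)
print(out._keras_mask)  # [[ True  True  True False False]] -- the query's mask
```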
@@ -449,7 +453,7 @@ def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product
customize attention computation to replace the default dot-product
attention.
Args:
@@ -502,7 +506,8 @@ def _compute_attention(
key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
attention to certain positions. It is generally not needed if the
`query` and `value` (and/or `key`) are masked.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
@@ -543,7 +548,16 @@ def call(
attention_mask=None,
return_attention_scores=False,
training=None,
use_causal_mask=False,
):
attention_mask = self._compute_attention_mask(
query,
value,
key=key,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
@@ -592,3 +606,94 @@ def call(
if return_attention_scores:
return attention_output, attention_scores
return attention_output

def _compute_attention_mask(
self, query, value, key=None, attention_mask=None, use_causal_mask=False
):
"""Computes the attention mask, using the Keras masks of the inputs.
* The `query`'s mask is reshaped from [B, T] to [B, T, 1].
* The `value`'s mask is reshaped from [B, S] to [B, 1, S].
* The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
mask is ignored if `key` is `None` or if `key is value`.
* If `use_causal_mask=True`, then the causal mask is computed. Its shape
is [1, T, S].
All defined masks are merged using a logical AND operation (`&`).
In general, if the `query` and `value` are masked, then there is no need
to define the `attention_mask`.
Args:
query: Query `Tensor` of shape `(B, T, ...)`.
value: Value `Tensor` of shape `(B, S, ...)`.
key: Key `Tensor` of shape `(B, S, ...)` (optional, defaults to `value`).
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
use_causal_mask: A boolean to indicate whether to apply a causal mask
to prevent tokens from attending to future tokens (e.g., used in a
decoder Transformer).
Returns:
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions, based on the Keras masks of the
`query`, `key`, `value`, and `attention_mask` tensors, and the
causal mask if `use_causal_mask=True`.
"""
query_mask = getattr(query, "_keras_mask", None)
value_mask = getattr(value, "_keras_mask", None)
key_mask = getattr(key, "_keras_mask", None)
auto_mask = None
if query_mask is not None:
query_mask = tf.cast(query_mask, tf.bool) # defensive casting
# B = batch size, T = max query length
auto_mask = query_mask[:, :, tf.newaxis] # shape is [B, T, 1]
if value_mask is not None:
value_mask = tf.cast(value_mask, tf.bool) # defensive casting
# B = batch size, S == max value length
mask = value_mask[:, tf.newaxis, :] # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if key_mask is not None:
key_mask = tf.cast(key_mask, tf.bool) # defensive casting
# B == batch size, S == max key length == max value length
mask = key_mask[:, tf.newaxis, :] # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if use_causal_mask:
# the shape of the causal mask is [1, T, S]
mask = self._compute_causal_mask(query, value)
auto_mask = mask if auto_mask is None else auto_mask & mask
if auto_mask is not None:
# merge attention_mask & automatic mask, to shape [B, T, S]
attention_mask = (
auto_mask
if attention_mask is None
else tf.cast(attention_mask, bool) & auto_mask
)
return attention_mask
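To make the broadcasting above concrete, here is a small standalone sketch of the same merge, using made-up padding masks (B=2, T=3, S=4):

```python
import tensorflow as tf

query_mask = tf.constant([[True, True, False],
                          [True, True, True]])           # [B, T]
value_mask = tf.constant([[True, True, True, False],
                          [True, False, False, False]])  # [B, S]

# Reshape to [B, T, 1] and [B, 1, S], then merge with a logical AND,
# mirroring what _compute_attention_mask does internally.
auto_mask = query_mask[:, :, tf.newaxis] & value_mask[:, tf.newaxis, :]
print(auto_mask.shape)  # (2, 3, 4), i.e. [B, T, S]
```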

def _compute_causal_mask(self, query, value=None):
"""Computes a causal mask (e.g., for masked self-attention layers).
For example, if query and value both contain sequences of length 4,
this function returns a boolean `Tensor` equal to:
```
[[[True, False, False, False],
[True, True, False, False],
[True, True, True, False],
[True, True, True, True]]]
```
Args:
query: query `Tensor` of shape `(B, T, ...)`.
value: value `Tensor` of shape `(B, S, ...)` (optional, defaults to
query).
Returns:
mask: a boolean `Tensor` of shape [1, T, S] containing a lower
triangular matrix of shape [T, S].
"""
q_seq_length = tf.shape(query)[1]
v_seq_length = q_seq_length if value is None else tf.shape(value)[1]
return tf.linalg.band_part( # creates a lower triangular matrix
tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0
)
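A quick way to sanity-check the `band_part` construction above (T = S = 4 chosen to match the docstring example):

```python
import tensorflow as tf

causal = tf.linalg.band_part(tf.ones((1, 4, 4), tf.bool), -1, 0)
print(causal.numpy()[0])
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]
```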
63 changes: 63 additions & 0 deletions keras/layers/attention/multi_head_attention_test.py
@@ -20,6 +20,7 @@

import keras
from keras.testing_infra import test_combinations
from keras.testing_infra import test_utils


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
@@ -328,6 +329,68 @@ def test_ragged_tensor(self, ragged_query, ragged_value, ragged_key):
results = test_layer(query, value, key)
self.assertAllEqual(results.shape.as_list(), query.shape.as_list())

def test_query_mask_propagation(self):
"""Test automatic propagation of the query's mask."""
test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
self.assertTrue(test_layer.supports_masking)
query = tf.constant([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query)
value = tf.random.normal((3, 3, 8))
output = test_layer(query=masked_query, value=value)
self.assertTrue(hasattr(output, "_keras_mask"))
self.assertAllEqual(masked_query._keras_mask, output._keras_mask)

@parameterized.named_parameters(("causal", True), ("not_causal", False))
@test_utils.run_v2_only
def test_value_mask(self, use_causal_mask):
"""Test that the value and causal masks are taken into account."""
test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
query = tf.constant([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query)
value = tf.constant([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = keras.layers.Embedding(6, 8, mask_zero=True)(value)
output = test_layer(
query=masked_query,
value=masked_value,
use_causal_mask=use_causal_mask,
)
mask = tf.constant(
[[[True, True, False]] * 3 + [[False, False, False]] * 2]
+ [[[True, False, False]] * 5]
+ [[[True, True, True]] + [[False, False, False]] * 4]
)
if use_causal_mask:
mask = mask & tf.constant(
[
[[True, False, False], [True, True, False]]
+ [[True, True, True]] * 3
]
)
del masked_query._keras_mask
del masked_value._keras_mask
output_with_manual_mask = test_layer(
query=masked_query, value=masked_value, attention_mask=mask
)
self.assertAllClose(output, output_with_manual_mask)

def test_masks_are_cast_to_bool(self):
"""Test that the implicit and explicit masks are cast to bool."""
test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query)
masked_query._keras_mask = tf.cast(masked_query._keras_mask, tf.float32)
value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = keras.layers.Embedding(6, 8, mask_zero=True)(value)
masked_value._keras_mask = tf.cast(masked_value._keras_mask, tf.float32)
float_mask = tf.constant([[[1.0]]])
# If everything works, the following should not raise an exception:
_ = test_layer(
query=masked_query,
value=masked_value,
use_causal_mask=True,
attention_mask=float_mask,
)


class SubclassAttention(keras.layers.MultiHeadAttention):
def _build_attention(self, qkv_rank):