Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TF LED/Longformer attentions computation #10007

Merged
merged 9 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 37 additions & 17 deletions src/transformers/models/led/modeling_tf_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,26 @@ def call(
),
lambda: attn_scores,
)

attn_probs = tf.nn.softmax(attn_scores, axis=-1)

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand this change here. The correct shape is given by attn_probs -> I don't understand why we cannot just use shape_list(attn_probs)? IMO, something like:

attn_probs = tf.where(
     masked_index,
     tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype
)

should work, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's better to make the dtype dependent on the type of attn_probs I think

Copy link
Contributor Author

@jplu jplu Feb 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need here, the default dtype of tf.zeros is always float (float16 if AMP is activated, or float32 if not).

0.0,
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
attn_probs,
)

Expand Down Expand Up @@ -320,11 +333,23 @@ def call(
)

# make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concat Line 713.
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_global_attn_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
attn_probs,
)

Expand Down Expand Up @@ -408,14 +433,9 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
axis=1,
)
first_chunk_mask = (
tf.broadcast_to(
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
shape=(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap,
),
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
Expand Down Expand Up @@ -463,7 +483,7 @@ def _mask_invalid_locations(input_tensor, window_overlap):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

# broadcast to full matrix
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
Expand Down Expand Up @@ -807,7 +827,7 @@ def _compute_global_attn_output_from_hidden(
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

# mask global attn scores
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
Expand Down Expand Up @@ -1684,7 +1704,7 @@ def _pad_to_window_size(
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window

if tf.math.greater(padding_len, 0):
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
Expand Down
69 changes: 44 additions & 25 deletions src/transformers/models/longformer/modeling_tf_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
# bool attention mask with True in locations of global attention
attention_mask = tf.range(input_ids_shape[1])
attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
if before_sep_token is True:
attention_mask = tf.cast(
tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape),
tf.dtypes.int32,
)
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
else:
# last token is separation token and should not be counted and in the middle are two separation tokens
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
attention_mask = (
tf.cast(
tf.broadcast_to(attention_mask, input_ids_shape)
> tf.broadcast_to(question_end_index + 1, input_ids_shape),
attention_mask > question_end_index,
tf.dtypes.int32,
)
* tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
* tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
)

return attention_mask
Expand Down Expand Up @@ -784,13 +783,26 @@ def call(
),
lambda: attn_scores,
)

attn_probs = tf.nn.softmax(attn_scores, axis=-1)

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
0.0,
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
attn_probs,
)

Expand Down Expand Up @@ -839,11 +851,23 @@ def call(
)

# make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concat Line 713.
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_global_attn_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
attn_probs,
)

Expand Down Expand Up @@ -927,14 +951,9 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
axis=1,
)
first_chunk_mask = (
tf.broadcast_to(
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
shape=(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap,
),
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
Expand Down Expand Up @@ -982,7 +1001,7 @@ def _mask_invalid_locations(input_tensor, window_overlap):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

# broadcast to full matrix
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
Expand Down Expand Up @@ -1326,7 +1345,7 @@ def _compute_global_attn_output_from_hidden(
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

# mask global attn scores
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
Expand Down Expand Up @@ -1701,7 +1720,7 @@ def _pad_to_window_size(
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window

if tf.math.greater(padding_len, 0):
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
Expand Down
13 changes: 3 additions & 10 deletions tests/test_modeling_tf_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window` and one before and one after
self.key_length = self.attention_window + 1
self.key_length = self.attention_window + 2

# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
Expand Down Expand Up @@ -362,15 +362,8 @@ def test_xla_mode(self):
pass

def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
# Temporarily disable this test in order to find
# how to better handle it without timing out the CI
pass

@slow
Expand Down
13 changes: 3 additions & 10 deletions tests/test_modeling_tf_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,8 @@ def test_for_multiple_choice(self):

@slow
def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
# Temporarily disable this test in order to find
# how to better handle it without timing out the CI
pass

@slow
Expand All @@ -372,7 +365,7 @@ def test_mixed_precision(self):
pass

def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
# TODO JP: Make Longformer XLA compliant
pass


Expand Down