Skip to content

Commit

Permalink
Apply style
Browse files Browse the repository at this point in the history
  • Loading branch information
jplu committed Feb 4, 2021
1 parent bfea217 commit 502ae7e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
14 changes: 10 additions & 4 deletions src/transformers/models/led/modeling_tf_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def call(
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
masked_index = tf.tile(is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1))
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
attn_probs = tf.where(
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
Expand Down Expand Up @@ -325,7 +328,10 @@ def call(
# make sure that local attention probabilities are set to 0 for indices of global attn
# Make sure to create a mask with the proper shape:
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
masked_global_attn_index = tf.tile(is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1))
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
Expand Down Expand Up @@ -413,8 +419,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap)
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
Expand Down
24 changes: 15 additions & 9 deletions src/transformers/models/longformer/modeling_tf_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,9 @@ def call(
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn)

#print("before", attn_scores.shape)
#print("max_num_global_attn_indices", max_num_global_attn_indices.numpy())
#print("self.one_sided_attn_window_size", self.one_sided_attn_window_size)
# print("before", attn_scores.shape)
# print("max_num_global_attn_indices", max_num_global_attn_indices.numpy())
# print("self.one_sided_attn_window_size", self.one_sided_attn_window_size)
# this function is only relevant for global attention
attn_scores = tf.cond(
is_global_attn,
Expand All @@ -921,13 +921,16 @@ def call(
),
lambda: attn_scores,
)
#print("after", attn_scores.shape)
# print("after", attn_scores.shape)
attn_probs = tf.nn.softmax(attn_scores, axis=-1)

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
masked_index = tf.tile(is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1))
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
attn_probs = tf.where(
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
Expand Down Expand Up @@ -981,7 +984,10 @@ def call(
# make sure that local attention probabilities are set to 0 for indices of global attn
# Make sure to create a mask with the proper shape:
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
masked_global_attn_index = tf.tile(is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1))
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
Expand Down Expand Up @@ -1069,8 +1075,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap)
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
Expand Down Expand Up @@ -1358,7 +1364,7 @@ def _concat_with_global_key_attn_probs(
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)

return attn_scores

def _compute_attn_output_with_global_indices(
Expand Down

0 comments on commit 502ae7e

Please sign in to comment.