
Problem with reproducing "strided" attention scheme from the paper #7

Open
krishnadubba opened this issue May 21, 2019 · 3 comments

@krishnadubba

Hi,
I am trying to visualize the attention schemes using this code, basically trying to reproduce Fig. 3 from the paper. I could reproduce the "fixed" attention scheme, as shown below:

fixed_sparse_attn

The problem is that I could not reproduce the "strided" scheme (Fig. 3b from the paper). No matter what parameters I try, all I get is the following:

strided_wrong

If I change some code, I can get the correct "strided" version as shown in the paper. The following is the result after those changes:

strided_correct

Has anyone faced the same issue?

@pengfeiZhao1993

After reading attention.py, I find that this code base only contains separate implementations of the two halves of strided attention, referred to internally as the "first / second step of strided attention". So you probably need to implement an integrated version of strided attention yourself, with each head of a two-head sparse self-attention corresponding to one of those two steps.
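
A minimal sketch of that idea (my own illustration, not code from attention.py; `strided_step_masks` is a hypothetical helper):

```python
import tensorflow as tf

def strided_step_masks(n_tokens, stride=4):
    """Build the two 'steps' of strided attention as separate boolean masks,
    one per head, so a two-head sparse self-attention covers the full pattern."""
    q = tf.range(n_tokens)[:, None]   # query positions, shape [n, 1]
    k = tf.range(n_tokens)[None, :]   # key positions,   shape [1, n]
    causal = q >= k                   # standard causal constraint

    # Step 1: each query attends to the previous `stride` positions (local window).
    step1 = tf.logical_and(causal, (q - k) < stride)
    # Step 2: each query attends to positions that are a multiple of `stride` behind it.
    step2 = tf.logical_and(causal, tf.equal(tf.math.floormod(q - k, stride), 0))

    # Shape [2, n_tokens, n_tokens]: head 0 uses step 1, head 1 uses step 2.
    return tf.stack([step1, step2], axis=0)
```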

@benathi

benathi commented Dec 6, 2022

@krishnadubba Have you successfully implemented the strided version, by the way? Could you share the code change?

@jaindhairyahere

jaindhairyahere commented Jan 13, 2024

> @krishnadubba Have you successfully implemented the strided version, by the way? Could you share the code change?

I was able to reproduce both patterns using the function below:

```python
import tensorflow as tf

def sparse_attention_mask(n_tokens, stride_length=3, c=2):
    # Query / key position grids, shape [n_tokens, n_tokens].
    x = tf.reshape(tf.range(n_tokens), [n_tokens, 1])
    y = tf.transpose(x)
    z = tf.zeros((n_tokens, n_tokens), dtype=tf.int32)
    Q = z + x  # row index    = query position
    K = z + y  # column index = key position

    # Standard causal constraint: a query may only attend to itself and earlier keys.
    causal_attention_mask = Q >= K

    # "Fixed" pattern: attend within the same block of stride_length tokens,
    # plus the last c positions of every block.
    fixed_mask_1 = tf.equal(Q // stride_length, K // stride_length)
    fixed_mask_2 = tf.math.floormod(K, stride_length) >= stride_length - c
    combined_mask_fixed = tf.logical_and(causal_attention_mask,
                                         tf.logical_or(fixed_mask_1, fixed_mask_2))

    # "Strided" pattern: attend to the previous stride_length positions,
    # plus every stride_length-th position before that.
    stride_mask_1 = tf.less_equal(Q - K, stride_length)
    stride_mask_2 = tf.equal(tf.math.floormod(Q - K, stride_length), 0)
    combined_mask_stride = tf.logical_and(causal_attention_mask,
                                          tf.logical_or(stride_mask_1, stride_mask_2))

    return (tf.reshape(combined_mask_fixed, [1, 1, n_tokens, n_tokens]),
            tf.reshape(combined_mask_stride, [1, 1, n_tokens, n_tokens]))
```
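
A quick way to visualize the returned masks (parameter values here are arbitrary):

```python
import matplotlib.pyplot as plt

fixed_mask, strided_mask = sparse_attention_mask(n_tokens=24, stride_length=4, c=1)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(tf.cast(fixed_mask[0, 0], tf.int32).numpy(), cmap="gray")
axes[0].set_title("fixed")
axes[1].imshow(tf.cast(strided_mask[0, 0], tf.int32).numpy(), cmap="gray")
axes[1].set_title("strided")
plt.show()
```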
