How to visualise the attention weight of the inputs #106

Open
oym1994 opened this issue May 31, 2024 · 2 comments
Comments

oym1994 commented May 31, 2024

Hello,

Thanks for your great work. We would like more explanation of the outputs: how can we visualise the attention weights over the inputs (both image and language)?

Thanks for your attention; we look forward to your kind response!

@dibyaghosh (Collaborator)

The general way to do this is to capture the intermediates of the network computation and then recompute the attention weights from the saved queries and keys. Here's a brief sketch of how you would do this:

import jax
import flax.linen as nn

@jax.jit
def get_attention_mask(model, observation, task):
    # Run the transformer with capture_intermediates=True so Flax records
    # the output of every submodule in the 'intermediates' collection.
    _, intermediates = model.module.apply(
        {'params': model.params},
        observation,
        task,
        observation['timestep_pad_mask'],
        train=False,
        method="octo_transformer",
        mutable=['intermediates'],
        capture_intermediates=True,
    )
    # intermediates holds literally the output of every submodule run in the NN.
    # As an example, let's get out the last transformer block's MHA.
    outs = intermediates['intermediates']['octo_transformer']['BlockTransformer_0'][
        'Transformer_0']['encoderblock_11']['MultiHeadDotProductAttention_0']
    # Flax stores sown intermediates as tuples, hence the [0].
    key = outs['key']['__call__'][0]
    query = outs['query']['__call__'][0]
    # Recompute softmax(QK^T / sqrt(d)) from the captured queries and keys.
    attention_weights = nn.dot_product_attention_weights(query, key)
    # Get the attention weights corresponding to the readout token (last token).
    return attention_weights[..., -1, :]  # shape: (batch_size, # heads, # tokens)
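
For instance, you could call this and plot the per-token weights (a minimal usage sketch; model, observation, and task are assumed to be set up as in the Octo inference examples, and head-averaging with a bar plot is just one way to visualise):

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical setup: model, observation, and task built as in the Octo
# inference examples (e.g. OctoModel.load_pretrained + model.create_tasks).
weights = get_attention_mask(model, observation, task)  # (batch, heads, tokens)

# Average over attention heads and take the first batch element.
per_token = np.asarray(weights[0].mean(axis=0))  # (tokens,)

plt.bar(np.arange(per_token.shape[0]), per_token)
plt.xlabel("input token index")
plt.ylabel("readout-token attention weight")
plt.show()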

Some notes:

  1. Always run capture_intermediates=True inside a jax.jit, so that the outputs of every submodule are never actually materialized; only the values returned from the jitted function are.
  2. This should be easy to extend to whatever else you might want to monitor / log.
  3. You have to do a little handwork to figure out which tokens correspond to images / language / etc., but it shouldn't be terrible (see the sketch after these notes).
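
For note 3, here is a sketch of attributing the weights to modalities (continuing from the plotting sketch above; the token ordering and counts are hypothetical placeholders, so read the real ones from your model's tokenizer configuration):

# Hypothetical token layout along the token axis:
#   [task/language tokens | image-patch tokens | readout tokens]
# The counts below are placeholders; derive the real ones from the
# model's observation/task tokenizer configs.
NUM_LANG_TOKENS = 16
NUM_IMG_TOKENS = 256

lang_attn = per_token[:NUM_LANG_TOKENS].sum()
img_attn = per_token[NUM_LANG_TOKENS:NUM_LANG_TOKENS + NUM_IMG_TOKENS].sum()
print(f"attention on language: {lang_attn:.3f}, on images: {img_attn:.3f}")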


oym1994 commented Jun 1, 2024

Got it! Thanks for your kind response! I will try this right now.
