Thanks for your great work! We would like a more detailed explanation of the output: how can we visualize the attention weights over the inputs (both image and language)?
Thanks for your attention; we look forward to your kind response!
The general way to do this is to log the intermediates of the network computation, and then recompute the attention weights from them. Here's a brief sketch of how you would do this:
```python
import jax
import flax.linen as nn


@jax.jit
def get_attention_mask(model, observation, task):
    _, intermediates = model.module.apply(
        {'params': model.params},
        observation,
        task,
        observation['timestep_pad_mask'],
        train=False,
        method="octo_transformer",
        mutable=['intermediates'],
        capture_intermediates=True,
    )
    # `intermediates` holds literally the output of every submodule run in the NN.
    # As an example, let's get out the last Transformer MHA layer.
    outs = intermediates['intermediates']['octo_transformer']['BlockTransformer_0'][
        'Transformer_0']['encoderblock_11']['MultiHeadDotProductAttention_0']
    key = outs['key']['__call__']
    query = outs['query']['__call__']
    attention_weights = nn.dot_product_attention_weights(query, key)
    # Get the attention weights corresponding to the readout token.
    return attention_weights[..., -1, :]  # Shape (batch_size, # attention heads, # tokens)
```
Some notes:

- Always run `capture_intermediates=True` inside a `jax.jit`, so that it never actually materializes the output of every submodule, only the ones it has to return from the `jax.jit`.
- This should be easily extensible to whatever else you might want to monitor / log.
- You have to do a little handwork to figure out which tokens correspond to images / language / etc., but it shouldn't be terrible.
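For intuition about the last step, here is what `nn.dot_product_attention_weights` computes under the hood, written out by hand on random arrays standing in for the captured query/key intermediates (the shapes are assumptions, following flax's `(batch, tokens, heads, head_dim)` layout):

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for the captured query/key intermediates.
batch, tokens, heads, head_dim = 2, 16, 4, 8
kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (batch, tokens, heads, head_dim))
key = jax.random.normal(kk, (batch, tokens, heads, head_dim))

# Scaled dot-product attention logits: q.k / sqrt(head_dim),
# then a softmax over the key-token axis.
logits = jnp.einsum('bqhd,bkhd->bhqk', query, key) / jnp.sqrt(head_dim)
attention_weights = jax.nn.softmax(logits, axis=-1)
# Shape: (batch, heads, query_tokens, key_tokens).

# Attention from the readout token (assumed to be the last token)
# over all input tokens; each row is a distribution that sums to 1.
readout_attention = attention_weights[..., -1, :]
print(readout_attention.shape)  # (2, 4, 16)
```

Once you know which slice of the token axis holds the image-patch tokens and which holds the language tokens, you can split `readout_attention` accordingly and, for the image part, reshape it back to the patch grid to overlay it on the input image.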