Incorrect attention mask computation #75

Open
yushijinhun opened this issue Apr 15, 2024 · 5 comments

Comments

@yushijinhun

I found that the generate_attention_mask function in block_transformer.py seems to calculate the attention mask incorrectly. Here is an example:

[Attention mask diagram from the debug output attached]

According to the attention mask diagram printed above, the token at index 656 (in t=1 obs_wrist) should NOT attend to the token at index 657 (in t=1 readout_action). However, attention_mask[656, 657] is True. You can reproduce this using jdb. It seems that the get_token_metadata function doesn't determine which group each token belongs to correctly.

(jdb) l
> /home/yushijinhun/octo/octo/octo/model/components/block_transformer.py(325)
                    mask = int(metadata_i.should_attend_to(metadata_j))
                    attention_mask[i, j] = mask
    
            pad_attention_mask = self.generate_pad_attention_mask(
                prefix_groups, timestep_groups
            )
->          jax.debug.breakpoint()
            attention_mask = jnp.logical_and(attention_mask, pad_attention_mask)
            return attention_mask
    
(jdb) bt
Traceback:
  File "/home/yushijinhun/octo/octo-experiment/test.py", line 11
    actions = model.sample_actions(
  File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 187
    transformer_outputs = self.run_transformer(
  File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 152
    return self.module.apply(
  File "/home/yushijinhun/octo/octo/octo/model/octo_module.py", line 249
    prefix_outputs, timestep_outputs = BlockTransformer(self.transformer_kwargs)(
  File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 172
    attention_mask = self.generate_attention_mask(prefix_groups, timestep_groups)
  File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 325
    jax.debug.breakpoint()
(jdb) tokens_per_prefix_group
[16]
(jdb) tokens_per_timestep_group
[256, 64, 1]
(jdb) horizon
2
(jdb) tokens_for_prefix
16
(jdb) tokens_per_time_step
321
(jdb) total_tokens
658
(jdb) get_token_metadata(657)    #### <--- Token 657 should belong to group "t=1 readout_action", NOT "t=1 obs_wrist"
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) get_token_metadata(656)
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) attention_mask[657, 656]
1
(jdb) attention_mask[656, 657]    #### <--- This should be FALSE, group "t=1 obs_wrist" should NOT attend to group "t=1 readout_action"
1
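
For reference, the misclassification can be reproduced with plain NumPy using the sizes printed above. The helper below is only my sketch of the index-to-group lookup, not the actual get_token_metadata implementation:

    import numpy as np

    # Sizes taken from the jdb session above
    tokens_for_prefix = 16
    tokens_per_timestep_group = [256, 64, 1]   # presumably obs_primary, obs_wrist, readout_action
    tokens_per_time_step = sum(tokens_per_timestep_group)  # 321

    def group_of(i):
        # Sketch of the token-index -> (timestep, group) lookup as currently computed
        pos = i - tokens_for_prefix
        timestep, offset = divmod(pos, tokens_per_time_step)
        # np.searchsorted defaults to side='left', so an offset equal to a cumsum
        # boundary gets assigned to the previous group
        group = np.searchsorted(np.cumsum(tokens_per_timestep_group), offset)
        return timestep, group

    print(group_of(656))  # (1, 1): last t=1 obs_wrist token, correct
    print(group_of(657))  # (1, 1): also reported as obs_wrist, but this is the t=1 readout_action token
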
@dibyaghosh (Collaborator) commented Apr 15, 2024

TL;DR: There is a small bug in the attention masking code; this should not practically affect anyone using the released model or training their own models (unless you're doing some special attention mask scheme), but we will fix it soon in an update.

Thanks for catching this! Looks like an off-by-one error caused by np.searchsorted's default side='left' behavior.

In this section

        def _get_position(i, tokens_per_elem):
            return np.searchsorted(np.cumsum(tokens_per_elem), i)

the correct code should have been:

        def _get_position(i, tokens_per_elem):
            return np.searchsorted(np.cumsum(tokens_per_elem), i, side='right')
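
A minimal illustration of the difference, using the group sizes from the debug session above (within a timestep, offset 320 is the readout_action token, i.e. group 2):

    import numpy as np

    boundaries = np.cumsum([256, 64, 1])                   # [256, 320, 321]
    print(np.searchsorted(boundaries, 320))                # 1 -- current code: classified as obs_wrist
    print(np.searchsorted(boundaries, 320, side='right'))  # 2 -- with the fix: readout_action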

I'll work to get this patched in at some point, but since fixing the issue will affect our current released checkpoints (which were trained on the old incorrect attention mask structure), doing so will require a little bit of tact and care. I'll keep you updated in this issue.

For others reading the issue:

For most people (if you are using the released model checkpoints, or if you are using our config for pretraining): there is a small bug in the attention mask generation that causes observation tokens to attend to the action readout; it should not affect your use cases. This isn't a fatal bug (e.g. there's no information leakage from future timesteps to current timesteps), and we will fix it in a future release.

If you are using octo.model.BlockTransformer with non-default attention masking strategies (not common): if you have multiple timestep groups, the bug causes the first token in the second group to be misclassified as belonging to the first group (similarly, the first token of the third group is misclassified as belonging to group 2, and so on). If your model relies on different timestep groups not being able to attend to each other (a pretty non-standard case), please make sure to incorporate the fix in #76 to avoid information leakage between timestep groups.
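
To make the scope concrete, here is a quick check with made-up group sizes showing that only the first token of each non-initial group flips between the old and the fixed lookup:

    import numpy as np

    sizes = [4, 3, 2]                      # hypothetical timestep group sizes
    bounds = np.cumsum(sizes)              # [4, 7, 9]
    for offset in range(sum(sizes)):
        old = np.searchsorted(bounds, offset)                # side='left' (buggy)
        new = np.searchsorted(bounds, offset, side='right')  # fixed
        if old != new:
            print(offset, old, new)        # prints "4 0 1" and "7 1 2" only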

@yushijinhun (Author)

Thank you for the timely reply! I've incorporated your fix and I can confirm the attention mask computation is now correct.

@andrearosasco (Contributor) commented Apr 18, 2024

@dibyaghosh do you think finetuning the pretrained "bugged" model with the correct mask would improve or deteriorate the results? Since it's just one token, my guess is that it shouldn't make much of a difference.

@dibyaghosh (Collaborator)

Especially for finetuning, my guess is that it shouldn't make any difference really, but these things can be hard to predict.

@zwbx commented Jul 22, 2024

This is useful, thanks!
