Incorrect attention mask computation #75
TL;DR: There is a small bug in the attention masking code. It should not practically affect anyone using the released model or training their own models (unless you're using a special attention mask scheme), but we will fix it soon in an update. Thanks for catching this!

Looks like a victim of an off-by-one error caused by `np.searchsorted`. In this section:

```python
def _get_position(i, tokens_per_elem):
    return np.searchsorted(np.cumsum(tokens_per_elem), i)
```

the correct code should have been:

```python
def _get_position(i, tokens_per_elem):
    return np.searchsorted(np.cumsum(tokens_per_elem), i, side='right')
```

I'll work to get this patched in at some point, but since fixing the issue will affect our currently released checkpoints (which were trained with the old, incorrect attention mask structure), doing so will require a little bit of tact and care. I'll keep you updated in this issue.

For others reading this issue:

For most people (if you are using the released model checkpoints, or if you are using our config for pretraining): there is a small bug in the attention mask generation that causes observation tokens to attend to the action readout; it should not affect your use cases. This isn't a fatal bug (e.g. there is no information leakage from future timesteps to current timesteps), but we will fix it in a future update.

If you are using `octo.model.BlockTransformer` with non-default attention masking strategies (not common): if you have multiple timestep groups, the bug causes the first token of the second group to be misclassified as belonging to the first group (similarly, the first token of the third group is misclassified as belonging to the second group, and so on). If your model relies on different timestep groups not being able to attend to each other (a pretty non-standard case), please make sure to incorporate the fix in #76 to avoid information leakage between timestep groups.
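To make the off-by-one concrete, here is a small self-contained sketch (the group sizes below are made up purely for illustration) showing how the default `side='left'` in `np.searchsorted` assigns the first token of each later group to the previous group, while `side='right'` assigns it to the correct one:

```python
import numpy as np

# Illustrative group sizes (made up for this example): three token groups
# containing 3, 2, and 4 tokens respectively.
tokens_per_elem = np.array([3, 2, 4])
boundaries = np.cumsum(tokens_per_elem)  # array([3, 5, 9])

def get_position_buggy(i, tokens_per_elem):
    # Default side='left': a token index that equals a group boundary is
    # assigned to the group *before* the boundary.
    return np.searchsorted(np.cumsum(tokens_per_elem), i)

def get_position_fixed(i, tokens_per_elem):
    # side='right': boundary tokens are assigned to the group they open.
    return np.searchsorted(np.cumsum(tokens_per_elem), i, side='right')

for i in range(int(boundaries[-1])):
    print(i, get_position_buggy(i, tokens_per_elem), get_position_fixed(i, tokens_per_elem))

# Token 3 is the first token of group 1, but the buggy version reports
# group 0; likewise token 5 opens group 2 but is reported as group 1.
assert get_position_buggy(3, tokens_per_elem) == 0
assert get_position_fixed(3, tokens_per_elem) == 1
```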
Thank you for the timely reply! I've incorporated your fix and I can confirm the attention mask computation is now correct.
@dibyaghosh do you think finetuning the pretrained "bugged" model with the correct mask would improve or deteriorate the results? Since it's just one token, my guess is it shouldn't make too much of a difference.
Especially for finetuning, my guess is that it shouldn't really make any difference, but these things can be hard to predict.
It is useful, thanks!
I found that the `generate_attention_mask` function in `block_transformer.py` seems to calculate the attention mask incorrectly. Here is an example:

[printed attention mask diagram not captured here]

According to the attention mask diagram above, the token at index 656 (in `t=1 obs_wrist`) should NOT attend to the token at index 657 (in `t=1 readout_action`). However, `attention_mask[656, 657]` is True. You can reproduce this using jdb. It seems that the `get_token_metadata` function doesn't assign tokens to the correct group.
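A minimal sketch of the intended rule being violated here, assuming a made-up single-timestep layout (the group sizes and indices below are not the real Octo token layout; only the rule that observation tokens must not attend to the action readout comes from the report above):

```python
import numpy as np

# Hypothetical single-timestep layout (NOT the real Octo layout): 2 primary
# observation tokens, 3 wrist observation tokens, 1 action readout token.
groups = ["obs_primary"] * 2 + ["obs_wrist"] * 3 + ["readout_action"]
n = len(groups)

# Intended rule: observation tokens must not attend to readout tokens
# (readouts may attend to observations, but not the other way around).
attention_mask = np.ones((n, n), dtype=bool)
for q, q_group in enumerate(groups):
    for k, k_group in enumerate(groups):
        if q_group.startswith("obs") and k_group.startswith("readout"):
            attention_mask[q, k] = False

# Analogue of the reporter's check of attention_mask[656, 657]: the last
# wrist-observation token (index 4) must not attend to the readout (index 5).
assert not attention_mask[4, 5]
```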