
fix: set use_reentrant to True to fix Mixtral-7b bug #3928

Merged 1 commit into master on Feb 9, 2024

Conversation

@geoffreyangus (Collaborator) commented Feb 9, 2024

Got the following error when training Mixtral-7b in a multi-GPU setting:

tensor at position 96:
saved metadata: {'shape': torch.Size([142]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([143]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
tensor at position 97:
saved metadata: {'shape': torch.Size([142]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([143]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
...

This can be resolved by setting use_reentrant to True, which ensures that:

Reentrant checkpoint always recomputes function in its entirety during the backward pass.
Source: https://pytorch.org/docs/stable/checkpoint.html

So shape mismatches should be prevented.
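For reference, a minimal sketch of how the flag might be passed when enabling gradient checkpointing on a Hugging Face model; the model id and the exact plumbing are illustrative here, not the precise change made in this PR:

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative model id; this PR concerns Mixtral training in Ludwig.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    torch_dtype=torch.bfloat16,
)

# Reentrant checkpointing recomputes the wrapped function in its entirety
# during the backward pass, so the recomputed activations match the saved
# metadata and the shape-mismatch error above is avoided.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)
```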

@arnavgarg1 (Contributor) left a comment


Can we also bump the minimum torch version as a follow-up PR?

@geoffreyangus geoffreyangus merged commit ea890d9 into master Feb 9, 2024
14 of 17 checks passed
@geoffreyangus geoffreyangus deleted the use-reentrant-fix branch February 9, 2024 22:03