[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296
Summary: Adds test coverage for `Float8Linear` with all dynamic scaling and FSDP2 with float8 all-gather. To make the tests pass, fixes a bug with initialization ordering in `Float8Linear.from_float`: the right forward config must be set before it is stashed on the weight wrapper.

Test Plan:
```
python test/test_fsdp2/test_fsdp2_eager.py
/test/test_everything.sh
```

Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
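The ordering bug described in the summary can be illustrated with a small, library-free sketch. Everything here is a hypothetical stand-in (`ForwardConfig`, `WeightWrapper`, `ToyLinear`, and the `use_fast_accum` field are illustration names, not the actual `Float8Linear` code); it only shows why a config stashed on a wrapper before the module's config is finalized ends up stale.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForwardConfig:
    # Hypothetical stand-in for the forward config; the field name is assumed.
    use_fast_accum: bool = False


class WeightWrapper:
    """Hypothetical weight wrapper that stashes a config at construction time."""

    def __init__(self, data, forward_config):
        self.data = data
        self.forward_config = forward_config


class ToyLinear:
    def __init__(self):
        self.forward_config = ForwardConfig()  # placeholder default
        self.weight = None

    @classmethod
    def from_float_buggy(cls, weight, use_fast_accum):
        new_mod = cls()
        # Bug: the wrapper captures the placeholder config ...
        new_mod.weight = WeightWrapper(weight, new_mod.forward_config)
        # ... and only afterwards is the real config set on the module.
        new_mod.forward_config = ForwardConfig(use_fast_accum=use_fast_accum)
        return new_mod

    @classmethod
    def from_float_fixed(cls, weight, use_fast_accum):
        new_mod = cls()
        # Fix: set the real config first, then stash it on the wrapper.
        new_mod.forward_config = ForwardConfig(use_fast_accum=use_fast_accum)
        new_mod.weight = WeightWrapper(weight, new_mod.forward_config)
        return new_mod
```

In the buggy variant the wrapper and the module disagree about the config; in the fixed variant they hold the same value.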
        )
        new_mod.weight = mod.weight
    else:
        assert not config.enable_fsdp_fp8_all_gather, "unsupported"
Nit: maybe a more helpful assert message
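One possible shape for a more descriptive message is sketched below. This assumes the `else` branch corresponds to the case where not all scaling is dynamic, which the diff excerpt does not show directly; `ToyConfig` and `check_fp8_all_gather_supported` are hypothetical names, and only the flag name `enable_fsdp_fp8_all_gather` comes from the diff.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Flag name taken from the diff; the class itself is a hypothetical stand-in.
    enable_fsdp_fp8_all_gather: bool = False


def check_fp8_all_gather_supported(config, all_dynamic):
    """Raise a descriptive AssertionError for the unsupported combination."""
    if not all_dynamic:
        # Assumed condition: float8 all-gather needs all-dynamic scaling.
        assert not config.enable_fsdp_fp8_all_gather, (
            "enable_fsdp_fp8_all_gather=True is only supported when all "
            "scaling is dynamic; disable it or switch to dynamic scaling"
        )
```

The message names the offending flag and the condition that would make it valid, so a user hitting the assert knows which setting to change.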
    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
    def swap_linear_with_dynamic(
Maybe I'm losing some context, but is there a reason why the existing swap function doesn't work?
If the question is why we need `swap_linear_with_dynamic`, we probably don't. Removing it is not related to this PR, though, so I left it for a future person.
        self._test_transformer_memory(enable_fsdp_fp8_all_gather)

    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
        # for enable_fsdp_fp8_all_gather in [False, True]:
Can remove the comment, right?
drisspg left a comment
Looks good. Maybe add a dummy test that `Float8Linear` without all-dynamic scaling errors out when trying to use fp8 all-gather.
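A dummy test along these lines might look roughly like the sketch below. It uses only `unittest` and a hypothetical stand-in (`toy_from_float`, `ToyConfig`) rather than the real conversion path, so it shows the test shape, not the repository's actual test; the assumption that the assert fires only when scaling is not all dynamic is inferred from the review discussion.

```python
import unittest
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Flag name from the diff; the class is a hypothetical stand-in.
    enable_fsdp_fp8_all_gather: bool = False


def toy_from_float(config, all_dynamic):
    """Hypothetical stand-in for the conversion path under test."""
    if not all_dynamic:
        assert not config.enable_fsdp_fp8_all_gather, "unsupported"
    return "converted"


class TestFp8AllGatherRequiresDynamic(unittest.TestCase):
    def test_non_dynamic_with_fp8_all_gather_errors(self):
        config = ToyConfig(enable_fsdp_fp8_all_gather=True)
        with self.assertRaises(AssertionError):
            toy_from_float(config, all_dynamic=False)

    def test_all_dynamic_with_fp8_all_gather_ok(self):
        config = ToyConfig(enable_fsdp_fp8_all_gather=True)
        self.assertEqual(toy_from_float(config, all_dynamic=True), "converted")
```

In the real test the stand-in would be replaced by the actual module conversion call with a non-dynamic scaling setup.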
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

This pull request has been merged in 412222b.
Stack from ghstack (oldest at bottom):

Summary:
Adds test coverage for `Float8Linear` with all dynamic scaling and FSDP2 with float8 all-gather.
To make the tests pass, fixes a bug with initialization ordering in `Float8Linear.from_float`: the right forward config must be set before it is stashed on the weight wrapper.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59305793