[feat] Add context manager to FSDP for easier child module wrapping #446
Conversation
Thank you for your contribution! I have some comments below.
tests/nn/data_parallel/test_fsdp.py (outdated)

```python
layer = FullyShardedDataParallel.auto_wrap(torch.nn.Linear(5, 5))
assert isinstance(layer, torch.nn.Linear)

def test_auto_wrap_override_defaults(self):
```
also, is it worth testing nested wrapping like:

```
with FSDP.config_context():
    FSDP.wrap_if_in...
    with FSDP.config_context():
        FSDP.wrap_if_in...
```
For now maybe we can just raise an exception for nested configuration contexts? I don’t expect that to be needed for most users, and it could introduce weird corner cases.
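A minimal sketch of what raising on nested contexts could look like — `enable_auto_wrap` and the `FullyShardedAutoWrap` holder mirror names used later in this thread, but this is an illustration under those assumptions, not the final API:

```python
import contextlib

class FullyShardedAutoWrap:
    # Module-level state consulted by wrap()/auto_wrap().
    in_autowrap_context = False
    kwargs: dict = {}

@contextlib.contextmanager
def enable_auto_wrap(**wrap_defaults):
    # Reject nesting outright rather than defining merge semantics for it.
    if FullyShardedAutoWrap.in_autowrap_context:
        raise RuntimeError("nested enable_auto_wrap contexts are not supported")
    FullyShardedAutoWrap.in_autowrap_context = True
    FullyShardedAutoWrap.kwargs = wrap_defaults
    try:
        yield
    finally:
        FullyShardedAutoWrap.in_autowrap_context = False
        FullyShardedAutoWrap.kwargs = {}
```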
Thanks for doing this! First, I just want to clarify some important details about this API and the use case. This API will be a core API for end users (e.g., Lightning users, Fairseq users) to support initializing really large models, where it's critical to shard layers as they are initialized to avoid system OOM. Thus, even after we solve "auto-determining which parts to wrap", we will still need this API to avoid system OOMs.

Re: naming, since this will be the main user-facing interface for most framework users (and will be peppered all over people's model init code), we should make the name really short and easy to type, and not likely to cause too much reformatting when using black (line_length=88). And at most one underscore 😄 How about just

Re: context manager, maybe just
Thanks guys!! I'll address all the points ASAP @min-xu-ai, and I agree with your view @myleott! I'll change the names to reflect this.

I got distracted trying to auto-wrap recursively based on parameter sizes, and I think I nearly have it! Just some weirdness when precision=mixed (@myleott @min-xu-ai for the nested FSDP wrappers, should mixed precision be set to True on all of them or just the outer layer?). Probably some kinks to iron out, but it looks like this on top of what's been added already:

```python
@staticmethod
def wrap_module(x: nn.Module, max_params, **kwargs):
    num_params = sum(p.numel() for p in x.parameters())
    if len(list(x.named_children())) == 0:
        # If the module has no children, no need to recurse; wrap it if needed
        if num_params > max_params:
            return FullyShardedDataParallel.wrap(x, **kwargs), num_params
        return x, 0
    if num_params >= max_params:
        total_wrapped_params = 0
        # Iterate through the children, recursively wrapping if necessary
        for name, module in x.named_children():
            wrapped_module, num_wrapped_params = FullyShardedDataParallel.wrap_module(module, max_params, **kwargs)
            setattr(x, name, wrapped_module)
            # Keep track of how many parameters have been wrapped
            total_wrapped_params += num_wrapped_params
        # Decide if we need to wrap the current module, i.e. whether the
        # leftover (unwrapped) parameters still exceed the threshold
        remainder = num_params - total_wrapped_params
        if remainder >= max_params:
            return FullyShardedDataParallel.wrap(x, **kwargs), num_params
        else:
            return x, total_wrapped_params
    return x, 0

@staticmethod
def auto_wrap(module, recursive=False, num_params=1e8, **kwargs):
    if FullyShardedAutoWrap.in_autowrap_context and recursive:
        wrapped_module, remainder = FullyShardedDataParallel.wrap_module(
            module, max_params=num_params, **kwargs
        )
        return wrapped_module
    return FullyShardedDataParallel.wrap(module, **kwargs)

@staticmethod
def wrap(module, **kwargs):
    if FullyShardedAutoWrap.in_autowrap_context:
        kwargs = {**FullyShardedAutoWrap.kwargs, **kwargs}
        return FullyShardedDataParallel(module, **kwargs)
    return module
```

This lets you set a parameter-count threshold and wrap every module above it, trying to wrap the deepest children first. Not sure if it makes sense, and it probably needs a small visualization to go with it:

```python
with FSDP.enable_auto_wrap(**handful_of_important_params):
    layer_1 = FSDP.auto_wrap(TransformerBlock(self.config), max_params=1e8, recursive=True)
```
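As a stand-in for that visualization, here is a hedged, standalone toy that mirrors the thresholding logic above and just prints what would get wrapped (`plan_wrapping` and the toy model are invented for this illustration):

```python
import torch.nn as nn

def plan_wrapping(module: nn.Module, max_params: int, prefix: str = "root") -> int:
    """Mirror wrap_module's decisions, printing instead of wrapping."""
    num_params = sum(p.numel() for p in module.parameters())
    children = list(module.named_children())
    if not children:
        if num_params > max_params:
            print(f"wrap leaf {prefix} ({num_params} params)")
            return num_params
        return 0
    if num_params >= max_params:
        wrapped = sum(
            plan_wrapping(child, max_params, f"{prefix}.{name}")
            for name, child in children
        )
        if num_params - wrapped >= max_params:
            print(f"wrap {prefix} ({num_params - wrapped} leftover params)")
            return num_params
        return wrapped
    return 0

model = nn.Sequential(nn.Linear(1000, 1000), nn.Linear(1000, 1000), nn.Linear(1000, 10))
plan_wrapping(model, max_params=500_000)
# -> wraps root.0 and root.1 individually; the small head and the Sequential
#    itself stay unwrapped since the leftover params fall below the threshold.
```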
Thanks for the notes on the OOM use case. I think it should go into the docstring. Just exploring a bit here, I wonder if we can do away without a new function. Check out here:

Basically, we can have FSDP() return the unwrapped class if it is passed with a

or manual nesting:

In the second mode, we have the

Does this simplify the user-facing API? I think it will only need a single additional function called
if the class.`__new__` trick doesn't work, perhaps we can consider
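Purely as a sketch of that `__new__` idea (and assuming the trick works at all, per the hedging above), construct-time short-circuiting could look roughly like this — `FSDPish` and `_IN_WRAP_CONTEXT` are invented stand-ins, not the real class or flag:

```python
import torch.nn as nn

_IN_WRAP_CONTEXT = False  # would be toggled by the context manager

class FSDPish(nn.Module):
    def __new__(cls, module: nn.Module, **kwargs):
        if not _IN_WRAP_CONTEXT:
            # Returning a non-FSDPish object means __init__ is skipped,
            # so callers outside the context get the plain module back.
            return module
        return super().__new__(cls)

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        self.module = module

layer = FSDPish(nn.Linear(5, 10))
assert isinstance(layer, nn.Linear)  # a no-op outside the context
```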
@SeanNaren you are amazing!

I have used mixed precision and nesting; the inner wrappers and the outer one all have mixed_precision=True and it seems to work fine in VISSL.
I was eyeing this as a solution as well; we could also detect that we're in a context and do some magic, but I wasn't sure if we'd be up for that. At the minimum the current implementation works reasonably, but I am curious to look down this path!

Thanks @min-xu-ai :) :) I've pushed the code as well as addressed a few code review changes. I need to get it working with my small reproducible script with precision set to mixed, and then I'll be able to benchmark, but if you could get a chance to use it, that would be awesome!
Yeah, I'd be reluctant to adopt the more magical interface, mostly because it will make the overrides more opaque. Consider the following example:

```python
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = FSDP(nn.Linear(5, 10))  # does this imply the default options for FSDP, i.e., mixed_precision=False, flatten_parameters=True?

with nesting(mixed_precision=True, flatten_parameters=False):
    net = MyModule()  # will net.l1 have mixed precision? flatten parameters?
```

With a separate function it seems clearer that

One less import for user code 😄
Leaving a few suggestions, will start experimenting with this in the meantime
```python
yield

@staticmethod
def wrap(module, **kwargs):
```
let's rename `**kwargs` here to `**overrides` or `**fsdp_overrides` to better reflect the behavior. Also, please add a docstring:

```diff
-def wrap(module, **kwargs):
+def wrap(module, **fsdp_overrides):
+    """
+    Annotate that a module should be wrapped with FullyShardedDataParallel.
+
+    Annotated modules will only be wrapped if inside of an ``enable_wrap``
+    context manager. An important use case is annotating large layers that
+    should be sharded (in-place) during initialization, to avoid running out
+    of system memory.
+
+    Args:
+        module (nn.Module): module to wrap (if in ``enable_wrap`` context)
+        **fsdp_overrides: FSDP configuration overrides that will take
+            priority over the values provided by the ``enable_wrap`` context
+    """
```
for lines 1103-1104: what will happen if the outer wrap and the inner wrap both have the same override key but different values?
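For what it's worth, with the merge as written earlier (`kwargs = {**FullyShardedAutoWrap.kwargs, **kwargs}`), the inner per-call value wins, since later entries in a dict unpacking override earlier ones:

```python
# Plain-Python illustration of the merge semantics in wrap():
context_kwargs = {"mixed_precision": True, "flatten_parameters": True}  # from the context
call_kwargs = {"mixed_precision": False}                                # per-call override
merged = {**context_kwargs, **call_kwargs}
assert merged == {"mixed_precision": False, "flatten_parameters": True}
```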
I see. :-) It is a trade-off: it will result in longer lines and (potentially a lot) more typing within a file, after the saving on the imports; just saying. :-)
Ah, actually you are right. You mean something like:

```python
from fairscale.nn.data_parallel.fully_sharded_data_parallel import wrap

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = wrap(nn.Linear(5, 10))
```

That's even fewer chars, let's definitely support that 😄 Should be easy to support both if we want, although maybe it's simpler to just decouple it from the FSDP class as you originally suggested. @SeanNaren, thoughts?
We can put it in
@@ -0,0 +1,52 @@

```python
import torch.nn as nn
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a test here for recursive_wrap? It would be good to have a simple one for now; we can extend it as we see more use cases. It is tricky, since we perhaps don't want to wrap an inner module when the outer module reaches into the inner one during forward().
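For example (a hypothetical module, just to pin down the failure mode): wrapping `self.inner` below would be unsafe, because the parent bypasses the child's `forward()` and touches its parameters directly:

```python
import torch
import torch.nn as nn

class ReachesIntoChild(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)

    def forward(self, x):
        # Uses the child's weight/bias directly instead of calling it; an
        # FSDP wrapper that shards/flattens inner's params would break this.
        return x @ self.inner.weight.t() + self.inner.bias
```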
I added a very rudimentary test! I think currently I see `auto_wrap` as providing very little flexibility (it makes the assumptions a sequential module satisfies, i.e. it guarantees we're just calling the forward function), and `wrap` as the most flexible option.

The great thing is we can interchange/swap them as we see fit, i.e. `auto_wrap` the modules we know are supported and that do no child-module magic other than calling `forward`, and just use `wrap` for the cases where we'd not want to modify the child modules.
I think I understand the API. I've made the changes, and now it looks like:

```python
from fairscale.nn import enable_wrap, wrap, auto_wrap

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = wrap(nn.Linear(5, 10))

with enable_wrap():
    module = MyModule()
```

Let me know if this looks good and whether the code lives in the right places. I have a few kinks to iron out, as well as figuring out why my benchmark test fails with auto wrap in mixed precision. Thanks for the review, guys!
fairscale/nn/misc/auto_wrap.py (outdated)

```python
return module


def auto_wrap(module, min_num_params: float = 1e8, cls: Callable = FullyShardedDataParallel, **kwargs):
```
If this works, we're going to need a more sensible default here; I just chose this one more or less at random.
Awesome! Will review this shortly. CI is also complaining about some lint; it would be good to run the

Re: "why in Mixed Precision my benchmark test fails with auto wrap", is this referring to an external benchmark or a unit test? I can try to help debug if you can share more details.
Will do! Regarding reproducing, I'm using Will's minGPT Lightning repo to benchmark/test. If you don't mind installing Lightning:

```bash
git clone https://github.com/SeanNaren/minGPT.git && cd minGPT && git checkout fully_sharded
pip install git+https://github.com/PyTorchLightning/pytorch-lightning.git#feat/fsdp  # Install fork https://github.com/PyTorchLightning/pytorch-lightning/pull/6152
pip install git+https://github.com/SeanNaren/fairscale.git#feat/fsdp_context_manager
python benchmark.py --n_layer 1 --n_head 16 --n_embd 8192 --gpus 1 --limit_train_batches 120 --precision 16  # breaks
python benchmark.py --n_layer 1 --n_head 16 --n_embd 8192 --gpus 1 --limit_train_batches 120 --precision 32  # runs fine
```

The error from the first command:
@myleott I noticed that when we wrap children modules, we'll call this part of the code every time:

```python
if self.mixed_precision:
    args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
```

Do we want to do this on the children modules as well? In minGPT the inputs are of type
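One way around non-float inputs (a hypothetical helper sketched here, not fairscale's actual `cast_inputs_to_fp16`) would be to cast only floating-point tensors, letting integer inputs like token ids pass through untouched:

```python
import torch

def cast_floats_to_fp16(*args, **kwargs):
    def maybe_cast(x):
        # Leave non-tensors and integer tensors (e.g. token ids) untouched.
        if torch.is_tensor(x) and x.is_floating_point():
            return x.half()
        return x
    return tuple(maybe_cast(a) for a in args), {k: maybe_cast(v) for k, v in kwargs.items()}
```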
By the way, I have been running into a few other issues with auto-wrapping:

For (1), perhaps have a list of disallowed Module types to wrap? E.g., nn.ModuleList and nn.ModuleDict. Wrapping children of these is fine.
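A sketch of that suggestion, with `FORBIDDEN_WRAP_TYPES` and `recursive_wrap` as invented names: the container itself is never wrapped, but recursion still visits its children (the parameter-count threshold logic is omitted for brevity):

```python
import torch.nn as nn

FORBIDDEN_WRAP_TYPES = (nn.ModuleList, nn.ModuleDict)

def recursive_wrap(module: nn.Module, wrap_fn) -> nn.Module:
    for name, child in module.named_children():
        setattr(module, name, recursive_wrap(child, wrap_fn))
    if isinstance(module, FORBIDDEN_WRAP_TYPES):
        return module  # never wrap the container itself
    return wrap_fn(module)
```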
Nice! Thanks for testing this out; all of this makes sense. Blacklist certain modules that we cannot wrap, pass this list through to the recursive wrap function, and add a test? I'm thinking we'll have a blacklist in FairScale to handle most modules, and allow the user to add their own overrides if necessary? What do you think?

```python
with enable_wrap(autowrap_blacklist=[MyModuleThatShouldntBeWrapped, ...]):
    ...
```

Regarding (2), where in the docs should I add this? I think it would be nice to summarize all the info you added in the FSDP PR into a page; I could add the details there as well. If I get time I can do the first part too, let me know.

EDIT: Using the minGPT example I'm seeing better throughput/scaling (we can fit a much larger model without CPU offload; can't use CPU offload right now because the scaler doesn't work).
@sshleifer will be working on a more detailed README under fairscale/nn/data_parallel. For now we can put it in the auto_wrap docstring. I will push some small local changes to this branch.
Hey @myleott, I added blacklisting and will need to update the docstrings. We might be able to make the blacklisting cleaner, but let me know what you think!
Closer to what was discussed, I added an additional list to handle modules we should not wrap themselves but whose children we should still wrap, and modules we shouldn't wrap at all, including their children. This is

Let me know if the names need changing and if we'd like to expose the additional list. Thanks for your input here, Myle! Really appreciate the collab on this :)
Like this new version a lot so far. Just want to send the comment on the circular import first.
```python
if t.requires_grad:
    t.register_hook(_pre_backward_hook)
```
Just to confirm: without the `if`, do things still work, just with a useless callback registered? If not, perhaps add a comment?
cc @myleott who I think added this
No, the `if` is actually a bugfix for modules that return outputs which don't require grad. In that case PyTorch will raise an exception when you call `register_hook`. I only discovered it because `auto_wrap` sometimes puts FSDP on module boundaries I hadn't tested before :)
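A minimal repro of that exception, for anyone hitting it:

```python
import torch

t = torch.zeros(3)  # requires_grad=False by default
try:
    t.register_hook(lambda grad: grad)
except RuntimeError as e:
    print(e)  # "cannot register a hook on a tensor that doesn't require gradient"
```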
This is turning out really really good. I like it a lot. Thanks again for the contribution! Approving with only minor suggestions.
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Any suggestion on the default
100M seems reasonable for now. I tested this on the 210M param Attention Is All You Need transformer and it's actually fastest not to have any child wrapping, but 100M causes the encoder and decoder to get wrapped separately, which is fine too (almost the same performance, and it yields some memory savings).
LGTM too! @SeanNaren anything remaining or shall we merge this?
It's ready to go from my end! Thanks Myle :) If we can get a release sometime this week, that would be great too!
What does this PR do?
As discussed in Lightning-AI/pytorch-lightning#6152 (comment), this adds a context manager that assists in wrapping child modules with shared defaults. If not within the FSDP context, wrapping is a no-op. This makes it easier to annotate layers without having to duplicate configuration arguments at every call site.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃