
[feat] Add context manager to FSDP for easier child module wrapping #446

Merged
25 commits merged into facebookresearch:master on Mar 2, 2021

Conversation

@SeanNaren commented on Feb 27, 2021

What does this PR do?

As discussed in Lightning-AI/pytorch-lightning#6152 (comment), this adds a context manager that assists in wrapping child modules with shared defaults.

from fairscale.nn.misc import enable_wrap, wrap

with enable_wrap(**handleful_of_important_params):
    layer_1 = wrap(torch.nn.Linear(5, 5))
    layer_2 = wrap(torch.nn.Linear(5, 5), flatten_parameters=True) # Override parameters if you'd like

...

# without the context manager, wrap is a no-op and just returns the Linear layer
layer_1 = wrap(torch.nn.Linear(5, 5))

If not within the enable_wrap context, wrap is a no-op. This makes it easier to annotate layers without having to duplicate parameter changes across call sites.
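For reference, a minimal sketch of how the context-driven no-op could be implemented (illustrative only; _wrap_enabled and _wrap_defaults are placeholder names, not the actual implementation in this PR):

import contextlib

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel

# Placeholder module-level state set by the context manager (sketch only).
_wrap_enabled = False
_wrap_defaults: dict = {}

@contextlib.contextmanager
def enable_wrap(**defaults):
    global _wrap_enabled, _wrap_defaults
    _wrap_enabled, _wrap_defaults = True, defaults
    try:
        yield
    finally:
        _wrap_enabled, _wrap_defaults = False, {}

def wrap(module: nn.Module, **overrides) -> nn.Module:
    if not _wrap_enabled:
        # Outside the context: no-op, return the module unchanged.
        return module
    # Inside the context: per-call overrides win over the context defaults.
    return FullyShardedDataParallel(module, **{**_wrap_defaults, **overrides})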

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label on Feb 27, 2021
@min-xu-ai (Contributor) left a comment

Thank you for your contribution! I have some comments below.

layer = FullyShardedDataParallel.auto_wrap(torch.nn.Linear(5, 5))
assert isinstance(layer, torch.nn.Linear)

def test_auto_wrap_override_defaults(self):
@min-xu-ai (Contributor) commented on Feb 27, 2021:

also, is it worth testing nested wrapping like:

with FSDP.config_context():
    FSDP.wrap_if_in...
    with FSDP.config_context():
        FSDP.wrap_if_in

@myleott (Contributor) commented on Feb 28, 2021:

For now maybe we can just raise an exception for nested configuration contexts? I don’t expect that to be needed for most users, and it could introduce weird corner cases.

@min-xu-ai changed the title from "Add context manager to FSDP for easier child module wrapping" to "[feat] Add context manager to FSDP for easier child module wrapping" on Feb 27, 2021
@myleott (Contributor) commented on Feb 28, 2021

Thanks for doing this!

First, I just want to clarify some important details about this API and the use case. This API will be a core API for end users (e.g., Lightning users, Fairseq users) to support initializing really large models, where it's critical to shard layers as they are initialized to avoid system OOM. Thus, even after we solve "auto-determining which parts to wrap", we will still need this API to avoid system OOMs.

Re: naming, since this will be the main user-facing interface for most framework users (and will be peppered all over people's model init code), we should make the name really short and easy to type, and not likely to cause too much reformatting when using black (line_length=88). And at most one underscore 😄

How about just FSDP.wrap? Then later if we introduce some other functionality to automatically wrap layers we still have the option to name that new functionality auto_wrap.

Re: context manager, maybe just with FSDP.enable_wrap(...)?
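Illustratively (assuming FSDP is just an alias for FullyShardedDataParallel, and with placeholder kwargs), usage under that spelling would read something like:

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Proposed spelling, not a final API: context defaults plus per-layer wrapping.
with FSDP.enable_wrap(mixed_precision=True, flatten_parameters=True):
    layer = FSDP.wrap(torch.nn.Linear(5, 5))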

@SeanNaren (Author) commented:
Thanks guys!! I'll address all the points ASAP, @min-xu-ai, and I agree with your view, @myleott! I'll change the names to reflect this.

I got distracted trying to auto-wrap recursively based on parameter sizes, and I think I nearly have it! Just some weirdness when precision=mixed (@myleott @min-xu-ai for the nested FSDP wrappers, should mixed precision be set to True on all of them or just the outer layer?)

Probably some kinks to iron out but it looks like this on top of what's been added already:

    @staticmethod
    def wrap_module(x: nn.Module, max_params, **kwargs):
        num_params = sum([p.numel() for p in x.parameters()])

        if len(list(x.named_children())) == 0:
            # If the module has no children, no need to recurse, wrap it if needed
            if num_params > max_params:
                return FullyShardedDataParallel.wrap(x, **kwargs), num_params
            return x, 0

        if num_params >= max_params:
            total_wrapped_params = 0
            # Iterate through the children, recursively wrap if necessary
            for name, module in x.named_children():
                wrapped_module, num_wrapped_params = FullyShardedDataParallel.wrap_module(module, max_params, **kwargs)
                setattr(x, name, wrapped_module)
                # Keep track of how many parameters have been wrapped
                total_wrapped_params += num_wrapped_params
            # Decide whether to wrap the current module itself, if the leftover (unwrapped) parameters still exceed the threshold
            remainder = num_params - total_wrapped_params
            if remainder >= max_params:
                return FullyShardedDataParallel.wrap(x, **kwargs), num_params
            else:
                return x, total_wrapped_params
        return x, 0

    @staticmethod
    def auto_wrap(module, recursive=False, num_params=1e8, **kwargs):
        if FullyShardedAutoWrap.in_autowrap_context and recursive:
            wrapped_module, remainder = FullyShardedDataParallel.wrap_module(
                module,
                max_params=num_params,
                **kwargs
            )
            return wrapped_module
        return FullyShardedDataParallel.wrap(module, **kwargs)

    @staticmethod
    def wrap(module, **kwargs):
        if FullyShardedAutoWrap.in_autowrap_context:
            kwargs = {**FullyShardedAutoWrap.kwargs, **kwargs}
            return FullyShardedDataParallel(module, **kwargs)
        return module

This will allow you to set a parameter-count threshold and wrap every module above it, trying to wrap the deepest children first. Not sure if it makes sense; it probably needs a small visualization to go with it.

with FSDP.enable_auto_wrap(**handleful_of_important_params):
    layer_1 = FSDP.auto_wrap(TransformerBlock(self.config), num_params=1e8, recursive=True)

@min-xu-ai (Contributor) commented:
Thanks for the notes on OOM use case. I think it should go into the docstring.

Just exploring a bit here: I wonder if we can do away without a new function. Check out:

https://stackoverflow.com/questions/3209233/how-to-replace-an-instance-in-init-with-a-different-object

Basically, we can have FSDP() return the unwrapped module if it is passed a passthrough=True option. Therefore, we have two modes:

  • normal mode, DDP replacement
m = FSDP(m, options ...)

or manual nesting:

m = FSDP(FSDP(m, options ...), options...)
  • maybe wrap mode
with nesting(options ...):
   FSDP(FSDP(m, overrides...), overrides...)

In the second mode, we have the nesting() context, which acts like a stack: the outer FSDP pushes its arguments onto the stack, and the inner one takes the top-of-stack options, unions them with its own overrides, and constructs the object. If passthrough=True is given, it just returns the underlying module without constructing a new one.

Does this simplify the user-facing API? I think it would only need a single additional function called nesting for the context. No new wrapper function is needed?
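Roughly, a sketch of one reading of this idea (nesting, _config_stack, and passthrough are hypothetical names, not code from this PR):

import contextlib

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel

_config_stack: list = []  # hypothetical stack of default options

@contextlib.contextmanager
def nesting(**defaults):
    _config_stack.append(defaults)
    try:
        yield
    finally:
        _config_stack.pop()

class FSDP:
    # Factory-style class: __new__ returns either a wrapped or the raw module,
    # so __init__ is never run on the returned object.
    def __new__(cls, module: nn.Module, passthrough: bool = False, **overrides):
        if passthrough:
            return module
        if _config_stack:
            # Union the top-of-stack defaults with this call's overrides.
            overrides = {**_config_stack[-1], **overrides}
        return FullyShardedDataParallel(module, **overrides)

This keeps the user-facing surface to FSDP() plus a single nesting() context, at the cost of some __new__ magic.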

@min-xu-ai (Contributor) commented:
How about just FSDP.wrap? Then later if we introduce some other functionality to automatically wrap layers we still have the option to name that new functionality auto_wrap.

Re: context manager, maybe just with FSDP.enable_wrap(...)?

If the class __new__ trick doesn't work, perhaps we could consider factory and enable_factory, since this is an object-factory pattern? Also, do these two functions need to be attached to the class?

@min-xu-ai (Contributor) commented:
@SeanNaren you are amazing!

for the nested FSDP wrappers, should mixed precision be set to True or just the outer layer?

I have used mixed precision with nesting, where the inner wrappers and the outer one all have mixed_precision=True, and it seems to work fine in VISSL.

@SeanNaren (Author) commented:

with nesting(options ...):
   FSDP(FSDP(m, overrides...), overrides...)

In the second mode, we have the nesting() context, which acts like a stack: the outer FSDP pushes its arguments onto the stack, and the inner one takes the top-of-stack options, unions them with its own overrides, and constructs the object. If passthrough=True is given, it just returns the underlying module without constructing a new one.

Does this simplify the user-facing API? I think it would only need a single additional function called nesting for the context. No new wrapper function is needed?

I was eyeing this as a solution as well; we could also detect that we're in a context and do some magic, but I wasn't sure if we'd be up for that. At minimum the current implementation works reasonably, but I'm curious to look down this path!

@SeanNaren you are amazing!

for the nested FSDP wrappers, should mixed precision be set to True or just the outer layer?

I have used mixed precision with nesting, where the inner wrappers and the outer one all have mixed_precision=True, and it seems to work fine in VISSL.

Thanks @min-xu-ai :) :)

I've pushed the code and addressed a few code review changes. I need to get it working with my small reproducible script with precision set to mixed, and then I'll be able to benchmark, but if you get a chance to try it that would be awesome!

@myleott (Contributor) commented on Feb 28, 2021:

with nesting(options ...):
   FSDP(FSDP(m, overrides...), overrides...)

In the second mode, we have the nesting() context, which acts like a stack: the outer FSDP pushes its arguments onto the stack, and the inner one takes the top-of-stack options, unions them with its own overrides, and constructs the object. If passthrough=True is given, it just returns the underlying module without constructing a new one.

Does this simplify the user-facing API? I think it would only need a single additional function called nesting for the context. No new wrapper function is needed?

I was eyeing this as a solution as well; we could also detect that we're in a context and do some magic, but I wasn't sure if we'd be up for that.

Yeah, I'd be reluctant to adopt the more magical interface, mostly because it will make the overrides more opaque. Consider the following example:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = FSDP(nn.Linear(5, 10))  # does this imply the default options for FSDP, i.e., mixed_precision=False, flatten_parameters=True?

with nesting(mixed_precision=True, flatten_parameters=False):
    net = MyModule()  # will net.l1 have mixed precision? flatten parameters?

With a separate function it seems clearer that wrap(**overrides) only applies overrides to the config.

do these two functions need to be attached to the class?

One less import for user code 😄

@myleott (Contributor) left a comment

Leaving a few suggestions, will start experimenting with this in the meantime

yield

@staticmethod
def wrap(module, **kwargs):
@myleott (Contributor) commented on Feb 28, 2021:

let's rename **kwargs here to **overrides or **fsdp_overrides to better reflect the behavior.

Also please add docstring:

Suggested change

-def wrap(module, **kwargs):
+def wrap(module, **fsdp_overrides):
+    """
+    Annotate that a module should be wrapped with FullyShardedDataParallel.
+    Annotated modules will only be wrapped if inside of an ``enable_wrap``
+    context manager. An important use case is annotating large layers that
+    should be sharded (in-place) during initialization, to avoid running out
+    of system memory.
+
+    Args:
+        module (nn.Module): module to wrap (if in ``enable_wrap`` context)
+        **fsdp_overrides: FSDP configuration overrides that will take
+            priority over the values provided by the ``enable_wrap`` context
+    """

Contributor commented:

For lines 1103-1104, what will happen if the outer wrap and the inner wrap both have the same override key but different values?
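(For reference: if lines 1103-1104 are the {**context_kwargs, **kwargs}-style dict merge shown in wrap above, the per-call value wins, because later keys overwrite earlier ones; a minimal illustration with placeholder values:)

# Python dict merge: later keys win, so per-call overrides beat the context defaults.
outer_kwargs = {"mixed_precision": True, "flatten_parameters": True}
inner_overrides = {"flatten_parameters": False}

merged = {**outer_kwargs, **inner_overrides}
assert merged == {"mixed_precision": True, "flatten_parameters": False}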

@min-xu-ai (Contributor) commented:

One less import for user code 😄

I see. :-) It is a trade-off: it will result in longer lines and (potentially a lot) more typing throughout a file, after the saving on the imports; just saying. :-)

@myleott (Contributor) commented on Feb 28, 2021:

I see. :-) It is a trade-off: it will result in longer lines and (potentially a lot) more typing throughout a file, after the saving on the imports; just saying. :-)

Ah, actually you are right. You mean something like:

from fairscale.nn.data_parallel.fully_sharded_data_parallel import wrap

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = wrap(nn.Linear(5, 10))

That's even fewer characters, let's definitely support that 😄 Should be easy to support both if we want, although maybe it's simpler to just decouple it from the FSDP class as you originally suggested. @SeanNaren, thoughts?

@min-xu-ai (Contributor) commented on Feb 28, 2021:

from fairscale.nn.data_parallel.fully_sharded_data_parallel import wrap

We can put it in __init__.py, so it will be

from fairscale.nn.data_parallel import wrap
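i.e. a one-line re-export in the package __init__.py (a sketch; the exact symbol list is still to be decided):

# fairscale/nn/data_parallel/__init__.py (sketch)
from .fully_sharded_data_parallel import FullyShardedDataParallel, enable_wrap, wrap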

@@ -0,0 +1,52 @@
import torch.nn as nn
Contributor commented:

Is there a test here for recursive_wrap? It would be good to have a simple one for now; we can extend it as we see more use cases. It is tricky since we probably don't want to wrap an inner module when the outer module reaches into it during forward().

@SeanNaren (Author) commented:

I added a very rudimentary test! Currently I see auto_wrap as providing very little flexibility (it makes the assumptions a sequential module does, i.e. it guarantees we only call the forward function), while wrap is the most flexible.

The great thing is we can swap them as we see fit: auto_wrap the modules we know are supported and that do no child-module magic other than calling forward, and just use wrap for the cases where we don't want to modify the child modules.

@SeanNaren (Author) commented:

I see. :-) It is a trade-off: it will result in longer lines and (potentially a lot) more typing throughout a file, after the saving on the imports; just saying. :-)

Ah, actually you are right. You mean something like:

from fairscale.nn.data_parallel.fully_sharded_data_parallel import wrap

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = wrap(nn.Linear(5, 10))

That's even fewer characters, let's definitely support that 😄 Should be easy to support both if we want, although maybe it's simpler to just decouple it from the FSDP class as you originally suggested. @SeanNaren, thoughts?

I think I understand the API; I've made the changes and it now looks like this:

from fairscale.nn import enable_wrap, wrap, auto_wrap

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = wrap(nn.Linear(5, 10))

with enable_wrap():
    module = MyModule()

Let me know if this looks good and whether the code lives in the right places. I have a few kinks to iron out, as well as figuring out why my benchmark test fails with auto_wrap in mixed precision.

Thanks for the review guys!

return module


def auto_wrap(module, min_num_params: float = 1e8, cls: Callable = FullyShardedDataParallel, **kwargs):
@SeanNaren (Author) commented:

If this works, we're going to need a more sensible default here; I chose this one somewhat at random.

@myleott (Contributor) commented on Mar 1, 2021:

Awesome! Will review this shortly. CI is also complaining about some lint; it would be good to run the alias that @min-xu-ai shared above.

Re: “why in Mixed Precision my benchmark test fails with auto wrap” is this referring to an external benchmark or a unit test? I can try to help debug if you can share more details

@SeanNaren (Author) commented on Mar 1, 2021:

Awesome! Will review this shortly. CI is also complaining about some lint; it would be good to run the alias that @min-xu-ai shared above.

Re: “why in Mixed Precision my benchmark test fails with auto wrap” is this referring to an external benchmark or a unit test? I can try to help debug if you can share more details

Will do!

To reproduce, I'm using Will's minGPT Lightning repo to benchmark/test. If you don't mind installing Lightning:

git clone https://github.com/SeanNaren/minGPT.git && cd minGPT && git checkout fully_sharded
pip install git+https://github.com/PyTorchLightning/pytorch-lightning.git#feat/fsdp # Install fork https://github.com/PyTorchLightning/pytorch-lightning/pull/6152
pip install git+https://github.com/SeanNaren/fairscale.git#feat/fsdp_context_manager

python benchmark.py --n_layer 1 --n_head 16 --n_embd 8192 --gpus 1 --limit_train_batches 120 --precision 16 # breaks
python benchmark.py --n_layer 1 --n_head 16 --n_embd 8192 --gpus 1 --limit_train_batches 120 --precision 32 # runs fine

The error from the first command:

Traceback (most recent call last):
  File "benchmark.py", line 96, in <module>
    trainer.fit(model, train_loader)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 489, in fit
    self.dispatch()
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 529, in dispatch
    self.accelerator.start_training(self)
  File "/home/sean/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 82, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 118, in start_training
    self._results = trainer.run_train()
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 612, in run_train
    self.train_loop.run_training_epoch()
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 491, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 652, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 424, in optimizer_step
    model_ref.optimizer_step(
  File "/home/sean/pytorch-lightning/pytorch_lightning/core/lightning.py", line 1395, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/sean/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 219, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/sean/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 135, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/sean/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 285, in optimizer_step
    make_optimizer_step = self.precision_plugin.pre_optimizer_step(
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/precision/native_amp.py", line 78, in pre_optimizer_step
    lambda_closure()
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 646, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 747, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 776, in backward
    result.closure_loss = self.trainer.accelerator.backward(
  File "/home/sean/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 268, in backward
    output = self.precision_plugin.backward(
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/precision/native_amp.py", line 59, in backward
    closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/precision/precision_plugin.py", line 71, in backward
    model.backward(closure_loss, optimizer, opt_idx)
  File "/home/sean/pytorch-lightning/pytorch_lightning/core/lightning.py", line 1259, in backward
    loss.backward(*args, **kwargs)
  File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f278e0788a0> returned NULL without setting an error

@SeanNaren (Author) commented:

@myleott I noticed that when we wrap child modules we'll call this part of the code every time:

        if self.mixed_precision:
            args, kwargs = cast_inputs_to_fp16(*args, **kwargs)

Do we want to do this on the child modules as well? In minGPT the inputs are of type torch.long so they won't be converted, but the child modules will receive floats, I think.

@SeanNaren (Author) commented:

Thanks @myleott, this fixes the error I was running into; I made a PR with your suggestion: #452

@myleott (Contributor) commented on Mar 1, 2021:

I have been running into a few other issues with auto-wrapping btw:

  1. wrapping an nn.ModuleList breaks because FSDP doesn't implement __len__ or __iter__. This is a bit tricky since for __iter__ we'd actually want to expand the params before iterating, since we'll never call the FSDP.forward method.
  2. shared params can be tricky. For example, an encoder-decoder architecture with shared embeddings on the encoder input, decoder input, and decoder output. In this case, wrapping the encoder and decoder separately can break sharing. I think we should ideally exclude shared params from auto_wrap and require these to be handled by the outer-most FSDP wrapper.
  3. certain modules should be forced to be leafs and their children should not be wrapped. For example, one should not wrap the nn.Linear inside nn.MultiheadAttention, since it will break because MHA.forward accesses self.out_proj.weight.

For (1), perhaps have a list of disallowed Module types to wrap? e.g., nn.ModuleList, nn.ModuleDict. Wrapping children of these is fine.
For (2), we could detect shared params and remove them from auto_wrap, but for now maybe just a note in the docs saying this doesn't work well with shared params is fine.
For (3), we can ideally have some Module types that are forced to be leafs, e.g., nn.MultiheadAttention
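A rough sketch of how the recursive wrapper could consult such lists (EXCLUDE_WRAP and FORCE_LEAF are placeholder names, not the final constants, and the parameter accounting is simplified compared to the remainder logic earlier in the thread):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel

EXCLUDE_WRAP = (nn.ModuleList, nn.ModuleDict)  # never wrap these themselves; wrapping their children is fine
FORCE_LEAF = (nn.MultiheadAttention,)          # treat as a unit; never wrap their children

def recursive_wrap(module: nn.Module, min_num_params: float = 1e8, **fsdp_kwargs) -> nn.Module:
    # NOTE: shared params (point 2) are not handled in this sketch.
    if not isinstance(module, FORCE_LEAF):
        # Recurse first so the deepest eligible children get wrapped.
        for name, child in module.named_children():
            setattr(module, name, recursive_wrap(child, min_num_params, **fsdp_kwargs))
    if isinstance(module, EXCLUDE_WRAP):
        return module
    num_params = sum(p.numel() for p in module.parameters())
    if num_params >= min_num_params:
        return FullyShardedDataParallel(module, **fsdp_kwargs)
    return module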

@SeanNaren (Author) commented on Mar 1, 2021:

Nice! Thanks for testing this out, all this makes sense.

Blacklist certain modules that we cannot wrap, pass this list through to the recursive wrap function, and add a test?

I'm thinking we'll have a blacklist in FairScale to handle most modules, and allow the user to add their own overrides if necessary? What do you think?

with enable_wrap(autowrap_blacklist=[MyModuleThatShouldntBeWrapped, ...]):
    ...

Regarding (2), where in the docs should I add this? I think it would be nice to summarize all the info you added in the FSDP PR into a page; I could add the details there as well. If I get time I can do the first part too, let me know.

EDIT: Using the minGPT example I'm seeing better throughput/scaling (I can fit a much larger model without CPU offload; I can't use CPU offload right now because the scaler doesn't work).

@myleott (Contributor) commented on Mar 1, 2021:

Regarding (2), where in the docs should I add this? I think it would be nice to summarize all the info you added in the FSDP PR into a page; I could add the details there as well. If I get time I can do the first part too, let me know.

@sshleifer will be working on a more detailed README under fairscale/nn/data_parallel.

For now we can put it in the auto_wrap docstring. I will push some small local changes to this branch.

@myleott self-requested a review on March 1, 2021 16:23
@SeanNaren (Author) commented:
Hey @myleott, I added blacklisting and will need to update the docstrings. We might be able to do the blacklisting more cleanly, but let me know what you think!

@SeanNaren (Author) commented:

Closer to what was discussed, I added an additional list to handle modules we should not wrap but whose children we should still wrap, and modules we shouldn't wrap at all, including their children. These are FSDP_MODULE_EXCLUDE_WRAP and FSDP_MODULE_BLOCKLIST respectively.

Let me know if the names need changing and whether we'd like to expose the additional list. Thanks for your input here, Myle! I really appreciate the collaboration on this :)

@min-xu-ai (Contributor) left a comment

Like this new version a lot so far. Just want to send the comment on the circular import first.

Comment on lines +796 to +797
if t.requires_grad:
t.register_hook(_pre_backward_hook)
Contributor commented:

Just to confirm: without the if, do things still work, with just an unnecessary callback registered? If not, perhaps add a comment?

@SeanNaren (Author) commented:

cc @myleott who I think added this

@myleott (Contributor) commented on Mar 2, 2021:

No, the if is actually a bugfix for modules that return outputs which don't require grad. In this case pytorch will raise an exception when you call register_hook. I only discovered it because auto_wrap sometimes puts FSDP on module boundaries I hadn't tested before :)

@min-xu-ai (Contributor) left a comment

This is turning out really really good. I like it a lot. Thanks again for the contribution! Approving with only minor suggestions.

@SeanNaren (Author) commented:

Any suggestion on default min_num_params? I set it to 100M as default which seemed reasonable, but was curious if you guys had any intuition around this.

@myleott (Contributor) commented on Mar 2, 2021:

Any suggestion on default min_num_params? I set it to 100M as default which seemed reasonable, but was curious if you guys had any intuition around this.

100M seems reasonable for now. I tested this on the 210M param Attention Is All You Need transformer and it's actually fastest not to have any child wrapping, but 100M causes the encoder and decoder to get wrapped separately, which is fine too (almost the same performance, and yields some memory savings).

@myleott (Contributor) left a comment

LGTM too! @SeanNaren anything remaining or shall we merge this?

@SeanNaren (Author) commented on Mar 2, 2021:

It's ready to go from my end! Thanks Myle :)

If we can get a release sometime this week that would be great too!

@myleott merged commit f335955 into facebookresearch:master on Mar 2, 2021
@SeanNaren deleted the feat/fsdp_context_manager branch on March 2, 2021 15:52