[FSDP] fix for high GPU reserved memory (v2) #1052

Open, wants to merge 13 commits into main
Conversation

ruanslv
Contributor

@ruanslv ruanslv commented Aug 3, 2022

v2 for #972

The general idea is to guess the FSDP module execution order and store it in two lists (one for the forward pass and one for the backward pass) during the first pass. Then, instead of letting the CPU run free and schedule all GPU operations ahead of time (reserving GPU memory for each), we only schedule the all-gather for the next module and wait until computation for the current module has finished before continuing. Scheduling the all-gather for the next module is meant to keep the data transfer and compute streams running in parallel.
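The bookkeeping described above could look roughly like the following. This is a minimal sketch, not the code in this PR: `ExecutionOrderTracker`, `record`, `finish_first_pass`, and `next_to_prefetch` are hypothetical names, and modules are represented by plain strings so the example runs without a GPU.

```python
from typing import List, Optional


class ExecutionOrderTracker:
    """Records the order in which modules run during the first pass.

    Hypothetical helper for illustration only; not part of fairscale.
    """

    def __init__(self) -> None:
        self.forward_order: List[str] = []
        self.backward_order: List[str] = []
        self.first_pass_done = False

    def record(self, name: str, backward: bool = False) -> None:
        # Only the first pass defines the guessed order.
        if self.first_pass_done:
            return
        order = self.backward_order if backward else self.forward_order
        if name not in order:
            order.append(name)

    def finish_first_pass(self) -> None:
        self.first_pass_done = True

    def next_to_prefetch(self, name: str, backward: bool = False) -> Optional[str]:
        """Return the module whose all-gather should be scheduled next."""
        order = self.backward_order if backward else self.forward_order
        if not self.first_pass_done or name not in order:
            return None  # no guess available yet
        idx = order.index(name) + 1
        return order[idx] if idx < len(order) else None


tracker = ExecutionOrderTracker()
for name in ["embed", "layer0", "layer1", "head"]:
    tracker.record(name)
for name in ["head", "layer1", "layer0", "embed"]:
    tracker.record(name, backward=True)
tracker.finish_first_pass()

print(tracker.next_to_prefetch("layer0"))                 # layer1
print(tracker.next_to_prefetch("layer0", backward=True))  # embed
print(tracker.next_to_prefetch("head"))                   # None
```

If the execution order changes between passes, the recorded lists go stale and the lookup returns the wrong module.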

Given the lack of a global/static execution graph in PyTorch, my understanding is that the best we can do here is a heuristic based on local information. PyTorch Distributed is facing a similar issue and, since there is no perfect solution, the arguments for or against each approach largely depend on the models you pick to measure success.

Advantages of current approach:
(1) Based on several comments in the original PR, it has been shown to significantly help different large-scale runs that are memory bound. I'm attaching some profiles for a ~370M param transformer showing the new behaviour. As long as we get the execution order right and there is not a lot of variance in how long each module takes to run, the parallelism across compute and data transfer is maintained.
(2) We only wait until computations are finished for the modules where reshard_after_forward = True, on the assumption that if this is not set then memory is not a limitation and we should let the CPU continue. This should help prevent side effects on models that do not care about memory.

Disadvantages:
(1) We may not always schedule the correct all-gather, which can cause run delays. One example is with activation checkpointing, where the execution needs to go further back in the model and execute a forward to recompute activations. The first module in this forward is not all-gathered ahead of time. Another example is if your execution order changes across different passes. So it is theoretically possible for this PR to cause performance degradations in some scenarios.
(2) We may under-schedule all-gathers, which can also cause performance degradation. This could happen on models where execution time varies widely across modules, and we end up having to wait for a long all-gather to finish after running a module that had a very short computation time.
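To make disadvantage (2) concrete, here is a toy timeline model. It is a sketch under loud assumptions, not a measurement: the all-gather for module i + 1 is issued exactly when module i starts computing, nothing else competes for the communication stream, and all durations are made-up numbers in milliseconds. The function name is hypothetical.

```python
from typing import List, Sequence


def stalls_with_one_step_lookahead(
    compute_ms: Sequence[float], all_gather_ms: Sequence[float]
) -> List[float]:
    """Per-module time the compute stream sits idle waiting for parameters.

    Toy model (an assumption, not measured behaviour): the all-gather for
    module i + 1 is issued exactly when module i starts computing.
    """
    assert len(compute_ms) == len(all_gather_ms)
    stalls = [all_gather_ms[0]]  # nothing hides the first all-gather
    compute_start = all_gather_ms[0]
    for i in range(len(compute_ms) - 1):
        compute_end = compute_start + compute_ms[i]
        gather_end = compute_start + all_gather_ms[i + 1]
        next_start = max(compute_end, gather_end)
        stalls.append(next_start - compute_end)
        compute_start = next_start
    return stalls


# Uniform modules: every all-gather hides behind the previous compute.
uniform = stalls_with_one_step_lookahead([10, 10, 10, 10], [4, 4, 4, 4])
# A 1 ms module followed by an 8 ms all-gather leaves compute idle for 7 ms.
skewed = stalls_with_one_step_lookahead([10, 1, 10, 1], [4, 4, 8, 4])
print(uniform)  # [4, 0, 0, 0]
print(skewed)   # [4, 0, 7, 0]
```

In this model the stall before module i + 1 is simply `max(0, all_gather[i + 1] - compute[i])`, which is why large variance between adjacent modules hurts.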

One alternative heuristic proposed by Min was to use the amount of available memory to decide whether to wait for the current module to finish or continue execution. However, I found it hard to find a justifiable memory threshold (either relative or absolute) that would work well across a large variety of cases, especially given the lack of a comprehensive benchmark to experiment with. Given we've already seen examples of the approach in this PR working well in practice, it seems safer to go with this route and reevaluate if we find evidence to the contrary.

Original behavior: the computation stream is always active; however, the CPU schedules everything at once.
Screenshot 2022-08-03 at 18 38 43

New behavior: computation stream still mostly active, with CPU scheduling one module at a time.
Screenshot 2022-08-03 at 18 38 16

@facebook-github-bot facebook-github-bot added the CLA Signed label Aug 3, 2022
@min-xu-ai
Contributor

@ruanslv, thanks for the detailed description and explanation in this new PR! Kudos for adding a new test.

I haven't looked at the code changes in detail, but can you say a bit about the changes in the overlap FSDP test? That test was put in specifically to catch a performance issue. With the changes you made in that test, does it still catch the performance issue it was intended to catch? You can "git log" on the file and find out which issue it was introduced to address. There was a bug fix together with the test, I believe. You can probably undo the fix and see whether the test catches the issue.

@ruanslv
Contributor Author

ruanslv commented Aug 4, 2022

Sorry, I had missed the SSD offload tests as they only run on a later PyTorch version than the one I was using. Fixed them with the last commit. The remaining CI failures are unrelated.

@ruanslv
Contributor Author

ruanslv commented Aug 4, 2022

@min-xu-ai re: the overlap test, in the deleted part it was measuring that CPU is running ahead of GPU, which is exactly what we are disabling in this PR so that behaviour is not expected anymore. Further down, the test still measures whether all-gather and compute are running in parallel, which still should happen here. I couldn't think of relevant tests comparing CPU against GPU timing, happy to take suggestions

Original commit for reference: 8a42a8e

@zhaojuanmao
Contributor

zhaojuanmao commented Aug 4, 2022

@ruanslv the traces are amazing and clearly show that the root cause is the fast CPU thread. @mrshenli this also confirms what we've discussed.

This seems to be a common problem for cudaCacheAllocator? It seems that FSDP use cases are hurt more by the high reserved memory issue because these use cases are usually training large models. When reserved memory is high and fragmented, allocating a new big tensor from the cache tends to fail more frequently and thus results in either OOM or a slow cudaMalloc (preceded by a lot of cudaFree calls).

@mrshenli will the pluggable cudaCacheAllocator help resolve the problem?

@awgu

awgu commented Aug 13, 2022

Hi @ruanslv! Do you have a code pointer/hardware setup info for me to reproduce the same 370M parameter transformer and check the trace (for a different FSDP patch)? Also, are there any publicly available setups for replicating the high GPU reserved memory?

I think I have a similar but slightly different approach for this.

@stephenroller

Does it always have issues with activation checkpointing? Is there a way we might address those? Activation checkpointing is pretty much non-optional for any model with 5B+ params.

@ruanslv
Contributor Author

ruanslv commented Aug 16, 2022

Sorry folks, I missed the comments here. Please feel free to ping me on work chat directly too.

@awgu I used an internal benchmark we have at https://github.com/fairinternal/fairscale_benchmarks/blob/main/benchmarks/transformer_wikitext2_benchmark.py. This model doesn't necessarily help in seeing the issue; I used it as an illustration / test environment for prefetching and CPU behavior. I'm not aware of any public setups for replicating the problem.

@stephenroller I couldn't find a way to get around checkpointing with prefetching. However, @zhaojuanmao organized a chat where we concluded that the optimal solution here is not to use prefetch + synchronize(). Instead, we should use explicit CUDA events to synchronize behaviour across different CUDA streams and help the CUDA cache allocator behave properly (i.e. reuse memory, based on the CUDA event information, as opposed to reserving new memory).

I haven't found the cycles to explore this idea yet, but if it works it would be a generic / clean solution that doesn't require CPU waits.

awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)
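The difference between recording the pre-forward order (Approach 1) and the post-forward order (Approach 2) can be seen on a small nested module tree. This is an illustrative model only: the tree, the module names, and the helpers `record_orders` and `next_in` are hypothetical, and the real order depends on which modules are actually wrapped by FSDP.

```python
from typing import Dict, List, Optional

# Hypothetical module tree: the root wraps two blocks, and block0 wraps
# two leaf layers. Names are placeholders.
TREE: Dict[str, List[str]] = {
    "root": ["block0", "block1"],
    "block0": ["block0.attn", "block0.mlp"],
    "block1": [],
    "block0.attn": [],
    "block0.mlp": [],
}


def record_orders(name: str, pre: List[str], post: List[str]) -> None:
    pre.append(name)   # what a pre-forward hook would record
    for child in TREE[name]:
        record_orders(child, pre, post)
    post.append(name)  # what a post-forward hook would record


def next_in(order: List[str], name: str) -> Optional[str]:
    """Prefetch target: the entry at index + 1 in the recorded order."""
    i = order.index(name) + 1
    return order[i] if i < len(order) else None


pre_order: List[str] = []
post_order: List[str] = []
record_orders("root", pre_order, post_order)

print(pre_order)   # ['root', 'block0', 'block0.attn', 'block0.mlp', 'block1']
print(post_order)  # ['block0.attn', 'block0.mlp', 'block0', 'block1', 'root']
print(next_in(pre_order, "block0"))   # block0.attn (descends into the child)
print(next_in(post_order, "block0"))  # block1 (moves to the sibling)
```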

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes
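A minimal sketch of the event-limiting idea in the two bullets above, assuming a simple FIFO of events. `FreeEventLimiter`, `on_reshard`, `before_unshard`, and `FakeEvent` are hypothetical names, not the PT-D implementation. On a GPU the events would be `torch.cuda.Event` objects, recorded with `record()` and waited on with `synchronize()`; a stand-in event is used here so the example runs anywhere.

```python
from collections import deque
from typing import Deque, List, Protocol


class EventLike(Protocol):
    def synchronize(self) -> None: ...


class FreeEventLimiter:
    """Bounds how many 'parameter freed' events may be outstanding."""

    def __init__(self, max_num_events: int) -> None:
        self._events: Deque[EventLike] = deque()
        self._max_num_events = max_num_events

    def on_reshard(self, event: EventLike) -> None:
        # Called after the unsharded parameter is freed and the event recorded.
        self._events.append(event)

    def before_unshard(self) -> None:
        # If too many frees are still pending, block the CPU on the earliest.
        if len(self._events) > self._max_num_events:
            self._events.popleft().synchronize()


class FakeEvent:
    """Stand-in for a CUDA event; logs when the CPU would have blocked."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def synchronize(self) -> None:
        self.log.append(self.name)


log: List[str] = []
limiter = FreeEventLimiter(max_num_events=2)
for step in range(4):
    limiter.before_unshard()  # may block the CPU thread
    # ... all-gather and compute for this module would run here ...
    limiter.on_reshard(FakeEvent(f"free{step}", log))

print(log)  # ['free0']
```

The CPU only blocks once more than `max_num_events` frees are outstanding, so it can still run a bounded distance ahead of the GPU.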

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may target (1) an already unsharded parameter, (2) a not-yet-unsharded parameter, or (3) an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).
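One possible shape for the guard mentioned in observation 4, sketched under the assumption that each parameter's state within a pass is tracked explicitly. `ParamState` and `should_issue_prefetch` are hypothetical names. A state check can only skip cases (1) and (3); a wrong target in case (2) looks the same as a correct one and still goes through.

```python
from enum import Enum, auto


class ParamState(Enum):
    NOT_YET_UNSHARDED = auto()  # valid target, but also case (2) if mistargeted
    UNSHARDED = auto()          # case (1): already unsharded
    ALREADY_RESHARDED = auto()  # case (3): used and resharded in this pass


def should_issue_prefetch(state: ParamState) -> bool:
    """Skip prefetches whose target is already unsharded or already done."""
    return state is ParamState.NOT_YET_UNSHARDED


for state in ParamState:
    print(state.name, should_issue_prefetch(state))
```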

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>
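For reference, the difference between the two runs above works out as follows. This is plain arithmetic on the reported numbers; the variable names are only for illustration.

```python
peak_allow_mb = 6784  # allow_over_all_gather=True
peak_limit_mb = 5846  # allow_over_all_gather=False

saved_mb = peak_allow_mb - peak_limit_mb
saved_pct = 100 * saved_mb / peak_allow_mb
print(f"{saved_mb} MB saved ({saved_pct:.1f}%)")  # 938 MB saved (13.8%)
```

Time per batch is 3.4 s in both runs, so in this setup the lower peak did not cost throughput.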


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced in Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)
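A minimal sketch of the order bookkeeping described in the list above, assuming modules are identified by name. `OrderRecorder` and its method names are hypothetical and are not taken from the Fairscale code; only the index arithmetic (pre-forward order index + 1 for forward, index - 1 for backward) comes from the list.

```python
from typing import List, Optional


class OrderRecorder:
    """Hypothetical helper: records the order in which modules hit pre-forward."""

    def __init__(self) -> None:
        self.pre_forward_order: List[str] = []

    def record_pre_forward(self, module_name: str) -> None:
        # Record each module only the first time it is seen (first iteration).
        if module_name not in self.pre_forward_order:
            self.pre_forward_order.append(module_name)

    def forward_prefetch_target(self, module_name: str) -> Optional[str]:
        # Pre-forward prefetching: pre-forward order index + 1.
        idx = self.pre_forward_order.index(module_name) + 1
        if idx < len(self.pre_forward_order):
            return self.pre_forward_order[idx]
        return None  # last module in the forward pass: nothing to prefetch

    def backward_prefetch_target(self, module_name: str) -> Optional[str]:
        # Pre-backward prefetching: pre-forward order index - 1.
        idx = self.pre_forward_order.index(module_name) - 1
        if idx >= 0:
            return self.pre_forward_order[idx]
        return None  # first module in the forward pass: nothing to prefetch


recorder = OrderRecorder()
for name in ["root", "layer0", "layer1"]:
    recorder.record_pre_forward(name)
```

Using the reversed pre-forward order for backward is a guess about the backward execution order, which is why Approach 2 records a separate pre-backward order instead.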

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes
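The two steps above could be sketched roughly as follows. `FreeEventLimiter`, `record_free()`, `wait_if_needed()`, and `max_num_events` are hypothetical names chosen for illustration, and the class only assumes that each event exposes a blocking `synchronize()` method (as `torch.cuda.Event` does); this is not the actual PT-D implementation.

```python
from collections import deque
from typing import Deque, Protocol


class Event(Protocol):
    def synchronize(self) -> None: ...


class FreeEventLimiter:
    """Hypothetical helper: bounds how far the CPU can run ahead of frees."""

    def __init__(self, max_num_events: int = 2) -> None:
        self._events: Deque[Event] = deque()
        self._max_num_events = max_num_events

    def record_free(self, event: Event) -> None:
        # Reshard path: called after the unsharded parameter has been freed
        # and an event has been recorded on the current stream.
        self._events.append(event)

    def wait_if_needed(self) -> bool:
        # Unshard path: called before issuing the all-gather. If too many free
        # events are pending, block the CPU thread on the earliest one.
        if len(self._events) > self._max_num_events:
            earliest = self._events.popleft()
            earliest.synchronize()
            return True
        return False


class FakeEvent:
    """Stand-in for a CUDA event so the sketch runs without a GPU."""

    def __init__(self, name: str, log: list) -> None:
        self.name = name
        self.log = log

    def synchronize(self) -> None:
        self.log.append(self.name)


sync_log: list = []
limiter = FreeEventLimiter(max_num_events=1)
limiter.record_free(FakeEvent("free0", sync_log))
blocked_first = limiter.wait_if_needed()
limiter.record_free(FakeEvent("free1", sync_log))
blocked_second = limiter.wait_if_needed()
```

On a GPU, the reshard path would create a `torch.cuda.Event()`, call `record()` on it after the free, and pass it to `record_free()`. Unlike a full `current_stream().synchronize()`, waiting on only the earliest event lets the CPU stay a bounded number of modules ahead.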

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Both Fairscale FSDP and PT-D FSDP follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be (1) targeting an already unsharded parameter, (2) targeting a not-yet-unsharded parameter, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).
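One possible shape for the guard mentioned in item 4, as a sketch only. `ParamState` and `should_issue_prefetch()` are hypothetical, and so is the assumption that each parameter's state within the current pass is tracked somewhere. Letting case (2) through is also an assumption: the target is out of order, but that is not visible from its state alone.

```python
from enum import Enum, auto


class ParamState(Enum):
    """Hypothetical per-pass state of a flattened parameter."""

    NOT_YET_UNSHARDED = auto()  # case (2) target: still sharded, not yet used
    UNSHARDED = auto()          # case (1) target: all-gather already done
    RESHARDED = auto()          # case (3) target: already used and freed


def should_issue_prefetch(state: ParamState) -> bool:
    # Case (1): skip, mirroring the no-op behavior for already unsharded
    # parameters. Case (3): skip, since re-gathering a parameter that was
    # already used and freed in this pass only costs memory.
    return state is ParamState.NOT_YET_UNSHARDED


decisions = {state.name: should_issue_prefetch(state) for state in ParamState}
```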

#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>
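For reference, the difference between the two runs above works out as follows. This is plain arithmetic on the reported numbers; the variable names are only for illustration.

```python
# Peak GPU reserved memory reported above, in MB.
peak_allow_over_all_gather = 6784
peak_limit_all_gather = 5846

saved_mb = peak_allow_over_all_gather - peak_limit_all_gather
saved_pct = 100 * saved_mb / peak_allow_over_all_gather

print(f"saved {saved_mb} MB ({saved_pct:.1f}% of peak reserved memory)")
```

Time per batch is 3.4 s in both runs, so on this workload the reduction comes with no measured slowdown at the reported precision.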


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
… CUDA events"


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced in Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)
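The index arithmetic behind Approach 1 can be sketched roughly as follows. This is a toy model, not Fairscale's code: `OrderTracker` and its methods are hypothetical names, and FSDP modules are represented by plain strings.

```python
from typing import List, Optional


class OrderTracker:
    """Toy model of pre-forward order recording (hypothetical helper)."""

    def __init__(self) -> None:
        self.pre_forward_order: List[str] = []

    def record_pre_forward(self, module_name: str) -> None:
        # Record each module once, in the order its pre-forward hook fires.
        if module_name not in self.pre_forward_order:
            self.pre_forward_order.append(module_name)

    def next_forward_target(self, module_name: str) -> Optional[str]:
        # Forward prefetch target: pre-forward order index + 1.
        idx = self.pre_forward_order.index(module_name) + 1
        if idx < len(self.pre_forward_order):
            return self.pre_forward_order[idx]
        return None

    def next_backward_target(self, module_name: str) -> Optional[str]:
        # Backward prefetch target: pre-forward order index - 1.
        idx = self.pre_forward_order.index(module_name) - 1
        if idx >= 0:
            return self.pre_forward_order[idx]
        return None


tracker = OrderTracker()
for name in ["embed", "block0", "block1", "head"]:
    tracker.record_pre_forward(name)

forward_target = tracker.next_forward_target("block0")
backward_target = tracker.next_backward_target("block0")
last_forward_target = tracker.next_forward_target("head")
```

The last module in the recorded order has no forward target, so nothing is prefetched there.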

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes
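A minimal sketch of this event-based limiter, under stated assumptions: `FreeEventLimiter`, `on_reshard`, `before_unshard`, and `FakeEvent` are hypothetical names, the `>` comparison is one reading of "exceeds a max number", and a stub event stands in for a recorded `torch.cuda.Event` so the sketch runs without a GPU.

```python
from collections import deque
from typing import Deque, List, Optional, Protocol


class SyncableEvent(Protocol):
    """Anything with a blocking synchronize(), e.g. a recorded torch.cuda.Event."""

    def synchronize(self) -> None: ...


class FreeEventLimiter:
    """Hypothetical limiter that bounds the number of pending 'free' events."""

    def __init__(self, max_events: Optional[int]) -> None:
        self.max_events = max_events  # None disables the limiter
        self._events: Deque[SyncableEvent] = deque()

    def on_reshard(self, event: SyncableEvent) -> None:
        # Called after the unsharded parameter has been freed; the caller is
        # assumed to have recorded `event` on the stream that did the free.
        if self.max_events is not None:
            self._events.append(event)

    def before_unshard(self) -> None:
        # If too many frees are still tracked, block the CPU thread on the
        # earliest one before issuing another all-gather.
        if self.max_events is not None and len(self._events) > self.max_events:
            self._events.popleft().synchronize()


class FakeEvent:
    """Stub event that logs when the CPU thread would have blocked on it."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def synchronize(self) -> None:
        self.log.append(self.name)


def run(max_events: Optional[int], steps: int) -> List[str]:
    log: List[str] = []
    limiter = FreeEventLimiter(max_events)
    for i in range(steps):
        limiter.before_unshard()
        limiter.on_reshard(FakeEvent(f"free{i}", log))
    return log


synced_with_limit = run(2, 5)
synced_without_limit = run(None, 5)
```

With a limit of 2, the CPU thread blocks on the two earliest frees over five unshard/reshard steps; with the limiter disabled it never blocks.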

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Both Fairscale FSDP and PT-D FSDP follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching: it only prefetches at the sibling level, whereas pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - Post-forward prefetching seems to be motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may target (1) an already unsharded parameter, (2) a not yet unsharded parameter, or (3) an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).
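One possible shape for a guard covering cases (1) and (3) is sketched below. It is purely illustrative: `ParamState` and `should_prefetch` are hypothetical names, and the sketch assumes the FSDP instance can tell which of the three states a prefetch target is in.

```python
from enum import Enum, auto


class ParamState(Enum):
    ALREADY_UNSHARDED = auto()   # case (1): unsharding again is redundant
    NOT_YET_UNSHARDED = auto()   # case (2): an all-gather is issued, possibly early
    ALREADY_RESHARDED = auto()   # case (3): used and freed earlier in this pass


def should_prefetch(state: ParamState) -> bool:
    # Skip the prefetch in cases (1) and (3); only case (2) does real work.
    return state is ParamState.NOT_YET_UNSHARDED


decisions = {state.name: should_prefetch(state) for state in ParamState}
```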

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
… CUDA events"


awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
… CUDA events"


awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used).
- Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there.
- I want to discuss this since it does not seem backward compatible. I am not sure whether [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes) allow different attributes per member.
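For the enum question, a small sketch of what Python's `enum` does and does not allow may help the discussion. Enum members can carry attributes, but each member is a singleton, so the attributes are fixed per member at class-definition time and cannot vary per FSDP instance. `Strategy`, `supports_limiter`, and `FullShardConfig` are hypothetical names for illustration, not existing PyTorch APIs; the frozen dataclass is just one alternative for per-instance settings.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Strategy(Enum):
    # Each member's tuple is unpacked into __init__, so these attributes are
    # fixed per member when the class is defined.
    FULL_SHARD = (1, True)
    SHARD_GRAD_OP = (2, False)
    NO_SHARD = (3, False)

    def __init__(self, code: int, supports_limiter: bool) -> None:
        self.code = code
        self.supports_limiter = supports_limiter


@dataclass(frozen=True)
class FullShardConfig:
    # Per-instance setting that an enum member cannot hold independently.
    all_gather_issue_limit: Optional[int] = None


limited = FullShardConfig(all_gather_issue_limit=2)
unlimited = FullShardConfig()
```

Two wrappers that both use `Strategy.FULL_SHARD` share the same member object, whereas `limited` and `unlimited` are distinct values.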

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes


#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>
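For reference, the difference between the two runs above works out as follows (plain arithmetic on the reported numbers, nothing measured here):

```python
peak_without_limit_mb = 6784  # all_gather_issue_limit=None
peak_with_limit_mb = 5846     # all_gather_issue_limit=2

saved_mb = peak_without_limit_mb - peak_with_limit_mb
saved_pct = 100 * saved_mb / peak_without_limit_mb

print(f"saved {saved_mb} MB ({saved_pct:.1f}%)")  # saved 938 MB (13.8%)
```

That is about 938 MB less peak reserved memory at the same 3.4 s per batch.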


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- When enabled, the limiter is only meaningful for `FULL_SHARD`, not for `SHARD_GRAD_OP` or `NO_SHARD`, because (1) we track free events rather than all-gather events, and (2) for the non-`FULL_SHARD` strategies the reserved memory will inevitably be used.
- Given this, ideally each sharding strategy would have its own attributes, so that `all_gather_issue_limit` could be an attribute of `FULL_SHARD` only. The same idea applies to `HYBRID_SHARD`, where one option is to pass the second process group as an attribute.
- I want to discuss this since it does not seem backward compatible. I am not sure whether [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes) can carry different attributes per member.
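
To make the enum concern concrete, here is a minimal sketch of one possible alternative, not something this PR implements. `ShardingStrategy` below is a local stand-in enum, and `FullShardConfig` is a hypothetical name introduced only for illustration. Enum members are singletons, so a per-instance value such as `all_gather_issue_limit` cannot live on the member itself; a small config object that wraps the member can carry it instead.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class ShardingStrategy(Enum):
    # Local stand-in for the real enum. Members are singletons, so an
    # attribute stored on FULL_SHARD would be shared by every user of it.
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()


@dataclass(frozen=True)
class FullShardConfig:
    """Hypothetical per-instance config for FULL_SHARD (not part of the PR)."""

    all_gather_issue_limit: Optional[int] = None
    strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD

    def __post_init__(self) -> None:
        # Mirror the rule above: None disables, a positive integer enables.
        limit = self.all_gather_issue_limit
        if limit is not None and limit <= 0:
            raise ValueError("all_gather_issue_limit must be None or positive")


config = FullShardConfig(all_gather_issue_limit=2)
```

The backward-compatibility question remains: callers that pass a bare enum member today would need a migration path to any such config object.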

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)
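
Both Fairscale approaches share a record-then-prefetch pattern: learn the module order on the first pass, then prefetch the module at index + 1 on later passes. The sketch below illustrates only that pattern. `OrderTracker` and its method names are hypothetical and are not taken from the Fairscale code.

```python
from typing import List, Optional


class OrderTracker:
    """Hypothetical helper: learns module order on the first pass."""

    def __init__(self) -> None:
        self.order: List[str] = []
        self.recording = True

    def record(self, module_name: str) -> None:
        # Only the first pass contributes to the recorded order.
        if self.recording and module_name not in self.order:
            self.order.append(module_name)

    def finish_first_pass(self) -> None:
        self.recording = False

    def next_to_prefetch(self, module_name: str) -> Optional[str]:
        # Returns the module at (index + 1), or None at the end of the
        # order or when the module was never recorded.
        if module_name not in self.order:
            return None
        idx = self.order.index(module_name) + 1
        return self.order[idx] if idx < len(self.order) else None


tracker = OrderTracker()
for name in ["layer0", "layer1", "layer2"]:
    tracker.record(name)
tracker.finish_first_pass()
```

If the execution order changes after the first pass, the recorded order becomes stale, which is why this is a heuristic rather than a guarantee.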

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes
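
The two steps above can be sketched without a GPU. This is an illustration under assumptions, not the PR's implementation: `FakeEvent` stands in for a CUDA event (a real version would record a `torch.cuda.Event` and call its `synchronize()`), `FreeEventLimiter` and its method names are hypothetical, and the strict `>` comparison simply follows the word "exceeds" above.

```python
from collections import deque
from typing import Deque, Optional


class FakeEvent:
    """Stand-in for a CUDA event so the sketch runs on CPU."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.synchronized = False

    def synchronize(self) -> None:
        # A real CUDA event would block the CPU thread here until the GPU
        # reaches the point where the event was recorded.
        self.synchronized = True


class FreeEventLimiter:
    """Hypothetical helper that bounds the number of outstanding frees."""

    def __init__(self, limit: Optional[int]) -> None:
        self.limit = limit  # None disables the limiter
        self.events: Deque[FakeEvent] = deque()

    def on_reshard(self, event: FakeEvent) -> None:
        # Called after freeing the padded unsharded flattened parameter.
        if self.limit is not None:
            self.events.append(event)

    def before_unshard(self) -> Optional[FakeEvent]:
        # Called before issuing the next all-gather. Strict ">" is an
        # assumption taken from the wording "exceeds a max number".
        if self.limit is None or len(self.events) <= self.limit:
            return None
        earliest = self.events.popleft()
        earliest.synchronize()
        return earliest


limiter = FreeEventLimiter(limit=2)
events = [FakeEvent(f"free{i}") for i in range(3)]
```

Blocking on the earliest free event keeps the CPU from running arbitrarily far ahead of the GPU, so the allocator can reuse freed blocks instead of reserving new ones.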


#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>
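
A quick check of the two measurements above, using only the numbers reported here:

```python
# Peak GPU reserved memory (MB) reported for the T5 (500M) runs.
peak_mb_no_limit = 6784  # all_gather_issue_limit=None
peak_mb_limit_2 = 5846   # all_gather_issue_limit=2

saved_mb = peak_mb_no_limit - peak_mb_limit_2
saved_pct = 100 * saved_mb / peak_mb_no_limit

print(f"saved {saved_mb} MB ({saved_pct:.1f}% of peak); time / batch unchanged at 3.4 s")
```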


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used).
- Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there.
- I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum.

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes


#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

- Peak GPU reserved memory: 6784 MB (6.784 GB)
- Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

- Peak GPU reserved memory: 5846 MB (5.846 GB)
- Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, the limiter is only meaningful for `FULL_SHARD`, not for `SHARD_GRAD_OP` or `NO_SHARD`, since (1) we track free events, not all-gather events, and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used.
- Given this, ideally each sharding strategy could have its own attributes, and `all_gather_issue_limit` would become an attribute of `FULL_SHARD` only. The same idea applies to `HYBRID_SHARD`, since one option there is to pass the second process group as an attribute.
- I want to discuss this since it does not seem backward compatible. I am not sure that [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes) allow different attributes per member.
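
To make the enum question above concrete, here is a minimal sketch, not the PR's code. `SketchShardingStrategy`, `supports_limiter`, and `FullShardConfig` are hypothetical names and are not the real `ShardingStrategy` API. It shows that a Python `Enum` member can carry fixed attributes, but because each member is a singleton, a user-chosen value such as `all_gather_issue_limit` would have to live in a separate object (here, a dataclass).

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class SketchShardingStrategy(Enum):
    """Hypothetical stand-in enum: each member carries fixed attributes."""

    FULL_SHARD = ("full_shard", True)
    SHARD_GRAD_OP = ("shard_grad_op", False)
    NO_SHARD = ("no_shard", False)

    def __init__(self, label: str, supports_limiter: bool) -> None:
        # Enum unpacks the tuple value into these arguments once per member.
        self.label = label
        self.supports_limiter = supports_limiter


@dataclass(frozen=True)
class FullShardConfig:
    """Hypothetical per-strategy config holding a user-chosen limit."""

    all_gather_issue_limit: Optional[int] = None

    def __post_init__(self) -> None:
        limit = self.all_gather_issue_limit
        # None disables the limiter; otherwise require a positive integer.
        if limit is not None and (not isinstance(limit, int) or limit <= 0):
            raise ValueError("all_gather_issue_limit must be None or a positive int")
```

The fixed attributes are shared by every user of `SketchShardingStrategy.FULL_SHARD`, which is why a per-run limit does not fit on the member itself.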

### High GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)
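
The index arithmetic in Approach 1 can be sketched as follows. This is an illustration only, assuming modules are identified by name; `PreForwardOrder` and its methods are hypothetical and do not appear in Fairscale.

```python
from typing import List, Optional


class PreForwardOrder:
    """Records the order in which modules enter forward on the first pass."""

    def __init__(self) -> None:
        self.order: List[str] = []

    def record(self, module_name: str) -> None:
        # Only the first visit of each module defines its position.
        if module_name not in self.order:
            self.order.append(module_name)

    def forward_prefetch_target(self, module_name: str) -> Optional[str]:
        # Forward pass: prefetch the module at pre-forward index + 1.
        idx = self.order.index(module_name) + 1
        return self.order[idx] if idx < len(self.order) else None

    def backward_prefetch_target(self, module_name: str) -> Optional[str]:
        # Backward pass: prefetch the module at pre-forward index - 1.
        idx = self.order.index(module_name) - 1
        return self.order[idx] if idx >= 0 else None
```

Using a single list for both directions assumes the backward pass visits modules in exactly the reverse of the pre-forward order.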

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)
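
One way to see why Approach 2 records a post-forward order rather than a pre-forward order: for nested modules the two differ. The sketch below uses a hypothetical three-module tree (`TREE`) and plain recursion in place of real forward hooks; none of these names come from Fairscale.

```python
from typing import Dict, List, Optional

# Hypothetical module tree: "root" wraps two child blocks.
TREE: Dict[str, List[str]] = {"root": ["block0", "block1"], "block0": [], "block1": []}


def run_forward(name: str, pre: List[str], post: List[str]) -> None:
    pre.append(name)  # where a pre-forward hook would fire (on entry)
    for child in TREE[name]:
        run_forward(child, pre, post)
    post.append(name)  # where a post-forward hook would fire (on exit)


def next_in_order(order: List[str], name: str) -> Optional[str]:
    # Prefetch target: the module at (recorded index + 1), if any.
    idx = order.index(name) + 1
    return order[idx] if idx < len(order) else None


pre_order: List[str] = []
post_order: List[str] = []
run_forward("root", pre_order, post_order)
```

Here `pre_order` starts with `root` while `post_order` ends with it. The same `next_in_order` lookup would apply to a separately recorded pre-backward list for the backward pass.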

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before unsharding, check whether the number of saved free events exceeds a maximum; if so, synchronize the earliest event, which blocks the CPU thread until that event completes
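
The two steps above can be sketched as a small queue of free events. This is a sketch under assumptions, not the PR's implementation. `FreeEventLimiter`, `on_reshard`, `before_unshard`, and `FakeEvent` are hypothetical names; `FakeEvent` stands in for a CUDA event so the example runs without a GPU, and the strict `>` comparison is one reading of "exceeds". With real CUDA, the event would be a `torch.cuda.Event`, whose `synchronize()` blocks the host until the event completes.

```python
from collections import deque
from typing import Deque, Optional, Protocol


class Event(Protocol):
    def synchronize(self) -> None: ...


class FakeEvent:
    """Stand-in for a CUDA event (assumption: lets the sketch run on CPU)."""

    def __init__(self) -> None:
        self.synchronized = False

    def synchronize(self) -> None:
        self.synchronized = True


class FreeEventLimiter:
    def __init__(self, max_events: Optional[int]) -> None:
        self.max_events = max_events  # None disables the limiter
        self.events: Deque[Event] = deque()

    def on_reshard(self, event: Event) -> None:
        # Called after the unsharded parameter has been freed.
        if self.max_events is not None:
            self.events.append(event)

    def before_unshard(self) -> None:
        # If too many free events are outstanding, wait on the earliest one.
        if self.max_events is not None and len(self.events) > self.max_events:
            self.events.popleft().synchronize()
```

Waiting on the oldest free event, rather than on the whole stream, keeps the CPU only a bounded number of frees ahead of the GPU.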


#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

- Peak GPU reserved memory: 6784 MB (6.784 GB)
- Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

- Peak GPU reserved memory: 5846 MB (5.846 GB)
- Time / batch: 3.4 s

</details>
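
For reference, the difference between the two runs above works out as follows, using only the numbers reported in this benchmark:

```python
baseline_mb = 6784  # all_gather_issue_limit=None
limited_mb = 5846   # all_gather_issue_limit=2

saved_mb = baseline_mb - limited_mb
saved_pct = 100 * saved_mb / baseline_mb

print(f"saved {saved_mb} MB ({saved_pct:.1f}%)")  # saved 938 MB (13.8%)
```

Time per batch is 3.4 s in both runs, so the reduction in reserved memory shows no measured throughput cost at this precision.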


[ghstack-poisoned]