fix for high gpu reserved memory #972

Open · wants to merge 2 commits into base: main

Conversation

@ngoyal2707 (Contributor) commented Apr 11, 2022

Before this change:

Training 175B on 1024 A100s:

"ppl": "4217.39", "wps": "105123", "ups": "0.05", "wpb": "2.09715e+06", "bsz": "1024", "num_updates": "26", "lr": "1.09091e-06", "gnorm": "40.827", "clip": "100", "train_wall": "20", "max_cuda_gb_allocated": "13.4", "max_cuda_gb_reserved": "75.4", "current_cuda_gb_allocated": "1.9", "current_cuda_gb_reserved": "75.4", "cuda_gb_free": "65.9"

After the change:

"ppl": "4217.39", "wps": "107317", "ups": "0.05", "wpb": "2.09715e+06", "bsz": "1024", "num_updates": "26", "lr": "1.09091e-06", "gnorm": "40.827", "clip": "100", "train_wall": "20", "max_cuda_gb_allocated": "13.8", "max_cuda_gb_reserved": "22.1", "current_cuda_gb_allocated": "1.9", "current_cuda_gb_reserved": "22.1", "cuda_gb_free": "65.5"

I don't know whether the approach is perfect, but for the above use case I am getting slightly better WPS (most likely a fluke; I'd expect them to match), much better CUDA GB reserved (which I can now mostly explain, i.e. where it is being used), and ppl matches.

Plus, I can now train much, much larger models (will post additional details and results later).

I will do additional testing.

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 11, 2022
@ngoyal2707 assigned ngoyal2707 and unassigned ngoyal2707 on Apr 11, 2022
@anj-s (Contributor) commented Apr 11, 2022

Thanks Naman! Looks great, and we should definitely incorporate it. A couple of questions:

  1. This still uses the prefetch mechanism, right? Can you explain the approach a bit?
  2. Curious why cuda_gb_free does not increase even though the reserved value decreases so drastically. Maybe we are logging something else, but I would expect the free value to be higher.

@@ -1978,8 +2001,8 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:

if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
    self._free_fp16_param_shard([p])

torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

Curious about this change: why do we condition on the boolean?

@ngoyal2707 (Contributor, Author)

  1. Yes, it is using the prefetch approach. There could maybe be a way to achieve the above without prefetch, but I thought hard about it and couldn't find one.

  2. That is just due to how we log it (see the sketch below): https://github.com/fairinternal/fairseq-py/blob/main/fairseq/trainer.py#L1060
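
A minimal sketch of that distinction, assuming the logged free value is derived from allocated rather than reserved memory (the helper below is illustrative, not fairseq's actual code):

```python
import torch

def cuda_gb_free_from_allocated(device: int = 0) -> float:
    # Illustrative only: a "free" metric computed against *allocated* memory
    # ignores the caching allocator's reservations entirely, so it barely
    # moves even when memory_reserved() drops from 75.4GB to 22.1GB.
    total = torch.cuda.get_device_properties(device).total_memory
    return (total - torch.cuda.max_memory_allocated(device)) / 1024 ** 3
```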

@min-xu-ai (Contributor)

cc @zhaojuanmao

@zhaojuanmao (Contributor)

@ngoyal2707, are you disabling the multiple process groups for reduce_scatter and all_gather in the backward pass and instead using the backward prefetch approach here?

The PyTorch version of FSDP is doing backward prefetch right now. This avoids over-prefetching layers in the backward pass better than the multiple-process-group approach does.
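
A rough sketch of the bookkeeping behind this kind of backward prefetch (class and attribute names are illustrative, not fairscale's or PyTorch's actual internals): each wrapped unit records its position in the forward order, and its pre-backward hook rebuilds the parameters of the unit at index - 1, since backward visits units roughly in reverse forward order.

```python
from typing import List, Optional

class FSDPUnit:
    """Illustrative stand-in for one FSDP-wrapped submodule."""
    forward_order: List["FSDPUnit"] = []  # shared; recorded on the first forward

    def __init__(self) -> None:
        self.index: Optional[int] = None

    def pre_forward(self) -> None:
        if self.index is None:  # record the execution order once
            self.index = len(FSDPUnit.forward_order)
            FSDPUnit.forward_order.append(self)

    def pre_backward(self) -> None:
        # Backward runs in (roughly) reverse forward order, so prefetch the
        # unit at index - 1, overlapping its all_gather with our computation.
        if self.index is not None and self.index > 0:
            FSDPUnit.forward_order[self.index - 1].rebuild_full_params()

    def rebuild_full_params(self) -> None:
        ...  # would launch this unit's all_gather on a side stream
```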

@zhaojuanmao (Contributor)

But will it conflict with the multi-process-group approach? Will the backward prefetch be a config option?

@stephenroller

Tried this on AWS. It roughly doubles WPS for my particular workload (1700 -> 3200).

@@ -2047,6 +2070,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()

@ngoyal2707 I guess this helped release the GPU reserved memory...

For @stephenroller's models, was your training slowed down by high reserved memory and a slow CUDA caching allocator? If so, maybe this change helped a lot...


But I'm wondering whether this will result in a regression for other models that do not suffer from the high reserved memory and slow CUDA caching allocator issue.


Yeah, my jobs are fine-tuning a large pre-trained model. We consistently saw on both 40GB and 80GB nodes that reserved memory was always at 99%.

Comment on lines +1406 to +1413
if (
    self._fsdp_forward_ordering is not None
    and self._my_fsdp_instance_idx is not None
    and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1
):
    self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params(
        wait_for_all_gather=False
    )


nice to add this forward prefetching!

@stephenroller

Let's push this through

@anj-s (Contributor) commented Apr 14, 2022

> Let's push this through

To push this through we need an owner who can:

  1. Evaluate whether we should move to the prefetch solution or whether it will work with the existing two-process-group (2 PG) solution.
  2. Fix the unit tests and make sure it works for FSDP in general, not just this use case.

Let me know if anyone is up for taking this on!

@zhaojuanmao (Contributor)

@anj-s and @stephenroller, I suspect the speedup mostly comes from the single-line change torch.cuda.current_stream().synchronize() in @stephenroller's use case; it would be good to verify that. E.g., @stephenroller, just use the FairScale master branch plus the torch.cuda.current_stream().synchronize() change in your use case and see whether there is a speedup.

But this change should be optional, not the default; otherwise other FSDP use cases that do not suffer from the high GPU reserved memory issue may regress.

@anj-s (Contributor) commented Apr 14, 2022

> @anj-s and @stephenroller, I suspect the speedup mostly comes from the single-line change torch.cuda.current_stream().synchronize() in @stephenroller's use case; it would be good to verify that.
>
> But this change should be optional, not the default; otherwise other FSDP use cases that do not suffer from the high GPU reserved memory issue may regress.

Can you explain a little why you think there might be a regression?

@stephenroller

I'm not entirely sure about this. We also have large-scale, non-latency-bound pretraining, and this change increased WPS there too, where ideal reserved memory is only 25% of the node. What's clear in both cases is that we were thrashing the allocator aggressively.

This could maybe be a regression in cases where the full state and activations fit in memory and the network can fetch faster than the compute can consume the state. But that case is better addressed with ZeRO-2 anyway, no?

@stephenroller

@ngoyal2707 is clearly the owner and should push this through.

@zhaojuanmao (Contributor)

"What's clear in both cases was we were thrashing the allocator aggressively."

==> that mostly means "torch.cuda.current_stream().synchronize()" helps here, because CUDA allocation is a async operation, it could try to allocate a large storage before completely free full parameters here. So it is trashing. Adding this sync for your use case basically makes sure full parameters are freed before allocating. This helped reducing CUDA allocator trash.

For the use cases that are not suffering from slow cudaAllocator, this blocking operation will wait for all pending async operations and reduce parallelizations, and thus could result in regression, no?

@zhaojuanmao (Contributor)

So, to unblock this effort, the suggestion here is:

  1. First, root-cause whether it is the prefetching or the CUDA sync that really helps here.
  2. If adding the CUDA sync, make it optional (see the sketch below).
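
A minimal sketch of option 2, assuming the sync is simply gated behind a flag (the flag name and helper are illustrative, not fairscale's API; the free-by-resize follows the hunk above):

```python
import torch
from torch import Tensor

def free_full_param_storage(full_param: Tensor, sync_after_free: bool = False) -> None:
    # As in the hunk above: resize the padded full-parameter Storage to 0.
    full_param.storage().resize_(0)
    if sync_after_free:
        # Optional: block the CPU thread so the free completes before the next
        # large allocation, avoiding caching-allocator thrashing.
        torch.cuda.current_stream().synchronize()
```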

@ngoyal2707 (Contributor, Author)

So I agree with everyone here. I should be the owner, but honestly I still don't know if this is the best approach, and I want the opinion of people who understand CUDA and the PyTorch caching allocator better.

The core cause of the issue is the unbounded queue of kernel launches combined with multiple CUDA streams.

One very simple PyTorch example of this is the following:

import torch

torch.cuda._sleep(100000000000)  # keep the GPU busy so the CPU thread runs far ahead
for i in range(100):
    x = torch.zeros(4 * 1024 * 1024, device="cuda")
    del x
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())
# print(torch.cuda.memory_snapshot())

The above code outputs at the end:

0
16777216

That means we only reserve 16MB of memory for the 4M-element fp32 tensor (4 * 1024 * 1024 elements x 4 bytes = 16MB), because it is only used by a single stream: the PyTorch caching allocator can hand out the same memory block every time, since operations within a stream execute sequentially.

Compare the above to the following:

import torch

s1 = torch.cuda.Stream()

torch.cuda._sleep(100000000000)  # stall the default stream

with torch.cuda.stream(s1):
    torch.cuda._sleep(100000000000)  # stall the side stream too

for i in range(100):
    x = torch.zeros(4 * 1024 * 1024, device="cuda")
    x.record_stream(s1)  # mark the block as also in use by s1
    torch.cuda.current_stream().wait_stream(s1)
    del x

print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())
# print(torch.cuda.memory_snapshot())

This outputs:

0
1677721600

i.e., PyTorch reserves a fresh 16MB block for every new allocation (100 x 16MB = 1.6GB in total), because each memory block is in use by both the current stream and the s1 stream.

So in the FSDP case, s1 is the all_gather stream, which of course runs slower than the CPU thread. What I don't know is whether there is any better solution than making the CPU thread wait. We can change how it waits: instead of waiting on torch.cuda.current_stream().synchronize(), we can always launch the new all_gather only once the previous one is done.
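
A rough sketch of that last alternative, gating each new all_gather launch on an event recorded after the previous one, instead of a full current-stream synchronize (illustrative, not the PR's code):

```python
from typing import Callable, Optional
import torch

_prev_all_gather_done: Optional[torch.cuda.Event] = None

def launch_all_gather(gather_stream: torch.cuda.Stream, launch: Callable[[], None]) -> None:
    global _prev_all_gather_done
    if _prev_all_gather_done is not None:
        # The CPU thread waits only for the previous all_gather to finish,
        # not for everything queued on the current stream.
        _prev_all_gather_done.synchronize()
    with torch.cuda.stream(gather_stream):
        launch()  # enqueue the next all_gather on the side stream
        _prev_all_gather_done = torch.cuda.Event()
        _prev_all_gather_done.record(gather_stream)
```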

So I think there can possibly be some regression, especially for very small models, because the GPU can sit idle while the CPU thread is launching new kernels.

@qingliu1111

Want to add a data point about the performance improvement on AWS brought by this PR :)
I'm running a fine-tuning flow for a 175B model on AWS with 16 nodes (128 GPUs).

  • Without this PR: wps is ~2817.1
  • With this PR: wps is ~8817.1 (more than tripled the speed!)

More context: with a similar training setup we were getting wps ~12000 on Azure. For 175B model fine-tuning, this PR closes much of the speed gap between Azure and AWS.

@zhaojuanmao (Contributor)

Adding @mrshenli, who is looking into a pluggable CUDA caching allocator and can hopefully resolve this issue in general.

yf225 pushed a commit to yf225/fairscale that referenced this pull request Jun 3, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced in Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check whether the number of saved free events exceeds a maximum and, if so, synchronize the earliest event, blocking the CPU thread until that event completes (sketched below)
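
A minimal sketch of that event-based limiting (the cap and hook names below are illustrative, not PT-D's internal API):

```python
from collections import deque
from typing import Deque
import torch

MAX_FREE_EVENTS = 2  # illustrative cap on in-flight frees
_free_events: Deque[torch.cuda.Event] = deque()

def on_reshard_free() -> None:
    # In _reshard(): record an event right after freeing the padded
    # unsharded flattened parameter.
    event = torch.cuda.Event()
    event.record()
    _free_events.append(event)

def before_unshard() -> None:
    # In _unshard(): if too many frees are outstanding, block the CPU
    # thread until the earliest one completes.
    if len(_free_events) > MAX_FREE_EVENTS:
        _free_events.popleft().synchronize()
```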

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not-yet-unsharded parameter, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
… CUDA events"


### High-GPU Reserved Memory
As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP).

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing the over-all-gathering.
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - Fairscale and PT-D FSDPs all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1).
    - We may want to add some logic to guard against case (3).

#### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 18, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 19, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` or `NO_SHARD` (since (1) we track free events, not all-gather events, and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used).
- Given this, ideally each sharding strategy would have its own attributes, and we could move `all_gather_issue_limit` to be an attribute only of `FULL_SHARD`. This idea also applies to `HYBRID_SHARD`, since one option then is to pass the second process group as an attribute there.
- I want to discuss this since it does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum member (see the sketch below).
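On the enum question: Python `Enum` members can carry per-member attributes by overriding `__new__`, so per-strategy defaults are at least expressible. A sketch for illustration only (not the proposed API):

```python
import enum
from typing import Optional

class ShardingStrategy(enum.Enum):
    # (enum value, default all-gather issue limit; None disables the limiter)
    FULL_SHARD = (1, 2)
    SHARD_GRAD_OP = (2, None)
    NO_SHARD = (3, None)

    def __new__(cls, value: int, limit: Optional[int]):
        member = object.__new__(cls)
        member._value_ = value
        member.default_all_gather_issue_limit = limit
        return member

assert ShardingStrategy.FULL_SHARD.default_all_gather_issue_limit == 2
assert ShardingStrategy.NO_SHARD.default_all_gather_issue_limit is None
```

This keeps member identities and values unchanged, though it does not by itself settle where the constructor argument should live.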

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check whether the number of outstanding free events exceeds a configured maximum; if so, synchronize the earliest event, blocking the CPU thread until it completes


#### T5 (500M), 2 Nodes, 16 A100 GPUs, 256 Batch Size

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used).
- Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there.
- I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum.

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter.
- In `_unshard()`, before actually unsharding, check whether the number of saved free events exceeds a maximum; if so, synchronize the earliest event, blocking the CPU thread until that event completes (sketched below).
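
A minimal sketch of this event-based limiter, with assumed names (`_FreeEventQueue`, `record_free`, `rate_limit`); in the actual PR the logic lives inside `_reshard()` and `_unshard()`:

```python
from collections import deque
from typing import Deque

import torch


class _FreeEventQueue:
    """Caps how many frees of unsharded parameters are still in flight,
    which in turn caps how far ahead the CPU can issue all-gathers."""

    def __init__(self, max_num_inflight: int) -> None:
        self._queue: Deque[torch.cuda.Event] = deque()
        self._max_num_inflight = max_num_inflight

    def record_free(self) -> None:
        # `_reshard()` side: after freeing the padded unsharded flattened
        # parameter, record where that free lands in the current stream.
        event = torch.cuda.Event()
        event.record()
        self._queue.append(event)

    def rate_limit(self) -> None:
        # `_unshard()` side: before issuing the next all-gather, block the
        # CPU thread on the earliest pending free so the caching allocator
        # can reuse that memory instead of reserving more.
        if len(self._queue) >= self._max_num_inflight:
            self._queue.popleft().synchronize()
```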


#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

- Peak GPU reserved memory: 6784 MB = 6.784 GB
- Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

- Peak GPU reserved memory: 5846 MB = 5.846 GB
- Time / batch: 3.4 s

</details>


[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 22, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022
awgu added a commit to pytorch/pytorch that referenced this pull request Aug 23, 2022