In [2]:
# confirm PyTorch version 
import torch
torch.__version__

'1.13.0.dev20220711+cu113'

FSDP has three sharding strategies (with a fourth on the way). They control how much of the model parameters, gradients and optimizer states is sharded across GPUs versus replicated on every GPU. More sharding lowers per-GPU memory and supports larger models. More replication lowers communication.

The strategy is selected with a single `sharding_strategy` argument to FSDP, so the same training code can cover a wide range of model sizes and server configurations.
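To make the communication side of this tradeoff concrete, here is a back-of-envelope sketch of how much data each GPU sends per training step under each strategy. It assumes ring collectives, assumes parameters and gradients are communicated in the same dtype, and ignores latency, bucketing and overlap with compute. The function name `comm_volume_per_step` is hypothetical and not part of PyTorch.

```python
# Back-of-envelope per-GPU communication volume for one training step.
# Assumptions (not measured): ring collectives, params and grads sent in the
# same dtype, no latency/bucketing/overlap effects.
# Result is in multiples of the full model size sent by each rank.

def comm_volume_per_step(strategy: str, world_size: int) -> float:
    """Approximate data sent per rank per step, as a multiple of model size."""
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    # A ring all-gather or reduce-scatter sends (N - 1) / N of the data per rank;
    # a ring all-reduce is a reduce-scatter followed by an all-gather.
    f = (world_size - 1) / world_size
    collectives = {
        # all-gather params (forward) + all-gather params (backward) + reduce-scatter grads
        "FULL_SHARD": 3,
        # all-gather params once (kept unsharded until after backward) + reduce-scatter grads
        "SHARD_GRAD_OP": 2,
        # all-reduce grads only (= reduce-scatter + all-gather)
        "NO_SHARD": 2,
    }
    return collectives[strategy] * f


for name in ("FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"):
    print(f"{name:14s} {comm_volume_per_step(name, world_size=8):.3f} x model size")
```

Under this simplified model, FULL_SHARD sends about 1.5x the data of NO_SHARD, which is the price paid for its lower memory footprint. Real step times also depend on how well communication overlaps with compute, so treat these numbers as a starting point rather than a prediction.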

In [3]:
# import FSDP and the ShardingStrategy enum:
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

In [4]:
# Three available sharding strategies - trade per-GPU memory against communication overhead:
ShardingStrategy.FULL_SHARD     # default: parameters, gradients and optimizer states are all sharded; supports the largest models
ShardingStrategy.SHARD_GRAD_OP  # ZeRO-2 style: gradients and optimizer states are sharded; parameters are not resharded
                                # after the forward pass, so backward needs no extra all-gather (less communication)
ShardingStrategy.NO_SHARD       # DDP style: each GPU keeps a full copy of parameters, gradients and optimizer states;
                                # only gradient sync (all-reduce) is needed
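The memory side of the tradeoff described in the comments above can be estimated the same way. This sketch computes per-GPU memory for model state only, assuming fp32 parameters and gradients and an Adam-style optimizer with two fp32 states per parameter. It deliberately leaves out activations, communication buffers, the FSDP unit that is temporarily unsharded during compute, and allocator overhead, so real limits will be lower. The function and constant names are hypothetical.

```python
# Rough per-GPU memory (GB) for parameters, gradients and optimizer state only.
# Assumptions: fp32 params and grads, Adam-style optimizer with two fp32 states.
# NOT included: activations, comm buffers, the temporarily unsharded FSDP unit,
# allocator fragmentation.

BYTES_PARAM = 4   # fp32 parameter
BYTES_GRAD = 4    # fp32 gradient
BYTES_OPTIM = 8   # two fp32 Adam states (exp_avg, exp_avg_sq)


def model_state_gb(num_params: float, strategy: str, world_size: int) -> float:
    p = num_params * BYTES_PARAM
    g = num_params * BYTES_GRAD
    o = num_params * BYTES_OPTIM
    n = world_size
    if strategy == "FULL_SHARD":
        # everything is sharded across all ranks
        total = (p + g + o) / n
    elif strategy == "SHARD_GRAD_OP":
        # params stay unsharded from forward until after backward, so the peak
        # holds the full parameters; grads and optimizer state are sharded
        total = p + (g + o) / n
    elif strategy == "NO_SHARD":
        # DDP-style full replica on every rank
        total = p + g + o
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return total / 1e9


for name in ("FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"):
    print(f"{name:14s} {model_state_gb(1e9, name, world_size=8):6.2f} GB")
```

For a 1B-parameter model on 8 GPUs this gives about 2 GB, 5.5 GB and 16 GB per GPU respectively, which illustrates why FULL_SHARD supports the largest models on a given server.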

In [5]:
# Future support (not yet available in this build, so left commented out):
# ShardingStrategy.HYBRID_SHARD  # full sharding (FSDP) within each node, replication (DDP) across nodes
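To illustrate the hybrid layout (HYBRID_SHARD: shard within a node, replicate across nodes), the sketch below lists which global ranks would share one sharded copy of the model and which ranks hold replicas of the same shard. It assumes ranks are numbered node-major (node k owns ranks k*g to k*g+g-1). The helper `hybrid_groups` is hypothetical and only shows the layout; it does not reflect how PyTorch builds its process groups.

```python
# Hypothetical helper: rank layout for "shard within node, replicate across nodes".
# Assumes node-major rank numbering: node k owns ranks k*g .. k*g + g - 1.

def hybrid_groups(num_nodes: int, gpus_per_node: int):
    # Shard groups: one per node. FSDP-style all-gather / reduce-scatter
    # would run inside each of these (fast intra-node links).
    shard_groups = [
        list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        for node in range(num_nodes)
    ]
    # Replicate groups: the same local rank on every node holds the same shard,
    # so a DDP-style all-reduce of that shard's gradients runs across these.
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_groups(num_nodes=2, gpus_per_node=4)
print(shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

The design intent is that the heavy per-layer all-gathers stay on fast intra-node links, while only gradient synchronization crosses the slower inter-node network.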

In [None]:
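# To use, just pass the desired strategy to the FSDP constructor.
# my_auto_wrap_policy, mp_policy and prefetch_policy are assumed to be defined earlier
# (auto-wrap policy, MixedPrecision config and BackwardPrefetch setting, respectively).
# ----- main FSDP init -----------
model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    mixed_precision=mp_policy,
    backward_prefetch=prefetch_policy,
    # sharding control: SHARD_GRAD_OP (ZeRO-2), NO_SHARD (DDP) or FULL_SHARD (FSDP default)
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    device_id=torch.cuda.current_device(),
    forward_prefetch=True,
)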

An example on a single server type (AWS G5.48xlarge, A10 GPUs) showing how each sharding strategy changes the maximum model size that can be trained:


![Max model size example](images/fsdp_sharding_strategies50.png)


![Max model size details](images/fsdp_sharding_strategies_details.png)


In [1]:
# Best practice: benchmark each sharding strategy with your own model, server resources and network speed,
# and keep the one with the highest throughput (see the GPU throughput maximization tutorial)
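A minimal timing harness for that comparison might look like the sketch below. It assumes you can wrap one training step per strategy in a zero-argument callable (for example, a closure around an FSDP model built with that `sharding_strategy`). The names `measure_throughput`, `pick_fastest` and `step_fn` are hypothetical. For real GPU runs, the step function should end with `torch.cuda.synchronize()` so the timer measures finished CUDA work rather than queued kernels.

```python
import time
from typing import Callable, Dict


def measure_throughput(step_fn: Callable[[], None], samples_per_step: int,
                       warmup: int = 3, iters: int = 10) -> float:
    """Return samples/sec for step_fn, measured after a few warm-up steps."""
    # Warm-up steps absorb one-time costs (allocator growth, kernel selection).
    for _ in range(warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()  # on GPU, step_fn should synchronize before returning
    elapsed = max(time.perf_counter() - start, 1e-12)  # guard against zero
    return samples_per_step * iters / elapsed


def pick_fastest(candidates: Dict[str, Callable[[], None]], samples_per_step: int) -> str:
    """Time each candidate step function and return the name with the highest throughput."""
    results = {name: measure_throughput(fn, samples_per_step)
               for name, fn in candidates.items()}
    for name, tput in sorted(results.items(), key=lambda kv: -kv[1]):
        print(f"{name:14s} {tput:10.1f} samples/s")
    return max(results, key=results.get)


# Usage with stand-in step functions (replace with real training steps):
best = pick_fastest(
    {"FULL_SHARD": lambda: time.sleep(0.004),
     "SHARD_GRAD_OP": lambda: time.sleep(0.002)},
    samples_per_step=8,
)
print("fastest:", best)
```

In a multi-GPU job every rank must run the same candidates in the same order, since each step involves collectives; it is usually enough to print the results from rank 0.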