In [4]:
# Import PyTorch and display the installed version string as the cell output.
import torch
torch.__version__

'1.13.0.dev20220711+cu113'

FSDP Mixed Precision supports BFloat16 and FP16 with fine-grained policies that control parameters, gradient communication, and buffers.

In [5]:
# import Mixed Precision class along with FSDP import:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)

In [6]:
# Example policy: BFloat16 for all three mixed-precision settings.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,   # precision of the parameters
    reduce_dtype=torch.bfloat16,  # precision used for gradient communication
    buffer_dtype=torch.bfloat16,  # precision of the buffers
)

# Dtypes can differ per setting: BFloat16 parameters combined with
# float32 gradient communication and float32 buffers.
comboPolicy = MixedPrecision(
    param_dtype=torch.bfloat16,   # precision of the parameters
    reduce_dtype=torch.float32,   # precision used for gradient communication
    buffer_dtype=torch.float32,   # precision of the buffers
)


In [7]:
# Details:

# 1 - BatchNorm is automatically kept in fp32 for precision (overrides buffer policy, no user action needed)
# 2 - Local gradients during backprop are also always fp32 (automatic, no user action needed)
# 3 - Models are always saved in fp32 format for max portability

In [11]:
# bfloat16 support verification imports (network and gpu native support)
from pkg_resources import packaging
import torch.cuda.nccl as nccl
import torch.distributed as dist

In [10]:
# bf16_ready is truthy only when every prerequisite for BFloat16 training holds.
# The `and` chain short-circuits, so each check runs only if all earlier ones
# passed; the result can therefore be a falsy non-bool (whatever value stopped
# the chain) rather than a strict False.
bf16_ready = (
    # CUDA version this PyTorch build reports (falsy stops the chain here)
    torch.version.cuda
    # a CUDA device is actually usable; checked before the per-device query below
    and torch.cuda.is_available()
    # the current device supports the bfloat16 dtype
    and torch.cuda.is_bf16_supported()
    # CUDA version is 11.0 or newer
    and packaging.version.parse(torch.version.cuda).release >= (11, 0)
    # NCCL backend is available and is at least version 2.10
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

  and LooseVersion(torch.version.cuda) >= "11.0"
