In [1]:
import torch
# Display the installed PyTorch build (this notebook was run on the version shown in the output below).
torch.__version__

'1.13.0.dev20220711+cu113'

FSDP mixed precision supports BFloat16 and FP16 through fine-grained policies that separately control the precision of parameters, gradient communication, and buffers.

In [5]:
# import Mixed Precision class along with FSDP import:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)

In [24]:
# Pure BF16 policy: parameters are cast to bfloat16 for forward/backward compute and
# gradients are communicated (reduced) in bfloat16. buffer_dtype is left at its default
# (None), i.e. this policy does not set a buffer dtype.
bfloatPolicy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameter precision
    reduce_dtype=torch.bfloat16,  # gradient communication precision
)

# Dtypes can be mixed per category: compute in bfloat16, but communicate gradients
# and keep buffers in float32.
comboPolicy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameter precision
    reduce_dtype=torch.float32,   # gradient communication precision
    buffer_dtype=torch.float32,   # buffer precision
)


In [23]:
# Enabling mixed precision is just one extra argument at FSDP construction time.
# NOTE(review): `model`, `my_auto_wrap_policy`, `prefetch_policy` and `cfg` are not defined
# anywhere in this notebook; presumably they come from the surrounding training script
# (and a distributed process group is initialized first) — TODO confirm before running.
model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    mixed_precision=bfloatPolicy,  # <-- the mixed precision policy
    backward_prefetch=prefetch_policy,
    sharding_strategy=cfg.sharding_strategy,
    device_id=torch.cuda.current_device(),
    forward_prefetch=True,
)

### BFloat16 offers a significant training speedup — nearly 2x, and it can go higher depending on available memory
![val loss comparison](./bfloat_training.png)


### Refresher on the mixed-precision types
![](./datatypes_mp.png)
Image credit: Nvidia - https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/

In [7]:
# Details:

# 1 - BatchNorm is automatically kept in fp32 for precision (overrides buffer policy, no user action needed)
# 2 - Local gradients during backprop are also always fp32 (automatic, no user action needed)
# 3 - Models are always saved in fp32 format for max portability

In [11]:
# bfloat16 support verification imports (network and gpu native support)
from pkg_resources import packaging
import torch.cuda.nccl as nccl
import torch.distributed as dist

In [20]:
# Full check: native bf16 on the GPU *and* bf16 support in the communication stack
# (CUDA >= 11 and NCCL >= 2.10, the NCCL release that added bfloat16 reductions).
# Wrapped in bool() so the result is always True/False; a bare `and`-chain would
# evaluate to None (or a version string) instead.
verify_bfloat_support = bool(
    torch.version.cuda
    and torch.cuda.is_available()
    and torch.cuda.is_bf16_supported()
    # Compute capability 8.x (Ampere) or newer has native bf16 math; older GPUs such as
    # V100 (7.0) only emulate it. Checked explicitly because newer PyTorch releases may
    # count emulated support in is_bf16_supported().
    and torch.cuda.get_device_capability()[0] >= 8
    # Compare (major, minor) ints parsed from e.g. "11.3" rather than relying on the
    # deprecated pkg_resources.packaging helper.
    and tuple(int(part) for part in torch.version.cuda.split(".")[:2]) >= (11, 0)
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

# Simple check: GPU-side bf16 support only — does NOT confirm NCCL/network can handle bf16.
# Guarded by is_available() because is_bf16_supported() may raise on machines without a GPU.
# NOTE(review): newer PyTorch releases may also report emulated bf16 here; prefer the full check.
basic_bfloat_ready = bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())

Important point — always verify that native bfloat16 support is available!  
V100 GPUs will appear to 'support' bfloat16 if you just run without checking, but it's emulated,
and training will run much, much slower (worse than fp32!).

In [22]:
# Should be True before relying on bfloatPolicy (native GPU bf16 plus CUDA/NCCL support, per the check above).
verify_bfloat_support

True

In [None]:
# FP16 needs more than a MixedPrecision policy: the loss must also be scaled with
# FSDP's sharded gradient scaler so small fp16 gradients do not underflow.
# `scaler` defaults to None so the training-loop branch below works for bf16/fp32 runs too
# (previously it was undefined whenever cfg.use_fp16 was False -> NameError).
scaler = None
if cfg.use_fp16:
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()


# in training loop:
# NOTE(review): `output`, `optimizer` and `cfg` are assumed to come from the surrounding
# training script — they are not defined in this notebook.
loss = output["loss"]
if scaler is not None:
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales grads; skips the step if they contain inf/NaN
    scaler.update()  # adjust scaling for next minibatch
else:
    loss.backward()
    optimizer.step()

In [2]:
# Recommendation: use BFloat16 if possible.
# FP16 runs about 4% slower than BFloat16, all else being equal, likely due to the cost of loss scaling.
# The scaler has to guess how much to scale the loss by;
# a bad guess produces inf/NaN gradients, and that mini-batch's optimizer step is skipped (wasted work).

In [3]:
# TF32 can be used as well, but it is not currently controlled via the FSDP mixed precision policy.

In [None]:
# Opt in to TF32 math for float32 matrix multiplies.
# Since PyTorch 1.12 this flag defaults to False, so it has to be enabled explicitly.
torch.backends.cuda.matmul.allow_tf32 = True

# Trade-off: TF32 keeps 10 mantissa bits vs. 7 for bfloat16, so it is more precise than
# bf16 but not as fast — while still being faster than plain FP32.
# Nvidia's own A100 material confirms bf16 is the faster option:
# "For maximum performance, the A100 also has enhanced 16-bit math capabilities.
#  It supports both FP16 and Bfloat16 (BF16) at double the rate of TF32."
