Using FSDP Checkpoint activations

In [1]:
# Ensure you are on a PyTorch nightly build from June 18, 2022 or later.
import torch
# Displaying the version string lets you confirm which build is installed.
torch.__version__

'1.13.0.dev20220618+cu113'

In [3]:
#some basic imports
import torch
from functools import partial

In [11]:
# main FSDP imports
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)

In [5]:
# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

In [7]:
# first step - we have to make a check function to find what layers we want to checkpoint.
# For transformers, you'll want to use the same layers as you used for wrapping your transformer. 
# (Please view the using transformer wrapper tutorial if needed first).

# we'll checkpoint a DeepVit model, so we'll want to look for the Residual layer class.
from vit_pytorch import Residual



In [8]:
# Second, define the submodule check function: it reports whether a given
# submodule is one of the layers we want to checkpoint (a Residual layer).
def check_fn(submodule):
    """Return True when `submodule` is a Residual instance, False otherwise."""
    return isinstance(submodule, Residual)

In [9]:
# Build a non-reentrant checkpoint wrapper.
# partial() pre-binds the options below to checkpoint_wrapper; the
# non-reentrant implementation is chosen for best performance and
# CPU offload is switched off.

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    offload_to_cpu=False,
)

In [None]:
# Important: the next step is to initialize your model with FSDP.
# Activation checkpointing is shard aware, so it must be applied **after** FSDP init.
# NOTE: `model`, `wrapping_policy`, `mp_policy` and `model_sharding_strategy`
# are not defined in the cells shown here; create them before running this cell.
model = FSDP(
    model,
    auto_wrap_policy=wrapping_policy,
    mixed_precision=mp_policy,
    sharding_strategy=model_sharding_strategy,
    device_id=torch.cuda.current_device(),  # streaming init
)

In [None]:
# Finally, apply the checkpoint wrapper and the submodule check function to
# your sharded model to complete the activation checkpointing process.
# The helper is called by the name imported earlier in this notebook,
# `apply_activation_checkpointing`; the previous spelling with a `_wrapper`
# suffix was never imported and raised a NameError.

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

In [12]:
# that's it!  Your model is now both sharded and checkpoint activation ready. 

 #### best practices - 
 In general, you can expect to see roughly a 20-25% training time slowdown with activation checkpointing. 

 In exchange, you'll usually free up 33-38% of GPU memory.  
 You can use that freed up memory by greatly increasing your batch size.
 The increased batch size can result in substantial (2-3x+) total training time improvements due to much greater throughput.    
 You can maximize your throughput with a bit of tuning: use as much GPU memory as you can without triggering cudaMalloc retries.
 (future tutorial on this).