FSDP - Saving and Loading Models and Optimizer checkpoints

This tutorial covers saving and loading via FULL_STATE_DICT, where the full model and optimizer state are assembled in CPU memory on rank 0.  This means CPU memory places an upper limit on the size of model that can be saved this way; assume models of around 20B will fit.

For larger models, we'll use LOCAL_STATE_DICT, which will be covered separately.

In [1]:
import torch
torch.__version__

'1.13.0.dev20220627+cu113'

In [3]:
# import FSDP with two additional items:
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,  # < -- configuration policy for full_state actions
    StateDictType,  # < -- enum, use to confirm what type of states we are handling (in this case FULL_STATE)
)

In [None]:
# build model
# (build_model and cfg are defined outside this snippet)
model = build_model(cfg.model_name)

In [None]:
# optionally preload a model checkpoint into the model
# runs only when cfg.load_model_checkpoint is set and the configured
# checkpoint type is FULL_STATE_DICT
    if (
        cfg.load_model_checkpoint
        and cfg.checkpoint_type == StateDictType.FULL_STATE_DICT
    ):
        model_checkpointing.load_model_checkpoint(model, rank, cfg)

In [4]:
# the load function is just like regular PyTorch checkpoint loading for FULL_STATE
# it can be called from every rank: only rank 0 loads, all other ranks return immediately
def load_model_checkpoint(model, rank, cfg, verbose=True):
    """Load a full-state model checkpoint into ``model`` on rank 0.

    Must be called *before* the model is passed to FSDP.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        rank: rank of the calling process; any rank other than 0 is a no-op.
        cfg: config object providing ``checkpoint_folder``,
            ``checkpoint_model_filename`` and ``verbose``.
        verbose: pass False to silence the confirmation message even when
            ``cfg.verbose`` is set.

    Returns:
        None. Returns early (after printing a notice) when the checkpoint
        file does not exist.
    """

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        print(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    # load the checkpoint; map_location="cpu" keeps the tensors on the cpu
    # regardless of the device they were saved from
    model_checkpoint = torch.load(full_state_dict_model_path, map_location="cpu")
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    # honor the explicit verbose argument as well as the config flag
    if verbose and cfg.verbose:
        print("model checkpoint loaded to rank0 cpu")

In [None]:
# init FSDP and shard the model
# ----- main FSDP init -----------
# NOTE(review): load_model_checkpoint only loads weights on rank 0; confirm that
# rank 0's weights reach the other ranks when the model is wrapped (e.g. via
# FSDP's sync_module_states=True, which is not passed here)
    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        mixed_precision=mp_policy,
        # backward_prefetch=prefetch_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 style: params, grads and optimizer state sharded (ZeRO-2 would be SHARD_GRAD_OP)
        # cpu_offload= cpu_policy,
        forward_prefetch=True,
    )

In [None]:
# prepare optimizer with sharded model
# (created from model.parameters() after the model has been wrapped in FSDP)
# optimizer ----------
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=8e-4, weight_decay=0.005
    )

In [None]:
# load optimizer checkpoint (only when cfg.load_optimizer is set)
# there is no rank guard in this snippet: the loader itself decides what each rank does
    if cfg.load_optimizer:
        model_checkpointing.load_optimizer_checkpoint(model, optimizer, rank, cfg)

In [None]:
def load_optimizer_checkpoint(model, optimizer, rank, cfg):
    """Load an FSDP optimizer full-state checkpoint using the scatter method.

    Only rank 0 reads the full optimizer state dict from disk. It is then
    scattered so that each rank receives its own shard, and that shard is
    loaded into ``optimizer``.

    Args:
        model: the FSDP-wrapped model the optimizer state belongs to.
        optimizer: optimizer that receives the sharded state.
        rank: rank of the calling process.
        cfg: config object providing ``checkpoint_folder``,
            ``optimizer_checkpoint_file`` and ``verbose``.

    Returns:
        None. Returns early (after printing a warning) when the checkpoint
        file does not exist.
    """

    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file

    # NOTE(review): this check runs on every rank; it presumably relies on all
    # ranks seeing the same filesystem, otherwise the ranks that return here
    # never join the scatter below - confirm
    if not opt_file_path.is_file():
        print(
            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(opt_file_path, map_location="cpu")

        if cfg.verbose:
            print("loaded full osd on rank 0")

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    # the scattered shard only takes effect once the optimizer loads it
    optimizer.load_state_dict(sharded_osd)

    if cfg.verbose:
        print(f"optimizer shard loaded on rank {rank}")

In [None]:
# Training Loop!

In [None]:
# after checking metrics, we have decided to save the model and optimizer
# model checkpointing

# create the save policy once and reuse it, rather than rebuilding it on every save
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

In [None]:
def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """Save the model via rank0 cpu streaming and full_state_dict.

    Args:
        model: the FSDP-wrapped model to checkpoint.
        optimizer: unused here; kept so the call signature stays unchanged.
        rank: rank of the calling process; only rank 0 writes the file.
        cfg: config object providing ``checkpoint_type``, ``checkpoint_folder``,
            ``model_save_name`` and ``verbose``.
        epoch: epoch number appended to the saved file name.

    Returns:
        None. Returns without saving when ``cfg.checkpoint_type`` is not
        ``StateDictType.FULL_STATE_DICT``.
    """

    # saving with rank0 cpu - any other checkpoint type is not handled here,
    # so actually abort instead of falling through to the save
    if cfg.checkpoint_type != StateDictType.FULL_STATE_DICT:
        print(f" unable to handle checkpoint type {cfg.checkpoint_type}, aborting")
        return

    # state_dict() is requested on every rank, not only on rank 0
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

    if cfg.verbose:
        print(f"saving process: rank {rank}  done w model state_dict")

    if rank == 0:
        print("--> saving model ...")
        # create save path
        save_dir = Path.cwd() / cfg.checkpoint_folder
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_save_name + "-" + str(epoch) + ".pt"
        save_full_path = save_dir / save_name

        # save model
        torch.save(cpu_state, save_full_path)

        if cfg.verbose:
            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}")



In [None]:
# and saving our optimizer
def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """Collect the full optimizer state dict and write it to disk from rank 0.

    Every rank takes part in gathering the state; only rank 0 creates the
    checkpoint folder and saves the file. The file name is built from
    ``cfg.optimizer_name``, ``cfg.model_save_name`` and ``epoch``.
    """

    if cfg.verbose:
        print(f"--> optim state call on rank {rank}")

    # gather the sharded optimizer states into one full state dict
    full_state = FSDP.full_optim_state_dict(model, optimizer)

    if cfg.verbose:
        print(f"optim state dict ready on {rank} and len of {len(full_state)}")

    # nothing left to do unless we are the rank that writes the file
    if rank != 0:
        return

    target_dir = Path.cwd() / cfg.checkpoint_folder
    target_dir.mkdir(parents=True, exist_ok=True)

    file_name = "-".join((cfg.optimizer_name, cfg.model_save_name, str(epoch))) + ".pt"
    target_path = target_dir / file_name

    # saving can be slow (e.g. up to ~3 minutes / 17GB for a 1.5B model), so
    # always announce progress to make clear the job has not hung
    print("--> saving optimizer state...")
    torch.save(full_state, target_path)
    print(f"--> saved {target_path} to disk")
