In [1]:
import torch
# Display the installed PyTorch build as this cell's output.
torch.__version__

'1.13.0.dev20220719+cu113'

Distributed checkpoints, also known as local state dict saving! 

Unlike saving with full_state_dict, saving a distributed checkpoint does not create a single .pt file. Instead, it generates hundreds to thousands of smaller files, all within one directory.

This makes it possible to save gigantic models that would otherwise exceed CPU memory: a full state dict tries to assemble the whole model in CPU memory, whereas distributed checkpointing does not.

In [2]:
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)



In [3]:
from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)

In [4]:
from pathlib import Path

In [None]:
def load_distributed_model_checkpoint(model, rank, cfg):
    """Load a distributed (local state dict) checkpoint into ``model`` in place.

    The checkpoint is read from
    ``<cwd>/<cfg.dist_checkpoint_root_folder>/<cfg.dist_checkpoint_folder>-<cfg.model_name>``.
    An absolute root folder replaces the ``<cwd>`` prefix, and an empty root
    folder resolves to the current working directory.

    Args:
        model: FSDP-wrapped module whose local state dict is read, filled from
            the checkpoint, and loaded back.
        rank: Rank of the calling process; used for log messages only.
        cfg: Config object providing ``checkpoint_type``,
            ``dist_checkpoint_root_folder`` (str or path-like),
            ``dist_checkpoint_folder`` and ``model_name``.

    Returns:
        None. The function does nothing (and prints nothing) when
        ``cfg.checkpoint_type`` is not ``StateDictType.LOCAL_STATE_DICT``, and
        returns early when the checkpoint directory does not exist.
    """
    if cfg.checkpoint_type != StateDictType.LOCAL_STATE_DICT:
        return

    print(f"loading distributed checkpoint, rank {rank}...")

    # Join the segments with pathlib instead of concatenating strings: this
    # accepts path-like roots and keeps an empty root from collapsing the
    # target into a path anchored at the filesystem root.
    checkdir = (
        Path.cwd()
        / cfg.dist_checkpoint_root_folder
        / f"{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )

    if not checkdir.exists():
        # Only rank 0 reports the miss so the message is not repeated per rank.
        if rank == 0:
            print("No checkpoint directory found...skipping")
        return

    reader = FileSystemReader(checkdir)

    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        # Start from the model's current local state dict; the checkpoint
        # loader's return value is not used, so it is expected to fill this
        # dict in place before it is handed back to the model.
        state_dict = model.state_dict()
        load_state_dict(state_dict, reader)
        model.load_state_dict(state_dict)

    print(f"--> local state loaded on rank {rank}")
        

In [5]:
def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
    """Save ``model`` as a distributed (local state dict) checkpoint.

    The checkpoint is written to
    ``<cwd>/<cfg.dist_checkpoint_root_folder>/<cfg.dist_checkpoint_folder>-<cfg.model_name>``.
    An absolute root folder replaces the ``<cwd>`` prefix, and an empty root
    folder resolves to the current working directory.

    Args:
        model: FSDP-wrapped module whose local state dict is saved.
        rank: Rank of the calling process; only rank 0 prints progress.
        cfg: Config object providing ``checkpoint_type``,
            ``dist_checkpoint_root_folder`` (str or path-like),
            ``dist_checkpoint_folder`` and ``model_name``.
        epoch: Currently unused. The target directory does not depend on it,
            so repeated calls with the same ``cfg`` write to the same
            directory.

    Returns:
        None. Nothing is written when ``cfg.checkpoint_type`` is not
        ``StateDictType.LOCAL_STATE_DICT``.
    """
    if rank == 0:
        print("Starting distributed checkpoint save...")

    # Only the local-state-dict flavor is handled by this function.
    if cfg.checkpoint_type != StateDictType.LOCAL_STATE_DICT:
        return

    # Join the segments with pathlib instead of concatenating strings: this
    # accepts path-like roots and keeps an empty root from sending the
    # checkpoint to a path anchored at the filesystem root.
    save_dir = (
        Path.cwd()
        / cfg.dist_checkpoint_root_folder
        / f"{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )
    writer = FileSystemWriter(save_dir)

    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        state_dict = model.state_dict()

    # Write out the distributed checkpoint.
    save_state_dict(state_dict, writer)

    if rank == 0:
        print(f"--> distributed checkpoint saved at {save_dir}")

        

In [6]:
# Note: use a new directory name for every distributed checkpoint. Saving into an existing checkpoint
# directory overwrites the previous .metadata file, which holds the info on all the GUID-named files within it.