
[REQUEST] Saving checkpoints to cloud bucket #2701

@BioGeek

Description
Is your feature request related to a problem? Please describe.

I'm using DeepSpeed in a Kubernetes setup on a multi-GPU cluster. After training finishes, the pod is shut down and I lose access to its local file system. Therefore I would like to periodically save checkpoints to a remote cloud bucket.

Describe the solution you'd like

Ideally, the save_dir argument of engine.save_checkpoint() would also accept a remote cloud bucket location in addition to a directory on the local file system:

engine.save_checkpoint(save_dir='s3://my-bucket/checkpoints/')
engine.save_checkpoint(save_dir='gs://my-bucket/checkpoints/')
engine.save_checkpoint(save_dir='az://my-bucket/checkpoints/')

I think a solution could be implemented by detecting the cloud service prefix (s3://, gs://, az://) and then using a library like cloudpathlib to upload the checkpoint.
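As a rough sketch of what I have in mind (the helper names here are hypothetical, not part of DeepSpeed; cloudpathlib dispatches on the URI scheme, so the same code would handle all three providers):

```python
from pathlib import Path
from urllib.parse import urlparse

CLOUD_SCHEMES = {"s3", "gs", "az"}

def is_cloud_path(save_dir: str) -> bool:
    """Return True if save_dir points at a supported cloud bucket."""
    return urlparse(save_dir).scheme in CLOUD_SCHEMES

def upload_checkpoint(local_dir: str, save_dir: str) -> None:
    """Mirror every file under local_dir into the bucket, preserving layout."""
    # Deferred import: cloudpathlib would only be needed for cloud targets.
    from cloudpathlib import CloudPath

    dest = CloudPath(save_dir)
    for f in Path(local_dir).rglob("*"):
        if f.is_file():
            (dest / str(f.relative_to(local_dir))).upload_from(f)
```

save_checkpoint() could then branch on is_cloud_path(save_dir): write to a temporary local directory first, then call upload_checkpoint() from rank 0.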

Describe alternatives you've considered

Currently, as a workaround, I first save the checkpoint locally on the Kubernetes pod and then upload the checkpoint to S3:

import logging
import os
from pathlib import Path

import s3fs
from omegaconf import OmegaConf

# Inside the training loop; `cfg`, `steps`, `epoch`, `rank` and
# `model_engine` come from the surrounding training code.
client_state = {}

if steps % cfg.checkpoint_interval == 0 and steps != 0:
    checkpoint_tag = f"ckpt_{steps:08d}"

    client_state["steps"] = steps
    client_state["last_epoch"] = epoch
    client_state["cfg_yaml"] = OmegaConf.to_yaml(cfg)

    local_dir = os.path.join(cfg.local_checkpoint_path, checkpoint_tag)
    if not os.path.exists(local_dir):
        logging.info(f"[RANK {rank}] Creating {local_dir}")
        os.makedirs(local_dir, exist_ok=True)

    # First save the checkpoint locally; this must be done on all ranks.
    # Blocks to synchronise all ranks, and also calls .barrier() at the
    # end to ensure all ranks are done writing.
    logging.info(f"[RANK {rank}] Saving checkpoint locally to {local_dir} with client_state: {client_state}")
    model_engine.save_checkpoint(
        save_dir=cfg.local_checkpoint_path,
        tag=checkpoint_tag,
        client_state=client_state,
    )
    logging.info(f"[RANK {rank}] Saved checkpoint locally to {local_dir}")

    if rank == 0:
        # Now upload to S3.
        logging.info(f"[RANK {rank}] Creating s3fs.core.S3FileSystem")
        s3 = s3fs.core.S3FileSystem(
            client_kwargs={"endpoint_url": os.environ.get("S3_ENDPOINT")}
        )
        logging.info(f"[RANK {rank}] Created s3fs.core.S3FileSystem: {s3}")

        # Prepare for checkpoint load by ensuring all parameters are partitioned
        # https://github.com/microsoft/DeepSpeed/blob/6273dffc2f192275a08268b683c309a328b52191/deepspeed/runtime/engine.py#L2752
        if model_engine.zero_optimization_partition_weights():
            model_engine.optimizer.checkpoint_event_prologue()

        # https://github.com/microsoft/DeepSpeed/blob/6273dffc2f192275a08268b683c309a328b52191/deepspeed/runtime/engine.py#L2789
        ckpt_list = model_engine._get_all_ckpt_names(
            cfg.local_checkpoint_path, checkpoint_tag
        )
        logging.info(f"ckpt_list: {ckpt_list}")
        for local_chkpt_path in ckpt_list:
            relative_path = Path(local_chkpt_path).relative_to(cfg.local_checkpoint_path)
            s3_chkpt_path = f"{os.environ['S3_BUCKET']}{relative_path}"
            logging.info(f"s3_chkpt_path: {s3_chkpt_path}")

            with open(local_chkpt_path, "rb") as local_fp, s3.open(
                s3_chkpt_path, "wb"
            ) as remote_fp:
                remote_fp.write(local_fp.read())
                logging.info(f"Wrote {local_chkpt_path} to {s3_chkpt_path}")
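
A simpler variant of the upload step would use fsspec's recursive put, which copies a whole directory tree in one call, instead of streaming each file by hand (bucket_path and upload_dir are hypothetical helpers assuming the same S3_ENDPOINT/S3_BUCKET environment variables as above):

```python
import os

def bucket_path(tag: str) -> str:
    """Destination key for a checkpoint tag, assuming an S3_BUCKET env var."""
    return f"{os.environ['S3_BUCKET'].rstrip('/')}/{tag}"

def upload_dir(local_dir: str, tag: str) -> None:
    """Recursively upload a checkpoint directory with a single s3fs call."""
    import s3fs  # deferred so the helper above stays importable without it

    s3 = s3fs.S3FileSystem(
        client_kwargs={"endpoint_url": os.environ.get("S3_ENDPOINT")}
    )
    # fsspec's recursive put mirrors the local directory tree into the bucket.
    s3.put(local_dir, bucket_path(tag), recursive=True)
```

This avoids reaching into the private _get_all_ckpt_names API, at the cost of uploading everything under local_dir rather than only the files DeepSpeed wrote.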

Additional context

See related discussion in #2638
