Description
Is your feature request related to a problem? Please describe.
I'm using DeepSpeed in a Kubernetes setup on a multi-GPU cluster. After training is done, the pod is shut down and I lose access to it. Therefore I would like to periodically save checkpoints to a remote cloud bucket.
Describe the solution you'd like
Ideally, the save_dir argument of engine.save_checkpoint() would also accept a remote cloud bucket location in addition to a directory on the local file system:
engine.save_checkpoint(save_dir='s3://my-bucket/checkpoints/')
engine.save_checkpoint(save_dir='gs://my-bucket/checkpoints/')
engine.save_checkpoint(save_dir='az://my-bucket/checkpoints/')
I think a solution could be implemented by detecting the cloud service prefix (s3://, gs://, az://) and then using a library like cloudpathlib to upload the checkpoint.
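To make the idea concrete, here is a minimal sketch of the prefix detection (the names `CLOUD_SCHEMES` and `is_cloud_path` are hypothetical, not part of DeepSpeed):

```python
from urllib.parse import urlsplit

# Hypothetical helper: schemes that identify a remote bucket location.
CLOUD_SCHEMES = {"s3", "gs", "az"}

def is_cloud_path(save_dir: str) -> bool:
    """Return True when save_dir points at a remote bucket rather than a local directory."""
    return urlsplit(save_dir).scheme in CLOUD_SCHEMES
```

save_checkpoint() could then branch on this check: write locally as it does today, or stage the checkpoint files to a temporary directory and hand them to something like cloudpathlib's CloudPath for upload.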
Describe alternatives you've considered
Currently, as a workaround, I first save the checkpoint locally on the Kubernetes pod and then upload it to S3:
```python
import logging
import os
from pathlib import Path

import s3fs
from omegaconf import OmegaConf

client_state = {}
if steps % cfg.checkpoint_interval == 0 and steps != 0:
    checkpoint_tag = f"ckpt_{steps:08d}"
    client_state["steps"] = steps
    client_state["last_epoch"] = epoch
    client_state["cfg_yaml"] = OmegaConf.to_yaml(cfg)
    local_dir = os.path.join(cfg.local_checkpoint_path, checkpoint_tag)
    if not os.path.exists(local_dir):
        logging.info(f"[RANK {rank}] Creating {local_dir}")
        os.makedirs(local_dir, exist_ok=True)
    # First save the checkpoint locally; must be done on all ranks.
    # Blocks to synchronise all ranks, and also calls .barrier() at the
    # end to ensure all ranks are done writing.
    logging.info(f"[RANK {rank}] Saving checkpoint locally to {local_dir} with client_state: {client_state}")
    model_engine.save_checkpoint(
        save_dir=cfg.local_checkpoint_path,
        tag=checkpoint_tag,
        client_state=client_state,
    )
    logging.info(f"[RANK {rank}] Saved checkpoint locally to {local_dir}")
    if rank == 0:
        # Now upload to S3
        logging.info(f"[RANK {rank}] Creating s3fs.core.S3FileSystem")
        s3 = s3fs.core.S3FileSystem(
            client_kwargs={"endpoint_url": os.environ.get("S3_ENDPOINT")}
        )
        logging.info(f"[RANK {rank}] Created s3fs.core.S3FileSystem: {s3}")
        # Prepare for checkpoint load by ensuring all parameters are partitioned
        # https://github.com/microsoft/DeepSpeed/blob/6273dffc2f192275a08268b683c309a328b52191/deepspeed/runtime/engine.py#L2752
        if model_engine.zero_optimization_partition_weights():
            model_engine.optimizer.checkpoint_event_prologue()
        # https://github.com/microsoft/DeepSpeed/blob/6273dffc2f192275a08268b683c309a328b52191/deepspeed/runtime/engine.py#L2789
        ckpt_list = model_engine._get_all_ckpt_names(
            cfg.local_checkpoint_path, checkpoint_tag
        )
        logging.info(f"ckpt_list: {ckpt_list}")
        for local_chkpt_path in ckpt_list:
            relative_path = Path(local_chkpt_path).relative_to(cfg.local_checkpoint_path)
            s3_chkpt_path = f"{os.environ['S3_BUCKET']}{relative_path}"
            logging.info(f"s3_chkpt_path: {s3_chkpt_path}")
            with open(local_chkpt_path, "rb") as local_fp, s3.open(
                s3_chkpt_path, "wb"
            ) as remote_fp:
                remote_fp.write(local_fp.read())
            logging.info(f"Wrote {local_chkpt_path} to {s3_chkpt_path}")
```

Additional context
Add any other context or screenshots about the feature request here.
See related discussion in #2638
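One detail of the workaround worth noting: the key computed as f"{os.environ['S3_BUCKET']}{relative_path}" only joins correctly when S3_BUCKET already carries a trailing slash. A small stdlib sketch that makes the join explicit (the helper name `to_bucket_key` is hypothetical):

```python
from pathlib import Path, PurePosixPath

def to_bucket_key(local_chkpt_path: str, local_root: str, bucket_prefix: str) -> str:
    # Mirror of the upload loop above: strip the local checkpoint root,
    # then join onto the bucket prefix with exactly one separator.
    relative = Path(local_chkpt_path).relative_to(local_root)
    return f"{bucket_prefix.rstrip('/')}/{PurePosixPath(relative)}"
```

With this helper, the loop would set s3_chkpt_path = to_bucket_key(local_chkpt_path, cfg.local_checkpoint_path, os.environ["S3_BUCKET"]) and work regardless of whether the environment variable ends in a slash.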