From d7faf37cccf0c5f3ef9ce70c49647ba160b949ae Mon Sep 17 00:00:00 2001 From: Yizhou Wang Date: Mon, 29 Jan 2024 09:56:45 +0000 Subject: [PATCH] fix PMI name and add EnvironmentError --- deepspeed/launcher/launcher_helper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepspeed/launcher/launcher_helper.py b/deepspeed/launcher/launcher_helper.py index d4a8755be2e6..649364a4dd62 100644 --- a/deepspeed/launcher/launcher_helper.py +++ b/deepspeed/launcher/launcher_helper.py @@ -60,6 +60,8 @@ def env_mapping(env, rank_name_list=None, local_rank_name_list=None): rank = env.get(rank_name) elif rank != env.get(rank_name): raise EnvironmentError(f"rank number doesn't match!") + if rank == None: + raise EnvironmentError(f"rank number is not in current env!") env['RANK'] = rank local_rank = None @@ -69,6 +71,8 @@ def env_mapping(env, rank_name_list=None, local_rank_name_list=None): local_rank = env.get(local_rank_name) elif local_rank != env.get(local_rank_name): raise EnvironmentError(f"local_rank number doesn't match!") + if local_rank == None: + raise EnvironmentError(f"rank number is not in current env!") env['LOCAL_RANK'] = local_rank return env @@ -81,7 +85,7 @@ def main(args=None): args.launcher = args.launcher.lower() if args.launcher == MPICH_LAUNCHER: - rank_name_list = ["PMIX_RANK"] + ["PMIX_RANK"] + rank_name_list = ["PMIX_RANK"] + ["PMI_RANK"] local_rank_name_list = ["PALS_LOCAL_RANKID"] + ["MPI_LOCALRANKID"] env = env_mapping(env, rank_name_list=rank_name_list, local_rank_name_list=local_rank_name_list) else: