Skip to content

Commit

Permalink
launcher/launcher_helper.py: fix PMI name and add EnvironmentError (m…
Browse files Browse the repository at this point in the history
…icrosoft#5025)

Hi, for my last PR microsoft#4699
about launcher_helper, it mistakenly used two "PMIX". In this PR I
corrected them to be "PMIX" and "PMI". And I also added
_EnvironmentError_ to make sure env not get _NONE_ type, otherwise it
would trigger env setting error.
  • Loading branch information
YizhouZ authored and amaurya committed Feb 17, 2024
1 parent bda5173 commit 97861c9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion deepspeed/launcher/launcher_helper.py
Expand Up @@ -60,6 +60,8 @@ def env_mapping(env, rank_name_list=None, local_rank_name_list=None):
rank = env.get(rank_name)
elif rank != env.get(rank_name):
raise EnvironmentError(f"rank number doesn't match!")
if rank == None:
raise EnvironmentError(f"rank number is not in current env!")
env['RANK'] = rank

local_rank = None
Expand All @@ -69,6 +71,8 @@ def env_mapping(env, rank_name_list=None, local_rank_name_list=None):
local_rank = env.get(local_rank_name)
elif local_rank != env.get(local_rank_name):
raise EnvironmentError(f"local_rank number doesn't match!")
if local_rank == None:
raise EnvironmentError(f"rank number is not in current env!")
env['LOCAL_RANK'] = local_rank

return env
Expand All @@ -81,7 +85,7 @@ def main(args=None):

args.launcher = args.launcher.lower()
if args.launcher == MPICH_LAUNCHER:
rank_name_list = ["PMIX_RANK"] + ["PMIX_RANK"]
rank_name_list = ["PMIX_RANK"] + ["PMI_RANK"]
local_rank_name_list = ["PALS_LOCAL_RANKID"] + ["MPI_LOCALRANKID"]
env = env_mapping(env, rank_name_list=rank_name_list, local_rank_name_list=local_rank_name_list)
else:
Expand Down

0 comments on commit 97861c9

Please sign in to comment.