This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

BUG: Don't update multi-node env vars for single node training (#796)
Update environment variables for multi-node jobs
mebristo committed Sep 2, 2022
1 parent 8ffab94 commit 8cf63c8
Showing 2 changed files with 33 additions and 8 deletions.
38 changes: 31 additions & 7 deletions InnerEye/Azure/azure_runner.py
@@ -25,13 +25,15 @@

# Environment variables used for multi-node training
ENV_AZ_BATCHAI_MPI_MASTER_NODE = "AZ_BATCHAI_MPI_MASTER_NODE"
ENV_AZ_BATCH_MASTER_NODE = "AZ_BATCH_MASTER_NODE"
ENV_MASTER_ADDR = "MASTER_ADDR"
ENV_MASTER_IP = "MASTER_IP"
ENV_MASTER_PORT = "MASTER_PORT"
ENV_OMPI_COMM_WORLD_RANK = "OMPI_COMM_WORLD_RANK"
ENV_NODE_RANK = "NODE_RANK"
ENV_GLOBAL_RANK = "GLOBAL_RANK"
ENV_LOCAL_RANK = "LOCAL_RANK"
MASTER_PORT_DEFAULT = 6105


def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
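
For orientation: the AZ_* and OMPI_* variables above are populated by the AzureML / Open MPI launcher on multi-node runs, while MASTER_ADDR, MASTER_PORT and NODE_RANK are the names that PyTorch Lightning reads back when it initialises distributed training. A minimal, illustrative sketch of inspecting which of them the platform has set (the helper name is hypothetical and not part of this commit):

import logging
import os

# Hypothetical helper, not part of the commit: report which distributed-training
# variables are present in the current process environment.
def log_distributed_env_vars() -> None:
    candidates = [
        "AZ_BATCH_MASTER_NODE",        # AzureML multi-node jobs, "<address>:<port>"
        "AZ_BATCHAI_MPI_MASTER_NODE",  # MPI-based runs
        "MASTER_IP",                   # e.g. AKS
        "OMPI_COMM_WORLD_RANK",        # per-process rank set by Open MPI
    ]
    for name in candidates:
        value = os.environ.get(name)
        logging.info("%s = %s", name, value if value is not None else "<not set>")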
@@ -294,21 +296,43 @@ def set_environment_variables_for_multi_node() -> None:
"""
Sets the environment variables that PyTorch Lightning needs for multi-node training.
"""

if ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ:
if ENV_AZ_BATCH_MASTER_NODE in os.environ:
master_node = os.environ[ENV_AZ_BATCH_MASTER_NODE]
logging.debug(
f"Found AZ_BATCH_MASTER_NODE: {master_node} in environment variables")
# For AML BATCHAI
split_master_node_addr = master_node.split(":")
if len(split_master_node_addr) == 2:
master_addr, port = split_master_node_addr
os.environ[ENV_MASTER_PORT] = port
elif len(split_master_node_addr) == 1:
master_addr = split_master_node_addr[0]
else:
raise ValueError(f"Format not recognized: {master_node}")
os.environ[ENV_MASTER_ADDR] = master_addr
elif ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ and os.environ.get(ENV_AZ_BATCHAI_MPI_MASTER_NODE) != "localhost":
mpi_master_node = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
logging.debug(
f"Found AZ_BATCHAI_MPI_MASTER_NODE: {mpi_master_node} in environment variables")
# For AML BATCHAI
os.environ[ENV_MASTER_ADDR] = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
os.environ[ENV_MASTER_ADDR] = mpi_master_node
elif ENV_MASTER_IP in os.environ:
master_ip = os.environ[ENV_MASTER_IP]
logging.debug(
f"Found MASTER_IP: {master_ip} in environment variables")
# AKS
os.environ[ENV_MASTER_ADDR] = os.environ[ENV_MASTER_IP]
os.environ[ENV_MASTER_ADDR] = master_ip
else:
logging.info("No settings for the MPI central node found. Assuming that this is a single node training job.")
return

if ENV_MASTER_PORT not in os.environ:
os.environ[ENV_MASTER_PORT] = "6105"
os.environ[ENV_MASTER_PORT] = str(MASTER_PORT_DEFAULT)

if ENV_OMPI_COMM_WORLD_RANK in os.environ:
os.environ[ENV_NODE_RANK] = os.environ[ENV_OMPI_COMM_WORLD_RANK] # node rank is the world_rank from mpi run
world_rank = os.environ[ENV_OMPI_COMM_WORLD_RANK]
logging.debug(f"Found OMPI_COMM_WORLD_RANK: {world_rank} in environment variables")
os.environ[ENV_NODE_RANK] = world_rank # node rank is the world_rank from mpi run

env_vars = ", ".join(f"{var} = {os.environ[var]}" for var in [ENV_MASTER_ADDR, ENV_MASTER_PORT, ENV_NODE_RANK])
print(f"Distributed training: {env_vars}")
logging.info(f"Distributed training: {env_vars}")
3 changes: 2 additions & 1 deletion InnerEye/ML/runner.py
@@ -405,7 +405,8 @@ def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
else:
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
set_environment_variables_for_multi_node()
if self.azure_config.num_nodes > 1:
set_environment_variables_for_multi_node()
self.ml_runner = self.create_ml_runner()
self.ml_runner.setup(azure_run_info)
self.ml_runner.run()
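
The runner change above gates the call on the configured node count, so a single-node job never has MASTER_ADDR or MASTER_PORT rewritten from leftover MPI variables; that is the bug named in the commit title. A simplified sketch of the same guard, with the surrounding Runner class omitted and num_nodes passed in directly:

from InnerEye.Azure.azure_runner import set_environment_variables_for_multi_node

def maybe_set_multi_node_env(num_nodes: int) -> None:
    # Only touch the distributed-training environment variables when the job
    # actually spans more than one node; single-node jobs are left untouched.
    if num_nodes > 1:
        set_environment_variables_for_multi_node()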

