Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,31 @@ def name(self):

def validate_args(self):
super().validate_args()

# Validate and set MPI environment variables
self._setup_mpi_environment()

#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

def _setup_mpi_environment(self):
"""Sets up MPI-related environment variables or raises an error if they're missing."""

required_vars = ['OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE']

# Check if all these are present
if not all(var in os.environ for var in required_vars):
raise EnvironmentError("MPI environment variables are not set. "
"Ensure you are running the script with an MPI-compatible launcher.")

# Now safe to read all
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']

def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values())

Expand Down
71 changes: 66 additions & 5 deletions tests/unit/launcher/test_multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def runner_info():
return env, hosts, world_info, args


@pytest.fixture
def mock_mpi_env(monkeypatch):
# Provide the 3 required MPI variables:
monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '0')
monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '1')


def test_pdsh_runner(runner_info):
env, resource_pool, world_info, args = runner_info
runner = mnrunner.PDSHRunner(args, world_info)
Expand All @@ -27,34 +35,87 @@ def test_pdsh_runner(runner_info):
assert env['PDSH_RCMD_TYPE'] == 'ssh'


def test_openmpi_runner(runner_info):
def test_openmpi_runner(runner_info, mock_mpi_env):
env, resource_pool, world_info, args = runner_info
runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
cmd = runner.get_cmd(env, resource_pool)
assert cmd[0] == 'mpirun'
assert 'eth0' in cmd


def test_btl_nic_openmpi_runner(runner_info):
def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
env, resource_pool, world_info, _ = runner_info
args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py'])

runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
cmd = runner.get_cmd(env, resource_pool)
assert 'eth0' not in cmd
assert 'eth1' in cmd


def test_btl_nic_two_dashes_openmpi_runner(runner_info):
def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
env, resource_pool, world_info, _ = runner_info
args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
cmd = runner.get_cmd(env, resource_pool)
assert 'eth0' not in cmd
assert 'eth1' in cmd


def test_setup_mpi_environment_success():
"""Test that _setup_mpi_environment correctly sets environment variables when MPI variables exist."""
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
os.environ['OMPI_COMM_WORLD_RANK'] = '1'
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

runner = mnrunner.OpenMPIRunner(args, None, None)
# Set up the MPI environment
runner._setup_mpi_environment()

assert os.environ['LOCAL_RANK'] == '0'
assert os.environ['RANK'] == '1'
assert os.environ['WORLD_SIZE'] == '2'

# Clean up environment
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
del os.environ['OMPI_COMM_WORLD_RANK']
del os.environ['OMPI_COMM_WORLD_SIZE']
del os.environ['LOCAL_RANK']
del os.environ['RANK']
del os.environ['WORLD_SIZE']


def test_setup_mpi_environment_missing_variables():
"""Test that _setup_mpi_environment raises an EnvironmentError when MPI variables are missing."""

# Clear relevant environment variables
os.environ.pop('OMPI_COMM_WORLD_LOCAL_RANK', None)
os.environ.pop('OMPI_COMM_WORLD_RANK', None)
os.environ.pop('OMPI_COMM_WORLD_SIZE', None)

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
mnrunner.OpenMPIRunner(args, None, None)


def test_setup_mpi_environment_fail():
"""Test that _setup_mpi_environment fails if only partial MPI variables are provided."""
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
os.environ.pop('OMPI_COMM_WORLD_RANK', None) # missing variable
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
runner = mnrunner.OpenMPIRunner(args, None, None)

# Clean up environment
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
del os.environ['OMPI_COMM_WORLD_SIZE']


def test_mpich_runner(runner_info):
env, resource_pool, world_info, args = runner_info
runner = mnrunner.MPICHRunner(args, world_info, resource_pool)
Expand Down