Skip to content
Permalink
Browse files

Added support for inclusive NIC names (#1808)

* changed network_interfaces to network_interface

Signed-off-by: fardin abdi <fardin@uber.com>

* removed extra code

Signed-off-by: fardin abdi <fardin@uber.com>

* removed duplicate

Signed-off-by: fardin abdi <fardin@uber.com>

* fixed comment

Signed-off-by: fardin abdi <fardin@uber.com>

* fixed comment

Signed-off-by: fardin abdi <fardin@uber.com>

* fixed comment

Signed-off-by: fardin abdi <fardin@uber.com>

* fixed None type

Signed-off-by: fardin abdi <fardin@uber.com>
  • Loading branch information
abditag2 committed Mar 25, 2020
1 parent 4371007 commit a38040921420f6e03f46b029b3300bb0f7094062
@@ -41,8 +41,8 @@ def __init__(self, all_task_addresses):


class BasicDriverService(network.BasicService):
def __init__(self, num_proc, name, key, nic):
super(BasicDriverService, self).__init__(name, key, nic)
def __init__(self, num_proc, name, key, nics):
super(BasicDriverService, self).__init__(name, key, nics)
self._num_proc = num_proc
self._all_task_addresses = {}
self._task_addresses_for_driver = {}
@@ -53,8 +53,8 @@ def __init__(self, result):


class BasicTaskService(network.BasicService):
def __init__(self, name, key, nic, service_env_keys):
super(BasicTaskService, self).__init__(name, key, nic)
def __init__(self, name, key, nics, service_env_keys):
super(BasicTaskService, self).__init__(name, key, nics)
self._initial_registration_complete = False
self._wait_cond = threading.Condition()
self._service_env_keys = service_env_keys
@@ -85,10 +85,10 @@ def read(self, rfile):


class BasicService(object):
def __init__(self, service_name, key, nic):
def __init__(self, service_name, key, nics):
self._service_name = service_name
self._wire = Wire(key)
self._nic = nic
self._nics = nics
self._server, _ = find_port(
lambda addr: socketserver.ThreadingTCPServer(
addr, self._make_handler()))
@@ -124,16 +124,16 @@ def _handle(self, req, client_address):
def _get_local_addresses(self):
result = {}
for intf, intf_addresses in psutil.net_if_addrs().items():
if self._nic and intf != self._nic:
if self._nics and intf not in self._nics:
continue
for addr in intf_addresses:
if addr.family == socket.AF_INET:
if intf not in result:
result[intf] = []
result[intf].append((addr.address, self._port))
if not result and self._nic:
if not result and self._nics:
raise NoValidAddressesFound(
'No available network interface found matching user provided interface: {}'.format(self._nic))
'No available network interface found matching user provided interface: {}'.format(self._nics))
return result

def addresses(self):
@@ -18,7 +18,7 @@ class Settings(object):

def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None,
binding_args=None, key=None, timeout=None, num_hosts=None, num_proc=None,
hosts=None, output_filename=None, run_func_mode=None, nic=None):
hosts=None, output_filename=None, run_func_mode=None, nics=None):
"""
:param verbose: level of verbosity
:type verbose: int
@@ -45,8 +45,8 @@ def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None,
:type output_filename: string
:param run_func_mode: whether it is run function mode
:type run_func_mode: boolean
:param nic: specify the NIC for tcp network communication.
:type nic: string
:param nics: specify the NICs to be used for tcp network communication.
:type nics: string
"""
self.verbose = verbose
self.ssh_port = ssh_port
@@ -60,5 +60,5 @@ def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None,
self.hosts = hosts
self.output_filename = output_filename
self.run_func_mode = run_func_mode
self.nic = nic
self.nics = nics

@@ -28,10 +28,10 @@
class HorovodRunDriverService(driver_service.BasicDriverService):
NAME = 'horovodrun driver service'

def __init__(self, num_hosts, key, nic):
def __init__(self, num_hosts, key, nics):
super(HorovodRunDriverService, self).__init__(num_hosts,
HorovodRunDriverService.NAME,
key, nic)
key, nics)


class HorovodRunDriverClient(driver_service.BasicDriverClient):
@@ -197,7 +197,7 @@ def _driver_fn(all_host_names, local_host_names, settings):
driver.shutdown()


def _get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache):
def get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache):
'''
Find the set of common and routed interfaces on all the hosts.
:param settings: the object that contains the setting for running horovod
@@ -215,39 +215,43 @@ def _get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache
return None

if len(remote_host_names) > 0:
if settings.verbose >= 2:
print('Testing interfaces on all the hosts.')
if settings.nics:
# If args.nics is provided, we will use those interfaces. All the workers
# must have at least one of those interfaces available.
nics = settings.nics
else:
# Find the set of common, routed interfaces on all the hosts (remote
# and local) and specify it in the args to be used by NCCL. It is
# expected that the following function will find at least one interface
# otherwise, it will raise an exception.
if settings.verbose >= 2:
print('Testing interfaces on all the hosts.')

local_host_names = set(all_host_names) - set(remote_host_names)
# Find the set of common, routed interfaces on all the hosts (remote
# and local) and specify it in the args to be used by NCCL. It is
# expected that the following function will find at least one interface
# otherwise, it will raise an exception.
common_intfs = _driver_fn(all_host_names, local_host_names,
settings, fn_cache=fn_cache)
local_host_names = set(all_host_names) - set(remote_host_names)
nics = _driver_fn(all_host_names, local_host_names, settings, fn_cache=fn_cache)

if settings.verbose >= 2:
print('Interfaces on all the hosts were successfully checked.')
print('Common interface found: ' + ' '.join(common_intfs))
if settings.verbose >= 2:
print('Interfaces on all the hosts were successfully checked.')
print('Common interface found: ' + ' '.join(nics))

else:
if settings.verbose >= 2:
print('All hosts are local, finding the interfaces '
'with address 127.0.0.1')
# If all the given hosts are local, find the interfaces with address
# 127.0.0.1
common_intfs = set()
nics = set()
for iface, addrs in net_if_addrs().items():
if settings.nic and iface != settings.nic:
if settings.nics and iface not in settings.nics:
continue
for addr in addrs:
if addr.family == AF_INET and addr.address == '127.0.0.1':
common_intfs.add(iface)
nics.add(iface)
break

if len(common_intfs) == 0:
if len(nics) == 0:
raise ValueError('No interface is found for address 127.0.0.1.')

if settings.verbose >= 2:
print('Local interface found ' + ' '.join(common_intfs))
return common_intfs
print('Local interface found ' + ' '.join(nics))
return nics
@@ -259,7 +259,7 @@ def set_event_on_sigterm(signum, frame):
.format(name=name, code=exit_code))


def gloo_run(settings, remote_host_names, common_intfs, env, server_ip, command):
def gloo_run(settings, remote_host_names, nics, env, server_ip, command):
# allocate processes into slots
host_alloc_plan = _allocate(settings.hosts, settings.num_proc)

@@ -268,20 +268,20 @@ def gloo_run(settings, remote_host_names, common_intfs, env, server_ip, command)
# Start rendezvous server and get port that it is listening
global_rendezv_port = global_rendezv.start_server(host_alloc_plan)

iface = list(common_intfs)[0]
iface = list(nics)[0]

run_command = (
'HOROVOD_GLOO_RENDEZVOUS_ADDR={addr} '
'HOROVOD_GLOO_RENDEZVOUS_PORT={port} '
'HOROVOD_CONTROLLER=gloo '
'HOROVOD_CPU_OPERATIONS=gloo '
'HOROVOD_GLOO_IFACE={iface} '
'NCCL_SOCKET_IFNAME={common_intfs} '
'NCCL_SOCKET_IFNAME={nics} '
'{command}' # expect a lot of environment variables
.format(addr=server_ip,
port=global_rendezv_port,
iface=iface, # TODO: add multiple ifaces in future
common_intfs=','.join(common_intfs),
nics=','.join(nics),
command=' '.join(quote(par) for par in command)))

_launch_jobs(settings, env, host_alloc_plan, remote_host_names, run_command)
@@ -78,14 +78,14 @@ def _get_mpi_implementation_flags(tcp_flag):
return None, None


def mpi_run(settings, common_intfs, env, command, stdout=None, stderr=None, run_func=safe_shell_exec.execute):
def mpi_run(settings, nics, env, command, stdout=None, stderr=None, run_func=safe_shell_exec.execute):
"""
Runs mpi_run.
Args:
settings: Settings for running MPI.
Note: settings.num_proc and settings.hosts must not be None.
common_intfs: Interfaces to include by MPI.
nics: Interfaces to include by MPI.
env: Environment dictionary to use for running MPI.
command: Command and arguments to run as a list of string.
stdout: Stdout of the mpi process.
@@ -108,9 +108,9 @@ def mpi_run(settings, common_intfs, env, command, stdout=None, stderr=None, run_
hosts_arg = '-H {hosts}'.format(hosts=settings.hosts)

tcp_intf_arg = '-mca btl_tcp_if_include {common_intfs}'.format(
common_intfs=','.join(common_intfs)) if common_intfs else ''
common_intfs=','.join(nics)) if nics else ''
nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME={common_intfs}'.format(
common_intfs=','.join(common_intfs)) if common_intfs else ''
common_intfs=','.join(nics)) if nics else ''

# On large cluster runs (e.g. Summit), we need extra settings to work around OpenMPI issues
if settings.num_hosts and settings.num_hosts >= _LARGE_CLUSTER_THRESHOLD:

0 comments on commit a380409

Please sign in to comment.
You can’t perform that action at this time.