-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Runner - support 'pattern' in 'mpi' mode to run tasks in parallel (#430)
* add mpi-parallels mode * update according to comments * fix and update doc * update * merge into 'mpi' mode * udpate according to comments * fix testcases * fix ansible * regard pattern as field * udpate * fix flake8 version * add flake8 range * remove map-by from host config * udpate comments
- Loading branch information
Showing
8 changed files
with
171 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
"""Utilities for traffic pattern config.""" | ||
from superbench.common.utils import logger | ||
|
||
|
||
def gen_all_nodes_config(n): | ||
"""Generate all nodes config. | ||
Args: | ||
n (int): the number of participants. | ||
Returns: | ||
config (list): the generated config list, each item in the list is a str like "0,1,2,3". | ||
""" | ||
config = [] | ||
if n <= 0: | ||
logger.warning('n is not positive') | ||
return config | ||
config = [','.join(map(str, range(n)))] | ||
return config | ||
|
||
|
||
def __convert_config_to_host_group(config, host_list): | ||
"""Convert config format to host node. | ||
Args: | ||
host_list (list): the list of hostnames read from hostfile. | ||
config (list): the traffic pattern config. | ||
Returns: | ||
host_groups (list): the host groups converted from traffic pattern config. | ||
""" | ||
host_groups = [] | ||
for item in config: | ||
groups = item.strip().strip(';').split(';') | ||
host_group = [] | ||
for group in groups: | ||
hosts = [] | ||
for index in group.split(','): | ||
hosts.append(host_list[int(index)]) | ||
host_group.append(hosts) | ||
host_groups.append(host_group) | ||
return host_groups | ||
|
||
|
||
def gen_tarffic_pattern_host_group(host_list, pattern): | ||
"""Generate host group from specified traffic pattern. | ||
Args: | ||
host_list (list): the list of hostnames read from hostfile. | ||
pattern (DictConfig): the mpi pattern dict. | ||
Returns: | ||
host_group (list): the host group generated from traffic pattern. | ||
""" | ||
config = [] | ||
n = len(host_list) | ||
if pattern.name == 'all-nodes': | ||
config = gen_all_nodes_config(n) | ||
else: | ||
logger.error('Unsupported traffic pattern: {}'.format(pattern.name)) | ||
host_group = __convert_config_to_host_group(config, host_list) | ||
return host_group |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
"""Tests for traffic pattern config generation module.""" | ||
import argparse | ||
import unittest | ||
|
||
from superbench.common.utils import gen_tarffic_pattern_host_group | ||
|
||
|
||
class GenConfigTest(unittest.TestCase): | ||
"""Test the utils for generating config.""" | ||
def test_gen_tarffic_pattern_host_group(self): | ||
"""Test the function of generating traffic pattern config from specified mode.""" | ||
# test under 8 nodes | ||
hostx = ['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7'] | ||
parser = argparse.ArgumentParser( | ||
add_help=False, | ||
usage=argparse.SUPPRESS, | ||
allow_abbrev=False, | ||
) | ||
parser.add_argument( | ||
'--name', | ||
type=str, | ||
default='all-nodes', | ||
required=False, | ||
) | ||
pattern, _ = parser.parse_known_args() | ||
expected_host_group = [[['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']]] | ||
self.assertEqual(gen_tarffic_pattern_host_group(hostx, pattern), expected_host_group) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters