Merge branch 'tnt_cli' into 'master'
tnt: CLI: update command line options

See merge request carpenamarie/hpdlf!207
Alexandra Carpen-Amarie committed May 13, 2022
2 parents f77d938 + 16b7d58 commit b69c95d
Showing 1 changed file with 93 additions and 52 deletions: src/runtime/tarantella_cli.py
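A rough usage sketch of the updated options (assumptions: the installed entry point is named `tarantella`, `train.py` and `./my_solver` are placeholder programs, and host-list options are omitted):

  tarantella --n-per-node=2 --log-level=INFO --fusion-threshold=64 \
             --pin-to-socket --pin-memory-to-socket \
             -x DATASET=/scratch/data TF_CPP_MIN_LOG_LEVEL=1 \
             -- train.py

  # Hypothetical: pass an empty interpreter string to launch a non-Python executable
  tarantella --python-interpreter "" -- ./my_solver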
@@ -28,8 +28,9 @@ def create_parser():
help="path to the list of nodes (hostnames) on which to execute the SCRIPT",
default = None)
multinode_group.add_argument("--n-per-node", "--devices-per-node",
help="number of devices (i.e., either GPUs or processes on CPUs) to be " \
"used on each node",
help="""number of devices (i.e., either GPUs or processes on CPUs) to be
used on each node
""",
dest = "npernode",
type = int,
default = None)
@@ -38,55 +39,76 @@ def create_parser():
dest = "use_gpus",
action='store_false',
default = True)
parser.add_argument("--output-on-all-devices",
help="enable output on all devices (e.g., training info)",
dest = "output_all",
action='store_true',
default = False)
log_levels = ('DEBUG', 'INFO', 'WARNING', 'ERROR')
parser.add_argument('--log-level', default='WARNING', choices=log_levels,
help = "logging level for library messages")
parser.add_argument("--log-on-all-devices",
help="enable library logging messages on all devices",
dest = "log_all",
action='store_true',
default = False)
log_levels = ('DEBUG', 'INFO', 'WARNING', 'ERROR')
parser.add_argument('--log-level', default='WARNING', choices=log_levels,
help = "logging level for library messages")
parser.add_argument("--fusion-threshold",
help="tensor fusion threshold [kilobytes]",
dest = "fusion_threshold_kb",
type = int,
default = None)
parser.add_argument("--dry-run",
help="print generated files and execution command",
dest = "dry_run",
parser.add_argument("--output-on-all-devices",
help="enable output on all devices (e.g., training info)",
dest = "output_all",
action='store_true',
default = False)
parser.add_argument("-x",
help = "list of space-separated KEY=VALUE environment variables to be " \
"set on all ranks. " \
"Example: `-x DATASET=/scratch/data TF_CPP_MIN_LOG_LEVEL=1`",
help = """list of space-separated KEY=VALUE environment variables to be
set on all ranks.
Example: `-x DATASET=/scratch/data TF_CPP_MIN_LOG_LEVEL=1`
""",
dest = "setenv",
type = str,
nargs="+",
default=[])
parser.add_argument("--pin-to-socket",
help="pin each rank to a socket based on rank id using `numactl`",
dest = "pin_to_socket",
action='store_true',
default=False)
parser.add_argument("--cleanup",
help="clean up remaining processes after an abnormal termination",
dest = "cleanup",
action='store_true',
default=False)
parser.add_argument("--force",
help="force termination of cleaned up processes",
dest = "force",
perf_group = parser.add_argument_group('Performance tuning')
perf_group.add_argument("--fusion-threshold",
help="tensor fusion threshold [kilobytes]; use 0 to disable tensor fusion",
dest = "fusion_threshold_kb",
type = int,
default = None)
perf_group.add_argument("--pin-to-socket",
help="pin each rank to a socket using `numactl`, based on rank id",
dest = "pin_to_socket",
action='store_true',
default=False)
perf_group.add_argument("--pin-memory-to-socket",
help="""pin memory allocation for each rank to a socket using `numactl`,
based on rank id [default: False - memory will only be
preferentially allocated from the current socket
(`numactl --preferred`)]
""",
dest = "pin_mem_to_socket",
action='store_true',
default=False)
perf_group.add_argument("--python-interpreter",
help="""use a specific Python interpreter instead of the default
`python` found in $PATH.
Pass an empty string to this option to run your SCRIPT without an
interpreter.
""",
dest = "python_interpreter",
type = str,
default = "python")
cleanup_group = parser.add_argument_group('Cleanup')
cleanup_group.add_argument("--cleanup",
help = "clean up remaining processes after an abnormal termination",
dest = "cleanup",
action = 'store_true',
default = False)
cleanup_group.add_argument("--force",
help = "force termination of cleaned up processes",
dest = "force",
action = 'store_true',
default = False)
parser.add_argument("--dry-run",
help = "print generated files and execution command",
dest = "dry_run",
action='store_true',
default=False)
default = False)
parser.add_argument("--version",
action='version',
version=generate_version_message())
action = 'version',
version = generate_version_message())
parser.add_argument('script', nargs='+', metavar='-- SCRIPT')
return parser

@@ -152,7 +174,23 @@ def get_numa_nodes_count():
out = out.split('\n')[0].split()[1]
return int(out) if out.isdigit() else 0

def get_numa_prefix(npernode):
def cpu_pinning_option(pin_to_socket: bool) -> str:
cpu_pinning = "--cpunodebind=$socket" if pin_to_socket else ""
return cpu_pinning

def mem_pinning_option(pin_mem_to_socket: bool) -> str:
mem_pinning = "--membind=$socket" if pin_mem_to_socket else "--preferred=$socket"
return mem_pinning

def validate_command(command: str) -> None:
if command is not None:
path_to_command = shutil.which(command.split(' ')[0])
if path_to_command is None:
raise FileNotFoundError(f"[TNT_CLI] Cannot execute `{command}`; make sure that `{command}` " \
"exists and has been added to the current `PATH`.")


def get_numa_prefix(npernode: int, pin_to_socket: bool, pin_mem_to_socket: bool):
node_count = get_numa_nodes_count()
if node_count == 0 or npernode == 0 or node_count < npernode:
raise ValueError(f"[TNT_CLI] Cannot pin {npernode} ranks to {node_count} NUMA nodes. " \
@@ -162,8 +200,11 @@ def get_numa_prefix(npernode):
if node_count != npernode:
logger.warn(f"Pinning {npernode} ranks to NUMA nodes on each host " \
f"(available NUMA nodes: {node_count}).")
cpu_pinning = cpu_pinning_option(pin_to_socket)
mem_pinning = mem_pinning_option(pin_mem_to_socket)

command = f"socket=$(( $GASPI_RANK % {npernode} ))\n"
command += f"numactl --cpunodebind=$socket --membind=$socket"
command += f"numactl {cpu_pinning} {mem_pinning}"
return command

class TarantellaCLI:
@@ -188,17 +229,13 @@ def __init__(self, hostlist, num_gpus_per_node, num_cpus_per_node, args):
len(hostlist), self.npernode, device_type))

def generate_interpreter(self):
interpreter = "python"

if self.args.pin_to_socket:
path_to_numa = shutil.which("numactl")
if path_to_numa is None:
raise FileNotFoundError("[TNT_CLI] Cannot execute `numactl` as required by " \
"the `--pin-to-socket` flag; make sure that `numactl` " \
"is installed and has been added to the current `PATH`.")
interpreter = f"{get_numa_prefix(self.npernode)} {interpreter}"

return interpreter
interpreter = self.args.python_interpreter
if self.args.pin_to_socket or self.args.pin_mem_to_socket:
numactl_prefix = get_numa_prefix(self.npernode,
pin_to_socket = self.args.pin_to_socket,
pin_mem_to_socket = self.args.pin_mem_to_socket)
interpreter = f"{numactl_prefix} {interpreter}"
return interpreter.strip()

def get_absolute_path(self, module):
return module.__file__
@@ -215,7 +252,7 @@ def generate_executable_script(self):
env_config.gen_exports_from_dict(env_config.get_environment_vars_from_args(self.args))

command = f"{self.generate_interpreter()} {' '.join(self.command_list)}"
return file_man.GPIScriptFile(header, environment, command, dir = os.getcwd())
return file_man.GPIScriptFile(header, environment, command.strip(), dir = os.getcwd())

def generate_cleanup_script(self):
header = "#!/bin/bash\n"
@@ -237,7 +274,7 @@ def run(self):

def normal_run(self):
self.execute_with_gaspi_run(self.nranks, self.hostfile, self.executable_script,
self.args.dry_run)
self.args.dry_run)

def clean_up_run(self):
cleanup_script = self.generate_cleanup_script()
@@ -256,6 +293,10 @@ def execute_with_gaspi_run(self, nranks, hostfile, executable_script, dry_run =
executable_script.filename))
return

if self.args.python_interpreter:
validate_command(self.args.python_interpreter)
if self.args.pin_to_socket or self.args.pin_mem_to_socket:
validate_command('numactl')
try:
subprocess.run(command_list,
check = True,

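For illustration, under the new pinning flags the prefix assembled by `get_numa_prefix` for `--n-per-node=2` with only `--pin-to-socket` set would make each generated launch line look roughly like this (`train.py` is a placeholder script):

  socket=$(( $GASPI_RANK % 2 ))
  numactl --cpunodebind=$socket --preferred=$socket python train.py

With `--pin-memory-to-socket` also given, `--preferred=$socket` becomes `--membind=$socket`; `validate_command` checks up front that `numactl` and the configured interpreter can be found on the current `PATH`.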