-
Notifications
You must be signed in to change notification settings - Fork 17
/
run.py
174 lines (137 loc) · 5.02 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# -*- coding: utf-8 -*-
import re
import os
import pcode.utils.op_files as op_files
import parameters as para
import tmux_cluster.tmux as tx
def read_hostfile(file_path):
    """Parse a hostfile into an ordered ``{host: slots}`` mapping.

    Every meaningful line must have the form ``<host> slots=<n>``; both parts
    are whitespace-stripped and kept as strings (callers convert ``slots``
    with ``int``). Blank lines and lines starting with ``#`` are skipped.

    Args:
        file_path: path of the hostfile, read via ``op_files.read_txt``.

    Returns:
        dict mapping host -> slot count string, in file order.

    Raises:
        ValueError: if a non-blank, non-comment line lacks the ``slots=`` form.
    """
    # Compiled once instead of being re-resolved for every line.
    line_pattern = re.compile(r"^(.*?) slots=(.*?)$", re.DOTALL)

    def _parse(line):
        matched_line = line_pattern.findall(line)
        if not matched_line:
            # Previously an opaque IndexError; name the offending line instead.
            raise ValueError(
                "Malformed hostfile line in {!r}: {!r} (expected '<host> slots=<n>')".format(
                    file_path, line
                )
            )
        return [x.strip() for x in matched_line[0]]

    def _is_entry(line):
        # Skip blank lines and comments so they neither crash the parser nor
        # get registered as hosts.
        stripped = line.strip()
        return bool(stripped) and not stripped.startswith("#")

    lines = op_files.read_txt(file_path)
    ip2slots = dict(_parse(line) for line in lines if _is_entry(line))
    return ip2slots
def map_slot(ip2slots):
    """Expand a ``{host: slots}`` mapping into a flat per-slot host list.

    Each host is repeated ``int(slots)`` times, following the mapping's
    iteration order; callers index the result by rank.
    """
    return [
        host
        for host, n_slots in ip2slots.items()
        for _ in range(int(n_slots))
    ]
def run_cmd(cmd):
    """Echo *cmd* and execute it through the system shell.

    Args:
        cmd: a complete shell command line (may contain ``&&`` / ``&``).

    Returns:
        The status value reported by ``os.system`` (0 on success). Earlier
        versions returned ``None``; callers that ignore the result are unaffected.
    """
    # Flush so the banner is written before the child process inherits stdout;
    # otherwise, with a redirected stdout, it can appear after the child's output.
    print("\nRun the following cmd:\n" + cmd, flush=True)
    return os.system(cmd)
def get_random_port():
    """Return a TCP port number that was unused at the time of the call.

    An IPv4 TCP socket is bound to port 0 on all interfaces so the OS picks a
    free port; the number is read back and the socket is closed on return.
    The port is not reserved afterwards, so another process could take it
    before the caller uses it.
    """
    import socket

    # socket objects are their own context managers and close on exit.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        # NOTE(review): set after bind(), so it has no influence on the bind above.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.getsockname()[1]
def build_nccl_script(conf, replacement=None):
    """Render `` main.py `` followed by one ``--key value`` flag per option.

    Args:
        conf: object whose ``__dict__`` holds the run options, emitted in
            attribute order.
        replacement: optional mapping whose values override ``conf`` for the
            same keys. Keys that are not attributes of ``conf`` are ignored.

    Returns:
        The command fragment. Options whose value is ``None`` are skipped
        unless overridden. Values are inserted unquoted.
    """
    overrides = {} if replacement is None else replacement
    parts = [" main.py "]
    for name, value in conf.__dict__.items():
        if name in overrides:
            value = overrides[name]
        elif value is None:
            continue
        parts.append(" --{} {} ".format(name, value))
    return "".join(parts)
def build_mpi_script(conf, replacement=None):
    """Render the full launch command for the MPI backend.

    When ``conf.n_mpi_process > 1`` the command is prefixed with an ``mpirun``
    invocation built from ``conf.hostfile``, ``conf.use_ipc``,
    ``conf.mpi_path`` and (when non-empty) ``conf.mpi_env``, which is passed
    through a single ``-x`` flag. The program part is
    ``<python_path> main.py`` plus one ``--key value`` flag per attribute of
    ``conf``.

    Args:
        conf: run configuration object (its ``__dict__`` is forwarded as flags).
        replacement: optional mapping overriding ``conf`` values for existing keys.

    Returns:
        The complete command string.
    """
    launcher = ""
    if conf.n_mpi_process > 1:
        launcher = (
            f"mpirun -n {conf.n_mpi_process} --hostfile {conf.hostfile} "
            "--mca orte_base_help_aggregate 0 "
            "--mca btl_tcp_if_exclude docker0,lo "
            f"--mca btl_smcuda_use_cuda_ipc {1 if conf.use_ipc else 0} "
            f"--prefix {conf.mpi_path} "
        )
        env = conf.mpi_env
        if env is not None and len(env) > 0:
            launcher += f" -x {env}"

    overrides = {} if replacement is None else replacement
    parts = [" {} main.py ".format(conf.python_path)]
    for name, value in conf.__dict__.items():
        if name in overrides:
            value = overrides[name]
        elif value is None:
            continue
        parts.append(" --{} {} ".format(name, value))
    return launcher + "".join(parts)
def create_job_on_nodes(conf, tasks):
    """Start the per-host command lists in *tasks*.

    Each host's commands are joined with `` & `` into one shell line.

    If ``conf.remote_exec`` is false, or ``"localhost"`` is one of the hosts,
    a single line is run on this machine via ``run_cmd``. That line is the
    ``"localhost"`` entry when present, otherwise the first host's entry.
    Otherwise, one ``tx.Run(...).make_job`` job is created per host.

    Args:
        conf: run configuration; reads ``remote_exec`` and ``experiment``.
        tasks: mapping host -> list of shell commands for that host.

    Raises:
        ValueError: if local execution is selected but ``tasks`` is empty.
    """
    # " & " backgrounds every command but the last, so the ranks on one
    # host are started concurrently by the shell.
    node_tasks = {ip: " & ".join(cmds) for ip, cmds in tasks.items()}

    if (not conf.remote_exec) or "localhost" in node_tasks:
        if not node_tasks:
            raise ValueError("create_job_on_nodes: no tasks to run locally.")
        # Previously the first host's commands ran even when "localhost" was
        # listed later in the hostfile; run the localhost commands in that case.
        local_ip = "localhost" if "localhost" in node_tasks else next(iter(node_tasks))
        run_cmd(node_tasks[local_ip])
    else:
        print("\nrun the job on the remote host.\n")
        for ip, script in node_tasks.items():
            tx.Run(name=f"{conf.experiment}", job_node=ip).make_job(
                job_name="job", task_scripts=[script]
            )
def main_nccl_or_gloo(conf, ip2slot):
    """Build one ``main.py`` command per rank for the NCCL/Gloo backends and launch them.

    Args:
        conf: run configuration. Reads ``work_dir``, ``n_mpi_process``,
            ``clean_python`` and ``python_path``; all attributes are also
            forwarded to ``main.py`` through ``build_nccl_script``.
        ip2slot: per-rank host list (as produced by ``map_slot``); entry ``r``
            names the host that runs rank ``r``.

    Raises:
        ValueError: if ``conf.work_dir`` is unset, or ``ip2slot`` has fewer
            entries than ``conf.n_mpi_process``.
    """
    # Explicit checks rather than ``assert`` so they still run under ``python -O``.
    if conf.work_dir is None:
        raise ValueError("conf.work_dir must be set for the nccl/gloo backends.")
    if conf.n_mpi_process > len(ip2slot):
        raise ValueError(
            "Requested {} processes but the hostfile only provides {} slots.".format(
                conf.n_mpi_process, len(ip2slot)
            )
        )

    # Group the per-rank commands by the host they should run on.
    tasks = dict()
    for rank in range(conf.n_mpi_process):
        if conf.clean_python:
            cmd = "pkill -9 python"
        else:
            # Same options for every rank except its own ``local_rank``.
            # NOTE(review): build_nccl_script only emits keys already on conf,
            # so this relies on conf defining ``local_rank``.
            script = build_nccl_script(conf, replacement={"local_rank": rank})
            cmd = "cd {work_dir} && {env} {python_path} {script}".format(
                work_dir=conf.work_dir,
                env="",
                python_path=conf.python_path,
                script=script,
            )
        tasks.setdefault(ip2slot[rank], []).append(cmd)

        print(
            "build cmd ({rank}/{world_size}): \n{cmd}\n\n".format(
                rank=rank + 1, world_size=conf.n_mpi_process, cmd=cmd
            )
        )

    # Run the grouped commands on their nodes.
    create_job_on_nodes(conf, tasks)
def main_mpi(conf, ip2slot):
    """Schedule a single launch command for the MPI backend on the first host.

    The command is ``pkill -9 python`` when ``conf.clean_python`` is set,
    otherwise the line produced by ``build_mpi_script``. It is prefixed with
    ``cd <work_dir> &&`` when ``conf.work_dir`` is given. Only ``ip2slot[0]``
    receives a task; the multi-process launcher line carries the hostfile and
    process count itself.
    """
    cmd = "pkill -9 python" if conf.clean_python else build_mpi_script(conf)

    if conf.work_dir is None:
        cd_prefix = ""
    else:
        cd_prefix = "cd {work_dir} && ".format(work_dir=conf.work_dir)

    tasks = {ip2slot[0]: [cd_prefix + cmd]}

    create_job_on_nodes(conf, tasks)
if __name__ == "__main__":
    # Parse the command-line arguments.
    conf = para.get_args()

    # Expand the hostfile into a per-rank host list.
    ip2slots = read_hostfile(conf.hostfile)
    ip2slot = map_slot(ip2slots)

    # Dispatch on the communication backend. An unknown backend used to fall
    # through silently and launch nothing.
    if conf.backend in ("nccl", "gloo"):
        main_nccl_or_gloo(conf, ip2slot)
    elif conf.backend == "mpi":
        main_mpi(conf, ip2slot)
    else:
        raise ValueError(
            "Unsupported backend {!r}; expected one of 'nccl', 'gloo', 'mpi'.".format(
                conf.backend
            )
        )