Skip to content

Commit

Permalink
Merge pull request #25 from daiyl0320/master
Browse files Browse the repository at this point in the history
Add a retry mechanism to ConnectTracker and change the listen backlog to 128 in rabit_tracker.py
  • Loading branch information
tqchen committed Oct 21, 2015
2 parents c71ed6f + 35c3b37 commit e81a11d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
25 changes: 22 additions & 3 deletions src/allreduce_base.cc
Expand Up @@ -24,6 +24,7 @@ AllreduceBase::AllreduceBase(void) {
nport_trial = 1000;
rank = 0;
world_size = -1;
connect_retry = 5;
hadoop_mode = 0;
version_number = 0;
// 32 K items
Expand All @@ -46,6 +47,7 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
}

// initialization function
Expand Down Expand Up @@ -175,6 +177,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_reduce_buffer")) {
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
}
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
}
/*!
* \brief initialize connection to the tracker
Expand All @@ -185,9 +190,23 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
// get information from tracker
utils::TCPSocket tracker;
tracker.Create();
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect");
}

int retry = 0;
do {
fprintf(stderr, "connect to ip: [%s]\n", tracker_uri.c_str());
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
if (++retry >= connect_retry) {
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
utils::Socket::Error("Connect");
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
sleep(1);
continue;
}
}
break;
} while (1);

using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Expand Down
2 changes: 2 additions & 0 deletions src/allreduce_base.h
Expand Up @@ -519,6 +519,8 @@ class AllreduceBase : public IEngine {
int rank;
// world size
int world_size;
// maximum number of attempts to connect to the tracker before reporting an error
int connect_retry;
};
} // namespace engine
} // namespace rabit
Expand Down
32 changes: 16 additions & 16 deletions tracker/rabit_tracker.py
@@ -1,6 +1,6 @@
"""
Tracker script for rabit
Implements the tracker control protocol
Implements the tracker control protocol
- start rabit jobs
- help nodes to establish links with each other
Expand All @@ -19,13 +19,13 @@
"""
Extension of socket to handle recv and send of special data
"""
class ExSocket:
class ExSocket:
def __init__(self, sock):
self.sock = sock
def recvall(self, nbytes):
res = []
sock = self.sock
nread = 0
nread = 0
while nread < nbytes:
chunk = self.sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
Expand Down Expand Up @@ -106,7 +106,7 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
for r in conset:
self.sock.sendstr(wait_conn[r].host)
self.sock.sendint(wait_conn[r].port)
self.sock.sendint(r)
self.sock.sendint(r)
nerr = self.sock.recvint()
if nerr != 0:
continue
Expand All @@ -121,7 +121,7 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
wait_conn.pop(r, None)
self.wait_accept = len(badset) - len(conset)
return rmset

class Tracker:
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -132,7 +132,7 @@ def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'
break
except socket.error:
continue
sock.listen(16)
sock.listen(128)
self.sock = sock
self.verbose = verbose
if hostIP == 'auto':
Expand All @@ -145,22 +145,22 @@ def slave_envs(self):
"""
get environment variables for slaves
can be passed in as args or envs
"""
"""
if self.hostIP == 'dns':
host = socket.gethostname()
elif self.hostIP == 'ip':
host = socket.gethostbyname(socket.getfqdn())
else:
host = self.hostIP
return {'rabit_tracker_uri': host,
'rabit_tracker_port': self.port}
'rabit_tracker_port': self.port}
def get_neighbor(self, rank, nslave):
rank = rank + 1
ret = []
if rank > 1:
ret.append(rank / 2 - 1)
if rank * 2 - 1 < nslave:
ret.append(rank * 2 - 1)
ret.append(rank * 2 - 1)
if rank * 2 < nslave:
ret.append(rank * 2)
return ret
Expand Down Expand Up @@ -198,10 +198,10 @@ def get_ring(self, tree_map, parent_map):
rlst = self.find_share_ring(tree_map, parent_map, 0)
assert len(rlst) == len(tree_map)
ring_map = {}
nslave = len(tree_map)
nslave = len(tree_map)
for r in range(nslave):
rprev = (r + nslave - 1) % nslave
rnext = (r + 1) % nslave
rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map

Expand Down Expand Up @@ -231,7 +231,7 @@ def get_link_map(self, nslave):
else:
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_

def handle_print(self,slave, msg):
sys.stdout.write(msg)

Expand All @@ -253,14 +253,14 @@ def accept_slaves(self, nslave):
pending = []
# lazy initialize tree_map
tree_map = None

while len(shutdown) != nslave:
fd, s_addr = self.sock.accept()
s = SlaveEntry(fd, s_addr)
if s.cmd == 'print':
msg = s.sock.recvstr()
self.handle_print(s, msg)
continue
continue
if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown
assert s.rank not in wait_conn
Expand All @@ -280,12 +280,12 @@ def accept_slaves(self, nslave):
assert s.world_size == -1 or s.world_size == nslave
if s.cmd == 'recover':
assert s.rank >= 0

rank = s.decide_rank(job_map)
# batch assignment of ranks
if rank == -1:
assert len(todo_nodes) != 0
pending.append(s)
pending.append(s)
if len(pending) == len(todo_nodes):
pending.sort(key = lambda x : x.host)
for s in pending:
Expand Down

0 comments on commit e81a11d

Please sign in to comment.