Skip to content

Commit

Permalink
Merge pull request #715 from materialsproject/bugfix/distributed_stalling
Browse files Browse the repository at this point in the history

Fix stalling in distributed code
  • Loading branch information
munrojm committed Sep 27, 2022
2 parents 90e3614 + 9ca8615 commit bc02193
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 27 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ aioitertools==0.10.0
pydantic==1.9.1
fastapi==0.79.0
numpy==1.21.0;python_version>"3.6"
pyzmq==22.3.0
pyzmq==24.0.1
dnspython==2.2.1
uvicorn==0.18.2
sshtunnel==0.4.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"numpy>=1.17.3",
"pydantic>=0.32.2",
"fastapi>=0.42.0",
"pyzmq==22.3.0",
"pyzmq==24.0.1",
"dnspython>=1.16.0",
"sshtunnel>=0.1.5",
"msgpack>=0.5.6",
Expand Down
2 changes: 1 addition & 1 deletion src/maggma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def run(
# worker
loop = asyncio.get_event_loop()
loop.run_until_complete(
worker(url=url, port=port, num_processes=num_processes)
worker(url=url, port=port, num_processes=num_processes, no_bars=no_bars)
)
else:
if num_processes == 1:
Expand Down
39 changes: 24 additions & 15 deletions src/maggma/cli/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from time import perf_counter
import asyncio
from random import randint

from monty.json import jsanitize
from monty.serialization import MontyDecoder
Expand Down Expand Up @@ -45,6 +46,7 @@ def manager(

logger.info(f"Binding to Manager URL {url}:{port}")
context = zmq.Context()
context.setsockopt(opt=zmq.SocketOption.ROUTER_MANDATORY, value=1)
socket = context.socket(zmq.ROUTER)
socket.bind(f"{url}:{port}")

Expand All @@ -66,12 +68,12 @@ def manager(
for d in builder.prechunk(num_chunks)
]
pbar_distributed = tqdm(
total=num_chunks,
total=len(chunk_dicts),
desc="Distributed chunks for {}".format(builder.__class__.__name__),
)

pbar_completed = tqdm(
total=num_chunks,
total=len(chunk_dicts),
desc="Completed chunks for {}".format(builder.__class__.__name__),
)

Expand All @@ -94,13 +96,13 @@ def manager(
raise RuntimeError("No workers to distribute chunks to")

# Poll and look for messages from workers
connections = dict(poll.poll(500))
connections = dict(poll.poll(100))

# If workers send messages, decode them and figure out what to do
if connections:
identity, _, msg = socket.recv_multipart()
identity, _, bmsg = socket.recv_multipart()

msg = msg.decode("utf-8")
msg = bmsg.decode("utf-8")

if "READY" in msg:
if identity not in workers:
Expand All @@ -121,7 +123,11 @@ def manager(

# If everything is distributed, send EXIT to the worker
if all(chunk["distributed"] for chunk in chunk_dicts):
logger.debug(
f"Sending exit signal to worker: {msg.split('_')[1]}"
)
socket.send_multipart([identity, b"", b"EXIT"])
workers.pop(identity)

elif "ERROR" in msg:
# Remove worker and requeue work sent to it
Expand All @@ -142,12 +148,12 @@ def manager(
handle_dead_workers(workers, socket)

for work_index, chunk_dict in enumerate(chunk_dicts):
temp_builder_dict = dict(**builder_dict)
temp_builder_dict.update(chunk_dict["chunk"]) # type: ignore
temp_builder_dict = jsanitize(temp_builder_dict)

if not chunk_dict["distributed"]:

temp_builder_dict = dict(**builder_dict)
temp_builder_dict.update(chunk_dict["chunk"]) # type: ignore
temp_builder_dict = jsanitize(temp_builder_dict)

# Send work for available workers
for identity in workers:
if not workers[identity]["working"]:
Expand Down Expand Up @@ -210,17 +216,19 @@ def handle_dead_workers(workers, socket):
)


async def worker(url: str, port: int, num_processes: int):
async def worker(url: str, port: int, num_processes: int, no_bars: bool):
"""
Simple distributed worker that connects to a manager asks for work and deploys
using multiprocessing
"""
# Should this have some sort of unique ID?
logger = getLogger("Worker")
identity = "%04X-%04X" % (randint(0, 0x10000), randint(0, 0x10000))
logger = getLogger(f"Worker {identity}")

logger.info(f"Connnecting to Manager at {url}:{port}")
context = azmq.Context()
socket = context.socket(zmq.REQ)

socket.setsockopt_string(zmq.IDENTITY, identity)
socket.connect(f"{url}:{port}")

# Initial message package
Expand All @@ -231,16 +239,17 @@ async def worker(url: str, port: int, num_processes: int):
while running:
await socket.send("READY_{}".format(hostname).encode("utf-8"))
try:
message = await asyncio.wait_for(socket.recv(), timeout=MANAGER_TIMEOUT)
bmessage: bytes = await asyncio.wait_for(socket.recv(), timeout=MANAGER_TIMEOUT) # type: ignore
except asyncio.TimeoutError:
socket.close()
raise RuntimeError("Stopping work as manager timed out.")
message = message.decode("utf-8")

message = bmessage.decode("utf-8")
if "@class" in message and "@module" in message:
# We have a valid builder
work = json.loads(message)
builder = MontyDecoder().process_decoded(work)
await multi(builder, num_processes, socket=socket)
await multi(builder, num_processes, socket=socket, no_bars=no_bars)
elif message == "EXIT":
# End the worker
running = False
Expand Down
2 changes: 1 addition & 1 deletion src/maggma/cli/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from maggma.utils import primed

MANAGER_TIMEOUT = 120 # max timeout in seconds for manager
MANAGER_TIMEOUT = 600 # max timeout in seconds for manager

logger = getLogger("MultiProcessor")

Expand Down
12 changes: 4 additions & 8 deletions tests/cli/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ async def test_manager_and_worker(log_to_stdout):
)
manager_thread.start()

context = zmq.Context()
socket = context.socket(REQ)
socket.connect(f"{SERVER_URL}:{SERVER_PORT}")

tasks = [worker(SERVER_URL, SERVER_PORT, num_processes=1) for _ in range(5)]
tasks = [worker(SERVER_URL, SERVER_PORT, num_processes=1, no_bars=True) for _ in range(5)]
await asyncio.gather(*tasks)

manager_thread.join()
Expand Down Expand Up @@ -115,7 +111,7 @@ async def test_worker_error():
socket = context.socket(REP)
socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1))
worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1, no_bars=True))

message = await socket.recv()
assert message == "READY_{}".format(HOSTNAME).encode("utf-8")
Expand All @@ -142,11 +138,11 @@ async def test_worker_exit():
socket = context.socket(REP)
socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1))
worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1, no_bars=True))

message = await socket.recv()
assert message == "READY_{}".format(HOSTNAME).encode("utf-8")

await asyncio.sleep(1)
await socket.send(b"EXIT")
await asyncio.sleep(1)
assert worker_task.done()
Expand Down

0 comments on commit bc02193

Please sign in to comment.