Skip to content

Commit

Permalink
Fix distribution if workers error out
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Jan 26, 2022
1 parent 3fceea7 commit dc05fbf
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 24 deletions.
39 changes: 24 additions & 15 deletions src/maggma/cli/distributed.py
Expand Up @@ -3,6 +3,7 @@

import json
from logging import getLogger
from multiprocessing.sharedctypes import Value
import socket as pysocket
from typing import List

Expand Down Expand Up @@ -47,25 +48,34 @@ async def manager(
try:

builder.connect()
chunks_dicts = list(builder.prechunk(num_chunks))
chunks_tuples = [(d, False) for d in builder.prechunk(num_chunks)]

logger.info(f"Distributing {len(chunks_dicts)} chunks to workers")
for chunk_dict in tqdm(chunks_dicts, desc="Chunks"):
temp_builder_dict = dict(**builder_dict)
temp_builder_dict.update(chunk_dict)
temp_builder_dict = jsanitize(temp_builder_dict)
logger.info(f"Distributing {len(chunks_tuples)} chunks to workers")

# Wait for client connection that announces client and says it is ready to do work
logger.debug("Waiting for a worker")
for chunk_dict, distributed in tqdm(chunks_tuples, desc="Chunks"):
while not distributed:
if num_workers <= 0:
socket.close()
raise RuntimeError("No workers left to distribute chunks to")

worker = await socket.recv()
temp_builder_dict = dict(**builder_dict)
temp_builder_dict.update(chunk_dict)
temp_builder_dict = jsanitize(temp_builder_dict)

if worker.decode("utf-8") == "ERROR":
num_workers -= 1
# Wait for client connection that announces client and says it is ready to do work
logger.debug("Waiting for a worker")

logger.debug(f"Got connection from worker: {worker.decode('utf-8')}")
# Send out the next chunk
await socket.send(json.dumps(temp_builder_dict).encode("utf-8"))
worker = await socket.recv()

if worker.decode("utf-8") == "ERROR":
num_workers -= 1
else:
logger.debug(
f"Got connection from worker: {worker.decode('utf-8')}"
)
# Send out the next chunk
await socket.send(json.dumps(temp_builder_dict).encode("utf-8"))
distributed = True

logger.info("Sending exit messages to workers")
for _ in range(num_workers):
Expand Down Expand Up @@ -113,7 +123,6 @@ async def worker(url: str, port: int, num_processes: int):
except Exception as e:
logger.error(f"A worker failed with error: {e}")
await socket.send("ERROR".encode("utf-8"))
await socket.recv()

socket.close()

Expand Down
55 changes: 46 additions & 9 deletions tests/cli/test_distributed.py
Expand Up @@ -56,10 +56,10 @@ def process_items(self, items):
SERVER_PORT = 8234


@pytest.fixture(scope="function")
async def manager_server(event_loop, log_to_stdout):
@pytest.mark.asyncio
async def test_manager_give_out_chunks(log_to_stdout):

task = asyncio.create_task(
manager_server = asyncio.create_task(
manager(
SERVER_URL,
SERVER_PORT,
Expand All @@ -68,12 +68,6 @@ async def manager_server(event_loop, log_to_stdout):
num_workers=10,
)
)
yield task
task.cancel()


@pytest.mark.asyncio
async def test_manager_give_out_chunks(manager_server, log_to_stdout):

context = zmq.Context()
socket = context.socket(REQ)
Expand All @@ -95,6 +89,31 @@ async def test_manager_give_out_chunks(manager_server, log_to_stdout):
message = await socket.recv()
assert message == b'"EXIT"'

manager_server.cancel()


@pytest.mark.asyncio
async def test_manager_worker_error(log_to_stdout):

manager_server = asyncio.create_task(
manager(
SERVER_URL,
SERVER_PORT,
[DummyBuilder(dummy_prechunk=False)],
num_chunks=10,
num_workers=1,
)
)

context = zmq.Context()
socket = context.socket(REQ)
socket.connect(f"{SERVER_URL}:{SERVER_PORT}")

await socket.send("ERROR".encode("utf-8"))
await asyncio.sleep(1)
assert manager_server.done()
manager_server.cancel()


@pytest.mark.asyncio
async def test_worker():
Expand Down Expand Up @@ -149,6 +168,24 @@ async def test_worker_error():
worker_task.cancel()


@pytest.mark.asyncio
async def test_worker_exit():
context = zmq.Context()
socket = context.socket(REP)
socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1))

message = await socket.recv()
assert message == HOSTNAME.encode("utf-8")

await socket.send_json("EXIT")
await asyncio.sleep(1)
assert worker_task.done()

worker_task.cancel()


@pytest.mark.asyncio
async def test_no_prechunk(caplog):

Expand Down

0 comments on commit dc05fbf

Please sign in to comment.