Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Jan 25, 2022
1 parent be77103 commit 39a0b5f
Showing 1 changed file with 45 additions and 56 deletions.
101 changes: 45 additions & 56 deletions tests/cli/test_distributed.py
Expand Up @@ -2,12 +2,18 @@
import json

import pytest
from pynng import Pair1
from pynng.exceptions import Timeout

from maggma.cli.distributed import find_port, manager, worker
from maggma.core import Builder

from zmq import REP, REQ
import zmq.asyncio as zmq
import socket as pysocket

# TODO: Timeout errors?

HOSTNAME = pysocket.gethostname()


class DummyBuilderWithNoPrechunk(Builder):
def __init__(self, dummy_prechunk: bool, val: int = -1, **kwargs):
Expand Down Expand Up @@ -44,78 +50,60 @@ async def manager_server(event_loop, log_to_stdout):

task = asyncio.create_task(
manager(
SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], num_chunks=10
SERVER_URL,
SERVER_PORT,
[DummyBuilder(dummy_prechunk=False)],
num_chunks=10,
num_workers=10,
)
)
yield task
task.cancel()


@pytest.mark.asyncio
async def test_manager_wait_for_ready(manager_server):
with Pair1(
dial=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=100
) as manager:
with pytest.raises(Timeout):
manager.recv()
async def test_manager_give_out_chunks(manager_server, log_to_stdout):

context = zmq.Context()
socket = context.socket(REQ)
socket.connect(f"{SERVER_URL}:{SERVER_PORT}")

for i in range(0, 10):
log_to_stdout.debug(f"Going to ask Manager for work: {i}")
await socket.send(b"Ready")
message = await socket.recv()

@pytest.mark.asyncio
async def test_manager_give_out_chunks(manager_server, log_to_stdout):
with Pair1(
dial=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=500
) as manager_socket:

for i in range(0, 10):
log_to_stdout.debug(f"Going to ask Manager for work: {i}")
await manager_socket.asend(b"Ready")
message = await manager_socket.arecv()
print(message)
work = json.loads(message.decode("utf-8"))

assert work["@class"] == "DummyBuilder"
assert work["@module"] == "tests.cli.test_distributed"
assert work["val"] == i

await manager_socket.asend(b"Ready")
message = await manager_socket.arecv()
work = json.loads(message.decode("utf-8"))
assert work == {}

assert work["@class"] == "DummyBuilder"
assert work["@module"] == "tests.cli.test_distributed"
assert work["val"] == i


@pytest.mark.asyncio
async def test_worker():
with Pair1(
listen=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=500
) as worker_socket:

worker_task = asyncio.create_task(
worker(SERVER_URL, SERVER_PORT, num_workers=1)
)

message = await worker_socket.arecv()
assert message == b"Ready"
context = zmq.Context()
socket = context.socket(REP)
socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

dummy_work = {
"@module": "tests.cli.test_distributed",
"@class": "DummyBuilder",
"@version": None,
"dummy_prechunk": False,
"val": 0,
}
for i in range(2):
await worker_socket.asend(json.dumps(dummy_work).encode("utf-8"))
await asyncio.sleep(1)
message = await worker_socket.arecv()
assert message == b"Ready"
worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_workers=1))

await worker_socket.asend(json.dumps({}).encode("utf-8"))
with pytest.raises(Timeout):
await worker_socket.arecv()
message = await socket.recv()

assert len(worker_socket.pipes) == 0
dummy_work = {
"@module": "tests.cli.test_distributed",
"@class": "DummyBuilder",
"@version": None,
"dummy_prechunk": False,
"val": 0,
}
for i in range(2):
await socket.send(json.dumps(dummy_work).encode("utf-8"))
await asyncio.sleep(1)
message = await socket.recv()
assert message == HOSTNAME.encode("utf-8")

worker_task.cancel()
worker_task.cancel()


@pytest.mark.asyncio
Expand All @@ -127,6 +115,7 @@ async def test_no_prechunk(caplog):
SERVER_PORT,
[DummyBuilderWithNoPrechunk(dummy_prechunk=False)],
num_chunks=10,
num_workers=10,
)
)
await asyncio.sleep(1)
Expand Down

0 comments on commit 39a0b5f

Please sign in to comment.