diff --git a/asyncio_dgram/aio.py b/asyncio_dgram/aio.py index 8117282..7487a98 100644 --- a/asyncio_dgram/aio.py +++ b/asyncio_dgram/aio.py @@ -22,17 +22,20 @@ class DatagramStream: raised. """ - def __init__(self, transport, recvq, excq): + def __init__(self, transport, recvq, excq, drained): """ @param transport - asyncio transport @param recvq - asyncio queue that gets populated by the DatagramProtocol with received datagrams. @param excq - asyncio queue that gets populated with any errors detected by the DatagramProtocol. + @param drained - asyncio event that is unset when writing is + paused and set otherwise. """ self._transport = transport self._recvq = recvq self._excq = excq + self._drained = drained def __del__(self): self._transport.close() @@ -87,6 +90,7 @@ async def send(self, data, addr=None): """ _ = self.exception self._transport.sendto(data, addr) + await self._drained.wait() async def recv(self): """ @@ -132,13 +136,18 @@ class Protocol(asyncio.DatagramProtocol): based asyncio into higher level coroutines. """ - def __init__(self, recvq, excq): + def __init__(self, recvq, excq, drained): """ @param recvq - asyncio.Queue for new datagrams @param excq - asyncio.Queue for exceptions + @param drained - asyncio.Event set when the write buffer is below the + high watermark. """ self._recvq = recvq self._excq = excq + self._drained = drained + + self._drained.set() # Transports are connected at the time a connection is made. self._transport = None @@ -167,12 +176,19 @@ def datagram_received(self, data, addr): def error_received(self, exc): self._excq.put_nowait(exc) + def pause_writing(self): + self._drained.clear() + super().pause_writing() + + def resume_writing(self): + self._drained.set() + super().resume_writing() + async def bind(addr): """ Bind a socket to a local address for datagrams. The socket will be either - AF_INET or AF_INET6 depending upon the type of address specified. The - socket will be reusable (SO_REUSEADDR) once it enters TIME_WAIT. + AF_INET or AF_INET6 depending upon the type of address specified. @param addr - For AF_INET or AF_INET6, a tuple with the the host and port to to bind; port may be set to 0 to get any free port. @@ -181,12 +197,13 @@ async def bind(addr): loop = asyncio.get_event_loop() recvq = asyncio.Queue() excq = asyncio.Queue() + drained = asyncio.Event() transport, protocol = await loop.create_datagram_endpoint( - lambda: Protocol(recvq, excq), local_addr=addr, reuse_address=True + lambda: Protocol(recvq, excq, drained), local_addr=addr, reuse_address=False ) - return DatagramServer(transport, recvq, excq) + return DatagramServer(transport, recvq, excq, drained) async def connect(addr): @@ -201,12 +218,13 @@ async def connect(addr): loop = asyncio.get_event_loop() recvq = asyncio.Queue() excq = asyncio.Queue() + drained = asyncio.Event() transport, protocol = await loop.create_datagram_endpoint( - lambda: Protocol(recvq, excq), remote_addr=addr + lambda: Protocol(recvq, excq, drained), remote_addr=addr ) - return DatagramClient(transport, recvq, excq) + return DatagramClient(transport, recvq, excq, drained) async def from_socket(sock): @@ -224,6 +242,7 @@ async def from_socket(sock): loop = asyncio.get_event_loop() recvq = asyncio.Queue() excq = asyncio.Queue() + drained = asyncio.Event() if sock.family not in (socket.AF_INET, socket.AF_INET6): raise TypeError( @@ -234,12 +253,12 @@ async def from_socket(sock): raise TypeError("socket must be %s" % (socket.SOCK_DGRAM,)) transport, protocol = await loop.create_datagram_endpoint( - lambda: Protocol(recvq, excq), sock=sock + lambda: Protocol(recvq, excq, drained), sock=sock ) if transport.get_extra_info("peername") is not None: # Workaround transport ignoring the peer address of the socket. transport._address = transport.get_extra_info("peername") - return DatagramClient(transport, recvq, excq) + return DatagramClient(transport, recvq, excq, drained) else: - return DatagramServer(transport, recvq, excq) + return DatagramServer(transport, recvq, excq, drained) diff --git a/test/test_aio.py b/test/test_aio.py index 7d52a99..7b71ffc 100644 --- a/test/test_aio.py +++ b/test/test_aio.py @@ -1,12 +1,23 @@ import asyncio import contextlib +import os import socket +import unittest.mock import pytest import asyncio_dgram +@pytest.fixture +def mock_socket(): + s = unittest.mock.create_autospec(socket.socket) + s.family = socket.AF_INET + s.type = socket.SOCK_DGRAM + + return s + + @contextlib.contextmanager def loop_exception_handler(): """ @@ -260,3 +271,81 @@ async def test_unconnected_sender(addr): with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(connected.recv(), 0.05) + + +@pytest.mark.asyncio +async def test_protocol_pause_resume(monkeypatch, mock_socket, tmp_path): + # This is a little involved, but necessary to make sure that the Protocol + # is correctly noticing when writing as been paused and resumed. In + # summary: + # + # - Mock the Protocol with one that sets the write buffer limits to 0 and + # records when pause and recume writing are called. + # + # - Use a mock socket so that we can inject a BlockingIOError on send. + # Ideally we'd mock method itself, but it's read-only the entire object + # needs to be mocked. Due to this, we need to use a temporary file that we + # can write to in order to kick the event loop to consider it ready for + # writing. + + class TestableProtocol(asyncio_dgram.aio.Protocol): + pause_writing_called = 0 + resume_writing_called = 0 + instance = None + + def __init__(self, *args, **kwds): + TestableProtocol.instance = self + super().__init__(*args, **kwds) + + def connection_made(self, transport): + transport.set_write_buffer_limits(low=0, high=0) + super().connection_made(transport) + + def pause_writing(self): + self.pause_writing_called += 1 + super().pause_writing() + + def resume_writing(self): + self.resume_writing_called += 1 + super().resume_writing() + + async def passthrough(): + """ + Used to mock the wait method on the asyncio.Event tracking if the write + buffer is past the high water mark or not. Given we're testing how + that case is handled, we know it's safe locally to mock it. + """ + pass + + with monkeypatch.context() as ctx: + ctx.setattr(asyncio_dgram.aio, "Protocol", TestableProtocol) + + client = await asyncio_dgram.from_socket(mock_socket) + mock_socket.send.side_effect = BlockingIOError + mock_socket.fileno.return_value = os.open( + tmp_path / "socket", os.O_RDONLY | os.O_CREAT + ) + + with monkeypatch.context() as ctx2: + ctx2.setattr(client._drained, "wait", passthrough) + await client.send(b"foo") + + assert TestableProtocol.instance.pause_writing_called == 1 + assert TestableProtocol.instance.resume_writing_called == 0 + assert not TestableProtocol.instance._drained.is_set() + + mock_socket.send.side_effect = None + fd = os.open(tmp_path / "socket", os.O_WRONLY) + os.write(fd, b"\n") + os.close(fd) + + with monkeypatch.context() as ctx2: + ctx2.setattr(client._drained, "wait", passthrough) + await client.send(b"foo") + await asyncio.sleep(0.1) + + assert TestableProtocol.instance.pause_writing_called == 1 + assert TestableProtocol.instance.resume_writing_called == 1 + assert TestableProtocol.instance._drained.is_set() + + os.close(mock_socket.fileno.return_value)