Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aio: drain write buffer during send #2

Merged
merged 2 commits into from Dec 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 30 additions & 11 deletions asyncio_dgram/aio.py
Expand Up @@ -22,17 +22,20 @@ class DatagramStream:
raised.
"""

def __init__(self, transport, recvq, excq):
def __init__(self, transport, recvq, excq, drained):
"""
@param transport - asyncio transport
@param recvq - asyncio queue that gets populated by the
DatagramProtocol with received datagrams.
@param excq - asyncio queue that gets populated with any errors
detected by the DatagramProtocol.
@param drained - asyncio event that is unset when writing is
paused and set otherwise.
"""
self._transport = transport
self._recvq = recvq
self._excq = excq
self._drained = drained

def __del__(self):
self._transport.close()
Expand Down Expand Up @@ -87,6 +90,7 @@ async def send(self, data, addr=None):
"""
_ = self.exception
self._transport.sendto(data, addr)
await self._drained.wait()

async def recv(self):
"""
Expand Down Expand Up @@ -132,13 +136,18 @@ class Protocol(asyncio.DatagramProtocol):
based asyncio into higher level coroutines.
"""

def __init__(self, recvq, excq):
def __init__(self, recvq, excq, drained):
"""
@param recvq - asyncio.Queue for new datagrams
@param excq - asyncio.Queue for exceptions
@param drained - asyncio.Event set when the write buffer is below the
high watermark.
"""
self._recvq = recvq
self._excq = excq
self._drained = drained

self._drained.set()

# Transports are connected at the time a connection is made.
self._transport = None
Expand Down Expand Up @@ -167,12 +176,19 @@ def datagram_received(self, data, addr):
def error_received(self, exc):
self._excq.put_nowait(exc)

def pause_writing(self):
self._drained.clear()
super().pause_writing()

def resume_writing(self):
self._drained.set()
super().resume_writing()


async def bind(addr):
"""
Bind a socket to a local address for datagrams. The socket will be either
AF_INET or AF_INET6 depending upon the type of address specified. The
socket will be reusable (SO_REUSEADDR) once it enters TIME_WAIT.
AF_INET or AF_INET6 depending upon the type of address specified.

@param addr - For AF_INET or AF_INET6, a tuple with the the host and port to
to bind; port may be set to 0 to get any free port.
Expand All @@ -181,12 +197,13 @@ async def bind(addr):
loop = asyncio.get_event_loop()
recvq = asyncio.Queue()
excq = asyncio.Queue()
drained = asyncio.Event()

transport, protocol = await loop.create_datagram_endpoint(
lambda: Protocol(recvq, excq), local_addr=addr, reuse_address=True
lambda: Protocol(recvq, excq, drained), local_addr=addr, reuse_address=False
)

return DatagramServer(transport, recvq, excq)
return DatagramServer(transport, recvq, excq, drained)


async def connect(addr):
Expand All @@ -201,12 +218,13 @@ async def connect(addr):
loop = asyncio.get_event_loop()
recvq = asyncio.Queue()
excq = asyncio.Queue()
drained = asyncio.Event()

transport, protocol = await loop.create_datagram_endpoint(
lambda: Protocol(recvq, excq), remote_addr=addr
lambda: Protocol(recvq, excq, drained), remote_addr=addr
)

return DatagramClient(transport, recvq, excq)
return DatagramClient(transport, recvq, excq, drained)


async def from_socket(sock):
Expand All @@ -224,6 +242,7 @@ async def from_socket(sock):
loop = asyncio.get_event_loop()
recvq = asyncio.Queue()
excq = asyncio.Queue()
drained = asyncio.Event()

if sock.family not in (socket.AF_INET, socket.AF_INET6):
raise TypeError(
Expand All @@ -234,12 +253,12 @@ async def from_socket(sock):
raise TypeError("socket must be %s" % (socket.SOCK_DGRAM,))

transport, protocol = await loop.create_datagram_endpoint(
lambda: Protocol(recvq, excq), sock=sock
lambda: Protocol(recvq, excq, drained), sock=sock
)

if transport.get_extra_info("peername") is not None:
# Workaround transport ignoring the peer address of the socket.
transport._address = transport.get_extra_info("peername")
return DatagramClient(transport, recvq, excq)
return DatagramClient(transport, recvq, excq, drained)
else:
return DatagramServer(transport, recvq, excq)
return DatagramServer(transport, recvq, excq, drained)
89 changes: 89 additions & 0 deletions test/test_aio.py
@@ -1,12 +1,23 @@
import asyncio
import contextlib
import os
import socket
import unittest.mock

import pytest

import asyncio_dgram


@pytest.fixture
def mock_socket():
s = unittest.mock.create_autospec(socket.socket)
s.family = socket.AF_INET
s.type = socket.SOCK_DGRAM

return s


@contextlib.contextmanager
def loop_exception_handler():
"""
Expand Down Expand Up @@ -260,3 +271,81 @@ async def test_unconnected_sender(addr):

with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(connected.recv(), 0.05)


@pytest.mark.asyncio
async def test_protocol_pause_resume(monkeypatch, mock_socket, tmp_path):
# This is a little involved, but necessary to make sure that the Protocol
# is correctly noticing when writing as been paused and resumed. In
# summary:
#
# - Mock the Protocol with one that sets the write buffer limits to 0 and
# records when pause and recume writing are called.
#
# - Use a mock socket so that we can inject a BlockingIOError on send.
# Ideally we'd mock method itself, but it's read-only the entire object
# needs to be mocked. Due to this, we need to use a temporary file that we
# can write to in order to kick the event loop to consider it ready for
# writing.

class TestableProtocol(asyncio_dgram.aio.Protocol):
pause_writing_called = 0
resume_writing_called = 0
instance = None

def __init__(self, *args, **kwds):
TestableProtocol.instance = self
super().__init__(*args, **kwds)

def connection_made(self, transport):
transport.set_write_buffer_limits(low=0, high=0)
super().connection_made(transport)

def pause_writing(self):
self.pause_writing_called += 1
super().pause_writing()

def resume_writing(self):
self.resume_writing_called += 1
super().resume_writing()

async def passthrough():
"""
Used to mock the wait method on the asyncio.Event tracking if the write
buffer is past the high water mark or not. Given we're testing how
that case is handled, we know it's safe locally to mock it.
"""
pass

with monkeypatch.context() as ctx:
ctx.setattr(asyncio_dgram.aio, "Protocol", TestableProtocol)

client = await asyncio_dgram.from_socket(mock_socket)
mock_socket.send.side_effect = BlockingIOError
mock_socket.fileno.return_value = os.open(
tmp_path / "socket", os.O_RDONLY | os.O_CREAT
)

with monkeypatch.context() as ctx2:
ctx2.setattr(client._drained, "wait", passthrough)
await client.send(b"foo")

assert TestableProtocol.instance.pause_writing_called == 1
assert TestableProtocol.instance.resume_writing_called == 0
assert not TestableProtocol.instance._drained.is_set()

mock_socket.send.side_effect = None
fd = os.open(tmp_path / "socket", os.O_WRONLY)
os.write(fd, b"\n")
os.close(fd)

with monkeypatch.context() as ctx2:
ctx2.setattr(client._drained, "wait", passthrough)
await client.send(b"foo")
await asyncio.sleep(0.1)

assert TestableProtocol.instance.pause_writing_called == 1
assert TestableProtocol.instance.resume_writing_called == 1
assert TestableProtocol.instance._drained.is_set()

os.close(mock_socket.fileno.return_value)