Skip to content

Commit

Permalink
Process session requests concurrently. Add processing_count.
Browse files Browse the repository at this point in the history
  • Loading branch information
Neil Booth committed Apr 12, 2019
1 parent 9bf761a commit 41cbf54
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 54 deletions.
113 changes: 63 additions & 50 deletions aiorpcx/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(self, *, framer=None, loop=None):
self.recv_count = 0
self.recv_size = 0
self.last_recv = self.start_time
self.processing_count = 0
# Resource usage
self.cost = 0.0
self._cost_last = 0.0
Expand Down Expand Up @@ -349,34 +350,40 @@ class MessageSession(SessionBase):
and perhaps a proxy.
'''
async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except BadMagicError as e:
magic, expected = e.args
self.logger.error(
f'bad network magic: got {magic} expected {expected}, '
f'disconnecting'
)
self._close()
except OversizedPayloadError as e:
command, payload_len = e.args
self.logger.error(
f'oversized payload of {payload_len:,d} bytes to command '
f'{command}, disconnecting'
)
self._close()
except BadChecksumError as e:
payload_checksum, claimed_checksum = e.args
self.logger.warning(
f'checksum mismatch: actual {payload_checksum.hex()} '
f'vs claimed {claimed_checksum.hex()}'
)
self._bump_errors()
else:
self.last_recv = time.time()
self.recv_count += 1
await self._throttled_message(message)
async with TaskGroup() as group:
while not self.is_closing():
try:
message = await self.framer.receive_message()
except BadMagicError as e:
magic, expected = e.args
self.logger.error(
f'bad network magic: got {magic} expected {expected}, '
f'disconnecting'
)
self._close()
except OversizedPayloadError as e:
command, payload_len = e.args
self.logger.error(
f'oversized payload of {payload_len:,d} bytes to command '
f'{command}, disconnecting'
)
self._close()
except BadChecksumError as e:
payload_checksum, claimed_checksum = e.args
self.logger.warning(
f'checksum mismatch: actual {payload_checksum.hex()} '
f'vs claimed {claimed_checksum.hex()}'
)
self._bump_errors()
else:
self.last_recv = time.time()
self.recv_count += 1
self.processing_count += 1
await group.spawn(self._throttled_message(message))

while group._done:
await group.next_result()
self.processing_count -= 1

async def _throttled_message(self, message):
'''Process a single request, respecting the concurrency limit.'''
Expand Down Expand Up @@ -493,28 +500,34 @@ def __init__(self, *, framer=None, loop=None, connection=None):
self.connection = connection or self.default_connection()

async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except MemoryError as e:
self.logger.warning(f'{e!r}')
continue

self.last_recv = time.time()
self.recv_count += 1

try:
requests = self.connection.receive_message(message)
except ProtocolError as e:
self.logger.debug(f'{e}')
if e.error_message:
await self._send_message(e.error_message)
if e.code == JSONRPC.PARSE_ERROR:
self.max_errors = 0
self._bump_errors()
else:
for request in requests:
await self._throttled_request(request)
async with TaskGroup() as group:
while not self.is_closing():
try:
message = await self.framer.receive_message()
except MemoryError as e:
self.logger.warning(f'{e!r}')
continue

self.last_recv = time.time()
self.recv_count += 1

try:
requests = self.connection.receive_message(message)
except ProtocolError as e:
self.logger.debug(f'{e}')
if e.error_message:
await self._send_message(e.error_message)
if e.code == JSONRPC.PARSE_ERROR:
self.max_errors = 0
self._bump_errors()
else:
self.processing_count += len(requests)
for request in requests:
await group.spawn(self._throttled_request(request))

while group._done:
await group.next_result()
self.processing_count -= 1

async def _throttled_request(self, request):
'''Process a single request, respecting the concurrency limit.'''
Expand Down
9 changes: 5 additions & 4 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,11 @@ async def test_basic_send(self, msg_server):
async with Connector(MessageSession, 'localhost',
msg_server.port) as client:
server_session = await MessageServer.current_server()
await client.send_message((b'version', b'abc'))
# Give the receiving task time to process before closing the connection
await sleep(0.001)
assert server_session.messages == [(b'version', b'abc')]
for n in range(3):
await client.send_message((b'version', b'abc'))
# Give the receiving task time to process before closing the connection
await sleep(0.001)
assert server_session.messages == [(b'version', b'abc')] * 3

@pytest.mark.asyncio
async def test_many_sends(self, msg_server):
Expand Down

0 comments on commit 41cbf54

Please sign in to comment.