Skip to content

Commit

Permalink
Merge e3b1ad2 into 0d21644
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp committed Oct 16, 2017
2 parents 0d21644 + e3b1ad2 commit 7f553e6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
22 changes: 13 additions & 9 deletions kafka/client.py
Expand Up @@ -175,7 +175,8 @@ def _send_broker_unaware_request(self, payloads, encoder_fn, decoder_fn):

# Block
while not future.is_done:
conn.recv()
for r, f in conn.recv():
f.success(r)

if future.failed():
log.error("Request failed: %s", future.exception)
Expand Down Expand Up @@ -288,7 +289,8 @@ def failed_payloads(payloads):

if not future.is_done:
conn, _ = connections_by_future[future]
conn.recv()
for r, f in conn.recv():
f.success(r)
continue

_, broker = connections_by_future.pop(future)
Expand Down Expand Up @@ -352,8 +354,6 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
try:
host, port, afi = get_ip_port_afi(broker.host)
conn = self._get_conn(host, broker.port, afi)
conn.send(request_id, request)

except ConnectionError as e:
log.warning('ConnectionError attempting to send request %s '
'to server %s: %s', request_id, broker, e)
Expand All @@ -365,6 +365,11 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
# No exception, try to get response
else:

future = conn.send(request_id, request)
while not future.is_done:
for r, f in conn.recv():
f.success(r)

# decoder_fn=None signal that the server is expected to not
# send a response. This probably only applies to
# ProduceRequest w/ acks = 0
Expand All @@ -376,18 +381,17 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
responses[topic_partition] = None
return []

try:
response = conn.recv(request_id)
except ConnectionError as e:
log.warning('ConnectionError attempting to receive a '
if future.failed():
log.warning('Error attempting to receive a '
'response to request %s from server %s: %s',
request_id, broker, e)
request_id, broker, future.exception)

for payload in payloads:
topic_partition = (payload.topic, payload.partition)
responses[topic_partition] = FailedPayloadsError(payload)

else:
response = future.value
_resps = []
for payload_response in decoder_fn(response):
topic_partition = (payload_response.topic,
Expand Down
32 changes: 25 additions & 7 deletions kafka/client_async.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division

import collections
import copy
import functools
import heapq
Expand Down Expand Up @@ -204,6 +205,11 @@ def __init__(self, **configs):
self._wake_r, self._wake_w = socket.socketpair()
self._wake_r.setblocking(False)
self._wake_lock = threading.Lock()

# when requests complete, they are transferred to this queue prior to
# invocation.
self._pending_completion = collections.deque()

self._selector.register(self._wake_r, selectors.EVENT_READ)
self._idle_expiry_manager = IdleConnectionManager(self.config['connections_max_idle_ms'])
self._closed = False
Expand Down Expand Up @@ -254,7 +260,8 @@ def _bootstrap(self, hosts):
future = bootstrap.send(metadata_request)
while not future.is_done:
self._selector.select(1)
bootstrap.recv()
for r, f in bootstrap.recv():
f.success(r)
if future.failed():
bootstrap.close()
continue
Expand Down Expand Up @@ -512,7 +519,9 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
Returns:
list: responses received (can be empty)
"""
if timeout_ms is None:
if future is not None:
timeout_ms = 100
elif timeout_ms is None:
timeout_ms = self.config['request_timeout_ms']

responses = []
Expand Down Expand Up @@ -551,7 +560,9 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
self.config['request_timeout_ms'])
timeout = max(0, timeout / 1000.0) # avoid negative timeouts

responses.extend(self._poll(timeout))
self._poll(timeout)

responses.extend(self._fire_pending_completed_requests())

# If all we had was a timeout (future is None) - only do one poll
# If we do have a future, we keep looping until it is done
Expand All @@ -561,7 +572,7 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
return responses

def _poll(self, timeout):
responses = []
"""Returns list of (response, future) tuples"""
processed = set()

start_select = time.time()
Expand Down Expand Up @@ -600,14 +611,14 @@ def _poll(self, timeout):
continue

self._idle_expiry_manager.update(conn.node_id)
responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks
self._pending_completion.extend(conn.recv())

# Check for additional pending SSL bytes
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
# TODO: optimize
for conn in self._conns.values():
if conn not in processed and conn.connected() and conn._sock.pending():
responses.extend(conn.recv())
self._pending_completion.extend(conn.recv())

for conn in six.itervalues(self._conns):
if conn.requests_timed_out():
Expand All @@ -621,7 +632,6 @@ def _poll(self, timeout):
self._sensors.io_time.record((time.time() - end_select) * 1000000000)

self._maybe_close_oldest_connection()
return responses

def in_flight_request_count(self, node_id=None):
"""Get the number of in-flight requests for a node or all nodes.
Expand All @@ -640,6 +650,14 @@ def in_flight_request_count(self, node_id=None):
else:
return sum([len(conn.in_flight_requests) for conn in self._conns.values()])

def _fire_pending_completed_requests(self):
responses = []
while self._pending_completion:
response, future = self._pending_completion.popleft()
future.success(response)
responses.append(response)
return responses

def least_loaded_node(self):
"""Choose the node with fewest outstanding requests, with fallbacks.
Expand Down
39 changes: 25 additions & 14 deletions kafka/conn.py
Expand Up @@ -5,6 +5,14 @@
import errno
import logging
from random import shuffle, uniform

# selectors in stdlib as of py3.4
try:
import selectors # pylint: disable=import-error
except ImportError:
# vendored backport module
from .vendor import selectors34 as selectors

import socket
import struct
import sys
Expand Down Expand Up @@ -138,6 +146,9 @@ class BrokerConnection(object):
api_version_auto_timeout_ms (int): number of milliseconds to throw a
timeout exception from the constructor when checking the broker
api version. Only applies if api_version is None
selector (selectors.BaseSelector): Provide a specific selector
implementation to use for I/O multiplexing.
Default: selectors.DefaultSelector
state_change_callback (callable): function to be called when the
connection state changes from CONNECTING to CONNECTED etc.
metrics (kafka.metrics.Metrics): Optionally provide a metrics
Expand Down Expand Up @@ -173,6 +184,7 @@ class BrokerConnection(object):
'ssl_crlfile': None,
'ssl_password': None,
'api_version': (0, 8, 2), # default to most restrictive
'selector': selectors.DefaultSelector,
'state_change_callback': lambda conn: True,
'metrics': None,
'metric_group_prefix': '',
Expand Down Expand Up @@ -705,7 +717,7 @@ def can_send_more(self):
def recv(self):
"""Non-blocking network receive.
Return response if available
Return list of (response, future)
"""
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
log.warning('%s cannot recv: socket not connected', self)
Expand All @@ -728,17 +740,16 @@ def recv(self):
self.config['request_timeout_ms']))
return ()

for response in responses:
# augment respones w/ correlation_id, future, and timestamp
for i in range(len(responses)):
(correlation_id, future, timestamp) = self.in_flight_requests.popleft()
if isinstance(response, Errors.KafkaError):
self.close(response)
break

latency_ms = (time.time() - timestamp) * 1000
if self._sensors:
self._sensors.request_time.record((time.time() - timestamp) * 1000)
self._sensors.request_time.record(latency_ms)

log.debug('%s Response %d: %s', self, correlation_id, response)
future.success(response)
response = responses[i]
log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response)
responses[i] = (response, future)

return responses

Expand Down Expand Up @@ -900,12 +911,12 @@ def connect():
# request was unrecognized
mr = self.send(MetadataRequest[0]([]))

if self._sock:
self._sock.setblocking(True)
selector = self.config['selector']()
selector.register(self._sock, selectors.EVENT_READ)
while not (f.is_done and mr.is_done):
self.recv()
if self._sock:
self._sock.setblocking(False)
for response, future in self.recv():
future.success(response)
selector.select(1)

if f.succeeded():
if isinstance(request, ApiVersionRequest[0]):
Expand Down
3 changes: 2 additions & 1 deletion test/test_client.py
Expand Up @@ -28,6 +28,7 @@ def mock_conn(conn, success=True):
else:
mocked.send.return_value = Future().failure(Exception())
conn.return_value = mocked
conn.recv.return_value = []


class TestSimpleClient(unittest.TestCase):
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_send_broker_unaware_request(self):
mock_conn(mocked_conns[('kafka03', 9092)], success=False)
future = Future()
mocked_conns[('kafka02', 9092)].send.return_value = future
mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response')
mocked_conns[('kafka02', 9092)].recv.return_value = [('valid response', future)]

def mock_get_conn(host, port, afi):
return mocked_conns[(host, port)]
Expand Down

0 comments on commit 7f553e6

Please sign in to comment.