diff --git a/kafka/client_async.py b/kafka/client_async.py index 0d9e56258..9c19ac319 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -617,7 +617,7 @@ def _poll(self, timeout): conn = key.data processed.add(conn) - if not conn.in_flight_requests: + if not conn.has_in_flight_requests(): # if we got an EVENT_READ but there were no in-flight requests, one of # two things has happened: # @@ -648,12 +648,7 @@ def _poll(self, timeout): self._pending_completion.extend(conn.recv()) for conn in six.itervalues(self._conns): - if conn.requests_timed_out(): - log.warning('%s timed out after %s ms. Closing connection.', - conn, conn.config['request_timeout_ms']) - conn.close(error=Errors.RequestTimedOutError( - 'Request timed out after %s ms' % - conn.config['request_timeout_ms'])) + conn.close_if_timed_out() if self._sensors: self._sensors.io_time.record((time.time() - end_select) * 1000000000) diff --git a/kafka/conn.py b/kafka/conn.py index 33950dbbf..dd214358d 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -273,7 +273,7 @@ def __init__(self, host, port, afi, **configs): # per-connection locks to the upstream client, we will use this lock to # make sure that access to the protocol buffer is synchronized # when sends happen on multiple threads - self._lock = threading.Lock() + self._lock = threading.RLock() # the protocol parser instance manages actual tracking of the # sequence of in-flight requests to responses, which should @@ -316,38 +316,43 @@ def _next_afi_sockaddr(self): return (afi, sockaddr) def connect_blocking(self, timeout=float('inf')): - if self.connected(): - return True - timeout += time.time() - # First attempt to perform dns lookup - # note that the underlying interface, socket.getaddrinfo, - # has no explicit timeout so we may exceed the user-specified timeout - self._dns_lookup() - - # Loop once over all returned dns entries - selector = None - while self._gai: - while time.time() < timeout: - self.connect() - if self.connected(): - if selector is not None: - selector.close() - return True - elif self.connecting(): - if selector is None: - selector = self.config['selector']() - selector.register(self._sock, selectors.EVENT_WRITE) - selector.select(1) - elif self.disconnected(): - if selector is not None: - selector.close() - selector = None + with self._lock: + if self.connected(): + return True + timeout += time.time() + # First attempt to perform dns lookup + # note that the underlying interface, socket.getaddrinfo, + # has no explicit timeout so we may exceed the user-specified timeout + self._dns_lookup() + + # Loop once over all returned dns entries + selector = None + while self._gai: + while time.time() < timeout: + self._connect() + if self.connected(): + if selector is not None: + selector.close() + return True + elif self.connecting(): + if selector is None: + selector = self.config['selector']() + selector.register(self._sock, selectors.EVENT_WRITE) + selector.select(1) + elif self.disconnected(): + if selector is not None: + selector.close() + selector = None + break + else: break - else: - break - return False + return False def connect(self): + with self._lock: + self._connect() + + def _connect(self): """Attempt to connect and return ConnectionState""" if self.state is ConnectionStates.DISCONNECTED and not self.blacked_out(): self.last_attempt = time.time() @@ -784,43 +789,48 @@ def close(self, error=None): will be failed with this exception. Default: kafka.errors.KafkaConnectionError. """ - if self.state is ConnectionStates.DISCONNECTED: - if error is not None: - log.warning('%s: Duplicate close() with error: %s', self, error) - return - log.info('%s: Closing connection. %s', self, error or '') - self.state = ConnectionStates.DISCONNECTING - self.config['state_change_callback'](self) - self._update_reconnect_backoff() - self._close_socket() - self.state = ConnectionStates.DISCONNECTED - self._sasl_auth_future = None - self._protocol = KafkaProtocol( - client_id=self.config['client_id'], - api_version=self.config['api_version']) + with self._lock: + if self.state is ConnectionStates.DISCONNECTED: + if error is not None: + log.warning('%s: Duplicate close() with error: %s', self, error) + self._fail_ifrs(error) + return + log.info('%s: Closing connection. %s', self, error or '') + self.state = ConnectionStates.DISCONNECTING + self.config['state_change_callback'](self) + self._update_reconnect_backoff() + self._close_socket() + self.state = ConnectionStates.DISCONNECTED + self._sasl_auth_future = None + self._protocol = KafkaProtocol( + client_id=self.config['client_id'], + api_version=self.config['api_version']) + self._fail_ifrs(error) + self.config['state_change_callback'](self) + + def _fail_ifrs(self, error=None): if error is None: error = Errors.Cancelled(str(self)) while self.in_flight_requests: (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem() future.failure(error) - self.config['state_change_callback'](self) def send(self, request, blocking=True): """Queue request for async network send, return Future()""" - future = Future() - if self.connecting(): - return future.failure(Errors.NodeNotReadyError(str(self))) - elif not self.connected(): - return future.failure(Errors.KafkaConnectionError(str(self))) - elif not self.can_send_more(): - return future.failure(Errors.TooManyInFlightRequests(str(self))) - return self._send(request, blocking=blocking) + with self._lock: + future = Future() + if self.connecting(): + return future.failure(Errors.NodeNotReadyError(str(self))) + elif not self.connected(): + return future.failure(Errors.KafkaConnectionError(str(self))) + elif not self.can_send_more(): + return future.failure(Errors.TooManyInFlightRequests(str(self))) + return self._send(request, blocking=blocking) def _send(self, request, blocking=True): assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED) future = Future() - with self._lock: - correlation_id = self._protocol.send_request(request) + correlation_id = self._protocol.send_request(request) log.debug('%s Request %d: %s', self, correlation_id, request) if request.expect_response(): @@ -839,24 +849,25 @@ def _send(self, request, blocking=True): def send_pending_requests(self): """Can block on network if request is larger than send_buffer_bytes""" - if self.state not in (ConnectionStates.AUTHENTICATING, - ConnectionStates.CONNECTED): - return Errors.NodeNotReadyError(str(self)) with self._lock: + if self.state not in (ConnectionStates.AUTHENTICATING, + ConnectionStates.CONNECTED): + return Errors.NodeNotReadyError(str(self)) + data = self._protocol.send_bytes() - try: - # In the future we might manage an internal write buffer - # and send bytes asynchronously. For now, just block - # sending each request payload - total_bytes = self._send_bytes_blocking(data) - if self._sensors: - self._sensors.bytes_sent.record(total_bytes) - return total_bytes - except ConnectionError as e: - log.exception("Error sending request data to %s", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return error + try: + # In the future we might manage an internal write buffer + # and send bytes asynchronously. For now, just block + # sending each request payload + total_bytes = self._send_bytes_blocking(data) + if self._sensors: + self._sensors.bytes_sent.record(total_bytes) + return total_bytes + except ConnectionError as e: + log.exception("Error sending request data to %s", self) + error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + self.close(error=error) + return error def can_send_more(self): """Return True unless there are max_in_flight_requests_per_connection.""" @@ -868,42 +879,38 @@ def recv(self): Return list of (response, future) tuples """ - if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING: - log.warning('%s cannot recv: socket not connected', self) - # If requests are pending, we should close the socket and - # fail all the pending request futures - if self.in_flight_requests: - self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests')) - return () - - elif not self.in_flight_requests: - log.warning('%s: No in-flight-requests to recv', self) - return () - - responses = self._recv() - if not responses and self.requests_timed_out(): - log.warning('%s timed out after %s ms. Closing connection.', - self, self.config['request_timeout_ms']) - self.close(error=Errors.RequestTimedOutError( - 'Request timed out after %s ms' % - self.config['request_timeout_ms'])) - return () - - # augment respones w/ correlation_id, future, and timestamp - for i, (correlation_id, response) in enumerate(responses): - try: - (future, timestamp) = self.in_flight_requests.pop(correlation_id) - except KeyError: - self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) + with self._lock: + if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING: + log.warning('%s cannot recv: socket not connected', self) + # If requests are pending, we should close the socket and + # fail all the pending request futures + if self.in_flight_requests: + self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests')) return () - latency_ms = (time.time() - timestamp) * 1000 - if self._sensors: - self._sensors.request_time.record(latency_ms) - log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response) - responses[i] = (response, future) + elif not self.in_flight_requests: + log.warning('%s: No in-flight-requests to recv', self) + return () - return responses + responses = self._recv() + if not responses and self.close_if_timed_out(): + return () + + # augment respones w/ correlation_id, future, and timestamp + for i, (correlation_id, response) in enumerate(responses): + try: + (future, timestamp) = self.in_flight_requests.pop(correlation_id) + except KeyError: + self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) + return () + latency_ms = (time.time() - timestamp) * 1000 + if self._sensors: + self._sensors.request_time.record(latency_ms) + + log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response) + responses[i] = (response, future) + + return responses def _recv(self): """Take all available bytes from socket, return list of any responses from parser""" @@ -948,6 +955,24 @@ def _recv(self): else: return responses + def has_in_flight_requests(self): + with self._lock: + return bool(self.in_flight_requests) + + def close_if_timed_out(self): + """ If the connection has timed-out in-flight-requests, close it and return True. Otherwise return False """ + with self._lock: + if self.requests_timed_out(): + log.warning('%s timed out after %s ms. Closing connection.', + self, self.config['request_timeout_ms']) + self.close(error=Errors.RequestTimedOutError( + 'Request timed out after %s ms' % + self.config['request_timeout_ms'])) + + return True + + return False + def requests_timed_out(self): if self.in_flight_requests: get_timestamp = lambda v: v[1]