Merge branch 'master' into master

wbarnha committed Apr 9, 2024
2 parents 3c997fd + 6756974 commit 8f99d84
Showing 19 changed files with 516 additions and 394 deletions.
4 changes: 4 additions & 0 deletions kafka/client_async.py
@@ -59,6 +59,9 @@ class KafkaClient:
rate. To avoid connection storms, a randomization factor of 0.2
will be applied to the backoff resulting in a random range between
20% below and 20% above the computed value. Default: 1000.
connection_timeout_ms (int): Connection timeout in milliseconds.
Default: None, which falls back to the value of request_timeout_ms.
request_timeout_ms (int): Client request timeout in milliseconds.
Default: 30000.
connections_max_idle_ms: Close idle connections after the number of
@@ -145,6 +148,7 @@ class KafkaClient:
'bootstrap_servers': 'localhost',
'bootstrap_topics_filter': set(),
'client_id': 'kafka-python-' + __version__,
'connection_timeout_ms': None,
'request_timeout_ms': 30000,
'wakeup_timeout_ms': 3000,
'connections_max_idle_ms': 9 * 60 * 1000,
37 changes: 20 additions & 17 deletions kafka/conn.py
@@ -68,13 +68,6 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
socket.AF_INET: "IPv4",
@@ -112,6 +105,9 @@ class BrokerConnection:
rate. To avoid connection storms, a randomization factor of 0.2
will be applied to the backoff resulting in a random range between
20% below and 20% above the computed value. Default: 1000.
connection_timeout_ms (int): Connection timeout in milliseconds.
Default: None, which falls back to the value of request_timeout_ms.
request_timeout_ms (int): Client request timeout in milliseconds.
Default: 30000.
max_in_flight_requests_per_connection (int): Requests are pipelined
@@ -188,6 +184,7 @@ class BrokerConnection:
'client_id': 'kafka-python-' + __version__,
'node_id': 0,
'request_timeout_ms': 30000,
'connection_timeout_ms': None,
'reconnect_backoff_ms': 50,
'reconnect_backoff_max_ms': 1000,
'max_in_flight_requests_per_connection': 5,
@@ -232,6 +229,9 @@ def __init__(self, host, port, afi, **configs):
if key in configs:
self.config[key] = configs[key]

if self.config['connection_timeout_ms'] is None:
self.config['connection_timeout_ms'] = self.config['request_timeout_ms']

self.node_id = self.config.pop('node_id')

if self.config['receive_buffer_bytes'] is not None:
@@ -246,19 +246,15 @@ def __init__(self, host, port, afi, **configs):
assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, (
'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS))


if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
assert ssl_available, "Python wasn't built with SSL support"

if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'

if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
assert self.config['sasl_mechanism'] in sasl.MECHANISMS, (
'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys()))
)
sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self)

# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
@@ -284,7 +280,10 @@ def __init__(self, host, port, afi, **configs):
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
self._sasl_auth_future = None
self.last_attempt = 0
self.last_activity = 0
# This value is not used for internal state, but it is kept for backwards compatibility.
# The variable last_activity is now used instead, but it is updated more often and may therefore break compatibility with some hacks.
self.last_attempt = 0
self._gai = []
self._sensors = None
if self.config['metrics']:
@@ -362,6 +361,7 @@ def connect(self):
self.config['state_change_callback'](self.node_id, self._sock, self)
log.info('%s: connecting to %s:%d [%s %s]', self, self.host,
self.port, self._sock_addr, AFI_NAMES[self._sock_afi])
self.last_activity = time.time()

if self.state is ConnectionStates.CONNECTING:
# in non-blocking mode, use repeated calls to socket.connect_ex
@@ -394,6 +394,7 @@ def connect(self):
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self.node_id, self._sock, self)
self.last_activity = time.time()

# Connection failed
# WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems
@@ -419,6 +420,7 @@ def connect(self):
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self.node_id, self._sock, self)
self.last_activity = time.time()

if self.state is ConnectionStates.AUTHENTICATING:
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
@@ -429,12 +431,13 @@ def connect(self):
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self.node_id, self._sock, self)
self.last_activity = time.time()

if self.state not in (ConnectionStates.CONNECTED,
ConnectionStates.DISCONNECTED):
# Connection timed out
request_timeout = self.config['request_timeout_ms'] / 1000.0
if time.time() > request_timeout + self.last_attempt:
request_timeout = self.config['connection_timeout_ms'] / 1000.0
if time.time() > request_timeout + self.last_activity:
log.error('Connection attempt to %s timed out', self)
self.close(Errors.KafkaConnectionError('timeout'))
return self.state
@@ -595,7 +598,7 @@ def blacked_out(self):
re-establish a connection yet
"""
if self.state is ConnectionStates.DISCONNECTED:
if time.time() < self.last_attempt + self._reconnect_backoff:
if time.time() < self.last_activity + self._reconnect_backoff:
return True
return False

@@ -606,7 +609,7 @@ def connection_delay(self):
the reconnect backoff time. When connecting or connected, returns a very
large number to handle slow/stalled connections.
"""
time_waited = time.time() - (self.last_attempt or 0)
time_waited = time.time() - (self.last_activity or 0)
if self.state is ConnectionStates.DISCONNECTED:
return max(self._reconnect_backoff - time_waited, 0) * 1000
else:
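Taken together, the conn.py changes mean a stalled connection attempt is now judged against connection_timeout_ms and the new last_activity timestamp instead of request_timeout_ms and last_attempt. Below is a minimal sketch of that check with hypothetical config values; it is illustrative only and not the actual BrokerConnection code path.

import time

def connect_attempt_timed_out(config, last_activity):
    """Sketch of the check in BrokerConnection.connect(): an in-flight
    connection attempt times out against connection_timeout_ms (falling
    back to request_timeout_ms when it is None), measured from
    last_activity, which is refreshed on every connection state change."""
    timeout_ms = config['connection_timeout_ms']
    if timeout_ms is None:
        timeout_ms = config['request_timeout_ms']
    return time.time() > last_activity + timeout_ms / 1000.0

# Hypothetical numbers: an attempt whose last state change was 40s ago,
# with the default 30s request timeout, is treated as timed out.
config = {'connection_timeout_ms': None, 'request_timeout_ms': 30000}
print(connect_attempt_timed_out(config, time.time() - 40))  # True
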
2 changes: 1 addition & 1 deletion kafka/coordinator/assignors/abstract.py
@@ -12,7 +12,7 @@ class AbstractPartitionAssignor(object):
partition counts which are always needed in assignors).
"""

@abc.abstractproperty
@abc.abstractmethod
def name(self):
""".name should be a string identifying the assignor"""
pass
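abc.abstractproperty has been deprecated since Python 3.3, so the declaration is switched to abc.abstractmethod here (and likewise in kafka/protocol/api.py below). Subclasses can still satisfy it with a plain class attribute, which is how the built-in assignors define name. A hedged sketch with a hypothetical assignor:

import abc

class AbstractPartitionAssignor(abc.ABC):
    """Simplified stand-in for kafka.coordinator.assignors.abstract."""

    @abc.abstractmethod
    def name(self):
        """.name should be a string identifying the assignor"""

class GreedyAssignor(AbstractPartitionAssignor):
    # A class attribute overrides the abstract member, so the subclass
    # is concrete and can be instantiated.
    name = 'greedy'

print(GreedyAssignor().name)  # 'greedy'
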
1 change: 0 additions & 1 deletion kafka/coordinator/assignors/sticky/sticky_assignor.py
@@ -2,7 +2,6 @@
from collections import defaultdict, namedtuple
from copy import deepcopy

from kafka.cluster import ClusterMetadata
from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements
from kafka.coordinator.assignors.sticky.sorted_set import SortedSet
9 changes: 5 additions & 4 deletions kafka/errors.py
@@ -1,13 +1,14 @@
import inspect
import sys
from typing import Any


class KafkaError(RuntimeError):
retriable = False
# whether metadata should be refreshed on error
invalid_metadata = False

def __str__(self):
def __str__(self) -> str:
if not self.args:
return self.__class__.__name__
return '{}: {}'.format(self.__class__.__name__,
@@ -65,7 +66,7 @@ class IncompatibleBrokerVersion(KafkaError):


class CommitFailedError(KafkaError):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(
"""Commit cannot be completed since the group has already
rebalanced and assigned the partitions to another member.
@@ -92,7 +93,7 @@ class BrokerResponseError(KafkaError):
message = None
description = None

def __str__(self):
def __str__(self) -> str:
"""Add errno to standard KafkaError str"""
return '[Error {}] {}'.format(
self.errno,
@@ -509,7 +510,7 @@ def _iter_broker_errors():
kafka_errors = {x.errno: x for x in _iter_broker_errors()}


def for_code(error_code):
def for_code(error_code: int) -> Any:
return kafka_errors.get(error_code, UnknownError)


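The typing additions in kafka/errors.py annotate for_code, which maps a broker error code to its exception class and falls back to UnknownError for unrecognized codes. A brief usage sketch:

from kafka import errors as Errors

# Error code 3 is UNKNOWN_TOPIC_OR_PARTITION in the Kafka protocol.
err_cls = Errors.for_code(3)
print(err_cls.__name__)                                 # UnknownTopicOrPartitionError
print(issubclass(err_cls, Errors.BrokerResponseError))  # True

# Codes that are not in the table map to UnknownError.
print(Errors.for_code(-12345) is Errors.UnknownError)   # True
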
4 changes: 4 additions & 0 deletions kafka/producer/kafka.py
@@ -190,6 +190,9 @@ class KafkaProducer:
brokers or partitions. Default: 300000
retry_backoff_ms (int): Milliseconds to backoff when retrying on
errors. Default: 100.
connection_timeout_ms (int): Connection timeout in milliseconds.
Default: None, which falls back to the value of request_timeout_ms.
request_timeout_ms (int): Client request timeout in milliseconds.
Default: 30000.
receive_buffer_bytes (int): The size of the TCP receive buffer
@@ -300,6 +303,7 @@ class KafkaProducer:
'max_request_size': 1048576,
'metadata_max_age_ms': 300000,
'retry_backoff_ms': 100,
'connection_timeout_ms': None,
'request_timeout_ms': 30000,
'receive_buffer_bytes': None,
'send_buffer_bytes': None,
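A hedged usage sketch of the new setting from the producer side; the broker address and timeout values are hypothetical. Leaving connection_timeout_ms unset (None) preserves the previous behaviour of reusing request_timeout_ms for connection attempts.

from kafka import KafkaProducer

# Hypothetical local broker: give up on a stalled connection attempt after
# 5 seconds while keeping the default 30 second request timeout.
producer = KafkaProducer(
    bootstrap_servers='localhost:9092',
    connection_timeout_ms=5000,
    request_timeout_ms=30000,
)
producer.send('example-topic', b'hello')
producer.flush()
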
14 changes: 7 additions & 7 deletions kafka/protocol/api.py
@@ -52,22 +52,22 @@ class Request(Struct):

FLEXIBLE_VERSION = False

@abc.abstractproperty
@abc.abstractmethod
def API_KEY(self):
"""Integer identifier for api request"""
pass

@abc.abstractproperty
@abc.abstractmethod
def API_VERSION(self):
"""Integer of api request version"""
pass

@abc.abstractproperty
@abc.abstractmethod
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""
pass

@abc.abstractproperty
@abc.abstractmethod
def RESPONSE_TYPE(self):
"""The Response class associated with the api request"""
pass
@@ -93,17 +93,17 @@ def parse_response_header(self, read_buffer):
class Response(Struct):
__metaclass__ = abc.ABCMeta

@abc.abstractproperty
@abc.abstractmethod
def API_KEY(self):
"""Integer identifier for api request/response"""
pass

@abc.abstractproperty
@abc.abstractmethod
def API_VERSION(self):
"""Integer of api request/response version"""
pass

@abc.abstractproperty
@abc.abstractmethod
def SCHEMA(self):
"""An instance of Schema() representing the response structure"""
pass
14 changes: 8 additions & 6 deletions kafka/protocol/struct.py
@@ -1,15 +1,17 @@
from io import BytesIO
from typing import List, Union

from kafka.protocol.abstract import AbstractType
from kafka.protocol.types import Schema


from kafka.util import WeakMethod


class Struct(AbstractType):
SCHEMA = Schema()

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if len(args) == len(self.SCHEMA.fields):
for i, name in enumerate(self.SCHEMA.names):
self.__dict__[name] = args[i]
@@ -36,23 +38,23 @@ def encode(cls, item): # pylint: disable=E0202
bits.append(field.encode(item[i]))
return b''.join(bits)

def _encode_self(self):
def _encode_self(self) -> bytes:
return self.SCHEMA.encode(
[self.__dict__[name] for name in self.SCHEMA.names]
)

@classmethod
def decode(cls, data):
def decode(cls, data: Union[BytesIO, bytes]) -> "Struct":
if isinstance(data, bytes):
data = BytesIO(data)
return cls(*[field.decode(data) for field in cls.SCHEMA.fields])

def get_item(self, name):
def get_item(self, name: str) -> Union[int, List[List[Union[int, str, bool, List[List[Union[int, List[int]]]]]]], str, List[List[Union[int, str]]]]:
if name not in self.SCHEMA.names:
raise KeyError("%s is not in the schema" % name)
return self.__dict__[name]

def __repr__(self):
def __repr__(self) -> str:
key_vals = []
for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields):
key_vals.append(f'{name}={field.repr(self.__dict__[name])}')
@@ -61,7 +63,7 @@ def __repr__(self):
def __hash__(self):
return hash(self.encode())

def __eq__(self, other):
def __eq__(self, other: "Struct") -> bool:
if self.SCHEMA != other.SCHEMA:
return False
for attr in self.SCHEMA.names:
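The struct.py annotations document that decode accepts either raw bytes or a BytesIO stream and returns a Struct instance. A hedged round-trip sketch with a hypothetical schema (not a real Kafka protocol message), assuming the Schema/String field API shown elsewhere in kafka.protocol.types:

from kafka.protocol.struct import Struct
from kafka.protocol.types import Int32, Schema, String

class ExampleStruct(Struct):
    # Hypothetical schema, for illustration only.
    SCHEMA = Schema(
        ('id', Int32),
        ('name', String('utf-8')),
    )

original = ExampleStruct(42, 'demo')
raw = original.encode()              # instance encode -> wire bytes
decoded = ExampleStruct.decode(raw)  # bytes or BytesIO both accepted
assert decoded == original           # __eq__ compares the schema fields
print(decoded)                       # repr built from the schema fields
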
6 changes: 3 additions & 3 deletions kafka/record/_crc32c.py
@@ -97,7 +97,7 @@
_MASK = 0xFFFFFFFF


def crc_update(crc, data):
def crc_update(crc: int, data: bytes) -> int:
"""Update CRC-32C checksum with data.
Args:
crc: 32-bit checksum to update as long.
@@ -116,7 +116,7 @@ def crc_update(crc, data):
return crc ^ _MASK


def crc_finalize(crc):
def crc_finalize(crc: int) -> int:
"""Finalize CRC-32C checksum.
This function should be called as last step of crc calculation.
Args:
@@ -127,7 +127,7 @@ def crc_finalize(crc):
return crc & _MASK


def crc(data):
def crc(data: bytes) -> int:
"""Compute CRC-32C checksum of the data.
Args:
data: byte array, string or iterable over bytes.
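The _crc32c.py annotations make the pure-Python CRC-32C helpers explicit: crc is the one-shot form, while crc_update and crc_finalize support incremental use. A small sketch, assuming the module's initial CRC value is 0 (its CRC_INIT constant):

from kafka.record._crc32c import crc, crc_finalize, crc_update

data = b'kafka-python'

# One-shot checksum of the whole buffer.
full = crc(data)

# Incremental form: feed the data in chunks, then finalize.
running = crc_update(0, data[:5])        # 0 assumed to match CRC_INIT
running = crc_update(running, data[5:])
assert crc_finalize(running) == full

print(hex(full))
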