Skip to content

Commit

Permalink
Merge c4833a8 into 6f932ba
Browse files Browse the repository at this point in the history
  • Loading branch information
Tincu Gabriel committed Oct 20, 2020
2 parents 6f932ba + c4833a8 commit 82c44a0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
13 changes: 9 additions & 4 deletions kafka/admin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \
ACLResourcePatternType
from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol
import kafka.errors as Errors
from kafka.errors import (
Expand All @@ -26,6 +26,7 @@
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.types import Array
from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation
from kafka.util import get_client_factory
from kafka.version import __version__


Expand Down Expand Up @@ -146,6 +147,7 @@ class KafkaAdminClient(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
"""
DEFAULT_CONFIG = {
Expand Down Expand Up @@ -186,6 +188,7 @@ class KafkaAdminClient(object):
'metric_reporters': [],
'metrics_num_samples': 2,
'metrics_sample_window_ms': 30000,
'client_factory': None,
}

def __init__(self, **configs):
Expand All @@ -205,9 +208,11 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

self._client = KafkaClient(metrics=self._metrics,
metric_group_prefix='admin',
**self.config)
self._client = get_client_factory(self.config)(
metrics=self._metrics,
metric_group_prefix='admin',
**self.config
)
self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000))

# Get auto-discovered version from client if necessary
Expand Down
7 changes: 5 additions & 2 deletions kafka/consumer/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from kafka.vendor import six

from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.consumer.fetcher import Fetcher
from kafka.consumer.subscription_state import SubscriptionState
from kafka.coordinator.consumer import ConsumerCoordinator
Expand All @@ -18,6 +18,7 @@
from kafka.metrics import MetricConfig, Metrics
from kafka.protocol.offset import OffsetResetStrategy
from kafka.structs import TopicPartition
from kafka.util import get_client_factory
from kafka.version import __version__

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -244,6 +245,7 @@ class KafkaConsumer(six.Iterator):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -306,6 +308,7 @@ class KafkaConsumer(six.Iterator):
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None,
'legacy_iterator': False, # enable to revert to < 1.4.7 iterator
'client_factory': None,
}
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000

Expand Down Expand Up @@ -353,7 +356,7 @@ def __init__(self, *topics, **configs):
log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated',
str(self.config['api_version']), str_version)

self._client = KafkaClient(metrics=self._metrics, **self.config)
self._client = get_client_factory(self.config)(metrics=self._metrics, **self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down
14 changes: 9 additions & 5 deletions kafka/producer/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from kafka.vendor import six

import kafka.errors as Errors
from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd
from kafka.metrics import MetricConfig, Metrics
from kafka.partitioner.default import DefaultPartitioner
Expand All @@ -22,6 +22,7 @@
from kafka.record.legacy_records import LegacyRecordBatchBuilder
from kafka.serializer import Serializer
from kafka.structs import TopicPartition
from kafka.util import get_client_factory


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -280,6 +281,7 @@ class KafkaProducer(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -332,7 +334,8 @@ class KafkaProducer(object):
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None
'sasl_oauth_token_provider': None,
'client_factory': None,
}

_COMPRESSORS = {
Expand Down Expand Up @@ -378,9 +381,10 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)
client = get_client_factory(self.config)(
metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down
10 changes: 10 additions & 0 deletions kafka/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import binascii
import kafka
import weakref

from kafka.vendor import six
Expand Down Expand Up @@ -64,3 +65,12 @@ class Dict(dict):
See: https://docs.python.org/2/library/weakref.html
"""
pass


def get_client_factory(config):
if 'client_factory' in config and config['client_factory'] is not None:
client_factory = config['client_factory']
assert callable(client_factory), "'client_factory' should be a callable or None, is {}".format(type(client_factory))
else:
client_factory = kafka.client_async.KafkaClient
return client_factory

0 comments on commit 82c44a0

Please sign in to comment.