Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion kafka/net/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,13 @@ async def _sasl_authenticate(self):

# Step 2: SASL authentication exchange
version = response.API_VERSION
# Prefer the configured hostname (stored on the transport) so that
# mechanisms like GSSAPI construct service principals against the
# user-supplied name, not whichever IP getaddrinfo handed us.
sasl_host = self.transport.host if self.transport.host else self.transport.getPeer()[0]
try:
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
host=self.transport.getPeer()[0], **self.config)
host=sasl_host, **self.config)
except Exception as exc:
self.close(exc)
return
Expand Down
6 changes: 3 additions & 3 deletions kafka/net/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ async def _build_transport(self, node):
self.config['socket_options'],
proxy_url=self.config['proxy_url'])
if self.ssl_enabled:
hostname = node.host if self.config['ssl_check_hostname'] else None
transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(), hostname)
transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(),
host=node.host, ssl_check_hostname=self.config['ssl_check_hostname'])
else:
transport = KafkaTCPTransport(self._net, sock)
transport = KafkaTCPTransport(self._net, sock, host=node.host)

try:
await transport.handshake()
Expand Down
10 changes: 7 additions & 3 deletions kafka/net/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@


class KafkaTCPTransport:
def __init__(self, net, sock):
def __init__(self, net, sock, host=None):
self._net = net
self._sock = sock
self.host = host
self._closed = False
self._write_buffer = deque()
self._writing = False
Expand Down Expand Up @@ -335,6 +336,8 @@ async def handshake(self):
pass

def host_port(self):
if self._sock is None:
return 'none'
try:
host, port = self._sock.getpeername()[0:2]
except (OSError, ValueError):
Expand All @@ -351,11 +354,12 @@ def __str__(self):


class KafkaSSLTransport(KafkaTCPTransport):
def __init__(self, net, sock, ssl_context, server_hostname=None):
def __init__(self, net, sock, ssl_context, host=None, ssl_check_hostname=False):
self._ssl_context = ssl_context
server_hostname = host if ssl_check_hostname else None
sock = ssl_context.wrap_socket(
sock, server_hostname=server_hostname, do_handshake_on_connect=False)
super().__init__(net, sock)
super().__init__(net, sock, host=host)

async def handshake(self):
while True:
Expand Down
60 changes: 60 additions & 0 deletions test/net/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,63 @@ def mock_send_request(request):

net.run(conn._sasl_authenticate())
transport.abort.assert_called_once()

def _drive_handshake_with_recording_mechanism(self, net, conn):
from kafka.protocol.sasl import SaslHandshakeRequest
api_versions = {SaslHandshakeRequest[0].API_KEY: (0, 1)}
conn.broker_version_data = BrokerVersionData(api_versions=api_versions)
handshake_response = MagicMock()
handshake_response.error_code = 0
handshake_response.mechanisms = ['PLAIN']
auth_response = MagicMock()
auth_response.error_code = 0
auth_response.auth_bytes = b''
responses = iter([handshake_response, auth_response])
def mock_send_request(_):
f = Future()
f.success(next(responses))
return f
conn._send_request = mock_send_request

captured = {}
from kafka.sasl import register_sasl_mechanism
from kafka.sasl.plain import SaslMechanismPlain

class RecordingPlain(SaslMechanismPlain):
def __init__(self, **config):
captured['host'] = config.get('host')
super().__init__(**config)
register_sasl_mechanism('PLAIN', RecordingPlain, overwrite=True)
try:
net.run(conn._sasl_authenticate())
finally:
register_sasl_mechanism('PLAIN', SaslMechanismPlain, overwrite=True)
return captured

def test_sasl_uses_transport_host_for_mechanism(self, net):
conn = KafkaConnection(
net, node_id='test',
security_protocol='SASL_PLAINTEXT', sasl_mechanism='PLAIN',
sasl_plain_username='user', sasl_plain_password='pass')
transport = MagicMock()
transport.host = 'kafka.example.com'
transport.getPeer.return_value = ('10.0.0.1', 9092)
conn.transport = transport
conn.initializing = True

captured = self._drive_handshake_with_recording_mechanism(net, conn)
assert captured['host'] == 'kafka.example.com'

def test_sasl_falls_back_to_peer_ip_when_transport_host_unset(self, net):
conn = KafkaConnection(
net, node_id='test',
security_protocol='SASL_PLAINTEXT', sasl_mechanism='PLAIN',
sasl_plain_username='user', sasl_plain_password='pass')
transport = MagicMock()
transport.host = None
transport.getPeer.return_value = ('10.0.0.1', 9092)
conn.transport = transport
conn.initializing = True

captured = self._drive_handshake_with_recording_mechanism(net, conn)
assert captured['host'] == '10.0.0.1'
Loading