Skip to content

Commit

Permalink
core: add support for flexible versions
Browse files Browse the repository at this point in the history
Varint code lifted from the Java Kafka repo, with the caveat that we
need to specify a byte size by bitmasking due to how Python integers
handle the sign bit.
Serialization tests lifted from the Java repo as well.

- Add support for varints
- Add support for compact collections (byte array, string, array)
- Add support for new request and response headers, supporting flexible
versions
- Add List / Alter partition reassignments APIs
  • Loading branch information
Tincu Gabriel committed Nov 19, 2020
1 parent 53dc740 commit eaedde8
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 17 deletions.
2 changes: 2 additions & 0 deletions kafka/protocol/__init__.py
Expand Up @@ -43,5 +43,7 @@
40: 'ExpireDelegationToken',
41: 'DescribeDelegationToken',
42: 'DeleteGroups',
45: 'AlterPartitionReassignments',
46: 'ListPartitionReassignments',
48: 'DescribeClientQuotas',
}
91 changes: 90 additions & 1 deletion kafka/protocol/admin.py
@@ -1,7 +1,7 @@
from __future__ import absolute_import

from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64
from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64, CompactString, CompactArray, TaggedFields


class ApiVersionResponse_v0(Response):
Expand Down Expand Up @@ -963,3 +963,92 @@ class DescribeClientQuotasRequest_v0(Request):
# Response classes for the DescribeClientQuotas API; the list holds only v0.
DescribeClientQuotasResponse = [DescribeClientQuotasResponse_v0]


class AlterPartitionReassignmentsResponse_v0(Response):
    """Version 0 of the AlterPartitionReassignments response (API key 45).

    All strings and arrays use the compact encodings, and every struct
    level of the schema (top level, each ``responses`` entry, each
    ``partitions`` entry) ends with a ``tags`` field of type
    ``TaggedFields``.
    """
    API_KEY = 45
    API_VERSION = 0
    SCHEMA = Schema(
        ("throttle_time_ms", Int32),
        # Error reported for the response as a whole.
        ("error_code", Int16),
        ("error_message", CompactString("utf-8")),
        ("responses", CompactArray(
            ("name", CompactString("utf-8")),
            # Each partition entry carries its own error code and message.
            ("partitions", CompactArray(
                ("partition_index", Int32),
                ("error_code", Int16),
                ("error_message", CompactString("utf-8")),
                ("tags", TaggedFields)
            )),
            ("tags", TaggedFields)
        )),
        ("tags", TaggedFields)
    )


class AlterPartitionReassignmentsRequest_v0(Request):
    """Version 0 of the AlterPartitionReassignments request (API key 45).

    ``FLEXIBLE_VERSION`` is set, which is the flag
    ``Request.build_request_header`` / ``Request.parse_response_header``
    check to pick the header variants that carry tagged fields.
    """
    FLEXIBLE_VERSION = True
    API_KEY = 45
    API_VERSION = 0
    RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0
    SCHEMA = Schema(
        ("timeout_ms", Int32),
        ("topics", CompactArray(
            ("name", CompactString("utf-8")),
            ("partitions", CompactArray(
                ("partition_index", Int32),
                # Replica ids are Int32 values in a compact array.
                ("replicas", CompactArray(Int32)),
                ("tags", TaggedFields)
            )),
            ("tags", TaggedFields)
        )),
        ("tags", TaggedFields)
    )


# Request classes for the AlterPartitionReassignments API; the list holds only v0.
AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0]

# Response classes for the AlterPartitionReassignments API; the list holds only v0.
AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0]


class ListPartitionReassignmentsResponse_v0(Response):
    """Version 0 of the ListPartitionReassignments response (API key 46).

    All strings and arrays use the compact encodings, and every struct
    level of the schema ends with a ``tags`` field of type
    ``TaggedFields``.
    """
    API_KEY = 46
    API_VERSION = 0
    SCHEMA = Schema(
        ("throttle_time_ms", Int32),
        # Error reported for the response as a whole.
        ("error_code", Int16),
        ("error_message", CompactString("utf-8")),
        ("topics", CompactArray(
            ("name", CompactString("utf-8")),
            ("partitions", CompactArray(
                ("partition_index", Int32),
                # Three separate Int32 arrays per partition entry.
                ("replicas", CompactArray(Int32)),
                ("adding_replicas", CompactArray(Int32)),
                ("removing_replicas", CompactArray(Int32)),
                ("tags", TaggedFields)
            )),
            ("tags", TaggedFields)
        )),
        ("tags", TaggedFields)
    )


class ListPartitionReassignmentsRequest_v0(Request):
    """Version 0 of the ListPartitionReassignments request (API key 46).

    ``FLEXIBLE_VERSION`` is set, which is the flag
    ``Request.build_request_header`` / ``Request.parse_response_header``
    check to pick the header variants that carry tagged fields.
    """
    FLEXIBLE_VERSION = True
    API_KEY = 46
    API_VERSION = 0
    RESPONSE_TYPE = ListPartitionReassignmentsResponse_v0
    SCHEMA = Schema(
        ("timeout_ms", Int32),
        ("topics", CompactArray(
            ("name", CompactString("utf-8")),
            # NOTE: despite the singular name this field is an array of
            # Int32 values; the name is kept so keyword callers keep working.
            ("partition_index", CompactArray(Int32)),
            ("tags", TaggedFields)
        )),
        ("tags", TaggedFields)
    )


# Request classes for the ListPartitionReassignments API; the list holds only v0.
ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0]

# Response classes for the ListPartitionReassignments API; the list holds only v0.
ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0]
43 changes: 42 additions & 1 deletion kafka/protocol/api.py
Expand Up @@ -3,7 +3,7 @@
import abc

from kafka.protocol.struct import Struct
from kafka.protocol.types import Int16, Int32, String, Schema, Array
from kafka.protocol.types import Int16, Int32, String, Schema, Array, TaggedFields


class RequestHeader(Struct):
Expand All @@ -20,9 +20,40 @@ def __init__(self, request, correlation_id=0, client_id='kafka-python'):
)


class RequestHeaderV2(Struct):
    """Request header variant that ends with a tagged-fields section.

    ``Request.build_request_header`` selects this header when the request
    class sets ``FLEXIBLE_VERSION``.
    """
    SCHEMA = Schema(
        ('api_key', Int16),
        ('api_version', Int16),
        ('correlation_id', Int32),
        ('client_id', String('utf-8')),
        ('tags', TaggedFields),
    )

    def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=None):
        """Populate the header from ``request`` and the given identifiers.

        Arguments:
            request: object providing ``API_KEY`` and ``API_VERSION``.
            correlation_id: value stored in the ``correlation_id`` field.
            client_id: value stored in the ``client_id`` field.
            tags: value for the ``tags`` field; any falsy value
                (including ``None``) is replaced by an empty dict.
        """
        if not tags:
            tags = {}
        super(RequestHeaderV2, self).__init__(
            request.API_KEY,
            request.API_VERSION,
            correlation_id,
            client_id,
            tags,
        )


class ResponseHeader(Struct):
    """Response header consisting of a single Int32 ``correlation_id``."""
    SCHEMA = Schema(
        ('correlation_id', Int32),
    )


class ResponseHeaderV2(Struct):
    """Response header variant: ``correlation_id`` followed by tagged fields.

    ``Request.parse_response_header`` decodes this header when the request
    class sets ``FLEXIBLE_VERSION``.
    """
    SCHEMA = Schema(
        ('correlation_id', Int32),
        ('tags', TaggedFields),
    )


class Request(Struct):
__metaclass__ = abc.ABCMeta

FLEXIBLE_VERSION = False

@abc.abstractproperty
def API_KEY(self):
"""Integer identifier for api request"""
Expand Down Expand Up @@ -50,6 +81,16 @@ def expect_response(self):
def to_object(self):
    """Return the result of ``_to_object`` applied to this instance and its ``SCHEMA``."""
    schema = self.SCHEMA
    return _to_object(schema, self)

def build_request_header(self, correlation_id, client_id):
    """Create the header struct to send in front of this request.

    Arguments:
        correlation_id: correlation id to place in the header.
        client_id: client id to place in the header.

    Returns:
        A ``RequestHeaderV2`` when ``FLEXIBLE_VERSION`` is truthy,
        otherwise a ``RequestHeader``.
    """
    header_cls = RequestHeaderV2 if self.FLEXIBLE_VERSION else RequestHeader
    return header_cls(self, correlation_id=correlation_id, client_id=client_id)

def parse_response_header(self, read_buffer):
    """Decode the response header that precedes the reply to this request.

    Arguments:
        read_buffer: buffer passed straight to the header's ``decode``.

    Returns:
        A decoded ``ResponseHeaderV2`` when ``FLEXIBLE_VERSION`` is truthy,
        otherwise a decoded ``ResponseHeader``.
    """
    header_cls = ResponseHeaderV2 if self.FLEXIBLE_VERSION else ResponseHeader
    return header_cls.decode(read_buffer)


class Response(Struct):
__metaclass__ = abc.ABCMeta
Expand Down
21 changes: 7 additions & 14 deletions kafka/protocol/parser.py
Expand Up @@ -4,10 +4,9 @@
import logging

import kafka.errors as Errors
from kafka.protocol.api import RequestHeader
from kafka.protocol.commit import GroupCoordinatorResponse
from kafka.protocol.frame import KafkaBytes
from kafka.protocol.types import Int32
from kafka.protocol.types import Int32, TaggedFields
from kafka.version import __version__

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,9 +58,8 @@ def send_request(self, request, correlation_id=None):
log.debug('Sending request %s', request)
if correlation_id is None:
correlation_id = self._next_correlation_id()
header = RequestHeader(request,
correlation_id=correlation_id,
client_id=self._client_id)

header = request.build_request_header(correlation_id=correlation_id, client_id=self._client_id)
message = b''.join([header.encode(), request.encode()])
size = Int32.encode(len(message))
data = size + message
Expand Down Expand Up @@ -135,17 +133,12 @@ def receive_bytes(self, data):
return responses

def _process_response(self, read_buffer):
recv_correlation_id = Int32.decode(read_buffer)
log.debug('Received correlation id: %d', recv_correlation_id)

if not self.in_flight_requests:
raise Errors.CorrelationIdError(
'No in-flight-request found for server response'
' with correlation ID %d'
% (recv_correlation_id,))

raise Errors.CorrelationIdError('No in-flight-request found for server response')
(correlation_id, request) = self.in_flight_requests.popleft()

response_header = request.parse_response_header(read_buffer)
recv_correlation_id = response_header.correlation_id
log.debug('Received correlation id: %d', recv_correlation_id)
# 0.8.2 quirk
if (recv_correlation_id == 0 and
correlation_id != 0 and
Expand Down
153 changes: 153 additions & 0 deletions kafka/protocol/types.py
Expand Up @@ -210,3 +210,156 @@ def repr(self, list_of_items):
if list_of_items is None:
return 'NULL'
return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']'


class UnsignedVarInt32(AbstractType):
    """Unsigned 32-bit integer stored as a variable-length byte sequence.

    Each byte carries 7 payload bits, least-significant group first; the
    high bit of a byte is set when another byte follows.
    """

    @classmethod
    def decode(cls, data):
        """Read one unsigned varint from ``data``.

        Arguments:
            data: object with a ``read(n)`` method returning bytes.

        Returns:
            The decoded integer.

        Raises:
            ValueError: if a fifth continuation byte is encountered.
        """
        result = 0
        shift = 0
        byte = struct.unpack('B', data.read(1))[0]
        while byte & 0x80:
            result |= (byte & 0x7f) << shift
            shift += 7
            if shift > 28:
                raise ValueError('Invalid value {}'.format(result))
            byte = struct.unpack('B', data.read(1))[0]
        # The terminating byte (high bit clear) supplies the top bits.
        return result | (byte << shift)

    @classmethod
    def encode(cls, value):
        """Encode ``value`` (truncated to its low 32 bits) as varint bytes."""
        remaining = value & 0xffffffff
        out = bytearray()
        while remaining & 0xffffff80:
            # More than 7 bits left: emit the low 7 with the continuation flag.
            out.append((remaining & 0x7f) | 0x80)
            remaining >>= 7
        out.append(remaining)
        return bytes(out)


class VarInt32(AbstractType):
    """Signed 32-bit integer: zig-zag mapped, then written as an unsigned varint.

    Zig-zag maps 0, -1, 1, -2, ... to 0, 1, 2, 3, ... so that values of
    small magnitude need few bytes regardless of sign.
    """

    @classmethod
    def decode(cls, data):
        """Read one signed varint from ``data`` and return it as an int.

        Arguments:
            data: object with a ``read(n)`` method returning bytes.
        """
        value = UnsignedVarInt32.decode(data)
        # Undo zig-zag: the low bit is the sign, the remaining bits the magnitude.
        return (value >> 1) ^ -(value & 1)

    @classmethod
    def encode(cls, value):
        """Encode ``value`` as a zig-zag varint.

        Arguments:
            value: integer; only its low 32 bits are used, interpreted as
                a two's-complement signed value (so both ``-1`` and
                ``0xffffffff`` encode as ``b'\\x01'``).

        Returns:
            The encoded bytes.
        """
        # Python ints are unbounded, so first normalise to a signed 32-bit
        # value. The zig-zag formula needs ``value >> 31`` to be an
        # arithmetic shift (all ones for negatives); shifting the masked,
        # non-negative value instead would only flip the lowest bit and
        # mis-encode every negative number.
        value &= 0xffffffff
        if value & 0x80000000:
            value -= 0x100000000
        return UnsignedVarInt32.encode((value << 1) ^ (value >> 31))


class VarInt64(AbstractType):
    """Signed 64-bit integer: zig-zag mapped, then written as a varint.

    Each byte carries 7 payload bits, least-significant group first; the
    high bit of a byte is set when another byte follows.
    """

    @classmethod
    def decode(cls, data):
        """Read one signed 64-bit varint from ``data``.

        Arguments:
            data: object with a ``read(n)`` method returning bytes.

        Returns:
            The decoded signed integer.

        Raises:
            ValueError: if more than nine continuation bytes are found.
        """
        value, i = 0, 0
        while True:
            # read() returns a bytes object; unpack it to an int before
            # doing any bit arithmetic on it.
            b, = struct.unpack('B', data.read(1))
            if not (b & 0x80):
                break
            value |= (b & 0x7f) << i
            i += 7
            if i > 63:
                raise ValueError('Invalid value {}'.format(value))
        value |= b << i
        # Keep the result within 64 bits even if the final byte carries
        # surplus high bits.
        value &= 0xffffffffffffffff
        # Undo zig-zag: the low bit is the sign, the remaining bits the magnitude.
        return (value >> 1) ^ -(value & 1)

    @classmethod
    def encode(cls, value):
        """Encode ``value`` as a zig-zag varint.

        Arguments:
            value: integer; only its low 64 bits are used, interpreted as
                a two's-complement signed value.

        Returns:
            The encoded bytes (1 to 10 bytes long).
        """
        # Normalise to a signed 64-bit value so that ``value >> 63`` is an
        # arithmetic shift, as the zig-zag formula requires.
        value &= 0xffffffffffffffff
        if value & 0x8000000000000000:
            value -= 0x10000000000000000
        v = ((value << 1) ^ (value >> 63)) & 0xffffffffffffffff
        ret = bytearray()
        while v & 0xffffffffffffff80:
            # Emit the low 7 bits of the *remaining* zig-zag value with the
            # continuation flag set.
            ret.append((v & 0x7f) | 0x80)
            v >>= 7
        ret.append(v)
        return bytes(ret)


class CompactString(String):
    """String whose length prefix is an unsigned varint holding ``length + 1``.

    A prefix of 0 stands for ``None``.
    """

    def decode(self, data):
        """Read a compact string from ``data``.

        Returns:
            The decoded text, or ``None`` when the length prefix is 0.

        Raises:
            ValueError: if fewer bytes than announced are available.
        """
        size = UnsignedVarInt32.decode(data) - 1
        if size < 0:
            return None
        raw = data.read(size)
        if len(raw) != size:
            raise ValueError('Buffer underrun decoding string')
        return raw.decode(self.encoding)

    def encode(self, value):
        """Return the compact encoding of ``value`` (``None`` gives prefix 0)."""
        if value is None:
            return UnsignedVarInt32.encode(0)
        payload = str(value).encode(self.encoding)
        prefix = UnsignedVarInt32.encode(len(payload) + 1)
        return prefix + payload


class TaggedFields(AbstractType):
    """Tagged-field section: a mapping of integer tag to raw bytes.

    Wire layout, all integers as unsigned varints: the number of fields,
    then for each field its tag, the size of its payload, and the payload
    bytes. Tags must appear in strictly increasing order.
    """

    @classmethod
    def decode(cls, data):
        """Read a tagged-field section from ``data``.

        Arguments:
            data: object with a ``read(n)`` method returning bytes.

        Returns:
            dict mapping each tag (int) to its undecoded payload (bytes);
            empty when the section announces zero fields.

        Raises:
            ValueError: if a tag is repeated or out of order, or if a
                payload is shorter than its announced size.
        """
        num_fields = UnsignedVarInt32.decode(data)
        ret = {}
        prev_tag = -1
        for _ in range(num_fields):
            tag = UnsignedVarInt32.decode(data)
            if tag <= prev_tag:
                raise ValueError('Invalid or out-of-order tag {}'.format(tag))
            prev_tag = tag
            size = UnsignedVarInt32.decode(data)
            val = data.read(size)
            if len(val) != size:
                raise ValueError('Buffer underrun decoding tagged field {}'.format(tag))
            ret[tag] = val
        return ret

    @classmethod
    def encode(cls, value):
        """Serialise a ``{tag: bytes}`` mapping.

        Arguments:
            value: dict whose keys are non-negative ints and whose values
                are already-serialised ``bytes`` payloads.

        Returns:
            The encoded section as bytes.

        Raises:
            AssertionError: if a key is not a non-negative int or a value
                is not ``bytes``.
        """
        # Validate everything up front so that sorting below never has to
        # compare keys of mixed types.
        for k, v in value.items():
            # do we allow for other data types ?? It could get complicated really fast
            assert isinstance(v, bytes), 'Value {} is not a byte array'.format(v)
            # Tag 0 is accepted by decode(), so it must be accepted here too.
            assert isinstance(k, int) and k >= 0, 'Key {} is not a non-negative integer'.format(k)
        parts = [UnsignedVarInt32.encode(len(value))]
        # decode() rejects tags that are not strictly increasing, so emit
        # them sorted rather than in dict iteration order.
        for k in sorted(value):
            v = value[k]
            parts.append(UnsignedVarInt32.encode(k))
            # The payload size precedes the payload; decode() reads it to
            # know how many bytes belong to this field.
            parts.append(UnsignedVarInt32.encode(len(v)))
            parts.append(v)
        return b''.join(parts)


class CompactBytes(AbstractType):
    """Byte string whose length prefix is an unsigned varint holding ``length + 1``.

    A prefix of 0 stands for ``None``.
    """

    @classmethod
    def decode(cls, data):
        """Read compact bytes from ``data``.

        Returns:
            The payload bytes, or ``None`` when the length prefix is 0.

        Raises:
            ValueError: if fewer bytes than announced are available.
        """
        size = UnsignedVarInt32.decode(data) - 1
        if size < 0:
            return None
        payload = data.read(size)
        if len(payload) != size:
            raise ValueError('Buffer underrun decoding Bytes')
        return payload

    @classmethod
    def encode(cls, value):
        """Return the compact encoding of ``value`` (``None`` gives prefix 0)."""
        if value is not None:
            return UnsignedVarInt32.encode(len(value) + 1) + value
        return UnsignedVarInt32.encode(0)


class CompactArray(Array):
    """Array whose element count is an unsigned varint holding ``count + 1``.

    A prefix of 0 stands for ``None``; elements are encoded and decoded
    with ``self.array_of``.
    """

    def encode(self, items):
        """Return the compact encoding of ``items`` (``None`` gives prefix 0)."""
        if items is None:
            return UnsignedVarInt32.encode(0)
        parts = [UnsignedVarInt32.encode(len(items) + 1)]
        parts.extend([self.array_of.encode(item) for item in items])
        return b''.join(parts)

    def decode(self, data):
        """Read a compact array from ``data``.

        Returns:
            A list of decoded elements, or ``None`` when the prefix is 0.
        """
        count = UnsignedVarInt32.decode(data) - 1
        if count == -1:
            return None
        decoded = []
        for _ in range(count):
            decoded.append(self.array_of.decode(data))
        return decoded

0 comments on commit eaedde8

Please sign in to comment.