39 changes: 23 additions & 16 deletions kafka/producer/buffer.py
@@ -89,22 +89,29 @@ def is_full(self):
         return self._buffer.tell() >= self._batch_size
 
     def close(self):
-        if self._compressor:
-            # TODO: avoid copies with bytearray / memoryview
-            self._buffer.seek(4)
-            msg = Message(self._compressor(self._buffer.read()),
-                          attributes=self._compression_attributes,
-                          magic=self._message_version)
-            encoded = msg.encode()
-            self._buffer.seek(4)
-            self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
-            self._buffer.write(Int32.encode(len(encoded)))
-            self._buffer.write(encoded)
-
-        # Update the message set size, and return ready for full read()
-        size = self._buffer.tell() - 4
-        self._buffer.seek(0)
-        self._buffer.write(Int32.encode(size))
+        # This method may be called multiple times on the same batch
+        # i.e., on retries
+        # we need to make sure we only close it out once
+        # otherwise compressed messages may be double-compressed
+        # see Issue 718
+        if not self._closed:
+            if self._compressor:
+                # TODO: avoid copies with bytearray / memoryview
+                self._buffer.seek(4)
+                msg = Message(self._compressor(self._buffer.read()),
+                              attributes=self._compression_attributes,
+                              magic=self._message_version)
+                encoded = msg.encode()
+                self._buffer.seek(4)
+                self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
+                self._buffer.write(Int32.encode(len(encoded)))
+                self._buffer.write(encoded)
+
+            # Update the message set size, and return ready for full read()
+            size = self._buffer.tell() - 4
+            self._buffer.seek(0)
+            self._buffer.write(Int32.encode(size))
+
         self._buffer.seek(0)
         self._closed = True
 
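The fix itself is an idempotence guard: the non-repeatable work (compressing the accumulated bytes and rewriting the size prefix) now runs only on the first close(), while the final rewind stays safe to repeat. Below is a minimal sketch of the same pattern, not kafka-python's actual class; the `_finalize()` helper is a hypothetical stand-in for the compression and size-rewrite steps:

    import io

    class ClosableBuffer(object):
        """Illustrative sketch only -- mirrors the guard in MessageSetBuffer.close()."""

        def __init__(self):
            self._buffer = io.BytesIO()
            self._buffer.write(b'payload')
            self._closed = False

        def _finalize(self):
            # Stand-in for the non-repeatable work: in MessageSetBuffer.close()
            # this is compressing the accumulated messages and rewriting the
            # 4-byte size prefix. Running it twice would corrupt the batch.
            self._buffer.write(b'|finalized')

        def close(self):
            if not self._closed:       # the guard added by this change
                self._finalize()       # runs at most once, even across retries
            self._buffer.seek(0)       # safe to repeat: just rewind for readers
            self._closed = True

    buf = ClosableBuffer()
    buf.close()
    buf.close()  # a second close (e.g. on a retry) is now a no-op
    assert buf._buffer.read() == b'payload|finalized'

Before the guard, a retried batch would feed the already-wrapped message set back through the compressor, which is the double-compression described in Issue 718.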
70 changes: 70 additions & 0 deletions test/test_buffer.py
@@ -0,0 +1,70 @@
+# pylint: skip-file
+from __future__ import absolute_import
+
+import io
+
+import pytest
+
+from kafka.producer.buffer import MessageSetBuffer
+from kafka.protocol.message import Message, MessageSet
+
+
+def test_buffer_close():
+    records = MessageSetBuffer(io.BytesIO(), 100000)
+    orig_msg = Message(b'foobar')
+    records.append(1234, orig_msg)
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 1234
+    assert msg == orig_msg
+
+    # Closing again should work fine
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 1234
+    assert msg == orig_msg
+
+
+@pytest.mark.parametrize('compression', [
+    'gzip',
+    'snappy',
+    pytest.mark.skipif("sys.version_info < (2,7)")('lz4'),  # lz4tools does not work on py26
+])
+def test_compressed_buffer_close(compression):
+    records = MessageSetBuffer(io.BytesIO(), 100000, compression_type=compression)
+    orig_msg = Message(b'foobar')
+    records.append(1234, orig_msg)
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 0
+    assert msg.is_compressed()
+
+    msgset = msg.decompress()
+    (offset, size, msg) = msgset[0]
+    assert not msg.is_compressed()
+    assert offset == 1234
+    assert msg == orig_msg
+
+    # Closing again should work fine
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 0
+    assert msg.is_compressed()
+
+    msgset = msg.decompress()
+    (offset, size, msg) = msgset[0]
+    assert not msg.is_compressed()
+    assert offset == 1234
+    assert msg == orig_msg
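The `offset == 0` and `decompress()` assertions in the compressed test rest on the framing that close() writes: the buffer's first four bytes hold the total message-set size, and the compressed path replaces everything after them with one wrapper entry of [Int64 offset][Int32 size][message bytes]. A rough sketch of that layout, assuming the v0/v1 message-set wire format; `frame_compressed_set` and the placeholder payload are hypothetical:

    import struct

    def frame_compressed_set(wrapper_bytes):
        # [Int64 offset = 0][Int32 message size][wrapper message bytes] ...
        entry = struct.pack('>q', 0) + struct.pack('>i', len(wrapper_bytes)) + wrapper_bytes
        # ... prefixed by the Int32 total size that close() rewrites at position 0
        return struct.pack('>i', len(entry)) + entry

    framed = frame_compressed_set(b'<compressed wrapper message>')
    total_size = struct.unpack_from('>i', framed, 0)[0]
    assert total_size == len(framed) - 4   # mirrors `size = self._buffer.tell() - 4`

Decoding such a buffer yields a single entry at offset 0 whose payload is the compressed inner set, which is exactly what the test unwraps with `msg.decompress()` to recover the original offset 1234.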