Skip to content

Commit

Permalink
Fix ZSTD decompression #269
Browse files Browse the repository at this point in the history
  • Loading branch information
xzkostyan committed Nov 20, 2021
1 parent a44a172 commit 7a82f21
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 79 deletions.
2 changes: 1 addition & 1 deletion clickhouse_driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Client(object):
)

def __init__(self, *args, **kwargs):
self.settings = kwargs.pop('settings', {}).copy()
self.settings = (kwargs.pop('settings', None) or {}).copy()

self.client_settings = {
'insert_block_size': int(self.settings.pop(
Expand Down
43 changes: 39 additions & 4 deletions clickhouse_driver/compression/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from io import BytesIO

from ..reader import read_binary_uint32
from ..writer import write_binary_uint8, write_binary_uint32
from .. import errors

try:
from clickhouse_cityhash.cityhash import CityHash128
except ImportError:
raise RuntimeError(
'Package clickhouse-cityhash is required to use compression'
)

from .. import errors


class BaseCompressor(object):
"""
Expand All @@ -31,9 +33,23 @@ def get_value(self):
def write(self, p_str):
self.data.write(p_str)

def get_compressed_data(self, extra_header_size):
def compress_data(self, data):
raise NotImplementedError

def get_compressed_data(self, extra_header_size):
rv = BytesIO()

data = self.get_value()
compressed = self.compress_data(data)

header_size = extra_header_size + 4 + 4 # sizes

write_binary_uint32(header_size + len(compressed), rv)
write_binary_uint32(len(data), rv)
rv.write(compressed)

return rv.getvalue()


class BaseDecompressor(object):
method = None
Expand All @@ -43,10 +59,29 @@ def __init__(self, real_stream):
self.stream = real_stream
super(BaseDecompressor, self).__init__()

def decompress_data(self, data, uncompressed_size):
raise NotImplementedError

def check_hash(self, compressed_data, compressed_hash):
if CityHash128(compressed_data) != compressed_hash:
raise errors.ChecksumDoesntMatchError()

def get_decompressed_data(self, method_byte, compressed_hash,
extra_header_size):
raise NotImplementedError
size_with_header = read_binary_uint32(self.stream)
compressed_size = size_with_header - extra_header_size - 4

compressed = BytesIO(self.stream.read(compressed_size))

block_check = BytesIO()
write_binary_uint8(method_byte, block_check)
write_binary_uint32(size_with_header, block_check)
block_check.write(compressed.getvalue())

self.check_hash(block_check.getvalue(), compressed_hash)

uncompressed_size = read_binary_uint32(compressed)

compressed = compressed.read(compressed_size - 4)

return self.decompress_data(compressed, uncompressed_size)
42 changes: 4 additions & 38 deletions clickhouse_driver/compression/lz4.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,21 @@
from __future__ import absolute_import
from io import BytesIO

from lz4 import block

from .base import BaseCompressor, BaseDecompressor
from ..protocol import CompressionMethod, CompressionMethodByte
from ..reader import read_binary_uint32
from ..writer import write_binary_uint32, write_binary_uint8


class Compressor(BaseCompressor):
method = CompressionMethod.LZ4
method_byte = CompressionMethodByte.LZ4
mode = 'default'

def get_compressed_data(self, extra_header_size):
rv = BytesIO()

data = self.get_value()
compressed = block.compress(data, store_size=False, mode=self.mode)

header_size = extra_header_size + 4 + 4 # sizes

write_binary_uint32(header_size + len(compressed), rv)
write_binary_uint32(len(data), rv)
rv.write(compressed)

return rv.getvalue()
def compress_data(self, data):
return block.compress(data, store_size=False, mode=self.mode)


class Decompressor(BaseDecompressor):
method = CompressionMethod.LZ4
method_byte = CompressionMethodByte.LZ4

def get_decompressed_data(self, method_byte, compressed_hash,
extra_header_size):
size_with_header = read_binary_uint32(self.stream)
compressed_size = size_with_header - extra_header_size - 4

compressed = BytesIO(self.stream.read(compressed_size))

block_check = BytesIO()
write_binary_uint8(method_byte, block_check)
write_binary_uint32(size_with_header, block_check)
block_check.write(compressed.getvalue())

self.check_hash(block_check.getvalue(), compressed_hash)

uncompressed_size = read_binary_uint32(compressed)

compressed = compressed.read(compressed_size - 4)

return block.decompress(compressed,
uncompressed_size=uncompressed_size)
def decompress_data(self, data, uncompressed_size):
return block.decompress(data, uncompressed_size=uncompressed_size)
39 changes: 4 additions & 35 deletions clickhouse_driver/compression/zstd.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,20 @@
from __future__ import absolute_import
from io import BytesIO

import zstd

from .base import BaseCompressor, BaseDecompressor
from ..protocol import CompressionMethod, CompressionMethodByte
from ..reader import read_binary_uint32
from ..writer import write_binary_uint32, write_binary_uint8


class Compressor(BaseCompressor):
method = CompressionMethod.ZSTD
method_byte = CompressionMethodByte.ZSTD

def get_compressed_data(self, extra_header_size):
rv = BytesIO()

data = self.get_value()
compressed = zstd.compress(data)

header_size = extra_header_size + 4 + 4 # sizes

write_binary_uint32(header_size + len(compressed), rv)
write_binary_uint32(len(data), rv)
rv.write(compressed)

return rv.getvalue()
def compress_data(self, data):
return zstd.compress(data)


class Decompressor(BaseDecompressor):
method = CompressionMethod.ZSTD
method_byte = CompressionMethodByte.ZSTD

def get_decompressed_data(self, method_byte, compressed_hash,
extra_header_size):
size_with_header = read_binary_uint32(self.stream)
compressed_size = size_with_header - extra_header_size - 4

compressed = BytesIO(self.stream.read(compressed_size))

block_check = BytesIO()
write_binary_uint8(method_byte, block_check)
write_binary_uint32(size_with_header, block_check)
block_check.write(compressed.getvalue())

self.check_hash(block_check.getvalue(), compressed_hash)

compressed = compressed.read(compressed_size - 4)

return zstd.decompress(compressed)
def decompress_data(self, data, uncompressed_size):
return zstd.decompress(data)
11 changes: 10 additions & 1 deletion tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@ class BaseCompressionTestCase(BaseTestCase):
supported_compressions = file_config.get('db', 'compression').split(',')

def _create_client(self):
settings = None
if self.compression:
# Set server compression method explicitly
# By default server sends blocks compressed by LZ4.
method = self.compression
if self.server_version > (19, ):
method = method.upper()
settings = {'network_compression_method': method}

return Client(
self.host, self.port, self.database, self.user, self.password,
compression=self.compression
compression=self.compression, settings=settings
)

def setUp(self):
Expand Down

0 comments on commit 7a82f21

Please sign in to comment.