Skip to content

Commit

Permalink
Merge pull request #73 from AchilleAsh/master
Browse files Browse the repository at this point in the history
Add native IPv4 and IPv6 types support
  • Loading branch information
xzkostyan committed Feb 24, 2019
2 parents 787f1a3 + dd2aaa7 commit 2beeac0
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 1 deletion.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
env:
- VERSION=19.3.3
- VERSION=18.12.17
- VERSION=18.12.13
- VERSION=18.10.3
Expand Down
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Features
* Nullable(T)
* UUID
* Decimal
* IPv4/IPv6

- Query progress information.

Expand Down
93 changes: 93 additions & 0 deletions clickhouse_driver/columns/ipcolumn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from ipaddress import IPv4Address, IPv6Address, AddressValueError

from .. import errors
from ..util import compat
from .exceptions import ColumnTypeMismatchException
from .stringcolumn import ByteFixedString
from .intcolumn import UInt32Column


class IPv4Column(UInt32Column):
ch_type = "IPv4"
py_types = compat.string_types + (IPv4Address, int)

def __init__(self, types_check=False, **kwargs):
# UIntColumn overrides before_write_item and check_item
# in its __init__ when types_check is True so we force
# __init__ without it then add the appropriate check method for IPv4
super(UInt32Column, self).__init__(types_check=False, **kwargs)

self.types_check_enabled = types_check
if types_check:

def check_item(value):
if isinstance(value, int) and value < 0:
raise ColumnTypeMismatchException(value)

if not isinstance(value, IPv4Address):
try:
value = IPv4Address(value)
except AddressValueError:
# Cannot parse input in a valid IPv4
raise ColumnTypeMismatchException(value)

self.check_item = check_item

def before_write_item(self, value):
# allow Ipv4 in integer, string or IPv4Address object
try:
if isinstance(value, int):
return value

if not isinstance(value, IPv4Address):
value = IPv4Address(value)

return int(value)
except AddressValueError:
raise errors.CannotParseDomainError(
"Cannot parse IPv4 '{}'".format(value)
)

def after_read_item(self, value):
return IPv4Address(value)


class IPv6Column(ByteFixedString):
ch_type = "IPv6"
py_types = compat.string_types + (IPv6Address, bytes)

def __init__(self, types_check=False, **kwargs):
super(IPv6Column, self).__init__(16, types_check=types_check, **kwargs)

if types_check:

def check_item(value):
if isinstance(value, bytes) and len(value) != 16:
raise ColumnTypeMismatchException(value)

if not isinstance(value, IPv6Address):
try:
value = IPv6Address(value)
except AddressValueError:
# Cannot parse input in a valid IPv6
raise ColumnTypeMismatchException(value)

self.check_item = check_item

def before_write_item(self, value):
# allow Ipv6 in bytes or python IPv6Address
# this is raw bytes (not encoded) in order to fit FixedString(16)
try:
if isinstance(value, bytes):
return value

if not isinstance(value, IPv6Address):
value = IPv6Address(value)
return value.packed
except AddressValueError:
raise errors.CannotParseDomainError(
"Cannot parse IPv6 '{}'".format(value)
)

def after_read_item(self, value):
return IPv6Address(value)
3 changes: 2 additions & 1 deletion clickhouse_driver/columns/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
IntervalDayColumn, IntervalHourColumn, IntervalMinuteColumn,
IntervalSecondColumn
)
from .ipcolumn import IPv4Column, IPv6Column


column_by_type = {c.ch_type: c for c in [
Expand All @@ -29,7 +30,7 @@
NothingColumn, NullColumn, UUIDColumn,
IntervalYearColumn, IntervalMonthColumn, IntervalWeekColumn,
IntervalDayColumn, IntervalHourColumn, IntervalMinuteColumn,
IntervalSecondColumn
IntervalSecondColumn, IPv4Column, IPv6Column
]}


Expand Down
5 changes: 5 additions & 0 deletions clickhouse_driver/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ class ErrorCodes(object):
FUNCTION_THROW_IF_VALUE_IS_NON_ZERO = 395
TOO_MANY_ROWS_OR_BYTES = 396
QUERY_IS_NOT_SUPPORTED_IN_MATERIALIZED_VIEW = 397
CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING = 441

KEEPER_EXCEPTION = 999
POCO_EXCEPTION = 1000
Expand Down Expand Up @@ -466,3 +467,7 @@ class UnknownPacketFromServerError(Error):

class CannotParseUuidError(Error):
code = ErrorCodes.CANNOT_PARSE_UUID


class CannotParseDomainError(Error):
code = ErrorCodes.CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
install_requires = ['pytz']
if not PY34:
install_requires.append('enum34')
install_requires.append('ipaddress')


def read_version():
Expand Down
228 changes: 228 additions & 0 deletions tests/columns/test_ip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
from __future__ import unicode_literals

from clickhouse_driver import errors
from ipaddress import IPv6Address, IPv4Address

from tests.testcase import BaseTestCase
from tests.util import require_server_version


class IPv4TestCase(BaseTestCase):
@require_server_version(19, 3, 3)
def test_simple(self):
with self.create_table('a IPv4'):
data = [
(IPv4Address("10.0.0.1"),),
(IPv4Address("192.168.253.42"),)
]
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'10.0.0.1\n'
'192.168.253.42\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv4Address("10.0.0.1"),),
(IPv4Address("192.168.253.42"),)
])

@require_server_version(19, 3, 3)
def test_from_int(self):
with self.create_table('a IPv4'):
data = [
(167772161,),
]
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'10.0.0.1\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv4Address("10.0.0.1"),),
])

@require_server_version(19, 3, 3)
def test_from_str(self):
with self.create_table('a IPv4'):
data = [
("10.0.0.1",),
]
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'10.0.0.1\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv4Address("10.0.0.1"),),
])

@require_server_version(19, 3, 3)
def test_type_mismatch(self):
data = [(1025.2147,)]
with self.create_table('a IPv4'):
with self.assertRaises(errors.TypeMismatchError):
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

@require_server_version(19, 3, 3)
def test_bad_ipv4(self):
data = [('985.512.12.0',)]
with self.create_table('a IPv4'):
with self.assertRaises(errors.CannotParseDomainError):
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

@require_server_version(19, 3, 3)
def test_bad_ipv4_with_type_check(self):
data = [('985.512.12.0',)]
with self.create_table('a IPv4'):
with self.assertRaises(errors.TypeMismatchError):
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

@require_server_version(19, 3, 3)
def test_nullable(self):
with self.create_table('a Nullable(IPv4)'):
data = [(IPv4Address('10.10.10.10'),), (None,)]
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted,
'10.10.10.10\n\\N\n')

inserted = self.client.execute(query)
self.assertEqual(inserted, data)


class IPv6TestCase(BaseTestCase):
@require_server_version(19, 3, 3)
def test_simple(self):
with self.create_table('a IPv6'):
data = [
(IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),),
(IPv6Address('a22:cc64:cf47:1653:4976:3c0c:ff8d:417c'),),
(IPv6Address('12ff:0000:0000:0000:0000:0000:0000:0001'),)
]
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n'
'a22:cc64:cf47:1653:4976:3c0c:ff8d:417c\n'
'12ff::1\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),),
(IPv6Address('a22:cc64:cf47:1653:4976:3c0c:ff8d:417c'),),
(IPv6Address('12ff::1'),)
])

@require_server_version(19, 3, 3)
def test_from_str(self):
with self.create_table('a IPv6'):
data = [
('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae',),
]
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),),
])

@require_server_version(19, 3, 3)
def test_from_bytes(self):
with self.create_table('a IPv6'):
data = [
(b"y\xf4\xe6\x98E\xde\xa5\x9b'e(\xe3\x8d:5\xae",),
]
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted, (
'79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n'
))
inserted = self.client.execute(query)
self.assertEqual(inserted, [
(IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),),
])

@require_server_version(19, 3, 3)
def test_type_mismatch(self):
data = [(1025.2147,)]
with self.create_table('a IPv6'):
with self.assertRaises(errors.TypeMismatchError):
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

@require_server_version(19, 3, 3)
def test_bad_ipv6(self):
data = [("ghjk:e698:45de:a59b:2765:28e3:8d3a:zzzz",)]
with self.create_table('a IPv6'):
with self.assertRaises(errors.CannotParseDomainError):
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

@require_server_version(19, 3, 3)
def test_bad_ipv6_with_type_check(self):
data = [("ghjk:e698:45de:a59b:2765:28e3:8d3a:zzzz",)]
with self.create_table('a IPv6'):
with self.assertRaises(errors.TypeMismatchError):
self.client.execute(
'INSERT INTO test (a) VALUES', data, types_check=True
)

@require_server_version(19, 3, 3)
def test_nullable(self):
with self.create_table('a Nullable(IPv6)'):
data = [
(IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),),
(None,)]
self.client.execute(
'INSERT INTO test (a) VALUES', data
)

query = 'SELECT * FROM test'
inserted = self.emit_cli(query)
self.assertEqual(inserted,
'79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n\\N\n')

inserted = self.client.execute(query)
self.assertEqual(inserted, data)

0 comments on commit 2beeac0

Please sign in to comment.