From b34e1d055db79cb69ab619babc56b64f3b664e99 Mon Sep 17 00:00:00 2001 From: joocer Date: Fri, 16 Feb 2024 22:41:58 +0000 Subject: [PATCH] #1451 --- opteryx/compiled/functions/__init__.py | 1 + opteryx/compiled/functions/ip_address.pyx | 53 +++++++++++++++++++ opteryx/functions/binary_operators.py | 10 ++-- setup.py | 6 +++ .../test_shapes_and_errors_battery.py | 11 ++-- 5 files changed, 70 insertions(+), 11 deletions(-) create mode 100644 opteryx/compiled/functions/ip_address.pyx diff --git a/opteryx/compiled/functions/__init__.py b/opteryx/compiled/functions/__init__.py index 6475ca02..1b03f36a 100644 --- a/opteryx/compiled/functions/__init__.py +++ b/opteryx/compiled/functions/__init__.py @@ -2,3 +2,4 @@ from hash_table import HashSet from hash_table import HashTable from hash_table import distinct +from ip_address import ip_in_cidr diff --git a/opteryx/compiled/functions/ip_address.pyx b/opteryx/compiled/functions/ip_address.pyx new file mode 100644 index 00000000..e8be2a5d --- /dev/null +++ b/opteryx/compiled/functions/ip_address.pyx @@ -0,0 +1,53 @@ +# cython: language_level=3 + +from libc.stdint cimport uint32_t +from libc.stdlib cimport strtol +from libc.string cimport strchr +from libc.string cimport memset +import numpy as np +cimport numpy as cnp + + +cdef uint32_t ip_to_int(char* ip): + cdef uint32_t result = 0 + cdef uint32_t num = 0 + cdef int shift = 24 # Start with the leftmost byte + cdef char* end + + # Convert each part of the IP to an integer + for _ in range(4): + num = strtol(ip, &end, 10) # Convert substring to long + if num > 255 or ip == end or (end[0] not in (b'.', b'\0') and _ < 3): # Validate octet and check for non-digit characters + raise ValueError("Invalid IP address") + result += num << shift + shift -= 8 + if end[0] == b'\0': # Check if end of string + break + ip = end + 1 # Move to the next part + + if shift != -8 or end[0] != b'\0': # Ensure exactly 4 octets and end of string + raise ValueError("Invalid IP address") + + return result + + +def ip_in_cidr(cnp.ndarray ip_addresses, str cidr): + cdef uint32_t base_ip, netmask, ip_int + cdef int mask_size + cdef str base_ip_str + cdef list cidr_parts = cidr.split('/') + + base_ip_str, mask_size = cidr_parts[0], int(cidr_parts[1]) + netmask = (0xFFFFFFFF << (32 - mask_size)) & 0xFFFFFFFF + + base_ip = ip_to_int(base_ip_str.encode('utf-8')) + + cdef cnp.ndarray result = np.empty(ip_addresses.shape[0], dtype=np.bool_) + cdef int i = 0 + + for i in range(ip_addresses.shape[0]): + if ip_addresses[i] is not None: + ip_int = ip_to_int(ip_addresses[i].encode('utf-8')) + result[i] = (ip_int & netmask) == base_ip + + return result diff --git a/opteryx/functions/binary_operators.py b/opteryx/functions/binary_operators.py index f3ae27bf..91cc88b0 100644 --- a/opteryx/functions/binary_operators.py +++ b/opteryx/functions/binary_operators.py @@ -120,14 +120,12 @@ def _ip_containment(left: List[Optional[str]], right: List[str]) -> List[Optiona List[Optional[bool]]: A list of boolean values indicating if each corresponding IP in 'left' is in 'right'. """ - from ipaddress import AddressValueError - from ipaddress import IPv4Address - from ipaddress import IPv4Network + + from opteryx.compiled.functions import ip_in_cidr try: - network = IPv4Network(right[0], strict=False) - return [(IPv4Address(ip) in network) if ip is not None else None for ip in left] - except AddressValueError as err: + return ip_in_cidr(left, str(right[0])) + except (IndexError, AttributeError, ValueError) as err: from opteryx.exceptions import IncorrectTypeError raise IncorrectTypeError( diff --git a/setup.py b/setup.py index 11849562..a8edc4ca 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,12 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None: include_dirs=[numpy.get_include()], extra_compile_args=COMPILE_FLAGS, ), + Extension( + name="ip_address", + sources=["opteryx/compiled/functions/ip_address.pyx"], + include_dirs=[numpy.get_include()], + extra_compile_args=COMPILE_FLAGS, + ), Extension( name="hash_table", sources=["opteryx/compiled/functions/hash_table.pyx"], diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index b51656cc..41c716cf 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -1093,16 +1093,17 @@ ("SELECT name FROM $satellites WHERE '1' | '1'", None, None, IncorrectTypeError), ("SELECT name FROM $satellites WHERE 'abc' | '192.168.1.1'", None, None, IncorrectTypeError), ("SELECT name FROM $satellites WHERE 123 | '192.168.1.1'", None, None, IncorrectTypeError), - ("SELECT name FROM $satellites WHERE '10.10.10.10' | '192.168.1.1'", 0, 1, None), + ("SELECT name FROM $satellites WHERE '10.10.10.10' | '192.168.1.1'", 0, 1, IncorrectTypeError), ("SELECT name FROM $satellites WHERE 0 | 0", 0, 1, None), ("SELECT name FROM $satellites WHERE 0 | 123456", 177, 1, None), ("SELECT name FROM $satellites WHERE 123456 | 0", 177, 1, None), ("SELECT name FROM $satellites WHERE 987654321 | 123456789", 177, 1, None), ("SELECT '192.168.1.1' | '255.255..255'", None, None, IncorrectTypeError), - ("SELECT '192.168.1.1/32' | '192.168.1.1'", None, None, IncorrectTypeError), - ("SELECT '192.168.1.*' | '192.168.1.1'", None, None, IncorrectTypeError), - ("SELECT '!!' | '192.168.1.1'", None, None, IncorrectTypeError), - ("SELECT null | '192.168.1.1'", 1, 1, None), + ("SELECT '192.168.1.1/32' | '192.168.1.1/8'", None, None, IncorrectTypeError), + ("SELECT '192.168.1.*' | '192.168.1.1/8'", None, None, IncorrectTypeError), + ("SELECT '!!' | '192.168.1.1/8'", None, None, IncorrectTypeError), + ("SELECT null | '192.168.1.1'", 1, 1, IncorrectTypeError), + ("SELECT null | '192.168.1.1/8'", 1, 1, IncorrectTypeError), ("SELECT * FROM testdata.flat.different", 196902, 15, None), ("SELECT * FROM testdata.flat.different WHERE following < 10", 7814, 15, None),