Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Feb 16, 2024
1 parent eaea5ef commit b34e1d0
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 11 deletions.
1 change: 1 addition & 0 deletions opteryx/compiled/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from hash_table import HashSet
from hash_table import HashTable
from hash_table import distinct
from ip_address import ip_in_cidr
53 changes: 53 additions & 0 deletions opteryx/compiled/functions/ip_address.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# cython: language_level=3

from libc.stdint cimport uint32_t
from libc.stdlib cimport strtol
from libc.string cimport strchr
from libc.string cimport memset
import numpy as np
cimport numpy as cnp


cdef uint32_t ip_to_int(char* ip):
cdef uint32_t result = 0
cdef uint32_t num = 0
cdef int shift = 24 # Start with the leftmost byte
cdef char* end

# Convert each part of the IP to an integer
for _ in range(4):
num = strtol(ip, &end, 10) # Convert substring to long
if num > 255 or ip == end or (end[0] not in (b'.', b'\0') and _ < 3): # Validate octet and check for non-digit characters
raise ValueError("Invalid IP address")
result += num << shift
shift -= 8
if end[0] == b'\0': # Check if end of string
break
ip = end + 1 # Move to the next part

if shift != -8 or end[0] != b'\0': # Ensure exactly 4 octets and end of string
raise ValueError("Invalid IP address")

return result


def ip_in_cidr(cnp.ndarray ip_addresses, str cidr):
cdef uint32_t base_ip, netmask, ip_int
cdef int mask_size
cdef str base_ip_str
cdef list cidr_parts = cidr.split('/')

base_ip_str, mask_size = cidr_parts[0], int(cidr_parts[1])
netmask = (0xFFFFFFFF << (32 - mask_size)) & 0xFFFFFFFF

base_ip = ip_to_int(base_ip_str.encode('utf-8'))

cdef cnp.ndarray result = np.empty(ip_addresses.shape[0], dtype=np.bool_)
cdef int i = 0

for i in range(ip_addresses.shape[0]):
if ip_addresses[i] is not None:
ip_int = ip_to_int(ip_addresses[i].encode('utf-8'))
result[i] = (ip_int & netmask) == base_ip

return result
10 changes: 4 additions & 6 deletions opteryx/functions/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,12 @@ def _ip_containment(left: List[Optional[str]], right: List[str]) -> List[Optiona
List[Optional[bool]]:
A list of boolean values indicating if each corresponding IP in 'left' is in 'right'.
"""
from ipaddress import AddressValueError
from ipaddress import IPv4Address
from ipaddress import IPv4Network

from opteryx.compiled.functions import ip_in_cidr

try:
network = IPv4Network(right[0], strict=False)
return [(IPv4Address(ip) in network) if ip is not None else None for ip in left]
except AddressValueError as err:
return ip_in_cidr(left, str(right[0]))
except (IndexError, AttributeError, ValueError) as err:
from opteryx.exceptions import IncorrectTypeError

raise IncorrectTypeError(
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
include_dirs=[numpy.get_include()],
extra_compile_args=COMPILE_FLAGS,
),
Extension(
name="ip_address",
sources=["opteryx/compiled/functions/ip_address.pyx"],
include_dirs=[numpy.get_include()],
extra_compile_args=COMPILE_FLAGS,
),
Extension(
name="hash_table",
sources=["opteryx/compiled/functions/hash_table.pyx"],
Expand Down
11 changes: 6 additions & 5 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,16 +1093,17 @@
("SELECT name FROM $satellites WHERE '1' | '1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE 'abc' | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE 123 | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE '10.10.10.10' | '192.168.1.1'", 0, 1, None),
("SELECT name FROM $satellites WHERE '10.10.10.10' | '192.168.1.1'", 0, 1, IncorrectTypeError),
("SELECT name FROM $satellites WHERE 0 | 0", 0, 1, None),
("SELECT name FROM $satellites WHERE 0 | 123456", 177, 1, None),
("SELECT name FROM $satellites WHERE 123456 | 0", 177, 1, None),
("SELECT name FROM $satellites WHERE 987654321 | 123456789", 177, 1, None),
("SELECT '192.168.1.1' | '255.255..255'", None, None, IncorrectTypeError),
("SELECT '192.168.1.1/32' | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT '192.168.1.*' | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT '!!' | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT null | '192.168.1.1'", 1, 1, None),
("SELECT '192.168.1.1/32' | '192.168.1.1/8'", None, None, IncorrectTypeError),
("SELECT '192.168.1.*' | '192.168.1.1/8'", None, None, IncorrectTypeError),
("SELECT '!!' | '192.168.1.1/8'", None, None, IncorrectTypeError),
("SELECT null | '192.168.1.1'", 1, 1, IncorrectTypeError),
("SELECT null | '192.168.1.1/8'", 1, 1, IncorrectTypeError),

("SELECT * FROM testdata.flat.different", 196902, 15, None),
("SELECT * FROM testdata.flat.different WHERE following < 10", 7814, 15, None),
Expand Down

0 comments on commit b34e1d0

Please sign in to comment.