Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1571 #1572

Merged
merged 4 commits into from
Apr 8, 2024
Merged

#1571 #1572

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 410
__build__ = 414

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,8 +28,8 @@ class VersionStatus(Enum):

_major = 0
_minor = 14
_revision = 0
_status = VersionStatus.RELEASE
_revision = 1
_status = VersionStatus.ALPHA

__author__ = "@joocer"
__version__ = f"{_major}.{_minor}.{_revision}" + (
Expand Down
1 change: 1 addition & 0 deletions opteryx/compiled/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from hash_table import HashTable
from hash_table import distinct
from ip_address import ip_in_cidr
from vectors import vectorize
49 changes: 49 additions & 0 deletions opteryx/compiled/functions/vectors.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False

import numpy as np
cimport numpy as cnp
cimport cython

from libc.stdint cimport uint32_t, int32_t, uint16_t, uint64_t
from cpython cimport PyUnicode_AsUTF8String, PyBytes_GET_SIZE

cdef double GOLDEN_RATIO_APPROX = 1.618033988749895
cdef uint32_t VECTOR_SIZE = 1024

cdef uint64_t djb2_hash(char* byte_array, uint64_t length) nogil:
"""
Hashes a byte array using the djb2 algorithm, designed to be called without
holding the Global Interpreter Lock (GIL).

Parameters:
byte_array: char*
The byte array to hash.
length: uint64_t
The length of the byte array.

Returns:
uint64_t: The hash value.
"""
cdef uint64_t hash_value = 5381
cdef uint64_t i = 0
for i in range(length):
hash_value = ((hash_value << 5) + hash_value) + byte_array[i]
return hash_value


def vectorize(list tokens):
cdef cnp.ndarray[cnp.uint16_t, ndim=1] vector = np.zeros(VECTOR_SIZE, dtype=np.uint16)
cdef uint32_t hash_1
cdef uint32_t hash_2
cdef bytes token_bytes

for token in tokens:
token_bytes = PyUnicode_AsUTF8String(token)
hash_1 = djb2_hash(token_bytes, PyBytes_GET_SIZE(token_bytes)) & (VECTOR_SIZE - 1)
hash_2 = <int32_t>(hash_1 * GOLDEN_RATIO_APPROX) & (VECTOR_SIZE - 1)
vector[hash_1] += 1
vector[hash_2] += 1

return vector
3 changes: 3 additions & 0 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def select_values(boolean_arrays, value_arrays):
"NULLIF": other_functions.null_if,
"CASE": select_values, #other_functions.case_when,

# Vector
"COSINE_SIMILARITY": other_functions.cosine_similarity,

# NUMERIC
"ROUND": number_functions.round,
"FLOOR": compute.floor,
Expand Down
52 changes: 52 additions & 0 deletions opteryx/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,55 @@ def case_when(conditions, values):
else:
res.append(None)
return res


def cosine_similarity(arr, val):
"""
ad hoc cosine similarity function, slow.
"""
import re
import string

import numpy as np

from opteryx.compiled.functions import vectorize

# import time

if len(val) == 0:
return []
# print(len(val))

# Compile a regular expression pattern that matches any punctuation
punctuation_pattern = re.compile(r"[{}]".format(re.escape(string.punctuation)))

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray, vec2_norm: np.float32) -> float:
vec1 = vec1.astype(np.float32)
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * vec2_norm)

def tokenize_and_remove_punctuation(arr):
# Replace each punctuation mark with a space
no_punctuation = punctuation_pattern.sub(" ", arr)
# Split the modified string into tokens by spaces and filter out empty tokens
tokens = [token for token in no_punctuation.lower().split(" ") if token]
return tokens

# t = time.monotonic_ns()
tokenized_strings = [tokenize_and_remove_punctuation(s) for s in arr] + [
tokenize_and_remove_punctuation(val[0])
]
# print("time tokenizing ", time.monotonic_ns() - t)
# t = time.monotonic_ns()
vectors = [vectorize(tokens) for tokens in tokenized_strings]
# print("time vectorizing", time.monotonic_ns() - t)
comparison_vector = vectors[-1].astype(np.float32)
comparison_vector_norm = np.linalg.norm(comparison_vector)

# t = time.monotonic_ns()
similarities = [
cosine_similarity(vector, comparison_vector, comparison_vector_norm)
for vector in vectors[:-1]
]
# print("time comparing ", time.monotonic_ns() - t)

return similarities
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
language="c++",
extra_compile_args=COMPILE_FLAGS + ["-std=c++11"],
),
Extension(
name="vectors",
sources=["opteryx/compiled/functions/vectors.pyx"],
include_dirs=[numpy.get_include()],
language="c++",
extra_compile_args=COMPILE_FLAGS + ["-std=c++11"],
),
]

setup_config = {
Expand Down
2 changes: 2 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,8 @@
("SELECT P_1.* FROM $planets AS P_1 INNER JOIN $planets AS P_2 USING (id, name)", 9, 20, None),
("SELECT * FROM $satellites AS P_1 INNER JOIN $satellites AS P_2 USING (id, name)", 177, 14, None),

("SELECT * FROM $missions WHERE COSINE_SIMILARITY(Location, 'LC-18A, Cape Canaveral AFS, Florida, USA') > 0.7", 657, 8, None),

("SELECT DISTINCT planetId FROM $satellites RIGHT OUTER JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT DISTINCT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 179, 1, None),
Expand Down