Skip to content

Commit

Permalink
Merge pull request #1572 from mabel-dev/#1571
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Apr 8, 2024
2 parents 3e8a643 + 7a4b5f2 commit e725eed
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 3 deletions.
6 changes: 3 additions & 3 deletions opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 410
__build__ = 414

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,8 +28,8 @@ class VersionStatus(Enum):

_major = 0
_minor = 14
_revision = 0
_status = VersionStatus.RELEASE
_revision = 1
_status = VersionStatus.ALPHA

__author__ = "@joocer"
__version__ = f"{_major}.{_minor}.{_revision}" + (
Expand Down
1 change: 1 addition & 0 deletions opteryx/compiled/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from hash_table import HashTable
from hash_table import distinct
from ip_address import ip_in_cidr
from vectors import vectorize
49 changes: 49 additions & 0 deletions opteryx/compiled/functions/vectors.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False

import numpy as np
cimport numpy as cnp
cimport cython

from libc.stdint cimport uint32_t, int32_t, uint16_t, uint64_t
from cpython cimport PyUnicode_AsUTF8String, PyBytes_GET_SIZE

# Approximation of the golden ratio. vectorize() multiplies a token's first
# bucket index by this value to derive a second bucket index for that token.
cdef double GOLDEN_RATIO_APPROX = 1.618033988749895
# Length of the vector returned by vectorize(). Bucket indices are reduced with
# `& (VECTOR_SIZE - 1)`, which only covers every slot when this is a power of two.
cdef uint32_t VECTOR_SIZE = 1024

cdef uint64_t djb2_hash(char* byte_array, uint64_t length) nogil:
    """
    Hash a byte buffer with the djb2 algorithm (hash = hash * 33 + byte, seeded
    with 5381). Declared `nogil` so it can be called without holding the
    Global Interpreter Lock (GIL).

    Parameters:
        byte_array: char*
            Pointer to the first byte to hash.
        length: uint64_t
            Number of bytes to read from `byte_array`.

    Returns:
        uint64_t: The hash value; the arithmetic wraps modulo 2**64.
    """
    cdef uint64_t hash_value = 5381
    cdef uint64_t i
    for i in range(length):
        # Read each byte as unsigned: whether plain `char` is signed is left to
        # the C implementation, so without the cast bytes >= 0x80 (any
        # non-ASCII UTF-8 sequence) would hash differently across platforms.
        hash_value = ((hash_value << 5) + hash_value) + <unsigned char>byte_array[i]
    return hash_value


def vectorize(list tokens):
    """
    Turn a list of string tokens into a fixed-length vector of bucket counts.

    Each token is UTF-8 encoded and hashed with djb2_hash to pick a first
    bucket; a second bucket is derived by scaling the first index by
    GOLDEN_RATIO_APPROX. Both buckets are incremented for every token.

    Parameters:
        tokens: list
            The string tokens to count.

    Returns:
        numpy.ndarray: uint16 array of length VECTOR_SIZE holding the counts.
    """
    cdef cnp.ndarray[cnp.uint16_t, ndim=1] counts = np.zeros(VECTOR_SIZE, dtype=np.uint16)
    # Both bucket indices are reduced with the same mask, so compute it once.
    cdef uint32_t mask = VECTOR_SIZE - 1
    cdef uint32_t primary
    cdef uint32_t secondary
    cdef bytes encoded

    for token in tokens:
        encoded = PyUnicode_AsUTF8String(token)
        primary = djb2_hash(encoded, PyBytes_GET_SIZE(encoded)) & mask
        secondary = <int32_t>(primary * GOLDEN_RATIO_APPROX) & mask
        counts[primary] += 1
        counts[secondary] += 1

    return counts
3 changes: 3 additions & 0 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def select_values(boolean_arrays, value_arrays):
"NULLIF": other_functions.null_if,
"CASE": select_values, #other_functions.case_when,

# Vector
"COSINE_SIMILARITY": other_functions.cosine_similarity,

# NUMERIC
"ROUND": number_functions.round,
"FLOOR": compute.floor,
Expand Down
52 changes: 52 additions & 0 deletions opteryx/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,55 @@ def case_when(conditions, values):
else:
res.append(None)
return res


def cosine_similarity(arr, val):
    """
    Score every string in `arr` by its cosine similarity to the string `val[0]`.

    Each string is stripped of punctuation (replaced by spaces), lower-cased and
    split on spaces; the tokens are turned into a count vector by the compiled
    `vectorize` helper and the vectors are compared. Ad hoc and slow.

    Parameters:
        arr: iterable of str (entries may be None)
            The values to score.
        val: sequence of str
            Only the first element is used, as the comparison text.

    Returns:
        list: one similarity per entry of `arr`, in order, or `[]` when `val`
        is empty. Entries that cannot be compared (None, text without any
        tokens, or a comparison text without any tokens) score 0.0 instead of
        producing NaN from a division by zero.
    """
    # Nothing to compare against; return before paying for any imports.
    if len(val) == 0:
        return []

    import re
    import string

    import numpy as np

    from opteryx.compiled.functions import vectorize

    # Character class matching any single punctuation mark.
    punctuation_pattern = re.compile(r"[{}]".format(re.escape(string.punctuation)))

    def tokenize(text):
        # Null values carry no tokens; without this re.sub raises TypeError.
        if text is None:
            return []
        # Punctuation becomes a space so "a,b" yields two tokens, not one.
        spaced = punctuation_pattern.sub(" ", text)
        return [token for token in spaced.lower().split(" ") if token]

    def to_vector(text):
        return vectorize(tokenize(text)).astype(np.float32)

    # The comparison vector and its norm are the same for every row.
    reference = to_vector(val[0])
    reference_norm = np.linalg.norm(reference)

    def similarity(text):
        candidate = to_vector(text)
        denominator = np.linalg.norm(candidate) * reference_norm
        # A zero norm means one side had no tokens; define that as "no
        # similarity" rather than dividing by zero.
        if denominator == 0:
            return np.float32(0.0)
        return np.dot(candidate, reference) / denominator

    return [similarity(text) for text in arr]
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
language="c++",
extra_compile_args=COMPILE_FLAGS + ["-std=c++11"],
),
Extension(
name="vectors",
sources=["opteryx/compiled/functions/vectors.pyx"],
include_dirs=[numpy.get_include()],
language="c++",
extra_compile_args=COMPILE_FLAGS + ["-std=c++11"],
),
]

setup_config = {
Expand Down
2 changes: 2 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,8 @@
("SELECT P_1.* FROM $planets AS P_1 INNER JOIN $planets AS P_2 USING (id, name)", 9, 20, None),
("SELECT * FROM $satellites AS P_1 INNER JOIN $satellites AS P_2 USING (id, name)", 177, 14, None),

("SELECT * FROM $missions WHERE COSINE_SIMILARITY(Location, 'LC-18A, Cape Canaveral AFS, Florida, USA') > 0.7", 657, 8, None),

("SELECT DISTINCT planetId FROM $satellites RIGHT OUTER JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT DISTINCT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 179, 1, None),
Expand Down

0 comments on commit e725eed

Please sign in to comment.