Skip to content

Commit 45f09c1

Browse files
committed
fix(vec0): correct parser logic and use proper float math functions
- Fix parser condition checks: change && to || in token validation (vec0_parse_table_option, vec0_parse_partition_key_definition, vec0_parse_auxiliary_column_definition, vec0_parse_primary_key_definition, vec0_parse_vector_column) - Use sqrtf() instead of sqrt() for f32 distance calculations - Initialize variables to prevent undefined behavior on error paths - Add missing <stdarg.h> include for va_start/va_end Add comprehensive parser edge case tests and adjust existing tests for sqrtf precision differences.
1 parent c39ada1 commit 45f09c1

File tree

4 files changed

+507
-26
lines changed

4 files changed

+507
-26
lines changed

sqlite-vec.c

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <inttypes.h>
77
#include <limits.h>
88
#include <math.h>
9+
#include <stdarg.h>
910
#include <stdbool.h>
1011
#include <stdint.h>
1112
#include <stdlib.h>
@@ -448,7 +449,7 @@ static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v,
448449
pVect2++;
449450
res += t * t;
450451
}
451-
return sqrt(res);
452+
return sqrtf(res);
452453
}
453454

454455
static f32 l2_sqr_int8(const void *pA, const void *pB, const void *pD) {
@@ -463,7 +464,7 @@ static f32 l2_sqr_int8(const void *pA, const void *pB, const void *pD) {
463464
b++;
464465
res += t * t;
465466
}
466-
return sqrt(res);
467+
return sqrtf(res);
467468
}
468469

469470
static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) {
@@ -591,7 +592,7 @@ static f32 distance_cosine_bit_u64(u64 *a, u64 *b, size_t n) {
591592
if (aMag == 0 || bMag == 0) {
592593
return 1.0f;
593594
}
594-
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
595+
return 1 - (dot / (sqrtf(aMag) * sqrtf(bMag)));
595596
}
596597

597598
static f32 distance_cosine_bit_u8(u8 *a, u8 *b, size_t n) {
@@ -609,7 +610,7 @@ static f32 distance_cosine_bit_u8(u8 *a, u8 *b, size_t n) {
609610
if (aMag == 0 || bMag == 0) {
610611
return 1.0f;
611612
}
612-
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
613+
return 1 - (dot / (sqrtf(aMag) * sqrtf(bMag)));
613614
}
614615

615616
static f32 distance_cosine_bit(const void *pA, const void *pB,
@@ -642,7 +643,7 @@ static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
642643
if (aMag == 0 || bMag == 0) {
643644
return 1.0f;
644645
}
645-
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
646+
return 1 - (dot / (sqrtf(aMag) * sqrtf(bMag)));
646647
}
647648
static f32 distance_cosine_int8(const void *pA, const void *pB,
648649
const void *pD) {
@@ -664,7 +665,7 @@ static f32 distance_cosine_int8(const void *pA, const void *pB,
664665
if (aMag == 0 || bMag == 0) {
665666
return 1.0f;
666667
}
667-
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
668+
return 1 - (dot / (sqrtf(aMag) * sqrtf(bMag)));
668669
}
669670

670671
static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
@@ -2047,20 +2048,20 @@ int vec0_parse_table_option(const char *source, int source_length,
20472048
vec0_scanner_init(&scanner, source, source_length);
20482049

20492050
rc = vec0_scanner_next(&scanner, &token);
2050-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2051+
if (rc != VEC0_TOKEN_RESULT_SOME ||
20512052
token.token_type != TOKEN_TYPE_IDENTIFIER) {
20522053
return SQLITE_EMPTY;
20532054
}
20542055
key = token.start;
20552056
keyLength = token.end - token.start;
20562057

20572058
rc = vec0_scanner_next(&scanner, &token);
2058-
if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_EQ) {
2059+
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) {
20592060
return SQLITE_EMPTY;
20602061
}
20612062

20622063
rc = vec0_scanner_next(&scanner, &token);
2063-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2064+
if (rc != VEC0_TOKEN_RESULT_SOME ||
20642065
!((token.token_type == TOKEN_TYPE_IDENTIFIER) ||
20652066
(token.token_type == TOKEN_TYPE_DIGIT))) {
20662067
return SQLITE_ERROR;
@@ -2103,7 +2104,7 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
21032104

21042105
// Check first token is identifier, will be the column name
21052106
int rc = vec0_scanner_next(&scanner, &token);
2106-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2107+
if (rc != VEC0_TOKEN_RESULT_SOME ||
21072108
token.token_type != TOKEN_TYPE_IDENTIFIER) {
21082109
return SQLITE_EMPTY;
21092110
}
@@ -2113,7 +2114,7 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
21132114

21142115
// Check the next token matches "text" or "integer", as column type
21152116
rc = vec0_scanner_next(&scanner, &token);
2116-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2117+
if (rc != VEC0_TOKEN_RESULT_SOME ||
21172118
token.token_type != TOKEN_TYPE_IDENTIFIER) {
21182119
return SQLITE_EMPTY;
21192120
}
@@ -2130,7 +2131,7 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
21302131

21312132
// Check the next token is identifier and matches "partition"
21322133
rc = vec0_scanner_next(&scanner, &token);
2133-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2134+
if (rc != VEC0_TOKEN_RESULT_SOME ||
21342135
token.token_type != TOKEN_TYPE_IDENTIFIER) {
21352136
return SQLITE_EMPTY;
21362137
}
@@ -2140,7 +2141,7 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
21402141

21412142
// Check the next token is identifier and matches "key"
21422143
rc = vec0_scanner_next(&scanner, &token);
2143-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2144+
if (rc != VEC0_TOKEN_RESULT_SOME ||
21442145
token.token_type != TOKEN_TYPE_IDENTIFIER) {
21452146
return SQLITE_EMPTY;
21462147
}
@@ -2186,7 +2187,7 @@ int vec0_parse_auxiliary_column_definition(const char *source, int source_length
21862187
}
21872188

21882189
rc = vec0_scanner_next(&scanner, &token);
2189-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2190+
if (rc != VEC0_TOKEN_RESULT_SOME ||
21902191
token.token_type != TOKEN_TYPE_IDENTIFIER) {
21912192
return SQLITE_EMPTY;
21922193
}
@@ -2196,7 +2197,7 @@ int vec0_parse_auxiliary_column_definition(const char *source, int source_length
21962197

21972198
// Check the next token matches "text" or "integer", as column type
21982199
rc = vec0_scanner_next(&scanner, &token);
2199-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2200+
if (rc != VEC0_TOKEN_RESULT_SOME ||
22002201
token.token_type != TOKEN_TYPE_IDENTIFIER) {
22012202
return SQLITE_EMPTY;
22022203
}
@@ -2318,7 +2319,7 @@ int vec0_parse_primary_key_definition(const char *source, int source_length,
23182319

23192320
// Check first token is identifier, will be the column name
23202321
int rc = vec0_scanner_next(&scanner, &token);
2321-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2322+
if (rc != VEC0_TOKEN_RESULT_SOME ||
23222323
token.token_type != TOKEN_TYPE_IDENTIFIER) {
23232324
return SQLITE_EMPTY;
23242325
}
@@ -2328,7 +2329,7 @@ int vec0_parse_primary_key_definition(const char *source, int source_length,
23282329

23292330
// Check the next token matches "text" or "integer", as column type
23302331
rc = vec0_scanner_next(&scanner, &token);
2331-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2332+
if (rc != VEC0_TOKEN_RESULT_SOME ||
23322333
token.token_type != TOKEN_TYPE_IDENTIFIER) {
23332334
return SQLITE_EMPTY;
23342335
}
@@ -2345,7 +2346,7 @@ int vec0_parse_primary_key_definition(const char *source, int source_length,
23452346

23462347
// Check the next token is identifier and matches "primary"
23472348
rc = vec0_scanner_next(&scanner, &token);
2348-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2349+
if (rc != VEC0_TOKEN_RESULT_SOME ||
23492350
token.token_type != TOKEN_TYPE_IDENTIFIER) {
23502351
return SQLITE_EMPTY;
23512352
}
@@ -2355,7 +2356,7 @@ int vec0_parse_primary_key_definition(const char *source, int source_length,
23552356

23562357
// Check the next token is identifier and matches "key"
23572358
rc = vec0_scanner_next(&scanner, &token);
2358-
if (rc != VEC0_TOKEN_RESULT_SOME &&
2359+
if (rc != VEC0_TOKEN_RESULT_SOME ||
23592360
token.token_type != TOKEN_TYPE_IDENTIFIER) {
23602361
return SQLITE_EMPTY;
23612362
}
@@ -2525,7 +2526,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
25252526
}
25262527
// ensure equal sign after distance_metric
25272528
rc = vec0_scanner_next(&scanner, &token);
2528-
if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_EQ) {
2529+
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) {
25292530
return SQLITE_ERROR;
25302531
}
25312532

@@ -9463,7 +9464,7 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset,
94639464
char *pzError;
94649465
size_t dimensions;
94659466
enum VectorElementType elementType;
9466-
void *vector;
9467+
void *vector = NULL;
94679468
vector_cleanup cleanup = vector_cleanup_noop;
94689469
// https://github.com/asg017/sqlite-vec/issues/53
94699470
rc = vector_from_value(valueVector, &vector, &dimensions, &elementType,
@@ -9539,7 +9540,7 @@ int vec0Update_Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) {
95399540
i64 chunk_id;
95409541
i64 chunk_offset;
95419542

9542-
i64 rowid;
9543+
i64 rowid = 0;
95439544
if (p->pkIsText) {
95449545
const char *a = (const char *)sqlite3_value_text(argv[0]);
95459546
const char *b = (const char *)sqlite3_value_text(argv[1]);
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# serializer version: 1
2+
# name: TestAuxiliaryColumnParser.test_invalid_auxiliary_type[invalid auxiliary type]
3+
dict({
4+
'error': 'OperationalError',
5+
'message': "vec0 constructor error: Could not parse '+aux varchar'",
6+
})
7+
# ---
8+
# name: TestAuxiliaryColumnParser.test_plus_without_name[plus without column name]
9+
dict({
10+
'error': 'OperationalError',
11+
'message': "vec0 constructor error: Could not parse '+ text'",
12+
})
13+
# ---
14+
# name: TestAuxiliaryColumnParser.test_plus_without_type[auxiliary without type]
15+
dict({
16+
'error': 'OperationalError',
17+
'message': "vec0 constructor error: Could not parse '+aux'",
18+
})
19+
# ---
20+
# name: TestMalformedDefinitions.test_double_comma[double comma]
21+
OrderedDict({
22+
'sql': 'create virtual table v using vec0(a float[4],, b float[4])',
23+
'rows': list([
24+
]),
25+
})
26+
# ---
27+
# name: TestMalformedDefinitions.test_empty_definition[empty definition]
28+
dict({
29+
'error': 'OperationalError',
30+
'message': 'vec0 constructor error: At least one vector column is required',
31+
})
32+
# ---
33+
# name: TestMalformedDefinitions.test_just_comma[just comma]
34+
dict({
35+
'error': 'OperationalError',
36+
'message': 'vec0 constructor error: At least one vector column is required',
37+
})
38+
# ---
39+
# name: TestMalformedDefinitions.test_leading_comma[leading comma]
40+
OrderedDict({
41+
'sql': 'create virtual table v using vec0(, a float[4])',
42+
'rows': list([
43+
]),
44+
})
45+
# ---
46+
# name: TestMalformedDefinitions.test_number_only[number only]
47+
dict({
48+
'error': 'OperationalError',
49+
'message': "vec0 constructor error: Could not parse '123'",
50+
})
51+
# ---
52+
# name: TestMalformedDefinitions.test_only_whitespace[only whitespace]
53+
dict({
54+
'error': 'OperationalError',
55+
'message': 'vec0 constructor error: At least one vector column is required',
56+
})
57+
# ---
58+
# name: TestMalformedDefinitions.test_special_characters[special characters]
59+
dict({
60+
'error': 'OperationalError',
61+
'message': 'unrecognized token: "@"',
62+
})
63+
# ---
64+
# name: TestMalformedDefinitions.test_trailing_comma[trailing comma]
65+
OrderedDict({
66+
'sql': 'create virtual table v using vec0(a float[4],)',
67+
'rows': list([
68+
]),
69+
})
70+
# ---
71+
# name: TestPartitionKeyParser.test_invalid_type[invalid partition key type]
72+
dict({
73+
'error': 'OperationalError',
74+
'message': "vec0 constructor error: Could not parse 'p blob partition key'",
75+
})
76+
# ---
77+
# name: TestPartitionKeyParser.test_missing_key_keyword[missing key keyword]
78+
OrderedDict({
79+
'sql': 'create virtual table v using vec0(p int partition, a float[4])',
80+
'rows': list([
81+
]),
82+
})
83+
# ---
84+
# name: TestPartitionKeyParser.test_missing_partition_keyword[missing partition keyword]
85+
OrderedDict({
86+
'sql': 'create virtual table v using vec0(p int key, a float[4])',
87+
'rows': list([
88+
]),
89+
})
90+
# ---
91+
# name: TestPartitionKeyParser.test_missing_type[partition key missing type]
92+
dict({
93+
'error': 'OperationalError',
94+
'message': "vec0 constructor error: Could not parse 'p partition key'",
95+
})
96+
# ---
97+
# name: TestPrimaryKeyParser.test_invalid_type[invalid primary key type]
98+
dict({
99+
'error': 'OperationalError',
100+
'message': "vec0 constructor error: Could not parse 'id blob primary key'",
101+
})
102+
# ---
103+
# name: TestPrimaryKeyParser.test_missing_key_keyword[missing key keyword after primary]
104+
OrderedDict({
105+
'sql': 'create virtual table v using vec0(id int primary, a float[4])',
106+
'rows': list([
107+
]),
108+
})
109+
# ---
110+
# name: TestPrimaryKeyParser.test_missing_primary_keyword[missing primary keyword]
111+
OrderedDict({
112+
'sql': 'create virtual table v using vec0(id int key, a float[4])',
113+
'rows': list([
114+
]),
115+
})
116+
# ---
117+
# name: TestPrimaryKeyParser.test_missing_type[primary key missing type]
118+
dict({
119+
'error': 'OperationalError',
120+
'message': "vec0 constructor error: Could not parse 'id primary key'",
121+
})
122+
# ---
123+
# name: TestTableOptionParser.test_extra_tokens_after_value[extra tokens after value]
124+
dict({
125+
'error': 'OperationalError',
126+
'message': "vec0 constructor error: could not parse table option 'chunk_size=8 extra'",
127+
})
128+
# ---
129+
# name: TestTableOptionParser.test_missing_equals_sign[missing equals sign]
130+
dict({
131+
'error': 'OperationalError',
132+
'message': "vec0 constructor error: Could not parse 'chunk_size 8'",
133+
})
134+
# ---
135+
# name: TestTableOptionParser.test_missing_key[missing key before equals]
136+
dict({
137+
'error': 'OperationalError',
138+
'message': "vec0 constructor error: Could not parse '=8'",
139+
})
140+
# ---
141+
# name: TestTableOptionParser.test_missing_value[missing value after equals]
142+
dict({
143+
'error': 'OperationalError',
144+
'message': "vec0 constructor error: could not parse table option 'chunk_size='",
145+
})
146+
# ---
147+
# name: TestVectorColumnParser.test_distance_metric_invalid_value[distance_metric invalid value]
148+
dict({
149+
'error': 'OperationalError',
150+
'message': "vec0 constructor error: could not parse vector column 'a float[4] distance_metric=invalid'",
151+
})
152+
# ---
153+
# name: TestVectorColumnParser.test_distance_metric_missing_equals[distance_metric missing equals]
154+
dict({
155+
'error': 'OperationalError',
156+
'message': "vec0 constructor error: could not parse vector column 'a float[4] distance_metric l2'",
157+
})
158+
# ---
159+
# name: TestVectorColumnParser.test_distance_metric_missing_value[distance_metric missing value]
160+
dict({
161+
'error': 'OperationalError',
162+
'message': "vec0 constructor error: could not parse vector column 'a float[4] distance_metric='",
163+
})
164+
# ---
165+
# name: TestVectorColumnParser.test_missing_dimensions[vector missing dimensions]
166+
dict({
167+
'error': 'OperationalError',
168+
'message': 'vec0 constructor error: At least one vector column is required',
169+
})
170+
# ---
171+
# name: TestVectorColumnParser.test_missing_type[vector missing type]
172+
dict({
173+
'error': 'OperationalError',
174+
'message': "vec0 constructor error: Could not parse 'a [4]'",
175+
})
176+
# ---
177+
# name: TestVectorColumnParser.test_negative_dimensions[negative dimensions]
178+
dict({
179+
'error': 'OperationalError',
180+
'message': "vec0 constructor error: could not parse vector column 'a float[-1]'",
181+
})
182+
# ---
183+
# name: TestVectorColumnParser.test_zero_dimensions[zero dimensions]
184+
dict({
185+
'error': 'OperationalError',
186+
'message': "vec0 constructor error: could not parse vector column 'a float[0]'",
187+
})
188+
# ---

0 commit comments

Comments
 (0)