Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Allow custom window functions to be registered with the driver #1220

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,31 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Step(ctx, args)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Step(ctx, args)
}
}

//export inverseTrampoline
func inverseTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what this is doing?

Copy link
Author

@ohaibbq ohaibbq Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try my best. For the step and inverse interfaces, SQLite passes us argc, the number of arguments to consume, and argv a double pointer to underlying sqlite3_values. A simple explanation of this code is that it initializes a slice of C.sqlite3_values with the length of argc.

In the initial implementation, we used the value
1 << 30 to represent the maximum possible length of this array.

#238 identified that this caused an overflow error. It was changed to use the maximum int32 size, divided by the size of nil *C.sqlite3_value, to limit its length without overflowing.

The crux seems to be that we cannot dynamically initialize an array with a length equal to argc. We must specify a constant size at compile time.

Once this array is initialized, we slice it down to the correct length as determined by argc.

I'm not familiar enough with Go / C to be sure of the exact performance implications of this, but
if you examine the pointer list prior to slicing, you can see that it contains 268,435,455 ((math.MaxInt31 - 1) / 8) elements. it seems like it would be much better if there were a way to only initialize it to argc.

if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Inverse(ctx, args)
}
}

//export valueTrampoline
func valueTrampoline(ctx *C.sqlite3_context) {
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Value(ctx)
}
}

//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Done(ctx)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Done(ctx)
}
}

//export compareTrampoline
Expand Down
239 changes: 184 additions & 55 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,25 @@ int _sqlite3_create_function(
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal);
}

int _sqlite3_create_window_function(
sqlite3 *db,
const char *zFunctionName,
int nArg,
int eTextRep,
uintptr_t pApp,
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*),
void (*xValue)(sqlite3_context*),
void (*xInverse)(sqlite3_context*,int,sqlite3_value**)
) {
return sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xStep, xFinal, xValue, xInverse, 0);
}


void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void valueTrampoline(sqlite3_context*);
void inverseTrampoline(sqlite3_context*);
void doneTrampoline(sqlite3_context*);

int compareTrampoline(void*, int, char*, int, char*);
Expand Down Expand Up @@ -438,10 +455,18 @@ type aggInfo struct {
active map[int64]reflect.Value
next int64

nArgs int

stepArgConverters []callbackArgConverter
stepVariadicConverter callbackArgConverter

doneRetConverter callbackRetConverter

// Inverse and Value arg converters are used for window aggregations.
inverseArgConverters []callbackArgConverter
inverseVariadicConverter callbackArgConverter

valueRetConverter callbackRetConverter
}

func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
Expand All @@ -461,6 +486,8 @@ func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
return *aggIdx, ai.active[*aggIdx], nil
}

// Step Implements the xStep function for both aggregate and window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
Expand All @@ -481,6 +508,8 @@ func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
}
}

// Done Implements the xFinal function for both aggregate and window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
idx, agg, err := ai.agg(ctx)
if err != nil {
Expand All @@ -502,6 +531,49 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
}
}

// Inverse Implements the xInverse function for window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Inverse(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}

args, err := callbackConvertArgs(argv, ai.inverseArgConverters, ai.inverseVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}

ret := agg.MethodByName("Inverse").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}

// Value Implements the xValue function for window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Value(ctx *C.sqlite3_context) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Value").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}

err = ai.valueRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}

// Commit transaction.
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec(context.Background(), "COMMIT", nil)
Expand Down Expand Up @@ -684,20 +756,28 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xFunc), (*[0]byte)(xStep), (*[0]byte)(xFinal))
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function.
func sqlite3CreateWindowFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer, xValue unsafe.Pointer, xInverse unsafe.Pointer) C.int {
return C._sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xStep), (*[0]byte)(xFinal), (*[0]byte)(xValue), (*[0]byte)(xInverse))
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function or window function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// To register a window function, the type must also contain implement
// a Value and Inverse function.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
// type, or an interface that implements Step() and Done(), and optionally
// Value() and Inverse() if the aggregator is a window function.
//
// The constructor function and the Step/Done methods may optionally
// The constructor function and the Step/Done/Value/Inverse methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
Expand All @@ -719,93 +799,142 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error
}

agg := t.Out(0)
var implReturnsPointer bool
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
case reflect.Ptr:
implReturnsPointer = true
case reflect.Interface:
implReturnsPointer = false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. How can you conclude whether it returns a pointer or not in this case?
  2. Saying it doesn't return a pointer and allowing that seems to contradict the error message in the default case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that I can't conclude if it returns a pointer or not in this case.

It does contradict the error message in the default case, but it was previously allowed in the initial implementation of user-defined functions #229.

Perhaps there is a better name for this variable.

default:
return errors.New("SQlite aggregator constructor must return a pointer object")
return errors.New("SQLite aggregator constructor must return a pointer object")
}

stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
return errors.New("SQLite aggregator doesn't have a Step() function")
}
err := ai.setupStepInterface(stepFn, &ai.stepArgConverters, &ai.stepVariadicConverter, implReturnsPointer, "Step()")
Copy link
Collaborator

@rittneje rittneje Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't entirely understand this call. The penultimate parameter to setupStepInterface is named isImplPointer, which actually means there is a method receiver? And that doesn't seem to have anything to do with implReturnsPointer.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could rename the parameter to hasMethodReceiver.

I maintained the behavior from the original implementation, where we'd only skip the method receiver if the agg.Kind() == reflect.Pointer, but if we must always return a pointer, then there should always be a method receiver.

According to the reflect.Type docs:

// For a non-interface type T or *T, the returned Method's Type and Func
// fields describe a function whose first argument is the receiver,
// and only exported methods are accessible.

When running sqlite3_test.go, I did not encounter any cases where implReturnsPointer would be false. I wonder if we should remove this branch of the logic.

if err != nil {
return err
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")

doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQLite aggregator doesn't have a Done() function")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
err = ai.setupDoneInterface(doneFn, &ai.doneRetConverter, implReturnsPointer, "Done()")
if err != nil {
return err
}

stepNArgs := step.NumIn()
valueFn, valueFnFound := agg.MethodByName("Value")
inverseFn, inverseFnFound := agg.MethodByName("Inverse")
if (inverseFnFound && !valueFnFound) || (valueFnFound && !inverseFnFound) {
return errors.New("SQLite window aggregator must implement both Value() and Inverse() functions")
}
isWindowFunction := valueFnFound && inverseFnFound
// Validate window function interface
if isWindowFunction {
if inverseFn.Type.NumIn() != stepFn.Type.NumIn() {
return errors.New("SQLite window aggregator Inverse() function must accept the same number of arguments as Step()")
}
err := ai.setupStepInterface(inverseFn, &ai.inverseArgConverters, &ai.inverseVariadicConverter, implReturnsPointer, "Inverse()")
if err != nil {
return err
}
err = ai.setupDoneInterface(valueFn, &ai.valueRetConverter, implReturnsPointer, "Value()")
if err != nil {
return err
}
}

ai.active = make(map[int64]reflect.Value)
ai.next = 1

// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
var rv C.int
if isWindowFunction {
rv = sqlite3CreateWindowFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), C.stepTrampoline, C.doneTrampoline, C.valueTrampoline, C.inverseTrampoline)
} else {
rv = sqlite3CreateFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
}
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}

func (ai *aggInfo) setupStepInterface(fn reflect.Method, argConverters *[]callbackArgConverter, variadicConverter *callbackArgConverter, isImplPointer bool, name string) error {
t := fn.Type
if t.NumOut() != 0 && t.NumOut() != 1 {
return fmt.Errorf("SQLite aggregator %s function must return 0 or 1 values", name)
}
if t.NumOut() == 1 && !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return fmt.Errorf("type of SQLite aggregator %s return value must be error", name)
}
nArgs := t.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
if isImplPointer {
// Skip over the method receiver
stepNArgs--
nArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
if t.IsVariadic() {
nArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
for i := start; i < start+nArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)

*argConverters = append(*argConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(step.In(start + stepNArgs).Elem())
if t.IsVariadic() {
conv, err := callbackArg(t.In(start + nArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
*variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
nArgs = -1
}
ai.nArgs = nArgs
return nil
}

doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
func (ai *aggInfo) setupDoneInterface(fn reflect.Method, retConverter *callbackRetConverter, implReturnsPointer bool, name string) error {
t := fn.Type
nArgs := t.NumIn()
if implReturnsPointer {
// Skip over the method receiver
doneNArgs--
nArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
if nArgs != 0 {
return fmt.Errorf("SQlite aggregator %s function must have no arguments", name)
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
if t.NumOut() != 1 && t.NumOut() != 2 {
return fmt.Errorf("SQLite aggregator %s function must return 1 or 2 values", name)
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return fmt.Errorf("second return value of SQLite aggregator %s function must be error", name)
}

conv, err := callbackRet(done.Out(0))
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1

// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
*retConverter = conv
return nil
}

Expand Down
Loading