Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Allow custom window functions to be registered #1216

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,36 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Step(ctx, args)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Step(ctx, args)
}
if window, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*windowAggInfo); ok {
window.Step(ctx, args)
}
}

//export inverseTrampoline
func inverseTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
window := lookupHandle(C.sqlite3_user_data(ctx)).(*windowAggInfo)
window.Inverse(ctx, args)
}

//export valueTrampoline
func valueTrampoline(ctx *C.sqlite3_context) {
window := lookupHandle(C.sqlite3_user_data(ctx)).(*windowAggInfo)
window.Value(ctx)
}

//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Done(ctx)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Done(ctx)
}

if window, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*windowAggInfo); ok {
window.Done(ctx)
}
}

//export compareTrampoline
Expand Down
264 changes: 264 additions & 0 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,25 @@ int _sqlite3_create_function(
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal);
}

int _sqlite3_create_window_function(
sqlite3 *db,
const char *zFunctionName,
int nArg,
int eTextRep,
uintptr_t pApp,
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*),
void (*xValue)(sqlite3_context*),
void (*xInverse)(sqlite3_context*,int,sqlite3_value**)
) {
return sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xStep, xFinal, xValue, xInverse, 0);
}


void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void valueTrampoline(sqlite3_context*);
void inverseTrampoline(sqlite3_context*);
void doneTrampoline(sqlite3_context*);

int compareTrampoline(void*, int, char*, int, char*);
Expand Down Expand Up @@ -367,6 +384,7 @@ type SQLiteConn struct {
txlock string
funcs []*functionInfo
aggregators []*aggInfo
windows []*windowAggInfo
}

// SQLiteTx implements driver.Tx.
Expand Down Expand Up @@ -444,6 +462,15 @@ type aggInfo struct {
doneRetConverter callbackRetConverter
}

type windowAggInfo struct {
aggInfo

inverseArgConverters []callbackArgConverter
inverseVariadicConverter callbackArgConverter

valueRetConverter callbackRetConverter
}

func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
if *aggIdx == 0 {
Expand Down Expand Up @@ -502,6 +529,45 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
}
}

func (window *windowAggInfo) Inverse(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := window.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}

args, err := callbackConvertArgs(argv, window.inverseArgConverters, window.inverseVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}

ret := agg.MethodByName("Inverse").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}

func (window *windowAggInfo) Value(ctx *C.sqlite3_context) {
_, agg, err := window.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Value").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}

err = window.valueRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}

// Commit transaction.
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec(context.Background(), "COMMIT", nil)
Expand Down Expand Up @@ -845,6 +911,204 @@ func lastError(db *C.sqlite3) error {
}
}

func sqlite3CreateWindowFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer, xValue unsafe.Pointer, xInverse unsafe.Pointer) C.int {
return C._sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xStep), (*[0]byte)(xFinal), (*[0]byte)(xValue), (*[0]byte)(xInverse))
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterWindowAggregator(name string, impl any, pure bool) error {
var window windowAggInfo
window.constructor = reflect.ValueOf(impl)
t := window.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterWindowAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 4 or 5 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite window aggregator constructors must not have arguments")
}

agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQLite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQLite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQLite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQLite aggregator Step() return value must be error")
}

stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
window.stepArgConverters = append(window.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(step.In(start + stepNArgs).Elem())
if err != nil {
return err
}
window.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}

inverseFn, found := agg.MethodByName("Inverse")
if !found {
return errors.New("SQLite aggregator doesn't have a Inverse() function")
}
inverse := inverseFn.Type
if inverse.NumOut() != 0 && inverse.NumOut() != 1 {
return errors.New("SQLite aggregator Inverse() function must return 0 or 1 values")
}
if inverse.NumOut() == 1 && !inverse.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQLite aggregator Inverse() return value must be error")
}

inverseNArgs := inverse.NumIn()
start = 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
inverseNArgs--
start++
}
if inverse.IsVariadic() {
inverseNArgs--
}
for i := start; i < start+inverseNArgs; i++ {
conv, err := callbackArg(inverse.In(i))
if err != nil {
return err
}
window.inverseArgConverters = append(window.inverseArgConverters, conv)
}
if inverse.IsVariadic() {
conv, err := callbackArg(inverse.In(start + inverseNArgs).Elem())
if err != nil {
return err
}
window.inverseVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
inverseNArgs = -1
}

doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQLite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQLite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}

conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
window.doneRetConverter = conv
window.active = make(map[int64]reflect.Value)
window.next = 1

valueFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQLite window aggregator doesn't have a Value() function")
}
value := valueFn.Type
valueNArgs := value.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
valueNArgs--
}
if valueNArgs != 0 {
return errors.New("SQLite window aggregator Value() function must have no arguments")
}
if value.NumOut() != 1 && value.NumOut() != 2 {
return errors.New("SQLite aggregator value() function must return 1 or 2 values")
}
if value.NumOut() == 2 && !value.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator value() function must be error")
}

valueConv, err := callbackRet(value.Out(0))
if err != nil {
return err
}
window.valueRetConverter = valueConv
window.active = make(map[int64]reflect.Value)
window.next = 1

// window must outlast the database connection, or we'll have dangling pointers.
c.windows = append(c.windows, &window)

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateWindowFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &window), C.stepTrampoline, C.doneTrampoline, C.valueTrampoline, C.inverseTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}

// Exec implements Execer.
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
list := make([]driver.NamedValue, len(args))
Expand Down
Loading