Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions _example/hook/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ func main() {
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
sqlite3conn = append(sqlite3conn, conn)
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
switch op {
case sqlite3.SQLITE_INSERT:
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
}
})
return nil
},
})
Expand Down
18 changes: 18 additions & 0 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ func doneTrampoline(ctx *C.sqlite3_context) {
ai.Done(ctx)
}

//export commitHookTrampoline
func commitHookTrampoline(handle uintptr) int {
callback := lookupHandle(handle).(func() int)
return callback()
}

//export rollbackHookTrampoline
func rollbackHookTrampoline(handle uintptr) {
callback := lookupHandle(handle).(func())
callback()
}

//export updateHookTrampoline
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
callback := lookupHandle(handle).(func(int, string, string, int64))
callback(op, C.GoString(db), C.GoString(table), rowid)
}

// Use handles to avoid passing Go pointers to C.

type handleVal struct {
Expand Down
54 changes: 54 additions & 0 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ int _sqlite3_create_function(
}

void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
int commitHookTrampoline(void*);
void rollbackHookTrampoline(void*);
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
*/
import "C"
import (
Expand Down Expand Up @@ -145,6 +148,12 @@ func Version() (libVersion string, libVersionNumber int, sourceId string) {
return libVersion, libVersionNumber, sourceId
}

const (
SQLITE_DELETE = C.SQLITE_DELETE
SQLITE_INSERT = C.SQLITE_INSERT
SQLITE_UPDATE = C.SQLITE_UPDATE
)

// Driver struct.
type SQLiteDriver struct {
Extensions []string
Expand Down Expand Up @@ -310,6 +319,51 @@ func (tx *SQLiteTx) Rollback() error {
return err
}

// RegisterCommitHook sets the commit hook for a connection.
//
// If the callback returns non-zero the transaction will become a rollback.
//
// If there is an existing commit hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one.
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
if callback == nil {
C.sqlite3_commit_hook(c.db, nil, nil)
} else {
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
}
}

// RegisterRollbackHook sets the rollback hook for a connection.
//
// If there is an existing rollback hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one.
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
if callback == nil {
C.sqlite3_rollback_hook(c.db, nil, nil)
} else {
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
}
}

// RegisterUpdateHook sets the update hook for a connection.
//
// The parameters to the callback are the operation (one of the constants
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
// table name, and the rowid.
//
// If there is an existing update hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one.
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
if callback == nil {
C.sqlite3_update_hook(c.db, nil, nil)
} else {
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
}
}

// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
Expand Down
61 changes: 61 additions & 0 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,67 @@ func TestDeclTypes(t *testing.T) {
}
}

func TestUpdateAndTransactionHooks(t *testing.T) {
var events []string
var commitHookReturn = 0

sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
conn.RegisterCommitHook(func() int {
events = append(events, "commit")
return commitHookReturn
})
conn.RegisterRollbackHook(func() {
events = append(events, "rollback")
})
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
})
return nil
},
})
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()

statements := []string{
"create table foo (id integer primary key)",
"insert into foo values (9)",
"update foo set id = 99 where id = 9",
"delete from foo where id = 99",
}
for _, statement := range statements {
_, err = db.Exec(statement)
if err != nil {
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
}
}

commitHookReturn = 1
_, err = db.Exec("insert into foo values (5)")
if err == nil {
t.Error("Commit hook failed to rollback transaction")
}

var expected = []string{
"commit",
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
"commit",
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
"commit",
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
"commit",
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
"commit",
"rollback",
}
if !reflect.DeepEqual(events, expected) {
t.Errorf("Expected notifications %v but got %v", expected, events)
}
}

var customFunctionOnce sync.Once

func BenchmarkCustomFunctions(b *testing.B) {
Expand Down