diff --git a/example/extension/extension.go b/example/extension/extension.go index d4b8fdb7..f58ea3ad 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -10,8 +10,9 @@ import ( func main() { sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ - EnableLoadExtension: true, - ConnectHook: nil, + Extensions: []string{ + "sqlite3_mod_regexp.dll", + }, }) db, err := sql.Open("sqlite3_with_extensions", ":memory:") @@ -20,11 +21,15 @@ func main() { } defer db.Close() - _, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')") + // Force db to make a new connection in pool + // by putting the original in a transaction + tx, err := db.Begin() if err != nil { log.Fatal(err) } + defer tx.Commit() + // New connection works (hopefully!) rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'") if err != nil { log.Fatal(err) diff --git a/sqlite3.go b/sqlite3.go index 692306d0..e7417ec4 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false, nil}) + sql.Register("sqlite3", &SQLiteDriver{}) } // Driver struct. type SQLiteDriver struct { - EnableLoadExtension bool - ConnectHook func(*SQLiteConn) + Extensions []string + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - enableLoadExtension := 0 - if d.EnableLoadExtension { - enableLoadExtension = 1 - } - rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension)) - if rv != C.SQLITE_OK { - return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) - } - conn := &SQLiteConn{db} + if len(d.Extensions) > 0 { + rv = C.sqlite3_enable_load_extension(db, 1) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + + stmt, err := conn.Prepare("SELECT load_extension(?);") + if err != nil { + return nil, err + } + + for _, extension := range d.Extensions { + if _, err = stmt.Exec([]driver.Value{extension}); err != nil { + return nil, err + } + } + + if err = stmt.Close(); err != nil { + return nil, err + } + + rv = C.sqlite3_enable_load_extension(db, 0) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + } + if d.ConnectHook != nil { - d.ConnectHook(conn) + if err := d.ConnectHook(conn); err != nil { + return nil, err + } } return conn, nil