diff --git a/example/extension/extension.go b/example/extension/extension.go index d4b8fdb7..49eacf17 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -8,10 +8,30 @@ import ( ) func main() { + const ( + use_hook = true + load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')" + ) + sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ EnableLoadExtension: true, - ConnectHook: nil, + ConnectHook: func(c *sqlite3.SQLiteConn) error { + if use_hook { + stmt, err := c.Prepare(load_query) + if err != nil { + return err + } + + _, err = stmt.Exec(nil) + if err != nil { + return err + } + + return stmt.Close() + } + return nil + }, }) db, err := sql.Open("sqlite3_with_extensions", ":memory:") @@ -20,11 +40,21 @@ func main() { } defer db.Close() - _, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')") + if !use_hook { + if _, err = db.Exec(load_query); err != nil { + log.Fatal(err) + } + } + + // Force db to make a new connection in pool + // by putting the original in a transaction + tx, err := db.Begin() if err != nil { log.Fatal(err) } + defer tx.Commit() + // New connection works (hopefully!) rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'") if err != nil { log.Fatal(err) diff --git a/sqlite3.go b/sqlite3.go index 692306d0..cc42c13d 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -78,7 +78,7 @@ func init() { // Driver struct. type SQLiteDriver struct { EnableLoadExtension bool - ConnectHook func(*SQLiteConn) + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -194,7 +194,9 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { conn := &SQLiteConn{db} if d.ConnectHook != nil { - d.ConnectHook(conn) + if err := d.ConnectHook(conn); err != nil { + return nil, err + } } return conn, nil