Skip to content

Commit

Permalink
Merge pull request #73 from cookieo9/extlist
Browse files Browse the repository at this point in the history
Change extension loading mechanism to use a string list of extensions
  • Loading branch information
mattn committed Aug 25, 2013
2 parents 248e51c + 0dd7156 commit a3e3a8e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
11 changes: 8 additions & 3 deletions example/extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
func main() {
sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{
EnableLoadExtension: true,
ConnectHook: nil,
Extensions: []string{
"sqlite3_mod_regexp.dll",
},
})

db, err := sql.Open("sqlite3_with_extensions", ":memory:")
Expand All @@ -20,11 +21,15 @@ func main() {
}
defer db.Close()

_, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')")
// Force db to make a new connection in pool
// by putting the original in a transaction
tx, err := db.Begin()
if err != nil {
log.Fatal(err)
}
defer tx.Commit()

// New connection works (hopefully!)
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
if err != nil {
log.Fatal(err)
Expand Down
46 changes: 33 additions & 13 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}

func init() {
sql.Register("sqlite3", &SQLiteDriver{false, nil})
sql.Register("sqlite3", &SQLiteDriver{})
}

// Driver struct.
type SQLiteDriver struct {
EnableLoadExtension bool
ConnectHook func(*SQLiteConn)
Extensions []string
ConnectHook func(*SQLiteConn) error
}

// Conn struct.
Expand Down Expand Up @@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}

enableLoadExtension := 0
if d.EnableLoadExtension {
enableLoadExtension = 1
}
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}

conn := &SQLiteConn{db}

if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}

stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}

for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
}
}

if err = stmt.Close(); err != nil {
return nil, err
}

rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}

if d.ConnectHook != nil {
d.ConnectHook(conn)
if err := d.ConnectHook(conn); err != nil {
return nil, err
}
}

return conn, nil
Expand Down

0 comments on commit a3e3a8e

Please sign in to comment.