Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ type SQLiteDriver struct {

// Conn struct.
type SQLiteConn struct {
db *C.sqlite3
loc *time.Location
db *C.sqlite3
loc *time.Location
txlock string
}

// Tx struct.
Expand Down Expand Up @@ -252,7 +253,7 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {

// Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec("BEGIN"); err != nil {
if _, err := c.exec(c.txlock); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
Expand All @@ -273,12 +274,16 @@ func errorString(err Error) string {
// Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout.
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}

var loc *time.Location
txlock := "BEGIN"
busy_timeout := 5000
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
Expand Down Expand Up @@ -308,6 +313,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
busy_timeout = int(iv)
}

// _txlock
if val := params.Get("_txlock"); val != "" {
switch val {
case "immediate":
txlock = "BEGIN IMMEDIATE"
case "exclusive":
txlock = "BEGIN EXCLUSIVE"
case "deferred":
txlock = "BEGIN"
default:
return nil, fmt.Errorf("Invalid _txlock: %v", val)
}
}

if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
Expand All @@ -333,7 +352,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, Error{Code: ErrNo(rv)}
}

conn := &SQLiteConn{db: db, loc: loc}
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}

if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
Expand Down
44 changes: 39 additions & 5 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"net/url"
"os"
"path/filepath"
Expand All @@ -25,23 +26,56 @@ func TempFilename() string {
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db")
}

func TestOpen(t *testing.T) {
func doTestOpen(t *testing.T, option string) (string, error) {
var url string
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if option != "" {
url = tempFilename + option
} else {
url = tempFilename
}
db, err := sql.Open("sqlite3", url)
if err != nil {
t.Fatal("Failed to open database:", err)
return "Failed to open database:", err
}
defer os.Remove(tempFilename)
defer db.Close()

_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
return "Failed to create table:", err
}

if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
t.Error("Failed to create ./foo.db")
return "Failed to create ./foo.db", nil
}

return "", nil
}

func TestOpen(t *testing.T) {
cases := map[string]bool{
"": true,
"?_txlock=immediate": true,
"?_txlock=deferred": true,
"?_txlock=exclusive": true,
"?_txlock=bogus": false,
}
for option, expectedPass := range cases {
result, err := doTestOpen(t, option)
if result == "" {
if ! expectedPass {
errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option)
t.Fatal(errmsg)
}
} else if expectedPass {
if err == nil {
t.Fatal(result)
} else {
t.Fatal(result, err)
}
}
}
}

Expand Down