Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 89 additions & 37 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ import (
"strings"
"time"
"unsafe"

"golang.org/x/net/context"
)

// Timestamp formats understood by both this module and SQLite.
Expand Down Expand Up @@ -170,8 +172,6 @@ type SQLiteTx struct {
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
nv int
nn []string
t string
closed bool
cls bool
Expand Down Expand Up @@ -295,19 +295,19 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {

// Commit transaction.
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT")
_, err := tx.c.execQuery("COMMIT")
if err != nil && err.(Error).Code == C.SQLITE_BUSY {
// sqlite3 will leave the transaction open in this scenario.
// However, database/sql considers the transaction complete once we
// return from Commit() - we must clean up to honour its semantics.
tx.c.exec("ROLLBACK")
tx.c.execQuery("ROLLBACK")
}
return err
}

// Rollback transaction.
func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.exec("ROLLBACK")
_, err := tx.c.execQuery("ROLLBACK")
return err
}

Expand Down Expand Up @@ -404,9 +404,21 @@ func (c *SQLiteConn) lastError() Error {
// Implements Execer
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 {
return c.exec(query)
return c.execQuery(query)
}

list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return c.exec(context.Background(), query, list)
}

func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) (driver.Result, error) {
start := 0
for {
s, err := c.Prepare(query)
if err != nil {
Expand All @@ -418,12 +430,16 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
res, err = s.Exec(args[:na])
for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
start += na
}
tail := s.(*SQLiteStmt).t
s.Close()
Expand All @@ -434,8 +450,26 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
}
}

type namedValue struct {
Name string
Ordinal int
Value driver.Value
}

// Implements Queryer
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return c.query(context.Background(), query, list)
}

func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
start := 0
for {
s, err := c.Prepare(query)
if err != nil {
Expand All @@ -446,12 +480,16 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
rows, err := s.Query(args[:na])
for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
start += na
tail := s.(*SQLiteStmt).t
if tail == "" {
return rows, nil
Expand All @@ -462,7 +500,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
}
}

func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
func (c *SQLiteConn) execQuery(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))

Expand All @@ -476,7 +514,11 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {

// Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec(c.txlock); err != nil {
return c.begin(context.Background())
}

func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
if _, err := c.execQuery(c.txlock); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
Expand Down Expand Up @@ -606,6 +648,10 @@ func (c *SQLiteConn) Close() error {

// Prepare the query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return c.prepare(context.Background(), query)
}

func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) {
pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt
Expand All @@ -618,15 +664,7 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail))
}
nv := int(C.sqlite3_bind_parameter_count(s))
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
ss := &SQLiteStmt{c: c, s: s, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil
}
Expand All @@ -650,39 +688,31 @@ func (s *SQLiteStmt) Close() error {

// Return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return s.nv
return int(C.sqlite3_bind_parameter_count(s.s))
}

type bindArg struct {
n int
v driver.Value
}

func (s *SQLiteStmt) bind(args []driver.Value) error {
func (s *SQLiteStmt) bind(args []namedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
}

var vargs []bindArg
narg := len(args)
vargs = make([]bindArg, narg)
if len(s.nn) > 0 {
for i, v := range s.nn {
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
for i, v := range args {
if v.Name != "" {
cname := C.CString(v.Name)
args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
C.free(unsafe.Pointer(cname))
}
}

for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) {
for _, arg := range args {
n := C.int(arg.Ordinal)
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
Expand Down Expand Up @@ -722,6 +752,17 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {

// Query the statement with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.query(context.Background(), list)
}

func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
Expand All @@ -740,6 +781,17 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {

// Execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.exec(context.Background(), list)
}

func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
Expand Down
58 changes: 58 additions & 0 deletions sqlite3_go18.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// +build go1.8

package sqlite3

import (
"database/sql/driver"
"errors"

"golang.org/x/net/context"
)

// Ping implement Pinger.
func (c *SQLiteConn) Ping(ctx context.Context) error {
if c.db == nil {
return errors.New("Connection was closed")
}
return nil
}

func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.query(ctx, query, list)
}

func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.exec(ctx, query, list)
}

func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepare(ctx, query)
}

func (c *SQLiteConn) BeginContext(ctx context.Context) (driver.Tx, error) {
return c.begin(ctx)
}

func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.query(ctx, list)
}

func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.exec(ctx, list)
}
49 changes: 49 additions & 0 deletions sqlite3_go18_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build go1.8

package sqlite3

import (
"database/sql"
"os"
"testing"
)

func TestNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()

_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}

_, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Param(":name", "foo"), sql.Param(":id", 1))
if err != nil {
t.Error("Failed to call db.Exec:", err)
}

row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Param(":id", 1), sql.Param(":extra", "foo"))
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}
Loading