Skip to content

Commit

Permalink
use a prepared statement explicitly to make mysql happy when there ar…
Browse files Browse the repository at this point in the history
…e no query parameters
  • Loading branch information
jonbodner committed Sep 3, 2020
1 parent 136fd2a commit d8ed813
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
59 changes: 59 additions & 0 deletions proteus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,65 @@ func TestNilScanner(t *testing.T) {
})
}

func TestNoParams(t *testing.T) {
type ScannerProduct struct {
Id int `prof:"id"`
Name string `prof:"name"`
}

type ScannerProductDao struct {
Insert func(e Executor) (int64, error) `proq:"insert into product(name) values('hi')"`
FindAll func(e Querier) (ScannerProduct, error) `proq:"select * from product"`
FindAllContext func(ctx context.Context, e ContextQuerier) (ScannerProduct, error) `proq:"select * from product"`
}

doTest := func(t *testing.T, setup setup, create string) {
productDao := ScannerProductDao{}
c := logger.WithLevel(context.Background(), logger.DEBUG)
db, err := setup(c, &productDao)
if err != nil {
t.Fatal(err)
}
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Commit()

_, err = tx.Exec(create)
if err != nil {
t.Fatal(err)
}

_, err = productDao.Insert(tx)
if err != nil {
t.Fatal(err)
}
roundTrip, err := productDao.FindAll(tx)
if err != nil {
t.Fatal(err)
}
if roundTrip.Id != 1 || roundTrip.Name != "hi" {
t.Errorf("Expected {1 hi}, got %v", roundTrip)
}
roundTrip2, err := productDao.FindAllContext(context.Background(), tx)
if err != nil {
t.Fatal(err)
}
if roundTrip2.Id != 1 || roundTrip2.Name != "hi" {
t.Errorf("Expected {1 hi}, got %v", roundTrip2)
}
}
t.Run("postgres", func(t *testing.T) {
doTest(t, setupPostgres, " drop table if exists product; CREATE TABLE product(id SERIAL PRIMARY KEY, name VARCHAR(100), null_field VARCHAR(100))")
})
t.Run("mysql", func(t *testing.T) {
doTest(t, setupMySQL, " drop table if exists product; CREATE TABLE product(id int AUTO_INCREMENT, name VARCHAR(100), null_field VARCHAR(100), PRIMARY KEY(id))")
})
}

type setup func(c context.Context, dao interface{}) (*sql.DB, error)

func setupPostgres(c context.Context, dao interface{}) (*sql.DB, error) {
Expand Down
40 changes: 38 additions & 2 deletions runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,25 @@ func makeContextQuerierImplementation(c context.Context, funcType reflect.Type,
}

logger.Log(ctx, logger.DEBUG, fmt.Sprintln("calling", finalQuery, "with params", queryArgs))
rows, err = querier.QueryContext(ctx, finalQuery, queryArgs...)
// going to work around the defective Go MySQL driver, which refuses to convert the text protocol properly.
// It is used when doing a query without parameters.
type ContextPreparer interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
if cp, ok := querier.(ContextPreparer); ok {
var stmt *sql.Stmt
stmt, err = cp.PrepareContext(ctx, finalQuery)
if err != nil {
return buildRetVals(rows, err)
}
defer stmt.Close()
rows, err = stmt.QueryContext(ctx, queryArgs...)
if err != nil {
return buildRetVals(rows, err)
}
} else {
rows, err = querier.QueryContext(ctx, finalQuery, queryArgs...)
}

return buildRetVals(rows, err)
}, nil
Expand Down Expand Up @@ -220,7 +238,25 @@ func makeQuerierImplementation(c context.Context, funcType reflect.Type, query q
}

logger.Log(c, logger.DEBUG, fmt.Sprintln("calling", finalQuery, "with params", queryArgs))
rows, err = querier.Query(finalQuery, queryArgs...)
// going to work around the defective Go MySQL driver, which refuses to convert the text protocol properly.
// It is used when doing a query without parameters.
type Preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
if cp, ok := querier.(Preparer); ok {
var stmt *sql.Stmt
stmt, err = cp.Prepare(finalQuery)
if err != nil {
return buildRetVals(rows, err)
}
defer stmt.Close()
rows, err = stmt.Query(queryArgs...)
if err != nil {
return buildRetVals(rows, err)
}
} else {
rows, err = querier.Query(finalQuery, queryArgs...)
}

return buildRetVals(rows, err)
}, nil
Expand Down

0 comments on commit d8ed813

Please sign in to comment.