diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..a610df71 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,27 @@ +version: 2 +jobs: + build: + docker: + - image: circleci/golang:1.10 + environment: + - DBR_TEST_MYSQL_DSN=root:root@tcp(127.0.0.1:3306)/circle_test + - DBR_TEST_POSTGRES_DSN=postgres://postgres:mysecretpassword@127.0.0.1:5432/postgres?sslmode=disable + - image: percona:5.7.20 + environment: + - MYSQL_DATABASE=circle_test + - MYSQL_ROOT_PASSWORD=root + - image: postgres:9.6-alpine + + working_directory: /go/src/github.com/gocraft/dbr + steps: + - checkout + - run: + name: Install dockerize + command: wget https://github.com/jwilder/dockerize/releases/download/$DOCKERIZE_VERSION/dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && rm dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz + environment: + DOCKERIZE_VERSION: v0.3.0 + - run: + name: Wait for db + command: dockerize -wait tcp://127.0.0.1:3306 -wait tcp://127.0.0.1:5432 -timeout 1m + - run: go get -t ./... + - run: go test -v -cover -bench . ./... diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index b196ccee..00000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,20 +0,0 @@ -# Change log - -Not all minor changes may be noted here, but all large and/or breaking changes -should be. - -## v2.0 - 2015-10-09 - -### Added -- PostgreSQL support! -- `Open(driver, dsn string, log EventReceiver)` creates an underlying connection for you based on a supplied driver and dsn string -- All builders are now available without a `Session` facilitating much more complex queries -- More common SQL support: Subqueries, Unions, Joins, Aliases -- More complex condition building support: And/Or/Eq/Neq/Gt/Gte/Lt/Lte - -### Deprecated -- `NewConnection` is deprecated. It assumes MySQL driver. Please use `Open` instead - -### Changed -- `NullTime` no longer relies on the mysql package. E.g. instead of `NullTime{mysql.NullTime{Time: t, Valid: true}}` it's now simply `NullTime{Time: t, Valid: true}` -- All `*Builder` structs now embed a corresponding `*Stmt` struct (E.g. `SelectBuilder` embeds `SelectStmt`). All non-`Session` specific properies have been moved the `*Stmt` structs diff --git a/README.md b/README.md index 5b4e8025..bfa441a6 100644 --- a/README.md +++ b/README.md @@ -1,293 +1,160 @@ -# gocraft/dbr (database records) [![GoDoc](https://godoc.org/github.com/gocraft/web?status.png)](https://godoc.org/github.com/gocraft/dbr) +# gocraft/dbr (database records) -gocraft/dbr provides additions to Go's database/sql for super fast performance and convenience. - -## Getting Started - -```go -// create a connection (e.g. "postgres", "mysql", or "sqlite3") -conn, _ := dbr.Open("postgres", "...") +[![GoDoc](https://godoc.org/github.com/gocraft/dbr?status.png)](https://godoc.org/github.com/gocraft/dbr) +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fgocraft%2Fdbr.svg?type=shield)](https://app.fossa.io/projects/git%2Bgithub.com%2Fgocraft%2Fdbr?ref=badge_shield) +[![Go Report Card](https://goreportcard.com/badge/github.com/gocraft/dbr)](https://goreportcard.com/report/github.com/gocraft/dbr) +[![CircleCI](https://circleci.com/gh/gocraft/dbr.svg?style=svg)](https://circleci.com/gh/gocraft/dbr) -// create a session for each business unit of execution (e.g. a web request or goworkers job) -sess := conn.NewSession(nil) - -// get a record -var suggestion Suggestion -sess.Select("id", "title").From("suggestions").Where("id = ?", 1).Load(&suggestion) +gocraft/dbr provides additions to Go's database/sql for super fast performance and convenience. -// JSON-ready, with dbr.Null* types serialized like you want -json.Marshal(&suggestion) ``` - -## Feature highlights - -### Use a Sweet Query Builder or use Plain SQL - -gocraft/dbr supports both. - -Sweet Query Builder: -```go -stmt := dbr.Select("title", "body"). - From("suggestions"). - OrderBy("id"). - Limit(10) +$ go get -u github.com/gocraft/dbr ``` -Plain SQL: - ```go -builder := dbr.SelectBySql("SELECT `title`, `body` FROM `suggestions` ORDER BY `id` ASC LIMIT 10") +import "github.com/gocraft/dbr" ``` -### Amazing instrumentation with session - -All queries in gocraft/dbr are made in the context of a session. This is because when instrumenting your app, it's important to understand which business action the query took place in. See gocraft/health for more detail. - -Writing instrumented code is a first-class concern for gocraft/dbr. We instrument each query to emit to a gocraft/health-compatible EventReceiver interface. - -### Faster performance than using database/sql directly -Every time you call database/sql's db.Query("SELECT ...") method, under the hood, the mysql driver will create a prepared statement, execute it, and then throw it away. This has a big performance cost. - -gocraft/dbr doesn't use prepared statements. We ported mysql's query escape functionality directly into our package, which means we interpolate all of those question marks with their arguments before they get to MySQL. The result of this is that it's way faster, and just as secure. - -Check out these [benchmarks](https://github.com/tyler-smith/golang-sql-benchmark). - -### IN queries that aren't horrible -Traditionally, database/sql uses prepared statements, which means each argument in an IN clause needs its own question mark. gocraft/dbr, on the other hand, handles interpolation itself so that you can easily use a single question mark paired with a dynamically sized slice. - -```go -ids := []int64{1, 2, 3, 4, 5} -builder.Where("id IN ?", ids) // `id` IN ? -``` +## Driver support -### JSON Friendly -Every try to JSON-encode a sql.NullString? You get: -```json -{ - "str1": { - "Valid": true, - "String": "Hi!" - }, - "str2": { - "Valid": false, - "String": "" - } -} -``` +* MySQL +* PostgreSQL +* SQLite3 -Not quite what you want. gocraft/dbr has dbr.NullString (and the rest of the Null* types) that encode correctly, giving you: +## Examples -```json -{ - "str1": "Hi!", - "str2": null -} -``` +See [godoc](https://godoc.org/github.com/gocraft/dbr) for more examples. -### Inserting multiple records +### Open connections ```go -sess.InsertInto("suggestions").Columns("title", "body"). - Record(suggestion1). - Record(suggestion2) -``` +// create a connection (e.g. "postgres", "mysql", or "sqlite3") +conn, _ := Open("postgres", "...", nil) +conn.SetMaxOpenConns(10) -### Updating records +// create a session for each business unit of execution (e.g. a web request or goworkers job) +sess := conn.NewSession(nil) -```go -sess.Update("suggestions"). - Set("title", "Gopher"). - Set("body", "I love go."). - Where("id = ?", 1) +// create a tx from sessions +sess.Begin() ``` -### Transactions +### Create and use Tx ```go +sess := mysqlSession tx, err := sess.Begin() if err != nil { - return err + return } defer tx.RollbackUnlessCommitted() // do stuff... -return tx.Commit() +tx.Commit() ``` -### Load database values to variables - -Querying is the heart of gocraft/dbr. - -* Load(&any): load everything! -* LoadStruct(&oneStruct): load struct -* LoadStructs(&manyStructs): load a slice of structs -* LoadValue(&oneValue): load basic type -* LoadValues(&manyValues): load a slice of basic types +### SelectStmt loads data into structs ```go // columns are mapped by tag then by field type Suggestion struct { - ID int64 // id, will be autoloaded by last insert id - Title string // title - Url string `db:"-"` // ignored - secret string // ignored - Body dbr.NullString `db:"content"` // content - User User -} - -// By default dbr converts CamelCase property names to snake_case column_names -// You can override this with struct tags, just like with JSON tags -// This is especially helpful while migrating from legacy systems -type Suggestion struct { - Id int64 - Title dbr.NullString `db:"subject"` // subjects are called titles now - CreatedAt dbr.NullTime + ID int64 // id, will be autoloaded by last insert id + Title NullString `db:"subject"` // subjects are called titles now + Url string `db:"-"` // ignored + secret string // ignored } +// By default gocraft/dbr converts CamelCase property names to snake_case column_names. +// You can override this with struct tags, just like with JSON tags. +// This is especially helpful while migrating from legacy systems. var suggestions []Suggestion +sess := mysqlSession sess.Select("*").From("suggestions").Load(&suggestions) ``` -### Join multiple tables - -dbr supports many join types: +### SelectStmt with where-value interpolation ```go -sess.Select("*").From("suggestions"). - Join("subdomains", "suggestions.subdomain_id = subdomains.id") - -sess.Select("*").From("suggestions"). - LeftJoin("subdomains", "suggestions.subdomain_id = subdomains.id") +// database/sql uses prepared statements, which means each argument +// in an IN clause needs its own question mark. +// gocraft/dbr, on the other hand, handles interpolation itself +// so that you can easily use a single question mark paired with a +// dynamically sized slice. -sess.Select("*").From("suggestions"). - RightJoin("subdomains", "suggestions.subdomain_id = subdomains.id") - -sess.Select("*").From("suggestions"). - FullJoin("subdomains", "suggestions.subdomain_id = subdomains.id") +sess := mysqlSession +ids := []int64{1, 2, 3, 4, 5} +sess.Select("*").From("suggestions").Where("id IN ?", ids) ``` -You can join on multiple tables: +### SelectStmt with joins ```go +sess := mysqlSession sess.Select("*").From("suggestions"). - Join("subdomains", "suggestions.subdomain_id = subdomains.id"). - Join("accounts", "subdomains.accounts_id = accounts.id") -``` + Join("subdomains", "suggestions.subdomain_id = subdomains.id") -### Quoting/escaping identifiers (e.g. table and column names) +sess.Select("*").From("suggestions"). + LeftJoin("subdomains", "suggestions.subdomain_id = subdomains.id") -```go -dbr.I("suggestions.id") // `suggestions`.`id` +// join multiple tables +sess.Select("*").From("suggestions"). + Join("subdomains", "suggestions.subdomain_id = subdomains.id"). + Join("accounts", "subdomains.accounts_id = accounts.id") ``` -### Subquery +### SelectStmt with raw SQL ```go -sess.Select("count(id)").From( - dbr.Select("*").From("suggestions").As("count"), -) +SelectBySql("SELECT `title`, `body` FROM `suggestions` ORDER BY `id` ASC LIMIT 10") ``` -### Union +### InsertStmt adds data from struct ```go -dbr.Union( - dbr.Select("*"), - dbr.Select("*"), -) - -dbr.UnionAll( - dbr.Select("*"), - dbr.Select("*"), -) -``` - -Union can be used in subquery. - -### Alias/AS - -* SelectStmt +type Suggestion struct { + ID int64 + Title NullString + CreatedAt time.Time +} +sugg := &Suggestion{ + Title: NewNullString("Gopher"), + CreatedAt: time.Now(), +} +sess := mysqlSession +sess.InsertInto("suggestions"). + Columns("id", "title"). + Record(&sugg). + Exec() -```go -dbr.Select("*").From("suggestions").As("count") +// id is set automatically +fmt.Println(sugg.ID) ``` -* Identity +### InsertStmt adds data from value ```go -dbr.I("suggestions").As("s") +sess := mysqlSession +sess.InsertInto("suggestions"). + Pair("title", "Gopher"). + Pair("body", "I love go.") ``` -* Union +## Benchmark (2018-05-11) -```go -dbr.Union( - dbr.Select("*"), - dbr.Select("*"), -).As("u1") - -dbr.UnionAll( - dbr.Select("*"), - dbr.Select("*"), -).As("u2") ``` - -### Building arbitrary condition - -One common reason to use this is to prevent string concatenation in a loop. - -* And -* Or -* Eq -* Neq -* Gt -* Gte -* Lt -* Lte -* Like -* NotLike - -```go -dbr.And( - dbr.Or( - dbr.Gt("created_at", "2015-09-10"), - dbr.Lte("created_at", "2015-09-11"), - ), - dbr.Eq("title", "hello world"), -) +BenchmarkLoadValues/sqlx_10-8 5000 407318 ns/op 3913 B/op 164 allocs/op +BenchmarkLoadValues/dbr_10-8 5000 372940 ns/op 3874 B/op 123 allocs/op +BenchmarkLoadValues/sqlx_100-8 2000 584197 ns/op 30195 B/op 1428 allocs/op +BenchmarkLoadValues/dbr_100-8 3000 558852 ns/op 22965 B/op 937 allocs/op +BenchmarkLoadValues/sqlx_1000-8 1000 2319101 ns/op 289339 B/op 14031 allocs/op +BenchmarkLoadValues/dbr_1000-8 1000 2310441 ns/op 210092 B/op 9040 allocs/op +BenchmarkLoadValues/sqlx_10000-8 100 17004716 ns/op 3193997 B/op 140043 allocs/op +BenchmarkLoadValues/dbr_10000-8 100 16150062 ns/op 2394698 B/op 90051 allocs/op +BenchmarkLoadValues/sqlx_100000-8 10 170068209 ns/op 31679944 B/op 1400053 allocs/op +BenchmarkLoadValues/dbr_100000-8 10 147202536 ns/op 23680625 B/op 900061 allocs/op ``` -### Built with extensibility - -The core of dbr is interpolation, which can expand `?` with arbitrary SQL. If you need a feature that is not currently supported, -you can build it on your own (or use `dbr.Expr`). - -To do that, the value that you wish to be expaned with `?` needs to implement `dbr.Builder`. - -```go -type Builder interface { - Build(Dialect, Buffer) error -} -``` - -## Driver support - -* MySQL -* PostgreSQL -* SQLite3 - -## gocraft - -gocraft offers a toolkit for building web apps. Currently these packages are available: - -* [gocraft/web](https://github.com/gocraft/web) - Go Router + Middleware. Your Contexts. -* [gocraft/dbr](https://github.com/gocraft/dbr) - Additions to Go's database/sql for super fast performance and convenience. -* [gocraft/health](https://github.com/gocraft/health) - Instrument your web apps with logging and metrics. -* [gocraft/work](https://github.com/gocraft/work) - Process background jobs in Go. - -These packages were developed by the [engineering team](https://eng.uservoice.com) at [UserVoice](https://www.uservoice.com) and currently power much of its infrastructure and tech stack. - ## Thanks & Authors Inspiration from these excellent libraries: * [sqlx](https://github.com/jmoiron/sqlx) - various useful tools and utils for interacting with database/sql. @@ -300,3 +167,6 @@ Authors: Contributors: * Paul Bergeron -- [https://github.com/dinedal](https://github.com/dinedal) - SQLite dialect + +## License +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fgocraft%2Fdbr.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fgocraft%2Fdbr?ref=badge_large) diff --git a/README.md.tpl b/README.md.tpl new file mode 100644 index 00000000..1eabc170 --- /dev/null +++ b/README.md.tpl @@ -0,0 +1,90 @@ +# gocraft/dbr (database records) + +[![GoDoc](https://godoc.org/github.com/gocraft/dbr?status.png)](https://godoc.org/github.com/gocraft/dbr) +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fgocraft%2Fdbr.svg?type=shield)](https://app.fossa.io/projects/git%2Bgithub.com%2Fgocraft%2Fdbr?ref=badge_shield) +[![Go Report Card](https://goreportcard.com/badge/github.com/gocraft/dbr)](https://goreportcard.com/report/github.com/gocraft/dbr) +[![CircleCI](https://circleci.com/gh/gocraft/dbr.svg?style=svg)](https://circleci.com/gh/gocraft/dbr) + +gocraft/dbr provides additions to Go's database/sql for super fast performance and convenience. + +``` +$ go get -u github.com/gocraft/dbr +``` + +```go +import "github.com/gocraft/dbr" +``` + +## Driver support + +* MySQL +* PostgreSQL +* SQLite3 + +## Examples + +See [godoc](https://godoc.org/github.com/gocraft/dbr) for more examples. + +### Open connections + +{{ "ExampleOpen" | example }} + +### Create and use Tx + +{{ "ExampleTx" | example }} + +### SelectStmt loads data into structs + +{{ "ExampleSelectStmt_Load" | example }} + +### SelectStmt with where-value interpolation + +{{ "ExampleSelectStmt_Where" | example }} + +### SelectStmt with joins + +{{ "ExampleSelectStmt_Join" | example }} + +### SelectStmt with raw SQL + +{{ "ExampleSelectBySql" | example }} + +### InsertStmt adds data from struct + +{{ "ExampleInsertStmt_Record" | example }} + +### InsertStmt adds data from value + +{{ "ExampleInsertStmt_Pair" | example }} + + +## Benchmark (2018-05-11) + +``` +BenchmarkLoadValues/sqlx_10-8 5000 407318 ns/op 3913 B/op 164 allocs/op +BenchmarkLoadValues/dbr_10-8 5000 372940 ns/op 3874 B/op 123 allocs/op +BenchmarkLoadValues/sqlx_100-8 2000 584197 ns/op 30195 B/op 1428 allocs/op +BenchmarkLoadValues/dbr_100-8 3000 558852 ns/op 22965 B/op 937 allocs/op +BenchmarkLoadValues/sqlx_1000-8 1000 2319101 ns/op 289339 B/op 14031 allocs/op +BenchmarkLoadValues/dbr_1000-8 1000 2310441 ns/op 210092 B/op 9040 allocs/op +BenchmarkLoadValues/sqlx_10000-8 100 17004716 ns/op 3193997 B/op 140043 allocs/op +BenchmarkLoadValues/dbr_10000-8 100 16150062 ns/op 2394698 B/op 90051 allocs/op +BenchmarkLoadValues/sqlx_100000-8 10 170068209 ns/op 31679944 B/op 1400053 allocs/op +BenchmarkLoadValues/dbr_100000-8 10 147202536 ns/op 23680625 B/op 900061 allocs/op +``` + +## Thanks & Authors +Inspiration from these excellent libraries: +* [sqlx](https://github.com/jmoiron/sqlx) - various useful tools and utils for interacting with database/sql. +* [Squirrel](https://github.com/lann/squirrel) - simple fluent query builder. + +Authors: +* Jonathan Novak -- [https://github.com/cypriss](https://github.com/cypriss) +* Tai-Lin Chu -- [https://github.com/taylorchu](https://github.com/taylorchu) +* Sponsored by [UserVoice](https://eng.uservoice.com) + +Contributors: +* Paul Bergeron -- [https://github.com/dinedal](https://github.com/dinedal) - SQLite dialect + +## License +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fgocraft%2Fdbr.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fgocraft%2Fdbr?ref=badge_large) diff --git a/buffer.go b/buffer.go index 72b885de..d3194221 100644 --- a/buffer.go +++ b/buffer.go @@ -1,9 +1,11 @@ package dbr -import "bytes" +import "strings" +// Buffer collects strings, and values that are ready to be interpolated. +// This is used internally to efficiently build SQL statement. type Buffer interface { - WriteString(s string) (n int, err error) + WriteString(string) (int, error) String() string WriteValue(v ...interface{}) (err error) @@ -11,10 +13,11 @@ type Buffer interface { } type buffer struct { - bytes.Buffer + strings.Builder v []interface{} } +// NewBuffer creates a new Buffer. func NewBuffer() Buffer { return &buffer{} } diff --git a/builder.go b/builder.go index 625ceb06..70be4103 100644 --- a/builder.go +++ b/builder.go @@ -1,13 +1,22 @@ package dbr -// Builder builds sql in one dialect like MySQL/PostgreSQL -// e.g. XxxBuilder +// Builder builds SQL in Dialect like MySQL, and PostgreSQL. +// The raw SQL and values are stored in Buffer. +// +// The core of gocraft/dbr is interpolation, which can expand ? with arbitrary SQL. +// If you need a feature that is not currently supported, you can build it +// on your own (or use Expr). +// +// To do that, the value that you wish to be expanded with ? needs to +// implement Builder. type Builder interface { Build(Dialect, Buffer) error } +// BuildFunc implements Builder. type BuildFunc func(Dialect, Buffer) error +// Build calls itself to build SQL. func (b BuildFunc) Build(d Dialect, buf Buffer) error { return b(d, buf) } diff --git a/circle.yml b/circle.yml deleted file mode 100644 index 7fb688a8..00000000 --- a/circle.yml +++ /dev/null @@ -1,5 +0,0 @@ -## Customize the test machine -machine: - environment: - DBR_TEST_MYSQL_DSN: "ubuntu:@unix(/var/run/mysqld/mysqld.sock)/circle_test?charset=utf8" - DBR_TEST_POSTGRES_DSN: "postgres://ubuntu:@127.0.0.1:5432/circle_test" diff --git a/condition.go b/condition.go index 61537af4..7d37f12e 100644 --- a/condition.go +++ b/condition.go @@ -21,14 +21,14 @@ func buildCond(d Dialect, buf Buffer, pred string, cond ...Builder) error { return nil } -// And creates AND from a list of conditions +// And creates AND from a list of conditions. func And(cond ...Builder) Builder { return BuildFunc(func(d Dialect, buf Buffer) error { return buildCond(d, buf, "AND", cond...) }) } -// Or creates OR from a list of conditions +// Or creates OR from a list of conditions. func Or(cond ...Builder) Builder { return BuildFunc(func(d Dialect, buf Buffer) error { return buildCond(d, buf, "OR", cond...) @@ -122,7 +122,7 @@ func Lte(column string, value interface{}) Builder { func buildLikeCmp(d Dialect, buf Buffer, pred string, column string, value interface{}) error { if value == nil { - return ErrColumnNotSpecified + return ErrInvalidValue } v := reflect.ValueOf(value) @@ -142,10 +142,10 @@ func buildLikeCmp(d Dialect, buf Buffer, pred string, column string, value inter // need to convert into string return buildCmp(d, buf, pred, column, string(value.([]rune))) } - fallthrough - default: - return ErrColumnNotSpecified + } + + return ErrInvalidValue } // Like is `LIKE`. diff --git a/condition_test.go b/condition_test.go index 5f9a4682..5a52b252 100644 --- a/condition_test.go +++ b/condition_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCondition(t *testing.T) { @@ -150,9 +150,9 @@ func TestCondition(t *testing.T) { buf := NewBuffer() err := test.cond.Build(dialect.MySQL, buf) if !test.isErr { - assert.NoError(t, err) + require.NoError(t, err) } - assert.Equal(t, test.query, buf.String()) - assert.Equal(t, test.value, buf.Value()) + require.Equal(t, test.query, buf.String()) + require.Equal(t, test.value, buf.Value()) } } diff --git a/dbr.go b/dbr.go index a6c38aa0..3e16216c 100644 --- a/dbr.go +++ b/dbr.go @@ -1,6 +1,8 @@ +// Package dbr provides additions to Go's database/sql for super fast performance and convenience. package dbr import ( + "context" "database/sql" "fmt" "time" @@ -8,8 +10,8 @@ import ( "github.com/gocraft/dbr/dialect" ) -// Open instantiates a Connection for a given database/sql connection -// and event receiver +// Open creates a Connection. +// log can be nil to ignore logging. func Open(driver, dsn string, log EventReceiver) (*Connection, error) { if log == nil { log = nullReceiver @@ -36,21 +38,36 @@ const ( placeholder = "?" ) -// Connection is a connection to the database with an EventReceiver -// to send events, errors, and timings to +// Connection wraps sql.DB with an EventReceiver +// to send events, errors, and timings. type Connection struct { *sql.DB - Dialect Dialect + Dialect EventReceiver } -// Session represents a business unit of execution for some connection +// Session represents a business unit of execution. +// +// All queries in gocraft/dbr are made in the context of a session. +// This is because when instrumenting your app, it's important +// to understand which business action the query took place in. +// +// A custom EventReceiver can be set. +// +// Timeout specifies max duration for an operation like Select. type Session struct { *Connection EventReceiver + Timeout time.Duration } -// NewSession instantiates a Session for the Connection +// GetTimeout returns current timeout enforced in session. +func (sess *Session) GetTimeout() time.Duration { + return sess.Timeout +} + +// NewSession instantiates a Session from Connection. +// If log is nil, Connection EventReceiver is used. func (conn *Connection) NewSession(log EventReceiver) *Session { if log == nil { log = conn.EventReceiver // Use parent instrumentation @@ -65,6 +82,7 @@ var ( ) // SessionRunner can do anything that a Session can except start a transaction. +// Both Session and Tx implements this interface. type SessionRunner interface { Select(column ...string) *SelectBuilder SelectBySql(query string, value ...interface{}) *SelectBuilder @@ -80,17 +98,25 @@ type SessionRunner interface { } type runner interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Query(query string, args ...interface{}) (*sql.Rows, error) + GetTimeout() time.Duration + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } -func exec(runner runner, log EventReceiver, builder Builder, d Dialect) (sql.Result, error) { +func exec(ctx context.Context, runner runner, log EventReceiver, builder Builder, d Dialect) (sql.Result, error) { + timeout := runner.GetTimeout() + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + i := interpolator{ Buffer: NewBuffer(), Dialect: d, IgnoreBinary: true, } - err := i.interpolate(placeholder, []interface{}{builder}) + err := i.encodePlaceholder(builder, true) query, value := i.String(), i.Value() if err != nil { return nil, log.EventErrKv("dbr.exec.interpolate", err, kvs{ @@ -106,7 +132,7 @@ func exec(runner runner, log EventReceiver, builder Builder, d Dialect) (sql.Res }) }() - result, err := runner.Exec(query, value...) + result, err := runner.ExecContext(ctx, query, value...) if err != nil { return result, log.EventErrKv("dbr.exec.exec", err, kvs{ "sql": query, @@ -115,13 +141,20 @@ func exec(runner runner, log EventReceiver, builder Builder, d Dialect) (sql.Res return result, nil } -func query(runner runner, log EventReceiver, builder Builder, d Dialect, dest interface{}) (int, error) { +func query(ctx context.Context, runner runner, log EventReceiver, builder Builder, d Dialect, dest interface{}) (int, error) { + timeout := runner.GetTimeout() + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + i := interpolator{ Buffer: NewBuffer(), Dialect: d, IgnoreBinary: true, } - err := i.interpolate(placeholder, []interface{}{builder}) + err := i.encodePlaceholder(builder, true) query, value := i.String(), i.Value() if err != nil { return 0, log.EventErrKv("dbr.select.interpolate", err, kvs{ @@ -137,7 +170,7 @@ func query(runner runner, log EventReceiver, builder Builder, d Dialect, dest in }) }() - rows, err := runner.Query(query, value...) + rows, err := runner.QueryContext(ctx, query, value...) if err != nil { return 0, log.EventErrKv("dbr.select.load.query", err, kvs{ "sql": query, diff --git a/dbr_test.go b/dbr_test.go index c7f2f364..78402443 100644 --- a/dbr_test.go +++ b/dbr_test.go @@ -1,17 +1,17 @@ package dbr import ( - "bytes" + "context" "fmt" - "log" "os" "testing" + "time" _ "github.com/go-sql-driver/mysql" "github.com/gocraft/dbr/dialect" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // @@ -19,41 +19,17 @@ import ( // var ( - currID int64 = 256 -) - -// create id -func nextID() int64 { - currID++ - return currID -} - -const ( - mysqlDSN = "root@unix(/tmp/mysql.sock)/uservoice_test?charset=utf8" - postgresDSN = "postgres://postgres@localhost:5432/uservoice_test?sslmode=disable" + mysqlDSN = os.Getenv("DBR_TEST_MYSQL_DSN") + postgresDSN = os.Getenv("DBR_TEST_POSTGRES_DSN") sqlite3DSN = ":memory:" ) func createSession(driver, dsn string) *Session { - var testDSN string - switch driver { - case "mysql": - testDSN = os.Getenv("DBR_TEST_MYSQL_DSN") - case "postgres": - testDSN = os.Getenv("DBR_TEST_POSTGRES_DSN") - case "sqlite3": - testDSN = os.Getenv("DBR_TEST_SQLITE3_DSN") - } - if testDSN != "" { - dsn = testDSN - } conn, err := Open(driver, dsn, nil) if err != nil { - log.Fatal(err) + panic(err) } - sess := conn.NewSession(nil) - reset(sess) - return sess + return conn.NewSession(nil) } var ( @@ -81,7 +57,7 @@ type nullTypedRecord struct { BoolVal NullBool } -func reset(sess *Session) { +func reset(t *testing.T, sess *Session) { var autoIncrementType string switch sess.Dialect { case dialect.MySQL: @@ -110,102 +86,121 @@ func reset(sess *Session) { )`, autoIncrementType), } { _, err := sess.Exec(v) - if err != nil { - log.Fatalf("Failed to execute statement: %s, Got error: %s", v, err) - } - } -} - -func BenchmarkByteaNoBinaryEncode(b *testing.B) { - benchmarkBytea(b, postgresSession) -} - -func BenchmarkByteaBinaryEncode(b *testing.B) { - benchmarkBytea(b, postgresBinarySession) -} - -func benchmarkBytea(b *testing.B, sess *Session) { - data := bytes.Repeat([]byte("0123456789"), 1000) - for _, v := range []string{ - `DROP TABLE IF EXISTS bytea_table`, - `CREATE TABLE bytea_table ( - val bytea - )`, - } { - _, err := sess.Exec(v) - assert.NoError(b, err) - } - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := sess.InsertInto("bytea_table").Pair("val", data).Exec() - assert.NoError(b, err) + require.NoError(t, err) } } func TestBasicCRUD(t *testing.T) { for _, sess := range testSession { + reset(t, sess) + jonathan := dbrPerson{ Name: "jonathan", Email: "jonathan@uservoice.com", } insertColumns := []string{"name", "email"} if sess.Dialect == dialect.PostgreSQL { - jonathan.Id = nextID() + jonathan.Id = 1 insertColumns = []string{"id", "name", "email"} } // insert result, err := sess.InsertInto("dbr_people").Columns(insertColumns...).Record(&jonathan).Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) - assert.True(t, jonathan.Id > 0) + require.True(t, jonathan.Id > 0) // select var people []dbrPerson - count, err := sess.Select("*").From("dbr_people").Where(Eq("id", jonathan.Id)).LoadStructs(&people) - assert.NoError(t, err) - if assert.Equal(t, 1, count) { - assert.Equal(t, jonathan.Id, people[0].Id) - assert.Equal(t, jonathan.Name, people[0].Name) - assert.Equal(t, jonathan.Email, people[0].Email) - } + count, err := sess.Select("*").From("dbr_people").Where(Eq("id", jonathan.Id)).Load(&people) + require.NoError(t, err) + require.Equal(t, 1, count) + require.Equal(t, jonathan.Id, people[0].Id) + require.Equal(t, jonathan.Name, people[0].Name) + require.Equal(t, jonathan.Email, people[0].Email) // select id ids, err := sess.Select("id").From("dbr_people").ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 1, len(ids)) + require.NoError(t, err) + require.Equal(t, 1, len(ids)) // select id limit ids, err = sess.Select("id").From("dbr_people").Limit(1).ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 1, len(ids)) + require.NoError(t, err) + require.Equal(t, 1, len(ids)) // update result, err = sess.Update("dbr_people").Where(Eq("id", jonathan.Id)).Set("name", "jonathan1").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err = result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) var n NullInt64 - sess.Select("count(*)").From("dbr_people").Where("name = ?", "jonathan1").LoadValue(&n) - assert.EqualValues(t, 1, n.Int64) + sess.Select("count(*)").From("dbr_people").Where("name = ?", "jonathan1").LoadOne(&n) + require.Equal(t, int64(1), n.Int64) // delete result, err = sess.DeleteFrom("dbr_people").Where(Eq("id", jonathan.Id)).Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err = result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) // select id ids, err = sess.Select("id").From("dbr_people").ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 0, len(ids)) + require.NoError(t, err) + require.Equal(t, 0, len(ids)) + } +} + +func TestTimeout(t *testing.T) { + mysqlSession := createSession("mysql", mysqlDSN) + postgresSession := createSession("postgres", postgresDSN) + sqlite3Session := createSession("sqlite3", sqlite3DSN) + + // all test sessions should be here + testSession := []*Session{mysqlSession, postgresSession, sqlite3Session} + + for _, sess := range testSession { + reset(t, sess) + + // session op timeout + sess.Timeout = time.Nanosecond + var people []dbrPerson + _, err := sess.Select("*").From("dbr_people").Load(&people) + require.Equal(t, context.DeadlineExceeded, err) + + _, err = sess.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec() + require.Equal(t, context.DeadlineExceeded, err) + + _, err = sess.Update("dbr_people").Set("name", "test1").Exec() + require.Equal(t, context.DeadlineExceeded, err) + + _, err = sess.DeleteFrom("dbr_people").Exec() + require.Equal(t, context.DeadlineExceeded, err) + + // tx op timeout + sess.Timeout = 0 + tx, err := sess.Begin() + require.NoError(t, err) + defer tx.RollbackUnlessCommitted() + tx.Timeout = time.Nanosecond + + _, err = tx.Select("*").From("dbr_people").Load(&people) + require.Equal(t, context.DeadlineExceeded, err) + + _, err = tx.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec() + require.Equal(t, context.DeadlineExceeded, err) + + _, err = tx.Update("dbr_people").Set("name", "test1").Exec() + require.Equal(t, context.DeadlineExceeded, err) + + _, err = tx.DeleteFrom("dbr_people").Exec() + require.Equal(t, context.DeadlineExceeded, err) } } diff --git a/delete.go b/delete.go index 5b62b30f..8e0757ed 100644 --- a/delete.go +++ b/delete.go @@ -1,15 +1,26 @@ package dbr -// DeleteStmt builds `DELETE ...` +import ( + "context" + "database/sql" + "strconv" +) + +// DeleteStmt builds `DELETE ...`. type DeleteStmt struct { - raw + runner + EventReceiver + Dialect - Table string + raw - WhereCond []Builder + Table string + WhereCond []Builder + LimitCount int64 } -// Build builds `DELETE ...` in dialect +type DeleteBuilder = DeleteStmt + func (b *DeleteStmt) Build(d Dialect, buf Buffer) error { if b.raw.Query != "" { return b.raw.Build(d, buf) @@ -29,27 +40,70 @@ func (b *DeleteStmt) Build(d Dialect, buf Buffer) error { return err } } + if b.LimitCount >= 0 { + buf.WriteString(" LIMIT ") + buf.WriteString(strconv.FormatInt(b.LimitCount, 10)) + } return nil } -// DeleteFrom creates a DeleteStmt +// DeleteFrom creates a DeleteStmt. func DeleteFrom(table string) *DeleteStmt { return &DeleteStmt{ - Table: table, + Table: table, + LimitCount: -1, } } -// DeleteBySql creates a DeleteStmt from raw query +// DeleteFrom creates a DeleteStmt. +func (sess *Session) DeleteFrom(table string) *DeleteStmt { + b := DeleteFrom(table) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// DeleteFrom creates a DeleteStmt. +func (tx *Tx) DeleteFrom(table string) *DeleteStmt { + b := DeleteFrom(table) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// DeleteBySql creates a DeleteStmt from raw query. func DeleteBySql(query string, value ...interface{}) *DeleteStmt { return &DeleteStmt{ raw: raw{ Query: query, Value: value, }, + LimitCount: -1, } } -// Where adds a where condition +// DeleteBySql creates a DeleteStmt from raw query. +func (sess *Session) DeleteBySql(query string, value ...interface{}) *DeleteStmt { + b := DeleteBySql(query, value...) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// DeleteBySql creates a DeleteStmt from raw query. +func (tx *Tx) DeleteBySql(query string, value ...interface{}) *DeleteStmt { + b := DeleteBySql(query, value...) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// Where adds a where condition. +// query can be Builder or string. value is used only if query type is string. func (b *DeleteStmt) Where(query interface{}, value ...interface{}) *DeleteStmt { switch query := query.(type) { case string: @@ -59,3 +113,16 @@ func (b *DeleteStmt) Where(query interface{}, value ...interface{}) *DeleteStmt } return b } + +func (b *DeleteStmt) Limit(n uint64) *DeleteStmt { + b.LimitCount = int64(n) + return b +} + +func (b *DeleteStmt) Exec() (sql.Result, error) { + return b.ExecContext(context.Background()) +} + +func (b *DeleteStmt) ExecContext(ctx context.Context) (sql.Result, error) { + return exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) +} diff --git a/delete_builder.go b/delete_builder.go deleted file mode 100644 index 73145451..00000000 --- a/delete_builder.go +++ /dev/null @@ -1,82 +0,0 @@ -package dbr - -import ( - "database/sql" - "fmt" -) - -type DeleteBuilder struct { - runner - EventReceiver - Dialect Dialect - - *DeleteStmt - - LimitCount int64 -} - -func (sess *Session) DeleteFrom(table string) *DeleteBuilder { - return &DeleteBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - DeleteStmt: DeleteFrom(table), - LimitCount: -1, - } -} - -func (tx *Tx) DeleteFrom(table string) *DeleteBuilder { - return &DeleteBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - DeleteStmt: DeleteFrom(table), - LimitCount: -1, - } -} - -func (sess *Session) DeleteBySql(query string, value ...interface{}) *DeleteBuilder { - return &DeleteBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - DeleteStmt: DeleteBySql(query, value...), - LimitCount: -1, - } -} - -func (tx *Tx) DeleteBySql(query string, value ...interface{}) *DeleteBuilder { - return &DeleteBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - DeleteStmt: DeleteBySql(query, value...), - LimitCount: -1, - } -} - -func (b *DeleteBuilder) Exec() (sql.Result, error) { - return exec(b.runner, b.EventReceiver, b, b.Dialect) -} - -func (b *DeleteBuilder) Where(query interface{}, value ...interface{}) *DeleteBuilder { - b.DeleteStmt.Where(query, value...) - return b -} - -func (b *DeleteBuilder) Limit(n uint64) *DeleteBuilder { - b.LimitCount = int64(n) - return b -} - -func (b *DeleteBuilder) Build(d Dialect, buf Buffer) error { - err := b.DeleteStmt.Build(b.Dialect, buf) - if err != nil { - return err - } - if b.LimitCount >= 0 { - buf.WriteString(" LIMIT ") - buf.WriteString(fmt.Sprint(b.LimitCount)) - } - return nil -} diff --git a/delete_test.go b/delete_test.go index e5cff17a..1046dbba 100644 --- a/delete_test.go +++ b/delete_test.go @@ -4,16 +4,16 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDeleteStmt(t *testing.T) { buf := NewBuffer() builder := DeleteFrom("table").Where(Eq("a", 1)) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String()) - assert.Equal(t, []interface{}{1}, buf.Value()) + require.NoError(t, err) + require.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String()) + require.Equal(t, []interface{}{1}, buf.Value()) } func BenchmarkDeleteSQL(b *testing.B) { diff --git a/dialect.go b/dialect.go index 136b7284..b89567c8 100644 --- a/dialect.go +++ b/dialect.go @@ -2,7 +2,8 @@ package dbr import "time" -// Dialect abstracts database differences +// Dialect abstracts database driver differences in encoding +// types, and placeholders. type Dialect interface { QuoteIdent(id string) string diff --git a/dialect/dialect_test.go b/dialect/dialect_test.go index 85ea1086..418d3d08 100644 --- a/dialect/dialect_test.go +++ b/dialect/dialect_test.go @@ -3,7 +3,7 @@ package dialect import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMySQL(t *testing.T) { @@ -20,7 +20,7 @@ func TestMySQL(t *testing.T) { want: "`col`", }, } { - assert.Equal(t, test.want, MySQL.QuoteIdent(test.in)) + require.Equal(t, test.want, MySQL.QuoteIdent(test.in)) } } @@ -38,7 +38,7 @@ func TestPostgreSQL(t *testing.T) { want: `"col"`, }, } { - assert.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in)) + require.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in)) } } @@ -56,6 +56,6 @@ func TestSQLite3(t *testing.T) { want: `"col"`, }, } { - assert.Equal(t, test.want, SQLite3.QuoteIdent(test.in)) + require.Equal(t, test.want, SQLite3.QuoteIdent(test.in)) } } diff --git a/dialect/mysql.go b/dialect/mysql.go index a29f65a8..24bdb8ef 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -1,8 +1,8 @@ package dialect import ( - "bytes" "fmt" + "strings" "time" ) @@ -13,7 +13,7 @@ func (d mysql) QuoteIdent(s string) string { } func (d mysql) EncodeString(s string) string { - buf := new(bytes.Buffer) + var buf strings.Builder buf.WriteRune('\'') // https://dev.mysql.com/doc/refman/5.7/en/string-literals.html diff --git a/errors.go b/errors.go index b53661c3..37d488b0 100644 --- a/errors.go +++ b/errors.go @@ -13,4 +13,5 @@ var ( ErrInvalidSliceLength = errors.New("dbr: length of slice is 0. length must be >= 1") ErrCantConvertToTime = errors.New("dbr: can't convert to time.Time") ErrInvalidTimestring = errors.New("dbr: invalid time string") + ErrInvalidValue = errors.New("dbr: invalid value") ) diff --git a/event.go b/event.go index 983e123f..110534c3 100644 --- a/event.go +++ b/event.go @@ -1,6 +1,6 @@ package dbr -// EventReceiver gets events from dbr methods for logging purposes +// EventReceiver gets events from dbr methods for logging purposes. type EventReceiver interface { Event(eventName string) EventKv(eventName string, kvs map[string]string) @@ -14,27 +14,28 @@ type kvs map[string]string var nullReceiver = &NullEventReceiver{} -// NullEventReceiver is a sentinel EventReceiver; use it if the caller doesn't supply one +// NullEventReceiver is a sentinel EventReceiver. +// Use it if the caller doesn't supply one. type NullEventReceiver struct{} -// Event receives a simple notification when various events occur +// Event receives a simple notification when various events occur. func (n *NullEventReceiver) Event(eventName string) {} // EventKv receives a notification when various events occur along with -// optional key/value data +// optional key/value data. func (n *NullEventReceiver) EventKv(eventName string, kvs map[string]string) {} -// EventErr receives a notification of an error if one occurs +// EventErr receives a notification of an error if one occurs. func (n *NullEventReceiver) EventErr(eventName string, err error) error { return err } // EventErrKv receives a notification of an error if one occurs along with -// optional key/value data +// optional key/value data. func (n *NullEventReceiver) EventErrKv(eventName string, err error, kvs map[string]string) error { return err } -// Timing receives the time an event took to happen +// Timing receives the time an event took to happen. func (n *NullEventReceiver) Timing(eventName string, nanoseconds int64) {} -// TimingKv receives the time an event took to happen along with optional key/value data +// TimingKv receives the time an event took to happen along with optional key/value data. func (n *NullEventReceiver) TimingKv(eventName string, nanoseconds int64, kvs map[string]string) {} diff --git a/example_test.go b/example_test.go new file mode 100644 index 00000000..94f5fc04 --- /dev/null +++ b/example_test.go @@ -0,0 +1,162 @@ +package dbr + +import ( + "fmt" + "time" +) + +func ExampleOpen() { + // create a connection (e.g. "postgres", "mysql", or "sqlite3") + conn, _ := Open("postgres", "...", nil) + conn.SetMaxOpenConns(10) + + // create a session for each business unit of execution (e.g. a web request or goworkers job) + sess := conn.NewSession(nil) + + // create a tx from sessions + sess.Begin() +} + +func ExampleSelect() { + Select("title", "body"). + From("suggestions"). + OrderBy("id"). + Limit(10) +} + +func ExampleSelectBySql() { + SelectBySql("SELECT `title`, `body` FROM `suggestions` ORDER BY `id` ASC LIMIT 10") +} + +func ExampleSelectStmt_Load() { + // columns are mapped by tag then by field + type Suggestion struct { + ID int64 // id, will be autoloaded by last insert id + Title NullString `db:"subject"` // subjects are called titles now + Url string `db:"-"` // ignored + secret string // ignored + } + + // By default gocraft/dbr converts CamelCase property names to snake_case column_names. + // You can override this with struct tags, just like with JSON tags. + // This is especially helpful while migrating from legacy systems. + var suggestions []Suggestion + sess := mysqlSession + sess.Select("*").From("suggestions").Load(&suggestions) +} + +func ExampleSelectStmt_Where() { + // database/sql uses prepared statements, which means each argument + // in an IN clause needs its own question mark. + // gocraft/dbr, on the other hand, handles interpolation itself + // so that you can easily use a single question mark paired with a + // dynamically sized slice. + + sess := mysqlSession + ids := []int64{1, 2, 3, 4, 5} + sess.Select("*").From("suggestions").Where("id IN ?", ids) +} + +func ExampleSelectStmt_Join() { + sess := mysqlSession + sess.Select("*").From("suggestions"). + Join("subdomains", "suggestions.subdomain_id = subdomains.id") + + sess.Select("*").From("suggestions"). + LeftJoin("subdomains", "suggestions.subdomain_id = subdomains.id") + + // join multiple tables + sess.Select("*").From("suggestions"). + Join("subdomains", "suggestions.subdomain_id = subdomains.id"). + Join("accounts", "subdomains.accounts_id = accounts.id") +} + +func ExampleSelectStmt_As() { + sess := mysqlSession + sess.Select("count(id)").From( + Select("*").From("suggestions").As("count"), + ) +} + +func ExampleInsertStmt_Pair() { + sess := mysqlSession + sess.InsertInto("suggestions"). + Pair("title", "Gopher"). + Pair("body", "I love go.") +} + +func ExampleInsertStmt_Record() { + type Suggestion struct { + ID int64 + Title NullString + CreatedAt time.Time + } + sugg := &Suggestion{ + Title: NewNullString("Gopher"), + CreatedAt: time.Now(), + } + sess := mysqlSession + sess.InsertInto("suggestions"). + Columns("id", "title"). + Record(&sugg). + Exec() + + // id is set automatically + fmt.Println(sugg.ID) +} + +func ExampleUpdateStmt() { + sess := mysqlSession + sess.Update("suggestions"). + Set("title", "Gopher"). + Set("body", "I love go."). + Where("id = ?", 1) +} + +func ExampleDeleteStmt() { + sess := mysqlSession + sess.DeleteFrom("suggestions"). + Where("id = ?", 1) +} + +func ExampleTx() { + sess := mysqlSession + tx, err := sess.Begin() + if err != nil { + return + } + defer tx.RollbackUnlessCommitted() + + // do stuff... + + tx.Commit() +} + +func ExampleAnd() { + And( + Or( + Gt("created_at", "2015-09-10"), + Lte("created_at", "2015-09-11"), + ), + Eq("title", "hello world"), + ) +} + +func ExampleI() { + // I, identifier, can be used to quote. + I("suggestions.id").As("id") // `suggestions`.`id` +} + +func ExampleUnion() { + Union( + Select("*"), + Select("*"), + ).As("subquery") +} + +func ExampleUnionAll() { + UnionAll( + Select("*"), + Select("*"), + ).As("subquery") +} diff --git a/expr.go b/expr.go index ecc0e1f3..048ec27e 100644 --- a/expr.go +++ b/expr.go @@ -1,12 +1,12 @@ package dbr -// XxxBuilders all support raw query type raw struct { Query string Value []interface{} } -// Expr should be used when sql syntax is not supported +// Expr allows raw expression to be used when current SQL syntax is +// not supported by gocraft/dbr. func Expr(query string, value ...interface{}) Builder { return &raw{Query: query, Value: value} } diff --git a/ident.go b/ident.go index 42d90a05..b2f8aa89 100644 --- a/ident.go +++ b/ident.go @@ -1,14 +1,15 @@ package dbr -// identifier is a type of string +// I is quoted identifier type I string +// Build quotes string with dialect. func (i I) Build(d Dialect, buf Buffer) error { buf.WriteString(d.QuoteIdent(string(i))) return nil } -// As creates an alias for expr. e.g. SELECT `a1` AS `a2` +// As creates an alias for expr. func (i I) As(alias string) Builder { return as(i, alias) } diff --git a/insert.go b/insert.go index c1f26c86..b5a0da82 100644 --- a/insert.go +++ b/insert.go @@ -1,20 +1,29 @@ package dbr import ( - "bytes" + "context" + "database/sql" "reflect" + "strings" ) -// InsertStmt builds `INSERT INTO ...` +// InsertStmt builds `INSERT INTO ...`. type InsertStmt struct { + runner + EventReceiver + Dialect + raw - Table string - Column []string - Value [][]interface{} + Table string + Column []string + Value [][]interface{} + ReturnColumn []string + RecordID *int64 } -// Build builds `INSERT INTO ...` in dialect +type InsertBuilder = InsertStmt + func (b *InsertStmt) Build(d Dialect, buf Buffer) error { if b.raw.Query != "" { return b.raw.Build(d, buf) @@ -31,7 +40,7 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error { buf.WriteString("INSERT INTO ") buf.WriteString(d.QuoteIdent(b.Table)) - placeholderBuf := new(bytes.Buffer) + var placeholderBuf strings.Builder placeholderBuf.WriteString("(") buf.WriteString(" (") for i, col := range b.Column { @@ -55,17 +64,45 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error { buf.WriteValue(tuple...) } + if len(b.ReturnColumn) > 0 { + buf.WriteString(" RETURNING ") + for i, col := range b.ReturnColumn { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString(d.QuoteIdent(col)) + } + } + return nil } -// InsertInto creates an InsertStmt +// InsertInto creates an InsertStmt. func InsertInto(table string) *InsertStmt { return &InsertStmt{ Table: table, } } -// InsertBySql creates an InsertStmt from raw query +// InsertInto creates an InsertStmt. +func (sess *Session) InsertInto(table string) *InsertStmt { + b := InsertInto(table) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// InsertInto creates an InsertStmt. +func (tx *Tx) InsertInto(table string) *InsertStmt { + b := InsertInto(table) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// InsertBySql creates an InsertStmt from raw query. func InsertBySql(query string, value ...interface{}) *InsertStmt { return &InsertStmt{ raw: raw{ @@ -75,33 +112,115 @@ func InsertBySql(query string, value ...interface{}) *InsertStmt { } } -// Columns adds columns +// InsertBySql creates an InsertStmt from raw query. +func (sess *Session) InsertBySql(query string, value ...interface{}) *InsertStmt { + b := InsertBySql(query, value...) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// InsertBySql creates an InsertStmt from raw query. +func (tx *Tx) InsertBySql(query string, value ...interface{}) *InsertStmt { + b := InsertBySql(query, value...) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + func (b *InsertStmt) Columns(column ...string) *InsertStmt { b.Column = column return b } -// Values adds a tuple for columns +// Values adds a tuple to be inserted. +// The order of the tuple should match Columns. func (b *InsertStmt) Values(value ...interface{}) *InsertStmt { b.Value = append(b.Value, value) return b } -// Record adds a tuple for columns from a struct +// Record adds a tuple for columns from a struct. +// +// If there is a field called "Id" or "ID" in the struct, +// it will be set to LastInsertId. func (b *InsertStmt) Record(structValue interface{}) *InsertStmt { v := reflect.Indirect(reflect.ValueOf(structValue)) if v.Kind() == reflect.Struct { - var value []interface{} - m := structMap(v) - for _, key := range b.Column { - if val, ok := m[key]; ok { - value = append(value, val.Interface()) - } else { - value = append(value, nil) + found := make([]interface{}, len(b.Column)+1) + // ID is recommended by golint here + s := newTagStore() + s.findValueByName(v, append(b.Column, "id"), found, false) + + value := found[:len(found)-1] + for i, v := range value { + if v != nil { + value[i] = v.(reflect.Value).Interface() + } + } + + if v.CanSet() { + switch idField := found[len(found)-1].(type) { + case reflect.Value: + if idField.Kind() == reflect.Int64 { + b.RecordID = idField.Addr().Interface().(*int64) + } } } b.Values(value...) } return b } + +// Returning specifies the returning columns for postgres. +func (b *InsertStmt) Returning(column ...string) *InsertStmt { + b.ReturnColumn = column + return b +} + +// Pair adds (column, value) to be inserted. +// It is an error to mix Pair with Values and Record. +func (b *InsertStmt) Pair(column string, value interface{}) *InsertStmt { + b.Column = append(b.Column, column) + switch len(b.Value) { + case 0: + b.Values(value) + case 1: + b.Value[0] = append(b.Value[0], value) + default: + panic("pair only allows one record to insert") + } + return b +} + +func (b *InsertStmt) Exec() (sql.Result, error) { + return b.ExecContext(context.Background()) +} + +func (b *InsertStmt) ExecContext(ctx context.Context) (sql.Result, error) { + result, err := exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) + if err != nil { + return nil, err + } + + if b.RecordID != nil { + if id, err := result.LastInsertId(); err == nil { + *b.RecordID = id + } + b.RecordID = nil + } + + return result, nil +} + +func (b *InsertStmt) LoadContext(ctx context.Context, value interface{}) error { + _, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) + return err +} + +func (b *InsertStmt) Load(value interface{}) error { + return b.LoadContext(context.Background(), value) +} diff --git a/insert_builder.go b/insert_builder.go deleted file mode 100644 index 2d7a1c38..00000000 --- a/insert_builder.go +++ /dev/null @@ -1,107 +0,0 @@ -package dbr - -import ( - "database/sql" - "reflect" -) - -type InsertBuilder struct { - runner - EventReceiver - Dialect Dialect - - RecordID reflect.Value - - *InsertStmt -} - -func (sess *Session) InsertInto(table string) *InsertBuilder { - return &InsertBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - InsertStmt: InsertInto(table), - } -} - -func (tx *Tx) InsertInto(table string) *InsertBuilder { - return &InsertBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - InsertStmt: InsertInto(table), - } -} - -func (sess *Session) InsertBySql(query string, value ...interface{}) *InsertBuilder { - return &InsertBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - InsertStmt: InsertBySql(query, value...), - } -} - -func (tx *Tx) InsertBySql(query string, value ...interface{}) *InsertBuilder { - return &InsertBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - InsertStmt: InsertBySql(query, value...), - } -} - -func (b *InsertBuilder) Pair(column string, value interface{}) *InsertBuilder { - b.Column = append(b.Column, column) - switch len(b.Value) { - case 0: - b.InsertStmt.Values(value) - case 1: - b.Value[0] = append(b.Value[0], value) - default: - panic("pair only allows one record to insert") - } - return b -} - -func (b *InsertBuilder) Exec() (sql.Result, error) { - result, err := exec(b.runner, b.EventReceiver, b, b.Dialect) - if err != nil { - return nil, err - } - - if b.RecordID.IsValid() { - if id, err := result.LastInsertId(); err == nil { - b.RecordID.SetInt(id) - } - } - - return result, nil -} - -func (b *InsertBuilder) Columns(column ...string) *InsertBuilder { - b.InsertStmt.Columns(column...) - return b -} - -func (b *InsertBuilder) Record(structValue interface{}) *InsertBuilder { - v := reflect.Indirect(reflect.ValueOf(structValue)) - if v.Kind() == reflect.Struct && v.CanSet() { - // ID is recommended by golint here - for _, name := range []string{"Id", "ID"} { - field := v.FieldByName(name) - if field.IsValid() && field.Kind() == reflect.Int64 { - b.RecordID = field - break - } - } - } - - b.InsertStmt.Record(structValue) - return b -} - -func (b *InsertBuilder) Values(value ...interface{}) *InsertBuilder { - b.InsertStmt.Values(value...) - return b -} diff --git a/insert_test.go b/insert_test.go index 055bf64f..ac363e24 100644 --- a/insert_test.go +++ b/insert_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type insertTest struct { @@ -19,9 +19,20 @@ func TestInsertStmt(t *testing.T) { C: "two", }) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) - assert.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value()) + require.NoError(t, err) + require.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) + require.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value()) +} + +func TestPostgresReturning(t *testing.T) { + sess := postgresSession + reset(t, sess) + + var person dbrPerson + err := sess.InsertInto("dbr_people").Columns("name").Record(&person). + Returning("id").Load(&person.Id) + require.NoError(t, err) + require.True(t, person.Id > 0) } func BenchmarkInsertValuesSQL(b *testing.B) { diff --git a/interpolate.go b/interpolate.go index 4471672f..372164f5 100644 --- a/interpolate.go +++ b/interpolate.go @@ -15,24 +15,37 @@ type interpolator struct { N int } -// InterpolateForDialect replaces placeholder in query with corresponding value in dialect +// InterpolateForDialect replaces placeholder +// in query with corresponding value in dialect. +// +// It can be also used for debugging custom Builder. +// +// Every time you call database/sql's db.Query("SELECT ...") method, +// under the hood, the mysql driver will create a prepared statement, +// execute it, and then throw it away. This has a big performance cost. +// +// gocraft/dbr doesn't use prepared statements. +// We ported mysql's query escape functionality directly into our package, +// which means we interpolate all of those question marks with +// their arguments before they get to MySQL. +// The result of this is that it's way faster, and just as secure. +// +// Check out these benchmarks from https://github.com/tyler-smith/golang-sql-benchmark. func InterpolateForDialect(query string, value []interface{}, d Dialect) (string, error) { i := interpolator{ Buffer: NewBuffer(), Dialect: d, } - err := i.interpolate(query, value) + err := i.interpolate(query, value, true) if err != nil { return "", err } return i.String(), nil } -func (i *interpolator) interpolate(query string, value []interface{}) error { - if strings.Count(query, placeholder) != len(value) { - return ErrPlaceholderCount - } +var escapedPlaceholder = strings.Repeat(placeholder, 2) +func (i *interpolator) interpolate(query string, value []interface{}, topLevel bool) error { valueIndex := 0 for { @@ -41,13 +54,24 @@ func (i *interpolator) interpolate(query string, value []interface{}) error { break } + // escape placeholder by repeating it twice + if strings.HasPrefix(query[index:], escapedPlaceholder) { + i.WriteString(query[:index+1]) // Write placeholder once, not twice + query = query[index+len(escapedPlaceholder):] + continue + } + + if valueIndex >= len(value) { + return ErrPlaceholderCount + } + i.WriteString(query[:index]) if _, ok := value[valueIndex].([]byte); ok && i.IgnoreBinary { i.WriteString(i.Placeholder(i.N)) i.N++ i.WriteValue(value[valueIndex]) } else { - err := i.encodePlaceholder(value[valueIndex]) + err := i.encodePlaceholder(value[valueIndex], topLevel) if err != nil { return err } @@ -56,30 +80,36 @@ func (i *interpolator) interpolate(query string, value []interface{}) error { valueIndex++ } + if valueIndex != len(value) { + return ErrPlaceholderCount + } + // placeholder not found; write remaining query i.WriteString(query) return nil } -func (i *interpolator) encodePlaceholder(value interface{}) error { +var ( + typeTime = reflect.TypeOf(time.Time{}) +) + +func (i *interpolator) encodePlaceholder(value interface{}, topLevel bool) error { if builder, ok := value.(Builder); ok { pbuf := NewBuffer() err := builder.Build(i.Dialect, pbuf) if err != nil { return err } - paren := true + paren := false switch value.(type) { - case *SelectStmt: - case *union: - default: - paren = false + case *SelectStmt, *union: + paren = !topLevel } if paren { i.WriteString("(") } - err = i.interpolate(pbuf.String(), pbuf.Value()) + err = i.interpolate(pbuf.String(), pbuf.Value(), false) if err != nil { return err } @@ -120,7 +150,7 @@ func (i *interpolator) encodePlaceholder(value interface{}) error { i.WriteString(strconv.FormatFloat(v.Float(), 'f', -1, 64)) return nil case reflect.Struct: - if v.Type() == reflect.TypeOf(time.Time{}) { + if v.Type() == typeTime { i.WriteString(i.EncodeTime(v.Interface().(time.Time))) return nil } @@ -139,7 +169,7 @@ func (i *interpolator) encodePlaceholder(value interface{}) error { if n > 0 { i.WriteString(",") } - err := i.encodePlaceholder(v.Index(n).Interface()) + err := i.encodePlaceholder(v.Index(n).Interface(), topLevel) if err != nil { return err } @@ -151,7 +181,7 @@ func (i *interpolator) encodePlaceholder(value interface{}) error { i.WriteString("NULL") return nil } - return i.encodePlaceholder(v.Elem().Interface()) + return i.encodePlaceholder(v.Elem().Interface(), topLevel) } return ErrNotSupported } diff --git a/interpolate_test.go b/interpolate_test.go index eca55e02..30e8c3c2 100644 --- a/interpolate_test.go +++ b/interpolate_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInterpolateIgnoreBinary(t *testing.T) { @@ -47,11 +47,11 @@ func TestInterpolateIgnoreBinary(t *testing.T) { IgnoreBinary: true, } - err := i.interpolate(test.query, test.value) - assert.NoError(t, err) + err := i.interpolate(test.query, test.value, true) + require.NoError(t, err) - assert.Equal(t, test.wantQuery, i.String()) - assert.Equal(t, test.wantValue, i.Value()) + require.Equal(t, test.wantQuery, i.String()) + require.Equal(t, test.wantValue, i.Value()) } } @@ -104,7 +104,7 @@ func TestInterpolateForDialect(t *testing.T) { { query: "?", value: []interface{}{Select("a").From("table")}, - want: "(SELECT a FROM table)", + want: "SELECT a FROM table", }, { query: "?", @@ -136,10 +136,20 @@ func TestInterpolateForDialect(t *testing.T) { value: []interface{}{(*int64)(nil)}, want: "NULL", }, + { + query: "???? ? ?? ? ??", + value: []interface{}{1, 2}, + want: "?? 1 ? 2 ?", + }, + { + query: "???", + value: []interface{}{1}, + want: "?1", + }, } { s, err := InterpolateForDialect(test.query, test.value, dialect.MySQL) - assert.NoError(t, err) - assert.Equal(t, test.want, s) + require.NoError(t, err) + require.Equal(t, test.want, s) } } @@ -147,17 +157,20 @@ func TestInterpolateForDialect(t *testing.T) { // more information on the source and the strings themselves. func TestCommonSQLInjections(t *testing.T) { for _, sess := range testSession { + reset(t, sess) + for _, injectionAttempt := range strings.Split(injectionAttempts, "\n") { // Create a user with the attempted injection as the email address _, err := sess.InsertInto("dbr_people"). Pair("name", injectionAttempt). Exec() - assert.NoError(t, err) + require.NoError(t, err) // SELECT the name back and ensure it's equal to the injection attempt var name string - err = sess.Select("name").From("dbr_people").OrderDir("id", false).Limit(1).LoadValue(&name) - assert.Equal(t, injectionAttempt, name) + err = sess.Select("name").From("dbr_people").OrderDesc("id").Limit(1).LoadOne(&name) + require.NoError(t, err) + require.Equal(t, injectionAttempt, name) } } } diff --git a/load.go b/load.go index 14d69188..2bc82a41 100644 --- a/load.go +++ b/load.go @@ -5,7 +5,29 @@ import ( "reflect" ) -// Load loads any value from sql.Rows +type interfaceLoader struct { + v interface{} + typ reflect.Type +} + +func InterfaceLoader(value interface{}, concreteType interface{}) interface{} { + return interfaceLoader{value, reflect.TypeOf(concreteType)} +} + +// Load loads any value from sql.Rows. +// +// value can be: +// +// 1. simple type like int64, string, etc. +// +// 2. sql.Scanner, which allows loading with custom types. +// +// 3. map; the first column from SQL result loaded to the key, +// and the rest of columns will be loaded into the value. +// This is useful to dedup SQL result with first column. +// +// 4. map of slice; like map, values with the same key are +// collected with a slice. func Load(rows *sql.Rows, value interface{}) (int, error) { defer rows.Close() @@ -13,32 +35,89 @@ func Load(rows *sql.Rows, value interface{}) (int, error) { if err != nil { return 0, err } + ptr := make([]interface{}, len(column)) + + var v reflect.Value + var elemType reflect.Type + + if il, ok := value.(interfaceLoader); ok { + v = reflect.ValueOf(il.v) + elemType = il.typ + } else { + v = reflect.ValueOf(value) + } - v := reflect.ValueOf(value) if v.Kind() != reflect.Ptr || v.IsNil() { return 0, ErrInvalidPointer } v = v.Elem() - isSlice := v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 + isScanner := v.Addr().Type().Implements(typeScanner) + isSlice := v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 && !isScanner + isMap := v.Kind() == reflect.Map && !isScanner + isMapOfSlices := isMap && v.Type().Elem().Kind() == reflect.Slice && v.Type().Elem().Elem().Kind() != reflect.Uint8 + if isMap { + v.Set(reflect.MakeMap(v.Type())) + } + + s := newTagStore() count := 0 for rows.Next() { - var elem reflect.Value - if isSlice { - elem = reflect.New(v.Type().Elem()).Elem() + var elem, keyElem reflect.Value + + if elemType != nil { + elem = reflectAlloc(elemType) + } else if isMapOfSlices { + elem = reflectAlloc(v.Type().Elem().Elem()) + } else if isSlice || isMap { + elem = reflectAlloc(v.Type().Elem()) } else { elem = v } - ptr, err := findPtr(column, elem) - if err != nil { - return 0, err + + if isMap { + err := s.findPtr(elem, column[1:], ptr[1:]) + if err != nil { + return 0, err + } + keyElem = reflectAlloc(v.Type().Key()) + err = s.findPtr(keyElem, column[:1], ptr[:1]) + if err != nil { + return 0, err + } + } else { + err := s.findPtr(elem, column, ptr) + if err != nil { + return 0, err + } + } + + // Before scanning, set nil pointer to dummy dest. + // After that, reset pointers to nil for the next batch. + for i := range ptr { + if ptr[i] == nil { + ptr[i] = dummyDest + } } err = rows.Scan(ptr...) if err != nil { return 0, err } + for i := range ptr { + ptr[i] = nil + } + count++ + if isSlice { v.Set(reflect.Append(v, elem)) + } else if isMapOfSlices { + s := v.MapIndex(keyElem) + if !s.IsValid() { + s = reflect.Zero(v.Type().Elem()) + } + v.SetMapIndex(keyElem, reflect.Append(s, elem)) + } else if isMap { + v.SetMapIndex(keyElem, elem) } else { break } @@ -46,6 +125,13 @@ func Load(rows *sql.Rows, value interface{}) (int, error) { return count, nil } +func reflectAlloc(typ reflect.Type) reflect.Value { + if typ.Kind() == reflect.Ptr { + return reflect.New(typ.Elem()) + } + return reflect.New(typ).Elem() +} + type dummyScanner struct{} func (dummyScanner) Scan(interface{}) error { @@ -56,28 +142,3 @@ var ( dummyDest sql.Scanner = dummyScanner{} typeScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() ) - -func findPtr(column []string, value reflect.Value) ([]interface{}, error) { - if value.Addr().Type().Implements(typeScanner) { - return []interface{}{value.Addr().Interface()}, nil - } - switch value.Kind() { - case reflect.Struct: - var ptr []interface{} - m := structMap(value) - for _, key := range column { - if val, ok := m[key]; ok { - ptr = append(ptr, val.Addr().Interface()) - } else { - ptr = append(ptr, dummyDest) - } - } - return ptr, nil - case reflect.Ptr: - if value.IsNil() { - value.Set(reflect.New(value.Type().Elem())) - } - return findPtr(column, value.Elem()) - } - return []interface{}{value.Addr().Interface()}, nil -} diff --git a/load_benchmark_test.go b/load_benchmark_test.go new file mode 100644 index 00000000..ca9985c3 --- /dev/null +++ b/load_benchmark_test.go @@ -0,0 +1,76 @@ +package dbr + +import ( + "context" + "fmt" + "testing" + + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/require" +) + +func BenchmarkLoadValues(b *testing.B) { + sess := mysqlSession + for _, v := range []string{ + `DROP TABLE IF EXISTS suggestions`, + `CREATE TABLE suggestions ( + id serial PRIMARY KEY, + title varchar(255), + body text + )`, + } { + _, err := sess.Exec(v) + require.NoError(b, err) + } + tx, err := sess.Begin() + require.NoError(b, err) + + const maxRows = 100000 + + for i := 0; i < maxRows; i++ { + _, err := tx.InsertInto("suggestions"). + Columns("title", "body"). + Values("title", "body"). + Exec() + require.NoError(b, err) + } + err = tx.Commit() + require.NoError(b, err) + + type Suggestion struct { + Title *string + Body *string + } + for n := 10; n <= maxRows; n *= 10 { + query := fmt.Sprintf("SELECT * FROM suggestions ORDER BY id ASC LIMIT %d", n) + + b.Run(fmt.Sprintf("sqlx_%d", n), func(b *testing.B) { + b.StopTimer() + db, err := sqlx.Connect("mysql", mysqlDSN) + require.NoError(b, err) + db = db.Unsafe() + defer db.Close() + + for i := 0; i < b.N; i++ { + var suggs []*Suggestion + b.StartTimer() + err := db.SelectContext(context.Background(), &suggs, query) + b.StopTimer() + require.NoError(b, err) + require.Len(b, suggs, n) + } + }) + b.Run(fmt.Sprintf("dbr_%d", n), func(b *testing.B) { + b.StopTimer() + + for i := 0; i < b.N; i++ { + var suggs []*Suggestion + b.StartTimer() + _, err := sess.SelectBySql(query).LoadContext(context.Background(), &suggs) + b.StopTimer() + require.NoError(b, err) + require.Len(b, suggs, n) + } + }) + } +} diff --git a/now.go b/now.go index c1b19a17..67705155 100644 --- a/now.go +++ b/now.go @@ -5,7 +5,7 @@ import ( "time" ) -// Now is a value that serializes to the current time +// Now is a value that serializes to the current time in UTC. var Now = nowSentinel{} const timeFormat = "2006-01-02 15:04:05.000000" diff --git a/postgres_bytea_benchmark_test.go b/postgres_bytea_benchmark_test.go new file mode 100644 index 00000000..17ee2e2d --- /dev/null +++ b/postgres_bytea_benchmark_test.go @@ -0,0 +1,35 @@ +package dbr + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkByteaNoBinaryEncode(b *testing.B) { + benchmarkBytea(b, postgresSession) +} + +func BenchmarkByteaBinaryEncode(b *testing.B) { + benchmarkBytea(b, postgresBinarySession) +} + +func benchmarkBytea(b *testing.B, sess *Session) { + data := bytes.Repeat([]byte("0123456789"), 1000) + for _, v := range []string{ + `DROP TABLE IF EXISTS bytea_table`, + `CREATE TABLE bytea_table ( + val bytea + )`, + } { + _, err := sess.Exec(v) + require.NoError(b, err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := sess.InsertInto("bytea_table").Pair("val", data).Exec() + require.NoError(b, err) + } +} diff --git a/select.go b/select.go index dfe7e22f..a5836247 100644 --- a/select.go +++ b/select.go @@ -1,9 +1,16 @@ package dbr -import "fmt" +import ( + "context" + "strconv" +) -// SelectStmt builds `SELECT ...` +// SelectStmt builds `SELECT ...`. type SelectStmt struct { + runner + EventReceiver + Dialect + raw IsDistinct bool @@ -21,7 +28,8 @@ type SelectStmt struct { OffsetCount int64 } -// Build builds `SELECT ...` in dialect +type SelectBuilder = SelectStmt + func (b *SelectStmt) Build(d Dialect, buf Buffer) error { if b.raw.Query != "" { return b.raw.Build(d, buf) @@ -115,17 +123,17 @@ func (b *SelectStmt) Build(d Dialect, buf Buffer) error { if b.LimitCount >= 0 { buf.WriteString(" LIMIT ") - buf.WriteString(fmt.Sprint(b.LimitCount)) + buf.WriteString(strconv.FormatInt(b.LimitCount, 10)) } if b.OffsetCount >= 0 { buf.WriteString(" OFFSET ") - buf.WriteString(fmt.Sprint(b.OffsetCount)) + buf.WriteString(strconv.FormatInt(b.OffsetCount, 10)) } return nil } -// Select creates a SelectStmt +// Select creates a SelectStmt. func Select(column ...interface{}) *SelectStmt { return &SelectStmt{ Column: column, @@ -134,13 +142,33 @@ func Select(column ...interface{}) *SelectStmt { } } -// From specifies table -func (b *SelectStmt) From(table interface{}) *SelectStmt { - b.Table = table +func prepareSelect(a []string) []interface{} { + b := make([]interface{}, len(a)) + for i := range a { + b[i] = a[i] + } + return b +} + +// Select creates a SelectStmt. +func (sess *Session) Select(column ...string) *SelectStmt { + b := Select(prepareSelect(column)...) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// Select creates a SelectStmt. +func (tx *Tx) Select(column ...string) *SelectStmt { + b := Select(prepareSelect(column)...) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect return b } -// SelectBySql creates a SelectStmt from raw query +// SelectBySql creates a SelectStmt from raw query. func SelectBySql(query string, value ...interface{}) *SelectStmt { return &SelectStmt{ raw: raw{ @@ -152,13 +180,38 @@ func SelectBySql(query string, value ...interface{}) *SelectStmt { } } -// Distinct adds `DISTINCT` +// SelectBySql creates a SelectStmt from raw query. +func (sess *Session) SelectBySql(query string, value ...interface{}) *SelectStmt { + b := SelectBySql(query, value...) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// SelectBySql creates a SelectStmt from raw query. +func (tx *Tx) SelectBySql(query string, value ...interface{}) *SelectStmt { + b := SelectBySql(query, value...) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// From specifies table to select from. +// table can be Builder like SelectStmt, or string. +func (b *SelectStmt) From(table interface{}) *SelectStmt { + b.Table = table + return b +} + func (b *SelectStmt) Distinct() *SelectStmt { b.IsDistinct = true return b } -// Where adds a where condition +// Where adds a where condition. +// query can be Builder or string. value is used only if query type is string. func (b *SelectStmt) Where(query interface{}, value ...interface{}) *SelectStmt { switch query := query.(type) { case string: @@ -169,7 +222,8 @@ func (b *SelectStmt) Where(query interface{}, value ...interface{}) *SelectStmt return b } -// Having adds a having condition +// Having adds a having condition. +// query can be Builder or string. value is used only if query type is string. func (b *SelectStmt) Having(query interface{}, value ...interface{}) *SelectStmt { switch query := query.(type) { case string: @@ -180,7 +234,7 @@ func (b *SelectStmt) Having(query interface{}, value ...interface{}) *SelectStmt return b } -// GroupBy specifies columns for grouping +// GroupBy specifies columns for grouping. func (b *SelectStmt) GroupBy(col ...string) *SelectStmt { for _, group := range col { b.Group = append(b.Group, Expr(group)) @@ -188,7 +242,6 @@ func (b *SelectStmt) GroupBy(col ...string) *SelectStmt { return b } -// OrderBy specifies columns for ordering func (b *SelectStmt) OrderAsc(col string) *SelectStmt { b.Order = append(b.Order, order(col, asc)) return b @@ -199,40 +252,98 @@ func (b *SelectStmt) OrderDesc(col string) *SelectStmt { return b } -// Limit adds limit +// OrderBy specifies columns for ordering. +func (b *SelectStmt) OrderBy(col string) *SelectStmt { + b.Order = append(b.Order, Expr(col)) + return b +} + func (b *SelectStmt) Limit(n uint64) *SelectStmt { b.LimitCount = int64(n) return b } -// Offset adds offset func (b *SelectStmt) Offset(n uint64) *SelectStmt { b.OffsetCount = int64(n) return b } -// Join joins table on condition +// Paginate fetches a page in a naive way for a small set of data. +func (b *SelectStmt) Paginate(page, perPage uint64) *SelectStmt { + b.Limit(perPage) + b.Offset((page - 1) * perPage) + return b +} + +// OrderDir is a helper for OrderAsc and OrderDesc. +func (b *SelectStmt) OrderDir(col string, isAsc bool) *SelectStmt { + if isAsc { + b.OrderAsc(col) + } else { + b.OrderDesc(col) + } + return b +} + +// Join add inner-join. +// on can be Builder or string. func (b *SelectStmt) Join(table, on interface{}) *SelectStmt { b.JoinTable = append(b.JoinTable, join(inner, table, on)) return b } +// LeftJoin add left-join. +// on can be Builder or string. func (b *SelectStmt) LeftJoin(table, on interface{}) *SelectStmt { b.JoinTable = append(b.JoinTable, join(left, table, on)) return b } +// RightJoin add right-join. +// on can be Builder or string. func (b *SelectStmt) RightJoin(table, on interface{}) *SelectStmt { b.JoinTable = append(b.JoinTable, join(right, table, on)) return b } +// FullJoin add full-join. +// on can be Builder or string. func (b *SelectStmt) FullJoin(table, on interface{}) *SelectStmt { b.JoinTable = append(b.JoinTable, join(full, table, on)) return b } -// As creates alias for select statement +// As creates alias for select statement. func (b *SelectStmt) As(alias string) Builder { return as(b, alias) } + +func (b *SelectStmt) LoadOneContext(ctx context.Context, value interface{}) error { + count, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) + if err != nil { + return err + } + if count == 0 { + return ErrNotFound + } + return nil +} + +// LoadOne loads SQL result into go variable that is not a slice. +// Unlike Load, it returns ErrNotFound if the SQL result row count is 0. +// +// See https://godoc.org/github.com/gocraft/dbr#Load. +func (b *SelectStmt) LoadOne(value interface{}) error { + return b.LoadOneContext(context.Background(), value) +} + +func (b *SelectStmt) LoadContext(ctx context.Context, value interface{}) (int, error) { + return query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) +} + +// Load loads multi-row SQL result into a slice of go variables. +// +// See https://godoc.org/github.com/gocraft/dbr#Load. +func (b *SelectStmt) Load(value interface{}) (int, error) { + return b.LoadContext(context.Background(), value) +} diff --git a/select_builder.go b/select_builder.go deleted file mode 100644 index 95ad69c8..00000000 --- a/select_builder.go +++ /dev/null @@ -1,162 +0,0 @@ -package dbr - -type SelectBuilder struct { - runner - EventReceiver - Dialect Dialect - - *SelectStmt -} - -func prepareSelect(a []string) []interface{} { - b := make([]interface{}, len(a)) - for i := range a { - b[i] = a[i] - } - return b -} - -func (sess *Session) Select(column ...string) *SelectBuilder { - return &SelectBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - SelectStmt: Select(prepareSelect(column)...), - } -} - -func (tx *Tx) Select(column ...string) *SelectBuilder { - return &SelectBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - SelectStmt: Select(prepareSelect(column)...), - } -} - -func (sess *Session) SelectBySql(query string, value ...interface{}) *SelectBuilder { - return &SelectBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - SelectStmt: SelectBySql(query, value...), - } -} - -func (tx *Tx) SelectBySql(query string, value ...interface{}) *SelectBuilder { - return &SelectBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - SelectStmt: SelectBySql(query, value...), - } -} - -func (b *SelectBuilder) Load(value interface{}) (int, error) { - return query(b.runner, b.EventReceiver, b, b.Dialect, value) -} - -func (b *SelectBuilder) LoadStruct(value interface{}) error { - count, err := query(b.runner, b.EventReceiver, b, b.Dialect, value) - if err != nil { - return err - } - if count == 0 { - return ErrNotFound - } - return nil -} - -func (b *SelectBuilder) LoadStructs(value interface{}) (int, error) { - return query(b.runner, b.EventReceiver, b, b.Dialect, value) -} - -func (b *SelectBuilder) LoadValue(value interface{}) error { - count, err := query(b.runner, b.EventReceiver, b, b.Dialect, value) - if err != nil { - return err - } - if count == 0 { - return ErrNotFound - } - return nil -} - -func (b *SelectBuilder) LoadValues(value interface{}) (int, error) { - return query(b.runner, b.EventReceiver, b, b.Dialect, value) -} - -func (b *SelectBuilder) Join(table, on interface{}) *SelectBuilder { - b.SelectStmt.Join(table, on) - return b -} - -func (b *SelectBuilder) LeftJoin(table, on interface{}) *SelectBuilder { - b.SelectStmt.LeftJoin(table, on) - return b -} - -func (b *SelectBuilder) RightJoin(table, on interface{}) *SelectBuilder { - b.SelectStmt.RightJoin(table, on) - return b -} - -func (b *SelectBuilder) FullJoin(table, on interface{}) *SelectBuilder { - b.SelectStmt.FullJoin(table, on) - return b -} - -func (b *SelectBuilder) Distinct() *SelectBuilder { - b.SelectStmt.Distinct() - return b -} - -func (b *SelectBuilder) From(table interface{}) *SelectBuilder { - b.SelectStmt.From(table) - return b -} - -func (b *SelectBuilder) GroupBy(col ...string) *SelectBuilder { - b.SelectStmt.GroupBy(col...) - return b -} - -func (b *SelectBuilder) Having(query interface{}, value ...interface{}) *SelectBuilder { - b.SelectStmt.Having(query, value...) - return b -} - -func (b *SelectBuilder) Limit(n uint64) *SelectBuilder { - b.SelectStmt.Limit(n) - return b -} - -func (b *SelectBuilder) Offset(n uint64) *SelectBuilder { - b.SelectStmt.Offset(n) - return b -} - -func (b *SelectBuilder) OrderDir(col string, isAsc bool) *SelectBuilder { - if isAsc { - b.SelectStmt.OrderAsc(col) - } else { - b.SelectStmt.OrderDesc(col) - } - return b -} - -func (b *SelectBuilder) Paginate(page, perPage uint64) *SelectBuilder { - b.Limit(perPage) - b.Offset((page - 1) * perPage) - return b -} - -func (b *SelectBuilder) OrderBy(col string) *SelectBuilder { - b.SelectStmt.Order = append(b.SelectStmt.Order, Expr(col)) - return b -} - -func (b *SelectBuilder) Where(query interface{}, value ...interface{}) *SelectBuilder { - b.SelectStmt.Where(query, value...) - return b -} diff --git a/select_return.go b/select_return.go index d55c208d..c4a3c698 100644 --- a/select_return.go +++ b/select_return.go @@ -1,66 +1,43 @@ package dbr -// -// These are a set of helpers that just call LoadValue and return the value. -// They return (_, ErrNotFound) if nothing was found. -// - -// The inclusion of these helpers in the package is not an obvious choice: -// Benefits: -// - slight increase in code clarity/conciseness b/c you can use ":=" to define the variable -// -// count, err := d.Select("COUNT(*)").From("users").Where("x = ?", x).ReturnInt64() -// -// vs -// -// var count int64 -// err := d.Select("COUNT(*)").From("users").Where("x = ?", x).LoadValue(&count) -// -// Downsides: -// - very small increase in code cost, although it's not complex code -// - increase in conceptual model / API footprint when presenting the package to new users -// - no functionality that you can't achieve calling .LoadValue directly. -// - There's a lot of possible types. Do we want to include ALL of them? u?int{8,16,32,64}?, strings, null varieties, etc. -// - Let's just do the common, non-null varieties. - -// ReturnInt64 executes the SelectStmt and returns the value as an int64 -func (b *SelectBuilder) ReturnInt64() (int64, error) { +// ReturnInt64 executes the SelectStmt and returns the value as an int64. +func (b *SelectStmt) ReturnInt64() (int64, error) { var v int64 - err := b.LoadValue(&v) + err := b.LoadOne(&v) return v, err } -// ReturnInt64s executes the SelectStmt and returns the value as a slice of int64s -func (b *SelectBuilder) ReturnInt64s() ([]int64, error) { +// ReturnInt64s executes the SelectStmt and returns the value as a slice of int64s. +func (b *SelectStmt) ReturnInt64s() ([]int64, error) { var v []int64 - _, err := b.LoadValues(&v) + _, err := b.Load(&v) return v, err } -// ReturnUint64 executes the SelectStmt and returns the value as an uint64 -func (b *SelectBuilder) ReturnUint64() (uint64, error) { +// ReturnUint64 executes the SelectStmt and returns the value as an uint64. +func (b *SelectStmt) ReturnUint64() (uint64, error) { var v uint64 - err := b.LoadValue(&v) + err := b.LoadOne(&v) return v, err } -// ReturnUint64s executes the SelectStmt and returns the value as a slice of uint64s -func (b *SelectBuilder) ReturnUint64s() ([]uint64, error) { +// ReturnUint64s executes the SelectStmt and returns the value as a slice of uint64s. +func (b *SelectStmt) ReturnUint64s() ([]uint64, error) { var v []uint64 - _, err := b.LoadValues(&v) + _, err := b.Load(&v) return v, err } -// ReturnString executes the SelectStmt and returns the value as a string -func (b *SelectBuilder) ReturnString() (string, error) { +// ReturnString executes the SelectStmt and returns the value as a string. +func (b *SelectStmt) ReturnString() (string, error) { var v string - err := b.LoadValue(&v) + err := b.LoadOne(&v) return v, err } -// ReturnStrings executes the SelectStmt and returns the value as a slice of strings -func (b *SelectBuilder) ReturnStrings() ([]string, error) { +// ReturnStrings executes the SelectStmt and returns the value as a slice of strings. +func (b *SelectStmt) ReturnStrings() ([]string, error) { var v []string - _, err := b.LoadValues(&v) + _, err := b.Load(&v) return v, err } diff --git a/select_test.go b/select_test.go index 94928017..16566d30 100644 --- a/select_test.go +++ b/select_test.go @@ -3,8 +3,10 @@ package dbr import ( "testing" + "github.com/lib/pq" + "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSelectStmt(t *testing.T) { @@ -20,10 +22,10 @@ func TestSelectStmt(t *testing.T) { Limit(3). Offset(4) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "SELECT DISTINCT a, b FROM ? LEFT JOIN `table2` ON table.a1 = table.a2 WHERE (`c` = ?) GROUP BY d HAVING (`e` = ?) ORDER BY f ASC LIMIT 3 OFFSET 4", buf.String()) + require.NoError(t, err) + require.Equal(t, "SELECT DISTINCT a, b FROM ? LEFT JOIN `table2` ON table.a1 = table.a2 WHERE (`c` = ?) GROUP BY d HAVING (`e` = ?) ORDER BY f ASC LIMIT 3 OFFSET 4", buf.String()) // two functions cannot be compared - assert.Equal(t, 3, len(buf.Value())) + require.Equal(t, 3, len(buf.Value())) } func BenchmarkSelectSQL(b *testing.B) { @@ -32,3 +34,132 @@ func BenchmarkSelectSQL(b *testing.B) { Select("a", "b").From("table").Where(Eq("c", 1)).OrderAsc("d").Build(dialect.MySQL, buf) } } + +type stringSliceWithSQLScanner []string + +func (ss *stringSliceWithSQLScanner) Scan(src interface{}) error { + *ss = append(*ss, "called") + return nil +} + +func TestSliceWithSQLScannerSelect(t *testing.T) { + for _, sess := range testSession { + reset(t, sess) + + _, err := sess.InsertInto("dbr_people"). + Columns("name", "email"). + Values("test1", "test1@test.com"). + Values("test2", "test2@test.com"). + Values("test3", "test3@test.com"). + Exec() + + //plain string slice (original behavior) + var stringSlice []string + cnt, err := sess.Select("name").From("dbr_people").Load(&stringSlice) + + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, stringSlice, 3) + + //string slice with sql.Scanner implemented, should act as a single record + var sliceScanner stringSliceWithSQLScanner + cnt, err = sess.Select("name").From("dbr_people").Load(&sliceScanner) + + require.NoError(t, err) + require.Equal(t, 1, cnt) + require.Len(t, sliceScanner, 1) + } +} + +func TestMaps(t *testing.T) { + for _, sess := range testSession { + reset(t, sess) + + _, err := sess.InsertInto("dbr_people"). + Columns("name", "email"). + Values("test1", "test1@test.com"). + Values("test2", "test2@test.com"). + Values("test2", "test3@test.com"). + Exec() + + var m map[string]string + cnt, err := sess.Select("email, name").From("dbr_people").Load(&m) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m, 3) + require.Equal(t, "test1", m["test1@test.com"]) + + var m2 map[int64]*dbrPerson + cnt, err = sess.Select("id, name, email").From("dbr_people").Load(&m2) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m2, 3) + require.Equal(t, "test1@test.com", m2[1].Email) + require.Equal(t, "test1", m2[1].Name) + // the id value is used as the map key, so it is not hydrated in the struct + require.Equal(t, int64(0), m2[1].Id) + + var m3 map[string][]string + cnt, err = sess.Select("name, email").From("dbr_people").OrderAsc("id").Load(&m3) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m3, 2) + require.Equal(t, []string{"test1@test.com"}, m3["test1"]) + require.Equal(t, []string{"test2@test.com", "test3@test.com"}, m3["test2"]) + + var set map[string]struct{} + cnt, err = sess.Select("name").From("dbr_people").Load(&set) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, set, 2) + _, ok := set["test1"] + require.True(t, ok) + } +} + +func TestInterfaceLoader(t *testing.T) { + for _, sess := range testSession { + reset(t, sess) + + _, err := sess.InsertInto("dbr_people"). + Columns("name", "email"). + Values("test1", "test1@test.com"). + Values("test2", "test2@test.com"). + Values("test2", "test3@test.com"). + Exec() + + var m []interface{} + cnt, err := sess.Select("*").From("dbr_people").Load(InterfaceLoader(&m, dbrPerson{})) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m, 3) + person, ok := m[0].(dbrPerson) + require.True(t, ok) + require.Equal(t, "test1", person.Name) + } +} + +func TestPostgresArray(t *testing.T) { + sess := postgresSession + for _, v := range []string{ + `DROP TABLE IF EXISTS array_table`, + `CREATE TABLE array_table ( + val integer[] + )`, + } { + _, err := sess.Exec(v) + require.NoError(t, err) + } + + // INSERT INTO "array_table" ("val") VALUES ('{1,2,3}') + _, err := sess.InsertInto("array_table"). + Pair("val", pq.Array([]int64{1, 2, 3})). + Exec() + require.NoError(t, err) + + var ns []int64 + err = sess.Select("val").From("array_table").LoadOne(pq.Array(&ns)) + require.NoError(t, err) + + require.Equal(t, []int64{1, 2, 3}, ns) +} diff --git a/sqlmock_test.go b/sqlmock_test.go new file mode 100644 index 00000000..ff38ec7f --- /dev/null +++ b/sqlmock_test.go @@ -0,0 +1,32 @@ +package dbr + +import ( + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/gocraft/dbr/dialect" + "github.com/stretchr/testify/require" +) + +func TestSQLMock(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + conn := &Connection{ + DB: db, + EventReceiver: &NullEventReceiver{}, + Dialect: dialect.MySQL, + } + sess := conn.NewSession(nil) + + mock.ExpectQuery("SELECT id FROM suggestions"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2)) + id, err := sess.Select("id").From("suggestions").ReturnInt64s() + require.NoError(t, err) + require.Equal(t, []int64{1, 2}, id) + + mock.ExpectClose() + conn.Close() + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/transaction.go b/transaction.go index f6bafd19..e4f8f2ea 100644 --- a/transaction.go +++ b/transaction.go @@ -1,17 +1,27 @@ package dbr -import "database/sql" +import ( + "context" + "database/sql" + "time" +) -// Tx is a transaction for the given Session +// Tx is a transaction created by Session. type Tx struct { EventReceiver - Dialect Dialect + Dialect *sql.Tx + Timeout time.Duration } -// Begin creates a transaction for the given session -func (sess *Session) Begin() (*Tx, error) { - tx, err := sess.Connection.Begin() +// GetTimeout returns timeout enforced in Tx. +func (tx *Tx) GetTimeout() time.Duration { + return tx.Timeout +} + +// BeginTx creates a transaction with TxOptions. +func (sess *Session) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := sess.Connection.BeginTx(ctx, opts) if err != nil { return nil, sess.EventErr("dbr.begin.error", err) } @@ -21,10 +31,16 @@ func (sess *Session) Begin() (*Tx, error) { EventReceiver: sess, Dialect: sess.Dialect, Tx: tx, + Timeout: sess.GetTimeout(), }, nil } -// Commit finishes the transaction +// Begin creates a transaction for the given session. +func (sess *Session) Begin() (*Tx, error) { + return sess.BeginTx(context.Background(), nil) +} + +// Commit finishes the transaction. func (tx *Tx) Commit() error { err := tx.Tx.Commit() if err != nil { @@ -34,7 +50,7 @@ func (tx *Tx) Commit() error { return nil } -// Rollback cancels the transaction +// Rollback cancels the transaction. func (tx *Tx) Rollback() error { err := tx.Tx.Rollback() if err != nil { @@ -44,9 +60,13 @@ func (tx *Tx) Rollback() error { return nil } -// RollbackUnlessCommitted rollsback the transaction unless it has already been committed or rolled back. -// Useful to defer tx.RollbackUnlessCommitted() -- so you don't have to handle N failure cases -// Keep in mind the only way to detect an error on the rollback is via the event log. +// RollbackUnlessCommitted rollsback the transaction unless +// it has already been committed or rolled back. +// +// Useful to defer tx.RollbackUnlessCommitted(), so you don't +// have to handle N failure cases. +// Keep in mind the only way to detect an error on the rollback +// is via the event log. func (tx *Tx) RollbackUnlessCommitted() { err := tx.Tx.Rollback() if err == sql.ErrTxDone { diff --git a/transaction_test.go b/transaction_test.go index 3187ef61..b2819681 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -3,53 +3,57 @@ package dbr import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTransactionCommit(t *testing.T) { for _, sess := range testSession { + reset(t, sess) + tx, err := sess.Begin() - assert.NoError(t, err) + require.NoError(t, err) defer tx.RollbackUnlessCommitted() - id := nextID() + id := 1 result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) err = tx.Commit() - assert.NoError(t, err) + require.NoError(t, err) var person dbrPerson - err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadStruct(&person) - assert.Error(t, err) + err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadOne(&person) + require.Error(t, err) } } func TestTransactionRollback(t *testing.T) { for _, sess := range testSession { + reset(t, sess) + tx, err := sess.Begin() - assert.NoError(t, err) + require.NoError(t, err) defer tx.RollbackUnlessCommitted() - id := nextID() + id := 1 result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) var person dbrPerson - err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadStruct(&person) - assert.Error(t, err) + err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadOne(&person) + require.Error(t, err) } } diff --git a/types.go b/types.go index 4b58fd91..9c8ea369 100644 --- a/types.go +++ b/types.go @@ -12,22 +12,22 @@ import ( // Your app can use these Null types instead of the defaults. The sole benefit you get is a MarshalJSON method that is not retarded. // -// NullString is a type that can be null or a string +// NullString is a type that can be null or a string. type NullString struct { sql.NullString } -// NullFloat64 is a type that can be null or a float64 +// NullFloat64 is a type that can be null or a float64. type NullFloat64 struct { sql.NullFloat64 } -// NullInt64 is a type that can be null or an int +// NullInt64 is a type that can be null or an int. type NullInt64 struct { sql.NullInt64 } -// NullTime is a type that can be null or a time +// NullTime is a type that can be null or a time. type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL @@ -41,14 +41,14 @@ func (n NullTime) Value() (driver.Value, error) { return n.Time, nil } -// NullBool is a type that can be null or a bool +// NullBool is a type that can be null or a bool. type NullBool struct { sql.NullBool } var nullString = []byte("null") -// MarshalJSON correctly serializes a NullString to JSON +// MarshalJSON correctly serializes a NullString to JSON. func (n NullString) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.String) @@ -56,7 +56,7 @@ func (n NullString) MarshalJSON() ([]byte, error) { return nullString, nil } -// MarshalJSON correctly serializes a NullInt64 to JSON +// MarshalJSON correctly serializes a NullInt64 to JSON. func (n NullInt64) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.Int64) @@ -64,7 +64,7 @@ func (n NullInt64) MarshalJSON() ([]byte, error) { return nullString, nil } -// MarshalJSON correctly serializes a NullFloat64 to JSON +// MarshalJSON correctly serializes a NullFloat64 to JSON. func (n NullFloat64) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.Float64) @@ -72,7 +72,7 @@ func (n NullFloat64) MarshalJSON() ([]byte, error) { return nullString, nil } -// MarshalJSON correctly serializes a NullTime to JSON +// MarshalJSON correctly serializes a NullTime to JSON. func (n NullTime) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.Time) @@ -80,7 +80,7 @@ func (n NullTime) MarshalJSON() ([]byte, error) { return nullString, nil } -// MarshalJSON correctly serializes a NullBool to JSON +// MarshalJSON correctly serializes a NullBool to JSON. func (n NullBool) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.Bool) @@ -88,7 +88,7 @@ func (n NullBool) MarshalJSON() ([]byte, error) { return nullString, nil } -// UnmarshalJSON correctly deserializes a NullString from JSON +// UnmarshalJSON correctly deserializes a NullString from JSON. func (n *NullString) UnmarshalJSON(b []byte) error { var s interface{} if err := json.Unmarshal(b, &s); err != nil { @@ -97,16 +97,19 @@ func (n *NullString) UnmarshalJSON(b []byte) error { return n.Scan(s) } -// UnmarshalJSON correctly deserializes a NullInt64 from JSON +// UnmarshalJSON correctly deserializes a NullInt64 from JSON. func (n *NullInt64) UnmarshalJSON(b []byte) error { var s json.Number if err := json.Unmarshal(b, &s); err != nil { return err } + if s == "" { + return n.Scan(nil) + } return n.Scan(s) } -// UnmarshalJSON correctly deserializes a NullFloat64 from JSON +// UnmarshalJSON correctly deserializes a NullFloat64 from JSON. func (n *NullFloat64) UnmarshalJSON(b []byte) error { var s interface{} if err := json.Unmarshal(b, &s); err != nil { @@ -115,7 +118,7 @@ func (n *NullFloat64) UnmarshalJSON(b []byte) error { return n.Scan(s) } -// UnmarshalJSON correctly deserializes a NullTime from JSON +// UnmarshalJSON correctly deserializes a NullTime from JSON. func (n *NullTime) UnmarshalJSON(b []byte) error { // scan for null if bytes.Equal(b, nullString) { @@ -129,7 +132,7 @@ func (n *NullTime) UnmarshalJSON(b []byte) error { return n.Scan(t) } -// UnmarshalJSON correctly deserializes a NullBool from JSON +// UnmarshalJSON correctly deserializes a NullBool from JSON. func (n *NullBool) UnmarshalJSON(b []byte) error { var s interface{} if err := json.Unmarshal(b, &s); err != nil { @@ -138,26 +141,31 @@ func (n *NullBool) UnmarshalJSON(b []byte) error { return n.Scan(s) } +// NewNullInt64 creates a NullInt64 with Scan(). func NewNullInt64(v interface{}) (n NullInt64) { n.Scan(v) return } +// NewNullFloat64 creates a NullFloat64 with Scan(). func NewNullFloat64(v interface{}) (n NullFloat64) { n.Scan(v) return } +// NewNullString creates a NullString with Scan(). func NewNullString(v interface{}) (n NullString) { n.Scan(v) return } +// NewNullTime creates a NullTime with Scan(). func NewNullTime(v interface{}) (n NullTime) { n.Scan(v) return } +// NewNullBool creates a NullBool with Scan(). func NewNullBool(v interface{}) (n NullBool) { n.Scan(v) return diff --git a/types_test.go b/types_test.go index bbc92cd7..89d4dc8b 100644 --- a/types_test.go +++ b/types_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -29,24 +29,54 @@ func TestNullTypesScanning(t *testing.T) { }, } { for _, sess := range testSession { - test.in.Id = nextID() + reset(t, sess) + + test.in.Id = 1 _, err := sess.InsertInto("null_types").Columns("id", "string_val", "int64_val", "float64_val", "time_val", "bool_val").Record(test.in).Exec() - assert.NoError(t, err) + require.NoError(t, err) var record nullTypedRecord - err = sess.Select("*").From("null_types").Where(Eq("id", test.in.Id)).LoadStruct(&record) - assert.NoError(t, err) + err = sess.Select("*").From("null_types").Where(Eq("id", test.in.Id)).LoadOne(&record) + require.NoError(t, err) if sess.Dialect == dialect.PostgreSQL { // TODO: https://github.com/lib/pq/issues/329 if !record.TimeVal.Time.IsZero() { record.TimeVal.Time = record.TimeVal.Time.UTC() } } - assert.Equal(t, test.in, record) + require.Equal(t, test.in, record) } } } +func TestNullInt64Unmarshal(t *testing.T) { + var test struct { + Num NullInt64 + } + err := json.Unmarshal([]byte(`{"num":null}`), &test) + require.NoError(t, err) + require.Equal(t, int64(0), test.Num.Int64) + require.False(t, test.Num.Valid) +} + +func TestNullTypesActuallyNullJSON(t *testing.T) { + var out struct { + Bool NullBool `json:"b"` + Float NullFloat64 `json:"f"` + String NullString `json:"s"` + Time NullTime `json:"t"` + Int NullInt64 `json:"i"` + } + jsonBs := []byte(`{"b":null,"f":null,"s":null,"t":null,"i":null}`) + err := json.Unmarshal(jsonBs, &out) + require.NoError(t, err) + require.False(t, out.Bool.Valid) + require.False(t, out.Float.Valid) + require.False(t, out.String.Valid) + require.False(t, out.Time.Valid) + require.False(t, out.Int.Valid) +} + func TestNullTypesJSON(t *testing.T) { for _, test := range []struct { in interface{} @@ -87,17 +117,17 @@ func TestNullTypesJSON(t *testing.T) { } { // marshal ptr b, err := json.Marshal(test.in) - assert.NoError(t, err) - assert.Equal(t, test.want, string(b)) + require.NoError(t, err) + require.Equal(t, test.want, string(b)) // marshal value b, err = json.Marshal(test.in2) - assert.NoError(t, err) - assert.Equal(t, test.want, string(b)) + require.NoError(t, err) + require.Equal(t, test.want, string(b)) // unmarshal err = json.Unmarshal(b, test.out) - assert.NoError(t, err) - assert.Equal(t, test.in, test.out) + require.NoError(t, err) + require.Equal(t, test.in, test.out) } } diff --git a/union.go b/union.go index 73db6a81..d0b2f8bb 100644 --- a/union.go +++ b/union.go @@ -5,6 +5,7 @@ type union struct { all bool } +// Union builds `... UNION ...`. func Union(builder ...Builder) interface { Builder As(string) Builder @@ -14,6 +15,7 @@ func Union(builder ...Builder) interface { } } +// UnionAll builds `... UNION ALL ...`. func UnionAll(builder ...Builder) interface { Builder As(string) Builder diff --git a/update.go b/update.go index 523a1184..81dd21d5 100644 --- a/update.go +++ b/update.go @@ -1,16 +1,27 @@ package dbr -// UpdateStmt builds `UPDATE ...` +import ( + "context" + "database/sql" + "strconv" +) + +// UpdateStmt builds `UPDATE ...`. type UpdateStmt struct { - raw + runner + EventReceiver + Dialect - Table string - Value map[string]interface{} + raw - WhereCond []Builder + Table string + Value map[string]interface{} + WhereCond []Builder + LimitCount int64 } -// Build builds `UPDATE ...` in dialect +type UpdateBuilder = UpdateStmt + func (b *UpdateStmt) Build(d Dialect, buf Buffer) error { if b.raw.Query != "" { return b.raw.Build(d, buf) @@ -48,29 +59,74 @@ func (b *UpdateStmt) Build(d Dialect, buf Buffer) error { return err } } + + if b.LimitCount >= 0 { + buf.WriteString(" LIMIT ") + buf.WriteString(strconv.FormatInt(b.LimitCount, 10)) + } + return nil } -// Update creates an UpdateStmt +// Update creates an UpdateStmt. func Update(table string) *UpdateStmt { return &UpdateStmt{ - Table: table, - Value: make(map[string]interface{}), + Table: table, + Value: make(map[string]interface{}), + LimitCount: -1, } } -// UpdateBySql creates an UpdateStmt with raw query +// Update creates an UpdateStmt. +func (sess *Session) Update(table string) *UpdateStmt { + b := Update(table) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// Update creates an UpdateStmt. +func (tx *Tx) Update(table string) *UpdateStmt { + b := Update(table) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// UpdateBySql creates an UpdateStmt with raw query. func UpdateBySql(query string, value ...interface{}) *UpdateStmt { return &UpdateStmt{ raw: raw{ Query: query, Value: value, }, - Value: make(map[string]interface{}), + Value: make(map[string]interface{}), + LimitCount: -1, } } -// Where adds a where condition +// UpdateBySql creates an UpdateStmt with raw query. +func (sess *Session) UpdateBySql(query string, value ...interface{}) *UpdateStmt { + b := UpdateBySql(query, value...) + b.runner = sess + b.EventReceiver = sess + b.Dialect = sess.Dialect + return b +} + +// UpdateBySql creates an UpdateStmt with raw query. +func (tx *Tx) UpdateBySql(query string, value ...interface{}) *UpdateStmt { + b := UpdateBySql(query, value...) + b.runner = tx + b.EventReceiver = tx + b.Dialect = tx.Dialect + return b +} + +// Where adds a where condition. +// query can be Builder or string. value is used only if query type is string. func (b *UpdateStmt) Where(query interface{}, value ...interface{}) *UpdateStmt { switch query := query.(type) { case string: @@ -81,16 +137,29 @@ func (b *UpdateStmt) Where(query interface{}, value ...interface{}) *UpdateStmt return b } -// Set specifies a key-value pair +// Set updates column with value. func (b *UpdateStmt) Set(column string, value interface{}) *UpdateStmt { b.Value[column] = value return b } -// SetMap specifies a list of key-value pair +// SetMap specifies a map of (column, value) to update in bulk. func (b *UpdateStmt) SetMap(m map[string]interface{}) *UpdateStmt { for col, val := range m { b.Set(col, val) } return b } + +func (b *UpdateStmt) Limit(n uint64) *UpdateStmt { + b.LimitCount = int64(n) + return b +} + +func (b *UpdateStmt) Exec() (sql.Result, error) { + return b.ExecContext(context.Background()) +} + +func (b *UpdateStmt) ExecContext(ctx context.Context) (sql.Result, error) { + return exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) +} diff --git a/update_builder.go b/update_builder.go deleted file mode 100644 index f41a0344..00000000 --- a/update_builder.go +++ /dev/null @@ -1,92 +0,0 @@ -package dbr - -import ( - "database/sql" - "fmt" -) - -type UpdateBuilder struct { - runner - EventReceiver - Dialect Dialect - - *UpdateStmt - - LimitCount int64 -} - -func (sess *Session) Update(table string) *UpdateBuilder { - return &UpdateBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - UpdateStmt: Update(table), - LimitCount: -1, - } -} - -func (tx *Tx) Update(table string) *UpdateBuilder { - return &UpdateBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - UpdateStmt: Update(table), - LimitCount: -1, - } -} - -func (sess *Session) UpdateBySql(query string, value ...interface{}) *UpdateBuilder { - return &UpdateBuilder{ - runner: sess, - EventReceiver: sess, - Dialect: sess.Dialect, - UpdateStmt: UpdateBySql(query, value...), - LimitCount: -1, - } -} - -func (tx *Tx) UpdateBySql(query string, value ...interface{}) *UpdateBuilder { - return &UpdateBuilder{ - runner: tx, - EventReceiver: tx, - Dialect: tx.Dialect, - UpdateStmt: UpdateBySql(query, value...), - LimitCount: -1, - } -} - -func (b *UpdateBuilder) Exec() (sql.Result, error) { - return exec(b.runner, b.EventReceiver, b, b.Dialect) -} - -func (b *UpdateBuilder) Set(column string, value interface{}) *UpdateBuilder { - b.UpdateStmt.Set(column, value) - return b -} - -func (b *UpdateBuilder) SetMap(m map[string]interface{}) *UpdateBuilder { - b.UpdateStmt.SetMap(m) - return b -} - -func (b *UpdateBuilder) Where(query interface{}, value ...interface{}) *UpdateBuilder { - b.UpdateStmt.Where(query, value...) - return b -} - -func (b *UpdateBuilder) Limit(n uint64) *UpdateBuilder { - b.LimitCount = int64(n) - return b -} - -func (b *UpdateBuilder) Build(d Dialect, buf Buffer) error { - err := b.UpdateStmt.Build(b.Dialect, buf) - if err != nil { - return err - } - if b.LimitCount >= 0 { - buf.WriteString(" LIMIT ") - buf.WriteString(fmt.Sprint(b.LimitCount)) - } - return nil -} diff --git a/update_test.go b/update_test.go index bc0007cd..851a67dd 100644 --- a/update_test.go +++ b/update_test.go @@ -4,17 +4,17 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUpdateStmt(t *testing.T) { buf := NewBuffer() builder := Update("table").Set("a", 1).Where(Eq("b", 2)) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "UPDATE `table` SET `a` = ? WHERE (`b` = ?)", buf.String()) - assert.Equal(t, []interface{}{1, 2}, buf.Value()) + require.Equal(t, "UPDATE `table` SET `a` = ? WHERE (`b` = ?)", buf.String()) + require.Equal(t, []interface{}{1, 2}, buf.Value()) } func BenchmarkUpdateValuesSQL(b *testing.B) { diff --git a/util.go b/util.go index 69f4b639..81af3e17 100644 --- a/util.go +++ b/util.go @@ -1,51 +1,68 @@ package dbr import ( - "bytes" "database/sql/driver" "reflect" - "unicode" + "strings" ) -func camelCaseToSnakeCase(name string) string { - buf := new(bytes.Buffer) +var NameMapping = camelCaseToSnakeCase + +func isUpper(b byte) bool { + return 'A' <= b && b <= 'Z' +} - runes := []rune(name) +func isLower(b byte) bool { + return 'a' <= b && b <= 'z' +} - for i := 0; i < len(runes); i++ { - buf.WriteRune(unicode.ToLower(runes[i])) - if i != len(runes)-1 && unicode.IsUpper(runes[i+1]) && - (unicode.IsLower(runes[i]) || unicode.IsDigit(runes[i]) || - (i != len(runes)-2 && unicode.IsLower(runes[i+2]))) { - buf.WriteRune('_') +func isDigit(b byte) bool { + return '0' <= b && b <= '9' +} + +func toLower(b byte) byte { + if isUpper(b) { + return b - 'A' + 'a' + } + return b +} + +func camelCaseToSnakeCase(name string) string { + var buf strings.Builder + buf.Grow(len(name) * 2) + + for i := 0; i < len(name); i++ { + buf.WriteByte(toLower(name[i])) + if i != len(name)-1 && isUpper(name[i+1]) && + (isLower(name[i]) || isDigit(name[i]) || + (i != len(name)-2 && isLower(name[i+2]))) { + buf.WriteByte('_') } } return buf.String() } -func structMap(value reflect.Value) map[string]reflect.Value { - m := make(map[string]reflect.Value) - structValue(m, value) - return m -} - var ( typeValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() ) -func structValue(m map[string]reflect.Value, value reflect.Value) { - if value.Type().Implements(typeValuer) { - return +type tagStore struct { + m map[reflect.Type][]string +} + +func newTagStore() *tagStore { + return &tagStore{ + m: make(map[reflect.Type][]string), } - switch value.Kind() { - case reflect.Ptr: - if value.IsNil() { - return - } - structValue(m, value.Elem()) - case reflect.Struct: - t := value.Type() +} + +func (s *tagStore) get(t reflect.Type) []string { + if t.Kind() != reflect.Struct { + return nil + } + if _, ok := s.m[t]; !ok { + l := make([]string, t.NumField()) for i := 0; i < t.NumField(); i++ { field := t.Field(i) if field.PkgPath != "" && !field.Anonymous { @@ -59,13 +76,66 @@ func structValue(m map[string]reflect.Value, value reflect.Value) { } if tag == "" { // no tag, but we can record the field name - tag = camelCaseToSnakeCase(field.Name) + tag = NameMapping(field.Name) + } + l[i] = tag + } + s.m[t] = l + } + return s.m[t] +} + +func (s *tagStore) findPtr(value reflect.Value, name []string, ptr []interface{}) error { + if value.CanAddr() && value.Addr().Type().Implements(typeScanner) { + ptr[0] = value.Addr().Interface() + return nil + } + switch value.Kind() { + case reflect.Struct: + s.findValueByName(value, name, ptr, true) + return nil + case reflect.Ptr: + if value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } + return s.findPtr(value.Elem(), name, ptr) + default: + ptr[0] = value.Addr().Interface() + return nil + } +} + +func (s *tagStore) findValueByName(value reflect.Value, name []string, ret []interface{}, retPtr bool) { + if value.Type().Implements(typeValuer) { + return + } + switch value.Kind() { + case reflect.Ptr: + if value.IsNil() { + return + } + s.findValueByName(value.Elem(), name, ret, retPtr) + case reflect.Struct: + l := s.get(value.Type()) + for i := 0; i < value.NumField(); i++ { + tag := l[i] + if tag == "" { + continue } fieldValue := value.Field(i) - if _, ok := m[tag]; !ok { - m[tag] = fieldValue + for i, want := range name { + if want != tag { + continue + } + if ret[i] == nil { + if retPtr { + ret[i] = fieldValue.Addr().Interface() + } else { + ret[i] = fieldValue + } + } } - structValue(m, fieldValue) + s.findValueByName(fieldValue, name, ret, retPtr) } } } diff --git a/util_test.go b/util_test.go index 43999661..0e9d0514 100644 --- a/util_test.go +++ b/util_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSnakeCase(t *testing.T) { @@ -46,40 +46,47 @@ func TestSnakeCase(t *testing.T) { want: "xml_name", }, } { - assert.Equal(t, test.want, camelCaseToSnakeCase(test.in)) + require.Equal(t, test.want, camelCaseToSnakeCase(test.in)) } } -func TestStructMap(t *testing.T) { +func BenchmarkCamelCaseToSnakeCase(b *testing.B) { + for i := 0; i < b.N; i++ { + camelCaseToSnakeCase("getHTTPResponseCode") + } +} + +func TestFindValueByName(t *testing.T) { for _, test := range []struct { - in interface{} - ok []string - bad []string + in interface{} + name []string + want []string }{ { in: struct { CreatedAt time.Time }{}, - ok: []string{"created_at"}, + name: []string{"created_at"}, + want: []string{"created_at"}, }, { in: struct { intVal int }{}, - bad: []string{"int_val"}, + name: []string{"int_val"}, }, { in: struct { IntVal int `db:"test"` }{}, - ok: []string{"test"}, - bad: []string{"int_val"}, + name: []string{"test"}, + want: []string{"test"}, }, { in: struct { IntVal int `db:"-"` }{}, - bad: []string{"int_val"}, + name: []string{"int_val"}, }, { in: struct { @@ -87,17 +94,21 @@ func TestStructMap(t *testing.T) { Test2 int } }{}, - ok: []string{"test2"}, + name: []string{"test2"}, + want: []string{"test2"}, }, } { - m := structMap(reflect.ValueOf(test.in)) - for _, c := range test.ok { - _, ok := m[c] - assert.True(t, ok) - } - for _, c := range test.bad { - _, ok := m[c] - assert.False(t, ok) + found := make([]interface{}, len(test.name)) + s := newTagStore() + s.findValueByName(reflect.ValueOf(test.in), test.name, found, false) + + var got []string + for i, v := range found { + if v != nil { + got = append(got, test.name[i]) + } } + + require.Equal(t, test.want, got) } }