From e51ed665caac5bdce32d8e6bd7fe4dd2f4a62ada Mon Sep 17 00:00:00 2001 From: taylorchu Date: Tue, 8 May 2018 22:06:17 -0700 Subject: [PATCH] TEST: use testify/require --- condition_test.go | 8 +-- dbr.go | 2 +- dbr_test.go | 88 +++++++++++++++----------------- delete_test.go | 8 +-- dialect/dialect_test.go | 8 +-- insert_test.go | 14 ++--- interpolate_test.go | 18 ++++--- load_test.go | 52 +++++++++---------- postgres_bytea_benchmark_test.go | 6 +-- select_test.go | 8 +-- transaction.go | 36 ++----------- transaction_test.go | 26 +++++----- types_test.go | 38 +++++++------- update_test.go | 8 +-- util_test.go | 8 +-- 15 files changed, 149 insertions(+), 179 deletions(-) diff --git a/condition_test.go b/condition_test.go index a25f3fcc..62e634c6 100644 --- a/condition_test.go +++ b/condition_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCondition(t *testing.T) { @@ -66,8 +66,8 @@ func TestCondition(t *testing.T) { } { buf := NewBuffer() err := test.cond.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, test.query, buf.String()) - assert.Equal(t, test.value, buf.Value()) + require.NoError(t, err) + require.Equal(t, test.query, buf.String()) + require.Equal(t, test.value, buf.Value()) } } diff --git a/dbr.go b/dbr.go index a88ab47d..6d5ac208 100644 --- a/dbr.go +++ b/dbr.go @@ -41,7 +41,7 @@ const ( // to send events, errors, and timings to type Connection struct { *sql.DB - Dialect Dialect + Dialect EventReceiver } diff --git a/dbr_test.go b/dbr_test.go index 19fc4995..78402443 100644 --- a/dbr_test.go +++ b/dbr_test.go @@ -2,7 +2,6 @@ package dbr import ( "context" - "database/sql" "fmt" "os" "testing" @@ -12,7 +11,7 @@ import ( "github.com/gocraft/dbr/dialect" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // @@ -87,7 +86,7 @@ func reset(t *testing.T, sess *Session) { )`, autoIncrementType), } { _, err := sess.Exec(v) - assert.NoError(t, err) + require.NoError(t, err) } } @@ -106,61 +105,67 @@ func TestBasicCRUD(t *testing.T) { } // insert result, err := sess.InsertInto("dbr_people").Columns(insertColumns...).Record(&jonathan).Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) - assert.True(t, jonathan.Id > 0) + require.True(t, jonathan.Id > 0) // select var people []dbrPerson count, err := sess.Select("*").From("dbr_people").Where(Eq("id", jonathan.Id)).Load(&people) - assert.NoError(t, err) - if assert.Equal(t, 1, count) { - assert.Equal(t, jonathan.Id, people[0].Id) - assert.Equal(t, jonathan.Name, people[0].Name) - assert.Equal(t, jonathan.Email, people[0].Email) - } + require.NoError(t, err) + require.Equal(t, 1, count) + require.Equal(t, jonathan.Id, people[0].Id) + require.Equal(t, jonathan.Name, people[0].Name) + require.Equal(t, jonathan.Email, people[0].Email) // select id ids, err := sess.Select("id").From("dbr_people").ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 1, len(ids)) + require.NoError(t, err) + require.Equal(t, 1, len(ids)) // select id limit ids, err = sess.Select("id").From("dbr_people").Limit(1).ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 1, len(ids)) + require.NoError(t, err) + require.Equal(t, 1, len(ids)) // update result, err = sess.Update("dbr_people").Where(Eq("id", jonathan.Id)).Set("name", "jonathan1").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err = result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) var n NullInt64 sess.Select("count(*)").From("dbr_people").Where("name = ?", "jonathan1").LoadOne(&n) - assert.EqualValues(t, 1, n.Int64) + require.Equal(t, int64(1), n.Int64) // delete result, err = sess.DeleteFrom("dbr_people").Where(Eq("id", jonathan.Id)).Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err = result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) // select id ids, err = sess.Select("id").From("dbr_people").ReturnInt64s() - assert.NoError(t, err) - assert.Equal(t, 0, len(ids)) + require.NoError(t, err) + require.Equal(t, 0, len(ids)) } } func TestTimeout(t *testing.T) { + mysqlSession := createSession("mysql", mysqlDSN) + postgresSession := createSession("postgres", postgresDSN) + sqlite3Session := createSession("sqlite3", sqlite3DSN) + + // all test sessions should be here + testSession := []*Session{mysqlSession, postgresSession, sqlite3Session} + for _, sess := range testSession { reset(t, sess) @@ -168,47 +173,34 @@ func TestTimeout(t *testing.T) { sess.Timeout = time.Nanosecond var people []dbrPerson _, err := sess.Select("*").From("dbr_people").Load(&people) - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = sess.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = sess.Update("dbr_people").Set("name", "test1").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = sess.DeleteFrom("dbr_people").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) - - // tx timeout - _, err = sess.Begin() - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) // tx op timeout sess.Timeout = 0 tx, err := sess.Begin() - assert.NoError(t, err) + require.NoError(t, err) defer tx.RollbackUnlessCommitted() tx.Timeout = time.Nanosecond _, err = tx.Select("*").From("dbr_people").Load(&people) - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = tx.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = tx.Update("dbr_people").Set("name", "test1").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) + require.Equal(t, context.DeadlineExceeded, err) _, err = tx.DeleteFrom("dbr_people").Exec() - assert.EqualValues(t, context.DeadlineExceeded, err) - - // tx commit timeout - sess.Timeout = time.Second - tx, err = sess.Begin() - assert.NoError(t, err) - defer tx.RollbackUnlessCommitted() - time.Sleep(2 * time.Second) - err = tx.Commit() - assert.EqualValues(t, sql.ErrTxDone, err) + require.Equal(t, context.DeadlineExceeded, err) } } diff --git a/delete_test.go b/delete_test.go index e5cff17a..1046dbba 100644 --- a/delete_test.go +++ b/delete_test.go @@ -4,16 +4,16 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDeleteStmt(t *testing.T) { buf := NewBuffer() builder := DeleteFrom("table").Where(Eq("a", 1)) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String()) - assert.Equal(t, []interface{}{1}, buf.Value()) + require.NoError(t, err) + require.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String()) + require.Equal(t, []interface{}{1}, buf.Value()) } func BenchmarkDeleteSQL(b *testing.B) { diff --git a/dialect/dialect_test.go b/dialect/dialect_test.go index 85ea1086..418d3d08 100644 --- a/dialect/dialect_test.go +++ b/dialect/dialect_test.go @@ -3,7 +3,7 @@ package dialect import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMySQL(t *testing.T) { @@ -20,7 +20,7 @@ func TestMySQL(t *testing.T) { want: "`col`", }, } { - assert.Equal(t, test.want, MySQL.QuoteIdent(test.in)) + require.Equal(t, test.want, MySQL.QuoteIdent(test.in)) } } @@ -38,7 +38,7 @@ func TestPostgreSQL(t *testing.T) { want: `"col"`, }, } { - assert.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in)) + require.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in)) } } @@ -56,6 +56,6 @@ func TestSQLite3(t *testing.T) { want: `"col"`, }, } { - assert.Equal(t, test.want, SQLite3.QuoteIdent(test.in)) + require.Equal(t, test.want, SQLite3.QuoteIdent(test.in)) } } diff --git a/insert_test.go b/insert_test.go index 716664d6..ac363e24 100644 --- a/insert_test.go +++ b/insert_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type insertTest struct { @@ -19,18 +19,20 @@ func TestInsertStmt(t *testing.T) { C: "two", }) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) - assert.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value()) + require.NoError(t, err) + require.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) + require.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value()) } func TestPostgresReturning(t *testing.T) { sess := postgresSession + reset(t, sess) + var person dbrPerson err := sess.InsertInto("dbr_people").Columns("name").Record(&person). Returning("id").Load(&person.Id) - assert.NoError(t, err) - assert.True(t, person.Id > 0) + require.NoError(t, err) + require.True(t, person.Id > 0) } func BenchmarkInsertValuesSQL(b *testing.B) { diff --git a/interpolate_test.go b/interpolate_test.go index a198e26e..c5d89df6 100644 --- a/interpolate_test.go +++ b/interpolate_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInterpolateIgnoreBinary(t *testing.T) { @@ -48,10 +48,10 @@ func TestInterpolateIgnoreBinary(t *testing.T) { } err := i.interpolate(test.query, test.value, true) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, test.wantQuery, i.String()) - assert.Equal(t, test.wantValue, i.Value()) + require.Equal(t, test.wantQuery, i.String()) + require.Equal(t, test.wantValue, i.Value()) } } @@ -138,8 +138,8 @@ func TestInterpolateForDialect(t *testing.T) { }, } { s, err := InterpolateForDialect(test.query, test.value, dialect.MySQL) - assert.NoError(t, err) - assert.Equal(t, test.want, s) + require.NoError(t, err) + require.Equal(t, test.want, s) } } @@ -147,17 +147,19 @@ func TestInterpolateForDialect(t *testing.T) { // more information on the source and the strings themselves. func TestCommonSQLInjections(t *testing.T) { for _, sess := range testSession { + reset(t, sess) + for _, injectionAttempt := range strings.Split(injectionAttempts, "\n") { // Create a user with the attempted injection as the email address _, err := sess.InsertInto("dbr_people"). Pair("name", injectionAttempt). Exec() - assert.NoError(t, err) + require.NoError(t, err) // SELECT the name back and ensure it's equal to the injection attempt var name string err = sess.Select("name").From("dbr_people").OrderDesc("id").Limit(1).LoadOne(&name) - assert.Equal(t, injectionAttempt, name) + require.Equal(t, injectionAttempt, name) } } } diff --git a/load_test.go b/load_test.go index 67cafebb..cb46dac7 100644 --- a/load_test.go +++ b/load_test.go @@ -3,7 +3,7 @@ package dbr import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type stringSliceWithSQLScanner []string @@ -28,17 +28,17 @@ func TestSliceWithSQLScannerSelect(t *testing.T) { var stringSlice []string cnt, err := sess.Select("name").From("dbr_people").Load(&stringSlice) - assert.NoError(t, err) - assert.Equal(t, cnt, 3) - assert.Len(t, stringSlice, 3) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, stringSlice, 3) //string slice with sql.Scanner implemented, should act as a single record var sliceScanner stringSliceWithSQLScanner cnt, err = sess.Select("name").From("dbr_people").Load(&sliceScanner) - assert.NoError(t, err) - assert.Equal(t, cnt, 1) - assert.Len(t, sliceScanner, 1) + require.NoError(t, err) + require.Equal(t, 1, cnt) + require.Len(t, sliceScanner, 1) } } @@ -55,35 +55,35 @@ func TestMaps(t *testing.T) { var m map[string]string cnt, err := sess.Select("email, name").From("dbr_people").Load(&m) - assert.NoError(t, err) - assert.Equal(t, cnt, 3) - assert.Len(t, m, 3) - assert.Equal(t, m["test1@test.com"], "test1") + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m, 3) + require.Equal(t, "test1", m["test1@test.com"]) var m2 map[int64]*dbrPerson cnt, err = sess.Select("id, name, email").From("dbr_people").Load(&m2) - assert.NoError(t, err) - assert.Equal(t, cnt, 3) - assert.Len(t, m2, 3) - assert.Equal(t, m2[1].Email, "test1@test.com") - assert.Equal(t, m2[1].Name, "test1") + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m2, 3) + require.Equal(t, "test1@test.com", m2[1].Email) + require.Equal(t, "test1", m2[1].Name) // the id value is used as the map key, so it is not hydrated in the struct - assert.EqualValues(t, m2[1].Id, 0) + require.Equal(t, int64(0), m2[1].Id) var m3 map[string][]string cnt, err = sess.Select("name, email").From("dbr_people").OrderAsc("id").Load(&m3) - assert.NoError(t, err) - assert.Equal(t, cnt, 3) - assert.Len(t, m3, 2) - assert.Equal(t, m3["test1"], []string{"test1@test.com"}) - assert.Equal(t, m3["test2"], []string{"test2@test.com", "test3@test.com"}) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m3, 2) + require.Equal(t, []string{"test1@test.com"}, m3["test1"]) + require.Equal(t, []string{"test2@test.com", "test3@test.com"}, m3["test2"]) var set map[string]struct{} cnt, err = sess.Select("name").From("dbr_people").Load(&set) - assert.NoError(t, err) - assert.Equal(t, cnt, 3) - assert.Len(t, set, 2) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, set, 2) _, ok := set["test1"] - assert.True(t, ok) + require.True(t, ok) } } diff --git a/postgres_bytea_benchmark_test.go b/postgres_bytea_benchmark_test.go index 383d314d..17ee2e2d 100644 --- a/postgres_bytea_benchmark_test.go +++ b/postgres_bytea_benchmark_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func BenchmarkByteaNoBinaryEncode(b *testing.B) { @@ -24,12 +24,12 @@ func benchmarkBytea(b *testing.B, sess *Session) { )`, } { _, err := sess.Exec(v) - assert.NoError(b, err) + require.NoError(b, err) } b.ResetTimer() for i := 0; i < b.N; i++ { _, err := sess.InsertInto("bytea_table").Pair("val", data).Exec() - assert.NoError(b, err) + require.NoError(b, err) } } diff --git a/select_test.go b/select_test.go index 94928017..a9cb9da0 100644 --- a/select_test.go +++ b/select_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSelectStmt(t *testing.T) { @@ -20,10 +20,10 @@ func TestSelectStmt(t *testing.T) { Limit(3). Offset(4) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) - assert.Equal(t, "SELECT DISTINCT a, b FROM ? LEFT JOIN `table2` ON table.a1 = table.a2 WHERE (`c` = ?) GROUP BY d HAVING (`e` = ?) ORDER BY f ASC LIMIT 3 OFFSET 4", buf.String()) + require.NoError(t, err) + require.Equal(t, "SELECT DISTINCT a, b FROM ? LEFT JOIN `table2` ON table.a1 = table.a2 WHERE (`c` = ?) GROUP BY d HAVING (`e` = ?) ORDER BY f ASC LIMIT 3 OFFSET 4", buf.String()) // two functions cannot be compared - assert.Equal(t, 3, len(buf.Value())) + require.Equal(t, 3, len(buf.Value())) } func BenchmarkSelectSQL(b *testing.B) { diff --git a/transaction.go b/transaction.go index 9217d73d..197efeeb 100644 --- a/transaction.go +++ b/transaction.go @@ -9,14 +9,9 @@ import ( // Tx is a transaction for the given Session type Tx struct { EventReceiver - Dialect Dialect + Dialect *sql.Tx Timeout time.Duration - - // normally we don't call the context cancelFunc. - // however, if we start a tx without explictly tx, - // we will need to call this after the transaction. - Cancel func() } // GetTimeout returns timeout enforced in Tx @@ -32,40 +27,21 @@ func (sess *Session) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, err } sess.Event("dbr.begin") - stx := &Tx{ + return &Tx{ EventReceiver: sess, Dialect: sess.Dialect, Tx: tx, - Cancel: func() {}, - } - deadline, ok := ctx.Deadline() - if ok { - stx.Timeout = deadline.Sub(time.Now()) - } - return stx, nil + Timeout: sess.GetTimeout(), + }, nil } // Begin creates a transaction for the given session func (sess *Session) Begin() (*Tx, error) { - ctx := context.Background() - var cancel func() - timeout := sess.GetTimeout() - if timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, timeout) - } - stx, err := sess.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - if cancel != nil { - stx.Cancel = cancel - } - return stx, nil + return sess.BeginTx(context.Background(), nil) } // Commit finishes the transaction func (tx *Tx) Commit() error { - defer tx.Cancel() err := tx.Tx.Commit() if err != nil { return tx.EventErr("dbr.commit.error", err) @@ -76,7 +52,6 @@ func (tx *Tx) Commit() error { // Rollback cancels the transaction func (tx *Tx) Rollback() error { - defer tx.Cancel() err := tx.Tx.Rollback() if err != nil { return tx.EventErr("dbr.rollback", err) @@ -89,7 +64,6 @@ func (tx *Tx) Rollback() error { // Useful to defer tx.RollbackUnlessCommitted() -- so you don't have to handle N failure cases // Keep in mind the only way to detect an error on the rollback is via the event log. func (tx *Tx) RollbackUnlessCommitted() { - defer tx.Cancel() err := tx.Tx.Rollback() if err == sql.ErrTxDone { // ok diff --git a/transaction_test.go b/transaction_test.go index 356e545e..b2819681 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -3,7 +3,7 @@ package dbr import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTransactionCommit(t *testing.T) { @@ -11,24 +11,24 @@ func TestTransactionCommit(t *testing.T) { reset(t, sess) tx, err := sess.Begin() - assert.NoError(t, err) + require.NoError(t, err) defer tx.RollbackUnlessCommitted() id := 1 result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) err = tx.Commit() - assert.NoError(t, err) + require.NoError(t, err) var person dbrPerson err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadOne(&person) - assert.Error(t, err) + require.Error(t, err) } } @@ -37,23 +37,23 @@ func TestTransactionRollback(t *testing.T) { reset(t, sess) tx, err := sess.Begin() - assert.NoError(t, err) + require.NoError(t, err) defer tx.RollbackUnlessCommitted() id := 1 result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, err := result.RowsAffected() - assert.NoError(t, err) - assert.EqualValues(t, 1, rowsAffected) + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) var person dbrPerson err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadOne(&person) - assert.Error(t, err) + require.Error(t, err) } } diff --git a/types_test.go b/types_test.go index 6e21b893..89d4dc8b 100644 --- a/types_test.go +++ b/types_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -33,18 +33,18 @@ func TestNullTypesScanning(t *testing.T) { test.in.Id = 1 _, err := sess.InsertInto("null_types").Columns("id", "string_val", "int64_val", "float64_val", "time_val", "bool_val").Record(test.in).Exec() - assert.NoError(t, err) + require.NoError(t, err) var record nullTypedRecord err = sess.Select("*").From("null_types").Where(Eq("id", test.in.Id)).LoadOne(&record) - assert.NoError(t, err) + require.NoError(t, err) if sess.Dialect == dialect.PostgreSQL { // TODO: https://github.com/lib/pq/issues/329 if !record.TimeVal.Time.IsZero() { record.TimeVal.Time = record.TimeVal.Time.UTC() } } - assert.Equal(t, test.in, record) + require.Equal(t, test.in, record) } } } @@ -54,9 +54,9 @@ func TestNullInt64Unmarshal(t *testing.T) { Num NullInt64 } err := json.Unmarshal([]byte(`{"num":null}`), &test) - assert.NoError(t, err) - assert.Equal(t, int64(0), test.Num.Int64) - assert.False(t, test.Num.Valid) + require.NoError(t, err) + require.Equal(t, int64(0), test.Num.Int64) + require.False(t, test.Num.Valid) } func TestNullTypesActuallyNullJSON(t *testing.T) { @@ -69,12 +69,12 @@ func TestNullTypesActuallyNullJSON(t *testing.T) { } jsonBs := []byte(`{"b":null,"f":null,"s":null,"t":null,"i":null}`) err := json.Unmarshal(jsonBs, &out) - assert.NoError(t, err) - assert.False(t, out.Bool.Valid) - assert.False(t, out.Float.Valid) - assert.False(t, out.String.Valid) - assert.False(t, out.Time.Valid) - assert.False(t, out.Int.Valid) + require.NoError(t, err) + require.False(t, out.Bool.Valid) + require.False(t, out.Float.Valid) + require.False(t, out.String.Valid) + require.False(t, out.Time.Valid) + require.False(t, out.Int.Valid) } func TestNullTypesJSON(t *testing.T) { @@ -117,17 +117,17 @@ func TestNullTypesJSON(t *testing.T) { } { // marshal ptr b, err := json.Marshal(test.in) - assert.NoError(t, err) - assert.Equal(t, test.want, string(b)) + require.NoError(t, err) + require.Equal(t, test.want, string(b)) // marshal value b, err = json.Marshal(test.in2) - assert.NoError(t, err) - assert.Equal(t, test.want, string(b)) + require.NoError(t, err) + require.Equal(t, test.want, string(b)) // unmarshal err = json.Unmarshal(b, test.out) - assert.NoError(t, err) - assert.Equal(t, test.in, test.out) + require.NoError(t, err) + require.Equal(t, test.in, test.out) } } diff --git a/update_test.go b/update_test.go index bc0007cd..851a67dd 100644 --- a/update_test.go +++ b/update_test.go @@ -4,17 +4,17 @@ import ( "testing" "github.com/gocraft/dbr/dialect" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUpdateStmt(t *testing.T) { buf := NewBuffer() builder := Update("table").Set("a", 1).Where(Eq("b", 2)) err := builder.Build(dialect.MySQL, buf) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "UPDATE `table` SET `a` = ? WHERE (`b` = ?)", buf.String()) - assert.Equal(t, []interface{}{1, 2}, buf.Value()) + require.Equal(t, "UPDATE `table` SET `a` = ? WHERE (`b` = ?)", buf.String()) + require.Equal(t, []interface{}{1, 2}, buf.Value()) } func BenchmarkUpdateValuesSQL(b *testing.B) { diff --git a/util_test.go b/util_test.go index 49b8021f..98fcf96a 100644 --- a/util_test.go +++ b/util_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSnakeCase(t *testing.T) { @@ -46,7 +46,7 @@ func TestSnakeCase(t *testing.T) { want: "xml_name", }, } { - assert.Equal(t, test.want, camelCaseToSnakeCase(test.in)) + require.Equal(t, test.want, camelCaseToSnakeCase(test.in)) } } @@ -99,11 +99,11 @@ func TestStructMap(t *testing.T) { m := structMap(reflect.ValueOf(test.in)) for _, c := range test.ok { _, ok := m[c] - assert.True(t, ok) + require.True(t, ok) } for _, c := range test.bad { _, ok := m[c] - assert.False(t, ok) + require.False(t, ok) } } }