Skip to content

Commit

Permalink
TEST: use testify/require
Browse files Browse the repository at this point in the history
  • Loading branch information
taylorchu committed May 10, 2018
1 parent 6ac14d2 commit e51ed66
Show file tree
Hide file tree
Showing 15 changed files with 149 additions and 179 deletions.
8 changes: 4 additions & 4 deletions condition_test.go
Expand Up @@ -4,7 +4,7 @@ import (
"testing"

"github.com/gocraft/dbr/dialect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCondition(t *testing.T) {
Expand Down Expand Up @@ -66,8 +66,8 @@ func TestCondition(t *testing.T) {
} {
buf := NewBuffer()
err := test.cond.Build(dialect.MySQL, buf)
assert.NoError(t, err)
assert.Equal(t, test.query, buf.String())
assert.Equal(t, test.value, buf.Value())
require.NoError(t, err)
require.Equal(t, test.query, buf.String())
require.Equal(t, test.value, buf.Value())
}
}
2 changes: 1 addition & 1 deletion dbr.go
Expand Up @@ -41,7 +41,7 @@ const (
// to send events, errors, and timings to
type Connection struct {
*sql.DB
Dialect Dialect
Dialect
EventReceiver
}

Expand Down
88 changes: 40 additions & 48 deletions dbr_test.go
Expand Up @@ -2,7 +2,6 @@ package dbr

import (
"context"
"database/sql"
"fmt"
"os"
"testing"
Expand All @@ -12,7 +11,7 @@ import (
"github.com/gocraft/dbr/dialect"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

//
Expand Down Expand Up @@ -87,7 +86,7 @@ func reset(t *testing.T, sess *Session) {
)`, autoIncrementType),
} {
_, err := sess.Exec(v)
assert.NoError(t, err)
require.NoError(t, err)
}
}

Expand All @@ -106,109 +105,102 @@ func TestBasicCRUD(t *testing.T) {
}
// insert
result, err := sess.InsertInto("dbr_people").Columns(insertColumns...).Record(&jonathan).Exec()
assert.NoError(t, err)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, rowsAffected)
require.NoError(t, err)
require.Equal(t, int64(1), rowsAffected)

assert.True(t, jonathan.Id > 0)
require.True(t, jonathan.Id > 0)
// select
var people []dbrPerson
count, err := sess.Select("*").From("dbr_people").Where(Eq("id", jonathan.Id)).Load(&people)
assert.NoError(t, err)
if assert.Equal(t, 1, count) {
assert.Equal(t, jonathan.Id, people[0].Id)
assert.Equal(t, jonathan.Name, people[0].Name)
assert.Equal(t, jonathan.Email, people[0].Email)
}
require.NoError(t, err)
require.Equal(t, 1, count)
require.Equal(t, jonathan.Id, people[0].Id)
require.Equal(t, jonathan.Name, people[0].Name)
require.Equal(t, jonathan.Email, people[0].Email)

// select id
ids, err := sess.Select("id").From("dbr_people").ReturnInt64s()
assert.NoError(t, err)
assert.Equal(t, 1, len(ids))
require.NoError(t, err)
require.Equal(t, 1, len(ids))

// select id limit
ids, err = sess.Select("id").From("dbr_people").Limit(1).ReturnInt64s()
assert.NoError(t, err)
assert.Equal(t, 1, len(ids))
require.NoError(t, err)
require.Equal(t, 1, len(ids))

// update
result, err = sess.Update("dbr_people").Where(Eq("id", jonathan.Id)).Set("name", "jonathan1").Exec()
assert.NoError(t, err)
require.NoError(t, err)

rowsAffected, err = result.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, rowsAffected)
require.NoError(t, err)
require.Equal(t, int64(1), rowsAffected)

var n NullInt64
sess.Select("count(*)").From("dbr_people").Where("name = ?", "jonathan1").LoadOne(&n)
assert.EqualValues(t, 1, n.Int64)
require.Equal(t, int64(1), n.Int64)

// delete
result, err = sess.DeleteFrom("dbr_people").Where(Eq("id", jonathan.Id)).Exec()
assert.NoError(t, err)
require.NoError(t, err)

rowsAffected, err = result.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, rowsAffected)
require.NoError(t, err)
require.Equal(t, int64(1), rowsAffected)

// select id
ids, err = sess.Select("id").From("dbr_people").ReturnInt64s()
assert.NoError(t, err)
assert.Equal(t, 0, len(ids))
require.NoError(t, err)
require.Equal(t, 0, len(ids))
}
}

func TestTimeout(t *testing.T) {
mysqlSession := createSession("mysql", mysqlDSN)
postgresSession := createSession("postgres", postgresDSN)
sqlite3Session := createSession("sqlite3", sqlite3DSN)

// all test sessions should be here
testSession := []*Session{mysqlSession, postgresSession, sqlite3Session}

for _, sess := range testSession {
reset(t, sess)

// session op timeout
sess.Timeout = time.Nanosecond
var people []dbrPerson
_, err := sess.Select("*").From("dbr_people").Load(&people)
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = sess.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = sess.Update("dbr_people").Set("name", "test1").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = sess.DeleteFrom("dbr_people").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)

// tx timeout
_, err = sess.Begin()
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

// tx op timeout
sess.Timeout = 0
tx, err := sess.Begin()
assert.NoError(t, err)
require.NoError(t, err)
defer tx.RollbackUnlessCommitted()
tx.Timeout = time.Nanosecond

_, err = tx.Select("*").From("dbr_people").Load(&people)
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = tx.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = tx.Update("dbr_people").Set("name", "test1").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)
require.Equal(t, context.DeadlineExceeded, err)

_, err = tx.DeleteFrom("dbr_people").Exec()
assert.EqualValues(t, context.DeadlineExceeded, err)

// tx commit timeout
sess.Timeout = time.Second
tx, err = sess.Begin()
assert.NoError(t, err)
defer tx.RollbackUnlessCommitted()
time.Sleep(2 * time.Second)
err = tx.Commit()
assert.EqualValues(t, sql.ErrTxDone, err)
require.Equal(t, context.DeadlineExceeded, err)
}
}
8 changes: 4 additions & 4 deletions delete_test.go
Expand Up @@ -4,16 +4,16 @@ import (
"testing"

"github.com/gocraft/dbr/dialect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDeleteStmt(t *testing.T) {
buf := NewBuffer()
builder := DeleteFrom("table").Where(Eq("a", 1))
err := builder.Build(dialect.MySQL, buf)
assert.NoError(t, err)
assert.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String())
assert.Equal(t, []interface{}{1}, buf.Value())
require.NoError(t, err)
require.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String())
require.Equal(t, []interface{}{1}, buf.Value())
}

func BenchmarkDeleteSQL(b *testing.B) {
Expand Down
8 changes: 4 additions & 4 deletions dialect/dialect_test.go
Expand Up @@ -3,7 +3,7 @@ package dialect
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMySQL(t *testing.T) {
Expand All @@ -20,7 +20,7 @@ func TestMySQL(t *testing.T) {
want: "`col`",
},
} {
assert.Equal(t, test.want, MySQL.QuoteIdent(test.in))
require.Equal(t, test.want, MySQL.QuoteIdent(test.in))
}
}

Expand All @@ -38,7 +38,7 @@ func TestPostgreSQL(t *testing.T) {
want: `"col"`,
},
} {
assert.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in))
require.Equal(t, test.want, PostgreSQL.QuoteIdent(test.in))
}
}

Expand All @@ -56,6 +56,6 @@ func TestSQLite3(t *testing.T) {
want: `"col"`,
},
} {
assert.Equal(t, test.want, SQLite3.QuoteIdent(test.in))
require.Equal(t, test.want, SQLite3.QuoteIdent(test.in))
}
}
14 changes: 8 additions & 6 deletions insert_test.go
Expand Up @@ -4,7 +4,7 @@ import (
"testing"

"github.com/gocraft/dbr/dialect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type insertTest struct {
Expand All @@ -19,18 +19,20 @@ func TestInsertStmt(t *testing.T) {
C: "two",
})
err := builder.Build(dialect.MySQL, buf)
assert.NoError(t, err)
assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String())
assert.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value())
require.NoError(t, err)
require.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String())
require.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value())
}

func TestPostgresReturning(t *testing.T) {
sess := postgresSession
reset(t, sess)

var person dbrPerson
err := sess.InsertInto("dbr_people").Columns("name").Record(&person).
Returning("id").Load(&person.Id)
assert.NoError(t, err)
assert.True(t, person.Id > 0)
require.NoError(t, err)
require.True(t, person.Id > 0)
}

func BenchmarkInsertValuesSQL(b *testing.B) {
Expand Down
18 changes: 10 additions & 8 deletions interpolate_test.go
Expand Up @@ -6,7 +6,7 @@ import (
"time"

"github.com/gocraft/dbr/dialect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestInterpolateIgnoreBinary(t *testing.T) {
Expand Down Expand Up @@ -48,10 +48,10 @@ func TestInterpolateIgnoreBinary(t *testing.T) {
}

err := i.interpolate(test.query, test.value, true)
assert.NoError(t, err)
require.NoError(t, err)

assert.Equal(t, test.wantQuery, i.String())
assert.Equal(t, test.wantValue, i.Value())
require.Equal(t, test.wantQuery, i.String())
require.Equal(t, test.wantValue, i.Value())
}
}

Expand Down Expand Up @@ -138,26 +138,28 @@ func TestInterpolateForDialect(t *testing.T) {
},
} {
s, err := InterpolateForDialect(test.query, test.value, dialect.MySQL)
assert.NoError(t, err)
assert.Equal(t, test.want, s)
require.NoError(t, err)
require.Equal(t, test.want, s)
}
}

// Attempts to test common SQL injection strings. See `InjectionAttempts` for
// more information on the source and the strings themselves.
func TestCommonSQLInjections(t *testing.T) {
for _, sess := range testSession {
reset(t, sess)

for _, injectionAttempt := range strings.Split(injectionAttempts, "\n") {
// Create a user with the attempted injection as the email address
_, err := sess.InsertInto("dbr_people").
Pair("name", injectionAttempt).
Exec()
assert.NoError(t, err)
require.NoError(t, err)

// SELECT the name back and ensure it's equal to the injection attempt
var name string
err = sess.Select("name").From("dbr_people").OrderDesc("id").Limit(1).LoadOne(&name)
assert.Equal(t, injectionAttempt, name)
require.Equal(t, injectionAttempt, name)
}
}
}
Expand Down

0 comments on commit e51ed66

Please sign in to comment.