Skip to content
This repository has been archived by the owner on Mar 16, 2019. It is now read-only.

Commit

Permalink
Rename Rollback -> RollbackErr, add example
Browse files Browse the repository at this point in the history
  • Loading branch information
bahlo committed Sep 1, 2015
1 parent 8af2599 commit a2ac889
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
4 changes: 2 additions & 2 deletions database.go
Expand Up @@ -4,11 +4,11 @@ import (
"database/sql"
)

// Rollback does a rollback on the transaction and returns either the error
// RollbackErr does a rollback on the transaction and returns either the error
// from the rollback if there was one or the alternative.
// This is useful if you have multiple statments in a row but don't want to
// call rollback and check for errors every time.
func Rollback(tx *sql.Tx, alt error) error {
func RollbackErr(tx *sql.Tx, alt error) error {
if err := tx.Rollback(); err != nil {
return err
}
Expand Down
44 changes: 38 additions & 6 deletions database_test.go
Expand Up @@ -3,6 +3,7 @@ package abutil
import (
"database/sql"
"errors"
"fmt"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand All @@ -18,7 +19,7 @@ func mockDBContext(t *testing.T, fn func(*sql.DB)) {
fn(db)
}

func TestRollback(t *testing.T) {
func TestRollbackErr(t *testing.T) {
mockDBContext(t, func(db *sql.DB) {
sqlmock.ExpectBegin()
sqlmock.ExpectRollback()
Expand All @@ -29,15 +30,15 @@ func TestRollback(t *testing.T) {
}

alt := errors.New("Some alternative error")
err = Rollback(tx, alt)
err = RollbackErr(tx, alt)

if err != alt {
t.Errorf("Expected Rollback to return %v, but got %v", alt, err)
t.Errorf("Expected RollbackErr to return %v, but got %v", alt, err)
}
})
}

func TestRollbackFailing(t *testing.T) {
func TestRollbackErrFailing(t *testing.T) {
mockDBContext(t, func(db *sql.DB) {
rberr := errors.New("Some rollback error")

Expand All @@ -50,9 +51,40 @@ func TestRollbackFailing(t *testing.T) {
t.Error(err)
}

err = Rollback(tx, errors.New("This should not be used"))
err = RollbackErr(tx, errors.New("This should not be used"))
if err != rberr {
t.Errorf("Expected Rollback to return %v, but got %v", rberr, err)
t.Errorf("Expected RollbackErr to return %v, but got %v", rberr, err)
}
})
}

func rollbackDBContext(fn func(*sql.DB)) {
db, _ := sqlmock.New()
fn(db)
db.Close()
}

func RollbackErrExample() {
insertSomething := func(db *sql.DB) error {
tx, _ := db.Begin()

_, err := tx.Exec("INSERT INTO some_table (some_column) VALUES (?)",
"foobar")
if err != nil {
// We now have a one-liner instead of a check every time an error
// occurs
return RollbackErr(tx, err)
}

_, err = tx.Exec("DROP DATABASE foobar")
if err != nil {
return RollbackErr(tx, err)
}

return nil
}

rollbackDBContext(func(db *sql.DB) {
fmt.Println(insertSomething(db))
})
}

0 comments on commit a2ac889

Please sign in to comment.