Skip to content

Commit

Permalink
Remove error from mapper function signature.
Browse files Browse the repository at this point in the history
Errors (if any) in mapper functions should be surfaced using panic() instead.
People shouldn't be burdened with writing `func(*Row) (T, error)` over and over
if 99% of the time they are never going to return any error.
  • Loading branch information
bokwoon95 committed Nov 30, 2022
1 parent 8b9daa1 commit d31a2ba
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 95 deletions.
67 changes: 25 additions & 42 deletions fetch_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Cursor[T any] struct {
ctx context.Context
row *Row
sqlRows *sql.Rows
rowmapper func(*Row) (T, error)
rowmapper func(*Row) T
stats QueryStats
logSettings LogSettings
logger SqLogger
Expand All @@ -28,16 +28,16 @@ type Cursor[T any] struct {
}

// FetchCursor returns a new cursor.
func FetchCursor[T any](db DB, q Query, rowmapper func(*Row) (T, error)) (*Cursor[T], error) {
func FetchCursor[T any](db DB, q Query, rowmapper func(*Row) T) (*Cursor[T], error) {
return fetchCursor(context.Background(), db, q, rowmapper, 1)
}

// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
func FetchCursorContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) (T, error)) (*Cursor[T], error) {
func FetchCursorContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (*Cursor[T], error) {
return fetchCursor(ctx, db, q, rowmapper, 1)
}

func fetchCursor[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) (T, error), skip int) (c *Cursor[T], err error) {
func fetchCursor[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T, skip int) (c *Cursor[T], err error) {
if db == nil {
return nil, fmt.Errorf("db is nil")
}
Expand All @@ -52,11 +52,8 @@ func fetchCursor[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row
// Get fields and dest from rowmapper
dialect := q.GetDialect()
c.row = newRow(dialect)
defer recoverRowmapperPanic(&err)
_, err = c.rowmapper(c.row)
if err != nil {
return nil, err
}
defer recoverPanic(&err)
_ = c.rowmapper(c.row)
c.row.active = true
if len(c.row.fields) == 0 || len(c.row.dest) == 0 {
return nil, fmt.Errorf("rowmapper did not yield any fields")
Expand Down Expand Up @@ -196,8 +193,8 @@ func (c *Cursor[T]) Result() (result T, err error) {
}
}
c.row.index = 0
defer recoverRowmapperPanic(&err)
result, c.stats.Err = c.rowmapper(c.row)
defer recoverPanic(&err)
result = c.rowmapper(c.row)
if c.stats.Err != nil {
c.log()
return result, c.stats.Err
Expand Down Expand Up @@ -237,7 +234,7 @@ func (c *Cursor[T]) Close() error {

// FetchOne returns the first result from running the given Query on the given
// DB.
func FetchOne[T any](db DB, q Query, rowmapper func(*Row) (T, error)) (T, error) {
func FetchOne[T any](db DB, q Query, rowmapper func(*Row) T) (T, error) {
cursor, err := fetchCursor(context.Background(), db, q, rowmapper, 1)
if err != nil {
return *new(T), err
Expand All @@ -247,7 +244,7 @@ func FetchOne[T any](db DB, q Query, rowmapper func(*Row) (T, error)) (T, error)
}

// FetchOneContext is like FetchOne but additionally requires a context.Context.
func FetchOneContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) (T, error)) (T, error) {
func FetchOneContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (T, error) {
cursor, err := fetchCursor(ctx, db, q, rowmapper, 1)
if err != nil {
return *new(T), err
Expand All @@ -257,7 +254,7 @@ func FetchOneContext[T any](ctx context.Context, db DB, q Query, rowmapper func(
}

// FetchAll returns all results from running the given Query on the given DB.
func FetchAll[T any](db DB, q Query, rowmapper func(*Row) (T, error)) ([]T, error) {
func FetchAll[T any](db DB, q Query, rowmapper func(*Row) T) ([]T, error) {
cursor, err := fetchCursor(context.Background(), db, q, rowmapper, 1)
if err != nil {
return nil, err
Expand All @@ -267,7 +264,7 @@ func FetchAll[T any](db DB, q Query, rowmapper func(*Row) (T, error)) ([]T, erro
}

// FetchAllContext is like FetchAll but additionally requires a context.Context.
func FetchAllContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) (T, error)) ([]T, error) {
func FetchAllContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) ([]T, error) {
cursor, err := fetchCursor(ctx, db, q, rowmapper, 1)
if err != nil {
return nil, err
Expand All @@ -283,11 +280,11 @@ type CompiledFetch[T any] struct {
query string
args []any
params map[string][]int
rowmapper func(*Row) (T, error)
rowmapper func(*Row) T
}

// NewCompiledFetch returns a new CompiledFetch.
func NewCompiledFetch[T any](dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) (T, error)) *CompiledFetch[T] {
func NewCompiledFetch[T any](dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) *CompiledFetch[T] {
return &CompiledFetch[T]{
dialect: dialect,
query: query,
Expand All @@ -298,12 +295,12 @@ func NewCompiledFetch[T any](dialect string, query string, args []any, params ma
}

// CompileFetch returns a new CompileFetch.
func CompileFetch[T any](q Query, rowmapper func(*Row) (T, error)) (*CompiledFetch[T], error) {
func CompileFetch[T any](q Query, rowmapper func(*Row) T) (*CompiledFetch[T], error) {
return CompileFetchContext(context.Background(), q, rowmapper)
}

// CompileFetchContext is like CompileFetch but accpets a context.Context.
func CompileFetchContext[T any](ctx context.Context, q Query, rowmapper func(*Row) (T, error)) (f *CompiledFetch[T], err error) {
func CompileFetchContext[T any](ctx context.Context, q Query, rowmapper func(*Row) T) (f *CompiledFetch[T], err error) {
if q == nil {
return nil, fmt.Errorf("query is nil")
}
Expand All @@ -314,11 +311,8 @@ func CompileFetchContext[T any](ctx context.Context, q Query, rowmapper func(*Ro
// Get fields from rowmapper
dialect := q.GetDialect()
row := newRow(dialect)
defer recoverRowmapperPanic(&err)
_, err = rowmapper(row)
if err != nil {
return nil, err
}
defer recoverPanic(&err)
_ = rowmapper(row)
if len(row.fields) == 0 {
return nil, fmt.Errorf("rowmapper did not yield any fields")
}
Expand Down Expand Up @@ -366,8 +360,8 @@ func (f *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, params Params

// Get fields and dest from rowmapper
c.row = newRow(f.dialect)
defer recoverRowmapperPanic(&err)
_, err = c.rowmapper(c.row)
defer recoverPanic(&err)
_ = c.rowmapper(c.row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -456,7 +450,7 @@ func (f *CompiledFetch[T]) FetchAllContext(ctx context.Context, db DB, params Pa

// GetSQL returns a copy of the dialect, query, args, params and rowmapper that
// make up the CompiledFetch.
func (f *CompiledFetch[T]) GetSQL() (dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) (T, error)) {
func (f *CompiledFetch[T]) GetSQL() (dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) {
dialect = f.dialect
query = f.query
args = make([]any, len(f.args))
Expand Down Expand Up @@ -501,12 +495,12 @@ type PreparedFetch[T any] struct {
}

// PrepareFetch returns a new PreparedFetch.
func PrepareFetch[T any](db DB, q Query, rowmapper func(*Row) (T, error)) (*PreparedFetch[T], error) {
func PrepareFetch[T any](db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
return PrepareFetchContext(context.Background(), db, q, rowmapper)
}

// PrepareFetchContext is like PrepareFetch but additionally requires a context.Context.
func PrepareFetchContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) (T, error)) (*PreparedFetch[T], error) {
func PrepareFetchContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
compiledFetch, err := CompileFetchContext(ctx, q, rowmapper)
if err != nil {
return nil, err
Expand Down Expand Up @@ -539,8 +533,8 @@ func (f *PreparedFetch[T]) fetchCursor(ctx context.Context, params Params, skip

// Get fields and dest from rowmapper
c.row = newRow(f.compiled.dialect)
defer recoverRowmapperPanic(&err)
_, err = c.rowmapper(c.row)
defer recoverPanic(&err)
_ = c.rowmapper(c.row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -913,17 +907,6 @@ func (e *PreparedExec) exec(ctx context.Context, params Params, skip int) (resul
return execResult(res, &stats)
}

func recoverRowmapperPanic(err *error) {
if r := recover(); r != nil {
switch r := r.(type) {
case error:
*err = r
default:
*err = fmt.Errorf("rowmapper panic: %v", r)
}
}
}

func getFieldNames(ctx context.Context, dialect string, fields []Field) []string {
buf := bufpool.Get().(*bytes.Buffer)
buf.Reset()
Expand Down
15 changes: 6 additions & 9 deletions fetch_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ type Actor struct {
LastUpdate time.Time
}

func (actor Actor) RowMapper(a ACTOR) func(*Row) (Actor, error) {
return func(row *Row) (Actor, error) {
func (actor Actor) RowMapper(a ACTOR) func(*Row) Actor {
return func(row *Row) Actor {
actor.ActorID = row.IntField(a.ACTOR_ID)
actor.FirstName = row.StringField(a.FIRST_NAME)
actor.LastName = row.StringField(a.LAST_NAME)
actor.LastUpdate = row.TimeField(a.LAST_UPDATE)
return actor, nil
return actor
}
}

Expand Down Expand Up @@ -69,14 +69,13 @@ func TestFetchExec(t *testing.T) {
// Exec
res, err := Exec(Log(db), SQLite.
InsertInto(a).
ColumnValues(func(col *Column) error {
ColumnValues(func(col *Column) {
for _, actor := range __actors__ {
col.SetInt(a.ACTOR_ID, actor.ActorID)
col.SetString(a.FIRST_NAME, actor.FirstName)
col.SetString(a.LAST_NAME, actor.LastName)
col.SetTime(a.LAST_UPDATE, actor.LastUpdate)
}
return nil
}),
)
if err != nil {
Expand Down Expand Up @@ -121,12 +120,11 @@ func TestCompiledFetchExec(t *testing.T) {
// CompiledExec
insertActor, err := CompileExec(SQLite.
InsertInto(a).
ColumnValues(func(col *Column) error {
ColumnValues(func(col *Column) {
col.Set(a.ACTOR_ID, IntParam("actor_id", 0))
col.Set(a.FIRST_NAME, StringParam("first_name", ""))
col.Set(a.LAST_NAME, StringParam("last_name", ""))
col.Set(a.LAST_UPDATE, TimeParam("last_update", time.Time{}))
return nil
}),
)
if err != nil {
Expand Down Expand Up @@ -187,12 +185,11 @@ func TestPreparedFetchExec(t *testing.T) {
// PreparedExec
insertActor, err := PrepareExec(Log(db), SQLite.
InsertInto(a).
ColumnValues(func(col *Column) error {
ColumnValues(func(col *Column) {
col.Set(a.ACTOR_ID, IntParam("actor_id", 0))
col.Set(a.FIRST_NAME, StringParam("first_name", ""))
col.Set(a.LAST_NAME, StringParam("last_name", ""))
col.Set(a.LAST_UPDATE, TimeParam("last_update", time.Time{}))
return nil
}),
)
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions insert_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// InsertQuery represents an SQL INSERT query.
type InsertQuery struct {
Dialect string
ColumnMapper func(*Column) error
ColumnMapper func(*Column)
// WITH
CTEs []CTE
// INSERT INTO
Expand All @@ -30,14 +30,14 @@ type InsertQuery struct {
var _ Query = (*InsertQuery)(nil)

// WriteSQL implements the SQLWriter interface.
func (q InsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
var err error
func (q InsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) (err error) {
if q.ColumnMapper != nil {
col := &Column{
dialect: q.Dialect,
isUpdate: false,
}
err = q.ColumnMapper(col)
defer recoverPanic(&err)
q.ColumnMapper(col)
if err != nil {
return err
}
Expand Down Expand Up @@ -159,7 +159,7 @@ func (q InsertQuery) Values(values ...any) InsertQuery {
}

// ColumnValues sets the ColumnMapper field of the InsertQuery.
func (q InsertQuery) ColumnValues(colmapper func(*Column) error) InsertQuery {
func (q InsertQuery) ColumnValues(colmapper func(*Column)) InsertQuery {
q.ColumnMapper = colmapper
return q
}
Expand Down Expand Up @@ -317,7 +317,7 @@ func (q SQLiteInsertQuery) Values(values ...any) SQLiteInsertQuery {
}

// ColumnValues sets the ColumnMapper field of the SQLiteInsertQuery.
func (q SQLiteInsertQuery) ColumnValues(colmapper func(*Column) error) SQLiteInsertQuery {
func (q SQLiteInsertQuery) ColumnValues(colmapper func(*Column)) SQLiteInsertQuery {
q.ColumnMapper = colmapper
return q
}
Expand Down Expand Up @@ -418,7 +418,7 @@ func (q PostgresInsertQuery) Values(values ...any) PostgresInsertQuery {
}

// ColumnValues sets the ColumnMapper field of the PostgresInsertQuery.
func (q PostgresInsertQuery) ColumnValues(colmapper func(*Column) error) PostgresInsertQuery {
func (q PostgresInsertQuery) ColumnValues(colmapper func(*Column)) PostgresInsertQuery {
q.ColumnMapper = colmapper
return q
}
Expand Down Expand Up @@ -541,7 +541,7 @@ func (q MySQLInsertQuery) As(rowAlias string) MySQLInsertQuery {
}

// ColumnValues sets the ColumnMapper field of the MySQLInsertQuery.
func (q MySQLInsertQuery) ColumnValues(colmapper func(*Column) error) MySQLInsertQuery {
func (q MySQLInsertQuery) ColumnValues(colmapper func(*Column)) MySQLInsertQuery {
q.ColumnMapper = colmapper
return q
}
Expand Down Expand Up @@ -610,7 +610,7 @@ func (q SQLServerInsertQuery) Values(values ...any) SQLServerInsertQuery {
}

// ColumnValues sets the ColumnMapper field of the SQLServerInsertQuery.
func (q SQLServerInsertQuery) ColumnValues(colmapper func(*Column) error) SQLServerInsertQuery {
func (q SQLServerInsertQuery) ColumnValues(colmapper func(*Column)) SQLServerInsertQuery {
q.ColumnMapper = colmapper
return q
}
Expand Down
Loading

0 comments on commit d31a2ba

Please sign in to comment.