Skip to content

Commit

Permalink
add support for static queries (raw SQL)
Browse files Browse the repository at this point in the history
  • Loading branch information
bokwoon95 committed May 12, 2024
1 parent e1322f1 commit f076302
Show file tree
Hide file tree
Showing 4 changed files with 894 additions and 276 deletions.
181 changes: 120 additions & 61 deletions fetch_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ import (
"time"
)

var (
errMixedCalls = fmt.Errorf("rowmapper cannot mix calls to row.Values()/row.Columns()/row.ColumnTypes() with the other row methods")
errNoFieldsAccessed = fmt.Errorf("rowmapper did not access any fields, unable to determine fields to insert into query")
errForbiddenCalls = fmt.Errorf("rowmapper can only contain calls to row.Values()/row.Columns()/row.ColumnTypes() because query's SELECT clause is not dynamic")
)

// Default dialect used by all queries (if no dialect is explicitly provided).
var DefaultDialect atomic.Pointer[string]

Expand Down Expand Up @@ -62,11 +56,14 @@ func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(
dialect = *defaultDialect
}
}
// If we can't set the fetchable fields, the query is static.
_, ok := query.SetFetchableFields(nil)
cursor = &Cursor[T]{
ctx: ctx,
rowmapper: rowmapper,
row: &Row{
dialect: dialect,
dialect: dialect,
queryIsStatic: !ok,
},
queryStats: QueryStats{
Dialect: dialect,
Expand All @@ -75,21 +72,12 @@ func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(
},
}

// Call the rowmapper to populate row.fields and row.scanDest.
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
var ok bool
if cursor.row.rawSQLMode && len(cursor.row.fields) > 0 {
return nil, errMixedCalls
}

// Insert the fields into the query.
query, ok = query.SetFetchableFields(cursor.row.fields)
if ok && len(cursor.row.fields) == 0 {
return nil, errNoFieldsAccessed
}
if !ok && len(cursor.row.fields) > 0 {
return nil, errForbiddenCalls
// If the query is dynamic, call the rowmapper to populate row.fields and
// row.scanDest. Then, insert those fields back into the query.
if !cursor.row.queryIsStatic {
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
query, _ = query.SetFetchableFields(cursor.row.fields)
}

// Build query.
Expand Down Expand Up @@ -134,6 +122,29 @@ func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(
return nil, cursor.queryStats.Err
}

// If the query is static, we now know the number of columns returned by
// the query and can allocate the values slice and scanDest slice for
// scanning later.
if cursor.row.queryIsStatic {
cursor.row.columns, err = cursor.row.sqlRows.Columns()
if err != nil {
return nil, err
}
cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
if err != nil {
return nil, err
}
cursor.row.columnIndex = make(map[string]int)
for index, column := range cursor.row.columns {
cursor.row.columnIndex[column] = index
}
cursor.row.values = make([]any, len(cursor.row.columns))
cursor.row.scanDest = make([]any, len(cursor.row.columns))
for index := range cursor.row.values {
cursor.row.scanDest[index] = &cursor.row.values[index]
}
}

// Allocate the resultsBuffer.
if cursor.logSettings.IncludeResults > 0 {
cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
Expand All @@ -158,13 +169,11 @@ func (cursor *Cursor[T]) RowCount() int64 { return cursor.queryStats.RowCount.In

// Result returns the cursor result.
func (cursor *Cursor[T]) Result() (result T, err error) {
if !cursor.row.rawSQLMode {
err = cursor.row.sqlRows.Scan(cursor.row.scanDest...)
if err != nil {
cursor.log()
fieldMappings := getFieldMappings(cursor.queryStats.Dialect, cursor.row.fields, cursor.row.scanDest)
return result, fmt.Errorf("please check if your mapper function is correct:%s\n%w", fieldMappings, err)
}
err = cursor.row.sqlRows.Scan(cursor.row.scanDest...)
if err != nil {
cursor.log()
fieldMappings := getFieldMappings(cursor.queryStats.Dialect, cursor.row.fields, cursor.row.scanDest)
return result, fmt.Errorf("please check if your mapper function is correct:%s\n%w", fieldMappings, err)
}
// If results should be logged, write the row into the resultsBuffer.
if cursor.resultsBuffer != nil && cursor.queryStats.RowCount.Int64 <= int64(cursor.logSettings.IncludeResults) {
Expand All @@ -187,7 +196,7 @@ func (cursor *Cursor[T]) Result() (result T, err error) {
cursor.resultsBuffer.WriteString(rhs)
}
}
cursor.row.index = 0
cursor.row.runningIndex = 0
defer mapperFunctionPanicked(&err)
result = cursor.rowmapper(cursor.row)
return result, nil
Expand Down Expand Up @@ -272,6 +281,10 @@ type CompiledFetch[T any] struct {
args []any
params map[string][]int
rowmapper func(*Row) T
// if queryIsStatic is true, the rowmapper doesn't actually know what
// columns are in the query and it must be determined at runtime after
// running the query.
queryIsStatic bool
}

// NewCompiledFetch returns a new CompiledFetch.
Expand Down Expand Up @@ -305,30 +318,25 @@ func CompileFetchContext[T any](ctx context.Context, query Query, rowmapper func
dialect = *defaultDialect
}
}
// If we can't set the fetchable fields, the query is static.
_, ok := query.SetFetchableFields(nil)
compiledFetch = &CompiledFetch[T]{
dialect: dialect,
params: make(map[string][]int),
rowmapper: rowmapper,
dialect: dialect,
params: make(map[string][]int),
rowmapper: rowmapper,
queryIsStatic: !ok,
}
row := &Row{
dialect: dialect,
}

// Call the rowmapper to populate row.fields.
defer mapperFunctionPanicked(&err)
_ = rowmapper(row)
var ok bool
if row.rawSQLMode && len(row.fields) > 0 {
return nil, errMixedCalls
dialect: dialect,
queryIsStatic: !ok,
}

// Insert the fields into the query.
query, ok = query.SetFetchableFields(row.fields)
if ok && len(row.fields) == 0 {
return nil, errNoFieldsAccessed
}
if !ok && len(row.fields) > 0 {
return nil, errForbiddenCalls
// If the query is dynamic, call the rowmapper to populate row.fields.
// Then, insert those fields back into the query.
if !row.queryIsStatic {
defer mapperFunctionPanicked(&err)
_ = rowmapper(row)
query, _ = query.SetFetchableFields(row.fields)
}

// Build query.
Expand Down Expand Up @@ -361,7 +369,8 @@ func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, p
ctx: ctx,
rowmapper: compiledFetch.rowmapper,
row: &Row{
dialect: compiledFetch.dialect,
dialect: compiledFetch.dialect,
queryIsStatic: compiledFetch.queryIsStatic,
},
queryStats: QueryStats{
Dialect: compiledFetch.dialect,
Expand All @@ -372,10 +381,9 @@ func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, p
}

// Call the rowmapper to populate row.scanDest.
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
if err != nil {
return nil, err
if !cursor.row.queryIsStatic {
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
}

// Substitute params.
Expand Down Expand Up @@ -416,6 +424,29 @@ func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, p
return nil, cursor.queryStats.Err
}

// If the query is static, we now know the number of columns returned by
// the query and can allocate the values slice and scanDest slice for
// scanning later.
if cursor.row.queryIsStatic {
cursor.row.columns, err = cursor.row.sqlRows.Columns()
if err != nil {
return nil, err
}
cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
if err != nil {
return nil, err
}
cursor.row.columnIndex = make(map[string]int)
for index, column := range cursor.row.columns {
cursor.row.columnIndex[column] = index
}
cursor.row.values = make([]any, len(cursor.row.columns))
cursor.row.scanDest = make([]any, len(cursor.row.columns))
for index := range cursor.row.values {
cursor.row.scanDest[index] = &cursor.row.values[index]
}
}

// Allocate the resultsBuffer.
if cursor.logSettings.IncludeResults > 0 {
cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
Expand Down Expand Up @@ -494,6 +525,7 @@ func (compiledFetch *CompiledFetch[T]) PrepareContext(ctx context.Context, db DB
preparedFetch := &PreparedFetch[T]{
compiledFetch: NewCompiledFetch(compiledFetch.GetSQL()),
}
preparedFetch.compiledFetch.queryIsStatic = compiledFetch.queryIsStatic
if db == nil {
return nil, fmt.Errorf("db is nil")
}
Expand Down Expand Up @@ -551,7 +583,8 @@ func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params P
ctx: ctx,
rowmapper: preparedFetch.compiledFetch.rowmapper,
row: &Row{
dialect: preparedFetch.compiledFetch.dialect,
dialect: preparedFetch.compiledFetch.dialect,
queryIsStatic: preparedFetch.compiledFetch.queryIsStatic,
},
queryStats: QueryStats{
Dialect: preparedFetch.compiledFetch.dialect,
Expand All @@ -563,11 +596,10 @@ func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params P
logger: preparedFetch.logger,
}

// Call the rowmapper to populate row.scanDest.
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
if err != nil {
return nil, err
// If the query is dynamic, call the rowmapper to populate row.scanDest.
if !cursor.row.queryIsStatic {
defer mapperFunctionPanicked(&err)
_ = cursor.rowmapper(cursor.row)
}

// Substitute params.
Expand Down Expand Up @@ -596,6 +628,29 @@ func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params P
return nil, cursor.queryStats.Err
}

// If the query is static, we now know the number of columns returned by
// the query and can allocate the values slice and scanDest slice for
// scanning later.
if cursor.row.queryIsStatic {
cursor.row.columns, err = cursor.row.sqlRows.Columns()
if err != nil {
return nil, err
}
cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
if err != nil {
return nil, err
}
cursor.row.columnIndex = make(map[string]int)
for index, column := range cursor.row.columns {
cursor.row.columnIndex[column] = index
}
cursor.row.values = make([]any, len(cursor.row.columns))
cursor.row.scanDest = make([]any, len(cursor.row.columns))
for index := range cursor.row.values {
cursor.row.scanDest[index] = &cursor.row.values[index]
}
}

// Allocate the resultsBuffer.
if cursor.logSettings.IncludeResults > 0 {
cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
Expand Down Expand Up @@ -648,7 +703,9 @@ func (preparedFetch *PreparedFetch[T]) FetchAllContext(ctx context.Context, para

// GetCompiled returns a copy of the underlying CompiledFetch.
func (preparedFetch *PreparedFetch[T]) GetCompiled() *CompiledFetch[T] {
return NewCompiledFetch(preparedFetch.compiledFetch.GetSQL())
compiledFetch := NewCompiledFetch(preparedFetch.compiledFetch.GetSQL())
compiledFetch.queryIsStatic = preparedFetch.compiledFetch.queryIsStatic
return compiledFetch
}

// Close closes the PreparedFetch.
Expand Down Expand Up @@ -1048,6 +1105,8 @@ func getFieldMappings(dialect string, fields []Field, scanDest []any) string {
return b.String()
}

// TODO: inline cursorResult, cursorResults and execResult.

func cursorResult[T any](cursor *Cursor[T]) (result T, err error) {
for cursor.Next() {
result, err = cursor.Result()
Expand Down
Loading

0 comments on commit f076302

Please sign in to comment.