Skip to content

Commit

Permalink
move codecs into template
Browse files Browse the repository at this point in the history
  • Loading branch information
jschaf committed Jan 26, 2024
1 parent 1b29c21 commit 2b41dd1
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 30 deletions.
112 changes: 87 additions & 25 deletions internal/codegen/golang/query.gotemplate
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ func NewQuerier(conn genericConn) *DBQuerier {
return &DBQuerier{conn: conn}
}

{{- range .Declarers}}{{- "\n\n" -}}{{ .Declare $.PkgPath }}{{ end -}}
{{- end -}}

{{- range $i, $q := .Queries -}}
Expand All @@ -55,34 +54,41 @@ const {{ $q.SQLVarName }} = {{ $q.EmitPreparedSQL }}
func (q *DBQuerier) {{ $q.Name }}(ctx context.Context {{- $q.EmitParams }}) ({{ $q.EmitResultType }}, error) {
ctx = context.WithValue(ctx, "pggen_query_name", "{{ $q.Name }}")
{{- if eq $q.ResultKind ":one" }}
row := q.conn.QueryRow(ctx, {{ $q.SQLVarName }} {{- $q.EmitParamNames }})
{{ $q.EmitResultTypeInit "item" }}
{{- $q.EmitResultDecoders }}
if err := row.Scan({{ $q.EmitRowScanArgs }}); err != nil {
return {{ $q.EmitResultExpr "item" }}, fmt.Errorf("query {{ $q.Name }}: %w", err)
}
{{- $q.EmitResultAssigns "item" }}
return {{ $q.EmitResultExpr "item" }}, nil
{{- else if eq $q.ResultKind ":many" }}
rows, err := q.conn.Query(ctx, {{ $q.SQLVarName }} {{- $q.EmitParamNames }})
row, err := q.conn.Query(ctx, {{ $q.SQLVarName }} {{- $q.EmitParamNames }})
if err != nil {
return nil, fmt.Errorf("query {{ $q.Name }}: %w", err)
}
defer rows.Close()
{{ $q.EmitResultTypeInit "items" }}
{{- $q.EmitResultDecoders }}
for rows.Next() {
var item {{ $q.EmitResultElem }}
if err := rows.Scan({{- $q.EmitRowScanArgs -}}); err != nil {
return nil, fmt.Errorf("scan {{ $q.Name }} row: %w", err)
}
{{- $q.EmitResultAssigns "nil" }}
items = append(items, {{ $q.EmitResultExpr "item" }})
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("close {{ $q.Name }} rows: %w", err)

fds := rows.FieldDescriptions()
{{- range $i, $col := $q.Outputs -}}
{{ $q.EmitPlanScan $i $col}}
{{- end -}}

return pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) ({{ $q.EmitSingularResultType }}, error) {
vals := row.RawValues()
{{ $q.EmitResultTypeInit "item" }}
{{- range $i, $col := $q.Outputs -}}
{{- $q.EmitScanColumn $i $col -}}
{{- end -}}
})
{{- else if eq $q.ResultKind ":many" }}
row, err := q.conn.Query(ctx, {{ $q.SQLVarName }} {{- $q.EmitParamNames }})
if err != nil {
return nil, fmt.Errorf("query {{ $q.Name }}: %w", err)
}
return items, err

fds := rows.FieldDescriptions()
{{- range $i, $col := $q.Outputs -}}
{{ $q.EmitPlanScan $i $col}}
{{- end -}}

return pgx.CollectRows(rows, func(row pgx.CollectableRow) ({{ $q.EmitSingularResultType }}, error) {
vals := row.RawValues()
{{ $q.EmitResultTypeInit "item" }}
{{- range $i, $col := $q.Outputs -}}
{{- $q.EmitScanColumn $i $col -}}
{{- end -}}
})
{{- else if eq $q.ResultKind ":exec" }}
cmdTag, err := q.conn.Exec(ctx, {{ $q.SQLVarName }} {{- $q.EmitParamNames }})
if err != nil {
Expand All @@ -93,4 +99,60 @@ func (q *DBQuerier) {{ $q.Name }}(ctx context.Context {{- $q.EmitParams }}) ({{
}
{{- end -}}
{{- "\n" -}}
{{ if .IsLeader }}
type scanCacheKey struct {
oid uint32
format int16
typeName string
}

var (
plans = make(map[scanCacheKey]pgtype.ScanPlan, 16)
plansMu sync.RWMutex
)

func planScan(codec pgtype.Codec, fd pgconn.FieldDescription, target any) pgtype.ScanPlan {
key := scanCacheKey{fd.DataTypeOID, fd.Format, fmt.Sprintf("%T", target)}
plansMu.RLock()
plan := plans[key]
plansMu.RUnlock()
if plan != nil {
return plan
}
plan = codec.PlanScan(nil, fd.DataTypeOID, fd.Format, target)
plansMu.Lock()
plans[key] = plan
plansMu.Unlock()
return plan
}

type ptrScanner[T any] struct {
basePlan pgtype.ScanPlan
}

func (s ptrScanner[T]) Scan(src []byte, dst any) error {
if src == nil {
return nil
}
d := dst.(**T)
*d = new(T)
return s.basePlan.Scan(src, *d)
}

func planPtrScan[T any](codec pgtype.Codec, fd pgconn.FieldDescription, target *T) pgtype.ScanPlan {
key := scanCacheKey{fd.DataTypeOID, fd.Format, fmt.Sprintf("*%T", target)}
plansMu.RLock()
plan := plans[key]
plansMu.RUnlock()
if plan != nil {
return plan
}
basePlan := planScan(codec, fd, target)
ptrPlan := ptrScanner[T]{basePlan}
plansMu.Lock()
plans[key] = plan
plansMu.Unlock()
return ptrPlan
}
{{- end -}}
{{- end -}}
59 changes: 54 additions & 5 deletions internal/codegen/golang/templated_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,31 @@ func (tq TemplatedQuery) isInlineParams() bool {
return len(tq.Inputs) <= tq.InlineParamCount
}

// EmitPlanScan emits the variable that hold the pgtype.ScanPlan for a query
// output.
func (tq TemplatedQuery) EmitPlanScan(idx int, out TemplatedColumn) (string, error) {
switch tq.ResultKind {
case ast.ResultKindExec:
return "", fmt.Errorf("cannot EmitPlanScanArgs for :exec query %s", tq.Name)
case ast.ResultKindMany, ast.ResultKindOne:
break // okay
default:
return "", fmt.Errorf("unhandled EmitPlanScanArgs type: %s", tq.ResultKind)
}
return fmt.Sprintf("planScan%d = pgtype.TODOCodec{}, fs[%d], (*%s)(nil))", idx, idx, out.Type.BaseName()), nil
}

// EmitScanColumn emits scan call for a single TemplatedColumn.
func (tq TemplatedQuery) EmitScanColumn(idx int, out TemplatedColumn) (string, error) {
sb := &strings.Builder{}
_, _ = fmt.Fprintf(sb, "if err := plan%d.Scan(vals[%d], &item); err != nil\n", idx, idx)
sb.WriteString("\t\t\t")
_, _ = fmt.Fprintf(sb, `return item, fmt.Errorf("scan %s.%s: %%w", err)`, tq.Name, out.PgName)
sb.WriteString("\n")
sb.WriteString("\t\t}\n")
return sb.String(), nil
}

// EmitRowScanArgs emits the args to scan a single row from a pgx.Row or
// pgx.Rows.
func (tq TemplatedQuery) EmitRowScanArgs() (string, error) {
Expand Down Expand Up @@ -256,9 +281,10 @@ func (tq TemplatedQuery) EmitRowScanArgs() (string, error) {
return sb.String(), nil
}

// EmitResultType returns the string representing the overall query result type,
// meaning the return result.
func (tq TemplatedQuery) EmitResultType() (string, error) {
// EmitSingularResultType returns the string representing a single element
// of the overall query result type, like FindAuthorsRow when the overall return
// type is []FindAuthorsRow.
func (tq TemplatedQuery) EmitSingularResultType() (string, error) {
outs := removeVoidColumns(tq.Outputs)
switch tq.ResultKind {
case ast.ResultKindExec:
Expand All @@ -268,9 +294,9 @@ func (tq TemplatedQuery) EmitResultType() (string, error) {
case 0:
return "pgconn.CommandTag", nil
case 1:
return "[]" + outs[0].QualType, nil
return outs[0].QualType, nil
default:
return "[]" + tq.Name + "Row", nil
return tq.Name + "Row", nil
}
case ast.ResultKindOne:
switch len(outs) {
Expand All @@ -281,6 +307,29 @@ func (tq TemplatedQuery) EmitResultType() (string, error) {
default:
return tq.Name + "Row", nil
}
default:
return "", fmt.Errorf("unhandled EmitSingularResultType kind: %s", tq.ResultKind)
}
}

// EmitResultType returns the string representing the overall query result type,
// meaning the return result.
func (tq TemplatedQuery) EmitResultType() (string, error) {
rt, err := tq.EmitSingularResultType()
if err != nil {
return "", fmt.Errorf("unhandled EmitResultType: %w", err)
}
switch tq.ResultKind {
case ast.ResultKindExec:
return rt, nil
case ast.ResultKindMany:
outs := removeVoidColumns(tq.Outputs)
if len(outs) == 0 {
return rt, nil
}
return "[]" + rt, nil
case ast.ResultKindOne:
return rt, nil
default:
return "", fmt.Errorf("unhandled EmitResultType kind: %s", tq.ResultKind)
}
Expand Down

0 comments on commit 2b41dd1

Please sign in to comment.