Infer types with prepare instead of execution
Previously, we used the higher-level pgx Exec method to run a Postgres PREPARE
command and then inspected the Postgres catalog.

Instead, issue the prepare command through the lower-level pgconn, which
returns the field descriptions and parameter OIDs directly.
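For context, a minimal sketch of the new flow (not part of the commit): prepare through the lower-level pgconn and read the parameter and field type OIDs off the returned StatementDescription. The connection string and the `author` table are placeholders borrowed from pggen's README example.

```go
package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v4"
)

func main() {
	ctx := context.Background()
	// Placeholder connection string; point it at any database that has the
	// example author table.
	conn, err := pgx.Connect(ctx, "postgres://postgres:password@localhost:5432/postgres")
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	// Prepare the unnamed statement with the lower-level pgconn. Passing nil
	// for paramOIDs lets Postgres infer the type of each $n parameter.
	sd, err := conn.PgConn().Prepare(ctx, "", "SELECT first_name FROM author WHERE author_id = $1", nil)
	if err != nil {
		panic(err)
	}

	fmt.Println("parameter OIDs:", sd.ParamOIDs) // one OID per $n placeholder
	for _, field := range sd.Fields {
		fmt.Println("output column:", string(field.Name), "type OID:", field.DataTypeOID)
	}
}
```

The StatementDescription carries the same information the old code assembled from a throwaway execution plus the pg_prepared_statements catalog view.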
jschaf committed Jun 17, 2023
1 parent 85a8613 commit 55d8fc9
Showing 3 changed files with 57 additions and 108 deletions.
20 changes: 9 additions & 11 deletions ARCHITECTURE.md
@@ -21,23 +21,21 @@ in the following steps.
 `pginfer.TypedQuery` in [internal/pginfer/pginfer.go].

 To determine the Postgres types, pggen uses itself to compile the queries
-in [internal/pg/query.sql]. The queries leverage the Postgres catalog
-tables to get the input parameter types. Specifically, pggen determines
-input parameters types by using a `PREPARE` statement and querying the
-[`pg_prepared_statement`] table to get type information for each parameter.
+in [internal/pg/query.sql]. The queries leverage the Postgres prepare
+command to find the input parameter types.

-pggen determines output columns types and names by executing the query and
+pggen determines output columns types and names by preparing the query and
 reading the field descriptions returned with the query result rows. The
 field descriptions contain the type ID for each output column. The type ID
 is a Postgres object ID (OID), the primary key to identify a row in the
 [`pg_type`] catalog table.

-pggen determines if an output column can be null using heuristics. If a column
-cannot be null, pggen uses more ergonomic types to represent the output like
-`string` instead of `pgtype.Text`. The heuristics are quite simple; see
-[internal/pginfer/nullability.go]. A proper approach requires a control
-flow analysis to determine nullability. I've started down that road in
-[pgplan.go](./internal/pgplan/pgplan.go).
+pggen determines if an output column can be null using heuristics. If a
+column cannot be null, pggen uses more ergonomic types to represent the
+output like `string` instead of `pgtype.Text`. The heuristics are quite
+simple; see [internal/pginfer/nullability.go]. A proper approach requires a
+control flow analysis to determine nullability. I've started down that road
+in [pgplan.go](./internal/pgplan/pgplan.go).

 5. Transform each `*ast.File` into `codegen.QueryFile` in [generate.go]
    `parseQueries`.
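As a concrete illustration of the OID lookup described above: each output column's type OID identifies a row in `pg_type`. The sketch below resolves an OID to its type name; pggen's real lookup goes through its type fetcher (`FindTypesByOIDs` in the diff further down), and the package and function names here are made up.

```go
package pgtypename

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v4"
)

// typeNameForOID resolves a data type OID (e.g. a FieldDescription's
// DataTypeOID) to its pg_type name, such as "int4" or "text".
func typeNameForOID(ctx context.Context, conn *pgx.Conn, oid uint32) (string, error) {
	var name string
	err := conn.QueryRow(ctx, "SELECT typname FROM pg_type WHERE oid = $1", oid).Scan(&name)
	if err != nil {
		return "", fmt.Errorf("look up pg_type name for oid %d: %w", oid, err)
	}
	return name, nil
}
```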
7 changes: 3 additions & 4 deletions README.md
@@ -561,10 +561,9 @@ We'll walk through the generated file `author/query.sql.go`:
 example, the generated query for `SELECT author_id from author` returns
 `int32`, not a `<query_name>Row` struct.

-pggen infers struct field types by running the query. When Postgres returns
-query results, Postgres also sends the column types as a header for the
-results. pggen looks up the types in the header using the `pg_type` catalog
-table and chooses an appropriate Go type in
+pggen infers struct field types by preparing the query. When Postgres
+prepares a query, Postgres returns the parameter and column types as OIDs.
+pggen finds the type name from the returned OIDs in
 [internal/codegen/golang/gotype/types.go].

 Choosing an appropriate type is more difficult than might seem at first
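For illustration, an OID-to-Go-type table in the spirit of [internal/codegen/golang/gotype/types.go] might look like the sketch below. The real mapping is far larger and also handles nullability, enums, arrays, and composite types; the Go type names on the right are examples, not pggen's exact choices.

```go
package gotypesketch

import "github.com/jackc/pgtype"

// goTypeForOID is an illustrative subset of an OID-to-Go-type mapping.
// The OID constants come from github.com/jackc/pgtype.
var goTypeForOID = map[uint32]string{
	pgtype.BoolOID:        "bool",
	pgtype.Int4OID:        "int32",
	pgtype.Int8OID:        "int64",
	pgtype.TextOID:        "string",
	pgtype.TimestamptzOID: "time.Time",
}
```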
138 changes: 45 additions & 93 deletions internal/pginfer/pginfer.go
@@ -11,7 +11,6 @@ import (
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
"github.com/jschaf/pggen/internal/ast"
"github.com/jschaf/pggen/internal/errs"
"github.com/jschaf/pggen/internal/pg"
)

@@ -76,11 +75,7 @@ func NewInferrer(conn *pgx.Conn) *Inferrer {
 }

 func (inf *Inferrer) InferTypes(query *ast.SourceQuery) (TypedQuery, error) {
-    inputs, err := inf.inferInputTypes(query)
-    if err != nil {
-        return TypedQuery{}, fmt.Errorf("infer input types for query: %w", err)
-    }
-    outputs, err := inf.inferOutputTypes(query)
+    inputs, outputs, err := inf.prepareTypes(query)
     if err != nil {
         return TypedQuery{}, fmt.Errorf("infer output types for query: %w", err)
     }
@@ -108,81 +103,15 @@ func (inf *Inferrer) InferTypes(query *ast.SourceQuery) (TypedQuery, error) {
     }, nil
 }

-func (inf *Inferrer) inferInputTypes(query *ast.SourceQuery) (ps []InputParam, mErr error) {
-    if len(query.ParamNames) == 0 {
-        return nil, nil
-    }
-
-    // Prepare the query so we can get the parameter types from Postgres.
-    ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
-    defer cancel()
-    prepareName := "pggen_" + query.Name
-    prepareQuery := fmt.Sprintf(`PREPARE %s AS %s`, prepareName, query.PreparedSQL)
-    _, err := inf.conn.Exec(ctx, prepareQuery)
-    if err != nil {
-        return nil, fmt.Errorf("exec prepare statement to infer input query types: %w", err)
-    }
-    defer errs.Capture(&mErr, func() error { return inf.deallocateQuery(prepareName) }, "")
-
-    // Get the parameter types from the pg_prepared_statements table.
-    ctx, cancel = context.WithTimeout(context.Background(), defaultTimeout)
-    defer cancel()
-    catalogQuery := `SELECT parameter_types::int[] FROM pg_prepared_statements WHERE lower(name) = lower($1)`
-    row := inf.conn.QueryRow(ctx, catalogQuery, prepareName)
-    oids := make([]uint32, 0, len(query.ParamNames))
-    if err := row.Scan(&oids); err != nil {
-        return nil, fmt.Errorf("scan prepared parameter types: %w", err)
-    }
-    if len(oids) != len(query.ParamNames) {
-        return nil, fmt.Errorf("expected %d parameter types for query; got %d",
-            len(query.ParamNames), len(oids))
-    }
-    types, err := inf.typeFetcher.FindTypesByOIDs(oids...)
-    if err != nil {
-        return nil, fmt.Errorf("fetch oid types: %w", err)
-    }
-
-    // Build up the input params.
-    params := make([]InputParam, len(query.ParamNames))
-    for i := 0; i < len(params); i++ {
-        pgType := types[pgtype.OID(oids[i])]
-        params[i] = InputParam{
-            PgName: query.ParamNames[i],
-            PgType: pgType,
-        }
-    }
-    return params, nil
-}
-
-// Deallocates a prepared query. Implemented mostly for tests so we can reuse
-// the same query name. Postgres doesn't allow PREPARE with duplicates.
-func (inf *Inferrer) deallocateQuery(name string) error {
-    ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
-    defer cancel()
-    _, err := inf.conn.Exec(ctx, "DEALLOCATE "+name)
-    if err != nil {
-        return fmt.Errorf("deallocate query %s: %w", name, err)
-    }
-    return nil
-}
-
-func (inf *Inferrer) inferOutputTypes(query *ast.SourceQuery) ([]OutputColumn, error) {
+func (inf *Inferrer) prepareTypes(query *ast.SourceQuery) (_ []InputParam, _ []OutputColumn, mErr error) {
-    // Execute the query to get field descriptions of the output columns.
     ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
     defer cancel()
-    rows, err := inf.conn.Query(ctx, query.PreparedSQL, createParamArgs(query)...)
+
+    // If paramOIDs is null, Postgres infers the type for each parameter.
+    var paramOIDs []uint32
+    stmtDesc, err := inf.conn.PgConn().Prepare(ctx, "", query.PreparedSQL, paramOIDs)
     if err != nil {
-        return nil, fmt.Errorf("execute output query: %w", err)
-    }
-    descriptions := make([]pgproto3.FieldDescription, len(rows.FieldDescriptions()))
-    copy(descriptions, rows.FieldDescriptions()) // pgx reuses row objects
-    rows.Close()
-    // We can ignore the error if we got the field descriptions. Most queries
-    // will error with a not-null constraint violation since we use null for all
-    // parameters in createParamArgs. For :exec queries, ignore the error since we
-    // don't need the field descriptions.
-    haveDescriptions := len(descriptions) > 0 || query.ResultKind == ast.ResultKindExec
-    if err := rows.Err(); err != nil && !haveDescriptions {
         if pgErr, ok := err.(*pgconn.PgError); ok {
             msg := "fetch field descriptions: " + pgErr.Message
             if pgErr.Where != "" {
@@ -210,42 +139,65 @@ func (inf *Inferrer) inferOutputTypes(query *ast.SourceQuery) ([]OutputColumn, e
                 msg += "\n a RETURNING clause (this query is marked " + string(query.ResultKind) + ")."
                 msg += "\n Use :exec if you don't need the query output."
             }
-            return nil, fmt.Errorf(msg+"\n %w", pgErr)
+            return nil, nil, fmt.Errorf(msg+"\n %w", pgErr)
         }
+        return nil, nil, fmt.Errorf("prepare query to infer types: %w", err)
+    }
+
+    // Validate.
+    if len(stmtDesc.ParamOIDs) != len(query.ParamNames) {
+        return nil, nil, fmt.Errorf("expected %d parameter types for query; got %d", len(query.ParamNames), len(stmtDesc.ParamOIDs))
+    }
+
+    // Build input params.
+    var inputParams []InputParam
+    if len(stmtDesc.ParamOIDs) > 0 {
+        types, err := inf.typeFetcher.FindTypesByOIDs(stmtDesc.ParamOIDs...)
+        if err != nil {
+            return nil, nil, fmt.Errorf("fetch oid types: %w", err)
+        }
+        for i, oid := range stmtDesc.ParamOIDs {
+            inputType, ok := types[pgtype.OID(oid)]
+            if !ok {
+                return nil, nil, fmt.Errorf("no postgres type name found for parameter %s with oid %d", query.ParamNames[i], oid)
+            }
+            inputParams = append(inputParams, InputParam{
+                PgName: query.ParamNames[i],
+                PgType: inputType,
+            })
+        }
-        return nil, fmt.Errorf("fetch field descriptions: %w", err)
     }

     // Resolve type names of output column data type OIDs.
-    oids := make([]uint32, len(descriptions))
-    for i, desc := range descriptions {
-        oids[i] = desc.DataTypeOID
+    outputOIDs := make([]uint32, len(stmtDesc.Fields))
+    for i, desc := range stmtDesc.Fields {
+        outputOIDs[i] = desc.DataTypeOID
     }
-    types, err := inf.typeFetcher.FindTypesByOIDs(oids...)
+    outputTypes, err := inf.typeFetcher.FindTypesByOIDs(outputOIDs...)
     if err != nil {
-        return nil, fmt.Errorf("fetch oid types: %w", err)
+        return nil, nil, fmt.Errorf("fetch oid types: %w", err)
     }

     // Output nullability.
-    nullables, err := inf.inferOutputNullability(query, descriptions)
+    nullables, err := inf.inferOutputNullability(query, stmtDesc.Fields)
     if err != nil {
-        return nil, fmt.Errorf("infer output type nullability: %w", err)
+        return nil, nil, fmt.Errorf("infer output type nullability: %w", err)
     }

     // Create output columns
-    var outs []OutputColumn
-    for i, desc := range descriptions {
-        pgType, ok := types[pgtype.OID(desc.DataTypeOID)]
+    var outputColumns []OutputColumn
+    for i, desc := range stmtDesc.Fields {
+        pgType, ok := outputTypes[pgtype.OID(desc.DataTypeOID)]
         if !ok {
-            return nil, fmt.Errorf("no type name found for oid %d", desc.DataTypeOID)
+            return nil, nil, fmt.Errorf("no postgres type name found for column %s with oid %d", string(desc.Name), desc.DataTypeOID)
         }

-        outs = append(outs, OutputColumn{
+        outputColumns = append(outputColumns, OutputColumn{
             PgName: string(desc.Name),
             PgType: pgType,
             Nullable: nullables[i],
         })
     }
-    return outs, nil
+    return inputParams, outputColumns, nil
 }

 // inferOutputNullability infers which of the output columns produced by the
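A side note on the diff above: the removed deallocateQuery helper existed because a named PREPARE cannot be issued twice with the same name (as its comment says), whereas PgConn().Prepare with an empty statement name targets the protocol's unnamed prepared statement, which is simply replaced by the next prepare. A hedged sketch, with hypothetical package and function names:

```go
package preparesketch

import (
	"context"

	"github.com/jackc/pgx/v4"
)

// reprepare prepares the same SQL repeatedly through the unnamed statement.
// Unlike the old named PREPARE flow, no DEALLOCATE is needed between calls.
func reprepare(ctx context.Context, conn *pgx.Conn) error {
	for i := 0; i < 3; i++ {
		sd, err := conn.PgConn().Prepare(ctx, "", "SELECT $1::text AS greeting", nil)
		if err != nil {
			return err
		}
		_ = sd.ParamOIDs // parameter and field OIDs are available on every pass
	}
	return nil
}
```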
