diff --git a/pg/catalog/analyze.go b/pg/catalog/analyze.go index 9c94e595..2dbc658c 100644 --- a/pg/catalog/analyze.go +++ b/pg/catalog/analyze.go @@ -783,7 +783,7 @@ func (ac *analyzeCtx) addRangeTableEntryForFunction( switch cls { case typeFuncScalar: // Single column with name from function. - cname := chooseScalarFunctionAlias(funcNames[i], alias, len(funcExprs)) + cname := ac.chooseScalarFunctionAlias(fexpr, funcNames[i], alias, len(funcExprs)) colNames = append(colNames, cname) colTypes = append(colTypes, retType) colTypMods = append(colTypMods, fexpr.exprTypMod()) @@ -979,13 +979,59 @@ func (ac *analyzeCtx) appendOutputParamColumns( // chooseScalarFunctionAlias picks the column name for a scalar function in FROM. // // pg: src/backend/parser/parse_relation.c — chooseScalarFunctionAlias -func chooseScalarFunctionAlias(funcName, rteAlias string, nfuncs int) string { +func (ac *analyzeCtx) chooseScalarFunctionAlias(fexpr AnalyzedExpr, funcName, rteAlias string, nfuncs int) string { + if pname := ac.funcResultName(fexpr); pname != "" { + return pname + } if nfuncs == 1 && rteAlias != "" { return rteAlias } return funcName } +// funcResultName returns the single named OUT/INOUT/TABLE parameter name for +// a function, if one exists. +// +// pg: src/backend/utils/fmgr/funcapi.c — get_func_result_name +func (ac *analyzeCtx) funcResultName(fexpr AnalyzedExpr) string { + call, ok := fexpr.(*FuncCallExpr) + if !ok || call.FuncOID == 0 { + return "" + } + + var modes []byte + var names []string + if up := ac.catalog.userProcs[call.FuncOID]; up != nil { + modes = up.ArgModes + names = up.ArgNames + } else if proc := ac.catalog.procByOID[call.FuncOID]; proc != nil { + modes = proc.ArgModes + names = proc.ArgNames + } + if len(modes) == 0 || len(names) == 0 { + return "" + } + + outCount := 0 + result := "" + for i, mode := range modes { + switch mode { + case 'i', 'v': + continue + case 'o', 'b', 't': + outCount++ + if outCount > 1 || i >= len(names) || names[i] == "" { + return "" + } + result = names[i] + } + } + if outCount != 1 { + return "" + } + return result +} + // figureColname extracts the function name from a raw expression for FROM alias. // // pg: src/backend/parser/parse_target.c — FigureColname diff --git a/pg/catalog/query_span_test.go b/pg/catalog/query_span_test.go index 31c25c4d..c074b1e1 100644 --- a/pg/catalog/query_span_test.go +++ b/pg/catalog/query_span_test.go @@ -696,6 +696,61 @@ func TestQuerySpan_ViewLineage(t *testing.T) { } } +func TestAnalyzeReturnsTableSingleColumnFunctionNames(t *testing.T) { + c := New() + _, err := c.Exec(` + CREATE FUNCTION plpgsql_t(x integer) + RETURNS TABLE(a integer) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT x; + END + $$; + `, nil) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + sql string + want string + }{ + { + name: "table source uses table column name", + sql: `SELECT * FROM plpgsql_t(1)`, + want: "a", + }, + { + name: "rte alias does not override named output parameter", + sql: `SELECT * FROM plpgsql_t(1) AS x`, + want: "a", + }, + { + name: "column alias overrides named output parameter", + sql: `SELECT * FROM plpgsql_t(1) AS x(y)`, + want: "y", + }, + { + name: "scalar call uses function name", + sql: `SELECT plpgsql_t(1)`, + want: "plpgsql_t", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := parseAndAnalyze(t, c, tt.sql) + if got := len(q.TargetList); got != 1 { + t.Fatalf("target count = %d, want 1", got) + } + if got := q.TargetList[0].ResName; got != tt.want { + t.Fatalf("target name = %q, want %q", got, tt.want) + } + }) + } +} + func TestQuerySpan_DifferentSchemas(t *testing.T) { // Test cross-schema references c := New()