From 2bdb5d3689fad4fadd76bca65ca42bc9fd8b6f1e Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Thu, 20 Nov 2025 07:32:17 -0800 Subject: [PATCH] Table as column draft --- enginetest/memory_engine_test.go | 29 ++++++++----- sql/expression/alias.go | 49 ++++++++++++++++++++++ sql/expression/function/registry.go | 64 +++++++++++++++++++++++++++++ sql/functions.go | 9 ++++ sql/planbuilder/from.go | 1 + sql/planbuilder/scalar.go | 39 ++++++++++++++---- sql/planbuilder/scope.go | 61 +++++++++++++++++++++++++++ 7 files changed, 234 insertions(+), 18 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index f1ec7b45d0..a482705349 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,23 +200,30 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "AS OF propagates to nested CALLs", - SetUpScript: []string{}, + Name: "AS OF propagates to nested CALLs", + SetUpScript: []string{ + `CREATE TABLE test (pk INT PRIMARY KEY, v1 VARCHAR(255));`, + `INSERT INTO test VALUES (1, 'a'), (2, 'b');`, + }, Assertions: []queries.ScriptTestAssertion{ { - Query: "create procedure create_proc() create table t (i int primary key, j int);", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, + Query: "SELECT temporarytesting(t) FROM test AS t;", + Expected: []sql.Row{}, }, { - Query: "call create_proc()", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, + Query: "SELECT temporarytesting(test) FROM test;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT temporarytesting(pk, test) FROM test;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT temporarytesting(v1, test, pk) FROM test;", + Expected: []sql.Row{}, }, }, }, diff --git a/sql/expression/alias.go b/sql/expression/alias.go index ea587555c9..aaf61d7607 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -165,3 +165,52 @@ func (e *Alias) WithChildren(children ...sql.Expression) (sql.Expression, error) // Name implements the Nameable interface. func (e *Alias) Name() string { return e.name } + +// TODO: DELETE EVERYTHING UNDER HERE ----------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- + +// This is an expression that will be returned from a Doltgres hook (GMS' hook will return a nil expression to indicate +// incompatibility). This function is just a stand-in for testing purposes. +type DoltgresHookExpression struct { + args []sql.Expression +} + +var _ sql.Expression = (*DoltgresHookExpression)(nil) + +func NewDoltgresHookExpression(args ...sql.Expression) sql.Expression { + return &DoltgresHookExpression{args: args} +} + +func (tt *DoltgresHookExpression) String() string { return "temporarytesting2" } + +// Type implements the Expression interface. +func (tt *DoltgresHookExpression) Type() sql.Type { return types.Int32 } + +// Eval implements the Expression interface. +func (tt *DoltgresHookExpression) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + rowLen := len(row) + return int32(-rowLen), nil +} + +// Resolved implements the Expression interface. +func (tt *DoltgresHookExpression) Resolved() bool { + return true +} + +// Children implements the Expression interface. +func (tt *DoltgresHookExpression) Children() []sql.Expression { + return tt.args +} + +// IsNullable implements the Expression interface. +func (tt *DoltgresHookExpression) IsNullable() bool { + return false +} + +// WithChildren implements the Expression interface. +func (*DoltgresHookExpression) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDoltgresHookExpression(children...), nil +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 8a6fccbc66..0e00c85264 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression/function/json" "github.com/dolthub/go-mysql-server/sql/expression/function/spatial" "github.com/dolthub/go-mysql-server/sql/expression/function/vector" + "github.com/dolthub/go-mysql-server/sql/types" ) // ErrFunctionAlreadyRegistered is thrown when a function is already registered @@ -342,6 +343,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "weekofyear", Fn: NewWeekOfYear}, sql.Function1{Name: "year", Fn: NewYear}, sql.FunctionN{Name: "yearweek", Fn: NewYearWeek}, + sql.FunctionN{Name: "temporarytesting", Fn: NewTemporaryTesting}, // TODO: DELETE ME } func GetLockingFuncs(ls *sql.LockSubsystem) []sql.Function { @@ -390,3 +392,65 @@ func (r Registry) mustRegister(fn ...sql.Function) { panic(err) } } + +// TODO: DELETE EVERYTHING UNDER HERE ----------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- +// --------------------------------------------------------------------------------------------------------- + +// This is an example Doltgres function. This exists solely for testing purposes, and will be deleted as noted by the +// massive comment above this. +type DoltgresFunction struct { + args []sql.Expression +} + +var _ sql.FunctionExpression = (*DoltgresFunction)(nil) + +func NewTemporaryTesting(args ...sql.Expression) (sql.Expression, error) { + return &DoltgresFunction{args: args}, nil +} + +// FunctionName implements sql.FunctionExpression +func (tt *DoltgresFunction) FunctionName() string { + return "temporarytesting" +} + +// Description implements sql.FunctionExpression +func (tt *DoltgresFunction) Description() string { + return "" +} + +func (tt *DoltgresFunction) String() string { return "temporarytesting()" } + +// Type implements the Expression interface. +func (tt *DoltgresFunction) Type() sql.Type { return types.Int32 } + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*DoltgresFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// Eval implements the Expression interface. +func (tt *DoltgresFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + rowLen := len(row) + return int32(rowLen), nil +} + +// Resolved implements the Expression interface. +func (tt *DoltgresFunction) Resolved() bool { + return true +} + +// Children implements the Expression interface. +func (tt *DoltgresFunction) Children() []sql.Expression { return tt.args } + +// IsNullable implements the Expression interface. +func (tt *DoltgresFunction) IsNullable() bool { + return false +} + +// WithChildren implements the Expression interface. +func (*DoltgresFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewTemporaryTesting(children...) +} diff --git a/sql/functions.go b/sql/functions.go index b0a804f3c4..f9bbec7d41 100644 --- a/sql/functions.go +++ b/sql/functions.go @@ -108,6 +108,7 @@ func NewFunction0(name string, fn func() Expression) Function0 { } } +// NewInstance implements the interface Function. func (fn Function0) NewInstance(args []Expression) (Expression, error) { if len(args) != 0 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args)) @@ -116,6 +117,7 @@ func (fn Function0) NewInstance(args []Expression) (Expression, error) { return fn.Fn(), nil } +// NewInstance implements the interface Function. func (fn Function1) NewInstance(args []Expression) (Expression, error) { if len(args) != 1 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args)) @@ -124,6 +126,7 @@ func (fn Function1) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0]), nil } +// NewInstance implements the interface Function. func (fn Function2) NewInstance(args []Expression) (Expression, error) { if len(args) != 2 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args)) @@ -132,6 +135,7 @@ func (fn Function2) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1]), nil } +// NewInstance implements the interface Function. func (fn Function3) NewInstance(args []Expression) (Expression, error) { if len(args) != 3 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args)) @@ -140,6 +144,7 @@ func (fn Function3) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2]), nil } +// NewInstance implements the interface Function. func (fn Function4) NewInstance(args []Expression) (Expression, error) { if len(args) != 4 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args)) @@ -148,6 +153,7 @@ func (fn Function4) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3]), nil } +// NewInstance implements the interface Function. func (fn Function5) NewInstance(args []Expression) (Expression, error) { if len(args) != 5 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args)) @@ -156,6 +162,7 @@ func (fn Function5) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil } +// NewInstance implements the interface Function. func (fn Function6) NewInstance(args []Expression) (Expression, error) { if len(args) != 6 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args)) @@ -164,6 +171,7 @@ func (fn Function6) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil } +// NewInstance implements the interface Function. func (fn Function7) NewInstance(args []Expression) (Expression, error) { if len(args) != 7 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args)) @@ -172,6 +180,7 @@ func (fn Function7) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil } +// NewInstance implements the interface Function. func (fn FunctionN) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args...) } diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 8df135a69e..3ff1f4f33b 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -751,6 +751,7 @@ func (b *Builder) buildResolvedTable(inScope *scope, db, schema, name string, as }) cols.Add(sql.ColumnId(id)) } + outScope.recordTableAsColumn(db, strings.ToLower(tab.Name()), tabId, rt) rt = rt.WithId(tabId).WithColumns(cols).(*plan.ResolvedTable) outScope.node = rt diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 5d83706230..b2edc91da5 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -136,19 +136,44 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { return sysVar } } - var err error if scope == ast.SetScope_User || scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { - err = sql.ErrUnknownUserVariable.New(colName) + err := sql.ErrUnknownUserVariable.New(colName) + b.handleErr(err) } else if scope == ast.SetScope_Global || scope == ast.SetScope_Session { - err = sql.ErrUnknownSystemVariable.New(colName) + err := sql.ErrUnknownSystemVariable.New(colName) + b.handleErr(err) } else if tblName != "" && !inScope.hasTable(tblName) { - err = sql.ErrTableNotFound.New(tblName) + err := sql.ErrTableNotFound.New(tblName) + b.handleErr(err) } else if tblName != "" { - err = sql.ErrTableColumnNotFound.New(tblName, colName) + err := sql.ErrTableColumnNotFound.New(tblName, colName) + b.handleErr(err) + } else if inScope.hasTable(colName) { + // TODO: only relevant for Doltgres, this will use a hook + scopeTable, ok := inScope.resolveColumnAsTable(dbName, colName) + if !ok { + err := sql.ErrColumnNotFound.New(v) + b.handleErr(err) + } + tableSch := scopeTable.rt.Schema() + astQualifier := ast.TableName{ + Name: ast.NewTableIdent(colName), // This must be the colName due to aliases + DbQualifier: ast.NewTableIdent(scopeTable.rt.Database().Name()), + } + fieldArgs := make([]sql.Expression, len(tableSch)) + for i := range tableSch { + astArg := ast.ColName{ + StoredProcVal: nil, + Qualifier: astQualifier, + Name: ast.NewColIdent(tableSch[i].Name), + } + fieldArgs[i] = b.buildScalar(inScope, &astArg) + } + return expression.NewDoltgresHookExpression(fieldArgs...) } else { - err = sql.ErrColumnNotFound.New(v) + err := sql.ErrColumnNotFound.New(v) + b.handleErr(err) } - b.handleErr(err) } origTbl := b.getOrigTblName(inScope.node, c.table) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 80ee268303..3d05b7f0e9 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -65,10 +65,22 @@ type scope struct { extraCols []scopeColumn // windowFuncs is a list of window functions in the current scope windowFuncs []scopeColumn + // tablesAsColumns allow for using tables in the same scope as columns, which is only valid in Doltgres + tablesAsColumns []resolvedScopeTable refsSubquery bool } +// resolvedScopeTable contains a table that may be resolved as though it were a column by the step that searches the +// column namespace. Integrators may implement support for treating table names as just another column, although +// standard MySQL does not allow this. +type resolvedScopeTable struct { + db string + table string + tableId sql.TableId + rt *plan.ResolvedTable +} + // resolveColumn matches a variable use to a column definition with a unique // expression id. |chooseFirst| is indicated for accepting ambiguous having and // group by columns. @@ -153,6 +165,32 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo return c, true } +// resolveColumnAsTable resolves a column as though it were a table, by searching the table space. This is not standard +// in MySQL, and is only used by integrators. +func (s *scope) resolveColumnAsTable(db, table string) (resolvedScopeTable, bool) { + var scopeTable resolvedScopeTable + var found bool + tabId := s.getTable(table) + for _, tab := range s.tablesAsColumns { + if tab.tableId == tabId && (strings.EqualFold(tab.db, db) || db == "") { + if found { + if scopeTable.tableId == tab.tableId { + continue + } + // TODO: fix error being invalid for a table + err := sql.ErrAmbiguousColumnName.New(table, []string{tab.table, scopeTable.table}) + s.handleErr(err) + } + scopeTable = tab + found = true + } + } + if !found && s.parent != nil { + return s.parent.resolveColumnAsTable(db, table) + } + return scopeTable, found +} + // getCol gets a scopeColumn based on a columnId func (s *scope) getCol(colId sql.ColumnId) (scopeColumn, bool) { if s.colset.Contains(colId) { @@ -176,6 +214,19 @@ func (s *scope) hasTable(table string) bool { return false } +// getTable returns the table ID matching the given name. +func (s *scope) getTable(table string) sql.TableId { + // TODO: this doesn't take a database, but maybe it doesn't need to? (only care if table name exists at all) + id, ok := s.tables[strings.ToLower(table)] + if ok { + return id + } + if s.parent != nil { + return s.parent.getTable(table) + } + return 0 +} + // triggerCol is used to hallucinate a new column during trigger DDL // when we fail a resolveColumn. func (s *scope) triggerCol(table, col string) (scopeColumn, bool) { @@ -557,6 +608,16 @@ func (s *scope) addTable(name string) sql.TableId { return s.tables[name] } +// recordTableAsColumn records the properties of a table defined in this scope as a referenceable column +func (s *scope) recordTableAsColumn(db, table string, tableId sql.TableId, rt *plan.ResolvedTable) { + s.tablesAsColumns = append(s.tablesAsColumns, resolvedScopeTable{ + db: db, + table: table, + tableId: tableId, + rt: rt, + }) +} + // addExtraColumn marks an auxiliary column used in an // aggregation, sorting, or having clause. func (s *scope) addExtraColumn(col scopeColumn) {