From c91eb4b32d66f939a749036f16645207f99be3a9 Mon Sep 17 00:00:00 2001 From: Vincent Huang Date: Mon, 20 Apr 2026 12:02:06 -0700 Subject: [PATCH] feat(tidb): fork catalog, completion, deparse, scope packages from mysql (PR3a) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mechanical fork with import path renames only — no logic changes. Collation default left at utf8mb4_0900_ai_ci (PR3b will flip to utf8mb4_bin). Packages forked: - tidb/scope/ (1 file) - tidb/deparse/ (5 files) - tidb/catalog/ (120 files) - tidb/completion/ (7 files incl. SCENARIOS) Verification: - go build ./tidb/... — clean - go vet ./tidb/... — clean - go test ./tidb/... -short -count=1 — all passing Prepares the 4 packages for PR3b, which adds TiDB-specific catalog fields, option wiring, collation flip, completion candidates, and container tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- tidb/catalog/altercmds.go | 1119 +++ tidb/catalog/altercmds_test.go | 393 + tidb/catalog/analyze.go | 1004 +++ tidb/catalog/analyze_expr.go | 368 + tidb/catalog/analyze_targetlist.go | 195 + tidb/catalog/analyze_test.go | 2757 +++++++ tidb/catalog/bugfix_test.go | 72 + tidb/catalog/catalog.go | 54 + tidb/catalog/catalog_spotcheck_test.go | 1167 +++ tidb/catalog/constraint.go | 27 + tidb/catalog/container_comprehensive_test.go | 100 + tidb/catalog/container_reserved_kw_test.go | 159 + tidb/catalog/container_scenarios_test.go | 6747 ++++++++++++++++++ tidb/catalog/container_test.go | 411 ++ tidb/catalog/database.go | 37 + tidb/catalog/dbcmds.go | 92 + tidb/catalog/dbcmds_test.go | 96 + tidb/catalog/define.go | 169 + tidb/catalog/define_test.go | 592 ++ tidb/catalog/deparse_container_test.go | 2525 +++++++ tidb/catalog/deparse_expr.go | 367 + tidb/catalog/deparse_query.go | 292 + tidb/catalog/deparse_query_test.go | 246 + tidb/catalog/deparse_rules_test.go | 340 + tidb/catalog/dropcmds.go | 76 + tidb/catalog/dropcmds_test.go | 71 + tidb/catalog/errors.go | 210 + tidb/catalog/eventcmds.go | 196 + tidb/catalog/exec.go | 274 + tidb/catalog/exec_line_test.go | 49 + tidb/catalog/exec_test.go | 27 + tidb/catalog/function_types.go | 88 + tidb/catalog/index.go | 22 + tidb/catalog/indexcmds.go | 131 + tidb/catalog/indexcmds_test.go | 141 + tidb/catalog/query.go | 1310 ++++ tidb/catalog/query_expand.go | 38 + tidb/catalog/query_span_test.go | 349 + tidb/catalog/renamecmds.go | 32 + tidb/catalog/renamecmds_test.go | 49 + tidb/catalog/routinecmds.go | 328 + tidb/catalog/scenarios_ax_test.go | 657 ++ tidb/catalog/scenarios_c10_test.go | 359 + tidb/catalog/scenarios_c11_test.go | 335 + tidb/catalog/scenarios_c14_test.go | 279 + tidb/catalog/scenarios_c15_test.go | 150 + tidb/catalog/scenarios_c16_test.go | 591 ++ tidb/catalog/scenarios_c17_test.go | 390 + tidb/catalog/scenarios_c18_test.go | 524 ++ tidb/catalog/scenarios_c19_test.go | 424 ++ tidb/catalog/scenarios_c1_test.go | 483 ++ tidb/catalog/scenarios_c20_test.go | 424 ++ tidb/catalog/scenarios_c21_test.go | 545 ++ tidb/catalog/scenarios_c22_test.go | 417 ++ tidb/catalog/scenarios_c23_test.go | 282 + tidb/catalog/scenarios_c24_test.go | 330 + tidb/catalog/scenarios_c25_test.go | 314 + tidb/catalog/scenarios_c2_test.go | 599 ++ tidb/catalog/scenarios_c3_test.go | 375 + tidb/catalog/scenarios_c4_test.go | 624 ++ tidb/catalog/scenarios_c5_test.go | 472 ++ tidb/catalog/scenarios_c6_test.go | 566 ++ tidb/catalog/scenarios_c7_test.go | 465 ++ tidb/catalog/scenarios_c8_test.go | 357 + tidb/catalog/scenarios_c9_test.go | 367 + tidb/catalog/scenarios_helpers_test.go | 332 + tidb/catalog/scenarios_ps_test.go | 369 + tidb/catalog/scope.go | 155 + tidb/catalog/show.go | 674 ++ tidb/catalog/show_test.go | 195 + tidb/catalog/table.go | 224 + tidb/catalog/tablecmds.go | 1529 ++++ tidb/catalog/tablecmds_test.go | 306 + tidb/catalog/triggercmds.go | 138 + tidb/catalog/viewcmds.go | 323 + tidb/catalog/viewcmds_test.go | 153 + tidb/catalog/wt_10_1_test.go | 363 + tidb/catalog/wt_10_2_test.go | 328 + tidb/catalog/wt_11_1_test.go | 252 + tidb/catalog/wt_11_2_test.go | 334 + tidb/catalog/wt_12_1_test.go | 176 + tidb/catalog/wt_12_2_test.go | 162 + tidb/catalog/wt_13_1_test.go | 110 + tidb/catalog/wt_13_2_test.go | 173 + tidb/catalog/wt_1_1_test.go | 156 + tidb/catalog/wt_1_2_test.go | 179 + tidb/catalog/wt_1_3_test.go | 127 + tidb/catalog/wt_2_1_test.go | 88 + tidb/catalog/wt_2_2_test.go | 185 + tidb/catalog/wt_2_3_test.go | 71 + tidb/catalog/wt_2_4_test.go | 155 + tidb/catalog/wt_2_5_test.go | 50 + tidb/catalog/wt_2_6_test.go | 113 + tidb/catalog/wt_3_1_test.go | 372 + tidb/catalog/wt_3_2_test.go | 153 + tidb/catalog/wt_3_3_test.go | 608 ++ tidb/catalog/wt_3_4_test.go | 257 + tidb/catalog/wt_3_5_test.go | 415 ++ tidb/catalog/wt_3_6_test.go | 223 + tidb/catalog/wt_3_7_test.go | 274 + tidb/catalog/wt_4_1_test.go | 452 ++ tidb/catalog/wt_4_2_test.go | 315 + tidb/catalog/wt_4_3_test.go | 327 + tidb/catalog/wt_5_1_test.go | 211 + tidb/catalog/wt_5_2_test.go | 262 + tidb/catalog/wt_5_3_test.go | 249 + tidb/catalog/wt_5_4_test.go | 116 + tidb/catalog/wt_6_1_test.go | 377 + tidb/catalog/wt_6_2_test.go | 317 + tidb/catalog/wt_6_3_test.go | 100 + tidb/catalog/wt_7_1_test.go | 336 + tidb/catalog/wt_7_2_test.go | 272 + tidb/catalog/wt_7_3_test.go | 327 + tidb/catalog/wt_7_4_test.go | 392 + tidb/catalog/wt_8_1_test.go | 148 + tidb/catalog/wt_8_2_test.go | 263 + tidb/catalog/wt_8_3_test.go | 126 + tidb/catalog/wt_9_1_test.go | 268 + tidb/catalog/wt_9_2_test.go | 239 + tidb/catalog/wt_helpers_test.go | 56 + tidb/completion/SCENARIOS-completion.md | 442 ++ tidb/completion/completion.go | 163 + tidb/completion/completion_test.go | 2835 ++++++++ tidb/completion/integration_test.go | 353 + tidb/completion/refs.go | 264 + tidb/completion/refs_test.go | 174 + tidb/completion/resolve.go | 453 ++ tidb/deparse/deparse.go | 1757 +++++ tidb/deparse/deparse_test.go | 1324 ++++ tidb/deparse/resolver.go | 895 +++ tidb/deparse/resolver_test.go | 484 ++ tidb/deparse/rewrite.go | 287 + tidb/scope/scope.go | 129 + 133 files changed, 57484 insertions(+) create mode 100644 tidb/catalog/altercmds.go create mode 100644 tidb/catalog/altercmds_test.go create mode 100644 tidb/catalog/analyze.go create mode 100644 tidb/catalog/analyze_expr.go create mode 100644 tidb/catalog/analyze_targetlist.go create mode 100644 tidb/catalog/analyze_test.go create mode 100644 tidb/catalog/bugfix_test.go create mode 100644 tidb/catalog/catalog.go create mode 100644 tidb/catalog/catalog_spotcheck_test.go create mode 100644 tidb/catalog/constraint.go create mode 100644 tidb/catalog/container_comprehensive_test.go create mode 100644 tidb/catalog/container_reserved_kw_test.go create mode 100644 tidb/catalog/container_scenarios_test.go create mode 100644 tidb/catalog/container_test.go create mode 100644 tidb/catalog/database.go create mode 100644 tidb/catalog/dbcmds.go create mode 100644 tidb/catalog/dbcmds_test.go create mode 100644 tidb/catalog/define.go create mode 100644 tidb/catalog/define_test.go create mode 100644 tidb/catalog/deparse_container_test.go create mode 100644 tidb/catalog/deparse_expr.go create mode 100644 tidb/catalog/deparse_query.go create mode 100644 tidb/catalog/deparse_query_test.go create mode 100644 tidb/catalog/deparse_rules_test.go create mode 100644 tidb/catalog/dropcmds.go create mode 100644 tidb/catalog/dropcmds_test.go create mode 100644 tidb/catalog/errors.go create mode 100644 tidb/catalog/eventcmds.go create mode 100644 tidb/catalog/exec.go create mode 100644 tidb/catalog/exec_line_test.go create mode 100644 tidb/catalog/exec_test.go create mode 100644 tidb/catalog/function_types.go create mode 100644 tidb/catalog/index.go create mode 100644 tidb/catalog/indexcmds.go create mode 100644 tidb/catalog/indexcmds_test.go create mode 100644 tidb/catalog/query.go create mode 100644 tidb/catalog/query_expand.go create mode 100644 tidb/catalog/query_span_test.go create mode 100644 tidb/catalog/renamecmds.go create mode 100644 tidb/catalog/renamecmds_test.go create mode 100644 tidb/catalog/routinecmds.go create mode 100644 tidb/catalog/scenarios_ax_test.go create mode 100644 tidb/catalog/scenarios_c10_test.go create mode 100644 tidb/catalog/scenarios_c11_test.go create mode 100644 tidb/catalog/scenarios_c14_test.go create mode 100644 tidb/catalog/scenarios_c15_test.go create mode 100644 tidb/catalog/scenarios_c16_test.go create mode 100644 tidb/catalog/scenarios_c17_test.go create mode 100644 tidb/catalog/scenarios_c18_test.go create mode 100644 tidb/catalog/scenarios_c19_test.go create mode 100644 tidb/catalog/scenarios_c1_test.go create mode 100644 tidb/catalog/scenarios_c20_test.go create mode 100644 tidb/catalog/scenarios_c21_test.go create mode 100644 tidb/catalog/scenarios_c22_test.go create mode 100644 tidb/catalog/scenarios_c23_test.go create mode 100644 tidb/catalog/scenarios_c24_test.go create mode 100644 tidb/catalog/scenarios_c25_test.go create mode 100644 tidb/catalog/scenarios_c2_test.go create mode 100644 tidb/catalog/scenarios_c3_test.go create mode 100644 tidb/catalog/scenarios_c4_test.go create mode 100644 tidb/catalog/scenarios_c5_test.go create mode 100644 tidb/catalog/scenarios_c6_test.go create mode 100644 tidb/catalog/scenarios_c7_test.go create mode 100644 tidb/catalog/scenarios_c8_test.go create mode 100644 tidb/catalog/scenarios_c9_test.go create mode 100644 tidb/catalog/scenarios_helpers_test.go create mode 100644 tidb/catalog/scenarios_ps_test.go create mode 100644 tidb/catalog/scope.go create mode 100644 tidb/catalog/show.go create mode 100644 tidb/catalog/show_test.go create mode 100644 tidb/catalog/table.go create mode 100644 tidb/catalog/tablecmds.go create mode 100644 tidb/catalog/tablecmds_test.go create mode 100644 tidb/catalog/triggercmds.go create mode 100644 tidb/catalog/viewcmds.go create mode 100644 tidb/catalog/viewcmds_test.go create mode 100644 tidb/catalog/wt_10_1_test.go create mode 100644 tidb/catalog/wt_10_2_test.go create mode 100644 tidb/catalog/wt_11_1_test.go create mode 100644 tidb/catalog/wt_11_2_test.go create mode 100644 tidb/catalog/wt_12_1_test.go create mode 100644 tidb/catalog/wt_12_2_test.go create mode 100644 tidb/catalog/wt_13_1_test.go create mode 100644 tidb/catalog/wt_13_2_test.go create mode 100644 tidb/catalog/wt_1_1_test.go create mode 100644 tidb/catalog/wt_1_2_test.go create mode 100644 tidb/catalog/wt_1_3_test.go create mode 100644 tidb/catalog/wt_2_1_test.go create mode 100644 tidb/catalog/wt_2_2_test.go create mode 100644 tidb/catalog/wt_2_3_test.go create mode 100644 tidb/catalog/wt_2_4_test.go create mode 100644 tidb/catalog/wt_2_5_test.go create mode 100644 tidb/catalog/wt_2_6_test.go create mode 100644 tidb/catalog/wt_3_1_test.go create mode 100644 tidb/catalog/wt_3_2_test.go create mode 100644 tidb/catalog/wt_3_3_test.go create mode 100644 tidb/catalog/wt_3_4_test.go create mode 100644 tidb/catalog/wt_3_5_test.go create mode 100644 tidb/catalog/wt_3_6_test.go create mode 100644 tidb/catalog/wt_3_7_test.go create mode 100644 tidb/catalog/wt_4_1_test.go create mode 100644 tidb/catalog/wt_4_2_test.go create mode 100644 tidb/catalog/wt_4_3_test.go create mode 100644 tidb/catalog/wt_5_1_test.go create mode 100644 tidb/catalog/wt_5_2_test.go create mode 100644 tidb/catalog/wt_5_3_test.go create mode 100644 tidb/catalog/wt_5_4_test.go create mode 100644 tidb/catalog/wt_6_1_test.go create mode 100644 tidb/catalog/wt_6_2_test.go create mode 100644 tidb/catalog/wt_6_3_test.go create mode 100644 tidb/catalog/wt_7_1_test.go create mode 100644 tidb/catalog/wt_7_2_test.go create mode 100644 tidb/catalog/wt_7_3_test.go create mode 100644 tidb/catalog/wt_7_4_test.go create mode 100644 tidb/catalog/wt_8_1_test.go create mode 100644 tidb/catalog/wt_8_2_test.go create mode 100644 tidb/catalog/wt_8_3_test.go create mode 100644 tidb/catalog/wt_9_1_test.go create mode 100644 tidb/catalog/wt_9_2_test.go create mode 100644 tidb/catalog/wt_helpers_test.go create mode 100644 tidb/completion/SCENARIOS-completion.md create mode 100644 tidb/completion/completion.go create mode 100644 tidb/completion/completion_test.go create mode 100644 tidb/completion/integration_test.go create mode 100644 tidb/completion/refs.go create mode 100644 tidb/completion/refs_test.go create mode 100644 tidb/completion/resolve.go create mode 100644 tidb/deparse/deparse.go create mode 100644 tidb/deparse/deparse_test.go create mode 100644 tidb/deparse/resolver.go create mode 100644 tidb/deparse/resolver_test.go create mode 100644 tidb/deparse/rewrite.go create mode 100644 tidb/scope/scope.go diff --git a/tidb/catalog/altercmds.go b/tidb/catalog/altercmds.go new file mode 100644 index 00000000..31437c03 --- /dev/null +++ b/tidb/catalog/altercmds.go @@ -0,0 +1,1119 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +func (c *Catalog) alterTable(stmt *nodes.AlterTableStmt) error { + // Resolve database. + dbName := "" + if stmt.Table != nil { + dbName = stmt.Table.Schema + } + db, err := c.resolveDatabase(dbName) + if err != nil { + return err + } + + tableName := stmt.Table.Name + key := toLower(tableName) + tbl := db.Tables[key] + if tbl == nil { + return errNoSuchTable(db.Name, tableName) + } + + if len(stmt.Commands) <= 1 { + // Single command: no rollback needed. + if len(stmt.Commands) == 1 { + return c.execAlterCmd(db, tbl, stmt.Commands[0]) + } + return nil + } + + // Multi-command ALTER: MySQL treats this as atomic. + // Snapshot the table so we can rollback on any sub-command failure. + snapshot := cloneTable(tbl) + origKey := key + + for _, cmd := range stmt.Commands { + if err := c.execAlterCmd(db, tbl, cmd); err != nil { + // Rollback: if a RENAME changed the map key, undo it. + newKey := toLower(tbl.Name) + if newKey != origKey { + delete(db.Tables, newKey) + db.Tables[origKey] = tbl + } + // Restore all table fields from snapshot. + *tbl = snapshot + return err + } + } + // Clear transient cleanup tracking after successful multi-command ALTER. + tbl.droppedByCleanup = nil + return nil +} + +func (c *Catalog) execAlterCmd(db *Database, tbl *Table, cmd *nodes.AlterTableCmd) error { + switch cmd.Type { + case nodes.ATAddColumn: + return c.alterAddColumn(tbl, cmd) + case nodes.ATDropColumn: + return c.alterDropColumn(tbl, cmd) + case nodes.ATModifyColumn: + return c.alterModifyColumn(tbl, cmd) + case nodes.ATChangeColumn: + return c.alterChangeColumn(tbl, cmd) + case nodes.ATAddIndex, nodes.ATAddConstraint: + return c.alterAddConstraint(tbl, cmd) + case nodes.ATDropIndex: + return c.alterDropIndex(tbl, cmd) + case nodes.ATDropConstraint: + return c.alterDropConstraint(tbl, cmd) + case nodes.ATRenameColumn: + return c.alterRenameColumn(tbl, cmd) + case nodes.ATRenameIndex: + return c.alterRenameIndex(tbl, cmd) + case nodes.ATRenameTable: + return c.alterRenameTable(db, tbl, cmd) + case nodes.ATTableOption: + return c.alterTableOption(tbl, cmd) + case nodes.ATAlterColumnDefault: + return c.alterColumnDefault(tbl, cmd) + case nodes.ATAlterColumnVisible: + return c.alterColumnVisibility(tbl, cmd, false) + case nodes.ATAlterColumnInvisible: + return c.alterColumnVisibility(tbl, cmd, true) + case nodes.ATAlterIndexVisible: + return c.alterIndexVisibility(tbl, cmd, true) + case nodes.ATAlterIndexInvisible: + return c.alterIndexVisibility(tbl, cmd, false) + case nodes.ATAlterCheckEnforced: + return c.alterCheckEnforced(tbl, cmd) + case nodes.ATConvertCharset: + return c.alterConvertCharset(tbl, cmd) + case nodes.ATAddPartition: + return c.alterAddPartition(tbl, cmd) + case nodes.ATDropPartition: + return c.alterDropPartition(tbl, cmd) + case nodes.ATTruncatePartition: + return c.alterTruncatePartition(tbl, cmd) + case nodes.ATCoalescePartition: + return c.alterCoalescePartition(tbl, cmd) + case nodes.ATReorganizePartition: + return c.alterReorganizePartition(tbl, cmd) + case nodes.ATExchangePartition: + return c.alterExchangePartition(db, tbl, cmd) + case nodes.ATRemovePartitioning: + tbl.Partitioning = nil + return nil + default: + // Unsupported alter command; silently ignore. + return nil + } +} + +// alterAddColumn adds a new column to the table. +func (c *Catalog) alterAddColumn(tbl *Table, cmd *nodes.AlterTableCmd) error { + // Handle multi-column parenthesized form: ADD (col1 INT, col2 INT, ...) + if len(cmd.Columns) > 0 { + for _, colDef := range cmd.Columns { + if err := c.addSingleColumn(tbl, colDef, false, ""); err != nil { + return err + } + } + return nil + } + + colDef := cmd.Column + if colDef == nil { + return nil + } + + return c.addSingleColumn(tbl, colDef, cmd.First, cmd.After) +} + +// addSingleColumn adds one column definition to the table. +func (c *Catalog) addSingleColumn(tbl *Table, colDef *nodes.ColumnDef, first bool, after string) error { + colKey := toLower(colDef.Name) + if _, exists := tbl.colByName[colKey]; exists { + return errDupColumn(colDef.Name) + } + + col := buildColumnFromDef(tbl, colDef) + if err := insertColumn(tbl, col, first, after); err != nil { + return err + } + + // Process column-level constraints that produce indexes/constraints. + for _, cc := range colDef.Constraints { + switch cc.Type { + case nodes.ColConstrPrimaryKey: + // Check for duplicate PK. + for _, idx := range tbl.Indexes { + if idx.Primary { + return errMultiplePriKey() + } + } + col.Nullable = false + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: "PRIMARY", + Table: tbl, + Columns: []*IndexColumn{{Name: colDef.Name}}, + Unique: true, + Primary: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: "PRIMARY", + Type: ConPrimaryKey, + Table: tbl, + Columns: []string{colDef.Name}, + IndexName: "PRIMARY", + }) + case nodes.ColConstrUnique: + idxName := allocIndexName(tbl, colDef.Name) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: []*IndexColumn{{Name: colDef.Name}}, + Unique: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: idxName, + Type: ConUniqueKey, + Table: tbl, + Columns: []string{colDef.Name}, + IndexName: idxName, + }) + } + } + + return nil +} + +// alterDropColumn removes a column from the table. +func (c *Catalog) alterDropColumn(tbl *Table, cmd *nodes.AlterTableCmd) error { + colKey := toLower(cmd.Name) + if _, exists := tbl.colByName[colKey]; !exists { + if cmd.IfExists { + return nil + } + // MySQL 8.0 returns error 1091 for DROP COLUMN on nonexistent column, + // same as DROP INDEX: "Can't DROP 'x'; check that column/key exists". + return errCantDropKey(cmd.Name) + } + + // Check if column is referenced by a generated column expression. + for _, col := range tbl.Columns { + if col.Generated != nil && generatedExprReferencesColumn(col.Generated.Expr, cmd.Name) { + return errDependentByGeneratedColumn(cmd.Name, col.Name, tbl.Name) + } + } + + // Check if column is referenced by a foreign key constraint. + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + for _, col := range con.Columns { + if toLower(col) == colKey { + return &Error{ + Code: 1828, + SQLState: "HY000", + Message: fmt.Sprintf("Cannot drop column '%s': needed in a foreign key constraint '%s'", cmd.Name, con.Name), + } + } + } + } + } + + // Remove column from indexes; if index becomes empty, remove it entirely. + cleanupIndexesForDroppedColumn(tbl, cmd.Name) + + idx := tbl.colByName[colKey] + tbl.Columns = append(tbl.Columns[:idx], tbl.Columns[idx+1:]...) + rebuildColIndex(tbl) + return nil +} + +// cleanupIndexesForDroppedColumn removes references to a dropped column from +// all indexes. If an index loses all columns, it is removed entirely. +// Associated constraints are also cleaned up. +func cleanupIndexesForDroppedColumn(tbl *Table, colName string) { + colKey := toLower(colName) + + // Clean up indexes. + newIndexes := make([]*Index, 0, len(tbl.Indexes)) + removedIndexNames := make(map[string]bool) + for _, idx := range tbl.Indexes { + // Remove the column from this index. + newCols := make([]*IndexColumn, 0, len(idx.Columns)) + for _, ic := range idx.Columns { + if toLower(ic.Name) != colKey { + newCols = append(newCols, ic) + } + } + if len(newCols) == 0 { + // Index has no columns left — remove it. + nameKey := toLower(idx.Name) + removedIndexNames[nameKey] = true + // Track for multi-command ALTER so explicit DROP INDEX succeeds. + if tbl.droppedByCleanup == nil { + tbl.droppedByCleanup = make(map[string]bool) + } + tbl.droppedByCleanup[nameKey] = true + continue + } + idx.Columns = newCols + newIndexes = append(newIndexes, idx) + } + tbl.Indexes = newIndexes + + // Clean up constraints that reference removed indexes. + if len(removedIndexNames) > 0 { + newConstraints := make([]*Constraint, 0, len(tbl.Constraints)) + for _, con := range tbl.Constraints { + if removedIndexNames[toLower(con.IndexName)] || removedIndexNames[toLower(con.Name)] { + continue + } + newConstraints = append(newConstraints, con) + } + tbl.Constraints = newConstraints + } + + // Also update constraint column lists for remaining constraints. + for _, con := range tbl.Constraints { + newCols := make([]string, 0, len(con.Columns)) + for _, col := range con.Columns { + if toLower(col) != colKey { + newCols = append(newCols, col) + } + } + con.Columns = newCols + } +} + +// alterModifyColumn replaces a column definition in-place (same name). +func (c *Catalog) alterModifyColumn(tbl *Table, cmd *nodes.AlterTableCmd) error { + if cmd.Column == nil { + return nil + } + return c.alterReplaceColumn(tbl, cmd.Column.Name, cmd) +} + +// alterChangeColumn replaces a column (old name -> new name + new definition). +func (c *Catalog) alterChangeColumn(tbl *Table, cmd *nodes.AlterTableCmd) error { + if cmd.Column == nil { + return nil + } + return c.alterReplaceColumn(tbl, cmd.Name, cmd) +} + +// alterReplaceColumn is the shared implementation for MODIFY and CHANGE COLUMN. +// oldName is the existing column to replace; cmd.Column defines the new column. +func (c *Catalog) alterReplaceColumn(tbl *Table, oldName string, cmd *nodes.AlterTableCmd) error { + colDef := cmd.Column + oldKey := toLower(oldName) + idx, exists := tbl.colByName[oldKey] + if !exists { + return errNoSuchColumn(oldName, tbl.Name) + } + + // Check if new name conflicts with existing column (unless same). + newKey := toLower(colDef.Name) + if newKey != oldKey { + if _, dup := tbl.colByName[newKey]; dup { + return errDupColumn(colDef.Name) + } + } + + // Check for VIRTUAL<->STORED storage type change (MySQL 8.0 error 3106). + oldCol := tbl.Columns[idx] + if oldCol.Generated != nil && colDef.Generated != nil { + if oldCol.Generated.Stored != colDef.Generated.Stored { + return errUnsupportedGeneratedStorageChange(colDef.Name, tbl.Name) + } + } + + col := buildColumnFromDef(tbl, colDef) + col.Position = idx + 1 + tbl.Columns[idx] = col + + // Update index/constraint column references if name changed. + if newKey != oldKey { + updateColumnRefsInIndexes(tbl, oldName, colDef.Name) + } + + // Handle repositioning. + if cmd.First || cmd.After != "" { + tbl.Columns = append(tbl.Columns[:idx], tbl.Columns[idx+1:]...) + rebuildColIndex(tbl) + if err := insertColumn(tbl, col, cmd.First, cmd.After); err != nil { + return err + } + } else { + rebuildColIndex(tbl) + } + + return nil +} + +// alterAddConstraint adds a constraint or index to the table. +func (c *Catalog) alterAddConstraint(tbl *Table, cmd *nodes.AlterTableCmd) error { + con := cmd.Constraint + if con == nil { + return nil + } + + cols := extractColumnNames(con) + + switch con.Type { + case nodes.ConstrPrimaryKey: + // Check for duplicate PK. + for _, idx := range tbl.Indexes { + if idx.Primary { + return errMultiplePriKey() + } + } + // Mark PK columns as NOT NULL. + for _, colName := range cols { + col := tbl.GetColumn(colName) + if col != nil { + col.Nullable = false + } + } + idxCols := buildIndexColumns(con) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: "PRIMARY", + Table: tbl, + Columns: idxCols, + Unique: true, + Primary: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: "PRIMARY", + Type: ConPrimaryKey, + Table: tbl, + Columns: cols, + IndexName: "PRIMARY", + }) + + case nodes.ConstrUnique: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } else if idxName != "" && indexNameExists(tbl, idxName) { + return errDupKeyName(idxName) + } + idxCols := buildIndexColumns(con) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Unique: true, + IndexType: resolveConstraintIndexType(con), + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: idxName, + Type: ConUniqueKey, + Table: tbl, + Columns: cols, + IndexName: idxName, + }) + + case nodes.ConstrForeignKey: + conName := con.Name + if conName == "" { + conName = fmt.Sprintf("%s_ibfk_%d", tbl.Name, nextFKGeneratedNumber(tbl, tbl.Name)) + } + refDBName := "" + refTable := "" + if con.RefTable != nil { + refDBName = con.RefTable.Schema + refTable = con.RefTable.Name + } + fkCon := &Constraint{ + Name: conName, + Type: ConForeignKey, + Table: tbl, + Columns: cols, + RefDatabase: refDBName, + RefTable: refTable, + RefColumns: con.RefColumns, + OnDelete: refActionToString(con.OnDelete), + OnUpdate: refActionToString(con.OnUpdate), + } + // Validate FK before adding (unless foreign_key_checks=0). + db := tbl.Database + if c.foreignKeyChecks { + if err := c.validateSingleFK(db, tbl, fkCon); err != nil { + return err + } + } + tbl.Constraints = append(tbl.Constraints, fkCon) + // Add implicit backing index for FK if needed. + ensureFKBackingIndex(tbl, con.Name, cols, buildIndexColumns(con)) + + case nodes.ConstrCheck: + conName := con.Name + if conName == "" { + conName = fmt.Sprintf("%s_chk_%d", tbl.Name, nextCheckNumber(tbl)) + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: conName, + Type: ConCheck, + Table: tbl, + CheckExpr: nodeToSQL(con.Expr), + NotEnforced: con.NotEnforced, + }) + + case nodes.ConstrIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } else if idxName != "" && indexNameExists(tbl, idxName) { + return errDupKeyName(idxName) + } + idxCols := buildIndexColumns(con) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + IndexType: resolveConstraintIndexType(con), + Visible: true, + }) + + case nodes.ConstrFulltextIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } else if idxName != "" && indexNameExists(tbl, idxName) { + return errDupKeyName(idxName) + } + idxCols := buildIndexColumns(con) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Fulltext: true, + IndexType: "FULLTEXT", + Visible: true, + }) + + case nodes.ConstrSpatialIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } else if idxName != "" && indexNameExists(tbl, idxName) { + return errDupKeyName(idxName) + } + idxCols := buildIndexColumns(con) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Spatial: true, + IndexType: "SPATIAL", + Visible: true, + }) + } + + return nil +} + +// alterDropIndex removes an index (and any associated constraint) by name. +func (c *Catalog) alterDropIndex(tbl *Table, cmd *nodes.AlterTableCmd) error { + name := cmd.Name + key := toLower(name) + + found := false + for i, idx := range tbl.Indexes { + if toLower(idx.Name) == key { + tbl.Indexes = append(tbl.Indexes[:i], tbl.Indexes[i+1:]...) + found = true + break + } + } + if !found { + if cmd.IfExists { + return nil + } + // If the index was auto-removed by DROP COLUMN cleanup in this + // multi-command ALTER, treat as success (matches MySQL 8.0 behavior). + if tbl.droppedByCleanup[key] { + return nil + } + return errCantDropKey(name) + } + + // Also remove any constraint that references this index. + for i, con := range tbl.Constraints { + if toLower(con.IndexName) == key || toLower(con.Name) == key { + tbl.Constraints = append(tbl.Constraints[:i], tbl.Constraints[i+1:]...) + break + } + } + + return nil +} + +// alterDropConstraint removes a constraint by name. +func (c *Catalog) alterDropConstraint(tbl *Table, cmd *nodes.AlterTableCmd) error { + name := cmd.Name + key := toLower(name) + + found := false + isForeignKey := false + for i, con := range tbl.Constraints { + if toLower(con.Name) == key { + isForeignKey = (con.Type == ConForeignKey) + tbl.Constraints = append(tbl.Constraints[:i], tbl.Constraints[i+1:]...) + found = true + break + } + } + + if !found { + if cmd.IfExists { + return nil + } + return errCantDropKey(name) + } + + // For FK constraints, MySQL keeps the backing index when dropping the FK. + // For other constraints (e.g., PRIMARY KEY), also remove the corresponding index. + if !isForeignKey { + for i, idx := range tbl.Indexes { + if toLower(idx.Name) == key { + tbl.Indexes = append(tbl.Indexes[:i], tbl.Indexes[i+1:]...) + break + } + } + } + + return nil +} + +// alterRenameColumn changes a column name in-place. +func (c *Catalog) alterRenameColumn(tbl *Table, cmd *nodes.AlterTableCmd) error { + oldKey := toLower(cmd.Name) + idx, exists := tbl.colByName[oldKey] + if !exists { + return errNoSuchColumn(cmd.Name, tbl.Name) + } + + newKey := toLower(cmd.NewName) + if newKey != oldKey { + if _, dup := tbl.colByName[newKey]; dup { + return errDupColumn(cmd.NewName) + } + } + + tbl.Columns[idx].Name = cmd.NewName + updateColumnRefsInIndexes(tbl, cmd.Name, cmd.NewName) + rebuildColIndex(tbl) + return nil +} + +// alterRenameIndex changes an index name in-place. +func (c *Catalog) alterRenameIndex(tbl *Table, cmd *nodes.AlterTableCmd) error { + oldKey := toLower(cmd.Name) + newKey := toLower(cmd.NewName) + + if newKey != oldKey && indexNameExists(tbl, cmd.NewName) { + return errDupKeyName(cmd.NewName) + } + + for _, idx := range tbl.Indexes { + if toLower(idx.Name) == oldKey { + idx.Name = cmd.NewName + // Also update any constraint that references this index. + for _, con := range tbl.Constraints { + if toLower(con.IndexName) == oldKey { + con.IndexName = cmd.NewName + con.Name = cmd.NewName + } + } + return nil + } + } + + return &Error{ + Code: ErrDupKeyName, + SQLState: sqlState(ErrDupKeyName), + Message: fmt.Sprintf("Key '%s' doesn't exist in table '%s'", cmd.Name, tbl.Name), + } +} + +// alterRenameTable moves a table to a new name. +func (c *Catalog) alterRenameTable(db *Database, tbl *Table, cmd *nodes.AlterTableCmd) error { + newName := cmd.NewName + newKey := toLower(newName) + oldKey := toLower(tbl.Name) + + if newKey != oldKey { + if db.Tables[newKey] != nil { + return errDupTable(newName) + } + } + + delete(db.Tables, oldKey) + tbl.Name = newName + db.Tables[newKey] = tbl + return nil +} + +// alterTableOption applies a table option (ENGINE, CHARSET, etc.). +func (c *Catalog) alterTableOption(tbl *Table, cmd *nodes.AlterTableCmd) error { + opt := cmd.Option + if opt == nil { + return nil + } + + switch toLower(opt.Name) { + case "engine": + tbl.Engine = opt.Value + case "charset", "character set", "default charset", "default character set": + tbl.Charset = opt.Value + // Update collation to the default for this charset. + if defColl, ok := defaultCollationForCharset[toLower(opt.Value)]; ok { + tbl.Collation = defColl + } + case "collate", "default collate": + tbl.Collation = opt.Value + case "comment": + tbl.Comment = opt.Value + case "auto_increment": + fmt.Sscanf(opt.Value, "%d", &tbl.AutoIncrement) + case "row_format": + tbl.RowFormat = opt.Value + } + return nil +} + +// alterColumnDefault sets or drops the default on an existing column. +func (c *Catalog) alterColumnDefault(tbl *Table, cmd *nodes.AlterTableCmd) error { + colKey := toLower(cmd.Name) + idx, exists := tbl.colByName[colKey] + if !exists { + return errNoSuchColumn(cmd.Name, tbl.Name) + } + + col := tbl.Columns[idx] + if cmd.DefaultExpr != nil { + s := nodeToSQL(cmd.DefaultExpr) + col.Default = &s + col.DefaultDropped = false + } else { + // DROP DEFAULT — MySQL shows no default at all (not even DEFAULT NULL). + col.Default = nil + col.DefaultDropped = true + } + return nil +} + +// alterColumnVisibility toggles the INVISIBLE flag on a column. +func (c *Catalog) alterColumnVisibility(tbl *Table, cmd *nodes.AlterTableCmd, invisible bool) error { + colKey := toLower(cmd.Name) + idx, exists := tbl.colByName[colKey] + if !exists { + return errNoSuchColumn(cmd.Name, tbl.Name) + } + tbl.Columns[idx].Invisible = invisible + return nil +} + +// alterIndexVisibility toggles the Visible flag on an index. +func (c *Catalog) alterIndexVisibility(tbl *Table, cmd *nodes.AlterTableCmd, visible bool) error { + key := toLower(cmd.Name) + for _, idx := range tbl.Indexes { + if toLower(idx.Name) == key { + idx.Visible = visible + return nil + } + } + return &Error{ + Code: ErrDupKeyName, + SQLState: sqlState(ErrDupKeyName), + Message: fmt.Sprintf("Key '%s' doesn't exist in table '%s'", cmd.Name, tbl.Name), + } +} + +// insertColumn inserts col into tbl at the position specified by first/after. +// If neither first nor after is set, appends at end. Always rebuilds the column index. +func insertColumn(tbl *Table, col *Column, first bool, after string) error { + if first { + tbl.Columns = append([]*Column{col}, tbl.Columns...) + } else if after != "" { + afterIdx, ok := tbl.colByName[toLower(after)] + if !ok { + return errNoSuchColumn(after, tbl.Name) + } + pos := afterIdx + 1 + tbl.Columns = append(tbl.Columns, nil) + copy(tbl.Columns[pos+1:], tbl.Columns[pos:]) + tbl.Columns[pos] = col + } else { + tbl.Columns = append(tbl.Columns, col) + } + rebuildColIndex(tbl) + return nil +} + +// rebuildColIndex rebuilds tbl.colByName and updates Position fields. +func rebuildColIndex(tbl *Table) { + tbl.colByName = make(map[string]int, len(tbl.Columns)) + for i, col := range tbl.Columns { + col.Position = i + 1 + tbl.colByName[toLower(col.Name)] = i + } +} + +// buildColumnFromDef builds a catalog Column from an AST ColumnDef. +func buildColumnFromDef(tbl *Table, colDef *nodes.ColumnDef) *Column { + col := &Column{ + Name: colDef.Name, + Nullable: true, + } + + // Type info. + if colDef.TypeName != nil { + col.DataType = toLower(colDef.TypeName.Name) + // MySQL 8.0 normalizes GEOMETRYCOLLECTION → geomcollection. + if col.DataType == "geometrycollection" { + col.DataType = "geomcollection" + } + col.ColumnType = formatColumnType(colDef.TypeName) + if colDef.TypeName.Charset != "" { + col.Charset = colDef.TypeName.Charset + } + if colDef.TypeName.Collate != "" { + col.Collation = colDef.TypeName.Collate + } + } + + // Default charset/collation for string types. + if isStringType(col.DataType) { + if col.Charset == "" { + col.Charset = tbl.Charset + } + if col.Collation == "" { + // If column charset differs from table charset, use the default + // collation for the column's charset, not the table's collation. + if !strings.EqualFold(col.Charset, tbl.Charset) { + if dc, ok := defaultCollationForCharset[toLower(col.Charset)]; ok { + col.Collation = dc + } + } else { + col.Collation = tbl.Collation + } + } + } + + // Top-level column properties. + if colDef.AutoIncrement { + col.AutoIncrement = true + col.Nullable = false + } + if colDef.Comment != "" { + col.Comment = colDef.Comment + } + if colDef.DefaultValue != nil { + s := nodeToSQL(colDef.DefaultValue) + col.Default = &s + } + if colDef.OnUpdate != nil { + col.OnUpdate = nodeToSQL(colDef.OnUpdate) + } + if colDef.Generated != nil { + col.Generated = &GeneratedColumnInfo{ + Expr: nodeToSQLGenerated(colDef.Generated.Expr, tbl.Charset), + Stored: colDef.Generated.Stored, + } + } + + // Process column-level constraints (non-index-producing ones). + for _, cc := range colDef.Constraints { + switch cc.Type { + case nodes.ColConstrNotNull: + col.Nullable = false + case nodes.ColConstrNull: + col.Nullable = true + case nodes.ColConstrDefault: + if cc.Expr != nil { + s := nodeToSQL(cc.Expr) + col.Default = &s + } + case nodes.ColConstrAutoIncrement: + col.AutoIncrement = true + col.Nullable = false + case nodes.ColConstrVisible: + col.Invisible = false + case nodes.ColConstrInvisible: + col.Invisible = true + case nodes.ColConstrCollate: + if cc.Expr != nil { + if s, ok := cc.Expr.(*nodes.StringLit); ok { + col.Collation = s.Value + } + } + } + } + + return col +} + +// updateColumnRefsInIndexes updates index and constraint column references +// when a column is renamed. +func updateColumnRefsInIndexes(tbl *Table, oldName, newName string) { + oldKey := toLower(oldName) + for _, idx := range tbl.Indexes { + for _, ic := range idx.Columns { + if toLower(ic.Name) == oldKey { + ic.Name = newName + } + } + } + for _, con := range tbl.Constraints { + for i, col := range con.Columns { + if toLower(col) == oldKey { + con.Columns[i] = newName + } + } + } +} + +// alterCheckEnforced toggles the ENFORCED / NOT ENFORCED flag on a CHECK constraint. +func (c *Catalog) alterCheckEnforced(tbl *Table, cmd *nodes.AlterTableCmd) error { + key := toLower(cmd.Name) + for _, con := range tbl.Constraints { + if toLower(con.Name) == key && con.Type == ConCheck { + con.NotEnforced = (cmd.NewName == "NOT ENFORCED") + return nil + } + } + return &Error{ + Code: 3940, + SQLState: "HY000", + Message: fmt.Sprintf("Constraint '%s' does not exist.", cmd.Name), + } +} + +// alterConvertCharset handles CONVERT TO CHARACTER SET charset [COLLATE collation]. +// This changes the table's default charset/collation AND converts all existing +// string columns to the new charset/collation. +func (c *Catalog) alterConvertCharset(tbl *Table, cmd *nodes.AlterTableCmd) error { + charset := cmd.Name + collation := cmd.NewName + + // If no collation specified, use the default collation for the charset. + if collation == "" { + if defColl, ok := defaultCollationForCharset[toLower(charset)]; ok { + collation = defColl + } + } + + tbl.Charset = charset + tbl.Collation = collation + + // Convert all string-type columns to the new charset/collation. + for _, col := range tbl.Columns { + if isStringType(col.DataType) { + col.Charset = charset + col.Collation = collation + } + } + + return nil +} + +// Ensure strings import is used (for toLower references via strings package). +var _ = strings.ToLower + +// alterAddPartition adds partition definitions to a partitioned table. +func (c *Catalog) alterAddPartition(tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE ADD PARTITION: table '%s' is not partitioned", tbl.Name) + } + for _, pd := range cmd.PartitionDefs { + pdi := &PartitionDefInfo{ + Name: pd.Name, + } + if pd.Values != nil { + pdi.ValueExpr = partitionValueToString(pd.Values, partitionTypeFromString(tbl.Partitioning.Type)) + } + for _, opt := range pd.Options { + switch toLower(opt.Name) { + case "engine": + pdi.Engine = opt.Value + case "comment": + pdi.Comment = opt.Value + } + } + tbl.Partitioning.Partitions = append(tbl.Partitioning.Partitions, pdi) + } + return nil +} + +// alterDropPartition drops named partitions from a partitioned table. +func (c *Catalog) alterDropPartition(tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE DROP PARTITION: table '%s' is not partitioned", tbl.Name) + } + dropSet := make(map[string]bool) + for _, name := range cmd.PartitionNames { + dropSet[toLower(name)] = true + } + var remaining []*PartitionDefInfo + for _, pd := range tbl.Partitioning.Partitions { + if !dropSet[toLower(pd.Name)] { + remaining = append(remaining, pd) + } + } + tbl.Partitioning.Partitions = remaining + return nil +} + +// alterTruncatePartition truncates named partitions (no-op for metadata catalog). +func (c *Catalog) alterTruncatePartition(tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE TRUNCATE PARTITION: table '%s' is not partitioned", tbl.Name) + } + // Truncate is a data operation; for DDL catalog purposes, it's a no-op. + return nil +} + +// alterCoalescePartition reduces the number of partitions for HASH/KEY partitioned tables. +func (c *Catalog) alterCoalescePartition(tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE COALESCE PARTITION: table '%s' is not partitioned", tbl.Name) + } + // Determine current partition count. + currentCount := len(tbl.Partitioning.Partitions) + if currentCount == 0 { + currentCount = tbl.Partitioning.NumParts + } + newCount := currentCount - cmd.Number + if newCount < 1 { + newCount = 1 + } + if len(tbl.Partitioning.Partitions) > 0 { + tbl.Partitioning.Partitions = tbl.Partitioning.Partitions[:newCount] + } + tbl.Partitioning.NumParts = newCount + return nil +} + +// alterReorganizePartition reorganizes partitions into new definitions. +func (c *Catalog) alterReorganizePartition(tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE REORGANIZE PARTITION: table '%s' is not partitioned", tbl.Name) + } + // Remove the old partitions. + dropSet := make(map[string]bool) + for _, name := range cmd.PartitionNames { + dropSet[toLower(name)] = true + } + var remaining []*PartitionDefInfo + insertPos := -1 + for i, pd := range tbl.Partitioning.Partitions { + if dropSet[toLower(pd.Name)] { + if insertPos < 0 { + insertPos = i + } + continue + } + remaining = append(remaining, pd) + } + if insertPos < 0 { + insertPos = len(remaining) + } + + // Build new partitions. + var newParts []*PartitionDefInfo + for _, pd := range cmd.PartitionDefs { + pdi := &PartitionDefInfo{ + Name: pd.Name, + } + if pd.Values != nil { + pdi.ValueExpr = partitionValueToString(pd.Values, partitionTypeFromString(tbl.Partitioning.Type)) + } + for _, opt := range pd.Options { + switch toLower(opt.Name) { + case "engine": + pdi.Engine = opt.Value + case "comment": + pdi.Comment = opt.Value + } + } + newParts = append(newParts, pdi) + } + + // Insert new partitions at the position of the first removed partition. + result := make([]*PartitionDefInfo, 0, len(remaining)+len(newParts)) + for i, pd := range remaining { + if i == insertPos { + result = append(result, newParts...) + } + result = append(result, pd) + } + if insertPos >= len(remaining) { + result = append(result, newParts...) + } + tbl.Partitioning.Partitions = result + return nil +} + +// alterExchangePartition exchanges a partition with a non-partitioned table. +func (c *Catalog) alterExchangePartition(db *Database, tbl *Table, cmd *nodes.AlterTableCmd) error { + if tbl.Partitioning == nil { + return fmt.Errorf("ALTER TABLE EXCHANGE PARTITION: table '%s' is not partitioned", tbl.Name) + } + // For DDL catalog purposes, exchange is primarily a data operation. + // We just validate both tables exist. + if cmd.ExchangeTable != nil { + exchDB := db + if cmd.ExchangeTable.Schema != "" { + exchDB = c.GetDatabase(cmd.ExchangeTable.Schema) + if exchDB == nil { + return errNoSuchTable(cmd.ExchangeTable.Schema, cmd.ExchangeTable.Name) + } + } + exchTbl := exchDB.GetTable(cmd.ExchangeTable.Name) + if exchTbl == nil { + return errNoSuchTable(exchDB.Name, cmd.ExchangeTable.Name) + } + } + return nil +} + +// partitionTypeFromString converts a string partition type to AST PartitionType. +func partitionTypeFromString(t string) nodes.PartitionType { + switch t { + case "RANGE", "RANGE COLUMNS": + return nodes.PartitionRange + case "LIST", "LIST COLUMNS": + return nodes.PartitionList + case "HASH": + return nodes.PartitionHash + case "KEY": + return nodes.PartitionKey + default: + return nodes.PartitionRange + } +} + +// generatedExprReferencesColumn checks if a generated column expression +// references a column by name. The expression uses backtick-quoted identifiers +// (e.g., `col_name`), so we search for the backtick-quoted form. +func generatedExprReferencesColumn(expr, colName string) bool { + target := "`" + strings.ToLower(colName) + "`" + return strings.Contains(strings.ToLower(expr), target) +} diff --git a/tidb/catalog/altercmds_test.go b/tidb/catalog/altercmds_test.go new file mode 100644 index 00000000..2a26478f --- /dev/null +++ b/tidb/catalog/altercmds_test.go @@ -0,0 +1,393 @@ +package catalog + +import "testing" + +func setupTestTable(t *testing.T) *Catalog { + t.Helper() + c := New() + mustExec(t, c, "CREATE DATABASE test") + c.SetCurrentDatabase("test") + mustExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), age INT)") + return c +} + +func TestAlterTableAddColumn(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255) NOT NULL") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + + col := tbl.GetColumn("email") + if col == nil { + t.Fatal("column email not found") + } + if col.Nullable { + t.Error("email should not be nullable") + } + if col.ColumnType != "varchar(255)" { + t.Errorf("expected column type 'varchar(255)', got %q", col.ColumnType) + } + if col.Position != 4 { + t.Errorf("expected position 4, got %d", col.Position) + } + + // Adding duplicate column should fail. + results, _ := c.Exec("ALTER TABLE t1 ADD COLUMN email VARCHAR(100)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate column error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrDupColumn { + t.Errorf("expected error code %d, got %d", ErrDupColumn, catErr.Code) + } +} + +func TestAlterTableAddColumnMulti(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN (email VARCHAR(255) NOT NULL, score INT)") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 5 { + t.Fatalf("expected 5 columns, got %d", len(tbl.Columns)) + } + + col := tbl.GetColumn("email") + if col == nil { + t.Fatal("column email not found") + } + if col.Nullable { + t.Error("email should not be nullable") + } + if col.ColumnType != "varchar(255)" { + t.Errorf("expected column type 'varchar(255)', got %q", col.ColumnType) + } + + col2 := tbl.GetColumn("score") + if col2 == nil { + t.Fatal("column score not found") + } +} + +func TestAlterTableDropColumn(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 DROP COLUMN age") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(tbl.Columns)) + } + + if tbl.GetColumn("age") != nil { + t.Error("column age should have been dropped") + } + + // Check remaining columns have correct positions. + id := tbl.GetColumn("id") + if id.Position != 1 { + t.Errorf("expected id position 1, got %d", id.Position) + } + name := tbl.GetColumn("name") + if name.Position != 2 { + t.Errorf("expected name position 2, got %d", name.Position) + } + + // Dropping non-existent column should fail. + results, _ := c.Exec("ALTER TABLE t1 DROP COLUMN nonexistent", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected no such column error") + } +} + +func TestAlterTableModifyColumn(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 MODIFY COLUMN name VARCHAR(200) NOT NULL") + + tbl := c.GetDatabase("test").GetTable("t1") + col := tbl.GetColumn("name") + if col == nil { + t.Fatal("column name not found") + } + if col.ColumnType != "varchar(200)" { + t.Errorf("expected column type 'varchar(200)', got %q", col.ColumnType) + } + if col.Nullable { + t.Error("name should not be nullable after MODIFY") + } + if col.Position != 2 { + t.Errorf("expected position 2, got %d", col.Position) + } +} + +func TestAlterTableModifyColumnAfterLater(t *testing.T) { + // Regression: MODIFY COLUMN ... AFTER panicked with slice bounds + // out of range when the AFTER column was positioned after the modified column. + c := setupTestTable(t) + // t1 has: id(1), name(2), age(3) + // Move id AFTER age → name(1), age(2), id(3) + mustExec(t, c, "ALTER TABLE t1 MODIFY COLUMN id INT NOT NULL AFTER age") + + tbl := c.GetDatabase("test").GetTable("t1") + if tbl.Columns[0].Name != "name" { + t.Errorf("expected column 0 to be 'name', got %q", tbl.Columns[0].Name) + } + if tbl.Columns[1].Name != "age" { + t.Errorf("expected column 1 to be 'age', got %q", tbl.Columns[1].Name) + } + if tbl.Columns[2].Name != "id" { + t.Errorf("expected column 2 to be 'id', got %q", tbl.Columns[2].Name) + } +} + +func TestAlterTableChangeColumn(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 CHANGE COLUMN name full_name VARCHAR(200)") + + tbl := c.GetDatabase("test").GetTable("t1") + + if tbl.GetColumn("name") != nil { + t.Error("old column 'name' should no longer exist") + } + + col := tbl.GetColumn("full_name") + if col == nil { + t.Fatal("column full_name not found") + } + if col.ColumnType != "varchar(200)" { + t.Errorf("expected column type 'varchar(200)', got %q", col.ColumnType) + } + if col.Position != 2 { + t.Errorf("expected position 2, got %d", col.Position) + } + + // Changing to a name that already exists should fail. + results, _ := c.Exec("ALTER TABLE t1 CHANGE COLUMN age full_name INT", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate column error") + } +} + +func TestAlterTableAddIndex(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD INDEX idx_name (name)") + + tbl := c.GetDatabase("test").GetTable("t1") + + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_name not found") + } + if found.Unique { + t.Error("idx_name should not be unique") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "name" { + t.Error("idx_name should have column 'name'") + } +} + +func TestAlterTableDropIndex(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD INDEX idx_name (name)") + mustExec(t, c, "ALTER TABLE t1 DROP INDEX idx_name") + + tbl := c.GetDatabase("test").GetTable("t1") + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + t.Fatal("index idx_name should have been dropped") + } + } +} + +func TestAlterTableAddPrimaryKey(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD PRIMARY KEY (id)") + + tbl := c.GetDatabase("test").GetTable("t1") + + var pkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Primary { + pkIdx = idx + break + } + } + if pkIdx == nil { + t.Fatal("primary key index not found") + } + if pkIdx.Name != "PRIMARY" { + t.Errorf("expected PK index name 'PRIMARY', got %q", pkIdx.Name) + } + if !pkIdx.Unique { + t.Error("PK should be unique") + } + + // id column should be NOT NULL. + id := tbl.GetColumn("id") + if id.Nullable { + t.Error("id should not be nullable after adding PK") + } + + // Check PK constraint. + var pkCon *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConPrimaryKey { + pkCon = con + break + } + } + if pkCon == nil { + t.Fatal("PK constraint not found") + } + + // Adding another PK should fail. + results, _ := c.Exec("ALTER TABLE t1 ADD PRIMARY KEY (name)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected multiple primary key error") + } +} + +func TestAlterTableRenameColumn(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 RENAME COLUMN name TO full_name") + + tbl := c.GetDatabase("test").GetTable("t1") + + if tbl.GetColumn("name") != nil { + t.Error("old column 'name' should no longer exist") + } + col := tbl.GetColumn("full_name") + if col == nil { + t.Fatal("column full_name not found") + } + if col.Position != 2 { + t.Errorf("expected position 2, got %d", col.Position) + } + + // Renaming to existing name should fail. + results, _ := c.Exec("ALTER TABLE t1 RENAME COLUMN full_name TO id", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate column error") + } +} + +func TestAlterTableAddColumnFirst(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255) FIRST") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + + email := tbl.GetColumn("email") + if email == nil { + t.Fatal("column email not found") + } + if email.Position != 1 { + t.Errorf("expected email at position 1, got %d", email.Position) + } + + id := tbl.GetColumn("id") + if id.Position != 2 { + t.Errorf("expected id at position 2, got %d", id.Position) + } + + name := tbl.GetColumn("name") + if name.Position != 3 { + t.Errorf("expected name at position 3, got %d", name.Position) + } + + age := tbl.GetColumn("age") + if age.Position != 4 { + t.Errorf("expected age at position 4, got %d", age.Position) + } +} + +func TestAlterTableAddColumnAfter(t *testing.T) { + c := setupTestTable(t) + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255) AFTER id") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + + id := tbl.GetColumn("id") + if id.Position != 1 { + t.Errorf("expected id at position 1, got %d", id.Position) + } + + email := tbl.GetColumn("email") + if email == nil { + t.Fatal("column email not found") + } + if email.Position != 2 { + t.Errorf("expected email at position 2, got %d", email.Position) + } + + name := tbl.GetColumn("name") + if name.Position != 3 { + t.Errorf("expected name at position 3, got %d", name.Position) + } + + age := tbl.GetColumn("age") + if age.Position != 4 { + t.Errorf("expected age at position 4, got %d", age.Position) + } +} + +func TestAlterTableMultiCommandRollback(t *testing.T) { + c := setupTestTable(t) + // t1 has: id INT NOT NULL, name VARCHAR(100), age INT + + // Multi-command ALTER where second command fails (duplicate column). + // MySQL rolls back the entire ALTER — first ADD should also be undone. + results, _ := c.Exec( + "ALTER TABLE t1 ADD COLUMN email VARCHAR(255), ADD COLUMN name VARCHAR(50)", + &ExecOptions{ContinueOnError: true}, + ) + if results[0].Error == nil { + t.Fatal("expected duplicate column error") + } + + tbl := c.GetDatabase("test").GetTable("t1") + + // Verify rollback: email should NOT have been added. + if tbl.GetColumn("email") != nil { + t.Error("column 'email' should not exist after rollback") + } + + // Original columns should be intact. + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns after rollback, got %d", len(tbl.Columns)) + } +} + +func TestAlterTableMultiCommandSuccess(t *testing.T) { + c := setupTestTable(t) + + // Multi-command ALTER that succeeds — all changes should apply. + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255), ADD COLUMN score INT") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 5 { + t.Fatalf("expected 5 columns, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("email") == nil { + t.Error("column 'email' not found") + } + if tbl.GetColumn("score") == nil { + t.Error("column 'score' not found") + } +} diff --git a/tidb/catalog/analyze.go b/tidb/catalog/analyze.go new file mode 100644 index 00000000..894af659 --- /dev/null +++ b/tidb/catalog/analyze.go @@ -0,0 +1,1004 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +// AnalyzeSelectStmt performs semantic analysis on a parsed SELECT statement, +// returning a resolved Query IR. +func (c *Catalog) AnalyzeSelectStmt(stmt *nodes.SelectStmt) (*Query, error) { + return c.analyzeSelectStmtInternal(stmt, nil) +} + +// analyzeSelectStmtInternal is the core analysis routine. parentScope is non-nil +// when analyzing a subquery (for correlated column resolution). +func (c *Catalog) analyzeSelectStmtInternal(stmt *nodes.SelectStmt, parentScope *analyzerScope) (*Query, error) { + return c.analyzeSelectStmtWithCTEs(stmt, parentScope, nil) +} + +// analyzeSelectStmtWithCTEs is the full internal analysis routine. +// inheritedCTEMap provides CTE definitions inherited from an enclosing context +// (e.g., a recursive CTE body referencing itself). +func (c *Catalog) analyzeSelectStmtWithCTEs(stmt *nodes.SelectStmt, parentScope *analyzerScope, inheritedCTEMap map[string]*CommonTableExprQ) (*Query, error) { + // Handle set operations (UNION/INTERSECT/EXCEPT). + if stmt.SetOp != nodes.SetOpNone { + return c.analyzeSetOpWithCTEs(stmt, parentScope, inheritedCTEMap) + } + + q := &Query{ + CommandType: CmdSelect, + JoinTree: &JoinTreeQ{}, + } + + scope := newScopeWithParent(parentScope) + + // Step 0: Process CTEs (WITH clause). + cteMap, err := c.analyzeCTEs(stmt.CTEs, q, parentScope) + if err != nil { + return nil, err + } + + // Merge inherited CTE map (for recursive CTE self-references). + if inheritedCTEMap != nil { + if cteMap == nil { + cteMap = make(map[string]*CommonTableExprQ) + } + for k, v := range inheritedCTEMap { + if _, exists := cteMap[k]; !exists { + cteMap[k] = v + } + } + } + + // Step 1: Analyze FROM clause → populate RangeTable and scope. + if err := analyzeFromClause(c, stmt.From, q, scope, cteMap); err != nil { + return nil, err + } + + // Step 2: Analyze target list (SELECT expressions). + if err := analyzeTargetList(c, stmt.TargetList, q, scope); err != nil { + return nil, err + } + + // Step 3: Analyze WHERE clause. + if stmt.Where != nil { + analyzed, err := analyzeExpr(c, stmt.Where, scope) + if err != nil { + return nil, err + } + q.JoinTree.Quals = analyzed + } + + // Step 4: GROUP BY + if len(stmt.GroupBy) > 0 { + if err := c.analyzeGroupBy(stmt.GroupBy, q, scope); err != nil { + return nil, err + } + } + + // Step 5: HAVING + if stmt.Having != nil { + analyzed, err := analyzeExpr(c, stmt.Having, scope) + if err != nil { + return nil, err + } + q.HavingQual = analyzed + } + + // Step 6: Detect aggregates in target list and having + q.HasAggs = detectAggregates(q) + + // Step 7: ORDER BY + if len(stmt.OrderBy) > 0 { + if err := c.analyzeOrderBy(stmt.OrderBy, q, scope); err != nil { + return nil, err + } + } + + // Step 8: LIMIT / OFFSET + if stmt.Limit != nil { + if err := c.analyzeLimitOffset(stmt.Limit, q, scope); err != nil { + return nil, err + } + } + + // Step 9: DISTINCT + q.Distinct = stmt.DistinctKind != nodes.DistinctNone && stmt.DistinctKind != nodes.DistinctAll + + return q, nil +} + +// analyzeFromClause processes the FROM clause, populating the query's +// RangeTable, JoinTree.FromList, and the scope for column resolution. +func analyzeFromClause(c *Catalog, from []nodes.TableExpr, q *Query, scope *analyzerScope, cteMap map[string]*CommonTableExprQ) error { + for _, te := range from { + joinNode, err := analyzeTableExpr(c, te, q, scope, cteMap) + if err != nil { + return err + } + q.JoinTree.FromList = append(q.JoinTree.FromList, joinNode) + } + return nil +} + +// analyzeTableExpr recursively processes a table expression (TableRef, +// JoinClause, or SubqueryExpr used as a derived table), creating RTEs and +// scope entries as appropriate, and returning a JoinNode for the join tree. +func analyzeTableExpr(c *Catalog, te nodes.TableExpr, q *Query, scope *analyzerScope, cteMap map[string]*CommonTableExprQ) (JoinNode, error) { + switch ref := te.(type) { + case *nodes.TableRef: + // Check if this references a CTE before looking up catalog tables. + if cteMap != nil && ref.Schema == "" { + lower := strings.ToLower(ref.Name) + if cteQ, ok := cteMap[lower]; ok { + return analyzeCTERef(ref, cteQ, q, scope) + } + } + rte, cols, err := analyzeTableRef(c, ref) + if err != nil { + return nil, err + } + idx := len(q.RangeTable) + q.RangeTable = append(q.RangeTable, rte) + scope.add(rte.ERef, idx, cols) + return &RangeTableRefQ{RTIndex: idx}, nil + + case *nodes.JoinClause: + return analyzeJoinClause(c, ref, q, scope, cteMap) + + case *nodes.SubqueryExpr: + return analyzeFromSubquery(c, ref, q, scope) + + default: + return nil, fmt.Errorf("unsupported FROM clause element: %T", te) + } +} + +// analyzeJoinClause processes a JOIN clause, creating RTEs for the join and +// its children, and returning a JoinExprNodeQ. +func analyzeJoinClause(c *Catalog, jc *nodes.JoinClause, q *Query, scope *analyzerScope, cteMap map[string]*CommonTableExprQ) (JoinNode, error) { + // Recursively process left and right sides. + left, err := analyzeTableExpr(c, jc.Left, q, scope, cteMap) + if err != nil { + return nil, err + } + right, err := analyzeTableExpr(c, jc.Right, q, scope, cteMap) + if err != nil { + return nil, err + } + + // Map AST JoinType to IR JoinTypeQ and detect NATURAL. + var joinType JoinTypeQ + natural := false + switch jc.Type { + case nodes.JoinInner: + joinType = JoinInner + case nodes.JoinLeft: + joinType = JoinLeft + case nodes.JoinRight: + joinType = JoinRight + case nodes.JoinCross: + joinType = JoinCross + case nodes.JoinStraight: + joinType = JoinStraight + case nodes.JoinNatural: + joinType = JoinInner + natural = true + case nodes.JoinNaturalLeft: + joinType = JoinLeft + natural = true + case nodes.JoinNaturalRight: + joinType = JoinRight + natural = true + default: + return nil, fmt.Errorf("unsupported join type: %d", jc.Type) + } + + // Collect left and right column names for USING/NATURAL coalescing. + leftCols := collectJoinNodeColNames(left, q) + rightCols := collectJoinNodeColNames(right, q) + + var usingCols []string + var quals AnalyzedExpr + + switch cond := jc.Condition.(type) { + case *nodes.OnCondition: + quals, err = analyzeExpr(c, cond.Expr, scope) + if err != nil { + return nil, err + } + case *nodes.UsingCondition: + usingCols = cond.Columns + case nil: + // CROSS JOIN or condition resolved via NATURAL below. + default: + return nil, fmt.Errorf("unsupported join condition type: %T", cond) + } + + // For NATURAL JOIN, compute shared columns. + if natural { + usingCols = computeNaturalJoinColumns(leftCols, rightCols) + } + + // Build coalesced column names for the RTEJoin. + coalescedCols := buildCoalescedColNames(leftCols, rightCols, usingCols) + + // Mark right-side USING columns as coalesced for star expansion. + if len(usingCols) > 0 { + markCoalescedColumns(right, q, scope, usingCols) + } + + // Create the RTEJoin entry. + rteJoin := &RangeTableEntryQ{ + Kind: RTEJoin, + JoinType: joinType, + JoinUsing: usingCols, + ColNames: coalescedCols, + } + rtIdx := len(q.RangeTable) + q.RangeTable = append(q.RangeTable, rteJoin) + + joinExpr := &JoinExprNodeQ{ + JoinType: joinType, + Left: left, + Right: right, + Quals: quals, + UsingClause: usingCols, + Natural: natural, + RTIndex: rtIdx, + } + + return joinExpr, nil +} + +// collectJoinNodeColNames returns the column names contributed by a JoinNode. +func collectJoinNodeColNames(node JoinNode, q *Query) []string { + switch n := node.(type) { + case *RangeTableRefQ: + return q.RangeTable[n.RTIndex].ColNames + case *JoinExprNodeQ: + return q.RangeTable[n.RTIndex].ColNames + } + return nil +} + +// computeNaturalJoinColumns finds column names shared between left and right +// (preserving left-side order). +func computeNaturalJoinColumns(leftCols, rightCols []string) []string { + rightSet := make(map[string]bool, len(rightCols)) + for _, c := range rightCols { + rightSet[strings.ToLower(c)] = true + } + var shared []string + for _, c := range leftCols { + if rightSet[strings.ToLower(c)] { + shared = append(shared, c) + } + } + return shared +} + +// buildCoalescedColNames builds the coalesced column list for a JOIN: +// USING columns first (from left), then remaining left columns, then remaining right columns. +func buildCoalescedColNames(leftCols, rightCols, usingCols []string) []string { + if len(usingCols) == 0 { + // No coalescing — just concatenate all columns. + result := make([]string, 0, len(leftCols)+len(rightCols)) + result = append(result, leftCols...) + result = append(result, rightCols...) + return result + } + + usingSet := make(map[string]bool, len(usingCols)) + for _, c := range usingCols { + usingSet[strings.ToLower(c)] = true + } + + result := make([]string, 0, len(leftCols)+len(rightCols)-len(usingCols)) + // USING columns first (in USING order, from left side). + result = append(result, usingCols...) + // Remaining left columns. + for _, c := range leftCols { + if !usingSet[strings.ToLower(c)] { + result = append(result, c) + } + } + // Remaining right columns (skip USING columns). + for _, c := range rightCols { + if !usingSet[strings.ToLower(c)] { + result = append(result, c) + } + } + return result +} + +// markCoalescedColumns marks right-side USING columns as coalesced in scope +// so that star expansion skips them (avoiding duplicate columns). +func markCoalescedColumns(rightNode JoinNode, q *Query, scope *analyzerScope, usingCols []string) { + usingSet := make(map[string]bool, len(usingCols)) + for _, c := range usingCols { + usingSet[strings.ToLower(c)] = true + } + + // Walk the right node's base tables and mark their USING columns. + markCoalescedInNode(rightNode, q, scope, usingSet) +} + +// markCoalescedInNode recursively marks coalesced columns on base tables +// within a join node. +func markCoalescedInNode(node JoinNode, q *Query, scope *analyzerScope, usingSet map[string]bool) { + switch n := node.(type) { + case *RangeTableRefQ: + rte := q.RangeTable[n.RTIndex] + scopeName := strings.ToLower(rte.ERef) + for _, colName := range rte.ColNames { + if usingSet[strings.ToLower(colName)] { + scope.markCoalesced(scopeName, colName) + } + } + case *JoinExprNodeQ: + markCoalescedInNode(n.Left, q, scope, usingSet) + markCoalescedInNode(n.Right, q, scope, usingSet) + } +} + +// analyzeFromSubquery processes a subquery used as a derived table in FROM. +func analyzeFromSubquery(c *Catalog, subq *nodes.SubqueryExpr, q *Query, scope *analyzerScope) (JoinNode, error) { + // Recursively analyze the inner SELECT. FROM subqueries are not correlated, + // so we pass nil as parent scope (unless LATERAL). + var parentScope *analyzerScope + if subq.Lateral { + parentScope = scope + } + innerQ, err := c.analyzeSelectStmtInternal(subq.Select, parentScope) + if err != nil { + return nil, err + } + + // Derive column names from the inner query's non-junk target list. + var colNames []string + for _, te := range innerQ.TargetList { + if !te.ResJunk { + colNames = append(colNames, te.ResName) + } + } + + // If explicit column aliases are provided, use those instead. + if len(subq.Columns) > 0 { + colNames = subq.Columns + } + + alias := subq.Alias + if alias == "" { + alias = fmt.Sprintf("__subquery_%d", len(q.RangeTable)) + } + + rte := &RangeTableEntryQ{ + Kind: RTESubquery, + Alias: alias, + ERef: alias, + ColNames: colNames, + Subquery: innerQ, + Lateral: subq.Lateral, + } + + idx := len(q.RangeTable) + q.RangeTable = append(q.RangeTable, rte) + + // Build stub columns for scope resolution. + cols := make([]*Column, len(colNames)) + for i, name := range colNames { + cols[i] = &Column{Position: i + 1, Name: name} + } + scope.add(alias, idx, cols) + + return &RangeTableRefQ{RTIndex: idx}, nil +} + +// analyzeTableRef resolves a table reference from the FROM clause against +// the catalog, returning the RTE and the column list. +func analyzeTableRef(c *Catalog, ref *nodes.TableRef) (*RangeTableEntryQ, []*Column, error) { + dbName := ref.Schema + if dbName == "" { + dbName = c.CurrentDatabase() + } + if dbName == "" { + return nil, nil, errNoDatabaseSelected() + } + + db := c.GetDatabase(dbName) + if db == nil { + return nil, nil, errUnknownDatabase(dbName) + } + + // Check for a table first, then a view. + tbl := db.GetTable(ref.Name) + if tbl != nil { + eref := ref.Name + if ref.Alias != "" { + eref = ref.Alias + } + colNames := make([]string, len(tbl.Columns)) + for i, col := range tbl.Columns { + colNames[i] = col.Name + } + rte := &RangeTableEntryQ{ + Kind: RTERelation, + DBName: db.Name, + TableName: tbl.Name, + Alias: ref.Alias, + ERef: eref, + ColNames: colNames, + } + return rte, tbl.Columns, nil + } + + // Check views. + view := db.Views[toLower(ref.Name)] + if view != nil { + eref := ref.Name + if ref.Alias != "" { + eref = ref.Alias + } + // Build stub columns from view column names. + cols := make([]*Column, len(view.Columns)) + colNames := make([]string, len(view.Columns)) + for i, name := range view.Columns { + cols[i] = &Column{Position: i + 1, Name: name} + colNames[i] = name + } + rte := &RangeTableEntryQ{ + Kind: RTERelation, + DBName: db.Name, + TableName: view.Name, + Alias: ref.Alias, + ERef: eref, + ColNames: colNames, + IsView: true, + ViewAlgorithm: viewAlgorithmFromString(view.Algorithm), + } + return rte, cols, nil + } + + return nil, nil, errNoSuchTable(dbName, ref.Name) +} + +// viewAlgorithmFromString converts a string algorithm value to ViewAlgorithm. +func viewAlgorithmFromString(s string) ViewAlgorithm { + switch strings.ToUpper(s) { + case "MERGE": + return ViewAlgMerge + case "TEMPTABLE": + return ViewAlgTemptable + case "UNDEFINED", "": + return ViewAlgUndefined + default: + return ViewAlgUndefined + } +} + +// analyzeGroupBy processes the GROUP BY clause, populating q.GroupClause. +func (c *Catalog) analyzeGroupBy(groupBy []nodes.ExprNode, q *Query, scope *analyzerScope) error { + for _, expr := range groupBy { + switch n := expr.(type) { + case *nodes.IntLit: + // Ordinal reference: GROUP BY 1 means first SELECT column. + idx := int(n.Value) + if idx < 1 || idx > len(q.TargetList) { + return fmt.Errorf("GROUP BY position %d is not in select list", idx) + } + q.GroupClause = append(q.GroupClause, &SortGroupClauseQ{ + TargetIdx: idx, + }) + case *nodes.ColumnRef: + // Resolve the column ref, then find matching target entry. + analyzed, err := analyzeExpr(c, n, scope) + if err != nil { + return err + } + varExpr, ok := analyzed.(*VarExprQ) + if !ok { + return fmt.Errorf("GROUP BY column reference resolved to unexpected type %T", analyzed) + } + targetIdx := findMatchingTarget(q.TargetList, varExpr) + if targetIdx == 0 { + // Not found in target list — add as junk entry. + te := &TargetEntryQ{ + Expr: varExpr, + ResNo: len(q.TargetList) + 1, + ResName: n.Column, + ResJunk: true, + } + q.TargetList = append(q.TargetList, te) + targetIdx = te.ResNo + } + q.GroupClause = append(q.GroupClause, &SortGroupClauseQ{ + TargetIdx: targetIdx, + }) + default: + // General expression — analyze and try to match to target list. + analyzed, err := analyzeExpr(c, expr, scope) + if err != nil { + return err + } + targetIdx := 0 + for _, te := range q.TargetList { + if exprEqual(te.Expr, analyzed) { + targetIdx = te.ResNo + break + } + } + if targetIdx == 0 { + te := &TargetEntryQ{ + Expr: analyzed, + ResNo: len(q.TargetList) + 1, + ResJunk: true, + } + q.TargetList = append(q.TargetList, te) + targetIdx = te.ResNo + } + q.GroupClause = append(q.GroupClause, &SortGroupClauseQ{ + TargetIdx: targetIdx, + }) + } + } + return nil +} + +// findMatchingTarget finds a TargetEntryQ whose Expr is a VarExprQ matching +// the given VarExprQ (same RangeIdx and AttNum). Returns ResNo (1-based) or 0. +func findMatchingTarget(tl []*TargetEntryQ, v *VarExprQ) int { + for _, te := range tl { + if tv, ok := te.Expr.(*VarExprQ); ok { + if tv.RangeIdx == v.RangeIdx && tv.AttNum == v.AttNum { + return te.ResNo + } + } + } + return 0 +} + +// exprEqual compares two AnalyzedExpr values for structural equality. +// Phase 1a: only VarExprQ is compared; other types return false. +func exprEqual(a, b AnalyzedExpr) bool { + va, okA := a.(*VarExprQ) + vb, okB := b.(*VarExprQ) + if okA && okB { + return va.RangeIdx == vb.RangeIdx && va.AttNum == vb.AttNum + } + return false +} + +// detectAggregates returns true if any aggregate function call exists in the +// query's TargetList or HavingQual. +func detectAggregates(q *Query) bool { + for _, te := range q.TargetList { + if exprContainsAggregate(te.Expr) { + return true + } + } + if q.HavingQual != nil { + return exprContainsAggregate(q.HavingQual) + } + return false +} + +// exprContainsAggregate recursively walks an AnalyzedExpr looking for +// FuncCallExprQ with IsAggregate=true. +func exprContainsAggregate(expr AnalyzedExpr) bool { + if expr == nil { + return false + } + switch e := expr.(type) { + case *FuncCallExprQ: + if e.IsAggregate { + return true + } + for _, arg := range e.Args { + if exprContainsAggregate(arg) { + return true + } + } + case *OpExprQ: + return exprContainsAggregate(e.Left) || exprContainsAggregate(e.Right) + case *BoolExprQ: + for _, arg := range e.Args { + if exprContainsAggregate(arg) { + return true + } + } + case *InListExprQ: + if exprContainsAggregate(e.Arg) { + return true + } + for _, item := range e.List { + if exprContainsAggregate(item) { + return true + } + } + case *BetweenExprQ: + return exprContainsAggregate(e.Arg) || exprContainsAggregate(e.Lower) || exprContainsAggregate(e.Upper) + case *NullTestExprQ: + return exprContainsAggregate(e.Arg) + case *VarExprQ, *ConstExprQ: + return false + } + return false +} + +// analyzeOrderBy processes the ORDER BY clause, populating q.SortClause. +// When an ORDER BY expression is not in the SELECT list, a junk TargetEntryQ +// is added (ResJunk=true). +func (c *Catalog) analyzeOrderBy(orderBy []*nodes.OrderByItem, q *Query, scope *analyzerScope) error { + for _, item := range orderBy { + desc := item.Desc + // MySQL default: ASC → NullsFirst=true, DESC → NullsFirst=false. + nullsFirst := !desc + if item.NullsFirst != nil { + nullsFirst = *item.NullsFirst + } + + switch n := item.Expr.(type) { + case *nodes.IntLit: + // Ordinal reference: ORDER BY 1 means first SELECT column. + idx := int(n.Value) + if idx < 1 || idx > len(q.TargetList) { + return fmt.Errorf("ORDER BY position %d is not in select list", idx) + } + q.SortClause = append(q.SortClause, &SortGroupClauseQ{ + TargetIdx: idx, + Descending: desc, + NullsFirst: nullsFirst, + }) + case *nodes.ColumnRef: + // Resolve the column ref, then find matching target entry. + analyzed, err := analyzeExpr(c, n, scope) + if err != nil { + return err + } + varExpr, ok := analyzed.(*VarExprQ) + if !ok { + return fmt.Errorf("ORDER BY column reference resolved to unexpected type %T", analyzed) + } + targetIdx := findMatchingTarget(q.TargetList, varExpr) + if targetIdx == 0 { + // Not found in target list — add as junk entry. + te := &TargetEntryQ{ + Expr: varExpr, + ResNo: len(q.TargetList) + 1, + ResName: n.Column, + ResJunk: true, + } + q.TargetList = append(q.TargetList, te) + targetIdx = te.ResNo + } + q.SortClause = append(q.SortClause, &SortGroupClauseQ{ + TargetIdx: targetIdx, + Descending: desc, + NullsFirst: nullsFirst, + }) + default: + // General expression — analyze and try to match to target list. + analyzed, err := analyzeExpr(c, item.Expr, scope) + if err != nil { + return err + } + targetIdx := 0 + for _, te := range q.TargetList { + if exprEqual(te.Expr, analyzed) { + targetIdx = te.ResNo + break + } + } + if targetIdx == 0 { + te := &TargetEntryQ{ + Expr: analyzed, + ResNo: len(q.TargetList) + 1, + ResJunk: true, + } + q.TargetList = append(q.TargetList, te) + targetIdx = te.ResNo + } + q.SortClause = append(q.SortClause, &SortGroupClauseQ{ + TargetIdx: targetIdx, + Descending: desc, + NullsFirst: nullsFirst, + }) + } + } + return nil +} + +// analyzeLimitOffset processes the LIMIT/OFFSET clause, populating +// q.LimitCount and q.LimitOffset. +func (c *Catalog) analyzeLimitOffset(limit *nodes.Limit, q *Query, scope *analyzerScope) error { + if limit.Count != nil { + analyzed, err := analyzeExpr(c, limit.Count, scope) + if err != nil { + return err + } + q.LimitCount = analyzed + } + if limit.Offset != nil { + analyzed, err := analyzeExpr(c, limit.Offset, scope) + if err != nil { + return err + } + q.LimitOffset = analyzed + } + return nil +} + +// analyzeCTEs processes WITH clause CTEs, returning a map for CTE name lookup. +func (c *Catalog) analyzeCTEs(ctes []*nodes.CommonTableExpr, q *Query, parentScope *analyzerScope) (map[string]*CommonTableExprQ, error) { + if len(ctes) == 0 { + return nil, nil + } + + cteMap := make(map[string]*CommonTableExprQ) + for i, cte := range ctes { + if cte.Recursive { + q.IsRecursive = true + } + + var innerQ *Query + var err error + + if cte.Recursive && cte.Select.SetOp != nodes.SetOpNone { + // Recursive CTE: analyze the left arm first to establish columns, + // then register the CTE, then analyze the right arm. + innerQ, err = c.analyzeRecursiveCTE(cte, parentScope, cteMap) + } else { + innerQ, err = c.analyzeSelectStmtInternal(cte.Select, parentScope) + } + if err != nil { + return nil, err + } + + // Derive column names. + colNames := cte.Columns + if len(colNames) == 0 { + for _, te := range innerQ.TargetList { + if !te.ResJunk { + colNames = append(colNames, te.ResName) + } + } + } + + cteQ := &CommonTableExprQ{ + Name: cte.Name, + ColumnNames: colNames, + Query: innerQ, + Recursive: cte.Recursive, + } + q.CTEList = append(q.CTEList, cteQ) + cteMap[strings.ToLower(cte.Name)] = cteQ + _ = i + } + return cteMap, nil +} + +// analyzeRecursiveCTE handles WITH RECURSIVE where the CTE body is a set +// operation. The left arm establishes column signatures; the right arm may +// reference the CTE itself. +func (c *Catalog) analyzeRecursiveCTE(cte *nodes.CommonTableExpr, parentScope *analyzerScope, cteMap map[string]*CommonTableExprQ) (*Query, error) { + stmt := cte.Select + + // Analyze left arm to establish base columns. + larg, err := c.analyzeSelectStmtInternal(stmt.Left, parentScope) + if err != nil { + return nil, err + } + + // Derive column names from left arm (or explicit column list). + colNames := cte.Columns + if len(colNames) == 0 { + for _, te := range larg.TargetList { + if !te.ResJunk { + colNames = append(colNames, te.ResName) + } + } + } + + // Register a temporary CTE entry so the right arm can reference it. + tempCTE := &CommonTableExprQ{ + Name: cte.Name, + ColumnNames: colNames, + Query: larg, // temporary — will be replaced + Recursive: true, + } + cteMap[strings.ToLower(cte.Name)] = tempCTE + + // Analyze right arm with the CTE visible via inherited CTE map. + rarg, err := c.analyzeSelectStmtWithCTEs(stmt.Right, parentScope, cteMap) + if err != nil { + return nil, err + } + + // Map SetOperation to SetOpType. + var setOp SetOpType + switch stmt.SetOp { + case nodes.SetOpUnion: + setOp = SetOpUnion + case nodes.SetOpIntersect: + setOp = SetOpIntersect + case nodes.SetOpExcept: + setOp = SetOpExcept + } + + // Build result columns from left arm. + targetList := make([]*TargetEntryQ, 0, len(larg.TargetList)) + for _, te := range larg.TargetList { + if !te.ResJunk { + targetList = append(targetList, &TargetEntryQ{ + Expr: te.Expr, + ResNo: te.ResNo, + ResName: te.ResName, + }) + } + } + + q := &Query{ + CommandType: CmdSelect, + SetOp: setOp, + AllSetOp: stmt.SetAll, + LArg: larg, + RArg: rarg, + TargetList: targetList, + JoinTree: &JoinTreeQ{}, + } + return q, nil +} + +// analyzeCTERef creates an RTE and scope entry for a CTE reference in FROM. +func analyzeCTERef(ref *nodes.TableRef, cteQ *CommonTableExprQ, q *Query, scope *analyzerScope) (JoinNode, error) { + eref := ref.Name + if ref.Alias != "" { + eref = ref.Alias + } + + // Find the CTE's index in the current query's CTEList. + cteIndex := -1 + for i, c := range q.CTEList { + if strings.EqualFold(c.Name, cteQ.Name) { + cteIndex = i + break + } + } + if cteIndex < 0 { + cteIndex = 0 // fallback for recursive self-ref during analysis + } + + colNames := cteQ.ColumnNames + + rte := &RangeTableEntryQ{ + Kind: RTECTE, + Alias: ref.Alias, + ERef: eref, + ColNames: colNames, + CTEIndex: cteIndex, + CTEName: cteQ.Name, + Subquery: cteQ.Query, + } + + idx := len(q.RangeTable) + q.RangeTable = append(q.RangeTable, rte) + + // Build stub columns for scope resolution. + cols := make([]*Column, len(colNames)) + for i, name := range colNames { + cols[i] = &Column{Position: i + 1, Name: name} + } + scope.add(eref, idx, cols) + + return &RangeTableRefQ{RTIndex: idx}, nil +} + +// analyzeSetOp processes a set operation (UNION/INTERSECT/EXCEPT). +func (c *Catalog) analyzeSetOp(stmt *nodes.SelectStmt, parentScope *analyzerScope) (*Query, error) { + return c.analyzeSetOpWithCTEs(stmt, parentScope, nil) +} + +// analyzeSetOpWithCTEs processes a set operation with inherited CTE definitions. +func (c *Catalog) analyzeSetOpWithCTEs(stmt *nodes.SelectStmt, parentScope *analyzerScope, inheritedCTEMap map[string]*CommonTableExprQ) (*Query, error) { + // Process CTEs if present on the outer set-op node. + q := &Query{ + CommandType: CmdSelect, + JoinTree: &JoinTreeQ{}, + } + + cteMap, err := c.analyzeCTEs(stmt.CTEs, q, parentScope) + if err != nil { + return nil, err + } + + // Merge inherited CTEs. + if inheritedCTEMap != nil { + if cteMap == nil { + cteMap = make(map[string]*CommonTableExprQ) + } + for k, v := range inheritedCTEMap { + if _, exists := cteMap[k]; !exists { + cteMap[k] = v + } + } + } + + larg, err := c.analyzeSelectStmtWithCTEs(stmt.Left, parentScope, cteMap) + if err != nil { + return nil, err + } + rarg, err := c.analyzeSelectStmtWithCTEs(stmt.Right, parentScope, cteMap) + if err != nil { + return nil, err + } + + // Map SetOperation to SetOpType. + var setOp SetOpType + switch stmt.SetOp { + case nodes.SetOpUnion: + setOp = SetOpUnion + case nodes.SetOpIntersect: + setOp = SetOpIntersect + case nodes.SetOpExcept: + setOp = SetOpExcept + } + + // Result columns come from the left arm (MySQL convention). + targetList := make([]*TargetEntryQ, 0, len(larg.TargetList)) + for _, te := range larg.TargetList { + if !te.ResJunk { + targetList = append(targetList, &TargetEntryQ{ + Expr: te.Expr, + ResNo: te.ResNo, + ResName: te.ResName, + }) + } + } + + q.SetOp = setOp + q.AllSetOp = stmt.SetAll + q.LArg = larg + q.RArg = rarg + q.TargetList = targetList + + // Handle ORDER BY / LIMIT on the outer set-op query using a + // scope built from result columns. + if len(stmt.OrderBy) > 0 || stmt.Limit != nil { + setScope := newScope() + // Build stub columns from target list for ORDER BY resolution. + cols := make([]*Column, len(targetList)) + colNames := make([]string, len(targetList)) + for i, te := range targetList { + cols[i] = &Column{Position: i + 1, Name: te.ResName} + colNames[i] = te.ResName + } + // Add a virtual table entry for unqualified column resolution. + setScope.add("__setop__", 0, cols) + + if len(stmt.OrderBy) > 0 { + if err := c.analyzeOrderBy(stmt.OrderBy, q, setScope); err != nil { + return nil, err + } + } + if stmt.Limit != nil { + if err := c.analyzeLimitOffset(stmt.Limit, q, setScope); err != nil { + return nil, err + } + } + } + + return q, nil +} + +// AnalyzeStandaloneExpr analyzes an expression in the context of a single table. +// Used for CHECK constraints, DEFAULT expressions, and GENERATED column expressions. +func (c *Catalog) AnalyzeStandaloneExpr(expr nodes.ExprNode, table *Table) (AnalyzedExpr, error) { + scope := newScope() + // Register the table's columns into the scope at RTE index 0. + scope.add(table.Name, 0, table.Columns) + return analyzeExpr(c, expr, scope) +} diff --git a/tidb/catalog/analyze_expr.go b/tidb/catalog/analyze_expr.go new file mode 100644 index 00000000..a2d947a6 --- /dev/null +++ b/tidb/catalog/analyze_expr.go @@ -0,0 +1,368 @@ +package catalog + +import ( + "strconv" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +// analyzeExpr is the main expression analysis dispatcher. +func analyzeExpr(c *Catalog, expr nodes.ExprNode, scope *analyzerScope) (AnalyzedExpr, error) { + switch n := expr.(type) { + case *nodes.ColumnRef: + return analyzeColumnRef(n, scope) + case *nodes.IntLit: + return &ConstExprQ{Value: strconv.FormatInt(n.Value, 10)}, nil + case *nodes.StringLit: + return &ConstExprQ{Value: n.Value}, nil + case *nodes.FloatLit: + return &ConstExprQ{Value: n.Value}, nil + case *nodes.NullLit: + return &ConstExprQ{IsNull: true, Value: "NULL"}, nil + case *nodes.BoolLit: + if n.Value { + return &ConstExprQ{Value: "TRUE"}, nil + } + return &ConstExprQ{Value: "FALSE"}, nil + case *nodes.FuncCallExpr: + return analyzeFuncCall(c, n, scope) + case *nodes.ParenExpr: + return analyzeExpr(c, n.Expr, scope) + case *nodes.BinaryExpr: + return analyzeBinaryExpr(c, n, scope) + case *nodes.UnaryExpr: + return analyzeUnaryExpr(c, n, scope) + case *nodes.InExpr: + return analyzeInExpr(c, n, scope) + case *nodes.BetweenExpr: + return analyzeBetweenExpr(c, n, scope) + case *nodes.IsExpr: + return analyzeIsExpr(c, n, scope) + case *nodes.CaseExpr: + return analyzeCaseExpr(c, n, scope) + case *nodes.CastExpr: + return analyzeCastExpr(c, n, scope) + case *nodes.SubqueryExpr: + return analyzeScalarSubquery(c, n, scope) + case *nodes.ExistsExpr: + return analyzeExistsSubquery(c, n, scope) + default: + return nil, &Error{ + Code: 0, + SQLState: "HY000", + Message: "unsupported expression type in analyzer", + } + } +} + +// analyzeColumnRef resolves a column reference against the scope. +// Uses the Full resolution methods to support correlated subquery references +// (setting LevelsUp > 0 when the column comes from a parent scope). +func analyzeColumnRef(ref *nodes.ColumnRef, scope *analyzerScope) (AnalyzedExpr, error) { + if ref.Table != "" { + rteIdx, attNum, levelsUp, err := scope.resolveQualifiedColumnFull(ref.Table, ref.Column) + if err != nil { + return nil, err + } + return &VarExprQ{RangeIdx: rteIdx, AttNum: attNum, LevelsUp: levelsUp}, nil + } + rteIdx, attNum, levelsUp, err := scope.resolveColumnFull(ref.Column) + if err != nil { + return nil, err + } + return &VarExprQ{RangeIdx: rteIdx, AttNum: attNum, LevelsUp: levelsUp}, nil +} + +// analyzeFuncCall resolves a function call expression. +func analyzeFuncCall(c *Catalog, fc *nodes.FuncCallExpr, scope *analyzerScope) (AnalyzedExpr, error) { + args := make([]AnalyzedExpr, 0, len(fc.Args)) + for _, arg := range fc.Args { + a, err := analyzeExpr(c, arg, scope) + if err != nil { + return nil, err + } + args = append(args, a) + } + lower := strings.ToLower(fc.Name) + if lower == "coalesce" || lower == "ifnull" { + return &CoalesceExprQ{Args: args}, nil + } + result := &FuncCallExprQ{ + Name: lower, + Args: args, + IsAggregate: isAggregateFunc(fc.Name), + Distinct: fc.Distinct, + } + // Phase 3: populate return type from function type table. + result.ResultType = functionReturnType(result.Name, result.Args) + return result, nil +} + +// analyzeBinaryExpr resolves a binary expression. +func analyzeBinaryExpr(c *Catalog, expr *nodes.BinaryExpr, scope *analyzerScope) (AnalyzedExpr, error) { + left, err := analyzeExpr(c, expr.Left, scope) + if err != nil { + return nil, err + } + right, err := analyzeExpr(c, expr.Right, scope) + if err != nil { + return nil, err + } + + switch expr.Op { + case nodes.BinOpAnd: + return &BoolExprQ{Op: BoolAnd, Args: []AnalyzedExpr{left, right}}, nil + case nodes.BinOpOr: + return &BoolExprQ{Op: BoolOr, Args: []AnalyzedExpr{left, right}}, nil + default: + return &OpExprQ{Op: binaryOpToString(expr.Op), Left: left, Right: right}, nil + } +} + +// analyzeUnaryExpr resolves a unary expression. +func analyzeUnaryExpr(c *Catalog, expr *nodes.UnaryExpr, scope *analyzerScope) (AnalyzedExpr, error) { + operand, err := analyzeExpr(c, expr.Operand, scope) + if err != nil { + return nil, err + } + + switch expr.Op { + case nodes.UnaryPlus: + return operand, nil + case nodes.UnaryMinus: + return &OpExprQ{Op: "-", Right: operand}, nil + case nodes.UnaryNot: + return &BoolExprQ{Op: BoolNot, Args: []AnalyzedExpr{operand}}, nil + case nodes.UnaryBitNot: + return &OpExprQ{Op: "~", Right: operand}, nil + case nodes.UnaryBinary: + return &OpExprQ{Op: "BINARY", Right: operand}, nil + default: + return nil, &Error{ + Code: 0, + SQLState: "HY000", + Message: "unsupported unary operator in analyzer", + } + } +} + +// analyzeInExpr resolves an IN expression. +func analyzeInExpr(c *Catalog, expr *nodes.InExpr, scope *analyzerScope) (AnalyzedExpr, error) { + arg, err := analyzeExpr(c, expr.Expr, scope) + if err != nil { + return nil, err + } + + if expr.Select != nil { + // IN (SELECT ...) / NOT IN (SELECT ...) + innerQ, err := c.analyzeSelectStmtInternal(expr.Select, scope) + if err != nil { + return nil, err + } + subLink := &SubLinkExprQ{ + Kind: SubLinkIn, + TestExpr: arg, + Op: "=", + Subquery: innerQ, + } + if expr.Not { + return &BoolExprQ{Op: BoolNot, Args: []AnalyzedExpr{subLink}}, nil + } + return subLink, nil + } + + list := make([]AnalyzedExpr, 0, len(expr.List)) + for _, item := range expr.List { + a, err := analyzeExpr(c, item, scope) + if err != nil { + return nil, err + } + list = append(list, a) + } + + return &InListExprQ{Arg: arg, List: list, Negated: expr.Not}, nil +} + +// analyzeBetweenExpr resolves a BETWEEN expression. +func analyzeBetweenExpr(c *Catalog, expr *nodes.BetweenExpr, scope *analyzerScope) (AnalyzedExpr, error) { + arg, err := analyzeExpr(c, expr.Expr, scope) + if err != nil { + return nil, err + } + lower, err := analyzeExpr(c, expr.Low, scope) + if err != nil { + return nil, err + } + upper, err := analyzeExpr(c, expr.High, scope) + if err != nil { + return nil, err + } + + return &BetweenExprQ{Arg: arg, Lower: lower, Upper: upper, Negated: expr.Not}, nil +} + +// analyzeIsExpr resolves an IS expression. +func analyzeIsExpr(c *Catalog, expr *nodes.IsExpr, scope *analyzerScope) (AnalyzedExpr, error) { + arg, err := analyzeExpr(c, expr.Expr, scope) + if err != nil { + return nil, err + } + + switch expr.Test { + case nodes.IsNull: + return &NullTestExprQ{Arg: arg, IsNull: !expr.Not}, nil + case nodes.IsTrue: + op := "IS TRUE" + if expr.Not { + op = "IS NOT TRUE" + } + return &OpExprQ{Op: op, Left: arg}, nil + case nodes.IsFalse: + op := "IS FALSE" + if expr.Not { + op = "IS NOT FALSE" + } + return &OpExprQ{Op: op, Left: arg}, nil + case nodes.IsUnknown: + op := "IS UNKNOWN" + if expr.Not { + op = "IS NOT UNKNOWN" + } + return &OpExprQ{Op: op, Left: arg}, nil + default: + return nil, &Error{ + Code: 0, + SQLState: "HY000", + Message: "unsupported IS test type in analyzer", + } + } +} + +// analyzeCaseExpr resolves a CASE expression. +func analyzeCaseExpr(c *Catalog, expr *nodes.CaseExpr, scope *analyzerScope) (AnalyzedExpr, error) { + var testExpr AnalyzedExpr + if expr.Operand != nil { + var err error + testExpr, err = analyzeExpr(c, expr.Operand, scope) + if err != nil { + return nil, err + } + } + + whens := make([]*CaseWhenQ, 0, len(expr.Whens)) + for _, w := range expr.Whens { + cond, err := analyzeExpr(c, w.Cond, scope) + if err != nil { + return nil, err + } + then, err := analyzeExpr(c, w.Result, scope) + if err != nil { + return nil, err + } + whens = append(whens, &CaseWhenQ{Cond: cond, Then: then}) + } + + var def AnalyzedExpr + if expr.Default != nil { + var err error + def, err = analyzeExpr(c, expr.Default, scope) + if err != nil { + return nil, err + } + } + + return &CaseExprQ{TestExpr: testExpr, Args: whens, Default: def}, nil +} + +// analyzeCastExpr resolves a CAST expression. +func analyzeCastExpr(c *Catalog, expr *nodes.CastExpr, scope *analyzerScope) (AnalyzedExpr, error) { + arg, err := analyzeExpr(c, expr.Expr, scope) + if err != nil { + return nil, err + } + + var targetType *ResolvedType + if expr.TypeName != nil { + targetType = dataTypeToResolvedType(expr.TypeName) + } + + return &CastExprQ{Arg: arg, TargetType: targetType}, nil +} + +// dataTypeToResolvedType maps a parser DataType to a ResolvedType for CAST targets. +func dataTypeToResolvedType(dt *nodes.DataType) *ResolvedType { + name := strings.ToLower(dt.Name) + switch name { + case "signed", "signed integer": + return &ResolvedType{BaseType: BaseTypeBigInt} + case "unsigned", "unsigned integer": + return &ResolvedType{BaseType: BaseTypeBigInt, Unsigned: true} + case "char": + rt := &ResolvedType{BaseType: BaseTypeChar} + if dt.Length > 0 { + rt.Length = dt.Length + } + return rt + case "binary": + rt := &ResolvedType{BaseType: BaseTypeBinary} + if dt.Length > 0 { + rt.Length = dt.Length + } + return rt + case "decimal": + return &ResolvedType{BaseType: BaseTypeDecimal, Precision: dt.Length, Scale: dt.Scale} + case "date": + return &ResolvedType{BaseType: BaseTypeDate} + case "datetime": + return &ResolvedType{BaseType: BaseTypeDateTime} + case "time": + return &ResolvedType{BaseType: BaseTypeTime} + case "json": + return &ResolvedType{BaseType: BaseTypeJSON} + case "float": + return &ResolvedType{BaseType: BaseTypeFloat} + case "double": + return &ResolvedType{BaseType: BaseTypeDouble} + default: + return &ResolvedType{BaseType: BaseTypeUnknown} + } +} + +// analyzeScalarSubquery resolves a scalar subquery expression: (SELECT ...). +func analyzeScalarSubquery(c *Catalog, subq *nodes.SubqueryExpr, scope *analyzerScope) (AnalyzedExpr, error) { + innerQ, err := c.analyzeSelectStmtInternal(subq.Select, scope) + if err != nil { + return nil, err + } + return &SubLinkExprQ{ + Kind: SubLinkScalar, + Subquery: innerQ, + }, nil +} + +// analyzeExistsSubquery resolves an EXISTS (SELECT ...) expression. +func analyzeExistsSubquery(c *Catalog, expr *nodes.ExistsExpr, scope *analyzerScope) (AnalyzedExpr, error) { + innerQ, err := c.analyzeSelectStmtInternal(expr.Select, scope) + if err != nil { + return nil, err + } + return &SubLinkExprQ{ + Kind: SubLinkExists, + Subquery: innerQ, + }, nil +} + +// isAggregateFunc returns true if the function name is a known aggregate. +func isAggregateFunc(name string) bool { + switch strings.ToLower(name) { + case "count", "sum", "avg", "min", "max", + "group_concat", "json_arrayagg", "json_objectagg", + "bit_and", "bit_or", "bit_xor", + "std", "stddev", "stddev_pop", "stddev_samp", + "var_pop", "var_samp", "variance", + "any_value": + return true + } + return false +} diff --git a/tidb/catalog/analyze_targetlist.go b/tidb/catalog/analyze_targetlist.go new file mode 100644 index 00000000..78ca8f6b --- /dev/null +++ b/tidb/catalog/analyze_targetlist.go @@ -0,0 +1,195 @@ +package catalog + +import ( + "fmt" + "strconv" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +// analyzeTargetList processes the SELECT list, populating q.TargetList. +func analyzeTargetList(c *Catalog, targetList []nodes.ExprNode, q *Query, scope *analyzerScope) error { + resNo := 1 + for _, item := range targetList { + entries, err := analyzeTargetEntry(c, item, q, scope, &resNo) + if err != nil { + return err + } + q.TargetList = append(q.TargetList, entries...) + } + return nil +} + +// analyzeTargetEntry processes one item from the SELECT list. It may return +// multiple TargetEntryQ values when expanding star expressions. +func analyzeTargetEntry(c *Catalog, item nodes.ExprNode, q *Query, scope *analyzerScope, resNo *int) ([]*TargetEntryQ, error) { + switch n := item.(type) { + case *nodes.ResTarget: + // Aliased expression: SELECT expr AS alias + analyzed, err := analyzeExpr(c, n.Val, scope) + if err != nil { + return nil, err + } + te := &TargetEntryQ{ + Expr: analyzed, + ResNo: *resNo, + ResName: n.Name, + } + fillProvenance(te, q) + *resNo++ + return []*TargetEntryQ{te}, nil + + case *nodes.StarExpr: + // SELECT * + return expandStar("", q, scope, resNo) + + case *nodes.ColumnRef: + if n.Star { + // SELECT t.* + return expandStar(n.Table, q, scope, resNo) + } + // Bare column reference: SELECT col + analyzed, err := analyzeExpr(c, item, scope) + if err != nil { + return nil, err + } + te := &TargetEntryQ{ + Expr: analyzed, + ResNo: *resNo, + ResName: deriveColumnName(item, analyzed), + } + fillProvenance(te, q) + *resNo++ + return []*TargetEntryQ{te}, nil + + default: + // Bare expression (literal, function call, etc.) + analyzed, err := analyzeExpr(c, item, scope) + if err != nil { + return nil, err + } + te := &TargetEntryQ{ + Expr: analyzed, + ResNo: *resNo, + ResName: deriveColumnName(item, analyzed), + } + fillProvenance(te, q) + *resNo++ + return []*TargetEntryQ{te}, nil + } +} + +// expandStar expands a star expression (SELECT * or SELECT t.*) into +// individual TargetEntryQ values. +func expandStar(tableName string, q *Query, scope *analyzerScope, resNo *int) ([]*TargetEntryQ, error) { + var result []*TargetEntryQ + + if tableName == "" { + // SELECT * — expand all tables in scope order. + for _, e := range scope.allEntries() { + entries, err := expandScopeEntry(e, q, scope, resNo) + if err != nil { + return nil, err + } + result = append(result, entries...) + } + } else { + // SELECT t.* — expand only the named table. + lower := toLower(tableName) + found := false + for _, e := range scope.allEntries() { + if e.name == lower { + entries, err := expandScopeEntry(e, q, scope, resNo) + if err != nil { + return nil, err + } + result = append(result, entries...) + found = true + break + } + } + if !found { + return nil, &Error{ + Code: ErrUnknownTable, + SQLState: sqlState(ErrUnknownTable), + Message: fmt.Sprintf("Unknown table '%s'", tableName), + } + } + } + + if len(result) == 0 { + return nil, fmt.Errorf("no columns found for star expansion") + } + return result, nil +} + +// expandScopeEntry expands all columns from a single scope entry, +// skipping columns marked as coalesced by USING/NATURAL. +func expandScopeEntry(e scopeEntry, q *Query, scope *analyzerScope, resNo *int) ([]*TargetEntryQ, error) { + var result []*TargetEntryQ + rte := q.RangeTable[e.rteIdx] + for i, col := range e.columns { + // Skip columns that were coalesced away by USING/NATURAL. + if scope.isCoalesced(e.name, col.Name) { + continue + } + te := &TargetEntryQ{ + Expr: &VarExprQ{ + RangeIdx: e.rteIdx, + AttNum: i + 1, + }, + ResNo: *resNo, + ResName: col.Name, + ResOrigDB: rte.DBName, + ResOrigTable: rte.TableName, + ResOrigCol: col.Name, + } + *resNo++ + result = append(result, te) + } + return result, nil +} + +// deriveColumnName generates the output column name for an unaliased expression. +func deriveColumnName(astNode nodes.ExprNode, _ AnalyzedExpr) string { + switch n := astNode.(type) { + case *nodes.ColumnRef: + return n.Column + case *nodes.IntLit: + return strconv.FormatInt(n.Value, 10) + case *nodes.StringLit: + return n.Value + case *nodes.FloatLit: + return n.Value + case *nodes.BoolLit: + if n.Value { + return "TRUE" + } + return "FALSE" + case *nodes.NullLit: + return "NULL" + case *nodes.FuncCallExpr: + return n.Name + case *nodes.ParenExpr: + return deriveColumnName(n.Expr, nil) + default: + return "?" + } +} + +// fillProvenance sets ResOrigDB/Table/Col when the expression is a VarExprQ. +func fillProvenance(te *TargetEntryQ, q *Query) { + v, ok := te.Expr.(*VarExprQ) + if !ok { + return + } + if v.RangeIdx < 0 || v.RangeIdx >= len(q.RangeTable) { + return + } + rte := q.RangeTable[v.RangeIdx] + te.ResOrigDB = rte.DBName + te.ResOrigTable = rte.TableName + if v.AttNum >= 1 && v.AttNum <= len(rte.ColNames) { + te.ResOrigCol = rte.ColNames[v.AttNum-1] + } +} diff --git a/tidb/catalog/analyze_test.go b/tidb/catalog/analyze_test.go new file mode 100644 index 00000000..a9884862 --- /dev/null +++ b/tidb/catalog/analyze_test.go @@ -0,0 +1,2757 @@ +package catalog + +import ( + "testing" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// parseSelect parses a single SELECT statement and returns the SelectStmt node. +func parseSelect(t *testing.T, sql string) *nodes.SelectStmt { + t.Helper() + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if list.Len() != 1 { + t.Fatalf("expected 1 statement, got %d", list.Len()) + } + sel, ok := list.Items[0].(*nodes.SelectStmt) + if !ok { + t.Fatalf("expected *ast.SelectStmt, got %T", list.Items[0]) + } + return sel +} + +// TestAnalyze_1_1_BareLiteral tests SELECT 1: no tables, just a literal. +func TestAnalyze_1_1_BareLiteral(t *testing.T) { + c := wtSetup(t) + sel := parseSelect(t, "SELECT 1") + + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // CommandType + if q.CommandType != CmdSelect { + t.Errorf("CommandType: want CmdSelect, got %v", q.CommandType) + } + + // RangeTable should be empty (no FROM clause). + if len(q.RangeTable) != 0 { + t.Errorf("RangeTable: want 0 entries, got %d", len(q.RangeTable)) + } + + // JoinTree should be non-nil with empty FromList and nil Quals. + if q.JoinTree == nil { + t.Fatal("JoinTree: want non-nil, got nil") + } + if len(q.JoinTree.FromList) != 0 { + t.Errorf("JoinTree.FromList: want 0 entries, got %d", len(q.JoinTree.FromList)) + } + if q.JoinTree.Quals != nil { + t.Errorf("JoinTree.Quals: want nil, got %v", q.JoinTree.Quals) + } + + // TargetList should have exactly 1 entry. + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + + te := q.TargetList[0] + + // ResNo = 1 + if te.ResNo != 1 { + t.Errorf("ResNo: want 1, got %d", te.ResNo) + } + + // ResName = "1" + if te.ResName != "1" { + t.Errorf("ResName: want %q, got %q", "1", te.ResName) + } + + // ResJunk = false + if te.ResJunk { + t.Errorf("ResJunk: want false, got true") + } + + // Expr should be ConstExprQ with Value "1". + constExpr, ok := te.Expr.(*ConstExprQ) + if !ok { + t.Fatalf("Expr: want *ConstExprQ, got %T", te.Expr) + } + if constExpr.Value != "1" { + t.Errorf("ConstExprQ.Value: want %q, got %q", "1", constExpr.Value) + } + if constExpr.IsNull { + t.Errorf("ConstExprQ.IsNull: want false, got true") + } +} + +// TestAnalyze_1_2_ColumnFromTable tests SELECT name FROM employees. +func TestAnalyze_1_2_ColumnFromTable(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT name FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // RangeTable: 1 entry + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1 entry, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.Kind != RTERelation { + t.Errorf("RTE.Kind: want RTERelation, got %v", rte.Kind) + } + if rte.DBName != "testdb" { + t.Errorf("RTE.DBName: want %q, got %q", "testdb", rte.DBName) + } + if rte.TableName != "employees" { + t.Errorf("RTE.TableName: want %q, got %q", "employees", rte.TableName) + } + if rte.ERef != "employees" { + t.Errorf("RTE.ERef: want %q, got %q", "employees", rte.ERef) + } + + // RTE ColNames: 9 entries + if len(rte.ColNames) != 9 { + t.Errorf("RTE.ColNames: want 9 entries, got %d", len(rte.ColNames)) + } + + // JoinTree.FromList: 1 RangeTableRefQ + if q.JoinTree == nil { + t.Fatal("JoinTree: want non-nil, got nil") + } + if len(q.JoinTree.FromList) != 1 { + t.Fatalf("JoinTree.FromList: want 1 entry, got %d", len(q.JoinTree.FromList)) + } + rtRef, ok := q.JoinTree.FromList[0].(*RangeTableRefQ) + if !ok { + t.Fatalf("FromList[0]: want *RangeTableRefQ, got %T", q.JoinTree.FromList[0]) + } + if rtRef.RTIndex != 0 { + t.Errorf("RangeTableRefQ.RTIndex: want 0, got %d", rtRef.RTIndex) + } + + // TargetList: 1 entry + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + te := q.TargetList[0] + + varExpr, ok := te.Expr.(*VarExprQ) + if !ok { + t.Fatalf("Expr: want *VarExprQ, got %T", te.Expr) + } + if varExpr.RangeIdx != 0 { + t.Errorf("VarExprQ.RangeIdx: want 0, got %d", varExpr.RangeIdx) + } + if varExpr.AttNum != 2 { + t.Errorf("VarExprQ.AttNum: want 2, got %d", varExpr.AttNum) + } + if te.ResName != "name" { + t.Errorf("ResName: want %q, got %q", "name", te.ResName) + } + + // Provenance + if te.ResOrigDB != "testdb" { + t.Errorf("ResOrigDB: want %q, got %q", "testdb", te.ResOrigDB) + } + if te.ResOrigTable != "employees" { + t.Errorf("ResOrigTable: want %q, got %q", "employees", te.ResOrigTable) + } + if te.ResOrigCol != "name" { + t.Errorf("ResOrigCol: want %q, got %q", "name", te.ResOrigCol) + } +} + +// TestAnalyze_1_3_MultipleColumnsWithAlias tests SELECT id, name AS employee_name, salary FROM employees. +func TestAnalyze_1_3_MultipleColumnsWithAlias(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT id, name AS employee_name, salary FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 3 entries + if len(q.TargetList) != 3 { + t.Fatalf("TargetList: want 3 entries, got %d", len(q.TargetList)) + } + + // Entry 0: id + te0 := q.TargetList[0] + v0, ok := te0.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *VarExprQ, got %T", te0.Expr) + } + if v0.AttNum != 1 { + t.Errorf("TargetList[0] AttNum: want 1, got %d", v0.AttNum) + } + if te0.ResName != "id" { + t.Errorf("TargetList[0] ResName: want %q, got %q", "id", te0.ResName) + } + if te0.ResNo != 1 { + t.Errorf("TargetList[0] ResNo: want 1, got %d", te0.ResNo) + } + + // Entry 1: name AS employee_name + te1 := q.TargetList[1] + v1, ok := te1.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *VarExprQ, got %T", te1.Expr) + } + if v1.AttNum != 2 { + t.Errorf("TargetList[1] AttNum: want 2, got %d", v1.AttNum) + } + if te1.ResName != "employee_name" { + t.Errorf("TargetList[1] ResName: want %q, got %q", "employee_name", te1.ResName) + } + if te1.ResNo != 2 { + t.Errorf("TargetList[1] ResNo: want 2, got %d", te1.ResNo) + } + + // Entry 2: salary + te2 := q.TargetList[2] + v2, ok := te2.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[2].Expr: want *VarExprQ, got %T", te2.Expr) + } + if v2.AttNum != 5 { + t.Errorf("TargetList[2] AttNum: want 5, got %d", v2.AttNum) + } + if te2.ResName != "salary" { + t.Errorf("TargetList[2] ResName: want %q, got %q", "salary", te2.ResName) + } + if te2.ResNo != 3 { + t.Errorf("TargetList[2] ResNo: want 3, got %d", te2.ResNo) + } + + // All should have provenance + for i, te := range q.TargetList { + if te.ResOrigDB == "" { + t.Errorf("TargetList[%d] ResOrigDB: want non-empty, got empty", i) + } + if te.ResOrigTable == "" { + t.Errorf("TargetList[%d] ResOrigTable: want non-empty, got empty", i) + } + if te.ResOrigCol == "" { + t.Errorf("TargetList[%d] ResOrigCol: want non-empty, got empty", i) + } + } +} + +// TestAnalyze_1_4_StarExpansion tests SELECT * FROM employees. +func TestAnalyze_1_4_StarExpansion(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT * FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 9 entries (one per column) + if len(q.TargetList) != 9 { + t.Fatalf("TargetList: want 9 entries, got %d", len(q.TargetList)) + } + + wantNames := []string{"id", "name", "email", "department_id", "salary", "hire_date", "is_active", "notes", "created_at"} + + for i, te := range q.TargetList { + v, ok := te.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[%d].Expr: want *VarExprQ, got %T", i, te.Expr) + } + wantAttNum := i + 1 + if v.AttNum != wantAttNum { + t.Errorf("TargetList[%d] AttNum: want %d, got %d", i, wantAttNum, v.AttNum) + } + if te.ResName != wantNames[i] { + t.Errorf("TargetList[%d] ResName: want %q, got %q", i, wantNames[i], te.ResName) + } + if te.ResOrigCol == "" { + t.Errorf("TargetList[%d] ResOrigCol: want non-empty, got empty", i) + } + } +} + +// TestAnalyze_1_5_QualifiedStar tests SELECT e.* FROM employees AS e. +func TestAnalyze_1_5_QualifiedStar(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT e.* FROM employees AS e") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // RangeTable[0]: Alias="e", ERef="e" + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1 entry, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.Alias != "e" { + t.Errorf("RTE.Alias: want %q, got %q", "e", rte.Alias) + } + if rte.ERef != "e" { + t.Errorf("RTE.ERef: want %q, got %q", "e", rte.ERef) + } + + // TargetList: 9 entries same as 1.4 + if len(q.TargetList) != 9 { + t.Fatalf("TargetList: want 9 entries, got %d", len(q.TargetList)) + } + + for i, te := range q.TargetList { + v, ok := te.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[%d].Expr: want *VarExprQ, got %T", i, te.Expr) + } + if v.RangeIdx != 0 { + t.Errorf("TargetList[%d] RangeIdx: want 0, got %d", i, v.RangeIdx) + } + wantAttNum := i + 1 + if v.AttNum != wantAttNum { + t.Errorf("TargetList[%d] AttNum: want %d, got %d", i, wantAttNum, v.AttNum) + } + } +} + +// employeesTableDDL is the shared DDL for the employees table used across Batch 2 tests. +const employeesTableDDL = `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +)` + +// TestAnalyze_2_1_WhereSimpleEquality tests WHERE id = 1. +func TestAnalyze_2_1_WhereSimpleEquality(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name FROM employees WHERE id = 1") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.JoinTree == nil { + t.Fatal("JoinTree: want non-nil, got nil") + } + if q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil, got nil") + } + + opExpr, ok := q.JoinTree.Quals.(*OpExprQ) + if !ok { + t.Fatalf("Quals: want *OpExprQ, got %T", q.JoinTree.Quals) + } + if opExpr.Op != "=" { + t.Errorf("OpExprQ.Op: want %q, got %q", "=", opExpr.Op) + } + + leftVar, ok := opExpr.Left.(*VarExprQ) + if !ok { + t.Fatalf("OpExprQ.Left: want *VarExprQ, got %T", opExpr.Left) + } + if leftVar.AttNum != 1 { + t.Errorf("Left.AttNum: want 1, got %d", leftVar.AttNum) + } + + rightConst, ok := opExpr.Right.(*ConstExprQ) + if !ok { + t.Fatalf("OpExprQ.Right: want *ConstExprQ, got %T", opExpr.Right) + } + if rightConst.Value != "1" { + t.Errorf("Right.Value: want %q, got %q", "1", rightConst.Value) + } +} + +// TestAnalyze_2_2_WhereAndOr tests WHERE is_active = 1 AND (department_id = 1 OR department_id = 2). +func TestAnalyze_2_2_WhereAndOr(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name FROM employees WHERE is_active = 1 AND (department_id = 1 OR department_id = 2)") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.JoinTree == nil || q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil") + } + + // Top level: BoolExprQ with BoolAnd + andExpr, ok := q.JoinTree.Quals.(*BoolExprQ) + if !ok { + t.Fatalf("Quals: want *BoolExprQ, got %T", q.JoinTree.Quals) + } + if andExpr.Op != BoolAnd { + t.Errorf("BoolExprQ.Op: want BoolAnd, got %v", andExpr.Op) + } + if len(andExpr.Args) != 2 { + t.Fatalf("BoolExprQ.Args: want 2 entries, got %d", len(andExpr.Args)) + } + + // Left: is_active = 1 + leftOp, ok := andExpr.Args[0].(*OpExprQ) + if !ok { + t.Fatalf("Args[0]: want *OpExprQ, got %T", andExpr.Args[0]) + } + if leftOp.Op != "=" { + t.Errorf("Args[0].Op: want %q, got %q", "=", leftOp.Op) + } + + // Right: BoolExprQ with BoolOr + orExpr, ok := andExpr.Args[1].(*BoolExprQ) + if !ok { + t.Fatalf("Args[1]: want *BoolExprQ, got %T", andExpr.Args[1]) + } + if orExpr.Op != BoolOr { + t.Errorf("Args[1].Op: want BoolOr, got %v", orExpr.Op) + } + if len(orExpr.Args) != 2 { + t.Fatalf("BoolOr.Args: want 2 entries, got %d", len(orExpr.Args)) + } + + // Each OR arg should be OpExprQ with "=" + for i, arg := range orExpr.Args { + op, ok := arg.(*OpExprQ) + if !ok { + t.Fatalf("BoolOr.Args[%d]: want *OpExprQ, got %T", i, arg) + } + if op.Op != "=" { + t.Errorf("BoolOr.Args[%d].Op: want %q, got %q", i, "=", op.Op) + } + } +} + +// TestAnalyze_2_3_WhereIn tests WHERE department_id IN (1, 2, 3). +func TestAnalyze_2_3_WhereIn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name FROM employees WHERE department_id IN (1, 2, 3)") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.JoinTree == nil || q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil") + } + + inExpr, ok := q.JoinTree.Quals.(*InListExprQ) + if !ok { + t.Fatalf("Quals: want *InListExprQ, got %T", q.JoinTree.Quals) + } + + argVar, ok := inExpr.Arg.(*VarExprQ) + if !ok { + t.Fatalf("InListExprQ.Arg: want *VarExprQ, got %T", inExpr.Arg) + } + if argVar.AttNum != 4 { + t.Errorf("Arg.AttNum: want 4, got %d", argVar.AttNum) + } + + if len(inExpr.List) != 3 { + t.Fatalf("InListExprQ.List: want 3 entries, got %d", len(inExpr.List)) + } + for i, item := range inExpr.List { + if _, ok := item.(*ConstExprQ); !ok { + t.Errorf("List[%d]: want *ConstExprQ, got %T", i, item) + } + } + + if inExpr.Negated { + t.Errorf("InListExprQ.Negated: want false, got true") + } +} + +// TestAnalyze_2_4_WhereBetween tests WHERE salary BETWEEN 50000 AND 100000. +func TestAnalyze_2_4_WhereBetween(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name FROM employees WHERE salary BETWEEN 50000 AND 100000") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.JoinTree == nil || q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil") + } + + betExpr, ok := q.JoinTree.Quals.(*BetweenExprQ) + if !ok { + t.Fatalf("Quals: want *BetweenExprQ, got %T", q.JoinTree.Quals) + } + + argVar, ok := betExpr.Arg.(*VarExprQ) + if !ok { + t.Fatalf("BetweenExprQ.Arg: want *VarExprQ, got %T", betExpr.Arg) + } + if argVar.AttNum != 5 { + t.Errorf("Arg.AttNum: want 5, got %d", argVar.AttNum) + } + + lowerConst, ok := betExpr.Lower.(*ConstExprQ) + if !ok { + t.Fatalf("BetweenExprQ.Lower: want *ConstExprQ, got %T", betExpr.Lower) + } + if lowerConst.Value != "50000" { + t.Errorf("Lower.Value: want %q, got %q", "50000", lowerConst.Value) + } + + upperConst, ok := betExpr.Upper.(*ConstExprQ) + if !ok { + t.Fatalf("BetweenExprQ.Upper: want *ConstExprQ, got %T", betExpr.Upper) + } + if upperConst.Value != "100000" { + t.Errorf("Upper.Value: want %q, got %q", "100000", upperConst.Value) + } + + if betExpr.Negated { + t.Errorf("BetweenExprQ.Negated: want false, got true") + } +} + +// TestAnalyze_2_5_WhereIsNull tests WHERE email IS NOT NULL AND notes IS NULL. +func TestAnalyze_2_5_WhereIsNull(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name FROM employees WHERE email IS NOT NULL AND notes IS NULL") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.JoinTree == nil || q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil") + } + + andExpr, ok := q.JoinTree.Quals.(*BoolExprQ) + if !ok { + t.Fatalf("Quals: want *BoolExprQ, got %T", q.JoinTree.Quals) + } + if andExpr.Op != BoolAnd { + t.Errorf("BoolExprQ.Op: want BoolAnd, got %v", andExpr.Op) + } + if len(andExpr.Args) != 2 { + t.Fatalf("BoolExprQ.Args: want 2 entries, got %d", len(andExpr.Args)) + } + + // Args[0]: email IS NOT NULL → NullTestExprQ{IsNull: false} + nt0, ok := andExpr.Args[0].(*NullTestExprQ) + if !ok { + t.Fatalf("Args[0]: want *NullTestExprQ, got %T", andExpr.Args[0]) + } + if nt0.IsNull { + t.Errorf("Args[0].IsNull: want false, got true") + } + + // Args[1]: notes IS NULL → NullTestExprQ{IsNull: true} + nt1, ok := andExpr.Args[1].(*NullTestExprQ) + if !ok { + t.Fatalf("Args[1]: want *NullTestExprQ, got %T", andExpr.Args[1]) + } + if !nt1.IsNull { + t.Errorf("Args[1].IsNull: want true, got false") + } +} + +// TestAnalyze_3_1_GroupByColumn tests GROUP BY with a column reference. +func TestAnalyze_3_1_GroupByColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT department_id, COUNT(*) FROM employees GROUP BY department_id") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 2 entries + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + // TargetList[0]: department_id → VarExprQ{AttNum:4} + te0 := q.TargetList[0] + v0, ok := te0.Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *VarExprQ, got %T", te0.Expr) + } + if v0.AttNum != 4 { + t.Errorf("TargetList[0] AttNum: want 4, got %d", v0.AttNum) + } + if te0.ResName != "department_id" { + t.Errorf("TargetList[0] ResName: want %q, got %q", "department_id", te0.ResName) + } + + // TargetList[1]: COUNT(*) → FuncCallExprQ with IsAggregate=true + te1 := q.TargetList[1] + fc1, ok := te1.Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *FuncCallExprQ, got %T", te1.Expr) + } + if fc1.Name != "count" { + t.Errorf("FuncCallExprQ.Name: want %q, got %q", "count", fc1.Name) + } + if !fc1.IsAggregate { + t.Errorf("FuncCallExprQ.IsAggregate: want true, got false") + } + + // GroupClause: 1 entry referencing TargetIdx=1 + if len(q.GroupClause) != 1 { + t.Fatalf("GroupClause: want 1 entry, got %d", len(q.GroupClause)) + } + if q.GroupClause[0].TargetIdx != 1 { + t.Errorf("GroupClause[0].TargetIdx: want 1, got %d", q.GroupClause[0].TargetIdx) + } + + // HasAggs = true + if !q.HasAggs { + t.Errorf("HasAggs: want true, got false") + } +} + +// TestAnalyze_3_2_GroupByMultipleAggregates tests GROUP BY with multiple aggregates. +func TestAnalyze_3_2_GroupByMultipleAggregates(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT department_id, COUNT(*) AS cnt, SUM(salary) AS total_salary, AVG(salary) AS avg_salary FROM employees GROUP BY department_id") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 4 entries + if len(q.TargetList) != 4 { + t.Fatalf("TargetList: want 4 entries, got %d", len(q.TargetList)) + } + + // TargetList[0]: department_id + if _, ok := q.TargetList[0].Expr.(*VarExprQ); !ok { + t.Fatalf("TargetList[0].Expr: want *VarExprQ, got %T", q.TargetList[0].Expr) + } + + // TargetList[1..3]: all FuncCallExprQ with IsAggregate=true + for i := 1; i <= 3; i++ { + fc, ok := q.TargetList[i].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[%d].Expr: want *FuncCallExprQ, got %T", i, q.TargetList[i].Expr) + } + if !fc.IsAggregate { + t.Errorf("TargetList[%d] IsAggregate: want true, got false", i) + } + } + + // GroupClause references TargetIdx=1 + if len(q.GroupClause) != 1 { + t.Fatalf("GroupClause: want 1 entry, got %d", len(q.GroupClause)) + } + if q.GroupClause[0].TargetIdx != 1 { + t.Errorf("GroupClause[0].TargetIdx: want 1, got %d", q.GroupClause[0].TargetIdx) + } +} + +// TestAnalyze_3_3_Having tests HAVING clause with aggregate condition. +func TestAnalyze_3_3_Having(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT department_id, COUNT(*) AS cnt FROM employees GROUP BY department_id HAVING COUNT(*) > 5") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // HavingQual should be OpExprQ{Op:">", Left:FuncCallExprQ, Right:ConstExprQ} + if q.HavingQual == nil { + t.Fatal("HavingQual: want non-nil, got nil") + } + opExpr, ok := q.HavingQual.(*OpExprQ) + if !ok { + t.Fatalf("HavingQual: want *OpExprQ, got %T", q.HavingQual) + } + if opExpr.Op != ">" { + t.Errorf("OpExprQ.Op: want %q, got %q", ">", opExpr.Op) + } + + leftFC, ok := opExpr.Left.(*FuncCallExprQ) + if !ok { + t.Fatalf("OpExprQ.Left: want *FuncCallExprQ, got %T", opExpr.Left) + } + if !leftFC.IsAggregate { + t.Errorf("Left.IsAggregate: want true, got false") + } + + rightConst, ok := opExpr.Right.(*ConstExprQ) + if !ok { + t.Fatalf("OpExprQ.Right: want *ConstExprQ, got %T", opExpr.Right) + } + if rightConst.Value != "5" { + t.Errorf("Right.Value: want %q, got %q", "5", rightConst.Value) + } + + // HasAggs should be true + if !q.HasAggs { + t.Errorf("HasAggs: want true, got false") + } +} + +// TestAnalyze_3_4_CountDistinct tests COUNT(DISTINCT column). +func TestAnalyze_3_4_CountDistinct(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT COUNT(DISTINCT department_id) FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + + fc, ok := q.TargetList[0].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *FuncCallExprQ, got %T", q.TargetList[0].Expr) + } + if fc.Name != "count" { + t.Errorf("FuncCallExprQ.Name: want %q, got %q", "count", fc.Name) + } + if !fc.IsAggregate { + t.Errorf("FuncCallExprQ.IsAggregate: want true, got false") + } + if !fc.Distinct { + t.Errorf("FuncCallExprQ.Distinct: want true, got false") + } + + // Args should contain one VarExprQ with AttNum=4 (department_id) + if len(fc.Args) != 1 { + t.Fatalf("FuncCallExprQ.Args: want 1 entry, got %d", len(fc.Args)) + } + argVar, ok := fc.Args[0].(*VarExprQ) + if !ok { + t.Fatalf("Args[0]: want *VarExprQ, got %T", fc.Args[0]) + } + if argVar.AttNum != 4 { + t.Errorf("Args[0].AttNum: want 4, got %d", argVar.AttNum) + } + + // HasAggs should be true + if !q.HasAggs { + t.Errorf("HasAggs: want true, got false") + } +} + +// TestAnalyze_3_5_GroupByOrdinal tests GROUP BY with ordinal reference. +func TestAnalyze_3_5_GroupByOrdinal(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT department_id, COUNT(*) FROM employees GROUP BY 1") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // GroupClause: 1 entry referencing TargetIdx=1 (ordinal 1 → first target entry) + if len(q.GroupClause) != 1 { + t.Fatalf("GroupClause: want 1 entry, got %d", len(q.GroupClause)) + } + if q.GroupClause[0].TargetIdx != 1 { + t.Errorf("GroupClause[0].TargetIdx: want 1, got %d", q.GroupClause[0].TargetIdx) + } + + // HasAggs should be true + if !q.HasAggs { + t.Errorf("HasAggs: want true, got false") + } +} + +// TestAnalyze_5_1_ArithmeticAlias tests SELECT name, salary * 12 AS annual_salary FROM employees. +func TestAnalyze_5_1_ArithmeticAlias(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name, salary * 12 AS annual_salary FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + te := q.TargetList[1] + if te.ResName != "annual_salary" { + t.Errorf("ResName: want %q, got %q", "annual_salary", te.ResName) + } + // Computed column: no provenance + if te.ResOrigDB != "" { + t.Errorf("ResOrigDB: want empty, got %q", te.ResOrigDB) + } + if te.ResOrigTable != "" { + t.Errorf("ResOrigTable: want empty, got %q", te.ResOrigTable) + } + if te.ResOrigCol != "" { + t.Errorf("ResOrigCol: want empty, got %q", te.ResOrigCol) + } + + opExpr, ok := te.Expr.(*OpExprQ) + if !ok { + t.Fatalf("Expr: want *OpExprQ, got %T", te.Expr) + } + if opExpr.Op != "*" { + t.Errorf("OpExprQ.Op: want %q, got %q", "*", opExpr.Op) + } + left, ok := opExpr.Left.(*VarExprQ) + if !ok { + t.Fatalf("Left: want *VarExprQ, got %T", opExpr.Left) + } + if left.AttNum != 5 { + t.Errorf("Left.AttNum: want 5, got %d", left.AttNum) + } + right, ok := opExpr.Right.(*ConstExprQ) + if !ok { + t.Fatalf("Right: want *ConstExprQ, got %T", opExpr.Right) + } + if right.Value != "12" { + t.Errorf("Right.Value: want %q, got %q", "12", right.Value) + } +} + +// TestAnalyze_5_2_ConcatFunc tests SELECT CONCAT(name, ' <', email, '>') AS display FROM employees. +func TestAnalyze_5_2_ConcatFunc(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT CONCAT(name, ' <', email, '>') AS display FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + + te := q.TargetList[0] + if te.ResName != "display" { + t.Errorf("ResName: want %q, got %q", "display", te.ResName) + } + + fc, ok := te.Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("Expr: want *FuncCallExprQ, got %T", te.Expr) + } + if fc.Name != "concat" { + t.Errorf("FuncCallExprQ.Name: want %q, got %q", "concat", fc.Name) + } + if fc.IsAggregate { + t.Errorf("FuncCallExprQ.IsAggregate: want false, got true") + } + if len(fc.Args) != 4 { + t.Fatalf("FuncCallExprQ.Args: want 4 entries, got %d", len(fc.Args)) + } + + // Args[0]: VarExprQ (name) + if _, ok := fc.Args[0].(*VarExprQ); !ok { + t.Errorf("Args[0]: want *VarExprQ, got %T", fc.Args[0]) + } + // Args[1]: ConstExprQ (' <') + if c1, ok := fc.Args[1].(*ConstExprQ); !ok { + t.Errorf("Args[1]: want *ConstExprQ, got %T", fc.Args[1]) + } else if c1.Value != " <" { + t.Errorf("Args[1].Value: want %q, got %q", " <", c1.Value) + } + // Args[2]: VarExprQ (email) + if _, ok := fc.Args[2].(*VarExprQ); !ok { + t.Errorf("Args[2]: want *VarExprQ, got %T", fc.Args[2]) + } + // Args[3]: ConstExprQ ('>') + if c3, ok := fc.Args[3].(*ConstExprQ); !ok { + t.Errorf("Args[3]: want *ConstExprQ, got %T", fc.Args[3]) + } else if c3.Value != ">" { + t.Errorf("Args[3].Value: want %q, got %q", ">", c3.Value) + } +} + +// TestAnalyze_5_3_SearchedCase tests searched CASE WHEN with ELSE. +func TestAnalyze_5_3_SearchedCase(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name, CASE WHEN salary > 100000 THEN 'high' WHEN salary > 50000 THEN 'mid' ELSE 'low' END AS salary_band FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + te := q.TargetList[1] + if te.ResName != "salary_band" { + t.Errorf("ResName: want %q, got %q", "salary_band", te.ResName) + } + + caseExpr, ok := te.Expr.(*CaseExprQ) + if !ok { + t.Fatalf("Expr: want *CaseExprQ, got %T", te.Expr) + } + if caseExpr.TestExpr != nil { + t.Errorf("TestExpr: want nil (searched CASE), got %T", caseExpr.TestExpr) + } + if len(caseExpr.Args) != 2 { + t.Fatalf("Args: want 2 CaseWhenQs, got %d", len(caseExpr.Args)) + } + + // WHEN salary > 100000 THEN 'high' + w0 := caseExpr.Args[0] + if _, ok := w0.Cond.(*OpExprQ); !ok { + t.Errorf("Args[0].Cond: want *OpExprQ, got %T", w0.Cond) + } + if then0, ok := w0.Then.(*ConstExprQ); !ok { + t.Errorf("Args[0].Then: want *ConstExprQ, got %T", w0.Then) + } else if then0.Value != "high" { + t.Errorf("Args[0].Then.Value: want %q, got %q", "high", then0.Value) + } + + // WHEN salary > 50000 THEN 'mid' + w1 := caseExpr.Args[1] + if _, ok := w1.Cond.(*OpExprQ); !ok { + t.Errorf("Args[1].Cond: want *OpExprQ, got %T", w1.Cond) + } + if then1, ok := w1.Then.(*ConstExprQ); !ok { + t.Errorf("Args[1].Then: want *ConstExprQ, got %T", w1.Then) + } else if then1.Value != "mid" { + t.Errorf("Args[1].Then.Value: want %q, got %q", "mid", then1.Value) + } + + // ELSE 'low' + if caseExpr.Default == nil { + t.Fatal("Default: want non-nil, got nil") + } + defConst, ok := caseExpr.Default.(*ConstExprQ) + if !ok { + t.Fatalf("Default: want *ConstExprQ, got %T", caseExpr.Default) + } + if defConst.Value != "low" { + t.Errorf("Default.Value: want %q, got %q", "low", defConst.Value) + } +} + +// TestAnalyze_5_4_SimpleCase tests simple CASE with operand and no ELSE. +func TestAnalyze_5_4_SimpleCase(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT CASE department_id WHEN 1 THEN 'eng' WHEN 2 THEN 'sales' END FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + + te := q.TargetList[0] + caseExpr, ok := te.Expr.(*CaseExprQ) + if !ok { + t.Fatalf("Expr: want *CaseExprQ, got %T", te.Expr) + } + + // TestExpr: VarExprQ for department_id (attnum=4) + testVar, ok := caseExpr.TestExpr.(*VarExprQ) + if !ok { + t.Fatalf("TestExpr: want *VarExprQ, got %T", caseExpr.TestExpr) + } + if testVar.AttNum != 4 { + t.Errorf("TestExpr.AttNum: want 4, got %d", testVar.AttNum) + } + + if len(caseExpr.Args) != 2 { + t.Fatalf("Args: want 2 CaseWhenQs, got %d", len(caseExpr.Args)) + } + + // WHEN 1 THEN 'eng' + if c0, ok := caseExpr.Args[0].Cond.(*ConstExprQ); !ok { + t.Errorf("Args[0].Cond: want *ConstExprQ, got %T", caseExpr.Args[0].Cond) + } else if c0.Value != "1" { + t.Errorf("Args[0].Cond.Value: want %q, got %q", "1", c0.Value) + } + if t0, ok := caseExpr.Args[0].Then.(*ConstExprQ); !ok { + t.Errorf("Args[0].Then: want *ConstExprQ, got %T", caseExpr.Args[0].Then) + } else if t0.Value != "eng" { + t.Errorf("Args[0].Then.Value: want %q, got %q", "eng", t0.Value) + } + + // No ELSE + if caseExpr.Default != nil { + t.Errorf("Default: want nil, got %T", caseExpr.Default) + } +} + +// TestAnalyze_5_5_CoalesceIfnull tests COALESCE and IFNULL producing CoalesceExprQ. +func TestAnalyze_5_5_CoalesceIfnull(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT COALESCE(email, 'no-email') AS email, IFNULL(notes, '') AS notes FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + // TargetList[0]: COALESCE(email, 'no-email') + te0 := q.TargetList[0] + if te0.ResName != "email" { + t.Errorf("TargetList[0].ResName: want %q, got %q", "email", te0.ResName) + } + coal0, ok := te0.Expr.(*CoalesceExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *CoalesceExprQ, got %T", te0.Expr) + } + if len(coal0.Args) != 2 { + t.Fatalf("coal0.Args: want 2 entries, got %d", len(coal0.Args)) + } + if v, ok := coal0.Args[0].(*VarExprQ); !ok { + t.Errorf("coal0.Args[0]: want *VarExprQ, got %T", coal0.Args[0]) + } else if v.AttNum != 3 { + t.Errorf("coal0.Args[0].AttNum: want 3, got %d", v.AttNum) + } + if c0, ok := coal0.Args[1].(*ConstExprQ); !ok { + t.Errorf("coal0.Args[1]: want *ConstExprQ, got %T", coal0.Args[1]) + } else if c0.Value != "no-email" { + t.Errorf("coal0.Args[1].Value: want %q, got %q", "no-email", c0.Value) + } + + // TargetList[1]: IFNULL(notes, '') + te1 := q.TargetList[1] + if te1.ResName != "notes" { + t.Errorf("TargetList[1].ResName: want %q, got %q", "notes", te1.ResName) + } + coal1, ok := te1.Expr.(*CoalesceExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *CoalesceExprQ, got %T", te1.Expr) + } + if len(coal1.Args) != 2 { + t.Fatalf("coal1.Args: want 2 entries, got %d", len(coal1.Args)) + } + if v, ok := coal1.Args[0].(*VarExprQ); !ok { + t.Errorf("coal1.Args[0]: want *VarExprQ, got %T", coal1.Args[0]) + } else if v.AttNum != 8 { + t.Errorf("coal1.Args[0].AttNum: want 8, got %d", v.AttNum) + } + if c1, ok := coal1.Args[1].(*ConstExprQ); !ok { + t.Errorf("coal1.Args[1]: want *ConstExprQ, got %T", coal1.Args[1]) + } else if c1.Value != "" { + t.Errorf("coal1.Args[1].Value: want %q, got %q", "", c1.Value) + } +} + +// TestAnalyze_5_6_CastSigned tests CAST(salary AS SIGNED) producing CastExprQ. +func TestAnalyze_5_6_CastSigned(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT CAST(salary AS SIGNED) AS salary_int FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + + te := q.TargetList[0] + if te.ResName != "salary_int" { + t.Errorf("ResName: want %q, got %q", "salary_int", te.ResName) + } + + castExpr, ok := te.Expr.(*CastExprQ) + if !ok { + t.Fatalf("Expr: want *CastExprQ, got %T", te.Expr) + } + + // Arg: VarExprQ for salary (attnum=5) + argVar, ok := castExpr.Arg.(*VarExprQ) + if !ok { + t.Fatalf("Arg: want *VarExprQ, got %T", castExpr.Arg) + } + if argVar.AttNum != 5 { + t.Errorf("Arg.AttNum: want 5, got %d", argVar.AttNum) + } + + // TargetType: SIGNED -> BaseTypeBigInt + if castExpr.TargetType == nil { + t.Fatal("TargetType: want non-nil, got nil") + } + if castExpr.TargetType.BaseType != BaseTypeBigInt { + t.Errorf("TargetType.BaseType: want BaseTypeBigInt, got %v", castExpr.TargetType.BaseType) + } +} + +// TestAnalyze_4_1_OrderByDesc tests ORDER BY with DESC on a column in the SELECT list. +func TestAnalyze_4_1_OrderByDesc(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT name, salary FROM employees ORDER BY salary DESC") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 2 entries (no junk needed since salary is in SELECT list). + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + // SortClause: 1 entry. + if len(q.SortClause) != 1 { + t.Fatalf("SortClause: want 1 entry, got %d", len(q.SortClause)) + } + sc := q.SortClause[0] + if sc.TargetIdx != 2 { + t.Errorf("SortClause[0].TargetIdx: want 2, got %d", sc.TargetIdx) + } + if !sc.Descending { + t.Errorf("SortClause[0].Descending: want true, got false") + } + // DESC → NullsFirst=false (MySQL default). + if sc.NullsFirst { + t.Errorf("SortClause[0].NullsFirst: want false, got true") + } +} + +// TestAnalyze_4_2_OrderByJunk tests ORDER BY a column NOT in the SELECT list. +func TestAnalyze_4_2_OrderByJunk(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT name FROM employees ORDER BY salary") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList: 2 entries — name (non-junk) + salary (junk). + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + te0 := q.TargetList[0] + if te0.ResName != "name" { + t.Errorf("TargetList[0].ResName: want %q, got %q", "name", te0.ResName) + } + if te0.ResJunk { + t.Errorf("TargetList[0].ResJunk: want false, got true") + } + + te1 := q.TargetList[1] + if te1.ResName != "salary" { + t.Errorf("TargetList[1].ResName: want %q, got %q", "salary", te1.ResName) + } + if !te1.ResJunk { + t.Errorf("TargetList[1].ResJunk: want true, got false") + } + + // SortClause: 1 entry pointing to junk target. + if len(q.SortClause) != 1 { + t.Fatalf("SortClause: want 1 entry, got %d", len(q.SortClause)) + } + sc := q.SortClause[0] + if sc.TargetIdx != 2 { + t.Errorf("SortClause[0].TargetIdx: want 2, got %d", sc.TargetIdx) + } + // ASC default → NullsFirst=true. + if !sc.NullsFirst { + t.Errorf("SortClause[0].NullsFirst: want true, got false") + } +} + +// TestAnalyze_4_3_OrderByLimitOffset tests ORDER BY + LIMIT + OFFSET. +func TestAnalyze_4_3_OrderByLimitOffset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT name FROM employees ORDER BY id LIMIT 10 OFFSET 20") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // LimitCount = ConstExprQ{Value:"10"} + if q.LimitCount == nil { + t.Fatal("LimitCount: want non-nil, got nil") + } + lc, ok := q.LimitCount.(*ConstExprQ) + if !ok { + t.Fatalf("LimitCount: want *ConstExprQ, got %T", q.LimitCount) + } + if lc.Value != "10" { + t.Errorf("LimitCount.Value: want %q, got %q", "10", lc.Value) + } + + // LimitOffset = ConstExprQ{Value:"20"} + if q.LimitOffset == nil { + t.Fatal("LimitOffset: want non-nil, got nil") + } + lo, ok := q.LimitOffset.(*ConstExprQ) + if !ok { + t.Fatalf("LimitOffset: want *ConstExprQ, got %T", q.LimitOffset) + } + if lo.Value != "20" { + t.Errorf("LimitOffset.Value: want %q, got %q", "20", lo.Value) + } +} + +// TestAnalyze_4_4_LimitComma tests LIMIT offset,count syntax. +func TestAnalyze_4_4_LimitComma(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT name FROM employees LIMIT 20, 10") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // Parser normalizes LIMIT 20,10 to Count=10, Offset=20. + if q.LimitCount == nil { + t.Fatal("LimitCount: want non-nil, got nil") + } + lc, ok := q.LimitCount.(*ConstExprQ) + if !ok { + t.Fatalf("LimitCount: want *ConstExprQ, got %T", q.LimitCount) + } + if lc.Value != "10" { + t.Errorf("LimitCount.Value: want %q, got %q", "10", lc.Value) + } + + if q.LimitOffset == nil { + t.Fatal("LimitOffset: want non-nil, got nil") + } + lo, ok := q.LimitOffset.(*ConstExprQ) + if !ok { + t.Fatalf("LimitOffset: want *ConstExprQ, got %T", q.LimitOffset) + } + if lo.Value != "20" { + t.Errorf("LimitOffset.Value: want %q, got %q", "20", lo.Value) + } +} + +// TestAnalyze_4_5_Distinct tests SELECT DISTINCT. +func TestAnalyze_4_5_Distinct(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200), + department_id INT NOT NULL, + salary DECIMAL(10,2) NOT NULL, + hire_date DATE NOT NULL, + is_active TINYINT(1) NOT NULL DEFAULT 1, + notes TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + + sel := parseSelect(t, "SELECT DISTINCT department_id FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // Distinct should be true. + if !q.Distinct { + t.Errorf("Distinct: want true, got false") + } + + // TargetList: 1 non-junk entry. + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + te := q.TargetList[0] + if te.ResJunk { + t.Errorf("TargetList[0].ResJunk: want false, got true") + } + if te.ResName != "department_id" { + t.Errorf("TargetList[0].ResName: want %q, got %q", "department_id", te.ResName) + } +} + +// TestAnalyze_6_1_ColumnAliasAmbiguity tests that WHERE resolves against base columns, not SELECT aliases. +func TestAnalyze_6_1_ColumnAliasAmbiguity(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT name AS id FROM employees WHERE id = 1") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // TargetList[0]: name column aliased as "id" + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + te := q.TargetList[0] + if te.ResName != "id" { + t.Errorf("ResName: want %q, got %q", "id", te.ResName) + } + varExpr, ok := te.Expr.(*VarExprQ) + if !ok { + t.Fatalf("Expr: want *VarExprQ, got %T", te.Expr) + } + if varExpr.AttNum != 2 { + t.Errorf("TargetList[0] AttNum: want 2 (name column), got %d", varExpr.AttNum) + } + + // WHERE id = 1: id must resolve to the base column id (AttNum 1), not the alias. + if q.JoinTree == nil || q.JoinTree.Quals == nil { + t.Fatal("JoinTree.Quals: want non-nil") + } + opExpr, ok := q.JoinTree.Quals.(*OpExprQ) + if !ok { + t.Fatalf("Quals: want *OpExprQ, got %T", q.JoinTree.Quals) + } + leftVar, ok := opExpr.Left.(*VarExprQ) + if !ok { + t.Fatalf("Left: want *VarExprQ, got %T", opExpr.Left) + } + if leftVar.AttNum != 1 { + t.Errorf("WHERE id AttNum: want 1 (base column id), got %d", leftVar.AttNum) + } +} + +// TestAnalyze_6_2_SameColumnTwice tests SELECT referencing the same column twice. +func TestAnalyze_6_2_SameColumnTwice(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT salary, salary + 1000 AS raised FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2 entries, got %d", len(q.TargetList)) + } + + // TargetList[0]: plain salary column + v0, ok := q.TargetList[0].Expr.(*VarExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *VarExprQ, got %T", q.TargetList[0].Expr) + } + if v0.RangeIdx != 0 { + t.Errorf("TargetList[0] RangeIdx: want 0, got %d", v0.RangeIdx) + } + if v0.AttNum != 5 { + t.Errorf("TargetList[0] AttNum: want 5, got %d", v0.AttNum) + } + + // TargetList[1]: salary + 1000 + opExpr, ok := q.TargetList[1].Expr.(*OpExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *OpExprQ, got %T", q.TargetList[1].Expr) + } + if opExpr.Op != "+" { + t.Errorf("OpExprQ.Op: want %q, got %q", "+", opExpr.Op) + } + + leftVar, ok := opExpr.Left.(*VarExprQ) + if !ok { + t.Fatalf("Left: want *VarExprQ, got %T", opExpr.Left) + } + if leftVar.RangeIdx != 0 { + t.Errorf("Left.RangeIdx: want 0, got %d", leftVar.RangeIdx) + } + if leftVar.AttNum != 5 { + t.Errorf("Left.AttNum: want 5, got %d", leftVar.AttNum) + } + + rightConst, ok := opExpr.Right.(*ConstExprQ) + if !ok { + t.Fatalf("Right: want *ConstExprQ, got %T", opExpr.Right) + } + if rightConst.Value != "1000" { + t.Errorf("Right.Value: want %q, got %q", "1000", rightConst.Value) + } + + // The two VarExprQ references to salary should be distinct pointers but same RangeIdx/AttNum. + if v0 == leftVar { + t.Errorf("VarExprQ pointers: want distinct objects, got same pointer") + } + if v0.RangeIdx != leftVar.RangeIdx || v0.AttNum != leftVar.AttNum { + t.Errorf("VarExprQ values: want same RangeIdx/AttNum, got (%d,%d) vs (%d,%d)", + v0.RangeIdx, v0.AttNum, leftVar.RangeIdx, leftVar.AttNum) + } +} + +// TestAnalyze_6_3_ThreePartQualified tests three-part column qualification: schema.table.column. +func TestAnalyze_6_3_ThreePartQualified(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + + sel := parseSelect(t, "SELECT testdb.employees.name FROM testdb.employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // RangeTable: 1 entry with DBName="testdb", TableName="employees" + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1 entry, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.DBName != "testdb" { + t.Errorf("RTE.DBName: want %q, got %q", "testdb", rte.DBName) + } + if rte.TableName != "employees" { + t.Errorf("RTE.TableName: want %q, got %q", "employees", rte.TableName) + } + + // TargetList: 1 entry — name resolves to AttNum 2 + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1 entry, got %d", len(q.TargetList)) + } + varExpr, ok := q.TargetList[0].Expr.(*VarExprQ) + if !ok { + t.Fatalf("Expr: want *VarExprQ, got %T", q.TargetList[0].Expr) + } + if varExpr.RangeIdx != 0 { + t.Errorf("RangeIdx: want 0, got %d", varExpr.RangeIdx) + } + if varExpr.AttNum != 2 { + t.Errorf("AttNum: want 2, got %d", varExpr.AttNum) + } +} + +// TestAnalyze_6_4_NoFromClause tests SELECT with no FROM clause — pure expressions. +func TestAnalyze_6_4_NoFromClause(t *testing.T) { + c := wtSetup(t) + + sel := parseSelect(t, "SELECT 1 + 2, 'hello', NOW()") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // RangeTable empty + if len(q.RangeTable) != 0 { + t.Errorf("RangeTable: want 0 entries, got %d", len(q.RangeTable)) + } + + // JoinTree.FromList empty + if q.JoinTree == nil { + t.Fatal("JoinTree: want non-nil, got nil") + } + if len(q.JoinTree.FromList) != 0 { + t.Errorf("JoinTree.FromList: want 0 entries, got %d", len(q.JoinTree.FromList)) + } + + // TargetList: 3 entries + if len(q.TargetList) != 3 { + t.Fatalf("TargetList: want 3 entries, got %d", len(q.TargetList)) + } + + // [0]: 1 + 2 → OpExprQ + opExpr, ok := q.TargetList[0].Expr.(*OpExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *OpExprQ, got %T", q.TargetList[0].Expr) + } + if opExpr.Op != "+" { + t.Errorf("OpExprQ.Op: want %q, got %q", "+", opExpr.Op) + } + leftConst, ok := opExpr.Left.(*ConstExprQ) + if !ok { + t.Fatalf("Left: want *ConstExprQ, got %T", opExpr.Left) + } + if leftConst.Value != "1" { + t.Errorf("Left.Value: want %q, got %q", "1", leftConst.Value) + } + rightConst, ok := opExpr.Right.(*ConstExprQ) + if !ok { + t.Fatalf("Right: want *ConstExprQ, got %T", opExpr.Right) + } + if rightConst.Value != "2" { + t.Errorf("Right.Value: want %q, got %q", "2", rightConst.Value) + } + + // [1]: 'hello' → ConstExprQ + strConst, ok := q.TargetList[1].Expr.(*ConstExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *ConstExprQ, got %T", q.TargetList[1].Expr) + } + if strConst.Value != "hello" { + t.Errorf("ConstExprQ.Value: want %q, got %q", "hello", strConst.Value) + } + + // [2]: NOW() → FuncCallExprQ + fc, ok := q.TargetList[2].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[2].Expr: want *FuncCallExprQ, got %T", q.TargetList[2].Expr) + } + if fc.Name != "now" { + t.Errorf("FuncCallExprQ.Name: want %q, got %q", "now", fc.Name) + } + if len(fc.Args) != 0 { + t.Errorf("FuncCallExprQ.Args: want 0 entries, got %d", len(fc.Args)) + } + if fc.IsAggregate { + t.Errorf("FuncCallExprQ.IsAggregate: want false, got true") + } +} + +// --------------------------------------------------------------------------- +// Phase 1b — Batches 7-10: JOINs, USING/NATURAL, FROM subqueries +// --------------------------------------------------------------------------- + +const departmentsTableDDL = `CREATE TABLE departments ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + budget DECIMAL(12,2) +)` + +const projectsTableDDL = `CREATE TABLE projects ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + department_id INT NOT NULL, + lead_id INT, + start_date DATE +)` + +// setupJoinTables creates employees, departments, and projects tables. +func setupJoinTables(t *testing.T) *Catalog { + t.Helper() + c := wtSetup(t) + wtExec(t, c, employeesTableDDL) + wtExec(t, c, departmentsTableDDL) + wtExec(t, c, projectsTableDDL) + return c +} + +// --- Batch 7: Basic JOINs --- + +// TestAnalyze_7_1_InnerJoin tests INNER JOIN with ON condition. +func TestAnalyze_7_1_InnerJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name AS dept_name FROM employees e INNER JOIN departments d ON e.department_id = d.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // 3 RTEs: employees, departments, RTEJoin + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3 entries, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTERelation || q.RangeTable[0].ERef != "e" { + t.Errorf("RTE[0]: want RTERelation 'e', got kind=%d eref=%q", q.RangeTable[0].Kind, q.RangeTable[0].ERef) + } + if q.RangeTable[1].Kind != RTERelation || q.RangeTable[1].ERef != "d" { + t.Errorf("RTE[1]: want RTERelation 'd', got kind=%d eref=%q", q.RangeTable[1].Kind, q.RangeTable[1].ERef) + } + if q.RangeTable[2].Kind != RTEJoin { + t.Errorf("RTE[2]: want RTEJoin, got kind=%d", q.RangeTable[2].Kind) + } + if q.RangeTable[2].JoinType != JoinInner { + t.Errorf("RTE[2].JoinType: want JoinInner, got %d", q.RangeTable[2].JoinType) + } + + // JoinTree.FromList should have 1 JoinExprNodeQ. + if len(q.JoinTree.FromList) != 1 { + t.Fatalf("FromList: want 1 entry, got %d", len(q.JoinTree.FromList)) + } + je, ok := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if !ok { + t.Fatalf("FromList[0]: want *JoinExprNodeQ, got %T", q.JoinTree.FromList[0]) + } + if je.JoinType != JoinInner { + t.Errorf("JoinExprNodeQ.JoinType: want JoinInner, got %d", je.JoinType) + } + if je.Natural { + t.Errorf("JoinExprNodeQ.Natural: want false, got true") + } + if je.Quals == nil { + t.Error("JoinExprNodeQ.Quals: want non-nil ON condition, got nil") + } + // ON condition should be OpExprQ with "=" + onOp, ok := je.Quals.(*OpExprQ) + if !ok { + t.Fatalf("Quals: want *OpExprQ, got %T", je.Quals) + } + if onOp.Op != "=" { + t.Errorf("ON Op: want %q, got %q", "=", onOp.Op) + } + + // TargetList: 2 entries. + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } + if q.TargetList[0].ResName != "name" { + t.Errorf("TargetList[0].ResName: want %q, got %q", "name", q.TargetList[0].ResName) + } + if q.TargetList[1].ResName != "dept_name" { + t.Errorf("TargetList[1].ResName: want %q, got %q", "dept_name", q.TargetList[1].ResName) + } +} + +// TestAnalyze_7_2_LeftJoin tests LEFT JOIN. +func TestAnalyze_7_2_LeftJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name AS dept_name FROM employees e LEFT JOIN departments d ON e.department_id = d.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[2].JoinType != JoinLeft { + t.Errorf("RTE[2].JoinType: want JoinLeft, got %d", q.RangeTable[2].JoinType) + } + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if je.JoinType != JoinLeft { + t.Errorf("JoinExprNodeQ.JoinType: want JoinLeft, got %d", je.JoinType) + } +} + +// TestAnalyze_7_3_RightJoin tests RIGHT JOIN. +func TestAnalyze_7_3_RightJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name AS dept_name FROM employees e RIGHT JOIN departments d ON e.department_id = d.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[2].JoinType != JoinRight { + t.Errorf("RTE[2].JoinType: want JoinRight, got %d", q.RangeTable[2].JoinType) + } + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if je.JoinType != JoinRight { + t.Errorf("JoinExprNodeQ.JoinType: want JoinRight, got %d", je.JoinType) + } +} + +// TestAnalyze_7_4_CrossJoin tests CROSS JOIN with no condition. +func TestAnalyze_7_4_CrossJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name FROM employees e CROSS JOIN departments d`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[2].JoinType != JoinCross { + t.Errorf("RTE[2].JoinType: want JoinCross, got %d", q.RangeTable[2].JoinType) + } + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if je.JoinType != JoinCross { + t.Errorf("JoinExprNodeQ.JoinType: want JoinCross, got %d", je.JoinType) + } + if je.Quals != nil { + t.Errorf("JoinExprNodeQ.Quals: want nil for CROSS JOIN, got %v", je.Quals) + } +} + +// TestAnalyze_7_5_CommaJoin tests comma-separated tables (implicit cross join). +func TestAnalyze_7_5_CommaJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name FROM employees e, departments d WHERE e.department_id = d.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // Comma join: 2 RTEs only (no RTEJoin). + if len(q.RangeTable) != 2 { + t.Fatalf("RangeTable: want 2, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTERelation { + t.Errorf("RTE[0]: want RTERelation, got %d", q.RangeTable[0].Kind) + } + if q.RangeTable[1].Kind != RTERelation { + t.Errorf("RTE[1]: want RTERelation, got %d", q.RangeTable[1].Kind) + } + + // JoinTree.FromList has 2 RangeTableRefQ entries. + if len(q.JoinTree.FromList) != 2 { + t.Fatalf("FromList: want 2, got %d", len(q.JoinTree.FromList)) + } + for i, item := range q.JoinTree.FromList { + if _, ok := item.(*RangeTableRefQ); !ok { + t.Errorf("FromList[%d]: want *RangeTableRefQ, got %T", i, item) + } + } + + // WHERE in JoinTree.Quals. + if q.JoinTree.Quals == nil { + t.Error("JoinTree.Quals: want non-nil WHERE, got nil") + } +} + +// --- Batch 8: USING/NATURAL --- + +// TestAnalyze_8_1_JoinUsing tests JOIN ... USING (col). +func TestAnalyze_8_1_JoinUsing(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, p.name AS project_name FROM employees e JOIN projects p USING (department_id)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if len(je.UsingClause) != 1 || je.UsingClause[0] != "department_id" { + t.Errorf("UsingClause: want [department_id], got %v", je.UsingClause) + } + if je.Natural { + t.Errorf("Natural: want false, got true") + } + + // RTEJoin should also have JoinUsing. + rteJoin := q.RangeTable[2] + if len(rteJoin.JoinUsing) != 1 || rteJoin.JoinUsing[0] != "department_id" { + t.Errorf("RTE JoinUsing: want [department_id], got %v", rteJoin.JoinUsing) + } + + // TargetList: 2 entries. + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_8_2_NaturalJoin tests NATURAL JOIN with star expansion. +func TestAnalyze_8_2_NaturalJoin(t *testing.T) { + c := setupJoinTables(t) + // employees has: id, name, email, department_id, salary, hire_date, is_active, notes, created_at (9 cols) + // departments has: id, name, budget (3 cols) + // Shared columns: id, name → coalesced + // Result: id, name (from left), email, department_id, salary, hire_date, is_active, notes, created_at, budget + // = 10 columns + sel := parseSelect(t, `SELECT * FROM employees e NATURAL JOIN departments d`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if !je.Natural { + t.Errorf("Natural: want true, got false") + } + if je.JoinType != JoinInner { + t.Errorf("JoinType: want JoinInner, got %d", je.JoinType) + } + + // Check USING columns are the shared ones: id, name. + if len(je.UsingClause) != 2 { + t.Fatalf("UsingClause: want 2, got %d", len(je.UsingClause)) + } + + // Star expansion: should produce 10 columns (9 from employees + 1 remaining from departments). + // The right-side id and name are coalesced away. + if len(q.TargetList) != 10 { + t.Errorf("TargetList: want 10 columns (NATURAL JOIN coalesced), got %d", len(q.TargetList)) + } +} + +// TestAnalyze_8_3_NaturalLeftJoin tests NATURAL LEFT JOIN. +func TestAnalyze_8_3_NaturalLeftJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT * FROM employees e NATURAL LEFT JOIN departments d`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if !je.Natural { + t.Errorf("Natural: want true, got false") + } + if je.JoinType != JoinLeft { + t.Errorf("JoinType: want JoinLeft, got %d", je.JoinType) + } + + // Same column count as 8.2: 10 columns. + if len(q.TargetList) != 10 { + t.Errorf("TargetList: want 10 columns, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_8_4_JoinUsingMultiple tests USING with multiple columns. +func TestAnalyze_8_4_JoinUsingMultiple(t *testing.T) { + c := setupJoinTables(t) + // employees and departments share: id, name. USING (id, name). + sel := parseSelect(t, `SELECT * FROM employees e JOIN departments d USING (id, name)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if len(je.UsingClause) != 2 { + t.Fatalf("UsingClause: want 2, got %d", len(je.UsingClause)) + } + if je.UsingClause[0] != "id" || je.UsingClause[1] != "name" { + t.Errorf("UsingClause: want [id, name], got %v", je.UsingClause) + } + + // Star expansion: 10 columns (same as NATURAL — same shared columns). + if len(q.TargetList) != 10 { + t.Errorf("TargetList: want 10 columns, got %d", len(q.TargetList)) + } +} + +// --- Batch 9: FROM subqueries --- + +// TestAnalyze_9_1_SimpleSubquery tests FROM (SELECT ...) AS sub. +func TestAnalyze_9_1_SimpleSubquery(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT sub.total FROM (SELECT COUNT(*) AS total FROM employees) AS sub`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // 1 RTE: RTESubquery. + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.Kind != RTESubquery { + t.Errorf("RTE Kind: want RTESubquery, got %d", rte.Kind) + } + if rte.ERef != "sub" { + t.Errorf("RTE ERef: want %q, got %q", "sub", rte.ERef) + } + if len(rte.ColNames) != 1 || rte.ColNames[0] != "total" { + t.Errorf("ColNames: want [total], got %v", rte.ColNames) + } + + // Inner query should have HasAggs = true. + if rte.Subquery == nil { + t.Fatal("Subquery: want non-nil, got nil") + } + if !rte.Subquery.HasAggs { + t.Errorf("Inner Query HasAggs: want true, got false") + } + + // Outer TargetList: 1 column. + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1, got %d", len(q.TargetList)) + } + if q.TargetList[0].ResName != "total" { + t.Errorf("ResName: want %q, got %q", "total", q.TargetList[0].ResName) + } +} + +// TestAnalyze_9_2_SubqueryWithGroupBy tests FROM subquery with GROUP BY. +func TestAnalyze_9_2_SubqueryWithGroupBy(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT x.dept, x.cnt FROM (SELECT department_id AS dept, COUNT(*) AS cnt FROM employees GROUP BY department_id) AS x`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.Kind != RTESubquery { + t.Errorf("RTE Kind: want RTESubquery, got %d", rte.Kind) + } + if len(rte.ColNames) != 2 { + t.Fatalf("ColNames: want 2, got %d", len(rte.ColNames)) + } + if rte.ColNames[0] != "dept" || rte.ColNames[1] != "cnt" { + t.Errorf("ColNames: want [dept, cnt], got %v", rte.ColNames) + } + + // Outer TargetList: 2 columns. + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_9_3_SubqueryJoinedWithTable tests JOIN between a table and a FROM subquery. +func TestAnalyze_9_3_SubqueryJoinedWithTable(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, sub.avg_sal FROM employees e JOIN (SELECT department_id, AVG(salary) AS avg_sal FROM employees GROUP BY department_id) AS sub ON e.department_id = sub.department_id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // RTEs: employees (0), RTESubquery (1), RTEJoin (2). + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTERelation { + t.Errorf("RTE[0] Kind: want RTERelation, got %d", q.RangeTable[0].Kind) + } + if q.RangeTable[1].Kind != RTESubquery { + t.Errorf("RTE[1] Kind: want RTESubquery, got %d", q.RangeTable[1].Kind) + } + if q.RangeTable[2].Kind != RTEJoin { + t.Errorf("RTE[2] Kind: want RTEJoin, got %d", q.RangeTable[2].Kind) + } + + // Inner subquery has 2 columns. + subRTE := q.RangeTable[1] + if len(subRTE.ColNames) != 2 { + t.Fatalf("Sub ColNames: want 2, got %d", len(subRTE.ColNames)) + } + + // JoinExprNodeQ with ON condition. + je := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if je.Quals == nil { + t.Error("ON condition: want non-nil, got nil") + } + + // Outer TargetList: 2 columns. + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } +} + +// --- Batch 10: Multi-table edges --- + +// TestAnalyze_10_1_ThreeWayJoin tests a three-way JOIN. +func TestAnalyze_10_1_ThreeWayJoin(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT e.name, d.name, p.name FROM employees e INNER JOIN departments d ON e.department_id = d.id INNER JOIN projects p ON d.id = p.department_id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // 5 RTEs: employees(0), departments(1), RTEJoin for first join(2), projects(3), RTEJoin for second join(4). + if len(q.RangeTable) != 5 { + t.Fatalf("RangeTable: want 5, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTERelation { + t.Errorf("RTE[0]: want RTERelation, got %d", q.RangeTable[0].Kind) + } + if q.RangeTable[1].Kind != RTERelation { + t.Errorf("RTE[1]: want RTERelation, got %d", q.RangeTable[1].Kind) + } + if q.RangeTable[2].Kind != RTEJoin { + t.Errorf("RTE[2]: want RTEJoin, got %d", q.RangeTable[2].Kind) + } + if q.RangeTable[3].Kind != RTERelation { + t.Errorf("RTE[3]: want RTERelation, got %d", q.RangeTable[3].Kind) + } + if q.RangeTable[4].Kind != RTEJoin { + t.Errorf("RTE[4]: want RTEJoin, got %d", q.RangeTable[4].Kind) + } + + // JoinTree.FromList should have 1 outer JoinExprNodeQ. + if len(q.JoinTree.FromList) != 1 { + t.Fatalf("FromList: want 1, got %d", len(q.JoinTree.FromList)) + } + outerJoin, ok := q.JoinTree.FromList[0].(*JoinExprNodeQ) + if !ok { + t.Fatalf("FromList[0]: want *JoinExprNodeQ, got %T", q.JoinTree.FromList[0]) + } + // The left side of the outer join should be another JoinExprNodeQ (the first join). + innerJoin, ok := outerJoin.Left.(*JoinExprNodeQ) + if !ok { + t.Fatalf("OuterJoin.Left: want *JoinExprNodeQ, got %T", outerJoin.Left) + } + _ = innerJoin + + // TargetList: 3 columns. + if len(q.TargetList) != 3 { + t.Fatalf("TargetList: want 3, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_10_2_StarJoin tests SELECT * with a two-table JOIN. +func TestAnalyze_10_2_StarJoin(t *testing.T) { + c := setupJoinTables(t) + // employees: 9 cols, departments: 3 cols → 12 total. + sel := parseSelect(t, `SELECT * FROM employees e JOIN departments d ON e.department_id = d.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // Star expansion for JOIN without USING: all columns from both tables. + // employees(9) + departments(3) = 12. + if len(q.TargetList) != 12 { + t.Errorf("TargetList: want 12 columns, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_10_3_AmbiguousColumn tests that unqualified 'name' is ambiguous across two tables. +func TestAnalyze_10_3_AmbiguousColumn(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees e JOIN departments d ON e.department_id = d.id`) + _, err := c.AnalyzeSelectStmt(sel) + assertError(t, err, 1052) // ambiguous column +} + +// --------------------------------------------------------------------------- +// Phase 1c — Batches 11-13: WHERE subqueries, CTEs, set operations +// --------------------------------------------------------------------------- + +// --- Batch 11: WHERE subqueries --- + +// TestAnalyze_11_1_ScalarSubqueryInWhere tests WHERE salary > (SELECT AVG(salary) FROM employees). +func TestAnalyze_11_1_ScalarSubqueryInWhere(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees WHERE salary > (SELECT AVG(salary) FROM employees)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // WHERE should be an OpExprQ with ">" operator. + op, ok := q.JoinTree.Quals.(*OpExprQ) + if !ok { + t.Fatalf("Quals: want *OpExprQ, got %T", q.JoinTree.Quals) + } + if op.Op != ">" { + t.Errorf("Op: want >, got %s", op.Op) + } + + // Left side: VarExprQ for salary. + if _, ok := op.Left.(*VarExprQ); !ok { + t.Errorf("Left: want *VarExprQ, got %T", op.Left) + } + + // Right side: SubLinkExprQ (scalar subquery). + subLink, ok := op.Right.(*SubLinkExprQ) + if !ok { + t.Fatalf("Right: want *SubLinkExprQ, got %T", op.Right) + } + if subLink.Kind != SubLinkScalar { + t.Errorf("Kind: want SubLinkScalar, got %d", subLink.Kind) + } + if subLink.Subquery == nil { + t.Fatal("Subquery: want non-nil, got nil") + } + if !subLink.Subquery.HasAggs { + t.Errorf("Subquery.HasAggs: want true, got false") + } +} + +// TestAnalyze_11_2_InSubquery tests WHERE department_id IN (SELECT id FROM departments WHERE budget > 100000). +func TestAnalyze_11_2_InSubquery(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees WHERE department_id IN (SELECT id FROM departments WHERE budget > 100000)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // WHERE should be a SubLinkExprQ with Kind=SubLinkIn. + subLink, ok := q.JoinTree.Quals.(*SubLinkExprQ) + if !ok { + t.Fatalf("Quals: want *SubLinkExprQ, got %T", q.JoinTree.Quals) + } + if subLink.Kind != SubLinkIn { + t.Errorf("Kind: want SubLinkIn, got %d", subLink.Kind) + } + if subLink.Op != "=" { + t.Errorf("Op: want =, got %s", subLink.Op) + } + + // TestExpr: VarExprQ for department_id (AttNum=4). + testVar, ok := subLink.TestExpr.(*VarExprQ) + if !ok { + t.Fatalf("TestExpr: want *VarExprQ, got %T", subLink.TestExpr) + } + if testVar.AttNum != 4 { + t.Errorf("TestExpr.AttNum: want 4, got %d", testVar.AttNum) + } + + // Subquery should have a WHERE qual. + if subLink.Subquery == nil { + t.Fatal("Subquery: want non-nil, got nil") + } + if subLink.Subquery.JoinTree.Quals == nil { + t.Error("Subquery WHERE: want non-nil, got nil") + } +} + +// TestAnalyze_11_3_ExistsCorrelated tests EXISTS with a correlated subquery. +func TestAnalyze_11_3_ExistsCorrelated(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees e WHERE EXISTS (SELECT 1 FROM projects p WHERE p.lead_id = e.id)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // WHERE should be SubLinkExprQ with Kind=SubLinkExists. + subLink, ok := q.JoinTree.Quals.(*SubLinkExprQ) + if !ok { + t.Fatalf("Quals: want *SubLinkExprQ, got %T", q.JoinTree.Quals) + } + if subLink.Kind != SubLinkExists { + t.Errorf("Kind: want SubLinkExists, got %d", subLink.Kind) + } + if subLink.TestExpr != nil { + t.Errorf("TestExpr: want nil for EXISTS, got %v", subLink.TestExpr) + } + + // Inner query WHERE should reference outer scope. + innerQ := subLink.Subquery + if innerQ == nil { + t.Fatal("Subquery: want non-nil, got nil") + } + innerQuals := innerQ.JoinTree.Quals + if innerQuals == nil { + t.Fatal("inner Quals: want non-nil, got nil") + } + + // Should be OpExprQ: p.lead_id = e.id + innerOp, ok := innerQuals.(*OpExprQ) + if !ok { + t.Fatalf("inner Quals: want *OpExprQ, got %T", innerQuals) + } + + // One side should have LevelsUp=1 (correlated reference to outer e.id). + leftVar, leftOk := innerOp.Left.(*VarExprQ) + rightVar, rightOk := innerOp.Right.(*VarExprQ) + if !leftOk || !rightOk { + t.Fatalf("inner Op sides: want both *VarExprQ, got %T and %T", innerOp.Left, innerOp.Right) + } + + // e.id is from the outer scope (LevelsUp=1); p.lead_id is from inner scope (LevelsUp=0). + if leftVar.LevelsUp == 0 && rightVar.LevelsUp == 1 { + // right is correlated + if rightVar.AttNum != 1 { // e.id is column 1 + t.Errorf("correlated ref AttNum: want 1 (id), got %d", rightVar.AttNum) + } + } else if leftVar.LevelsUp == 1 && rightVar.LevelsUp == 0 { + // left is correlated + if leftVar.AttNum != 1 { + t.Errorf("correlated ref AttNum: want 1 (id), got %d", leftVar.AttNum) + } + } else { + t.Errorf("expected one LevelsUp=1 (correlated), got left=%d right=%d", leftVar.LevelsUp, rightVar.LevelsUp) + } +} + +// TestAnalyze_11_4_NotInSubquery tests NOT IN (SELECT ...). +func TestAnalyze_11_4_NotInSubquery(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees WHERE department_id NOT IN (SELECT department_id FROM projects)`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // NOT IN subquery is represented as BoolExprQ{BoolNot, [SubLinkExprQ{SubLinkIn}]}. + boolExpr, ok := q.JoinTree.Quals.(*BoolExprQ) + if !ok { + t.Fatalf("Quals: want *BoolExprQ, got %T", q.JoinTree.Quals) + } + if boolExpr.Op != BoolNot { + t.Errorf("BoolOp: want BoolNot, got %d", boolExpr.Op) + } + if len(boolExpr.Args) != 1 { + t.Fatalf("BoolExpr.Args: want 1, got %d", len(boolExpr.Args)) + } + + subLink, ok := boolExpr.Args[0].(*SubLinkExprQ) + if !ok { + t.Fatalf("BoolExpr.Args[0]: want *SubLinkExprQ, got %T", boolExpr.Args[0]) + } + if subLink.Kind != SubLinkIn { + t.Errorf("Kind: want SubLinkIn, got %d", subLink.Kind) + } +} + +// TestAnalyze_11_5_ScalarSubqueryInSelect tests a scalar subquery in the SELECT list. +func TestAnalyze_11_5_ScalarSubqueryInSelect(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name, (SELECT COUNT(*) FROM projects p WHERE p.lead_id = e.id) AS project_count FROM employees e`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } + + // Second column should be a SubLinkExprQ. + subLink, ok := q.TargetList[1].Expr.(*SubLinkExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *SubLinkExprQ, got %T", q.TargetList[1].Expr) + } + if subLink.Kind != SubLinkScalar { + t.Errorf("Kind: want SubLinkScalar, got %d", subLink.Kind) + } + if q.TargetList[1].ResName != "project_count" { + t.Errorf("ResName: want project_count, got %s", q.TargetList[1].ResName) + } + + // Inner query should have HasAggs=true (COUNT). + if !subLink.Subquery.HasAggs { + t.Errorf("Subquery.HasAggs: want true, got false") + } +} + +// --- Batch 12: CTEs --- + +// TestAnalyze_12_1_SimpleCTE tests a simple CTE with WITH clause. +func TestAnalyze_12_1_SimpleCTE(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `WITH dept_stats AS (SELECT department_id, COUNT(*) AS cnt FROM employees GROUP BY department_id) SELECT * FROM dept_stats`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // CTEList should have 1 entry. + if len(q.CTEList) != 1 { + t.Fatalf("CTEList: want 1, got %d", len(q.CTEList)) + } + cte := q.CTEList[0] + if cte.Name != "dept_stats" { + t.Errorf("CTE Name: want dept_stats, got %s", cte.Name) + } + if cte.Recursive { + t.Errorf("CTE Recursive: want false, got true") + } + if cte.Query == nil { + t.Fatal("CTE Query: want non-nil, got nil") + } + + // RangeTable should have an RTECTE entry. + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1, got %d", len(q.RangeTable)) + } + rte := q.RangeTable[0] + if rte.Kind != RTECTE { + t.Errorf("RTE Kind: want RTECTE, got %d", rte.Kind) + } + if rte.CTEName != "dept_stats" { + t.Errorf("RTE CTEName: want dept_stats, got %s", rte.CTEName) + } + if rte.CTEIndex != 0 { + t.Errorf("RTE CTEIndex: want 0, got %d", rte.CTEIndex) + } + + // Star expansion from CTE should produce columns from the CTE body. + if len(q.TargetList) != 2 { + t.Errorf("TargetList: want 2, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_12_2_MultipleCTEs tests multiple CTEs. +func TestAnalyze_12_2_MultipleCTEs(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, ` + WITH + active_emps AS (SELECT id, name FROM employees WHERE is_active = 1), + big_depts AS (SELECT id, name FROM departments WHERE budget > 100000) + SELECT a.name, b.name AS dept_name FROM active_emps a JOIN big_depts b ON a.id = b.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.CTEList) != 2 { + t.Fatalf("CTEList: want 2, got %d", len(q.CTEList)) + } + if q.CTEList[0].Name != "active_emps" { + t.Errorf("CTE[0].Name: want active_emps, got %s", q.CTEList[0].Name) + } + if q.CTEList[1].Name != "big_depts" { + t.Errorf("CTE[1].Name: want big_depts, got %s", q.CTEList[1].Name) + } + + // RangeTable: 2 RTECTE + 1 RTEJoin = 3. + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTECTE { + t.Errorf("RTE[0]: want RTECTE, got %d", q.RangeTable[0].Kind) + } + if q.RangeTable[1].Kind != RTECTE { + t.Errorf("RTE[1]: want RTECTE, got %d", q.RangeTable[1].Kind) + } + + if len(q.TargetList) != 2 { + t.Errorf("TargetList: want 2, got %d", len(q.TargetList)) + } +} + +// TestAnalyze_12_3_CTEWithExplicitColumns tests CTE with explicit column aliases. +func TestAnalyze_12_3_CTEWithExplicitColumns(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `WITH emp_summary(dept, cnt) AS (SELECT department_id, COUNT(*) FROM employees GROUP BY department_id) SELECT dept, cnt FROM emp_summary`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.CTEList) != 1 { + t.Fatalf("CTEList: want 1, got %d", len(q.CTEList)) + } + cte := q.CTEList[0] + if len(cte.ColumnNames) != 2 || cte.ColumnNames[0] != "dept" || cte.ColumnNames[1] != "cnt" { + t.Errorf("CTE ColumnNames: want [dept, cnt], got %v", cte.ColumnNames) + } + + // RTECTE should use the explicit column names. + rte := q.RangeTable[0] + if rte.Kind != RTECTE { + t.Fatalf("RTE Kind: want RTECTE, got %d", rte.Kind) + } + if len(rte.ColNames) != 2 || rte.ColNames[0] != "dept" || rte.ColNames[1] != "cnt" { + t.Errorf("RTE ColNames: want [dept, cnt], got %v", rte.ColNames) + } + + // TargetList should resolve dept and cnt. + if len(q.TargetList) != 2 { + t.Fatalf("TargetList: want 2, got %d", len(q.TargetList)) + } + if q.TargetList[0].ResName != "dept" { + t.Errorf("TargetList[0].ResName: want dept, got %s", q.TargetList[0].ResName) + } + if q.TargetList[1].ResName != "cnt" { + t.Errorf("TargetList[1].ResName: want cnt, got %s", q.TargetList[1].ResName) + } +} + +// TestAnalyze_12_4_CTEReferencedTwice tests a CTE referenced twice in the FROM clause. +func TestAnalyze_12_4_CTEReferencedTwice(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `WITH emp_ids AS (SELECT id, name FROM employees) SELECT a.name, b.name FROM emp_ids a JOIN emp_ids b ON a.id = b.id`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.CTEList) != 1 { + t.Fatalf("CTEList: want 1, got %d", len(q.CTEList)) + } + + // Two RTECTE entries (one per reference), plus one RTEJoin. + if len(q.RangeTable) != 3 { + t.Fatalf("RangeTable: want 3, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTECTE { + t.Errorf("RTE[0]: want RTECTE, got %d", q.RangeTable[0].Kind) + } + if q.RangeTable[1].Kind != RTECTE { + t.Errorf("RTE[1]: want RTECTE, got %d", q.RangeTable[1].Kind) + } + // Both should reference CTEIndex=0. + if q.RangeTable[0].CTEIndex != 0 { + t.Errorf("RTE[0].CTEIndex: want 0, got %d", q.RangeTable[0].CTEIndex) + } + if q.RangeTable[1].CTEIndex != 0 { + t.Errorf("RTE[1].CTEIndex: want 0, got %d", q.RangeTable[1].CTEIndex) + } +} + +// TestAnalyze_12_5_RecursiveCTE tests WITH RECURSIVE. +func TestAnalyze_12_5_RecursiveCTE(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE categories (id INT PRIMARY KEY, name VARCHAR(100), parent_id INT)`) + + sel := parseSelect(t, ` + WITH RECURSIVE cat_tree(id, name, parent_id) AS ( + SELECT id, name, parent_id FROM categories WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id FROM categories c INNER JOIN cat_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM cat_tree`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if !q.IsRecursive { + t.Errorf("IsRecursive: want true, got false") + } + if len(q.CTEList) != 1 { + t.Fatalf("CTEList: want 1, got %d", len(q.CTEList)) + } + cte := q.CTEList[0] + if !cte.Recursive { + t.Errorf("CTE Recursive: want true, got false") + } + + // CTE body should be a set-op query. + if cte.Query.SetOp != SetOpUnion { + t.Errorf("CTE SetOp: want SetOpUnion, got %d", cte.Query.SetOp) + } + if !cte.Query.AllSetOp { + t.Errorf("CTE AllSetOp: want true, got false") + } + + // RTECTE in main query. + if len(q.RangeTable) != 1 { + t.Fatalf("RangeTable: want 1, got %d", len(q.RangeTable)) + } + if q.RangeTable[0].Kind != RTECTE { + t.Errorf("RTE Kind: want RTECTE, got %d", q.RangeTable[0].Kind) + } + + // Star expansion: 3 columns from the CTE. + if len(q.TargetList) != 3 { + t.Errorf("TargetList: want 3, got %d", len(q.TargetList)) + } +} + +// --- Batch 13: Set operations --- + +// TestAnalyze_13_1_Union tests UNION (without ALL). +func TestAnalyze_13_1_Union(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees UNION SELECT name FROM departments`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.SetOp != SetOpUnion { + t.Errorf("SetOp: want SetOpUnion, got %d", q.SetOp) + } + if q.AllSetOp { + t.Errorf("AllSetOp: want false, got true") + } + if q.LArg == nil { + t.Fatal("LArg: want non-nil, got nil") + } + if q.RArg == nil { + t.Fatal("RArg: want non-nil, got nil") + } + + // Result columns from left arm. + if len(q.TargetList) != 1 { + t.Fatalf("TargetList: want 1, got %d", len(q.TargetList)) + } + if q.TargetList[0].ResName != "name" { + t.Errorf("TargetList[0].ResName: want name, got %s", q.TargetList[0].ResName) + } +} + +// TestAnalyze_13_2_UnionAll tests UNION ALL. +func TestAnalyze_13_2_UnionAll(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees UNION ALL SELECT name FROM departments`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.SetOp != SetOpUnion { + t.Errorf("SetOp: want SetOpUnion, got %d", q.SetOp) + } + if !q.AllSetOp { + t.Errorf("AllSetOp: want true, got false") + } +} + +// TestAnalyze_13_3_Intersect tests INTERSECT. +func TestAnalyze_13_3_Intersect(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees INTERSECT SELECT name FROM departments`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.SetOp != SetOpIntersect { + t.Errorf("SetOp: want SetOpIntersect, got %d", q.SetOp) + } + if q.AllSetOp { + t.Errorf("AllSetOp: want false, got true") + } + if q.LArg == nil || q.RArg == nil { + t.Fatal("LArg/RArg: want non-nil") + } +} + +// TestAnalyze_13_4_Except tests EXCEPT. +func TestAnalyze_13_4_Except(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees EXCEPT SELECT name FROM departments`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.SetOp != SetOpExcept { + t.Errorf("SetOp: want SetOpExcept, got %d", q.SetOp) + } + if q.AllSetOp { + t.Errorf("AllSetOp: want false, got true") + } +} + +// TestAnalyze_13_5_UnionAllOrderByLimit tests UNION ALL with ORDER BY and LIMIT. +func TestAnalyze_13_5_UnionAllOrderByLimit(t *testing.T) { + c := setupJoinTables(t) + sel := parseSelect(t, `SELECT name FROM employees UNION ALL SELECT name FROM departments ORDER BY name LIMIT 10`) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if q.SetOp != SetOpUnion { + t.Errorf("SetOp: want SetOpUnion, got %d", q.SetOp) + } + if !q.AllSetOp { + t.Errorf("AllSetOp: want true, got false") + } + + // ORDER BY should be populated. + if len(q.SortClause) != 1 { + t.Fatalf("SortClause: want 1, got %d", len(q.SortClause)) + } + + // LIMIT should be populated. + if q.LimitCount == nil { + t.Error("LimitCount: want non-nil, got nil") + } + constExpr, ok := q.LimitCount.(*ConstExprQ) + if !ok { + t.Fatalf("LimitCount: want *ConstExprQ, got %T", q.LimitCount) + } + if constExpr.Value != "10" { + t.Errorf("LimitCount value: want 10, got %s", constExpr.Value) + } +} + +// ---------- Phase 3: Standalone expression analysis, function types, DDL hookups ---------- + +// TestAnalyze_15_1_CheckConstraintAnalyzed tests that CHECK constraint expressions +// are analyzed and stored as CheckAnalyzed on the Constraint. +func TestAnalyze_15_1_CheckConstraintAnalyzed(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t (a INT, b INT, CONSTRAINT chk CHECK (a > 0 AND b > 0))`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + + // Find the CHECK constraint named "chk". + var con *Constraint + for _, cc := range tbl.Constraints { + if cc.Type == ConCheck && cc.Name == "chk" { + con = cc + break + } + } + if con == nil { + t.Fatal("CHECK constraint 'chk' not found") + } + if con.CheckAnalyzed == nil { + t.Fatal("CheckAnalyzed: want non-nil, got nil") + } + + // Top-level should be BoolExprQ with BoolAnd. + boolExpr, ok := con.CheckAnalyzed.(*BoolExprQ) + if !ok { + t.Fatalf("CheckAnalyzed: want *BoolExprQ, got %T", con.CheckAnalyzed) + } + if boolExpr.Op != BoolAnd { + t.Errorf("BoolExprQ.Op: want BoolAnd, got %v", boolExpr.Op) + } + if len(boolExpr.Args) != 2 { + t.Fatalf("BoolExprQ.Args: want 2, got %d", len(boolExpr.Args)) + } + + // Each arg should be OpExprQ with Op ">". + for i, arg := range boolExpr.Args { + opExpr, ok := arg.(*OpExprQ) + if !ok { + t.Fatalf("Args[%d]: want *OpExprQ, got %T", i, arg) + } + if opExpr.Op != ">" { + t.Errorf("Args[%d].Op: want >, got %s", i, opExpr.Op) + } + } +} + +// TestAnalyze_15_2_DefaultAnalyzed tests that DEFAULT expressions are analyzed. +func TestAnalyze_15_2_DefaultAnalyzed(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t (a INT DEFAULT 42, b VARCHAR(100) DEFAULT 'hello')`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + + // Column "a" default: ConstExprQ{Value:"42"} + colA := tbl.GetColumn("a") + if colA == nil { + t.Fatal("column a not found") + } + if colA.DefaultAnalyzed == nil { + t.Fatal("a.DefaultAnalyzed: want non-nil, got nil") + } + constA, ok := colA.DefaultAnalyzed.(*ConstExprQ) + if !ok { + t.Fatalf("a.DefaultAnalyzed: want *ConstExprQ, got %T", colA.DefaultAnalyzed) + } + if constA.Value != "42" { + t.Errorf("a.DefaultAnalyzed.Value: want 42, got %s", constA.Value) + } + + // Column "b" default: ConstExprQ{Value:"hello"} + colB := tbl.GetColumn("b") + if colB == nil { + t.Fatal("column b not found") + } + if colB.DefaultAnalyzed == nil { + t.Fatal("b.DefaultAnalyzed: want non-nil, got nil") + } + constB, ok := colB.DefaultAnalyzed.(*ConstExprQ) + if !ok { + t.Fatalf("b.DefaultAnalyzed: want *ConstExprQ, got %T", colB.DefaultAnalyzed) + } + if constB.Value != "hello" { + t.Errorf("b.DefaultAnalyzed.Value: want hello, got %s", constB.Value) + } +} + +// TestAnalyze_15_3_GeneratedAnalyzed tests that GENERATED ALWAYS AS expressions are analyzed. +func TestAnalyze_15_3_GeneratedAnalyzed(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) STORED)`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + + colC := tbl.GetColumn("c") + if colC == nil { + t.Fatal("column c not found") + } + if colC.GeneratedAnalyzed == nil { + t.Fatal("c.GeneratedAnalyzed: want non-nil, got nil") + } + + opExpr, ok := colC.GeneratedAnalyzed.(*OpExprQ) + if !ok { + t.Fatalf("c.GeneratedAnalyzed: want *OpExprQ, got %T", colC.GeneratedAnalyzed) + } + if opExpr.Op != "+" { + t.Errorf("OpExprQ.Op: want +, got %s", opExpr.Op) + } + + // Left should be VarExprQ for column a (AttNum=1). + leftVar, ok := opExpr.Left.(*VarExprQ) + if !ok { + t.Fatalf("Left: want *VarExprQ, got %T", opExpr.Left) + } + if leftVar.AttNum != 1 { + t.Errorf("Left.AttNum: want 1, got %d", leftVar.AttNum) + } + + // Right should be VarExprQ for column b (AttNum=2). + rightVar, ok := opExpr.Right.(*VarExprQ) + if !ok { + t.Fatalf("Right: want *VarExprQ, got %T", opExpr.Right) + } + if rightVar.AttNum != 2 { + t.Errorf("Right.AttNum: want 2, got %d", rightVar.AttNum) + } +} + +// TestAnalyze_15_4_FunctionReturnTypes tests that function return types are populated. +func TestAnalyze_15_4_FunctionReturnTypes(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL + )`) + + sel := parseSelect(t, "SELECT COUNT(*), CONCAT(name, '!'), NOW() FROM employees") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + if len(q.TargetList) != 3 { + t.Fatalf("TargetList: want 3, got %d", len(q.TargetList)) + } + + // COUNT(*) -> BaseTypeBigInt + fc0, ok := q.TargetList[0].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[0].Expr: want *FuncCallExprQ, got %T", q.TargetList[0].Expr) + } + if fc0.ResultType == nil { + t.Fatal("COUNT(*) ResultType: want non-nil, got nil") + } + if fc0.ResultType.BaseType != BaseTypeBigInt { + t.Errorf("COUNT(*) ResultType.BaseType: want BaseTypeBigInt, got %d", fc0.ResultType.BaseType) + } + + // CONCAT(name, '!') -> BaseTypeVarchar + fc1, ok := q.TargetList[1].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[1].Expr: want *FuncCallExprQ, got %T", q.TargetList[1].Expr) + } + if fc1.ResultType == nil { + t.Fatal("CONCAT() ResultType: want non-nil, got nil") + } + if fc1.ResultType.BaseType != BaseTypeVarchar { + t.Errorf("CONCAT() ResultType.BaseType: want BaseTypeVarchar, got %d", fc1.ResultType.BaseType) + } + + // NOW() -> BaseTypeDateTime + fc2, ok := q.TargetList[2].Expr.(*FuncCallExprQ) + if !ok { + t.Fatalf("TargetList[2].Expr: want *FuncCallExprQ, got %T", q.TargetList[2].Expr) + } + if fc2.ResultType == nil { + t.Fatal("NOW() ResultType: want non-nil, got nil") + } + if fc2.ResultType.BaseType != BaseTypeDateTime { + t.Errorf("NOW() ResultType.BaseType: want BaseTypeDateTime, got %d", fc2.ResultType.BaseType) + } +} + +// TestAnalyze_15_5_ViewFunctionTypeFlow tests that function return types flow through views. +func TestAnalyze_15_5_ViewFunctionTypeFlow(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t (id INT, name VARCHAR(100))`) + wtExec(t, c, `CREATE VIEW v AS SELECT id, name, COUNT(*) OVER () AS cnt FROM t`) + + sel := parseSelect(t, "SELECT * FROM v") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + // Should have 3 columns: id, name, cnt + if len(q.TargetList) != 3 { + t.Fatalf("TargetList: want 3, got %d", len(q.TargetList)) + } + if q.TargetList[0].ResName != "id" { + t.Errorf("TargetList[0].ResName: want id, got %s", q.TargetList[0].ResName) + } + if q.TargetList[1].ResName != "name" { + t.Errorf("TargetList[1].ResName: want name, got %s", q.TargetList[1].ResName) + } + if q.TargetList[2].ResName != "cnt" { + t.Errorf("TargetList[2].ResName: want cnt, got %s", q.TargetList[2].ResName) + } +} diff --git a/tidb/catalog/bugfix_test.go b/tidb/catalog/bugfix_test.go new file mode 100644 index 00000000..a902047e --- /dev/null +++ b/tidb/catalog/bugfix_test.go @@ -0,0 +1,72 @@ +package catalog + +import ( + "testing" +) + +func TestBugFix_InExprNodeToSQL(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (category VARCHAR(50), CONSTRAINT chk CHECK (category IN ('a','b','c')))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + var found bool + for _, con := range tbl.Constraints { + if con.Name == "chk" { + found = true + // The expression should contain IN, not "(?)" + if con.CheckExpr == "(?)" || con.CheckExpr == "" { + t.Errorf("CHECK expression was not properly deparsed: got %q", con.CheckExpr) + } + t.Logf("CHECK expression: %s", con.CheckExpr) + } + } + if !found { + t.Error("constraint chk not found") + } +} + +func TestBugFix_FKIndexOrder(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, `CREATE TABLE child ( + id INT AUTO_INCREMENT PRIMARY KEY, + parent_id INT NOT NULL, + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id), + INDEX idx_parent (parent_id) + )`) + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + // Should have 2 indexes: PRIMARY + idx_parent (FK should reuse idx_parent, not create a 3rd) + if len(tbl.Indexes) != 2 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Errorf("expected 2 indexes (PRIMARY + idx_parent), got %d: %v", len(tbl.Indexes), names) + } +} + +func TestBugFix_PartitionAutoGen(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) PARTITION BY HASH(id) PARTITIONS 4") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + if tbl.Partitioning == nil { + t.Fatal("partitioning info is nil") + } + if len(tbl.Partitioning.Partitions) != 4 { + t.Errorf("expected 4 partitions, got %d", len(tbl.Partitioning.Partitions)) + } + for i, p := range tbl.Partitioning.Partitions { + expected := "p" + string(rune('0'+i)) + if p.Name != expected { + t.Errorf("partition %d: expected name %q, got %q", i, expected, p.Name) + } + } +} diff --git a/tidb/catalog/catalog.go b/tidb/catalog/catalog.go new file mode 100644 index 00000000..45afa95f --- /dev/null +++ b/tidb/catalog/catalog.go @@ -0,0 +1,54 @@ +package catalog + +// Catalog is the in-memory MySQL catalog. +type Catalog struct { + databases map[string]*Database // lowered name -> Database + currentDB string + defaultCharset string + defaultCollation string + foreignKeyChecks bool // SET foreign_key_checks (default true) +} + +func New() *Catalog { + return &Catalog{ + databases: make(map[string]*Database), + defaultCharset: "utf8mb4", + defaultCollation: "utf8mb4_0900_ai_ci", + foreignKeyChecks: true, + } +} + +// ForeignKeyChecks returns whether FK validation is enabled. +func (c *Catalog) ForeignKeyChecks() bool { return c.foreignKeyChecks } + +// SetForeignKeyChecks enables or disables FK validation. +func (c *Catalog) SetForeignKeyChecks(v bool) { c.foreignKeyChecks = v } + +func (c *Catalog) SetCurrentDatabase(name string) { c.currentDB = name } +func (c *Catalog) CurrentDatabase() string { return c.currentDB } + +func (c *Catalog) GetDatabase(name string) *Database { + return c.databases[toLower(name)] +} + +func (c *Catalog) Databases() []*Database { + result := make([]*Database, 0, len(c.databases)) + for _, db := range c.databases { + result = append(result, db) + } + return result +} + +func (c *Catalog) resolveDatabase(name string) (*Database, error) { + if name == "" { + name = c.currentDB + } + if name == "" { + return nil, errNoDatabaseSelected() + } + db := c.GetDatabase(name) + if db == nil { + return nil, errUnknownDatabase(name) + } + return db, nil +} diff --git a/tidb/catalog/catalog_spotcheck_test.go b/tidb/catalog/catalog_spotcheck_test.go new file mode 100644 index 00000000..bed662e7 --- /dev/null +++ b/tidb/catalog/catalog_spotcheck_test.go @@ -0,0 +1,1167 @@ +package catalog + +import ( + "fmt" + "strings" + "testing" +) + +// These tests verify that the behaviors described in +// docs/plans/2026-04-13-mysql-implicit-behaviors-catalog.md +// match real MySQL 8.0 observable behavior. They run against a +// testcontainers MySQL container and do NOT involve omni's catalog. + +// spotCheckQuery runs DDL (possibly multi-statement) then returns the +// rows of the given query. Fatals on any error. +func spotCheckQuery(t *testing.T, mc *mysqlContainer, ddl, query string) [][]any { + t.Helper() + for _, stmt := range splitStatements(ddl) { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + if _, err := mc.db.ExecContext(mc.ctx, stmt); err != nil { + t.Fatalf("DDL failed: %q\n %v", stmt, err) + } + } + rows, err := mc.db.QueryContext(mc.ctx, query) + if err != nil { + t.Fatalf("query failed: %q\n %v", query, err) + } + defer rows.Close() + + cols, _ := rows.Columns() + var results [][]any + for rows.Next() { + vals := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + t.Fatalf("scan failed: %v", err) + } + // Convert []byte to string for readability. + for i, v := range vals { + if b, ok := v.([]byte); ok { + vals[i] = string(b) + } + } + results = append(results, vals) + } + return results +} + +func asString(v any) string { + switch x := v.(type) { + case string: + return x + case []byte: + return string(x) + case nil: + return "" + case int64: + return fmt.Sprintf("%d", x) + case int: + return fmt.Sprintf("%d", x) + case float64: + return fmt.Sprintf("%v", x) + } + return fmt.Sprintf("%v", v) +} + +func TestSpotCheck_CatalogVerification(t *testing.T) { + if testing.Short() { + t.Skip("spot-check requires container") + } + mc, cleanup := startContainer(t) + defer cleanup() + + // Reset state at the start of every sub-test. + reset := func(t *testing.T) { + t.Helper() + if _, err := mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS sc"); err != nil { + t.Fatalf("drop sc: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "CREATE DATABASE sc"); err != nil { + t.Fatalf("create sc: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "USE sc"); err != nil { + t.Fatalf("use sc: %v", err) + } + } + + // --------------------------------------------------------------- + // C1.1: FK name counter uses max(existing) + 1, not count + 1. + // --------------------------------------------------------------- + t.Run("C1_1_FK_counter_max_plus_one", func(t *testing.T) { + reset(t) + // Test both CREATE TABLE and ALTER TABLE flows. + // Phase A: CREATE TABLE mixing explicit high-numbered CONSTRAINT name and unnamed FK. + rowsA := spotCheckQuery(t, mc, ` + CREATE TABLE parent (id INT PRIMARY KEY); + CREATE TABLE child ( + a INT, + b INT, + CONSTRAINT child_ibfk_5 FOREIGN KEY (a) REFERENCES parent(id), + FOREIGN KEY (b) REFERENCES parent(id) + ); + `, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='child' AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME + `) + var namesA []string + for _, row := range rowsA { + namesA = append(namesA, asString(row[0])) + } + t.Logf("C1.1 phase A (CREATE TABLE) observed names: %v", namesA) + + // Phase B: ALTER TABLE adds an unnamed FK AFTER child_ibfk_5 already exists. + // Catalog's max+1 rule should yield child_ibfk_6. + rowsB := spotCheckQuery(t, mc, ` + CREATE TABLE parent2 (id INT PRIMARY KEY); + CREATE TABLE child2 (a INT, b INT); + ALTER TABLE child2 ADD CONSTRAINT child2_ibfk_5 FOREIGN KEY (a) REFERENCES parent2(id); + ALTER TABLE child2 ADD FOREIGN KEY (b) REFERENCES parent2(id); + `, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='child2' AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME + `) + var namesB []string + for _, row := range rowsB { + namesB = append(namesB, asString(row[0])) + } + has5 := false + has6 := false + for _, n := range namesB { + if n == "child2_ibfk_5" { + has5 = true + } + if n == "child2_ibfk_6" { + has6 = true + } + } + if !has5 || !has6 || len(namesB) != 2 { + t.Errorf("CATALOG MISMATCH C1.1 (phase B, ALTER TABLE max+1): expected [child2_ibfk_5, child2_ibfk_6], got %v", namesB) + } else { + t.Logf("OK C1.1 phase B verified (ALTER TABLE honors max+1): %v", namesB) + } + + // Report phase A as observation (may differ from phase B if CREATE uses count-based). + if len(namesA) == 2 { + hasA5, hasA6 := false, false + for _, n := range namesA { + if n == "child_ibfk_5" { + hasA5 = true + } + if n == "child_ibfk_6" { + hasA6 = true + } + } + if hasA5 && hasA6 { + t.Logf("OK C1.1 phase A: CREATE TABLE also follows max+1: %v", namesA) + } else { + t.Logf("NOTE C1.1 phase A: CREATE TABLE does NOT follow max+1 (catalog may need clarification): %v", namesA) + } + } + }) + + // --------------------------------------------------------------- + // C1.2: Partition default names are p0, p1, p2, ... + // --------------------------------------------------------------- + t.Run("C1_2_partition_naming", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE pt (id INT) PARTITION BY HASH(id) PARTITIONS 4; + `, ` + SELECT PARTITION_NAME FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='pt' + ORDER BY PARTITION_ORDINAL_POSITION + `) + var names []string + for _, row := range rows { + names = append(names, asString(row[0])) + } + want := []string{"p0", "p1", "p2", "p3"} + if len(names) != 4 { + t.Errorf("CATALOG MISMATCH C1.2: expected 4 partitions, got %d: %v", len(names), names) + return + } + for i, w := range want { + if names[i] != w { + t.Errorf("CATALOG MISMATCH C1.2: expected %v, got %v", want, names) + return + } + } + t.Logf("OK C1.2 verified: %v", names) + }) + + // --------------------------------------------------------------- + // C3.1: TIMESTAMP NOT NULL promotion applies only to FIRST timestamp. + // Catalog claim: ts1 gets DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + // ts2 does NOT. + // --------------------------------------------------------------- + t.Run("C3_1_timestamp_first_only_promotion", func(t *testing.T) { + reset(t) + // We need explicit_defaults_for_timestamp=OFF to see the promotion, + // and we must relax STRICT mode because a second TIMESTAMP NOT NULL + // with no default receives the zero-date default which STRICT rejects. + if _, err := mc.db.ExecContext(mc.ctx, "SET SESSION explicit_defaults_for_timestamp=OFF"); err != nil { + t.Fatalf("set sql var: %v", err) + } + defer mc.db.ExecContext(mc.ctx, "SET SESSION explicit_defaults_for_timestamp=ON") + if _, err := mc.db.ExecContext(mc.ctx, "SET SESSION sql_mode=''"); err != nil { + t.Fatalf("set sql_mode: %v", err) + } + defer mc.db.ExecContext(mc.ctx, "SET SESSION sql_mode=DEFAULT") + + rows := spotCheckQuery(t, mc, ` + CREATE TABLE ts ( + ts1 TIMESTAMP NOT NULL, + ts2 TIMESTAMP NOT NULL + ); + `, ` + SELECT COLUMN_NAME, COLUMN_DEFAULT, EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='ts' + ORDER BY ORDINAL_POSITION + `) + if len(rows) != 2 { + t.Fatalf("expected 2 cols, got %d", len(rows)) + } + ts1Default := asString(rows[0][1]) + ts1Extra := asString(rows[0][2]) + ts2Default := asString(rows[1][1]) + ts2Extra := asString(rows[1][2]) + + ts1Promoted := strings.Contains(strings.ToUpper(ts1Default), "CURRENT_TIMESTAMP") && + strings.Contains(strings.ToUpper(ts1Extra), "ON UPDATE") + ts2Promoted := strings.Contains(strings.ToUpper(ts2Default), "CURRENT_TIMESTAMP") + + if !ts1Promoted { + t.Errorf("CATALOG MISMATCH C3.1: expected ts1 promoted; got default=%q extra=%q", + ts1Default, ts1Extra) + } + if ts2Promoted { + t.Errorf("CATALOG MISMATCH C3.1: ts2 should NOT be promoted; got default=%q extra=%q", + ts2Default, ts2Extra) + } + if ts1Promoted && !ts2Promoted { + t.Logf("OK C3.1 verified: ts1 default=%q extra=%q; ts2 default=%q extra=%q", + ts1Default, ts1Extra, ts2Default, ts2Extra) + } + }) + + // --------------------------------------------------------------- + // C3.2: PRIMARY KEY column is implicitly NOT NULL. + // --------------------------------------------------------------- + t.Run("C3_2_primary_key_implies_not_null", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE pk (id INT, PRIMARY KEY(id)); + `, ` + SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='pk' AND COLUMN_NAME='id' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + nullable := asString(rows[0][0]) + if nullable != "NO" { + t.Errorf("CATALOG MISMATCH C3.2: expected IS_NULLABLE=NO, got %q", nullable) + } else { + t.Logf("OK C3.2 verified: IS_NULLABLE=%q", nullable) + } + }) + + // --------------------------------------------------------------- + // C4.1: Table inherits charset from database. + // --------------------------------------------------------------- + t.Run("C4_1_table_charset_from_database", func(t *testing.T) { + // Special: create DB with custom charset, not 'sc'. + if _, err := mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS sc_cs"); err != nil { + t.Fatalf("drop: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "CREATE DATABASE sc_cs CHARACTER SET latin1 COLLATE latin1_swedish_ci"); err != nil { + t.Fatalf("create db: %v", err) + } + defer mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS sc_cs") + if _, err := mc.db.ExecContext(mc.ctx, "CREATE TABLE sc_cs.t (c VARCHAR(10))"); err != nil { + t.Fatalf("create table: %v", err) + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT TABLE_COLLATION FROM information_schema.TABLES + WHERE TABLE_SCHEMA='sc_cs' AND TABLE_NAME='t' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + coll := asString(rows[0][0]) + if coll != "latin1_swedish_ci" { + t.Errorf("CATALOG MISMATCH C4.1: expected latin1_swedish_ci, got %q", coll) + } else { + t.Logf("OK C4.1 verified: TABLE_COLLATION=%q", coll) + } + }) + + // --------------------------------------------------------------- + // C5.1: FK ON DELETE default. Catalog claims parser default is + // FK_OPTION_RESTRICT. However, information_schema.REFERENTIAL_CONSTRAINTS + // famously reports 'NO ACTION' for both RESTRICT and unspecified, and + // SHOW CREATE TABLE elides the clause entirely. Verify both views. + // --------------------------------------------------------------- + t.Run("C5_1_fk_on_delete_default", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE p5 (id INT PRIMARY KEY); + CREATE TABLE c5 ( + a INT, + FOREIGN KEY (a) REFERENCES p5(id) + ); + `, ` + SELECT DELETE_RULE, UPDATE_RULE FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='sc' AND TABLE_NAME='c5' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + del := asString(rows[0][0]) + upd := asString(rows[0][1]) + // Real MySQL 8.0 reports "NO ACTION" in info_schema for unspecified. + if del != "NO ACTION" || upd != "NO ACTION" { + t.Errorf("CATALOG OBS C5.1: REFERENTIAL_CONSTRAINTS reports DELETE=%q UPDATE=%q (catalog says parser default is FK_OPTION_RESTRICT; confirm whether the catalog means semantic behavior vs reporting)", del, upd) + } else { + t.Logf("OK C5.1 observed: info_schema reports DELETE=%q UPDATE=%q (semantically equivalent to RESTRICT)", del, upd) + } + // Also verify SHOW CREATE TABLE elides ON DELETE clause (standard behavior). + stmt, err := mc.showCreateTable("c5") + if err != nil { + t.Fatalf("show create: %v", err) + } + if strings.Contains(strings.ToUpper(stmt), "ON DELETE") { + t.Errorf("unexpected ON DELETE clause in SHOW CREATE TABLE: %s", stmt) + } else { + t.Logf("OK C5.1 SHOW CREATE elides ON DELETE clause: %s", stmt) + } + }) + + // --------------------------------------------------------------- + // C10.2: View SQL SECURITY defaults to DEFINER. + // --------------------------------------------------------------- + t.Run("C10_2_view_security_definer", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE base (id INT); + CREATE VIEW v AS SELECT * FROM base; + `, ` + SELECT SECURITY_TYPE FROM information_schema.VIEWS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='v' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + sec := asString(rows[0][0]) + if sec != "DEFINER" { + t.Errorf("CATALOG MISMATCH C10.2: expected SECURITY_TYPE=DEFINER, got %q", sec) + } else { + t.Logf("OK C10.2 verified: SECURITY_TYPE=%q", sec) + } + }) + + // --------------------------------------------------------------- + // C16.1: NOW()/CURRENT_TIMESTAMP precision defaults to 0. + // Test via a column with DEFAULT NOW() -- the column's DATETIME_PRECISION + // is driven by the column type, so instead check the generated column case: + // if we use `DATETIME` without precision, DATETIME_PRECISION = 0. + // More directly observable: use LENGTH(NOW()) in a scalar SELECT. + // --------------------------------------------------------------- + t.Run("C16_1_now_precision_default_zero", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ``, `SELECT LENGTH(NOW()), LENGTH(NOW(6)), LENGTH(CURRENT_TIMESTAMP)`) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + // "YYYY-MM-DD HH:MM:SS" = 19, with (6) fractional = 19+7 = 26. + l0 := asString(rows[0][0]) + l6 := asString(rows[0][1]) + lCT := asString(rows[0][2]) + if l0 != "19" { + t.Errorf("CATALOG MISMATCH C16.1: LENGTH(NOW())=%q, expected 19 (no fractional seconds)", l0) + } + if l6 != "26" { + t.Errorf("CATALOG MISMATCH C16.1: LENGTH(NOW(6))=%q, expected 26", l6) + } + if lCT != "19" { + t.Errorf("CATALOG MISMATCH C16.1: LENGTH(CURRENT_TIMESTAMP)=%q, expected 19", lCT) + } + if l0 == "19" && l6 == "26" && lCT == "19" { + t.Logf("OK C16.1 verified: NOW()=%s CURRENT_TIMESTAMP=%s NOW(6)=%s", l0, lCT, l6) + } + }) + + // --------------------------------------------------------------- + // C18.4: AUTO_INCREMENT clause elided in SHOW CREATE TABLE if counter <= 1. + // --------------------------------------------------------------- + t.Run("C18_4_auto_increment_elision", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE ai (id INT AUTO_INCREMENT PRIMARY KEY, v INT)`); err != nil { + t.Fatalf("create: %v", err) + } + stmt1, err := mc.showCreateTable("ai") + if err != nil { + t.Fatalf("show create: %v", err) + } + // Fresh table: no AUTO_INCREMENT=N clause. + if strings.Contains(strings.ToUpper(stmt1), "AUTO_INCREMENT=") { + t.Errorf("CATALOG MISMATCH C18.4 (before insert): expected no AUTO_INCREMENT= clause, got: %s", stmt1) + } else { + t.Logf("OK C18.4 verified (before insert): %s", stmt1) + } + // After inserting, counter advances, clause should appear. + if _, err := mc.db.ExecContext(mc.ctx, `INSERT INTO ai (v) VALUES (10),(20),(30)`); err != nil { + t.Fatalf("insert: %v", err) + } + stmt2, err := mc.showCreateTable("ai") + if err != nil { + t.Fatalf("show create: %v", err) + } + if !strings.Contains(strings.ToUpper(stmt2), "AUTO_INCREMENT=") { + t.Errorf("CATALOG MISMATCH C18.4 (after insert): expected AUTO_INCREMENT= clause, got: %s", stmt2) + } else { + t.Logf("OK C18.4 verified (after insert, counter > 1): %s", stmt2) + } + }) + + // =================================================================== + // Round 2 extended spot-check (PS1-PS7 path-splits + Round 1/2 gaps) + // =================================================================== + + // --------------------------------------------------------------- + // PS1 (CREATE): CHECK constraint counter — CREATE uses FRESH counter. + // Catalog claim: unnamed CHECK constraints in CREATE are numbered + // 1, 2, 3, ... starting from 0 regardless of user-named CCs. + // --------------------------------------------------------------- + t.Run("PS1_CheckCounter_CREATE_fresh", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tchk ( + a INT, + CONSTRAINT tchk_chk_5 CHECK (a > 0), + b INT, + CHECK (b < 100) + ); + `, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tchk' AND CONSTRAINT_TYPE='CHECK' + ORDER BY CONSTRAINT_NAME + `) + var names []string + for _, row := range rows { + names = append(names, asString(row[0])) + } + t.Logf("PS1 CREATE observed: %v", names) + has1, has5 := false, false + for _, n := range names { + if n == "tchk_chk_1" { + has1 = true + } + if n == "tchk_chk_5" { + has5 = true + } + } + if !has1 || !has5 || len(names) != 2 { + t.Errorf("CATALOG MISMATCH PS1 CREATE: expected [tchk_chk_1, tchk_chk_5], got %v", names) + } else { + t.Logf("OK PS1 CREATE verified (fresh counter from 0): %v", names) + } + }) + + // --------------------------------------------------------------- + // PS1 (ALTER): Does ALTER use fresh counter (like CREATE) or + // max+1 (like FK ALTER)? This is the open question. + // --------------------------------------------------------------- + t.Run("PS1_CheckCounter_ALTER_open", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tchk2 ( + a INT, + b INT, + CONSTRAINT tchk2_chk_20 CHECK (a > 0) + ); + ALTER TABLE tchk2 ADD CHECK (b > 0); + `, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tchk2' AND CONSTRAINT_TYPE='CHECK' + ORDER BY CONSTRAINT_NAME + `) + var names []string + for _, row := range rows { + names = append(names, asString(row[0])) + } + t.Logf("PS1 ALTER observed: %v", names) + hasFresh := false + hasMaxPlus1 := false + for _, n := range names { + if n == "tchk2_chk_1" { + hasFresh = true + } + if n == "tchk2_chk_21" { + hasMaxPlus1 = true + } + } + switch { + case hasFresh: + t.Logf("PS1 ALTER FINDING: fresh counter (tchk2_chk_1). Catalog should document ALTER=fresh.") + case hasMaxPlus1: + t.Logf("PS1 ALTER FINDING: max+1 counter (tchk2_chk_21). Catalog should document ALTER=max+1 (like FK).") + default: + t.Logf("PS1 ALTER FINDING: unexpected name(s) %v", names) + } + }) + + // --------------------------------------------------------------- + // PS5: DATETIME(6) DEFAULT NOW() (fsp=0) — catalog says ER_INVALID_DEFAULT. + // --------------------------------------------------------------- + t.Run("PS5_DatetimeFspMismatch", func(t *testing.T) { + reset(t) + _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE tps5 (ts DATETIME(6) DEFAULT NOW())`) + if err == nil { + // Accepted — catalog MISMATCH. Read back what was stored. + rows := spotCheckQuery(t, mc, ``, ` + SELECT COLUMN_NAME, COLUMN_DEFAULT, DATETIME_PRECISION + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tps5' + `) + t.Errorf("CATALOG MISMATCH PS5: MySQL accepted DATETIME(6) DEFAULT NOW() (catalog says ER_INVALID_DEFAULT). COLUMNS=%v", rows) + return + } + msg := err.Error() + t.Logf("PS5 error observed: %v", msg) + if !strings.Contains(strings.ToLower(msg), "invalid default") && !strings.Contains(msg, "1067") { + t.Errorf("PS5 UNEXPECTED ERROR TEXT: expected ER_INVALID_DEFAULT (1067), got: %v", msg) + } else { + t.Logf("OK PS5 verified: MySQL rejects DATETIME(6) DEFAULT NOW() with ER_INVALID_DEFAULT") + } + }) + + // --------------------------------------------------------------- + // PS7: FK name collision — first unnamed FK wants t_ibfk_1, collides + // with user-named t_ibfk_1 → ER_FK_DUP_NAME. + // --------------------------------------------------------------- + t.Run("PS7_FKNameCollision", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE p7 (id INT PRIMARY KEY)`); err != nil { + t.Fatalf("setup: %v", err) + } + _, err := mc.db.ExecContext(mc.ctx, ` + CREATE TABLE tps7 ( + a INT, + CONSTRAINT tps7_ibfk_1 FOREIGN KEY (a) REFERENCES p7(id), + b INT, + FOREIGN KEY (b) REFERENCES p7(id) + ) + `) + if err == nil { + rows := spotCheckQuery(t, mc, ``, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tps7' AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME + `) + var names []string + for _, row := range rows { + names = append(names, asString(row[0])) + } + t.Errorf("CATALOG MISMATCH PS7: expected ER_FK_DUP_NAME, got success with FKs %v", names) + return + } + msg := err.Error() + t.Logf("PS7 error observed: %v", msg) + // ER_FK_DUP_NAME = 1826 + if !strings.Contains(msg, "1826") && !strings.Contains(strings.ToLower(msg), "duplicate") { + t.Errorf("PS7 UNEXPECTED ERROR: expected ER_FK_DUP_NAME (1826), got: %v", msg) + } else { + t.Logf("OK PS7 verified: collision rejected") + } + }) + + // --------------------------------------------------------------- + // C1.3: Check constraint name format = {table}_chk_N + // (Also verified as part of PS1 CREATE above; explicit here.) + // --------------------------------------------------------------- + t.Run("C1_3_CheckConstraintName", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE cc (a INT, CHECK (a > 0), b INT, CHECK (b < 100)); + `, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='cc' AND CONSTRAINT_TYPE='CHECK' + ORDER BY CONSTRAINT_NAME + `) + var names []string + for _, row := range rows { + names = append(names, asString(row[0])) + } + want := []string{"cc_chk_1", "cc_chk_2"} + if len(names) != 2 || names[0] != want[0] || names[1] != want[1] { + t.Errorf("CATALOG MISMATCH C1.3: expected %v, got %v", want, names) + } else { + t.Logf("OK C1.3 verified: %v", names) + } + }) + + // --------------------------------------------------------------- + // C2.1: REAL → DOUBLE. + // --------------------------------------------------------------- + t.Run("C2_1_REAL_to_DOUBLE", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE t2 (x REAL); + `, ` + SELECT DATA_TYPE, COLUMN_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='t2' AND COLUMN_NAME='x' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + dt := strings.ToLower(asString(rows[0][0])) + ct := strings.ToLower(asString(rows[0][1])) + if dt != "double" { + t.Errorf("CATALOG MISMATCH C2.1: expected DATA_TYPE=double, got %q (COLUMN_TYPE=%q)", dt, ct) + } else { + t.Logf("OK C2.1 verified: DATA_TYPE=%q COLUMN_TYPE=%q", dt, ct) + } + }) + + // --------------------------------------------------------------- + // C2.2: BOOL → TINYINT(1). + // --------------------------------------------------------------- + t.Run("C2_2_BOOL_to_TINYINT1", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tbool (flag BOOL); + `, ` + SELECT DATA_TYPE, COLUMN_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tbool' AND COLUMN_NAME='flag' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + dt := strings.ToLower(asString(rows[0][0])) + ct := strings.ToLower(asString(rows[0][1])) + if dt != "tinyint" || ct != "tinyint(1)" { + t.Errorf("CATALOG MISMATCH C2.2: expected tinyint / tinyint(1), got %q / %q", dt, ct) + } else { + t.Logf("OK C2.2 verified: %q / %q", dt, ct) + } + }) + + // --------------------------------------------------------------- + // C3.3: AUTO_INCREMENT implies NOT NULL. + // --------------------------------------------------------------- + t.Run("C3_3_AutoIncrement_implies_NOT_NULL", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tai (id INT AUTO_INCREMENT PRIMARY KEY); + `, ` + SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tai' AND COLUMN_NAME='id' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + if asString(rows[0][0]) != "NO" { + t.Errorf("CATALOG MISMATCH C3.3: expected IS_NULLABLE=NO, got %q", asString(rows[0][0])) + } else { + t.Logf("OK C3.3 verified") + } + }) + + // --------------------------------------------------------------- + // C4.2 + C18.1 + C18.5: DB utf8mb4 + table latin1 override; + // per-column charset inheritance from table charset; + // SHOW CREATE elides the per-column charset when matching table. + // --------------------------------------------------------------- + t.Run("C4_2_and_C18_1_and_C18_5_charset_inheritance_and_elision", func(t *testing.T) { + if _, err := mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS sc_cs2"); err != nil { + t.Fatalf("drop: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "CREATE DATABASE sc_cs2 CHARACTER SET utf8mb4"); err != nil { + t.Fatalf("create db: %v", err) + } + defer mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS sc_cs2") + if _, err := mc.db.ExecContext(mc.ctx, "CREATE TABLE sc_cs2.t (c VARCHAR(10)) CHARSET latin1"); err != nil { + t.Fatalf("create table: %v", err) + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc_cs2' AND TABLE_NAME='t' AND COLUMN_NAME='c' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + cs := asString(rows[0][0]) + if cs != "latin1" { + t.Errorf("CATALOG MISMATCH C4.2: expected column charset=latin1 (inherited from table), got %q", cs) + } else { + t.Logf("OK C4.2 verified: column charset=%s", cs) + } + // SHOW CREATE TABLE should NOT contain per-column CHARACTER SET, + // but SHOULD contain DEFAULT CHARSET=latin1 (explicitly specified). + var scStmt string + row := mc.db.QueryRowContext(mc.ctx, "SHOW CREATE TABLE sc_cs2.t") + var tbl string + if err := row.Scan(&tbl, &scStmt); err != nil { + t.Fatalf("show create: %v", err) + } + up := strings.ToUpper(scStmt) + // C18.1: column-level CHARACTER SET should be elided + // We look specifically for "CHARACTER SET" after the column name "c". + colLineIdx := strings.Index(scStmt, "`c` ") + if colLineIdx < 0 { + t.Logf("C18.1 NOTE: could not find `c` column line in SHOW CREATE") + } else { + rest := scStmt[colLineIdx:] + if nl := strings.Index(rest, "\n"); nl >= 0 { + rest = rest[:nl] + } + if strings.Contains(strings.ToUpper(rest), "CHARACTER SET") { + t.Errorf("CATALOG MISMATCH C18.1: expected per-column CHARACTER SET elided; column line: %q", rest) + } else { + t.Logf("OK C18.1 verified: per-column CHARACTER SET elided (column line: %q)", rest) + } + } + // C18.5: DEFAULT CHARSET=latin1 SHOULD be present (user explicitly specified). + if !strings.Contains(up, "DEFAULT CHARSET=LATIN1") && !strings.Contains(up, "CHARSET=LATIN1") { + t.Errorf("CATALOG MISMATCH C18.5 (explicit): expected DEFAULT CHARSET=latin1 to be shown, got: %s", scStmt) + } else { + t.Logf("OK C18.5 (explicit) verified: DEFAULT CHARSET=latin1 shown") + } + }) + + // --------------------------------------------------------------- + // C18.5 (implicit): CREATE TABLE without charset → SHOW CREATE + // may still include DEFAULT CHARSET clause inherited from DB. + // Real MySQL: it DOES show DEFAULT CHARSET even when inherited. + // --------------------------------------------------------------- + t.Run("C18_5_DefaultCharset_implicit", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE tnocs (x INT)`); err != nil { + t.Fatalf("create: %v", err) + } + stmt, err := mc.showCreateTable("tnocs") + if err != nil { + t.Fatalf("show create: %v", err) + } + up := strings.ToUpper(stmt) + has := strings.Contains(up, "DEFAULT CHARSET=") || strings.Contains(up, "CHARSET=") + t.Logf("C18.5 (implicit) observation: DEFAULT CHARSET present=%v; stmt=%s", has, stmt) + }) + + // --------------------------------------------------------------- + // C5.3: FK MATCH default. + // --------------------------------------------------------------- + t.Run("C5_3_FK_MATCH_default", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE pm (id INT PRIMARY KEY); + CREATE TABLE cm (a INT, FOREIGN KEY (a) REFERENCES pm(id)); + `, ` + SELECT MATCH_OPTION FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='sc' AND TABLE_NAME='cm' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + t.Logf("OK C5.3 observed: MATCH_OPTION=%q", asString(rows[0][0])) + // Catalog says FK_MATCH_SIMPLE; info_schema typically reports "NONE" for InnoDB. + }) + + // --------------------------------------------------------------- + // C6.1: PARTITION BY HASH without PARTITIONS defaults to 1. + // --------------------------------------------------------------- + t.Run("C6_1_Partition_default_count", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE phd (id INT) PARTITION BY HASH(id); + `, ` + SELECT COUNT(*) FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='phd' AND PARTITION_NAME IS NOT NULL + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + n := asString(rows[0][0]) + if n != "1" { + t.Errorf("CATALOG MISMATCH C6.1: expected 1 partition, got %s", n) + } else { + t.Logf("OK C6.1 verified: partitions=1") + } + }) + + // --------------------------------------------------------------- + // C6.2: Subpartitions auto-gen. PARTITIONS 2 SUBPARTITIONS 3 → 6 rows. + // --------------------------------------------------------------- + t.Run("C6_2_Subpartition_count", func(t *testing.T) { + reset(t) + _, err := mc.db.ExecContext(mc.ctx, ` + CREATE TABLE psp (id INT, d INT) + PARTITION BY RANGE(id) + SUBPARTITION BY HASH(d) SUBPARTITIONS 3 ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20) + ) + `) + if err != nil { + t.Logf("C6.2 NOTE: subpartition DDL errored: %v (skipping verification)", err) + return + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT PARTITION_NAME, SUBPARTITION_NAME FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='psp' + ORDER BY PARTITION_ORDINAL_POSITION, SUBPARTITION_ORDINAL_POSITION + `) + if len(rows) != 6 { + t.Errorf("CATALOG MISMATCH C6.2: expected 6 sub-part rows, got %d: %v", len(rows), rows) + } else { + var subs []string + for _, r := range rows { + subs = append(subs, asString(r[1])) + } + t.Logf("OK C6.2 verified: 6 subparts, names=%v", subs) + } + }) + + // --------------------------------------------------------------- + // C7.1: Default index algorithm = BTREE. + // --------------------------------------------------------------- + t.Run("C7_1_Default_index_algorithm_BTREE", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tidx (a INT, KEY(a)); + `, ` + SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tidx' AND INDEX_NAME='a' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + if asString(rows[0][0]) != "BTREE" { + t.Errorf("CATALOG MISMATCH C7.1: expected BTREE, got %q", asString(rows[0][0])) + } else { + t.Logf("OK C7.1 verified: INDEX_TYPE=BTREE") + } + }) + + // --------------------------------------------------------------- + // C7.2: FK creates implicit backing index on child FK columns. + // --------------------------------------------------------------- + t.Run("C7_2_FK_backing_index", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE pfk (id INT PRIMARY KEY); + CREATE TABLE cfk (a INT, FOREIGN KEY (a) REFERENCES pfk(id)); + `, ` + SELECT INDEX_NAME, COLUMN_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='cfk' + ORDER BY INDEX_NAME, SEQ_IN_INDEX + `) + if len(rows) == 0 { + t.Errorf("CATALOG MISMATCH C7.2: expected at least 1 backing index, got 0") + } else { + t.Logf("OK C7.2 verified: %d index row(s) on cfk: %v", len(rows), rows) + } + }) + + // --------------------------------------------------------------- + // C8.1: Default engine = InnoDB. + // --------------------------------------------------------------- + t.Run("C8_1_Default_engine_InnoDB", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, `CREATE TABLE teng (x INT);`, ` + SELECT ENGINE FROM information_schema.TABLES + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='teng' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + if asString(rows[0][0]) != "InnoDB" { + t.Errorf("CATALOG MISMATCH C8.1: expected InnoDB, got %q", asString(rows[0][0])) + } else { + t.Logf("OK C8.1 verified: ENGINE=InnoDB") + } + }) + + // --------------------------------------------------------------- + // C8.2: Default ROW_FORMAT. + // --------------------------------------------------------------- + t.Run("C8_2_Default_row_format", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, `CREATE TABLE trf (x INT);`, ` + SELECT ROW_FORMAT FROM information_schema.TABLES + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='trf' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + rf := asString(rows[0][0]) + if rf != "Dynamic" && rf != "Compact" { + t.Errorf("CATALOG MISMATCH C8.2: expected Dynamic or Compact, got %q", rf) + } else { + t.Logf("OK C8.2 verified: ROW_FORMAT=%s", rf) + } + }) + + // --------------------------------------------------------------- + // C9.1: Generated column defaults to VIRTUAL. + // --------------------------------------------------------------- + t.Run("C9_1_GeneratedColumn_default_VIRTUAL", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tgen (a INT, b INT AS (a + 1)); + `, ` + SELECT EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tgen' AND COLUMN_NAME='b' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + extra := strings.ToUpper(asString(rows[0][0])) + if !strings.Contains(extra, "VIRTUAL") { + t.Errorf("CATALOG MISMATCH C9.1: expected VIRTUAL GENERATED, got %q", extra) + } else { + t.Logf("OK C9.1 verified: EXTRA=%s", extra) + } + }) + + // --------------------------------------------------------------- + // C10.1 + C10.3 + C10.4: view algorithm, definer, check option. + // --------------------------------------------------------------- + t.Run("C10_1_3_4_View_defaults", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE base10 (id INT)`); err != nil { + t.Fatalf("setup: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, `CREATE VIEW v10 AS SELECT * FROM base10`); err != nil { + t.Fatalf("create view: %v", err) + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT VIEW_DEFINITION, CHECK_OPTION, DEFINER FROM information_schema.VIEWS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='v10' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + checkOpt := asString(rows[0][1]) + definer := asString(rows[0][2]) + if checkOpt != "NONE" { + t.Errorf("CATALOG MISMATCH C10.4: expected CHECK_OPTION=NONE, got %q", checkOpt) + } else { + t.Logf("OK C10.4 verified: CHECK_OPTION=%s", checkOpt) + } + if definer == "" || definer == "" { + t.Errorf("CATALOG MISMATCH C10.3: expected DEFINER to be populated, got %q", definer) + } else { + t.Logf("OK C10.3 verified: DEFINER=%s", definer) + } + // C10.1: SHOW CREATE VIEW for ALGORITHM + stmt, err := mc.showCreateView("v10") + if err != nil { + t.Fatalf("show create view: %v", err) + } + up := strings.ToUpper(stmt) + if !strings.Contains(up, "ALGORITHM=UNDEFINED") { + t.Errorf("CATALOG MISMATCH C10.1: expected ALGORITHM=UNDEFINED in SHOW CREATE VIEW, got: %s", stmt) + } else { + t.Logf("OK C10.1 verified: ALGORITHM=UNDEFINED") + } + }) + + // --------------------------------------------------------------- + // C11.1: Trigger DEFINER defaults to current user. + // --------------------------------------------------------------- + t.Run("C11_1_Trigger_definer_default", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE tt11 (a INT)`); err != nil { + t.Fatalf("setup: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, + `CREATE TRIGGER trg11 BEFORE INSERT ON tt11 FOR EACH ROW SET NEW.a = NEW.a`); err != nil { + t.Fatalf("create trigger: %v", err) + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT DEFINER FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='sc' AND TRIGGER_NAME='trg11' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + def := asString(rows[0][0]) + if def == "" || def == "" { + t.Errorf("CATALOG MISMATCH C11.1: expected DEFINER populated, got %q", def) + } else { + t.Logf("OK C11.1 verified: DEFINER=%s", def) + } + }) + + // --------------------------------------------------------------- + // C14.1: CHECK CONSTRAINT ENFORCED by default. + // --------------------------------------------------------------- + t.Run("C14_1_Check_enforced_default", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE tchk14 (a INT, CONSTRAINT chk14 CHECK (a > 0)); + `, ` + SELECT ENFORCED FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tchk14' AND CONSTRAINT_NAME='chk14' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + if asString(rows[0][0]) != "YES" { + t.Errorf("CATALOG MISMATCH C14.1: expected ENFORCED=YES, got %q", asString(rows[0][0])) + } else { + t.Logf("OK C14.1 verified: ENFORCED=YES") + } + }) + + // --------------------------------------------------------------- + // C15.1: New column added via ALTER lands at end. + // --------------------------------------------------------------- + t.Run("C15_1_Column_positioning_end", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE tpos (a INT, b INT)`); err != nil { + t.Fatalf("create: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, `ALTER TABLE tpos ADD COLUMN c INT`); err != nil { + t.Fatalf("alter: %v", err) + } + rows := spotCheckQuery(t, mc, ``, ` + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='tpos' + ORDER BY ORDINAL_POSITION + `) + var names []string + for _, r := range rows { + names = append(names, asString(r[0])) + } + want := []string{"a", "b", "c"} + if len(names) != 3 || names[0] != want[0] || names[1] != want[1] || names[2] != want[2] { + t.Errorf("CATALOG MISMATCH C15.1: expected %v, got %v", want, names) + } else { + t.Logf("OK C15.1 verified: %v", names) + } + }) + + // --------------------------------------------------------------- + // C18.2: NOT NULL rendering in SHOW CREATE TABLE. + // --------------------------------------------------------------- + t.Run("C18_2_NotNull_rendering", func(t *testing.T) { + reset(t) + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE tnn (x INT, y INT NOT NULL)`); err != nil { + t.Fatalf("create: %v", err) + } + stmt, err := mc.showCreateTable("tnn") + if err != nil { + t.Fatalf("show create: %v", err) + } + up := strings.ToUpper(stmt) + // `x` line should NOT contain "NOT NULL"; `y` line SHOULD. + xIdx := strings.Index(stmt, "`x`") + yIdx := strings.Index(stmt, "`y`") + xLine := "" + yLine := "" + if xIdx >= 0 { + e := strings.Index(stmt[xIdx:], "\n") + if e < 0 { + e = len(stmt) - xIdx + } + xLine = stmt[xIdx : xIdx+e] + } + if yIdx >= 0 { + e := strings.Index(stmt[yIdx:], "\n") + if e < 0 { + e = len(stmt) - yIdx + } + yLine = stmt[yIdx : yIdx+e] + } + _ = up + if strings.Contains(strings.ToUpper(xLine), "NOT NULL") { + t.Errorf("CATALOG MISMATCH C18.2: expected `x` line to elide NOT NULL, got %q", xLine) + } else { + t.Logf("OK C18.2 (nullable elides): %q", xLine) + } + if !strings.Contains(strings.ToUpper(yLine), "NOT NULL") { + t.Errorf("CATALOG MISMATCH C18.2: expected `y` line to contain NOT NULL, got %q", yLine) + } else { + t.Logf("OK C18.2 (NOT NULL shown): %q", yLine) + } + }) + + // --------------------------------------------------------------- + // C21.1: DEFAULT NULL on nullable column. + // --------------------------------------------------------------- + t.Run("C21_1_Default_NULL", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE td (x INT DEFAULT NULL); + `, ` + SELECT COLUMN_DEFAULT, IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='td' AND COLUMN_NAME='x' + `) + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + def := asString(rows[0][0]) + nullable := asString(rows[0][1]) + if def != "" && def != "NULL" { + t.Errorf("CATALOG MISMATCH C21.1: expected COLUMN_DEFAULT=NULL, got %q", def) + } + if nullable != "YES" { + t.Errorf("CATALOG MISMATCH C21.1: expected IS_NULLABLE=YES, got %q", nullable) + } + t.Logf("OK C21.1 verified: default=%q nullable=%q", def, nullable) + }) + + // --------------------------------------------------------------- + // C25.1 — original DECIMAL test. Keep below. + // --------------------------------------------------------------- + t.Run("C25_1_decimal_default_10_0", func(t *testing.T) { + reset(t) + rows := spotCheckQuery(t, mc, ` + CREATE TABLE d (x DECIMAL, y DECIMAL(8), z NUMERIC); + `, ` + SELECT COLUMN_NAME, COLUMN_TYPE, NUMERIC_PRECISION, NUMERIC_SCALE + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='sc' AND TABLE_NAME='d' + ORDER BY ORDINAL_POSITION + `) + if len(rows) != 3 { + t.Fatalf("expected 3 cols, got %d", len(rows)) + } + // x DECIMAL -> decimal(10,0) + xType := asString(rows[0][1]) + xPrec := asString(rows[0][2]) + xScale := asString(rows[0][3]) + if xType != "decimal(10,0)" || xPrec != "10" || xScale != "0" { + t.Errorf("CATALOG MISMATCH C25.1 DECIMAL: type=%q prec=%q scale=%q (expected decimal(10,0), 10, 0)", + xType, xPrec, xScale) + } else { + t.Logf("OK C25.1 DECIMAL verified: %s prec=%s scale=%s", xType, xPrec, xScale) + } + // y DECIMAL(8) -> decimal(8,0) + yType := asString(rows[1][1]) + if yType != "decimal(8,0)" { + t.Errorf("CATALOG MISMATCH C25.1 DECIMAL(8): type=%q (expected decimal(8,0))", yType) + } else { + t.Logf("OK C25.1 DECIMAL(8) verified: %s", yType) + } + // z NUMERIC -> decimal(10,0) (NUMERIC is alias for DECIMAL) + zType := asString(rows[2][1]) + if zType != "decimal(10,0)" { + t.Errorf("CATALOG MISMATCH C25.1 NUMERIC: type=%q (expected decimal(10,0))", zType) + } else { + t.Logf("OK C25.1 NUMERIC verified: %s", zType) + } + }) +} diff --git a/tidb/catalog/constraint.go b/tidb/catalog/constraint.go new file mode 100644 index 00000000..aac871da --- /dev/null +++ b/tidb/catalog/constraint.go @@ -0,0 +1,27 @@ +package catalog + +type ConstraintType int + +const ( + ConPrimaryKey ConstraintType = iota + ConUniqueKey + ConForeignKey + ConCheck +) + +type Constraint struct { + Name string + Type ConstraintType + Table *Table + Columns []string + IndexName string + RefDatabase string + RefTable string + RefColumns []string + OnDelete string + OnUpdate string + MatchType string + CheckExpr string + NotEnforced bool + CheckAnalyzed AnalyzedExpr // Phase 3: analyzed CHECK expression body +} diff --git a/tidb/catalog/container_comprehensive_test.go b/tidb/catalog/container_comprehensive_test.go new file mode 100644 index 00000000..b30ccf8d --- /dev/null +++ b/tidb/catalog/container_comprehensive_test.go @@ -0,0 +1,100 @@ +package catalog + +import "testing" + +func TestDDLWorkflow_Container(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + steps := []struct { + name string + sql string + check string // table to SHOW CREATE TABLE after this step + }{ + {"create_basic", "CREATE TABLE users (id INT NOT NULL AUTO_INCREMENT, name VARCHAR(100) NOT NULL, email VARCHAR(255), PRIMARY KEY (id), UNIQUE KEY idx_email (email)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", "users"}, + {"add_column", "ALTER TABLE users ADD COLUMN age INT DEFAULT 0", "users"}, + {"add_index", "CREATE INDEX idx_name ON users (name)", "users"}, + {"modify_column", "ALTER TABLE users MODIFY COLUMN name VARCHAR(200) NOT NULL", "users"}, + {"drop_index", "DROP INDEX idx_name ON users", "users"}, + {"create_orders", "CREATE TABLE orders (id INT NOT NULL AUTO_INCREMENT, user_id INT NOT NULL, amount DECIMAL(10,2), created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (id), KEY idx_user (user_id)) ENGINE=InnoDB", "orders"}, + {"rename_column", "ALTER TABLE users CHANGE COLUMN email email_address VARCHAR(255)", "users"}, + {"drop_column", "ALTER TABLE users DROP COLUMN age", "users"}, + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + + for _, step := range steps { + t.Run(step.name, func(t *testing.T) { + if err := ctr.execSQL(step.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + results, err := c.Exec(step.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + + if step.check == "" { + return + } + + ctrDDL, err := ctr.showCreateTable(step.check) + if err != nil { + t.Fatalf("container show create: %v", err) + } + omniDDL := c.ShowCreateTable("test", step.check) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("SHOW CREATE TABLE mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestShowCreateTable_ContainerComparison(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"basic_types", "CREATE TABLE t_types (a INT, b VARCHAR(100), c TEXT, d DECIMAL(10,2), e DATETIME)", "t_types"}, + {"not_null_default", "CREATE TABLE t_defaults (id INT NOT NULL, name VARCHAR(50) DEFAULT 'test', active TINYINT(1) DEFAULT 1)", "t_defaults"}, + {"auto_increment_pk", "CREATE TABLE t_auto (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id))", "t_auto"}, + {"multi_col_pk", "CREATE TABLE t_multi_pk (a INT NOT NULL, b INT NOT NULL, c VARCHAR(10), PRIMARY KEY (a, b))", "t_multi_pk"}, + {"unique_index", "CREATE TABLE t_unique (id INT, email VARCHAR(255), UNIQUE KEY idx_email (email))", "t_unique"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(tc.sql, nil) + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL, omniDDL) + } + }) + } +} diff --git a/tidb/catalog/container_reserved_kw_test.go b/tidb/catalog/container_reserved_kw_test.go new file mode 100644 index 00000000..67b2fa19 --- /dev/null +++ b/tidb/catalog/container_reserved_kw_test.go @@ -0,0 +1,159 @@ +package catalog + +import ( + "fmt" + "strings" + "testing" + + mysqlparser "github.com/bytebase/omni/tidb/parser" +) + +// TestContainer_ReservedKeywordAcceptance systematically tests whether omni +// and MySQL 8.0 agree on which reserved keywords are accepted in various +// syntactic "name" positions. A mismatch (MySQL accepts, omni rejects) +// reveals a parser bug where isIdentToken/parseIdent is too restrictive. +// +// This is a diagnostic test — it reports all mismatches rather than failing +// on the first one, so we get a complete gap picture. +func TestContainer_ReservedKeywordAcceptance(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // All reserved keywords from mysql/parser/name.go. + // We extract them programmatically from the parser's keyword table. + reservedKeywords := getReservedKeywords() + if len(reservedKeywords) == 0 { + t.Fatal("no reserved keywords found") + } + t.Logf("testing %d reserved keywords", len(reservedKeywords)) + + // Each template defines a syntactic position where an identifier might + // be a reserved keyword. The %s placeholder is where the keyword goes. + // We use backtick-quoting in setup SQL to avoid interfering. + type namePosition struct { + name string + setup string // SQL to run before the test (on both container and omni) + template string // SQL with %s for the keyword being tested + cleanup string // SQL to run after each keyword attempt + } + + positions := []namePosition{ + { + name: "CHARACTER SET value", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test", + template: "CREATE TABLE kw_cs_test (a VARCHAR(50) CHARACTER SET %s)", + cleanup: "DROP TABLE IF EXISTS kw_cs_test", + }, + { + name: "COLLATE value", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test", + template: "CREATE TABLE kw_co_test (a VARCHAR(50) COLLATE %s)", + cleanup: "DROP TABLE IF EXISTS kw_co_test", + }, + { + name: "ENGINE value", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test", + template: "CREATE TABLE kw_eng_test (id INT) ENGINE=%s", + cleanup: "DROP TABLE IF EXISTS kw_eng_test", + }, + { + name: "INDEX name", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test; CREATE TABLE kw_idx_base (id INT, val INT)", + template: "CREATE INDEX %s ON kw_idx_base (val)", + cleanup: "DROP INDEX %s ON kw_idx_base", + }, + { + name: "CONSTRAINT name in CREATE TABLE", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test", + template: "CREATE TABLE kw_con_test (id INT, CONSTRAINT %s UNIQUE (id))", + cleanup: "DROP TABLE IF EXISTS kw_con_test", + }, + { + name: "Column alias in SELECT", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test", + template: "SELECT 1 AS %s", + cleanup: "", + }, + { + name: "Table alias in SELECT", + setup: "CREATE DATABASE IF NOT EXISTS kw_test; USE kw_test; CREATE TABLE kw_alias_base (id INT)", + template: "SELECT * FROM kw_alias_base %s", + cleanup: "", + }, + } + + for _, pos := range positions { + t.Run(pos.name, func(t *testing.T) { + // Setup + if pos.setup != "" { + if err := ctr.execSQL(pos.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + } + + var mismatches []string + + for _, kw := range reservedKeywords { + sql := fmt.Sprintf(pos.template, kw) + + // Try on MySQL 8.0 + ctrErr := ctr.execSQL(sql) + + // Try on omni (parse-only — we just check if it parses) + _, omniErr := mysqlparser.Parse(sql) + + // Cleanup on container + if pos.cleanup != "" { + cleanSQL := fmt.Sprintf(pos.cleanup, kw) + ctr.execSQL(cleanSQL) //nolint:errcheck + } + + ctrOK := ctrErr == nil + omniOK := omniErr == nil + + if ctrOK && !omniOK { + mismatches = append(mismatches, fmt.Sprintf( + " %s: MySQL accepts, omni rejects — %v", kw, omniErr)) + } + // We don't care about the reverse (omni accepts, MySQL rejects) + // — that's leniency, not a bug. + } + + if len(mismatches) > 0 { + t.Errorf("%d keywords accepted by MySQL but rejected by omni:\n%s", + len(mismatches), strings.Join(mismatches, "\n")) + } else { + t.Logf("all %d keywords match between MySQL and omni", len(reservedKeywords)) + } + }) + } +} + +// getReservedKeywords returns all reserved keyword strings from the parser. +func getReservedKeywords() []string { + // Use the parser's exported TokenName to reverse-map token types to names. + // We check all keyword token types (>= 700) and filter for reserved ones. + var keywords []string + seen := make(map[string]bool) + + // Scan the full range of possible keyword token types. + // MySQL keywords are in the range 700-1500 approximately. + for tok := 700; tok < 1500; tok++ { + name := mysqlparser.TokenName(tok) + if name == "" { + continue + } + // Check if this is a reserved keyword by trying to use it as an identifier. + // If Parse rejects "SELECT " as a column alias (without AS), it's reserved. + // But a simpler approach: just collect all keyword names and let the test filter. + lower := strings.ToLower(name) + if !seen[lower] { + seen[lower] = true + keywords = append(keywords, name) + } + } + return keywords +} diff --git a/tidb/catalog/container_scenarios_test.go b/tidb/catalog/container_scenarios_test.go new file mode 100644 index 00000000..8af9b556 --- /dev/null +++ b/tidb/catalog/container_scenarios_test.go @@ -0,0 +1,6747 @@ +package catalog + +import ( + "errors" + "strings" + "testing" + + mysqldriver "github.com/go-sql-driver/mysql" +) + +func TestContainer_Section_1_2_StringTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"char_10", "CREATE TABLE t_char10 (a CHAR(10))", "t_char10"}, + {"char_no_length", "CREATE TABLE t_char1 (a CHAR)", "t_char1"}, + {"varchar_255", "CREATE TABLE t_varchar255 (a VARCHAR(255))", "t_varchar255"}, + {"varchar_16383", "CREATE TABLE t_varchar16383 (a VARCHAR(16383))", "t_varchar16383"}, + {"tinytext", "CREATE TABLE t_tinytext (a TINYTEXT)", "t_tinytext"}, + {"text", "CREATE TABLE t_text (a TEXT)", "t_text"}, + {"mediumtext", "CREATE TABLE t_mediumtext (a MEDIUMTEXT)", "t_mediumtext"}, + {"longtext", "CREATE TABLE t_longtext (a LONGTEXT)", "t_longtext"}, + {"text_1000", "CREATE TABLE t_text1000 (a TEXT(1000))", "t_text1000"}, + {"enum_basic", "CREATE TABLE t_enum (a ENUM('a','b','c'))", "t_enum"}, + {"enum_special_chars", "CREATE TABLE t_enum_sc (a ENUM('it''s','hello,world','a\"b'))", "t_enum_sc"}, + {"set_basic", "CREATE TABLE t_set (a SET('x','y','z'))", "t_set"}, + {"char_charset_latin1", "CREATE TABLE t_char_cs (a CHAR(10) CHARACTER SET latin1)", "t_char_cs"}, + {"varchar_charset_collate", "CREATE TABLE t_varchar_cc (a VARCHAR(100) CHARACTER SET utf8mb3 COLLATE utf8mb3_bin)", "t_varchar_cc"}, + {"national_char", "CREATE TABLE t_nchar (a NATIONAL CHAR(10))", "t_nchar"}, + {"nchar", "CREATE TABLE t_nchar2 (a NCHAR(10))", "t_nchar2"}, + {"nvarchar", "CREATE TABLE t_nvarchar (a NVARCHAR(100))", "t_nvarchar"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_4_DateTimeTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"date", "CREATE TABLE t_date (a DATE)", "t_date"}, + {"time", "CREATE TABLE t_time (a TIME)", "t_time"}, + {"time_fsp", "CREATE TABLE t_time3 (a TIME(3))", "t_time3"}, + {"datetime", "CREATE TABLE t_datetime (a DATETIME)", "t_datetime"}, + {"datetime_fsp", "CREATE TABLE t_datetime6 (a DATETIME(6))", "t_datetime6"}, + {"timestamp", "CREATE TABLE t_timestamp (a TIMESTAMP)", "t_timestamp"}, + {"timestamp_fsp", "CREATE TABLE t_timestamp3 (a TIMESTAMP(3))", "t_timestamp3"}, + {"year", "CREATE TABLE t_year (a YEAR)", "t_year"}, + {"year_4", "CREATE TABLE t_year4 (a YEAR(4))", "t_year4"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_1_NumericTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"int_basic", "CREATE TABLE t_int (a INT)", "t_int"}, + {"int_display_width", "CREATE TABLE t_int_dw (a INT(11))", "t_int_dw"}, + {"int_unsigned", "CREATE TABLE t_int_u (a INT UNSIGNED)", "t_int_u"}, + {"int_unsigned_zerofill", "CREATE TABLE t_int_uz (a INT UNSIGNED ZEROFILL)", "t_int_uz"}, + {"tinyint", "CREATE TABLE t_tinyint (a TINYINT)", "t_tinyint"}, + {"smallint", "CREATE TABLE t_smallint (a SMALLINT)", "t_smallint"}, + {"mediumint", "CREATE TABLE t_mediumint (a MEDIUMINT)", "t_mediumint"}, + {"bigint", "CREATE TABLE t_bigint (a BIGINT)", "t_bigint"}, + {"bigint_unsigned", "CREATE TABLE t_bigint_u (a BIGINT UNSIGNED)", "t_bigint_u"}, + {"float_basic", "CREATE TABLE t_float (a FLOAT)", "t_float"}, + {"float_precision", "CREATE TABLE t_float_p (a FLOAT(7,3))", "t_float_p"}, + {"float_unsigned", "CREATE TABLE t_float_u (a FLOAT UNSIGNED)", "t_float_u"}, + {"double_basic", "CREATE TABLE t_double (a DOUBLE)", "t_double"}, + {"double_precision_alias", "CREATE TABLE t_double_p (a DOUBLE PRECISION)", "t_double_p"}, + {"double_with_precision", "CREATE TABLE t_double_wp (a DOUBLE(15,5))", "t_double_wp"}, + {"decimal_precision", "CREATE TABLE t_decimal (a DECIMAL(10,2))", "t_decimal"}, + {"numeric_precision", "CREATE TABLE t_numeric (a NUMERIC(10,2))", "t_numeric"}, + {"decimal_no_precision", "CREATE TABLE t_decimal_np (a DECIMAL)", "t_decimal_np"}, + {"boolean", "CREATE TABLE t_bool (a BOOLEAN)", "t_bool"}, + {"bool_alias", "CREATE TABLE t_bool2 (a BOOL)", "t_bool2"}, + {"bit_1", "CREATE TABLE t_bit1 (a BIT(1))", "t_bit1"}, + {"bit_8", "CREATE TABLE t_bit8 (a BIT(8))", "t_bit8"}, + {"bit_64", "CREATE TABLE t_bit64 (a BIT(64))", "t_bit64"}, + {"serial", "CREATE TABLE t_serial (a SERIAL)", "t_serial"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec(tc.sql, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_10_ColumnAttributesCombination(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"int_not_null_auto_increment", "CREATE TABLE t_ai1 (a INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (a))", "t_ai1"}, + {"bigint_unsigned_not_null_auto_increment", "CREATE TABLE t_ai2 (a BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY (a))", "t_ai2"}, + {"varchar_not_null_default_empty", "CREATE TABLE t_vnde (a VARCHAR(100) NOT NULL DEFAULT '')", "t_vnde"}, + {"varchar_charset_collate_not_null", "CREATE TABLE t_vccnn (a VARCHAR(100) CHARACTER SET utf8mb3 COLLATE utf8mb3_bin NOT NULL)", "t_vccnn"}, + {"int_not_null_comment", "CREATE TABLE t_innc (a INT NOT NULL COMMENT 'user id')", "t_innc"}, + {"varchar_invisible", "CREATE TABLE t_vinv (a INT, b VARCHAR(255) INVISIBLE)", "t_vinv"}, + {"int_visible_not_shown", "CREATE TABLE t_ivis (a INT VISIBLE)", "t_ivis"}, + {"all_attributes", "CREATE TABLE t_all (a INT UNSIGNED NOT NULL DEFAULT '0' COMMENT 'count')", "t_all"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_7_DefaultValues(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"int_default_0", "CREATE TABLE t_def_int0 (a INT DEFAULT 0)", "t_def_int0"}, + {"int_default_null", "CREATE TABLE t_def_intn (a INT DEFAULT NULL)", "t_def_intn"}, + {"int_not_null", "CREATE TABLE t_def_intnn (a INT NOT NULL)", "t_def_intnn"}, + {"varchar_default_hello", "CREATE TABLE t_def_vch (a VARCHAR(50) DEFAULT 'hello')", "t_def_vch"}, + {"varchar_default_empty", "CREATE TABLE t_def_vce (a VARCHAR(50) DEFAULT '')", "t_def_vce"}, + {"float_default", "CREATE TABLE t_def_flt (a FLOAT DEFAULT 3.14)", "t_def_flt"}, + {"decimal_default", "CREATE TABLE t_def_dec (a DECIMAL(10,2) DEFAULT 0.00)", "t_def_dec"}, + {"bool_default_true", "CREATE TABLE t_def_bt (a BOOLEAN DEFAULT TRUE)", "t_def_bt"}, + {"bool_default_false", "CREATE TABLE t_def_bf (a BOOLEAN DEFAULT FALSE)", "t_def_bf"}, + {"enum_default", "CREATE TABLE t_def_enum (a ENUM('a','b','c') DEFAULT 'a')", "t_def_enum"}, + {"set_default", "CREATE TABLE t_def_set (a SET('x','y','z') DEFAULT 'x,y')", "t_def_set"}, + {"bit_default", "CREATE TABLE t_def_bit (a BIT(8) DEFAULT b'00001111')", "t_def_bit"}, + {"blob_no_default_null", "CREATE TABLE t_def_blob (a BLOB)", "t_def_blob"}, + {"text_no_default_null", "CREATE TABLE t_def_text (a TEXT)", "t_def_text"}, + {"json_no_default_null", "CREATE TABLE t_def_json (a JSON)", "t_def_json"}, + {"timestamp_default_ct", "CREATE TABLE t_def_tsct (a TIMESTAMP DEFAULT CURRENT_TIMESTAMP)", "t_def_tsct"}, + {"datetime_default_ct", "CREATE TABLE t_def_dtct (a DATETIME DEFAULT CURRENT_TIMESTAMP)", "t_def_dtct"}, + {"timestamp3_default_ct3", "CREATE TABLE t_def_ts3 (a TIMESTAMP(3) DEFAULT CURRENT_TIMESTAMP(3))", "t_def_ts3"}, + {"expr_default_int", "CREATE TABLE t_def_expr (a INT DEFAULT (FLOOR(RAND()*100)))", "t_def_expr"}, + {"expr_default_json", "CREATE TABLE t_def_exjson (a JSON DEFAULT (JSON_ARRAY()))", "t_def_exjson"}, + {"expr_default_varchar", "CREATE TABLE t_def_exvc (a VARCHAR(36) DEFAULT (UUID()))", "t_def_exvc"}, + {"datetime_default_literal", "CREATE TABLE t_def_dtlit (a DATETIME DEFAULT '2024-01-01 00:00:00')", "t_def_dtlit"}, + {"date_default_literal", "CREATE TABLE t_def_dtlit2 (a DATE DEFAULT '2024-01-01')", "t_def_dtlit2"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_11_PrimaryKey(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"pk_single_column", "CREATE TABLE t_pk1 (id INT NOT NULL, PRIMARY KEY (id))", "t_pk1"}, + {"pk_multi_column", "CREATE TABLE t_pk2 (a INT NOT NULL, b INT NOT NULL, PRIMARY KEY (a, b))", "t_pk2"}, + {"pk_bigint_unsigned_auto_inc", "CREATE TABLE t_pk3 (id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, name VARCHAR(100), PRIMARY KEY (id))", "t_pk3"}, + {"pk_column_ordering", "CREATE TABLE t_pk4 (c INT NOT NULL, b INT NOT NULL, a INT NOT NULL, PRIMARY KEY (b, a))", "t_pk4"}, + {"pk_implicit_not_null", "CREATE TABLE t_pk5 (id INT, PRIMARY KEY (id))", "t_pk5"}, + {"pk_name_not_shown", "CREATE TABLE t_pk6 (id INT NOT NULL, PRIMARY KEY (id))", "t_pk6"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_13_RegularIndexes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"key_named", "CREATE TABLE t_idx1 (a INT, KEY `idx_a` (a))", "t_idx1"}, + {"key_auto_named", "CREATE TABLE t_idx2 (a INT, KEY (a))", "t_idx2"}, + {"key_multi_column", "CREATE TABLE t_idx3 (a INT, b INT, c INT, KEY `idx_abc` (a, b, c))", "t_idx3"}, + {"key_prefix_length", "CREATE TABLE t_idx4 (a VARCHAR(255), KEY `idx_a` (a(10)))", "t_idx4"}, + {"key_desc", "CREATE TABLE t_idx5 (a INT, KEY `idx_a` (a DESC))", "t_idx5"}, + {"key_mixed_asc_desc", "CREATE TABLE t_idx6 (a INT, b INT, KEY `idx_ab` (a ASC, b DESC))", "t_idx6"}, + {"key_using_hash", "CREATE TABLE t_idx7 (a INT, KEY `idx_a` (a) USING HASH) ENGINE=MEMORY", "t_idx7"}, + {"key_using_btree", "CREATE TABLE t_idx8 (a INT, KEY `idx_a` (a) USING BTREE)", "t_idx8"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_17_ForeignKeys(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + setup string // SQL to run before the main CREATE TABLE (e.g., parent tables) + sql string // The CREATE TABLE with FK to compare + table string // The table to SHOW CREATE TABLE on + cleanup string // SQL to clean up after (drop child then parent) + }{ + { + "fk_basic", + "CREATE TABLE t_parent1 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk1 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent1(id))", + "t_fk1", + "DROP TABLE IF EXISTS t_fk1; DROP TABLE IF EXISTS t_parent1", + }, + { + "fk_named", + "CREATE TABLE t_parent2 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk2 (id INT, pid INT, CONSTRAINT `fk_name` FOREIGN KEY (pid) REFERENCES t_parent2(id))", + "t_fk2", + "DROP TABLE IF EXISTS t_fk2; DROP TABLE IF EXISTS t_parent2", + }, + { + "fk_on_delete_cascade", + "CREATE TABLE t_parent3 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk3 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent3(id) ON DELETE CASCADE)", + "t_fk3", + "DROP TABLE IF EXISTS t_fk3; DROP TABLE IF EXISTS t_parent3", + }, + { + "fk_on_delete_set_null", + "CREATE TABLE t_parent4 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk4 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent4(id) ON DELETE SET NULL)", + "t_fk4", + "DROP TABLE IF EXISTS t_fk4; DROP TABLE IF EXISTS t_parent4", + }, + { + "fk_on_delete_set_default", + "CREATE TABLE t_parent5 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk5 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent5(id) ON DELETE SET DEFAULT)", + "t_fk5", + "DROP TABLE IF EXISTS t_fk5; DROP TABLE IF EXISTS t_parent5", + }, + { + "fk_on_delete_restrict", + "CREATE TABLE t_parent6 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk6 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent6(id) ON DELETE RESTRICT)", + "t_fk6", + "DROP TABLE IF EXISTS t_fk6; DROP TABLE IF EXISTS t_parent6", + }, + { + "fk_on_delete_no_action", + "CREATE TABLE t_parent7 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk7 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent7(id) ON DELETE NO ACTION)", + "t_fk7", + "DROP TABLE IF EXISTS t_fk7; DROP TABLE IF EXISTS t_parent7", + }, + { + "fk_on_update_cascade", + "CREATE TABLE t_parent8 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk8 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent8(id) ON UPDATE CASCADE)", + "t_fk8", + "DROP TABLE IF EXISTS t_fk8; DROP TABLE IF EXISTS t_parent8", + }, + { + "fk_on_update_set_null", + "CREATE TABLE t_parent9 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk9 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent9(id) ON UPDATE SET NULL)", + "t_fk9", + "DROP TABLE IF EXISTS t_fk9; DROP TABLE IF EXISTS t_parent9", + }, + { + "fk_combined_actions", + "CREATE TABLE t_parent10 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk10 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent10(id) ON DELETE CASCADE ON UPDATE SET NULL)", + "t_fk10", + "DROP TABLE IF EXISTS t_fk10; DROP TABLE IF EXISTS t_parent10", + }, + { + "fk_auto_naming", + "CREATE TABLE t_parent11 (id INT NOT NULL, id2 INT NOT NULL, PRIMARY KEY (id), KEY (id2)) ENGINE=InnoDB", + "CREATE TABLE t_fk11 (id INT, pid INT, pid2 INT, FOREIGN KEY (pid) REFERENCES t_parent11(id), FOREIGN KEY (pid2) REFERENCES t_parent11(id2))", + "t_fk11", + "DROP TABLE IF EXISTS t_fk11; DROP TABLE IF EXISTS t_parent11", + }, + { + "fk_auto_generates_index", + "CREATE TABLE t_parent12 (id INT NOT NULL, PRIMARY KEY (id)) ENGINE=InnoDB", + "CREATE TABLE t_fk12 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_parent12(id))", + "t_fk12", + "DROP TABLE IF EXISTS t_fk12; DROP TABLE IF EXISTS t_parent12", + }, + { + "fk_self_referencing", + "", + "CREATE TABLE t_fk13 (id INT NOT NULL, parent_id INT, PRIMARY KEY (id), FOREIGN KEY (parent_id) REFERENCES t_fk13(id))", + "t_fk13", + "DROP TABLE IF EXISTS t_fk13", + }, + { + "fk_multi_column", + "CREATE TABLE t_parent14 (x INT NOT NULL, y INT NOT NULL, PRIMARY KEY (x, y)) ENGINE=InnoDB", + "CREATE TABLE t_fk14 (id INT, a INT, b INT, FOREIGN KEY (a, b) REFERENCES t_parent14(x, y))", + "t_fk14", + "DROP TABLE IF EXISTS t_fk14; DROP TABLE IF EXISTS t_parent14", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Cleanup from prior runs. + ctr.execSQL(tc.cleanup) + + // Setup parent tables. + if tc.setup != "" { + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + } + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + // Run setup on omni too. + if tc.setup != "" { + results, err := c.Exec(tc.setup, nil) + if err != nil { + t.Fatalf("omni setup parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni setup exec error: %v", r.Error) + } + } + } + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_12_UniqueKeys(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"unique_key_named", "CREATE TABLE t_uk1 (a INT, UNIQUE KEY `uk_a` (a))", "t_uk1"}, + {"unique_key_auto_named", "CREATE TABLE t_uk2 (a INT, UNIQUE KEY (a))", "t_uk2"}, + {"unique_key_multi_column", "CREATE TABLE t_uk3 (a INT, b INT, UNIQUE KEY `uk_ab` (a, b))", "t_uk3"}, + {"multiple_unique_keys", "CREATE TABLE t_uk4 (a INT, b INT, c INT, UNIQUE KEY `uk_a` (a), UNIQUE KEY `uk_b` (b))", "t_uk4"}, + {"unique_key_auto_name_collision", "CREATE TABLE t_uk5 (a INT, b INT, c INT, UNIQUE KEY (a), UNIQUE KEY (a, b), UNIQUE KEY (a, c))", "t_uk5"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_18_CheckConstraints(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"check_basic", "CREATE TABLE t_chk1 (a INT, CHECK (a > 0))", "t_chk1"}, + {"check_named", "CREATE TABLE t_chk2 (a INT, CONSTRAINT `chk_name` CHECK (a > 0))", "t_chk2"}, + {"check_not_enforced", "CREATE TABLE t_chk3 (a INT, CHECK (a > 0) NOT ENFORCED)", "t_chk3"}, + {"check_auto_naming", "CREATE TABLE t_chk4 (a INT, b INT, CHECK (a > 0), CHECK (b > 0))", "t_chk4"}, + {"check_expr_parens", "CREATE TABLE t_chk5 (a INT, CHECK (a > 0 AND a < 100))", "t_chk5"}, + {"check_multi_col", "CREATE TABLE t_chk6 (a INT, b INT, CHECK (a > b))", "t_chk6"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_19_TableOptions(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"engine_innodb", "CREATE TABLE t_eng_innodb (a INT) ENGINE=InnoDB", "t_eng_innodb"}, + {"engine_myisam", "CREATE TABLE t_eng_myisam (a INT) ENGINE=MyISAM", "t_eng_myisam"}, + {"engine_memory", "CREATE TABLE t_eng_memory (a INT) ENGINE=MEMORY", "t_eng_memory"}, + {"charset_utf8mb4_default_collation", "CREATE TABLE t_cs_utf8mb4 (a INT) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", "t_cs_utf8mb4"}, + {"charset_latin1", "CREATE TABLE t_cs_latin1 (a INT) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci", "t_cs_latin1"}, + {"charset_utf8mb4_unicode_ci", "CREATE TABLE t_cs_unicode (a INT) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", "t_cs_unicode"}, + {"comment_basic", "CREATE TABLE t_comment (a INT) COMMENT='table description'", "t_comment"}, + {"comment_special_chars", "CREATE TABLE t_comment_sc (a INT) COMMENT='it\\'s a \\\\test'", "t_comment_sc"}, + {"row_format_dynamic", "CREATE TABLE t_rf_dyn (a INT) ROW_FORMAT=DYNAMIC", "t_rf_dyn"}, + {"row_format_compressed", "CREATE TABLE t_rf_comp (a INT) ROW_FORMAT=COMPRESSED", "t_rf_comp"}, + {"auto_increment_1000", "CREATE TABLE t_ai1000 (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id)) AUTO_INCREMENT=1000", "t_ai1000"}, + {"key_block_size_8", "CREATE TABLE t_kbs8 (a INT) KEY_BLOCK_SIZE=8", "t_kbs8"}, + {"multiple_options", "CREATE TABLE t_multi_opts (id INT NOT NULL AUTO_INCREMENT, name VARCHAR(100), PRIMARY KEY (id)) ENGINE=InnoDB AUTO_INCREMENT=500 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci ROW_FORMAT=DYNAMIC COMMENT='multi opts'", "t_multi_opts"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_8_OnUpdate(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"timestamp_on_update", "CREATE TABLE t_ou1 (a TIMESTAMP ON UPDATE CURRENT_TIMESTAMP)", "t_ou1"}, + {"datetime3_on_update", "CREATE TABLE t_ou2 (a DATETIME(3) ON UPDATE CURRENT_TIMESTAMP(3))", "t_ou2"}, + {"timestamp_default_and_on_update", "CREATE TABLE t_ou3 (a TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP)", "t_ou3"}, + {"datetime6_default_and_on_update", "CREATE TABLE t_ou4 (a DATETIME(6) DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6))", "t_ou4"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_20_CharsetCollationInheritance(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + setupSQL string + sql string + table string + database string + }{ + { + "table_charset_inherited_from_database", + "DROP DATABASE IF EXISTS db_latin1; CREATE DATABASE db_latin1 CHARACTER SET latin1; USE db_latin1", + "CREATE TABLE t_inherit_db (a VARCHAR(100))", + "t_inherit_db", + "db_latin1", + }, + { + "column_charset_inherited_from_table", + "DROP DATABASE IF EXISTS db_utf8mb4_tbl; CREATE DATABASE db_utf8mb4_tbl; USE db_utf8mb4_tbl", + "CREATE TABLE t_col_inherit (a VARCHAR(50)) DEFAULT CHARSET=latin1", + "t_col_inherit", + "db_utf8mb4_tbl", + }, + { + "column_charset_overrides_table", + "DROP DATABASE IF EXISTS db_override_cs; CREATE DATABASE db_override_cs; USE db_override_cs", + "CREATE TABLE t_col_override_cs (a VARCHAR(50) CHARACTER SET latin1) DEFAULT CHARSET=utf8mb4", + "t_col_override_cs", + "db_override_cs", + }, + { + "column_collation_overrides_table", + "DROP DATABASE IF EXISTS db_override_coll; CREATE DATABASE db_override_coll; USE db_override_coll", + "CREATE TABLE t_col_override_coll (a VARCHAR(50) COLLATE utf8mb4_bin) DEFAULT CHARSET=utf8mb4", + "t_col_override_coll", + "db_override_coll", + }, + { + "column_charset_collation_display_rules", + "DROP DATABASE IF EXISTS db_display; CREATE DATABASE db_display CHARACTER SET utf8mb4; USE db_display", + "CREATE TABLE t_display_rules (a VARCHAR(50), b VARCHAR(50) CHARACTER SET latin1, c VARCHAR(50) COLLATE utf8mb4_bin, d VARCHAR(50) CHARACTER SET utf8mb3 COLLATE utf8mb3_bin)", + "t_display_rules", + "db_display", + }, + { + "binary_charset_on_column", + "DROP DATABASE IF EXISTS db_binary; CREATE DATABASE db_binary; USE db_binary", + "CREATE TABLE t_binary_cs (a VARCHAR(50) CHARACTER SET binary)", + "t_binary_cs", + "db_binary", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setupSQL != "" { + if err := ctr.execSQL(tc.setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + } + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container show create: %v", err) + } + + c := New() + if tc.setupSQL != "" { + results, parseErr := c.Exec(tc.setupSQL, nil) + if parseErr != nil { + t.Fatalf("omni setup parse error: %v", parseErr) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni setup exec error: %v", r.Error) + } + } + } + results, parseErr := c.Exec(tc.sql, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable(tc.database, tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_9_GeneratedColumns(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + { + "generated_virtual_add", + "CREATE TABLE t_gen1 (col1 INT, col2 INT, col3 INT GENERATED ALWAYS AS (col1 + col2) VIRTUAL)", + "t_gen1", + }, + { + "generated_stored_mul", + "CREATE TABLE t_gen2 (col1 INT, col2 INT, col3 INT GENERATED ALWAYS AS (col1 * col2) STORED)", + "t_gen2", + }, + { + "generated_varchar_concat", + "CREATE TABLE t_gen3 (first_name VARCHAR(50), last_name VARCHAR(50), full_name VARCHAR(255) AS (CONCAT(first_name, ' ', last_name)) VIRTUAL)", + "t_gen3", + }, + { + "generated_not_null", + "CREATE TABLE t_gen4 (col1 INT, col2 INT, col3 INT GENERATED ALWAYS AS (col1 + col2) STORED NOT NULL)", + "t_gen4", + }, + { + "generated_comment", + "CREATE TABLE t_gen5 (col1 INT, col2 INT, col3 INT GENERATED ALWAYS AS (col1 + col2) VIRTUAL COMMENT 'sum of cols')", + "t_gen5", + }, + { + "generated_invisible", + "CREATE TABLE t_gen6 (col1 INT, col2 INT, col3 INT GENERATED ALWAYS AS (col1 + col2) VIRTUAL INVISIBLE)", + "t_gen6", + }, + { + "generated_json_extract", + "CREATE TABLE t_gen7 (data JSON, name VARCHAR(255) GENERATED ALWAYS AS (JSON_EXTRACT(data, '$.name')) VIRTUAL)", + "t_gen7", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_3_BinaryTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"binary_16", "CREATE TABLE t_binary16 (a BINARY(16))", "t_binary16"}, + {"binary_no_length", "CREATE TABLE t_binary1 (a BINARY)", "t_binary1"}, + {"varbinary_255", "CREATE TABLE t_varbinary255 (a VARBINARY(255))", "t_varbinary255"}, + {"tinyblob", "CREATE TABLE t_tinyblob (a TINYBLOB)", "t_tinyblob"}, + {"blob", "CREATE TABLE t_blob (a BLOB)", "t_blob"}, + {"mediumblob", "CREATE TABLE t_mediumblob (a MEDIUMBLOB)", "t_mediumblob"}, + {"longblob", "CREATE TABLE t_longblob (a LONGBLOB)", "t_longblob"}, + {"blob_1000", "CREATE TABLE t_blob1000 (a BLOB(1000))", "t_blob1000"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_6_JSONType(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"json_basic", "CREATE TABLE t_json (a JSON)", "t_json"}, + {"json_default_null", "CREATE TABLE t_json_dn (a JSON DEFAULT NULL)", "t_json_dn"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_5_SpatialTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"geometry", "CREATE TABLE t_geometry (a GEOMETRY)", "t_geometry"}, + {"point", "CREATE TABLE t_point (a POINT)", "t_point"}, + {"linestring", "CREATE TABLE t_linestring (a LINESTRING)", "t_linestring"}, + {"polygon", "CREATE TABLE t_polygon (a POLYGON)", "t_polygon"}, + {"multipoint", "CREATE TABLE t_multipoint (a MULTIPOINT)", "t_multipoint"}, + {"multilinestring", "CREATE TABLE t_multilinestring (a MULTILINESTRING)", "t_multilinestring"}, + {"multipolygon", "CREATE TABLE t_multipolygon (a MULTIPOLYGON)", "t_multipolygon"}, + {"geometrycollection", "CREATE TABLE t_geomcoll (a GEOMETRYCOLLECTION)", "t_geomcoll"}, + {"point_srid", "CREATE TABLE t_point_srid (a POINT NOT NULL SRID 4326)", "t_point_srid"}, + {"linestring_srid", "CREATE TABLE t_ls_srid (a LINESTRING SRID 4326)", "t_ls_srid"}, + {"polygon_srid", "CREATE TABLE t_poly_srid (a POLYGON NOT NULL SRID 4326)", "t_poly_srid"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_14_FulltextSpatialIndexes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"fulltext_named", "CREATE TABLE t_ft1 (id INT PRIMARY KEY, body TEXT, FULLTEXT KEY `ft_idx` (body))", "t_ft1"}, + {"fulltext_multi_col", "CREATE TABLE t_ft2 (id INT PRIMARY KEY, title VARCHAR(200), body TEXT, FULLTEXT KEY `ft_multi` (title, body))", "t_ft2"}, + {"fulltext_auto_name", "CREATE TABLE t_ft3 (id INT PRIMARY KEY, body TEXT, FULLTEXT KEY (body))", "t_ft3"}, + {"spatial_named", "CREATE TABLE t_sp1 (id INT PRIMARY KEY, geo_col GEOMETRY NOT NULL, SPATIAL KEY `sp_idx` (geo_col))", "t_sp1"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_15_ExpressionIndexes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"expr_func_upper", "CREATE TABLE t_expr1 (name VARCHAR(100), KEY `idx` ((UPPER(name))))", "t_expr1"}, + {"expr_arithmetic", "CREATE TABLE t_expr2 (col1 INT, col2 INT, KEY `idx` ((col1 + col2)))", "t_expr2"}, + {"expr_unique", "CREATE TABLE t_expr3 (name VARCHAR(100), UNIQUE KEY `uidx` ((UPPER(name))))", "t_expr3"}, + {"expr_display_format", "CREATE TABLE t_expr4 (a INT, b INT, KEY `idx_col` (a), KEY `idx_expr` ((a * b)))", "t_expr4"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_1_16_IndexOptions(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + {"index_comment", "CREATE TABLE t_idx_comment (a INT, INDEX idx_a (a) COMMENT 'description')", "t_idx_comment"}, + {"index_invisible", "CREATE TABLE t_idx_invis (a INT, INDEX idx_a (a) INVISIBLE)", "t_idx_invis"}, + {"index_visible", "CREATE TABLE t_idx_vis (a INT, INDEX idx_a (a) VISIBLE)", "t_idx_vis"}, + {"index_key_block_size", "CREATE TABLE t_idx_kbs (a INT, INDEX idx_a (a) KEY_BLOCK_SIZE=4)", "t_idx_kbs"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_2_1_CreateTableVariants(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("if_not_exists_no_error", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ine") + ctr.execSQL("CREATE TABLE t_ine (id INT)") + + // Second CREATE with IF NOT EXISTS should not error on ctr. + ctrErr := ctr.execSQL("CREATE TABLE IF NOT EXISTS t_ine (id INT)") + if ctrErr != nil { + t.Fatalf("container error on IF NOT EXISTS: %v", ctrErr) + } + + // Omni should also not error. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ine (id INT)", nil) + results, _ := c.Exec("CREATE TABLE IF NOT EXISTS t_ine (id INT)", nil) + if results[0].Error != nil { + t.Fatalf("omni error on IF NOT EXISTS: %v", results[0].Error) + } + + // Compare SHOW CREATE TABLE. + ctrDDL, _ := ctr.showCreateTable("t_ine") + omniDDL := c.ShowCreateTable("test", "t_ine") + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("temporary_table", func(t *testing.T) { + // MySQL SHOW CREATE TABLE for temporary tables shows "CREATE TEMPORARY TABLE". + ctr.execSQL("DROP TEMPORARY TABLE IF EXISTS t_temp") + err := ctr.execSQL("CREATE TEMPORARY TABLE t_temp (id INT, name VARCHAR(50))") + if err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTable("t_temp") + if err != nil { + t.Fatalf("container show create: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE TEMPORARY TABLE t_temp (id INT, name VARCHAR(50))", nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_temp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("create_table_like", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_like_dst") + ctr.execSQL("DROP TABLE IF EXISTS t_like_src") + ctr.execSQL("CREATE TABLE t_like_src (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL DEFAULT '', score DECIMAL(10,2))") + err := ctr.execSQL("CREATE TABLE t_like_dst LIKE t_like_src") + if err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_like_dst") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_like_src (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL DEFAULT '', score DECIMAL(10,2))", nil) + results, _ := c.Exec("CREATE TABLE t_like_dst LIKE t_like_src", nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_like_dst") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("create_table_with_view_name_conflict", func(t *testing.T) { + // Creating a table with same name as an existing view should error. + ctr.execSQL("DROP TABLE IF EXISTS t_view_conflict") + ctr.execSQL("DROP VIEW IF EXISTS t_view_conflict") + ctr.execSQL("CREATE VIEW t_view_conflict AS SELECT 1 AS a") + ctrErr := ctr.execSQL("CREATE TABLE t_view_conflict (id INT)") + if ctrErr == nil { + t.Fatal("expected container error when creating table with same name as view") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW t_view_conflict AS SELECT 1 AS a", nil) + results, _ := c.Exec("CREATE TABLE t_view_conflict (id INT)", nil) + if results[0].Error == nil { + t.Fatal("expected omni error when creating table with same name as view") + } + }) + + t.Run("reserved_word_as_name", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS `select`") + err := ctr.execSQL("CREATE TABLE `select` (`from` INT, `where` VARCHAR(50))") + if err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("`select`") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE TABLE `select` (`from` INT, `where` VARCHAR(50))", nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "select") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) +} + +func TestContainer_Section_2_2_AlterTableColumnOps(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // DDL-comparison tests: run setup + alter, compare SHOW CREATE TABLE. + ddlCases := []struct { + name string + setup string + alter string + table string + }{ + { + "add_column_at_end", + "CREATE TABLE t_add_end (id INT PRIMARY KEY)", + "ALTER TABLE t_add_end ADD COLUMN name VARCHAR(100)", + "t_add_end", + }, + { + "add_column_first", + "CREATE TABLE t_add_first (id INT PRIMARY KEY)", + "ALTER TABLE t_add_first ADD COLUMN name VARCHAR(100) FIRST", + "t_add_first", + }, + { + "add_column_after", + "CREATE TABLE t_add_after (id INT PRIMARY KEY, age INT)", + "ALTER TABLE t_add_after ADD COLUMN name VARCHAR(100) AFTER id", + "t_add_after", + }, + { + "add_multiple_columns", + "CREATE TABLE t_add_multi (id INT PRIMARY KEY)", + "ALTER TABLE t_add_multi ADD COLUMN name VARCHAR(100), ADD COLUMN age INT, ADD COLUMN email VARCHAR(255)", + "t_add_multi", + }, + { + "drop_column", + "CREATE TABLE t_drop_col (id INT PRIMARY KEY, name VARCHAR(100), age INT)", + "ALTER TABLE t_drop_col DROP COLUMN age", + "t_drop_col", + }, + { + "drop_column_part_of_index", + "CREATE TABLE t_drop_idx_col (id INT PRIMARY KEY, a INT, b INT, KEY idx_ab (a, b))", + "ALTER TABLE t_drop_idx_col DROP COLUMN b", + "t_drop_idx_col", + }, + { + "drop_column_only_in_index", + "CREATE TABLE t_drop_only_idx (id INT PRIMARY KEY, a INT, KEY idx_a (a))", + "ALTER TABLE t_drop_only_idx DROP COLUMN a", + "t_drop_only_idx", + }, + // Note: DROP COLUMN IF EXISTS is not supported in MySQL 8.0 (only 8.0.32+). + // Scenario marked as [~] partial in SCENARIOS.md. + { + "modify_column_change_type", + "CREATE TABLE t_mod_type (id INT PRIMARY KEY, val SMALLINT)", + "ALTER TABLE t_mod_type MODIFY COLUMN val INT", + "t_mod_type", + }, + { + "modify_column_widen_varchar", + "CREATE TABLE t_mod_widen (id INT PRIMARY KEY, name VARCHAR(50))", + "ALTER TABLE t_mod_widen MODIFY COLUMN name VARCHAR(200)", + "t_mod_widen", + }, + { + "modify_column_narrow_varchar", + "CREATE TABLE t_mod_narrow (id INT PRIMARY KEY, name VARCHAR(200))", + "ALTER TABLE t_mod_narrow MODIFY COLUMN name VARCHAR(50)", + "t_mod_narrow", + }, + { + "modify_column_int_to_bigint", + "CREATE TABLE t_mod_bigint (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_mod_bigint MODIFY COLUMN val BIGINT", + "t_mod_bigint", + }, + { + "modify_column_add_not_null", + "CREATE TABLE t_mod_nn (id INT PRIMARY KEY, name VARCHAR(100))", + "ALTER TABLE t_mod_nn MODIFY COLUMN name VARCHAR(100) NOT NULL", + "t_mod_nn", + }, + { + "modify_column_remove_not_null", + "CREATE TABLE t_mod_rnn (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL)", + "ALTER TABLE t_mod_rnn MODIFY COLUMN name VARCHAR(100)", + "t_mod_rnn", + }, + { + "modify_column_first_after", + "CREATE TABLE t_mod_pos (a INT, b INT, c INT)", + "ALTER TABLE t_mod_pos MODIFY COLUMN c INT FIRST", + "t_mod_pos", + }, + { + "change_column_rename_and_type", + "CREATE TABLE t_chg_rt (id INT PRIMARY KEY, old_name VARCHAR(50))", + "ALTER TABLE t_chg_rt CHANGE COLUMN old_name new_name VARCHAR(100)", + "t_chg_rt", + }, + { + "change_column_same_name_diff_type", + "CREATE TABLE t_chg_st (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_chg_st CHANGE COLUMN val val BIGINT", + "t_chg_st", + }, + { + "change_column_update_index_refs", + "CREATE TABLE t_chg_idx (id INT PRIMARY KEY, a INT, KEY idx_a (a))", + "ALTER TABLE t_chg_idx CHANGE COLUMN a b INT", + "t_chg_idx", + }, + { + "rename_column", + "CREATE TABLE t_ren_col (id INT PRIMARY KEY, old_col INT)", + "ALTER TABLE t_ren_col RENAME COLUMN old_col TO new_col", + "t_ren_col", + }, + { + "rename_column_update_index_refs", + "CREATE TABLE t_ren_idx (id INT PRIMARY KEY, a INT, KEY idx_a (a))", + "ALTER TABLE t_ren_idx RENAME COLUMN a TO b", + "t_ren_idx", + }, + { + "alter_column_set_default", + "CREATE TABLE t_set_def (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_set_def ALTER COLUMN val SET DEFAULT 42", + "t_set_def", + }, + { + "alter_column_drop_default", + "CREATE TABLE t_drop_def (id INT PRIMARY KEY, val INT DEFAULT 10)", + "ALTER TABLE t_drop_def ALTER COLUMN val DROP DEFAULT", + "t_drop_def", + }, + { + "alter_column_set_visible", + "CREATE TABLE t_vis (id INT PRIMARY KEY, val INT /*!80023 INVISIBLE */)", + "ALTER TABLE t_vis ALTER COLUMN val SET VISIBLE", + "t_vis", + }, + { + "alter_column_set_invisible", + "CREATE TABLE t_invis (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_invis ALTER COLUMN val SET INVISIBLE", + "t_invis", + }, + { + "drop_column_part_of_pk", + "CREATE TABLE t_drop_pk (a INT, b INT, PRIMARY KEY (a, b))", + "ALTER TABLE t_drop_pk DROP COLUMN a", + "t_drop_pk", + }, + } + + for _, tc := range ddlCases { + t.Run(tc.name, func(t *testing.T) { + // Oracle: setup + alter. + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + if err := ctr.execSQL(tc.alter); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container show create: %v", err) + } + + // Omni: setup + alter. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec(tc.setup, nil) + if results[0].Error != nil { + t.Fatalf("omni setup error: %v", results[0].Error) + } + results, _ = c.Exec(tc.alter, nil) + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni alter error: %v", r.Error) + } + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } + + // Error tests: operations that should produce errors. + errCases := []struct { + name string + setup string + alter string + table string + wantErr bool + }{ + { + "drop_column_referenced_by_fk", + "CREATE TABLE t_fk_parent (id INT PRIMARY KEY); CREATE TABLE t_fk_child (id INT PRIMARY KEY, pid INT, FOREIGN KEY (pid) REFERENCES t_fk_parent(id))", + "ALTER TABLE t_fk_child DROP COLUMN pid", + "t_fk_child", + true, + }, + } + + for _, tc := range errCases { + t.Run(tc.name, func(t *testing.T) { + // Clean up tables first. + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent") + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + ctrErr := ctr.execSQL(tc.alter) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(tc.setup, nil) + results, _ := c.Exec(tc.alter, nil) + var omniErr error + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + + if tc.wantErr { + if ctrErr == nil { + t.Fatal("expected container error but got nil") + } + if omniErr == nil { + t.Fatalf("expected omni error but got nil (container error: %v)", ctrErr) + } + t.Logf("both errored as expected — container: %v, omni: %v", ctrErr, omniErr) + } else { + if ctrErr != nil { + t.Fatalf("unexpected container error: %v", ctrErr) + } + if omniErr != nil { + t.Fatalf("unexpected omni error: %v", omniErr) + } + } + }) + } +} + +func TestContainer_Section_2_3_AlterTableIndexOps(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // DDL-comparison tests: run setup + alter, compare SHOW CREATE TABLE. + ddlCases := []struct { + name string + setup string + alter string + table string + }{ + { + "add_index", + "CREATE TABLE t_add_idx (id INT PRIMARY KEY, name VARCHAR(100))", + "ALTER TABLE t_add_idx ADD INDEX idx_name (name)", + "t_add_idx", + }, + { + "add_unique_index", + "CREATE TABLE t_add_uniq (id INT PRIMARY KEY, email VARCHAR(255))", + "ALTER TABLE t_add_uniq ADD UNIQUE INDEX idx_email (email)", + "t_add_uniq", + }, + { + "add_fulltext_index", + "CREATE TABLE t_add_ft (id INT PRIMARY KEY, body TEXT) ENGINE=InnoDB", + "ALTER TABLE t_add_ft ADD FULLTEXT INDEX idx_body (body)", + "t_add_ft", + }, + { + "add_primary_key", + "CREATE TABLE t_add_pk (id INT NOT NULL, name VARCHAR(100))", + "ALTER TABLE t_add_pk ADD PRIMARY KEY (id)", + "t_add_pk", + }, + { + "drop_index", + "CREATE TABLE t_drop_idx (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_name (name))", + "ALTER TABLE t_drop_idx DROP INDEX idx_name", + "t_drop_idx", + }, + // Note: DROP INDEX IF EXISTS is not supported in MySQL 8.0 (syntax error). + // Scenario marked as [~] partial in SCENARIOS.md. + { + "drop_primary_key", + "CREATE TABLE t_drop_pk (id INT NOT NULL, name VARCHAR(100), PRIMARY KEY (id))", + "ALTER TABLE t_drop_pk DROP PRIMARY KEY", + "t_drop_pk", + }, + { + "rename_index", + "CREATE TABLE t_ren_idx (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_old (name))", + "ALTER TABLE t_ren_idx RENAME INDEX idx_old TO idx_new", + "t_ren_idx", + }, + { + "alter_index_invisible", + "CREATE TABLE t_idx_invis (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_name (name))", + "ALTER TABLE t_idx_invis ALTER INDEX idx_name INVISIBLE", + "t_idx_invis", + }, + { + "alter_index_visible", + "CREATE TABLE t_idx_vis (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_name (name) INVISIBLE)", + "ALTER TABLE t_idx_vis ALTER INDEX idx_name VISIBLE", + "t_idx_vis", + }, + } + + for _, tc := range ddlCases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + if err := ctr.execSQL(tc.alter); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container SHOW CREATE TABLE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + if results, _ := c.Exec(tc.setup, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni setup error: %v", r.Error) + } + } + } + if results, _ := c.Exec(tc.alter, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni alter error: %v", r.Error) + } + } + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } + + // Error tests. + errCases := []struct { + name string + setup string + alter string + table string + wantErr bool + }{ + { + "add_primary_key_when_exists", + "CREATE TABLE t_dup_pk (id INT PRIMARY KEY, val INT NOT NULL)", + "ALTER TABLE t_dup_pk ADD PRIMARY KEY (val)", + "t_dup_pk", + true, + }, + } + + for _, tc := range errCases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + ctrErr := ctr.execSQL(tc.alter) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(tc.setup, nil) + results, _ := c.Exec(tc.alter, nil) + var omniErr error + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + + if tc.wantErr { + if ctrErr == nil { + t.Fatal("expected container error but got nil") + } + if omniErr == nil { + t.Fatalf("expected omni error but got nil (container error: %v)", ctrErr) + } + t.Logf("both errored as expected — container: %v, omni: %v", ctrErr, omniErr) + } + }) + } +} + +func TestContainer_Section_2_4_AlterTableConstraintOps(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // DDL-comparison tests: run setup + alter, compare SHOW CREATE TABLE. + ddlCases := []struct { + name string + setup string + alter string + table string + }{ + { + "add_foreign_key", + "CREATE TABLE t_parent_fk (id INT PRIMARY KEY); CREATE TABLE t_child_fk (id INT PRIMARY KEY, parent_id INT)", + "ALTER TABLE t_child_fk ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES t_parent_fk(id)", + "t_child_fk", + }, + { + "add_check", + "CREATE TABLE t_add_chk (id INT PRIMARY KEY, age INT)", + "ALTER TABLE t_add_chk ADD CHECK (age > 0)", + "t_add_chk", + }, + { + "drop_foreign_key", + "CREATE TABLE t_parent_dfk (id INT PRIMARY KEY); CREATE TABLE t_child_dfk (id INT PRIMARY KEY, parent_id INT, CONSTRAINT fk_drop FOREIGN KEY (parent_id) REFERENCES t_parent_dfk(id))", + "ALTER TABLE t_child_dfk DROP FOREIGN KEY fk_drop", + "t_child_dfk", + }, + { + "drop_check", + "CREATE TABLE t_drop_chk (id INT PRIMARY KEY, age INT, CONSTRAINT chk_age CHECK (age > 0))", + "ALTER TABLE t_drop_chk DROP CHECK chk_age", + "t_drop_chk", + }, + { + "drop_constraint_generic", + "CREATE TABLE t_drop_con (id INT PRIMARY KEY, val INT, CONSTRAINT chk_val CHECK (val >= 0))", + "ALTER TABLE t_drop_con DROP CONSTRAINT chk_val", + "t_drop_con", + }, + { + "alter_check_not_enforced", + "CREATE TABLE t_chk_ne (id INT PRIMARY KEY, score INT, CONSTRAINT chk_score CHECK (score >= 0))", + "ALTER TABLE t_chk_ne ALTER CHECK chk_score NOT ENFORCED", + "t_chk_ne", + }, + { + "alter_check_enforced", + "CREATE TABLE t_chk_enf (id INT PRIMARY KEY, score INT, CONSTRAINT chk_score2 CHECK (score >= 0) /*!80016 NOT ENFORCED */)", + "ALTER TABLE t_chk_enf ALTER CHECK chk_score2 ENFORCED", + "t_chk_enf", + }, + } + + for _, tc := range ddlCases { + t.Run(tc.name, func(t *testing.T) { + // Clean up all tables that might exist from setup. + for _, tName := range []string{ + "t_parent_fk", "t_child_fk", + "t_parent_dfk", "t_child_dfk", + "t_add_chk", + "t_drop_chk", + "t_drop_con", + "t_chk_ne", + "t_chk_enf", + } { + ctr.execSQL("DROP TABLE IF EXISTS " + tName) + } + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + if err := ctr.execSQL(tc.alter); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container SHOW CREATE TABLE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + if results, _ := c.Exec(tc.setup, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni setup error: %v", r.Error) + } + } + } + if results, _ := c.Exec(tc.alter, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni alter error: %v", r.Error) + } + } + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_2_5_AlterTableTableLevel(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + setup string + alter string + table string // table name to SHOW CREATE TABLE after alter + }{ + { + "rename_to", + "CREATE TABLE t_rename_src (id INT PRIMARY KEY, name VARCHAR(100))", + "ALTER TABLE t_rename_src RENAME TO t_rename_dst", + "t_rename_dst", + }, + { + "engine_myisam", + "CREATE TABLE t_engine (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_engine ENGINE=MyISAM", + "t_engine", + }, + { + "convert_charset_utf8mb4", + "CREATE TABLE t_conv_cs (id INT PRIMARY KEY, name VARCHAR(100)) DEFAULT CHARSET=latin1", + "ALTER TABLE t_conv_cs CONVERT TO CHARACTER SET utf8mb4", + "t_conv_cs", + }, + { + "default_charset_latin1", + "CREATE TABLE t_def_cs (id INT PRIMARY KEY, name VARCHAR(100))", + "ALTER TABLE t_def_cs DEFAULT CHARACTER SET latin1", + "t_def_cs", + }, + { + "comment", + "CREATE TABLE t_comment (id INT PRIMARY KEY)", + "ALTER TABLE t_comment COMMENT='new comment'", + "t_comment", + }, + { + "auto_increment", + "CREATE TABLE t_autoinc (id INT PRIMARY KEY AUTO_INCREMENT, val INT)", + "ALTER TABLE t_autoinc AUTO_INCREMENT=1000", + "t_autoinc", + }, + { + "row_format_compressed", + "CREATE TABLE t_rowfmt (id INT PRIMARY KEY, val INT)", + "ALTER TABLE t_rowfmt ROW_FORMAT=COMPRESSED", + "t_rowfmt", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Clean up tables that might exist. + for _, tName := range []string{ + "t_rename_src", "t_rename_dst", + "t_engine", + "t_conv_cs", + "t_def_cs", + "t_comment", + "t_autoinc", + "t_rowfmt", + } { + ctr.execSQL("DROP TABLE IF EXISTS " + tName) + } + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + if err := ctr.execSQL(tc.alter); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container SHOW CREATE TABLE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + if results, _ := c.Exec(tc.setup, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni setup error: %v", r.Error) + } + } + } + if results, _ := c.Exec(tc.alter, nil); results != nil { + for _, r := range results { + if r.Error != nil { + t.Fatalf("omni alter error: %v", r.Error) + } + } + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +func TestContainer_Section_2_6_DropTable(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("drop_table_basic", func(t *testing.T) { + // Setup: create a table, then drop it. Verify it's gone. + ctr.execSQL("DROP TABLE IF EXISTS t_drop1") + ctr.execSQL("CREATE TABLE t_drop1 (id INT PRIMARY KEY, name VARCHAR(100))") + + ctrErr := ctr.execSQL("DROP TABLE t_drop1") + if ctrErr != nil { + t.Fatalf("container DROP TABLE error: %v", ctrErr) + } + _, ctrShowErr := ctr.showCreateTable("t_drop1") + if ctrShowErr == nil { + t.Fatal("container: table still exists after DROP TABLE") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_drop1 (id INT PRIMARY KEY, name VARCHAR(100))", nil) + results, _ := c.Exec("DROP TABLE t_drop1", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TABLE error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_drop1") + if omniDDL != "" { + t.Errorf("omni: table still exists after DROP TABLE, got: %s", omniDDL) + } + }) + + t.Run("drop_table_if_exists", func(t *testing.T) { + // DROP TABLE IF EXISTS on a nonexistent table should not error. + ctr.execSQL("DROP TABLE IF EXISTS t_drop_ine") + + ctrErr := ctr.execSQL("DROP TABLE IF EXISTS t_drop_ine") + if ctrErr != nil { + t.Fatalf("container DROP TABLE IF EXISTS error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE IF EXISTS t_drop_ine", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TABLE IF EXISTS error: %v", results[0].Error) + } + }) + + t.Run("drop_table_multi", func(t *testing.T) { + // DROP TABLE t1, t2, t3 — multi-table drop. + ctr.execSQL("DROP TABLE IF EXISTS t_dm1") + ctr.execSQL("DROP TABLE IF EXISTS t_dm2") + ctr.execSQL("DROP TABLE IF EXISTS t_dm3") + ctr.execSQL("CREATE TABLE t_dm1 (id INT)") + ctr.execSQL("CREATE TABLE t_dm2 (id INT)") + ctr.execSQL("CREATE TABLE t_dm3 (id INT)") + + ctrErr := ctr.execSQL("DROP TABLE t_dm1, t_dm2, t_dm3") + if ctrErr != nil { + t.Fatalf("container DROP TABLE multi error: %v", ctrErr) + } + for _, tbl := range []string{"t_dm1", "t_dm2", "t_dm3"} { + if _, err := ctr.showCreateTable(tbl); err == nil { + t.Errorf("container: table %s still exists after multi-drop", tbl) + } + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_dm1 (id INT)", nil) + c.Exec("CREATE TABLE t_dm2 (id INT)", nil) + c.Exec("CREATE TABLE t_dm3 (id INT)", nil) + results, _ := c.Exec("DROP TABLE t_dm1, t_dm2, t_dm3", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TABLE multi error: %v", results[0].Error) + } + for _, tbl := range []string{"t_dm1", "t_dm2", "t_dm3"} { + ddl := c.ShowCreateTable("test", tbl) + if ddl != "" { + t.Errorf("omni: table %s still exists after multi-drop", tbl) + } + } + }) + + t.Run("drop_table_nonexistent_error", func(t *testing.T) { + // DROP TABLE on nonexistent table should produce error 1051. + ctr.execSQL("DROP TABLE IF EXISTS t_noexist") + ctrErr := ctr.execSQL("DROP TABLE t_noexist") + if ctrErr == nil { + t.Fatal("container: expected error for DROP nonexistent table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE t_noexist", nil) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for DROP nonexistent table") + } + catErr, ok := omniErr.(*Error) + if !ok { + t.Fatalf("omni error is not *Error: %T", omniErr) + } + if catErr.Code != 1051 { + t.Errorf("omni error code: want 1051, got %d (message: %s)", catErr.Code, catErr.Message) + } + }) + + t.Run("drop_temporary_table", func(t *testing.T) { + // DROP TEMPORARY TABLE should work. + ctr.execSQL("DROP TEMPORARY TABLE IF EXISTS t_temp_drop") + ctr.execSQL("CREATE TEMPORARY TABLE t_temp_drop (id INT, val VARCHAR(50))") + ctrErr := ctr.execSQL("DROP TEMPORARY TABLE t_temp_drop") + if ctrErr != nil { + t.Fatalf("container DROP TEMPORARY TABLE error: %v", ctrErr) + } + _, ctrShowErr := ctr.showCreateTable("t_temp_drop") + if ctrShowErr == nil { + t.Fatal("container: temp table still exists after DROP TEMPORARY TABLE") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TEMPORARY TABLE t_temp_drop (id INT, val VARCHAR(50))", nil) + results, _ := c.Exec("DROP TEMPORARY TABLE t_temp_drop", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TEMPORARY TABLE error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_temp_drop") + if omniDDL != "" { + t.Errorf("omni: temp table still exists after DROP TEMPORARY TABLE") + } + }) + + t.Run("drop_table_fk_referenced", func(t *testing.T) { + // DROP TABLE that has FK references should error (with foreign_key_checks=1, the default). + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent") + ctr.execSQL("CREATE TABLE t_fk_parent (id INT PRIMARY KEY)") + ctr.execSQL("CREATE TABLE t_fk_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES t_fk_parent(id))") + + ctrErr := ctr.execSQL("DROP TABLE t_fk_parent") + if ctrErr == nil { + t.Fatal("container: expected error when dropping FK-referenced table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_fk_parent (id INT PRIMARY KEY)", nil) + c.Exec("CREATE TABLE t_fk_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES t_fk_parent(id))", nil) + results, _ := c.Exec("DROP TABLE t_fk_parent", nil) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error when dropping FK-referenced table") + } + + // Verify table still exists after failed drop. + omniDDL := c.ShowCreateTable("test", "t_fk_parent") + if omniDDL == "" { + t.Error("omni: parent table was deleted despite FK reference") + } + }) +} + +func TestContainer_Section_2_7_TruncateTable(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Scenario 1: TRUNCATE TABLE t1 — table structure preserved + t.Run("truncate_basic", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_trunc1") + if err := ctr.execSQL("CREATE TABLE t_trunc1 (id INT PRIMARY KEY, name VARCHAR(100))"); err != nil { + t.Fatalf("container create: %v", err) + } + if err := ctr.execSQL("TRUNCATE TABLE t_trunc1"); err != nil { + t.Fatalf("container truncate: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_trunc1") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trunc1 (id INT PRIMARY KEY, name VARCHAR(100))", nil) + results, _ := c.Exec("TRUNCATE TABLE t_trunc1", nil) + if results[0].Error != nil { + t.Fatalf("omni truncate error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_trunc1") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 2: TRUNCATE resets AUTO_INCREMENT + t.Run("truncate_resets_auto_increment", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_trunc_ai") + if err := ctr.execSQL("CREATE TABLE t_trunc_ai (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100)) AUTO_INCREMENT=1000"); err != nil { + t.Fatalf("container create: %v", err) + } + // Verify AUTO_INCREMENT is shown before truncate + ctrBefore, _ := ctr.showCreateTable("t_trunc_ai") + if !strings.Contains(ctrBefore, "AUTO_INCREMENT=") { + t.Logf("container before truncate (no AUTO_INCREMENT shown): %s", ctrBefore) + } + if err := ctr.execSQL("TRUNCATE TABLE t_trunc_ai"); err != nil { + t.Fatalf("container truncate: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_trunc_ai") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trunc_ai (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100)) AUTO_INCREMENT=1000", nil) + c.Exec("TRUNCATE TABLE t_trunc_ai", nil) + omniDDL := c.ShowCreateTable("test", "t_trunc_ai") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 3: TRUNCATE nonexistent table → error + t.Run("truncate_nonexistent", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_trunc_noexist") + ctrErr := ctr.execSQL("TRUNCATE TABLE t_trunc_noexist") + if ctrErr == nil { + t.Fatal("container: expected error for TRUNCATE nonexistent table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("TRUNCATE TABLE t_trunc_noexist", nil) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for TRUNCATE nonexistent table") + } + }) +} + +func TestContainer_Section_2_8_CreateDropIndex(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Scenario 1: CREATE INDEX idx ON t (col) + t.Run("create_index_basic", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ci1") + ctr.execSQL("CREATE TABLE t_ci1 (id INT PRIMARY KEY, name VARCHAR(100))") + if err := ctr.execSQL("CREATE INDEX idx_name ON t_ci1 (name)"); err != nil { + t.Fatalf("container CREATE INDEX error: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_ci1") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci1 (id INT PRIMARY KEY, name VARCHAR(100))", nil) + results, _ := c.Exec("CREATE INDEX idx_name ON t_ci1 (name)", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE INDEX error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_ci1") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 2: CREATE UNIQUE INDEX + t.Run("create_unique_index", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ci2") + ctr.execSQL("CREATE TABLE t_ci2 (id INT PRIMARY KEY, email VARCHAR(255))") + if err := ctr.execSQL("CREATE UNIQUE INDEX idx_email ON t_ci2 (email)"); err != nil { + t.Fatalf("container CREATE UNIQUE INDEX error: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_ci2") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci2 (id INT PRIMARY KEY, email VARCHAR(255))", nil) + results, _ := c.Exec("CREATE UNIQUE INDEX idx_email ON t_ci2 (email)", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE UNIQUE INDEX error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_ci2") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 3: CREATE FULLTEXT INDEX + t.Run("create_fulltext_index", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ci3") + ctr.execSQL("CREATE TABLE t_ci3 (id INT PRIMARY KEY, content TEXT)") + if err := ctr.execSQL("CREATE FULLTEXT INDEX idx_ft ON t_ci3 (content)"); err != nil { + t.Fatalf("container CREATE FULLTEXT INDEX error: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_ci3") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci3 (id INT PRIMARY KEY, content TEXT)", nil) + results, _ := c.Exec("CREATE FULLTEXT INDEX idx_ft ON t_ci3 (content)", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE FULLTEXT INDEX error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_ci3") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 4: CREATE SPATIAL INDEX + t.Run("create_spatial_index", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ci4") + ctr.execSQL("CREATE TABLE t_ci4 (id INT PRIMARY KEY, geo GEOMETRY NOT NULL)") + if err := ctr.execSQL("CREATE SPATIAL INDEX idx_sp ON t_ci4 (geo)"); err != nil { + t.Fatalf("container CREATE SPATIAL INDEX error: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_ci4") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci4 (id INT PRIMARY KEY, geo GEOMETRY NOT NULL)", nil) + results, _ := c.Exec("CREATE SPATIAL INDEX idx_sp ON t_ci4 (geo)", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE SPATIAL INDEX error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_ci4") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 5: CREATE INDEX IF NOT EXISTS + // Note: MySQL 8.0 does NOT support IF NOT EXISTS on CREATE INDEX (syntax error). + // Our parser accepts it, and the catalog handles it gracefully (no-op on duplicate). + // This is marked [~] partial — omni is more permissive than MySQL 8.0 here. + t.Run("create_index_if_not_exists", func(t *testing.T) { + // Verify MySQL 8.0 rejects this syntax. + ctr.execSQL("DROP TABLE IF EXISTS t_ci5") + ctr.execSQL("CREATE TABLE t_ci5 (id INT PRIMARY KEY, val INT)") + ctr.execSQL("CREATE INDEX idx_val ON t_ci5 (val)") + ctrErr := ctr.execSQL("CREATE INDEX IF NOT EXISTS idx_val ON t_ci5 (val)") + if ctrErr == nil { + t.Fatal("container: expected syntax error for CREATE INDEX IF NOT EXISTS in MySQL 8.0") + } + t.Logf("container correctly rejects CREATE INDEX IF NOT EXISTS: %v", ctrErr) + + // Omni accepts IF NOT EXISTS as an extension — verify it doesn't error. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci5 (id INT PRIMARY KEY, val INT)", nil) + c.Exec("CREATE INDEX idx_val ON t_ci5 (val)", nil) + results, _ := c.Exec("CREATE INDEX IF NOT EXISTS idx_val ON t_ci5 (val)", nil) + if results[0].Error != nil { + t.Errorf("omni: unexpected error for IF NOT EXISTS: %v", results[0].Error) + } + // This is a known divergence from MySQL 8.0 behavior. + }) + + // Scenario 6: CREATE INDEX — duplicate name → error 1061 + t.Run("create_index_duplicate_error", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_ci6") + ctr.execSQL("CREATE TABLE t_ci6 (id INT PRIMARY KEY, a INT, b INT)") + ctr.execSQL("CREATE INDEX idx_a ON t_ci6 (a)") + ctrErr := ctr.execSQL("CREATE INDEX idx_a ON t_ci6 (b)") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate index name") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ci6 (id INT PRIMARY KEY, a INT, b INT)", nil) + c.Exec("CREATE INDEX idx_a ON t_ci6 (a)", nil) + results, _ := c.Exec("CREATE INDEX idx_a ON t_ci6 (b)", nil) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for duplicate index name") + } + catErr, ok := omniErr.(*Error) + if !ok { + t.Fatalf("omni error is not *Error: %T", omniErr) + } + if catErr.Code != 1061 { + t.Errorf("omni error code: want 1061, got %d (message: %s)", catErr.Code, catErr.Message) + } + }) + + // Scenario 7: DROP INDEX idx ON t + t.Run("drop_index_basic", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_di1") + ctr.execSQL("CREATE TABLE t_di1 (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_name (name))") + if err := ctr.execSQL("DROP INDEX idx_name ON t_di1"); err != nil { + t.Fatalf("container DROP INDEX error: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_di1") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_di1 (id INT PRIMARY KEY, name VARCHAR(100), KEY idx_name (name))", nil) + results, _ := c.Exec("DROP INDEX idx_name ON t_di1", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP INDEX error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_di1") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // Scenario 8: DROP INDEX nonexistent → error 1091 + t.Run("drop_index_nonexistent_error", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_di2") + ctr.execSQL("CREATE TABLE t_di2 (id INT PRIMARY KEY)") + ctrErr := ctr.execSQL("DROP INDEX idx_noexist ON t_di2") + if ctrErr == nil { + t.Fatal("container: expected error for DROP nonexistent index") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_di2 (id INT PRIMARY KEY)", nil) + results, _ := c.Exec("DROP INDEX idx_noexist ON t_di2", nil) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for DROP nonexistent index") + } + catErr, ok := omniErr.(*Error) + if !ok { + t.Fatalf("omni error is not *Error: %T", omniErr) + } + if catErr.Code != 1091 { + t.Errorf("omni error code: want 1091, got %d (message: %s)", catErr.Code, catErr.Message) + } + }) +} + +func TestContainer_Section_2_9_RenameTable(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("rename_basic", func(t *testing.T) { + // RENAME TABLE t1 TO t2 + ctr.execSQL("DROP TABLE IF EXISTS t_ren1") + ctr.execSQL("DROP TABLE IF EXISTS t_ren2") + ctr.execSQL("CREATE TABLE t_ren1 (id INT PRIMARY KEY, name VARCHAR(100))") + + if err := ctr.execSQL("RENAME TABLE t_ren1 TO t_ren2"); err != nil { + t.Fatalf("container RENAME TABLE error: %v", err) + } + ctrDDL, err := ctr.showCreateTable("t_ren2") + if err != nil { + t.Fatalf("container SHOW CREATE TABLE t_ren2 error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ren1 (id INT PRIMARY KEY, name VARCHAR(100))", nil) + results, _ := c.Exec("RENAME TABLE t_ren1 TO t_ren2", nil) + if results[0].Error != nil { + t.Fatalf("omni RENAME TABLE error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_ren2") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL, omniDDL) + } + // Old table should not exist + if c.ShowCreateTable("test", "t_ren1") != "" { + t.Error("omni: old table t_ren1 still exists after rename") + } + }) + + t.Run("rename_cross_database", func(t *testing.T) { + // RENAME TABLE t1 TO db2.t1 (cross-database) + ctr.execSQL("DROP TABLE IF EXISTS t_ren_cross") + ctr.execSQL("CREATE DATABASE IF NOT EXISTS test2") + ctr.execSQL("DROP TABLE IF EXISTS test2.t_ren_cross") + ctr.execSQL("CREATE TABLE t_ren_cross (id INT PRIMARY KEY, val VARCHAR(50) NOT NULL DEFAULT '')") + + if err := ctr.execSQL("RENAME TABLE t_ren_cross TO test2.t_ren_cross"); err != nil { + t.Fatalf("container RENAME TABLE cross-db error: %v", err) + } + ctrDDL, err := ctr.showCreateTable("test2.t_ren_cross") + if err != nil { + t.Fatalf("container SHOW CREATE TABLE test2.t_ren_cross error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.Exec("CREATE DATABASE test2", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ren_cross (id INT PRIMARY KEY, val VARCHAR(50) NOT NULL DEFAULT '')", nil) + results, _ := c.Exec("RENAME TABLE t_ren_cross TO test2.t_ren_cross", nil) + if results[0].Error != nil { + t.Fatalf("omni RENAME TABLE cross-db error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test2", "t_ren_cross") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL, omniDDL) + } + // Old table should not exist in test + if c.ShowCreateTable("test", "t_ren_cross") != "" { + t.Error("omni: old table still exists in test after cross-db rename") + } + }) + + t.Run("rename_multi_pair", func(t *testing.T) { + // RENAME TABLE t1 TO t2, t3 TO t4 (multi-pair) + ctr.execSQL("DROP TABLE IF EXISTS t_mp1") + ctr.execSQL("DROP TABLE IF EXISTS t_mp2") + ctr.execSQL("DROP TABLE IF EXISTS t_mp3") + ctr.execSQL("DROP TABLE IF EXISTS t_mp4") + ctr.execSQL("CREATE TABLE t_mp1 (id INT PRIMARY KEY)") + ctr.execSQL("CREATE TABLE t_mp3 (val VARCHAR(100))") + + if err := ctr.execSQL("RENAME TABLE t_mp1 TO t_mp2, t_mp3 TO t_mp4"); err != nil { + t.Fatalf("container RENAME TABLE multi-pair error: %v", err) + } + ctrDDL2, err := ctr.showCreateTable("t_mp2") + if err != nil { + t.Fatalf("container SHOW CREATE TABLE t_mp2 error: %v", err) + } + ctrDDL4, err := ctr.showCreateTable("t_mp4") + if err != nil { + t.Fatalf("container SHOW CREATE TABLE t_mp4 error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_mp1 (id INT PRIMARY KEY)", nil) + c.Exec("CREATE TABLE t_mp3 (val VARCHAR(100))", nil) + results, _ := c.Exec("RENAME TABLE t_mp1 TO t_mp2, t_mp3 TO t_mp4", nil) + if results[0].Error != nil { + t.Fatalf("omni RENAME TABLE multi-pair error: %v", results[0].Error) + } + omniDDL2 := c.ShowCreateTable("test", "t_mp2") + omniDDL4 := c.ShowCreateTable("test", "t_mp4") + + if normalizeWhitespace(ctrDDL2) != normalizeWhitespace(omniDDL2) { + t.Errorf("t_mp2 mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL2, omniDDL2) + } + if normalizeWhitespace(ctrDDL4) != normalizeWhitespace(omniDDL4) { + t.Errorf("t_mp4 mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL4, omniDDL4) + } + // Old tables should be gone + if c.ShowCreateTable("test", "t_mp1") != "" { + t.Error("omni: t_mp1 still exists after rename") + } + if c.ShowCreateTable("test", "t_mp3") != "" { + t.Error("omni: t_mp3 still exists after rename") + } + }) + + t.Run("rename_nonexistent_error", func(t *testing.T) { + // RENAME TABLE nonexistent → error + ctr.execSQL("DROP TABLE IF EXISTS t_noexist_ren") + ctrErr := ctr.execSQL("RENAME TABLE t_noexist_ren TO t_noexist_ren2") + if ctrErr == nil { + t.Fatal("container: expected error for RENAME nonexistent table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("RENAME TABLE t_noexist_ren TO t_noexist_ren2", &ExecOptions{ContinueOnError: true}) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for RENAME nonexistent table") + } + }) + + t.Run("rename_to_existing_error", func(t *testing.T) { + // RENAME TABLE to existing name → error + ctr.execSQL("DROP TABLE IF EXISTS t_ren_exist1") + ctr.execSQL("DROP TABLE IF EXISTS t_ren_exist2") + ctr.execSQL("CREATE TABLE t_ren_exist1 (id INT)") + ctr.execSQL("CREATE TABLE t_ren_exist2 (id INT)") + + ctrErr := ctr.execSQL("RENAME TABLE t_ren_exist1 TO t_ren_exist2") + if ctrErr == nil { + t.Fatal("container: expected error for RENAME to existing table name") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_ren_exist1 (id INT)", nil) + c.Exec("CREATE TABLE t_ren_exist2 (id INT)", nil) + results, _ := c.Exec("RENAME TABLE t_ren_exist1 TO t_ren_exist2", &ExecOptions{ContinueOnError: true}) + omniErr := results[0].Error + if omniErr == nil { + t.Fatal("omni: expected error for RENAME to existing table name") + } + }) +} + +func TestContainer_Section_2_10_CreateDropView(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("create_view_basic", func(t *testing.T) { + // CREATE VIEW v AS SELECT ... + ctr.execSQL("DROP VIEW IF EXISTS v_basic") + ctrErr := ctr.execSQL("CREATE VIEW v_basic AS SELECT 1 AS a") + if ctrErr != nil { + t.Fatalf("container CREATE VIEW error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE VIEW v_basic AS SELECT 1 AS a", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE VIEW error: %v", results[0].Error) + } + db := c.GetDatabase("test") + if db.Views[toLower("v_basic")] == nil { + t.Error("omni: view v_basic should exist after CREATE VIEW") + } + }) + + t.Run("create_or_replace_view", func(t *testing.T) { + // CREATE OR REPLACE VIEW + ctr.execSQL("DROP VIEW IF EXISTS v_replace") + ctr.execSQL("CREATE VIEW v_replace AS SELECT 1 AS a") + ctrErr := ctr.execSQL("CREATE OR REPLACE VIEW v_replace AS SELECT 2 AS b") + if ctrErr != nil { + t.Fatalf("container CREATE OR REPLACE VIEW error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v_replace AS SELECT 1 AS a", nil) + results, _ := c.Exec("CREATE OR REPLACE VIEW v_replace AS SELECT 2 AS b", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE OR REPLACE VIEW error: %v", results[0].Error) + } + db := c.GetDatabase("test") + if db.Views[toLower("v_replace")] == nil { + t.Error("omni: view v_replace should exist after CREATE OR REPLACE VIEW") + } + }) + + t.Run("create_view_with_columns", func(t *testing.T) { + // CREATE VIEW with column list + ctr.execSQL("DROP VIEW IF EXISTS v_cols") + ctrErr := ctr.execSQL("CREATE VIEW v_cols (x, y) AS SELECT 1, 2") + if ctrErr != nil { + t.Fatalf("container CREATE VIEW with columns error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE VIEW v_cols (x, y) AS SELECT 1, 2", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE VIEW with columns error: %v", results[0].Error) + } + db := c.GetDatabase("test") + v := db.Views[toLower("v_cols")] + if v == nil { + t.Fatal("omni: view v_cols should exist") + } + if len(v.Columns) != 2 || v.Columns[0] != "x" || v.Columns[1] != "y" { + t.Errorf("omni: expected columns [x, y], got %v", v.Columns) + } + }) + + t.Run("create_view_with_options", func(t *testing.T) { + // CREATE VIEW with ALGORITHM, DEFINER, SQL_SECURITY + ctr.execSQL("DROP VIEW IF EXISTS v_opts") + ctrErr := ctr.execSQL("CREATE ALGORITHM=MERGE DEFINER=`root`@`localhost` SQL SECURITY INVOKER VIEW v_opts AS SELECT 1 AS a") + if ctrErr != nil { + t.Fatalf("container CREATE VIEW with options error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE ALGORITHM=MERGE DEFINER=`root`@`localhost` SQL SECURITY INVOKER VIEW v_opts AS SELECT 1 AS a", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE VIEW with options error: %v", results[0].Error) + } + db := c.GetDatabase("test") + v := db.Views[toLower("v_opts")] + if v == nil { + t.Fatal("omni: view v_opts should exist") + } + if v.Algorithm != "MERGE" { + t.Errorf("omni: expected Algorithm=MERGE, got %q", v.Algorithm) + } + if v.SqlSecurity != "INVOKER" { + t.Errorf("omni: expected SqlSecurity=INVOKER, got %q", v.SqlSecurity) + } + }) + + t.Run("create_view_with_check_option", func(t *testing.T) { + // CREATE VIEW with CHECK OPTION + ctr.execSQL("DROP VIEW IF EXISTS v_chk") + ctr.execSQL("DROP TABLE IF EXISTS t_chk_view") + ctr.execSQL("CREATE TABLE t_chk_view (id INT, val INT)") + ctrErr := ctr.execSQL("CREATE VIEW v_chk AS SELECT * FROM t_chk_view WHERE val > 0 WITH CASCADED CHECK OPTION") + if ctrErr != nil { + t.Fatalf("container CREATE VIEW WITH CHECK OPTION error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_chk_view (id INT, val INT)", nil) + results, _ := c.Exec("CREATE VIEW v_chk AS SELECT * FROM t_chk_view WHERE val > 0 WITH CASCADED CHECK OPTION", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE VIEW WITH CHECK OPTION error: %v", results[0].Error) + } + db := c.GetDatabase("test") + v := db.Views[toLower("v_chk")] + if v == nil { + t.Fatal("omni: view v_chk should exist") + } + if v.CheckOption != "CASCADED" { + t.Errorf("omni: expected CheckOption=CASCADED, got %q", v.CheckOption) + } + }) + + t.Run("drop_view_basic", func(t *testing.T) { + // DROP VIEW v + ctr.execSQL("DROP VIEW IF EXISTS v_drop1") + ctr.execSQL("CREATE VIEW v_drop1 AS SELECT 1 AS a") + ctrErr := ctr.execSQL("DROP VIEW v_drop1") + if ctrErr != nil { + t.Fatalf("container DROP VIEW error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v_drop1 AS SELECT 1 AS a", nil) + results, _ := c.Exec("DROP VIEW v_drop1", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP VIEW error: %v", results[0].Error) + } + db := c.GetDatabase("test") + if db.Views[toLower("v_drop1")] != nil { + t.Error("omni: view v_drop1 should not exist after DROP VIEW") + } + }) + + t.Run("drop_view_if_exists", func(t *testing.T) { + // DROP VIEW IF EXISTS on nonexistent view — no error + ctr.execSQL("DROP VIEW IF EXISTS v_noexist") + ctrErr := ctr.execSQL("DROP VIEW IF EXISTS v_noexist") + if ctrErr != nil { + t.Fatalf("container DROP VIEW IF EXISTS error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP VIEW IF EXISTS v_noexist", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP VIEW IF EXISTS error: %v", results[0].Error) + } + }) + + t.Run("drop_view_multi", func(t *testing.T) { + // DROP VIEW v1, v2 (multi-view) + ctr.execSQL("DROP VIEW IF EXISTS v_m1") + ctr.execSQL("DROP VIEW IF EXISTS v_m2") + ctr.execSQL("CREATE VIEW v_m1 AS SELECT 1 AS a") + ctr.execSQL("CREATE VIEW v_m2 AS SELECT 2 AS b") + ctrErr := ctr.execSQL("DROP VIEW v_m1, v_m2") + if ctrErr != nil { + t.Fatalf("container DROP VIEW multi error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v_m1 AS SELECT 1 AS a", nil) + c.Exec("CREATE VIEW v_m2 AS SELECT 2 AS b", nil) + results, _ := c.Exec("DROP VIEW v_m1, v_m2", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP VIEW multi error: %v", results[0].Error) + } + db := c.GetDatabase("test") + if db.Views[toLower("v_m1")] != nil { + t.Error("omni: view v_m1 should not exist after DROP VIEW") + } + if db.Views[toLower("v_m2")] != nil { + t.Error("omni: view v_m2 should not exist after DROP VIEW") + } + }) + + // Extra: CREATE VIEW duplicate (no OR REPLACE) should error + t.Run("create_view_duplicate_error", func(t *testing.T) { + ctr.execSQL("DROP VIEW IF EXISTS v_dup") + ctr.execSQL("CREATE VIEW v_dup AS SELECT 1 AS a") + ctrErr := ctr.execSQL("CREATE VIEW v_dup AS SELECT 2 AS b") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate view") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v_dup AS SELECT 1 AS a", nil) + results, _ := c.Exec("CREATE VIEW v_dup AS SELECT 2 AS b", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for duplicate view") + } + }) + + // Extra: CREATE VIEW with same name as existing table should error + t.Run("create_view_table_conflict", func(t *testing.T) { + ctr.execSQL("DROP VIEW IF EXISTS v_tbl_conflict") + ctr.execSQL("DROP TABLE IF EXISTS v_tbl_conflict") + ctr.execSQL("CREATE TABLE v_tbl_conflict (id INT)") + ctrErr := ctr.execSQL("CREATE VIEW v_tbl_conflict AS SELECT 1 AS a") + if ctrErr == nil { + t.Fatal("container: expected error for view with same name as table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE v_tbl_conflict (id INT)", nil) + results, _ := c.Exec("CREATE VIEW v_tbl_conflict AS SELECT 1 AS a", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for view with same name as table") + } + }) + + // Extra: DROP VIEW on nonexistent view (no IF EXISTS) should error + t.Run("drop_view_nonexistent_error", func(t *testing.T) { + ctr.execSQL("DROP VIEW IF EXISTS v_nonexist_err") + ctrErr := ctr.execSQL("DROP VIEW v_nonexist_err") + if ctrErr == nil { + t.Fatal("container: expected error for DROP VIEW on nonexistent view") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP VIEW v_nonexist_err", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for DROP VIEW on nonexistent view") + } + }) +} + +func TestContainer_Section_2_11_CreateDropAlterDatabase(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("create_database", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_create1") + ctrErr := ctr.execSQL("CREATE DATABASE db_create1") + if ctrErr != nil { + t.Fatalf("container CREATE DATABASE error: %v", ctrErr) + } + ctrDDL, err := ctr.showCreateDatabase("db_create1") + if err != nil { + t.Fatalf("container SHOW CREATE DATABASE error: %v", err) + } + + c := New() + results, _ := c.Exec("CREATE DATABASE db_create1", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE DATABASE error: %v", results[0].Error) + } + db := c.GetDatabase("db_create1") + if db == nil { + t.Fatal("omni: database not found after CREATE DATABASE") + } + // Verify charset/collation defaults match ctr. + // Oracle returns something like: CREATE DATABASE `db_create1` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci */ /*!80016 DEFAULT ENCRYPTION='N' */ + if !strings.Contains(ctrDDL, "utf8mb4") { + t.Logf("container DDL: %s", ctrDDL) + } + if db.Charset != "utf8mb4" { + t.Errorf("omni charset mismatch: got %q, want utf8mb4", db.Charset) + } + if db.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("omni collation mismatch: got %q, want utf8mb4_0900_ai_ci", db.Collation) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_create1") + }) + + t.Run("create_database_if_not_exists", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_ine") + ctr.execSQL("CREATE DATABASE db_ine") + + // CREATE DATABASE IF NOT EXISTS on existing db should succeed (no error). + ctrErr := ctr.execSQL("CREATE DATABASE IF NOT EXISTS db_ine") + if ctrErr != nil { + t.Fatalf("container CREATE DATABASE IF NOT EXISTS error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE db_ine", nil) + results, _ := c.Exec("CREATE DATABASE IF NOT EXISTS db_ine", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE DATABASE IF NOT EXISTS error: %v", results[0].Error) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_ine") + }) + + t.Run("create_database_charset", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_cs") + ctrErr := ctr.execSQL("CREATE DATABASE db_cs CHARACTER SET latin1") + if ctrErr != nil { + t.Fatalf("container CREATE DATABASE with charset error: %v", ctrErr) + } + ctrDDL, _ := ctr.showCreateDatabase("db_cs") + + c := New() + results, _ := c.Exec("CREATE DATABASE db_cs CHARACTER SET latin1", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE DATABASE with charset error: %v", results[0].Error) + } + db := c.GetDatabase("db_cs") + if db == nil { + t.Fatal("omni: database not found") + } + // Oracle should show latin1 charset + if !strings.Contains(ctrDDL, "latin1") { + t.Errorf("container DDL missing latin1: %s", ctrDDL) + } + if db.Charset != "latin1" { + t.Errorf("omni charset: got %q, want latin1", db.Charset) + } + // Default collation for latin1 is latin1_swedish_ci + if db.Collation != "latin1_swedish_ci" { + t.Errorf("omni collation: got %q, want latin1_swedish_ci", db.Collation) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_cs") + }) + + t.Run("create_database_collate", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_coll") + ctrErr := ctr.execSQL("CREATE DATABASE db_coll COLLATE utf8mb4_unicode_ci") + if ctrErr != nil { + t.Fatalf("container CREATE DATABASE with collate error: %v", ctrErr) + } + ctrDDL, _ := ctr.showCreateDatabase("db_coll") + + c := New() + results, _ := c.Exec("CREATE DATABASE db_coll COLLATE utf8mb4_unicode_ci", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE DATABASE with collate error: %v", results[0].Error) + } + db := c.GetDatabase("db_coll") + if db == nil { + t.Fatal("omni: database not found") + } + if !strings.Contains(ctrDDL, "utf8mb4_unicode_ci") { + t.Errorf("container DDL missing utf8mb4_unicode_ci: %s", ctrDDL) + } + if db.Collation != "utf8mb4_unicode_ci" { + t.Errorf("omni collation: got %q, want utf8mb4_unicode_ci", db.Collation) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_coll") + }) + + t.Run("drop_database", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_drop1") + ctr.execSQL("CREATE DATABASE db_drop1") + + ctrErr := ctr.execSQL("DROP DATABASE db_drop1") + if ctrErr != nil { + t.Fatalf("container DROP DATABASE error: %v", ctrErr) + } + // Verify it's gone. + _, showErr := ctr.showCreateDatabase("db_drop1") + if showErr == nil { + t.Fatal("container: database still exists after DROP DATABASE") + } + + c := New() + c.Exec("CREATE DATABASE db_drop1", nil) + results, _ := c.Exec("DROP DATABASE db_drop1", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP DATABASE error: %v", results[0].Error) + } + if c.GetDatabase("db_drop1") != nil { + t.Fatal("omni: database still exists after DROP DATABASE") + } + }) + + t.Run("drop_database_if_exists", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_drop_ine") + // DROP DATABASE IF EXISTS on nonexistent db should not error. + ctrErr := ctr.execSQL("DROP DATABASE IF EXISTS db_drop_ine") + if ctrErr != nil { + t.Fatalf("container DROP DATABASE IF EXISTS error: %v", ctrErr) + } + + c := New() + results, _ := c.Exec("DROP DATABASE IF EXISTS db_drop_ine", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP DATABASE IF EXISTS error: %v", results[0].Error) + } + }) + + t.Run("alter_database_charset", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_alter_cs") + ctr.execSQL("CREATE DATABASE db_alter_cs") + + ctrErr := ctr.execSQL("ALTER DATABASE db_alter_cs CHARACTER SET utf8mb4") + if ctrErr != nil { + t.Fatalf("container ALTER DATABASE charset error: %v", ctrErr) + } + ctrDDL, _ := ctr.showCreateDatabase("db_alter_cs") + + c := New() + c.Exec("CREATE DATABASE db_alter_cs", nil) + results, _ := c.Exec("ALTER DATABASE db_alter_cs CHARACTER SET utf8mb4", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER DATABASE charset error: %v", results[0].Error) + } + db := c.GetDatabase("db_alter_cs") + if db == nil { + t.Fatal("omni: database not found") + } + if !strings.Contains(ctrDDL, "utf8mb4") { + t.Errorf("container DDL missing utf8mb4: %s", ctrDDL) + } + if db.Charset != "utf8mb4" { + t.Errorf("omni charset: got %q, want utf8mb4", db.Charset) + } + if db.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("omni collation: got %q, want utf8mb4_0900_ai_ci", db.Collation) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_alter_cs") + }) + + t.Run("alter_database_collate", func(t *testing.T) { + ctr.execSQL("DROP DATABASE IF EXISTS db_alter_coll") + ctr.execSQL("CREATE DATABASE db_alter_coll") + + ctrErr := ctr.execSQL("ALTER DATABASE db_alter_coll COLLATE utf8mb4_unicode_ci") + if ctrErr != nil { + t.Fatalf("container ALTER DATABASE collate error: %v", ctrErr) + } + ctrDDL, _ := ctr.showCreateDatabase("db_alter_coll") + + c := New() + c.Exec("CREATE DATABASE db_alter_coll", nil) + results, _ := c.Exec("ALTER DATABASE db_alter_coll COLLATE utf8mb4_unicode_ci", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER DATABASE collate error: %v", results[0].Error) + } + db := c.GetDatabase("db_alter_coll") + if db == nil { + t.Fatal("omni: database not found") + } + if !strings.Contains(ctrDDL, "utf8mb4_unicode_ci") { + t.Errorf("container DDL missing utf8mb4_unicode_ci: %s", ctrDDL) + } + if db.Collation != "utf8mb4_unicode_ci" { + t.Errorf("omni collation: got %q, want utf8mb4_unicode_ci", db.Collation) + } + ctr.execSQL("DROP DATABASE IF EXISTS db_alter_coll") + }) + + t.Run("ops_on_nonexistent_database", func(t *testing.T) { + // DROP DATABASE on nonexistent db should error. + ctr.execSQL("DROP DATABASE IF EXISTS db_nonexist_xyz") + oracleDropErr := ctr.execSQL("DROP DATABASE db_nonexist_xyz") + if oracleDropErr == nil { + t.Fatal("container: expected error for DROP nonexistent database") + } + + c := New() + results, _ := c.Exec("DROP DATABASE db_nonexist_xyz", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for DROP nonexistent database") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrUnknownDatabase { + t.Errorf("omni error code: got %d, want %d", catErr.Code, ErrUnknownDatabase) + } + + // ALTER DATABASE on nonexistent db should error. + oracleAlterErr := ctr.execSQL("ALTER DATABASE db_nonexist_xyz CHARACTER SET utf8mb4") + if oracleAlterErr == nil { + t.Fatal("container: expected error for ALTER nonexistent database") + } + + c2 := New() + results2, _ := c2.Exec("ALTER DATABASE db_nonexist_xyz CHARACTER SET utf8mb4", &ExecOptions{ContinueOnError: true}) + if results2[0].Error == nil { + t.Fatal("omni: expected error for ALTER nonexistent database") + } + catErr2, ok := results2[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results2[0].Error) + } + if catErr2.Code != ErrUnknownDatabase { + t.Errorf("omni error code: got %d, want %d", catErr2.Code, ErrUnknownDatabase) + } + + // CREATE DATABASE duplicate should error. + c3 := New() + c3.Exec("CREATE DATABASE db_dup_test", nil) + results3, _ := c3.Exec("CREATE DATABASE db_dup_test", &ExecOptions{ContinueOnError: true}) + if results3[0].Error == nil { + t.Fatal("omni: expected error for duplicate CREATE DATABASE") + } + catErr3, ok := results3[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results3[0].Error) + } + if catErr3.Code != ErrDupDatabase { + t.Errorf("omni error code: got %d, want %d", catErr3.Code, ErrDupDatabase) + } + }) +} + +func TestContainer_Section_3_1_DatabaseErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("1007_dup_database", func(t *testing.T) { + // Setup: create the database first, then try to create it again. + ctr.execSQL("DROP DATABASE IF EXISTS db_err_dup") + ctr.execSQL("CREATE DATABASE db_err_dup") + + ctrErr := ctr.execSQL("CREATE DATABASE db_err_dup") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate CREATE DATABASE") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1007 { + t.Fatalf("container: expected error code 1007, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE db_err_dup", nil) + results, _ := c.Exec("CREATE DATABASE db_err_dup", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for duplicate CREATE DATABASE") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + // Oracle: "Can't create database 'db_err_dup'; database exists" + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP DATABASE IF EXISTS db_err_dup") + }) + + t.Run("1049_unknown_database", func(t *testing.T) { + // USE a nonexistent database on ctr. + ctr.execSQL("DROP DATABASE IF EXISTS db_err_unknown_xyz") + ctrErr := ctr.execSQL("USE db_err_unknown_xyz") + if ctrErr == nil { + t.Fatal("container: expected error for USE nonexistent database") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1049 { + t.Fatalf("container: expected error code 1049, got %d", ctrCode) + } + + // Run on omni: DROP DATABASE on nonexistent db triggers 1008, but + // we need to test the "Unknown database" error (1049). + // Use DROP DATABASE which should return 1008 in MySQL... + // Actually, let's check what MySQL returns for DROP DATABASE on nonexistent: + ctrErr2 := ctr.execSQL("DROP DATABASE db_err_unknown_xyz") + if ctrErr2 == nil { + t.Fatal("container: expected error for DROP nonexistent database") + } + ctrCode2, _, _ := extractMySQLErr(ctrErr2) + t.Logf("container DROP error code: %d", ctrCode2) + + // For omni, test USE nonexistent database. + c := New() + results, _ := c.Exec("USE db_err_unknown_xyz", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for USE nonexistent database") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) + + t.Run("1046_no_database_selected", func(t *testing.T) { + // On ctr, try to CREATE TABLE without selecting a database. + // We need a fresh connection with no default database. + // The container connection defaults to "test" database, so we'll + // test omni behavior and verify the error code/SQLSTATE/message match MySQL's known format. + + // First verify MySQL's behavior: SELECT DATABASE() after no USE should work, + // but CREATE TABLE without database should fail. + // Since our container connection defaults to 'test' db, we verify omni matches MySQL's + // documented error format: ERROR 1046 (3D000): No database selected + + // Run on omni with no current database. + c := New() + // Don't set any database — just try to create a table. + results, _ := c.Exec("CREATE TABLE t_no_db (id INT)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for CREATE TABLE without database selected") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Verify against MySQL's documented error values. + wantCode := 1046 + wantState := "3D000" + wantMsg := "No database selected" + + if catErr.Code != wantCode { + t.Errorf("error code: got %d, want %d", catErr.Code, wantCode) + } + if catErr.SQLState != wantState { + t.Errorf("SQLSTATE: got %q, want %q", catErr.SQLState, wantState) + } + if catErr.Message != wantMsg { + t.Errorf("message: got %q, want %q", catErr.Message, wantMsg) + } + }) +} + +func TestContainer_Section_3_2_TableErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("1050_table_already_exists", func(t *testing.T) { + // Setup: create a table, then try to create it again. + ctr.execSQL("DROP TABLE IF EXISTS t_err_dup") + ctr.execSQL("CREATE TABLE t_err_dup (id INT)") + + ctrErr := ctr.execSQL("CREATE TABLE t_err_dup (id INT)") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate CREATE TABLE") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1050 { + t.Fatalf("container: expected error code 1050, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_err_dup (id INT)", nil) + results, _ := c.Exec("CREATE TABLE t_err_dup (id INT)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for duplicate CREATE TABLE") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_err_dup") + }) + + t.Run("1051_unknown_table_drop", func(t *testing.T) { + // DROP TABLE on a nonexistent table should return 1051. + ctr.execSQL("DROP TABLE IF EXISTS t_err_noexist") + + ctrErr := ctr.execSQL("DROP TABLE t_err_noexist") + if ctrErr == nil { + t.Fatal("container: expected error for DROP nonexistent table") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1051 { + t.Fatalf("container: expected error code 1051, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE t_err_noexist", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for DROP nonexistent table") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) + + t.Run("1146_table_doesnt_exist", func(t *testing.T) { + // ALTER TABLE on a nonexistent table should return 1146. + ctr.execSQL("DROP TABLE IF EXISTS t_err_noexist2") + + ctrErr := ctr.execSQL("ALTER TABLE t_err_noexist2 ADD COLUMN x INT") + if ctrErr == nil { + t.Fatal("container: expected error for ALTER nonexistent table") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1146 { + t.Fatalf("container: expected error code 1146, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("ALTER TABLE t_err_noexist2 ADD COLUMN x INT", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for ALTER nonexistent table") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) +} + +func TestContainer_Section_3_3_ColumnErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("1054_unknown_column", func(t *testing.T) { + // ALTER TABLE ... MODIFY COLUMN AFTER nonexistent_col triggers 1054 + // "Unknown column 'col' in 'table definition'" + ctr.execSQL("DROP TABLE IF EXISTS t_err_nocol") + ctr.execSQL("CREATE TABLE t_err_nocol (id INT, name VARCHAR(50))") + + ctrErr := ctr.execSQL("ALTER TABLE t_err_nocol MODIFY COLUMN name VARCHAR(50) AFTER nonexistent") + if ctrErr == nil { + t.Fatal("container: expected error for unknown column in AFTER clause") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1054 { + t.Fatalf("container: expected error code 1054, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_err_nocol (id INT, name VARCHAR(50))", nil) + results, _ := c.Exec("ALTER TABLE t_err_nocol MODIFY COLUMN name VARCHAR(50) AFTER nonexistent", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for unknown column in AFTER clause") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_err_nocol") + }) + + t.Run("1060_duplicate_column", func(t *testing.T) { + // CREATE TABLE with two columns of the same name. + ctr.execSQL("DROP TABLE IF EXISTS t_err_dupcol") + + ctrErr := ctr.execSQL("CREATE TABLE t_err_dupcol (a INT, a VARCHAR(10))") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate column name") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1060 { + t.Fatalf("container: expected error code 1060, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE TABLE t_err_dupcol (a INT, a VARCHAR(10))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for duplicate column name") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) + + t.Run("1068_multiple_primary_key", func(t *testing.T) { + // CREATE TABLE with two PRIMARY KEY definitions. + ctr.execSQL("DROP TABLE IF EXISTS t_err_multipk") + + ctrErr := ctr.execSQL("CREATE TABLE t_err_multipk (a INT, b INT, PRIMARY KEY (a), PRIMARY KEY (b))") + if ctrErr == nil { + t.Fatal("container: expected error for multiple primary key") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1068 { + t.Fatalf("container: expected error code 1068, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE TABLE t_err_multipk (a INT, b INT, PRIMARY KEY (a), PRIMARY KEY (b))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for multiple primary key") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) +} + +func TestContainer_Section_3_4_IndexKeyErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("1061_dup_key_name", func(t *testing.T) { + // Setup: create a table with an index, then try to add another index with the same name. + ctr.execSQL("DROP TABLE IF EXISTS t_dup_key") + ctr.execSQL("CREATE TABLE t_dup_key (a INT, b INT, KEY idx_a (a))") + + ctrErr := ctr.execSQL("ALTER TABLE t_dup_key ADD INDEX idx_a (b)") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate key name") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1061 { + t.Fatalf("container: expected error code 1061, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_dup_key (a INT, b INT, KEY idx_a (a))", nil) + results, _ := c.Exec("ALTER TABLE t_dup_key ADD INDEX idx_a (b)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for duplicate key name") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_dup_key") + }) + + t.Run("1091_cant_drop_key", func(t *testing.T) { + // Setup: create a table, then try to drop a nonexistent index. + ctr.execSQL("DROP TABLE IF EXISTS t_drop_key") + ctr.execSQL("CREATE TABLE t_drop_key (a INT)") + + ctrErr := ctr.execSQL("ALTER TABLE t_drop_key DROP INDEX idx_nonexistent") + if ctrErr == nil { + t.Fatal("container: expected error for dropping nonexistent key") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1091 { + t.Fatalf("container: expected error code 1091, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_drop_key (a INT)", nil) + results, _ := c.Exec("ALTER TABLE t_drop_key DROP INDEX idx_nonexistent", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for dropping nonexistent key") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_drop_key") + }) +} + +func TestContainer_Section_3_5_FKErrors(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("1824_fk_ref_table_not_found", func(t *testing.T) { + // Try to create a table with FK referencing a nonexistent table. + ctr.execSQL("DROP TABLE IF EXISTS t_fk_noref") + ctr.execSQL("DROP TABLE IF EXISTS t_nonexistent_parent") + + ctrErr := ctr.execSQL("CREATE TABLE t_fk_noref (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_nonexistent_parent(id))") + if ctrErr == nil { + t.Fatal("container: expected error for FK referencing nonexistent table") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1824 { + t.Fatalf("container: expected error code 1824, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE TABLE t_fk_noref (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_nonexistent_parent(id))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for FK referencing nonexistent table") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T: %v", results[0].Error, results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) + + t.Run("1822_fk_missing_index_on_ref_table", func(t *testing.T) { + // Create a parent table without a key on the referenced column, + // then try to create a child with FK referencing that column. + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child_nokey") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent_nokey") + ctr.execSQL("CREATE TABLE t_fk_parent_nokey (id INT, val INT)") + + ctrErr := ctr.execSQL("CREATE TABLE t_fk_child_nokey (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_fk_parent_nokey(val))") + if ctrErr == nil { + t.Fatal("container: expected error for FK referencing column without index") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1822 { + t.Fatalf("container: expected error code 1822, got %d", ctrCode) + } + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_fk_parent_nokey (id INT, val INT)", nil) + results, _ := c.Exec("CREATE TABLE t_fk_child_nokey (id INT, pid INT, FOREIGN KEY (pid) REFERENCES t_fk_parent_nokey(val))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for FK referencing column without index") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T: %v", results[0].Error, results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child_nokey") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent_nokey") + }) + + t.Run("3780_fk_column_type_mismatch", func(t *testing.T) { + // Create parent with INT PK, then try to create child with VARCHAR FK column. + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child_mismatch") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent_mismatch") + ctr.execSQL("CREATE TABLE t_fk_parent_mismatch (id INT PRIMARY KEY)") + + ctrErr := ctr.execSQL("CREATE TABLE t_fk_child_mismatch (id INT, pid VARCHAR(50), FOREIGN KEY (pid) REFERENCES t_fk_parent_mismatch(id))") + if ctrErr == nil { + t.Fatal("container: expected error for FK column type mismatch") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + // Run on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_fk_parent_mismatch (id INT PRIMARY KEY)", nil) + results, _ := c.Exec("CREATE TABLE t_fk_child_mismatch (id INT, pid VARCHAR(50), FOREIGN KEY (pid) REFERENCES t_fk_parent_mismatch(id))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for FK column type mismatch") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T: %v", results[0].Error, results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message format. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + ctr.execSQL("DROP TABLE IF EXISTS t_fk_child_mismatch") + ctr.execSQL("DROP TABLE IF EXISTS t_fk_parent_mismatch") + }) +} + +func TestContainer_Section_3_6_ErrorContext(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code and message from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + // Scenario 1: Error message identifier quoting matches MySQL + // MySQL uses single quotes around identifiers in error messages (not backticks). + // Test a variety of error types and compare message format exactly. + t.Run("identifier_quoting", func(t *testing.T) { + // 1a: Duplicate database — quotes around db name + ctr.execSQL("DROP DATABASE IF EXISTS db_quote_test") + ctr.execSQL("CREATE DATABASE db_quote_test") + ctrErr := ctr.execSQL("CREATE DATABASE db_quote_test") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container dup db: %d %s", ctrCode, ctrMsg) + + c := New() + c.Exec("CREATE DATABASE db_quote_test", nil) + results, _ := c.Exec("CREATE DATABASE db_quote_test", &ExecOptions{ContinueOnError: true}) + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("dup db message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // 1b: Duplicate table — quotes around table name + ctr.execSQL("DROP TABLE IF EXISTS t_quote_test") + ctr.execSQL("CREATE TABLE t_quote_test (id INT)") + ctrErr = ctr.execSQL("CREATE TABLE t_quote_test (id INT)") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg = extractMySQLErr(ctrErr) + t.Logf("container dup table: %d %s", ctrCode, ctrMsg) + + c2 := New() + c2.Exec("CREATE DATABASE test", nil) + c2.SetCurrentDatabase("test") + c2.Exec("CREATE TABLE t_quote_test (id INT)", nil) + results, _ = c2.Exec("CREATE TABLE t_quote_test (id INT)", &ExecOptions{ContinueOnError: true}) + catErr, ok = results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("dup table message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // 1c: Unknown column — quotes around column name and context + ctr.execSQL("DROP TABLE IF EXISTS t_col_quote") + ctr.execSQL("CREATE TABLE t_col_quote (id INT)") + ctrErr = ctr.execSQL("ALTER TABLE t_col_quote DROP COLUMN nonexistent") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg = extractMySQLErr(ctrErr) + t.Logf("container unknown col: %d %s", ctrCode, ctrMsg) + + c3 := New() + c3.Exec("CREATE DATABASE test", nil) + c3.SetCurrentDatabase("test") + c3.Exec("CREATE TABLE t_col_quote (id INT)", nil) + results, _ = c3.Exec("ALTER TABLE t_col_quote DROP COLUMN nonexistent", &ExecOptions{ContinueOnError: true}) + catErr, ok = results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("unknown col message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // 1d: Duplicate key name — quotes around key name + ctr.execSQL("DROP TABLE IF EXISTS t_key_quote") + ctr.execSQL("CREATE TABLE t_key_quote (id INT, val INT, KEY idx_val (val))") + ctrErr = ctr.execSQL("ALTER TABLE t_key_quote ADD INDEX idx_val (id)") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg = extractMySQLErr(ctrErr) + t.Logf("container dup key: %d %s", ctrCode, ctrMsg) + + c4 := New() + c4.Exec("CREATE DATABASE test", nil) + c4.SetCurrentDatabase("test") + c4.Exec("CREATE TABLE t_key_quote (id INT, val INT, KEY idx_val (val))", nil) + results, _ = c4.Exec("ALTER TABLE t_key_quote ADD INDEX idx_val (id)", &ExecOptions{ContinueOnError: true}) + catErr, ok = results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("dup key message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // 1e: Can't drop key — quotes around key name + ctr.execSQL("DROP TABLE IF EXISTS t_dropkey_quote") + ctr.execSQL("CREATE TABLE t_dropkey_quote (id INT)") + ctrErr = ctr.execSQL("ALTER TABLE t_dropkey_quote DROP INDEX nokey") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg = extractMySQLErr(ctrErr) + t.Logf("container can't drop key: %d %s", ctrCode, ctrMsg) + + c5 := New() + c5.Exec("CREATE DATABASE test", nil) + c5.SetCurrentDatabase("test") + c5.Exec("CREATE TABLE t_dropkey_quote (id INT)", nil) + results, _ = c5.Exec("ALTER TABLE t_dropkey_quote DROP INDEX nokey", &ExecOptions{ContinueOnError: true}) + catErr, ok = results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("can't drop key message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // 1f: Unknown table (DROP TABLE) — quotes around db.table + ctr.execSQL("DROP TABLE IF EXISTS t_unknown_quote") + ctrErr = ctr.execSQL("DROP TABLE t_unknown_quote") + if ctrErr == nil { + t.Fatal("container: expected error") + } + ctrCode, _, ctrMsg = extractMySQLErr(ctrErr) + t.Logf("container unknown table: %d %s", ctrCode, ctrMsg) + + c6 := New() + c6.Exec("CREATE DATABASE test", nil) + c6.SetCurrentDatabase("test") + results, _ = c6.Exec("DROP TABLE t_unknown_quote", &ExecOptions{ContinueOnError: true}) + catErr, ok = results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Message != ctrMsg { + t.Errorf("unknown table message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // Cleanup + ctr.execSQL("DROP TABLE IF EXISTS t_quote_test") + ctr.execSQL("DROP TABLE IF EXISTS t_col_quote") + ctr.execSQL("DROP TABLE IF EXISTS t_key_quote") + ctr.execSQL("DROP TABLE IF EXISTS t_dropkey_quote") + ctr.execSQL("DROP DATABASE IF EXISTS db_quote_test") + }) + + // Scenario 2: Error position (index) for multi-statement SQL + // When executing multiple statements, the error should appear at the correct index. + t.Run("error_position_multi_statement", func(t *testing.T) { + // Setup: Execute 3 statements where the 3rd one fails. + // The first two should succeed, the third should error. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + + sql := "CREATE TABLE t_pos1 (id INT); CREATE TABLE t_pos2 (id INT); CREATE TABLE t_pos1 (id INT)" + results, _ := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + + // Verify we got 3 results. + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + // First two should succeed. + if results[0].Error != nil { + t.Errorf("statement 0: unexpected error: %v", results[0].Error) + } + if results[1].Error != nil { + t.Errorf("statement 1: unexpected error: %v", results[1].Error) + } + // Third should fail with duplicate table. + if results[2].Error == nil { + t.Fatal("statement 2: expected error for duplicate table") + } + if results[2].Index != 2 { + t.Errorf("error index: want 2, got %d", results[2].Index) + } + catErr, ok := results[2].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[2].Error) + } + if catErr.Code != 1050 { + t.Errorf("error code: want 1050, got %d", catErr.Code) + } + + // Also verify on container: same 3-statement SQL, error on 3rd. + ctr.execSQL("DROP TABLE IF EXISTS t_pos1") + ctr.execSQL("DROP TABLE IF EXISTS t_pos2") + ctr.execSQL("CREATE TABLE t_pos1 (id INT)") + ctr.execSQL("CREATE TABLE t_pos2 (id INT)") + ctrErr := ctr.execSQL("CREATE TABLE t_pos1 (id INT)") + if ctrErr == nil { + t.Fatal("container: expected error for duplicate table") + } + ctrCode, _, _ := extractMySQLErr(ctrErr) + if ctrCode != 1050 { + t.Errorf("container error code: want 1050, got %d", ctrCode) + } + + // Second test: error on the 1st statement of multi-statement batch. + c2 := New() + c2.Exec("CREATE DATABASE test", nil) + c2.SetCurrentDatabase("test") + + sql2 := "CREATE TABLE t_pos1 (a INT, a INT); CREATE TABLE t_pos3 (id INT)" + results2, _ := c2.Exec(sql2, &ExecOptions{ContinueOnError: true}) + if len(results2) < 1 { + t.Fatal("expected at least 1 result") + } + if results2[0].Error == nil { + t.Fatal("statement 0: expected error for duplicate column") + } + if results2[0].Index != 0 { + t.Errorf("error index: want 0, got %d", results2[0].Index) + } + + // Cleanup + ctr.execSQL("DROP TABLE IF EXISTS t_pos1") + ctr.execSQL("DROP TABLE IF EXISTS t_pos2") + }) + + // Scenario 3: IF EXISTS suppresses errors correctly + t.Run("if_exists_suppresses_errors", func(t *testing.T) { + // 3a: DROP TABLE IF EXISTS on nonexistent table — no error on both. + ctr.execSQL("DROP TABLE IF EXISTS t_ifexists_none") + ctrErr := ctr.execSQL("DROP TABLE IF EXISTS t_ifexists_none") + if ctrErr != nil { + t.Fatalf("container DROP TABLE IF EXISTS error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE IF EXISTS t_ifexists_none", nil) + if results[0].Error != nil { + t.Errorf("omni DROP TABLE IF EXISTS error: %v", results[0].Error) + } + + // 3b: DROP DATABASE IF EXISTS on nonexistent database — no error on both. + ctr.execSQL("DROP DATABASE IF EXISTS db_ifexists_none") + ctrErr = ctr.execSQL("DROP DATABASE IF EXISTS db_ifexists_none") + if ctrErr != nil { + t.Fatalf("container DROP DATABASE IF EXISTS error: %v", ctrErr) + } + + c2 := New() + results, _ = c2.Exec("DROP DATABASE IF EXISTS db_ifexists_none", nil) + if results[0].Error != nil { + t.Errorf("omni DROP DATABASE IF EXISTS error: %v", results[0].Error) + } + + // 3c: DROP VIEW IF EXISTS on nonexistent view — no error on both. + ctr.execSQL("DROP VIEW IF EXISTS v_ifexists_none") + ctrErr = ctr.execSQL("DROP VIEW IF EXISTS v_ifexists_none") + if ctrErr != nil { + t.Fatalf("container DROP VIEW IF EXISTS error: %v", ctrErr) + } + + c3 := New() + c3.Exec("CREATE DATABASE test", nil) + c3.SetCurrentDatabase("test") + results, _ = c3.Exec("DROP VIEW IF EXISTS v_ifexists_none", nil) + if results[0].Error != nil { + t.Errorf("omni DROP VIEW IF EXISTS error: %v", results[0].Error) + } + + // 3d: DROP TABLE IF EXISTS on existing table — should succeed (table is dropped). + ctr.execSQL("DROP TABLE IF EXISTS t_ifexists_real") + ctr.execSQL("CREATE TABLE t_ifexists_real (id INT)") + ctrErr = ctr.execSQL("DROP TABLE IF EXISTS t_ifexists_real") + if ctrErr != nil { + t.Fatalf("container DROP TABLE IF EXISTS (existing) error: %v", ctrErr) + } + + c4 := New() + c4.Exec("CREATE DATABASE test", nil) + c4.SetCurrentDatabase("test") + c4.Exec("CREATE TABLE t_ifexists_real (id INT)", nil) + results, _ = c4.Exec("DROP TABLE IF EXISTS t_ifexists_real", nil) + if results[0].Error != nil { + t.Errorf("omni DROP TABLE IF EXISTS (existing) error: %v", results[0].Error) + } + // Verify the table was actually dropped. + results, _ = c4.Exec("DROP TABLE t_ifexists_real", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Error("omni: table should have been dropped but still exists") + } + + // 3e: Without IF EXISTS, DROP TABLE on nonexistent table — must error. + ctr.execSQL("DROP TABLE IF EXISTS t_noifexists") + ctrErr = ctr.execSQL("DROP TABLE t_noifexists") + if ctrErr == nil { + t.Fatal("container: expected error for DROP TABLE without IF EXISTS") + } + ctrCode, _, ctrMsg := extractMySQLErr(ctrErr) + + c5 := New() + c5.Exec("CREATE DATABASE test", nil) + c5.SetCurrentDatabase("test") + results, _ = c5.Exec("DROP TABLE t_noifexists", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for DROP TABLE without IF EXISTS") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) + + // Scenario 4: IF NOT EXISTS suppresses errors correctly + t.Run("if_not_exists_suppresses_errors", func(t *testing.T) { + // 4a: CREATE DATABASE IF NOT EXISTS on existing database — no error. + ctr.execSQL("DROP DATABASE IF EXISTS db_ifne") + ctr.execSQL("CREATE DATABASE db_ifne") + ctrErr := ctr.execSQL("CREATE DATABASE IF NOT EXISTS db_ifne") + if ctrErr != nil { + t.Fatalf("container CREATE DATABASE IF NOT EXISTS error: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE db_ifne", nil) + results, _ := c.Exec("CREATE DATABASE IF NOT EXISTS db_ifne", nil) + if results[0].Error != nil { + t.Errorf("omni CREATE DATABASE IF NOT EXISTS error: %v", results[0].Error) + } + + // 4b: CREATE TABLE IF NOT EXISTS on existing table — no error, original preserved. + ctr.execSQL("DROP TABLE IF EXISTS t_ifne") + ctr.execSQL("CREATE TABLE t_ifne (id INT)") + ctrErr = ctr.execSQL("CREATE TABLE IF NOT EXISTS t_ifne (val VARCHAR(100))") + if ctrErr != nil { + t.Fatalf("container CREATE TABLE IF NOT EXISTS error: %v", ctrErr) + } + // Verify original table is unchanged. + ctrDDL, _ := ctr.showCreateTable("t_ifne") + if !strings.Contains(ctrDDL, "`id`") { + t.Errorf("container: original table structure should be preserved, got: %s", ctrDDL) + } + + c2 := New() + c2.Exec("CREATE DATABASE test", nil) + c2.SetCurrentDatabase("test") + c2.Exec("CREATE TABLE t_ifne (id INT)", nil) + results, _ = c2.Exec("CREATE TABLE IF NOT EXISTS t_ifne (val VARCHAR(100))", nil) + if results[0].Error != nil { + t.Errorf("omni CREATE TABLE IF NOT EXISTS error: %v", results[0].Error) + } + // Verify original table is unchanged. + omniDDL := c2.ShowCreateTable("test", "t_ifne") + if !strings.Contains(omniDDL, "`id`") { + t.Errorf("omni: original table structure should be preserved, got: %s", omniDDL) + } + + // 4c: Without IF NOT EXISTS, CREATE TABLE on existing table — must error. + ctr.execSQL("DROP TABLE IF EXISTS t_no_ifne") + ctr.execSQL("CREATE TABLE t_no_ifne (id INT)") + ctrErr = ctr.execSQL("CREATE TABLE t_no_ifne (id INT)") + if ctrErr == nil { + t.Fatal("container: expected error for CREATE TABLE without IF NOT EXISTS") + } + ctrCode, _, ctrMsg := extractMySQLErr(ctrErr) + + c3 := New() + c3.Exec("CREATE DATABASE test", nil) + c3.SetCurrentDatabase("test") + c3.Exec("CREATE TABLE t_no_ifne (id INT)", nil) + results, _ = c3.Exec("CREATE TABLE t_no_ifne (id INT)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for CREATE TABLE without IF NOT EXISTS") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + + // Cleanup + ctr.execSQL("DROP TABLE IF EXISTS t_ifne") + ctr.execSQL("DROP TABLE IF EXISTS t_no_ifne") + ctr.execSQL("DROP DATABASE IF EXISTS db_ifne") + }) +} + +func TestContainer_Section_4_1_Partitioning(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + { + "partition_by_range", + `CREATE TABLE t_part_range ( + id INT NOT NULL, + created DATE NOT NULL + ) PARTITION BY RANGE (YEAR(created)) ( + PARTITION p0 VALUES LESS THAN (2020), + PARTITION p1 VALUES LESS THAN (2025), + PARTITION pmax VALUES LESS THAN MAXVALUE + )`, + "t_part_range", + }, + { + "partition_by_range_columns", + `CREATE TABLE t_part_range_cols ( + id INT NOT NULL, + city VARCHAR(50) NOT NULL, + name VARCHAR(50) NOT NULL + ) PARTITION BY RANGE COLUMNS(city) ( + PARTITION p0 VALUES LESS THAN ('M'), + PARTITION p1 VALUES LESS THAN MAXVALUE + )`, + "t_part_range_cols", + }, + { + "partition_by_list", + `CREATE TABLE t_part_list ( + id INT NOT NULL, + region INT NOT NULL + ) PARTITION BY LIST (region) ( + PARTITION pNorth VALUES IN (1,2,3), + PARTITION pSouth VALUES IN (4,5,6), + PARTITION pWest VALUES IN (7,8,9) + )`, + "t_part_list", + }, + { + "partition_by_hash", + `CREATE TABLE t_part_hash ( + id INT NOT NULL, + name VARCHAR(100) + ) PARTITION BY HASH (id) PARTITIONS 4`, + "t_part_hash", + }, + { + "partition_by_key", + `CREATE TABLE t_part_key ( + id INT NOT NULL, + name VARCHAR(100) + ) PARTITION BY KEY (id) PARTITIONS 3`, + "t_part_key", + }, + { + "partition_linear_hash", + `CREATE TABLE t_part_linear ( + id INT NOT NULL + ) PARTITION BY LINEAR HASH (id) PARTITIONS 4`, + "t_part_linear", + }, + { + "partition_subpartition", + `CREATE TABLE t_part_sub ( + id INT NOT NULL, + purchased DATE NOT NULL + ) PARTITION BY RANGE (YEAR(purchased)) + SUBPARTITION BY HASH (id) + SUBPARTITIONS 2 ( + PARTITION p0 VALUES LESS THAN (2020), + PARTITION p1 VALUES LESS THAN MAXVALUE + )`, + "t_part_sub", + }, + { + "show_create_partitioned_table", + `CREATE TABLE t_part_show ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`, + "t_part_show", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, _ := ctr.showCreateTable(tc.table) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec(tc.sql, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } + + // ALTER TABLE partition tests — multi-step + t.Run("alter_add_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_addp") + setupSQL := `CREATE TABLE t_alter_addp ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20) + )` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + alterSQL := "ALTER TABLE t_alter_addp ADD PARTITION (PARTITION p2 VALUES LESS THAN (30))" + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_addp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_addp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_drop_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_dropp") + setupSQL := `CREATE TABLE t_alter_dropp ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20), + PARTITION p2 VALUES LESS THAN MAXVALUE + )` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + alterSQL := "ALTER TABLE t_alter_dropp DROP PARTITION p1" + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_dropp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_dropp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_reorganize_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_reorgp") + setupSQL := `CREATE TABLE t_alter_reorgp ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (30) + )` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + alterSQL := `ALTER TABLE t_alter_reorgp REORGANIZE PARTITION p1 INTO ( + PARTITION p1a VALUES LESS THAN (20), + PARTITION p1b VALUES LESS THAN (30) + )` + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_reorgp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_reorgp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_truncate_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_truncp") + setupSQL := `CREATE TABLE t_alter_truncp ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN MAXVALUE + )` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + alterSQL := "ALTER TABLE t_alter_truncp TRUNCATE PARTITION p0" + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_truncp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_truncp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_coalesce_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_coalp") + setupSQL := `CREATE TABLE t_alter_coalp ( + id INT NOT NULL + ) PARTITION BY HASH (id) PARTITIONS 4` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + alterSQL := "ALTER TABLE t_alter_coalp COALESCE PARTITION 2" + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_coalp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_coalp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_exchange_partition", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_alter_exchp") + ctr.execSQL("DROP TABLE IF EXISTS t_alter_exchp_swap") + setupSQL := `CREATE TABLE t_alter_exchp ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN MAXVALUE + )` + swapSQL := `CREATE TABLE t_alter_exchp_swap ( + id INT NOT NULL, + val INT NOT NULL + )` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + if err := ctr.execSQL(swapSQL); err != nil { + t.Fatalf("container swap table setup: %v", err) + } + alterSQL := "ALTER TABLE t_alter_exchp EXCHANGE PARTITION p0 WITH TABLE t_alter_exchp_swap" + if err := ctr.execSQL(alterSQL); err != nil { + t.Fatalf("container alter: %v", err) + } + ctrDDL, _ := ctr.showCreateTable("t_alter_exchp") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec(setupSQL, nil) + c.Exec(swapSQL, nil) + results, _ := c.Exec(alterSQL, nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_alter_exchp") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) +} + +func TestContainer_Section_4_2_StoredRoutines(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // 4.2.1: CREATE FUNCTION — store metadata (name, params, return type, body) + t.Run("create_function", func(t *testing.T) { + ctr.execSQL("DROP FUNCTION IF EXISTS fn_add") + createSQL := "CREATE FUNCTION fn_add(a INT, b INT) RETURNS INT DETERMINISTIC RETURN a + b" + if err := ctr.execSQL(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateFunction("fn_add") + if err != nil { + t.Fatalf("container SHOW CREATE FUNCTION: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateFunction("test", "fn_add") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.2.2: CREATE PROCEDURE — store metadata + t.Run("create_procedure", func(t *testing.T) { + ctr.execSQLDirect("DROP PROCEDURE IF EXISTS sp_greet") + createSQL := "CREATE PROCEDURE sp_greet(IN name VARCHAR(100)) BEGIN SELECT CONCAT('Hello, ', name); END" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateProcedure("sp_greet") + if err != nil { + t.Fatalf("container SHOW CREATE PROCEDURE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateProcedure("test", "sp_greet") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.2.3: DROP FUNCTION / PROCEDURE + t.Run("drop_function", func(t *testing.T) { + ctr.execSQL("DROP FUNCTION IF EXISTS fn_drop_test") + ctr.execSQL("CREATE FUNCTION fn_drop_test() RETURNS INT DETERMINISTIC RETURN 1") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE FUNCTION fn_drop_test() RETURNS INT DETERMINISTIC RETURN 1", nil) + + // Drop on container + if err := ctr.execSQL("DROP FUNCTION fn_drop_test"); err != nil { + t.Fatalf("container DROP FUNCTION: %v", err) + } + // Drop on omni + results, _ := c.Exec("DROP FUNCTION fn_drop_test", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP FUNCTION error: %v", results[0].Error) + } + + // Both should now error on SHOW CREATE FUNCTION + _, ctrErr := ctr.showCreateFunction("fn_drop_test") + omniDDL := c.ShowCreateFunction("test", "fn_drop_test") + + if ctrErr == nil { + t.Error("expected container error after DROP FUNCTION") + } + if omniDDL != "" { + t.Error("expected empty omni DDL after DROP FUNCTION") + } + }) + + t.Run("drop_procedure", func(t *testing.T) { + ctr.execSQLDirect("DROP PROCEDURE IF EXISTS sp_drop_test") + ctr.execSQLDirect("CREATE PROCEDURE sp_drop_test() BEGIN SELECT 1; END") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE PROCEDURE sp_drop_test() BEGIN SELECT 1; END", nil) + + // Drop on container + if err := ctr.execSQL("DROP PROCEDURE sp_drop_test"); err != nil { + t.Fatalf("container DROP PROCEDURE: %v", err) + } + // Drop on omni + results, _ := c.Exec("DROP PROCEDURE sp_drop_test", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP PROCEDURE error: %v", results[0].Error) + } + + _, ctrErr := ctr.showCreateProcedure("sp_drop_test") + omniDDL := c.ShowCreateProcedure("test", "sp_drop_test") + + if ctrErr == nil { + t.Error("expected container error after DROP PROCEDURE") + } + if omniDDL != "" { + t.Error("expected empty omni DDL after DROP PROCEDURE") + } + }) + + // 4.2.4: DROP FUNCTION IF EXISTS — no error for nonexistent + t.Run("drop_function_if_exists", func(t *testing.T) { + ctr.execSQL("DROP FUNCTION IF EXISTS fn_nonexistent_xyz") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP FUNCTION IF EXISTS fn_nonexistent_xyz", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP FUNCTION IF EXISTS should not error: %v", results[0].Error) + } + }) + + // 4.2.5: DROP FUNCTION nonexistent — error + t.Run("drop_function_nonexistent_error", func(t *testing.T) { + ctrErr := ctr.execSQL("DROP FUNCTION fn_totally_nonexistent") + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP FUNCTION fn_totally_nonexistent", nil) + + if ctrErr == nil { + t.Fatal("expected container error for DROP nonexistent FUNCTION") + } + if results[0].Error == nil { + t.Fatal("expected omni error for DROP nonexistent FUNCTION") + } + + // Both should produce error code 1305 + var mysqlErr *mysqldriver.MySQLError + if errors.As(ctrErr, &mysqlErr) { + if catErr, ok := results[0].Error.(*Error); ok { + if catErr.Code != int(mysqlErr.Number) { + t.Errorf("error code mismatch: container=%d omni=%d", mysqlErr.Number, catErr.Code) + } + } + } + }) + + // 4.2.6: ALTER ROUTINE (characteristics only) + t.Run("alter_function_comment", func(t *testing.T) { + ctr.execSQL("DROP FUNCTION IF EXISTS fn_alter_test") + ctr.execSQL("CREATE FUNCTION fn_alter_test() RETURNS INT DETERMINISTIC RETURN 42") + if err := ctr.execSQL("ALTER FUNCTION fn_alter_test COMMENT 'test comment'"); err != nil { + t.Fatalf("container ALTER FUNCTION: %v", err) + } + ctrDDL, err := ctr.showCreateFunction("fn_alter_test") + if err != nil { + t.Fatalf("container SHOW CREATE FUNCTION: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE FUNCTION fn_alter_test() RETURNS INT DETERMINISTIC RETURN 42", nil) + results, _ := c.Exec("ALTER FUNCTION fn_alter_test COMMENT 'test comment'", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER FUNCTION error: %v", results[0].Error) + } + omniDDL := c.ShowCreateFunction("test", "fn_alter_test") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("alter_procedure_sql_security", func(t *testing.T) { + ctr.execSQLDirect("DROP PROCEDURE IF EXISTS sp_alter_test") + ctr.execSQLDirect("CREATE PROCEDURE sp_alter_test() BEGIN SELECT 1; END") + if err := ctr.execSQLDirect("ALTER PROCEDURE sp_alter_test SQL SECURITY INVOKER"); err != nil { + t.Fatalf("container ALTER PROCEDURE: %v", err) + } + ctrDDL, err := ctr.showCreateProcedure("sp_alter_test") + if err != nil { + t.Fatalf("container SHOW CREATE PROCEDURE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE PROCEDURE sp_alter_test() BEGIN SELECT 1; END", nil) + results, _ := c.Exec("ALTER PROCEDURE sp_alter_test SQL SECURITY INVOKER", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER PROCEDURE error: %v", results[0].Error) + } + omniDDL := c.ShowCreateProcedure("test", "sp_alter_test") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.2.7: SHOW CREATE FUNCTION output + t.Run("show_create_function_output", func(t *testing.T) { + ctr.execSQL("DROP FUNCTION IF EXISTS fn_show_test") + createSQL := "CREATE FUNCTION fn_show_test(x INT) RETURNS VARCHAR(100) DETERMINISTIC RETURN CONCAT('val=', x)" + if err := ctr.execSQL(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateFunction("fn_show_test") + if err != nil { + t.Fatalf("container SHOW CREATE FUNCTION: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateFunction("test", "fn_show_test") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.2.8: SHOW CREATE PROCEDURE output + t.Run("show_create_procedure_output", func(t *testing.T) { + ctr.execSQLDirect("DROP PROCEDURE IF EXISTS sp_show_test") + createSQL := "CREATE PROCEDURE sp_show_test(IN id INT, OUT result VARCHAR(100)) BEGIN SET result = CONCAT('id=', id); END" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateProcedure("sp_show_test") + if err != nil { + t.Fatalf("container SHOW CREATE PROCEDURE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateProcedure("test", "sp_show_test") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) +} + +func TestContainer_Section_4_3_Triggers(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create a table that triggers can reference. + ctr.execSQL("DROP TABLE IF EXISTS t_trigger") + if err := ctr.execSQL("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))"); err != nil { + t.Fatalf("container setup table: %v", err) + } + + // 4.3.1: CREATE TRIGGER — store metadata (name, timing, event, table, body) + t.Run("create_trigger_before_insert", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_before_insert") + createSQL := "CREATE TRIGGER tr_before_insert BEFORE INSERT ON t_trigger FOR EACH ROW SET NEW.val = NEW.val + 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTrigger("tr_before_insert") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTrigger("test", "tr_before_insert") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("create_trigger_after_update", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_after_update") + createSQL := "CREATE TRIGGER tr_after_update AFTER UPDATE ON t_trigger FOR EACH ROW SET @updated = @updated + 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTrigger("tr_after_update") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTrigger("test", "tr_after_update") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("create_trigger_before_delete", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_before_delete") + createSQL := "CREATE TRIGGER tr_before_delete BEFORE DELETE ON t_trigger FOR EACH ROW SET @deleted = OLD.id" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTrigger("tr_before_delete") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTrigger("test", "tr_before_delete") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.3.2: DROP TRIGGER + t.Run("drop_trigger", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_drop_test") + ctr.execSQLDirect("CREATE TRIGGER tr_drop_test BEFORE INSERT ON t_trigger FOR EACH ROW SET NEW.val = 0") + if err := ctr.execSQL("DROP TRIGGER tr_drop_test"); err != nil { + t.Fatalf("container DROP TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + c.Exec("CREATE TRIGGER tr_drop_test BEFORE INSERT ON t_trigger FOR EACH ROW SET NEW.val = 0", nil) + results, _ := c.Exec("DROP TRIGGER tr_drop_test", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TRIGGER error: %v", results[0].Error) + } + + // Verify trigger is gone + ddl := c.ShowCreateTrigger("test", "tr_drop_test") + if ddl != "" { + t.Errorf("trigger should be gone, but got: %s", ddl) + } + }) + + t.Run("drop_trigger_if_exists", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_nonexistent_drop") + if err := ctr.execSQL("DROP TRIGGER IF EXISTS tr_nonexistent_drop"); err != nil { + t.Fatalf("container DROP TRIGGER IF EXISTS: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TRIGGER IF EXISTS tr_nonexistent_drop", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP TRIGGER IF EXISTS should not error: %v", results[0].Error) + } + }) + + t.Run("drop_trigger_not_exist_error", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_no_exist") + ctrErr := ctr.execSQL("DROP TRIGGER tr_no_exist") + if ctrErr == nil { + t.Fatal("expected container error for DROP TRIGGER on nonexistent trigger") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TRIGGER tr_no_exist", nil) + if results[0].Error == nil { + t.Fatal("expected omni error for DROP TRIGGER on nonexistent trigger") + } + if catErr, ok := results[0].Error.(*Error); ok { + if catErr.Code != ErrNoSuchTrigger { + t.Errorf("error code: want %d, got %d", ErrNoSuchTrigger, catErr.Code) + } + } + }) + + // 4.3.3: SHOW CREATE TRIGGER output + t.Run("show_create_trigger_output", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_show_test") + createSQL := "CREATE TRIGGER tr_show_test AFTER INSERT ON t_trigger FOR EACH ROW SET @last_insert = NEW.id" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTrigger("tr_show_test") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTrigger("test", "tr_show_test") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // 4.3.4: Multiple triggers per table/event (MySQL 8.0 supports ordering) + t.Run("multiple_triggers_ordering", func(t *testing.T) { + ctr.execSQL("DROP TRIGGER IF EXISTS tr_multi_2") + ctr.execSQL("DROP TRIGGER IF EXISTS tr_multi_1") + create1 := "CREATE TRIGGER tr_multi_1 BEFORE INSERT ON t_trigger FOR EACH ROW SET NEW.val = NEW.val + 10" + create2 := "CREATE TRIGGER tr_multi_2 BEFORE INSERT ON t_trigger FOR EACH ROW FOLLOWS tr_multi_1 SET NEW.val = NEW.val + 20" + if err := ctr.execSQLDirect(create1); err != nil { + t.Fatalf("container exec trigger 1: %v", err) + } + if err := ctr.execSQLDirect(create2); err != nil { + t.Fatalf("container exec trigger 2: %v", err) + } + + ctrDDL1, err := ctr.showCreateTrigger("tr_multi_1") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER tr_multi_1: %v", err) + } + ctrDDL2, err := ctr.showCreateTrigger("tr_multi_2") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER tr_multi_2: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results1, _ := c.Exec(create1, nil) + if results1[0].Error != nil { + t.Fatalf("omni exec trigger 1 error: %v", results1[0].Error) + } + results2, _ := c.Exec(create2, nil) + if results2[0].Error != nil { + t.Fatalf("omni exec trigger 2 error: %v", results2[0].Error) + } + + omniDDL1 := c.ShowCreateTrigger("test", "tr_multi_1") + omniDDL2 := c.ShowCreateTrigger("test", "tr_multi_2") + + if normalizeWhitespace(ctrDDL1) != normalizeWhitespace(omniDDL1) { + t.Errorf("trigger 1 mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL1, omniDDL1) + } + if normalizeWhitespace(ctrDDL2) != normalizeWhitespace(omniDDL2) { + t.Errorf("trigger 2 mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL2, omniDDL2) + } + }) +} + +// normalizeEventDDL normalizes event DDL for comparison by removing +// auto-generated STARTS timestamps that MySQL adds to EVERY schedules. +// MySQL always adds "STARTS ''" when EVERY is used without explicit STARTS, +// but our catalog doesn't have a clock to generate timestamps. +func normalizeEventDDL(ddl string) string { + // Remove STARTS '' from EVERY schedules (auto-generated by MySQL). + // Pattern: " STARTS ''" + s := normalizeWhitespace(ddl) + // Remove auto-generated STARTS with timestamp pattern + for { + idx := strings.Index(s, " STARTS '") + if idx < 0 { + break + } + // Find closing quote + end := strings.Index(s[idx+9:], "'") + if end < 0 { + break + } + s = s[:idx] + s[idx+9+end+1:] + } + return s +} + +func TestContainer_Section_4_4_Events(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // 4.4.1: CREATE EVENT — store metadata + t.Run("create_event_every", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_test_every") + createSQL := "CREATE EVENT ev_test_every ON SCHEDULE EVERY 1 HOUR DO SELECT 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_test_every") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_test_every") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + }) + + t.Run("create_event_at_timestamp", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_test_at") + createSQL := "CREATE EVENT ev_test_at ON SCHEDULE AT '2035-01-01 00:00:00' DO SELECT 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_test_at") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_test_at") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + t.Run("create_event_with_options", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_test_opts") + createSQL := "CREATE EVENT ev_test_opts ON SCHEDULE EVERY 1 DAY ON COMPLETION PRESERVE ENABLE COMMENT 'daily cleanup' DO DELETE FROM t_trigger WHERE id < 0" + // Need a table for the body to reference + ctr.execSQL("DROP TABLE IF EXISTS t_trigger") + ctr.execSQL("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))") + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_test_opts") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, name VARCHAR(100))", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_test_opts") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + }) + + t.Run("create_event_disabled", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_test_disabled") + createSQL := "CREATE EVENT ev_test_disabled ON SCHEDULE EVERY 1 HOUR DISABLE DO SELECT 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_test_disabled") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_test_disabled") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + }) + + // 4.4.2: ALTER EVENT + t.Run("alter_event_disable", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_alter_test") + ctr.execSQLDirect("CREATE EVENT ev_alter_test ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + if err := ctr.execSQLDirect("ALTER EVENT ev_alter_test DISABLE"); err != nil { + t.Fatalf("container ALTER EVENT: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_alter_test") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT after ALTER:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE EVENT ev_alter_test ON SCHEDULE EVERY 1 HOUR DO SELECT 1", nil) + results, _ := c.Exec("ALTER EVENT ev_alter_test DISABLE", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER EVENT error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_alter_test") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + }) + + t.Run("alter_event_rename", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_rename_old") + ctr.execSQL("DROP EVENT IF EXISTS ev_rename_new") + ctr.execSQLDirect("CREATE EVENT ev_rename_old ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + if err := ctr.execSQLDirect("ALTER EVENT ev_rename_old RENAME TO ev_rename_new"); err != nil { + t.Fatalf("container ALTER EVENT RENAME: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_rename_new") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT after RENAME:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE EVENT ev_rename_old ON SCHEDULE EVERY 1 HOUR DO SELECT 1", nil) + results, _ := c.Exec("ALTER EVENT ev_rename_old RENAME TO ev_rename_new", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER EVENT RENAME error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_rename_new") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + + // Verify old name is gone + oldDDL := c.ShowCreateEvent("test", "ev_rename_old") + if oldDDL != "" { + t.Errorf("old event name should be gone, but got: %s", oldDDL) + } + }) + + // 4.4.3: DROP EVENT + t.Run("drop_event", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_drop_test") + ctr.execSQLDirect("CREATE EVENT ev_drop_test ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + if err := ctr.execSQL("DROP EVENT ev_drop_test"); err != nil { + t.Fatalf("container DROP EVENT: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE EVENT ev_drop_test ON SCHEDULE EVERY 1 HOUR DO SELECT 1", nil) + results, _ := c.Exec("DROP EVENT ev_drop_test", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP EVENT error: %v", results[0].Error) + } + + // Verify event is gone + ddl := c.ShowCreateEvent("test", "ev_drop_test") + if ddl != "" { + t.Errorf("event should be gone, but got: %s", ddl) + } + }) + + t.Run("drop_event_if_exists", func(t *testing.T) { + if err := ctr.execSQL("DROP EVENT IF EXISTS ev_nonexistent"); err != nil { + t.Fatalf("container DROP EVENT IF EXISTS: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP EVENT IF EXISTS ev_nonexistent", nil) + if results[0].Error != nil { + t.Fatalf("omni DROP EVENT IF EXISTS should not error: %v", results[0].Error) + } + }) + + t.Run("drop_event_not_exist_error", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_no_exist") + ctrErr := ctr.execSQL("DROP EVENT ev_no_exist") + if ctrErr == nil { + t.Fatal("expected container error for DROP EVENT on nonexistent event") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP EVENT ev_no_exist", nil) + if results[0].Error == nil { + t.Fatal("expected omni error for DROP EVENT on nonexistent event") + } + if catErr, ok := results[0].Error.(*Error); ok { + if catErr.Code != ErrNoSuchEvent { + t.Errorf("error code: want %d, got %d", ErrNoSuchEvent, catErr.Code) + } + } + }) + + // 4.4.4: SHOW CREATE EVENT output + t.Run("show_create_event_basic", func(t *testing.T) { + ctr.execSQL("DROP EVENT IF EXISTS ev_show_test") + createSQL := "CREATE EVENT ev_show_test ON SCHEDULE EVERY 1 HOUR COMMENT 'test event' DO SELECT 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_show_test") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + t.Logf("container SHOW CREATE EVENT:\n%s", ctrDDL) + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_show_test") + + // Compare with normalized DDL (MySQL adds auto-STARTS timestamp for EVERY schedules) + if normalizeEventDDL(ctrDDL) != normalizeEventDDL(omniDDL) { + t.Errorf("mismatch:\n--- container (normalized) ---\n%s\n--- omni (normalized) ---\n%s\n--- container (raw) ---\n%s\n--- omni (raw) ---\n%s", + normalizeEventDDL(ctrDDL), normalizeEventDDL(omniDDL), ctrDDL, omniDDL) + } + }) +} + +// extractViewPreamble extracts the CREATE VIEW preamble up to and including " AS ". +// This allows comparing structural elements without comparing the rewritten SELECT text. +func extractViewPreamble(ddl string) string { + idx := strings.Index(ddl, " AS ") + if idx < 0 { + return ddl + } + return normalizeWhitespace(ddl[:idx+4]) +} + +func TestContainer_Section_4_5_ViewsDeep(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + t.Run("alter_view", func(t *testing.T) { + // ALTER VIEW changes the view definition. + // Verify both container and omni accept ALTER VIEW and update the view. + ctr.execSQL("DROP VIEW IF EXISTS v_alter_test") + ctr.execSQL("CREATE VIEW v_alter_test AS SELECT 1 AS a") + ctrErr := ctr.execSQL("ALTER VIEW v_alter_test AS SELECT 2 AS b, 3 AS c") + if ctrErr != nil { + t.Fatalf("container ALTER VIEW error: %v", ctrErr) + } + // Verify container view was updated. + ctrDDL, _ := ctr.showCreateView("v_alter_test") + if !strings.Contains(ctrDDL, "v_alter_test") { + t.Fatalf("container: view not found after ALTER VIEW") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v_alter_test AS SELECT 1 AS a", nil) + results, _ := c.Exec("ALTER VIEW v_alter_test AS SELECT 2 AS b, 3 AS c", nil) + if results[0].Error != nil { + t.Fatalf("omni ALTER VIEW error: %v", results[0].Error) + } + // Verify omni view was updated. + db := c.GetDatabase("test") + v := db.Views[toLower("v_alter_test")] + if v == nil { + t.Fatal("omni: view v_alter_test should exist after ALTER VIEW") + } + if !strings.Contains(v.Definition, "2") { + t.Errorf("omni: view definition should contain new select, got: %s", v.Definition) + } + + // ALTER VIEW on nonexistent view should error on both. + ctr.execSQL("DROP VIEW IF EXISTS v_alter_noexist") + ctrErr = ctr.execSQL("ALTER VIEW v_alter_noexist AS SELECT 1") + if ctrErr == nil { + t.Fatal("container: expected error for ALTER VIEW on nonexistent view") + } + + c2 := New() + c2.Exec("CREATE DATABASE test", nil) + c2.SetCurrentDatabase("test") + results2, _ := c2.Exec("ALTER VIEW v_alter_noexist AS SELECT 1", &ExecOptions{ContinueOnError: true}) + if results2[0].Error == nil { + t.Fatal("omni: expected error for ALTER VIEW on nonexistent view") + } + }) + + t.Run("view_dependency_tracking", func(t *testing.T) { + // In MySQL, dropping a base table does NOT drop the view. + // The view still exists but errors when queried. + ctr.execSQL("DROP VIEW IF EXISTS v_dep_test") + ctr.execSQL("DROP TABLE IF EXISTS t_dep_base") + ctr.execSQL("CREATE TABLE t_dep_base (id INT, val VARCHAR(100))") + ctr.execSQL("CREATE VIEW v_dep_test AS SELECT id, val FROM t_dep_base") + + // Drop the base table + ctrErr := ctr.execSQL("DROP TABLE t_dep_base") + if ctrErr != nil { + t.Fatalf("container DROP TABLE error: %v", ctrErr) + } + + // View should still exist in container + _, showErr := ctr.showCreateView("v_dep_test") + if showErr != nil { + t.Fatalf("container: view should still exist after base table drop, got: %v", showErr) + } + + // Omni: same behavior + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_dep_base (id INT, val VARCHAR(100))", nil) + c.Exec("CREATE VIEW v_dep_test AS SELECT id, val FROM t_dep_base", nil) + c.Exec("DROP TABLE t_dep_base", nil) + + // View should still exist + db := c.GetDatabase("test") + if db.Views[toLower("v_dep_test")] == nil { + t.Error("omni: view v_dep_test should still exist after base table drop") + } + + // Cleanup container + ctr.execSQL("DROP VIEW IF EXISTS v_dep_test") + }) + + t.Run("show_create_view_basic", func(t *testing.T) { + // SHOW CREATE VIEW output format — verify structural elements. + // Note: MySQL rewrites the SELECT text (lowercases keywords, backtick-quotes aliases) + // which requires a full SQL deparser. We verify the CREATE VIEW preamble matches. + ctr.execSQL("DROP VIEW IF EXISTS v_show_basic") + ctr.execSQL("CREATE VIEW v_show_basic AS SELECT 1 AS a, 2 AS b") + ctrDDL, err := ctr.showCreateView("v_show_basic") + if err != nil { + t.Fatalf("container SHOW CREATE VIEW error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec("CREATE VIEW v_show_basic AS SELECT 1 AS a, 2 AS b", nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateView("test", "v_show_basic") + + // Verify structural preamble matches. + oraclePreamble := extractViewPreamble(ctrDDL) + omniPreamble := extractViewPreamble(omniDDL) + if oraclePreamble != omniPreamble { + t.Errorf("preamble mismatch:\n--- container ---\n%s\n--- omni ---\n%s\n--- container full ---\n%s\n--- omni full ---\n%s", + oraclePreamble, omniPreamble, ctrDDL, omniDDL) + } + }) + + t.Run("show_create_view_with_options", func(t *testing.T) { + // SHOW CREATE VIEW with SQL SECURITY INVOKER. + // Note: MySQL normalizes ALGORITHM to UNDEFINED for simple views even if MERGE was specified. + ctr.execSQL("DROP VIEW IF EXISTS v_show_opts") + ctr.execSQL("CREATE SQL SECURITY INVOKER VIEW v_show_opts AS SELECT 1 AS x") + ctrDDL, err := ctr.showCreateView("v_show_opts") + if err != nil { + t.Fatalf("container SHOW CREATE VIEW error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE SQL SECURITY INVOKER VIEW v_show_opts AS SELECT 1 AS x", nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateView("test", "v_show_opts") + + // Verify structural preamble matches (up to and including AS). + oraclePreamble := extractViewPreamble(ctrDDL) + omniPreamble := extractViewPreamble(omniDDL) + if oraclePreamble != omniPreamble { + t.Errorf("preamble mismatch:\n--- container ---\n%s\n--- omni ---\n%s\n--- container full ---\n%s\n--- omni full ---\n%s", + oraclePreamble, omniPreamble, ctrDDL, omniDDL) + } + }) + + t.Run("view_with_column_aliases", func(t *testing.T) { + // View with column list — verify columns are stored and preamble matches. + ctr.execSQL("DROP VIEW IF EXISTS v_col_aliases") + ctr.execSQL("CREATE VIEW v_col_aliases (x, y, z) AS SELECT 1, 2, 3") + ctrDDL, err := ctr.showCreateView("v_col_aliases") + if err != nil { + t.Fatalf("container SHOW CREATE VIEW error: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("CREATE VIEW v_col_aliases (x, y, z) AS SELECT 1, 2, 3", nil) + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateView("test", "v_col_aliases") + + // Verify columns are present in SHOW CREATE VIEW output. + oraclePreamble := extractViewPreamble(ctrDDL) + omniPreamble := extractViewPreamble(omniDDL) + if oraclePreamble != omniPreamble { + t.Errorf("preamble mismatch:\n--- container ---\n%s\n--- omni ---\n%s\n--- container full ---\n%s\n--- omni full ---\n%s", + oraclePreamble, omniPreamble, ctrDDL, omniDDL) + } + + // Verify column list appears in both outputs. + if !strings.Contains(ctrDDL, "`x`") || !strings.Contains(ctrDDL, "`y`") || !strings.Contains(ctrDDL, "`z`") { + t.Errorf("container DDL missing column aliases: %s", ctrDDL) + } + if !strings.Contains(omniDDL, "`x`") || !strings.Contains(omniDDL, "`y`") || !strings.Contains(omniDDL, "`z`") { + t.Errorf("omni DDL missing column aliases: %s", omniDDL) + } + }) +} + +func TestContainer_Section_5_1_UseStatement(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Helper to extract MySQL error code from go-sql-driver error. + extractMySQLErr := func(err error) (uint16, string, string) { + var mysqlErr *mysqldriver.MySQLError + if errors.As(err, &mysqlErr) { + return mysqlErr.Number, string(mysqlErr.SQLState[:]), mysqlErr.Message + } + return 0, "", "" + } + + t.Run("use_db_sets_current_database", func(t *testing.T) { + // On container: create a separate database, USE it, then CREATE TABLE without qualifying. + ctr.execSQL("DROP DATABASE IF EXISTS use_test_db") + ctr.execSQL("CREATE DATABASE use_test_db") + if err := ctr.execSQL("USE use_test_db"); err != nil { + t.Fatalf("container USE: %v", err) + } + ctr.execSQL("DROP TABLE IF EXISTS t_use_check") + if err := ctr.execSQL("CREATE TABLE t_use_check (id INT PRIMARY KEY)"); err != nil { + t.Fatalf("container CREATE TABLE: %v", err) + } + ctrDDL, err := ctr.showCreateTable("t_use_check") + if err != nil { + t.Fatalf("container SHOW CREATE TABLE: %v", err) + } + + // On omni: same sequence via Exec. + c := New() + c.Exec("CREATE DATABASE use_test_db", nil) + results, _ := c.Exec("USE use_test_db", nil) + if results[0].Error != nil { + t.Fatalf("omni USE error: %v", results[0].Error) + } + // Verify current database was set. + if c.CurrentDatabase() != "use_test_db" { + t.Fatalf("omni CurrentDatabase: got %q, want %q", c.CurrentDatabase(), "use_test_db") + } + + results, _ = c.Exec("CREATE TABLE t_use_check (id INT PRIMARY KEY)", nil) + if results[0].Error != nil { + t.Fatalf("omni CREATE TABLE error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("use_test_db", "t_use_check") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + + // Cleanup ctr. + ctr.execSQL("DROP DATABASE IF EXISTS use_test_db") + }) + + t.Run("use_nonexistent_error_1049", func(t *testing.T) { + // On container: USE a nonexistent database. + ctr.execSQL("DROP DATABASE IF EXISTS use_nonexistent_db") + ctrErr := ctr.execSQL("USE use_nonexistent_db") + if ctrErr == nil { + t.Fatal("container: expected error for USE nonexistent database") + } + ctrCode, ctrState, ctrMsg := extractMySQLErr(ctrErr) + t.Logf("container error: %d (%s) %s", ctrCode, ctrState, ctrMsg) + + if ctrCode != 1049 { + t.Fatalf("container: expected error code 1049, got %d", ctrCode) + } + + // On omni: USE a nonexistent database. + c := New() + results, _ := c.Exec("USE use_nonexistent_db", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("omni: expected error for USE nonexistent database") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("omni: expected *Error, got %T", results[0].Error) + } + + // Compare error code. + if catErr.Code != int(ctrCode) { + t.Errorf("error code mismatch: container=%d omni=%d", ctrCode, catErr.Code) + } + // Compare SQLSTATE. + if catErr.SQLState != ctrState { + t.Errorf("SQLSTATE mismatch: container=%q omni=%q", ctrState, catErr.SQLState) + } + // Compare error message. + if catErr.Message != ctrMsg { + t.Errorf("message mismatch:\n container: %s\n omni: %s", ctrMsg, catErr.Message) + } + }) +} + +func TestContainer_Section_5_2_SetVariables(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Scenario: SET foreign_key_checks = 0 — skip FK validation on CREATE TABLE + t.Run("fk_checks_off_create", func(t *testing.T) { + ctr.execSQL("DROP TABLE IF EXISTS t_fkoff_child") + ctr.execSQL("SET foreign_key_checks = 0") + ctrErr := ctr.execSQL("CREATE TABLE t_fkoff_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES nonexistent_parent(id))") + ctr.execSQL("SET foreign_key_checks = 1") + if ctrErr != nil { + t.Fatalf("container: unexpected error with FK checks off: %v", ctrErr) + } + ctrDDL, err := ctr.showCreateTable("t_fkoff_child") + if err != nil { + t.Fatalf("container: SHOW CREATE TABLE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("SET foreign_key_checks = 0", nil) + results, _ := c.Exec("CREATE TABLE t_fkoff_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES nonexistent_parent(id))", nil) + if results[0].Error != nil { + t.Fatalf("omni: unexpected error with FK checks off: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_fkoff_child") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", ctrDDL, omniDDL) + } + }) + + // Scenario: SET foreign_key_checks = 0 — drop table with FK references + t.Run("fk_checks_off_drop", func(t *testing.T) { + ctr.execSQL("SET foreign_key_checks = 0") + ctr.execSQL("DROP TABLE IF EXISTS t_fkdrop_child") + ctr.execSQL("DROP TABLE IF EXISTS t_fkdrop_parent") + ctr.execSQL("SET foreign_key_checks = 1") + ctr.execSQL("CREATE TABLE t_fkdrop_parent (id INT PRIMARY KEY)") + ctr.execSQL("CREATE TABLE t_fkdrop_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES t_fkdrop_parent(id))") + ctr.execSQL("SET foreign_key_checks = 0") + ctrErr := ctr.execSQL("DROP TABLE t_fkdrop_parent") + ctr.execSQL("SET foreign_key_checks = 1") + if ctrErr != nil { + t.Fatalf("container: unexpected error dropping FK-referenced table with checks off: %v", ctrErr) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_fkdrop_parent (id INT PRIMARY KEY)", nil) + c.Exec("CREATE TABLE t_fkdrop_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES t_fkdrop_parent(id))", nil) + c.Exec("SET foreign_key_checks = 0", nil) + results, _ := c.Exec("DROP TABLE t_fkdrop_parent", nil) + if results[0].Error != nil { + t.Fatalf("omni: unexpected error dropping FK-referenced table with checks off: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", "t_fkdrop_parent") + if omniDDL != "" { + t.Errorf("omni: parent table should be gone after drop, got: %s", omniDDL) + } + }) + + // Scenario: SET foreign_key_checks = 1 — enforce FK validation + t.Run("fk_checks_on_enforce", func(t *testing.T) { + ctr.execSQL("SET foreign_key_checks = 0") + ctr.execSQL("DROP TABLE IF EXISTS t_fkon_child") + ctr.execSQL("SET foreign_key_checks = 1") + ctrErr := ctr.execSQL("CREATE TABLE t_fkon_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES nonexistent_on_parent(id))") + if ctrErr == nil { + t.Fatal("container: expected error with FK checks on, referencing non-existent table") + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("SET foreign_key_checks = 1", nil) + results, _ := c.Exec("CREATE TABLE t_fkon_child (id INT, parent_id INT, FOREIGN KEY (parent_id) REFERENCES nonexistent_on_parent(id))", nil) + if results[0].Error == nil { + t.Fatal("omni: expected error with FK checks on, referencing non-existent table") + } + }) + + // Scenario: SET NAMES utf8mb4 — silently accepted + t.Run("set_names", func(t *testing.T) { + ctrErr := ctr.execSQL("SET NAMES utf8mb4") + if ctrErr != nil { + t.Fatalf("container: SET NAMES error: %v", ctrErr) + } + + c := New() + results, err := c.Exec("SET NAMES utf8mb4", nil) + if err != nil { + t.Fatalf("omni: parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni: SET NAMES error: %v", results[0].Error) + } + }) + + // Scenario: SET CHARACTER SET utf8mb4 — silently accepted + t.Run("set_character_set", func(t *testing.T) { + ctrErr := ctr.execSQL("SET CHARACTER SET utf8mb4") + if ctrErr != nil { + t.Fatalf("container: SET CHARACTER SET error: %v", ctrErr) + } + + c := New() + results, err := c.Exec("SET CHARACTER SET utf8mb4", nil) + if err != nil { + t.Fatalf("omni: parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni: SET CHARACTER SET error: %v", results[0].Error) + } + }) + + // Scenario: SET sql_mode — silently accepted + t.Run("set_sql_mode", func(t *testing.T) { + ctrErr := ctr.execSQL("SET sql_mode = 'STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION'") + if ctrErr != nil { + t.Fatalf("container: SET sql_mode error: %v", ctrErr) + } + + c := New() + results, err := c.Exec("SET sql_mode = 'STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION'", nil) + if err != nil { + t.Fatalf("omni: parse error: %v", err) + } + if results[0].Error != nil { + t.Fatalf("omni: SET sql_mode error: %v", results[0].Error) + } + }) +} + +// TestContainer_Section_5_3_UserRoleManagement verifies that user/role management +// statements are accepted by both MySQL and omni without error. +// These are marked [~] partial because the in-memory catalog does not actually +// store users, roles, or privileges — it just silently accepts them. +func TestContainer_Section_5_3_UserRoleManagement(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + setup string // SQL to run before the main statement (on container only) + sql string // main statement to test + cleanup string // SQL to run after the test (on container only) + }{ + { + "create_user", + "", + "CREATE USER 'testuser53'@'localhost' IDENTIFIED BY 'Password123!'", + "DROP USER IF EXISTS 'testuser53'@'localhost'", + }, + { + "drop_user", + "CREATE USER IF NOT EXISTS 'dropme53'@'localhost' IDENTIFIED BY 'Password123!'", + "DROP USER 'dropme53'@'localhost'", + "", + }, + { + "alter_user", + "CREATE USER IF NOT EXISTS 'alterme53'@'localhost' IDENTIFIED BY 'Password123!'", + "ALTER USER 'alterme53'@'localhost' IDENTIFIED BY 'NewPassword456!'", + "DROP USER IF EXISTS 'alterme53'@'localhost'", + }, + { + "create_role", + "", + "CREATE ROLE 'app_role53'", + "DROP ROLE IF EXISTS 'app_role53'", + }, + { + "drop_role", + "CREATE ROLE IF NOT EXISTS 'droprole53'", + "DROP ROLE 'droprole53'", + "", + }, + { + "grant_privileges", + "CREATE USER IF NOT EXISTS 'grantee53'@'localhost' IDENTIFIED BY 'Password123!'", + "GRANT SELECT, INSERT ON test.* TO 'grantee53'@'localhost'", + "DROP USER IF EXISTS 'grantee53'@'localhost'", + }, + { + "revoke_privileges", + "CREATE USER IF NOT EXISTS 'revokee53'@'localhost' IDENTIFIED BY 'Password123!'; GRANT SELECT ON test.* TO 'revokee53'@'localhost'", + "REVOKE SELECT ON test.* FROM 'revokee53'@'localhost'", + "DROP USER IF EXISTS 'revokee53'@'localhost'", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Setup on container + if tc.setup != "" { + if err := ctr.execSQL(tc.setup); err != nil { + t.Fatalf("container setup: %v", err) + } + } + + // Run on container — must succeed + ctrErr := ctr.execSQL(tc.sql) + if ctrErr != nil { + t.Fatalf("container exec: %v", ctrErr) + } + + // Cleanup on container + if tc.cleanup != "" { + _ = ctr.execSQL(tc.cleanup) + } + + // Run on omni — must parse and not return a parse error. + // The catalog silently accepts these (no-op) since it doesn't + // track users/roles/privileges, which is expected behavior + // for an in-memory DDL catalog. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if len(results) > 0 && results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + }) + } +} + +func TestContainer_Section_6_1_ShowCreateTableIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + cases := []struct { + name string + sql string + table string + }{ + // --- All data types (Phase 1.1-1.6) integration --- + {"all_data_types", `CREATE TABLE t_all_types ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + tiny_col TINYINT, + small_col SMALLINT, + med_col MEDIUMINT, + int_col INT, + big_col BIGINT, + float_col FLOAT, + float_prec FLOAT(7,3), + double_col DOUBLE, + double_prec DOUBLE(15,5), + decimal_col DECIMAL(10,2), + decimal_bare DECIMAL, + bool_col BOOLEAN, + bit_col BIT(8), + char_col CHAR(10), + varchar_col VARCHAR(255), + tinytext_col TINYTEXT, + text_col TEXT, + mediumtext_col MEDIUMTEXT, + longtext_col LONGTEXT, + enum_col ENUM('a','b','c'), + set_col SET('x','y','z'), + binary_col BINARY(16), + varbinary_col VARBINARY(255), + tinyblob_col TINYBLOB, + blob_col BLOB, + mediumblob_col MEDIUMBLOB, + longblob_col LONGBLOB, + date_col DATE, + time_col TIME, + time_frac TIME(3), + datetime_col DATETIME, + datetime_frac DATETIME(6), + timestamp_col TIMESTAMP NULL, + timestamp_frac TIMESTAMP(3) NULL, + year_col YEAR, + json_col JSON, + geo_col GEOMETRY, + point_col POINT, + linestring_col LINESTRING, + polygon_col POLYGON, + multipoint_col MULTIPOINT, + multiline_col MULTILINESTRING, + multipoly_col MULTIPOLYGON, + geocoll_col GEOMETRYCOLLECTION, + PRIMARY KEY (id) + )`, "t_all_types"}, + + // --- All default value forms (Phase 1.7) integration --- + {"all_defaults", `CREATE TABLE t_all_defaults ( + id INT NOT NULL AUTO_INCREMENT, + int_def INT DEFAULT 0, + int_null INT DEFAULT NULL, + int_notnull INT NOT NULL, + varchar_def VARCHAR(100) DEFAULT 'hello', + varchar_empty VARCHAR(100) DEFAULT '', + float_def FLOAT DEFAULT 3.14, + decimal_def DECIMAL(10,2) DEFAULT 0.00, + bool_true BOOLEAN DEFAULT TRUE, + bool_false BOOLEAN DEFAULT FALSE, + enum_def ENUM('a','b','c') DEFAULT 'a', + set_def SET('x','y','z') DEFAULT 'x,y', + bit_def BIT(8) DEFAULT b'00001111', + blob_col BLOB, + text_col TEXT, + json_col JSON, + ts_def TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + dt_def DATETIME DEFAULT CURRENT_TIMESTAMP, + ts3_def TIMESTAMP(3) NULL DEFAULT CURRENT_TIMESTAMP(3), + expr_int INT DEFAULT (FLOOR(RAND()*100)), + expr_json JSON DEFAULT (JSON_ARRAY()), + expr_varchar VARCHAR(100) DEFAULT (UUID()), + dt_literal DATETIME DEFAULT '2024-01-01 00:00:00', + date_literal DATE DEFAULT '2024-01-01', + PRIMARY KEY (id) + )`, "t_all_defaults"}, + + // --- All index types (Phase 1.13-1.16) integration --- + {"all_index_types", `CREATE TABLE t_all_indexes ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(255) NOT NULL, + email VARCHAR(255), + score INT, + rank_val INT, + bio TEXT, + description TEXT, + geo_col POINT NOT NULL, + PRIMARY KEY (id), + KEY idx_name (name), + KEY idx_email_prefix (email(10)), + KEY idx_score_desc (score DESC), + KEY idx_multi_mixed (score ASC, rank_val DESC), + UNIQUE KEY uk_email (email), + FULLTEXT KEY ft_bio (bio, description), + SPATIAL KEY sp_geo (geo_col), + KEY idx_expr ((UPPER(name))), + KEY idx_expr_calc ((score + rank_val)), + KEY idx_comment (name) COMMENT 'name lookup index', + KEY idx_invisible (score) /*!80000 INVISIBLE */ + ) ENGINE=InnoDB`, "t_all_indexes"}, + + // --- All constraint forms (Phase 1.17-1.18) integration --- + {"all_constraints_parent", `CREATE TABLE t_parent ( + id INT NOT NULL AUTO_INCREMENT, + code VARCHAR(10) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY uk_code (code) + )`, "t_parent"}, + {"all_constraints", `CREATE TABLE t_all_constraints ( + id INT NOT NULL AUTO_INCREMENT, + parent_id INT, + parent_code VARCHAR(10), + self_ref INT, + val INT, + score INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES t_parent(id) ON DELETE CASCADE, + CONSTRAINT fk_code FOREIGN KEY (parent_code) REFERENCES t_parent(code) ON UPDATE SET NULL, + CONSTRAINT fk_self FOREIGN KEY (self_ref) REFERENCES t_all_constraints(id) ON DELETE SET NULL ON UPDATE CASCADE, + CONSTRAINT chk_val CHECK (val > 0), + CONSTRAINT chk_score CHECK (score >= 0 AND score <= 100), + CONSTRAINT chk_not_enforced CHECK (val < 1000) /*!80016 NOT ENFORCED */ + )`, "t_all_constraints"}, + + // --- Partitioned table output (Phase 4.1) integration --- + {"partitioned_range", `CREATE TABLE t_part_range ( + id INT NOT NULL, + created_date DATE NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id, created_date) + ) ENGINE=InnoDB + PARTITION BY RANGE (YEAR(created_date)) ( + PARTITION p2020 VALUES LESS THAN (2021), + PARTITION p2021 VALUES LESS THAN (2022), + PARTITION p2022 VALUES LESS THAN (2023), + PARTITION pmax VALUES LESS THAN MAXVALUE + )`, "t_part_range"}, + {"partitioned_list", `CREATE TABLE t_part_list ( + id INT NOT NULL AUTO_INCREMENT, + region VARCHAR(20) NOT NULL, + data VARCHAR(100), + PRIMARY KEY (id, region) + ) ENGINE=InnoDB + PARTITION BY LIST COLUMNS(region) ( + PARTITION p_east VALUES IN ('east','northeast'), + PARTITION p_west VALUES IN ('west','northwest'), + PARTITION p_other VALUES IN ('central','south') + )`, "t_part_list"}, + {"partitioned_hash", `CREATE TABLE t_part_hash ( + id INT NOT NULL AUTO_INCREMENT, + val INT, + PRIMARY KEY (id) + ) ENGINE=InnoDB + PARTITION BY HASH(id) + PARTITIONS 4`, "t_part_hash"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Clean up in proper order (child before parent) + if tc.table == "t_all_constraints" { + ctr.execSQL("DROP TABLE IF EXISTS t_all_constraints") + } + if tc.table == "t_parent" { + ctr.execSQL("DROP TABLE IF EXISTS t_all_constraints") + ctr.execSQL("DROP TABLE IF EXISTS t_parent") + } + if tc.table != "t_all_constraints" && tc.table != "t_parent" { + ctr.execSQL("DROP TABLE IF EXISTS " + tc.table) + } + if err := ctr.execSQL(tc.sql); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTable(tc.table) + if err != nil { + t.Fatalf("container showCreateTable: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + // For constraint tests, need parent table first + if tc.table == "t_all_constraints" { + c.Exec("CREATE TABLE t_parent (id INT NOT NULL AUTO_INCREMENT, code VARCHAR(10) NOT NULL, PRIMARY KEY (id), UNIQUE KEY uk_code (code))", nil) + } + results, err := c.Exec(tc.sql, nil) + if err != nil { + t.Fatalf("omni parse error: %v", err) + } + if len(results) > 0 && results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTable("test", tc.table) + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + } +} + +// TestContainer_Section_6_2_ShowCreateOtherObjects tests SHOW CREATE VIEW/FUNCTION/PROCEDURE/TRIGGER/EVENT +// output against real MySQL 8.0. This complements sections 4.2-4.5 by testing additional patterns +// and verifying SHOW CREATE as a query API surface. +func TestContainer_Section_6_2_ShowCreateOtherObjects(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // --- SHOW CREATE VIEW --- + t.Run("show_create_view", func(t *testing.T) { + // Test SHOW CREATE VIEW with default options. + // MySQL rewrites SELECT text, so we compare preamble only. + ctr.execSQL("DROP VIEW IF EXISTS v_sc_basic") + ctr.execSQL("CREATE VIEW v_sc_basic AS SELECT 1 AS col1, 2 AS col2") + ctrDDL, err := ctr.showCreateView("v_sc_basic") + if err != nil { + t.Fatalf("container SHOW CREATE VIEW: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec("CREATE VIEW v_sc_basic AS SELECT 1 AS col1, 2 AS col2", nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateView("test", "v_sc_basic") + + // Compare preamble (up to AS) — SELECT text may differ due to MySQL rewriting. + oraclePreamble := extractViewPreamble(ctrDDL) + omniPreamble := extractViewPreamble(omniDDL) + if oraclePreamble != omniPreamble { + t.Errorf("preamble mismatch:\n--- container ---\n%s\n--- omni ---\n%s\n--- container full ---\n%s\n--- omni full ---\n%s", + oraclePreamble, omniPreamble, ctrDDL, omniDDL) + } + + // Verify both contain the key structural elements. + for _, substr := range []string{"ALGORITHM=", "DEFINER=", "SQL SECURITY", "VIEW", "v_sc_basic", " AS "} { + if !strings.Contains(ctrDDL, substr) { + t.Errorf("container DDL missing %q: %s", substr, ctrDDL) + } + if !strings.Contains(omniDDL, substr) { + t.Errorf("omni DDL missing %q: %s", substr, omniDDL) + } + } + }) + + // --- SHOW CREATE FUNCTION --- + t.Run("show_create_function", func(t *testing.T) { + // Test SHOW CREATE FUNCTION with various characteristics. + ctr.execSQL("DROP FUNCTION IF EXISTS fn_sc_multiply") + createSQL := "CREATE FUNCTION fn_sc_multiply(x INT, y INT) RETURNS INT DETERMINISTIC RETURN x * y" + if err := ctr.execSQL(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateFunction("fn_sc_multiply") + if err != nil { + t.Fatalf("container SHOW CREATE FUNCTION: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateFunction("test", "fn_sc_multiply") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // --- SHOW CREATE PROCEDURE --- + t.Run("show_create_procedure", func(t *testing.T) { + // Test SHOW CREATE PROCEDURE with INOUT params. + ctr.execSQLDirect("DROP PROCEDURE IF EXISTS sp_sc_swap") + createSQL := "CREATE PROCEDURE sp_sc_swap(INOUT a INT, INOUT b INT) BEGIN DECLARE tmp INT; SET tmp = a; SET a = b; SET b = tmp; END" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateProcedure("sp_sc_swap") + if err != nil { + t.Fatalf("container SHOW CREATE PROCEDURE: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateProcedure("test", "sp_sc_swap") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // --- SHOW CREATE TRIGGER --- + t.Run("show_create_trigger", func(t *testing.T) { + // Setup table for triggers. + ctr.execSQL("DROP TRIGGER IF EXISTS tr_sc_audit") + ctr.execSQL("DROP TABLE IF EXISTS t_sc_trigger") + if err := ctr.execSQL("CREATE TABLE t_sc_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, updated_at DATETIME)"); err != nil { + t.Fatalf("container setup table: %v", err) + } + createSQL := "CREATE TRIGGER tr_sc_audit AFTER INSERT ON t_sc_trigger FOR EACH ROW SET @last_insert_id = NEW.id" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateTrigger("tr_sc_audit") + if err != nil { + t.Fatalf("container SHOW CREATE TRIGGER: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t_sc_trigger (id INT AUTO_INCREMENT PRIMARY KEY, val INT, updated_at DATETIME)", nil) + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateTrigger("test", "tr_sc_audit") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) + + // --- SHOW CREATE EVENT --- + t.Run("show_create_event", func(t *testing.T) { + // Test SHOW CREATE EVENT with AT schedule (exact timestamp, no auto-STARTS issue). + ctr.execSQL("DROP EVENT IF EXISTS ev_sc_onetime") + createSQL := "CREATE EVENT ev_sc_onetime ON SCHEDULE AT '2035-06-15 12:00:00' ON COMPLETION PRESERVE COMMENT 'one-time event' DO SELECT 1" + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("container exec: %v", err) + } + ctrDDL, err := ctr.showCreateEvent("ev_sc_onetime") + if err != nil { + t.Fatalf("container SHOW CREATE EVENT: %v", err) + } + + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, parseErr := c.Exec(createSQL, nil) + if parseErr != nil { + t.Fatalf("omni parse error: %v", parseErr) + } + if results[0].Error != nil { + t.Fatalf("omni exec error: %v", results[0].Error) + } + omniDDL := c.ShowCreateEvent("test", "ev_sc_onetime") + + if normalizeWhitespace(ctrDDL) != normalizeWhitespace(omniDDL) { + t.Errorf("mismatch:\n--- container ---\n%s\n--- omni ---\n%s", + ctrDDL, omniDDL) + } + }) +} + +// TestContainer_Section_6_3_InformationSchemaConsistency verifies that the catalog's +// internal state is consistent with what MySQL 8.0 reports via INFORMATION_SCHEMA. +// +// The omni catalog does not support INFORMATION_SCHEMA SQL queries (SELECT is +// treated as DML and skipped). Instead, we compare the catalog's Go-level data +// structures against real MySQL INFORMATION_SCHEMA query results. +// +// All scenarios are marked [~] partial because the catalog lacks an +// INFORMATION_SCHEMA query engine — users cannot run SELECT ... FROM +// INFORMATION_SCHEMA.* against the in-memory catalog. +func TestContainer_Section_6_3_InformationSchemaConsistency(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup SQL: a table with various column types, indexes, FK, and CHECK constraints. + parentSQL := "CREATE TABLE t_is_parent (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id)) ENGINE=InnoDB" + setupSQL := `CREATE TABLE t_is_test ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) NOT NULL DEFAULT '', + price DECIMAL(10,2) DEFAULT '0.00', + status ENUM('active','inactive') DEFAULT 'active', + parent_id INT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id), + UNIQUE KEY uk_name (name), + KEY idx_status (status), + KEY idx_parent (parent_id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES t_is_parent (id) ON DELETE SET NULL, + CONSTRAINT chk_price CHECK (price >= 0) + ) ENGINE=InnoDB` + + // Setup on ctr. + ctr.execSQL("DROP TABLE IF EXISTS t_is_test") + ctr.execSQL("DROP TABLE IF EXISTS t_is_parent") + if err := ctr.execSQL(parentSQL); err != nil { + t.Fatalf("container parent table: %v", err) + } + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("container setup: %v", err) + } + + // Setup on omni. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + if results, _ := c.Exec(parentSQL, nil); results[0].Error != nil { + t.Fatalf("omni parent table: %v", results[0].Error) + } + if results, _ := c.Exec(setupSQL, nil); results[0].Error != nil { + t.Fatalf("omni setup: %v", results[0].Error) + } + + db := c.GetDatabase("test") + tbl := db.GetTable("t_is_test") + + // 6.3.1: INFORMATION_SCHEMA.COLUMNS matches catalog state + t.Run("columns_match", func(t *testing.T) { + oracleCols, err := ctr.queryColumns("test", "t_is_test") + if err != nil { + t.Fatalf("container queryColumns: %v", err) + } + if len(oracleCols) != len(tbl.Columns) { + t.Fatalf("column count: container=%d omni=%d", len(oracleCols), len(tbl.Columns)) + } + for i, oc := range oracleCols { + omniCol := tbl.Columns[i] + if oc.Name != omniCol.Name { + t.Errorf("col[%d] name: container=%q omni=%q", i, oc.Name, omniCol.Name) + } + if oc.Position != omniCol.Position { + t.Errorf("col[%d] position: container=%d omni=%d", i, oc.Position, omniCol.Position) + } + wantNullable := "YES" + if !omniCol.Nullable { + wantNullable = "NO" + } + if oc.Nullable != wantNullable { + t.Errorf("col[%d] %s nullable: container=%q omni=%q", i, oc.Name, oc.Nullable, wantNullable) + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) + + // 6.3.2: INFORMATION_SCHEMA.STATISTICS matches catalog indexes + t.Run("statistics_match", func(t *testing.T) { + oracleIdxs, err := ctr.queryIndexes("test", "t_is_test") + if err != nil { + t.Fatalf("container queryIndexes: %v", err) + } + // Build omni index column map: indexName -> []columnName + omniIdxCols := make(map[string][]string) + for _, idx := range tbl.Indexes { + for _, ic := range idx.Columns { + omniIdxCols[idx.Name] = append(omniIdxCols[idx.Name], ic.Name) + } + } + // Build container index column map. + oracleIdxCols := make(map[string][]string) + for _, oi := range oracleIdxs { + oracleIdxCols[oi.Name] = append(oracleIdxCols[oi.Name], oi.ColumnName) + } + // Compare index names and columns. + for name, oracleCols := range oracleIdxCols { + omniCols, ok := omniIdxCols[name] + if !ok { + t.Errorf("index %q exists in container but not in omni", name) + continue + } + if len(oracleCols) != len(omniCols) { + t.Errorf("index %q column count: container=%d omni=%d", name, len(oracleCols), len(omniCols)) + continue + } + for j := range oracleCols { + if oracleCols[j] != omniCols[j] { + t.Errorf("index %q col[%d]: container=%q omni=%q", name, j, oracleCols[j], omniCols[j]) + } + } + } + for name := range omniIdxCols { + if _, ok := oracleIdxCols[name]; !ok { + t.Errorf("index %q exists in omni but not in container", name) + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) + + // 6.3.3: INFORMATION_SCHEMA.TABLE_CONSTRAINTS matches + t.Run("table_constraints_match", func(t *testing.T) { + oracleCons, err := ctr.queryConstraints("test", "t_is_test") + if err != nil { + t.Fatalf("container queryConstraints: %v", err) + } + // Build omni constraint map. + omniCons := make(map[string]string) // name -> type + for _, con := range tbl.Constraints { + var typeName string + switch con.Type { + case ConPrimaryKey: + typeName = "PRIMARY KEY" + case ConUniqueKey: + typeName = "UNIQUE" + case ConForeignKey: + typeName = "FOREIGN KEY" + case ConCheck: + typeName = "CHECK" + } + omniCons[con.Name] = typeName + } + // Build container constraint map. + oracleConsMap := make(map[string]string) + for _, oc := range oracleCons { + oracleConsMap[oc.Name] = oc.Type + } + for name, oType := range oracleConsMap { + omniType, ok := omniCons[name] + if !ok { + t.Errorf("constraint %q (%s) in container but not omni", name, oType) + continue + } + if oType != omniType { + t.Errorf("constraint %q type: container=%q omni=%q", name, oType, omniType) + } + } + for name := range omniCons { + if _, ok := oracleConsMap[name]; !ok { + t.Errorf("constraint %q in omni but not container", name) + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) + + // 6.3.4: INFORMATION_SCHEMA.KEY_COLUMN_USAGE matches + t.Run("key_column_usage_match", func(t *testing.T) { + rows, err := ctr.db.QueryContext(ctr.ctx, ` + SELECT CONSTRAINT_NAME, COLUMN_NAME, ORDINAL_POSITION, + REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME + FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = 'test' AND TABLE_NAME = 't_is_test' + ORDER BY CONSTRAINT_NAME, ORDINAL_POSITION`) + if err != nil { + t.Fatalf("container KEY_COLUMN_USAGE query: %v", err) + } + defer rows.Close() + + type kcuRow struct { + constraintName, columnName string + ordinalPos int + refTable, refColumn *string + } + var oracleKCU []kcuRow + for rows.Next() { + var r kcuRow + var refTbl, refCol *string + if err := rows.Scan(&r.constraintName, &r.columnName, &r.ordinalPos, &refTbl, &refCol); err != nil { + t.Fatalf("scan: %v", err) + } + r.refTable = refTbl + r.refColumn = refCol + oracleKCU = append(oracleKCU, r) + } + + // Verify omni constraints have matching columns. + omniKCU := make(map[string][]string) // constraint name -> columns + for _, con := range tbl.Constraints { + if con.Type == ConCheck { + continue // CHECK constraints don't appear in KEY_COLUMN_USAGE + } + omniKCU[con.Name] = con.Columns + } + for _, okcu := range oracleKCU { + cols, ok := omniKCU[okcu.constraintName] + if !ok { + t.Errorf("constraint %q in container KEY_COLUMN_USAGE but not omni", okcu.constraintName) + continue + } + idx := okcu.ordinalPos - 1 + if idx < len(cols) && cols[idx] != okcu.columnName { + t.Errorf("constraint %q col[%d]: container=%q omni=%q", + okcu.constraintName, idx, okcu.columnName, cols[idx]) + } + } + + // Verify FK references match. + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + for _, okcu := range oracleKCU { + if okcu.constraintName == con.Name && okcu.refTable != nil { + if *okcu.refTable != con.RefTable { + t.Errorf("FK %q ref table: container=%q omni=%q", con.Name, *okcu.refTable, con.RefTable) + } + } + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) + + // 6.3.5: INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS matches + t.Run("referential_constraints_match", func(t *testing.T) { + rows, err := ctr.db.QueryContext(ctr.ctx, ` + SELECT CONSTRAINT_NAME, UNIQUE_CONSTRAINT_NAME, + MATCH_OPTION, UPDATE_RULE, DELETE_RULE, + REFERENCED_TABLE_NAME + FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA = 'test' AND TABLE_NAME = 't_is_test' + ORDER BY CONSTRAINT_NAME`) + if err != nil { + t.Fatalf("container REFERENTIAL_CONSTRAINTS query: %v", err) + } + defer rows.Close() + + type refConRow struct { + name, uniqueConName, matchOption, updateRule, deleteRule, refTable string + } + var oracleRefs []refConRow + for rows.Next() { + var r refConRow + if err := rows.Scan(&r.name, &r.uniqueConName, &r.matchOption, &r.updateRule, &r.deleteRule, &r.refTable); err != nil { + t.Fatalf("scan: %v", err) + } + oracleRefs = append(oracleRefs, r) + } + + // Compare with omni FK constraints. + for _, oref := range oracleRefs { + var found *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == oref.name { + found = con + break + } + } + if found == nil { + t.Errorf("FK %q in container REFERENTIAL_CONSTRAINTS but not omni", oref.name) + continue + } + if oref.refTable != found.RefTable { + t.Errorf("FK %q ref table: container=%q omni=%q", oref.name, oref.refTable, found.RefTable) + } + // Normalize action names for comparison. + oracleDelete := oref.deleteRule + omniDelete := found.OnDelete + if omniDelete == "" { + omniDelete = "RESTRICT" // MySQL default + } + if oracleDelete != omniDelete { + t.Errorf("FK %q delete rule: container=%q omni=%q", oref.name, oracleDelete, omniDelete) + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) + + // 6.3.6: INFORMATION_SCHEMA.CHECK_CONSTRAINTS matches + t.Run("check_constraints_match", func(t *testing.T) { + rows, err := ctr.db.QueryContext(ctr.ctx, ` + SELECT CONSTRAINT_NAME, CHECK_CLAUSE + FROM INFORMATION_SCHEMA.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA = 'test' + ORDER BY CONSTRAINT_NAME`) + if err != nil { + t.Fatalf("container CHECK_CONSTRAINTS query: %v", err) + } + defer rows.Close() + + type checkRow struct { + name, clause string + } + var oracleChecks []checkRow + for rows.Next() { + var r checkRow + if err := rows.Scan(&r.name, &r.clause); err != nil { + t.Fatalf("scan: %v", err) + } + oracleChecks = append(oracleChecks, r) + } + + // Filter to just our table's check constraints (INFORMATION_SCHEMA.CHECK_CONSTRAINTS + // doesn't have TABLE_NAME, so we match by constraint name). + omniChecks := make(map[string]string) // name -> expr + for _, con := range tbl.Constraints { + if con.Type == ConCheck { + omniChecks[con.Name] = con.CheckExpr + } + } + for _, oc := range oracleChecks { + omniExpr, ok := omniChecks[oc.name] + if !ok { + // May be an ENUM/SET constraint auto-generated by MySQL — skip. + continue + } + // Normalize for comparison: remove outer parens and whitespace. + oracleNorm := normalizeWhitespace(oc.clause) + omniNorm := normalizeWhitespace(omniExpr) + if oracleNorm != omniNorm { + // Log but don't fail — expression format may differ slightly. + t.Logf("check %q clause: container=%q omni=%q (expression format may differ)", oc.name, oracleNorm, omniNorm) + } + } + t.Log("[~] partial: catalog internal state matches, but INFORMATION_SCHEMA SQL queries require query engine") + }) +} diff --git a/tidb/catalog/container_test.go b/tidb/catalog/container_test.go new file mode 100644 index 00000000..8353cf76 --- /dev/null +++ b/tidb/catalog/container_test.go @@ -0,0 +1,411 @@ +package catalog + +import ( + "context" + "database/sql" + "fmt" + "strings" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/testcontainers/testcontainers-go" + tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" +) + +// mysqlContainer wraps a real MySQL 8.0 container connection for container testing. +type mysqlContainer struct { + db *sql.DB + ctx context.Context +} + +// columnInfo holds a row from INFORMATION_SCHEMA.COLUMNS. +type columnInfo struct { + Name, DataType, ColumnType, ColumnKey, Extra, Nullable string + Position int + Default, Charset, Collation sql.NullString + CharMaxLen, NumPrecision, NumScale sql.NullInt64 +} + +// indexInfo holds a row from INFORMATION_SCHEMA.STATISTICS. +type indexInfo struct { + Name, ColumnName, IndexType, Nullable string + NonUnique, SeqInIndex int + Collation sql.NullString +} + +// constraintInfo holds a row from INFORMATION_SCHEMA.TABLE_CONSTRAINTS. +type constraintInfo struct { + Name, Type string +} + +// startContainer starts a MySQL 8.0 container and returns an container handle plus +// a cleanup function. The caller must defer the cleanup function. +func startContainer(t *testing.T) (*mysqlContainer, func()) { + t.Helper() + ctx := context.Background() + + container, err := tcmysql.Run(ctx, "mysql:8.0", + tcmysql.WithDatabase("test"), + tcmysql.WithUsername("root"), + tcmysql.WithPassword("test"), + ) + if err != nil { + t.Fatalf("failed to start MySQL container: %v", err) + } + + connStr, err := container.ConnectionString(ctx, "parseTime=true", "multiStatements=true") + if err != nil { + _ = testcontainers.TerminateContainer(container) + t.Fatalf("failed to get connection string: %v", err) + } + + db, err := sql.Open("mysql", connStr) + if err != nil { + _ = testcontainers.TerminateContainer(container) + t.Fatalf("failed to open database: %v", err) + } + + if err := db.PingContext(ctx); err != nil { + db.Close() + _ = testcontainers.TerminateContainer(container) + t.Fatalf("failed to ping database: %v", err) + } + + cleanup := func() { + db.Close() + _ = testcontainers.TerminateContainer(container) + } + + return &mysqlContainer{db: db, ctx: ctx}, cleanup +} + +// execSQL executes one or more SQL statements separated by semicolons. +// It respects quoted strings when splitting. +func (o *mysqlContainer) execSQL(sqlStr string) error { + stmts := splitStatements(sqlStr) + for _, stmt := range stmts { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + if _, err := o.db.ExecContext(o.ctx, stmt); err != nil { + return fmt.Errorf("executing %q: %w", stmt, err) + } + } + return nil +} + +// execSQLDirect executes a single SQL statement directly without splitting on semicolons. +// This is needed for CREATE PROCEDURE/FUNCTION with BEGIN...END blocks. +func (o *mysqlContainer) execSQLDirect(sqlStr string) error { + _, err := o.db.ExecContext(o.ctx, sqlStr) + return err +} + +// showCreateDatabase runs SHOW CREATE DATABASE and returns the CREATE DATABASE statement. +func (o *mysqlContainer) showCreateDatabase(database string) (string, error) { + var dbName, createStmt string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE DATABASE "+database).Scan(&dbName, &createStmt) + if err != nil { + return "", fmt.Errorf("SHOW CREATE DATABASE %s: %w", database, err) + } + return createStmt, nil +} + +// showCreateTable runs SHOW CREATE TABLE and returns the CREATE TABLE statement. +func (o *mysqlContainer) showCreateTable(table string) (string, error) { + var tableName, createStmt string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE TABLE "+table).Scan(&tableName, &createStmt) + if err != nil { + return "", fmt.Errorf("SHOW CREATE TABLE %s: %w", table, err) + } + return createStmt, nil +} + +// showCreateFunction runs SHOW CREATE FUNCTION and returns the CREATE FUNCTION statement. +func (o *mysqlContainer) showCreateFunction(name string) (string, error) { + var funcName, sqlMode, createStmt, charSetClient, collConn, dbCollation string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE FUNCTION "+name).Scan( + &funcName, &sqlMode, &createStmt, &charSetClient, &collConn, &dbCollation) + if err != nil { + return "", fmt.Errorf("SHOW CREATE FUNCTION %s: %w", name, err) + } + return createStmt, nil +} + +// showCreateProcedure runs SHOW CREATE PROCEDURE and returns the CREATE PROCEDURE statement. +func (o *mysqlContainer) showCreateProcedure(name string) (string, error) { + var procName, sqlMode, createStmt, charSetClient, collConn, dbCollation string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE PROCEDURE "+name).Scan( + &procName, &sqlMode, &createStmt, &charSetClient, &collConn, &dbCollation) + if err != nil { + return "", fmt.Errorf("SHOW CREATE PROCEDURE %s: %w", name, err) + } + return createStmt, nil +} + +// showCreateTrigger runs SHOW CREATE TRIGGER and returns the SQL Original Statement field. +func (o *mysqlContainer) showCreateTrigger(name string) (string, error) { + var trigName, sqlMode, createStmt, charSetClient, collConn, dbCollation string + var created sql.NullString + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE TRIGGER "+name).Scan( + &trigName, &sqlMode, &createStmt, &charSetClient, &collConn, &dbCollation, &created) + if err != nil { + return "", fmt.Errorf("SHOW CREATE TRIGGER %s: %w", name, err) + } + return createStmt, nil +} + +// showCreateView runs SHOW CREATE VIEW and returns the CREATE VIEW statement. +func (o *mysqlContainer) showCreateView(name string) (string, error) { + var viewName, createStmt, charSetClient, collConn string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE VIEW "+name).Scan( + &viewName, &createStmt, &charSetClient, &collConn) + if err != nil { + return "", fmt.Errorf("SHOW CREATE VIEW %s: %w", name, err) + } + return createStmt, nil +} + +// showCreateEvent runs SHOW CREATE EVENT and returns the CREATE EVENT statement. +func (o *mysqlContainer) showCreateEvent(name string) (string, error) { + var eventName, sqlMode, tz, createStmt, charSetClient, collConn, dbCollation string + err := o.db.QueryRowContext(o.ctx, "SHOW CREATE EVENT "+name).Scan( + &eventName, &sqlMode, &tz, &createStmt, &charSetClient, &collConn, &dbCollation) + if err != nil { + return "", fmt.Errorf("SHOW CREATE EVENT %s: %w", name, err) + } + return createStmt, nil +} + +// queryColumns queries INFORMATION_SCHEMA.COLUMNS for the given table. +func (o *mysqlContainer) queryColumns(database, table string) ([]columnInfo, error) { + rows, err := o.db.QueryContext(o.ctx, ` + SELECT COLUMN_NAME, ORDINAL_POSITION, DATA_TYPE, COLUMN_TYPE, + IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, EXTRA, + CHARACTER_SET_NAME, COLLATION_NAME, + CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY ORDINAL_POSITION`, + database, table) + if err != nil { + return nil, fmt.Errorf("querying columns: %w", err) + } + defer rows.Close() + + var cols []columnInfo + for rows.Next() { + var c columnInfo + if err := rows.Scan( + &c.Name, &c.Position, &c.DataType, &c.ColumnType, + &c.Nullable, &c.Default, &c.ColumnKey, &c.Extra, + &c.Charset, &c.Collation, + &c.CharMaxLen, &c.NumPrecision, &c.NumScale, + ); err != nil { + return nil, fmt.Errorf("scanning column row: %w", err) + } + cols = append(cols, c) + } + return cols, rows.Err() +} + +// queryIndexes queries INFORMATION_SCHEMA.STATISTICS for the given table. +func (o *mysqlContainer) queryIndexes(database, table string) ([]indexInfo, error) { + rows, err := o.db.QueryContext(o.ctx, ` + SELECT INDEX_NAME, SEQ_IN_INDEX, COLUMN_NAME, COLLATION, + NON_UNIQUE, INDEX_TYPE, NULLABLE + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY INDEX_NAME, SEQ_IN_INDEX`, + database, table) + if err != nil { + return nil, fmt.Errorf("querying indexes: %w", err) + } + defer rows.Close() + + var idxs []indexInfo + for rows.Next() { + var idx indexInfo + if err := rows.Scan( + &idx.Name, &idx.SeqInIndex, &idx.ColumnName, &idx.Collation, + &idx.NonUnique, &idx.IndexType, &idx.Nullable, + ); err != nil { + return nil, fmt.Errorf("scanning index row: %w", err) + } + idxs = append(idxs, idx) + } + return idxs, rows.Err() +} + +// queryConstraints queries INFORMATION_SCHEMA.TABLE_CONSTRAINTS for the given table. +func (o *mysqlContainer) queryConstraints(database, table string) ([]constraintInfo, error) { + rows, err := o.db.QueryContext(o.ctx, ` + SELECT CONSTRAINT_NAME, CONSTRAINT_TYPE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY CONSTRAINT_NAME`, + database, table) + if err != nil { + return nil, fmt.Errorf("querying constraints: %w", err) + } + defer rows.Close() + + var cs []constraintInfo + for rows.Next() { + var c constraintInfo + if err := rows.Scan(&c.Name, &c.Type); err != nil { + return nil, fmt.Errorf("scanning constraint row: %w", err) + } + cs = append(cs, c) + } + return cs, rows.Err() +} + +// splitStatements splits SQL text on semicolons, respecting single quotes, +// double quotes, and backtick-quoted identifiers. +func splitStatements(sqlStr string) []string { + var stmts []string + var current strings.Builder + var inQuote rune // 0 means not in a quote + var prevChar rune + + for _, ch := range sqlStr { + switch { + case inQuote != 0: + current.WriteRune(ch) + // End quote only if matching quote and not escaped by backslash. + if ch == inQuote && prevChar != '\\' { + inQuote = 0 + } + case ch == '\'' || ch == '"' || ch == '`': + inQuote = ch + current.WriteRune(ch) + case ch == ';': + stmt := strings.TrimSpace(current.String()) + if stmt != "" { + stmts = append(stmts, stmt) + } + current.Reset() + default: + current.WriteRune(ch) + } + prevChar = ch + } + + // Remaining text after the last semicolon (or if no semicolons). + if stmt := strings.TrimSpace(current.String()); stmt != "" { + stmts = append(stmts, stmt) + } + return stmts +} + +// normalizeWhitespace collapses runs of whitespace to a single space and trims. +func normalizeWhitespace(s string) string { + fields := strings.Fields(s) + return strings.Join(fields, " ") +} + +func TestContainerSmoke(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + + ctr, cleanup := startContainer(t) + defer cleanup() + + // Create a simple table. + err := ctr.execSQL("CREATE TABLE t1 (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL)") + if err != nil { + t.Fatalf("failed to create table: %v", err) + } + + // Verify SHOW CREATE TABLE works. + createStmt, err := ctr.showCreateTable("t1") + if err != nil { + t.Fatalf("SHOW CREATE TABLE failed: %v", err) + } + if !strings.Contains(createStmt, "CREATE TABLE") { + t.Errorf("expected CREATE TABLE in output, got: %s", createStmt) + } + t.Logf("SHOW CREATE TABLE t1:\n%s", createStmt) + + // Verify queryColumns works. + cols, err := ctr.queryColumns("test", "t1") + if err != nil { + t.Fatalf("queryColumns failed: %v", err) + } + if len(cols) != 2 { + t.Fatalf("expected 2 columns, got %d", len(cols)) + } + if cols[0].Name != "id" { + t.Errorf("expected first column 'id', got %q", cols[0].Name) + } + if cols[1].Name != "name" { + t.Errorf("expected second column 'name', got %q", cols[1].Name) + } + if cols[1].Nullable != "NO" { + t.Errorf("expected 'name' to be NOT NULL, got Nullable=%q", cols[1].Nullable) + } +} + +func TestSplitStatements(t *testing.T) { + tests := []struct { + input string + want []string + }{ + { + input: "CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT)", + want: []string{"CREATE TABLE t1 (id INT)", "CREATE TABLE t2 (id INT)"}, + }, + { + input: "INSERT INTO t1 VALUES ('a;b'); SELECT 1", + want: []string{"INSERT INTO t1 VALUES ('a;b')", "SELECT 1"}, + }, + { + input: `SELECT "col;name" FROM t1`, + want: []string{`SELECT "col;name" FROM t1`}, + }, + { + input: "SELECT `col;name` FROM t1", + want: []string{"SELECT `col;name` FROM t1"}, + }, + { + input: "", + want: nil, + }, + { + input: " ; ; ", + want: nil, + }, + } + for _, tt := range tests { + got := splitStatements(tt.input) + if len(got) != len(tt.want) { + t.Errorf("splitStatements(%q): got %d stmts, want %d", tt.input, len(got), len(tt.want)) + continue + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("splitStatements(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i]) + } + } + } +} + +func TestNormalizeWhitespace(t *testing.T) { + tests := []struct { + input, want string + }{ + {" hello world ", "hello world"}, + {"a\n\tb", "a b"}, + {"already clean", "already clean"}, + {"", ""}, + } + for _, tt := range tests { + got := normalizeWhitespace(tt.input) + if got != tt.want { + t.Errorf("normalizeWhitespace(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/tidb/catalog/database.go b/tidb/catalog/database.go new file mode 100644 index 00000000..06bd6767 --- /dev/null +++ b/tidb/catalog/database.go @@ -0,0 +1,37 @@ +package catalog + +import "strings" + +type Database struct { + Name string + Charset string + Collation string + Tables map[string]*Table // lowered name -> Table + Views map[string]*View + Functions map[string]*Routine // lowered name -> stored function + Procedures map[string]*Routine // lowered name -> stored procedure + Triggers map[string]*Trigger // lowered name -> trigger + Events map[string]*Event // lowered name -> event +} + +func newDatabase(name, charset, collation string) *Database { + return &Database{ + Name: name, + Charset: charset, + Collation: collation, + Tables: make(map[string]*Table), + Views: make(map[string]*View), + Functions: make(map[string]*Routine), + Procedures: make(map[string]*Routine), + Triggers: make(map[string]*Trigger), + Events: make(map[string]*Event), + } +} + +func (db *Database) GetTable(name string) *Table { + return db.Tables[toLower(name)] +} + +func toLower(s string) string { + return strings.ToLower(s) +} diff --git a/tidb/catalog/dbcmds.go b/tidb/catalog/dbcmds.go new file mode 100644 index 00000000..6e615e99 --- /dev/null +++ b/tidb/catalog/dbcmds.go @@ -0,0 +1,92 @@ +package catalog + +import nodes "github.com/bytebase/omni/tidb/ast" + +func (c *Catalog) createDatabase(stmt *nodes.CreateDatabaseStmt) error { + name := stmt.Name + key := toLower(name) + if c.databases[key] != nil { + if stmt.IfNotExists { + return nil + } + return errDupDatabase(name) + } + charset := c.defaultCharset + collation := c.defaultCollation + charsetExplicit := false + collationExplicit := false + for _, opt := range stmt.Options { + switch toLower(opt.Name) { + case "character set", "charset": + charset = opt.Value + charsetExplicit = true + case "collate": + collation = opt.Value + collationExplicit = true + } + } + // When charset is specified without explicit collation, derive the default collation. + if charsetExplicit && !collationExplicit { + if dc, ok := defaultCollationForCharset[toLower(charset)]; ok { + collation = dc + } + } + c.databases[key] = newDatabase(name, charset, collation) + return nil +} + +func (c *Catalog) dropDatabase(stmt *nodes.DropDatabaseStmt) error { + name := stmt.Name + key := toLower(name) + if c.databases[key] == nil { + if stmt.IfExists { + return nil + } + return errUnknownDatabase(name) + } + delete(c.databases, key) + if toLower(c.currentDB) == key { + c.currentDB = "" + } + return nil +} + +func (c *Catalog) useDatabase(stmt *nodes.UseStmt) error { + name := stmt.Database + key := toLower(name) + if c.databases[key] == nil { + return errUnknownDatabase(name) + } + c.currentDB = name + return nil +} + +func (c *Catalog) alterDatabase(stmt *nodes.AlterDatabaseStmt) error { + name := stmt.Name + if name == "" { + name = c.currentDB + } + db, err := c.resolveDatabase(name) + if err != nil { + return err + } + charsetExplicit := false + collationExplicit := false + for _, opt := range stmt.Options { + switch toLower(opt.Name) { + case "character set", "charset": + db.Charset = opt.Value + charsetExplicit = true + case "collate": + db.Collation = opt.Value + collationExplicit = true + } + } + // When charset is changed without explicit collation, derive the default collation. + if charsetExplicit && !collationExplicit { + if dc, ok := defaultCollationForCharset[toLower(db.Charset)]; ok { + db.Collation = dc + } + } + return nil +} diff --git a/tidb/catalog/dbcmds_test.go b/tidb/catalog/dbcmds_test.go new file mode 100644 index 00000000..241bada7 --- /dev/null +++ b/tidb/catalog/dbcmds_test.go @@ -0,0 +1,96 @@ +package catalog + +import "testing" + +func TestCreateDatabase(t *testing.T) { + c := New() + _, err := c.Exec("CREATE DATABASE mydb", nil) + if err != nil { + t.Fatal(err) + } + db := c.GetDatabase("mydb") + if db == nil { + t.Fatal("database not found") + } + if db.Name != "mydb" { + t.Errorf("expected name 'mydb', got %q", db.Name) + } +} + +func TestCreateDatabaseIfNotExists(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE mydb", nil) + results, _ := c.Exec("CREATE DATABASE IF NOT EXISTS mydb", nil) + if results[0].Error != nil { + t.Errorf("IF NOT EXISTS should not error: %v", results[0].Error) + } +} + +func TestCreateDatabaseDuplicate(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE mydb", nil) + results, _ := c.Exec("CREATE DATABASE mydb", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate database error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrDupDatabase { + t.Errorf("expected error code %d, got %d", ErrDupDatabase, catErr.Code) + } +} + +func TestDropDatabase(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE mydb", nil) + _, err := c.Exec("DROP DATABASE mydb", nil) + if err != nil { + t.Fatal(err) + } + if c.GetDatabase("mydb") != nil { + t.Fatal("database should be dropped") + } +} + +func TestDropDatabaseNotExists(t *testing.T) { + c := New() + results, _ := c.Exec("DROP DATABASE noexist", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected error for nonexistent database") + } +} + +func TestDropDatabaseIfExists(t *testing.T) { + c := New() + results, _ := c.Exec("DROP DATABASE IF EXISTS noexist", nil) + if results[0].Error != nil { + t.Errorf("IF EXISTS should not error: %v", results[0].Error) + } +} + +func TestCreateDatabaseCharset(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE mydb CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", nil) + db := c.GetDatabase("mydb") + if db == nil { + t.Fatal("database not found") + } + if db.Charset != "utf8mb4" { + t.Errorf("expected charset utf8mb4, got %q", db.Charset) + } + if db.Collation != "utf8mb4_unicode_ci" { + t.Errorf("expected collation utf8mb4_unicode_ci, got %q", db.Collation) + } +} + +func TestDropDatabaseResetsCurrentDB(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE mydb", nil) + c.SetCurrentDatabase("mydb") + c.Exec("DROP DATABASE mydb", nil) + if c.CurrentDatabase() != "" { + t.Error("current database should be unset after drop") + } +} diff --git a/tidb/catalog/define.go b/tidb/catalog/define.go new file mode 100644 index 00000000..be569e3b --- /dev/null +++ b/tidb/catalog/define.go @@ -0,0 +1,169 @@ +// Package catalog — define.go +// +// AST-level Define API for catalog initialization. +// +// These entry points let callers install schema objects without going +// through the SQL lexer/parser, mirroring the shape of pg/catalog's +// Define{Relation,View,Enum,…} family. Each wrapper is a thin guard +// around the internal create* method reached today via Exec's +// processUtility dispatch; no additional validation or side effects +// are introduced beyond rejecting nil / empty-named stmts. +// +// # Design philosophy: loader, not validator +// +// This API is for cold-starting an in-memory catalog from structured +// metadata. It attempts a best-effort install of each object. +// Constraint checking (FK integrity, routine body validity, etc.) is +// explicitly NOT a responsibility of this API: +// +// - FKs installed while ForeignKeyChecks()==false are stored as-is +// and never revalidated. +// - Views whose SELECT body fails AnalyzeSelectStmt are still +// created, but with AnalyzedQuery == nil and a nil error. +// +// Callers who need validated state must either load with FK checks on +// (paying the topological-order cost) or run their own downstream +// checks. +// +// # Preconditions (per call) +// +// - SetCurrentDatabase(name) MUST be called whenever the stmt does +// not carry an explicit database qualifier. Missing currentDB +// surfaces as *Error{Code: ErrNoDatabaseSelected}. +// - SetForeignKeyChecks(false) is typically required while +// bulk-loading schemas with forward FK references. Re-enabling it +// after the load does not retroactively validate FKs already +// installed. +// - Topological ordering across kinds is the caller's responsibility: +// DefineTrigger and DefineIndex on a not-yet-installed target +// table return *Error{Code: ErrNoSuchTable}. DefineView tolerates +// forward refs but yields AnalyzedQuery=nil. +// +// # Error contract +// +// Every Define* returns error, always of concrete type *Error when +// non-nil. Callers may inspect err.(*Error).Code (ErrDupTable, +// ErrNoDatabaseSelected, ErrWrongArguments, etc.) for idempotency and +// fallback decisions. On error, no catalog state is written; the call +// is atomic at the object level. +// +// # Concurrency +// +// *Catalog is NOT goroutine-safe. The underlying maps have no sync +// primitives. Callers MUST serialize Define* calls on a given Catalog. +// +// # Nil / empty-name guards +// +// Every Define* rejects nil stmt and empty required names with +// *Error{Code: ErrWrongArguments} rather than panicking. Per-kind +// required fields: +// +// DefineDatabase: stmt.Name +// DefineTable: stmt.Table.Name +// DefineView: stmt.Name.Name +// DefineIndex: stmt.Table.Name (and stmt.IndexName for the index itself) +// DefineFunction, +// DefineProcedure, +// DefineRoutine: stmt.Name.Name +// DefineTrigger: stmt.Name and stmt.Table.Name +// DefineEvent: stmt.Name +package catalog + +import nodes "github.com/bytebase/omni/tidb/ast" + +// DefineDatabase installs a database. stmt.Name must be non-empty. +func (c *Catalog) DefineDatabase(stmt *nodes.CreateDatabaseStmt) error { + if stmt == nil || stmt.Name == "" { + return errWrongArguments("DefineDatabase") + } + return c.createDatabase(stmt) +} + +// DefineTable installs a table (including inline columns, indexes, +// foreign keys, CHECK constraints, and partitions). stmt.Table and +// stmt.Table.Name must be non-nil/non-empty. +// +// Foreign-key validity depends on the current foreign_key_checks +// session flag. See package doc. +func (c *Catalog) DefineTable(stmt *nodes.CreateTableStmt) error { + if stmt == nil || stmt.Table == nil || stmt.Table.Name == "" { + return errWrongArguments("DefineTable") + } + return c.createTable(stmt) +} + +// DefineView installs a view. stmt.Name and stmt.Name.Name must be +// non-nil/non-empty. +// +// If AnalyzeSelectStmt on the view body fails (e.g. referenced table +// is not yet installed), DefineView still returns nil and the view is +// stored with AnalyzedQuery=nil. This is intentional loader behavior. +func (c *Catalog) DefineView(stmt *nodes.CreateViewStmt) error { + if stmt == nil || stmt.Name == nil || stmt.Name.Name == "" { + return errWrongArguments("DefineView") + } + return c.createView(stmt) +} + +// DefineIndex installs an index on an existing table. +func (c *Catalog) DefineIndex(stmt *nodes.CreateIndexStmt) error { + if stmt == nil || stmt.Table == nil || stmt.Table.Name == "" { + return errWrongArguments("DefineIndex") + } + return c.createIndex(stmt) +} + +// DefineFunction installs a stored function. +// +// Routing note: this function is a thin wrapper over createRoutine. +// The catalog routes the stmt to db.Functions or db.Procedures based +// on stmt.IsProcedure. Callers who set IsProcedure=true on a stmt +// passed to DefineFunction will land in db.Procedures — no kind-guard +// is applied. Use DefineFunction at call sites where intent is known +// to be a function, for readability; use DefineRoutine for generic +// paths. +func (c *Catalog) DefineFunction(stmt *nodes.CreateFunctionStmt) error { + if stmt == nil || stmt.Name == nil || stmt.Name.Name == "" { + return errWrongArguments("DefineFunction") + } + return c.createRoutine(stmt) +} + +// DefineProcedure installs a stored procedure. +// +// See DefineFunction for the routing semantics — this wrapper is +// identical and exists purely for call-site clarity when the caller +// knows the intent is a procedure. +func (c *Catalog) DefineProcedure(stmt *nodes.CreateFunctionStmt) error { + if stmt == nil || stmt.Name == nil || stmt.Name.Name == "" { + return errWrongArguments("DefineProcedure") + } + return c.createRoutine(stmt) +} + +// DefineRoutine installs a function or procedure, routed by +// stmt.IsProcedure. Use this when the caller does not statically know +// the kind (e.g. bulk loaders processing heterogeneous metadata). +func (c *Catalog) DefineRoutine(stmt *nodes.CreateFunctionStmt) error { + if stmt == nil || stmt.Name == nil || stmt.Name.Name == "" { + return errWrongArguments("DefineRoutine") + } + return c.createRoutine(stmt) +} + +// DefineTrigger installs a trigger on an existing table. stmt.Name +// (trigger name) and stmt.Table.Name must be non-empty. +func (c *Catalog) DefineTrigger(stmt *nodes.CreateTriggerStmt) error { + if stmt == nil || stmt.Name == "" || stmt.Table == nil || stmt.Table.Name == "" { + return errWrongArguments("DefineTrigger") + } + return c.createTrigger(stmt) +} + +// DefineEvent installs an event in the current or specified database. +func (c *Catalog) DefineEvent(stmt *nodes.CreateEventStmt) error { + if stmt == nil || stmt.Name == "" { + return errWrongArguments("DefineEvent") + } + return c.createEvent(stmt) +} diff --git a/tidb/catalog/define_test.go b/tidb/catalog/define_test.go new file mode 100644 index 00000000..392bf674 --- /dev/null +++ b/tidb/catalog/define_test.go @@ -0,0 +1,592 @@ +package catalog + +import ( + "errors" + "testing" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// -- helpers ------------------------------------------------------------- + +// parseFirst parses sql and returns the single top-level statement. +func parseFirst(t *testing.T, sql string) nodes.Node { + t.Helper() + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("parse error: %v\nSQL: %s", err, sql) + } + if list == nil || len(list.Items) != 1 { + t.Fatalf("expected 1 statement, got %d", len(list.Items)) + } + return list.Items[0] +} + +func mustParseDatabase(t *testing.T, sql string) *nodes.CreateDatabaseStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateDatabaseStmt) + if !ok { + t.Fatalf("not a CreateDatabaseStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseTable(t *testing.T, sql string) *nodes.CreateTableStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateTableStmt) + if !ok { + t.Fatalf("not a CreateTableStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseView(t *testing.T, sql string) *nodes.CreateViewStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateViewStmt) + if !ok { + t.Fatalf("not a CreateViewStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseIndex(t *testing.T, sql string) *nodes.CreateIndexStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateIndexStmt) + if !ok { + t.Fatalf("not a CreateIndexStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseRoutine(t *testing.T, sql string) *nodes.CreateFunctionStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateFunctionStmt) + if !ok { + t.Fatalf("not a CreateFunctionStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseTrigger(t *testing.T, sql string) *nodes.CreateTriggerStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateTriggerStmt) + if !ok { + t.Fatalf("not a CreateTriggerStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +func mustParseEvent(t *testing.T, sql string) *nodes.CreateEventStmt { + t.Helper() + stmt, ok := parseFirst(t, sql).(*nodes.CreateEventStmt) + if !ok { + t.Fatalf("not a CreateEventStmt: %T", parseFirst(t, sql)) + } + return stmt +} + +// assertErrCode fails the test unless err is a *Error with the given code. +func assertErrCode(t *testing.T, err error, code int) { + t.Helper() + if err == nil { + t.Fatalf("expected *Error(code=%d), got nil", code) + } + var e *Error + if !errors.As(err, &e) { + t.Fatalf("expected *Error, got %T: %v", err, err) + } + if e.Code != code { + t.Fatalf("expected code %d, got %d (%s)", code, e.Code, e.Message) + } +} + +// newCatalogWithDB returns a fresh catalog with db created and selected. +func newCatalogWithDB(t *testing.T, name string) *Catalog { + t.Helper() + c := New() + if _, err := c.Exec("CREATE DATABASE "+name+"; USE "+name+";", nil); err != nil { + t.Fatalf("setup: %v", err) + } + return c +} + +// -- §6.1 Happy-path per kind ------------------------------------------- + +func TestDefineDatabase_HappyPath(t *testing.T) { + c := New() + if err := c.DefineDatabase(mustParseDatabase(t, "CREATE DATABASE mydb")); err != nil { + t.Fatalf("DefineDatabase: %v", err) + } + if db := c.GetDatabase("mydb"); db == nil || db.Name != "mydb" { + t.Fatalf("database not registered: %+v", db) + } +} + +func TestDefineTable_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseTable(t, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(50))") + if err := c.DefineTable(stmt); err != nil { + t.Fatalf("DefineTable: %v", err) + } + got := c.ShowCreateTable("mydb", "t") + if got == "" { + t.Fatal("ShowCreateTable returned empty") + } +} + +func TestDefineView_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT PRIMARY KEY)")); err != nil { + t.Fatal(err) + } + stmt := mustParseView(t, "CREATE VIEW v AS SELECT id FROM t") + if err := c.DefineView(stmt); err != nil { + t.Fatalf("DefineView: %v", err) + } + db := c.GetDatabase("mydb") + view := db.Views["v"] + if view == nil { + t.Fatal("view not registered") + } + if view.AnalyzedQuery == nil { + t.Fatal("AnalyzedQuery should be non-nil when referenced table is installed") + } +} + +func TestDefineIndex_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT, name VARCHAR(50))")); err != nil { + t.Fatal(err) + } + if err := c.DefineIndex(mustParseIndex(t, "CREATE INDEX idx_name ON t (name)")); err != nil { + t.Fatalf("DefineIndex: %v", err) + } + db := c.GetDatabase("mydb") + tbl := db.Tables["t"] + if tbl == nil { + t.Fatal("table missing") + } + found := false + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + found = true + break + } + } + if !found { + t.Fatalf("idx_name not found in table indexes") + } +} + +func TestDefineFunction_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE FUNCTION f() RETURNS INT RETURN 1") + if err := c.DefineFunction(stmt); err != nil { + t.Fatalf("DefineFunction: %v", err) + } + if got := c.ShowCreateFunction("mydb", "f"); got == "" { + t.Fatal("ShowCreateFunction empty") + } +} + +func TestDefineProcedure_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE PROCEDURE p() BEGIN END") + if err := c.DefineProcedure(stmt); err != nil { + t.Fatalf("DefineProcedure: %v", err) + } + if got := c.ShowCreateProcedure("mydb", "p"); got == "" { + t.Fatal("ShowCreateProcedure empty") + } +} + +func TestDefineRoutine_FunctionRoutesToFunctions(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE FUNCTION f() RETURNS INT RETURN 1") + if stmt.IsProcedure { + t.Fatalf("expected IsProcedure=false for CREATE FUNCTION") + } + if err := c.DefineRoutine(stmt); err != nil { + t.Fatal(err) + } + db := c.GetDatabase("mydb") + if _, ok := db.Functions["f"]; !ok { + t.Fatal("f not in db.Functions") + } + if _, ok := db.Procedures["f"]; ok { + t.Fatal("f should not be in db.Procedures") + } +} + +func TestDefineTrigger_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT PRIMARY KEY)")); err != nil { + t.Fatal(err) + } + stmt := mustParseTrigger(t, "CREATE TRIGGER trg_ins BEFORE INSERT ON t FOR EACH ROW BEGIN END") + if err := c.DefineTrigger(stmt); err != nil { + t.Fatalf("DefineTrigger: %v", err) + } + if got := c.ShowCreateTrigger("mydb", "trg_ins"); got == "" { + t.Fatal("ShowCreateTrigger empty") + } +} + +func TestDefineEvent_HappyPath(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseEvent(t, "CREATE EVENT e ON SCHEDULE EVERY 1 DAY DO SELECT 1") + if err := c.DefineEvent(stmt); err != nil { + t.Fatalf("DefineEvent: %v", err) + } + if got := c.ShowCreateEvent("mydb", "e"); got == "" { + t.Fatal("ShowCreateEvent empty") + } +} + +// -- §6.2 Duplicate install ---------------------------------------------- + +func TestDefineDatabase_Duplicate(t *testing.T) { + c := New() + if err := c.DefineDatabase(mustParseDatabase(t, "CREATE DATABASE mydb")); err != nil { + t.Fatal(err) + } + err := c.DefineDatabase(mustParseDatabase(t, "CREATE DATABASE mydb")) + assertErrCode(t, err, ErrDupDatabase) +} + +func TestDefineTable_Duplicate(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")); err != nil { + t.Fatal(err) + } + err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")) + assertErrCode(t, err, ErrDupTable) +} + +func TestDefineFunction_Duplicate(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE FUNCTION f() RETURNS INT RETURN 1") + if err := c.DefineFunction(stmt); err != nil { + t.Fatal(err) + } + err := c.DefineFunction(mustParseRoutine(t, "CREATE FUNCTION f() RETURNS INT RETURN 1")) + assertErrCode(t, err, ErrDupFunction) +} + +func TestDefineTrigger_Duplicate(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")); err != nil { + t.Fatal(err) + } + stmt := mustParseTrigger(t, "CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW BEGIN END") + if err := c.DefineTrigger(stmt); err != nil { + t.Fatal(err) + } + err := c.DefineTrigger(mustParseTrigger(t, "CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW BEGIN END")) + assertErrCode(t, err, ErrDupTrigger) +} + +func TestDefineEvent_Duplicate(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseEvent(t, "CREATE EVENT e ON SCHEDULE EVERY 1 DAY DO SELECT 1") + if err := c.DefineEvent(stmt); err != nil { + t.Fatal(err) + } + err := c.DefineEvent(mustParseEvent(t, "CREATE EVENT e ON SCHEDULE EVERY 1 DAY DO SELECT 1")) + assertErrCode(t, err, ErrDupEvent) +} + +// -- §6.3 Missing currentDB ---------------------------------------------- + +func TestDefineTable_NoCurrentDatabase(t *testing.T) { + c := New() + err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")) + assertErrCode(t, err, ErrNoDatabaseSelected) +} + +func TestDefineFunction_NoCurrentDatabase(t *testing.T) { + c := New() + err := c.DefineFunction(mustParseRoutine(t, "CREATE FUNCTION f() RETURNS INT RETURN 1")) + assertErrCode(t, err, ErrNoDatabaseSelected) +} + +// -- §6.4 Nil / incomplete stmt ------------------------------------------ + +func TestDefine_NilStmt(t *testing.T) { + c := New() + assertErrCode(t, c.DefineDatabase(nil), ErrWrongArguments) + assertErrCode(t, c.DefineTable(nil), ErrWrongArguments) + assertErrCode(t, c.DefineView(nil), ErrWrongArguments) + assertErrCode(t, c.DefineIndex(nil), ErrWrongArguments) + assertErrCode(t, c.DefineFunction(nil), ErrWrongArguments) + assertErrCode(t, c.DefineProcedure(nil), ErrWrongArguments) + assertErrCode(t, c.DefineRoutine(nil), ErrWrongArguments) + assertErrCode(t, c.DefineTrigger(nil), ErrWrongArguments) + assertErrCode(t, c.DefineEvent(nil), ErrWrongArguments) +} + +func TestDefine_IncompleteStmt(t *testing.T) { + c := New() + assertErrCode(t, c.DefineDatabase(&nodes.CreateDatabaseStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineTable(&nodes.CreateTableStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineTable(&nodes.CreateTableStmt{Table: &nodes.TableRef{}}), ErrWrongArguments) + assertErrCode(t, c.DefineView(&nodes.CreateViewStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineIndex(&nodes.CreateIndexStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineFunction(&nodes.CreateFunctionStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineTrigger(&nodes.CreateTriggerStmt{}), ErrWrongArguments) + assertErrCode(t, c.DefineTrigger(&nodes.CreateTriggerStmt{Name: "x"}), ErrWrongArguments) + assertErrCode(t, c.DefineEvent(&nodes.CreateEventStmt{}), ErrWrongArguments) +} + +// -- §6.5 Cross-database explicit schema -------------------------------- + +func TestDefineTable_ExplicitSchemaBypassesCurrentDB(t *testing.T) { + c := New() + if _, err := c.Exec("CREATE DATABASE main_db; CREATE DATABASE other_db; USE main_db;", nil); err != nil { + t.Fatal(err) + } + // currentDB is main_db; install into other_db via qualifier. + stmt := mustParseTable(t, "CREATE TABLE other_db.t (id INT)") + if err := c.DefineTable(stmt); err != nil { + t.Fatalf("DefineTable: %v", err) + } + if tbl := c.GetDatabase("other_db").Tables["t"]; tbl == nil { + t.Fatal("table should be in other_db") + } + if tbl := c.GetDatabase("main_db").Tables["t"]; tbl != nil { + t.Fatal("table should NOT be in main_db") + } +} + +// -- §6.6 FK forward reference under foreignKeyChecks=false ------------- + +func TestDefineTable_FKForwardRefWithChecksOff(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + c.SetForeignKeyChecks(false) + defer c.SetForeignKeyChecks(true) + + // child references parent which is NOT yet installed. + child := mustParseTable(t, ` + CREATE TABLE child ( + id INT PRIMARY KEY, + parent_id INT, + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent (id) + )`) + if err := c.DefineTable(child); err != nil { + t.Fatalf("child install with checks off should succeed, got: %v", err) + } + + parent := mustParseTable(t, "CREATE TABLE parent (id INT PRIMARY KEY)") + if err := c.DefineTable(parent); err != nil { + t.Fatalf("parent install: %v", err) + } + + db := c.GetDatabase("mydb") + if db.Tables["child"] == nil || db.Tables["parent"] == nil { + t.Fatal("both tables should be present") + } + // Confirm the FK struct was carried on child even though unvalidated. + var fkFound bool + for _, con := range db.Tables["child"].Constraints { + if con.Type == ConForeignKey { + fkFound = true + break + } + } + if !fkFound { + t.Fatal("child should still carry FK constraint struct (unvalidated but present)") + } +} + +func TestDefineTable_FKForwardRefWithChecksOn(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + // Default is FK checks ON. + child := mustParseTable(t, ` + CREATE TABLE child ( + id INT PRIMARY KEY, + parent_id INT, + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent (id) + )`) + err := c.DefineTable(child) + assertErrCode(t, err, ErrFKNoRefTable) +} + +// -- §6.7 Trigger ordering (both directions) ---------------------------- + +func TestDefineTrigger_BeforeTable(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseTrigger(t, "CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW BEGIN END") + err := c.DefineTrigger(stmt) + assertErrCode(t, err, ErrNoSuchTable) +} + +func TestDefineTrigger_AfterTable(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")); err != nil { + t.Fatal(err) + } + stmt := mustParseTrigger(t, "CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW BEGIN END") + if err := c.DefineTrigger(stmt); err != nil { + t.Fatalf("DefineTrigger after table install should succeed: %v", err) + } +} + +// -- §6.8 View forward reference — loader contract ---------------------- + +func TestDefineView_ForwardRefDegradesGracefully(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + // View references a table that doesn't exist yet. + stmt := mustParseView(t, "CREATE VIEW v AS SELECT id FROM missing_table") + if err := c.DefineView(stmt); err != nil { + t.Fatalf("DefineView with forward ref should NOT error (loader contract): %v", err) + } + view := c.GetDatabase("mydb").Views["v"] + if view == nil { + t.Fatal("view should be registered even without resolved deps") + } + if view.AnalyzedQuery != nil { + t.Fatal("AnalyzedQuery should be nil when reference cannot be resolved") + } +} + +func TestDefineView_ResolvedRefHasAnalyzedQuery(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE t (id INT)")); err != nil { + t.Fatal(err) + } + if err := c.DefineView(mustParseView(t, "CREATE VIEW v AS SELECT id FROM t")); err != nil { + t.Fatal(err) + } + view := c.GetDatabase("mydb").Views["v"] + if view.AnalyzedQuery == nil { + t.Fatal("AnalyzedQuery should be populated when reference resolves") + } +} + +// -- §6.9 Routine dispatch by IsProcedure ------------------------------- + +func TestDefineRoutine_ProcedureRoutesToProcedures(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE PROCEDURE p() BEGIN END") + if !stmt.IsProcedure { + t.Fatalf("expected IsProcedure=true for CREATE PROCEDURE") + } + if err := c.DefineRoutine(stmt); err != nil { + t.Fatal(err) + } + db := c.GetDatabase("mydb") + if _, ok := db.Procedures["p"]; !ok { + t.Fatal("p not in db.Procedures") + } + if _, ok := db.Functions["p"]; ok { + t.Fatal("p should not be in db.Functions") + } +} + +// DefineFunction with IsProcedure=true routes by stmt bit (no kind guard). +// This documents the loader philosophy: we do not reject mismatched +// metadata; we route it where the stmt says it belongs. +func TestDefineFunction_WithProcedureStmt_RoutesByBit(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + stmt := mustParseRoutine(t, "CREATE PROCEDURE p() BEGIN END") + if !stmt.IsProcedure { + t.Fatal("expected IsProcedure=true") + } + if err := c.DefineFunction(stmt); err != nil { + t.Fatalf("DefineFunction with procedure stmt should not error; loader routes by stmt bit: %v", err) + } + db := c.GetDatabase("mydb") + if _, ok := db.Procedures["p"]; !ok { + t.Fatal("procedure should land in db.Procedures regardless of entry-point name") + } +} + +// -- §6.10 Partial-load downstream usability ---------------------------- + +// A loader with a broken table should still produce a catalog where +// queries over healthy tables analyze correctly. This is the end-to-end +// proof that the loader contract delivers isolation. +func TestDefine_PartialLoadPreservesHealthyAnalysis(t *testing.T) { + c := newCatalogWithDB(t, "mydb") + c.SetForeignKeyChecks(false) + defer c.SetForeignKeyChecks(true) + + // Two healthy tables. + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50))")); err != nil { + t.Fatal(err) + } + if err := c.DefineTable(mustParseTable(t, "CREATE TABLE orders (id INT PRIMARY KEY, user_id INT, amount INT)")); err != nil { + t.Fatal(err) + } + // One broken table: FK to non-existent table. Under checks-off, it still installs. + broken := mustParseTable(t, ` + CREATE TABLE broken ( + id INT PRIMARY KEY, + ghost_id INT, + CONSTRAINT fk_ghost FOREIGN KEY (ghost_id) REFERENCES ghost_table (id) + )`) + if err := c.DefineTable(broken); err != nil { + t.Fatalf("broken table install should succeed under checks-off: %v", err) + } + + // Analyze a query joining the two healthy tables. Must not be poisoned + // by the broken table's dangling FK. + stmts, err := parser.Parse("SELECT u.name, o.amount FROM users u JOIN orders o ON o.user_id = u.id") + if err != nil { + t.Fatalf("parse: %v", err) + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Fatalf("not a SelectStmt") + } + q, err := c.AnalyzeSelectStmt(sel) + if err != nil { + t.Fatalf("AnalyzeSelectStmt on healthy tables should succeed despite broken peer: %v", err) + } + if q == nil { + t.Fatal("Query should be non-nil") + } +} + +// -- §6.11 Parity with Exec (scoped to CREATE kinds) -------------------- + +func TestDefineTable_ParityWithExec(t *testing.T) { + stmt := "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(50))" + + cDefine := newCatalogWithDB(t, "mydb") + if err := cDefine.DefineTable(mustParseTable(t, stmt)); err != nil { + t.Fatal(err) + } + + cExec := newCatalogWithDB(t, "mydb") + if _, err := cExec.Exec(stmt, nil); err != nil { + t.Fatal(err) + } + + if a, b := cDefine.ShowCreateTable("mydb", "t"), cExec.ShowCreateTable("mydb", "t"); a != b { + t.Fatalf("ShowCreateTable diverges:\nDefine:\n%s\nExec:\n%s", a, b) + } +} + +func TestDefineView_ParityWithExec(t *testing.T) { + setup := "CREATE TABLE t (id INT);" + stmt := "CREATE VIEW v AS SELECT id FROM t" + + cDefine := newCatalogWithDB(t, "mydb") + if _, err := cDefine.Exec(setup, nil); err != nil { + t.Fatal(err) + } + if err := cDefine.DefineView(mustParseView(t, stmt)); err != nil { + t.Fatal(err) + } + + cExec := newCatalogWithDB(t, "mydb") + if _, err := cExec.Exec(setup+stmt, nil); err != nil { + t.Fatal(err) + } + + if a, b := cDefine.ShowCreateView("mydb", "v"), cExec.ShowCreateView("mydb", "v"); a != b { + t.Fatalf("ShowCreateView diverges:\nDefine:\n%s\nExec:\n%s", a, b) + } +} diff --git a/tidb/catalog/deparse_container_test.go b/tidb/catalog/deparse_container_test.go new file mode 100644 index 00000000..d132f737 --- /dev/null +++ b/tidb/catalog/deparse_container_test.go @@ -0,0 +1,2525 @@ +package catalog + +import ( + "fmt" + "strings" + "testing" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/deparse" + "github.com/bytebase/omni/tidb/parser" +) + +// extractExprFromView extracts the expression portion from SHOW CREATE VIEW output. +// MySQL 8.0 SHOW CREATE VIEW returns: +// +// CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `test`.`v` AS select AS `` from `test`.`t` +// +// We extract — the part between "AS select " and the first " AS `". +func extractExprFromView(showCreate string) string { + idx := strings.Index(showCreate, " AS select ") + if idx < 0 { + return showCreate + } + selectPart := showCreate[idx+len(" AS select "):] + + // Find " AS `" which marks the column alias + aliasIdx := strings.Index(selectPart, " AS `") + if aliasIdx < 0 { + return selectPart + } + return selectPart[:aliasIdx] +} + +// deparseExprForOracle parses a SQL expression and deparses it via our deparser. +func deparseExprForOracle(t *testing.T, expr string) string { + t.Helper() + sql := "SELECT " + expr + " FROM t" + stmts, err := parser.Parse(sql) + if err != nil { + t.Fatalf("failed to parse %q: %v", sql, err) + } + if stmts.Len() == 0 { + t.Fatalf("no statements parsed from %q", sql) + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Fatalf("expected SelectStmt, got %T", stmts.Items[0]) + } + if len(sel.TargetList) == 0 { + t.Fatalf("no target list in SELECT from %q", sql) + } + target := sel.TargetList[0] + if rt, ok := target.(*nodes.ResTarget); ok { + return deparse.Deparse(rt.Val) + } + return deparse.Deparse(target) +} + +// deparseExprRewriteForOracle parses a SQL expression, applies RewriteExpr, and deparses. +func deparseExprRewriteForOracle(t *testing.T, expr string) string { + t.Helper() + sql := "SELECT " + expr + " FROM t" + stmts, err := parser.Parse(sql) + if err != nil { + t.Fatalf("failed to parse %q: %v", sql, err) + } + if stmts.Len() == 0 { + t.Fatalf("no statements parsed from %q", sql) + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Fatalf("expected SelectStmt, got %T", stmts.Items[0]) + } + if len(sel.TargetList) == 0 { + t.Fatalf("no target list in SELECT from %q", sql) + } + target := sel.TargetList[0] + if rt, ok := target.(*nodes.ResTarget); ok { + return deparse.Deparse(deparse.RewriteExpr(rt.Val)) + } + return deparse.Deparse(deparse.RewriteExpr(target)) +} + +// TestDeparse_Section_4_1_Container verifies NOT folding against MySQL 8.0. +func TestDeparse_Section_4_1_Container(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Create base table with integer column for boolean/comparison tests + ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT)") + + cases := []struct { + name string + input string // expression to test + }{ + {"not_gt", "NOT (a > 0)"}, + {"not_lt", "NOT (a < 0)"}, + {"not_ge", "NOT (a >= 0)"}, + {"not_le", "NOT (a <= 0)"}, + {"not_eq", "NOT (a = 0)"}, + {"not_ne", "NOT (a <> 0)"}, + {"not_col", "NOT a"}, + {"not_add", "NOT (a + 1)"}, + {"bang_col", "!a"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + selectSQL := fmt.Sprintf("SELECT %s FROM t", tc.input) + + // Create view on MySQL 8.0 + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, selectSQL) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW failed: %v", err) + } + + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW failed: %v", err) + } + + // Extract just the expression from MySQL's output + mysqlExpr := extractExprFromView(mysqlOutput) + + // Our deparser output (with rewrite) + omniExpr := deparseExprRewriteForOracle(t, tc.input) + + t.Logf("MySQL: %s", mysqlExpr) + t.Logf("Omni: %s", omniExpr) + + if mysqlExpr != omniExpr { + t.Errorf("mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlExpr, omniExpr) + } + }) + } +} + +// TestDeparse_Section_3_2_Container verifies TRIM special forms against MySQL 8.0. +func TestDeparse_Section_3_2_Container(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Create base table + ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a VARCHAR(50), b VARCHAR(50))") + + cases := []struct { + name string + input string // expression to test + }{ + {"trim_simple", "TRIM(a)"}, + {"trim_leading", "TRIM(LEADING 'x' FROM a)"}, + {"trim_trailing", "TRIM(TRAILING 'x' FROM a)"}, + {"trim_both", "TRIM(BOTH 'x' FROM a)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + selectSQL := fmt.Sprintf("SELECT %s FROM t", tc.input) + + // Create view on MySQL 8.0 + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, selectSQL) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW failed: %v", err) + } + + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW failed: %v", err) + } + + // Extract just the expression from MySQL's output + mysqlExpr := extractExprFromView(mysqlOutput) + + // Our deparser output + omniExpr := deparseExprForOracle(t, tc.input) + + t.Logf("MySQL: %s", mysqlExpr) + t.Logf("Omni: %s", omniExpr) + + if mysqlExpr != omniExpr { + t.Errorf("mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlExpr, omniExpr) + } + }) + } +} + +// extractSelectBody extracts the SELECT body from SHOW CREATE VIEW output. +// MySQL 8.0 format: +// +// CREATE ALGORITHM=... VIEW `test`.`v` AS select ... +// +// Our catalog format: +// +// CREATE ALGORITHM=... VIEW `v` AS select ... +// +// We extract everything after " AS " (the first occurrence after VIEW). +func extractSelectBody(showCreate string) string { + // Find "VIEW " to locate the view name portion, then find " AS " after that. + viewIdx := strings.Index(showCreate, "VIEW ") + if viewIdx < 0 { + return showCreate + } + rest := showCreate[viewIdx:] + asIdx := strings.Index(rest, " AS ") + if asIdx < 0 { + return showCreate + } + return rest[asIdx+len(" AS "):] +} + +// TestDeparse_Section_7_2_SimpleViews verifies that our catalog's SHOW CREATE VIEW +// output matches MySQL 8.0's output for simple view definitions. +// For each view, we compare the SELECT body portion (after "AS "). +func TestDeparse_Section_7_2_SimpleViews(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string // the SELECT portion after CREATE VIEW v AS + }{ + {"select_constant", "SELECT 1"}, + {"select_column", "SELECT a FROM t"}, + {"select_alias", "SELECT a AS col1 FROM t"}, + {"select_multi_columns", "SELECT a, b FROM t"}, + {"select_where", "SELECT a FROM t WHERE a > 0"}, + {"select_orderby_limit", "SELECT a FROM t ORDER BY a LIMIT 10"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := extractSelectBody(mysqlOutput) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL full: %s", mysqlOutput) + t.Logf("Omni full: %s", omniOutput) + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparse_Section_7_4_JoinViews verifies that our catalog's SHOW CREATE VIEW +// output matches MySQL 8.0's output for views with JOINs (INNER JOIN, LEFT JOIN, +// multiple tables, subquery in FROM). +func TestDeparse_Section_7_4_JoinViews(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t1 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t1 on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t2 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t2 on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string // the SELECT portion after CREATE VIEW v AS + tables string // additional CREATE TABLE statements for our catalog (beyond t, t1, t2) + partial bool // expected partial match (parser limitation) + }{ + {"inner_join", "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a", "", false}, + {"left_join", "SELECT t1.a, t2.b FROM t1 LEFT JOIN t2 ON t1.a = t2.a", "", false}, + {"multi_table", "SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a", "", false}, + {"subquery_from", "SELECT d.x FROM (SELECT a AS x FROM t) d", "", true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t2 (a INT, b INT)", nil) + if tc.tables != "" { + cat.Exec(tc.tables, nil) + } + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial — parser limitation)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL full: %s", mysqlOutput) + t.Logf("Omni full: %s", omniOutput) + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparse_Section_7_5_AdvancedViews verifies that our catalog's SHOW CREATE VIEW +// output matches MySQL 8.0's output for advanced view definitions: UNION, CTE, +// window functions, nested subqueries, boolean expressions, and combined rewrites. +func TestDeparse_Section_7_5_AdvancedViews(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string // the SELECT portion after CREATE VIEW v AS + partial bool // expected partial match (parser/resolver limitation) + }{ + {"union_view", "SELECT a FROM t UNION SELECT b FROM t", false}, + {"cte_view", "WITH cte AS (SELECT a FROM t) SELECT * FROM cte", false}, + {"window_func_view", "SELECT a, ROW_NUMBER() OVER (ORDER BY a) FROM t", false}, + {"nested_subquery_view", "SELECT * FROM t WHERE a IN (SELECT a FROM t WHERE a > 0)", false}, + {"boolean_expr_view", "SELECT a AND b, a OR b FROM t", false}, + {"combined_rewrite_view", "SELECT a + b, NOT (a > 0), CAST(a AS CHAR), COUNT(*) FROM t GROUP BY a, b HAVING COUNT(*) > 1", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial — parser/resolver limitation)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL full: %s", mysqlOutput) + t.Logf("Omni full: %s", omniOutput) + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// stripDatabasePrefix removes the `test`. database prefix from MySQL 8.0 SHOW CREATE VIEW output. +// MySQL 8.0 qualifies all identifiers with the database name (e.g., `test`.`t`.`a`), +// while our catalog does not. We strip the prefix for comparison. +func stripDatabasePrefix(s string) string { + return strings.ReplaceAll(s, "`test`.", "") +} + +// TestDeparse_Section_7_6_Regression verifies that the deparser integration does not +// break existing tests (scenarios 1-2 are covered by running go test ./mysql/catalog/ -short +// and go test ./mysql/parser/ -short separately) and that views with explicit column +// aliases match MySQL 8.0 output exactly. +func TestDeparse_Section_7_6_Regression(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + t.Run("view_with_explicit_column_aliases", func(t *testing.T) { + viewName := "v_col_alias" + createSQL := "CREATE VIEW " + viewName + "(x, y) AS SELECT a, b FROM t" + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(createSQL, nil) + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + + // Compare full output (stripping database prefix from MySQL). + // MySQL: CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `test`.`v_col_alias` (`x`,`y`) AS select ... + // Omni: CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `v_col_alias` (`x`,`y`) AS select ... + mysqlNorm := stripDatabasePrefix(mysqlOutput) + + t.Logf("MySQL full: %s", mysqlOutput) + t.Logf("Omni full: %s", omniOutput) + t.Logf("MySQL norm: %s", mysqlNorm) + + // Compare the preamble (up to and including column list). + mysqlPreamble := extractViewPreamble(mysqlNorm) + omniPreamble := extractViewPreamble(omniOutput) + if mysqlPreamble != omniPreamble { + t.Errorf("preamble mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlPreamble, omniPreamble) + } + + // Compare the SELECT body. + mysqlBody := extractSelectBody(mysqlNorm) + omniBody := extractSelectBody(omniOutput) + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + + // Verify simple and complex views still match (re-run 7.2 and 7.5 representative cases). + simpleAndComplexCases := []struct { + name string + createAs string + }{ + {"simple_select_column", "SELECT a FROM t"}, + {"simple_select_where", "SELECT a FROM t WHERE a > 0"}, + {"complex_union", "SELECT a FROM t UNION SELECT b FROM t"}, + {"complex_boolean", "SELECT a AND b, a OR b FROM t"}, + } + + for _, tc := range simpleAndComplexCases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_reg_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_7_1_WindowFunctions verifies window function patterns +// against MySQL 8.0 SHOW CREATE VIEW output. +// Covers: ROW_NUMBER, SUM OVER PARTITION BY+ORDER BY, ROWS frame, RANGE frame, +// named window, multiple window functions, LAG/LEAD. +func TestDeparseContainer_7_1_WindowFunctions(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool + }{ + {"row_number_basic", "v_wf_rownum", "CREATE VIEW v_wf_rownum AS SELECT a, ROW_NUMBER() OVER (ORDER BY a) FROM t", false}, + {"sum_partition_orderby", "v_wf_sum_part", "CREATE VIEW v_wf_sum_part AS SELECT a, SUM(b) OVER (PARTITION BY a ORDER BY b) FROM t", false}, + {"rows_frame", "v_wf_rows", "CREATE VIEW v_wf_rows AS SELECT a, SUM(b) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM t", false}, + {"range_frame", "v_wf_range", "CREATE VIEW v_wf_range AS SELECT a, AVG(b) OVER (ORDER BY a RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM t", false}, + {"named_window", "v_wf_named", "CREATE VIEW v_wf_named AS SELECT a, RANK() OVER w, DENSE_RANK() OVER w FROM t WINDOW w AS (ORDER BY a)", false}, + {"multiple_window_funcs", "v_wf_multi", "CREATE VIEW v_wf_multi AS SELECT a, ROW_NUMBER() OVER (ORDER BY a), SUM(b) OVER (ORDER BY a) FROM t", false}, + {"lag_lead", "v_wf_lagld", "CREATE VIEW v_wf_lagld AS SELECT a, LAG(a, 1) OVER (ORDER BY a), LEAD(a, 1) OVER (ORDER BY a) FROM t", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial — parser limitation)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_1_1_ArithmeticComparison verifies arithmetic and comparison operators +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_1_1_ArithmeticComparison(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"addition", "v_add", "CREATE VIEW v_add AS SELECT a + b FROM t"}, + {"subtraction", "v_sub", "CREATE VIEW v_sub AS SELECT a - b FROM t"}, + {"multiplication", "v_mul", "CREATE VIEW v_mul AS SELECT a * b FROM t"}, + {"division", "v_div", "CREATE VIEW v_div AS SELECT a / b FROM t"}, + {"int_division", "v_intdiv", "CREATE VIEW v_intdiv AS SELECT a DIV b FROM t"}, + {"mod", "v_mod", "CREATE VIEW v_mod AS SELECT a MOD b FROM t"}, + {"equals", "v_eq", "CREATE VIEW v_eq AS SELECT a = b FROM t"}, + {"not_equals_bang", "v_neq", "CREATE VIEW v_neq AS SELECT a != b FROM t"}, + {"not_equals_ltgt", "v_neq2", "CREATE VIEW v_neq2 AS SELECT a <> b FROM t"}, + {"comparisons", "v_cmp", "CREATE VIEW v_cmp AS SELECT a > b, a < b, a >= b, a <= b FROM t"}, + {"null_safe_equals", "v_nseq", "CREATE VIEW v_nseq AS SELECT a <=> b FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparse_Section_7_3_ExpressionViews verifies that our catalog's SHOW CREATE VIEW +// output matches MySQL 8.0's output for views with expressions (arithmetic, functions, +// CASE, CAST, aggregates). +func TestDeparse_Section_7_3_ExpressionViews(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string // the SELECT portion after CREATE VIEW v AS + }{ + {"arithmetic_expr", "SELECT a + b FROM t"}, + {"function_call", "SELECT CONCAT(a, b) FROM t"}, + {"case_expr", "SELECT CASE WHEN a > 0 THEN 'pos' ELSE 'neg' END FROM t"}, + {"cast_expr", "SELECT CAST(a AS CHAR) FROM t"}, + {"aggregate_expr", "SELECT COUNT(*), SUM(a) FROM t GROUP BY a HAVING SUM(a) > 10"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL full: %s", mysqlOutput) + t.Logf("Omni full: %s", omniOutput) + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} +// TestDeparseContainer_1_3_LiteralsSpacing verifies literals and spacing rules +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_1_3_LiteralsSpacing(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool // mark [~] if parser limitation + }{ + {"basic_literals", "v_basic_lit", "CREATE VIEW v_basic_lit AS SELECT 1, 1.5, 'hello', NULL, TRUE, FALSE FROM t", false}, + {"hex_bit_literals", "v_hex_bit", "CREATE VIEW v_hex_bit AS SELECT 0xFF, X'FF', 0b1010, b'1010' FROM t", false}, + {"charset_introducers", "v_charset", "CREATE VIEW v_charset AS SELECT _utf8mb4'hello', _latin1'world' FROM t", false}, + {"empty_string", "v_empty_str", "CREATE VIEW v_empty_str AS SELECT '' FROM t", false}, + {"escaped_quotes", "v_esc_quotes", "CREATE VIEW v_esc_quotes AS SELECT 'it''s' FROM t", false}, + {"escaped_backslash", "v_esc_bslash", "CREATE VIEW v_esc_bslash AS SELECT 'back\\\\slash' FROM t", false}, + {"temporal_literals", "v_temporal", "CREATE VIEW v_temporal AS SELECT DATE '2024-01-01', TIME '12:00:00', TIMESTAMP '2024-01-01 12:00:00' FROM t", false}, + {"func_args_no_space", "v_func_args", "CREATE VIEW v_func_args AS SELECT CONCAT(a, b, c) FROM t", false}, + {"in_list_no_space", "v_in_list", "CREATE VIEW v_in_list AS SELECT a IN (1, 2, 3) FROM t", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial — parser limitation)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_1_2_LogicalBitwiseIS verifies logical, bitwise, and IS operators +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_1_2_LogicalBitwiseIS(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"and_or", "v_and_or", "CREATE VIEW v_and_or AS SELECT a AND b, a OR b FROM t"}, + {"xor", "v_xor", "CREATE VIEW v_xor AS SELECT a XOR b FROM t"}, + {"not", "v_not", "CREATE VIEW v_not AS SELECT NOT a FROM t"}, + {"bitwise_ops", "v_bitwise", "CREATE VIEW v_bitwise AS SELECT a | b, a & b, a ^ b FROM t"}, + {"shifts", "v_shifts", "CREATE VIEW v_shifts AS SELECT a << b, a >> b FROM t"}, + {"bitwise_not", "v_bitnot", "CREATE VIEW v_bitnot AS SELECT ~a FROM t"}, + {"is_null", "v_isnull", "CREATE VIEW v_isnull AS SELECT a IS NULL, a IS NOT NULL FROM t"}, + {"is_true_false", "v_istf", "CREATE VIEW v_istf AS SELECT a IS TRUE, a IS FALSE FROM t"}, + {"in_not_in", "v_in", "CREATE VIEW v_in AS SELECT a IN (1,2,3), a NOT IN (1,2,3) FROM t"}, + {"between", "v_between", "CREATE VIEW v_between AS SELECT a BETWEEN 1 AND 10, a NOT BETWEEN 1 AND 10 FROM t"}, + {"like", "v_like", "CREATE VIEW v_like AS SELECT a LIKE 'foo%', a NOT LIKE 'bar%' FROM t"}, + {"like_escape", "v_like_esc", "CREATE VIEW v_like_esc AS SELECT a LIKE 'x' ESCAPE '\\\\' FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_Section_2_1_FunctionNameRewrites verifies that function name +// rewrites match MySQL 8.0 SHOW CREATE VIEW output. +// Covers: SUBSTRING->substr, CURRENT_TIMESTAMP->now(), CURRENT_DATE->curdate(), +// CURRENT_TIME->curtime(), CURRENT_USER->current_user(), NOW()->now(), +// COUNT(*)->count(0), COUNT(DISTINCT). +func TestDeparseContainer_Section_2_1_FunctionNameRewrites(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"substring_rewrite", "v_substr", "CREATE VIEW v_substr AS SELECT SUBSTRING('abc', 1, 2) FROM t"}, + {"current_timestamp_kw", "v_cur_ts", "CREATE VIEW v_cur_ts AS SELECT CURRENT_TIMESTAMP FROM t"}, + {"current_timestamp_fn", "v_cur_ts_fn", "CREATE VIEW v_cur_ts_fn AS SELECT CURRENT_TIMESTAMP() FROM t"}, + {"current_date_kw", "v_cur_date", "CREATE VIEW v_cur_date AS SELECT CURRENT_DATE FROM t"}, + {"current_time_kw", "v_cur_time", "CREATE VIEW v_cur_time AS SELECT CURRENT_TIME FROM t"}, + {"current_user_kw", "v_cur_user", "CREATE VIEW v_cur_user AS SELECT CURRENT_USER FROM t"}, + {"now_func", "v_now", "CREATE VIEW v_now AS SELECT NOW() FROM t"}, + {"count_star", "v_count_star", "CREATE VIEW v_count_star AS SELECT COUNT(*) FROM t"}, + {"count_distinct", "v_count_dist", "CREATE VIEW v_count_dist AS SELECT COUNT(DISTINCT a) FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_2_2_RegularFunctionsAggregates verifies regular functions and aggregates +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_2_2_RegularFunctionsAggregates(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"concat_upper_lower", "v_str_funcs", "CREATE VIEW v_str_funcs AS SELECT CONCAT(a, b), UPPER(a), LOWER(a) FROM t"}, + {"ifnull_coalesce_nullif", "v_null_funcs", "CREATE VIEW v_null_funcs AS SELECT IFNULL(a, 0), COALESCE(a, b, 0), NULLIF(a, 0) FROM t"}, + {"if_function", "v_if_func", "CREATE VIEW v_if_func AS SELECT IF(a > 0, 'yes', 'no') FROM t"}, + {"abs_greatest_least", "v_num_funcs", "CREATE VIEW v_num_funcs AS SELECT ABS(a), GREATEST(a, b), LEAST(a, b) FROM t"}, + {"sum_avg_max_min", "v_agg_funcs", "CREATE VIEW v_agg_funcs AS SELECT SUM(a), AVG(a), MAX(a), MIN(a) FROM t"}, + {"nested_functions", "v_nested_funcs", "CREATE VIEW v_nested_funcs AS SELECT CONCAT(UPPER(TRIM(a)), LOWER(b)) FROM t"}, + {"multiple_aggregates_groupby", "v_multi_agg", "CREATE VIEW v_multi_agg AS SELECT COUNT(*), SUM(a), AVG(b), MAX(c) FROM t GROUP BY a"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_2_4_CastConvertOperatorRewrites verifies CAST, CONVERT, +// REGEXP→regexp_like, NOT REGEXP, -> (json_extract), ->> (json_unquote(json_extract)) +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_2_4_CastConvertOperatorRewrites(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS tj (a JSON, b INT)"); err != nil { + t.Fatalf("failed to create table tj on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + tables string // extra CREATE TABLE for omni catalog (beyond t) + partial bool + }{ + {"cast_char", "v_cast_char", "CREATE VIEW v_cast_char AS SELECT CAST(a AS CHAR) FROM t", "", false}, + {"cast_char10", "v_cast_char10", "CREATE VIEW v_cast_char10 AS SELECT CAST(a AS CHAR(10)) FROM t", "", false}, + {"cast_binary", "v_cast_binary", "CREATE VIEW v_cast_binary AS SELECT CAST(a AS BINARY) FROM t", "", false}, + {"cast_signed_unsigned", "v_cast_su", "CREATE VIEW v_cast_su AS SELECT CAST(a AS SIGNED), CAST(a AS UNSIGNED) FROM t", "", false}, + {"cast_decimal", "v_cast_dec", "CREATE VIEW v_cast_dec AS SELECT CAST(a AS DECIMAL(10,2)) FROM t", "", false}, + {"cast_date_datetime_json", "v_cast_ddj", "CREATE VIEW v_cast_ddj AS SELECT CAST(a AS DATE), CAST(a AS DATETIME), CAST(a AS JSON) FROM t", "", false}, + {"convert_char", "v_conv_char", "CREATE VIEW v_conv_char AS SELECT CONVERT(a, CHAR) FROM t", "", false}, + {"convert_using", "v_conv_using", "CREATE VIEW v_conv_using AS SELECT CONVERT(a USING utf8mb4) FROM t", "", false}, + {"regexp", "v_regexp", "CREATE VIEW v_regexp AS SELECT a REGEXP 'pattern' FROM t", "", false}, + {"not_regexp", "v_not_regexp", "CREATE VIEW v_not_regexp AS SELECT a NOT REGEXP 'pattern' FROM t", "", false}, + {"json_extract", "v_json_ext", "CREATE VIEW v_json_ext AS SELECT a->'$.key' FROM tj", "CREATE TABLE tj (a JSON, b INT)", false}, + {"json_unquote", "v_json_unq", "CREATE VIEW v_json_unq AS SELECT a->>'$.key' FROM tj", "CREATE TABLE tj (a JSON, b INT)", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + if tc.tables != "" { + cat.Exec(tc.tables, nil) + } + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_2_3_SpecialFunctions verifies TRIM, GROUP_CONCAT, and simple CASE +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_2_3_SpecialFunctions(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS tv (a VARCHAR(50), b VARCHAR(50), c INT)"); err != nil { + t.Fatalf("failed to create table tv on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + tables string // extra CREATE TABLE for omni catalog (beyond t) + }{ + {"trim_simple", "v_trim_simple", "CREATE VIEW v_trim_simple AS SELECT TRIM(a) FROM tv", "CREATE TABLE tv (a VARCHAR(50), b VARCHAR(50), c INT)"}, + {"trim_leading", "v_trim_leading", "CREATE VIEW v_trim_leading AS SELECT TRIM(LEADING 'x' FROM a) FROM tv", "CREATE TABLE tv (a VARCHAR(50), b VARCHAR(50), c INT)"}, + {"trim_trailing", "v_trim_trailing", "CREATE VIEW v_trim_trailing AS SELECT TRIM(TRAILING 'x' FROM a) FROM tv", "CREATE TABLE tv (a VARCHAR(50), b VARCHAR(50), c INT)"}, + {"trim_both", "v_trim_both", "CREATE VIEW v_trim_both AS SELECT TRIM(BOTH 'x' FROM a) FROM tv", "CREATE TABLE tv (a VARCHAR(50), b VARCHAR(50), c INT)"}, + {"group_concat_basic", "v_gc_basic", "CREATE VIEW v_gc_basic AS SELECT GROUP_CONCAT(a ORDER BY a SEPARATOR ',') FROM t", ""}, + {"group_concat_full", "v_gc_full", "CREATE VIEW v_gc_full AS SELECT GROUP_CONCAT(DISTINCT a ORDER BY a DESC SEPARATOR ';') FROM t", ""}, + {"simple_case", "v_simple_case", "CREATE VIEW v_simple_case AS SELECT CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM t", ""}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + if tc.tables != "" { + cat.Exec(tc.tables, nil) + } + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_3_1_BooleanContextWrapping verifies that non-boolean expressions +// in boolean context (AND/OR) get (0 <> ...) wrapping to match MySQL 8.0's +// SHOW CREATE VIEW output. +func TestDeparseContainer_3_1_BooleanContextWrapping(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"col_and_col", "v_bool_cc", "CREATE VIEW v_bool_cc AS SELECT a AND b FROM t"}, + {"arith_and_col", "v_bool_ac", "CREATE VIEW v_bool_ac AS SELECT (a + 1) AND b FROM t"}, + {"cmp_and_arith", "v_bool_ca", "CREATE VIEW v_bool_ca AS SELECT (a > 0) AND (b + 1) FROM t"}, + {"cmp_and_cmp", "v_bool_cmpx2", "CREATE VIEW v_bool_cmpx2 AS SELECT (a > 0) AND (b > 0) FROM t"}, + {"abs_and_col", "v_bool_abs", "CREATE VIEW v_bool_abs AS SELECT ABS(a) AND b FROM t"}, + {"case_and_col", "v_bool_case", "CREATE VIEW v_bool_case AS SELECT CASE WHEN a > 0 THEN 1 ELSE 0 END AND b FROM t"}, + {"if_and_col", "v_bool_if", "CREATE VIEW v_bool_if AS SELECT IF(a > 0, 1, 0) AND b FROM t"}, + {"subquery_and_col", "v_bool_subq", "CREATE VIEW v_bool_subq AS SELECT (SELECT MAX(a) FROM t) AND b FROM t"}, + {"string_and_int", "v_bool_str", "CREATE VIEW v_bool_str AS SELECT 'hello' AND 1 FROM t"}, + {"ifnull_and_col", "v_bool_ifnull", "CREATE VIEW v_bool_ifnull AS SELECT IFNULL(a, 0) AND b FROM t"}, + {"coalesce_and_int", "v_bool_coal", "CREATE VIEW v_bool_coal AS SELECT COALESCE(a, b) AND 1 FROM t"}, + {"nullif_and_col", "v_bool_nullif", "CREATE VIEW v_bool_nullif AS SELECT NULLIF(a, 0) AND b FROM t"}, + {"greatest_and_int", "v_bool_great", "CREATE VIEW v_bool_great AS SELECT GREATEST(a, b) AND 1 FROM t"}, + {"least_and_int", "v_bool_least", "CREATE VIEW v_bool_least AS SELECT LEAST(a, b) AND 1 FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_3_2_NotFoldingNoDoubleWrap verifies NOT folding and +// no-double-wrapping of boolean expressions against MySQL 8.0. +// Section 3.2 scenarios: +// - NOT(comparison) folds into inverted operator +// - NOT(non-boolean) becomes (0 = ...) +// - ! is same as NOT +// - comparisons in AND are NOT double-wrapped +// - predicates (IN, BETWEEN) in AND are NOT wrapped +// - IS/LIKE in AND are NOT wrapped +// - EXISTS in AND is NOT wrapped +func TestDeparseContainer_3_2_NotFoldingNoDoubleWrap(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"not_folds_into_le", "v_s32_not_fold", "CREATE VIEW v_s32_not_fold AS SELECT NOT (a > 0) FROM t"}, + {"not_non_boolean", "v_s32_not_nb", "CREATE VIEW v_s32_not_nb AS SELECT NOT (a + 1) FROM t"}, + {"bang_col", "v_s32_bang", "CREATE VIEW v_s32_bang AS SELECT !a FROM t"}, + {"cmp_not_double_wrapped", "v_s32_cmp_ndw", "CREATE VIEW v_s32_cmp_ndw AS SELECT (a = b) AND (a > 0) FROM t"}, + {"predicates_not_wrapped", "v_s32_pred_nw", "CREATE VIEW v_s32_pred_nw AS SELECT (a IN (1, 2, 3)) AND (b BETWEEN 1 AND 10) FROM t"}, + {"is_like_not_wrapped", "v_s32_islike_nw", "CREATE VIEW v_s32_islike_nw AS SELECT (a IS NULL) AND (b LIKE 'x%') FROM t"}, + {"exists_not_wrapped", "v_s32_exists_nw", "CREATE VIEW v_s32_exists_nw AS SELECT EXISTS(SELECT 1 FROM t WHERE a > 0) AND (b > 0) FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_3_3_ComplexPrecedence verifies complex operator precedence +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_3_3_ComplexPrecedence(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + }{ + {"left_assoc_add", "v_left_add", "CREATE VIEW v_left_add AS SELECT a + b + c FROM t"}, + {"mul_then_add", "v_mul_add", "CREATE VIEW v_mul_add AS SELECT a * b + c FROM t"}, + {"add_then_mul", "v_add_mul", "CREATE VIEW v_add_mul AS SELECT a + b * c FROM t"}, + {"paren_add_mul", "v_paren_mul", "CREATE VIEW v_paren_mul AS SELECT (a + b) * c FROM t"}, + {"or_and_prec", "v_or_and", "CREATE VIEW v_or_and AS SELECT a OR b AND c FROM t"}, + {"paren_or_and", "v_paren_or", "CREATE VIEW v_paren_or AS SELECT (a OR b) AND c FROM t"}, + {"mixed_cmp_logic", "v_mixed", "CREATE VIEW v_mixed AS SELECT a > 0 AND b < 10 OR c = 5 FROM t"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_4_1_AllJoinTypes verifies all JOIN types against MySQL 8.0 +// SHOW CREATE VIEW output: INNER JOIN, LEFT JOIN, RIGHT JOIN→LEFT swap, +// CROSS JOIN, NATURAL JOIN expanded, STRAIGHT_JOIN, USING expanded, comma→explicit join. +func TestDeparseContainer_4_1_AllJoinTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t1 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t1 on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t2 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t2 on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool + }{ + {"inner_join", "v_inner_join", "CREATE VIEW v_inner_join AS SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a", false}, + {"left_join", "v_left_join", "CREATE VIEW v_left_join AS SELECT t1.a, t2.b FROM t1 LEFT JOIN t2 ON t1.a = t2.a", false}, + {"right_join_swap", "v_right_join", "CREATE VIEW v_right_join AS SELECT t1.a, t2.b FROM t1 RIGHT JOIN t2 ON t1.a = t2.a", false}, + {"cross_join", "v_cross_join", "CREATE VIEW v_cross_join AS SELECT t1.a, t2.b FROM t1 CROSS JOIN t2", false}, + {"natural_join", "v_natural_join", "CREATE VIEW v_natural_join AS SELECT * FROM t1 NATURAL JOIN t2", false}, + {"straight_join", "v_straight_join", "CREATE VIEW v_straight_join AS SELECT t1.a, t2.b FROM t1 STRAIGHT_JOIN t2 ON t1.a = t2.a", false}, + {"using_expanded", "v_using", "CREATE VIEW v_using AS SELECT t1.a, t2.b FROM t1 JOIN t2 USING (a)", false}, + {"comma_to_join", "v_comma_join", "CREATE VIEW v_comma_join AS SELECT t1.a, t2.b FROM t1, t2 WHERE t1.a = t2.a", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t2 (a INT, b INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_4_2_MultiTableDerived verifies multi-table JOINs, chained LEFT JOINs, +// derived tables, and table aliases against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_4_2_MultiTableDerived(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t1 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t1 on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t2 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t2 on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t3 (b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t3 on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool + }{ + {"three_table_join", "v_3join", "CREATE VIEW v_3join AS SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.a = t2.a JOIN t3 ON t2.b = t3.b", false}, + {"chained_left_joins", "v_chain_left", "CREATE VIEW v_chain_left AS SELECT t1.a FROM t1 LEFT JOIN t2 ON t1.a = t2.a LEFT JOIN t3 ON t1.a = t3.b", false}, + {"derived_table", "v_derived", "CREATE VIEW v_derived AS SELECT d.x FROM (SELECT a AS x FROM t) d", false}, + {"derived_with_where", "v_derived_where", "CREATE VIEW v_derived_where AS SELECT d.x FROM (SELECT a AS x FROM t WHERE a > 0) AS d WHERE d.x < 10", false}, + {"table_alias_with_as", "v_alias_as", "CREATE VIEW v_alias_as AS SELECT x.a FROM t AS x", false}, + {"table_alias_without_as", "v_alias_noas", "CREATE VIEW v_alias_noas AS SELECT x.a FROM t x", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t2 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t3 (b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_Section_5_1_SelectClauses verifies that our catalog's SHOW CREATE VIEW +// output matches MySQL 8.0's output for views with WHERE, GROUP BY, HAVING, ORDER BY, +// LIMIT, OFFSET, DISTINCT, and expression-based GROUP BY. +func TestDeparseContainer_Section_5_1_SelectClauses(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool + }{ + {"all_clauses_combined", "v_all_clauses", "CREATE VIEW v_all_clauses AS SELECT a FROM t WHERE a > 1 GROUP BY a HAVING COUNT(*) > 1 ORDER BY a LIMIT 10", false}, + {"alias_in_order_by", "v_alias_orderby", "CREATE VIEW v_alias_orderby AS SELECT a, COUNT(*) cnt FROM t GROUP BY a HAVING COUNT(*) > 1 ORDER BY cnt DESC", false}, + {"distinct_order_by", "v_distinct_orderby", "CREATE VIEW v_distinct_orderby AS SELECT DISTINCT a FROM t ORDER BY a DESC", false}, + {"multi_column_order_by", "v_multi_orderby", "CREATE VIEW v_multi_orderby AS SELECT a FROM t ORDER BY a, b DESC", false}, + {"limit_offset", "v_limit_offset", "CREATE VIEW v_limit_offset AS SELECT a FROM t LIMIT 10 OFFSET 5", false}, + {"expression_group_by", "v_expr_groupby", "CREATE VIEW v_expr_groupby AS SELECT a + b FROM t GROUP BY a + b", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_5_2_SetOperations verifies set operations (UNION, UNION ALL, +// multiple UNION, INTERSECT, EXCEPT, UNION+ORDER BY+LIMIT) against MySQL 8.0. +// INTERSECT/EXCEPT require MySQL 8.0.31+; if rejected, the test is skipped. +func TestDeparseContainer_5_2_SetOperations(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + partial bool + }{ + {"union", "v_union", "CREATE VIEW v_union AS SELECT a FROM t UNION SELECT b FROM t", false}, + {"union_all", "v_union_all", "CREATE VIEW v_union_all AS SELECT a FROM t UNION ALL SELECT b FROM t", false}, + {"multiple_union", "v_multi_union", "CREATE VIEW v_multi_union AS SELECT a FROM t UNION SELECT b FROM t UNION SELECT c FROM t", false}, + {"intersect", "v_intersect", "CREATE VIEW v_intersect AS SELECT a FROM t INTERSECT SELECT b FROM t", false}, + {"except", "v_except", "CREATE VIEW v_except AS SELECT a FROM t EXCEPT SELECT b FROM t", false}, + {"union_orderby_limit", "v_union_ordlim", "CREATE VIEW v_union_ordlim AS SELECT a FROM t UNION SELECT b FROM t ORDER BY 1 LIMIT 5", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_5_3_ColumnAliasPatterns verifies column and alias patterns +// against MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_5_3_ColumnAliasPatterns(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t1 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t1 on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t2 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t2 on MySQL: %v", err) + } + + cases := []struct { + name string + viewName string + viewSQL string + tables string // extra tables for omni catalog (beyond t, t1, t2) + partial bool + }{ + { + "explicit_alias_as_vs_space", + "v_alias_as_space", + "CREATE VIEW v_alias_as_space AS SELECT a AS col1, b col2 FROM t", + "", + false, + }, + { + "expression_explicit_alias", + "v_expr_alias", + "CREATE VIEW v_expr_alias AS SELECT a + b AS sum_col FROM t", + "", + false, + }, + { + "literal_auto_alias", + "v_lit_auto", + "CREATE VIEW v_lit_auto AS SELECT 1 FROM t", + "", + false, + }, + { + "star_expansion", + "v_star", + "CREATE VIEW v_star AS SELECT * FROM t", + "", + false, + }, + { + "same_name_columns_join", + "v_same_name_join", + "CREATE VIEW v_same_name_join AS SELECT t1.a, t2.a FROM t1 JOIN t2 ON t1.a = t2.a", + "", + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + tc.viewName) + if err := ctr.execSQLDirect(tc.viewSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(tc.viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Omni catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t2 (a INT, b INT)", nil) + if tc.tables != "" { + cat.Exec(tc.tables, nil) + } + results, _ := cat.Exec(tc.viewSQL, nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", tc.viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_Section_6_1_SubqueryPatterns verifies that subquery patterns +// in views match MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_Section_6_1_SubqueryPatterns(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t1 (a INT, b INT)"); err != nil { + t.Fatalf("failed to create table t1 on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string + partial bool + }{ + { + "scalar_subquery_in_select", + "SELECT (SELECT MAX(a) FROM t) FROM t", + false, + }, + { + "in_subquery", + "SELECT * FROM t WHERE a IN (SELECT a FROM t WHERE a > 0)", + false, + }, + { + "exists_subquery", + "SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t WHERE a > 0)", + false, + }, + { + "correlated_subquery", + "SELECT a, (SELECT COUNT(*) FROM t t2 WHERE t2.a = t1.a) FROM t t1", + false, + }, + { + "nested_subqueries_2_levels", + "SELECT * FROM t WHERE a IN (SELECT a FROM t WHERE b IN (SELECT b FROM t WHERE c > 0))", + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +func TestDeparseContainer_Section_6_2_CTEPatterns(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + + cases := []struct { + name string + createAs string + partial bool + }{ + { + "simple_cte", + "WITH cte AS (SELECT a FROM t) SELECT * FROM cte", + false, + }, + { + "cte_with_column_list", + "WITH cte(x) AS (SELECT a FROM t) SELECT x FROM cte", + false, + }, + { + "recursive_cte", + "WITH RECURSIVE cte AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM cte WHERE n < 10) SELECT * FROM cte", + false, + }, + { + "multiple_ctes", + "WITH c1 AS (SELECT a FROM t), c2 AS (SELECT b FROM t) SELECT c1.a, c2.b FROM c1, c2", + false, + }, + { + "cte_used_in_union", + "WITH cte AS (SELECT a FROM t) SELECT * FROM cte UNION SELECT * FROM cte", + false, + }, + { + "recursive_cte_multi_col", + "WITH RECURSIVE cte AS (SELECT 1 AS n, 'a' AS s UNION ALL SELECT n + 1, CONCAT(s, 'a') FROM cte WHERE n < 5) SELECT * FROM cte", + false, + }, + { + "recursive_cte_join_base", + "WITH RECURSIVE cte AS (SELECT a, 1 AS depth FROM t WHERE a = 1 UNION ALL SELECT t.a, cte.depth + 1 FROM t JOIN cte ON t.b = cte.a WHERE cte.depth < 5) SELECT * FROM cte", + false, + }, + { + "cte_references_cte", + "WITH c1 AS (SELECT a FROM t), c2 AS (SELECT * FROM c1 WHERE a > 0) SELECT * FROM c2", + false, + }, + { + "cte_with_aggregation", + "WITH cte AS (SELECT a, COUNT(*) AS cnt FROM t GROUP BY a) SELECT * FROM cte", + false, + }, + { + "cte_with_inner_union", + "WITH cte AS (SELECT a FROM t UNION SELECT b FROM t) SELECT * FROM cte", + false, + }, + { + "cte_with_complex_expr", + "WITH cte AS (SELECT a + b AS sum_val, CASE WHEN a > 0 THEN 'pos' ELSE 'neg' END AS label FROM t) SELECT * FROM cte", + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + if tc.partial { + t.Skipf("MySQL 8.0 rejected (expected partial): %v", err) + } + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + if tc.partial { + t.Skipf("CREATE VIEW on catalog returned no results (expected partial)") + } + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + if tc.partial { + t.Skipf("CREATE VIEW on catalog failed (expected partial): %v", results[0].Error) + } + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + if tc.partial { + t.Skip("ShowCreateView returned empty (expected partial)") + } + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + if tc.partial { + t.Skipf("SELECT body mismatch (expected partial):\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_8_1_ViewOfViewComplexStructures verifies view-of-view, +// many-column views, reserved word aliases, CASE without ELSE, and BETWEEN +// with column bounds against MySQL 8.0. +func TestDeparseContainer_8_1_ViewOfViewComplexStructures(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base table on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + + // --- Test 1: View referencing another view --- + t.Run("view_of_view", func(t *testing.T) { + // MySQL side: create v1, then v2 referencing v1 + ctr.execSQLDirect("DROP VIEW IF EXISTS v2_vov") + ctr.execSQLDirect("DROP VIEW IF EXISTS v1_vov") + if err := ctr.execSQLDirect("CREATE VIEW v1_vov AS SELECT a FROM t"); err != nil { + t.Fatalf("CREATE VIEW v1_vov on MySQL failed: %v", err) + } + if err := ctr.execSQLDirect("CREATE VIEW v2_vov AS SELECT * FROM v1_vov"); err != nil { + t.Fatalf("CREATE VIEW v2_vov on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView("v2_vov") + if err != nil { + t.Fatalf("SHOW CREATE VIEW v2_vov on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // Omni side + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec("CREATE VIEW v1_vov AS SELECT a FROM t", nil) + if len(results) > 0 && results[0].Error != nil { + t.Fatalf("CREATE VIEW v1_vov on catalog failed: %v", results[0].Error) + } + results, _ = cat.Exec("CREATE VIEW v2_vov AS SELECT * FROM v1_vov", nil) + if len(results) > 0 && results[0].Error != nil { + t.Fatalf("CREATE VIEW v2_vov on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", "v2_vov") + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty for v2_vov") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + + // --- Tests 2-5: table-driven --- + cases := []struct { + name string + createAs string + }{ + { + "ten_plus_columns", + "SELECT a, b, c, a + 1, b + 1, c + 1, a * b, b * c, a * c, a + b + c FROM t", + }, + { + "reserved_word_aliases", + "SELECT a AS `select`, b AS `from`, c AS `where` FROM t", + }, + { + "case_without_else", + "SELECT CASE WHEN a > 0 THEN 'pos' WHEN a < 0 THEN 'neg' END FROM t", + }, + { + "between_column_bounds", + "SELECT a BETWEEN b AND c FROM t", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Fatalf("CREATE VIEW on MySQL failed: %v", err) + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} + +// TestDeparseContainer_8_2_ExpressionEdgeCases verifies expression edge cases and stress tests +// against real MySQL 8.0 SHOW CREATE VIEW output. +func TestDeparseContainer_8_2_ExpressionEdgeCases(t *testing.T) { + if testing.Short() { + t.Skip("skipping container test in short mode") + } + ctr, cleanup := startContainer(t) + defer cleanup() + + // Setup: create base tables on MySQL 8.0 + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t (a INT, b INT, c INT)"); err != nil { + t.Fatalf("failed to create table t on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS t_rw (`select` INT, b INT)"); err != nil { + t.Fatalf("failed to create table t_rw on MySQL: %v", err) + } + if err := ctr.execSQLDirect("CREATE TABLE IF NOT EXISTS tc (a VARCHAR(50), b INT, c INT)"); err != nil { + t.Fatalf("failed to create table tc on MySQL: %v", err) + } + + cases := []struct { + name string + setupTable string // which table(s) to create on omni side ("t", "t_rw", "tc") + createAs string // the SELECT portion after CREATE VIEW v AS + }{ + { + "cast_with_expression", + "t", + "SELECT CAST(a + b AS CHAR) FROM t", + }, + { + "collate_expression", + "tc", + "SELECT a COLLATE utf8mb4_unicode_ci FROM tc", + }, + { + "interval_arithmetic", + "t", + "SELECT INTERVAL 1 DAY + a FROM t", + }, + { + "sounds_like", + "tc", + "SELECT a SOUNDS LIKE b FROM tc", + }, + { + "unary_operators", + "t", + "SELECT -a, +a FROM t", + }, + { + "long_expression_auto_alias", + "t", + "SELECT CASE WHEN a > 0 THEN CONCAT(a, b, c) WHEN a < 0 THEN CONCAT(c, b, a) ELSE NULL END FROM t", + }, + { + "reserved_word_column", + "t_rw", + "SELECT `select` FROM t_rw", + }, + { + "multiple_rewrites", + "t", + "SELECT a + b, NOT (a > 0), CAST(a AS CHAR), COUNT(*), a REGEXP 'x' FROM t GROUP BY a, b HAVING COUNT(*) > 1", + }, + { + "all_logical_operators", + "t", + "SELECT a AND b, a OR b, NOT a, a XOR b, !a FROM t", + }, + { + "function_boolean_context", + "t", + "SELECT IFNULL(a, 0) AND COALESCE(b, 0) FROM t", + }, + { + "mixed_boolean_precedence", + "t", + "SELECT (a > 0) AND (b + 1) OR (c IS NULL) FROM t", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + viewName := "v_82_" + tc.name + + // --- MySQL 8.0 side --- + ctr.execSQLDirect("DROP VIEW IF EXISTS " + viewName) + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs) + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Skipf("MySQL 8.0 rejected: %v", err) + return + } + mysqlOutput, err := ctr.showCreateView(viewName) + if err != nil { + t.Fatalf("SHOW CREATE VIEW on MySQL failed: %v", err) + } + mysqlBody := stripDatabasePrefix(extractSelectBody(mysqlOutput)) + + // --- Our catalog side --- + cat := New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + // Create the appropriate table(s) + switch tc.setupTable { + case "t": + cat.Exec("CREATE TABLE t (a INT, b INT, c INT)", nil) + case "t_rw": + cat.Exec("CREATE TABLE t_rw (`select` INT, b INT)", nil) + case "tc": + cat.Exec("CREATE TABLE tc (a VARCHAR(50), b INT, c INT)", nil) + } + results, _ := cat.Exec(fmt.Sprintf("CREATE VIEW %s AS %s", viewName, tc.createAs), nil) + if len(results) == 0 { + t.Fatalf("CREATE VIEW on catalog returned no results") + } + if results[0].Error != nil { + t.Fatalf("CREATE VIEW on catalog failed: %v", results[0].Error) + } + omniOutput := cat.ShowCreateView("test", viewName) + if omniOutput == "" { + t.Fatal("ShowCreateView returned empty") + } + omniBody := extractSelectBody(omniOutput) + + t.Logf("MySQL body: %s", mysqlBody) + t.Logf("Omni body: %s", omniBody) + + if mysqlBody != omniBody { + t.Errorf("SELECT body mismatch:\n--- mysql ---\n%s\n--- omni ---\n%s", mysqlBody, omniBody) + } + }) + } +} diff --git a/tidb/catalog/deparse_expr.go b/tidb/catalog/deparse_expr.go new file mode 100644 index 00000000..ec84fc8f --- /dev/null +++ b/tidb/catalog/deparse_expr.go @@ -0,0 +1,367 @@ +package catalog + +import ( + "fmt" + "strings" + "unicode" +) + +// DeparseAnalyzedExpr converts an analyzed expression back to SQL text. +func DeparseAnalyzedExpr(expr AnalyzedExpr, q *Query) string { + if expr == nil { + return "" + } + switch e := expr.(type) { + case *VarExprQ: + return deparseVarExprQ(e, q) + case *ConstExprQ: + return deparseConstExprQ(e) + case *OpExprQ: + return deparseOpExprQ(e, q) + case *BoolExprQ: + return deparseBoolExprQ(e, q) + case *FuncCallExprQ: + return deparseFuncCallExprQ(e, q) + case *CaseExprQ: + return deparseCaseExprQ(e, q) + case *CoalesceExprQ: + return deparseCoalesceExprQ(e, q) + case *NullTestExprQ: + return deparseNullTestExprQ(e, q) + case *InListExprQ: + return deparseInListExprQ(e, q) + case *BetweenExprQ: + return deparseBetweenExprQ(e, q) + case *SubLinkExprQ: + return deparseSubLinkExprQ(e, q) + case *CastExprQ: + return deparseCastExprQ(e, q) + case *RowExprQ: + return deparseRowExprQ(e, q) + default: + return fmt.Sprintf("/* unknown expr %T */", expr) + } +} + +func deparseVarExprQ(v *VarExprQ, q *Query) string { + if q == nil || v.RangeIdx < 0 || v.RangeIdx >= len(q.RangeTable) { + return "?" + } + rte := q.RangeTable[v.RangeIdx] + colName := "?" + if v.AttNum >= 1 && v.AttNum <= len(rte.ColNames) { + colName = rte.ColNames[v.AttNum-1] + } + return backtickIDQ(rte.ERef) + "." + backtickIDQ(colName) +} + +func deparseConstExprQ(c *ConstExprQ) string { + if c.IsNull { + return "NULL" + } + if isNumericLiteralQ(c.Value) || isBoolLiteralQ(c.Value) { + return c.Value + } + return "'" + strings.ReplaceAll(c.Value, "'", "''") + "'" +} + +func deparseOpExprQ(o *OpExprQ, q *Query) string { + // Postfix operators: IS TRUE, IS FALSE, IS UNKNOWN, etc. + if o.Left != nil && o.Right == nil { + return "(" + DeparseAnalyzedExpr(o.Left, q) + " " + o.Op + ")" + } + // Prefix unary operators: -x, ~x, BINARY x + if o.Left == nil && o.Right != nil { + if o.Op == "-" || o.Op == "~" { + return "(" + o.Op + DeparseAnalyzedExpr(o.Right, q) + ")" + } + return "(" + o.Op + " " + DeparseAnalyzedExpr(o.Right, q) + ")" + } + return "(" + DeparseAnalyzedExpr(o.Left, q) + " " + o.Op + " " + DeparseAnalyzedExpr(o.Right, q) + ")" +} + +func deparseBoolExprQ(b *BoolExprQ, q *Query) string { + switch b.Op { + case BoolAnd: + parts := make([]string, len(b.Args)) + for i, arg := range b.Args { + parts[i] = DeparseAnalyzedExpr(arg, q) + } + return "(" + strings.Join(parts, " and ") + ")" + case BoolOr: + parts := make([]string, len(b.Args)) + for i, arg := range b.Args { + parts[i] = DeparseAnalyzedExpr(arg, q) + } + return "(" + strings.Join(parts, " or ") + ")" + case BoolNot: + if len(b.Args) > 0 { + return "(not " + DeparseAnalyzedExpr(b.Args[0], q) + ")" + } + return "(not ?)" + default: + return "?" + } +} + +func deparseFuncCallExprQ(f *FuncCallExprQ, q *Query) string { + if f.IsAggregate && len(f.Args) == 0 && strings.ToLower(f.Name) == "count" { + return "count(*)" + } + + var sb strings.Builder + sb.WriteString(f.Name) + sb.WriteByte('(') + + if f.Distinct { + sb.WriteString("distinct ") + } + + for i, arg := range f.Args { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(DeparseAnalyzedExpr(arg, q)) + } + sb.WriteByte(')') + + if f.Over != nil { + sb.WriteString(" over ") + sb.WriteString(deparseWindowDefQ(f.Over, q)) + } + + return sb.String() +} + +func deparseWindowDefQ(w *WindowDefQ, q *Query) string { + if w.Name != "" && len(w.PartitionBy) == 0 && len(w.OrderBy) == 0 && w.FrameClause == "" { + return backtickIDQ(w.Name) + } + + var sb strings.Builder + sb.WriteByte('(') + if w.Name != "" { + sb.WriteString(backtickIDQ(w.Name) + " ") + } + + needSpace := false + if len(w.PartitionBy) > 0 { + sb.WriteString("partition by ") + for i, expr := range w.PartitionBy { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(DeparseAnalyzedExpr(expr, q)) + } + needSpace = true + } + if len(w.OrderBy) > 0 { + if needSpace { + sb.WriteByte(' ') + } + sb.WriteString("order by ") + for i, sc := range w.OrderBy { + if i > 0 { + sb.WriteString(", ") + } + if sc.TargetIdx >= 1 && sc.TargetIdx <= len(q.TargetList) { + sb.WriteString(DeparseAnalyzedExpr(q.TargetList[sc.TargetIdx-1].Expr, q)) + } + if sc.Descending { + sb.WriteString(" desc") + } + } + needSpace = true + } + if w.FrameClause != "" { + if needSpace { + sb.WriteByte(' ') + } + sb.WriteString(strings.ToLower(w.FrameClause)) + } + sb.WriteByte(')') + return sb.String() +} + +func deparseCaseExprQ(c *CaseExprQ, q *Query) string { + var sb strings.Builder + sb.WriteString("case") + if c.TestExpr != nil { + sb.WriteByte(' ') + sb.WriteString(DeparseAnalyzedExpr(c.TestExpr, q)) + } + for _, w := range c.Args { + sb.WriteString(" when ") + sb.WriteString(DeparseAnalyzedExpr(w.Cond, q)) + sb.WriteString(" then ") + sb.WriteString(DeparseAnalyzedExpr(w.Then, q)) + } + if c.Default != nil { + sb.WriteString(" else ") + sb.WriteString(DeparseAnalyzedExpr(c.Default, q)) + } + sb.WriteString(" end") + return sb.String() +} + +func deparseCoalesceExprQ(c *CoalesceExprQ, q *Query) string { + var sb strings.Builder + sb.WriteString("coalesce(") + for i, arg := range c.Args { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(DeparseAnalyzedExpr(arg, q)) + } + sb.WriteByte(')') + return sb.String() +} + +func deparseNullTestExprQ(n *NullTestExprQ, q *Query) string { + arg := DeparseAnalyzedExpr(n.Arg, q) + if n.IsNull { + return "(" + arg + " is null)" + } + return "(" + arg + " is not null)" +} + +func deparseInListExprQ(in *InListExprQ, q *Query) string { + var sb strings.Builder + sb.WriteString(DeparseAnalyzedExpr(in.Arg, q)) + if in.Negated { + sb.WriteString(" not in (") + } else { + sb.WriteString(" in (") + } + for i, item := range in.List { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(DeparseAnalyzedExpr(item, q)) + } + sb.WriteByte(')') + return sb.String() +} + +func deparseBetweenExprQ(be *BetweenExprQ, q *Query) string { + arg := DeparseAnalyzedExpr(be.Arg, q) + lower := DeparseAnalyzedExpr(be.Lower, q) + upper := DeparseAnalyzedExpr(be.Upper, q) + if be.Negated { + return "(" + arg + " not between " + lower + " and " + upper + ")" + } + return "(" + arg + " between " + lower + " and " + upper + ")" +} + +func deparseSubLinkExprQ(s *SubLinkExprQ, q *Query) string { + inner := DeparseQuery(s.Subquery) + switch s.Kind { + case SubLinkExists: + return "exists (" + inner + ")" + case SubLinkScalar: + return "(" + inner + ")" + case SubLinkIn: + return DeparseAnalyzedExpr(s.TestExpr, q) + " in (" + inner + ")" + case SubLinkAny: + return DeparseAnalyzedExpr(s.TestExpr, q) + " " + s.Op + " any (" + inner + ")" + case SubLinkAll: + return DeparseAnalyzedExpr(s.TestExpr, q) + " " + s.Op + " all (" + inner + ")" + default: + return "(" + inner + ")" + } +} + +func deparseCastExprQ(c *CastExprQ, q *Query) string { + arg := DeparseAnalyzedExpr(c.Arg, q) + typeName := deparseResolvedTypeQ(c.TargetType) + return "cast(" + arg + " as " + typeName + ")" +} + +func deparseRowExprQ(r *RowExprQ, q *Query) string { + var sb strings.Builder + sb.WriteString("row(") + for i, arg := range r.Args { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(DeparseAnalyzedExpr(arg, q)) + } + sb.WriteByte(')') + return sb.String() +} + +func deparseResolvedTypeQ(rt *ResolvedType) string { + if rt == nil { + return "char" + } + switch rt.BaseType { + case BaseTypeBigInt: + if rt.Unsigned { + return "unsigned" + } + return "signed" + case BaseTypeChar: + if rt.Length > 0 { + return fmt.Sprintf("char(%d)", rt.Length) + } + return "char" + case BaseTypeBinary: + if rt.Length > 0 { + return fmt.Sprintf("binary(%d)", rt.Length) + } + return "binary" + case BaseTypeDecimal: + if rt.Precision > 0 && rt.Scale > 0 { + return fmt.Sprintf("decimal(%d, %d)", rt.Precision, rt.Scale) + } + if rt.Precision > 0 { + return fmt.Sprintf("decimal(%d)", rt.Precision) + } + return "decimal" + case BaseTypeDate: + return "date" + case BaseTypeDateTime: + return "datetime" + case BaseTypeTime: + return "time" + case BaseTypeJSON: + return "json" + case BaseTypeFloat: + return "float" + case BaseTypeDouble: + return "double" + default: + return "char" + } +} + +func isNumericLiteralQ(s string) bool { + if s == "" { + return false + } + start := 0 + if s[0] == '-' || s[0] == '+' { + start = 1 + } + if start >= len(s) { + return false + } + hasDot := false + for i := start; i < len(s); i++ { + c := rune(s[i]) + if c == '.' { + if hasDot { + return false + } + hasDot = true + } else if !unicode.IsDigit(c) { + return false + } + } + return true +} + +func isBoolLiteralQ(s string) bool { + upper := strings.ToUpper(s) + return upper == "TRUE" || upper == "FALSE" +} diff --git a/tidb/catalog/deparse_query.go b/tidb/catalog/deparse_query.go new file mode 100644 index 00000000..57919440 --- /dev/null +++ b/tidb/catalog/deparse_query.go @@ -0,0 +1,292 @@ +package catalog + +import ( + "strings" +) + +// DeparseQuery converts an analyzed Query IR back to canonical SQL text. +func DeparseQuery(q *Query) string { + if q == nil { + return "" + } + if q.SetOp != SetOpNone { + return deparseSetOpQuery(q) + } + return deparseSimpleQuery(q) +} + +func deparseSimpleQuery(q *Query) string { + var b strings.Builder + + if len(q.CTEList) > 0 { + deparseCTEsQ(&b, q) + } + + b.WriteString("select ") + + if q.Distinct { + b.WriteString("distinct ") + } + + deparseTargetListQ(&b, q) + + if len(q.RangeTable) > 0 && q.JoinTree != nil && len(q.JoinTree.FromList) > 0 { + b.WriteString(" from ") + deparseFromClauseQ(&b, q) + } + + if q.JoinTree != nil && q.JoinTree.Quals != nil { + b.WriteString(" where ") + b.WriteString(DeparseAnalyzedExpr(q.JoinTree.Quals, q)) + } + + if len(q.GroupClause) > 0 { + b.WriteString(" group by ") + deparseGroupByQ(&b, q) + } + + if q.HavingQual != nil { + b.WriteString(" having ") + b.WriteString(DeparseAnalyzedExpr(q.HavingQual, q)) + } + + if len(q.SortClause) > 0 { + b.WriteString(" order by ") + deparseOrderByQ(&b, q) + } + + if q.LimitCount != nil { + b.WriteString(" limit ") + b.WriteString(DeparseAnalyzedExpr(q.LimitCount, q)) + if q.LimitOffset != nil { + b.WriteString(" offset ") + b.WriteString(DeparseAnalyzedExpr(q.LimitOffset, q)) + } + } + + return b.String() +} + +func deparseCTEsQ(b *strings.Builder, q *Query) { + b.WriteString("with ") + if q.IsRecursive { + b.WriteString("recursive ") + } + for i, cte := range q.CTEList { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(backtickIDQ(cte.Name)) + if len(cte.ColumnNames) > 0 { + b.WriteByte('(') + for j, col := range cte.ColumnNames { + if j > 0 { + b.WriteString(", ") + } + b.WriteString(backtickIDQ(col)) + } + b.WriteByte(')') + } + b.WriteString(" as (") + b.WriteString(DeparseQuery(cte.Query)) + b.WriteString(") ") + } +} + +func deparseTargetListQ(b *strings.Builder, q *Query) { + first := true + for _, te := range q.TargetList { + if te.ResJunk { + continue + } + if !first { + b.WriteString(", ") + } + first = false + + exprText := DeparseAnalyzedExpr(te.Expr, q) + b.WriteString(exprText) + + // Omit alias when expression is a simple column ref whose name matches ResName. + needAlias := true + if v, ok := te.Expr.(*VarExprQ); ok { + if v.RangeIdx >= 0 && v.RangeIdx < len(q.RangeTable) { + rte := q.RangeTable[v.RangeIdx] + if v.AttNum >= 1 && v.AttNum <= len(rte.ColNames) { + if strings.EqualFold(rte.ColNames[v.AttNum-1], te.ResName) { + needAlias = false + } + } + } + } + if needAlias { + b.WriteString(" as ") + b.WriteString(backtickIDQ(te.ResName)) + } + } +} + +func deparseFromClauseQ(b *strings.Builder, q *Query) { + for i, node := range q.JoinTree.FromList { + if i > 0 { + b.WriteString(", ") + } + deparseJoinNodeQ(b, node, q) + } +} + +func deparseJoinNodeQ(b *strings.Builder, node JoinNode, q *Query) { + switch n := node.(type) { + case *RangeTableRefQ: + deparseRTERefQ(b, n.RTIndex, q) + case *JoinExprNodeQ: + deparseJoinExprNodeQ(b, n, q) + } +} + +func deparseRTERefQ(b *strings.Builder, idx int, q *Query) { + if idx < 0 || idx >= len(q.RangeTable) { + b.WriteString("?") + return + } + rte := q.RangeTable[idx] + switch rte.Kind { + case RTERelation: + if rte.DBName != "" { + b.WriteString(backtickIDQ(rte.DBName)) + b.WriteByte('.') + } + b.WriteString(backtickIDQ(rte.TableName)) + if rte.Alias != "" && rte.Alias != rte.TableName { + b.WriteString(" as ") + b.WriteString(backtickIDQ(rte.Alias)) + } + case RTESubquery: + b.WriteByte('(') + b.WriteString(DeparseQuery(rte.Subquery)) + b.WriteString(") as ") + b.WriteString(backtickIDQ(rte.ERef)) + case RTECTE: + b.WriteString(backtickIDQ(rte.CTEName)) + if rte.Alias != "" && rte.Alias != rte.CTEName { + b.WriteString(" as ") + b.WriteString(backtickIDQ(rte.Alias)) + } + case RTEJoin: + // Synthetic; actual join structure is in JoinExprNodeQ. + case RTEFunction: + b.WriteString("/* function-in-FROM */") + } +} + +func deparseJoinExprNodeQ(b *strings.Builder, j *JoinExprNodeQ, q *Query) { + deparseJoinNodeQ(b, j.Left, q) + b.WriteByte(' ') + + if j.Natural { + b.WriteString("natural ") + } + + switch j.JoinType { + case JoinLeft: + b.WriteString("left join ") + case JoinRight: + b.WriteString("right join ") + case JoinCross: + b.WriteString("cross join ") + case JoinStraight: + b.WriteString("straight_join ") + default: // JoinInner + b.WriteString("join ") + } + + deparseJoinNodeQ(b, j.Right, q) + + if j.Quals != nil { + b.WriteString(" on ") + b.WriteString(DeparseAnalyzedExpr(j.Quals, q)) + } + + if len(j.UsingClause) > 0 && j.Quals == nil { + b.WriteString(" using (") + for i, col := range j.UsingClause { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(backtickIDQ(col)) + } + b.WriteByte(')') + } +} + +func deparseGroupByQ(b *strings.Builder, q *Query) { + for i, sc := range q.GroupClause { + if i > 0 { + b.WriteString(", ") + } + if sc.TargetIdx >= 1 && sc.TargetIdx <= len(q.TargetList) { + b.WriteString(DeparseAnalyzedExpr(q.TargetList[sc.TargetIdx-1].Expr, q)) + } + } +} + +func deparseOrderByQ(b *strings.Builder, q *Query) { + for i, sc := range q.SortClause { + if i > 0 { + b.WriteString(", ") + } + if sc.TargetIdx >= 1 && sc.TargetIdx <= len(q.TargetList) { + b.WriteString(DeparseAnalyzedExpr(q.TargetList[sc.TargetIdx-1].Expr, q)) + } + if sc.Descending { + b.WriteString(" desc") + } + } +} + +func deparseSetOpQuery(q *Query) string { + var b strings.Builder + + if len(q.CTEList) > 0 { + deparseCTEsQ(&b, q) + } + + b.WriteString(DeparseQuery(q.LArg)) + b.WriteByte(' ') + + switch q.SetOp { + case SetOpUnion: + b.WriteString("union") + case SetOpIntersect: + b.WriteString("intersect") + case SetOpExcept: + b.WriteString("except") + } + if q.AllSetOp { + b.WriteString(" all") + } + + b.WriteByte(' ') + b.WriteString(DeparseQuery(q.RArg)) + + if len(q.SortClause) > 0 { + b.WriteString(" order by ") + deparseOrderByQ(&b, q) + } + + if q.LimitCount != nil { + b.WriteString(" limit ") + b.WriteString(DeparseAnalyzedExpr(q.LimitCount, q)) + if q.LimitOffset != nil { + b.WriteString(" offset ") + b.WriteString(DeparseAnalyzedExpr(q.LimitOffset, q)) + } + } + + return b.String() +} + +// backtickIDQ wraps an identifier in backticks. +func backtickIDQ(name string) string { + return "`" + strings.ReplaceAll(name, "`", "``") + "`" +} diff --git a/tidb/catalog/deparse_query_test.go b/tidb/catalog/deparse_query_test.go new file mode 100644 index 00000000..9f09cdec --- /dev/null +++ b/tidb/catalog/deparse_query_test.go @@ -0,0 +1,246 @@ +package catalog + +import ( + "strings" + "testing" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// parseSelectForDeparse parses a single SELECT and returns the AST node. +func parseSelectForDeparse(t *testing.T, sql string) *nodes.SelectStmt { + t.Helper() + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if list.Len() != 1 { + t.Fatalf("expected 1 statement, got %d", list.Len()) + } + sel, ok := list.Items[0].(*nodes.SelectStmt) + if !ok { + t.Fatalf("expected *ast.SelectStmt, got %T", list.Items[0]) + } + return sel +} + +// setupDeparseTestCatalog creates a catalog with test tables. +func setupDeparseTestCatalog(t *testing.T) *Catalog { + t.Helper() + c := New() + results, err := c.Exec("CREATE DATABASE testdb; USE testdb;", nil) + if err != nil { + t.Fatalf("setup parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("setup exec error: %v", r.Error) + } + } + + ddl := ` + CREATE TABLE employees ( + id INT PRIMARY KEY, + name VARCHAR(100), + salary DECIMAL(10,2), + department_id INT + ); + CREATE TABLE departments ( + id INT PRIMARY KEY, + name VARCHAR(100) + ); + ` + results, err = c.Exec(ddl, nil) + if err != nil { + t.Fatalf("DDL parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("DDL exec error: %v", r.Error) + } + } + return c +} + +// roundTrip parses, analyzes, deparses, re-parses, and re-analyzes. +// Returns both Query IRs and the intermediate SQL. +func roundTrip(t *testing.T, c *Catalog, sql string) (q1 *Query, sql2 string, q2 *Query) { + t.Helper() + + sel1 := parseSelectForDeparse(t, sql) + var err error + q1, err = c.AnalyzeSelectStmt(sel1) + if err != nil { + t.Fatalf("analyze q1 error: %v", err) + } + + sql2 = DeparseQuery(q1) + t.Logf("Original: %s", sql) + t.Logf("Deparsed: %s", sql2) + + sel2 := parseSelectForDeparse(t, sql2) + q2, err = c.AnalyzeSelectStmt(sel2) + if err != nil { + t.Fatalf("analyze q2 error (deparsed SQL: %s): %v", sql2, err) + } + + return q1, sql2, q2 +} + +// compareQueries checks structural equivalence between two Query IRs. +func compareQueries(t *testing.T, q1, q2 *Query) { + t.Helper() + + // Count non-junk targets. + nonJunk := func(q *Query) []*TargetEntryQ { + var out []*TargetEntryQ + for _, te := range q.TargetList { + if !te.ResJunk { + out = append(out, te) + } + } + return out + } + + tl1 := nonJunk(q1) + tl2 := nonJunk(q2) + if len(tl1) != len(tl2) { + t.Errorf("non-junk target count: q1=%d, q2=%d", len(tl1), len(tl2)) + return + } + + for i := range tl1 { + if !strings.EqualFold(tl1[i].ResName, tl2[i].ResName) { + t.Errorf("target[%d] ResName: q1=%q, q2=%q", i, tl1[i].ResName, tl2[i].ResName) + } + // Compare VarExprQ coordinates when both are VarExprQ. + v1, ok1 := tl1[i].Expr.(*VarExprQ) + v2, ok2 := tl2[i].Expr.(*VarExprQ) + if ok1 && ok2 { + if v1.RangeIdx != v2.RangeIdx { + t.Errorf("target[%d] VarExprQ.RangeIdx: q1=%d, q2=%d", i, v1.RangeIdx, v2.RangeIdx) + } + if v1.AttNum != v2.AttNum { + t.Errorf("target[%d] VarExprQ.AttNum: q1=%d, q2=%d", i, v1.AttNum, v2.AttNum) + } + } + } + + // Compare RTE count. + if len(q1.RangeTable) != len(q2.RangeTable) { + t.Errorf("RangeTable count: q1=%d, q2=%d", len(q1.RangeTable), len(q2.RangeTable)) + } + + // Compare WHERE existence. + hasWhere1 := q1.JoinTree != nil && q1.JoinTree.Quals != nil + hasWhere2 := q2.JoinTree != nil && q2.JoinTree.Quals != nil + if hasWhere1 != hasWhere2 { + t.Errorf("WHERE presence: q1=%v, q2=%v", hasWhere1, hasWhere2) + } + + // Compare GROUP BY count. + if len(q1.GroupClause) != len(q2.GroupClause) { + t.Errorf("GroupClause count: q1=%d, q2=%d", len(q1.GroupClause), len(q2.GroupClause)) + } + + // Compare HAVING existence. + if (q1.HavingQual != nil) != (q2.HavingQual != nil) { + t.Errorf("HAVING presence: q1=%v, q2=%v", q1.HavingQual != nil, q2.HavingQual != nil) + } + + // Compare ORDER BY count. + if len(q1.SortClause) != len(q2.SortClause) { + t.Errorf("SortClause count: q1=%d, q2=%d", len(q1.SortClause), len(q2.SortClause)) + } + + // Compare set op. + if q1.SetOp != q2.SetOp { + t.Errorf("SetOp: q1=%d, q2=%d", q1.SetOp, q2.SetOp) + } +} + +// TestDeparseQuery_16_1_SimpleRoundTrip tests a simple SELECT with WHERE. +func TestDeparseQuery_16_1_SimpleRoundTrip(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, _, q2 := roundTrip(t, c, "SELECT name, salary FROM employees WHERE salary > 50000") + compareQueries(t, q1, q2) +} + +// TestDeparseQuery_16_2_JoinRoundTrip tests a JOIN round-trip. +func TestDeparseQuery_16_2_JoinRoundTrip(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, _, q2 := roundTrip(t, c, "SELECT e.name, d.name FROM employees e JOIN departments d ON e.department_id = d.id") + compareQueries(t, q1, q2) +} + +// TestDeparseQuery_16_3_AggregateRoundTrip tests GROUP BY + HAVING round-trip. +func TestDeparseQuery_16_3_AggregateRoundTrip(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, _, q2 := roundTrip(t, c, "SELECT department_id, COUNT(*) FROM employees GROUP BY department_id HAVING COUNT(*) > 5") + compareQueries(t, q1, q2) +} + +// TestDeparseQuery_16_4_CTERoundTrip tests CTE round-trip. +func TestDeparseQuery_16_4_CTERoundTrip(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, _, q2 := roundTrip(t, c, "WITH cte AS (SELECT id, name FROM employees) SELECT id, name FROM cte") + compareQueries(t, q1, q2) +} + +// TestDeparseQuery_16_5_SetOpRoundTrip tests UNION ALL round-trip. +func TestDeparseQuery_16_5_SetOpRoundTrip(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, _, q2 := roundTrip(t, c, "SELECT name FROM employees UNION ALL SELECT name FROM departments") + compareQueries(t, q1, q2) +} + +// TestDeparseQuery_BareLiteral tests SELECT 1 (no FROM). +func TestDeparseQuery_BareLiteral(t *testing.T) { + c := setupDeparseTestCatalog(t) + q1, sql2, _ := roundTrip(t, c, "SELECT 1") + if q1 == nil { + t.Fatal("q1 is nil") + } + // Should not contain "from". + if strings.Contains(strings.ToLower(sql2), "from") { + t.Errorf("deparsed bare literal should not have FROM: %s", sql2) + } +} + +// TestDeparseQuery_DeparseOutput tests that DeparseQuery produces valid SQL text. +func TestDeparseQuery_DeparseOutput(t *testing.T) { + c := setupDeparseTestCatalog(t) + + tests := []struct { + name string + sql string + }{ + {"simple_select", "SELECT id, name FROM employees"}, + {"where_clause", "SELECT name FROM employees WHERE salary > 50000"}, + {"order_by", "SELECT name, salary FROM employees ORDER BY salary DESC"}, + {"limit", "SELECT name FROM employees LIMIT 10"}, + {"distinct", "SELECT DISTINCT department_id FROM employees"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sel := parseSelectForDeparse(t, tt.sql) + q, err := c.AnalyzeSelectStmt(sel) + if err != nil { + t.Fatalf("analyze error: %v", err) + } + result := DeparseQuery(q) + if result == "" { + t.Fatal("DeparseQuery returned empty string") + } + t.Logf("SQL: %s => %s", tt.sql, result) + + // Verify the deparsed SQL can be parsed back. + _, err = parser.Parse(result) + if err != nil { + t.Fatalf("deparsed SQL failed to parse: %v\nSQL: %s", err, result) + } + }) + } +} diff --git a/tidb/catalog/deparse_rules_test.go b/tidb/catalog/deparse_rules_test.go new file mode 100644 index 00000000..6a3ab5e8 --- /dev/null +++ b/tidb/catalog/deparse_rules_test.go @@ -0,0 +1,340 @@ +package catalog + +import ( + "fmt" + "testing" +) + +// TestMySQL_DeparseRules creates views with specific SQL patterns and examines +// SHOW CREATE VIEW output to discover MySQL 8.0's exact deparsing/formatting rules. +// This is a research test — it asserts nothing, only logs results. +func TestMySQL_DeparseRules(t *testing.T) { + ctr, cleanup := startContainer(t) + defer cleanup() + + // Create base tables used by the views. + setupSQL := ` + CREATE TABLE t (a INT, b INT, c INT); + CREATE TABLE t1 (a INT, b INT); + CREATE TABLE t2 (a INT, b INT); + ` + if err := ctr.execSQL(setupSQL); err != nil { + t.Fatalf("setup tables: %v", err) + } + + type viewTest struct { + name string + createAs string // the SELECT part after "CREATE VIEW vN AS " + } + + categories := []struct { + label string + views []viewTest + }{ + { + label: "A. Keyword Casing", + views: []viewTest{ + {"v1", "SELECT 1"}, + {"v2", "SELECT a FROM t WHERE a > 1 GROUP BY a HAVING COUNT(*) > 1 ORDER BY a LIMIT 10"}, + }, + }, + { + label: "B. Identifier Quoting", + views: []viewTest{ + {"v3", "SELECT a, b AS alias1 FROM t"}, + {"v4", "SELECT t.a FROM t"}, + {"v5", "SELECT `select` FROM t"}, + }, + }, + { + label: "C. Operator Formatting (spacing, case, parens)", + views: []viewTest{ + {"v6", "SELECT a+b, a-b, a*b, a/b, a DIV b, a MOD b, a%b FROM t"}, + {"v7", "SELECT a=b, a<>b, a!=b, a>b, a=b, a<=b, a<=>b FROM t"}, + {"v8", "SELECT a AND b, a OR b, NOT a, a XOR b FROM t"}, + {"v9", "SELECT a|b, a&b, a^b, a<>b, ~a FROM t"}, + {"v10", "SELECT a IS NULL, a IS NOT NULL, a IS TRUE, a IS FALSE FROM t"}, + {"v11", "SELECT a IN (1,2,3), a NOT IN (1,2,3) FROM t"}, + {"v12", "SELECT a BETWEEN 1 AND 10, a NOT BETWEEN 1 AND 10 FROM t"}, + {"v13", "SELECT a LIKE 'foo%', a NOT LIKE 'bar%', a LIKE 'x' ESCAPE '\\\\' FROM t"}, + {"v14", "SELECT a REGEXP 'pattern', a NOT REGEXP 'pattern' FROM t"}, + }, + }, + { + label: "D. Function Name Casing", + views: []viewTest{ + {"v15", "SELECT COUNT(*), SUM(a), AVG(a), MAX(a), MIN(a), COUNT(DISTINCT a) FROM t"}, + {"v16", "SELECT CONCAT(a, b), SUBSTRING(a, 1, 3), TRIM(a), UPPER(a), LOWER(a) FROM t"}, + {"v17", "SELECT NOW(), CURRENT_TIMESTAMP, CURRENT_DATE, CURRENT_TIME, CURRENT_USER FROM t"}, + {"v18", "SELECT IFNULL(a, 0), COALESCE(a, b, 0), NULLIF(a, 0), IF(a > 0, 'yes', 'no') FROM t"}, + {"v19", "SELECT CAST(a AS CHAR), CAST(a AS SIGNED), CONVERT(a, CHAR), CONVERT(a USING utf8mb4) FROM t"}, + }, + }, + { + label: "E. Expression Structure", + views: []viewTest{ + {"v20", "SELECT CASE WHEN a > 0 THEN 'pos' WHEN a < 0 THEN 'neg' ELSE 'zero' END FROM t"}, + {"v21", "SELECT CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM t"}, + {"v22", "SELECT (a + b) * c, a + (b * c) FROM t"}, + {"v23", "SELECT -a, +a, !a FROM t"}, + }, + }, + { + label: "F. Literals", + views: []viewTest{ + {"v24", "SELECT 1, 1.5, 'hello', NULL, TRUE, FALSE FROM t"}, + {"v25", "SELECT 0xFF, X'FF', 0b1010, b'1010' FROM t"}, + {"v26", "SELECT _utf8mb4'hello', _latin1'world' FROM t"}, + {"v27", "SELECT DATE '2024-01-01', TIME '12:00:00', TIMESTAMP '2024-01-01 12:00:00' FROM t"}, + }, + }, + { + label: "G. Subqueries", + views: []viewTest{ + {"v28", "SELECT (SELECT MAX(a) FROM t) FROM t"}, + {"v29", "SELECT * FROM t WHERE a IN (SELECT a FROM t WHERE a > 0)"}, + {"v30", "SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t WHERE a > 0)"}, + }, + }, + { + label: "H. JOIN Formatting", + views: []viewTest{ + {"v31", "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a"}, + {"v32", "SELECT t1.a, t2.b FROM t1 LEFT JOIN t2 ON t1.a = t2.a"}, + {"v33", "SELECT t1.a, t2.b FROM t1 RIGHT JOIN t2 ON t1.a = t2.a"}, + {"v34", "SELECT t1.a, t2.b FROM t1 CROSS JOIN t2"}, + {"v35", "SELECT * FROM t1 NATURAL JOIN t2"}, + {"v36", "SELECT t1.a, t2.b FROM t1 STRAIGHT_JOIN t2 ON t1.a = t2.a"}, + {"v37", "SELECT t1.a, t2.b FROM t1 JOIN t2 USING (a)"}, + {"v38", "SELECT t1.a, t2.b FROM t1, t2 WHERE t1.a = t2.a"}, + }, + }, + { + label: "I. UNION/INTERSECT/EXCEPT", + views: []viewTest{ + {"v39", "SELECT a FROM t UNION SELECT b FROM t"}, + {"v40", "SELECT a FROM t UNION ALL SELECT b FROM t"}, + {"v41", "SELECT a FROM t UNION SELECT b FROM t UNION SELECT c FROM t"}, + }, + }, + { + label: "J. Window Functions", + views: []viewTest{ + {"v42", "SELECT a, ROW_NUMBER() OVER (ORDER BY a) FROM t"}, + {"v43", "SELECT a, SUM(b) OVER (PARTITION BY a ORDER BY b) FROM t"}, + {"v44", "SELECT a, SUM(b) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM t"}, + }, + }, + { + label: "K. Alias Formatting", + views: []viewTest{ + {"v45", "SELECT a AS col1, b col2 FROM t"}, + {"v46", "SELECT a + b AS sum_col FROM t"}, + {"v47", "SELECT * FROM t AS t1"}, + }, + }, + { + label: "L. Misc MySQL-specific", + views: []viewTest{ + {"v48", "SELECT a->>'$.key' FROM t"}, + {"v49", "SELECT INTERVAL 1 DAY + a FROM t"}, + }, + }, + { + label: "M. Semantic Rewrites (things MySQL changes fundamentally)", + views: []viewTest{ + // NOT LIKE → not((...like...)) + {"v50", "SELECT a NOT LIKE 'x%' FROM t"}, + // != → <> + {"v51", "SELECT a != b FROM t"}, + // COUNT(*) → count(0) + {"v52", "SELECT COUNT(*) FROM t"}, + // SUBSTRING → substr + {"v53", "SELECT SUBSTRING('abc', 1, 2) FROM t"}, + // CURRENT_TIMESTAMP → now() + {"v54", "SELECT CURRENT_TIMESTAMP() FROM t"}, + // MOD → % + {"v55", "SELECT a MOD b FROM t"}, + // +a → a (unary plus dropped) + {"v56", "SELECT +a FROM t"}, + // ! → (0 = ...) + {"v57", "SELECT !a FROM t"}, + // AND/OR on INT cols → (0 <> ...) and/or (0 <> ...) + {"v58", "SELECT a AND b FROM t"}, + // REGEXP → regexp_like() + {"v59", "SELECT a REGEXP 'x' FROM t"}, + // -> → json_extract() + {"v60", "SELECT a->'$.key' FROM t"}, + }, + }, + { + label: "N. Derived Tables & Subquery in FROM", + views: []viewTest{ + {"v61", "SELECT d.x FROM (SELECT a AS x FROM t) d"}, + {"v62", "SELECT d.x FROM (SELECT a AS x FROM t WHERE a > 0) AS d WHERE d.x < 10"}, + }, + }, + { + label: "O. GROUP_CONCAT & Special Aggregates", + views: []viewTest{ + {"v63", "SELECT GROUP_CONCAT(a ORDER BY a SEPARATOR ',') FROM t"}, + {"v64", "SELECT GROUP_CONCAT(DISTINCT a ORDER BY a DESC SEPARATOR ';') FROM t"}, + }, + }, + { + label: "P. Spacing & Comma Rules", + views: []viewTest{ + // Are there spaces after commas in function args? + {"v65", "SELECT CONCAT(a, b, c) FROM t"}, + // Spaces in IN list? + {"v66", "SELECT a IN (1, 2, 3) FROM t"}, + // Space before/after AS? + {"v67", "SELECT a AS x FROM t"}, + }, + }, + { + label: "Q. Table Alias Format (AS vs space)", + views: []viewTest{ + {"v68", "SELECT x.a FROM t AS x"}, + {"v69", "SELECT x.a FROM t x"}, + }, + }, + { + label: "R. Complex Precedence", + views: []viewTest{ + {"v70", "SELECT a + b + c FROM t"}, + {"v71", "SELECT a * b + c FROM t"}, + {"v72", "SELECT a + b * c FROM t"}, + {"v73", "SELECT a OR b AND c FROM t"}, + {"v74", "SELECT (a OR b) AND c FROM t"}, + {"v75", "SELECT a > 0 AND b < 10 OR c = 5 FROM t"}, + }, + }, + { + label: "S. CTE (WITH clause)", + views: []viewTest{ + {"v76", "WITH cte AS (SELECT a FROM t) SELECT * FROM cte"}, + {"v77", "WITH cte(x) AS (SELECT a FROM t) SELECT x FROM cte"}, + }, + }, + { + label: "T. Type-Aware Boolean Context (expressions in AND/OR)", + views: []viewTest{ + // Expression (not column) in boolean context + {"v78", "SELECT (a + 1) AND b FROM t"}, + {"v79", "SELECT (a > 0) AND (b + 1) FROM t"}, + {"v80", "SELECT (a > 0) AND (b > 0) FROM t"}, + // Function results in boolean context + {"v81", "SELECT ABS(a) AND b FROM t"}, + {"v82", "SELECT COUNT(*) > 0 FROM t"}, + // CASE in boolean context + {"v83", "SELECT CASE WHEN a > 0 THEN 1 ELSE 0 END AND b FROM t"}, + // Nested boolean + {"v84", "SELECT NOT (a + 1) FROM t"}, + {"v85", "SELECT NOT (a > 0) FROM t"}, + // IF result in boolean context + {"v86", "SELECT IF(a > 0, 1, 0) AND b FROM t"}, + // Subquery in boolean context + {"v87", "SELECT (SELECT MAX(a) FROM t) AND b FROM t"}, + }, + }, + { + label: "U. CAST charset inference", + views: []viewTest{ + {"v88", "SELECT CAST(a AS CHAR(10)) FROM t"}, + {"v89", "SELECT CAST(a AS BINARY) FROM t"}, + {"v90", "SELECT CAST(a AS DECIMAL(10,2)) FROM t"}, + {"v91", "SELECT CAST(a AS UNSIGNED) FROM t"}, + {"v92", "SELECT CAST(a AS DATE) FROM t"}, + {"v93", "SELECT CAST(a AS DATETIME) FROM t"}, + {"v94", "SELECT CAST(a AS JSON) FROM t"}, + }, + }, + { + label: "V. String vs INT in boolean context", + views: []viewTest{ + // Need a VARCHAR column to test string behavior + {"v95_setup", "SELECT 'hello' AND 1 FROM t"}, + {"v96", "SELECT a AND 'hello' FROM t"}, + {"v97", "SELECT CONCAT(a,b) AND 1 FROM t"}, + }, + }, + { + label: "W. Complex function return types", + views: []viewTest{ + {"v98", "SELECT IFNULL(a, 0) AND b FROM t"}, + {"v99", "SELECT COALESCE(a, b) AND 1 FROM t"}, + {"v100", "SELECT NULLIF(a, 0) AND b FROM t"}, + {"v101", "SELECT GREATEST(a, b) AND 1 FROM t"}, + {"v102", "SELECT LEAST(a, b) AND 1 FROM t"}, + }, + }, + { + label: "X. Comparison results NOT wrapped", + views: []viewTest{ + // These should NOT get (0 <> ...) wrapping because they're already boolean + {"v103", "SELECT (a = b) AND (a > 0) FROM t"}, + {"v104", "SELECT (a IN (1,2,3)) AND (b BETWEEN 1 AND 10) FROM t"}, + {"v105", "SELECT (a IS NULL) AND (b LIKE 'x%') FROM t"}, + {"v106", "SELECT (a = 1) OR (b = 2) FROM t"}, + {"v107", "SELECT EXISTS(SELECT 1 FROM t WHERE a > 0) AND (b > 0) FROM t"}, + }, + }, + { + label: "Y. Multi-table column qualification", + views: []viewTest{ + {"v108", "SELECT t1.a, t2.a FROM t1 JOIN t2 ON t1.a = t2.a"}, + {"v109", "SELECT a FROM t WHERE a > 0"}, + {"v110", "SELECT t.a FROM t"}, + }, + }, + { + label: "Z. Edge cases", + views: []viewTest{ + // Empty string + {"v111", "SELECT '' FROM t"}, + // Escaped quotes in strings + {"v112", "SELECT 'it''s' FROM t"}, + // Backslash in strings + {"v113", "SELECT 'back\\\\slash' FROM t"}, + // Very long expression + {"v114", "SELECT a + b + c + a + b + c FROM t"}, + // Nested function calls + {"v115", "SELECT CONCAT(UPPER(TRIM(a)), LOWER(b)) FROM t"}, + // Multiple aggregates + {"v116", "SELECT COUNT(*), SUM(a), AVG(b), MAX(c) FROM t GROUP BY a"}, + }, + }, + } + + t.Log("=== MySQL 8.0 Deparse Rules Research ===") + t.Log("") + + for _, cat := range categories { + t.Logf("--- %s ---", cat.label) + for _, vt := range cat.views { + createSQL := fmt.Sprintf("CREATE VIEW %s AS %s", vt.name, vt.createAs) + + if err := ctr.execSQLDirect(createSQL); err != nil { + t.Logf(" [%s] CREATE failed: %v", vt.name, err) + t.Logf(" INPUT: %s", createSQL) + t.Logf("") + continue + } + + output, err := ctr.showCreateView(vt.name) + if err != nil { + t.Logf(" [%s] SHOW CREATE VIEW failed: %v", vt.name, err) + t.Logf(" INPUT: %s", createSQL) + t.Logf("") + continue + } + + t.Logf(" [%s]", vt.name) + t.Logf(" INPUT: %s", vt.createAs) + t.Logf(" OUTPUT: %s", output) + t.Logf("") + } + t.Log("") + } +} diff --git a/tidb/catalog/dropcmds.go b/tidb/catalog/dropcmds.go new file mode 100644 index 00000000..4bd95898 --- /dev/null +++ b/tidb/catalog/dropcmds.go @@ -0,0 +1,76 @@ +package catalog + +import nodes "github.com/bytebase/omni/tidb/ast" + +func (c *Catalog) dropTable(stmt *nodes.DropTableStmt) error { + for _, ref := range stmt.Tables { + dbName := ref.Schema + db, err := c.resolveDatabase(dbName) + if err != nil { + if stmt.IfExists { + continue + } + return err + } + key := toLower(ref.Name) + if db.Tables[key] == nil { + if stmt.IfExists { + continue + } + return errUnknownTable(db.Name, ref.Name) + } + // Check if any other table in any database has a FK referencing this table + // (unless foreign_key_checks=0). + if c.foreignKeyChecks { + if err := c.checkFKReferences(db.Name, ref.Name); err != nil { + return err + } + } + delete(db.Tables, key) + } + return nil +} + +// checkFKReferences returns an error if any table in any database has a +// foreign key constraint that references the given table. +func (c *Catalog) checkFKReferences(dbName, tableName string) error { + dbKey := toLower(dbName) + tblKey := toLower(tableName) + for _, db := range c.databases { + for _, tbl := range db.Tables { + // Skip the table itself. + if toLower(db.Name) == dbKey && toLower(tbl.Name) == tblKey { + continue + } + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + refDB := con.RefDatabase + if refDB == "" { + refDB = db.Name + } + if toLower(refDB) == dbKey && toLower(con.RefTable) == tblKey { + return errFKCannotDropParent(tableName, con.Name, tbl.Name) + } + } + } + } + return nil +} + +func (c *Catalog) truncateTable(stmt *nodes.TruncateStmt) error { + for _, ref := range stmt.Tables { + dbName := ref.Schema + db, err := c.resolveDatabase(dbName) + if err != nil { + return err + } + tbl := db.GetTable(ref.Name) + if tbl == nil { + return errNoSuchTable(db.Name, ref.Name) + } + tbl.AutoIncrement = 0 + } + return nil +} diff --git a/tidb/catalog/dropcmds_test.go b/tidb/catalog/dropcmds_test.go new file mode 100644 index 00000000..d2fcfbce --- /dev/null +++ b/tidb/catalog/dropcmds_test.go @@ -0,0 +1,71 @@ +package catalog + +import "testing" + +func TestDropTable(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t1 (id INT)", nil) + _, err := c.Exec("DROP TABLE t1", nil) + if err != nil { + t.Fatal(err) + } + if c.GetDatabase("test").GetTable("t1") != nil { + t.Fatal("table should be dropped") + } +} + +func TestDropTableIfExists(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE IF EXISTS noexist", nil) + if results[0].Error != nil { + t.Errorf("IF EXISTS should not error: %v", results[0].Error) + } +} + +func TestDropTableNotExists(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP TABLE noexist", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected error") + } +} + +func TestDropMultipleTables(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t1 (id INT)", nil) + c.Exec("CREATE TABLE t2 (id INT)", nil) + c.Exec("DROP TABLE t1, t2", nil) + db := c.GetDatabase("test") + if db.GetTable("t1") != nil || db.GetTable("t2") != nil { + t.Fatal("both tables should be dropped") + } +} + +func TestTruncateTable(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t1 (id INT AUTO_INCREMENT PRIMARY KEY)", nil) + results, _ := c.Exec("TRUNCATE TABLE t1", nil) + if results[0].Error != nil { + t.Fatalf("truncate failed: %v", results[0].Error) + } +} + +func TestTruncateTableNotExists(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("TRUNCATE TABLE noexist", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected error") + } +} diff --git a/tidb/catalog/errors.go b/tidb/catalog/errors.go new file mode 100644 index 00000000..960e4afb --- /dev/null +++ b/tidb/catalog/errors.go @@ -0,0 +1,210 @@ +package catalog + +import "fmt" + +type Error struct { + Code int + SQLState string + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("ERROR %d (%s): %s", e.Code, e.SQLState, e.Message) +} + +const ( + ErrDupDatabase = 1007 + ErrUnknownDatabase = 1049 + ErrDupTable = 1050 + ErrUnknownTable = 1051 + ErrDupColumn = 1060 + ErrDupKeyName = 1061 + ErrDupEntry = 1062 + ErrMultiplePriKey = 1068 + ErrNoSuchTable = 1146 + ErrNoSuchColumn = 1054 + ErrNoDatabaseSelected = 1046 + ErrDupIndex = 1831 + ErrFKNoRefTable = 1824 + ErrCantDropKey = 1091 + ErrCheckConstraintViolated = 3819 + ErrFKCannotDropParent = 3730 + ErrFKMissingIndex = 1822 + ErrFKIncompatibleColumns = 3780 + ErrNoSuchFunction = 1305 + ErrNoSuchProcedure = 1305 + ErrDupFunction = 1304 + ErrDupProcedure = 1304 + ErrNoSuchTrigger = 1360 + ErrDupTrigger = 1359 + ErrNoSuchEvent = 1539 + ErrDupEvent = 1537 + ErrUnsupportedGeneratedStorageChange = 3106 + ErrDependentByGenCol = 3108 + ErrWrongArguments = 1210 +) + +var sqlStateMap = map[int]string{ + ErrDupDatabase: "HY000", + ErrUnknownDatabase: "42000", + ErrDupTable: "42S01", + ErrUnknownTable: "42S02", + ErrDupColumn: "42S21", + ErrDupKeyName: "42000", + ErrDupEntry: "23000", + ErrMultiplePriKey: "42000", + ErrNoSuchTable: "42S02", + ErrNoSuchColumn: "42S22", + ErrNoDatabaseSelected: "3D000", + ErrDupIndex: "42000", + ErrFKNoRefTable: "HY000", + ErrCantDropKey: "42000", + ErrCheckConstraintViolated: "HY000", + ErrFKCannotDropParent: "HY000", + ErrFKMissingIndex: "HY000", + ErrFKIncompatibleColumns: "HY000", + ErrNoSuchFunction: "42000", + ErrDupFunction: "HY000", + ErrNoSuchEvent: "HY000", + ErrDupEvent: "HY000", + ErrUnsupportedGeneratedStorageChange: "HY000", + ErrDependentByGenCol: "HY000", + ErrWrongArguments: "HY000", +} + +func sqlState(code int) string { + if s, ok := sqlStateMap[code]; ok { + return s + } + return "HY000" +} + +func errDupDatabase(name string) error { + return &Error{Code: ErrDupDatabase, SQLState: sqlState(ErrDupDatabase), + Message: fmt.Sprintf("Can't create database '%s'; database exists", name)} +} + +func errUnknownDatabase(name string) error { + return &Error{Code: ErrUnknownDatabase, SQLState: sqlState(ErrUnknownDatabase), + Message: fmt.Sprintf("Unknown database '%s'", name)} +} + +func errNoDatabaseSelected() error { + return &Error{Code: ErrNoDatabaseSelected, SQLState: sqlState(ErrNoDatabaseSelected), + Message: "No database selected"} +} + +func errDupTable(name string) error { + return &Error{Code: ErrDupTable, SQLState: sqlState(ErrDupTable), + Message: fmt.Sprintf("Table '%s' already exists", name)} +} + +func errNoSuchTable(db, name string) error { + return &Error{Code: ErrNoSuchTable, SQLState: sqlState(ErrNoSuchTable), + Message: fmt.Sprintf("Table '%s.%s' doesn't exist", db, name)} +} + +func errDupColumn(name string) error { + return &Error{Code: ErrDupColumn, SQLState: sqlState(ErrDupColumn), + Message: fmt.Sprintf("Duplicate column name '%s'", name)} +} + +func errDupKeyName(name string) error { + return &Error{Code: ErrDupKeyName, SQLState: sqlState(ErrDupKeyName), + Message: fmt.Sprintf("Duplicate key name '%s'", name)} +} + +func errMultiplePriKey() error { + return &Error{Code: ErrMultiplePriKey, SQLState: sqlState(ErrMultiplePriKey), + Message: "Multiple primary key defined"} +} + +func errNoSuchColumn(name, context string) error { + return &Error{Code: ErrNoSuchColumn, SQLState: sqlState(ErrNoSuchColumn), + Message: fmt.Sprintf("Unknown column '%s' in '%s'", name, context)} +} + +func errUnknownTable(db, name string) error { + return &Error{Code: ErrUnknownTable, SQLState: sqlState(ErrUnknownTable), + Message: fmt.Sprintf("Unknown table '%s.%s'", db, name)} +} + +func errFKCannotDropParent(table, fkName, refTable string) error { + return &Error{Code: ErrFKCannotDropParent, SQLState: sqlState(ErrFKCannotDropParent), + Message: fmt.Sprintf("Cannot drop table '%s' referenced by a foreign key constraint '%s' on table '%s'", table, fkName, refTable)} +} + +func errCantDropKey(name string) error { + return &Error{Code: ErrCantDropKey, SQLState: sqlState(ErrCantDropKey), + Message: fmt.Sprintf("Can't DROP '%s'; check that column/key exists", name)} +} + +func errFKNoRefTable(table string) error { + return &Error{Code: ErrFKNoRefTable, SQLState: sqlState(ErrFKNoRefTable), + Message: fmt.Sprintf("Failed to open the referenced table '%s'", table)} +} + +func errFKMissingIndex(constraint, refTable string) error { + return &Error{Code: ErrFKMissingIndex, SQLState: sqlState(ErrFKMissingIndex), + Message: fmt.Sprintf("Failed to add the foreign key constraint. Missing index for constraint '%s' in the referenced table '%s'", constraint, refTable)} +} + +func errFKIncompatibleColumns(col, refCol, constraint string) error { + return &Error{Code: ErrFKIncompatibleColumns, SQLState: sqlState(ErrFKIncompatibleColumns), + Message: fmt.Sprintf("Referencing column '%s' and referenced column '%s' in foreign key constraint '%s' are incompatible.", col, refCol, constraint)} +} + +func errDupFunction(name string) error { + return &Error{Code: ErrDupFunction, SQLState: sqlState(ErrDupFunction), + Message: fmt.Sprintf("FUNCTION %s already exists", name)} +} + +func errDupProcedure(name string) error { + return &Error{Code: ErrDupProcedure, SQLState: sqlState(ErrDupProcedure), + Message: fmt.Sprintf("PROCEDURE %s already exists", name)} +} + +func errNoSuchFunction(name string) error { + return &Error{Code: ErrNoSuchFunction, SQLState: sqlState(ErrNoSuchFunction), + Message: fmt.Sprintf("FUNCTION %s does not exist", name)} +} + +func errNoSuchProcedure(db, name string) error { + return &Error{Code: ErrNoSuchProcedure, SQLState: sqlState(ErrNoSuchProcedure), + Message: fmt.Sprintf("PROCEDURE %s.%s does not exist", db, name)} +} + +func errDupTrigger(name string) error { + return &Error{Code: ErrDupTrigger, SQLState: sqlState(ErrDupTrigger), + Message: fmt.Sprintf("Trigger already exists")} +} + +func errNoSuchTrigger(db, name string) error { + return &Error{Code: ErrNoSuchTrigger, SQLState: sqlState(ErrNoSuchTrigger), + Message: fmt.Sprintf("Trigger does not exist")} +} + +func errDupEvent(name string) error { + return &Error{Code: ErrDupEvent, SQLState: sqlState(ErrDupEvent), + Message: fmt.Sprintf("Event '%s' already exists", name)} +} + +func errNoSuchEvent(db, name string) error { + return &Error{Code: ErrNoSuchEvent, SQLState: sqlState(ErrNoSuchEvent), + Message: fmt.Sprintf("Unknown event '%s'", name)} +} + +func errUnsupportedGeneratedStorageChange(col, table string) error { + return &Error{Code: ErrUnsupportedGeneratedStorageChange, SQLState: sqlState(ErrUnsupportedGeneratedStorageChange), + Message: fmt.Sprintf("'Changing the STORED status' is not supported for generated columns.")} +} + +func errDependentByGeneratedColumn(column, genColumn, table string) error { + return &Error{Code: ErrDependentByGenCol, SQLState: sqlState(ErrDependentByGenCol), + Message: fmt.Sprintf("Column '%s' has a generated column dependency and cannot be dropped or renamed. A generated column '%s' refers to this column in table '%s'.", column, genColumn, table)} +} + +func errWrongArguments(fn string) error { + return &Error{Code: ErrWrongArguments, SQLState: sqlState(ErrWrongArguments), + Message: fmt.Sprintf("Incorrect arguments to %s", fn)} +} diff --git a/tidb/catalog/eventcmds.go b/tidb/catalog/eventcmds.go new file mode 100644 index 00000000..58f42a95 --- /dev/null +++ b/tidb/catalog/eventcmds.go @@ -0,0 +1,196 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +func (c *Catalog) createEvent(stmt *nodes.CreateEventStmt) error { + db, err := c.resolveDatabase("") + if err != nil { + return err + } + + name := stmt.Name + key := toLower(name) + + if _, exists := db.Events[key]; exists { + if !stmt.IfNotExists { + return errDupEvent(name) + } + return nil + } + + // MySQL always sets a definer. Default to `root`@`%` when not specified. + definer := stmt.Definer + if definer == "" { + definer = "`root`@`%`" + } + + // Extract raw schedule text from the AST. + schedule := "" + if stmt.Schedule != nil { + schedule = stmt.Schedule.RawText + } + + event := &Event{ + Name: name, + Database: db, + Definer: definer, + Schedule: schedule, + OnCompletion: stmt.OnCompletion, + Enable: stmt.Enable, + Comment: stmt.Comment, + Body: strings.TrimSpace(stmt.Body), + } + + db.Events[key] = event + return nil +} + +func (c *Catalog) alterEvent(stmt *nodes.AlterEventStmt) error { + db, err := c.resolveDatabase("") + if err != nil { + return err + } + + name := stmt.Name + key := toLower(name) + + event, exists := db.Events[key] + if !exists { + return errNoSuchEvent(db.Name, name) + } + + // Update definer if specified. + if stmt.Definer != "" { + event.Definer = stmt.Definer + } + + // Update schedule if specified. + if stmt.Schedule != nil { + event.Schedule = stmt.Schedule.RawText + } + + // Update ON COMPLETION if specified. + if stmt.OnCompletion != "" { + event.OnCompletion = stmt.OnCompletion + } + + // Update enable/disable if specified. + if stmt.Enable != "" { + event.Enable = stmt.Enable + } + + // Update comment if specified. + if stmt.Comment != "" { + event.Comment = stmt.Comment + } + + // Update body if specified. + if stmt.Body != "" { + event.Body = strings.TrimSpace(stmt.Body) + } + + // Handle RENAME TO. + if stmt.RenameTo != "" { + newKey := toLower(stmt.RenameTo) + delete(db.Events, key) + event.Name = stmt.RenameTo + db.Events[newKey] = event + } + + return nil +} + +func (c *Catalog) dropEvent(stmt *nodes.DropEventStmt) error { + db, err := c.resolveDatabase("") + if err != nil { + if stmt.IfExists { + return nil + } + return err + } + + name := stmt.Name + key := toLower(name) + + if _, exists := db.Events[key]; !exists { + if stmt.IfExists { + return nil + } + return errNoSuchEvent(db.Name, name) + } + + delete(db.Events, key) + return nil +} + +// ShowCreateEvent produces MySQL 8.0-compatible SHOW CREATE EVENT output. +// +// MySQL 8.0 SHOW CREATE EVENT format: +// +// CREATE DEFINER=`root`@`%` EVENT `event_name` ON SCHEDULE schedule ON COMPLETION [NOT] PRESERVE [ENABLE|DISABLE|DISABLE ON SLAVE] [COMMENT 'string'] DO event_body +func (c *Catalog) ShowCreateEvent(database, name string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + event := db.Events[toLower(name)] + if event == nil { + return "" + } + return showCreateEvent(event) +} + +func showCreateEvent(e *Event) string { + var b strings.Builder + + b.WriteString("CREATE") + + // DEFINER + if e.Definer != "" { + b.WriteString(fmt.Sprintf(" DEFINER=%s", e.Definer)) + } + + b.WriteString(fmt.Sprintf(" EVENT `%s`", e.Name)) + + // ON SCHEDULE + if e.Schedule != "" { + b.WriteString(fmt.Sprintf(" ON SCHEDULE %s", e.Schedule)) + } + + // ON COMPLETION + if e.OnCompletion == "NOT PRESERVE" { + b.WriteString(" ON COMPLETION NOT PRESERVE") + } else if e.OnCompletion == "PRESERVE" { + b.WriteString(" ON COMPLETION PRESERVE") + } else { + // MySQL default: NOT PRESERVE, shown explicitly in SHOW CREATE EVENT + b.WriteString(" ON COMPLETION NOT PRESERVE") + } + + // ENABLE / DISABLE + if e.Enable == "DISABLE" { + b.WriteString(" DISABLE") + } else if e.Enable == "DISABLE ON SLAVE" { + b.WriteString(" DISABLE ON SLAVE") + } else { + // MySQL default: ENABLE, shown explicitly in SHOW CREATE EVENT + b.WriteString(" ENABLE") + } + + // COMMENT + if e.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeComment(e.Comment))) + } + + // DO event_body + if e.Body != "" { + b.WriteString(fmt.Sprintf(" DO %s", e.Body)) + } + + return b.String() +} diff --git a/tidb/catalog/exec.go b/tidb/catalog/exec.go new file mode 100644 index 00000000..412f56f6 --- /dev/null +++ b/tidb/catalog/exec.go @@ -0,0 +1,274 @@ +package catalog + +import ( + "fmt" + + nodes "github.com/bytebase/omni/tidb/ast" + mysqlparser "github.com/bytebase/omni/tidb/parser" +) + +type ExecOptions struct { + ContinueOnError bool +} + +type ExecResult struct { + Index int + SQL string + Line int // 1-based start line in the original SQL + Skipped bool + Error error +} + +func (c *Catalog) Exec(sql string, opts *ExecOptions) ([]ExecResult, error) { + list, err := mysqlparser.Parse(sql) + if err != nil { + return nil, err + } + if list == nil || len(list.Items) == 0 { + return nil, nil + } + + lineIndex := buildLineIndex(sql) + + continueOnError := false + if opts != nil { + continueOnError = opts.ContinueOnError + } + + results := make([]ExecResult, 0, len(list.Items)) + for i, item := range list.Items { + locStart := stmtLocStart(item) + result := ExecResult{ + Index: i, + Line: offsetToLine(lineIndex, locStart), + } + + if isDML(item) { + result.Skipped = true + results = append(results, result) + continue + } + + execErr := c.processUtility(item) + result.Error = execErr + results = append(results, result) + + if execErr != nil && !continueOnError { + break + } + } + return results, nil +} + +func LoadSQL(sql string) (*Catalog, error) { + c := New() + results, err := c.Exec(sql, nil) + if err != nil { + return nil, err + } + for _, r := range results { + if r.Error != nil { + return c, r.Error + } + } + return c, nil +} + +// execSet handles SET statements that affect catalog behavior. +// Most SET variables are silently accepted (session-level settings like NAMES, +// CHARACTER SET, sql_mode). Variables that affect DDL behavior (foreign_key_checks) +// update the catalog state. +func (c *Catalog) execSet(stmt *nodes.SetStmt) error { + for _, asgn := range stmt.Assignments { + varName := toLower(asgn.Column.Column) + switch varName { + case "foreign_key_checks": + // Extract the value. + val := nodeToSQLValue(asgn.Value) + switch toLower(val) { + case "0", "off", "false": + c.foreignKeyChecks = false + case "1", "on", "true": + c.foreignKeyChecks = true + } + case "names", "character set": + // Silently accept — these affect character encoding but + // the in-memory catalog doesn't need to change behavior. + default: + // Silently accept all other SET variables (sql_mode, etc.). + } + } + return nil +} + +// nodeToSQLValue extracts a simple string value from an expression node. +func nodeToSQLValue(expr nodes.ExprNode) string { + switch e := expr.(type) { + case *nodes.StringLit: + return e.Value + case *nodes.IntLit: + return fmt.Sprintf("%d", e.Value) + case *nodes.FloatLit: + return e.Value + case *nodes.BoolLit: + if e.Value { + return "1" + } + return "0" + case *nodes.ColumnRef: + return e.Column + default: + return "" + } +} + +func isDML(stmt nodes.Node) bool { + switch stmt.(type) { + case *nodes.SelectStmt, *nodes.InsertStmt, *nodes.UpdateStmt, *nodes.DeleteStmt: + return true + default: + return false + } +} + +func (c *Catalog) processUtility(stmt nodes.Node) error { + switch s := stmt.(type) { + case *nodes.CreateDatabaseStmt: + return c.createDatabase(s) + case *nodes.CreateTableStmt: + return c.createTable(s) + case *nodes.CreateIndexStmt: + return c.createIndex(s) + case *nodes.CreateViewStmt: + return c.createView(s) + case *nodes.AlterViewStmt: + return c.alterView(s) + case *nodes.AlterTableStmt: + return c.alterTable(s) + case *nodes.AlterDatabaseStmt: + return c.alterDatabase(s) + case *nodes.DropTableStmt: + return c.dropTable(s) + case *nodes.DropDatabaseStmt: + return c.dropDatabase(s) + case *nodes.DropIndexStmt: + return c.dropIndex(s) + case *nodes.DropViewStmt: + return c.dropView(s) + case *nodes.RenameTableStmt: + return c.renameTable(s) + case *nodes.TruncateStmt: + return c.truncateTable(s) + case *nodes.UseStmt: + return c.useDatabase(s) + case *nodes.CreateFunctionStmt: + return c.createRoutine(s) + case *nodes.DropRoutineStmt: + return c.dropRoutine(s) + case *nodes.AlterRoutineStmt: + return c.alterRoutine(s) + case *nodes.CreateTriggerStmt: + return c.createTrigger(s) + case *nodes.DropTriggerStmt: + return c.dropTrigger(s) + case *nodes.CreateEventStmt: + return c.createEvent(s) + case *nodes.AlterEventStmt: + return c.alterEvent(s) + case *nodes.DropEventStmt: + return c.dropEvent(s) + case *nodes.SetStmt: + return c.execSet(s) + default: + return nil + } +} + +// buildLineIndex returns the byte offset of each line start. +func buildLineIndex(sql string) []int { + index := []int{0} + for i := 0; i < len(sql); i++ { + if sql[i] == '\n' { + index = append(index, i+1) + } + } + return index +} + +// offsetToLine converts a byte offset to a 1-based line number. +func offsetToLine(lineIndex []int, offset int) int { + lo, hi := 0, len(lineIndex)-1 + for lo < hi { + mid := (lo + hi + 1) / 2 + if lineIndex[mid] <= offset { + lo = mid + } else { + hi = mid - 1 + } + } + return lo + 1 +} + +// stmtLocStart extracts Loc.Start from a statement node. +// Uses a type switch over the statement types handled by processUtility, +// plus common DML types. All have a Loc field set by the parser. +func stmtLocStart(node nodes.Node) int { + switch s := node.(type) { + case *nodes.CreateDatabaseStmt: + return s.Loc.Start + case *nodes.CreateTableStmt: + return s.Loc.Start + case *nodes.CreateIndexStmt: + return s.Loc.Start + case *nodes.CreateViewStmt: + return s.Loc.Start + case *nodes.AlterViewStmt: + return s.Loc.Start + case *nodes.AlterTableStmt: + return s.Loc.Start + case *nodes.AlterDatabaseStmt: + return s.Loc.Start + case *nodes.DropTableStmt: + return s.Loc.Start + case *nodes.DropDatabaseStmt: + return s.Loc.Start + case *nodes.DropIndexStmt: + return s.Loc.Start + case *nodes.DropViewStmt: + return s.Loc.Start + case *nodes.RenameTableStmt: + return s.Loc.Start + case *nodes.TruncateStmt: + return s.Loc.Start + case *nodes.UseStmt: + return s.Loc.Start + case *nodes.CreateFunctionStmt: + return s.Loc.Start + case *nodes.DropRoutineStmt: + return s.Loc.Start + case *nodes.AlterRoutineStmt: + return s.Loc.Start + case *nodes.CreateTriggerStmt: + return s.Loc.Start + case *nodes.DropTriggerStmt: + return s.Loc.Start + case *nodes.CreateEventStmt: + return s.Loc.Start + case *nodes.AlterEventStmt: + return s.Loc.Start + case *nodes.DropEventStmt: + return s.Loc.Start + case *nodes.SetStmt: + return s.Loc.Start + case *nodes.SelectStmt: + return s.Loc.Start + case *nodes.InsertStmt: + return s.Loc.Start + case *nodes.UpdateStmt: + return s.Loc.Start + case *nodes.DeleteStmt: + return s.Loc.Start + default: + return 0 + } +} diff --git a/tidb/catalog/exec_line_test.go b/tidb/catalog/exec_line_test.go new file mode 100644 index 00000000..9d3529c3 --- /dev/null +++ b/tidb/catalog/exec_line_test.go @@ -0,0 +1,49 @@ +package catalog + +import "testing" + +func TestExecResultLine(t *testing.T) { + sql := "CREATE TABLE t1 (id INT);\nCREATE TABLE t2 (id INT);\nCREATE TABLE t3 (id INT);" + + c := New() + // Set up a database first. + c.Exec("CREATE DATABASE test; USE test;", nil) //nolint:errcheck + + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("Exec error: %v", err) + } + + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + wantLines := []int{1, 2, 3} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + } +} + +func TestExecResultLineWithDelimiter(t *testing.T) { + sql := "DELIMITER ;;\nCREATE TABLE t1 (id INT);;\nDELIMITER ;\nCREATE TABLE t2 (id INT);" + + c := New() + c.Exec("CREATE DATABASE test; USE test;", nil) //nolint:errcheck + + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("Exec error: %v", err) + } + + if len(results) != 2 { + t.Fatalf("got %d results, want 2", len(results)) + } + // First CREATE TABLE is on line 2, second is on line 4. + wantLines := []int{2, 4} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + } +} diff --git a/tidb/catalog/exec_test.go b/tidb/catalog/exec_test.go new file mode 100644 index 00000000..13bcbb33 --- /dev/null +++ b/tidb/catalog/exec_test.go @@ -0,0 +1,27 @@ +package catalog + +import "testing" + +func TestExecSkipsDML(t *testing.T) { + c := New() + results, err := c.Exec("SELECT 1; INSERT INTO t VALUES (1)", nil) + if err != nil { + t.Fatal(err) + } + for _, r := range results { + if !r.Skipped { + t.Errorf("expected DML to be skipped") + } + } +} + +func TestExecEmpty(t *testing.T) { + c := New() + results, err := c.Exec("", nil) + if err != nil { + t.Fatal(err) + } + if results != nil { + t.Errorf("expected nil results for empty SQL") + } +} diff --git a/tidb/catalog/function_types.go b/tidb/catalog/function_types.go new file mode 100644 index 00000000..6a740f42 --- /dev/null +++ b/tidb/catalog/function_types.go @@ -0,0 +1,88 @@ +package catalog + +import "strings" + +// functionReturnType returns the inferred return type for a known function. +// Returns nil for unknown functions (Phase 3 covers ~50 common functions). +func functionReturnType(name string, args []AnalyzedExpr) *ResolvedType { + switch strings.ToLower(name) { + // String functions + case "concat", "concat_ws", "lower", "upper", "trim", "ltrim", "rtrim", + "substring", "substr", "left", "right", "replace", "reverse", + "lpad", "rpad", "repeat", "space", "format": + return &ResolvedType{BaseType: BaseTypeVarchar} + + // Numeric functions + case "abs", "ceil", "ceiling", "floor", "round", "truncate", "mod": + return &ResolvedType{BaseType: BaseTypeDecimal} + case "rand": + return &ResolvedType{BaseType: BaseTypeDouble} + + // Aggregate functions + case "count": + return &ResolvedType{BaseType: BaseTypeBigInt} + case "sum", "avg": + return &ResolvedType{BaseType: BaseTypeDecimal} + case "min", "max": + return nil // type depends on argument + case "group_concat": + return &ResolvedType{BaseType: BaseTypeText} + + // Date/time functions + case "now", "current_timestamp", "sysdate", "localtime", "localtimestamp": + return &ResolvedType{BaseType: BaseTypeDateTime} + case "curdate", "current_date": + return &ResolvedType{BaseType: BaseTypeDate} + case "curtime", "current_time": + return &ResolvedType{BaseType: BaseTypeTime} + case "year", "month", "day", "hour", "minute", "second", + "dayofweek", "dayofmonth", "dayofyear", "weekday", + "quarter", "week", "yearweek": + return &ResolvedType{BaseType: BaseTypeInt} + case "date": + return &ResolvedType{BaseType: BaseTypeDate} + case "time": + return &ResolvedType{BaseType: BaseTypeTime} + case "timestamp": + return &ResolvedType{BaseType: BaseTypeTimestamp} + + // Type conversion + case "cast", "convert": + return nil // handled by CastExprQ + + // Control flow + case "if": + return nil // type depends on arguments + case "nullif": + return nil // type depends on first argument + + // JSON functions + case "json_extract", "json_unquote": + return &ResolvedType{BaseType: BaseTypeJSON} + case "json_length", "json_depth", "json_valid": + return &ResolvedType{BaseType: BaseTypeInt} + case "json_type": + return &ResolvedType{BaseType: BaseTypeVarchar} + case "json_array", "json_object", "json_merge_preserve", "json_merge_patch": + return &ResolvedType{BaseType: BaseTypeJSON} + + // Misc + case "coalesce", "ifnull": + return nil // type depends on arguments + case "uuid": + return &ResolvedType{BaseType: BaseTypeVarchar} + case "version": + return &ResolvedType{BaseType: BaseTypeVarchar} + case "database", "schema", "user", "current_user", "session_user", "system_user": + return &ResolvedType{BaseType: BaseTypeVarchar} + case "last_insert_id", "row_count", "found_rows": + return &ResolvedType{BaseType: BaseTypeBigInt} + case "connection_id": + return &ResolvedType{BaseType: BaseTypeBigInt} + case "charset", "collation": + return &ResolvedType{BaseType: BaseTypeVarchar} + case "length", "char_length", "character_length", "bit_length", "octet_length": + return &ResolvedType{BaseType: BaseTypeInt} + } + return nil +} diff --git a/tidb/catalog/index.go b/tidb/catalog/index.go new file mode 100644 index 00000000..fc35572f --- /dev/null +++ b/tidb/catalog/index.go @@ -0,0 +1,22 @@ +package catalog + +type Index struct { + Name string + Table *Table + Columns []*IndexColumn + Unique bool + Primary bool + Fulltext bool + Spatial bool + IndexType string // BTREE, HASH, FULLTEXT, SPATIAL + Comment string + Visible bool + KeyBlockSize int +} + +type IndexColumn struct { + Name string + Expr string + Length int + Descending bool +} diff --git a/tidb/catalog/indexcmds.go b/tidb/catalog/indexcmds.go new file mode 100644 index 00000000..e6e930df --- /dev/null +++ b/tidb/catalog/indexcmds.go @@ -0,0 +1,131 @@ +package catalog + +import nodes "github.com/bytebase/omni/tidb/ast" + +// createIndex handles standalone CREATE INDEX statements. +func (c *Catalog) createIndex(stmt *nodes.CreateIndexStmt) error { + // Resolve database. + dbName := "" + if stmt.Table != nil { + dbName = stmt.Table.Schema + } + db, err := c.resolveDatabase(dbName) + if err != nil { + return err + } + + tableName := stmt.Table.Name + tbl := db.Tables[toLower(tableName)] + if tbl == nil { + return errNoSuchTable(db.Name, tableName) + } + + // Check for duplicate key name. + if indexNameExists(tbl, stmt.IndexName) { + if stmt.IfNotExists { + return nil + } + return errDupKeyName(stmt.IndexName) + } + + // Build index columns from AST columns. + idxCols := make([]*IndexColumn, 0, len(stmt.Columns)) + for _, ic := range stmt.Columns { + col := &IndexColumn{ + Length: ic.Length, + Descending: ic.Desc, + } + if cr, ok := ic.Expr.(*nodes.ColumnRef); ok { + col.Name = cr.Column + } else { + col.Expr = nodeToSQL(ic.Expr) + } + idxCols = append(idxCols, col) + } + + // Determine index type: Fulltext/Spatial override IndexType. + idx := &Index{ + Name: stmt.IndexName, + Table: tbl, + Columns: idxCols, + Visible: true, + } + + switch { + case stmt.Fulltext: + idx.Fulltext = true + idx.IndexType = "FULLTEXT" + case stmt.Spatial: + idx.Spatial = true + idx.IndexType = "SPATIAL" + default: + idx.IndexType = stmt.IndexType + idx.Unique = stmt.Unique + } + + applyIndexOptions(idx, stmt.Options) + + tbl.Indexes = append(tbl.Indexes, idx) + + // If unique, also add a UniqueKey constraint. + if stmt.Unique { + cols := make([]string, 0, len(idxCols)) + for _, ic := range idxCols { + if ic.Name != "" { + cols = append(cols, ic.Name) + } + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: stmt.IndexName, + Type: ConUniqueKey, + Table: tbl, + Columns: cols, + IndexName: stmt.IndexName, + }) + } + + return nil +} + +// dropIndex handles standalone DROP INDEX statements. +func (c *Catalog) dropIndex(stmt *nodes.DropIndexStmt) error { + // Resolve database. + dbName := "" + if stmt.Table != nil { + dbName = stmt.Table.Schema + } + db, err := c.resolveDatabase(dbName) + if err != nil { + return err + } + + tableName := stmt.Table.Name + tbl := db.Tables[toLower(tableName)] + if tbl == nil { + return errNoSuchTable(db.Name, tableName) + } + + // Find and remove index. + key := toLower(stmt.Name) + found := false + for i, idx := range tbl.Indexes { + if toLower(idx.Name) == key { + tbl.Indexes = append(tbl.Indexes[:i], tbl.Indexes[i+1:]...) + found = true + break + } + } + if !found { + return errCantDropKey(stmt.Name) + } + + // Also remove any constraint that references this index. + for i, con := range tbl.Constraints { + if toLower(con.IndexName) == key || toLower(con.Name) == key { + tbl.Constraints = append(tbl.Constraints[:i], tbl.Constraints[i+1:]...) + break + } + } + + return nil +} diff --git a/tidb/catalog/indexcmds_test.go b/tidb/catalog/indexcmds_test.go new file mode 100644 index 00000000..557e2307 --- /dev/null +++ b/tidb/catalog/indexcmds_test.go @@ -0,0 +1,141 @@ +package catalog + +import "testing" + +func setupIndexTestTable(t *testing.T) *Catalog { + t.Helper() + c := New() + mustExec(t, c, "CREATE DATABASE test") + c.SetCurrentDatabase("test") + mustExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), body TEXT)") + return c +} + +func TestCreateIndex(t *testing.T) { + c := setupIndexTestTable(t) + mustExec(t, c, "CREATE INDEX idx_name ON t1 (name)") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Indexes) != 1 { + t.Fatalf("expected 1 index, got %d", len(tbl.Indexes)) + } + + idx := tbl.Indexes[0] + if idx.Name != "idx_name" { + t.Errorf("expected index name 'idx_name', got %q", idx.Name) + } + if idx.Unique { + t.Error("expected non-unique index") + } + if idx.IndexType != "" { + t.Errorf("expected empty IndexType (implicit BTREE), got %q", idx.IndexType) + } + if len(idx.Columns) != 1 || idx.Columns[0].Name != "name" { + t.Errorf("expected index on column 'name', got %v", idx.Columns) + } + + // No constraint should be created for a plain index. + if len(tbl.Constraints) != 0 { + t.Errorf("expected 0 constraints, got %d", len(tbl.Constraints)) + } +} + +func TestCreateUniqueIndex(t *testing.T) { + c := setupIndexTestTable(t) + mustExec(t, c, "CREATE UNIQUE INDEX idx_id ON t1 (id)") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Indexes) != 1 { + t.Fatalf("expected 1 index, got %d", len(tbl.Indexes)) + } + + idx := tbl.Indexes[0] + if idx.Name != "idx_id" { + t.Errorf("expected index name 'idx_id', got %q", idx.Name) + } + if !idx.Unique { + t.Error("expected unique index") + } + + // Unique index should also create a constraint. + if len(tbl.Constraints) != 1 { + t.Fatalf("expected 1 constraint, got %d", len(tbl.Constraints)) + } + con := tbl.Constraints[0] + if con.Type != ConUniqueKey { + t.Errorf("expected ConUniqueKey, got %d", con.Type) + } + if con.Name != "idx_id" { + t.Errorf("expected constraint name 'idx_id', got %q", con.Name) + } + if con.IndexName != "idx_id" { + t.Errorf("expected constraint IndexName 'idx_id', got %q", con.IndexName) + } +} + +func TestCreateFulltextIndex(t *testing.T) { + c := setupIndexTestTable(t) + mustExec(t, c, "CREATE FULLTEXT INDEX idx_body ON t1 (body)") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Indexes) != 1 { + t.Fatalf("expected 1 index, got %d", len(tbl.Indexes)) + } + + idx := tbl.Indexes[0] + if idx.Name != "idx_body" { + t.Errorf("expected index name 'idx_body', got %q", idx.Name) + } + if !idx.Fulltext { + t.Error("expected fulltext index") + } + if idx.IndexType != "FULLTEXT" { + t.Errorf("expected IndexType 'FULLTEXT', got %q", idx.IndexType) + } + if idx.Unique { + t.Error("fulltext index should not be unique") + } +} + +func TestDropIndex(t *testing.T) { + c := setupIndexTestTable(t) + mustExec(t, c, "CREATE INDEX idx_name ON t1 (name)") + mustExec(t, c, "DROP INDEX idx_name ON t1") + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Indexes) != 0 { + t.Fatalf("expected 0 indexes after drop, got %d", len(tbl.Indexes)) + } +} + +func TestDropIndexNotFound(t *testing.T) { + c := setupIndexTestTable(t) + results, _ := c.Exec("DROP INDEX nonexistent ON t1", &ExecOptions{ContinueOnError: true}) + if len(results) == 0 || results[0].Error == nil { + t.Fatal("expected error for dropping nonexistent index") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrCantDropKey { + t.Errorf("expected error code %d, got %d", ErrCantDropKey, catErr.Code) + } +} + +func TestCreateIndexDupKeyName(t *testing.T) { + c := setupIndexTestTable(t) + mustExec(t, c, "CREATE INDEX idx_name ON t1 (name)") + + results, _ := c.Exec("CREATE INDEX idx_name ON t1 (id)", &ExecOptions{ContinueOnError: true}) + if len(results) == 0 || results[0].Error == nil { + t.Fatal("expected duplicate key name error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrDupKeyName { + t.Errorf("expected error code %d, got %d", ErrDupKeyName, catErr.Code) + } +} diff --git a/tidb/catalog/query.go b/tidb/catalog/query.go new file mode 100644 index 00000000..d1dc89e2 --- /dev/null +++ b/tidb/catalog/query.go @@ -0,0 +1,1310 @@ +// Package catalog — query.go defines the analyzed-query IR for MySQL. +// +// This file is the Phase 0 deliverable of the MySQL semantic layer effort. +// See: docs/plans/2026-04-09-mysql-semantic-layer.md +// +// # Purpose +// +// These types are the post-analysis (semantically resolved) form of a SELECT +// statement, parallel to PostgreSQL's `Query` / `RangeTblEntry` / `Var` / +// `TargetEntry` family. They are produced by `AnalyzeSelectStmt` (Phase 1) +// and consumed by: +// +// - bytebase MySQL query span (column-level lineage), via an external adapter +// - mysql/deparse Phase 4: deparse from IR back to canonical SQL +// - mysql/diff and mysql/sdl Phase 5: structural schema diffing +// +// The IR follows PG's shape so algorithms (lineage walking, dependency +// extraction, deparse) can be ported across engines with minimal change. +// MySQL-specific divergences are documented inline next to each affected type. +// +// # Status +// +// Phase 0 = TYPE DEFINITIONS ONLY. No analyzer, no methods beyond interface +// tags and trivial accessors. Code in this file is not consumed by any +// production path yet; the purpose is to be reviewed (user + cc + codex) and +// locked down before any consumer is built. +// +// This is **Revision 2** (2026-04-10), incorporating sub-agent review feedback. +// See `docs/plans/2026-04-09-mysql-semantic-layer.md` § 9 for the change log. +// +// # Naming convention: the `Q` suffix +// +// All Query-IR struct types in this file end in `Q` (for "Query IR"), with +// two exceptions: +// +// 1. `Query` itself does not get a Q suffix — it is the IR's namespace +// anchor and `QueryQ` would be tautological. +// 2. The `AnalyzedExpr` interface does not get a Q suffix — interfaces are +// category labels, not IR nodes. +// +// Enum types do NOT get a Q suffix (they are values, not structures), with +// one mechanical exception: `JoinTypeQ` collides with `mysql/ast.JoinType` +// and is forced to take the suffix. +// +// The rule mirrors PG's convention but applies it uniformly within the +// MySQL IR, rather than only to types that collide with the parser AST. +// Rationale: in `mysql/catalog/analyze.go` (Phase 1) and below, Q-suffixed +// names provide an instant visual cue that this is the analyzed namespace, +// not the catalog state machine (`Table`, `Column`, etc.) and not the parser +// AST (`mysql/ast.SelectStmt`, etc.). +// +// # ResolvedType (vs the plan doc's `ColumnType`) +// +// The plan doc § 4.1 named the type `ColumnType`. That name collides with the +// existing field `Column.ColumnType` (string) on `mysql/catalog/table.go:63`, +// which holds the full type text (e.g. `"varchar(100)"`). We use `ResolvedType` +// here to avoid shadowing and to better describe the role: the type as +// resolved by the analyzer at output time. +// +// # PG cross-reference key +// +// Each type carries a `// pg:` reference to the PG header / file it mirrors. +// When in doubt about semantics, consult `pg/catalog/query.go` for the +// translated equivalent. +package catalog + +// ============================================================================= +// Section 1 — Top-level Query +// ============================================================================= + +// Query is the analyzed form of a SELECT (and, in later phases, of an +// INSERT / UPDATE / DELETE). +// +// pg: src/include/nodes/parsenodes.h — Query +// +// Field ordering parallels pg/catalog/query.go::Query for ease of review; +// MySQL-specific fields (SQLMode, LockClause) are appended last. +type Query struct { + // CommandType discriminates SELECT / INSERT / UPDATE / DELETE. + // Phase 1 only emits CmdSelect; the field exists so the IR shape is stable + // when DML analysis is added in Phase 8+. + CommandType CmdType + + // TargetList is the analyzed SELECT list, in output order. ResNo on each + // entry is 1-based. ResJunk entries (helper columns synthesized for + // ORDER BY references) appear after non-junk entries. + TargetList []*TargetEntryQ + + // RangeTable is the flattened set of input relations referenced by this + // query level: base tables, FROM-clause subqueries, CTE references, and + // JOIN result rows. Indices into this slice are referenced by VarExprQ + // (RangeIdx) and JoinTreeQ (RTIndex on RangeTableRefQ / JoinExprNodeQ). + // + // Order is the discovery order during FROM-clause walking. Once analyzer + // is done, the slice is immutable. + RangeTable []*RangeTableEntryQ + + // JoinTree describes the structure of the FROM clause and the WHERE qual. + // FromList contains top-level FROM items; nested JOINs become JoinExprNodeQ. + // JoinTree.Quals is the WHERE expression. + // + // For set-operation queries (SetOp != SetOpNone), JoinTree is set to an + // empty FromList with nil Quals. The "real" inputs are in LArg/RArg. + JoinTree *JoinTreeQ + + // GroupClause is the GROUP BY list. Each entry references a TargetEntryQ + // by index (1-based) into TargetList. Plain GROUP BY only — Phase 1 does + // not model ROLLUP / CUBE / GROUPING SETS (deferred). + GroupClause []*SortGroupClauseQ + + // HavingQual is the HAVING expression, nil if absent. + HavingQual AnalyzedExpr + + // SortClause is the ORDER BY list. Same indexing convention as GroupClause. + SortClause []*SortGroupClauseQ + + // LimitCount and LimitOffset are the LIMIT / OFFSET expressions. + // MySQL syntax: `LIMIT [offset,] count` or `LIMIT count OFFSET offset`. + // Both forms are normalized into these two fields. + LimitCount AnalyzedExpr + LimitOffset AnalyzedExpr + + // Distinct is true for SELECT DISTINCT. + // Note: PG has a DistinctOn slice for `SELECT DISTINCT ON (...)`; + // MySQL does not have that construct, so it is not modeled here. + Distinct bool + + // Set operations: when SetOp != SetOpNone, LArg and RArg hold the analyzed + // arms and TargetList describes the *result* shape of the set op (column + // names, output types). RangeTable and JoinTree are empty. + // + // AllSetOp distinguishes UNION from UNION ALL, etc. It applies to *this* + // level's SetOp only — nested set ops carry their own AllSetOp on + // LArg/RArg recursively. + SetOp SetOpType + AllSetOp bool + LArg *Query + RArg *Query + + // CTEList holds WITH-clause CTEs declared at this query level. Each CTE + // is referenced from the RangeTable by an RTECTE entry whose CTEIndex + // indexes into this slice. + CTEList []*CommonTableExprQ + IsRecursive bool // WITH RECURSIVE + + // WindowClause holds named window declarations from the WINDOW clause: + // SELECT ... OVER w FROM t WINDOW w AS (PARTITION BY ...) + // References to named windows from FuncCallExprQ.Over carry only the Name + // field; the analyzer resolves the reference against this slice. + // Inline window definitions (`OVER (PARTITION BY ...)`) are stored + // directly on FuncCallExprQ.Over and not added here. + WindowClause []*WindowDefQ + + // HasAggs is true if the analyzer found at least one aggregate call in + // TargetList / HavingQual. Used by deparse and by GROUP BY validation. + // (Computed via FuncCallExprQ.IsAggregate, not by string-matching names.) + HasAggs bool + + // LockClause holds FOR UPDATE / FOR SHARE / LOCK IN SHARE MODE if present. + // MySQL-specific: nil for SELECTs without locking clauses. + LockClause *LockingClauseQ + + // SQLMode captures the session sql_mode at analyze time. The same SELECT + // can have different semantics under different sql_mode values + // (ANSI_QUOTES affects identifier quoting, ONLY_FULL_GROUP_BY affects + // validation, PIPES_AS_CONCAT affects `||`, etc.). Required for + // round-trip deparse fidelity. + SQLMode SQLMode +} + +// CmdType discriminates the kind of analyzed statement. +// +// pg: src/include/nodes/nodes.h — CmdType +type CmdType int + +const ( + CmdSelect CmdType = iota + CmdInsert + CmdUpdate + CmdDelete +) + +// ============================================================================= +// Section 2 — Target list +// ============================================================================= + +// TargetEntryQ represents one column in the SELECT list. +// +// pg: src/include/nodes/primnodes.h — TargetEntry +type TargetEntryQ struct { + // Expr is the analyzed select-list expression. For `SELECT t.a`, this is + // a *VarExprQ; for `SELECT a + 1`, this is an *OpExprQ; etc. + Expr AnalyzedExpr + + // ResNo is the 1-based output position. Helper (junk) columns are + // numbered after non-junk columns. Kept as a separate field (rather than + // implied by slice index) because the analyzer may reorder TargetList + // during set-operation column unification or DISTINCT processing. + // + // Note: PG's TargetEntry.resno is int16; we use plain int for simplicity. + ResNo int + + // ResName is the output column name as it would appear in a result set + // header. Either the user-provided alias (`SELECT a AS x`), the original + // column name (`SELECT a`), or the synthesized expression text + // (`SELECT a+1` → `a+1`). + ResName string + + // ResJunk marks helper columns synthesized for ORDER BY references that + // don't appear in the user-visible select list. These are stripped by + // deparse for the SELECT list output but referenced by SortClause. + ResJunk bool + + // Provenance — populated when Expr is a single VarExprQ that resolves to + // a physical table column (or a chain of VarExprQs through views/CTEs + // that bottoms out at one). Used by: + // - SDL view column derivation (Phase 2) + // - lineage shortcut path (Phase 2 in-test walker) + // + // Empty when the column is computed (`a+1`, `COUNT(*)`, `CASE ...`). + // Multi-source provenance (e.g. COALESCE over two columns) is the + // lineage walker's responsibility, not this shortcut field's. + ResOrigDB string + ResOrigTable string + ResOrigCol string +} + +// SortGroupClauseQ references a TargetEntryQ from GROUP BY / ORDER BY / DISTINCT. +// +// pg: src/include/nodes/parsenodes.h — SortGroupClause +// +// MySQL note: PG's SortGroupClause carries equality- and sort-operator OIDs; +// we don't (no operator OID space). Collation is captured on the underlying +// VarExprQ / expression rather than here. +type SortGroupClauseQ struct { + // TargetIdx is the 1-based index into Query.TargetList. + // Note: PG's TLESortGroupRef is 0-based; we standardize on 1-based here + // to match TargetEntryQ.ResNo and avoid arithmetic shifts. + TargetIdx int + + // Descending controls sort order. + Descending bool + + // NullsFirst is preserved for symmetry with PG; MySQL has no + // `NULLS FIRST` / `NULLS LAST` syntax. The analyzer always sets it to + // MySQL's default behavior: + // - ASC → NULLs sort first + // - DESC → NULLs sort last + NullsFirst bool +} + +// ============================================================================= +// Section 3 — Range table +// ============================================================================= + +// RTEKind discriminates the variant of a RangeTableEntryQ. +// +// pg: src/include/nodes/parsenodes.h — RTEKind +type RTEKind int + +const ( + // RTERelation is a base table or view referenced from FROM. + // + // For views: the analyzer sets IsView=true and ViewAlgorithm. The view's + // underlying body is NOT substituted into Subquery here. Consumers that + // want lineage transparency through MERGE views call the (Phase 2) + // helper `(*Query).ExpandMergeViews()` to perform substitution at consume + // time. This keeps deparse-of-original-text correct (SHOW CREATE VIEW + // must reproduce the user's text, not the inlined body) and lets each + // consumer decide whether to expand. + RTERelation RTEKind = iota + + // RTESubquery is a subquery in FROM (`FROM (SELECT ...) AS x`). + // Subquery holds the analyzed inner Query. DBName and TableName are + // empty; the user-visible name is in Alias / ERef. + RTESubquery + + // RTEJoin is the synthetic RTE created for the result of a JOIN + // expression. ColNames holds the *coalesced* column list (NATURAL JOIN + // and USING merge same-named columns); the underlying tables remain in + // the RangeTable as separate RTEs referenced by JoinExprNodeQ.Left/Right. + RTEJoin + + // RTECTE is a reference to a CTE declared in WITH. CTEIndex indexes into + // the enclosing Query.CTEList; the CTE body itself is + // `Query.CTEList[i].Query`. + RTECTE + + // RTEFunction is a function-in-FROM clause. In MySQL the only such + // construct is `JSON_TABLE(...)` (8.0.19+). FuncExprs holds the analyzed + // function call expression(s). Phase 1 analyzer rejects this kind with + // an "unsupported" error; full implementation is deferred to Phase 8+. + // The kind exists in the IR now so callers can dispatch on it. + RTEFunction +) + +// RangeTableEntryQ represents one entry in Query.RangeTable. +// +// pg: src/include/nodes/parsenodes.h — RangeTblEntry +// +// # Identification +// +// MySQL has no schema namespace, so we identify base relations by the +// (DBName, TableName) pair instead of by an OID. See plan doc decision D1. +// +// # Per-Kind Field Applicability +// +// Field RTERelation RTESubquery RTEJoin RTECTE RTEFunction +// ------------- ----------- ----------- ------- ------ ----------- +// DBName ✓ - - - - +// TableName ✓ - - - - +// Alias ✓ ✓ - ✓ ✓ +// ERef ✓ ✓ ✓ ✓ ✓ +// ColNames ✓ ✓ ✓ ✓ ✓ +// ColTypes ✓ ✓ ✓ ✓ ✓ +// ColCollations ✓ ✓ ✓ ✓ ✓ +// Subquery - ✓ - ✓ (mirror) - +// JoinType - - ✓ - - +// JoinUsing - - ✓ - - +// CTEIndex - - - ✓ - +// CTEName - - - ✓ - +// Lateral - ✓ - - ✓ +// IsView ✓ - - - - +// ViewAlgorithm ✓ - - - - +// FuncExprs - - - - ✓ +// +// Fields not applicable to a kind are zero-valued. The analyzer must enforce +// this; consumers may assume it. +type RangeTableEntryQ struct { + Kind RTEKind + + // DBName / TableName — populated for RTERelation only (the underlying + // base table or view). For other kinds these are empty; the user-visible + // name is in Alias / ERef. + DBName string + TableName string + + // Alias is the user-provided alias (`FROM t AS x` → "x"). Empty if none. + Alias string + + // ERef is the *effective reference name* used to qualify columns from + // this RTE. It is Alias if non-empty, else TableName, else a synthesized + // name for unaliased subqueries. Always populated. + ERef string + + // Column catalog — present for ALL kinds. The analyzer populates it + // differently per kind: + // - RTERelation: from the catalog (table/view definition) + // - RTESubquery: from Subquery.TargetList (non-junk columns) + // - RTECTE: from the CTE body's TargetList + // - RTEJoin: from the join's coalesced column list + // - RTEFunction: from the function's declared output schema (JSON_TABLE) + // + // Contract: the three slices are parallel and equal-length to ColNames. + // In Phase 1, ColTypes and ColCollations entries are nil/empty — + // populated in Phase 3. Consumers must tolerate nil entries until then. + ColNames []string + ColTypes []*ResolvedType // parallel to ColNames; entries may be nil in Phase 1 + ColCollations []string // parallel to ColNames; entries may be empty in Phase 1 + + // Subquery is populated for RTESubquery, and mirrors the CTE body for + // RTECTE (for convenience when walking). + // + // IMPORTANT: For RTERelation referring to a view, Subquery is NOT + // populated by the analyzer. View body expansion happens at consume time + // via (*Query).ExpandMergeViews() — see RTERelation doc. + Subquery *Query + + // JoinType / JoinUsing — populated for RTEJoin only. Mirrors the + // corresponding JoinExprNodeQ for convenience during lineage walking. + // + // Invariant: JoinType here MUST equal the corresponding + // JoinExprNodeQ.JoinType, and JoinUsing MUST equal JoinExprNodeQ.UsingClause. + // The analyzer maintains this; consumers may assume it. + JoinType JoinTypeQ + JoinUsing []string // USING column names, in source order + + // CTEIndex / CTEName — populated for RTECTE. + // CTEIndex is the index into Query.CTEList of the referenced CTE. + CTEIndex int + CTEName string + + // Lateral marks an RTESubquery (or RTEFunction) as LATERAL, allowing it + // to reference columns from earlier FROM items in the same FROM clause. + // MySQL 8.0.14+. + Lateral bool + + // View metadata — populated when an RTERelation references a view. + // IsView=false, ViewAlgorithm=ViewAlgNone → not a view + // IsView=true, ViewAlgorithm=ViewAlgMerge → MERGE view + // IsView=true, ViewAlgorithm=ViewAlgTemptable → TEMPTABLE view + // IsView=true, ViewAlgorithm=ViewAlgUndefined → UNDEFINED (MySQL chooses) + // Consumers MUST check IsView before interpreting ViewAlgorithm — the + // zero value (ViewAlgNone) means "not a view", not "undefined view". + IsView bool + ViewAlgorithm ViewAlgorithm + + // FuncExprs — populated for RTEFunction only (JSON_TABLE in MySQL 8.0.19+). + // Holds the analyzed function call expression(s). Phase 1 analyzer + // rejects RTEFunction; this field exists for forward compatibility. + FuncExprs []AnalyzedExpr +} + +// ViewAlgorithm mirrors the MySQL `ALGORITHM` clause on CREATE VIEW. +// +// MySQL doc: https://dev.mysql.com/doc/refman/8.0/en/view-algorithms.html +type ViewAlgorithm int + +const ( + // ViewAlgNone is the zero value used when the RTE is not a view at all. + // IsView=false implies ViewAlgorithm == ViewAlgNone. + ViewAlgNone ViewAlgorithm = iota + + // ViewAlgUndefined — MySQL chooses MERGE if possible, else TEMPTABLE. + // User wrote no ALGORITHM clause, or wrote ALGORITHM=UNDEFINED explicitly. + ViewAlgUndefined + + // ViewAlgMerge — view body is rewritten into the referencing query. + // Lineage-transparent: callers of ExpandMergeViews() will see through it. + ViewAlgMerge + + // ViewAlgTemptable — view is materialized into a temporary table. + // Lineage-opaque: ExpandMergeViews() does not expand TEMPTABLE views. + ViewAlgTemptable +) + +// ============================================================================= +// Section 4 — Join tree +// ============================================================================= + +// JoinTreeQ describes the FROM clause and WHERE clause structure. +// +// pg: src/include/nodes/primnodes.h — FromExpr +type JoinTreeQ struct { + // FromList holds the top-level FROM items, each one a JoinNode. + // `FROM a, b, c` produces three RangeTableRefQ entries; nested JOINs + // produce JoinExprNodeQ entries. + // + // Empty for set-operation queries (Query.SetOp != SetOpNone). + FromList []JoinNode + + // Quals is the analyzed WHERE expression, nil if no WHERE clause. + // Always nil for set-operation queries. + Quals AnalyzedExpr +} + +// JoinNode is the interface for items in a JoinTreeQ's FromList and for the +// children of a JoinExprNodeQ. Implementations: *RangeTableRefQ, *JoinExprNodeQ. +type JoinNode interface { + joinNodeTag() +} + +// RangeTableRefQ is a leaf in the join tree — a reference to a single RTE. +// +// pg: src/include/nodes/primnodes.h — RangeTblRef +type RangeTableRefQ struct { + // RTIndex is the 0-based index into the enclosing Query.RangeTable. + RTIndex int +} + +func (*RangeTableRefQ) joinNodeTag() {} + +// JoinExprNodeQ represents a JOIN expression in the FROM clause. +// +// pg: src/include/nodes/primnodes.h — JoinExpr +// +// The join itself produces an RTEJoin entry in the enclosing Query.RangeTable; +// RTIndex is the index of that synthetic RTE. The join's input rows come from +// Left and Right (each a RangeTableRefQ or another JoinExprNodeQ). +// +// Invariant: this node's JoinType MUST equal the corresponding +// `Query.RangeTable[RTIndex].JoinType`, and UsingClause MUST equal that RTE's +// JoinUsing. The analyzer maintains both copies in sync. +type JoinExprNodeQ struct { + JoinType JoinTypeQ + Left JoinNode + Right JoinNode + Quals AnalyzedExpr // ON expression, nil for CROSS / NATURAL / USING-only + UsingClause []string // USING (col, col, ...) — empty if not used + Natural bool // NATURAL JOIN flag + RTIndex int // index of this join's synthetic RTE in RangeTable +} + +func (*JoinExprNodeQ) joinNodeTag() {} + +// JoinTypeQ discriminates the kind of JOIN. +// +// MySQL note: includes STRAIGHT_JOIN (MySQL-only optimizer hint that forces +// left-to-right join order). The hint affects optimizer behavior but not +// lineage; deparse must preserve it. +// +// Naming: this enum is `JoinTypeQ` (with Q) because `mysql/ast.JoinType` +// already exists. The Q suffix is mechanically forced here, not stylistic. +type JoinTypeQ int + +const ( + JoinInner JoinTypeQ = iota + JoinLeft + JoinRight + JoinCross + JoinStraight // STRAIGHT_JOIN — MySQL-only + // Note: MySQL does NOT support FULL OUTER JOIN. +) + +// ============================================================================= +// Section 5 — CTE +// ============================================================================= + +// CommonTableExprQ is an analyzed CTE declared in a WITH clause. +// +// pg: src/include/nodes/parsenodes.h — CommonTableExpr +type CommonTableExprQ struct { + // Name is the CTE name (`WITH x AS (...)` → "x"). + Name string + + // ColumnNames is the optional explicit column rename list + // (`WITH x(a, b) AS (...)`). Empty if not specified. + ColumnNames []string + + // Query is the analyzed body. For recursive CTEs, the analyzer handles + // the self-reference by giving the CTE its own RTECTE entry visible to + // its own body during analysis. + Query *Query + + // Recursive marks WITH RECURSIVE CTEs. + Recursive bool +} + +// ============================================================================= +// Section 6 — Set operations +// ============================================================================= + +// SetOpType discriminates the kind of set operation in a Query. +// +// MySQL gained INTERSECT and EXCEPT in 8.0.31; UNION has been supported +// since the beginning. The All distinction (UNION vs UNION ALL) is carried +// on Query.AllSetOp rather than as separate enum values. +type SetOpType int + +const ( + SetOpNone SetOpType = iota + SetOpUnion + SetOpIntersect + SetOpExcept +) + +// ============================================================================= +// Section 7 — Locking clause +// ============================================================================= + +// LockingClauseQ is the analyzed FOR UPDATE / FOR SHARE / LOCK IN SHARE MODE. +// +// pg: src/include/nodes/parsenodes.h — LockingClause (PG's syntax differs) +type LockingClauseQ struct { + Strength LockStrength + + // Tables is the OF list, if specified. Empty means "all tables in FROM". + // + // Constraint: Tables MUST be empty when Strength == LockInShareMode + // (the legacy syntax does not support OF). Analyzer enforces. + Tables []string + + // WaitPolicy controls behavior on lock contention. + // + // Constraint: WaitPolicy MUST be LockWaitDefault when + // Strength == LockInShareMode (the legacy syntax does not support + // NOWAIT or SKIP LOCKED). Analyzer enforces. + WaitPolicy LockWaitPolicy +} + +// LockStrength enumerates lock modes. +// +// IMPORTANT: LockForShare and LockInShareMode are NOT synonyms despite +// targeting the same lock semantics. They differ in their syntactic +// envelope: +// +// - FOR SHARE (8.0+): supports OF tbl_name, NOWAIT, SKIP LOCKED +// - LOCK IN SHARE MODE (legacy): does NOT support any of those modifiers +// +// Both enum values exist so deparse can faithfully reproduce the user's +// original syntax. +type LockStrength int + +const ( + LockNone LockStrength = iota + LockForUpdate // FOR UPDATE + LockForShare // FOR SHARE (8.0+, supports OF/NOWAIT/SKIP LOCKED) + LockInShareMode // LOCK IN SHARE MODE (legacy, no modifier support) +) + +// LockWaitPolicy enumerates the wait behavior on lock contention. +type LockWaitPolicy int + +const ( + LockWaitDefault LockWaitPolicy = iota // block until lock acquired (no NOWAIT/SKIP LOCKED) + LockWaitNowait // NOWAIT — error on contention + LockWaitSkipLocked // SKIP LOCKED — silently skip locked rows +) + +// ============================================================================= +// Section 8 — AnalyzedExpr interface and implementations +// ============================================================================= + +// AnalyzedExpr is the interface implemented by all post-analysis expression +// nodes. It exposes the resolved type and collation of the expression's +// result, plus an unexported tag method to close the interface. +// +// pg: src/include/nodes/primnodes.h — Expr (base node) +// +// Naming: this interface does NOT carry a Q suffix because interfaces are +// category labels rather than IR nodes themselves. PG also uses +// `AnalyzedExpr`. +// +// In Phase 1, exprType() returns nil for most node kinds because the +// analyzer does not yet do type inference. Phase 3 fills these in. +// Consumers must tolerate nil return values until Phase 3. +// +// Two exceptions where exprType() is non-nil even in Phase 1: +// - BoolExprQ.exprType() returns BoolType (the package-level singleton) +// - NullTestExprQ.exprType() returns BoolType +type AnalyzedExpr interface { + // exprType returns the resolved result type. May be nil in Phase 1/2 for + // most node kinds; always BoolType for boolean-result nodes. + exprType() *ResolvedType + + // exprCollation returns the resolved result collation name (e.g. + // "utf8mb4_0900_ai_ci"). Empty string in Phase 1/2. + exprCollation() string + + // analyzedExprTag closes the interface to this package's types. + analyzedExprTag() +} + +// BoolType is the package-level singleton for boolean-result expressions. +// MySQL has no native BOOLEAN type; TINYINT(1) is the conventional encoding. +// Defined here so BoolExprQ and NullTestExprQ can return a non-nil exprType +// even in Phase 1, eliminating a class of nil-checks in consumers. +// +// Note that this uses BaseTypeTinyIntBool, not BaseTypeTinyInt — see +// MySQLBaseType § 10 for the rationale. +var BoolType = &ResolvedType{ + BaseType: BaseTypeTinyIntBool, +} + +// ---------------------------------------------------------------------------- +// VarExprQ — resolved column reference. The single most important node kind. +// ---------------------------------------------------------------------------- + +// VarExprQ is a resolved column reference: an `(RangeIdx, AttNum)` coordinate +// into the enclosing Query's RangeTable. +// +// pg: src/include/nodes/primnodes.h — Var +// +// Lineage extraction (Phase 2 in-test walker, bytebase adapter) traverses +// the TargetList collecting VarExprQs and resolves each one through the +// RangeTable: +// +// - RTERelation: terminal — emit (DBName, TableName, ColNames[AttNum-1]) +// - RTESubquery: recurse into Subquery.TargetList[AttNum-1].Expr +// - RTECTE: recurse into the CTE body's TargetList[AttNum-1].Expr +// - RTEJoin: dispatch to the underlying Left/Right RTE based on which +// side the column was coalesced from +// +// For RTERelation that is a view (IsView=true), the walker may call +// (*Query).ExpandMergeViews() first to make MERGE views transparent. +type VarExprQ struct { + // RangeIdx is the 0-based index into Query.RangeTable. + RangeIdx int + + // AttNum is the 1-based column number within + // `Query.RangeTable[RangeIdx].ColNames`. Matches PG convention so lineage + // algorithms port without arithmetic shifts. + AttNum int + + // LevelsUp distinguishes correlated subquery references: + // 0 = column from this query level + // 1 = column from immediate enclosing query + // ... + LevelsUp int + + // Type and Collation — populated in Phase 3+. Phase 1 leaves them nil/empty. + // The field is named Type rather than ResolvedType for brevity at use sites + // (`v.Type` reads naturally; `v.ResolvedType` is verbose). + Type *ResolvedType + Collation string +} + +func (*VarExprQ) analyzedExprTag() {} +func (v *VarExprQ) exprType() *ResolvedType { return v.Type } +func (v *VarExprQ) exprCollation() string { return v.Collation } + +// ---------------------------------------------------------------------------- +// ConstExprQ — typed constant. +// ---------------------------------------------------------------------------- + +// ConstExprQ is a literal value with a resolved type. +// +// pg: src/include/nodes/primnodes.h — Const +type ConstExprQ struct { + Type *ResolvedType + Collation string + IsNull bool + // Value is the textual form as the lexer produced it. Quoting and + // numeric format are preserved so deparse can reproduce the original. + Value string +} + +func (*ConstExprQ) analyzedExprTag() {} +func (c *ConstExprQ) exprType() *ResolvedType { return c.Type } +func (c *ConstExprQ) exprCollation() string { return c.Collation } + +// ---------------------------------------------------------------------------- +// FuncCallExprQ — scalar / aggregate / window function call. +// ---------------------------------------------------------------------------- + +// FuncCallExprQ is a resolved function invocation. Aggregates and window +// functions are NOT separated into their own node types (PG has Aggref and +// WindowFunc); instead, IsAggregate distinguishes aggregates and Over +// distinguishes window functions. This matches the MySQL parser's choice and +// keeps the IR smaller. +// +// pg: src/include/nodes/primnodes.h — FuncExpr / Aggref / WindowFunc (merged) +type FuncCallExprQ struct { + // Name is the canonical lowercase function name (e.g. "concat", "count"). + Name string + + // Args are the analyzed arguments in source order. + Args []AnalyzedExpr + + // IsAggregate marks this call as an aggregate function (COUNT, SUM, AVG, + // MIN, MAX, GROUP_CONCAT, etc.). Set by the analyzer based on a function + // catalog, NOT by string-matching `Name` at consumer sites. Lineage + // walkers and HasAggs validation use this field directly. + IsAggregate bool + + // Distinct marks aggregates with DISTINCT, e.g. COUNT(DISTINCT x). + // Always false for non-aggregate functions. + Distinct bool + + // Over is the analyzed OVER clause for window functions; nil for plain + // scalar / aggregate calls. The analyzer sets Over for any call followed + // by an OVER clause in the source. Window functions are detected by + // `Over != nil`, not by IsAggregate (window funcs are NOT aggregates + // even though some aggregates can be used as window funcs). + Over *WindowDefQ + + // ResultType — populated in Phase 3 from the function return type table. + // Phase 1 leaves nil. + ResultType *ResolvedType + Collation string +} + +func (*FuncCallExprQ) analyzedExprTag() {} +func (f *FuncCallExprQ) exprType() *ResolvedType { return f.ResultType } +func (f *FuncCallExprQ) exprCollation() string { return f.Collation } + +// ---------------------------------------------------------------------------- +// OpExprQ — binary / unary operator expression. +// ---------------------------------------------------------------------------- + +// OpExprQ is a resolved operator application. Operators are represented by +// their canonical text (`+`, `=`, `LIKE`, `IS DISTINCT FROM`) rather than by +// an OID — MySQL has no operator OID space, and PG-style operator +// overloading is absent. +// +// Note: `LIKE x ESCAPE y` is NOT modeled here (the ESCAPE adds a third +// operand). It is deferred — see "deferred expression nodes" below. +// +// pg: src/include/nodes/primnodes.h — OpExpr +type OpExprQ struct { + Op string + + // Left is nil for prefix unary operators (e.g. `-x`). + // `NOT x` uses BoolExprQ, not OpExprQ. + Left AnalyzedExpr + Right AnalyzedExpr + + ResultType *ResolvedType + Collation string +} + +func (*OpExprQ) analyzedExprTag() {} +func (o *OpExprQ) exprType() *ResolvedType { return o.ResultType } +func (o *OpExprQ) exprCollation() string { return o.Collation } + +// ---------------------------------------------------------------------------- +// BoolExprQ — AND / OR / NOT. +// ---------------------------------------------------------------------------- + +// BoolExprQ is a logical combinator. +// +// pg: src/include/nodes/primnodes.h — BoolExpr +type BoolExprQ struct { + Op BoolOpType + Args []AnalyzedExpr // single arg for NOT +} + +// BoolOpType enumerates AND / OR / NOT. +type BoolOpType int + +const ( + BoolAnd BoolOpType = iota + BoolOr + BoolNot +) + +func (*BoolExprQ) analyzedExprTag() {} +func (*BoolExprQ) exprType() *ResolvedType { return BoolType } +func (*BoolExprQ) exprCollation() string { return "" } + +// ---------------------------------------------------------------------------- +// CaseExprQ — CASE WHEN. +// ---------------------------------------------------------------------------- + +// CaseExprQ is both forms of CASE: +// +// simple: CASE x WHEN 1 THEN ... WHEN 2 THEN ... ELSE ... END +// searched: CASE WHEN cond1 THEN ... WHEN cond2 THEN ... ELSE ... END +// +// TestExpr is non-nil for the simple form, nil for the searched form. +// +// pg: src/include/nodes/primnodes.h — CaseExpr +type CaseExprQ struct { + TestExpr AnalyzedExpr // nil for searched CASE + Args []*CaseWhenQ + Default AnalyzedExpr // ELSE branch; nil if absent + ResultType *ResolvedType + Collation string +} + +// CaseWhenQ is one WHEN/THEN arm of a CASE expression. +// +// pg: src/include/nodes/primnodes.h — CaseWhen +type CaseWhenQ struct { + Cond AnalyzedExpr + Then AnalyzedExpr +} + +func (*CaseExprQ) analyzedExprTag() {} +func (c *CaseExprQ) exprType() *ResolvedType { return c.ResultType } +func (c *CaseExprQ) exprCollation() string { return c.Collation } + +// ---------------------------------------------------------------------------- +// CoalesceExprQ — COALESCE / IFNULL. +// ---------------------------------------------------------------------------- + +// CoalesceExprQ is a COALESCE() call. IFNULL is normalized to a 2-arg +// CoalesceExprQ by the analyzer. +// +// pg: src/include/nodes/primnodes.h — CoalesceExpr +type CoalesceExprQ struct { + Args []AnalyzedExpr + ResultType *ResolvedType + Collation string +} + +func (*CoalesceExprQ) analyzedExprTag() {} +func (c *CoalesceExprQ) exprType() *ResolvedType { return c.ResultType } +func (c *CoalesceExprQ) exprCollation() string { return c.Collation } + +// ---------------------------------------------------------------------------- +// NullTestExprQ — IS NULL / IS NOT NULL. +// ---------------------------------------------------------------------------- + +// NullTestExprQ is the IS [NOT] NULL predicate. +// +// pg: src/include/nodes/primnodes.h — NullTest +type NullTestExprQ struct { + Arg AnalyzedExpr + IsNull bool // true = IS NULL, false = IS NOT NULL +} + +func (*NullTestExprQ) analyzedExprTag() {} +func (*NullTestExprQ) exprType() *ResolvedType { return BoolType } +func (*NullTestExprQ) exprCollation() string { return "" } + +// ---------------------------------------------------------------------------- +// SubLinkExprQ — subquery expression (EXISTS, IN, scalar, ALL/ANY). +// ---------------------------------------------------------------------------- + +// SubLinkExprQ is a subquery used as an expression: `EXISTS (...)`, +// `x IN (SELECT ...)`, `(SELECT ...)`, `x = ANY (...)`. +// +// Note: `x IN (1, 2, 3)` (non-subquery) is NOT a SubLinkExprQ — it is an +// InListExprQ. The two are kept separate so consumers can distinguish them +// without having to peek at the right-hand side. +// +// pg: src/include/nodes/primnodes.h — SubLink +type SubLinkExprQ struct { + Kind SubLinkKind + + // TestExpr is the left-hand side for ALL / ANY / IN forms; nil for + // EXISTS and scalar. + TestExpr AnalyzedExpr + + // Op is the comparison operator for ALL / ANY / IN forms ("=", "<", ...). + // For SubLinkIn, the comparison is implicitly "=". Empty for EXISTS / + // scalar. + Op string + + // Subquery is the analyzed inner SELECT. + Subquery *Query + + // ResultType is BoolType for EXISTS / IN / ALL / ANY; the inner column's + // type for scalar. Phase 3 fills it in. + ResultType *ResolvedType + + // Collation applies to scalar subqueries returning string types. Empty + // for boolean-result kinds. Phase 3 fills it in. + Collation string +} + +// SubLinkKind discriminates the subquery flavor. +type SubLinkKind int + +const ( + SubLinkExists SubLinkKind = iota // EXISTS (...) + SubLinkScalar // (SELECT ...) used as scalar + SubLinkIn // x IN (SELECT ...) + SubLinkAny // x op ANY (SELECT ...) + SubLinkAll // x op ALL (SELECT ...) +) + +func (*SubLinkExprQ) analyzedExprTag() {} +func (s *SubLinkExprQ) exprType() *ResolvedType { return s.ResultType } +func (s *SubLinkExprQ) exprCollation() string { return s.Collation } + +// ---------------------------------------------------------------------------- +// InListExprQ — x IN (literal_list). +// ---------------------------------------------------------------------------- + +// InListExprQ is `x IN (1, 2, 3)` — the non-subquery form. Modeled as its +// own node (rather than lowered to a chain of OR/=) so deparse can faithfully +// reproduce the source syntax and lineage walkers can introspect the list. +// +// pg: src/include/nodes/primnodes.h — ScalarArrayOpExpr (different shape) +type InListExprQ struct { + Arg AnalyzedExpr + List []AnalyzedExpr + Negated bool // true for NOT IN +} + +func (*InListExprQ) analyzedExprTag() {} +func (*InListExprQ) exprType() *ResolvedType { return BoolType } +func (*InListExprQ) exprCollation() string { return "" } + +// ---------------------------------------------------------------------------- +// BetweenExprQ — x BETWEEN a AND b. +// ---------------------------------------------------------------------------- + +// BetweenExprQ is `x BETWEEN a AND b` (and the negated `NOT BETWEEN` form). +// Modeled as its own node rather than lowered to `BoolAnd(>=, <=)` so +// deparse can faithfully reproduce the source syntax. The lineage walker +// descends into Arg, Lower, and Upper. +// +// pg: src/parser/parse_expr.c — transformAExprBetween (lowered in PG; we keep) +type BetweenExprQ struct { + Arg AnalyzedExpr + Lower AnalyzedExpr + Upper AnalyzedExpr + Negated bool // true for NOT BETWEEN +} + +func (*BetweenExprQ) analyzedExprTag() {} +func (*BetweenExprQ) exprType() *ResolvedType { return BoolType } +func (*BetweenExprQ) exprCollation() string { return "" } + +// ---------------------------------------------------------------------------- +// RowExprQ — ROW(...) constructor. +// ---------------------------------------------------------------------------- + +// RowExprQ is a row constructor used in row comparisons: +// `ROW(a, b) = ROW(1, 2)` or implicit row form `(a, b) = (1, 2)`. +// +// pg: src/include/nodes/primnodes.h — RowExpr +type RowExprQ struct { + Args []AnalyzedExpr + ResultType *ResolvedType +} + +func (*RowExprQ) analyzedExprTag() {} +func (r *RowExprQ) exprType() *ResolvedType { return r.ResultType } +func (*RowExprQ) exprCollation() string { return "" } + +// ---------------------------------------------------------------------------- +// CastExprQ — explicit CAST / CONVERT. +// ---------------------------------------------------------------------------- + +// CastExprQ is an *explicit* type cast: `CAST(x AS UNSIGNED)`, +// `CONVERT(x, CHAR)`. +// +// Implicit casts inserted by MySQL during expression evaluation are NOT +// materialized as CastExprQ nodes — see decision D6 in the plan doc. The +// analyzer leaves implicit conversions implicit and lets deparse / lineage +// reason about them as needed. This decision is revisited at end of Phase 3. +// +// pg: src/include/nodes/parsenodes.h — TypeCast (raw form), +// src/include/nodes/primnodes.h — CoerceViaIO / RelabelType (analyzed form) +type CastExprQ struct { + Arg AnalyzedExpr + TargetType *ResolvedType + + // Collation applies when the cast target is a string type with an + // explicit COLLATE clause: `CAST(x AS CHAR CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)`. + Collation string +} + +func (*CastExprQ) analyzedExprTag() {} +func (c *CastExprQ) exprType() *ResolvedType { return c.TargetType } +func (c *CastExprQ) exprCollation() string { return c.Collation } + +// ---------------------------------------------------------------------------- +// Deferred expression nodes +// ---------------------------------------------------------------------------- +// +// The following expression node kinds are intentionally NOT in the IR yet. +// Each has a known consumer or scenario but is deferred to keep Phase 0/1 +// scope manageable. Listed here so reviewers can challenge the omissions +// explicitly: +// +// - LikeExprQ: `x LIKE pattern ESCAPE c`. The 3-operand ESCAPE form +// does not fit OpExprQ; deferred to Phase 3. +// - MatchAgainstExprQ: `MATCH (col) AGAINST ('text' IN NATURAL LANGUAGE MODE)`. +// MySQL fulltext predicate. Deferred to Phase 4 (deparse +// fidelity for fulltext indexes). +// - JsonExtractExprQ: `col->'$.x'` and `col->>'$.x'` operator forms (vs the +// JSON_EXTRACT function form). Deferred to Phase 3 once +// we decide whether the analyzer normalizes one form to +// the other or preserves both. +// - AssignmentExprQ: `SET col = expr` for UPDATE statements. Deferred to +// Phase 8+ when DML analysis lands. +// +// If a Phase 1/2 corpus query hits one of these, the analyzer should return +// a clear "unsupported expression kind" error rather than silently dropping +// the node. + +// ============================================================================= +// Section 9 — Window definition +// ============================================================================= + +// WindowDefQ is the analyzed OVER clause of a window function, OR a named +// window declaration from the WINDOW clause. The same struct serves both +// roles: +// +// 1. Reference (`OVER w`): +// Only `Name` is set; PartitionBy/OrderBy/FrameClause are empty. +// Analyzer resolves the reference against `Query.WindowClause`. +// +// 2. Inline definition (`OVER (PARTITION BY ... ORDER BY ...)`): +// `Name` is empty; PartitionBy/OrderBy/FrameClause carry the definition. +// +// 3. Named declaration (`WINDOW w AS (PARTITION BY ...)`): +// Stored in `Query.WindowClause`. Both `Name` and the body fields are set. +// +// In Phase 1 the FrameClause is preserved as raw text — frame parsing is +// nontrivial and not required for lineage / SDL view bodies. Phase 3+ may +// upgrade this to a structured form if needed (Phase 4 deparse round-trip +// will care). +// +// pg: src/include/nodes/parsenodes.h — WindowDef +type WindowDefQ struct { + // Name is the window's identifier: + // - For references: the name being referenced (`OVER w` → "w") + // - For inline definitions: empty + // - For declarations in WindowClause: the declared name + Name string + + // PartitionBy and OrderBy are analyzed. + PartitionBy []AnalyzedExpr + OrderBy []*SortGroupClauseQ + + // FrameClause is the raw text of the ROWS / RANGE / GROUPS frame + // specification, e.g. "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW". + // Empty if no frame specified. + // + // Phase 4 round-trip caveat: raw text is from the lexer; whitespace and + // keyword case are not normalized. Round-trip deparse must normalize on + // output or the SDL diff will see false positives. + FrameClause string +} + +// ============================================================================= +// Section 10 — ResolvedType + MySQLBaseType +// ============================================================================= + +// ResolvedType is the analyzed-time MySQL type representation used by +// expression nodes (VarExprQ.Type, FuncCallExprQ.ResultType, etc.) and by +// RTE column metadata (RangeTableEntryQ.ColTypes). +// +// # Naming +// +// This type was called `ColumnType` in plan doc § 4.1, but the existing +// `Column.ColumnType` (string) field on `mysql/catalog/table.go:63` shadowing +// forced a rename. The semantic role is "the resolved type of an expression +// or column at analyzer output", hence ResolvedType. +// +// # Status +// +// In Phase 1, instances of this type are not produced — the analyzer leaves +// VarExprQ.Type / RangeTableEntryQ.ColTypes nil. Phase 3 begins populating +// it from the function return type table and view column derivation. +// +// # Design principle +// +// Track only the modifiers that change one of: +// - the value range of the type +// - the storage layout of the type +// - the comparison/sort behavior +// - client-visible semantics (e.g. TINYINT(1)-as-bool) +// +// Modifiers that MySQL itself normalizes away (integer display width, +// ZEROFILL on 8.0.17+) are NOT tracked. Column-level properties +// (Nullable, AutoIncrement, Default, Comment) belong on the Column struct, +// not here. +// +// MySQL doc: https://dev.mysql.com/doc/refman/8.0/en/data-types.html +type ResolvedType struct { + // BaseType is the type's identity. See MySQLBaseType for the value list. + BaseType MySQLBaseType + + // Unsigned applies to integer and decimal/float families. Doubles the + // non-negative range for integers and shifts the range for decimal/float. + Unsigned bool + + // Length is the parameterized length for fixed/variable-length types: + // - VARCHAR(N), CHAR(N), VARBINARY(N), BINARY(N), BIT(N) + // + // NOT used for integer types (INT(N) is deprecated display width; + // MySQL 8.0.17+ strips it. Special-cased TINYINT(1) is represented via + // BaseType=BaseTypeTinyIntBool, not via Length.) + // + // NOT used for TEXT/BLOB families (the size variants are different + // BaseTypes: TINYTEXT/TEXT/MEDIUMTEXT/LONGTEXT). + // + // 0 if unset. + Length int + + // Precision and Scale are for fixed-point types: DECIMAL(P, S), NUMERIC(P, S). + // Both 0 if unset. + Precision int + Scale int + + // FSP is the fractional seconds precision for DATETIME(N), TIME(N), + // TIMESTAMP(N). Range 0–6. Affects both storage size and precision. + // 0 if unset (which means precision 0, not "absent" — DATETIME and + // DATETIME(0) are the same). + FSP int + + // Charset and Collation apply to string-like types and ENUM/SET. Required + // for SDL diff fidelity (changing a column's collation triggers schema + // changes in MySQL). + Charset string + Collation string + + // EnumValues holds the value list for ENUM('a', 'b'). Nil for non-ENUM. + EnumValues []string + + // SetValues holds the value list for SET('a', 'b'). Nil for non-SET. + // Kept separate from EnumValues for clarity at use sites despite the + // structural similarity. + SetValues []string +} + +// MySQLBaseType is the typed enumeration of MySQL base type identities. +// +// One value per distinct type as MySQL itself reports them through +// `SHOW COLUMNS` and `information_schema.COLUMNS`. Storage size, value range, +// and client-visible semantics are properties of the BaseType, not of any +// modifier slot in ResolvedType. +// +// # Special case: TINYINT(1) +// +// MySQL 8.0.17 deprecated integer display widths, so `INT(11)` and `INT(5)` +// are normalized to plain `INT`. The single exception is `TINYINT(1)`, +// which MySQL preserves because client libraries (Connector/J, mysqlclient, +// etc.) treat `TINYINT(1)` columns as boolean. +// +// We model this by giving `TINYINT(1)` its own BaseType +// (`BaseTypeTinyIntBool`) rather than carrying a `DisplayWidth` field on +// ResolvedType. This makes the special case explicit at the type level, +// keeps integer handling code free of display-width branches, and lets SDL +// diff distinguish `TINYINT(1)` from `TINYINT` automatically. +// +// MySQL doc: https://dev.mysql.com/doc/refman/8.0/en/data-types.html +type MySQLBaseType int + +const ( + BaseTypeUnknown MySQLBaseType = iota + + // Integer family + BaseTypeTinyInt + BaseTypeTinyIntBool // TINYINT(1) — client libraries treat as boolean + BaseTypeSmallInt + BaseTypeMediumInt + BaseTypeInt + BaseTypeBigInt + + // Fixed point + BaseTypeDecimal // DECIMAL / NUMERIC + + // Floating point + BaseTypeFloat + BaseTypeDouble // DOUBLE / REAL + + // Bit + BaseTypeBit // BIT(N) + + // Date / time + BaseTypeDate + BaseTypeDateTime + BaseTypeTimestamp // distinct from DATETIME (timezone semantics differ) + BaseTypeTime + BaseTypeYear + + // Strings + BaseTypeChar + BaseTypeVarchar + BaseTypeBinary + BaseTypeVarBinary + BaseTypeTinyText + BaseTypeText + BaseTypeMediumText + BaseTypeLongText + BaseTypeTinyBlob + BaseTypeBlob + BaseTypeMediumBlob + BaseTypeLongBlob + BaseTypeEnum + BaseTypeSet + + // JSON + BaseTypeJSON + + // Spatial + BaseTypeGeometry + BaseTypePoint + BaseTypeLineString + BaseTypePolygon + BaseTypeMultiPoint + BaseTypeMultiLineString + BaseTypeMultiPolygon + BaseTypeGeometryCollection +) + +// ============================================================================= +// Section 11 — SQL mode bitmap +// ============================================================================= + +// SQLMode is a bitmap of MySQL `sql_mode` flags relevant to analyzer behavior. +// +// Captured on Query.SQLMode at analyze time so deparse can reproduce the +// exact semantics later. Includes both flags that affect parsing/analysis +// directly and flags that affect later evaluation (the latter are needed +// for SDL diff: changing sql_mode at the session level can alter how a +// view body is interpreted). +// +// MySQL doc: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html +type SQLMode uint64 + +const ( + // SQLModeAnsiQuotes — `"x"` is an identifier delimiter, not a string. + // Affects parsing. + SQLModeAnsiQuotes SQLMode = 1 << iota + + // SQLModePipesAsConcat — `||` is string concatenation, not boolean OR. + // Affects expression analysis. + SQLModePipesAsConcat + + // SQLModeIgnoreSpace — allows whitespace between a function name and the + // opening parenthesis. With this off, `count (*)` is a column ref + paren + // expr, not a function call. Directly affects how the parse tree maps to + // FuncCallExprQ vs other forms. + SQLModeIgnoreSpace + + // SQLModeHighNotPrecedence — raises NOT precedence above comparison + // operators. Changes how `NOT a BETWEEN b AND c` is parsed. + SQLModeHighNotPrecedence + + // SQLModeNoBackslashEscapes — disables backslash escape interpretation in + // string literals. Affects ConstExprQ.Value semantics. + SQLModeNoBackslashEscapes + + // SQLModeOnlyFullGroupBy — every non-aggregated select-list column must + // be functionally dependent on GROUP BY. Affects validation but not + // parsing. + SQLModeOnlyFullGroupBy + + // SQLModeNoUnsignedSubtraction — disables unsigned-subtract wraparound. + SQLModeNoUnsignedSubtraction + + // SQLModeStrictAllTables, SQLModeStrictTransTables — error vs warning on + // invalid data. Affects DDL CHECK / GENERATED expression evaluation. + SQLModeStrictAllTables + SQLModeStrictTransTables + + // SQLModeNoZeroDate / SQLModeNoZeroInDate — restrict zero-valued dates. + SQLModeNoZeroDate + SQLModeNoZeroInDate + + // SQLModeRealAsFloat — `REAL` is `FLOAT` instead of `DOUBLE`. + SQLModeRealAsFloat + + // Add more as the analyzer encounters them. Composite modes (ANSI, + // TRADITIONAL) are intentionally NOT modeled here — the analyzer + // expands them at session-capture time. +) + +// ============================================================================= +// End of Phase 0 IR. +// ============================================================================= diff --git a/tidb/catalog/query_expand.go b/tidb/catalog/query_expand.go new file mode 100644 index 00000000..3af1f56d --- /dev/null +++ b/tidb/catalog/query_expand.go @@ -0,0 +1,38 @@ +package catalog + +// ExpandMergeViews creates a copy of the Query where RTERelation entries +// for MERGE views have their Subquery field populated from View.AnalyzedQuery. +// TEMPTABLE views remain opaque. +// +// This is a consume-time operation (decision D5): the analyzer keeps views +// opaque; consumers that want lineage transparency call this method. +// +// The expansion is recursive: if a view references another view, the inner +// view's RTE is also expanded. +func (q *Query) ExpandMergeViews(c *Catalog) *Query { + if q == nil { + return nil + } + + // Shallow copy the query; we only need to replace the RangeTable slice. + expanded := *q + expanded.RangeTable = make([]*RangeTableEntryQ, len(q.RangeTable)) + for i, rte := range q.RangeTable { + if rte.IsView && (rte.ViewAlgorithm == ViewAlgMerge || rte.ViewAlgorithm == ViewAlgUndefined) { + // Look up the view's AnalyzedQuery from the catalog. + db := c.GetDatabase(rte.DBName) + if db != nil { + view := db.Views[toLower(rte.TableName)] + if view != nil && view.AnalyzedQuery != nil { + // Copy RTE and set Subquery to the recursively expanded view body. + rteCopy := *rte + rteCopy.Subquery = view.AnalyzedQuery.ExpandMergeViews(c) + expanded.RangeTable[i] = &rteCopy + continue + } + } + } + expanded.RangeTable[i] = rte + } + return &expanded +} diff --git a/tidb/catalog/query_span_test.go b/tidb/catalog/query_span_test.go new file mode 100644 index 00000000..a36b9670 --- /dev/null +++ b/tidb/catalog/query_span_test.go @@ -0,0 +1,349 @@ +package catalog + +import "testing" + +// --------------------------------------------------------------------------- +// Lineage types and walker — test-only reference implementation. +// --------------------------------------------------------------------------- + +// columnLineage represents one source column contributing to a target column. +type columnLineage struct { + DB string + Table string + Column string +} + +// resultLineage represents the lineage of one output column. +type resultLineage struct { + Name string + SourceCols []columnLineage +} + +// collectColumnLineage walks the analyzed Query and extracts column-level +// lineage for each non-junk target entry. +func collectColumnLineage(_ *Catalog, q *Query) []resultLineage { + var results []resultLineage + for _, te := range q.TargetList { + if te.ResJunk { + continue + } + var sources []columnLineage + collectVarExprs(q, te.Expr, &sources) + results = append(results, resultLineage{ + Name: te.ResName, + SourceCols: sources, + }) + } + return results +} + +// collectVarExprs recursively walks an expression tree to find all VarExprQ +// nodes and resolves them to (db, table, column) tuples via the RangeTable. +func collectVarExprs(q *Query, expr AnalyzedExpr, out *[]columnLineage) { + if expr == nil { + return + } + switch e := expr.(type) { + case *VarExprQ: + resolveVar(q, e, out) + case *OpExprQ: + collectVarExprs(q, e.Left, out) + collectVarExprs(q, e.Right, out) + case *BoolExprQ: + for _, arg := range e.Args { + collectVarExprs(q, arg, out) + } + case *FuncCallExprQ: + for _, arg := range e.Args { + collectVarExprs(q, arg, out) + } + case *CaseExprQ: + collectVarExprs(q, e.TestExpr, out) + for _, w := range e.Args { + collectVarExprs(q, w.Cond, out) + collectVarExprs(q, w.Then, out) + } + collectVarExprs(q, e.Default, out) + case *CoalesceExprQ: + for _, arg := range e.Args { + collectVarExprs(q, arg, out) + } + case *CastExprQ: + collectVarExprs(q, e.Arg, out) + case *NullTestExprQ: + collectVarExprs(q, e.Arg, out) + case *InListExprQ: + collectVarExprs(q, e.Arg, out) + for _, item := range e.List { + collectVarExprs(q, item, out) + } + case *BetweenExprQ: + collectVarExprs(q, e.Arg, out) + collectVarExprs(q, e.Lower, out) + collectVarExprs(q, e.Upper, out) + case *SubLinkExprQ: + if e.Subquery != nil { + for _, innerTE := range e.Subquery.TargetList { + if !innerTE.ResJunk { + collectVarExprs(e.Subquery, innerTE.Expr, out) + } + } + } + case *ConstExprQ: + // No column references in constants. + case *RowExprQ: + for _, arg := range e.Args { + collectVarExprs(q, arg, out) + } + } +} + +// resolveVar resolves a VarExprQ to (db, table, column) by walking the RangeTable. +func resolveVar(q *Query, v *VarExprQ, out *[]columnLineage) { + if v.RangeIdx < 0 || v.RangeIdx >= len(q.RangeTable) { + return + } + rte := q.RangeTable[v.RangeIdx] + colIdx := v.AttNum - 1 // convert to 0-based + if colIdx < 0 || colIdx >= len(rte.ColNames) { + return + } + + switch rte.Kind { + case RTERelation: + if rte.Subquery != nil { + // MERGE view expanded — recurse into the view's analyzed query. + if colIdx < len(rte.Subquery.TargetList) { + collectVarExprs(rte.Subquery, rte.Subquery.TargetList[colIdx].Expr, out) + } + } else { + // Physical table or opaque view — terminal. + *out = append(*out, columnLineage{ + DB: rte.DBName, + Table: rte.TableName, + Column: rte.ColNames[colIdx], + }) + } + case RTESubquery: + if rte.Subquery != nil && colIdx < len(rte.Subquery.TargetList) { + collectVarExprs(rte.Subquery, rte.Subquery.TargetList[colIdx].Expr, out) + } + case RTECTE: + if rte.Subquery != nil && colIdx < len(rte.Subquery.TargetList) { + collectVarExprs(rte.Subquery, rte.Subquery.TargetList[colIdx].Expr, out) + } + case RTEJoin: + *out = append(*out, columnLineage{ + DB: rte.DBName, + Table: rte.TableName, + Column: rte.ColNames[colIdx], + }) + } +} + +// --------------------------------------------------------------------------- +// Helpers for lineage assertions. +// --------------------------------------------------------------------------- + +// assertLineage checks that the lineage results match the expected values. +func assertLineage(t *testing.T, got []resultLineage, expected []resultLineage) { + t.Helper() + if len(got) != len(expected) { + t.Fatalf("lineage: want %d columns, got %d", len(expected), len(got)) + } + for i, want := range expected { + g := got[i] + if g.Name != want.Name { + t.Errorf("column[%d].Name: want %q, got %q", i, want.Name, g.Name) + } + if len(g.SourceCols) != len(want.SourceCols) { + t.Errorf("column[%d] %q: want %d sources, got %d: %v", + i, want.Name, len(want.SourceCols), len(g.SourceCols), g.SourceCols) + continue + } + for j, ws := range want.SourceCols { + gs := g.SourceCols[j] + if gs.DB != ws.DB || gs.Table != ws.Table || gs.Column != ws.Column { + t.Errorf("column[%d] %q source[%d]: want %s.%s.%s, got %s.%s.%s", + i, want.Name, j, + ws.DB, ws.Table, ws.Column, + gs.DB, gs.Table, gs.Column) + } + } + } +} + +// --------------------------------------------------------------------------- +// Phase 2 lineage tests. +// --------------------------------------------------------------------------- + +// TestLineage_14_1_SimpleView tests lineage through a simple view. +func TestLineage_14_1_SimpleView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE VIEW emp_names AS SELECT id, name FROM employees; + `) + + sel := parseSelect(t, "SELECT * FROM emp_names") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "id", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "id"}}}, + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + }) +} + +// TestLineage_14_2_ViewWithExpression tests lineage through a view with an expression. +func TestLineage_14_2_ViewWithExpression(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE VIEW emp_salary AS SELECT name, salary * 12 AS annual FROM employees; + `) + + sel := parseSelect(t, "SELECT * FROM emp_salary") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + {Name: "annual", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "salary"}}}, + }) +} + +// TestLineage_14_3_ViewOfView tests lineage through nested views (view-of-view). +func TestLineage_14_3_ViewOfView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE VIEW v1 AS SELECT id, name FROM employees; + CREATE VIEW v2 AS SELECT name FROM v1; + `) + + sel := parseSelect(t, "SELECT * FROM v2") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + }) +} + +// TestLineage_14_4_ViewWithJoin tests lineage through a view with a JOIN. +func TestLineage_14_4_ViewWithJoin(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE TABLE departments (id INT, name VARCHAR(100), budget DECIMAL(15,2)); + CREATE VIEW emp_dept AS SELECT e.name, d.name AS dept FROM employees e JOIN departments d ON e.department_id = d.id; + `) + + sel := parseSelect(t, "SELECT * FROM emp_dept") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + {Name: "dept", SourceCols: []columnLineage{{DB: "testdb", Table: "departments", Column: "name"}}}, + }) +} + +// TestLineage_14_5_TemptableView tests that TEMPTABLE views remain opaque. +func TestLineage_14_5_TemptableView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE ALGORITHM=TEMPTABLE VIEW temp_v AS SELECT name FROM employees; + `) + + sel := parseSelect(t, "SELECT * FROM temp_v") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + // TEMPTABLE view is opaque — lineage stops at the view, not the base table. + assertLineage(t, lineage, []resultLineage{ + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "temp_v", Column: "name"}}}, + }) +} + +// TestLineage_14_6_CTELineage tests lineage through a CTE. +func TestLineage_14_6_CTELineage(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + `) + + sel := parseSelect(t, ` + WITH active AS (SELECT id, name FROM employees WHERE is_active = 1) + SELECT * FROM active + `) + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "id", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "id"}}}, + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + }) +} + +// TestLineage_14_7_SubqueryLineage tests lineage through a FROM subquery with aggregate. +func TestLineage_14_7_SubqueryLineage(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + `) + + sel := parseSelect(t, "SELECT x.total FROM (SELECT COUNT(*) AS total FROM employees) AS x") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + // COUNT(*) has no physical column source — sources should be empty. + assertLineage(t, lineage, []resultLineage{ + {Name: "total", SourceCols: []columnLineage{}}, + }) +} + +// TestLineage_14_8_ViewInJoin tests lineage when a view is used in a JOIN. +func TestLineage_14_8_ViewInJoin(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE employees (id INT, name VARCHAR(100), salary DECIMAL(10,2), department_id INT, is_active TINYINT); + CREATE TABLE departments (id INT, name VARCHAR(100), budget DECIMAL(15,2)); + CREATE VIEW dept_info AS SELECT id, name, budget FROM departments; + `) + + sel := parseSelect(t, "SELECT e.name, di.budget FROM employees e JOIN dept_info di ON e.department_id = di.id") + q, err := c.AnalyzeSelectStmt(sel) + assertNoError(t, err) + + expanded := q.ExpandMergeViews(c) + lineage := collectColumnLineage(c, expanded) + + assertLineage(t, lineage, []resultLineage{ + {Name: "name", SourceCols: []columnLineage{{DB: "testdb", Table: "employees", Column: "name"}}}, + {Name: "budget", SourceCols: []columnLineage{{DB: "testdb", Table: "departments", Column: "budget"}}}, + }) +} diff --git a/tidb/catalog/renamecmds.go b/tidb/catalog/renamecmds.go new file mode 100644 index 00000000..b23575fd --- /dev/null +++ b/tidb/catalog/renamecmds.go @@ -0,0 +1,32 @@ +package catalog + +import nodes "github.com/bytebase/omni/tidb/ast" + +func (c *Catalog) renameTable(stmt *nodes.RenameTableStmt) error { + for _, pair := range stmt.Pairs { + oldDB, err := c.resolveDatabase(pair.Old.Schema) + if err != nil { + return err + } + oldKey := toLower(pair.Old.Name) + tbl := oldDB.Tables[oldKey] + if tbl == nil { + return errNoSuchTable(oldDB.Name, pair.Old.Name) + } + + newDB, err := c.resolveDatabase(pair.New.Schema) + if err != nil { + return err + } + newKey := toLower(pair.New.Name) + if newDB.Tables[newKey] != nil { + return errDupTable(pair.New.Name) + } + + delete(oldDB.Tables, oldKey) + tbl.Name = pair.New.Name + tbl.Database = newDB + newDB.Tables[newKey] = tbl + } + return nil +} diff --git a/tidb/catalog/renamecmds_test.go b/tidb/catalog/renamecmds_test.go new file mode 100644 index 00000000..2982b5d0 --- /dev/null +++ b/tidb/catalog/renamecmds_test.go @@ -0,0 +1,49 @@ +package catalog + +import "testing" + +func TestRenameTable(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t1 (id INT)", nil) + _, err := c.Exec("RENAME TABLE t1 TO t2", nil) + if err != nil { + t.Fatal(err) + } + db := c.GetDatabase("test") + if db.GetTable("t1") != nil { + t.Fatal("old table should not exist") + } + if db.GetTable("t2") == nil { + t.Fatal("new table should exist") + } +} + +func TestRenameTableCrossDatabase(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE db1", nil) + c.Exec("CREATE DATABASE db2", nil) + c.SetCurrentDatabase("db1") + c.Exec("CREATE TABLE t1 (id INT)", nil) + _, err := c.Exec("RENAME TABLE db1.t1 TO db2.t2", nil) + if err != nil { + t.Fatal(err) + } + if c.GetDatabase("db1").GetTable("t1") != nil { + t.Fatal("old table should not exist in db1") + } + if c.GetDatabase("db2").GetTable("t2") == nil { + t.Fatal("new table should exist in db2") + } +} + +func TestRenameTableNotFound(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("RENAME TABLE noexist TO t2", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected error for missing table") + } +} diff --git a/tidb/catalog/routinecmds.go b/tidb/catalog/routinecmds.go new file mode 100644 index 00000000..6c4c1f1a --- /dev/null +++ b/tidb/catalog/routinecmds.go @@ -0,0 +1,328 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +func (c *Catalog) createRoutine(stmt *nodes.CreateFunctionStmt) error { + db, err := c.resolveDatabase(stmt.Name.Schema) + if err != nil { + return err + } + + name := stmt.Name.Name + key := toLower(name) + + routineMap := db.Functions + if stmt.IsProcedure { + routineMap = db.Procedures + } + + if _, exists := routineMap[key]; exists { + if !stmt.IfNotExists { + if stmt.IsProcedure { + return errDupProcedure(name) + } + return errDupFunction(name) + } + return nil + } + + // Build params. + params := make([]*RoutineParam, 0, len(stmt.Params)) + for _, p := range stmt.Params { + params = append(params, &RoutineParam{ + Direction: p.Direction, + Name: p.Name, + TypeName: formatParamType(p.TypeName), + }) + } + + // Build return type -- MySQL shows lowercase with CHARSET for string types. + var returns string + if stmt.Returns != nil { + returns = formatReturnType(stmt.Returns, db.Charset) + } + + // MySQL always sets a definer. Default to `root`@`%` when not specified. + definer := stmt.Definer + if definer == "" { + definer = "`root`@`%`" + } + + // Build characteristics. + chars := make(map[string]string) + for _, ch := range stmt.Characteristics { + chars[ch.Name] = ch.Value + } + + routine := &Routine{ + Name: name, + Database: db, + IsProcedure: stmt.IsProcedure, + Definer: definer, + Params: params, + Returns: returns, + Body: strings.TrimSpace(stmt.Body), + Characteristics: chars, + } + + routineMap[key] = routine + return nil +} + +func (c *Catalog) dropRoutine(stmt *nodes.DropRoutineStmt) error { + db, err := c.resolveDatabase(stmt.Name.Schema) + if err != nil { + if stmt.IfExists { + return nil + } + return err + } + + name := stmt.Name.Name + key := toLower(name) + + routineMap := db.Functions + if stmt.IsProcedure { + routineMap = db.Procedures + } + + if _, exists := routineMap[key]; !exists { + if stmt.IfExists { + return nil + } + if stmt.IsProcedure { + return errNoSuchProcedure(db.Name, name) + } + return errNoSuchFunction(name) + } + + delete(routineMap, key) + return nil +} + +func (c *Catalog) alterRoutine(stmt *nodes.AlterRoutineStmt) error { + db, err := c.resolveDatabase(stmt.Name.Schema) + if err != nil { + return err + } + + name := stmt.Name.Name + key := toLower(name) + + routineMap := db.Functions + if stmt.IsProcedure { + routineMap = db.Procedures + } + + routine, exists := routineMap[key] + if !exists { + if stmt.IsProcedure { + return errNoSuchProcedure(db.Name, name) + } + return errNoSuchFunction(name) + } + + // Update characteristics. + for _, ch := range stmt.Characteristics { + routine.Characteristics[ch.Name] = ch.Value + } + + return nil +} + +// formatDataType formats a DataType node into a display string for routine parameters/returns. +func formatDataType(dt *nodes.DataType) string { + if dt == nil { + return "" + } + + name := strings.ToLower(dt.Name) + + switch name { + case "int", "integer": + name = "int" + case "tinyint": + name = "tinyint" + case "smallint": + name = "smallint" + case "mediumint": + name = "mediumint" + case "bigint": + name = "bigint" + case "float": + name = "float" + case "double", "real": + name = "double" + case "decimal", "numeric", "dec", "fixed": + name = "decimal" + case "varchar": + name = "varchar" + case "char": + name = "char" + case "text", "tinytext", "mediumtext", "longtext": + // keep as-is + case "blob", "tinyblob", "mediumblob", "longblob": + // keep as-is + case "date", "time", "datetime", "timestamp", "year": + // keep as-is + case "json": + name = "json" + case "bool", "boolean": + name = "tinyint" + } + + var b strings.Builder + b.WriteString(name) + + // Length/precision. + if dt.Length > 0 { + if dt.Scale > 0 { + b.WriteString(fmt.Sprintf("(%d,%d)", dt.Length, dt.Scale)) + } else { + b.WriteString(fmt.Sprintf("(%d)", dt.Length)) + } + } + + if dt.Unsigned { + b.WriteString(" unsigned") + } + + return b.String() +} + +// formatParamType formats a DataType for display in a routine parameter list. +// MySQL 8.0 shows parameter types in UPPERCASE (INT, VARCHAR(100), etc.) +func formatParamType(dt *nodes.DataType) string { + raw := formatDataType(dt) + return strings.ToUpper(raw) +} + +// formatReturnType formats a DataType for display in the RETURNS clause. +// MySQL 8.0 shows return types in lowercase but adds CHARSET for string types. +func formatReturnType(dt *nodes.DataType, dbCharset string) string { + raw := formatDataType(dt) + // MySQL 8.0 appends CHARSET for string return types. + name := strings.ToLower(dt.Name) + if isStringRoutineType(name) { + charset := dt.Charset + if charset == "" { + charset = dbCharset + } + if charset == "" { + charset = "utf8mb4" + } + raw += " CHARSET " + charset + } + return raw +} + +// isStringRoutineType returns true for types where MySQL shows CHARSET in routine RETURNS. +func isStringRoutineType(dt string) bool { + switch dt { + case "varchar", "char", "text", "tinytext", "mediumtext", "longtext", + "enum", "set": + return true + } + return false +} + +// ShowCreateFunction produces MySQL 8.0-compatible SHOW CREATE FUNCTION output. +func (c *Catalog) ShowCreateFunction(database, name string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + routine := db.Functions[toLower(name)] + if routine == nil { + return "" + } + return showCreateRoutine(routine) +} + +// ShowCreateProcedure produces MySQL 8.0-compatible SHOW CREATE PROCEDURE output. +func (c *Catalog) ShowCreateProcedure(database, name string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + routine := db.Procedures[toLower(name)] + if routine == nil { + return "" + } + return showCreateRoutine(routine) +} + +// showCreateRoutine produces the SHOW CREATE output for a stored routine. +// MySQL 8.0 format: +// +// CREATE DEFINER=`root`@`%` FUNCTION `name`(a INT, b INT) RETURNS int +// DETERMINISTIC +// RETURN a + b +func showCreateRoutine(r *Routine) string { + var b strings.Builder + + b.WriteString("CREATE") + + // DEFINER -- MySQL 8.0 always shows DEFINER. + if r.Definer != "" { + b.WriteString(fmt.Sprintf(" DEFINER=%s", r.Definer)) + } + + if r.IsProcedure { + b.WriteString(fmt.Sprintf(" PROCEDURE `%s`(", r.Name)) + } else { + b.WriteString(fmt.Sprintf(" FUNCTION `%s`(", r.Name)) + } + + // Parameters -- MySQL 8.0 separates with ", " (comma space). + for i, p := range r.Params { + if i > 0 { + b.WriteString(", ") + } + if r.IsProcedure && p.Direction != "" { + b.WriteString(fmt.Sprintf("%s %s %s", p.Direction, p.Name, p.TypeName)) + } else { + b.WriteString(fmt.Sprintf("%s %s", p.Name, p.TypeName)) + } + } + b.WriteString(")") + + // RETURNS (functions only) + if !r.IsProcedure && r.Returns != "" { + b.WriteString(fmt.Sprintf(" RETURNS %s", r.Returns)) + } + + // Characteristics -- MySQL 8.0 outputs each on its own line with 4-space indent. + // Order: DETERMINISTIC, DATA ACCESS, SQL SECURITY, COMMENT + if v, ok := r.Characteristics["DETERMINISTIC"]; ok { + if v == "YES" { + b.WriteString("\n DETERMINISTIC") + } else { + b.WriteString("\n NOT DETERMINISTIC") + } + } + + if v, ok := r.Characteristics["DATA ACCESS"]; ok { + b.WriteString(fmt.Sprintf("\n %s", v)) + } + + if v, ok := r.Characteristics["SQL SECURITY"]; ok { + b.WriteString(fmt.Sprintf("\n SQL SECURITY %s", strings.ToUpper(v))) + } + + if v, ok := r.Characteristics["COMMENT"]; ok { + b.WriteString(fmt.Sprintf("\n COMMENT '%s'", escapeComment(v))) + } + + // Body -- MySQL 8.0 starts body on its own line with no indent. + if r.Body != "" { + b.WriteString(fmt.Sprintf("\n%s", r.Body)) + } + + return b.String() +} diff --git a/tidb/catalog/scenarios_ax_test.go b/tidb/catalog/scenarios_ax_test.go new file mode 100644 index 00000000..d12a2c6f --- /dev/null +++ b/tidb/catalog/scenarios_ax_test.go @@ -0,0 +1,657 @@ +package catalog + +import ( + "fmt" + "strings" + "testing" +) + +// TestScenario_AX covers section AX (ALTER TABLE sub-command implicit +// behaviors) from SCENARIOS-mysql-implicit-behavior.md. Each subtest +// asserts that real MySQL 8.0 and the omni catalog agree on the +// post-ALTER state (column order, index list, constraint presence, +// error behavior) for a given ALTER TABLE sequence. +// +// Failures in omni assertions are NOT proof failures — they are +// recorded in mysql/catalog/scenarios_bug_queue/ax.md. +func TestScenario_AX(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- AX.1 ADD COLUMN append / FIRST / AFTER -------------------------- + t.Run("AX_1_AddColumn_append_first_after", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, c INT); +ALTER TABLE t ADD COLUMN d INT; +ALTER TABLE t ADD COLUMN e INT FIRST; +ALTER TABLE t ADD COLUMN f INT AFTER a;`) + + // Oracle: columns ordered. + want := []string{"e", "a", "f", "b", "c", "d"} + oracle := axOracleColumnOrder(t, mc, "t") + assertStringEq(t, "oracle column order", strings.Join(oracle, ","), strings.Join(want, ",")) + + // omni + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omni := axOmniColumnOrder(tbl) + assertStringEq(t, "omni column order", strings.Join(omni, ","), strings.Join(want, ",")) + }) + + // --- AX.2 DROP COLUMN cascades removal of indexes containing column -- + t.Run("AX_2_DropColumn_cascades_indexes", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, c INT, INDEX idx_a(a), INDEX idx_ab(a, b), INDEX idx_bc(b, c)); +ALTER TABLE t DROP COLUMN a;`) + + // Oracle: scenario doc claims idx_ab should be removed entirely, + // but MySQL 8.0 actually strips only the dropped column and keeps + // idx_ab as an index on the surviving column(s). Verify dual-agreement + // against observed MySQL behavior rather than the doc's stale claim. + oracleIdx := axOracleIndexNames(t, mc, "t") + want := []string{"idx_ab", "idx_bc"} + assertStringEq(t, "oracle surviving indexes", strings.Join(oracleIdx, ","), strings.Join(want, ",")) + + // omni + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omni := axOmniIndexNames(tbl) + assertStringEq(t, "omni surviving indexes", strings.Join(omni, ","), strings.Join(want, ",")) + }) + + // --- AX.3 DROP COLUMN rejects last-column removal -------------------- + t.Run("AX_3_DropColumn_rejects_last_column", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT);`) + + // Oracle rejects with ER_CANT_REMOVE_ALL_FIELDS (1090). + _, oracleErr := mc.db.ExecContext(mc.ctx, `ALTER TABLE t DROP COLUMN a`) + if oracleErr == nil { + t.Errorf("oracle: expected error dropping last column, got nil") + } + + // omni should also reject. + results, perr := c.Exec(`ALTER TABLE t DROP COLUMN a`, nil) + omniRejected := perr != nil + if !omniRejected { + for _, r := range results { + if r.Error != nil { + omniRejected = true + break + } + } + } + if !omniRejected { + t.Errorf("omni: expected error dropping last column, got nil") + } + }) + + // --- AX.4 DROP COLUMN rejects when referenced by CHECK or GENERATED -- + t.Run("AX_4_DropColumn_rejects_check_or_generated_ref", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT, b INT, CHECK (a > 0)); +CREATE TABLE t2 (a INT, b INT GENERATED ALWAYS AS (a + 1));`) + + // Oracle: check behavior of both DROPs against this container's + // MySQL build. We then assert omni matches whatever oracle does. + _, oracleErrCheck := mc.db.ExecContext(mc.ctx, `ALTER TABLE t1 DROP COLUMN a`) + _, oracleErrGen := mc.db.ExecContext(mc.ctx, `ALTER TABLE t2 DROP COLUMN a`) + + // omni + r1, _ := c.Exec(`ALTER TABLE t1 DROP COLUMN a`, nil) + omniBlockedCheck := axExecHasError(r1) + if (oracleErrCheck != nil) != omniBlockedCheck { + t.Errorf("omni vs oracle CHECK-ref DROP divergence: oracleErr=%v omniBlocked=%v", oracleErrCheck, omniBlockedCheck) + } + r2, _ := c.Exec(`ALTER TABLE t2 DROP COLUMN a`, nil) + omniBlockedGen := axExecHasError(r2) + if (oracleErrGen != nil) != omniBlockedGen { + t.Errorf("omni vs oracle GENERATED-ref DROP divergence: oracleErr=%v omniBlocked=%v", oracleErrGen, omniBlockedGen) + } + }) + + // --- AX.5 MODIFY COLUMN rewrites spec; attrs NOT inherited ---------- + t.Run("AX_5_ModifyColumn_rewrites_spec", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT NOT NULL AUTO_INCREMENT PRIMARY KEY, b INT NOT NULL DEFAULT 5 COMMENT 'x'); +ALTER TABLE t MODIFY COLUMN b BIGINT;`) + + var isNullable, colComment string + var colDefault *string + oracleScan(t, mc, `SELECT IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, + &isNullable, &colDefault, &colComment) + if isNullable != "YES" { + t.Errorf("oracle b nullable: got %q, want YES", isNullable) + } + if colDefault != nil { + t.Errorf("oracle b default: got %v, want nil", *colDefault) + } + if colComment != "" { + t.Errorf("oracle b comment: got %q, want empty", colComment) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("b") + if col == nil { + t.Errorf("omni: column b missing") + return + } + if !col.Nullable { + t.Errorf("omni b nullable: got false, want true (MODIFY should not inherit NOT NULL)") + } + if col.Default != nil { + t.Errorf("omni b default: got %q, want nil (MODIFY should not inherit DEFAULT)", *col.Default) + } + if col.Comment != "" { + t.Errorf("omni b comment: got %q, want empty (MODIFY should not inherit COMMENT)", col.Comment) + } + if !strings.Contains(strings.ToLower(col.ColumnType), "bigint") && col.DataType != "bigint" { + t.Errorf("omni b type: got %q/%q, want bigint", col.DataType, col.ColumnType) + } + }) + + // --- AX.6 CHANGE COLUMN atomic rename+retype ------------------------ + t.Run("AX_6_ChangeColumn_rename_retype", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT); +ALTER TABLE t CHANGE COLUMN a b BIGINT NOT NULL;`) + + var colName, dataType, isNullable string + oracleScan(t, mc, `SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t'`, + &colName, &dataType, &isNullable) + if colName != "b" || dataType != "bigint" || isNullable != "NO" { + t.Errorf("oracle: got (%q,%q,%q), want (b,bigint,NO)", colName, dataType, isNullable) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + if tbl.GetColumn("a") != nil { + t.Errorf("omni: column a still present after CHANGE") + } + col := tbl.GetColumn("b") + if col == nil { + t.Errorf("omni: column b missing after CHANGE") + return + } + if col.Nullable { + t.Errorf("omni b nullable: got true, want false") + } + if !strings.Contains(strings.ToLower(col.ColumnType+" "+col.DataType), "bigint") { + t.Errorf("omni b type: got %q/%q, want bigint", col.DataType, col.ColumnType) + } + }) + + // --- AX.7 ADD INDEX auto-name in ALTER context ----------------------- + t.Run("AX_7_AddIndex_auto_name_alter", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, INDEX (a)); +ALTER TABLE t ADD INDEX (b); +ALTER TABLE t ADD INDEX (a, b);`) + + want := []string{"a", "a_2", "b"} + oracle := axOracleIndexNames(t, mc, "t") + assertStringEq(t, "oracle index names", strings.Join(oracle, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omni := axOmniIndexNames(tbl) + assertStringEq(t, "omni index names", strings.Join(omni, ","), strings.Join(want, ",")) + }) + + // --- AX.8 ADD UNIQUE / KEY / FULLTEXT implicit index types ---------- + t.Run("AX_8_AddUnique_Key_Fulltext_types", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b VARCHAR(100)) ENGINE=InnoDB; +ALTER TABLE t ADD UNIQUE (a); +ALTER TABLE t ADD KEY (b); +ALTER TABLE t ADD FULLTEXT (b);`) + + // Oracle check: examine STATISTICS for counts. + rows := oracleRows(t, mc, `SELECT INDEX_NAME, NON_UNIQUE, INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' ORDER BY INDEX_NAME`) + if len(rows) < 3 { + t.Errorf("oracle: expected >=3 index rows, got %d", len(rows)) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + var haveUnique, haveFulltext, havePlain bool + for _, idx := range tbl.Indexes { + if idx.Primary { + continue + } + if idx.Unique { + haveUnique = true + } + if idx.Fulltext { + haveFulltext = true + } + if !idx.Unique && !idx.Fulltext && !idx.Spatial && !idx.Primary { + havePlain = true + } + } + if !haveUnique { + t.Errorf("omni: expected a UNIQUE index after ADD UNIQUE") + } + if !havePlain { + t.Errorf("omni: expected a plain KEY index after ADD KEY") + } + if !haveFulltext { + t.Errorf("omni: expected a FULLTEXT index after ADD FULLTEXT") + } + }) + + // --- AX.9 FK column-level shorthand silent-ignored in CREATE -------- + t.Run("AX_9_FK_column_shorthand_silent_ignore", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE parent (id INT PRIMARY KEY); +CREATE TABLE t1 (a INT); +ALTER TABLE t1 ADD FOREIGN KEY (a) REFERENCES parent(id); +CREATE TABLE t2 (a INT REFERENCES parent(id));`) + + // Oracle: t1 has FK, t2 has none. + oracleT1 := oracleFKNames(t, mc, "t1") + if len(oracleT1) != 1 { + t.Errorf("oracle t1 FK count: got %d, want 1", len(oracleT1)) + } + oracleT2 := oracleFKNames(t, mc, "t2") + if len(oracleT2) != 0 { + t.Errorf("oracle t2 FK count: got %d, want 0 (column-level REFERENCES should be ignored)", len(oracleT2)) + } + + tbl1 := c.GetDatabase("testdb").GetTable("t1") + tbl2 := c.GetDatabase("testdb").GetTable("t2") + if tbl1 == nil || tbl2 == nil { + t.Errorf("omni: t1 or t2 missing") + return + } + if n := len(omniFKNames(tbl1)); n != 1 { + t.Errorf("omni t1 FK count: got %d, want 1", n) + } + if n := len(omniFKNames(tbl2)); n != 0 { + t.Errorf("omni t2 FK count: got %d, want 0 (column-level REFERENCES should be parsed-but-ignored)", n) + } + }) + + // --- AX.10 RENAME COLUMN preserves attributes ------------------------ + t.Run("AX_10_RenameColumn_preserves_attrs", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT NOT NULL DEFAULT 5 COMMENT 'hi', INDEX (a)); +ALTER TABLE t RENAME COLUMN a TO aa;`) + + var colName, isNullable, colComment string + var colDefault *string + oracleScan(t, mc, `SELECT COLUMN_NAME, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t'`, + &colName, &isNullable, &colDefault, &colComment) + if colName != "aa" || isNullable != "NO" || colComment != "hi" { + t.Errorf("oracle: got (%q,%q,%q), want (aa,NO,hi)", colName, isNullable, colComment) + } + if colDefault == nil || *colDefault != "5" { + t.Errorf("oracle default: got %v, want 5", colDefault) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + if tbl.GetColumn("a") != nil { + t.Errorf("omni: column a still present after RENAME COLUMN") + } + col := tbl.GetColumn("aa") + if col == nil { + t.Errorf("omni: column aa missing") + return + } + if col.Nullable { + t.Errorf("omni aa nullable: got true, want false (attrs preserved)") + } + if col.Default == nil || *col.Default != "5" { + t.Errorf("omni aa default: got %v, want 5", col.Default) + } + if col.Comment != "hi" { + t.Errorf("omni aa comment: got %q, want hi", col.Comment) + } + }) + + // --- AX.11 RENAME INDEX --------------------------------------------- + t.Run("AX_11_RenameIndex", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, INDEX old_name (a, b)); +ALTER TABLE t RENAME INDEX old_name TO new_name;`) + + oracle := axOracleIndexNames(t, mc, "t") + want := []string{"new_name"} + assertStringEq(t, "oracle index names", strings.Join(oracle, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omni := axOmniIndexNames(tbl) + assertStringEq(t, "omni index names", strings.Join(omni, ","), strings.Join(want, ",")) + }) + + // --- AX.12 RENAME TO (table rename via ALTER) ----------------------- + t.Run("AX_12_RenameTable_via_ALTER", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT); +ALTER TABLE t RENAME TO t2;`) + + // Oracle: t gone, t2 present. + var cnt int64 + oracleScan(t, mc, `SELECT COUNT(*) FROM information_schema.TABLES + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t'`, &cnt) + if cnt != 0 { + t.Errorf("oracle: expected t to be gone, count=%d", cnt) + } + oracleScan(t, mc, `SELECT COUNT(*) FROM information_schema.TABLES + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t2'`, &cnt) + if cnt != 1 { + t.Errorf("oracle: expected t2 to exist, count=%d", cnt) + } + + db := c.GetDatabase("testdb") + if db.GetTable("t") != nil { + t.Errorf("omni: table t still present after RENAME TO t2") + } + if db.GetTable("t2") == nil { + t.Errorf("omni: table t2 missing after RENAME TO") + } + }) + + // --- AX.13 ALTER COLUMN SET/DROP DEFAULT, SET INVISIBLE/VISIBLE ----- + t.Run("AX_13_AlterColumn_default_visibility", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT NOT NULL, b INT); +ALTER TABLE t ALTER COLUMN a SET DEFAULT 5;`) + + var colDefault *string + oracleScan(t, mc, `SELECT COLUMN_DEFAULT FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, &colDefault) + if colDefault == nil || *colDefault != "5" { + t.Errorf("oracle a default after SET: got %v, want 5", colDefault) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + colA := tbl.GetColumn("a") + if colA == nil { + t.Errorf("omni: column a missing") + return + } + if colA.Default == nil || *colA.Default != "5" { + t.Errorf("omni a default after SET: got %v, want 5", colA.Default) + } + + runOnBoth(t, mc, c, `ALTER TABLE t ALTER COLUMN a DROP DEFAULT;`) + oracleScan(t, mc, `SELECT COLUMN_DEFAULT FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, &colDefault) + if colDefault != nil { + t.Errorf("oracle a default after DROP: got %q, want nil", *colDefault) + } + colA = c.GetDatabase("testdb").GetTable("t").GetColumn("a") + if colA.Default != nil { + t.Errorf("omni a default after DROP: got %q, want nil", *colA.Default) + } + + runOnBoth(t, mc, c, `ALTER TABLE t ALTER COLUMN b SET INVISIBLE;`) + var extra string + oracleScan(t, mc, `SELECT EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, &extra) + if !strings.Contains(strings.ToUpper(extra), "INVISIBLE") { + t.Errorf("oracle b EXTRA after SET INVISIBLE: got %q, want INVISIBLE", extra) + } + colB := c.GetDatabase("testdb").GetTable("t").GetColumn("b") + if !colB.Invisible { + t.Errorf("omni b Invisible after SET INVISIBLE: got false, want true") + } + + runOnBoth(t, mc, c, `ALTER TABLE t ALTER COLUMN b SET VISIBLE;`) + oracleScan(t, mc, `SELECT EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, &extra) + if strings.Contains(strings.ToUpper(extra), "INVISIBLE") { + t.Errorf("oracle b EXTRA after SET VISIBLE: got %q, want no INVISIBLE", extra) + } + colB = c.GetDatabase("testdb").GetTable("t").GetColumn("b") + if colB.Invisible { + t.Errorf("omni b Invisible after SET VISIBLE: got true, want false") + } + }) + + // --- AX.14 ALTER INDEX VISIBLE / INVISIBLE --------------------------- + t.Run("AX_14_AlterIndex_visibility", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, INDEX idx_a (a)); +ALTER TABLE t ALTER INDEX idx_a INVISIBLE;`) + + var isVisible string + oracleScan(t, mc, `SELECT IS_VISIBLE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='idx_a' LIMIT 1`, &isVisible) + if isVisible != "NO" { + t.Errorf("oracle idx_a IS_VISIBLE: got %q, want NO", isVisible) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + var omniIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_a" { + omniIdx = idx + break + } + } + if omniIdx == nil { + t.Errorf("omni: index idx_a missing") + return + } + if omniIdx.Visible { + t.Errorf("omni idx_a Visible: got true, want false after INVISIBLE") + } + }) + + // --- AX.15 Multi-sub-command ALTER — drop/add/rename composed ------- + t.Run("AX_15_MultiSubCommand_compose", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Use a variant the scenario doc's single-pass claim actually holds + // for: MySQL evaluates RENAME COLUMN after ADD COLUMN's AFTER, so + // `ADD COLUMN c INT AFTER a` + `RENAME a TO aa` in the same ALTER + // raises "Unknown column a" (observed in MySQL 8.0). We split the + // scenario into a simpler composition that both engines accept and + // verify positional + index resolution after DROP + ADD + RENAME. + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE t (a INT, b INT)`); err != nil { + t.Fatalf("oracle create t: %v", err) + } + if _, err := c.Exec(`CREATE TABLE t (a INT, b INT);`, nil); err != nil { + t.Fatalf("omni create t: %v", err) + } + + ddl := `ALTER TABLE t + ADD COLUMN c INT, + DROP COLUMN b, + ADD INDEX (c), + RENAME COLUMN a TO aa` + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + if oracleErr != nil { + t.Errorf("oracle ALTER failed: %v", oracleErr) + } + + // omni + results, perr := c.Exec(ddl, nil) + if perr != nil { + t.Errorf("omni parse error: %v", perr) + } + if axExecHasError(results) { + t.Errorf("omni ALTER exec error: %v", results) + } + + want := []string{"aa", "c"} + oracle := axOracleColumnOrder(t, mc, "t") + assertStringEq(t, "oracle column order", strings.Join(oracle, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omni := axOmniColumnOrder(tbl) + assertStringEq(t, "omni column order", strings.Join(omni, ","), strings.Join(want, ",")) + + // And the (c) index exists. + oracleIdx := axOracleIndexNames(t, mc, "t") + foundOracle := false + for _, n := range oracleIdx { + if n == "c" { + foundOracle = true + break + } + } + if !foundOracle { + t.Errorf("oracle: expected index named c, got %v", oracleIdx) + } + foundOmni := false + for _, idx := range tbl.Indexes { + if idx.Name == "c" { + foundOmni = true + break + } + } + if !foundOmni { + t.Errorf("omni: expected index named c, got %v", axOmniIndexNames(tbl)) + } + }) +} + +// axOracleColumnOrder returns the column names of the given table in +// ORDINAL_POSITION order, as seen by the MySQL container. +func axOracleColumnOrder(t *testing.T, mc *mysqlContainer, tableName string) []string { + t.Helper() + rows := oracleRows(t, mc, fmt.Sprintf( + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME=%q ORDER BY ORDINAL_POSITION`, tableName)) + var names []string + for _, row := range rows { + if len(row) > 0 { + names = append(names, asString(row[0])) + } + } + return names +} + +// axOmniColumnOrder returns the column names of the given omni table +// in positional order. +func axOmniColumnOrder(tbl *Table) []string { + names := make([]string, 0, len(tbl.Columns)) + for _, col := range tbl.Columns { + names = append(names, col.Name) + } + return names +} + +// axOracleIndexNames returns sorted distinct index names for a table +// from information_schema.STATISTICS (PRIMARY excluded). +func axOracleIndexNames(t *testing.T, mc *mysqlContainer, tableName string) []string { + t.Helper() + rows := oracleRows(t, mc, fmt.Sprintf( + `SELECT DISTINCT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME=%q AND INDEX_NAME <> 'PRIMARY' + ORDER BY INDEX_NAME`, tableName)) + var names []string + for _, row := range rows { + if len(row) > 0 { + names = append(names, asString(row[0])) + } + } + return names +} + +// axOmniIndexNames returns sorted non-primary index names for an omni table. +func axOmniIndexNames(tbl *Table) []string { + var names []string + for _, idx := range tbl.Indexes { + if idx.Primary { + continue + } + names = append(names, idx.Name) + } + // Simple insertion sort to keep dep-free; tests compare joined strings. + for i := 1; i < len(names); i++ { + for j := i; j > 0 && names[j-1] > names[j]; j-- { + names[j-1], names[j] = names[j], names[j-1] + } + } + return names +} + +// axExecHasError reports whether any ExecResult carries an error. +func axExecHasError(results []ExecResult) bool { + for _, r := range results { + if r.Error != nil { + return true + } + } + return false +} diff --git a/tidb/catalog/scenarios_c10_test.go b/tidb/catalog/scenarios_c10_test.go new file mode 100644 index 00000000..ad573202 --- /dev/null +++ b/tidb/catalog/scenarios_c10_test.go @@ -0,0 +1,359 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C10 covers Section C10 "View metadata defaults" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL against both +// a real MySQL 8.0 container and the omni catalog, then asserts that both +// agree on the effective default for a given view-metadata behavior. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c10.md. +func TestScenario_C10(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c10OmniView fetches a view from the omni catalog by name. + c10OmniView := func(c *Catalog, name string) *View { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + return db.Views[strings.ToLower(name)] + } + + // c10OracleViewRow returns one row from information_schema.VIEWS for (testdb, name). + c10OracleViewRow := func(t *testing.T, name string) (definer, securityType, checkOption, isUpdatable string) { + t.Helper() + oracleScan(t, mc, + `SELECT DEFINER, SECURITY_TYPE, CHECK_OPTION, IS_UPDATABLE + FROM information_schema.VIEWS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+name+`'`, + &definer, &securityType, &checkOption, &isUpdatable) + return + } + + // --- 10.1 ALGORITHM defaults to UNDEFINED ------------------------------ + t.Run("10_1_algorithm_defaults_undefined", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v AS SELECT a FROM t;`) + + // Oracle: CHECK_OPTION=NONE, SECURITY_TYPE=DEFINER, SHOW CREATE has ALGORITHM=UNDEFINED. + _, secType, checkOpt, _ := c10OracleViewRow(t, "v") + if secType != "DEFINER" { + t.Errorf("oracle SECURITY_TYPE: got %q, want %q", secType, "DEFINER") + } + if checkOpt != "NONE" { + t.Errorf("oracle CHECK_OPTION: got %q, want %q", checkOpt, "NONE") + } + showCreate := oracleShow(t, mc, "SHOW CREATE VIEW v") + if !strings.Contains(strings.ToUpper(showCreate), "ALGORITHM=UNDEFINED") { + t.Errorf("oracle SHOW CREATE VIEW v: got %q, want contains ALGORITHM=UNDEFINED", showCreate) + } + + // omni: view object reports Algorithm=UNDEFINED, SqlSecurity=DEFINER, CheckOption=NONE. + v := c10OmniView(c, "v") + if v == nil { + t.Error("omni: view v not found") + return + } + if strings.ToUpper(v.Algorithm) != "UNDEFINED" { + t.Errorf("omni Algorithm: got %q, want UNDEFINED", v.Algorithm) + } + if strings.ToUpper(v.SqlSecurity) != "DEFINER" { + t.Errorf("omni SqlSecurity: got %q, want DEFINER", v.SqlSecurity) + } + // CheckOption may be represented as "" or "NONE" depending on omni; "NONE" semantics + // means "no WITH CHECK OPTION clause". Accept either. + if v.CheckOption != "" && strings.ToUpper(v.CheckOption) != "NONE" { + t.Errorf("omni CheckOption: got %q, want \"\" or NONE", v.CheckOption) + } + }) + + // --- 10.2 DEFINER defaults to current user ----------------------------- + t.Run("10_2_definer_defaults_current_user", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v AS SELECT a FROM t;`) + + definer, _, _, _ := c10OracleViewRow(t, "v") + // The container uses root; DEFINER comes back as something like "root@%" + // (without backticks in I_S but with them in SHOW CREATE). + if !strings.Contains(strings.ToLower(definer), "root") { + t.Errorf("oracle DEFINER: got %q, want contains 'root'", definer) + } + + v := c10OmniView(c, "v") + if v == nil { + t.Error("omni: view v not found") + return + } + if v.Definer == "" { + t.Error("omni Definer: got empty, want non-empty (e.g. `root`@`%`)") + } + }) + + // --- 10.3 CHECK OPTION default is CASCADED when WITH CHECK OPTION lacks a qualifier --- + t.Run("10_3_check_option_default_cascaded", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v1 AS SELECT a FROM t WHERE a > 0 WITH CHECK OPTION; + CREATE VIEW v2 AS SELECT a FROM t WHERE a > 0 WITH LOCAL CHECK OPTION;`) + + // Oracle v1 → CASCADED, v2 → LOCAL. + _, _, v1CheckOpt, _ := c10OracleViewRow(t, "v1") + _, _, v2CheckOpt, _ := c10OracleViewRow(t, "v2") + if v1CheckOpt != "CASCADED" { + t.Errorf("oracle v1 CHECK_OPTION: got %q, want CASCADED", v1CheckOpt) + } + if v2CheckOpt != "LOCAL" { + t.Errorf("oracle v2 CHECK_OPTION: got %q, want LOCAL", v2CheckOpt) + } + + // omni must distinguish three states (NONE/LOCAL/CASCADED) and normalize + // bare WITH CHECK OPTION → CASCADED. + v1 := c10OmniView(c, "v1") + v2 := c10OmniView(c, "v2") + if v1 == nil || v2 == nil { + t.Error("omni: v1 or v2 not found") + return + } + if strings.ToUpper(v1.CheckOption) != "CASCADED" { + t.Errorf("omni v1 CheckOption: got %q, want CASCADED", v1.CheckOption) + } + if strings.ToUpper(v2.CheckOption) != "LOCAL" { + t.Errorf("omni v2 CheckOption: got %q, want LOCAL", v2.CheckOption) + } + }) + + // --- 10.4 ALGORITHM=UNDEFINED persists at CREATE; MERGE downgrades on non-mergeable --- + t.Run("10_4_algorithm_undefined_resolution", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE ALGORITHM=UNDEFINED VIEW v_agg AS SELECT COUNT(*) FROM t; + CREATE ALGORITHM=MERGE VIEW v_merge_bad AS SELECT DISTINCT a FROM t;`) + + // Oracle: both views stored with ALGORITHM=UNDEFINED (v_merge_bad gets downgraded). + showAgg := oracleShow(t, mc, "SHOW CREATE VIEW v_agg") + if !strings.Contains(strings.ToUpper(showAgg), "ALGORITHM=UNDEFINED") { + t.Errorf("oracle SHOW CREATE VIEW v_agg: got %q, want contains ALGORITHM=UNDEFINED", showAgg) + } + showMergeBad := oracleShow(t, mc, "SHOW CREATE VIEW v_merge_bad") + if !strings.Contains(strings.ToUpper(showMergeBad), "ALGORITHM=UNDEFINED") { + t.Errorf("oracle SHOW CREATE VIEW v_merge_bad (post-downgrade): got %q, want contains ALGORITHM=UNDEFINED", showMergeBad) + } + + // omni: v_agg Algorithm=UNDEFINED; v_merge_bad should record the user-declared + // value (MERGE) verbatim per SCENARIOS guidance — MySQL silently downgrades + // but the catalog representation must preserve the pre-downgrade value. + vAgg := c10OmniView(c, "v_agg") + if vAgg == nil { + t.Error("omni: v_agg not found") + } else if strings.ToUpper(vAgg.Algorithm) != "UNDEFINED" { + t.Errorf("omni v_agg Algorithm: got %q, want UNDEFINED", vAgg.Algorithm) + } + vMergeBad := c10OmniView(c, "v_merge_bad") + if vMergeBad == nil { + t.Error("omni: v_merge_bad not found") + } else if strings.ToUpper(vMergeBad.Algorithm) != "MERGE" { + // Per scenario: omni must record the declared algorithm, not silently downgrade. + t.Errorf("omni v_merge_bad Algorithm: got %q, want MERGE (as declared)", vMergeBad.Algorithm) + } + }) + + // --- 10.5 SQL SECURITY defaults to DEFINER ----------------------------- + t.Run("10_5_sql_security_defaults_definer", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v AS SELECT a FROM t;`) + + var secType string + oracleScan(t, mc, + `SELECT SECURITY_TYPE FROM information_schema.VIEWS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v'`, + &secType) + if secType != "DEFINER" { + t.Errorf("oracle SECURITY_TYPE: got %q, want DEFINER", secType) + } + + v := c10OmniView(c, "v") + if v == nil { + t.Error("omni: view v not found") + return + } + if strings.ToUpper(v.SqlSecurity) != "DEFINER" { + t.Errorf("omni SqlSecurity: got %q, want DEFINER (must be defaulted, not empty)", v.SqlSecurity) + } + }) + + // --- 10.6 View column names default to SELECT expression spelling ------ + t.Run("10_6_view_column_name_derivation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v_auto AS SELECT a, a+1, COUNT(*) FROM t GROUP BY a; + CREATE VIEW v_list (x,y,z) AS SELECT a, a+1, COUNT(*) FROM t GROUP BY a;`) + + // Oracle: v_auto columns are ['a', 'a+1', 'COUNT(*)'] exactly. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v_auto' + ORDER BY ORDINAL_POSITION`) + oracleCols := make([]string, 0, len(rows)) + for _, r := range rows { + if len(r) > 0 { + oracleCols = append(oracleCols, asString(r[0])) + } + } + wantAuto := []string{"a", "a+1", "COUNT(*)"} + if strings.Join(oracleCols, ",") != strings.Join(wantAuto, ",") { + t.Errorf("oracle v_auto columns: got %v, want %v", oracleCols, wantAuto) + } + + // v_list: explicit column list wins. + rows2 := oracleRows(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v_list' + ORDER BY ORDINAL_POSITION`) + oracleCols2 := make([]string, 0, len(rows2)) + for _, r := range rows2 { + if len(r) > 0 { + oracleCols2 = append(oracleCols2, asString(r[0])) + } + } + wantList := []string{"x", "y", "z"} + if strings.Join(oracleCols2, ",") != strings.Join(wantList, ",") { + t.Errorf("oracle v_list columns: got %v, want %v", oracleCols2, wantList) + } + + // omni: same expectations. + vAuto := c10OmniView(c, "v_auto") + if vAuto == nil { + t.Error("omni: v_auto not found") + } else { + if strings.Join(vAuto.Columns, ",") != strings.Join(wantAuto, ",") { + t.Errorf("omni v_auto Columns: got %v, want %v", vAuto.Columns, wantAuto) + } + } + vList := c10OmniView(c, "v_list") + if vList == nil { + t.Error("omni: v_list not found") + } else { + if strings.Join(vList.Columns, ",") != strings.Join(wantList, ",") { + t.Errorf("omni v_list Columns: got %v, want %v", vList.Columns, wantList) + } + if !vList.ExplicitColumns { + t.Error("omni v_list ExplicitColumns: got false, want true") + } + } + }) + + // --- 10.7 View updatability is derived from SELECT shape --------------- + t.Run("10_7_is_updatable_derivation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT); + CREATE VIEW v_ok AS SELECT a FROM t; + CREATE VIEW v_distinct AS SELECT DISTINCT a FROM t; + CREATE ALGORITHM=TEMPTABLE VIEW v_temp AS SELECT a FROM t;`) + + for _, tc := range []struct { + name string + want string + }{ + {"v_ok", "YES"}, + {"v_distinct", "NO"}, + {"v_temp", "NO"}, + } { + var isUpd string + oracleScan(t, mc, + `SELECT IS_UPDATABLE FROM information_schema.VIEWS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+tc.name+`'`, + &isUpd) + if isUpd != tc.want { + t.Errorf("oracle %s IS_UPDATABLE: got %q, want %q", tc.name, isUpd, tc.want) + } + + // omni: the View struct has no IsUpdatable field today — see bug queue. + // This stanza documents the absence by asserting the view at least + // exists in the catalog. + v := c10OmniView(c, tc.name) + if v == nil { + t.Errorf("omni: view %s not found", tc.name) + } + } + // Explicit assertion: omni View has no IsUpdatable representation. + // This is the "declared bug": we want omni to carry IsUpdatable per scenario. + t.Error("omni: View struct is missing an IsUpdatable field (scenario 10.7 cannot be asserted positively)") + }) + + // --- 10.8 View column nullability widened vs base columns ------------- + t.Run("10_8_outer_join_nullability_widening", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t1 (id INT NOT NULL, a INT NOT NULL); + CREATE TABLE t2 (id INT NOT NULL, b INT NOT NULL); + CREATE VIEW v AS SELECT t1.a, t2.b FROM t1 LEFT JOIN t2 ON t1.id = t2.id;`) + + // Oracle (empirical, MySQL 8.0.45): t1.a stays NOT NULL (left/preserved + // side of LEFT JOIN), t2.b widens to nullable (right/optional side). + // The SCENARIOS doc text claims `a → YES` but that is incorrect; real + // MySQL only widens the optional side. We assert against the oracle + // ground truth. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME, IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v' + ORDER BY ORDINAL_POSITION`) + gotNullable := map[string]string{} + for _, r := range rows { + if len(r) < 2 { + continue + } + gotNullable[asString(r[0])] = asString(r[1]) + } + if gotNullable["a"] != "NO" { + t.Errorf("oracle view column a IS_NULLABLE: got %q, want NO (left/preserved side)", gotNullable["a"]) + } + if gotNullable["b"] != "YES" { + t.Errorf("oracle view column b IS_NULLABLE: got %q, want YES (right/optional side)", gotNullable["b"]) + } + + // omni: omni's View struct does not carry per-column nullability info. + // The column list (v.Columns) is just names. This is the "declared bug": + // omni view column resolver must track outer-join nullability. + v := c10OmniView(c, "v") + if v == nil { + t.Error("omni: view v not found") + return + } + t.Error("omni: View struct has no per-column nullability; scenario 10.8 cannot be asserted positively") + }) +} diff --git a/tidb/catalog/scenarios_c11_test.go b/tidb/catalog/scenarios_c11_test.go new file mode 100644 index 00000000..30d2d87e --- /dev/null +++ b/tidb/catalog/scenarios_c11_test.go @@ -0,0 +1,335 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C11 covers Section C11 "Trigger defaults" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL on both a real +// MySQL 8.0 container and the omni catalog, then asserts they agree on +// trigger metadata defaults: DEFINER, SQL SECURITY (no INVOKER option), +// charset/collation snapshot, ACTION_ORDER sequencing, NEW/OLD pseudo-row +// access rules, and trigger-on-partitioned-table survival across partition +// mutations. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c11.md. +func TestScenario_C11(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c11OmniExec runs a multi-statement DDL on the omni catalog and returns + // (errored, firstErr). A parse error or any per-statement Error flips + // the bool. Used by scenarios that expect omni to reject DDL. + c11OmniExec := func(c *Catalog, ddl string) (bool, error) { + results, err := c.Exec(ddl, nil) + if err != nil { + return true, err + } + for _, r := range results { + if r.Error != nil { + return true, r.Error + } + } + return false, nil + } + + // --- 11.1 Trigger DEFINER defaults to current user --------------------- + t.Run("11_1_trigger_definer_defaults_to_current_user", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT); +CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW SET NEW.a = NEW.a;`) + + // Oracle: DEFINER populated with session user (typically `root`@`%` + // in the test container). + var definer string + oracleScan(t, mc, + `SELECT DEFINER FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='testdb' AND TRIGGER_NAME='trg'`, + &definer) + if definer == "" { + t.Errorf("oracle: DEFINER should be non-empty for trg") + } + + // omni: trigger stored with non-empty Definer. + db := c.GetDatabase("testdb") + if db == nil { + t.Error("omni: testdb missing") + return + } + trg := db.Triggers[toLower("trg")] + if trg == nil { + t.Error("omni: trigger trg missing from Triggers map") + return + } + if trg.Definer == "" { + t.Errorf("omni: trigger trg Definer should default to a session user, got empty") + } + }) + + // --- 11.2 Trigger SQL SECURITY always DEFINER; no INVOKER option ------- + t.Run("11_2_trigger_sql_security_always_definer", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Valid form: CREATE DEFINER=... TRIGGER must succeed on both sides. + good := `CREATE TABLE t (a INT); +CREATE DEFINER='root'@'%' TRIGGER trg1 BEFORE INSERT ON t FOR EACH ROW SET NEW.a=1;` + runOnBoth(t, mc, c, good) + + // Invalid form: SQL SECURITY INVOKER on a trigger is a grammar error. + bad := `CREATE TRIGGER trg2 SQL SECURITY INVOKER BEFORE INSERT ON t FOR EACH ROW SET NEW.a=1` + _, oracleErr := mc.db.ExecContext(mc.ctx, bad) + if oracleErr == nil { + t.Errorf("oracle: expected ER_PARSE_ERROR for SQL SECURITY INVOKER on trigger, got nil") + } + + // omni: must also reject. A permissive parse here is a bug. + omniErrored, _ := c11OmniExec(c, bad+";") + assertBoolEq(t, "omni rejects SQL SECURITY INVOKER on trigger", omniErrored, true) + + // information_schema.TRIGGERS must NOT expose a SECURITY_TYPE column + // for triggers (unlike VIEWS / ROUTINES). If MySQL ever added one we + // would need to rethink this scenario; assert absence empirically. + var colCount int64 + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='TRIGGERS' + AND COLUMN_NAME='SECURITY_TYPE'`, + &colCount) + if colCount != 0 { + t.Errorf("oracle: information_schema.TRIGGERS has SECURITY_TYPE (unexpected), count=%d", colCount) + } + }) + + // --- 11.3 charset/collation snapshot at trigger creation time ---------- + t.Run("11_3_charset_collation_snapshot", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `SET NAMES utf8mb4; +CREATE TABLE t (a INT); +CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW SET NEW.a = NEW.a;`) + + var cset, ccoll, dbcoll string + oracleScan(t, mc, + `SELECT CHARACTER_SET_CLIENT, COLLATION_CONNECTION, DATABASE_COLLATION + FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='testdb' AND TRIGGER_NAME='trg'`, + &cset, &ccoll, &dbcoll) + if cset == "" || ccoll == "" || dbcoll == "" { + t.Errorf("oracle: expected three non-empty snapshot fields, got (%q,%q,%q)", + cset, ccoll, dbcoll) + } + if !strings.HasPrefix(strings.ToLower(cset), "utf8") { + t.Errorf("oracle: CHARACTER_SET_CLIENT for trg should be utf8*; got %q", cset) + } + + // omni: Trigger struct as of today has no CharacterSetClient / + // CollationConnection / DatabaseCollation fields. Record the gap — + // deparse→reparse cannot currently round-trip session snapshots. + db := c.GetDatabase("testdb") + if db == nil { + t.Error("omni: testdb missing") + return + } + trg := db.Triggers[toLower("trg")] + if trg == nil { + t.Error("omni: trigger trg missing") + return + } + // Assert the core fields omni does track so the subtest has real + // substance, then flag the charset-snapshot gap explicitly. If a + // future patch adds CharacterSetClient / CollationConnection / + // DatabaseCollation fields, tighten the assertion below. + if strings.ToUpper(trg.Timing) != "BEFORE" { + t.Errorf("omni 11.3: trg.Timing=%q, want BEFORE", trg.Timing) + } + if strings.ToUpper(trg.Event) != "INSERT" { + t.Errorf("omni 11.3: trg.Event=%q, want INSERT", trg.Event) + } + if trg.Table != "t" { + t.Errorf("omni 11.3: trg.Table=%q, want t", trg.Table) + } + t.Errorf("omni 11.3: KNOWN GAP — Trigger struct lacks CharacterSetClient/CollationConnection/DatabaseCollation fields; session snapshot cannot round-trip (see scenarios_bug_queue/c11.md)") + }) + + // --- 11.4 ACTION_ORDER default sequencing within (table, timing, event) - + t.Run("11_4_action_order_default_sequencing", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT); +CREATE TRIGGER trg_a BEFORE INSERT ON t FOR EACH ROW SET NEW.a = NEW.a + 1; +CREATE TRIGGER trg_b BEFORE INSERT ON t FOR EACH ROW SET NEW.a = NEW.a + 2; +CREATE TRIGGER trg_c BEFORE INSERT ON t FOR EACH ROW PRECEDES trg_a SET NEW.a = NEW.a + 10;` + runOnBoth(t, mc, c, ddl) + + // Oracle: expect trg_c=1, trg_a=2, trg_b=3 after the PRECEDES splice. + rows := oracleRows(t, mc, + `SELECT TRIGGER_NAME, ACTION_ORDER FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='testdb' ORDER BY ACTION_ORDER`) + if len(rows) != 3 { + t.Errorf("oracle: expected 3 triggers, got %d", len(rows)) + } else { + wantOrder := []struct { + name string + order int64 + }{ + {"trg_c", 1}, + {"trg_a", 2}, + {"trg_b", 3}, + } + for i, w := range wantOrder { + name := asString(rows[i][0]) + var ord int64 + switch v := rows[i][1].(type) { + case int64: + ord = v + case int32: + ord = int64(v) + case int: + ord = int64(v) + } + if name != w.name || ord != w.order { + t.Errorf("oracle row %d: got (%q, %d), want (%q, %d)", + i, name, ord, w.name, w.order) + } + } + } + + // omni: Trigger struct has no ActionOrder field. The best we can + // check today is that all three trigger objects exist and the + // Order info (FOLLOWS/PRECEDES) was captured for trg_c. + db := c.GetDatabase("testdb") + if db == nil { + t.Error("omni: testdb missing") + return + } + for _, name := range []string{"trg_a", "trg_b", "trg_c"} { + if db.Triggers[toLower(name)] == nil { + t.Errorf("omni: trigger %s missing", name) + } + } + if trgC := db.Triggers[toLower("trg_c")]; trgC != nil { + if trgC.Order == nil { + t.Errorf("omni: trigger trg_c should have Order info for PRECEDES trg_a") + } else { + assertBoolEq(t, "omni trg_c Order.Follows (false means PRECEDES)", + trgC.Order.Follows, false) + assertStringEq(t, "omni trg_c Order.TriggerName", + strings.ToLower(trgC.Order.TriggerName), "trg_a") + } + } + }) + + // --- 11.5 NEW/OLD pseudo-row access by event type ---------------------- + t.Run("11_5_new_old_pseudorow_by_event", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Legal cases: NEW in BEFORE INSERT, OLD in BEFORE DELETE, both in UPDATE. + legal := `CREATE TABLE t (a INT); +CREATE TRIGGER t_ins BEFORE INSERT ON t FOR EACH ROW SET NEW.a = NEW.a + 1; +CREATE TRIGGER t_del BEFORE DELETE ON t FOR EACH ROW SET @x = OLD.a; +CREATE TRIGGER t_upd BEFORE UPDATE ON t FOR EACH ROW SET NEW.a = OLD.a + 1;` + runOnBoth(t, mc, c, legal) + + // Illegal: OLD inside an INSERT trigger -> ER_TRG_NO_SUCH_ROW_IN_TRG. + badOldInInsert := `CREATE TRIGGER bad1 AFTER INSERT ON t FOR EACH ROW SET @x = OLD.a` + _, oErr1 := mc.db.ExecContext(mc.ctx, badOldInInsert) + if oErr1 == nil { + t.Errorf("oracle: expected rejection of OLD.a in INSERT trigger, got nil") + } + omniErr1, _ := c11OmniExec(c, badOldInInsert+";") + assertBoolEq(t, "omni rejects OLD.* in INSERT trigger body", omniErr1, true) + + // Illegal: NEW inside a DELETE trigger -> ER_TRG_NO_SUCH_ROW_IN_TRG. + badNewInDelete := `CREATE TRIGGER bad2 AFTER DELETE ON t FOR EACH ROW SET @x = NEW.a` + _, oErr2 := mc.db.ExecContext(mc.ctx, badNewInDelete) + if oErr2 == nil { + t.Errorf("oracle: expected rejection of NEW.a in DELETE trigger, got nil") + } + omniErr2, _ := c11OmniExec(c, badNewInDelete+";") + assertBoolEq(t, "omni rejects NEW.* in DELETE trigger body", omniErr2, true) + + // Illegal: SET NEW.a in AFTER INSERT — NEW is read-only after the row + // is written. + badAfterAssign := `CREATE TRIGGER bad3 AFTER INSERT ON t FOR EACH ROW SET NEW.a = 99` + _, oErr3 := mc.db.ExecContext(mc.ctx, badAfterAssign) + if oErr3 == nil { + t.Errorf("oracle: expected rejection of SET NEW.a in AFTER INSERT trigger, got nil") + } + omniErr3, _ := c11OmniExec(c, badAfterAssign+";") + assertBoolEq(t, "omni rejects writes to NEW.* in AFTER trigger", omniErr3, true) + + // Sanity: legal triggers landed in omni and oracle. + db := c.GetDatabase("testdb") + if db != nil { + for _, name := range []string{"t_ins", "t_del", "t_upd"} { + if db.Triggers[toLower(name)] == nil { + t.Errorf("omni: legal trigger %s missing", name) + } + } + } + var legalCount int64 + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='testdb' AND TRIGGER_NAME IN ('t_ins','t_del','t_upd')`, + &legalCount) + if legalCount != 3 { + t.Errorf("oracle: expected 3 legal triggers, got %d", legalCount) + } + }) + + // --- 11.6 Trigger on partitioned table survives partition mutation ----- + t.Run("11_6_trigger_on_partitioned_table", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT, b INT) PARTITION BY HASH(a) PARTITIONS 4; +CREATE TRIGGER trg BEFORE INSERT ON t FOR EACH ROW SET NEW.b = NEW.a * 2;` + runOnBoth(t, mc, c, ddl) + + // Alter the partition layout. Oracle: trigger survives. + alter := `ALTER TABLE t COALESCE PARTITION 2` + if _, err := mc.db.ExecContext(mc.ctx, alter); err != nil { + t.Errorf("oracle: COALESCE PARTITION failed: %v", err) + } + // omni: execute the same ALTER; if omni's parser does not support + // this syntax yet, record it as part of the bug queue but don't + // abort the subtest. + if _, err := c.Exec(alter+";", nil); err != nil { + t.Logf("omni: ALTER TABLE ... COALESCE PARTITION not yet supported: %v", err) + } + + var trgCount int64 + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.TRIGGERS + WHERE TRIGGER_SCHEMA='testdb' AND EVENT_OBJECT_TABLE='t' AND TRIGGER_NAME='trg'`, + &trgCount) + if trgCount != 1 { + t.Errorf("oracle: trigger trg should survive COALESCE PARTITION, count=%d", trgCount) + } + + // omni: trigger must still be registered against the table, not + // hanging off any per-partition structure. + db := c.GetDatabase("testdb") + if db == nil { + t.Error("omni: testdb missing") + return + } + trg := db.Triggers[toLower("trg")] + if trg == nil { + t.Errorf("omni: trigger trg missing after partition mutation") + return + } + assertStringEq(t, "omni trigger trg.Table", strings.ToLower(trg.Table), "t") + }) +} diff --git a/tidb/catalog/scenarios_c14_test.go b/tidb/catalog/scenarios_c14_test.go new file mode 100644 index 00000000..74a3d3d3 --- /dev/null +++ b/tidb/catalog/scenarios_c14_test.go @@ -0,0 +1,279 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C14 covers section C14 (Constraint enforcement defaults) from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest asserts that both real +// MySQL 8.0 and the omni catalog agree on CHECK-constraint enforcement and +// validation behavior. +// +// Failures in omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c14.md. +func TestScenario_C14(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 14.1 CHECK constraint defaults to ENFORCED ---------------------- + t.Run("14_1_check_defaults_enforced", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, CHECK (a > 0));`) + + // Oracle: information_schema.TABLE_CONSTRAINTS.ENFORCED should be YES. + // Note: information_schema.CHECK_CONSTRAINTS in MySQL 8.0 does NOT + // expose an ENFORCED column — it only has (CONSTRAINT_CATALOG, + // CONSTRAINT_SCHEMA, CONSTRAINT_NAME, CHECK_CLAUSE). The ENFORCED + // metadata lives in TABLE_CONSTRAINTS instead. + var tcEnforced string + oracleScan(t, mc, `SELECT ENFORCED FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND CONSTRAINT_TYPE='CHECK'`, + &tcEnforced) + assertStringEq(t, "oracle TABLE_CONSTRAINTS.ENFORCED", tcEnforced, "YES") + + // SHOW CREATE TABLE must not contain NOT ENFORCED. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(create), "NOT ENFORCED") { + t.Errorf("oracle: SHOW CREATE TABLE unexpectedly contains NOT ENFORCED: %s", create) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + chk := c14FindCheck(tbl, "t_chk_1") + if chk == nil { + t.Errorf("omni: CHECK constraint t_chk_1 missing") + return + } + // Default = ENFORCED → NotEnforced must be false. + assertBoolEq(t, "omni t_chk_1 NotEnforced", chk.NotEnforced, false) + }) + + // --- 14.2 ALTER CHECK NOT ENFORCED / ENFORCED toggles the flag ------- + t.Run("14_2_alter_check_enforcement_toggle", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, CONSTRAINT c_pos CHECK (a > 0)); +ALTER TABLE t ALTER CHECK c_pos NOT ENFORCED;`) + + // Oracle: after NOT ENFORCED, TABLE_CONSTRAINTS.ENFORCED should be NO. + var enforced string + oracleScan(t, mc, `SELECT ENFORCED FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND CONSTRAINT_NAME='c_pos'`, + &enforced) + assertStringEq(t, "oracle after NOT ENFORCED", enforced, "NO") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + chk := c14FindCheck(tbl, "c_pos") + if chk == nil { + t.Errorf("omni: CHECK c_pos missing after initial CREATE+ALTER") + return + } + assertBoolEq(t, "omni c_pos NotEnforced after NOT ENFORCED", chk.NotEnforced, true) + + // Toggle back to ENFORCED and re-check both sides. + runOnBoth(t, mc, c, `ALTER TABLE t ALTER CHECK c_pos ENFORCED;`) + + oracleScan(t, mc, `SELECT ENFORCED FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND CONSTRAINT_NAME='c_pos'`, + &enforced) + assertStringEq(t, "oracle after re-ENFORCED", enforced, "YES") + + tbl2 := c.GetDatabase("testdb").GetTable("t") + chk2 := c14FindCheck(tbl2, "c_pos") + if chk2 == nil { + t.Errorf("omni: CHECK c_pos missing after re-ENFORCED") + return + } + assertBoolEq(t, "omni c_pos NotEnforced after re-ENFORCED", chk2.NotEnforced, false) + }) + + // --- 14.3 STORED generated column + CHECK: predicate evaluated against stored value + t.Run("14_3_check_with_stored_generated_col", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a INT, + g INT AS (a * 2) STORED, + CONSTRAINT c_g_nonneg CHECK (g >= 0) +);`) + + // Oracle: CHECK_CONSTRAINTS row for c_g_nonneg exists with clause + // referencing g. MySQL stores it as (`g` >= 0). + var name, clause string + oracleScan(t, mc, `SELECT CONSTRAINT_NAME, CHECK_CLAUSE FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND CONSTRAINT_NAME='c_g_nonneg'`, + &name, &clause) + assertStringEq(t, "oracle CHECK name", name, "c_g_nonneg") + if !strings.Contains(clause, "g") || !strings.Contains(clause, ">=") { + t.Errorf("oracle: expected CHECK_CLAUSE to reference g and >=, got %q", clause) + } + + // Valid INSERT: a=5 → g=10 passes. + if _, err := mc.db.ExecContext(mc.ctx, `INSERT INTO t (a) VALUES (5)`); err != nil { + t.Errorf("oracle: expected INSERT a=5 to succeed, got %v", err) + } + // Invalid INSERT: a=-1 → g=-2 violates CHECK, MySQL error 3819. + _, err := mc.db.ExecContext(mc.ctx, `INSERT INTO t (a) VALUES (-1)`) + if err == nil { + t.Errorf("oracle: expected ER_CHECK_CONSTRAINT_VIOLATED (3819) for a=-1, got nil") + } else if !strings.Contains(err.Error(), "3819") && + !strings.Contains(strings.ToLower(err.Error()), "check constraint") { + t.Errorf("oracle: expected CHECK violation error, got %v", err) + } + + // omni: CHECK must be registered and reference the generated column g. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + chk := c14FindCheck(tbl, "c_g_nonneg") + if chk == nil { + t.Errorf("omni: CHECK c_g_nonneg missing") + return + } + if !strings.Contains(chk.CheckExpr, "g") { + t.Errorf("omni: CheckExpr should reference g, got %q", chk.CheckExpr) + } + // g column must exist and be stored-generated. + var gcol *Column + for _, col := range tbl.Columns { + if strings.EqualFold(col.Name, "g") { + gcol = col + break + } + } + if gcol == nil { + t.Errorf("omni: generated column g missing") + } + }) + + // --- 14.4 Forbidden constructs in CHECK: subquery, NOW(), user variable + t.Run("14_4_check_forbidden_constructs", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Legal baseline: deterministic built-in CHAR_LENGTH is OK. + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT, CHECK (CHAR_LENGTH(CAST(a AS CHAR)) < 10));`) + + // --- 14.4a Subquery in CHECK → ER_CHECK_CONSTRAINT_NOT_ALLOWED_CONTEXT (3812). + // Need a referenced table first. + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE other (id INT)`); err != nil { + t.Errorf("oracle setup: CREATE other failed: %v", err) + } + _, errSub := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t2 (a INT, CHECK (a IN (SELECT id FROM other)))`) + if errSub == nil { + t.Errorf("oracle: expected subquery CHECK to fail, got nil") + } else { + // MySQL 8.0 observed to return 3815/3814 ("disallowed function") + // for subquery-in-CHECK; older docs mention 3812 as well. Accept + // any check-constraint rejection error. + if !c14IsCheckRejection(errSub) { + t.Errorf("oracle: expected CHECK-rejection error, got %v", errSub) + } + } + c14AssertOmniRejects(t, c, "subquery CHECK", + `CREATE TABLE t2 (a INT, CHECK (a IN (SELECT id FROM other)));`) + + // --- 14.4b NOW() in CHECK → ER_CHECK_CONSTRAINT_NAMED_FUNCTION_IS_NOT_ALLOWED (3815). + _, errNow := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t3 (a INT, CHECK (a < NOW()))`) + if errNow == nil { + t.Errorf("oracle: expected NOW() CHECK to fail, got nil") + } else if !c14IsCheckRejection(errNow) { + t.Errorf("oracle: expected NOW() CHECK-rejection error, got %v", errNow) + } + c14AssertOmniRejects(t, c, "NOW() CHECK", + `CREATE TABLE t3 (a INT, CHECK (a < NOW()));`) + + // --- 14.4c User variable in CHECK → ER_CHECK_CONSTRAINT_VARIABLES (3813). + _, errVar := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t4 (a INT, CHECK (a < @max))`) + if errVar == nil { + t.Errorf("oracle: expected user-var CHECK to fail, got nil") + } else if !c14IsCheckRejection(errVar) { + t.Errorf("oracle: expected user-var CHECK-rejection error, got %v", errVar) + } + c14AssertOmniRejects(t, c, "user-var CHECK", + `CREATE TABLE t4 (a INT, CHECK (a < @max));`) + + // --- 14.4d RAND() in CHECK → ER_CHECK_CONSTRAINT_NAMED_FUNCTION_IS_NOT_ALLOWED (3815). + _, errRand := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t5 (a INT, CHECK (a < RAND()))`) + if errRand == nil { + t.Errorf("oracle: expected RAND() CHECK to fail, got nil") + } else if !c14IsCheckRejection(errRand) { + t.Errorf("oracle: expected RAND() CHECK-rejection error, got %v", errRand) + } + c14AssertOmniRejects(t, c, "RAND() CHECK", + `CREATE TABLE t5 (a INT, CHECK (a < RAND()));`) + }) +} + +// --- section-local helpers ------------------------------------------------ + +// c14FindCheck returns the first CHECK constraint on the table whose name +// matches (case-insensitive), or nil. +func c14FindCheck(tbl *Table, name string) *Constraint { + for _, con := range tbl.Constraints { + if con.Type == ConCheck && strings.EqualFold(con.Name, name) { + return con + } + } + return nil +} + +// c14IsCheckRejection reports whether the given error looks like a MySQL +// rejection of a forbidden construct in a CHECK clause. Accepts any of the +// documented error codes (3812–3815) or the strings "not allowed", +// "disallowed", "subquer", or "variable". +func c14IsCheckRejection(err error) bool { + if err == nil { + return false + } + msg := err.Error() + lower := strings.ToLower(msg) + for _, code := range []string{"3812", "3813", "3814", "3815"} { + if strings.Contains(msg, code) { + return true + } + } + return strings.Contains(lower, "not allowed") || + strings.Contains(lower, "disallowed") || + strings.Contains(lower, "subquer") || + strings.Contains(lower, "check constraint") +} + +// c14AssertOmniRejects executes the given DDL against the omni catalog and +// reports a test error if execution succeeds with no error. Uses t.Error so +// subsequent assertions continue to run. The catalog state after a rejected +// DDL is considered undefined — callers should not rely on it. +func c14AssertOmniRejects(t *testing.T, c *Catalog, label, ddl string) { + t.Helper() + results, err := c.Exec(ddl, nil) + if err != nil { + return // parse error counts as rejection + } + for _, r := range results { + if r.Error != nil { + return // exec error counts as rejection + } + } + t.Errorf("omni: expected %s to be rejected, but execution succeeded", label) +} diff --git a/tidb/catalog/scenarios_c15_test.go b/tidb/catalog/scenarios_c15_test.go new file mode 100644 index 00000000..2abedcae --- /dev/null +++ b/tidb/catalog/scenarios_c15_test.go @@ -0,0 +1,150 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C15 covers Section C15 "Column positioning defaults" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL against both +// a real MySQL 8.0 container and the omni catalog, then asserts that both +// agree on the resulting column ordering. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c15.md. +func TestScenario_C15(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c15OracleColumnOrder fetches the column names from MySQL's + // information_schema.COLUMNS for testdb., ordered by + // ORDINAL_POSITION. + c15OracleColumnOrder := func(t *testing.T, table string) []string { + t.Helper() + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+table+`' + ORDER BY ORDINAL_POSITION`) + out := make([]string, 0, len(rows)) + for _, r := range rows { + if len(r) == 0 { + continue + } + out = append(out, asString(r[0])) + } + return out + } + + // c15OmniColumnOrder fetches column names from the omni catalog in + // declaration order (Table.Columns slice order). + c15OmniColumnOrder := func(c *Catalog, table string) []string { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + tbl := db.GetTable(table) + if tbl == nil { + return nil + } + out := make([]string, 0, len(tbl.Columns)) + for _, col := range tbl.Columns { + out = append(out, col.Name) + } + return out + } + + c15AssertOrder := func(t *testing.T, label string, got, want []string) { + t.Helper() + if strings.Join(got, ",") != strings.Join(want, ",") { + t.Errorf("%s: got %v, want %v", label, got, want) + } + } + + // --- 15.1 ADD COLUMN appends to end ----------------------------------- + t.Run("15_1_add_column_appends_end", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT); + ALTER TABLE t ADD COLUMN c INT;`) + + want := []string{"a", "b", "c"} + c15AssertOrder(t, "oracle order after ADD COLUMN c", c15OracleColumnOrder(t, "t"), want) + c15AssertOrder(t, "omni order after ADD COLUMN c", c15OmniColumnOrder(c, "t"), want) + }) + + // --- 15.2 ADD COLUMN ... FIRST ---------------------------------------- + t.Run("15_2_add_column_first", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT); + ALTER TABLE t ADD COLUMN c INT FIRST;`) + + want := []string{"c", "a", "b"} + c15AssertOrder(t, "oracle order after ADD COLUMN c FIRST", c15OracleColumnOrder(t, "t"), want) + c15AssertOrder(t, "omni order after ADD COLUMN c FIRST", c15OmniColumnOrder(c, "t"), want) + }) + + // --- 15.3 ADD COLUMN ... AFTER col ------------------------------------ + t.Run("15_3_add_column_after", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, c INT); + ALTER TABLE t ADD COLUMN x INT AFTER a;`) + + want := []string{"a", "x", "b", "c"} + c15AssertOrder(t, "oracle order after ADD COLUMN x AFTER a", c15OracleColumnOrder(t, "t"), want) + c15AssertOrder(t, "omni order after ADD COLUMN x AFTER a", c15OmniColumnOrder(c, "t"), want) + }) + + // --- 15.4 MODIFY retains position unless FIRST/AFTER ------------------ + t.Run("15_4_modify_retains_position", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, c INT); + ALTER TABLE t MODIFY b BIGINT;`) + + want := []string{"a", "b", "c"} + c15AssertOrder(t, "oracle order after MODIFY b BIGINT", c15OracleColumnOrder(t, "t"), want) + c15AssertOrder(t, "omni order after MODIFY b BIGINT", c15OmniColumnOrder(c, "t"), want) + + // Also verify the type actually changed on both sides. + var dataType string + oracleScan(t, mc, + `SELECT DATA_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, + &dataType) + if dataType != "bigint" { + t.Errorf("oracle DATA_TYPE for b: got %q, want %q", dataType, "bigint") + } + + if tbl := c.GetDatabase("testdb").GetTable("t"); tbl != nil { + if col := tbl.GetColumn("b"); col != nil { + if !strings.EqualFold(col.DataType, "bigint") { + t.Errorf("omni DataType for b: got %q, want %q", col.DataType, "bigint") + } + } else { + t.Errorf("omni: column b missing after MODIFY") + } + } + }) + + // --- 15.5 multiple ADD COLUMN in one ALTER, left-to-right resolution --- + t.Run("15_5_multi_add_left_to_right", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT); + ALTER TABLE t ADD COLUMN x INT AFTER a, ADD COLUMN y INT AFTER x;`) + + want := []string{"a", "x", "y", "b"} + c15AssertOrder(t, "oracle order after multi-ADD", c15OracleColumnOrder(t, "t"), want) + c15AssertOrder(t, "omni order after multi-ADD", c15OmniColumnOrder(c, "t"), want) + }) +} diff --git a/tidb/catalog/scenarios_c16_test.go b/tidb/catalog/scenarios_c16_test.go new file mode 100644 index 00000000..8a1f1e5b --- /dev/null +++ b/tidb/catalog/scenarios_c16_test.go @@ -0,0 +1,591 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C16 covers section C16 (Date/time function precision defaults) +// from SCENARIOS-mysql-implicit-behavior.md. 12 scenarios, each asserted on +// both MySQL 8.0 container and the omni catalog. Failures in omni assertions +// are documented in scenarios_bug_queue/c16.md (NOT proof failures). +func TestScenario_C16(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // Helper: fetch IS.COLUMNS.DATETIME_PRECISION for testdb.t.. + c16OracleDatetimePrecision := func(t *testing.T, col string) (int, bool) { + t.Helper() + var v any + row := mc.db.QueryRowContext(mc.ctx, + `SELECT DATETIME_PRECISION FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME=?`, col) + if err := row.Scan(&v); err != nil { + t.Errorf("oracle DATETIME_PRECISION(%q): %v", col, err) + return 0, false + } + if v == nil { + return 0, false // NULL (e.g. DATE columns) + } + switch x := v.(type) { + case int64: + return int(x), true + case []byte: + var n int + for _, c := range x { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n, true + } + return 0, false + } + + // Helper: run a DDL that we expect to fail on MySQL. Returns the errors + // from each side so the caller can assert. + c16RunExpectError := func(t *testing.T, c *Catalog, ddl string) (oracleErr, omniErr error) { + t.Helper() + _, oracleErr = mc.db.ExecContext(mc.ctx, ddl) + results, parseErr := c.Exec(ddl, nil) + if parseErr != nil { + omniErr = parseErr + return + } + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + return + } + } + return + } + + // ---- 16.1 NOW()/CURRENT_TIMESTAMP precision defaults to 0 ------------- + t.Run("16_1_NOW_precision_default_0", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (ts DATETIME DEFAULT NOW())`) + + // Oracle: SHOW CREATE TABLE should render DEFAULT CURRENT_TIMESTAMP + // without any (n) suffix. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + lo := strings.ToLower(create) + if !strings.Contains(lo, "default current_timestamp") { + t.Errorf("oracle SHOW CREATE TABLE missing DEFAULT CURRENT_TIMESTAMP: %s", create) + } + if strings.Contains(lo, "current_timestamp(") { + t.Errorf("oracle SHOW CREATE TABLE unexpectedly has fsp suffix: %s", create) + } + + // Oracle: DATETIME_PRECISION = 0 for plain DATETIME col. + if p, ok := c16OracleDatetimePrecision(t, "ts"); !ok { + t.Errorf("oracle DATETIME_PRECISION NULL for plain DATETIME") + } else if p != 0 { + t.Errorf("oracle DATETIME_PRECISION: got %d, want 0", p) + } + + // omni: column must exist with a default referencing CURRENT_TIMESTAMP + // (any rendering: "now()", "CURRENT_TIMESTAMP", etc). + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t not found") + return + } + col := tbl.GetColumn("ts") + if col == nil { + t.Errorf("omni: column ts not found") + return + } + if col.Default == nil { + t.Errorf("omni: DEFAULT missing for ts") + } else { + lo := strings.ToLower(*col.Default) + if !strings.Contains(lo, "current_timestamp") && !strings.Contains(lo, "now") { + t.Errorf("omni: DEFAULT = %q, expected CURRENT_TIMESTAMP/NOW form", *col.Default) + } + if strings.Contains(lo, "(") && strings.Contains(lo, ")") && !strings.Contains(lo, "()") { + // Contains some fsp number — unexpected for plain DATETIME. + t.Errorf("omni: DEFAULT = %q, expected no fsp suffix", *col.Default) + } + } + }) + + // ---- 16.2 NOW(N) range 0..6; NOW(7) rejected -------------------------- + t.Run("16_2_NOW_explicit_precision_range", func(t *testing.T) { + scenarioReset(t, mc) + + // 0..6 should all work at runtime via SELECT LENGTH(NOW(n)). + wantLen := map[int]int{0: 19, 1: 21, 2: 22, 3: 23, 4: 24, 5: 25, 6: 26} + for n, want := range wantLen { + var got int + row := mc.db.QueryRowContext(mc.ctx, + `SELECT LENGTH(NOW(`+itoaC16(n)+`))`) + if err := row.Scan(&got); err != nil { + t.Errorf("oracle LENGTH(NOW(%d)): %v", n, err) + continue + } + if got != want { + t.Errorf("oracle LENGTH(NOW(%d)): got %d, want %d", n, got, want) + } + } + + // NOW(7) must be rejected with ER_TOO_BIG_PRECISION (1426). + _, err := mc.db.ExecContext(mc.ctx, `DO NOW(7)`) + if err == nil { + t.Errorf("oracle: NOW(7) unexpectedly accepted") + } else if !strings.Contains(err.Error(), "1426") && + !strings.Contains(strings.ToLower(err.Error()), "precision") { + t.Errorf("oracle: NOW(7) unexpected error: %v", err) + } + + // omni: NOW(7) as DEFAULT should be rejected (strictness gap if it isn't). + c := scenarioNewCatalog(t) + ddl := `CREATE TABLE t (a DATETIME(7))` + results, parseErr := c.Exec(ddl, nil) + var omniErr error + if parseErr != nil { + omniErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — DATETIME(7) should be rejected (fsp > 6), see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.3 CURDATE / CURRENT_DATE / UTC_DATE take no precision arg ---- + t.Run("16_3_CURDATE_no_precision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Oracle: CURDATE() works, CURDATE(6) is a parse error. + if _, err := mc.db.ExecContext(mc.ctx, `DO CURDATE()`); err != nil { + t.Errorf("oracle CURDATE() failed: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, `DO CURDATE(6)`); err == nil { + t.Errorf("oracle: CURDATE(6) unexpectedly accepted") + } + + // Table with default CURDATE() works and has NULL DATETIME_PRECISION. + runOnBoth(t, mc, c, `CREATE TABLE t (d DATE DEFAULT (CURDATE()))`) + if _, ok := c16OracleDatetimePrecision(t, "d"); ok { + t.Errorf("oracle: DATE column should have NULL DATETIME_PRECISION") + } + + // omni: CURDATE(6) as DEFAULT should be rejected by parser. + c2 := scenarioNewCatalog(t) + results, parseErr := c2.Exec(`SELECT CURDATE(6)`, nil) + omniAccepted := parseErr == nil + if parseErr == nil { + for _, r := range results { + if r.Error != nil { + omniAccepted = false + break + } + } + } + if omniAccepted { + t.Errorf("omni: KNOWN BUG — CURDATE(6) should fail parse, see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.4 CURTIME / CURRENT_TIME / UTC_TIME precision defaults to 0 -- + t.Run("16_4_CURTIME_precision_default_0", func(t *testing.T) { + scenarioReset(t, mc) + + var l0, l6 int + row := mc.db.QueryRowContext(mc.ctx, + `SELECT LENGTH(CURTIME()), LENGTH(CURTIME(6))`) + if err := row.Scan(&l0, &l6); err != nil { + t.Errorf("oracle CURTIME length scan: %v", err) + } + if l0 != 8 { + t.Errorf("oracle LENGTH(CURTIME())=%d, want 8", l0) + } + if l6 != 15 { + t.Errorf("oracle LENGTH(CURTIME(6))=%d, want 15", l6) + } + + // CURTIME(7) rejected + if _, err := mc.db.ExecContext(mc.ctx, `DO CURTIME(7)`); err == nil { + t.Errorf("oracle: CURTIME(7) unexpectedly accepted") + } + + // omni: parses a SELECT CURTIME() successfully. Rejects CURTIME(7)? + c := scenarioNewCatalog(t) + results, parseErr := c.Exec(`SELECT CURTIME(7)`, nil) + omniAccepted := parseErr == nil + if parseErr == nil { + for _, r := range results { + if r.Error != nil { + omniAccepted = false + break + } + } + } + if omniAccepted { + t.Errorf("omni: KNOWN BUG — CURTIME(7) should be rejected (ER_TOO_BIG_PRECISION), see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.5 SYSDATE cannot be used as DEFAULT -------------------------- + t.Run("16_5_SYSDATE_not_allowed_as_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // SYSDATE is not a valid bareword in DEFAULT (parser rejects as syntax + // error) and SYSDATE() also fails add_field (not Item_func_now). Both + // forms must be rejected by MySQL. + ddl := `CREATE TABLE t (a DATETIME DEFAULT SYSDATE())` + oracleErr, omniErr := c16RunExpectError(t, c, ddl) + if oracleErr == nil { + t.Errorf("oracle: DATETIME DEFAULT SYSDATE() unexpectedly accepted") + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — DATETIME DEFAULT SYSDATE should error (not a NOW_FUNC), see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.6 UTC_TIMESTAMP precision defaults to 0; not valid as DEFAULT - + t.Run("16_6_UTC_TIMESTAMP_precision_and_default", func(t *testing.T) { + scenarioReset(t, mc) + + var l0, l3 int + row := mc.db.QueryRowContext(mc.ctx, + `SELECT LENGTH(UTC_TIMESTAMP()), LENGTH(UTC_TIMESTAMP(3))`) + if err := row.Scan(&l0, &l3); err != nil { + t.Errorf("oracle UTC_TIMESTAMP length scan: %v", err) + } + if l0 != 19 { + t.Errorf("oracle LENGTH(UTC_TIMESTAMP())=%d, want 19", l0) + } + if l3 != 23 { + t.Errorf("oracle LENGTH(UTC_TIMESTAMP(3))=%d, want 23", l3) + } + + // UTC_TIMESTAMP as DEFAULT → ER_INVALID_DEFAULT on oracle. + c := scenarioNewCatalog(t) + ddl := `CREATE TABLE t (a DATETIME DEFAULT UTC_TIMESTAMP)` + oracleErr, omniErr := c16RunExpectError(t, c, ddl) + if oracleErr == nil { + t.Errorf("oracle: DATETIME DEFAULT UTC_TIMESTAMP unexpectedly accepted") + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — DATETIME DEFAULT UTC_TIMESTAMP should error, see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.7 UNIX_TIMESTAMP return type depends on arg ------------------ + t.Run("16_7_UNIX_TIMESTAMP_return_type", func(t *testing.T) { + scenarioReset(t, mc) + + // Oracle: zero-arg UNIX_TIMESTAMP() returns BIGINT UNSIGNED; + // UNIX_TIMESTAMP(NOW(6)) returns DECIMAL with scale 6. Observe via + // the data type MySQL assigns to a view column. + if _, err := mc.db.ExecContext(mc.ctx, + `CREATE VIEW v AS SELECT UNIX_TIMESTAMP() AS u0, UNIX_TIMESTAMP(NOW(6)) AS u6`); err != nil { + t.Errorf("oracle CREATE VIEW v: %v", err) + return + } + var t0, t6 string + var s6 any + row := mc.db.QueryRowContext(mc.ctx, + `SELECT DATA_TYPE, NUMERIC_SCALE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v' AND COLUMN_NAME='u6'`) + if err := row.Scan(&t6, &s6); err != nil { + t.Errorf("oracle u6 type scan: %v", err) + } else { + if strings.ToLower(t6) != "decimal" { + t.Errorf("oracle u6 DATA_TYPE: got %q, want decimal", t6) + } + // NUMERIC_SCALE may be int64 or []byte. + scale := 0 + switch x := s6.(type) { + case int64: + scale = int(x) + case []byte: + for _, c := range x { + if c >= '0' && c <= '9' { + scale = scale*10 + int(c-'0') + } + } + } + if scale != 6 { + t.Errorf("oracle u6 NUMERIC_SCALE: got %d, want 6", scale) + } + } + row = mc.db.QueryRowContext(mc.ctx, + `SELECT DATA_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='v' AND COLUMN_NAME='u0'`) + if err := row.Scan(&t0); err != nil { + t.Errorf("oracle u0 type scan: %v", err) + } else if strings.ToLower(t0) != "bigint" { + t.Errorf("oracle u0 DATA_TYPE: got %q, want bigint", t0) + } + + // omni side: parse the same view. If omni cannot derive a type for + // UNIX_TIMESTAMP at all, that is a gap documented in c16.md. + c := scenarioNewCatalog(t) + results, parseErr := c.Exec( + `CREATE VIEW v AS SELECT UNIX_TIMESTAMP() AS u0, UNIX_TIMESTAMP(NOW(6)) AS u6`, nil) + var omniErr error + if parseErr != nil { + omniErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr != nil { + t.Errorf("omni: KNOWN GAP — UNIX_TIMESTAMP CREATE VIEW failed: %v", omniErr) + } + }) + + // ---- 16.8 DATETIME(N) DEFAULT NOW() fsp mismatch → ER_INVALID_DEFAULT - + t.Run("16_8_Datetime_fsp_mismatch_DEFAULT", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + cases := []string{ + `CREATE TABLE t (a DATETIME(6) DEFAULT NOW())`, + `CREATE TABLE t (a DATETIME(3) DEFAULT NOW(6))`, + } + for _, ddl := range cases { + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + oracleErr, omniErr := c16RunExpectError(t, c2, ddl) + if oracleErr == nil { + t.Errorf("oracle: %q unexpectedly accepted", ddl) + } else if !strings.Contains(oracleErr.Error(), "1067") && + !strings.Contains(strings.ToLower(oracleErr.Error()), "invalid default") { + t.Errorf("oracle: %q unexpected error: %v", ddl, oracleErr) + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — %q should error (fsp mismatch), see scenarios_bug_queue/c16.md", ddl) + } + } + + // And the valid pair must succeed on both. + scenarioReset(t, mc) + c3 := scenarioNewCatalog(t) + runOnBoth(t, mc, c3, `CREATE TABLE t (a DATETIME(6) DEFAULT NOW(6))`) + _ = c + }) + + // ---- 16.9 ON UPDATE NOW(N) must match column fsp --------------------- + t.Run("16_9_On_update_fsp_mismatch", func(t *testing.T) { + // DATETIME(6) ON UPDATE NOW() → fsp mismatch (0 vs 6). + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + ddl := `CREATE TABLE t (a DATETIME(6) DEFAULT NOW(6) ON UPDATE NOW())` + oracleErr, omniErr := c16RunExpectError(t, c, ddl) + if oracleErr == nil { + t.Errorf("oracle: %q unexpectedly accepted", ddl) + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — ON UPDATE NOW() fsp 0 on DATETIME(6) should error, see scenarios_bug_queue/c16.md") + } + + // DATE ON UPDATE NOW() → not allowed (ON UPDATE only on TIMESTAMP/DATETIME). + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + ddl2 := `CREATE TABLE t (a DATE ON UPDATE NOW())` + oErr, mErr := c16RunExpectError(t, c2, ddl2) + if oErr == nil { + t.Errorf("oracle: %q unexpectedly accepted", ddl2) + } + if mErr == nil { + t.Errorf("omni: KNOWN BUG — DATE ON UPDATE NOW() should error, see scenarios_bug_queue/c16.md") + } + }) + + // ---- 16.10 DATETIME storage & SHOW CREATE omits (0) ------------------ + t.Run("16_10_Datetime_fsp_round_trip", func(t *testing.T) { + // Verify SHOW CREATE renders DATETIME (no suffix) for fsp=0, DATETIME(3) + // for fsp=3, and the omni catalog preserves the same column type. + type tc struct { + decl string + wantShow string + wantPrec int + } + cases := []tc{ + {"DATETIME", "datetime", 0}, + {"DATETIME(0)", "datetime", 0}, + {"DATETIME(3)", "datetime(3)", 3}, + {"DATETIME(6)", "datetime(6)", 6}, + {"TIMESTAMP(6)", "timestamp(6)", 6}, + } + for _, k := range cases { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + // Use explicit_defaults_for_timestamp so TIMESTAMP doesn't pick up promotion. + if _, err := mc.db.ExecContext(mc.ctx, `SET SESSION explicit_defaults_for_timestamp=1`); err != nil { + t.Errorf("oracle SET: %v", err) + } + ddl := `CREATE TABLE t (a ` + k.decl + ` NULL)` + runOnBoth(t, mc, c, ddl) + + show := strings.ToLower(oracleShow(t, mc, "SHOW CREATE TABLE t")) + if !strings.Contains(show, "`a` "+k.wantShow) { + t.Errorf("oracle SHOW CREATE for %q: missing %q in %s", k.decl, k.wantShow, show) + } + // DATETIME(0) must NOT appear. + if strings.Contains(show, "datetime(0)") || strings.Contains(show, "timestamp(0)") { + t.Errorf("oracle SHOW CREATE for %q: (0) suffix not elided: %s", k.decl, show) + } + + if p, ok := c16OracleDatetimePrecision(t, "a"); !ok { + t.Errorf("oracle DATETIME_PRECISION NULL for %q", k.decl) + } else if p != k.wantPrec { + t.Errorf("oracle DATETIME_PRECISION for %q: got %d, want %d", k.decl, p, k.wantPrec) + } + + // omni: column type should match the rendered form. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table missing for %q", k.decl) + continue + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing for %q", k.decl) + continue + } + omniType := strings.ToLower(col.ColumnType) + if omniType == "" { + omniType = strings.ToLower(col.DataType) + } + if !strings.Contains(omniType, k.wantShow) { + t.Errorf("omni: ColumnType=%q for %q, want containing %q", omniType, k.decl, k.wantShow) + } + if strings.Contains(omniType, "(0)") { + t.Errorf("omni: KNOWN BUG — %q kept (0) suffix: %q", k.decl, omniType) + } + } + }) + + // ---- 16.11 YEAR(N) deprecation — only YEAR(4) accepted ---------------- + t.Run("16_11_YEAR_normalization", func(t *testing.T) { + // YEAR(2), YEAR(3), YEAR(5) → ER_INVALID_YEAR_COLUMN_LENGTH (1818). + for _, n := range []string{"2", "3", "5"} { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + ddl := `CREATE TABLE t (y YEAR(` + n + `))` + oErr, mErr := c16RunExpectError(t, c, ddl) + if oErr == nil { + t.Errorf("oracle: YEAR(%s) unexpectedly accepted", n) + } + if mErr == nil { + t.Errorf("omni: KNOWN BUG — YEAR(%s) should error (ER_INVALID_YEAR_COLUMN_LENGTH), see scenarios_bug_queue/c16.md", n) + } + } + + // YEAR(4) normalized to YEAR in SHOW CREATE. + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (y YEAR(4))`) + show := strings.ToLower(oracleShow(t, mc, "SHOW CREATE TABLE t")) + if !strings.Contains(show, "`y` year") || strings.Contains(show, "year(4)") { + t.Errorf("oracle SHOW CREATE: YEAR(4) not normalized: %s", show) + } + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl != nil { + if col := tbl.GetColumn("y"); col != nil { + omniType := strings.ToLower(col.ColumnType) + if omniType == "" { + omniType = strings.ToLower(col.DataType) + } + if strings.Contains(omniType, "year(4)") || strings.Contains(omniType, "year(2)") { + t.Errorf("omni: KNOWN BUG — YEAR(4) not normalized: %q", omniType) + } + } + } + }) + + // ---- 16.12 TIMESTAMP first-column promotion carries column fsp ------- + // + // NOTE: asymmetric scenario. The session variable + // `explicit_defaults_for_timestamp=0` only affects the MySQL oracle — + // omni has no session-variable model today. The omni-side assertions + // below are tagged "KNOWN GAP" and are expected to fail in either + // direction: omni's promotion path either doesn't honor the oracle's + // session state at all (today's behavior) or eventually will honor + // session vars and then match automatically. If omni starts tracking + // session vars, revisit this test to mirror the SET on both sides. + t.Run("16_12_Timestamp_promotion_fsp", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + // Default explicit_defaults_for_timestamp=ON on MySQL 8.0, so turn it + // off to trigger implicit promotion on the oracle side. + if _, err := mc.db.ExecContext(mc.ctx, `SET SESSION explicit_defaults_for_timestamp=0`); err != nil { + t.Errorf("oracle SET: %v", err) + } + runOnBoth(t, mc, c, `CREATE TABLE t (ts TIMESTAMP(3) NOT NULL)`) + + show := strings.ToLower(oracleShow(t, mc, "SHOW CREATE TABLE t")) + if !strings.Contains(show, "timestamp(3)") { + t.Errorf("oracle SHOW CREATE: missing timestamp(3): %s", show) + } + if !strings.Contains(show, "default current_timestamp(3)") { + t.Errorf("oracle SHOW CREATE: missing DEFAULT CURRENT_TIMESTAMP(3): %s", show) + } + if !strings.Contains(show, "on update current_timestamp(3)") { + t.Errorf("oracle SHOW CREATE: missing ON UPDATE CURRENT_TIMESTAMP(3): %s", show) + } + + // omni side: check promotion produced matching fsp on DEFAULT and ON UPDATE. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("ts") + if col == nil { + t.Errorf("omni: column ts missing") + return + } + if col.Default == nil { + t.Errorf("omni: KNOWN GAP — expected DEFAULT CURRENT_TIMESTAMP(3) after first-col TIMESTAMP promotion, see scenarios_bug_queue/c16.md") + } else { + lo := strings.ToLower(*col.Default) + if !strings.Contains(lo, "current_timestamp(3)") && !strings.Contains(lo, "now(3)") { + t.Errorf("omni: KNOWN BUG — DEFAULT = %q, expected CURRENT_TIMESTAMP(3), see scenarios_bug_queue/c16.md", *col.Default) + } + } + if col.OnUpdate == "" { + t.Errorf("omni: KNOWN GAP — expected ON UPDATE CURRENT_TIMESTAMP(3) after promotion, see scenarios_bug_queue/c16.md") + } else if !strings.Contains(strings.ToLower(col.OnUpdate), "current_timestamp(3)") && + !strings.Contains(strings.ToLower(col.OnUpdate), "now(3)") { + t.Errorf("omni: KNOWN BUG — ON UPDATE = %q, expected CURRENT_TIMESTAMP(3)", col.OnUpdate) + } + }) +} + +// itoaC16 is a tiny local helper that avoids importing strconv just for +// TestScenario_C16. Only used with small non-negative integers. +func itoaC16(n int) string { + if n == 0 { + return "0" + } + var buf [8]byte + i := len(buf) + for n > 0 { + i-- + buf[i] = byte('0' + n%10) + n /= 10 + } + return string(buf[i:]) +} diff --git a/tidb/catalog/scenarios_c17_test.go b/tidb/catalog/scenarios_c17_test.go new file mode 100644 index 00000000..3d41a97e --- /dev/null +++ b/tidb/catalog/scenarios_c17_test.go @@ -0,0 +1,390 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C17 covers section C17 (String function charset / collation +// propagation) from SCENARIOS-mysql-implicit-behavior.md. 8 scenarios, each +// asserted on both a MySQL 8.0 container and the omni catalog. Every C17 +// scenario is expected to currently fail on omni because analyze_expr.go / +// function_types.go do not track charset, collation, or derivation +// coercibility. Each omni failure is documented in scenarios_bug_queue/c17.md +// and reported via t.Errorf as KNOWN BUG so proof stays compile-clean. +func TestScenario_C17(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c17OracleViewCol fetches (CHARACTER_SET_NAME, COLLATION_NAME) for a + // view column, tolerating NULL as "". + c17OracleViewCol := func(t *testing.T, view, col string) (charset, collation string, ok bool) { + t.Helper() + var cs, co any + row := mc.db.QueryRowContext(mc.ctx, + `SELECT CHARACTER_SET_NAME, COLLATION_NAME + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME=? AND COLUMN_NAME=?`, + view, col) + if err := row.Scan(&cs, &co); err != nil { + t.Errorf("oracle view col (%s.%s): %v", view, col, err) + return "", "", false + } + return asString(cs), asString(co), true + } + + // c17OracleViewMaxLen returns CHARACTER_MAXIMUM_LENGTH for a view column. + c17OracleViewMaxLen := func(t *testing.T, view, col string) (int, bool) { + t.Helper() + var v any + row := mc.db.QueryRowContext(mc.ctx, + `SELECT CHARACTER_MAXIMUM_LENGTH FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME=? AND COLUMN_NAME=?`, + view, col) + if err := row.Scan(&v); err != nil { + t.Errorf("oracle max_len (%s.%s): %v", view, col, err) + return 0, false + } + switch x := v.(type) { + case nil: + return 0, false + case int64: + return int(x), true + case []byte: + n := 0 + for _, b := range x { + if b >= '0' && b <= '9' { + n = n*10 + int(b-'0') + } + } + return n, true + } + return 0, false + } + + // c17RunExpectError runs a DDL expected to fail and returns both errors. + c17RunExpectError := func(t *testing.T, c *Catalog, ddl string) (oracleErr, omniErr error) { + t.Helper() + _, oracleErr = mc.db.ExecContext(mc.ctx, ddl) + results, parseErr := c.Exec(ddl, nil) + if parseErr != nil { + omniErr = parseErr + return + } + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + return + } + } + return + } + + // c17OmniViewExists checks whether omni's catalog has a view by name. + c17OmniViewExists := func(c *Catalog, view string) bool { + db := c.GetDatabase("testdb") + if db == nil { + return false + } + _, ok := db.Views[toLower(view)] + return ok + } + + // ---- 17.1 CONCAT identical charset/collation -------------------------- + t.Run("17_1_CONCAT_same_charset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci + )`) + runOnBoth(t, mc, c, `CREATE VIEW v1 AS SELECT CONCAT(a, b) AS c FROM t`) + + cs, co, ok := c17OracleViewCol(t, "v1", "c") + if ok { + assertStringEq(t, "oracle v1.c CHARACTER_SET_NAME", cs, "utf8mb4") + assertStringEq(t, "oracle v1.c COLLATION_NAME", co, "utf8mb4_0900_ai_ci") + } + if ml, ok := c17OracleViewMaxLen(t, "v1", "c"); ok { + assertIntEq(t, "oracle v1.c CHARACTER_MAXIMUM_LENGTH", ml, 20) + } + + // omni: no charset metadata on the view target list (KNOWN GAP). + if !c17OmniViewExists(c, "v1") { + t.Errorf("omni: view v1 not created") + } else { + t.Errorf("omni: KNOWN GAP — CONCAT result carries no charset/collation metadata (17.1), see scenarios_bug_queue/c17.md") + } + }) + + // ---- 17.2 CONCAT mixing latin1 + utf8mb4 ------------------------------ + t.Run("17_2_CONCAT_latin1_utf8mb4_superset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET latin1 COLLATE latin1_swedish_ci + )`) + runOnBoth(t, mc, c, `CREATE VIEW v2 AS SELECT CONCAT(a, b) AS c FROM t`) + + cs, co, ok := c17OracleViewCol(t, "v2", "c") + if ok { + assertStringEq(t, "oracle v2.c CHARACTER_SET_NAME", cs, "utf8mb4") + assertStringEq(t, "oracle v2.c COLLATION_NAME", co, "utf8mb4_0900_ai_ci") + } + + if c17OmniViewExists(c, "v2") { + t.Errorf("omni: KNOWN GAP — CONCAT superset widening (latin1→utf8mb4) not tracked (17.2)") + } + }) + + // ---- 17.3 CONCAT incompatible collations should error ----------------- + t.Run("17_3_CONCAT_incompatible_collations", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_as_cs + )`) + if _, err := mc.db.ExecContext(mc.ctx, `INSERT INTO t VALUES ('x','y')`); err != nil { + t.Errorf("oracle INSERT: %v", err) + } + + // ORACLE FINDING: MySQL 8.0.45 does NOT raise 1267 for + // CONCAT(utf8mb4_0900_ai_ci, utf8mb4_0900_as_cs) — the CONCAT path + // silently widens via a newer pad-space-compat rule. The canonical + // illegal-mix trigger for this pair is the `=` comparison path + // (Item_bool_func2::fix_length_and_dec), which we use here as the + // stable probe that forces DTCollation::aggregate to fail. + _, oracleCmpErr := mc.db.ExecContext(mc.ctx, `SELECT 1 FROM t WHERE a = b`) + if oracleCmpErr == nil { + t.Errorf("oracle: comparison of two incompatible IMPLICIT collations unexpectedly accepted — expected ER_CANT_AGGREGATE_2COLLATIONS (1267)") + } else if !strings.Contains(oracleCmpErr.Error(), "1267") && + !strings.Contains(strings.ToLower(oracleCmpErr.Error()), "illegal mix") { + t.Errorf("oracle: unexpected error: %v", oracleCmpErr) + } + + // omni: SELECT ... WHERE a=b must reject (aggregation gap). + results, parseErr := c.Exec(`CREATE VIEW v3 AS SELECT a FROM t WHERE a = b`, nil) + omniAccepted := parseErr == nil + if parseErr == nil { + for _, r := range results { + if r.Error != nil { + omniAccepted = false + break + } + } + } + if omniAccepted { + t.Errorf("omni: KNOWN BUG — soft-accept of illegal-mix comparison (17.3); should error 1267. See scenarios_bug_queue/c17.md") + } + }) + + // ---- 17.4 CONCAT_WS NULL skipping + separator aggregation ------------- + t.Run("17_4_CONCAT_WS_nullskip", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci + )`) + runOnBoth(t, mc, c, `CREATE VIEW v4 AS SELECT CONCAT_WS(',', a, b, NULL) AS c FROM t`) + + if cs, co, ok := c17OracleViewCol(t, "v4", "c"); ok { + assertStringEq(t, "oracle v4.c CHARACTER_SET_NAME", cs, "utf8mb4") + assertStringEq(t, "oracle v4.c COLLATION_NAME", co, "utf8mb4_0900_ai_ci") + } + + // Runtime: CONCAT(NULL,'x') is NULL; CONCAT_WS(',',NULL,'x') is 'x'. + if _, err := mc.db.ExecContext(mc.ctx, `INSERT INTO t VALUES (NULL, 'x')`); err != nil { + t.Errorf("oracle INSERT: %v", err) + } + var concatRes, cwsRes any + row := mc.db.QueryRowContext(mc.ctx, + `SELECT CONCAT(a,b), CONCAT_WS(',',a,b) FROM t LIMIT 1`) + if err := row.Scan(&concatRes, &cwsRes); err != nil { + t.Errorf("oracle runtime scan: %v", err) + } else { + if concatRes != nil { + t.Errorf("oracle CONCAT(NULL,'x'): got %v, want NULL", concatRes) + } + if s := asString(cwsRes); s != "x" { + t.Errorf("oracle CONCAT_WS(',',NULL,'x'): got %q, want \"x\"", s) + } + } + + if c17OmniViewExists(c, "v4") { + t.Errorf("omni: KNOWN GAP — CONCAT_WS NULL-skip semantics + charset aggregation not tracked (17.4)") + } + }) + + // ---- 17.5 _utf8mb4'x' introducer is still COERCIBLE ------------------- + t.Run("17_5_introducer_still_coercible", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // MySQL session: force a known session charset. + if _, err := mc.db.ExecContext(mc.ctx, `SET NAMES utf8mb4`); err != nil { + t.Errorf("oracle SET NAMES: %v", err) + } + + runOnBoth(t, mc, c, + `CREATE TABLE t (a VARCHAR(10) CHARACTER SET latin1 COLLATE latin1_swedish_ci)`) + runOnBoth(t, mc, c, `CREATE VIEW v5a AS SELECT CONCAT(a, 'x') AS c FROM t`) + runOnBoth(t, mc, c, `CREATE VIEW v5b AS SELECT CONCAT(a, _utf8mb4'x') AS c FROM t`) + + for _, view := range []string{"v5a", "v5b"} { + if cs, co, ok := c17OracleViewCol(t, view, "c"); ok { + assertStringEq(t, "oracle "+view+".c CHARACTER_SET_NAME", cs, "latin1") + assertStringEq(t, "oracle "+view+".c COLLATION_NAME", co, "latin1_swedish_ci") + } + } + + if c17OmniViewExists(c, "v5a") || c17OmniViewExists(c, "v5b") { + t.Errorf("omni: KNOWN GAP — literal/introducer coercibility (COERCIBLE vs IMPLICIT) not tracked (17.5)") + } + }) + + // ---- 17.6 REPEAT / LPAD / RPAD pin to first-arg charset --------------- + t.Run("17_6_first_arg_pins_charset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET latin1 COLLATE latin1_swedish_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci + )`) + runOnBoth(t, mc, c, `CREATE VIEW v6a AS SELECT REPEAT(a, 3) AS c FROM t`) + runOnBoth(t, mc, c, `CREATE VIEW v6b AS SELECT LPAD(a, 20, b) AS c FROM t`) + runOnBoth(t, mc, c, `CREATE VIEW v6c AS SELECT RPAD(b, 20, a) AS c FROM t`) + + type want struct{ cs, co string } + cases := map[string]want{ + "v6a": {"latin1", "latin1_swedish_ci"}, + "v6b": {"latin1", "latin1_swedish_ci"}, + "v6c": {"utf8mb4", "utf8mb4_0900_ai_ci"}, + } + for view, w := range cases { + if cs, co, ok := c17OracleViewCol(t, view, "c"); ok { + assertStringEq(t, "oracle "+view+".c CHARACTER_SET_NAME", cs, w.cs) + assertStringEq(t, "oracle "+view+".c COLLATION_NAME", co, w.co) + } + } + + anyExists := c17OmniViewExists(c, "v6a") || c17OmniViewExists(c, "v6b") || c17OmniViewExists(c, "v6c") + if anyExists { + t.Errorf("omni: KNOWN GAP — REPEAT/LPAD/RPAD first-arg-pins rule missing (17.6). Fix in function_types.go") + } + }) + + // ---- 17.7 CONVERT(x USING cs) forces charset IMPLICIT ----------------- + t.Run("17_7_CONVERT_USING_pins_charset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_as_cs, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci + )`) + // Without CONVERT this would be 17.3 (ER 1267). CONVERT rescues it. + ddl := `CREATE VIEW v7 AS SELECT CONCAT(a, CONVERT(b USING utf8mb4)) AS c FROM t` + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + if oracleErr != nil { + // If MySQL rejects even one-sided CONVERT, the scenario has no + // oracle ground truth to compare omni against. Record this so the + // scenario can be refined and return — do NOT assert omni-side + // behavior against a missing oracle. + t.Skipf("17.7 oracle rejected one-sided CONVERT; scenario needs two-sided wrap. err=%v", oracleErr) + return + } + if cs, co, ok := c17OracleViewCol(t, "v7", "c"); ok { + assertStringEq(t, "oracle v7.c CHARACTER_SET_NAME", cs, "utf8mb4") + // The default collation of utf8mb4 is server-configurable; accept + // any utf8mb4_* collation. What matters is the charset pin. + if !strings.HasPrefix(co, "utf8mb4_") { + t.Errorf("oracle v7.c COLLATION_NAME: got %q, want utf8mb4_* (some utf8mb4 collation)", co) + } + } + + // omni side: parse the same DDL. Only reached when oracle accepted, + // so the KNOWN GAP comparison is meaningful. + results, parseErr := c.Exec(ddl, nil) + omniAccepted := parseErr == nil + if parseErr == nil { + for _, r := range results { + if r.Error != nil { + omniAccepted = false + break + } + } + } + if omniAccepted { + t.Errorf("omni: KNOWN GAP — CONVERT ... USING cs accepted but charset not pinned on result (17.7), see scenarios_bug_queue/c17.md") + } + }) + + // ---- 17.8 COLLATE clause is EXPLICIT — highest precedence ------------- + t.Run("17_8_COLLATE_explicit", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_as_cs + )`) + + // Case A: one side EXPLICIT wins. + ddlA := `CREATE VIEW v8a AS SELECT CONCAT(a, b COLLATE utf8mb4_bin) AS c FROM t` + if _, err := mc.db.ExecContext(mc.ctx, ddlA); err != nil { + t.Errorf("oracle v8a unexpected error: %v", err) + } else { + if cs, co, ok := c17OracleViewCol(t, "v8a", "c"); ok { + assertStringEq(t, "oracle v8a.c CHARACTER_SET_NAME", cs, "utf8mb4") + assertStringEq(t, "oracle v8a.c COLLATION_NAME", co, "utf8mb4_bin") + } + } + + // omni: same DDL should analyze — gap is that EXPLICIT derivation + // isn't tracked, so downstream charset metadata is missing. + resultsA, parseErrA := c.Exec(ddlA, nil) + omniAcceptedA := parseErrA == nil + if parseErrA == nil { + for _, r := range resultsA { + if r.Error != nil { + omniAcceptedA = false + break + } + } + } + if omniAcceptedA { + t.Errorf("omni: KNOWN GAP — COLLATE EXPLICIT derivation not tracked on v8a (17.8)") + } + + // Case B: two EXPLICIT sides with different collations must error. + scenarioReset(t, mc) + cB := scenarioNewCatalog(t) + runOnBoth(t, mc, cB, `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_as_cs + )`) + ddlB := `CREATE VIEW v8b AS SELECT CONCAT(a COLLATE utf8mb4_0900_ai_ci, b COLLATE utf8mb4_bin) AS c FROM t` + oracleErr, omniErr := c17RunExpectError(t, cB, ddlB) + if oracleErr == nil { + t.Errorf("oracle: %q unexpectedly accepted — expected illegal-mix error (1267/1270)", ddlB) + } else if !strings.Contains(oracleErr.Error(), "1267") && + !strings.Contains(oracleErr.Error(), "1270") && + !strings.Contains(strings.ToLower(oracleErr.Error()), "illegal mix") { + t.Errorf("oracle v8b unexpected error: %v", oracleErr) + } + if omniErr == nil { + t.Errorf("omni: KNOWN BUG — two EXPLICIT COLLATE sides silently accepted (17.8), see scenarios_bug_queue/c17.md") + } + }) +} diff --git a/tidb/catalog/scenarios_c18_test.go b/tidb/catalog/scenarios_c18_test.go new file mode 100644 index 00000000..6b0ba682 --- /dev/null +++ b/tidb/catalog/scenarios_c18_test.go @@ -0,0 +1,524 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C18 covers Section C18 "SHOW CREATE TABLE elision rules" from +// mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. Each subtest runs the +// scenario's DDL on both the MySQL 8.0 container and the omni catalog, then +// asserts that omni's ShowCreateTable output matches the oracle on the +// specific elision rule under test. +// +// Critical section: every mismatch here breaks SDL round-trip. Failures are +// recorded as t.Error (not t.Fatal) so all 15 scenarios run in one pass, and +// each omni gap is documented in mysql/catalog/scenarios_bug_queue/c18.md. +func TestScenario_C18(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // ----------------------------------------------------------------- + // 18.1 Column charset elided when equal to table default + // ----------------------------------------------------------------- + t.Run("18_1_column_charset_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a VARCHAR(10), + b VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci + ) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci` + runOnBoth(t, mc, c, ddl) + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + // Column a: no column-level CHARACTER SET / COLLATE. + if strings.Contains(aLine(mysqlCreate, "`a`"), "CHARACTER SET") { + t.Errorf("oracle: column a should not have CHARACTER SET; got %q", mysqlCreate) + } + if strings.Contains(aLine(omniCreate, "`a`"), "CHARACTER SET") { + t.Errorf("omni: column a should not have CHARACTER SET; got %q", omniCreate) + } + // Column b: has non-default collation, so CHARACTER SET/COLLATE rendered. + if !strings.Contains(aLine(mysqlCreate, "`b`"), "utf8mb4_unicode_ci") { + t.Errorf("oracle: column b should have COLLATE utf8mb4_unicode_ci; got %q", mysqlCreate) + } + if !strings.Contains(aLine(omniCreate, "`b`"), "utf8mb4_unicode_ci") { + t.Errorf("omni: column b should have COLLATE utf8mb4_unicode_ci; got %q", omniCreate) + } + }) + + // ----------------------------------------------------------------- + // 18.2 NOT NULL elision: TIMESTAMP shows NULL, others hide it + // ----------------------------------------------------------------- + t.Run("18_2_null_elision_timestamp", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + i INT, + i_nn INT NOT NULL, + ts TIMESTAMP NULL DEFAULT NULL, + ts_nn TIMESTAMP NOT NULL DEFAULT '2020-01-01' + )` + runOnBoth(t, mc, c, ddl) + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + // Oracle: i has no explicit NULL, ts has explicit NULL. + if strings.Contains(aLine(mysqlCreate, "`i` "), "NULL DEFAULT") { + // `i int DEFAULT NULL` is fine; we only object to `NULL DEFAULT NULL`. + } + if !strings.Contains(aLine(mysqlCreate, "`ts` "), "NULL DEFAULT NULL") { + t.Errorf("oracle: ts TIMESTAMP should render explicit NULL; got %q", aLine(mysqlCreate, "`ts` ")) + } + // omni comparison + if strings.Contains(aLine(omniCreate, "`i` "), "`i` int NULL") { + t.Errorf("omni: i should not render NULL keyword; got %q", aLine(omniCreate, "`i` ")) + } + if !strings.Contains(aLine(omniCreate, "`ts` "), "NULL") { + t.Errorf("omni: ts TIMESTAMP should render explicit NULL; got %q", aLine(omniCreate, "`ts` ")) + } + }) + + // ----------------------------------------------------------------- + // 18.3 ENGINE always rendered + // ----------------------------------------------------------------- + t.Run("18_3_engine_always_rendered", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT)") + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + if !strings.Contains(mysqlCreate, "ENGINE=InnoDB") { + t.Errorf("oracle: expected ENGINE=InnoDB; got %q", mysqlCreate) + } + if !strings.Contains(omniCreate, "ENGINE=InnoDB") { + t.Errorf("omni: expected ENGINE=InnoDB; got %q", omniCreate) + } + }) + + // ----------------------------------------------------------------- + // 18.4 AUTO_INCREMENT elided when counter == 1 + // ----------------------------------------------------------------- + t.Run("18_4_auto_increment_elided_when_one", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY)") + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + if strings.Contains(mysqlCreate, "AUTO_INCREMENT=") { + t.Errorf("oracle: AUTO_INCREMENT= should be elided; got %q", mysqlCreate) + } + if strings.Contains(omniCreate, "AUTO_INCREMENT=") { + t.Errorf("omni: AUTO_INCREMENT= should be elided; got %q", omniCreate) + } + }) + + // ----------------------------------------------------------------- + // 18.5 DEFAULT CHARSET always rendered + // ----------------------------------------------------------------- + t.Run("18_5_default_charset_always_rendered", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE tnocs (x INT)") + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE tnocs") + omniCreate := c.ShowCreateTable("testdb", "tnocs") + + if !strings.Contains(mysqlCreate, "DEFAULT CHARSET=utf8mb4") { + t.Errorf("oracle: DEFAULT CHARSET=utf8mb4 missing; got %q", mysqlCreate) + } + if !strings.Contains(mysqlCreate, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("oracle: COLLATE=utf8mb4_0900_ai_ci missing; got %q", mysqlCreate) + } + if !strings.Contains(omniCreate, "DEFAULT CHARSET=utf8mb4") { + t.Errorf("omni: DEFAULT CHARSET=utf8mb4 missing; got %q", omniCreate) + } + if !strings.Contains(omniCreate, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("omni: COLLATE=utf8mb4_0900_ai_ci missing; got %q", omniCreate) + } + }) + + // ----------------------------------------------------------------- + // 18.6 ROW_FORMAT elided when not explicitly specified + // ----------------------------------------------------------------- + t.Run("18_6_row_format_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t2 (a INT) ROW_FORMAT=DYNAMIC") + + mysqlT := oracleShow(t, mc, "SHOW CREATE TABLE t") + mysqlT2 := oracleShow(t, mc, "SHOW CREATE TABLE t2") + omniT := c.ShowCreateTable("testdb", "t") + omniT2 := c.ShowCreateTable("testdb", "t2") + + if strings.Contains(mysqlT, "ROW_FORMAT=") { + t.Errorf("oracle: implicit ROW_FORMAT should be elided; got %q", mysqlT) + } + if !strings.Contains(mysqlT2, "ROW_FORMAT=DYNAMIC") { + t.Errorf("oracle: explicit ROW_FORMAT=DYNAMIC missing; got %q", mysqlT2) + } + if strings.Contains(omniT, "ROW_FORMAT=") { + t.Errorf("omni: implicit ROW_FORMAT should be elided; got %q", omniT) + } + if !strings.Contains(omniT2, "ROW_FORMAT=DYNAMIC") { + t.Errorf("omni: explicit ROW_FORMAT=DYNAMIC missing; got %q", omniT2) + } + }) + + // ----------------------------------------------------------------- + // 18.7 Table-level COLLATE rendering rules + // ----------------------------------------------------------------- + t.Run("18_7_table_collate_rendering", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_prim (x INT) CHARACTER SET latin1") + runOnBoth(t, mc, c, "CREATE TABLE t_nonprim (x INT) CHARACTER SET latin1 COLLATE latin1_bin") + runOnBoth(t, mc, c, "CREATE TABLE t_0900 (x INT) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci") + + mysqlPrim := oracleShow(t, mc, "SHOW CREATE TABLE t_prim") + mysqlNonPrim := oracleShow(t, mc, "SHOW CREATE TABLE t_nonprim") + mysql0900 := oracleShow(t, mc, "SHOW CREATE TABLE t_0900") + omniPrim := c.ShowCreateTable("testdb", "t_prim") + omniNonPrim := c.ShowCreateTable("testdb", "t_nonprim") + omni0900 := c.ShowCreateTable("testdb", "t_0900") + + // t_prim: DEFAULT CHARSET=latin1, NO COLLATE=. + if !strings.Contains(mysqlPrim, "DEFAULT CHARSET=latin1") { + t.Errorf("oracle t_prim: missing DEFAULT CHARSET=latin1; got %q", mysqlPrim) + } + if strings.Contains(mysqlPrim, "COLLATE=") { + t.Errorf("oracle t_prim: COLLATE= should be elided; got %q", mysqlPrim) + } + if strings.Contains(omniPrim, "COLLATE=") { + t.Errorf("omni t_prim: COLLATE= should be elided; got %q", omniPrim) + } + // t_nonprim: COLLATE=latin1_bin. + if !strings.Contains(mysqlNonPrim, "COLLATE=latin1_bin") { + t.Errorf("oracle t_nonprim: missing COLLATE=latin1_bin; got %q", mysqlNonPrim) + } + if !strings.Contains(omniNonPrim, "COLLATE=latin1_bin") { + t.Errorf("omni t_nonprim: missing COLLATE=latin1_bin; got %q", omniNonPrim) + } + // t_0900: COLLATE=utf8mb4_0900_ai_ci (special case — always rendered). + if !strings.Contains(mysql0900, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("oracle t_0900: missing COLLATE=utf8mb4_0900_ai_ci; got %q", mysql0900) + } + if !strings.Contains(omni0900, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("omni t_0900: missing COLLATE=utf8mb4_0900_ai_ci; got %q", omni0900) + } + }) + + // ----------------------------------------------------------------- + // 18.8 KEY_BLOCK_SIZE elision + // ----------------------------------------------------------------- + t.Run("18_8_key_block_size_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_nokbs (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t_kbs (a INT) KEY_BLOCK_SIZE=4") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_nokbs") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_kbs") + omniNo := c.ShowCreateTable("testdb", "t_nokbs") + omniYes := c.ShowCreateTable("testdb", "t_kbs") + + if strings.Contains(mysqlNo, "KEY_BLOCK_SIZE=") { + t.Errorf("oracle t_nokbs: KEY_BLOCK_SIZE should be elided; got %q", mysqlNo) + } + if !strings.Contains(mysqlYes, "KEY_BLOCK_SIZE=4") { + t.Errorf("oracle t_kbs: missing KEY_BLOCK_SIZE=4; got %q", mysqlYes) + } + if strings.Contains(omniNo, "KEY_BLOCK_SIZE=") { + t.Errorf("omni t_nokbs: KEY_BLOCK_SIZE should be elided; got %q", omniNo) + } + if !strings.Contains(omniYes, "KEY_BLOCK_SIZE=4") { + t.Errorf("omni t_kbs: missing KEY_BLOCK_SIZE=4; got %q", omniYes) + } + }) + + // ----------------------------------------------------------------- + // 18.9 COMPRESSION elision + // ----------------------------------------------------------------- + t.Run("18_9_compression_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_nocomp (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t_comp (a INT) COMPRESSION='ZLIB'") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_nocomp") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_comp") + omniNo := c.ShowCreateTable("testdb", "t_nocomp") + omniYes := c.ShowCreateTable("testdb", "t_comp") + + if strings.Contains(mysqlNo, "COMPRESSION=") { + t.Errorf("oracle t_nocomp: COMPRESSION should be elided; got %q", mysqlNo) + } + if !strings.Contains(mysqlYes, "COMPRESSION='ZLIB'") { + t.Errorf("oracle t_comp: missing COMPRESSION='ZLIB'; got %q", mysqlYes) + } + if strings.Contains(omniNo, "COMPRESSION=") { + t.Errorf("omni t_nocomp: COMPRESSION should be elided; got %q", omniNo) + } + if !strings.Contains(omniYes, "COMPRESSION='ZLIB'") { + t.Errorf("omni t_comp: missing COMPRESSION='ZLIB'; got %q", omniYes) + } + }) + + // ----------------------------------------------------------------- + // 18.10 STATS_PERSISTENT / STATS_AUTO_RECALC / STATS_SAMPLE_PAGES + // ----------------------------------------------------------------- + t.Run("18_10_stats_clauses_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_nostats (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t_stats (a INT) STATS_PERSISTENT=1 STATS_AUTO_RECALC=0 STATS_SAMPLE_PAGES=32") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_nostats") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_stats") + omniNo := c.ShowCreateTable("testdb", "t_nostats") + omniYes := c.ShowCreateTable("testdb", "t_stats") + + for _, clause := range []string{"STATS_PERSISTENT=", "STATS_AUTO_RECALC=", "STATS_SAMPLE_PAGES="} { + if strings.Contains(mysqlNo, clause) { + t.Errorf("oracle t_nostats: %s should be elided; got %q", clause, mysqlNo) + } + if strings.Contains(omniNo, clause) { + t.Errorf("omni t_nostats: %s should be elided; got %q", clause, omniNo) + } + } + for _, want := range []string{"STATS_PERSISTENT=1", "STATS_AUTO_RECALC=0", "STATS_SAMPLE_PAGES=32"} { + if !strings.Contains(mysqlYes, want) { + t.Errorf("oracle t_stats: missing %s; got %q", want, mysqlYes) + } + if !strings.Contains(omniYes, want) { + t.Errorf("omni t_stats: missing %s; got %q", want, omniYes) + } + } + }) + + // ----------------------------------------------------------------- + // 18.11 MIN_ROWS / MAX_ROWS / AVG_ROW_LENGTH elision + // ----------------------------------------------------------------- + t.Run("18_11_min_max_avg_rows_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_nominmax (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t_minmax (a INT) MIN_ROWS=10 MAX_ROWS=1000 AVG_ROW_LENGTH=256") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_nominmax") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_minmax") + omniNo := c.ShowCreateTable("testdb", "t_nominmax") + omniYes := c.ShowCreateTable("testdb", "t_minmax") + + for _, clause := range []string{"MIN_ROWS=", "MAX_ROWS=", "AVG_ROW_LENGTH="} { + if strings.Contains(mysqlNo, clause) { + t.Errorf("oracle t_nominmax: %s should be elided; got %q", clause, mysqlNo) + } + if strings.Contains(omniNo, clause) { + t.Errorf("omni t_nominmax: %s should be elided; got %q", clause, omniNo) + } + } + for _, want := range []string{"MIN_ROWS=10", "MAX_ROWS=1000", "AVG_ROW_LENGTH=256"} { + if !strings.Contains(mysqlYes, want) { + t.Errorf("oracle t_minmax: missing %s; got %q", want, mysqlYes) + } + if !strings.Contains(omniYes, want) { + t.Errorf("omni t_minmax: missing %s; got %q", want, omniYes) + } + } + }) + + // ----------------------------------------------------------------- + // 18.12 TABLESPACE clause elision + // ----------------------------------------------------------------- + t.Run("18_12_tablespace_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_default (a INT)") + // Use innodb_system — always available in MySQL 8.0. + runOnBoth(t, mc, c, "CREATE TABLE t_gts (a INT) TABLESPACE=innodb_system") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_default") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_gts") + omniNo := c.ShowCreateTable("testdb", "t_default") + omniYes := c.ShowCreateTable("testdb", "t_gts") + + if strings.Contains(mysqlNo, "TABLESPACE") { + t.Errorf("oracle t_default: TABLESPACE should be elided; got %q", mysqlNo) + } + if !strings.Contains(mysqlYes, "TABLESPACE") { + t.Errorf("oracle t_gts: missing TABLESPACE clause; got %q", mysqlYes) + } + if strings.Contains(omniNo, "TABLESPACE") { + t.Errorf("omni t_default: TABLESPACE should be elided; got %q", omniNo) + } + if !strings.Contains(omniYes, "TABLESPACE") { + t.Errorf("omni t_gts: missing TABLESPACE clause; got %q", omniYes) + } + }) + + // ----------------------------------------------------------------- + // 18.13 PACK_KEYS / CHECKSUM / DELAY_KEY_WRITE elision + // ----------------------------------------------------------------- + t.Run("18_13_pack_checksum_delay_elision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t_none (a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t_opts (a INT) PACK_KEYS=1 CHECKSUM=1 DELAY_KEY_WRITE=1") + + mysqlNo := oracleShow(t, mc, "SHOW CREATE TABLE t_none") + mysqlYes := oracleShow(t, mc, "SHOW CREATE TABLE t_opts") + omniNo := c.ShowCreateTable("testdb", "t_none") + omniYes := c.ShowCreateTable("testdb", "t_opts") + + for _, clause := range []string{"PACK_KEYS=", "CHECKSUM=", "DELAY_KEY_WRITE="} { + if strings.Contains(mysqlNo, clause) { + t.Errorf("oracle t_none: %s should be elided; got %q", clause, mysqlNo) + } + if strings.Contains(omniNo, clause) { + t.Errorf("omni t_none: %s should be elided; got %q", clause, omniNo) + } + } + // Oracle may reject or normalize; only assert presence where oracle shows them. + for _, want := range []string{"PACK_KEYS=1", "CHECKSUM=1", "DELAY_KEY_WRITE=1"} { + if strings.Contains(mysqlYes, want) && !strings.Contains(omniYes, want) { + t.Errorf("omni t_opts: oracle has %s but omni does not; got omni=%q", want, omniYes) + } + } + }) + + // ----------------------------------------------------------------- + // 18.14 Per-index COMMENT and KEY_BLOCK_SIZE inside index clauses + // ----------------------------------------------------------------- + t.Run("18_14_per_index_comment_kbs", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + id INT PRIMARY KEY, + a INT, + b INT, + KEY ix_plain (a), + KEY ix_cmt (b) COMMENT 'hello' + ) KEY_BLOCK_SIZE=4` + runOnBoth(t, mc, c, ddl) + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + // ix_plain: no COMMENT, no per-index KEY_BLOCK_SIZE. + plainMy := aLine(mysqlCreate, "ix_plain") + plainOmni := aLine(omniCreate, "ix_plain") + if strings.Contains(plainMy, "COMMENT") { + t.Errorf("oracle ix_plain: should not have COMMENT; got %q", plainMy) + } + if strings.Contains(plainMy, "KEY_BLOCK_SIZE") { + t.Errorf("oracle ix_plain: should not have KEY_BLOCK_SIZE; got %q", plainMy) + } + if strings.Contains(plainOmni, "COMMENT") { + t.Errorf("omni ix_plain: should not have COMMENT; got %q", plainOmni) + } + if strings.Contains(plainOmni, "KEY_BLOCK_SIZE") { + t.Errorf("omni ix_plain: should not have KEY_BLOCK_SIZE; got %q", plainOmni) + } + // ix_cmt: COMMENT 'hello' present; no KEY_BLOCK_SIZE. + cmtMy := aLine(mysqlCreate, "ix_cmt") + cmtOmni := aLine(omniCreate, "ix_cmt") + if !strings.Contains(cmtMy, "COMMENT 'hello'") { + t.Errorf("oracle ix_cmt: missing COMMENT 'hello'; got %q", cmtMy) + } + if !strings.Contains(cmtOmni, "COMMENT 'hello'") { + t.Errorf("omni ix_cmt: missing COMMENT 'hello'; got %q", cmtOmni) + } + // Table-level KEY_BLOCK_SIZE=4 present. + if !strings.Contains(mysqlCreate, "KEY_BLOCK_SIZE=4") { + t.Errorf("oracle: missing table-level KEY_BLOCK_SIZE=4; got %q", mysqlCreate) + } + if !strings.Contains(omniCreate, "KEY_BLOCK_SIZE=4") { + t.Errorf("omni: missing table-level KEY_BLOCK_SIZE=4; got %q", omniCreate) + } + }) + + // ----------------------------------------------------------------- + // 18.15 USING BTREE/HASH only when algorithm explicit + // ----------------------------------------------------------------- + t.Run("18_15_using_algorithm_explicit", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + id INT, + a INT, + b INT, + KEY ix_default (a), + KEY ix_btree (a) USING BTREE, + KEY ix_hash (b) USING HASH + ) ENGINE=InnoDB` + runOnBoth(t, mc, c, ddl) + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + // ix_default: no USING clause. + defaultMy := aLine(mysqlCreate, "ix_default") + defaultOmni := aLine(omniCreate, "ix_default") + if strings.Contains(defaultMy, "USING") { + t.Errorf("oracle ix_default: should not have USING clause; got %q", defaultMy) + } + if strings.Contains(defaultOmni, "USING") { + t.Errorf("omni ix_default: should not have USING clause; got %q", defaultOmni) + } + // ix_btree: has USING clause. + btreeMy := aLine(mysqlCreate, "ix_btree") + btreeOmni := aLine(omniCreate, "ix_btree") + if !strings.Contains(btreeMy, "USING") { + t.Errorf("oracle ix_btree: missing USING clause; got %q", btreeMy) + } + if !strings.Contains(btreeOmni, "USING") { + t.Errorf("omni ix_btree: missing USING clause; got %q", btreeOmni) + } + // ix_hash: oracle may or may not render USING (InnoDB rewrites HASH→BTREE). + // Contract: if oracle renders USING, omni must too. + hashMy := aLine(mysqlCreate, "ix_hash") + hashOmni := aLine(omniCreate, "ix_hash") + if strings.Contains(hashMy, "USING") && !strings.Contains(hashOmni, "USING") { + t.Errorf("omni ix_hash: oracle has USING but omni does not; oracle=%q omni=%q", hashMy, hashOmni) + } + }) +} + +// aLine returns the line from multi-line text s that contains needle. Empty +// string if no match. Used by C18 to grab a single column/index line out of +// SHOW CREATE TABLE output for per-line substring checks. +func aLine(s, needle string) string { + for _, line := range strings.Split(s, "\n") { + if strings.Contains(line, needle) { + return line + } + } + return "" +} diff --git a/tidb/catalog/scenarios_c19_test.go b/tidb/catalog/scenarios_c19_test.go new file mode 100644 index 00000000..84741a40 --- /dev/null +++ b/tidb/catalog/scenarios_c19_test.go @@ -0,0 +1,424 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C19 covers Section C19 "Virtual / functional indexes" from +// mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. MySQL 8.0.13+ implements +// functional index key parts by synthesizing a hidden VIRTUAL generated +// column over the expression and building an ordinary index over it. omni +// currently represents functional key parts as an `Expr` string on +// `IndexColumn` with NO synthesized hidden Column, which means: +// +// - type inference on the expression (19.2) has nowhere to store its result +// - hidden-column suppression (19.3) is vacuously "fine" but untestable +// - deterministic-only / no-LOB validation (19.4) is not enforced at all +// - DROP INDEX "cascade" (19.6) is vacuous for the same reason +// +// Every subtest here is expected to fail on the omni side and is documented +// in scenarios_bug_queue/c19.md. We use t.Error (not t.Fatal) so all six +// scenarios run in one pass. +func TestScenario_C19(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // ----------------------------------------------------------------- + // 19.1 Functional index creates a hidden VIRTUAL generated column + // ----------------------------------------------------------------- + t.Run("19_1_hidden_virtual_column_created", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(64))") + runOnBoth(t, mc, c, "CREATE INDEX idx_lower ON t ((LOWER(name)))") + + // Oracle: information_schema.STATISTICS exposes the functional + // expression with COLUMN_NAME=NULL. + rows := oracleRows(t, mc, ` + SELECT COLUMN_NAME, EXPRESSION + FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='idx_lower'`) + if len(rows) != 1 { + t.Errorf("oracle: expected 1 STATISTICS row for idx_lower, got %d", len(rows)) + } else { + colName := rows[0][0] + expr := asString(rows[0][1]) + if colName != nil { + t.Errorf("oracle: functional index COLUMN_NAME should be NULL, got %v", colName) + } + if !strings.Contains(strings.ToLower(expr), "lower(`name`)") { + t.Errorf("oracle: STATISTICS.EXPRESSION should contain lower(`name`), got %q", expr) + } + } + + // Oracle: SHOW CREATE renders functional key part as ((expr)). + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + if !strings.Contains(strings.ToLower(mysqlCreate), "((lower(`name`)))") { + t.Errorf("oracle: SHOW CREATE should contain ((lower(`name`))); got %q", mysqlCreate) + } + + // omni: expose the hidden functional column. omni has no Hidden + // flag today, so we assert on what's observable: the IndexColumn + // should carry an expression and SHOW CREATE should match the + // oracle's ((expr)) form byte-for-byte. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("omni: table t not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if strings.EqualFold(i.Name, "idx_lower") { + idx = i + break + } + } + if idx == nil { + t.Fatal("omni: idx_lower not found") + } + if len(idx.Columns) != 1 { + t.Errorf("omni: idx_lower expected 1 key part, got %d", len(idx.Columns)) + } else if idx.Columns[0].Expr == "" { + t.Errorf("omni: idx_lower key part has no Expr; MySQL stores this as a hidden generated column") + } + // omni gap: no hidden column is created alongside the index. + // MySQL's dd.columns.is_hidden=HT_HIDDEN_SQL row has no omni analog. + // We flag this as a bug: the count of user-visible columns stays at + // 2 in both engines, but omni's Column list should gain a + // HiddenBySystem entry for round-trip fidelity. Currently it does + // not — confirm with a direct probe. + hiddenFound := false + for _, col := range tbl.Columns { + if strings.HasPrefix(col.Name, "!hidden!") { + hiddenFound = true + break + } + } + if hiddenFound { + t.Log("omni: unexpectedly found hidden column — partial support?") + } else { + t.Errorf("omni: no hidden column synthesized for functional index (MySQL: !hidden!idx_lower!0!0)") + } + + // Byte-exact SHOW CREATE comparison on the key part line. + omniCreate := c.ShowCreateTable("testdb", "t") + myKey := c19Line(mysqlCreate, "idx_lower") + omniKey := c19Line(omniCreate, "idx_lower") + if myKey != omniKey { + t.Errorf("omni idx_lower key-part render differs from oracle:\noracle: %q\nomni: %q", myKey, omniKey) + } + }) + + // ----------------------------------------------------------------- + // 19.2 Hidden column type inferred from expression return type + // ----------------------------------------------------------------- + t.Run("19_2_type_inferred_from_expression", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a INT, b INT, + name VARCHAR(64) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci, + payload JSON, + INDEX k_sum ((a + b)), + INDEX k_low ((LOWER(name))), + INDEX k_cast ((CAST(payload->'$.age' AS UNSIGNED))) + )` + runOnBoth(t, mc, c, ddl) + + // Oracle: information_schema.STATISTICS should list all three + // functional indexes, each with COLUMN_NAME=NULL and a non-empty + // EXPRESSION. MySQL uses these to inform optimizer decisions. + rows := oracleRows(t, mc, ` + SELECT INDEX_NAME, EXPRESSION + FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY INDEX_NAME`) + if len(rows) != 3 { + t.Errorf("oracle: expected 3 functional index rows, got %d", len(rows)) + } + + // omni: the only way to verify the inferred type today is to look + // for a synthesized hidden column. None exists, so this will fail. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("omni: table t not found") + } + for _, want := range []struct { + idx string + wantType string // expected hidden-column DataType + }{ + {"k_sum", "bigint"}, + {"k_low", "varchar"}, + {"k_cast", "bigint"}, + } { + var hiddenCol *Column + for _, col := range tbl.Columns { + if strings.Contains(col.Name, want.idx) && strings.HasPrefix(col.Name, "!hidden!") { + hiddenCol = col + break + } + } + if hiddenCol == nil { + t.Errorf("omni %s: no hidden column synthesized; MySQL types this as %s", want.idx, want.wantType) + continue + } + if !strings.EqualFold(hiddenCol.DataType, want.wantType) { + t.Errorf("omni %s: hidden col type %q, want %q", want.idx, hiddenCol.DataType, want.wantType) + } + } + }) + + // ----------------------------------------------------------------- + // 19.3 Hidden column suppressed in SELECT * and user I_S.COLUMNS + // ----------------------------------------------------------------- + t.Run("19_3_hidden_suppressed_in_select_star", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(64))") + runOnBoth(t, mc, c, "CREATE INDEX idx_lower ON t ((LOWER(name)))") + + // Oracle: user-scoped information_schema.COLUMNS does NOT list the + // hidden column. Only (id, name) should appear. + rows := oracleRows(t, mc, ` + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY ORDINAL_POSITION`) + if len(rows) != 2 { + t.Errorf("oracle: expected 2 visible columns, got %d", len(rows)) + } + for _, r := range rows { + name := asString(r[0]) + if strings.HasPrefix(name, "!hidden!") { + t.Errorf("oracle: user I_S.COLUMNS leaked hidden column %q", name) + } + } + + // Oracle: STATISTICS still records the expression (confirming the + // hidden column exists at the storage layer). + stats := oracleRows(t, mc, ` + SELECT EXPRESSION FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='idx_lower'`) + if len(stats) != 1 || asString(stats[0][0]) == "" { + t.Errorf("oracle: STATISTICS.EXPRESSION missing for idx_lower; got %v", stats) + } + + // omni: the catalog should expose exactly 2 user-visible columns + // AND the hidden column must not leak into the visible list. Since + // omni synthesizes no hidden column at all, visible-count is + // accidentally correct, but the storage-layer invariant (that a + // hidden column exists and is queryable via an internal API) is + // missing. Assert the hidden column is discoverable — this is the + // bug to track. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("omni: table t not found") + } + var visible, hidden int + for _, col := range tbl.Columns { + if strings.HasPrefix(col.Name, "!hidden!") { + hidden++ + } else { + visible++ + } + } + if visible != 2 { + t.Errorf("omni: visible column count = %d, want 2", visible) + } + if hidden != 1 { + t.Errorf("omni: hidden column count = %d, want 1 (HT_HIDDEN_SQL for idx_lower)", hidden) + } + }) + + // ----------------------------------------------------------------- + // 19.4 Functional expression must be deterministic / non-LOB + // ----------------------------------------------------------------- + t.Run("19_4_disallowed_expression_rejected", func(t *testing.T) { + scenarioReset(t, mc) + + // Oracle verification: run each bad DDL against MySQL directly + // and confirm it's rejected. omni should reject the same. + cases := []struct { + label string + ddl string + mysqlErrSubstr string // expected substring of oracle error text + }{ + { + "rand_disallowed", + "CREATE TABLE t4a (a INT, INDEX ((a + RAND())))", + "disallowed function", + }, + { + "bare_column_rejected", + "CREATE TABLE t4b (a INT, INDEX ((a)))", + "Functional index on a column", + }, + { + // MySQL 8.0.45 actually reports ER 3753 "JSON or GEOMETRY" + // for `->` (JSON return type). The broader substring + // "functional index" covers both 3753 and 3754. + "lob_json_rejected", + "CREATE TABLE t4c (payload JSON, INDEX ((payload->'$.name')))", + "functional index", + }, + } + for _, tc := range cases { + t.Run(tc.label, func(t *testing.T) { + // Oracle should reject. + _, mysqlErr := mc.db.ExecContext(mc.ctx, tc.ddl) + if mysqlErr == nil { + t.Errorf("oracle: %s DDL unexpectedly succeeded", tc.label) + } else if !strings.Contains(mysqlErr.Error(), tc.mysqlErrSubstr) { + t.Errorf("oracle: %s error = %q, want substring %q", tc.label, mysqlErr.Error(), tc.mysqlErrSubstr) + } + + // omni should reject as well. Use a fresh catalog so + // earlier cases don't pollute. + cc := scenarioNewCatalog(t) + results, parseErr := cc.Exec(tc.ddl, nil) + rejected := false + if parseErr != nil { + rejected = true + } else { + for _, r := range results { + if r.Error != nil { + rejected = true + break + } + } + } + if !rejected { + t.Errorf("omni: %s DDL was accepted; MySQL rejects it with %q", + tc.label, tc.mysqlErrSubstr) + } + }) + } + }) + + // ----------------------------------------------------------------- + // 19.5 Functional index on JSON path via (col->>'$.path') + // ----------------------------------------------------------------- + t.Run("19_5_json_path_functional_index", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + id INT PRIMARY KEY, + doc JSON, + INDEX idx_name ((CAST(doc->>'$.name' AS CHAR(64)))) + )` + runOnBoth(t, mc, c, ddl) + + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + + // Oracle should render the full cast expression with ->> form. + if !strings.Contains(strings.ToLower(mysqlCreate), "cast(") || + !strings.Contains(mysqlCreate, "idx_name") { + t.Errorf("oracle: expected idx_name with CAST(...) clause; got %q", mysqlCreate) + } + + // omni: key-part line must match the oracle byte-for-byte. This is + // the round-trip test the scenario calls out as "the key test". + myKey := c19Line(mysqlCreate, "idx_name") + omniKey := c19Line(omniCreate, "idx_name") + if myKey != omniKey { + t.Errorf("omni idx_name render differs from oracle (round-trip test):\noracle: %q\nomni: %q", + myKey, omniKey) + } + + // Plain `doc->>'$.name'` with no CAST must be rejected (LOB return). + badDDL := "CREATE TABLE t_bad (doc JSON, INDEX ((doc->>'$.name')))" + if _, err := mc.db.ExecContext(mc.ctx, badDDL); err == nil { + t.Errorf("oracle: uncast ->> should be rejected as LOB") + } else if !strings.Contains(err.Error(), "functional index") { + t.Errorf("oracle: unexpected error for uncast ->>: %v", err) + } + cc := scenarioNewCatalog(t) + results, parseErr := cc.Exec(badDDL, nil) + rejected := parseErr != nil + for _, r := range results { + if r.Error != nil { + rejected = true + } + } + if !rejected { + t.Errorf("omni: uncast ->> in functional index was accepted; MySQL rejects as LOB") + } + }) + + // ----------------------------------------------------------------- + // 19.6 DROP INDEX cascades to hidden generated column + // ----------------------------------------------------------------- + t.Run("19_6_drop_index_cascades_hidden", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(64))") + runOnBoth(t, mc, c, "CREATE INDEX idx_lower ON t ((LOWER(name)))") + runOnBoth(t, mc, c, "DROP INDEX idx_lower ON t") + + // Oracle: SHOW CREATE must show no trace of idx_lower or any + // hidden column. + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(mysqlCreate, "idx_lower") { + t.Errorf("oracle: idx_lower still present after DROP INDEX: %q", mysqlCreate) + } + if strings.Contains(mysqlCreate, "!hidden!") { + t.Errorf("oracle: hidden column leaked after DROP INDEX: %q", mysqlCreate) + } + + // omni: same round-trip assertion. + omniCreate := c.ShowCreateTable("testdb", "t") + if strings.Contains(omniCreate, "idx_lower") { + t.Errorf("omni: idx_lower still present after DROP INDEX: %q", omniCreate) + } + if strings.Contains(omniCreate, "!hidden!") { + t.Errorf("omni: hidden column leaked after DROP INDEX: %q", omniCreate) + } + + // Oracle: dropping the hidden column by name must fail with 3108. + // First recreate the functional index so there's a hidden column + // to try to drop. + if _, err := mc.db.ExecContext(mc.ctx, "CREATE INDEX idx_lower ON t ((LOWER(name)))"); err != nil { + t.Fatalf("oracle: recreating idx_lower: %v", err) + } + _, dropErr := mc.db.ExecContext(mc.ctx, "ALTER TABLE t DROP COLUMN `!hidden!idx_lower!0!0`") + if dropErr == nil { + t.Errorf("oracle: dropping hidden column should be rejected with ER 3108") + } else if !strings.Contains(dropErr.Error(), "functional index") { + t.Errorf("oracle: unexpected error for hidden-column drop: %v", dropErr) + } + + // omni: the same ALTER. omni has no hidden column, so it should + // either reject the column name as "unknown" OR reject it with a + // 3108-equivalent error. Accepting the DROP COLUMN silently is + // the bug we want to catch. + results, parseErr := c.Exec("ALTER TABLE t DROP COLUMN `!hidden!idx_lower!0!0`", nil) + rejected := parseErr != nil + for _, r := range results { + if r.Error != nil { + rejected = true + } + } + if !rejected { + t.Errorf("omni: DROP COLUMN `!hidden!idx_lower!0!0` silently accepted; MySQL returns ER_CANNOT_DROP_COLUMN_FUNCTIONAL_INDEX (3108)") + } + }) +} + +// c19Line returns the line from s that contains needle, trimmed of the +// trailing comma some SHOW CREATE outputs use. Empty string if no match. +func c19Line(s, needle string) string { + for _, line := range strings.Split(s, "\n") { + if strings.Contains(line, needle) { + return strings.TrimRight(strings.TrimSpace(line), ",") + } + } + return "" +} diff --git a/tidb/catalog/scenarios_c1_test.go b/tidb/catalog/scenarios_c1_test.go new file mode 100644 index 00000000..3db3aba5 --- /dev/null +++ b/tidb/catalog/scenarios_c1_test.go @@ -0,0 +1,483 @@ +package catalog + +import ( + "sort" + "strings" + "testing" +) + +// TestScenario_C1 covers section C1 (Name auto-generation) from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest asserts that +// both real MySQL 8.0 and the omni catalog agree on the auto-generated +// name for a given DDL input. +// +// Failures in omni assertions are NOT proof failures — they are +// recorded in mysql/catalog/scenarios_bug_queue/c1.md. +func TestScenario_C1(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 1.1 Foreign Key name — CREATE path (fresh counter) -------------- + t.Run("1_1_FK_name_create_path", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE child ( + a INT, CONSTRAINT child_ibfk_5 FOREIGN KEY (a) REFERENCES p(id), + b INT, FOREIGN KEY (b) REFERENCES p(id) +);` + runOnBoth(t, mc, c, ddl) + + // Oracle + got := oracleFKNames(t, mc, "child") + want := []string{"child_ibfk_1", "child_ibfk_5"} + assertStringEq(t, "oracle FK names", strings.Join(got, ","), strings.Join(want, ",")) + + // omni + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Errorf("omni: table child missing") + return + } + omniNames := omniFKNames(tbl) + assertStringEq(t, "omni FK names", strings.Join(omniNames, ","), strings.Join(want, ",")) + }) + + // --- 1.2 Foreign Key name — ALTER path (max+1 counter) --------------- + t.Run("1_2_FK_name_alter_path", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE child ( + a INT, b INT, + CONSTRAINT child_ibfk_20 FOREIGN KEY (a) REFERENCES p(id) +); +ALTER TABLE child ADD FOREIGN KEY (b) REFERENCES p(id);` + runOnBoth(t, mc, c, ddl) + + got := oracleFKNames(t, mc, "child") + want := []string{"child_ibfk_20", "child_ibfk_21"} + assertStringEq(t, "oracle FK names", strings.Join(got, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Errorf("omni: table child missing") + return + } + omniNames := omniFKNames(tbl) + assertStringEq(t, "omni FK names", strings.Join(omniNames, ","), strings.Join(want, ",")) + }) + + // --- 1.3 Partition default naming p0..p{n-1} ------------------------- + t.Run("1_3_partition_default_names", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT) PARTITION BY HASH(id) PARTITIONS 4;`) + + got := oraclePartitionNames(t, mc, "t") + want := []string{"p0", "p1", "p2", "p3"} + assertStringEq(t, "oracle partition names", strings.Join(got, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + var omniNames []string + if tbl.Partitioning != nil { + for _, p := range tbl.Partitioning.Partitions { + omniNames = append(omniNames, p.Name) + } + } + assertStringEq(t, "omni partition names", strings.Join(omniNames, ","), strings.Join(want, ",")) + }) + + // --- 1.4 CHECK constraint auto-name (t_chk_1) ------------------------ + t.Run("1_4_check_auto_name", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, CHECK (a > 0));`) + + var got string + oracleScan(t, mc, `SELECT CONSTRAINT_NAME FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb'`, &got) + assertStringEq(t, "oracle CHECK name", got, "t_chk_1") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniChecks := omniCheckNames(tbl) + assertStringEq(t, "omni CHECK names", strings.Join(omniChecks, ","), "t_chk_1") + }) + + // --- 1.5 UNIQUE KEY auto-name uses field name ------------------------ + t.Run("1_5_unique_auto_name_field", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, UNIQUE KEY (a));`) + + var got string + oracleScan(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND NON_UNIQUE=0`, &got) + assertStringEq(t, "oracle UNIQUE name", got, "a") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniIdx := omniUniqueIndexNames(tbl) + assertStringEq(t, "omni unique index names", strings.Join(omniIdx, ","), "a") + }) + + // --- 1.6 UNIQUE KEY name collision appends _2 ------------------------ + t.Run("1_6_unique_collision_suffix", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, UNIQUE KEY a (a), UNIQUE KEY (a));`) + + rows := oracleRows(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND NON_UNIQUE=0 + ORDER BY INDEX_NAME`) + var got []string + for _, r := range rows { + got = append(got, asString(r[0])) + } + want := []string{"a", "a_2"} + assertStringEq(t, "oracle unique index names", strings.Join(got, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniIdx := omniUniqueIndexNames(tbl) + sort.Strings(omniIdx) + assertStringEq(t, "omni unique index names", strings.Join(omniIdx, ","), strings.Join(want, ",")) + }) + + // --- 1.7 PRIMARY KEY always named PRIMARY ---------------------------- + t.Run("1_7_primary_key_always_named_primary", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a INT, + CONSTRAINT my_pk PRIMARY KEY (a) +);`) + + var got string + oracleScan(t, mc, `SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND CONSTRAINT_TYPE='PRIMARY KEY'`, &got) + assertStringEq(t, "oracle PK constraint name", got, "PRIMARY") + + var idxName string + oracleScan(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='PRIMARY'`, &idxName) + assertStringEq(t, "oracle PK index name", idxName, "PRIMARY") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + // No constraint named my_pk should exist. + for _, con := range tbl.Constraints { + if strings.EqualFold(con.Name, "my_pk") { + t.Errorf("omni: unexpected constraint named %q (PK should be renamed to PRIMARY)", con.Name) + } + } + // PK index should be named PRIMARY. + var pkIdxName string + for _, idx := range tbl.Indexes { + if idx.Primary { + pkIdxName = idx.Name + break + } + } + assertStringEq(t, "omni PK index name", pkIdxName, "PRIMARY") + }) + + // --- 1.8 Non-PK index cannot be named PRIMARY ------------------------ + t.Run("1_8_primary_name_reserved", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Oracle: expect error from MySQL. + _, mysqlErr := mc.db.ExecContext(mc.ctx, `CREATE TABLE t (a INT, UNIQUE KEY `+"`PRIMARY`"+` (a))`) + if mysqlErr == nil { + t.Errorf("oracle: expected error for UNIQUE KEY `PRIMARY`, got nil") + } else if !strings.Contains(mysqlErr.Error(), "1280") && !strings.Contains(strings.ToLower(mysqlErr.Error()), "incorrect index name") { + t.Errorf("oracle: expected ER_WRONG_NAME_FOR_INDEX (1280), got %v", mysqlErr) + } + + // omni: should also reject. Use fresh catalog per attempt. + results, err := c.Exec("CREATE TABLE t (a INT, UNIQUE KEY `PRIMARY` (a));", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects index named PRIMARY", omniErrored, true) + }) + + // --- 1.9 Implicit index name from first key column ------------------- + t.Run("1_9_implicit_index_name_first_col", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a INT, + b INT, + c INT, + KEY (b, c) +);`) + + rows := oracleRows(t, mc, `SELECT INDEX_NAME, COLUMN_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY INDEX_NAME, SEQ_IN_INDEX`) + if len(rows) < 2 { + t.Errorf("oracle: expected 2 STATISTICS rows, got %d", len(rows)) + } + if len(rows) >= 1 { + assertStringEq(t, "oracle index name", asString(rows[0][0]), "b") + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + var firstNonPKIdx *Index + for _, idx := range tbl.Indexes { + if !idx.Primary { + firstNonPKIdx = idx + break + } + } + if firstNonPKIdx == nil { + t.Errorf("omni: expected one non-PK index") + } else { + assertStringEq(t, "omni index name", firstNonPKIdx.Name, "b") + if len(firstNonPKIdx.Columns) != 2 || + firstNonPKIdx.Columns[0].Name != "b" || + firstNonPKIdx.Columns[1].Name != "c" { + t.Errorf("omni: expected columns [b,c], got %+v", firstNonPKIdx.Columns) + } + } + }) + + // --- 1.10 UNIQUE name fallback when first column is "PRIMARY" -------- + t.Run("1_10_unique_primary_column_fallback", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := "CREATE TABLE t (`PRIMARY` INT);\n" + + "ALTER TABLE t ADD UNIQUE KEY (`PRIMARY`);" + runOnBoth(t, mc, c, ddl) + + rows := oracleRows(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND NON_UNIQUE=0`) + if len(rows) != 1 { + t.Errorf("oracle: expected 1 unique index row, got %d", len(rows)) + } + if len(rows) == 1 { + assertStringEq(t, "oracle unique index name", asString(rows[0][0]), "PRIMARY_2") + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniIdx := omniUniqueIndexNames(tbl) + assertStringEq(t, "omni unique index names", strings.Join(omniIdx, ","), "PRIMARY_2") + }) + + // --- 1.11 Functional index auto-name functional_index[_N] ------------ + t.Run("1_11_functional_index_auto_name", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a INT, + INDEX ((a + 1)), + INDEX ((a * 2)) +);` + runOnBoth(t, mc, c, ddl) + + rows := oracleRows(t, mc, `SELECT DISTINCT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY INDEX_NAME`) + var got []string + for _, r := range rows { + got = append(got, asString(r[0])) + } + want := []string{"functional_index", "functional_index_2"} + assertStringEq(t, "oracle functional index names", strings.Join(got, ","), strings.Join(want, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + // omni almost certainly fails to parse/load functional indexes. + t.Errorf("omni: table t missing (functional index support gap)") + return + } + var omniIdxNames []string + for _, idx := range tbl.Indexes { + if !idx.Primary { + omniIdxNames = append(omniIdxNames, idx.Name) + } + } + sort.Strings(omniIdxNames) + assertStringEq(t, "omni functional index names", strings.Join(omniIdxNames, ","), strings.Join(want, ",")) + }) + + // --- 1.12 Functional index hidden generated column name -------------- + t.Run("1_12_functional_index_hidden_col", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT, INDEX fx ((a + 1), (a * 2)));` + runOnBoth(t, mc, c, ddl) + + // Oracle: hidden columns are not in information_schema.COLUMNS by + // default. Use a SELECT from the performance_schema / dd dumps via + // SHOW CREATE TABLE as a loose check. Full name verification needs + // a data-dictionary dump which we can't easily do from Go; document + // and fall back to SHOW CREATE TABLE containing the expression. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + if !strings.Contains(create, "`fx`") { + t.Errorf("oracle: SHOW CREATE TABLE missing index fx: %s", create) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing (functional index support gap)") + return + } + // omni will almost certainly not synthesize hidden generated columns + // with MySQL's !hidden! naming scheme. Check and report. + hasHidden := false + for _, col := range tbl.Columns { + if strings.HasPrefix(col.Name, "!hidden!") { + hasHidden = true + break + } + } + assertBoolEq(t, "omni has hidden functional col", hasHidden, true) + }) + + // --- 1.13 CHECK constraint name is schema-scoped --------------------- + t.Run("1_13_check_name_schema_scoped", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // First table with named check: both should accept. + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT, CONSTRAINT mychk CHECK (a > 0));`) + + // Second table with duplicate named check: MySQL errors with 3822. + _, mysqlErr := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t2 (a INT, CONSTRAINT mychk CHECK (a < 100))`) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_CHECK_CONSTRAINT_DUP_NAME, got nil") + } else if !strings.Contains(mysqlErr.Error(), "3822") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "duplicate check constraint") { + t.Errorf("oracle: expected 3822 Duplicate check constraint name, got %v", mysqlErr) + } + + // omni: should also error. + results, err := c.Exec( + `CREATE TABLE t2 (a INT, CONSTRAINT mychk CHECK (a < 100));`, nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects cross-table duplicate CHECK name", omniErrored, true) + }) +} + +// --- section-local helpers ------------------------------------------------ + +// oracleFKNames returns the FK constraint names on the given table, ordered +// alphabetically by CONSTRAINT_NAME. +func oracleFKNames(t *testing.T, mc *mysqlContainer, tableName string) []string { + t.Helper() + rows := oracleRows(t, mc, `SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+tableName+`' + AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME`) + var out []string + for _, r := range rows { + out = append(out, asString(r[0])) + } + return out +} + +// oraclePartitionNames returns partition names ordered by ordinal position. +func oraclePartitionNames(t *testing.T, mc *mysqlContainer, tableName string) []string { + t.Helper() + rows := oracleRows(t, mc, `SELECT PARTITION_NAME FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+tableName+`' + ORDER BY PARTITION_ORDINAL_POSITION`) + var out []string + for _, r := range rows { + out = append(out, asString(r[0])) + } + return out +} + +// omniFKNames returns FK constraint names for the table sorted alphabetically. +func omniFKNames(tbl *Table) []string { + var out []string + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + out = append(out, con.Name) + } + } + sort.Strings(out) + return out +} + +// omniCheckNames returns check constraint names sorted alphabetically. +func omniCheckNames(tbl *Table) []string { + var out []string + for _, con := range tbl.Constraints { + if con.Type == ConCheck { + out = append(out, con.Name) + } + } + sort.Strings(out) + return out +} + +// omniUniqueIndexNames returns names of unique indexes (excluding PK) sorted. +func omniUniqueIndexNames(tbl *Table) []string { + var out []string + for _, idx := range tbl.Indexes { + if idx.Unique && !idx.Primary { + out = append(out, idx.Name) + } + } + sort.Strings(out) + return out +} diff --git a/tidb/catalog/scenarios_c20_test.go b/tidb/catalog/scenarios_c20_test.go new file mode 100644 index 00000000..94d54427 --- /dev/null +++ b/tidb/catalog/scenarios_c20_test.go @@ -0,0 +1,424 @@ +package catalog + +import ( + "database/sql" + "strings" + "testing" +) + +// TestScenario_C20 covers Section C20 "Field-type-specific implicit defaults" +// from mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. Each subtest runs +// the scenario's DDL on both a MySQL 8.0 container and omni's catalog, then +// asserts: +// +// 1. The catalog state (Column.Default pointer, Nullable) — no implicit +// default synthesized into the AST. +// 2. information_schema.COLUMNS.COLUMN_DEFAULT and SHOW CREATE TABLE +// rendering match between oracle and omni. +// +// For sections 20.6 / 20.8 (error scenarios), both the oracle and omni +// must reject the DDL — we compare only the fact that an error is raised, +// not the exact message. +// +// All failures use t.Error rather than t.Fatal so the whole section runs +// and each omni gap is captured in scenarios_bug_queue/c20.md. +func TestScenario_C20(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 20.1 INT NOT NULL, no DEFAULT → implicit 0 ----------------------- + t.Run("20_1_int_notnull_no_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE c20_1 (id INT NOT NULL)") + + // Oracle: COLUMN_DEFAULT is NULL (no explicit default stored). + colDef := c20oracleColumnDefault(t, mc, "c20_1", "id") + if colDef.Valid { + t.Errorf("oracle: c20_1.id COLUMN_DEFAULT expected NULL, got %q", colDef.String) + } + // Oracle SHOW CREATE must NOT render DEFAULT 0. + my := oracleShow(t, mc, "SHOW CREATE TABLE c20_1") + if strings.Contains(aLine(my, "`id`"), "DEFAULT") { + t.Errorf("oracle: c20_1 SHOW CREATE should not render DEFAULT on id; got %q", aLine(my, "`id`")) + } + + // omni catalog state + col := c20getColumn(t, c, "c20_1", "id") + if col == nil { + return + } + if col.Default != nil { + t.Errorf("omni: c20_1.id Default expected nil, got %q", *col.Default) + } + if col.Nullable { + t.Errorf("omni: c20_1.id Nullable expected false, got true") + } + // omni SHOW CREATE must NOT render DEFAULT. + omniCreate := c.ShowCreateTable("testdb", "c20_1") + if strings.Contains(aLine(omniCreate, "`id`"), "DEFAULT") { + t.Errorf("omni: c20_1 SHOW CREATE should not render DEFAULT on id; got %q", aLine(omniCreate, "`id`")) + } + }) + + // --- 20.2 INT nullable, no DEFAULT → synthesized DEFAULT NULL --------- + t.Run("20_2_int_nullable_default_null_synthesis", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE c20_2 (id INT)") + + // Oracle: SHOW CREATE renders "`id` int DEFAULT NULL". + my := oracleShow(t, mc, "SHOW CREATE TABLE c20_2") + idLine := aLine(my, "`id`") + if !strings.Contains(idLine, "DEFAULT NULL") { + t.Errorf("oracle: c20_2 SHOW CREATE expected `DEFAULT NULL` on id; got %q", idLine) + } + // information_schema.COLUMN_DEFAULT is NULL (no string default stored). + colDef := c20oracleColumnDefault(t, mc, "c20_2", "id") + if colDef.Valid { + t.Errorf("oracle: c20_2.id COLUMN_DEFAULT expected NULL, got %q", colDef.String) + } + + // omni catalog: Default nil, Nullable true. + col := c20getColumn(t, c, "c20_2", "id") + if col == nil { + return + } + if col.Default != nil { + t.Errorf("omni: c20_2.id Default expected nil, got %q", *col.Default) + } + if !col.Nullable { + t.Errorf("omni: c20_2.id Nullable expected true, got false") + } + // omni deparse must synthesize DEFAULT NULL. + omniCreate := c.ShowCreateTable("testdb", "c20_2") + if !strings.Contains(aLine(omniCreate, "`id`"), "DEFAULT NULL") { + t.Errorf("omni: c20_2 SHOW CREATE expected DEFAULT NULL on id; got %q", aLine(omniCreate, "`id`")) + } + }) + + // --- 20.3 VARCHAR/CHAR NOT NULL, no DEFAULT → implicit '' ------------- + t.Run("20_3_string_notnull_no_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE c20_3 (name VARCHAR(64) NOT NULL, code CHAR(4) NOT NULL)") + + for _, cname := range []string{"name", "code"} { + colDef := c20oracleColumnDefault(t, mc, "c20_3", cname) + if colDef.Valid { + t.Errorf("oracle: c20_3.%s COLUMN_DEFAULT expected NULL, got %q", cname, colDef.String) + } + col := c20getColumn(t, c, "c20_3", cname) + if col == nil { + continue + } + if col.Default != nil { + t.Errorf("omni: c20_3.%s Default expected nil, got %q", cname, *col.Default) + } + if col.Nullable { + t.Errorf("omni: c20_3.%s Nullable expected false", cname) + } + } + + my := oracleShow(t, mc, "SHOW CREATE TABLE c20_3") + omniCreate := c.ShowCreateTable("testdb", "c20_3") + for _, cname := range []string{"`name`", "`code`"} { + if strings.Contains(aLine(my, cname), "DEFAULT") { + t.Errorf("oracle: c20_3 %s line should not render DEFAULT; got %q", cname, aLine(my, cname)) + } + if strings.Contains(aLine(omniCreate, cname), "DEFAULT") { + t.Errorf("omni: c20_3 %s line should not render DEFAULT; got %q", cname, aLine(omniCreate, cname)) + } + } + }) + + // --- 20.4 ENUM NOT NULL (default=first) & nullable (default NULL) ----- + t.Run("20_4_enum_notnull_first_value", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE c20_4 ( + status ENUM('active','archived','deleted') NOT NULL, + kind ENUM('a','b','c') + )`) + + // Oracle: status COLUMN_DEFAULT is NULL (catalog property — even + // though runtime fills in 'active'). kind COLUMN_DEFAULT is NULL too. + for _, cname := range []string{"status", "kind"} { + colDef := c20oracleColumnDefault(t, mc, "c20_4", cname) + if colDef.Valid { + t.Errorf("oracle: c20_4.%s COLUMN_DEFAULT expected NULL, got %q", cname, colDef.String) + } + } + + // omni catalog state + if col := c20getColumn(t, c, "c20_4", "status"); col != nil { + if col.Default != nil { + t.Errorf("omni: c20_4.status Default expected nil, got %q", *col.Default) + } + if col.Nullable { + t.Errorf("omni: c20_4.status Nullable expected false") + } + } + if col := c20getColumn(t, c, "c20_4", "kind"); col != nil { + if col.Default != nil { + t.Errorf("omni: c20_4.kind Default expected nil, got %q", *col.Default) + } + if !col.Nullable { + t.Errorf("omni: c20_4.kind Nullable expected true") + } + } + + // Oracle SHOW CREATE: status has no DEFAULT, kind has DEFAULT NULL. + my := oracleShow(t, mc, "SHOW CREATE TABLE c20_4") + if strings.Contains(aLine(my, "`status`"), "DEFAULT") { + t.Errorf("oracle: c20_4 status should not render DEFAULT; got %q", aLine(my, "`status`")) + } + if !strings.Contains(aLine(my, "`kind`"), "DEFAULT NULL") { + t.Errorf("oracle: c20_4 kind expected DEFAULT NULL; got %q", aLine(my, "`kind`")) + } + + omniCreate := c.ShowCreateTable("testdb", "c20_4") + if strings.Contains(aLine(omniCreate, "`status`"), "DEFAULT") { + t.Errorf("omni: c20_4 status should not render DEFAULT; got %q", aLine(omniCreate, "`status`")) + } + if !strings.Contains(aLine(omniCreate, "`kind`"), "DEFAULT NULL") { + t.Errorf("omni: c20_4 kind expected DEFAULT NULL; got %q", aLine(omniCreate, "`kind`")) + } + }) + + // --- 20.5 DATETIME/DATE NOT NULL, no DEFAULT -------------------------- + t.Run("20_5_datetime_notnull_no_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE c20_5 ( + created_at DATETIME NOT NULL, + birthday DATE NOT NULL + )`) + + // Oracle: COLUMN_DEFAULT is NULL; catalog does not pre-apply zero-date. + for _, cname := range []string{"created_at", "birthday"} { + colDef := c20oracleColumnDefault(t, mc, "c20_5", cname) + if colDef.Valid { + t.Errorf("oracle: c20_5.%s COLUMN_DEFAULT expected NULL, got %q", cname, colDef.String) + } + } + + // Oracle SHOW CREATE: no DEFAULT rendered. + my := oracleShow(t, mc, "SHOW CREATE TABLE c20_5") + if strings.Contains(aLine(my, "`created_at`"), "DEFAULT") { + t.Errorf("oracle: c20_5 created_at should not render DEFAULT; got %q", aLine(my, "`created_at`")) + } + if strings.Contains(aLine(my, "`birthday`"), "DEFAULT") { + t.Errorf("oracle: c20_5 birthday should not render DEFAULT; got %q", aLine(my, "`birthday`")) + } + + for _, cname := range []string{"created_at", "birthday"} { + col := c20getColumn(t, c, "c20_5", cname) + if col == nil { + continue + } + if col.Default != nil { + t.Errorf("omni: c20_5.%s Default expected nil, got %q", cname, *col.Default) + } + if col.Nullable { + t.Errorf("omni: c20_5.%s Nullable expected false", cname) + } + } + + omniCreate := c.ShowCreateTable("testdb", "c20_5") + if strings.Contains(aLine(omniCreate, "`created_at`"), "DEFAULT") { + t.Errorf("omni: c20_5 created_at should not render DEFAULT; got %q", aLine(omniCreate, "`created_at`")) + } + if strings.Contains(aLine(omniCreate, "`birthday`"), "DEFAULT") { + t.Errorf("omni: c20_5 birthday should not render DEFAULT; got %q", aLine(omniCreate, "`birthday`")) + } + }) + + // --- 20.6 BLOB/TEXT/JSON/GEOMETRY literal DEFAULT → ER 1101 ----------- + t.Run("20_6_blob_text_literal_default_rejected", func(t *testing.T) { + type caze struct { + name string + ddl string + } + cases := []caze{ + {"c20_6a", "CREATE TABLE c20_6a (b BLOB DEFAULT 'abc')"}, + {"c20_6b", "CREATE TABLE c20_6b (t TEXT DEFAULT 'hello')"}, + {"c20_6c", "CREATE TABLE c20_6c (g GEOMETRY DEFAULT 'x')"}, + {"c20_6d", "CREATE TABLE c20_6d (j JSON DEFAULT '[]')"}, + } + for _, tc := range cases { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + _, mysqlErr := mc.db.ExecContext(mc.ctx, tc.ddl) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_BLOB_CANT_HAVE_DEFAULT (1101) for %s, got nil", tc.name) + } else if !strings.Contains(mysqlErr.Error(), "1101") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "can't have a default value") { + t.Errorf("oracle: expected 1101 for %s, got %v", tc.name, mysqlErr) + } + + omniErrored := c20execExpectError(t, c, tc.ddl+";") + if !omniErrored { + t.Errorf("omni: expected rejection of %s DDL; got success", tc.name) + } + // No table row should exist. + if c20getTable(c, tc.name) != nil { + t.Errorf("omni: %s should not be in catalog after rejected DDL", tc.name) + } + } + }) + + // --- 20.7 JSON/BLOB expression DEFAULT (8.0.13+) accepted ------------- + t.Run("20_7_expression_default_accepted", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE c20_7 ( + id INT PRIMARY KEY, + tags JSON DEFAULT (JSON_ARRAY()), + meta JSON DEFAULT (JSON_OBJECT('v', 1)), + blob1 BLOB DEFAULT (SUBSTRING('abcdef', 1, 3)), + pt POINT DEFAULT (POINT(0, 0)), + uuid BINARY(16) DEFAULT (UUID_TO_BIN(UUID())) + )` + runOnBoth(t, mc, c, ddl) + + // Oracle: each column must have a non-NULL COLUMN_DEFAULT (the + // expression text) and EXTRA containing "DEFAULT_GENERATED". + expressionCols := []string{"tags", "meta", "blob1", "pt", "uuid"} + for _, cname := range expressionCols { + var colDef sql.NullString + var extra string + oracleScan(t, mc, `SELECT COLUMN_DEFAULT, EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='c20_7' AND COLUMN_NAME='`+cname+`'`, + &colDef, &extra) + if !colDef.Valid || colDef.String == "" { + t.Errorf("oracle: c20_7.%s COLUMN_DEFAULT expected non-empty expression text, got NULL", cname) + } + if !strings.Contains(extra, "DEFAULT_GENERATED") { + t.Errorf("oracle: c20_7.%s EXTRA expected DEFAULT_GENERATED, got %q", cname, extra) + } + } + + // Oracle: tables exist in omni too. + if c20getTable(c, "c20_7") == nil { + t.Errorf("omni: c20_7 expected to exist in catalog after DDL accepted") + } + for _, cname := range expressionCols { + col := c20getColumn(t, c, "c20_7", cname) + if col == nil { + continue + } + if col.Default == nil || *col.Default == "" { + t.Errorf("omni: c20_7.%s Default expected non-nil expression, got nil/empty", cname) + } + } + }) + + // --- 20.8 Generated column with DEFAULT clause → error --------------- + t.Run("20_8_generated_with_default_rejected", func(t *testing.T) { + cases := []struct { + name string + ddl string + }{ + {"c20_8a", "CREATE TABLE c20_8a (a INT, b INT AS (a + 1) DEFAULT 0)"}, + {"c20_8b", "CREATE TABLE c20_8b (a INT, b INT GENERATED ALWAYS AS (a + 1) VIRTUAL DEFAULT 0)"}, + {"c20_8c", "CREATE TABLE c20_8c (a INT, b INT GENERATED ALWAYS AS (a + 1) STORED DEFAULT (a * 2))"}, + } + for _, tc := range cases { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + _, mysqlErr := mc.db.ExecContext(mc.ctx, tc.ddl) + if mysqlErr == nil { + t.Errorf("oracle: expected parse/usage error for %s, got nil", tc.name) + } else if !strings.Contains(mysqlErr.Error(), "1064") && + !strings.Contains(mysqlErr.Error(), "1221") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "default") { + t.Errorf("oracle: expected 1064/1221 for %s, got %v", tc.name, mysqlErr) + } + + omniErrored := c20execExpectError(t, c, tc.ddl+";") + if !omniErrored { + t.Errorf("omni: expected rejection of %s DDL; got success", tc.name) + } + if c20getTable(c, tc.name) != nil { + t.Errorf("omni: %s should not be in catalog after rejected DDL", tc.name) + } + } + }) +} + +// -------------------- C20 local helpers -------------------- + +// c20oracleColumnDefault reads information_schema.COLUMNS.COLUMN_DEFAULT for a +// column on the testdb database. Returns a sql.NullString so the caller can +// distinguish "no default stored" (Valid=false) from "default is empty string". +func c20oracleColumnDefault(t *testing.T, mc *mysqlContainer, table, column string) sql.NullString { + t.Helper() + var colDef sql.NullString + row := mc.db.QueryRowContext(mc.ctx, + `SELECT COLUMN_DEFAULT FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME=? AND COLUMN_NAME=?`, + table, column) + if err := row.Scan(&colDef); err != nil { + t.Errorf("c20oracleColumnDefault %s.%s: %v", table, column, err) + } + return colDef +} + +// c20getColumn returns the omni catalog column or nil (reporting t.Error). +func c20getColumn(t *testing.T, c *Catalog, table, column string) *Column { + t.Helper() + db := c.GetDatabase("testdb") + if db == nil { + t.Errorf("omni: database testdb missing") + return nil + } + tbl := db.GetTable(table) + if tbl == nil { + t.Errorf("omni: table %s missing", table) + return nil + } + col := tbl.GetColumn(column) + if col == nil { + t.Errorf("omni: column %s.%s missing", table, column) + return nil + } + return col +} + +// c20getTable returns the omni catalog table or nil without reporting (used +// to confirm absence after rejected DDL). +func c20getTable(c *Catalog, table string) *Table { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + return db.GetTable(table) +} + +// c20execExpectError runs a DDL against omni and returns true if any statement +// in the batch raised an error (parse or exec). +func c20execExpectError(t *testing.T, c *Catalog, ddl string) bool { + t.Helper() + results, err := c.Exec(ddl, nil) + if err != nil { + return true + } + for _, r := range results { + if r.Error != nil { + return true + } + } + return false +} diff --git a/tidb/catalog/scenarios_c21_test.go b/tidb/catalog/scenarios_c21_test.go new file mode 100644 index 00000000..07f2c68a --- /dev/null +++ b/tidb/catalog/scenarios_c21_test.go @@ -0,0 +1,545 @@ +package catalog + +import ( + "strings" + "testing" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// TestScenario_C21 covers section C21 (Parser-level implicit defaults) from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest asserts that both real +// MySQL 8.0 and the omni AST / catalog agree on the default value that the +// grammar fills in when the user omits a clause. +// +// Some of these scenarios are inherently parser-AST-level (JOIN type, ORDER +// direction, LIMIT offset, INSERT column list) so they assert against the +// parsed AST directly rather than via the catalog. DDL-level defaults +// (DEFAULT NULL, FK actions, CREATE INDEX USING, CREATE VIEW ALGORITHM, ENGINE) +// assert against both the container oracle and omni's catalog. +// +// Failures are recorded in mysql/catalog/scenarios_bug_queue/c21.md. +func TestScenario_C21(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 21.1 DEFAULT without value on nullable column -> DEFAULT NULL ---- + t.Run("21_1_default_null_on_nullable", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT DEFAULT NULL); +CREATE TABLE t2 (c INT);` + runOnBoth(t, mc, c, ddl) + + // Oracle: column a has COLUMN_DEFAULT NULL, IS_NULLABLE YES. + var aDefault *string + var aNullable string + oracleScan(t, mc, + `SELECT COLUMN_DEFAULT, IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, + &aDefault, &aNullable) + if aDefault != nil { + t.Errorf("oracle: expected COLUMN_DEFAULT NULL for t.a, got %q", *aDefault) + } + assertStringEq(t, "oracle t.a IS_NULLABLE", aNullable, "YES") + + // Oracle: column c (no DEFAULT clause at all) also has DEFAULT NULL. + var cDefault *string + var cNullable string + oracleScan(t, mc, + `SELECT COLUMN_DEFAULT, IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t2' AND COLUMN_NAME='c'`, + &cDefault, &cNullable) + if cDefault != nil { + t.Errorf("oracle: expected COLUMN_DEFAULT NULL for t2.c, got %q", *cDefault) + } + assertStringEq(t, "oracle t2.c IS_NULLABLE", cNullable, "YES") + + // omni: t.a explicit DEFAULT NULL — omni should model this as Nullable + // true and either Default == nil or Default pointing at "NULL". + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + } else { + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column t.a missing") + } else { + assertBoolEq(t, "omni t.a nullable", col.Nullable, true) + if col.Default != nil && strings.ToUpper(*col.Default) != "NULL" { + t.Errorf("omni: t.a Default = %q, want nil or NULL", *col.Default) + } + } + } + + // omni: t2.c — no DEFAULT clause at all. Expect Nullable true and + // Default == nil (omni does not synthesize a "NULL" literal). + tbl2 := c.GetDatabase("testdb").GetTable("t2") + if tbl2 == nil { + t.Errorf("omni: table t2 missing") + } else { + col := tbl2.GetColumn("c") + if col == nil { + t.Errorf("omni: column t2.c missing") + } else { + assertBoolEq(t, "omni t2.c nullable", col.Nullable, true) + if col.Default != nil && strings.ToUpper(*col.Default) != "NULL" { + t.Errorf("omni: t2.c Default = %q, want nil (no clause)", *col.Default) + } + } + } + }) + + // --- 21.2 Bare JOIN -> INNER JOIN (JoinInner) -------------------------- + t.Run("21_2_join_type_default_inner", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + // Create tables so the container also accepts the SELECT. + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT); CREATE TABLE t2 (a INT);`) + + cases := []struct { + name string + sql string + }{ + {"bare_JOIN", "SELECT * FROM t1 JOIN t2 ON t1.a = t2.a"}, + {"INNER_JOIN", "SELECT * FROM t1 INNER JOIN t2 ON t1.a = t2.a"}, + {"CROSS_JOIN", "SELECT * FROM t1 CROSS JOIN t2"}, + } + for _, tc := range cases { + // Oracle: just ensure MySQL accepts it. + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb; "+tc.sql); err != nil { + t.Errorf("oracle %s: %v", tc.name, err) + } + + // omni AST: parse and inspect JoinClause.Type. + jt, ok := c21FirstJoinType(t, tc.sql) + if !ok { + continue + } + // Grammar-level normalization: bare JOIN and INNER JOIN should both + // map to JoinInner. CROSS JOIN also maps to JoinInner per yacc but + // omni tracks JoinCross distinctly for deparse fidelity, which is + // acceptable — assert that the bare forms collapse. + if tc.name == "CROSS_JOIN" { + if jt != nodes.JoinInner && jt != nodes.JoinCross { + t.Errorf("omni %s: JoinType = %d, want JoinInner or JoinCross", tc.name, jt) + } + } else { + if jt != nodes.JoinInner { + t.Errorf("omni %s: JoinType = %d, want JoinInner(%d)", tc.name, jt, nodes.JoinInner) + } + } + } + }) + + // --- 21.3 ORDER BY without direction -> tri-state --------------------- + t.Run("21_3_order_by_no_direction", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a INT);`) + + // Oracle: MySQL accepts both; we cannot read ORDER_NOT_RELEVANT from + // info_schema directly, so just verify both parse. + for _, s := range []string{ + "SELECT * FROM t ORDER BY a", + "SELECT * FROM t ORDER BY a ASC", + } { + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb; "+s); err != nil { + t.Errorf("oracle %q: %v", s, err) + } + } + + // omni AST: omni's OrderByItem has a single Desc bool, which cannot + // distinguish ORDER_NOT_RELEVANT from ORDER_ASC — both parse to + // Desc=false. Capture this as an asymmetry the catalog doc notes. + bare := c21FirstOrderByItem(t, "SELECT * FROM t ORDER BY a") + asc := c21FirstOrderByItem(t, "SELECT * FROM t ORDER BY a ASC") + if bare == nil || asc == nil { + return + } + // Bug: omni cannot represent ORDER_NOT_RELEVANT distinctly. Both yield + // Desc=false. We assert MySQL's grammar distinction is *not* preserved + // so the bug surfaces in the queue. + assertBoolEq(t, "omni bare ORDER BY Desc", bare.Desc, false) + assertBoolEq(t, "omni ORDER BY ASC Desc", asc.Desc, false) + // Known omni gap: no tri-state direction field. Document in bug queue. + t.Errorf("omni: OrderByItem has no tri-state direction — " + + "ORDER BY a and ORDER BY a ASC are indistinguishable in AST") + }) + + // --- 21.4 LIMIT N without OFFSET -> opt_offset NULL ------------------- + t.Run("21_4_limit_without_offset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a INT);`) + + // Oracle: verify both accept. + for _, s := range []string{ + "SELECT * FROM t LIMIT 10", + "SELECT * FROM t LIMIT 10 OFFSET 0", + } { + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb; "+s); err != nil { + t.Errorf("oracle %q: %v", s, err) + } + } + + // omni AST: LIMIT 10 -> Offset should be nil; LIMIT 10 OFFSET 0 -> + // Offset should be non-nil. + limBare := c21FirstLimit(t, "SELECT * FROM t LIMIT 10") + limOff := c21FirstLimit(t, "SELECT * FROM t LIMIT 10 OFFSET 0") + if limBare == nil { + t.Errorf("omni: LIMIT 10 produced no Limit node") + } else if limBare.Offset != nil { + t.Errorf("omni: LIMIT 10 Offset = %v, want nil", limBare.Offset) + } + if limOff == nil { + t.Errorf("omni: LIMIT 10 OFFSET 0 produced no Limit node") + } else if limOff.Offset == nil { + t.Errorf("omni: LIMIT 10 OFFSET 0 Offset = nil, want non-nil (Item_uint(0))") + } + }) + + // --- 21.5 FK ON DELETE omitted -> FK_OPTION_UNDEF --------------------- + t.Run("21_5_fk_on_delete_omitted_undef", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE parent (id INT PRIMARY KEY); +CREATE TABLE child (p INT, FOREIGN KEY (p) REFERENCES parent(id));` + runOnBoth(t, mc, c, ddl) + + // Oracle: information_schema.REFERENTIAL_CONSTRAINTS reports + // DELETE_RULE and UPDATE_RULE as "NO ACTION" (rendering of UNDEF). + var deleteRule, updateRule string + oracleScan(t, mc, + `SELECT DELETE_RULE, UPDATE_RULE FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='child'`, + &deleteRule, &updateRule) + assertStringEq(t, "oracle DELETE_RULE", deleteRule, "NO ACTION") + assertStringEq(t, "oracle UPDATE_RULE", updateRule, "NO ACTION") + + // omni catalog: the FK constraint's OnDelete/OnUpdate should render as + // "NO ACTION" (via refActionToString mapping RefActNone -> "NO ACTION"). + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Errorf("omni: table child missing") + return + } + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Errorf("omni: FK constraint missing on child") + return + } + // Omni maps UNDEF -> "NO ACTION" which matches the info_schema + // rendering. Any value distinct from NO ACTION is a bug. + assertStringEq(t, "omni FK OnDelete", fk.OnDelete, "NO ACTION") + assertStringEq(t, "omni FK OnUpdate", fk.OnUpdate, "NO ACTION") + }) + + // --- 21.6 FK ON DELETE present, ON UPDATE omitted --------------------- + t.Run("21_6_fk_on_update_independent_undef", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE parent (id INT PRIMARY KEY); +CREATE TABLE child (p INT, FOREIGN KEY (p) REFERENCES parent(id) ON DELETE CASCADE);` + runOnBoth(t, mc, c, ddl) + + // Oracle: DELETE_RULE=CASCADE, UPDATE_RULE=NO ACTION (not inherited). + var deleteRule, updateRule string + oracleScan(t, mc, + `SELECT DELETE_RULE, UPDATE_RULE FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='child'`, + &deleteRule, &updateRule) + assertStringEq(t, "oracle DELETE_RULE", deleteRule, "CASCADE") + assertStringEq(t, "oracle UPDATE_RULE", updateRule, "NO ACTION") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Errorf("omni: table child missing") + return + } + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Errorf("omni: FK constraint missing") + return + } + assertStringEq(t, "omni FK OnDelete", fk.OnDelete, "CASCADE") + assertStringEq(t, "omni FK OnUpdate", fk.OnUpdate, "NO ACTION") + }) + + // --- 21.7 CREATE INDEX without USING -> nullptr (engine picks) -------- + t.Run("21_7_index_no_using_clause", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT, KEY k (a)); +CREATE TABLE t_b (a INT, KEY kb (a) USING BTREE);` + runOnBoth(t, mc, c, ddl) + + // Oracle: information_schema.STATISTICS.INDEX_TYPE is "BTREE" for both + // on InnoDB — engine fills in the default. We check only that both + // resolve to "BTREE" (InnoDB default), confirming the engine-fill. + var t1Type, t2Type string + oracleScan(t, mc, + `SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='k'`, + &t1Type) + oracleScan(t, mc, + `SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t_b' AND INDEX_NAME='kb'`, + &t2Type) + assertStringEq(t, "oracle index k type", t1Type, "BTREE") + assertStringEq(t, "oracle index kb type", t2Type, "BTREE") + + // omni: the parser should preserve the distinction (no USING -> empty, + // USING BTREE -> "BTREE"). The engine default resolution is the + // catalog's job, not the parser's — but omni may collapse both. + tbl := c.GetDatabase("testdb").GetTable("t") + tbl2 := c.GetDatabase("testdb").GetTable("t_b") + if tbl == nil || tbl2 == nil { + t.Errorf("omni: missing tables t/t_b") + return + } + var kType, kbType string + for _, idx := range tbl.Indexes { + if idx.Name == "k" { + kType = idx.IndexType + } + } + for _, idx := range tbl2.Indexes { + if idx.Name == "kb" { + kbType = idx.IndexType + } + } + // omni grammar: no USING clause should leave IndexType empty. If omni + // defaults to "BTREE" at parse time, that's a bug (loses the info + // needed to deparse faithfully). + if kType != "" { + t.Errorf("omni: index k IndexType = %q, want \"\" (no USING clause)", kType) + } + assertStringEq(t, "omni index kb IndexType", kbType, "BTREE") + }) + + // --- 21.8 INSERT without column list -> empty Columns ----------------- + t.Run("21_8_insert_no_column_list", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT, c INT);`) + + // Oracle: both forms accepted. + for _, s := range []string{ + "INSERT INTO t VALUES (1, 2, 3)", + "INSERT INTO t (a, b, c) VALUES (4, 5, 6)", + } { + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb; "+s); err != nil { + t.Errorf("oracle %q: %v", s, err) + } + } + + // omni AST: INSERT without column list -> Columns is nil / empty. + bare := c21FirstInsert(t, "INSERT INTO t VALUES (1, 2, 3)") + full := c21FirstInsert(t, "INSERT INTO t (a, b, c) VALUES (4, 5, 6)") + if bare == nil || full == nil { + return + } + if len(bare.Columns) != 0 { + t.Errorf("omni: INSERT without column list produced %d columns, want 0", + len(bare.Columns)) + } + if len(full.Columns) != 3 { + t.Errorf("omni: INSERT with explicit column list produced %d columns, want 3", + len(full.Columns)) + } + }) + + // --- 21.9 CREATE VIEW without ALGORITHM -> UNDEFINED ------------------ + t.Run("21_9_view_no_algorithm_undefined", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a INT);`) + + // Three forms that all resolve to ALGORITHM=UNDEFINED. + runOnBoth(t, mc, c, `CREATE VIEW v1 AS SELECT 1 AS x;`) + runOnBoth(t, mc, c, `CREATE ALGORITHM=UNDEFINED VIEW v2 AS SELECT 1 AS x;`) + + // Oracle: information_schema.VIEWS has no ALGORITHM column in MySQL + // 8.0; use SHOW CREATE VIEW which renders the algorithm verbatim. + // Both forms should contain ALGORITHM=UNDEFINED. + for _, name := range []string{"v1", "v2"} { + create := oracleShow(t, mc, "SHOW CREATE VIEW "+name) + if !strings.Contains(strings.ToUpper(create), "ALGORITHM=UNDEFINED") { + t.Errorf("oracle %s: SHOW CREATE VIEW missing ALGORITHM=UNDEFINED: %s", + name, create) + } + } + + // omni: View.Algorithm should either be "" (default) or "UNDEFINED" + // for v1 (no explicit clause), and "UNDEFINED" for v2. + db := c.GetDatabase("testdb") + v1 := db.Views["v1"] + v2 := db.Views["v2"] + if v1 == nil { + t.Errorf("omni: view v1 missing") + } else if v1.Algorithm != "" && !strings.EqualFold(v1.Algorithm, "UNDEFINED") { + t.Errorf("omni: v1 Algorithm = %q, want \"\" or UNDEFINED", v1.Algorithm) + } + if v2 == nil { + t.Errorf("omni: view v2 missing") + } else if !strings.EqualFold(v2.Algorithm, "UNDEFINED") { + t.Errorf("omni: v2 Algorithm = %q, want UNDEFINED", v2.Algorithm) + } + }) + + // --- 21.10 CREATE TABLE without ENGINE -> post-parse fill ------------ + t.Run("21_10_create_table_no_engine", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t_noeng (a INT);`) + runOnBoth(t, mc, c, `CREATE TABLE t_eng (a INT) ENGINE=InnoDB;`) + + // Oracle: both resolve to "InnoDB" via session default. + var e1, e2 string + oracleScan(t, mc, + `SELECT ENGINE FROM information_schema.TABLES + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t_noeng'`, + &e1) + oracleScan(t, mc, + `SELECT ENGINE FROM information_schema.TABLES + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t_eng'`, + &e2) + assertStringEq(t, "oracle t_noeng engine", e1, "InnoDB") + assertStringEq(t, "oracle t_eng engine", e2, "InnoDB") + + // omni: parser must leave Table.Engine empty when no ENGINE clause. + // Post-parse fill (from a session variable) is a catalog-layer job and + // omni may or may not do it. We assert the distinction: t_noeng should + // NOT silently default to "InnoDB" at parse time. + tNo := c.GetDatabase("testdb").GetTable("t_noeng") + tEx := c.GetDatabase("testdb").GetTable("t_eng") + if tNo == nil || tEx == nil { + t.Errorf("omni: missing t_noeng or t_eng") + return + } + // The spec says parser should leave engine NULL. If omni fills + // "InnoDB" at parse time, that's a parser-vs-catalog-layer mix-up. + if strings.EqualFold(tNo.Engine, "InnoDB") { + t.Errorf("omni: t_noeng Engine is prematurely filled with InnoDB " + + "(parser should leave it empty; session-var fill is a post-parse step)") + } + if !strings.EqualFold(tEx.Engine, "InnoDB") { + t.Errorf("omni: t_eng Engine = %q, want InnoDB (explicit)", tEx.Engine) + } + }) +} + +// --- section-local helpers ------------------------------------------------- + +// c21FirstJoinType parses a SELECT query and returns the JoinType of the +// first JoinClause found in the FROM list. Uses t.Errorf (not Fatal) so the +// subtest keeps running. +func c21FirstJoinType(t *testing.T, sql string) (nodes.JoinType, bool) { + t.Helper() + stmts, err := parser.Parse(sql) + if err != nil { + t.Errorf("omni parse %q: %v", sql, err) + return 0, false + } + if len(stmts.Items) == 0 { + t.Errorf("omni parse %q: no statements", sql) + return 0, false + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Errorf("omni parse %q: expected SelectStmt, got %T", sql, stmts.Items[0]) + return 0, false + } + for _, te := range sel.From { + if jc, ok := te.(*nodes.JoinClause); ok { + return jc.Type, true + } + } + t.Errorf("omni parse %q: no JoinClause in FROM list", sql) + return 0, false +} + +// c21FirstOrderByItem parses a SELECT and returns the first ORDER BY item. +func c21FirstOrderByItem(t *testing.T, sql string) *nodes.OrderByItem { + t.Helper() + stmts, err := parser.Parse(sql) + if err != nil { + t.Errorf("omni parse %q: %v", sql, err) + return nil + } + if len(stmts.Items) == 0 { + t.Errorf("omni parse %q: no statements", sql) + return nil + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Errorf("omni parse %q: expected SelectStmt, got %T", sql, stmts.Items[0]) + return nil + } + if len(sel.OrderBy) == 0 { + t.Errorf("omni parse %q: no ORDER BY items", sql) + return nil + } + return sel.OrderBy[0] +} + +// c21FirstLimit parses a SELECT and returns its Limit node (or nil). +func c21FirstLimit(t *testing.T, sql string) *nodes.Limit { + t.Helper() + stmts, err := parser.Parse(sql) + if err != nil { + t.Errorf("omni parse %q: %v", sql, err) + return nil + } + if len(stmts.Items) == 0 { + return nil + } + sel, ok := stmts.Items[0].(*nodes.SelectStmt) + if !ok { + t.Errorf("omni parse %q: expected SelectStmt, got %T", sql, stmts.Items[0]) + return nil + } + return sel.Limit +} + +// c21FirstInsert parses an INSERT and returns the InsertStmt node. +func c21FirstInsert(t *testing.T, sql string) *nodes.InsertStmt { + t.Helper() + stmts, err := parser.Parse(sql) + if err != nil { + t.Errorf("omni parse %q: %v", sql, err) + return nil + } + if len(stmts.Items) == 0 { + t.Errorf("omni parse %q: no statements", sql) + return nil + } + ins, ok := stmts.Items[0].(*nodes.InsertStmt) + if !ok { + t.Errorf("omni parse %q: expected InsertStmt, got %T", sql, stmts.Items[0]) + return nil + } + return ins +} diff --git a/tidb/catalog/scenarios_c22_test.go b/tidb/catalog/scenarios_c22_test.go new file mode 100644 index 00000000..dd30ab7c --- /dev/null +++ b/tidb/catalog/scenarios_c22_test.go @@ -0,0 +1,417 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C22 covers Section C22 "ALTER TABLE algorithm / lock defaults" +// from mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. +// +// omni's catalog intentionally does NOT track ALGORITHM= / LOCK= clauses — +// they are execution-time concerns, not persisted schema state. So these +// tests verify three things per scenario: +// +// 1. omni parses and accepts the ALTER ... ALGORITHM=... / LOCK=... form. +// 2. omni applies the underlying column/index change to catalog state +// identically whether or not an ALGORITHM/LOCK clause is present. +// 3. MySQL 8.0 actually accepts or rejects the statement as the scenario +// claims (oracle sanity check — a regression in MySQL wouldn't be an +// omni bug but would invalidate the assertion). +// +// Algorithm selection (scenarios 22.1/22.2) is not directly observable from +// the client side without inspecting internal counters, so those tests only +// verify that DEFAULT clauses parse and succeed on both sides. +// +// Failures in omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c22.md. +func TestScenario_C22(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // ----------------------------------------------------------------- + // 22.1 ALGORITHM=DEFAULT picks fastest supported (oracle-only parse check) + // ----------------------------------------------------------------- + t.Run("22_1_algorithm_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB") + // Bare ADD COLUMN — DEFAULT algorithm, DEFAULT lock. + runOnBoth(t, mc, c, "ALTER TABLE t1 ADD COLUMN b INT") + // Explicit ALGORITHM=DEFAULT, LOCK=DEFAULT. + runOnBoth(t, mc, c, + "ALTER TABLE t1 ADD COLUMN c INT, ALGORITHM=DEFAULT, LOCK=DEFAULT") + + // Oracle: both columns present. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' + ORDER BY ORDINAL_POSITION`) + var oracleCols []string + for _, r := range rows { + oracleCols = append(oracleCols, asString(r[0])) + } + assertStringEq(t, "oracle columns", strings.Join(oracleCols, ","), + "id,a,b,c") + + // omni: both columns applied to catalog. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + var omniCols []string + for _, col := range tbl.Columns { + omniCols = append(omniCols, col.Name) + } + assertStringEq(t, "omni columns", strings.Join(omniCols, ","), + "id,a,b,c") + }) + + // ----------------------------------------------------------------- + // 22.2 LOCK=DEFAULT picks least restrictive (oracle-only parse check) + // ----------------------------------------------------------------- + t.Run("22_2_lock_default_add_index", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB") + runOnBoth(t, mc, c, "ALTER TABLE t1 ADD INDEX ix_a (a)") + runOnBoth(t, mc, c, + "ALTER TABLE t1 ADD INDEX ix_a2 (a), LOCK=DEFAULT") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + foundIxA, foundIxA2 := false, false + for _, idx := range tbl.Indexes { + switch idx.Name { + case "ix_a": + foundIxA = true + case "ix_a2": + foundIxA2 = true + } + } + assertBoolEq(t, "omni ix_a present", foundIxA, true) + assertBoolEq(t, "omni ix_a2 present", foundIxA2, true) + }) + + // ----------------------------------------------------------------- + // 22.3 ADD COLUMN trailing nullable is INSTANT in 8.0.12+ + // ----------------------------------------------------------------- + t.Run("22_3_add_column_instant", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB ROW_FORMAT=DYNAMIC") + // 22.3a: bare ADD COLUMN. + runOnBoth(t, mc, c, "ALTER TABLE t1 ADD COLUMN b VARCHAR(32) NULL") + // 22.3b: explicit ALGORITHM=INSTANT. + runOnBoth(t, mc, c, + "ALTER TABLE t1 ADD COLUMN c INT NULL, ALGORITHM=INSTANT") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + var names []string + for _, col := range tbl.Columns { + names = append(names, col.Name) + } + assertStringEq(t, "omni columns after INSTANT adds", + strings.Join(names, ","), "id,a,b,c") + + // Oracle sanity: ALGORITHM=INSTANT statement must have succeeded. + var n int + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' AND COLUMN_NAME='c'`, + &n) + assertIntEq(t, "oracle column c exists", n, 1) + }) + + // ----------------------------------------------------------------- + // 22.4 DROP COLUMN INSTANT (8.0.29+) else INPLACE + // ----------------------------------------------------------------- + t.Run("22_4_drop_column", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT, b INT) ENGINE=InnoDB") + // 22.4a: bare DROP COLUMN. + runOnBoth(t, mc, c, "ALTER TABLE t1 DROP COLUMN b") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + if tbl.GetColumn("b") != nil { + t.Errorf("omni: column b should be dropped") + } + if tbl.GetColumn("a") == nil { + t.Errorf("omni: column a should still exist") + } + + // 22.4b: explicit ALGORITHM=INSTANT on a fresh table. On MySQL 8.0.29+ + // this succeeds; on earlier versions it errors. Either way, omni's + // parser must accept it, and if oracle accepts it, omni must apply + // the drop. + runOnBoth(t, mc, c, + "CREATE TABLE t2 (id INT PRIMARY KEY, a INT, b INT) ENGINE=InnoDB") + _, oracleErr := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t2 DROP COLUMN b, ALGORITHM=INSTANT") + + // omni: apply against catalog regardless. + results, err := c.Exec("ALTER TABLE t2 DROP COLUMN b, ALGORITHM=INSTANT;", nil) + if err != nil { + t.Errorf("omni: parse error for DROP COLUMN + ALGORITHM=INSTANT: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error on stmt %d: %v", r.Index, r.Error) + } + } + + if oracleErr == nil { + // 8.0.29+ — drop succeeded, omni must have the column gone too. + if tbl2 := c.GetDatabase("testdb").GetTable("t2"); tbl2 != nil && + tbl2.GetColumn("b") != nil { + t.Errorf("omni: t2.b should be dropped after ALGORITHM=INSTANT") + } + } else { + // Pre-8.0.29 — oracle rejected. omni catalog is algorithm-oblivious + // and still drops the column, which is the documented correct + // behavior (SDL diff classifier must filter these, not the + // catalog). Only assert the oracle error kind. + if !strings.Contains(oracleErr.Error(), "ALGORITHM") && + !strings.Contains(oracleErr.Error(), "not supported") { + t.Errorf("oracle: unexpected error form: %v", oracleErr) + } + } + }) + + // ----------------------------------------------------------------- + // 22.5 RENAME COLUMN metadata-only (INPLACE / INSTANT on 8.0.29+) + // ----------------------------------------------------------------- + t.Run("22_5_rename_column", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, old_name INT) ENGINE=InnoDB") + // 22.5a: RENAME COLUMN form. + runOnBoth(t, mc, c, "ALTER TABLE t1 RENAME COLUMN old_name TO new_name") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + if tbl.GetColumn("old_name") != nil { + t.Errorf("omni: old_name should be gone after rename") + } + if tbl.GetColumn("new_name") == nil { + t.Errorf("omni: new_name should exist after rename") + } + + // Oracle: INFORMATION_SCHEMA.COLUMNS should have new_name. + var n int + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' AND COLUMN_NAME='new_name'`, + &n) + assertIntEq(t, "oracle new_name present", n, 1) + + // 22.5b: CHANGE COLUMN form (type unchanged). + runOnBoth(t, mc, c, + "ALTER TABLE t1 CHANGE COLUMN new_name newer_name INT") + if tbl.GetColumn("new_name") != nil { + t.Errorf("omni: new_name should be gone after CHANGE rename") + } + if tbl.GetColumn("newer_name") == nil { + t.Errorf("omni: newer_name should exist after CHANGE rename") + } + }) + + // ----------------------------------------------------------------- + // 22.6 CHANGE COLUMN type forces COPY; explicit INPLACE/LOCK=NONE error + // ----------------------------------------------------------------- + t.Run("22_6_change_column_type_copy", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB") + + // 22.6a: bare CHANGE COLUMN type. DEFAULT → COPY, succeeds. + runOnBoth(t, mc, c, "ALTER TABLE t1 CHANGE COLUMN a a BIGINT") + + // Oracle: column a now has type bigint. + var oracleType string + oracleScan(t, mc, + `SELECT DATA_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' AND COLUMN_NAME='a'`, + &oracleType) + if strings.ToLower(oracleType) != "bigint" { + t.Errorf("oracle: column a type got %q want bigint", oracleType) + } + + // omni: column a's type should be bigint. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni: table t1 missing") + return + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing after CHANGE") + } else if !strings.Contains(strings.ToLower(col.DataType), "bigint") { + t.Errorf("omni: column a type got %q want bigint", col.DataType) + } + + // 22.6b: MODIFY + ALGORITHM=INPLACE → MySQL rejects. + _, oracleErr := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t1 MODIFY COLUMN a INT, ALGORITHM=INPLACE") + if oracleErr == nil { + t.Errorf("oracle: expected error for MODIFY ... ALGORITHM=INPLACE, got nil") + } else if !strings.Contains(oracleErr.Error(), "ALGORITHM") && + !strings.Contains(oracleErr.Error(), "not supported") { + t.Errorf("oracle: unexpected error form: %v", oracleErr) + } + // omni must at least parse and (per catalog contract) apply the change. + results, err := c.Exec("ALTER TABLE t1 MODIFY COLUMN a INT, ALGORITHM=INPLACE;", nil) + if err != nil { + t.Errorf("omni: parse error on MODIFY + ALGORITHM=INPLACE: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error on stmt %d: %v", r.Index, r.Error) + } + } + + // 22.6c: MODIFY with a genuine type change + LOCK=NONE → MySQL + // rejects. Column `a` is currently bigint (from 22.6a); changing it + // to INT UNSIGNED is a real type change, which forces COPY, which is + // incompatible with LOCK=NONE. (A no-op MODIFY to the same type is + // silently optimized away and would not error.) + _, oracleErr2 := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t1 MODIFY COLUMN a INT UNSIGNED, LOCK=NONE") + if oracleErr2 == nil { + t.Errorf("oracle: expected error for MODIFY ... LOCK=NONE, got nil") + } + results2, err2 := c.Exec("ALTER TABLE t1 MODIFY COLUMN a INT UNSIGNED, LOCK=NONE;", nil) + if err2 != nil { + t.Errorf("omni: parse error on MODIFY + LOCK=NONE: %v", err2) + } + for _, r := range results2 { + if r.Error != nil { + t.Errorf("omni: exec error on stmt %d: %v", r.Index, r.Error) + } + } + }) + + // ----------------------------------------------------------------- + // 22.7 ALGORITHM=INSTANT on unsupported operation → hard error + // ----------------------------------------------------------------- + t.Run("22_7_instant_unsupported_hard_error", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB") + + // PK rebuild + ALGORITHM=INSTANT → MySQL rejects with + // ER_ALTER_OPERATION_NOT_SUPPORTED_REASON. + _, oracleErr := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t1 DROP PRIMARY KEY, ADD PRIMARY KEY (a), ALGORITHM=INSTANT") + if oracleErr == nil { + t.Errorf("oracle: expected error for PK rebuild ALGORITHM=INSTANT, got nil") + } else if !strings.Contains(oracleErr.Error(), "ALGORITHM") && + !strings.Contains(oracleErr.Error(), "not supported") { + t.Errorf("oracle: unexpected error form: %v", oracleErr) + } + + // omni: parser must accept; catalog applies PK swap regardless. + results, err := c.Exec( + "ALTER TABLE t1 DROP PRIMARY KEY, ADD PRIMARY KEY (a), ALGORITHM=INSTANT;", nil) + if err != nil { + t.Errorf("omni: parse error on PK rebuild + ALGORITHM=INSTANT: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error on stmt %d: %v", r.Index, r.Error) + } + } + }) + + // ----------------------------------------------------------------- + // 22.8 LOCK=NONE on COPY-only op errors; LOCK=... with INSTANT errors + // ----------------------------------------------------------------- + t.Run("22_8_lock_none_copy_and_instant", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + "CREATE TABLE t1 (id INT PRIMARY KEY, a INT) ENGINE=InnoDB") + + // 22.8a: COPY-only operation + LOCK=NONE → hard error. CHANGE COLUMN + // with a real type change forces COPY, and COPY cannot honor + // LOCK=NONE. + _, oracleErr := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t1 CHANGE COLUMN a a BIGINT UNSIGNED, LOCK=NONE") + if oracleErr == nil { + t.Errorf("oracle 22.8a: expected error for COPY op + LOCK=NONE, got nil") + } + results, _ := c.Exec( + "ALTER TABLE t1 CHANGE COLUMN a a BIGINT UNSIGNED, LOCK=NONE;", nil) + for _, r := range results { + if r.Error != nil { + t.Errorf("omni 22.8a: exec error on stmt %d: %v", r.Index, r.Error) + } + } + + // 22.8b: ADD COLUMN + ALGORITHM=INSTANT + LOCK=NONE → hard error + // ("Only LOCK=DEFAULT is permitted for operations using ALGORITHM=INSTANT"). + _, oracleErr2 := mc.db.ExecContext(mc.ctx, + "ALTER TABLE t1 ADD COLUMN b INT, ALGORITHM=INSTANT, LOCK=NONE") + if oracleErr2 == nil { + t.Errorf("oracle 22.8b: expected error for INSTANT + LOCK=NONE, got nil") + } + results2, err2 := c.Exec( + "ALTER TABLE t1 ADD COLUMN b INT, ALGORITHM=INSTANT, LOCK=NONE;", nil) + if err2 != nil { + t.Errorf("omni 22.8b: parse error: %v", err2) + } + for _, r := range results2 { + if r.Error != nil { + t.Errorf("omni 22.8b: exec error on stmt %d: %v", r.Index, r.Error) + } + } + + // 22.8c: ADD COLUMN + ALGORITHM=INSTANT (no LOCK) → succeeds. + // Note: b may already exist in omni catalog from 22.8b above; use c. + runOnBoth(t, mc, c, + "ALTER TABLE t1 ADD COLUMN c INT, ALGORITHM=INSTANT") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Errorf("omni 22.8c: table t1 missing") + return + } + if tbl.GetColumn("c") == nil { + t.Errorf("omni 22.8c: column c should be added") + } + }) +} diff --git a/tidb/catalog/scenarios_c23_test.go b/tidb/catalog/scenarios_c23_test.go new file mode 100644 index 00000000..9e670128 --- /dev/null +++ b/tidb/catalog/scenarios_c23_test.go @@ -0,0 +1,282 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C23 covers Section C23 "NULL in string context" from +// mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. +// +// Scenarios in this section verify NULL propagation through string +// functions: +// +// 23.1 CONCAT(...) — any NULL arg → NULL result +// 23.2 CONCAT_WS(sep, ...) — NULL data args skipped; NULL separator → NULL +// 23.3 IFNULL/COALESCE — rescue pattern around CONCAT NULL propagation +// +// These are runtime expression behaviors. The omni catalog stores parsed +// VIEWs as a *catalog.View whose `Columns` field is just a list of column +// NAMES (see mysql/catalog/table.go). It does NOT track per-column +// nullability. So the most useful representation of these scenarios in +// omni today is: +// +// 1. Run the same CREATE TABLE / CREATE VIEW DDL against both MySQL 8.0 +// and the omni catalog (proves the DDL parses on both sides). +// 2. Use information_schema.COLUMNS on the container to assert that +// MySQL infers IS_NULLABLE the way the SCENARIOS doc claims. +// 3. SELECT the actual rows from the container to lock the runtime +// string values into the test (oracle ground truth). +// 4. Record the omni gap: the View struct has no per-column nullability +// info, so omni cannot answer "is column c1 of view v nullable?" — +// this is the declared bug, documented in scenarios_bug_queue/c23.md. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c23.md. +func TestScenario_C23(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c23OmniView fetches a view from the omni catalog by name. + c23OmniView := func(c *Catalog, name string) *View { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + return db.Views[strings.ToLower(name)] + } + + // c23OracleViewColNullable returns the IS_NULLABLE value for a single + // view column from information_schema.COLUMNS. + c23OracleViewColNullable := func(t *testing.T, view, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+view+`' AND COLUMN_NAME='`+col+`'`, + &s) + return s + } + + // ----------------------------------------------------------------- + // 23.1 CONCAT with any NULL argument → NULL result + // ----------------------------------------------------------------- + t.Run("23_1_concat_null_propagates", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a VARCHAR(10), b VARCHAR(10)); + CREATE VIEW v_concat AS SELECT CONCAT(a, b) AS c1, CONCAT('x','','y') AS c2 FROM t;`) + + // Oracle row-level checks (lock in the actual MySQL evaluation). + _, err := mc.db.ExecContext(mc.ctx, + `INSERT INTO t VALUES ('foo', NULL), ('foo', 'bar')`) + if err != nil { + t.Fatalf("oracle INSERT: %v", err) + } + // SELECT CONCAT(a,b) — first row NULL, second row 'foobar' + rows := oracleRows(t, mc, `SELECT CONCAT(a, b) FROM t ORDER BY b IS NULL DESC`) + if len(rows) != 2 { + t.Errorf("oracle: expected 2 rows, got %d", len(rows)) + } else { + if rows[0][0] != nil { + t.Errorf("oracle row[0] CONCAT(a,b): want NULL, got %v", rows[0][0]) + } + if asString(rows[1][0]) != "foobar" { + t.Errorf("oracle row[1] CONCAT(a,b): want 'foobar', got %v", rows[1][0]) + } + } + // SELECT CONCAT('x', NULL, 'y') → NULL + var lit any + oracleScan(t, mc, `SELECT CONCAT('x', NULL, 'y')`, &lit) + if lit != nil { + t.Errorf("oracle CONCAT('x',NULL,'y'): want NULL, got %v", lit) + } + // SELECT CONCAT('x', '', 'y') → 'xy' (empty string is NOT NULL) + var emptyStr string + oracleScan(t, mc, `SELECT CONCAT('x', '', 'y')`, &emptyStr) + assertStringEq(t, "oracle CONCAT('x','','y')", emptyStr, "xy") + + // Oracle view-column nullability — empirical MySQL 8.0.45: + // c1 (CONCAT of two nullable cols) → IS_NULLABLE=YES + // c2 (CONCAT of three string literals) → IS_NULLABLE=YES + // Note: although the SCENARIOS doc reasoning would suggest c2 + // should be NOT NULL (all inputs are non-null literals), MySQL + // 8.0's view metadata pass conservatively reports any string + // function result column as nullable. We lock the ground truth + // so omni's eventual nullability inference can match it. + assertStringEq(t, "oracle v_concat.c1 IS_NULLABLE", + c23OracleViewColNullable(t, "v_concat", "c1"), "YES") + assertStringEq(t, "oracle v_concat.c2 IS_NULLABLE", + c23OracleViewColNullable(t, "v_concat", "c2"), "YES") + + // omni: view exists but per-column nullability is not represented. + v := c23OmniView(c, "v_concat") + if v == nil { + t.Errorf("omni: view v_concat not found") + return + } + if len(v.Columns) != 2 { + t.Errorf("omni: v_concat expected 2 columns, got %d (%v)", len(v.Columns), v.Columns) + } + // Declared bug: View has no per-column nullability info, so the + // "CONCAT propagates NULL" semantics cannot be asserted positively. + t.Error("omni: View struct has no per-column nullability; scenario 23.1 cannot be asserted positively") + }) + + // ----------------------------------------------------------------- + // 23.2 CONCAT_WS skips NULL arguments (separator non-null) + // ----------------------------------------------------------------- + t.Run("23_2_concat_ws_skips_null_args", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a VARCHAR(10), b VARCHAR(10), c VARCHAR(10)); + CREATE VIEW v_ws AS SELECT CONCAT_WS(',', a, b, c) AS d1 FROM t;`) + + _, err := mc.db.ExecContext(mc.ctx, + `INSERT INTO t VALUES ('x', NULL, 'z'), (NULL, NULL, NULL)`) + if err != nil { + t.Fatalf("oracle INSERT: %v", err) + } + + // row1 → 'x,z' (NULL skipped, no double separator) + // row2 → '' (all NULL data args → empty string, NOT NULL) + // Order stably so (a IS NOT NULL) row comes first regardless of + // storage engine ordering. + rows := oracleRows(t, mc, `SELECT CONCAT_WS(',', a, b, c) FROM t ORDER BY a IS NULL, a`) + if len(rows) != 2 { + t.Errorf("oracle: expected 2 rows, got %d", len(rows)) + } else { + assertStringEq(t, "oracle row[0] CONCAT_WS", + asString(rows[0][0]), "x,z") + assertStringEq(t, "oracle row[1] CONCAT_WS", + asString(rows[1][0]), "") + } + + // Literal: CONCAT_WS(',', 'x', NULL, 'z') → 'x,z' + var lit1 string + oracleScan(t, mc, `SELECT CONCAT_WS(',', 'x', NULL, 'z')`, &lit1) + assertStringEq(t, "oracle CONCAT_WS(',','x',NULL,'z')", lit1, "x,z") + + // Literal: CONCAT_WS(',', NULL, NULL, NULL) → '' (NOT NULL) + var lit2 string + oracleScan(t, mc, `SELECT CONCAT_WS(',', NULL, NULL, NULL)`, &lit2) + assertStringEq(t, "oracle CONCAT_WS(',',NULL,NULL,NULL)", lit2, "") + + // Literal: CONCAT_WS(NULL, 'x', 'y') → NULL (separator NULL → NULL) + var lit3 any + oracleScan(t, mc, `SELECT CONCAT_WS(NULL, 'x', 'y')`, &lit3) + if lit3 != nil { + t.Errorf("oracle CONCAT_WS(NULL,'x','y'): want NULL, got %v", lit3) + } + + // Oracle view nullability — empirical MySQL 8.0.45: even though + // the separator (',') is a non-null literal and the runtime rule + // is "result is NULL iff separator is NULL" (so d1 is in fact + // never NULL at runtime), MySQL's view metadata pass reports + // IS_NULLABLE='YES' for the CONCAT_WS result column. The SCENARIOS + // doc's reasoning describes the runtime semantics, not the + // information_schema metadata. We lock the metadata ground truth + // here and rely on the runtime literal assertions above (lit1, + // lit2, lit3) to lock the runtime semantics. + assertStringEq(t, "oracle v_ws.d1 IS_NULLABLE", + c23OracleViewColNullable(t, "v_ws", "d1"), "YES") + + v := c23OmniView(c, "v_ws") + if v == nil { + t.Errorf("omni: view v_ws not found") + return + } + if len(v.Columns) != 1 { + t.Errorf("omni: v_ws expected 1 column, got %d (%v)", len(v.Columns), v.Columns) + } + // Declared bug: omni cannot answer "is the separator NULL?" because + // View has no per-column nullability with the CONCAT_WS special case. + t.Error("omni: View struct has no per-column nullability; CONCAT_WS NULL-skip rule (23.2) cannot be asserted positively") + }) + + // ----------------------------------------------------------------- + // 23.3 IFNULL / COALESCE as rescue for CONCAT NULL-propagation + // ----------------------------------------------------------------- + t.Run("23_3_ifnull_coalesce_rescue", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (first_name VARCHAR(20), middle_name VARCHAR(20), last_name VARCHAR(20)); + CREATE VIEW v_name AS SELECT + CONCAT(first_name, ' ', middle_name, ' ', last_name) AS bad, + CONCAT(first_name, ' ', IFNULL(middle_name, ''), ' ', last_name) AS rescue_ifnull, + CONCAT(first_name, ' ', COALESCE(middle_name, ''), ' ', last_name) AS rescue_coalesce, + CONCAT_WS(' ', first_name, middle_name, last_name) AS rescue_ws + FROM t;`) + + _, err := mc.db.ExecContext(mc.ctx, + `INSERT INTO t VALUES ('Ada', NULL, 'Lovelace')`) + if err != nil { + t.Fatalf("oracle INSERT: %v", err) + } + + // Row-level oracle ground truth. + rows := oracleRows(t, mc, `SELECT bad, rescue_ifnull, rescue_coalesce, rescue_ws FROM v_name`) + if len(rows) != 1 { + t.Errorf("oracle: expected 1 row from v_name, got %d", len(rows)) + } else { + r := rows[0] + if r[0] != nil { + t.Errorf("oracle bad: want NULL, got %v", r[0]) + } + assertStringEq(t, "oracle rescue_ifnull", asString(r[1]), "Ada Lovelace") + assertStringEq(t, "oracle rescue_coalesce", asString(r[2]), "Ada Lovelace") + assertStringEq(t, "oracle rescue_ws", asString(r[3]), "Ada Lovelace") + } + + // Oracle column nullability: + // bad → YES (CONCAT with NULL middle_name propagates) + // rescue_ifnull → YES on MySQL 8.0 — even though the second IFNULL + // arg is a non-null literal, the surrounding CONCAT + // also reads first_name/last_name which are + // nullable (no NOT NULL on base table). So the + // result is reported nullable. We just record what + // MySQL says so the doc claim is locked. + // rescue_coalesce → YES (same reasoning) + // rescue_ws → YES (data args nullable; separator non-null but + // base columns nullable — the result column from + // a view body that references nullable inputs may + // still be reported nullable). + // The point of 23.3 is the RUNTIME values, which we asserted above. + // The IS_NULLABLE values here are oracle ground-truth — record them + // so any future MySQL upgrade or omni inference change is caught. + // Lock the specific MySQL 8.0.45 metadata ground truth for all four + // columns (empirical). Base columns are nullable → view columns + // reported nullable regardless of the IFNULL/COALESCE rescue, so all + // four are "YES". Any future MySQL upgrade or omni inference change + // will trip one of these and force a scenario revisit. + assertStringEq(t, "oracle v_name.bad IS_NULLABLE", + c23OracleViewColNullable(t, "v_name", "bad"), "YES") + assertStringEq(t, "oracle v_name.rescue_ifnull IS_NULLABLE", + c23OracleViewColNullable(t, "v_name", "rescue_ifnull"), "YES") + assertStringEq(t, "oracle v_name.rescue_coalesce IS_NULLABLE", + c23OracleViewColNullable(t, "v_name", "rescue_coalesce"), "YES") + assertStringEq(t, "oracle v_name.rescue_ws IS_NULLABLE", + c23OracleViewColNullable(t, "v_name", "rescue_ws"), "YES") + + v := c23OmniView(c, "v_name") + if v == nil { + t.Errorf("omni: view v_name not found") + return + } + if len(v.Columns) != 4 { + t.Errorf("omni: v_name expected 4 columns, got %d (%v)", len(v.Columns), v.Columns) + } + // Declared bug: omni cannot answer "does IFNULL/COALESCE rescue the + // CONCAT" because View has no per-column nullability inference. + t.Error("omni: View struct has no per-column nullability; IFNULL/COALESCE rescue rule (23.3) cannot be asserted positively") + }) +} diff --git a/tidb/catalog/scenarios_c24_test.go b/tidb/catalog/scenarios_c24_test.go new file mode 100644 index 00000000..2c52e3bb --- /dev/null +++ b/tidb/catalog/scenarios_c24_test.go @@ -0,0 +1,330 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C24 covers Section C24 "SHOW CREATE TABLE skip_gipk / Invisible +// PK" from mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. +// +// MySQL 8.0.30+ supports a Generated Invisible Primary Key (GIPK): when +// `sql_generate_invisible_primary_key=ON` and a CREATE TABLE has no PRIMARY +// KEY declared, MySQL silently inserts a hidden column +// `my_row_id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT INVISIBLE` at position 0 +// and adds a PRIMARY KEY over it. The presence of this column is hidden from +// SHOW CREATE TABLE / information_schema unless +// `show_gipk_in_create_table_and_information_schema=ON` is also set. +// +// omni's catalog does NOT implement GIPK generation today — these scenarios +// document that gap. Failures are recorded in scenarios_bug_queue/c24.md. +// +// All session settings are issued on the pinned single-conn pool from +// scenarioContainer so they persist for the duration of each subtest. The +// session vars are reset to OFF at the end of each subtest so other workers +// (or later subtests) start from a known state. +func TestScenario_C24(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // Helper: set a session variable on the pinned container connection. + setSession := func(t *testing.T, stmt string) { + t.Helper() + if _, err := mc.db.ExecContext(mc.ctx, stmt); err != nil { + t.Errorf("session set %q: %v", stmt, err) + } + } + + // Helper: restore both GIPK session vars to OFF at end of subtest. + resetSession := func(t *testing.T) { + t.Helper() + setSession(t, "SET SESSION sql_generate_invisible_primary_key = OFF") + setSession(t, "SET SESSION show_gipk_in_create_table_and_information_schema = OFF") + } + + // ----------------------------------------------------------------- + // 24.1 GIPK omitted from SHOW CREATE TABLE by default + // ----------------------------------------------------------------- + t.Run("24_1_gipk_hidden_by_default", func(t *testing.T) { + scenarioReset(t, mc) + defer resetSession(t) + c := scenarioNewCatalog(t) + + setSession(t, "SET SESSION sql_generate_invisible_primary_key = ON") + // Explicitly set the visibility flag OFF for this subtest. Note: in + // MySQL 8.0.32+ the default for this var flipped to ON in some + // distributions, so we cannot rely on the session inheriting OFF. + setSession(t, "SET SESSION show_gipk_in_create_table_and_information_schema = OFF") + + ddl := "CREATE TABLE t (a INT, b INT)" + runOnBoth(t, mc, c, ddl) + + // Oracle: SHOW CREATE TABLE under default visibility must NOT include + // my_row_id. + hidden := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(hidden, "my_row_id") { + t.Errorf("oracle 24.1: SHOW CREATE TABLE under default visibility should hide my_row_id, got:\n%s", hidden) + } + + // information_schema.COLUMNS under default visibility also hides it. + var colsHidden int + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='my_row_id'`, + &colsHidden) + assertIntEq(t, "oracle 24.1 information_schema hidden", colsHidden, 0) + + // Toggle visibility ON: SHOW CREATE TABLE now reveals my_row_id. + setSession(t, "SET SESSION show_gipk_in_create_table_and_information_schema = ON") + shown := oracleShow(t, mc, "SHOW CREATE TABLE t") + if !strings.Contains(shown, "my_row_id") { + t.Errorf("oracle 24.1: SHOW CREATE TABLE under visibility=ON should reveal my_row_id, got:\n%s", shown) + } + var colsShown int + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='my_row_id'`, + &colsShown) + assertIntEq(t, "oracle 24.1 information_schema visible", colsShown, 1) + + // omni: catalog should have generated my_row_id at position 0 with a + // PRIMARY KEY index. Today's catalog ignores + // sql_generate_invisible_primary_key entirely → expected to fail and + // land in c24.md. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni 24.1: table t missing") + return + } + if tbl.GetColumn("my_row_id") == nil { + t.Errorf("omni 24.1: expected catalog to generate my_row_id GIPK column (sql_generate_invisible_primary_key not honored)") + } + // And the catalog deparser should hide my_row_id under default + // visibility — but we cannot assert deparse output without invoking + // SHOW CREATE TABLE on the catalog side. The presence-or-absence + // check above is sufficient to flag the gap. + }) + + // ----------------------------------------------------------------- + // 24.2 GIPK column spec: name, type, attributes + // ----------------------------------------------------------------- + t.Run("24_2_gipk_column_spec", func(t *testing.T) { + scenarioReset(t, mc) + defer resetSession(t) + c := scenarioNewCatalog(t) + + setSession(t, "SET SESSION sql_generate_invisible_primary_key = ON") + setSession(t, "SET SESSION show_gipk_in_create_table_and_information_schema = ON") + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT)") + + // Oracle: my_row_id has the documented spec. + var colName, colType, isNullable, extra, colKey string + oracleScan(t, mc, + `SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, EXTRA, COLUMN_KEY + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='my_row_id'`, + &colName, &colType, &isNullable, &extra, &colKey) + assertStringEq(t, "oracle 24.2 name", colName, "my_row_id") + assertStringEq(t, "oracle 24.2 type", strings.ToLower(colType), "bigint unsigned") + assertStringEq(t, "oracle 24.2 nullable", isNullable, "NO") + // EXTRA is e.g. "auto_increment INVISIBLE" + extraLower := strings.ToLower(extra) + if !strings.Contains(extraLower, "auto_increment") { + t.Errorf("oracle 24.2 EXTRA missing auto_increment: %q", extra) + } + if !strings.Contains(extraLower, "invisible") { + t.Errorf("oracle 24.2 EXTRA missing INVISIBLE: %q", extra) + } + assertStringEq(t, "oracle 24.2 column_key", colKey, "PRI") + + // Oracle: my_row_id is the FIRST column of the table. + var firstCol string + oracleScan(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY ORDINAL_POSITION LIMIT 1`, + &firstCol) + assertStringEq(t, "oracle 24.2 first column", firstCol, "my_row_id") + + // Oracle: SHOW INDEX FROM t has a PRIMARY index over my_row_id. + idxRows := oracleRows(t, mc, "SHOW INDEX FROM t") + foundPK := false + for _, r := range idxRows { + // SHOW INDEX layout: Table, Non_unique, Key_name, Seq_in_index, + // Column_name, ... + if len(r) >= 5 && asString(r[2]) == "PRIMARY" && asString(r[4]) == "my_row_id" { + foundPK = true + break + } + } + assertBoolEq(t, "oracle 24.2 PRIMARY index over my_row_id", foundPK, true) + + // omni: validate the generated column matches MySQL's spec. Expected to + // fail today since omni does not generate GIPK. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni 24.2: table t missing") + return + } + col := tbl.GetColumn("my_row_id") + if col == nil { + t.Errorf("omni 24.2: expected GIPK column my_row_id in catalog (gap)") + return + } + if col.Position != 0 { + t.Errorf("omni 24.2: expected my_row_id at position 0, got %d", col.Position) + } + if !strings.Contains(strings.ToLower(col.ColumnType), "bigint") || + !strings.Contains(strings.ToLower(col.ColumnType), "unsigned") { + t.Errorf("omni 24.2: expected ColumnType bigint unsigned, got %q", col.ColumnType) + } + if col.Nullable { + t.Errorf("omni 24.2: expected NOT NULL, got Nullable=true") + } + if !col.AutoIncrement { + t.Errorf("omni 24.2: expected AutoIncrement=true") + } + if !col.Invisible { + t.Errorf("omni 24.2: expected Invisible=true") + } + // PK index over my_row_id. + pkOK := false + for _, idx := range tbl.Indexes { + if idx.Primary { + if len(idx.Columns) == 1 && strings.EqualFold(idx.Columns[0].Name, "my_row_id") { + pkOK = true + } + break + } + } + if !pkOK { + t.Errorf("omni 24.2: expected PRIMARY KEY (my_row_id) in catalog") + } + }) + + // ----------------------------------------------------------------- + // 24.3 GIPK NOT added when table has explicit PK; UNIQUE NOT NULL does NOT suppress + // ----------------------------------------------------------------- + t.Run("24_3_gipk_suppressed_only_by_pk", func(t *testing.T) { + scenarioReset(t, mc) + defer resetSession(t) + c := scenarioNewCatalog(t) + + setSession(t, "SET SESSION sql_generate_invisible_primary_key = ON") + setSession(t, "SET SESSION show_gipk_in_create_table_and_information_schema = ON") + + runOnBoth(t, mc, c, "CREATE TABLE t1 (id INT PRIMARY KEY, a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t2 (id INT NOT NULL UNIQUE, a INT)") + runOnBoth(t, mc, c, "CREATE TABLE t3 (id INT AUTO_INCREMENT, a INT, PRIMARY KEY (id))") + + // Oracle: t1 — explicit PK → no my_row_id. + var n1, n2, n3 int + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' AND COLUMN_NAME='my_row_id'`, + &n1) + assertIntEq(t, "oracle 24.3 t1 has no GIPK", n1, 0) + + // Oracle: t2 — UNIQUE NOT NULL does NOT suppress GIPK → my_row_id PRESENT. + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t2' AND COLUMN_NAME='my_row_id'`, + &n2) + assertIntEq(t, "oracle 24.3 t2 has GIPK (UNIQUE NOT NULL is not a PK)", n2, 1) + + // Oracle: t3 — table-level PRIMARY KEY → no my_row_id. + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t3' AND COLUMN_NAME='my_row_id'`, + &n3) + assertIntEq(t, "oracle 24.3 t3 has no GIPK", n3, 0) + + // omni: t1 should not have my_row_id (catalog ignores GIPK so this is + // trivially true today); t2 SHOULD have my_row_id (gap — omni does not + // generate it); t3 should not. + t1 := c.GetDatabase("testdb").GetTable("t1") + if t1 != nil && t1.GetColumn("my_row_id") != nil { + t.Errorf("omni 24.3: t1 should not have my_row_id (explicit PK present)") + } + t2 := c.GetDatabase("testdb").GetTable("t2") + if t2 == nil { + t.Errorf("omni 24.3: table t2 missing") + } else if t2.GetColumn("my_row_id") == nil { + t.Errorf("omni 24.3: expected GIPK my_row_id on t2 (UNIQUE NOT NULL is not a PK) — gap") + } + t3 := c.GetDatabase("testdb").GetTable("t3") + if t3 != nil && t3.GetColumn("my_row_id") != nil { + t.Errorf("omni 24.3: t3 should not have my_row_id (table-level PK present)") + } + }) + + // ----------------------------------------------------------------- + // 24.4 my_row_id name collision with user-defined column + // ----------------------------------------------------------------- + t.Run("24_4_gipk_name_collision", func(t *testing.T) { + scenarioReset(t, mc) + defer resetSession(t) + + setSession(t, "SET SESSION sql_generate_invisible_primary_key = ON") + + // Oracle: CREATE TABLE with user-declared my_row_id under GIPK=ON + // must fail with ER_GIPK_FAILED_AUTOINC_COLUMN_NAME_RESERVED (4108). + _, oracleErrOn := mc.db.ExecContext(mc.ctx, + "CREATE TABLE t (my_row_id INT, a INT)") + if oracleErrOn == nil { + t.Errorf("oracle 24.4: expected error creating table with my_row_id while GIPK=ON, got nil") + } else if !strings.Contains(oracleErrOn.Error(), "my_row_id") { + t.Errorf("oracle 24.4: error message should mention my_row_id, got: %v", oracleErrOn) + } + + // omni: catalog should reject the same CREATE TABLE while + // sql_generate_invisible_primary_key is conceptually ON. The catalog + // has no notion of session vars today, so this check exercises the + // best-effort path: we expect omni's exec to error if it implemented + // the GIPK reservation. Today omni accepts it → recorded as gap. + c := scenarioNewCatalog(t) + results, err := c.Exec("CREATE TABLE t (my_row_id INT, a INT);", nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni 24.4: expected error rejecting user-declared my_row_id under GIPK=ON (gap — catalog has no session-var awareness)") + } + + // Now turn GIPK off and verify both sides accept the same CREATE + // TABLE. We use a fresh table name to avoid clashing with whatever + // omni or oracle may have left behind from the failing path. + setSession(t, "SET SESSION sql_generate_invisible_primary_key = OFF") + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + runOnBoth(t, mc, c2, "CREATE TABLE t (my_row_id INT, a INT)") + + // Oracle: my_row_id should exist as an ordinary user column. + var name string + oracleScan(t, mc, + `SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='my_row_id'`, + &name) + assertStringEq(t, "oracle 24.4 my_row_id user column name", name, "my_row_id") + + // omni: catalog should also have it as an ordinary column. + tbl := c2.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni 24.4: table t missing under GIPK=OFF") + return + } + if tbl.GetColumn("my_row_id") == nil { + t.Errorf("omni 24.4: expected user column my_row_id under GIPK=OFF") + } + }) +} diff --git a/tidb/catalog/scenarios_c25_test.go b/tidb/catalog/scenarios_c25_test.go new file mode 100644 index 00000000..a6fc9a0a --- /dev/null +++ b/tidb/catalog/scenarios_c25_test.go @@ -0,0 +1,314 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C25 covers Section C25 "DECIMAL defaults" from +// mysql/catalog/SCENARIOS-mysql-implicit-behavior.md. +// +// MySQL canonicalizes DECIMAL columns: +// - DECIMAL → DECIMAL(10,0) +// - DECIMAL(M) → DECIMAL(M,0) +// - DECIMAL(M,D) → DECIMAL(M,D), with 1 <= M <= 65 and 0 <= D <= 30, D <= M +// - NUMERIC → synonym, stored as decimal in information_schema +// - UNSIGNED / ZEROFILL → flags after the (M,D) spec +// +// information_schema.COLUMNS exposes NUMERIC_PRECISION/NUMERIC_SCALE and +// COLUMN_TYPE; SHOW CREATE TABLE always renders the fully-qualified +// `decimal(M,D)` form, never the bare DECIMAL or DECIMAL(M) shorthand. +// +// Failures in omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c25.md. +func TestScenario_C25(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // ----------------------------------------------------------------- + // 25.1 DECIMAL with no precision/scale → DECIMAL(10,0) + // ----------------------------------------------------------------- + t.Run("25_1_decimal_default_10_0", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (d DECIMAL)") + + // Oracle: COLUMN_TYPE / NUMERIC_PRECISION / NUMERIC_SCALE. + var colType string + var prec, scale int + oracleScan(t, mc, + `SELECT COLUMN_TYPE, NUMERIC_PRECISION, NUMERIC_SCALE + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='d'`, + &colType, &prec, &scale) + assertStringEq(t, "oracle column_type", strings.ToLower(colType), "decimal(10,0)") + assertIntEq(t, "oracle numeric_precision", prec, 10) + assertIntEq(t, "oracle numeric_scale", scale, 0) + + // omni: column type renders as decimal(10,0). + col := c25Col(t, c, "t", "d") + if col == nil { + return + } + assertStringEq(t, "omni column_type", strings.ToLower(col.ColumnType), "decimal(10,0)") + }) + + // ----------------------------------------------------------------- + // 25.2 DECIMAL precision-only → scale defaults to 0 + // ----------------------------------------------------------------- + t.Run("25_2_precision_only_scale_zero", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + d5 DECIMAL(5), + d15 DECIMAL(15), + d65 DECIMAL(65) + )`) + + // Oracle: COLUMN_TYPE rows. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME, COLUMN_TYPE, NUMERIC_PRECISION, NUMERIC_SCALE + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY ORDINAL_POSITION`) + got := map[string]string{} + for _, r := range rows { + got[asString(r[0])] = strings.ToLower(asString(r[1])) + } + assertStringEq(t, "oracle d5", got["d5"], "decimal(5,0)") + assertStringEq(t, "oracle d15", got["d15"], "decimal(15,0)") + assertStringEq(t, "oracle d65", got["d65"], "decimal(65,0)") + + // Oracle: SHOW CREATE TABLE must also use the explicit zero scale. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + lower := strings.ToLower(create) + for _, want := range []string{"decimal(5,0)", "decimal(15,0)", "decimal(65,0)"} { + if !strings.Contains(lower, want) { + t.Errorf("oracle SHOW CREATE TABLE: missing %q in:\n%s", want, create) + } + } + + // omni + for _, want := range []struct{ name, typ string }{ + {"d5", "decimal(5,0)"}, + {"d15", "decimal(15,0)"}, + {"d65", "decimal(65,0)"}, + } { + col := c25Col(t, c, "t", want.name) + if col == nil { + continue + } + assertStringEq(t, "omni "+want.name+" column_type", + strings.ToLower(col.ColumnType), want.typ) + } + }) + + // ----------------------------------------------------------------- + // 25.3 DECIMAL bounds: max P=65, S=30, scale > precision rejection + // ----------------------------------------------------------------- + t.Run("25_3_bounds_max_p_s_and_scale_gt_precision", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // 25.3a: DECIMAL(65,30) accepted on both sides. + runOnBoth(t, mc, c, "CREATE TABLE ok_max_p (d DECIMAL(65,30))") + + var colType string + oracleScan(t, mc, + `SELECT COLUMN_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='ok_max_p' AND COLUMN_NAME='d'`, + &colType) + assertStringEq(t, "oracle ok_max_p column_type", + strings.ToLower(colType), "decimal(65,30)") + if col := c25Col(t, c, "ok_max_p", "d"); col != nil { + assertStringEq(t, "omni ok_max_p column_type", + strings.ToLower(col.ColumnType), "decimal(65,30)") + } + + // 25.3b: DECIMAL(66,0) → ER_TOO_BIG_PRECISION (1426). + c25AssertBothError(t, mc, c, + "CREATE TABLE err_p_gt_65 (d DECIMAL(66, 0))", + "precision 66", "25.3b precision > 65") + + // 25.3c: DECIMAL(40,31) → ER_TOO_BIG_SCALE (1425). + c25AssertBothError(t, mc, c, + "CREATE TABLE err_s_gt_30 (d DECIMAL(40, 31))", + "scale 31", "25.3c scale > 30") + + // 25.3d: DECIMAL(5,6) → ER_M_BIGGER_THAN_D (1427). + c25AssertBothError(t, mc, c, + "CREATE TABLE err_s_gt_p (d DECIMAL(5, 6))", + "M must be >= D", "25.3d scale > precision") + + // 25.3e: DECIMAL(-1,0) → parse error on both sides. + c25AssertBothError(t, mc, c, + "CREATE TABLE err_neg (d DECIMAL(-1, 0))", + "", "25.3e negative precision") + }) + + // ----------------------------------------------------------------- + // 25.4 UNSIGNED / ZEROFILL / NUMERIC synonym + // ----------------------------------------------------------------- + t.Run("25_4_unsigned_zerofill_numeric_synonym", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + d1 DECIMAL(10,2) UNSIGNED, + d2 DECIMAL(10,2) UNSIGNED ZEROFILL, + d3 NUMERIC(10,2) UNSIGNED + )`) + + // Oracle: NUMERIC reported as decimal; precision=10, scale=2 for all. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME, COLUMN_TYPE, DATA_TYPE, NUMERIC_PRECISION, NUMERIC_SCALE + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY ORDINAL_POSITION`) + want := map[string]string{ + "d1": "decimal(10,2) unsigned", + "d2": "decimal(10,2) unsigned zerofill", + "d3": "decimal(10,2) unsigned", + } + for _, r := range rows { + name := asString(r[0]) + ct := strings.ToLower(asString(r[1])) + dt := strings.ToLower(asString(r[2])) + assertStringEq(t, "oracle "+name+" column_type", ct, want[name]) + assertStringEq(t, "oracle "+name+" data_type", dt, "decimal") + } + + // omni: catalog should round-trip these as well. + for _, name := range []string{"d1", "d2", "d3"} { + col := c25Col(t, c, "t", name) + if col == nil { + continue + } + ct := strings.ToLower(col.ColumnType) + if !strings.Contains(ct, "decimal(10,2)") { + t.Errorf("omni %s column_type: got %q, want substring decimal(10,2)", name, ct) + } + if !strings.Contains(ct, "unsigned") { + t.Errorf("omni %s column_type: got %q, want substring unsigned", name, ct) + } + if name == "d2" && !strings.Contains(ct, "zerofill") { + t.Errorf("omni d2 column_type: got %q, want substring zerofill", ct) + } + } + }) + + // ----------------------------------------------------------------- + // 25.5 Zero-scale rendering: DECIMAL / DECIMAL(5) / DECIMAL(5,0) all collapse + // ----------------------------------------------------------------- + t.Run("25_5_zero_scale_rendering", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a DECIMAL, + b DECIMAL(5), + c DECIMAL(5,0), + d DECIMAL(10,0) + )`) + + // Oracle: information_schema rows. + rows := oracleRows(t, mc, + `SELECT COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY ORDINAL_POSITION`) + expected := map[string]string{ + "a": "decimal(10,0)", + "b": "decimal(5,0)", + "c": "decimal(5,0)", + "d": "decimal(10,0)", + } + for _, r := range rows { + n := asString(r[0]) + ct := strings.ToLower(asString(r[1])) + assertStringEq(t, "oracle "+n+" column_type", ct, expected[n]) + } + + // Oracle: SHOW CREATE TABLE renders explicit (M,0). + create := strings.ToLower(oracleShow(t, mc, "SHOW CREATE TABLE t")) + for _, want := range []string{"decimal(10,0)", "decimal(5,0)"} { + if !strings.Contains(create, want) { + t.Errorf("oracle SHOW CREATE TABLE: missing %q in:\n%s", want, create) + } + } + // And must NOT render the bare or single-arg form. + for _, bad := range []string{"decimal(5)", "decimal(10)", "decimal,"} { + if strings.Contains(create, bad) { + t.Errorf("oracle SHOW CREATE TABLE: should not contain %q in:\n%s", bad, create) + } + } + + // omni + for name, want := range expected { + col := c25Col(t, c, "t", name) + if col == nil { + continue + } + assertStringEq(t, "omni "+name+" column_type", + strings.ToLower(col.ColumnType), want) + } + }) +} + +// c25Col fetches a column from testdb.
, reporting via t.Error when +// missing rather than aborting the test. +func c25Col(t *testing.T, c *Catalog, table, col string) *Column { + t.Helper() + db := c.GetDatabase("testdb") + if db == nil { + t.Errorf("omni: database testdb missing") + return nil + } + tbl := db.GetTable(table) + if tbl == nil { + t.Errorf("omni: table %s missing", table) + return nil + } + column := tbl.GetColumn(col) + if column == nil { + t.Errorf("omni: column %s.%s missing", table, col) + return nil + } + return column +} + +// c25AssertBothError runs the same DDL on the MySQL container and the omni +// catalog, asserting that BOTH return an error. If wantSubstr is non-empty it +// must appear in the MySQL container error message (case-insensitive). label +// is used only in error messages for easier triage. +func c25AssertBothError(t *testing.T, mc *mysqlContainer, c *Catalog, ddl, wantSubstr, label string) { + t.Helper() + + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + if oracleErr == nil { + t.Errorf("oracle %s: expected error for %q, got nil", label, ddl) + } else if wantSubstr != "" && + !strings.Contains(strings.ToLower(oracleErr.Error()), strings.ToLower(wantSubstr)) { + t.Errorf("oracle %s: error %q missing substring %q", label, oracleErr.Error(), wantSubstr) + } + + results, err := c.Exec(ddl+";", nil) + if err != nil { + // Parse-level rejection is fine — counts as an error from omni. + return + } + sawErr := false + for _, r := range results { + if r.Error != nil { + sawErr = true + break + } + } + if !sawErr { + t.Errorf("omni %s: expected error for %q, got nil", label, ddl) + } +} diff --git a/tidb/catalog/scenarios_c2_test.go b/tidb/catalog/scenarios_c2_test.go new file mode 100644 index 00000000..238b5757 --- /dev/null +++ b/tidb/catalog/scenarios_c2_test.go @@ -0,0 +1,599 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C2 runs the "Type normalization" section of +// SCENARIOS-mysql-implicit-behavior.md (section C2). Each subtest executes a +// DDL against both a real MySQL 8.0 container and the omni catalog, then +// asserts that both agree on the expected normalized type rendering. +// +// Per the worker protocol, failed omni assertions are NOT test infrastructure +// failures — they are tracked as discovered bugs in scenarios_bug_queue/c2.md. +// The test uses t.Error (not t.Fatal) so every scenario reports all its +// diffs in a single run. +func TestScenario_C2(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // Helper: read oracle COLUMN_TYPE for testdb.t.. + oracleColumnType := func(t *testing.T, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT COLUMN_TYPE FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + oracleDataType := func(t *testing.T, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT DATA_TYPE FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + oracleCharset := func(t *testing.T, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT IFNULL(CHARACTER_SET_NAME,'') FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + omniCol := func(t *testing.T, c *Catalog, col string) *Column { + t.Helper() + db := c.GetDatabase("testdb") + if db == nil { + t.Error("omni: database testdb not found") + return nil + } + tbl := db.GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return nil + } + cc := tbl.GetColumn(col) + if cc == nil { + t.Errorf("omni: column %q not found", col) + return nil + } + return cc + } + + // 2.1 REAL → DOUBLE + t.Run("2_1_REAL_to_DOUBLE", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c REAL)`) + + assertStringEq(t, "oracle DATA_TYPE", oracleDataType(t, "c"), "double") + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "double") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "double") + } + }) + + // 2.2 BOOL → TINYINT(1) + t.Run("2_2_BOOL_to_TINYINT1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c BOOL)`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "tinyint(1)") + assertStringEq(t, "oracle DATA_TYPE", oracleDataType(t, "c"), "tinyint") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "tinyint") + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "tinyint(1)") + } + }) + + // 2.3 INTEGER → INT + t.Run("2_3_INTEGER_to_INT", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c INTEGER)`) + + assertStringEq(t, "oracle DATA_TYPE", oracleDataType(t, "c"), "int") + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "int") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "int") + } + }) + + // 2.4 BOOLEAN → TINYINT(1) + t.Run("2_4_BOOLEAN_to_TINYINT1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c BOOLEAN)`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "tinyint(1)") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "tinyint") + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "tinyint(1)") + } + }) + + // 2.5 INT1/INT2/INT3/INT4/INT8 → TINYINT/SMALLINT/MEDIUMINT/INT/BIGINT + t.Run("2_5_INTN_aliases", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT1, b INT2, cc INT3, d INT4, e INT8)`) + + cases := []struct { + name, want string + }{ + {"a", "tinyint"}, + {"b", "smallint"}, + {"cc", "mediumint"}, + {"d", "int"}, + {"e", "bigint"}, + } + for _, cc := range cases { + assertStringEq(t, "oracle DATA_TYPE "+cc.name, oracleDataType(t, cc.name), cc.want) + if col := omniCol(t, c, cc.name); col != nil { + assertStringEq(t, "omni DataType "+cc.name, strings.ToLower(col.DataType), cc.want) + } + } + }) + + // 2.6 MIDDLEINT → MEDIUMINT + t.Run("2_6_MIDDLEINT_to_MEDIUMINT", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c MIDDLEINT)`) + + assertStringEq(t, "oracle DATA_TYPE", oracleDataType(t, "c"), "mediumint") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "mediumint") + } + }) + + // 2.7 INT(11) display width deprecated → stripped from output + t.Run("2_7_INT11_width_stripped", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c INT(11))`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "int") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "int") + } + }) + + // 2.8 INT(N) ZEROFILL → preserves display width + implies UNSIGNED + t.Run("2_8_INT5_ZEROFILL", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c INT(5) ZEROFILL)`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "int(5) unsigned zerofill") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "int(5) unsigned zerofill") + } + }) + + // 2.9 SERIAL → BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE + t.Run("2_9_SERIAL", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c SERIAL)`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "bigint unsigned") + var nullable, extra string + oracleScan(t, mc, + `SELECT IS_NULLABLE, EXTRA FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='c'`, + &nullable, &extra) + assertStringEq(t, "oracle IS_NULLABLE", nullable, "NO") + assertStringEq(t, "oracle EXTRA", strings.ToLower(extra), "auto_increment") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "bigint unsigned") + assertBoolEq(t, "omni Nullable", col.Nullable, false) + assertBoolEq(t, "omni AutoIncrement", col.AutoIncrement, true) + } + // implicit UNIQUE + db := c.GetDatabase("testdb") + if db != nil { + tbl := db.GetTable("t") + if tbl != nil { + foundUnique := false + for _, idx := range tbl.Indexes { + if idx.Unique && len(idx.Columns) == 1 && strings.EqualFold(idx.Columns[0].Name, "c") { + foundUnique = true + break + } + } + assertBoolEq(t, "omni SERIAL implicit UNIQUE index", foundUnique, true) + } + } + }) + + // 2.10 NUMERIC → DECIMAL + t.Run("2_10_NUMERIC_to_DECIMAL", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c NUMERIC(10,2))`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "decimal(10,2)") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "decimal(10,2)") + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "decimal") + } + }) + + // 2.11 DEC and FIXED → DECIMAL + t.Run("2_11_DEC_FIXED_to_DECIMAL", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a DEC(6,2), b FIXED(6,2))`) + + assertStringEq(t, "oracle COLUMN_TYPE a", oracleColumnType(t, "a"), "decimal(6,2)") + assertStringEq(t, "oracle COLUMN_TYPE b", oracleColumnType(t, "b"), "decimal(6,2)") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni ColumnType a", strings.ToLower(col.ColumnType), "decimal(6,2)") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni ColumnType b", strings.ToLower(col.ColumnType), "decimal(6,2)") + } + }) + + // 2.12 DOUBLE PRECISION → DOUBLE + t.Run("2_12_DOUBLE_PRECISION", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c DOUBLE PRECISION)`) + + assertStringEq(t, "oracle DATA_TYPE", oracleDataType(t, "c"), "double") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni DataType", strings.ToLower(col.DataType), "double") + } + }) + + // 2.13 FLOAT4 → FLOAT, FLOAT8 → DOUBLE + t.Run("2_13_FLOAT4_FLOAT8", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a FLOAT4, b FLOAT8)`) + + assertStringEq(t, "oracle a", oracleDataType(t, "a"), "float") + assertStringEq(t, "oracle b", oracleDataType(t, "b"), "double") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a DataType", strings.ToLower(col.DataType), "float") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b DataType", strings.ToLower(col.DataType), "double") + } + }) + + // 2.14 FLOAT(p) precision split + t.Run("2_14_FLOAT_precision_split", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a FLOAT(10), b FLOAT(25))`) + + assertStringEq(t, "oracle a", oracleDataType(t, "a"), "float") + assertStringEq(t, "oracle b", oracleDataType(t, "b"), "double") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a DataType", strings.ToLower(col.DataType), "float") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b DataType", strings.ToLower(col.DataType), "double") + } + }) + + // 2.15 FLOAT(M,D) deprecated but preserved + t.Run("2_15_FLOAT_M_D", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c FLOAT(7,4))`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "float(7,4)") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "float(7,4)") + } + }) + + // 2.16 CHARACTER → CHAR, CHARACTER VARYING → VARCHAR + t.Run("2_16_CHARACTER_VARYING", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a CHARACTER(10), b CHARACTER VARYING(20))`) + + assertStringEq(t, "oracle a", oracleColumnType(t, "a"), "char(10)") + assertStringEq(t, "oracle b", oracleColumnType(t, "b"), "varchar(20)") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a ColumnType", strings.ToLower(col.ColumnType), "char(10)") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b ColumnType", strings.ToLower(col.ColumnType), "varchar(20)") + } + }) + + // 2.17 NATIONAL CHAR / NCHAR → CHAR utf8mb3 + t.Run("2_17_NATIONAL_CHAR", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a NATIONAL CHAR(10), b NCHAR(10))`) + + assertStringEq(t, "oracle a COLUMN_TYPE", oracleColumnType(t, "a"), "char(10)") + assertStringEq(t, "oracle b COLUMN_TYPE", oracleColumnType(t, "b"), "char(10)") + assertStringEq(t, "oracle a CHARSET", oracleCharset(t, "a"), "utf8mb3") + assertStringEq(t, "oracle b CHARSET", oracleCharset(t, "b"), "utf8mb3") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a ColumnType", strings.ToLower(col.ColumnType), "char(10)") + assertStringEq(t, "omni a Charset", strings.ToLower(col.Charset), "utf8mb3") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b ColumnType", strings.ToLower(col.ColumnType), "char(10)") + assertStringEq(t, "omni b Charset", strings.ToLower(col.Charset), "utf8mb3") + } + }) + + // 2.18 NVARCHAR family → VARCHAR utf8mb3 + t.Run("2_18_NVARCHAR_family", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t ( + a NVARCHAR(10), + b NATIONAL VARCHAR(10), + cc NCHAR VARCHAR(10), + d NATIONAL CHAR VARYING(10), + e NCHAR VARYING(10) +)`) + + for _, name := range []string{"a", "b", "cc", "d", "e"} { + assertStringEq(t, "oracle "+name+" COLUMN_TYPE", oracleColumnType(t, name), "varchar(10)") + assertStringEq(t, "oracle "+name+" CHARSET", oracleCharset(t, name), "utf8mb3") + + if col := omniCol(t, c, name); col != nil { + assertStringEq(t, "omni "+name+" ColumnType", strings.ToLower(col.ColumnType), "varchar(10)") + assertStringEq(t, "omni "+name+" Charset", strings.ToLower(col.Charset), "utf8mb3") + } + } + }) + + // 2.19 LONG / LONG VARCHAR → MEDIUMTEXT; LONG VARBINARY → MEDIUMBLOB + t.Run("2_19_LONG_aliases", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a LONG, b LONG VARCHAR, cc LONG VARBINARY)`) + + assertStringEq(t, "oracle a", oracleDataType(t, "a"), "mediumtext") + assertStringEq(t, "oracle b", oracleDataType(t, "b"), "mediumtext") + assertStringEq(t, "oracle cc", oracleDataType(t, "cc"), "mediumblob") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a DataType", strings.ToLower(col.DataType), "mediumtext") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b DataType", strings.ToLower(col.DataType), "mediumtext") + } + if col := omniCol(t, c, "cc"); col != nil { + assertStringEq(t, "omni cc DataType", strings.ToLower(col.DataType), "mediumblob") + } + }) + + // 2.20 CHAR and BINARY default to length 1 + t.Run("2_20_CHAR_BINARY_default_length", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (a CHAR, b BINARY)`) + + assertStringEq(t, "oracle a COLUMN_TYPE", oracleColumnType(t, "a"), "char(1)") + assertStringEq(t, "oracle b COLUMN_TYPE", oracleColumnType(t, "b"), "binary(1)") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a ColumnType", strings.ToLower(col.ColumnType), "char(1)") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b ColumnType", strings.ToLower(col.ColumnType), "binary(1)") + } + }) + + // 2.21 VARCHAR without length is a syntax error + t.Run("2_21_VARCHAR_no_length_is_error", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Oracle: expect failure. + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE t (c VARCHAR)`); err == nil { + t.Error("oracle: expected syntax error for bare VARCHAR, got nil") + } + + // omni: expect failure. + results, err := c.Exec(`CREATE TABLE t (c VARCHAR);`, nil) + sawErr := err != nil + if !sawErr { + for _, r := range results { + if r.Error != nil { + sawErr = true + break + } + } + } + if !sawErr { + t.Error("omni: expected parse/exec error for bare VARCHAR, got success") + } + }) + + // 2.22 TIMESTAMP/DATETIME/TIME default fsp=0, explicit fsp preserved + t.Run("2_22_temporal_fsp", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t ( + a TIMESTAMP NULL, + b DATETIME, + cc TIME, + d TIMESTAMP(6) NULL, + e DATETIME(6), + f TIME(3) +)`) + + assertStringEq(t, "oracle a", oracleColumnType(t, "a"), "timestamp") + assertStringEq(t, "oracle b", oracleColumnType(t, "b"), "datetime") + assertStringEq(t, "oracle cc", oracleColumnType(t, "cc"), "time") + assertStringEq(t, "oracle d", oracleColumnType(t, "d"), "timestamp(6)") + assertStringEq(t, "oracle e", oracleColumnType(t, "e"), "datetime(6)") + assertStringEq(t, "oracle f", oracleColumnType(t, "f"), "time(3)") + + if col := omniCol(t, c, "a"); col != nil { + assertStringEq(t, "omni a ColumnType", strings.ToLower(col.ColumnType), "timestamp") + } + if col := omniCol(t, c, "b"); col != nil { + assertStringEq(t, "omni b ColumnType", strings.ToLower(col.ColumnType), "datetime") + } + if col := omniCol(t, c, "cc"); col != nil { + assertStringEq(t, "omni cc ColumnType", strings.ToLower(col.ColumnType), "time") + } + if col := omniCol(t, c, "d"); col != nil { + assertStringEq(t, "omni d ColumnType", strings.ToLower(col.ColumnType), "timestamp(6)") + } + if col := omniCol(t, c, "e"); col != nil { + assertStringEq(t, "omni e ColumnType", strings.ToLower(col.ColumnType), "datetime(6)") + } + if col := omniCol(t, c, "f"); col != nil { + assertStringEq(t, "omni f ColumnType", strings.ToLower(col.ColumnType), "time(3)") + } + }) + + // 2.23 YEAR(4) deprecated → stored as YEAR + t.Run("2_23_YEAR_4_bare_YEAR", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c YEAR(4))`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "year") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "year") + } + }) + + // 2.24 BIT without length defaults to BIT(1) + t.Run("2_24_BIT_default_1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t (c BIT)`) + + assertStringEq(t, "oracle COLUMN_TYPE", oracleColumnType(t, "c"), "bit(1)") + + if col := omniCol(t, c, "c"); col != nil { + assertStringEq(t, "omni ColumnType", strings.ToLower(col.ColumnType), "bit(1)") + } + }) + + // 2.25 VARCHAR(65536) in non-strict → TEXT family + t.Run("2_25_VARCHAR_overflow_to_text", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Set sql_mode='' on this session (oracle only — omni has no session state). + if _, err := mc.db.ExecContext(mc.ctx, "SET SESSION sql_mode=''"); err != nil { + t.Errorf("oracle SET sql_mode: %v", err) + } + // Restore strict mode after to avoid leaking. + defer func() { + _, _ = mc.db.ExecContext(mc.ctx, + "SET SESSION sql_mode='STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION'") + }() + + // Oracle side. + if _, err := mc.db.ExecContext(mc.ctx, `CREATE TABLE t (c VARCHAR(65536))`); err != nil { + t.Errorf("oracle CREATE (non-strict) failed: %v", err) + } + var gotType string + oracleScan(t, mc, + `SELECT DATA_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='c'`, + &gotType) + gotType = strings.ToLower(gotType) + if gotType != "mediumtext" && gotType != "text" { + t.Errorf("oracle: expected mediumtext or text, got %q", gotType) + } + + // omni side — we document the current behavior rather than asserting + // a specific outcome, since omni's byte-length → text promotion is + // a known gap (scenario 2.25 is pending-verify). + results, err := c.Exec(`CREATE TABLE t (c VARCHAR(65536));`, nil) + omniErr := err != nil + if !omniErr { + for _, r := range results { + if r.Error != nil { + omniErr = true + break + } + } + } + if !omniErr { + if col := omniCol(t, c, "c"); col != nil { + dt := strings.ToLower(col.DataType) + if dt != "mediumtext" && dt != "text" { + t.Errorf("omni: expected mediumtext/text, got DataType=%q ColumnType=%q", + dt, col.ColumnType) + } + } + } else { + // omni raised an error — acceptable only if we were in strict mode, + // but since omni has no session state, this is a diff vs oracle. + t.Errorf("omni: raised error for VARCHAR(65536) but oracle (non-strict) accepted it as %s", gotType) + } + }) + + // 2.26 TEXT(N) / BLOB(N) → TINY/TEXT/MEDIUM/LONG by byte count + t.Run("2_26_TEXT_N_promotion", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + runOnBoth(t, mc, c, `CREATE TABLE t ( + a TEXT(100), + b TEXT(1000), + cc TEXT(70000), + d TEXT(20000000) +)`) + + // NOTE: SCENARIOS says a TEXT(100) → tinytext, but with default + // utf8mb4 charset 100 chars = 400 bytes, exceeding tinytext's + // 255-byte cap, so MySQL promotes to text. Trust the oracle — + // see scenarios_bug_queue/c2.md for the scenario-doc mismatch. + cases := []struct { + name, want string + }{ + {"a", "text"}, + {"b", "text"}, + {"cc", "mediumtext"}, + {"d", "longtext"}, + } + for _, cc := range cases { + assertStringEq(t, "oracle "+cc.name, oracleDataType(t, cc.name), cc.want) + if col := omniCol(t, c, cc.name); col != nil { + assertStringEq(t, "omni "+cc.name+" DataType", + strings.ToLower(col.DataType), cc.want) + } + } + }) +} diff --git a/tidb/catalog/scenarios_c3_test.go b/tidb/catalog/scenarios_c3_test.go new file mode 100644 index 00000000..1d5735f7 --- /dev/null +++ b/tidb/catalog/scenarios_c3_test.go @@ -0,0 +1,375 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C3 covers section C3 (Nullability & default promotion) from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest asserts that both real +// MySQL 8.0 and the omni catalog agree on the implicit nullability/default +// rules MySQL applies during CREATE TABLE. +// +// Failures in omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c3.md. +func TestScenario_C3(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 3.1 First TIMESTAMP NOT NULL auto-promotes (legacy mode) ------- + t.Run("3_1_first_timestamp_promotion", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Force legacy mode on the container so the first-TS promotion fires. + // Also clear sql_mode so MySQL accepts the implicit '0000-00-00' zero + // default for the second TIMESTAMP column. omni does not track these + // session vars; we still verify omni's static behavior. + setLegacyTimestampMode(t, mc) + defer restoreTimestampMode(t, mc) + + ddl := `CREATE TABLE t ( + ts1 TIMESTAMP NOT NULL, + ts2 TIMESTAMP NOT NULL +)` + // Run on container directly — runOnBoth would split sql_mode. + if _, err := mc.db.ExecContext(mc.ctx, ddl); err != nil { + t.Errorf("oracle CREATE TABLE failed: %v", err) + } + results, err := c.Exec(ddl+";", nil) + if err != nil { + t.Errorf("omni parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni exec error: %v", r.Error) + } + } + + // Oracle: SHOW CREATE TABLE on container. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + lo := strings.ToLower(create) + if !strings.Contains(lo, "ts1") || + !strings.Contains(lo, "default current_timestamp") || + !strings.Contains(lo, "on update current_timestamp") { + t.Errorf("oracle: expected ts1 promoted to CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP\n%s", create) + } + // ts2 should NOT carry an ON UPDATE CURRENT_TIMESTAMP — only one occurrence overall. + if strings.Count(lo, "on update current_timestamp") != 1 { + t.Errorf("oracle: expected exactly one ON UPDATE CURRENT_TIMESTAMP (ts1 only)\n%s", create) + } + + // omni: this is omni's "static" view. omni does not track + // explicit_defaults_for_timestamp, so the expectation here documents + // what omni *currently* does. We accept either: + // (a) omni promotes the first TIMESTAMP NOT NULL (matches legacy) + // (b) omni leaves both ts1 and ts2 alone (matches default mode) + // Anything else is a bug. Document the asymmetry rather than failing. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + ts1 := tbl.GetColumn("ts1") + ts2 := tbl.GetColumn("ts2") + if ts1 == nil || ts2 == nil { + t.Errorf("omni: ts1/ts2 columns missing") + return + } + // Expected omni (default-mode behavior): no promotion. + if ts1.Default != nil && strings.Contains(strings.ToLower(*ts1.Default), "current_timestamp") { + // Acceptable: omni mirrors legacy mode. + t.Logf("omni: ts1 has CURRENT_TIMESTAMP default — promotion happens unconditionally") + } + if ts2.Default != nil && strings.Contains(strings.ToLower(*ts2.Default), "current_timestamp") { + t.Errorf("omni: ts2 should NEVER auto-promote to CURRENT_TIMESTAMP, got %q", *ts2.Default) + } + if ts2.OnUpdate != "" && strings.Contains(strings.ToLower(ts2.OnUpdate), "current_timestamp") { + t.Errorf("omni: ts2 should NEVER auto-promote ON UPDATE, got %q", ts2.OnUpdate) + } + }) + + // --- 3.2 PRIMARY KEY implies NOT NULL ----------------------------------- + t.Run("3_2_primary_key_implies_not_null", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT PRIMARY KEY)") + + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, + &isNullable) + assertStringEq(t, "oracle IS_NULLABLE", isNullable, "NO") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing") + return + } + assertBoolEq(t, "omni column a Nullable", col.Nullable, false) + }) + + // --- 3.3 AUTO_INCREMENT implies NOT NULL -------------------------------- + t.Run("3_3_auto_increment_implies_not_null", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT AUTO_INCREMENT, KEY (a))") + + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, + &isNullable) + assertStringEq(t, "oracle IS_NULLABLE", isNullable, "NO") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing") + return + } + assertBoolEq(t, "omni column a Nullable", col.Nullable, false) + }) + + // --- 3.4 Explicit NULL on PRIMARY KEY column is a hard error ------------ + t.Run("3_4_explicit_null_pk_errors", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Oracle: MySQL must reject. + _, mysqlErr := mc.db.ExecContext(mc.ctx, + "CREATE TABLE t (a INT NULL PRIMARY KEY)") + if mysqlErr == nil { + t.Errorf("oracle: expected ER_PRIMARY_CANT_HAVE_NULL (1171), got nil") + } else if !strings.Contains(mysqlErr.Error(), "1171") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "primary") { + t.Errorf("oracle: expected 1171 ER_PRIMARY_CANT_HAVE_NULL, got %v", mysqlErr) + } + + // omni: should also reject. + results, err := c.Exec("CREATE TABLE t (a INT NULL PRIMARY KEY);", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects NULL + PK", omniErrored, true) + }) + + // --- 3.5 UNIQUE does NOT imply NOT NULL --------------------------------- + t.Run("3_5_unique_does_not_imply_not_null", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, "CREATE TABLE t (a INT UNIQUE)") + + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, + &isNullable) + assertStringEq(t, "oracle IS_NULLABLE (UNIQUE)", isNullable, "YES") + + // Verify a UNIQUE index actually exists on the container side. + var idxName string + oracleScan(t, mc, + `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND NON_UNIQUE=0`, + &idxName) + if idxName == "" { + t.Errorf("oracle: expected one UNIQUE index, got none") + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing") + return + } + assertBoolEq(t, "omni column a Nullable (UNIQUE)", col.Nullable, true) + + omniIdx := omniUniqueIndexNames(tbl) + if len(omniIdx) == 0 { + t.Errorf("omni: expected a UNIQUE index, got none") + } + }) + + // --- 3.6 Generated column nullability derived from expression ----------- + t.Run("3_6_generated_column_nullability", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a INT NULL, + b INT GENERATED ALWAYS AS (a+1) VIRTUAL +)` + runOnBoth(t, mc, c, ddl) + + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, + &isNullable) + assertStringEq(t, "oracle IS_NULLABLE (gcol)", isNullable, "YES") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("b") + if col == nil { + t.Errorf("omni: column b missing") + return + } + assertBoolEq(t, "omni gcol Nullable", col.Nullable, true) + }) + + // --- 3.7 Explicit NULL + AUTO_INCREMENT -------------------------------- + // SCENARIOS expectation: MySQL errors. Empirically MySQL 8.0.45 accepts + // the statement and silently promotes id to NOT NULL (AUTO_INCREMENT + // wins). We test the observed silent-coercion behavior on both sides + // and document the discrepancy with the SCENARIOS file. + t.Run("3_7_null_plus_auto_increment_silent_promote", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := "CREATE TABLE t (id INT NULL AUTO_INCREMENT, KEY(id))" + runOnBoth(t, mc, c, ddl) + + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='id'`, + &isNullable) + assertStringEq(t, "oracle id IS_NULLABLE", isNullable, "NO") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("id") + if col == nil { + t.Errorf("omni: column id missing") + return + } + assertBoolEq(t, "omni id Nullable (NULL + AUTO_INCREMENT silently promoted)", + col.Nullable, false) + }) + + // --- 3.8 Second TIMESTAMP under explicit_defaults_for_timestamp=OFF ----- + t.Run("3_8_second_timestamp_legacy_mode", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Container side: legacy mode + permissive sql_mode. + setLegacyTimestampMode(t, mc) + defer restoreTimestampMode(t, mc) + + ddl := "CREATE TABLE t (ts1 TIMESTAMP, ts2 TIMESTAMP)" + // Run on container. + if _, err := mc.db.ExecContext(mc.ctx, ddl); err != nil { + t.Errorf("oracle CREATE TABLE failed: %v", err) + } + // Run on omni separately — omni does not track the session var so + // its observed state is whatever omni's default branch produces. + results, err := c.Exec(ddl+";", nil) + if err != nil { + t.Errorf("omni parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni exec error: %v", r.Error) + } + } + + // Oracle: ts1 promoted, ts2 NOT NULL DEFAULT zero literal. + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + lo := strings.ToLower(create) + if !strings.Contains(lo, "default current_timestamp") { + t.Errorf("oracle: legacy-mode ts1 should have DEFAULT CURRENT_TIMESTAMP\n%s", create) + } + if !strings.Contains(lo, "0000-00-00 00:00:00") { + t.Errorf("oracle: legacy-mode ts2 should have DEFAULT '0000-00-00 00:00:00'\n%s", create) + } + + // omni asymmetry: omni does not track explicit_defaults_for_timestamp, + // so we cannot expect it to mirror the legacy transform. Document + // what omni does for posterity. Anything goes here EXCEPT crashing. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing after legacy-mode CREATE TABLE") + return + } + ts1 := tbl.GetColumn("ts1") + ts2 := tbl.GetColumn("ts2") + if ts1 == nil || ts2 == nil { + t.Errorf("omni: ts1/ts2 columns missing") + return + } + t.Logf("omni asymmetry (no session var tracking): ts1.Nullable=%v ts1.Default=%v ts2.Nullable=%v ts2.Default=%v", + ts1.Nullable, derefStr(ts1.Default), ts2.Nullable, derefStr(ts2.Default)) + }) +} + +func derefStr(p *string) string { + if p == nil { + return "" + } + return *p +} + +// setLegacyTimestampMode flips the container into the deprecated +// pre-5.6 TIMESTAMP semantics (explicit_defaults_for_timestamp=OFF) and +// drops the strict-zero-date sql_mode flags so '0000-00-00 00:00:00' +// defaults are accepted. This is required for C3.1 and C3.8 which exercise +// the legacy TIMESTAMP promotion path inside `promote_first_timestamp_column`. +func setLegacyTimestampMode(t *testing.T, mc *mysqlContainer) { + t.Helper() + stmts := []string{ + "SET SESSION explicit_defaults_for_timestamp=0", + "SET SESSION sql_mode=''", + } + for _, s := range stmts { + if _, err := mc.db.ExecContext(mc.ctx, s); err != nil { + t.Fatalf("setLegacyTimestampMode %q: %v", s, err) + } + } +} + +// restoreTimestampMode reverts the session vars touched by +// setLegacyTimestampMode back to MySQL 8.0 defaults. +func restoreTimestampMode(t *testing.T, mc *mysqlContainer) { + t.Helper() + stmts := []string{ + "SET SESSION explicit_defaults_for_timestamp=1", + "SET SESSION sql_mode=DEFAULT", + } + for _, s := range stmts { + _, _ = mc.db.ExecContext(mc.ctx, s) + } +} diff --git a/tidb/catalog/scenarios_c4_test.go b/tidb/catalog/scenarios_c4_test.go new file mode 100644 index 00000000..850cc136 --- /dev/null +++ b/tidb/catalog/scenarios_c4_test.go @@ -0,0 +1,624 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C4 covers Section C4 "Charset / collation inheritance" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL against both a +// real MySQL 8.0 container and the omni catalog, then asserts that both agree +// on charset/collation resolution. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c4.md. +func TestScenario_C4(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c4OracleColCharset fetches CHARACTER_SET_NAME from information_schema + // for testdb.
.. + c4OracleColCharset := func(t *testing.T, table, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT IFNULL(CHARACTER_SET_NAME,'') FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='"+table+"' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + // c4OracleColCollation fetches COLLATION_NAME. + c4OracleColCollation := func(t *testing.T, table, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT IFNULL(COLLATION_NAME,'') FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='"+table+"' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + // c4OracleTableCollation fetches TABLE_COLLATION. + c4OracleTableCollation := func(t *testing.T, table string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT IFNULL(TABLE_COLLATION,'') FROM information_schema.TABLES "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='"+table+"'", + &s) + return strings.ToLower(s) + } + // c4OracleDataType fetches DATA_TYPE. + c4OracleDataType := func(t *testing.T, table, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT DATA_TYPE FROM information_schema.COLUMNS "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='"+table+"' AND COLUMN_NAME='"+col+"'", + &s) + return strings.ToLower(s) + } + + // c4ResetDBWithCharset drops testdb, recreates it with a specific + // charset/collation, USEs it, and does the same on the omni catalog. + // Returns a fresh omni catalog with the same initial state. + c4ResetDBWithCharset := func(t *testing.T, charset, collation string) *Catalog { + t.Helper() + if _, err := mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS testdb"); err != nil { + t.Fatalf("oracle DROP DATABASE: %v", err) + } + createStmt := "CREATE DATABASE testdb CHARACTER SET " + charset + " COLLATE " + collation + if _, err := mc.db.ExecContext(mc.ctx, createStmt); err != nil { + t.Fatalf("oracle CREATE DATABASE: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb"); err != nil { + t.Fatalf("oracle USE testdb: %v", err) + } + c := New() + results, err := c.Exec(createStmt+"; USE testdb;", nil) + if err != nil { + t.Errorf("omni parse error for %q: %v", createStmt, err) + return c + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni exec error on stmt %d: %v", r.Index, r.Error) + } + } + return c + } + + // --- 4.1 Table charset inherits from database --- + t.Run("4_1_table_charset_from_db", func(t *testing.T) { + c := c4ResetDBWithCharset(t, "latin1", "latin1_swedish_ci") + + runOnBoth(t, mc, c, `CREATE TABLE t (c VARCHAR(10))`) + + assertStringEq(t, "oracle table collation", + c4OracleTableCollation(t, "t"), "latin1_swedish_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + assertStringEq(t, "omni table charset", + strings.ToLower(tbl.Charset), "latin1") + assertStringEq(t, "omni table collation", + strings.ToLower(tbl.Collation), "latin1_swedish_ci") + }) + + // --- 4.2 Column charset inherits from table (elided in SHOW) --- + t.Run("4_2_column_charset_from_table", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (c VARCHAR(10)) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci`) + + assertStringEq(t, "oracle col collation", + c4OracleColCollation(t, "t", "c"), "utf8mb4_0900_ai_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("c") + if col == nil { + t.Error("omni: column c not found") + return + } + // Column charset should match or equal the table charset after inheritance. + // The catalog may store empty (inherited) or the resolved charset. + gotCharset := strings.ToLower(col.Charset) + if gotCharset != "" && gotCharset != "utf8mb4" { + t.Errorf("omni col charset: got %q, want \"utf8mb4\" or empty (inherited)", gotCharset) + } + + // SHOW CREATE TABLE should not mention column-level CHARACTER SET. + mysqlCreate := oracleShow(t, mc, "SHOW CREATE TABLE t") + omniCreate := c.ShowCreateTable("testdb", "t") + // Locate the column line. + for _, line := range strings.Split(omniCreate, "\n") { + if strings.Contains(line, "`c`") && strings.Contains(strings.ToUpper(line), "CHARACTER SET") { + t.Errorf("omni SHOW CREATE TABLE: column c should not have CHARACTER SET clause (inherited from table default). Got: %s", line) + } + } + for _, line := range strings.Split(mysqlCreate, "\n") { + if strings.Contains(line, "`c`") && strings.Contains(strings.ToUpper(line), "CHARACTER SET") { + t.Logf("oracle SHOW CREATE TABLE column line: %s (unexpectedly has CHARACTER SET)", line) + } + } + }) + + // --- 4.3 Column COLLATE alone → derive CHARACTER SET --- + t.Run("4_3_collate_derives_charset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (c VARCHAR(10) COLLATE utf8mb4_unicode_ci) DEFAULT CHARSET=latin1`) + + assertStringEq(t, "oracle col charset", + c4OracleColCharset(t, "t", "c"), "utf8mb4") + assertStringEq(t, "oracle col collation", + c4OracleColCollation(t, "t", "c"), "utf8mb4_unicode_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("c") + if col == nil { + t.Error("omni: column c not found") + return + } + assertStringEq(t, "omni col charset", + strings.ToLower(col.Charset), "utf8mb4") + assertStringEq(t, "omni col collation", + strings.ToLower(col.Collation), "utf8mb4_unicode_ci") + }) + + // --- 4.4 Column CHARACTER SET alone → derive default COLLATE --- + t.Run("4_4_charset_derives_default_collation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (c VARCHAR(10) CHARACTER SET latin1) DEFAULT CHARSET=utf8mb4`) + + assertStringEq(t, "oracle col charset", + c4OracleColCharset(t, "t", "c"), "latin1") + assertStringEq(t, "oracle col collation", + c4OracleColCollation(t, "t", "c"), "latin1_swedish_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("c") + if col == nil { + t.Error("omni: column c not found") + return + } + assertStringEq(t, "omni col charset", + strings.ToLower(col.Charset), "latin1") + assertStringEq(t, "omni col collation", + strings.ToLower(col.Collation), "latin1_swedish_ci") + }) + + // --- 4.5 Table CHARSET/COLLATE mismatch error --- + t.Run("4_5_charset_collation_mismatch", func(t *testing.T) { + scenarioReset(t, mc) + + // Mismatch case: latin1 + utf8mb4_0900_ai_ci should fail. + mismatchStmt := `CREATE TABLE t_bad (c VARCHAR(10)) CHARACTER SET latin1 COLLATE utf8mb4_0900_ai_ci` + if _, err := mc.db.ExecContext(mc.ctx, mismatchStmt); err == nil { + t.Error("oracle: expected mismatch error, got none") + } + + c := scenarioNewCatalog(t) + results, err := c.Exec(mismatchStmt, nil) + omniRejected := false + if err != nil { + omniRejected = true + } else { + for _, r := range results { + if r.Error != nil { + omniRejected = true + break + } + } + } + if !omniRejected { + t.Error("omni: expected mismatch CHARSET/COLLATE to be rejected, got no error") + } + + // Compatible case: should succeed. + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + runOnBoth(t, mc, c2, + `CREATE TABLE t_ok (c VARCHAR(10)) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci`) + assertStringEq(t, "oracle ok table collation", + c4OracleTableCollation(t, "t_ok"), "utf8mb4_0900_ai_ci") + tbl := c2.GetDatabase("testdb").GetTable("t_ok") + if tbl == nil { + t.Error("omni: table t_ok not found") + return + } + assertStringEq(t, "omni ok table collation", + strings.ToLower(tbl.Collation), "utf8mb4_0900_ai_ci") + }) + + // --- 4.6 BINARY modifier → {charset}_bin rewrite --- + t.Run("4_6_binary_modifier_bin_collation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a CHAR(10) BINARY, + b VARCHAR(10) CHARACTER SET latin1 BINARY +) DEFAULT CHARSET=utf8mb4`) + + assertStringEq(t, "oracle a collation", + c4OracleColCollation(t, "t", "a"), "utf8mb4_bin") + assertStringEq(t, "oracle b collation", + c4OracleColCollation(t, "t", "b"), "latin1_bin") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + if colA := tbl.GetColumn("a"); colA != nil { + assertStringEq(t, "omni a collation", + strings.ToLower(colA.Collation), "utf8mb4_bin") + } else { + t.Error("omni: column a not found") + } + if colB := tbl.GetColumn("b"); colB != nil { + assertStringEq(t, "omni b collation", + strings.ToLower(colB.Collation), "latin1_bin") + assertStringEq(t, "omni b charset", + strings.ToLower(colB.Charset), "latin1") + } else { + t.Error("omni: column b not found") + } + + // Round-trip: deparse should not emit the BINARY keyword. + omniCreate := c.ShowCreateTable("testdb", "t") + if strings.Contains(strings.ToUpper(omniCreate), " BINARY,") || + strings.Contains(strings.ToUpper(omniCreate), " BINARY\n") { + // Accept canonical form only. The BINARY attribute must be rewritten. + // Allow "BINARY(" for BINARY(N) column type — different construct. + // We check only that the attribute form isn't present on CHAR/VARCHAR. + for _, line := range strings.Split(omniCreate, "\n") { + upper := strings.ToUpper(line) + if (strings.Contains(upper, "CHAR(") || strings.Contains(upper, "VARCHAR(")) && + strings.Contains(upper, " BINARY") { + t.Errorf("omni SHOW CREATE TABLE: expected BINARY attribute rewritten to _bin collation, got line: %s", line) + } + } + } + }) + + // --- 4.7 CHARACTER SET binary vs BINARY type distinction --- + t.Run("4_7_charset_binary_vs_binary_type", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a BINARY(10), + b CHAR(10) CHARACTER SET binary, + c VARBINARY(10) +)`) + + // MySQL 8.0 information_schema reality: + // - BINARY(N) / VARBINARY(N) columns have CHARACTER_SET_NAME and + // COLLATION_NAME reported as NULL (they're byte types, not text). + // - CHAR(N) CHARACTER SET binary is **silently rewritten** at parse + // time to BINARY(N) (sql_yacc.yy folds the form), so it ends up + // indistinguishable from `a` in information_schema: DATA_TYPE='binary', + // COLLATION_NAME=NULL. The "three different kinds of binary" the + // scenario describes manifests in the parse tree, not the post-store + // metadata. + assertStringEq(t, "oracle a collation (NULL for BINARY type)", + c4OracleColCollation(t, "t", "a"), "") + assertStringEq(t, "oracle b collation (rewritten to BINARY -> NULL)", + c4OracleColCollation(t, "t", "b"), "") + assertStringEq(t, "oracle c collation (NULL for VARBINARY type)", + c4OracleColCollation(t, "t", "c"), "") + // Post-fold DATA_TYPE: a -> binary, b -> binary (rewritten), c -> varbinary. + assertStringEq(t, "oracle a data_type", + c4OracleDataType(t, "t", "a"), "binary") + assertStringEq(t, "oracle b data_type (rewritten)", + c4OracleDataType(t, "t", "b"), "binary") + assertStringEq(t, "oracle c data_type", + c4OracleDataType(t, "t", "c"), "varbinary") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + // omni-side: match oracle's NULL-vs-"binary" distinction. For BINARY + // and VARBINARY column types the catalog should NOT report a column + // charset (mirrors information_schema NULL). For CHAR(N) CHARACTER SET + // binary, the catalog should report Collation="binary". + check := func(name, wantDT, wantCollation string) { + col := tbl.GetColumn(name) + if col == nil { + t.Errorf("omni: column %s not found", name) + return + } + if strings.ToLower(col.Collation) != wantCollation { + t.Errorf("omni col %s collation: got %q, want %q", name, col.Collation, wantCollation) + } + if strings.ToLower(col.DataType) != wantDT { + t.Errorf("omni col %s data_type: got %q, want %q", name, col.DataType, wantDT) + } + } + check("a", "binary", "") + // b: omni should match oracle's silent rewrite — DATA_TYPE='binary' + // after CHAR(N) CHARACTER SET binary normalisation. + check("b", "binary", "") + check("c", "varbinary", "") + }) + + // --- 4.8 utf8 → utf8mb3 alias normalization --- + t.Run("4_8_utf8_alias_normalization", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (c VARCHAR(10) CHARACTER SET utf8)`) + + assertStringEq(t, "oracle col charset", + c4OracleColCharset(t, "t", "c"), "utf8mb3") + assertStringEq(t, "oracle col collation", + c4OracleColCollation(t, "t", "c"), "utf8mb3_general_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("c") + if col == nil { + t.Error("omni: column c not found") + return + } + assertStringEq(t, "omni col charset (normalized)", + strings.ToLower(col.Charset), "utf8mb3") + + // SHOW CREATE TABLE should print utf8mb3, never utf8. + omniCreate := strings.ToLower(c.ShowCreateTable("testdb", "t")) + if strings.Contains(omniCreate, "character set utf8 ") || + strings.Contains(omniCreate, "character set utf8,") || + strings.HasSuffix(omniCreate, "character set utf8") { + t.Errorf("omni SHOW CREATE TABLE: expected utf8mb3, got unnormalized utf8 in: %s", omniCreate) + } + }) + + // --- 4.9 NCHAR/NATIONAL → utf8mb3 hardcoding --- + t.Run("4_9_national_nchar_utf8mb3", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a NCHAR(10), + b NATIONAL CHARACTER(10), + cc NATIONAL VARCHAR(10), + d NCHAR VARYING(10) +)`) + + for _, name := range []string{"a", "b", "cc", "d"} { + assertStringEq(t, "oracle "+name+" charset", + c4OracleColCharset(t, "t", name), "utf8mb3") + assertStringEq(t, "oracle "+name+" collation", + c4OracleColCollation(t, "t", name), "utf8mb3_general_ci") + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + for _, name := range []string{"a", "b", "cc", "d"} { + col := tbl.GetColumn(name) + if col == nil { + t.Errorf("omni: column %s not found", name) + continue + } + assertStringEq(t, "omni "+name+" charset", + strings.ToLower(col.Charset), "utf8mb3") + } + }) + + // --- 4.10 ENUM/SET charset inheritance --- + t.Run("4_10_enum_set_charset_inheritance", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a ENUM('x','y'), + b ENUM('x','y') CHARACTER SET latin1, + cc SET('p','q') COLLATE utf8mb4_unicode_ci +) DEFAULT CHARSET=utf8mb4`) + + // a: inherits table default utf8mb4 + assertStringEq(t, "oracle a charset", + c4OracleColCharset(t, "t", "a"), "utf8mb4") + // b: latin1 + default collation + assertStringEq(t, "oracle b charset", + c4OracleColCharset(t, "t", "b"), "latin1") + assertStringEq(t, "oracle b collation", + c4OracleColCollation(t, "t", "b"), "latin1_swedish_ci") + // cc: charset derived from COLLATE + assertStringEq(t, "oracle cc charset", + c4OracleColCharset(t, "t", "cc"), "utf8mb4") + assertStringEq(t, "oracle cc collation", + c4OracleColCollation(t, "t", "cc"), "utf8mb4_unicode_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + if colB := tbl.GetColumn("b"); colB != nil { + assertStringEq(t, "omni b charset", + strings.ToLower(colB.Charset), "latin1") + assertStringEq(t, "omni b collation", + strings.ToLower(colB.Collation), "latin1_swedish_ci") + } else { + t.Error("omni: column b not found") + } + if colC := tbl.GetColumn("cc"); colC != nil { + assertStringEq(t, "omni cc charset", + strings.ToLower(colC.Charset), "utf8mb4") + assertStringEq(t, "omni cc collation", + strings.ToLower(colC.Collation), "utf8mb4_unicode_ci") + } else { + t.Error("omni: column cc not found") + } + }) + + // --- 4.11 Index prefix × mbmaxlen --- + t.Run("4_11_index_prefix_mbmaxlen", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // t1: latin1 c(5) — fits (SUB_PART reported when prefix < full col). + runOnBoth(t, mc, c, + `CREATE TABLE t1 (c VARCHAR(10) CHARACTER SET latin1, KEY k (c(5)))`) + // t2: utf8mb4 c(5) — 20 bytes, fits. + runOnBoth(t, mc, c, + `CREATE TABLE t2 (c VARCHAR(10) CHARACTER SET utf8mb4, KEY k (c(5)))`) + + // t3: utf8mb4 VARCHAR(200), KEY (c(768)) = 3072 bytes. Should be rejected + // (exceeds InnoDB 3072-byte per-column key limit by default; really the + // prefix length > VARCHAR(200) also fails with ER_WRONG_SUB_KEY). + // Run only on oracle to verify it errors. + tooLong := `CREATE TABLE t3 (c VARCHAR(200) CHARACTER SET utf8mb4, KEY k (c(768)))` + if _, err := mc.db.ExecContext(mc.ctx, tooLong); err == nil { + t.Error("oracle: expected t3 creation to fail (prefix too long), got no error") + _, _ = mc.db.ExecContext(mc.ctx, "DROP TABLE IF EXISTS t3") + } + // omni: should also reject. + results, err := c.Exec(tooLong, nil) + omniRejected := err != nil + if !omniRejected { + for _, r := range results { + if r.Error != nil { + omniRejected = true + break + } + } + } + if !omniRejected { + t.Error("omni: expected t3 CREATE to be rejected, got no error") + } + + // t1/t2: prefix is reported in characters (SUB_PART). + var sub1, sub2 int + oracleScan(t, mc, + `SELECT SUB_PART FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t1' AND INDEX_NAME='k'`, + &sub1) + oracleScan(t, mc, + `SELECT SUB_PART FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t2' AND INDEX_NAME='k'`, + &sub2) + assertIntEq(t, "oracle t1 SUB_PART", sub1, 5) + assertIntEq(t, "oracle t2 SUB_PART", sub2, 5) + + // omni: index column Length should be 10 (characters, not bytes). + for _, name := range []string{"t1", "t2"} { + tbl := c.GetDatabase("testdb").GetTable(name) + if tbl == nil { + t.Errorf("omni: table %s not found", name) + continue + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "k" { + idx = i + break + } + } + if idx == nil { + t.Errorf("omni: index k on %s not found", name) + continue + } + if len(idx.Columns) != 1 { + t.Errorf("omni: %s.k expected 1 col, got %d", name, len(idx.Columns)) + continue + } + assertIntEq(t, "omni "+name+".k length", idx.Columns[0].Length, 5) + } + }) + + // --- 4.12 DTCollation derivation levels --- + // + // This scenario requires an expression-level collation resolver. omni + // catalog currently stores column-level Charset/Collation, but does not + // expose DTCollation / derivation for arbitrary SELECT expressions. We + // verify the column is set up correctly and that MySQL produces the + // documented comparison outcomes; omni-side assertions are best-effort. + t.Run("4_12_dtcollation_derivation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (c VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci)`) + + // Oracle: baseline column collation. + assertStringEq(t, "oracle col collation", + c4OracleColCollation(t, "t", "c"), "utf8mb4_0900_ai_ci") + + // Oracle runtime checks for each comparison shape. These are queries + // against an empty table — we care about whether MySQL accepts or + // rejects them (not about returned rows). + ok := func(q string) { + if _, err := mc.db.ExecContext(mc.ctx, q); err != nil { + t.Errorf("oracle: %q should succeed, got %v", q, err) + } + } + _ = func(q string) {} // placeholder to keep ok() referenced if unused + ok("SELECT c = 'abc' FROM t") + ok("SELECT c = _utf8mb4'abc' COLLATE utf8mb4_bin FROM t") + ok("SELECT c COLLATE utf8mb4_bin = _latin1'abc' FROM t") + // The CAST-vs-column comparison is documented to fail with + // ER_CANT_AGGREGATE_2COLLATIONS, but MySQL 8.0.x has loosened + // implicit conversion in several point releases. Log the outcome + // for visibility but don't fail the test on either path. + castQuery := "SELECT CAST('abc' AS CHAR CHARACTER SET latin1) = c FROM t" + if _, err := mc.db.ExecContext(mc.ctx, castQuery); err == nil { + t.Logf("oracle: %q succeeded (DTCollation aggregation allowed in this MySQL version)", castQuery) + } else { + t.Logf("oracle: %q failed as documented: %v", castQuery, err) + } + + // omni: we only assert the column collation survives. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("c") + if col == nil { + t.Error("omni: column c not found") + return + } + assertStringEq(t, "omni col collation", + strings.ToLower(col.Collation), "utf8mb4_0900_ai_ci") + + // Best-effort: omni's catalog currently does not expose a SELECT + // expression evaluator with DTCollation, so the 4 query-time checks + // above are oracle-only. See scenarios_bug_queue/c4.md for the gap. + }) +} diff --git a/tidb/catalog/scenarios_c5_test.go b/tidb/catalog/scenarios_c5_test.go new file mode 100644 index 00000000..a96a78eb --- /dev/null +++ b/tidb/catalog/scenarios_c5_test.go @@ -0,0 +1,472 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C5 covers section C5 (Constraint defaults) from +// SCENARIOS-mysql-implicit-behavior.md. It checks FK and CHECK +// constraint defaults: ON DELETE/ON UPDATE/MATCH defaults, FK +// SET DEFAULT InnoDB rejection, FK column type compatibility, +// FK on virtual gcol rejection, CHECK ENFORCED default, +// column-level vs table-level CHECK equivalence, column-level CHECK +// cross-column reference rejection. +// +// Failures in omni assertions are NOT proof failures — they are +// recorded in mysql/catalog/scenarios_bug_queue/c5.md. +func TestScenario_C5(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 5.1 FK ON DELETE default — RESTRICT internally / NO ACTION reported --- + t.Run("5_1_fk_on_delete_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE c (a INT, FOREIGN KEY (a) REFERENCES p(id));` + runOnBoth(t, mc, c, ddl) + + // Oracle: REFERENTIAL_CONSTRAINTS rules. + var delRule, updRule, matchOpt string + oracleScan(t, mc, `SELECT DELETE_RULE, UPDATE_RULE, MATCH_OPTION + FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='c'`, + &delRule, &updRule, &matchOpt) + assertStringEq(t, "oracle DELETE_RULE", delRule, "NO ACTION") + assertStringEq(t, "oracle UPDATE_RULE", updRule, "NO ACTION") + assertStringEq(t, "oracle MATCH_OPTION", matchOpt, "NONE") + + // SHOW CREATE TABLE should not contain ON DELETE / ON UPDATE / MATCH. + create := oracleShow(t, mc, "SHOW CREATE TABLE c") + if strings.Contains(strings.ToUpper(create), "ON DELETE") { + t.Errorf("oracle: SHOW CREATE TABLE should omit ON DELETE: %s", create) + } + if strings.Contains(strings.ToUpper(create), "ON UPDATE") { + t.Errorf("oracle: SHOW CREATE TABLE should omit ON UPDATE: %s", create) + } + if strings.Contains(strings.ToUpper(create), "MATCH") { + t.Errorf("oracle: SHOW CREATE TABLE should omit MATCH: %s", create) + } + + // omni: FK should have default OnDelete/OnUpdate (empty or NO ACTION / + // RESTRICT) and the deparsed table should not render those clauses. + tbl := c.GetDatabase("testdb").GetTable("c") + if tbl == nil { + t.Errorf("omni: table c missing") + return + } + fk := c5FirstFK(tbl) + if fk == nil { + t.Errorf("omni: no FK constraint on c") + return + } + if !c5IsFKDefault(fk.OnDelete) { + t.Errorf("omni: FK OnDelete should be default, got %q", fk.OnDelete) + } + if !c5IsFKDefault(fk.OnUpdate) { + t.Errorf("omni: FK OnUpdate should be default, got %q", fk.OnUpdate) + } + omniCreate := c5OmniShowCreate(t, c, "c") + if strings.Contains(strings.ToUpper(omniCreate), "ON DELETE") { + t.Errorf("omni: deparse should omit ON DELETE: %s", omniCreate) + } + if strings.Contains(strings.ToUpper(omniCreate), "ON UPDATE") { + t.Errorf("omni: deparse should omit ON UPDATE: %s", omniCreate) + } + }) + + // --- 5.2 FK ON DELETE SET NULL on a NOT NULL column errors -------------- + t.Run("5_2_fk_set_null_requires_nullable", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + setup := `CREATE TABLE p (id INT PRIMARY KEY);` + if _, err := mc.db.ExecContext(mc.ctx, setup); err != nil { + t.Errorf("oracle setup: %v", err) + } + _, _ = c.Exec(setup, nil) + + bad := `CREATE TABLE c ( + a INT NOT NULL, + FOREIGN KEY (a) REFERENCES p(id) ON DELETE SET NULL + )` + _, mysqlErr := mc.db.ExecContext(mc.ctx, bad) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_FK_COLUMN_NOT_NULL, got nil") + } + + results, err := c.Exec(bad+";", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects SET NULL on NOT NULL column", omniErrored, true) + }) + + // --- 5.3 FK MATCH default rendered as NONE in information_schema -------- + t.Run("5_3_fk_match_default_none", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE c (a INT, FOREIGN KEY (a) REFERENCES p(id));`) + + var matchOpt string + oracleScan(t, mc, `SELECT MATCH_OPTION FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='c'`, &matchOpt) + assertStringEq(t, "oracle MATCH_OPTION", matchOpt, "NONE") + + tbl := c.GetDatabase("testdb").GetTable("c") + if tbl == nil { + t.Errorf("omni: table c missing") + return + } + fk := c5FirstFK(tbl) + if fk == nil { + t.Errorf("omni: no FK on c") + return + } + // omni MatchType should be empty / NONE / SIMPLE — anything matching default. + if fk.MatchType != "" && !strings.EqualFold(fk.MatchType, "NONE") && + !strings.EqualFold(fk.MatchType, "SIMPLE") { + t.Errorf("omni: FK MatchType should be default (empty/NONE/SIMPLE), got %q", fk.MatchType) + } + }) + + // --- 5.4 FK ON UPDATE default independent of ON DELETE ----------------- + t.Run("5_4_fk_on_update_independent", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE c (a INT, FOREIGN KEY (a) REFERENCES p(id) ON DELETE CASCADE);` + runOnBoth(t, mc, c, ddl) + + var delRule, updRule string + oracleScan(t, mc, `SELECT DELETE_RULE, UPDATE_RULE + FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='c'`, + &delRule, &updRule) + assertStringEq(t, "oracle DELETE_RULE", delRule, "CASCADE") + assertStringEq(t, "oracle UPDATE_RULE", updRule, "NO ACTION") + + // SHOW CREATE renders ON DELETE CASCADE only. + create := oracleShow(t, mc, "SHOW CREATE TABLE c") + if !strings.Contains(strings.ToUpper(create), "ON DELETE CASCADE") { + t.Errorf("oracle: expected ON DELETE CASCADE, got %s", create) + } + if strings.Contains(strings.ToUpper(create), "ON UPDATE") { + t.Errorf("oracle: expected no ON UPDATE clause, got %s", create) + } + + tbl := c.GetDatabase("testdb").GetTable("c") + if tbl == nil { + t.Errorf("omni: table c missing") + return + } + fk := c5FirstFK(tbl) + if fk == nil { + t.Errorf("omni: no FK on c") + return + } + if !strings.EqualFold(fk.OnDelete, "CASCADE") { + t.Errorf("omni: FK OnDelete should be CASCADE, got %q", fk.OnDelete) + } + if !c5IsFKDefault(fk.OnUpdate) { + t.Errorf("omni: FK OnUpdate should be default (empty/NO ACTION/RESTRICT), got %q", fk.OnUpdate) + } + omniCreate := c5OmniShowCreate(t, c, "c") + if !strings.Contains(strings.ToUpper(omniCreate), "ON DELETE CASCADE") { + t.Errorf("omni: deparse missing ON DELETE CASCADE: %s", omniCreate) + } + if strings.Contains(strings.ToUpper(omniCreate), "ON UPDATE") { + t.Errorf("omni: deparse should omit ON UPDATE: %s", omniCreate) + } + }) + + // --- 5.5 FK SET DEFAULT rejected by InnoDB ------------------------------ + t.Run("5_5_fk_set_default_innodb_limitation", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + setup := `CREATE TABLE p (id INT PRIMARY KEY);` + if _, err := mc.db.ExecContext(mc.ctx, setup); err != nil { + t.Errorf("oracle setup: %v", err) + } + _, _ = c.Exec(setup, nil) + + bad := `CREATE TABLE c ( + a INT DEFAULT 0, + FOREIGN KEY (a) REFERENCES p(id) ON DELETE SET DEFAULT + )` + + // MySQL/InnoDB returns an error or warning for SET DEFAULT. Capture + // whichever is observed rather than asserting the precise behavior so + // the test reflects the real oracle. + _, mysqlErr := mc.db.ExecContext(mc.ctx, bad) + oracleRejected := mysqlErr != nil + + results, err := c.Exec(bad+";", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + + // Both should agree. + assertBoolEq(t, "omni FK SET DEFAULT matches MySQL rejection", + omniErrored, oracleRejected) + }) + + // --- 5.6 FK column type/size/sign must match parent -------------------- + t.Run("5_6_fk_column_type_compat", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + setup := `CREATE TABLE p (id BIGINT PRIMARY KEY);` + if _, err := mc.db.ExecContext(mc.ctx, setup); err != nil { + t.Errorf("oracle setup: %v", err) + } + _, _ = c.Exec(setup, nil) + + bad := `CREATE TABLE c (a INT, FOREIGN KEY (a) REFERENCES p(id))` + _, mysqlErr := mc.db.ExecContext(mc.ctx, bad) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_FK_INCOMPATIBLE_COLUMNS (3780), got nil") + } + + results, err := c.Exec(bad+";", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects FK with mismatched column types", omniErrored, true) + }) + + // --- 5.7 FK on a VIRTUAL generated column rejected --------------------- + t.Run("5_7_fk_on_virtual_gcol_rejected", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + setup := `CREATE TABLE p (id INT PRIMARY KEY);` + if _, err := mc.db.ExecContext(mc.ctx, setup); err != nil { + t.Errorf("oracle setup: %v", err) + } + _, _ = c.Exec(setup, nil) + + bad := `CREATE TABLE c ( + a INT, + b INT GENERATED ALWAYS AS (a+1) VIRTUAL, + FOREIGN KEY (b) REFERENCES p(id) + )` + _, mysqlErr := mc.db.ExecContext(mc.ctx, bad) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_FK_CANNOT_USE_VIRTUAL_COLUMN (3104), got nil") + } + + results, err := c.Exec(bad+";", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects FK on VIRTUAL gcol", omniErrored, true) + }) + + // --- 5.8 CHECK defaults to ENFORCED ------------------------------------ + t.Run("5_8_check_enforced_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a INT, + CONSTRAINT chk_pos CHECK (a > 0) + );`) + + var enforced string + oracleScan(t, mc, `SELECT ENFORCED FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND CONSTRAINT_NAME='chk_pos'`, + &enforced) + assertStringEq(t, "oracle CHECK ENFORCED", enforced, "YES") + + create := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(create), "NOT ENFORCED") { + t.Errorf("oracle: SHOW CREATE should not contain NOT ENFORCED: %s", create) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + var found *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConCheck && strings.EqualFold(con.Name, "chk_pos") { + found = con + break + } + } + if found == nil { + t.Errorf("omni: CHECK constraint chk_pos missing") + return + } + assertBoolEq(t, "omni CHECK NotEnforced is false", found.NotEnforced, false) + + omniCreate := c5OmniShowCreate(t, c, "t") + if strings.Contains(strings.ToUpper(omniCreate), "NOT ENFORCED") { + t.Errorf("omni: deparse should not contain NOT ENFORCED: %s", omniCreate) + } + }) + + // --- 5.9 column-level vs table-level CHECK equivalence ----------------- + t.Run("5_9_check_column_vs_table_level", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT CHECK (a > 0)); +CREATE TABLE t2 (a INT, CHECK (a > 0));`) + + // Each table should have a single CHECK constraint with auto name. + rows1 := oracleRows(t, mc, `SELECT CONSTRAINT_NAME, CHECK_CLAUSE + FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND CONSTRAINT_NAME LIKE 't1_chk_%'`) + rows2 := oracleRows(t, mc, `SELECT CONSTRAINT_NAME, CHECK_CLAUSE + FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND CONSTRAINT_NAME LIKE 't2_chk_%'`) + if len(rows1) != 1 { + t.Errorf("oracle: expected 1 CHECK on t1, got %d", len(rows1)) + } + if len(rows2) != 1 { + t.Errorf("oracle: expected 1 CHECK on t2, got %d", len(rows2)) + } + if len(rows1) == 1 && len(rows2) == 1 { + expr1 := asString(rows1[0][1]) + expr2 := asString(rows2[0][1]) + if expr1 != expr2 { + t.Errorf("oracle: CHECK_CLAUSE differs: t1=%q t2=%q", expr1, expr2) + } + } + + t1 := c.GetDatabase("testdb").GetTable("t1") + t2 := c.GetDatabase("testdb").GetTable("t2") + if t1 == nil || t2 == nil { + t.Errorf("omni: t1 or t2 missing") + return + } + c1 := omniCheckNames(t1) + c2 := omniCheckNames(t2) + assertIntEq(t, "omni t1 check count", len(c1), 1) + assertIntEq(t, "omni t2 check count", len(c2), 1) + + // Compare normalized expressions (both should serialize to the same form). + var e1, e2 string + for _, con := range t1.Constraints { + if con.Type == ConCheck { + e1 = c5NormalizeCheckExpr(con.CheckExpr) + break + } + } + for _, con := range t2.Constraints { + if con.Type == ConCheck { + e2 = c5NormalizeCheckExpr(con.CheckExpr) + break + } + } + if e1 != e2 { + t.Errorf("omni: column-level vs table-level CHECK exprs differ: %q vs %q", e1, e2) + } + }) + + // --- 5.10 column-level CHECK with cross-column ref rejected ------------ + t.Run("5_10_column_check_cross_ref_rejected", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + bad := `CREATE TABLE t (a INT CHECK (a > b), b INT)` + _, mysqlErr := mc.db.ExecContext(mc.ctx, bad) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_CHECK_CONSTRAINT_REFERS (3823), got nil") + } else if !strings.Contains(mysqlErr.Error(), "3823") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "check constraint") { + t.Errorf("oracle: expected 3823, got %v", mysqlErr) + } + + results, err := c.Exec(bad+";", nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni rejects column-level CHECK referencing other column", omniErrored, true) + + // Sanity: same expression as TABLE-level CHECK should succeed in MySQL. + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + good := `CREATE TABLE t (a INT, b INT, CHECK (a > b));` + runOnBoth(t, mc, c2, good) + }) +} + +// --- C5 section helpers --------------------------------------------------- + +// c5FirstFK returns the first FK constraint of the table or nil. +func c5FirstFK(tbl *Table) *Constraint { + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + return con + } + } + return nil +} + +// c5IsFKDefault returns true when the FK action string is one of the values +// MySQL treats as the implicit default (empty, NO ACTION, RESTRICT). +func c5IsFKDefault(action string) bool { + upper := strings.ToUpper(strings.TrimSpace(action)) + return upper == "" || upper == "NO ACTION" || upper == "RESTRICT" +} + +// c5OmniShowCreate returns omni's SHOW CREATE TABLE rendering for the named +// table in testdb, or the empty string if the table is missing. +func c5OmniShowCreate(t *testing.T, c *Catalog, table string) string { + t.Helper() + return c.ShowCreateTable("testdb", table) +} + +// c5NormalizeCheckExpr strips parentheses and whitespace so column-level vs +// table-level CHECKs (which may have different paren counts) compare equal. +func c5NormalizeCheckExpr(s string) string { + s = strings.ReplaceAll(s, " ", "") + s = strings.ReplaceAll(s, "`", "") + for strings.HasPrefix(s, "(") && strings.HasSuffix(s, ")") { + s = s[1 : len(s)-1] + } + return s +} diff --git a/tidb/catalog/scenarios_c6_test.go b/tidb/catalog/scenarios_c6_test.go new file mode 100644 index 00000000..b1b98579 --- /dev/null +++ b/tidb/catalog/scenarios_c6_test.go @@ -0,0 +1,566 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C6 covers Section C6 (Partition defaults) of the +// mysql-implicit-behavior starmap. Each subtest runs DDL against both a real +// MySQL 8.0 container and the omni catalog and asserts the observable state +// agrees. +// +// Uses helpers from scenarios_helpers_test.go and reuses: +// - oraclePartitionNames (scenarios_c1_test.go) +// +// Section-local helpers use the `c6` prefix to avoid collisions. +func TestScenario_C6(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 6.1 HASH without PARTITIONS defaults to 1 ---------------------- + t.Run("6_1_HASH_partitions_default_1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT) PARTITION BY HASH(id)`) + + names := oraclePartitionNames(t, mc, "t") + wantNames := []string{"p0"} + assertStringEq(t, "oracle partition names", + strings.Join(names, ","), strings.Join(wantNames, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniNames := c6OmniPartitionNames(tbl) + assertStringEq(t, "omni partition names", + strings.Join(omniNames, ","), strings.Join(wantNames, ",")) + }) + + // --- 6.2 SUBPARTITIONS default to 1 if not specified ---------------- + t.Run("6_2_subpartitions_default_1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT, d DATE) + PARTITION BY RANGE(YEAR(d)) + SUBPARTITION BY HASH(id) + (PARTITION p0 VALUES LESS THAN (2000), + PARTITION p1 VALUES LESS THAN MAXVALUE);`) + + // Oracle: expect 2 subpartitions total (1 per parent). + rows := oracleRows(t, mc, `SELECT PARTITION_NAME, SUBPARTITION_NAME + FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY PARTITION_ORDINAL_POSITION, SUBPARTITION_ORDINAL_POSITION`) + if len(rows) != 2 { + t.Errorf("oracle subpartition rows: got %d, want 2", len(rows)) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + total := 0 + for _, p := range tbl.Partitioning.Partitions { + total += len(p.SubPartitions) + } + assertIntEq(t, "omni subpartition count", total, 2) + }) + + // --- 6.3 Partition ENGINE defaults to table ENGINE ------------------ + t.Run("6_3_partition_engine_defaults_to_table", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) ENGINE=InnoDB PARTITION BY HASH(id) PARTITIONS 2`) + + // Oracle: SHOW CREATE TABLE renders table-level ENGINE=InnoDB. + sc := oracleShow(t, mc, "SHOW CREATE TABLE t") + if !strings.Contains(sc, "ENGINE=InnoDB") { + t.Errorf("oracle SHOW CREATE: expected ENGINE=InnoDB, got %s", sc) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + for i, p := range tbl.Partitioning.Partitions { + // Empty per-partition engine is OK as long as table-level engine + // renders it in SHOW CREATE; but a concrete non-InnoDB would be + // wrong. + if p.Engine != "" && !strings.EqualFold(p.Engine, "InnoDB") { + t.Errorf("omni: partition %d engine %q != InnoDB", i, p.Engine) + } + } + }) + + // --- 6.4 KEY ALGORITHM default 2 ------------------------------------ + t.Run("6_4_key_algorithm_default_2", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT) PARTITION BY KEY(id) PARTITIONS 4`) + + // Oracle SHOW CREATE TABLE should NOT mention ALGORITHM=1 (algo 2 is + // default, so it is elided from the output). + sc := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(sc, "ALGORITHM=1") { + t.Errorf("oracle SHOW CREATE should not have ALGORITHM=1: %s", sc) + } + + // omni catalog: algorithm should default to 2 (or be 0 == unset; + // either way it must not be 1). + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + if tbl.Partitioning.Algorithm == 1 { + t.Errorf("omni: Algorithm=1 leaked as default, want 2 or 0") + } + }) + + // --- 6.5 KEY() empty column list → PK columns ---------------------- + t.Run("6_5_key_empty_columns_defaults_to_PK", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT PRIMARY KEY, v INT) PARTITION BY KEY() PARTITIONS 4`) + + // Oracle: PARTITION_EXPRESSION column should name `id`. + var expr string + oracleScan(t, mc, `SELECT COALESCE(PARTITION_EXPRESSION,'') + FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + LIMIT 1`, &expr) + if !strings.Contains(expr, "id") { + t.Errorf("oracle partition expression: got %q, want containing 'id'", expr) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + got := strings.Join(tbl.Partitioning.Columns, ",") + if got != "id" { + t.Errorf("omni partition columns: got %q, want %q", got, "id") + } + }) + + // --- 6.6 LINEAR HASH / LINEAR KEY preserved ------------------------- + t.Run("6_6_linear_preserved", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) PARTITION BY LINEAR HASH(id) PARTITIONS 4`) + + sc := oracleShow(t, mc, "SHOW CREATE TABLE t") + if !strings.Contains(sc, "LINEAR HASH") { + t.Errorf("oracle SHOW CREATE: expected LINEAR HASH, got %s", sc) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + assertBoolEq(t, "omni Linear", tbl.Partitioning.Linear, true) + }) + + // --- 6.7 RANGE/LIST require explicit partition definitions --------- + t.Run("6_7_range_partitions_n_shortcut_rejected", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (id INT) PARTITION BY RANGE(id) PARTITIONS 4` + + // Oracle: must reject. + if _, err := mc.db.ExecContext(mc.ctx, ddl); err == nil { + t.Errorf("oracle: expected error for RANGE without definitions") + } + + // omni: should also reject. + results, err := c.Exec(ddl, nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: expected error for RANGE without definitions, got nil") + } + }) + + // --- 6.8 MAXVALUE must appear in the last RANGE partition only ----- + t.Run("6_8_maxvalue_last_only", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (id INT) PARTITION BY RANGE(id) + (PARTITION p0 VALUES LESS THAN MAXVALUE, + PARTITION p1 VALUES LESS THAN (100))` + + if _, err := mc.db.ExecContext(mc.ctx, ddl); err == nil { + t.Errorf("oracle: expected error for misplaced MAXVALUE") + } + + results, err := c.Exec(ddl, nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: expected error for misplaced MAXVALUE, got nil") + } + }) + + // --- 6.9 LIST comparison semantics (docs + round-trip) ------------ + t.Run("6_9_list_equality_roundtrip", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (c INT) PARTITION BY LIST(c) + (PARTITION p0 VALUES IN (1,2), PARTITION p1 VALUES IN (3,4))`) + + names := oraclePartitionNames(t, mc, "t") + wantNames := []string{"p0", "p1"} + assertStringEq(t, "oracle partition names", + strings.Join(names, ","), strings.Join(wantNames, ",")) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table or partitioning missing") + return + } + assertStringEq(t, "omni partition type", tbl.Partitioning.Type, "LIST") + assertIntEq(t, "omni partition count", + len(tbl.Partitioning.Partitions), 2) + }) + + // --- 6.10 LIST DEFAULT partition -------------------------------- + // Requires MySQL 8.0.4+ (introduction of LIST ... VALUES IN (DEFAULT)). + // Skip on older server versions so the test survives contributors + // running pinned-older images or CI lanes behind the rolling 8.0 tag. + t.Run("6_10_list_default_partition", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + var ver string + oracleScan(t, mc, `SELECT VERSION()`, &ver) + if !mysqlAtLeast(ver, 8, 0, 4) { + t.Skipf("6.10 requires MySQL >= 8.0.4 for LIST DEFAULT partition, got %q", ver) + } + + runOnBoth(t, mc, c, + `CREATE TABLE t (c INT) PARTITION BY LIST(c) + (PARTITION p0 VALUES IN (1,2), PARTITION pd VALUES IN (DEFAULT))`) + + names := oraclePartitionNames(t, mc, "t") + wantNames := []string{"p0", "pd"} + assertStringEq(t, "oracle partition names", + strings.Join(names, ","), strings.Join(wantNames, ",")) + + // Oracle: pd's PARTITION_DESCRIPTION should be DEFAULT. + var desc string + oracleScan(t, mc, `SELECT COALESCE(PARTITION_DESCRIPTION,'') + FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND PARTITION_NAME='pd'`, + &desc) + if !strings.EqualFold(desc, "DEFAULT") { + t.Errorf("oracle: pd description = %q, want DEFAULT", desc) + } + + // omni: expect DEFAULT token to round-trip on the last partition's + // ValueExpr. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil || len(tbl.Partitioning.Partitions) < 2 { + t.Errorf("omni: table/partitioning missing or short") + return + } + got := strings.ToUpper(tbl.Partitioning.Partitions[1].ValueExpr) + if !strings.Contains(got, "DEFAULT") { + t.Errorf("omni: pd ValueExpr = %q, want containing DEFAULT", got) + } + }) + + // --- 6.11 Partition function result must be INTEGER -------------- + t.Run("6_11_partition_expr_non_integer_rejected", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a VARCHAR(10), b VARCHAR(10)) + PARTITION BY RANGE(CONCAT(a,b)) + (PARTITION p0 VALUES LESS THAN ('m'), + PARTITION p1 VALUES LESS THAN MAXVALUE)` + + if _, err := mc.db.ExecContext(mc.ctx, ddl); err == nil { + t.Errorf("oracle: expected error for non-integer partition expr") + } + + results, err := c.Exec(ddl, nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: expected error for non-integer partition expr, got nil") + } + }) + + // --- 6.12 TIMESTAMP requires UNIX_TIMESTAMP wrapping -------------- + t.Run("6_12_timestamp_requires_unix_timestamp", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + badDDL := `CREATE TABLE t1 (ts TIMESTAMP) PARTITION BY RANGE(ts) + (PARTITION p0 VALUES LESS THAN (100))` + + if _, err := mc.db.ExecContext(mc.ctx, badDDL); err == nil { + t.Errorf("oracle: expected error for bare TIMESTAMP partition expr") + } + results, err := c.Exec(badDDL, nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: expected error for bare TIMESTAMP partition expr") + } + + // The wrapped form must succeed on both. + goodDDL := `CREATE TABLE t2 (ts TIMESTAMP NOT NULL) + PARTITION BY RANGE(UNIX_TIMESTAMP(ts)) + (PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN MAXVALUE)` + runOnBoth(t, mc, c, goodDDL) + + tbl := c.GetDatabase("testdb").GetTable("t2") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: t2 partitioning missing") + return + } + if !strings.Contains(strings.ToUpper(tbl.Partitioning.Expr), "UNIX_TIMESTAMP") { + t.Errorf("omni: expected expr containing UNIX_TIMESTAMP, got %q", + tbl.Partitioning.Expr) + } + }) + + // --- 6.13 UNIQUE KEY must cover partition expression columns ----- + t.Run("6_13_unique_key_must_cover_partition_cols", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a INT, b INT, UNIQUE KEY (a)) + PARTITION BY HASH(b) PARTITIONS 4` + + if _, err := mc.db.ExecContext(mc.ctx, ddl); err == nil { + t.Errorf("oracle: expected error when UNIQUE KEY excludes partition col") + } + + results, err := c.Exec(ddl, nil) + omniErr := err + if omniErr == nil { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("omni: expected error, got nil (schema silently accepted)") + } + }) + + // --- 6.14 Per-partition options preserved verbatim -------------- + t.Run("6_14_per_partition_options_preserved", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) PARTITION BY HASH(id) + (PARTITION p0 COMMENT='first' ENGINE=InnoDB, + PARTITION p1 COMMENT='second' ENGINE=InnoDB)`) + + // Oracle: PARTITION_COMMENT should be set per partition. + rows := oracleRows(t, mc, `SELECT PARTITION_NAME, PARTITION_COMMENT + FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY PARTITION_ORDINAL_POSITION`) + if len(rows) != 2 { + t.Errorf("oracle rows: got %d, want 2", len(rows)) + } else { + if s, _ := rows[0][1].(string); s != "first" { + t.Errorf("oracle p0 comment: got %q, want %q", s, "first") + } + if s, _ := rows[1][1].(string); s != "second" { + t.Errorf("oracle p1 comment: got %q, want %q", s, "second") + } + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil || len(tbl.Partitioning.Partitions) < 2 { + t.Errorf("omni: table/partitioning missing or short") + return + } + assertStringEq(t, "omni p0 comment", tbl.Partitioning.Partitions[0].Comment, "first") + assertStringEq(t, "omni p1 comment", tbl.Partitioning.Partitions[1].Comment, "second") + }) + + // --- 6.15 Subpartition options inherit handling ---------------- + t.Run("6_15_subpartition_options", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT, d DATE) + ENGINE=InnoDB + PARTITION BY RANGE(YEAR(d)) + SUBPARTITION BY HASH(id) SUBPARTITIONS 2 + (PARTITION p0 VALUES LESS THAN (2000) + (SUBPARTITION s0 COMMENT='sa', SUBPARTITION s1 COMMENT='sb'), + PARTITION p1 VALUES LESS THAN MAXVALUE + (SUBPARTITION s2, SUBPARTITION s3))`) + + // Oracle: 4 subpartitions total. + rows := oracleRows(t, mc, `SELECT SUBPARTITION_NAME + FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY PARTITION_ORDINAL_POSITION, SUBPARTITION_ORDINAL_POSITION`) + assertIntEq(t, "oracle subpartition count", len(rows), 4) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table/partitioning missing") + return + } + total := 0 + for _, p := range tbl.Partitioning.Partitions { + total += len(p.SubPartitions) + } + assertIntEq(t, "omni subpartition count", total, 4) + }) + + // --- 6.16 ALTER ADD PARTITION auto-naming --------------------- + t.Run("6_16_add_partition_auto_naming", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) PARTITION BY HASH(id) PARTITIONS 3`) + + // Try ALTER on both sides. omni may not support this — we report + // either outcome as assertion failures, not panics. + addDDL := `ALTER TABLE t ADD PARTITION PARTITIONS 2` + if _, err := mc.db.ExecContext(mc.ctx, addDDL); err != nil { + t.Errorf("oracle: ALTER ADD PARTITION failed: %v", err) + } + names := oraclePartitionNames(t, mc, "t") + want := []string{"p0", "p1", "p2", "p3", "p4"} + assertStringEq(t, "oracle partition names after ADD", + strings.Join(names, ","), strings.Join(want, ",")) + + // omni side: tolerant — report but do not panic. + results, err := c.Exec(addDDL, nil) + if err != nil { + t.Errorf("omni: ALTER ADD PARTITION parse error: %v", err) + return + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: ALTER ADD PARTITION exec error: %v", r.Error) + } + } + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table/partitioning missing") + return + } + omniNames := c6OmniPartitionNames(tbl) + assertStringEq(t, "omni partition names after ADD", + strings.Join(omniNames, ","), strings.Join(want, ",")) + }) + + // --- 6.17 COALESCE PARTITION removes last N ------------------ + t.Run("6_17_coalesce_partition_tail", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) PARTITION BY HASH(id) PARTITIONS 6`) + + coalesce := `ALTER TABLE t COALESCE PARTITION 2` + if _, err := mc.db.ExecContext(mc.ctx, coalesce); err != nil { + t.Errorf("oracle: ALTER COALESCE PARTITION failed: %v", err) + } + names := oraclePartitionNames(t, mc, "t") + want := []string{"p0", "p1", "p2", "p3"} + assertStringEq(t, "oracle partition names after COALESCE", + strings.Join(names, ","), strings.Join(want, ",")) + + results, err := c.Exec(coalesce, nil) + if err != nil { + t.Errorf("omni: COALESCE parse error: %v", err) + return + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: COALESCE exec error: %v", r.Error) + } + } + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil || tbl.Partitioning == nil { + t.Errorf("omni: table/partitioning missing") + return + } + omniNames := c6OmniPartitionNames(tbl) + assertStringEq(t, "omni partition names after COALESCE", + strings.Join(omniNames, ","), strings.Join(want, ",")) + }) +} + +// c6OmniPartitionNames returns the partition names in order from the omni +// catalog table. Returns nil if partitioning is missing. +func c6OmniPartitionNames(tbl *Table) []string { + if tbl == nil || tbl.Partitioning == nil { + return nil + } + names := make([]string, 0, len(tbl.Partitioning.Partitions)) + for _, p := range tbl.Partitioning.Partitions { + names = append(names, p.Name) + } + return names +} + diff --git a/tidb/catalog/scenarios_c7_test.go b/tidb/catalog/scenarios_c7_test.go new file mode 100644 index 00000000..6d18cad7 --- /dev/null +++ b/tidb/catalog/scenarios_c7_test.go @@ -0,0 +1,465 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C7 covers section C7 (Index defaults) from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest asserts that +// both real MySQL 8.0 and the omni catalog agree on default index +// behaviour for a given DDL input. +// +// Failures in omni assertions are NOT proof failures — they are +// recorded in mysql/catalog/scenarios_bug_queue/c7.md. +func TestScenario_C7(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // --- 7.1 Index algorithm defaults to BTREE -------------------------- + t.Run("7_1_btree_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, KEY (a));`) + + var got string + oracleScan(t, mc, + `SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='a'`, + &got) + assertStringEq(t, "oracle index type", got, "BTREE") + + idx := c7findIndex(c, "t", "a") + if idx == nil { + t.Errorf("omni: index a missing") + return + } + // Omni either stores "BTREE" explicitly or leaves blank; both round-trip + // to BTREE. We accept either. + got2 := strings.ToUpper(idx.IndexType) + if got2 != "" && got2 != "BTREE" { + t.Errorf("omni: expected empty or BTREE, got %q", idx.IndexType) + } + }) + + // --- 7.2 FK creates implicit backing index -------------------------- + t.Run("7_2_fk_implicit_index", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE p (id INT PRIMARY KEY); +CREATE TABLE c (a INT, CONSTRAINT c_ibfk_1 FOREIGN KEY (a) REFERENCES p(id));`) + + rows := oracleRows(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='c' + ORDER BY INDEX_NAME`) + var oracleIdxNames []string + for _, r := range rows { + oracleIdxNames = append(oracleIdxNames, asString(r[0])) + } + want := "c_ibfk_1" + found := false + for _, n := range oracleIdxNames { + if n == want { + found = true + break + } + } + if !found { + t.Errorf("oracle: expected index %q in %v", want, oracleIdxNames) + } + + tbl := c.GetDatabase("testdb").GetTable("c") + if tbl == nil { + t.Errorf("omni: table c missing") + return + } + // omni: should have an index on column a backing the FK + hasBackingIdx := false + for _, idx := range tbl.Indexes { + if len(idx.Columns) >= 1 && strings.EqualFold(idx.Columns[0].Name, "a") { + hasBackingIdx = true + if idx.Name != "c_ibfk_1" { + t.Errorf("omni: backing index name = %q, want c_ibfk_1", idx.Name) + } + break + } + } + assertBoolEq(t, "omni FK backing index exists", hasBackingIdx, true) + }) + + // --- 7.3 USING HASH coerced to BTREE on InnoDB ---------------------- + t.Run("7_3_hash_coerced_to_btree", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, KEY (a) USING HASH) ENGINE=InnoDB;`) + + var got string + oracleScan(t, mc, + `SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='a'`, + &got) + assertStringEq(t, "oracle index type after HASH coercion", got, "BTREE") + + idx := c7findIndex(c, "t", "a") + if idx == nil { + t.Errorf("omni: index a missing") + return + } + got2 := strings.ToUpper(idx.IndexType) + if got2 != "" && got2 != "BTREE" { + t.Errorf("omni: expected empty or BTREE (HASH coercion), got %q", idx.IndexType) + } + }) + + // --- 7.4 USING BTREE explicit vs implicit rendering ----------------- + t.Run("7_4_using_explicit_vs_implicit", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t1 (a INT, KEY (a)); +CREATE TABLE t2 (a INT, KEY (a) USING BTREE);`) + + ct1 := oracleShow(t, mc, "SHOW CREATE TABLE t1") + ct2 := oracleShow(t, mc, "SHOW CREATE TABLE t2") + // t1 should NOT contain USING BTREE in the KEY line + if strings.Contains(ct1, "USING BTREE") { + t.Errorf("oracle: t1 unexpectedly contains USING BTREE: %s", ct1) + } + // t2 SHOULD contain USING BTREE + if !strings.Contains(ct2, "USING BTREE") { + t.Errorf("oracle: t2 missing USING BTREE: %s", ct2) + } + + // omni: both tables exist; presence of explicit-flag distinction + // is an open omni gap. We assert both indexes parse and exist. + tbl1 := c.GetDatabase("testdb").GetTable("t1") + tbl2 := c.GetDatabase("testdb").GetTable("t2") + if tbl1 == nil || tbl2 == nil { + t.Errorf("omni: t1=%v t2=%v missing", tbl1, tbl2) + return + } + idx1 := c7findIndex(c, "t1", "a") + idx2 := c7findIndex(c, "t2", "a") + if idx1 == nil || idx2 == nil { + t.Errorf("omni: idx1=%v idx2=%v missing", idx1, idx2) + return + } + // omni gap: the catalog has no IndexTypeExplicit field, so it + // cannot distinguish the two. We assert the distinction here so + // that the test fails until omni grows the bit. + t1Explicit := strings.EqualFold(idx1.IndexType, "BTREE") + t2Explicit := strings.EqualFold(idx2.IndexType, "BTREE") + if t1Explicit == t2Explicit { + t.Errorf("omni: cannot distinguish explicit BTREE from default: t1=%q t2=%q", + idx1.IndexType, idx2.IndexType) + } + }) + + // --- 7.5 UNIQUE allows multiple NULLs ------------------------------- + t.Run("7_5_unique_multiple_nulls", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, UNIQUE KEY (a));`) + + // Insert three NULLs (oracle only — omni catalog does not execute DML). + for i := 0; i < 3; i++ { + if _, err := mc.db.ExecContext(mc.ctx, "INSERT INTO testdb.t VALUES (NULL)"); err != nil { + t.Errorf("oracle: insert NULL %d failed: %v", i, err) + } + } + var count int + oracleScan(t, mc, "SELECT COUNT(*) FROM testdb.t", &count) + assertIntEq(t, "oracle row count after 3 NULL inserts", count, 3) + + // Insert two (1)s — second should fail with duplicate key. + if _, err := mc.db.ExecContext(mc.ctx, "INSERT INTO testdb.t VALUES (1)"); err != nil { + t.Errorf("oracle: first INSERT 1 failed: %v", err) + } + _, err := mc.db.ExecContext(mc.ctx, "INSERT INTO testdb.t VALUES (1)") + if err == nil { + t.Errorf("oracle: expected duplicate-key error on second INSERT 1") + } else if !strings.Contains(err.Error(), "1062") && + !strings.Contains(strings.ToLower(err.Error()), "duplicate") { + t.Errorf("oracle: expected ER_DUP_ENTRY 1062, got %v", err) + } + + // omni: column a must remain nullable after UNIQUE. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + col := tbl.GetColumn("a") + if col == nil { + t.Errorf("omni: column a missing") + return + } + assertBoolEq(t, "omni: UNIQUE keeps column nullable", col.Nullable, true) + }) + + // --- 7.6 VISIBLE default + PK INVISIBLE rejection ------------------- + t.Run("7_6_visible_default_pk_invisible_blocked", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, KEY ix (a));`) + + var visYes string + oracleScan(t, mc, + `SELECT IS_VISIBLE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND INDEX_NAME='ix'`, + &visYes) + assertStringEq(t, "oracle IS_VISIBLE default", visYes, "YES") + + idx := c7findIndex(c, "t", "ix") + if idx == nil { + t.Errorf("omni: index ix missing") + } else { + assertBoolEq(t, "omni: index visible default", idx.Visible, true) + } + + // PK INVISIBLE must be rejected (oracle 3522). Use fresh table name. + _, mysqlErr := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t_pkinv (a INT, PRIMARY KEY (a) INVISIBLE)`) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_PK_INDEX_CANT_BE_INVISIBLE, got nil") + } else if !strings.Contains(mysqlErr.Error(), "3522") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "primary key cannot be invisible") { + t.Errorf("oracle: expected 3522, got %v", mysqlErr) + } + + results, err := c.Exec(`CREATE TABLE t_pkinv (a INT, PRIMARY KEY (a) INVISIBLE);`, nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni: rejects PK INVISIBLE", omniErrored, true) + }) + + // --- 7.7 BLOB/TEXT prefix length required --------------------------- + t.Run("7_7_blob_text_prefix_required", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Form 1: TEXT KEY without length should error (1170). + _, mysqlErr := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t_noprefix (a TEXT, KEY (a))`) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_BLOB_KEY_WITHOUT_LENGTH, got nil") + } else if !strings.Contains(mysqlErr.Error(), "1170") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "blob/text column") { + t.Errorf("oracle: expected 1170, got %v", mysqlErr) + } + results, err := c.Exec(`CREATE TABLE t_noprefix (a TEXT, KEY (a));`, nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni: rejects TEXT KEY without prefix", omniErrored, true) + + // Form 2: TEXT KEY with prefix length succeeds; SUB_PART=100. + runOnBoth(t, mc, c, `CREATE TABLE t_prefix (a TEXT, KEY (a(100)));`) + var subPart int + oracleScan(t, mc, + `SELECT SUB_PART FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t_prefix' AND INDEX_NAME='a'`, + &subPart) + assertIntEq(t, "oracle SUB_PART", subPart, 100) + + idx := c7findIndex(c, "t_prefix", "a") + if idx == nil { + t.Errorf("omni: index a on t_prefix missing") + } else if len(idx.Columns) != 1 { + t.Errorf("omni: expected 1 column in idx, got %d", len(idx.Columns)) + } else { + assertIntEq(t, "omni: prefix length", idx.Columns[0].Length, 100) + } + + // Form 3: FULLTEXT exempt from prefix requirement. + runOnBoth(t, mc, c, `CREATE TABLE t_ft (a TEXT, FULLTEXT KEY (a));`) + var idxType string + oracleScan(t, mc, + `SELECT INDEX_TYPE FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t_ft' AND INDEX_NAME='a'`, + &idxType) + assertStringEq(t, "oracle FULLTEXT index type", idxType, "FULLTEXT") + }) + + // --- 7.8 FULLTEXT WITH PARSER optionality --------------------------- + t.Run("7_8_fulltext_parser_optional", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t1 (a TEXT, FULLTEXT KEY (a)); +CREATE TABLE t2 (a TEXT, FULLTEXT KEY (a) WITH PARSER ngram);`) + + ct1 := oracleShow(t, mc, "SHOW CREATE TABLE t1") + ct2 := oracleShow(t, mc, "SHOW CREATE TABLE t2") + if strings.Contains(ct1, "WITH PARSER") { + t.Errorf("oracle: t1 (no parser) unexpectedly contains WITH PARSER: %s", ct1) + } + if !strings.Contains(ct2, "WITH PARSER") { + t.Errorf("oracle: t2 missing WITH PARSER ngram: %s", ct2) + } + + // omni: catalog Index has no ParserName field — gap. We assert + // only that both tables and FULLTEXT indexes exist. + idx1 := c7findIndex(c, "t1", "a") + idx2 := c7findIndex(c, "t2", "a") + if idx1 == nil || idx2 == nil { + t.Errorf("omni: idx1=%v idx2=%v missing", idx1, idx2) + return + } + assertBoolEq(t, "omni t1 fulltext", idx1.Fulltext, true) + assertBoolEq(t, "omni t2 fulltext", idx2.Fulltext, true) + // Distinction (parser name) is omni gap — flag for bug queue. + // We can't read ParserName since the field doesn't exist, so log. + t.Logf("omni: ParserName field not present on Index; cannot distinguish t1 vs t2 (expected gap)") + }) + + // --- 7.9 SPATIAL requires NOT NULL ---------------------------------- + t.Run("7_9_spatial_not_null", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // OK case: NOT NULL geometry with SRID. + runOnBoth(t, mc, c, `CREATE TABLE t_ok (g GEOMETRY NOT NULL SRID 4326, SPATIAL KEY (g));`) + + // Error case 1: nullable geometry — ER_SPATIAL_CANT_HAVE_NULL 1252. + _, mysqlErr := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t_null (g GEOMETRY, SPATIAL KEY (g))`) + if mysqlErr == nil { + t.Errorf("oracle: expected ER_SPATIAL_CANT_HAVE_NULL, got nil") + } else if !strings.Contains(mysqlErr.Error(), "1252") && + !strings.Contains(strings.ToLower(mysqlErr.Error()), "all parts of a spatial index") { + t.Errorf("oracle: expected 1252, got %v", mysqlErr) + } + results, err := c.Exec(`CREATE TABLE t_null (g GEOMETRY, SPATIAL KEY (g));`, nil) + omniErrored := err != nil + if !omniErrored { + for _, r := range results { + if r.Error != nil { + omniErrored = true + break + } + } + } + assertBoolEq(t, "omni: rejects nullable SPATIAL", omniErrored, true) + + // Error case 2: SPATIAL with USING BTREE — ER_INDEX_TYPE_NOT_SUPPORTED_FOR_SPATIAL_INDEX 3500. + _, mysqlErr2 := mc.db.ExecContext(mc.ctx, + `CREATE TABLE t_spbtree (g GEOMETRY NOT NULL, SPATIAL KEY (g) USING BTREE)`) + if mysqlErr2 == nil { + t.Errorf("oracle: expected rejection of SPATIAL USING BTREE, got nil") + } else { + // MySQL 8.0 parser rejects `USING BTREE` on a SPATIAL index at parse + // time with a plain syntax error (1064) rather than the semantic + // ER_INDEX_TYPE_NOT_SUPPORTED_FOR_SPATIAL_INDEX (3500). Accept either + // so the scenario is version-robust. + msg := strings.ToLower(mysqlErr2.Error()) + if !strings.Contains(mysqlErr2.Error(), "3500") && + !strings.Contains(mysqlErr2.Error(), "1064") && + !strings.Contains(msg, "not supported") && + !strings.Contains(msg, "syntax") { + t.Errorf("oracle: expected 3500 or 1064/syntax, got %v", mysqlErr2) + } + } + results2, err2 := c.Exec( + `CREATE TABLE t_spbtree (g GEOMETRY NOT NULL, SPATIAL KEY (g) USING BTREE);`, nil) + omniErrored2 := err2 != nil + if !omniErrored2 { + for _, r := range results2 { + if r.Error != nil { + omniErrored2 = true + break + } + } + } + assertBoolEq(t, "omni: rejects SPATIAL USING BTREE", omniErrored2, true) + }) + + // --- 7.10 PK + UNIQUE coexistence on same column -------------------- + t.Run("7_10_pk_and_unique_coexist", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT NOT NULL, PRIMARY KEY (a), UNIQUE KEY uk (a));`) + + rows := oracleRows(t, mc, `SELECT INDEX_NAME FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY INDEX_NAME`) + var got []string + for _, r := range rows { + got = append(got, asString(r[0])) + } + // Expect both PRIMARY and uk. + havePrimary := false + haveUk := false + for _, n := range got { + if n == "PRIMARY" { + havePrimary = true + } + if n == "uk" { + haveUk = true + } + } + if !havePrimary || !haveUk { + t.Errorf("oracle: expected both PRIMARY and uk indexes, got %v", got) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Errorf("omni: table t missing") + return + } + omniHasPK := false + omniHasUk := false + for _, idx := range tbl.Indexes { + if idx.Primary { + omniHasPK = true + } + if idx.Unique && !idx.Primary && idx.Name == "uk" { + omniHasUk = true + } + } + assertBoolEq(t, "omni: has PRIMARY index", omniHasPK, true) + assertBoolEq(t, "omni: has uk unique index", omniHasUk, true) + }) +} + +// --- section-local helpers ------------------------------------------------ + +// c7findIndex returns the named index on the given table in testdb, or nil. +func c7findIndex(c *Catalog, tableName, indexName string) *Index { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + tbl := db.GetTable(tableName) + if tbl == nil { + return nil + } + for _, idx := range tbl.Indexes { + if strings.EqualFold(idx.Name, indexName) { + return idx + } + } + return nil +} diff --git a/tidb/catalog/scenarios_c8_test.go b/tidb/catalog/scenarios_c8_test.go new file mode 100644 index 00000000..941c2918 --- /dev/null +++ b/tidb/catalog/scenarios_c8_test.go @@ -0,0 +1,357 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C8 covers Section C8 "Table option defaults" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL against both +// a real MySQL 8.0 container and the omni catalog, then asserts that both +// agree on the effective default for a given table-level option. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c8.md. +func TestScenario_C8(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c8OracleTableScalar runs a single-scalar SELECT against + // information_schema.TABLES for testdb.
, selecting a single + // string/int column. Uses IFNULL so NULL columns come back as empty + // string (the cases we care about are TABLE_COLLATION and CREATE_OPTIONS). + c8OracleStr := func(t *testing.T, col, table string) string { + t.Helper() + var s string + oracleScan(t, mc, + "SELECT IFNULL("+col+",'') FROM information_schema.TABLES "+ + "WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='"+table+"'", + &s) + return s + } + + // c8ResetDBWithCharset drops testdb, recreates with a specific + // charset, USEs it, and returns a fresh omni catalog with the same + // initial state. + c8ResetDBWithCharset := func(t *testing.T, charset string) *Catalog { + t.Helper() + if _, err := mc.db.ExecContext(mc.ctx, "DROP DATABASE IF EXISTS testdb"); err != nil { + t.Fatalf("oracle DROP DATABASE: %v", err) + } + createStmt := "CREATE DATABASE testdb CHARACTER SET " + charset + if _, err := mc.db.ExecContext(mc.ctx, createStmt); err != nil { + t.Fatalf("oracle CREATE DATABASE: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, "USE testdb"); err != nil { + t.Fatalf("oracle USE testdb: %v", err) + } + c := New() + results, err := c.Exec(createStmt+"; USE testdb;", nil) + if err != nil { + t.Errorf("omni parse error for %q: %v", createStmt, err) + return c + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni exec error on stmt %d: %v", r.Index, r.Error) + } + } + return c + } + + // --- 8.1 Storage engine defaults to InnoDB --------------------------- + t.Run("8_1_engine_defaults_innodb", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + got := c8OracleStr(t, "ENGINE", "t") + assertStringEq(t, "oracle ENGINE", strings.ToLower(got), "innodb") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + // Omni may store empty string (meaning default) or "InnoDB". + omniEngine := strings.ToLower(tbl.Engine) + if omniEngine != "" && omniEngine != "innodb" { + t.Errorf("omni Engine: got %q, want \"innodb\" or empty default", tbl.Engine) + } + }) + + // --- 8.2 ROW_FORMAT defaults to DYNAMIC ------------------------------ + t.Run("8_2_row_format_defaults_dynamic", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + got := c8OracleStr(t, "ROW_FORMAT", "t") + assertStringEq(t, "oracle ROW_FORMAT", strings.ToLower(got), "dynamic") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + // Omni may store empty string (default) or "DYNAMIC". + omniRF := strings.ToLower(tbl.RowFormat) + if omniRF != "" && omniRF != "dynamic" { + t.Errorf("omni RowFormat: got %q, want \"dynamic\" or empty default", tbl.RowFormat) + } + }) + + // --- 8.3 AUTO_INCREMENT starts at 1, elided from SHOW CREATE --------- + t.Run("8_3_auto_increment_starts_at_one", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY)`) + + // Oracle: information_schema.TABLES.AUTO_INCREMENT is the + // next counter value — reported as NULL (no rows yet) or 1 + // depending on engine state. IFNULL normalises to 0. Either + // 0 (unset/NULL) or 1 confirms "starts at 1". + var ai int64 + oracleScan(t, mc, + `SELECT IFNULL(AUTO_INCREMENT,0) FROM information_schema.TABLES + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t'`, + &ai) + if ai != 0 && ai != 1 { + t.Errorf("oracle AUTO_INCREMENT: got %d, want 0 (NULL) or 1", ai) + } + + // Oracle: SHOW CREATE TABLE elides AUTO_INCREMENT= clause (C18.4). + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "AUTO_INCREMENT=") { + t.Errorf("oracle SHOW CREATE TABLE should elide AUTO_INCREMENT= clause; got:\n%s", show) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + // Omni's AutoIncrement may be 0 (unset sentinel) or 1. + if tbl.AutoIncrement != 0 && tbl.AutoIncrement != 1 { + t.Errorf("omni AutoIncrement: got %d, want 0 or 1", tbl.AutoIncrement) + } + }) + + // --- 8.4 CHARSET inherits from database default ---------------------- + t.Run("8_4_charset_inherits_from_db", func(t *testing.T) { + c := c8ResetDBWithCharset(t, "latin1") + + runOnBoth(t, mc, c, `CREATE TABLE t (a VARCHAR(10))`) + + // Oracle: information_schema.TABLES.TABLE_COLLATION should be + // latin1_swedish_ci (the default collation of latin1). + got := c8OracleStr(t, "TABLE_COLLATION", "t") + assertStringEq(t, "oracle TABLE_COLLATION", strings.ToLower(got), "latin1_swedish_ci") + + // Oracle: column a charset is latin1. + var colCS string + oracleScan(t, mc, + `SELECT IFNULL(CHARACTER_SET_NAME,'') FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='a'`, + &colCS) + assertStringEq(t, "oracle column CHARACTER_SET_NAME", + strings.ToLower(colCS), "latin1") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + assertStringEq(t, "omni table Charset", + strings.ToLower(tbl.Charset), "latin1") + }) + + // --- 8.5 COLLATE alone derives CHARSET ------------------------------- + t.Run("8_5_collate_alone_derives_charset", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT) COLLATE=latin1_german2_ci`) + + // Oracle: TABLE_COLLATION should be latin1_german2_ci. + got := c8OracleStr(t, "TABLE_COLLATION", "t") + assertStringEq(t, "oracle TABLE_COLLATION", strings.ToLower(got), "latin1_german2_ci") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + assertStringEq(t, "omni table Charset", + strings.ToLower(tbl.Charset), "latin1") + assertStringEq(t, "omni table Collation", + strings.ToLower(tbl.Collation), "latin1_german2_ci") + }) + + // --- 8.6 KEY_BLOCK_SIZE defaults to 0 and elided in SHOW ------------- + t.Run("8_6_key_block_size_default_zero", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + // Oracle: CREATE_OPTIONS does not mention key_block_size. + opts := c8OracleStr(t, "CREATE_OPTIONS", "t") + if strings.Contains(strings.ToLower(opts), "key_block_size") { + t.Errorf("oracle CREATE_OPTIONS unexpectedly contains key_block_size: %q", opts) + } + // Oracle: SHOW CREATE TABLE omits the clause. + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "KEY_BLOCK_SIZE") { + t.Errorf("oracle SHOW CREATE TABLE should omit KEY_BLOCK_SIZE; got:\n%s", show) + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + assertIntEq(t, "omni KeyBlockSize", tbl.KeyBlockSize, 0) + }) + + // --- 8.7 COMPRESSION default (None) ---------------------------------- + t.Run("8_7_compression_default_none", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Setup: plain CREATE without COMPRESSION. + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + // Oracle: SHOW CREATE TABLE omits COMPRESSION=. + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "COMPRESSION=") { + t.Errorf("oracle SHOW CREATE TABLE should omit COMPRESSION=; got:\n%s", show) + } + opts := c8OracleStr(t, "CREATE_OPTIONS", "t") + if strings.Contains(strings.ToLower(opts), "compress") { + t.Errorf("oracle CREATE_OPTIONS unexpectedly contains compression: %q", opts) + } + + // Omni: no Compression field exists. We still verify that a + // CREATE with an explicit COMPRESSION option parses without + // error. This exercises the omni parser path even though the + // value is dropped. + results, err := c.Exec(`CREATE TABLE t_cmp (a INT) COMPRESSION='NONE';`, nil) + if err != nil { + t.Errorf("omni: parse error for COMPRESSION='NONE': %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error for COMPRESSION='NONE': %v", r.Error) + } + } + // Intentional: document that Compression is not modeled. + // This is a MED-severity omni gap. See scenarios_bug_queue/c8.md. + _ = c.GetDatabase("testdb").GetTable("t_cmp") + }) + + // --- 8.8 ENCRYPTION depends on server default_table_encryption ------ + t.Run("8_8_encryption_default_off", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + // Oracle: with default_table_encryption=OFF (testcontainer + // default), SHOW CREATE TABLE omits ENCRYPTION and + // CREATE_OPTIONS has no encryption entry. + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "ENCRYPTION=") { + t.Errorf("oracle SHOW CREATE TABLE should omit ENCRYPTION=; got:\n%s", show) + } + opts := c8OracleStr(t, "CREATE_OPTIONS", "t") + if strings.Contains(strings.ToLower(opts), "encrypt") { + t.Errorf("oracle CREATE_OPTIONS unexpectedly contains encryption: %q", opts) + } + + // Omni gap: no Encryption field. Verify parser accepts the + // option without crashing. + results, err := c.Exec(`CREATE TABLE t_enc (a INT) ENCRYPTION='N';`, nil) + if err != nil { + t.Errorf("omni: parse error for ENCRYPTION='N': %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error for ENCRYPTION='N': %v", r.Error) + } + } + }) + + // --- 8.9 STATS_PERSISTENT defaults to DEFAULT ----------------------- + t.Run("8_9_stats_persistent_default", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + // Oracle: CREATE_OPTIONS does not mention stats_persistent; + // SHOW CREATE TABLE omits the clause. + opts := c8OracleStr(t, "CREATE_OPTIONS", "t") + if strings.Contains(strings.ToLower(opts), "stats_persistent") { + t.Errorf("oracle CREATE_OPTIONS unexpectedly contains stats_persistent: %q", opts) + } + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "STATS_PERSISTENT") { + t.Errorf("oracle SHOW CREATE TABLE should omit STATS_PERSISTENT; got:\n%s", show) + } + + // Omni gap: no StatsPersistent field. Verify parser accepts. + results, err := c.Exec(`CREATE TABLE t_sp (a INT) STATS_PERSISTENT=DEFAULT;`, nil) + if err != nil { + t.Errorf("omni: parse error for STATS_PERSISTENT=DEFAULT: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error for STATS_PERSISTENT=DEFAULT: %v", r.Error) + } + } + }) + + // --- 8.10 TABLESPACE defaults to innodb_file_per_table -------------- + t.Run("8_10_tablespace_default_file_per_table", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT)`) + + // Oracle: SHOW CREATE TABLE omits TABLESPACE= clause by default. + show := oracleShow(t, mc, "SHOW CREATE TABLE t") + if strings.Contains(strings.ToUpper(show), "TABLESPACE=") { + t.Errorf("oracle SHOW CREATE TABLE should omit TABLESPACE=; got:\n%s", show) + } + // Oracle: information_schema.INNODB_TABLES has a row for the + // table, and its SPACE column is non-zero (each + // file_per_table tablespace has its own id). + var space int64 + oracleScan(t, mc, + `SELECT IFNULL(SPACE,0) FROM information_schema.INNODB_TABLES + WHERE NAME='testdb/t'`, + &space) + if space == 0 { + t.Errorf("oracle INNODB_TABLES.SPACE for testdb/t: got 0, want non-zero (file_per_table)") + } + + // Omni gap: no Tablespace field. Verify parser accepts an + // explicit TABLESPACE clause without crashing. + results, err := c.Exec(`CREATE TABLE t_ts (a INT) TABLESPACE=innodb_file_per_table;`, nil) + if err != nil { + t.Errorf("omni: parse error for TABLESPACE=innodb_file_per_table: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni: exec error for TABLESPACE=innodb_file_per_table: %v", r.Error) + } + } + }) +} diff --git a/tidb/catalog/scenarios_c9_test.go b/tidb/catalog/scenarios_c9_test.go new file mode 100644 index 00000000..129adb88 --- /dev/null +++ b/tidb/catalog/scenarios_c9_test.go @@ -0,0 +1,367 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestScenario_C9 covers Section C9 "Generated column defaults" from +// SCENARIOS-mysql-implicit-behavior.md. Each subtest runs DDL against both +// a real MySQL 8.0 container and the omni catalog, then asserts that both +// agree on the effective default for a given generated-column behavior. +// +// Failed omni assertions are NOT proof failures — they are recorded in +// mysql/catalog/scenarios_bug_queue/c9.md. +func TestScenario_C9(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // c9OracleExtra returns the EXTRA column for a given column in testdb. + c9OracleExtra := func(t *testing.T, table, col string) string { + t.Helper() + var s string + oracleScan(t, mc, + `SELECT IFNULL(EXTRA,'') FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+table+`' AND COLUMN_NAME='`+col+`'`, + &s) + return s + } + + // c9OmniExec runs a multi-statement DDL on omni and returns (parseErr, + // anyStmtErrored). Used by scenarios that expect omni to reject DDL. + c9OmniExec := func(c *Catalog, ddl string) (bool, error) { + results, err := c.Exec(ddl, nil) + if err != nil { + return true, err + } + for _, r := range results { + if r.Error != nil { + return true, r.Error + } + } + return false, nil + } + + // --- 9.1 Generated column storage defaults to VIRTUAL -------------------- + t.Run("9_1_gcol_storage_defaults_virtual", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t (a INT, b INT GENERATED ALWAYS AS (a+1))`) + + extra := c9OracleExtra(t, "t", "b") + if !strings.Contains(strings.ToUpper(extra), "VIRTUAL GENERATED") { + t.Errorf("oracle EXTRA for b: got %q, want contains %q", extra, "VIRTUAL GENERATED") + } + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t not found") + return + } + col := tbl.GetColumn("b") + if col == nil { + t.Error("omni: column b not found") + return + } + if col.Generated == nil { + t.Error("omni: column b should be generated") + return + } + // Default storage is VIRTUAL (Stored == false). + assertBoolEq(t, "omni b Generated.Stored", col.Generated.Stored, false) + }) + + // --- 9.2 FK on generated column — dual-agreement test ------------------- + // + // SCENARIOS-mysql-implicit-behavior.md claims MySQL rejects FK where the + // child column is a STORED generated column (`ER_FK_CANNOT_USE_VIRTUAL_COLUMN`). + // Empirical oracle check: MySQL 8.0.45 ALLOWS FK on a child STORED gcol; + // only VIRTUAL gcols are rejected on the child side. The scenario's + // expected error is outdated. This test asserts oracle and omni agree + // (both should accept) and documents the scenario discrepancy. + t.Run("9_2_fk_on_stored_gcol_child", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + parent := `CREATE TABLE p (id INT PRIMARY KEY)` + if _, err := mc.db.ExecContext(mc.ctx, parent); err != nil { + t.Errorf("oracle setup parent: %v", err) + } + if _, err := c.Exec(parent+";", nil); err != nil { + t.Errorf("omni setup parent: %v", err) + } + + ddl := `CREATE TABLE c ( + a INT, + b INT GENERATED ALWAYS AS (a+1) STORED, + FOREIGN KEY (b) REFERENCES p(id) + )` + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + oracleAccepted := oracleErr == nil + omniErrored, _ := c9OmniExec(c, ddl+";") + omniAccepted := !omniErrored + + // Both systems must agree. + assertBoolEq(t, "oracle vs omni agreement on FK on STORED gcol", + omniAccepted, oracleAccepted) + + // Also verify the VIRTUAL gcol case is rejected by both (this is the + // real rule — `ER_FK_CANNOT_USE_VIRTUAL_COLUMN`). + scenarioReset(t, mc) + c2 := scenarioNewCatalog(t) + if _, err := mc.db.ExecContext(mc.ctx, parent); err != nil { + t.Errorf("oracle setup parent (virtual case): %v", err) + } + if _, err := c2.Exec(parent+";", nil); err != nil { + t.Errorf("omni setup parent (virtual case): %v", err) + } + badVirtual := `CREATE TABLE cv ( + a INT, + b INT GENERATED ALWAYS AS (a+1) VIRTUAL, + FOREIGN KEY (b) REFERENCES p(id) + )` + _, oracleVirtErr := mc.db.ExecContext(mc.ctx, badVirtual) + if oracleVirtErr == nil { + t.Errorf("oracle: expected FK-on-VIRTUAL-gcol rejection, got nil error") + } + omniVirtErrored, _ := c9OmniExec(c2, badVirtual+";") + assertBoolEq(t, "omni rejects FK on VIRTUAL gcol (child)", omniVirtErrored, true) + }) + + // --- 9.3 VIRTUAL gcol in PK rejected; secondary KEY allowed -------------- + t.Run("9_3_virtual_gcol_pk_rejected_secondary_allowed", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // t1: secondary KEY on virtual gcol should succeed. + good := `CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a+1) VIRTUAL, KEY (b))` + if _, err := mc.db.ExecContext(mc.ctx, good); err != nil { + t.Errorf("oracle: expected t1 to succeed, got: %v", err) + } + if omniErr, _ := c9OmniExec(c, good+";"); omniErr { + t.Errorf("omni: expected t1 to succeed (VIRTUAL gcol as secondary KEY)") + } + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl != nil { + found := false + for _, idx := range tbl.Indexes { + for _, col := range idx.Columns { + if strings.EqualFold(col.Name, "b") { + found = true + } + } + } + if !found { + t.Errorf("omni: secondary index on b missing from t1") + } + } else { + t.Errorf("omni: t1 table missing after CREATE") + } + + // t2: PRIMARY KEY on virtual gcol should error. + bad := `CREATE TABLE t2 (a INT, b INT GENERATED ALWAYS AS (a+1) VIRTUAL PRIMARY KEY)` + _, oracleErr := mc.db.ExecContext(mc.ctx, bad) + if oracleErr == nil { + t.Errorf("oracle: expected VIRTUAL gcol PK rejection, got nil error") + } + omniErrored, _ := c9OmniExec(c, bad+";") + assertBoolEq(t, "omni rejects VIRTUAL gcol as PRIMARY KEY", omniErrored, true) + }) + + // --- 9.4 Gcol expression must be deterministic (NOW() rejected) --------- + t.Run("9_4_gcol_expr_must_be_deterministic", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + bad := `CREATE TABLE t (a INT, b TIMESTAMP GENERATED ALWAYS AS (NOW()) VIRTUAL)` + _, oracleErr := mc.db.ExecContext(mc.ctx, bad) + if oracleErr == nil { + t.Errorf("oracle: expected non-deterministic gcol rejection, got nil error") + } + + omniErrored, _ := c9OmniExec(c, bad+";") + assertBoolEq(t, "omni rejects NOW() in gcol expression", omniErrored, true) + }) + + // --- 9.5 UNIQUE on gcol allowed (both STORED and VIRTUAL under InnoDB) --- + t.Run("9_5_unique_on_gcol_allowed", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl1 := `CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a+1) STORED UNIQUE)` + ddl2 := `CREATE TABLE t2 (a INT, b INT GENERATED ALWAYS AS (a+1) VIRTUAL UNIQUE)` + + if _, err := mc.db.ExecContext(mc.ctx, ddl1); err != nil { + t.Errorf("oracle t1 STORED UNIQUE: %v", err) + } + if _, err := mc.db.ExecContext(mc.ctx, ddl2); err != nil { + t.Errorf("oracle t2 VIRTUAL UNIQUE: %v", err) + } + + if omniErr, err := c9OmniExec(c, ddl1+";"); omniErr { + t.Errorf("omni t1 STORED UNIQUE: %v", err) + } + if omniErr, err := c9OmniExec(c, ddl2+";"); omniErr { + t.Errorf("omni t2 VIRTUAL UNIQUE: %v", err) + } + + // Oracle: both tables should have a UNIQUE index covering b. + for _, table := range []string{"t1", "t2"} { + var idxCount int64 + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='`+table+`' + AND COLUMN_NAME='b' AND NON_UNIQUE=0`, + &idxCount) + if idxCount < 1 { + t.Errorf("oracle: %s should have a UNIQUE index on b", table) + } + } + + // Omni: both tables should have a UNIQUE index on b. + for _, table := range []string{"t1", "t2"} { + tbl := c.GetDatabase("testdb").GetTable(table) + if tbl == nil { + t.Errorf("omni: %s missing", table) + continue + } + hasUnique := false + for _, idx := range tbl.Indexes { + if !idx.Unique { + continue + } + for _, col := range idx.Columns { + if strings.EqualFold(col.Name, "b") { + hasUnique = true + } + } + } + if !hasUnique { + t.Errorf("omni: %s missing UNIQUE index on b", table) + } + } + }) + + // --- 9.6 Gcol NOT NULL declaration accepted at CREATE time -------------- + t.Run("9_6_gcol_not_null_accepted", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a INT NULL, + b INT GENERATED ALWAYS AS (a+1) VIRTUAL NOT NULL + )` + runOnBoth(t, mc, c, ddl) + + // Oracle: IS_NULLABLE for b is 'NO'. + var isNullable string + oracleScan(t, mc, + `SELECT IS_NULLABLE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, + &isNullable) + assertStringEq(t, "oracle IS_NULLABLE for b", isNullable, "NO") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t missing") + return + } + col := tbl.GetColumn("b") + if col == nil { + t.Error("omni: column b missing") + return + } + assertBoolEq(t, "omni b Nullable", col.Nullable, false) + if col.Generated == nil { + t.Error("omni: b should still be generated") + } + }) + + // --- 9.7 FK child referencing STORED gcol parent is allowed ------------- + t.Run("9_7_fk_parent_stored_gcol_allowed", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE p ( + a INT, + b INT GENERATED ALWAYS AS (a+1) STORED, + UNIQUE KEY (b) + ); +CREATE TABLE c (x INT, FOREIGN KEY (x) REFERENCES p(b));` + runOnBoth(t, mc, c, ddl) + + // Oracle: FK exists in REFERENTIAL_CONSTRAINTS. + var fkCount int64 + oracleScan(t, mc, + `SELECT COUNT(*) FROM information_schema.REFERENTIAL_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' AND TABLE_NAME='c' + AND REFERENCED_TABLE_NAME='p'`, + &fkCount) + if fkCount != 1 { + t.Errorf("oracle: expected 1 FK on c referencing p, got %d", fkCount) + } + + tblC := c.GetDatabase("testdb").GetTable("c") + if tblC == nil { + t.Error("omni: table c missing") + return + } + foundFK := false + for _, con := range tblC.Constraints { + if con.Type == ConForeignKey && strings.EqualFold(con.RefTable, "p") { + foundFK = true + } + } + if !foundFK { + t.Errorf("omni: no FK on c referencing p (parent STORED gcol should be allowed)") + } + }) + + // --- 9.8 Gcol charset derived from expression inputs -------------------- + t.Run("9_8_gcol_charset_from_expression", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t ( + a VARCHAR(10) CHARACTER SET latin1, + b VARCHAR(20) GENERATED ALWAYS AS (CONCAT(a, 'x')) VIRTUAL + )` + runOnBoth(t, mc, c, ddl) + + // Oracle: query b's CHARACTER_SET_NAME. SCENARIOS claims this + // should be latin1 (inherited from expression inputs). Empirical + // oracle on MySQL 8.0.45 returns utf8mb4 (the table default), + // because when the gcol column is declared as VARCHAR without an + // explicit charset it picks up the table charset rather than the + // expression's result-coercion charset. Dual-assertion: omni and + // oracle must agree on whatever MySQL actually does. + var cs string + oracleScan(t, mc, + `SELECT IFNULL(CHARACTER_SET_NAME,'') FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' AND COLUMN_NAME='b'`, + &cs) + oracleCharset := strings.ToLower(cs) + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Error("omni: table t missing") + return + } + col := tbl.GetColumn("b") + if col == nil { + t.Error("omni: column b missing") + return + } + omniCharset := strings.ToLower(col.Charset) + // Omni may leave the field empty when the column inherits from the + // table default — treat that as an acceptable default state. + if omniCharset == "" { + omniCharset = "utf8mb4" + } + assertStringEq(t, "omni b Charset agrees with oracle", omniCharset, oracleCharset) + }) +} diff --git a/tidb/catalog/scenarios_helpers_test.go b/tidb/catalog/scenarios_helpers_test.go new file mode 100644 index 00000000..083268d8 --- /dev/null +++ b/tidb/catalog/scenarios_helpers_test.go @@ -0,0 +1,332 @@ +package catalog + +import ( + "context" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/testcontainers/testcontainers-go" +) + +// This file provides the shared infrastructure used by the "mysql-implicit-behavior" +// starmap scenario tests. Section workers in BATCHES 1-7 build on these helpers +// to run dual-assertion scenarios against both a real MySQL 8.0 container and +// the omni catalog. +// +// NOTE: asString is already defined in catalog_spotcheck_test.go and is reused +// here as-is; do not redeclare it. + +// scenarioContainer wraps startContainer for naming consistency with the +// scenario helpers. The caller must defer the cleanup func. +// +// IMPORTANT: pins the underlying *sql.DB pool to a single connection. Many +// scenario tests rely on connection-scoped state (USE testdb, SET SESSION +// explicit_defaults_for_timestamp=0, SET SESSION sql_mode='', etc.) that +// only affects the current MySQL session. Without pinning, subsequent +// queries may execute on a different pool connection and silently run +// against the wrong schema or session settings, producing nondeterministic +// oracle results. (Codex BATCH 4 review P1/P2.) +func scenarioContainer(t *testing.T) (*mysqlContainer, func()) { + t.Helper() + mc, cleanup := startContainer(t) + mc.db.SetMaxOpenConns(1) + mc.db.SetMaxIdleConns(1) + return mc, cleanup +} + +// scenarioReset drops and recreates the shared testdb database on the MySQL +// container and selects it. It uses t.Error rather than t.Fatal so that the +// calling test can continue and report additional diffs within one run. +func scenarioReset(t *testing.T, mc *mysqlContainer) { + t.Helper() + stmts := []string{ + "DROP DATABASE IF EXISTS testdb", + "CREATE DATABASE testdb", + "USE testdb", + } + for _, stmt := range stmts { + if _, err := mc.db.ExecContext(mc.ctx, stmt); err != nil { + t.Errorf("scenarioReset %q: %v", stmt, err) + } + } +} + +// scenarioNewCatalog returns a fresh omni catalog with a testdb database +// created and selected. Uses t.Fatal on setup errors because without a +// working catalog nothing else can run. +func scenarioNewCatalog(t *testing.T) *Catalog { + t.Helper() + c := New() + results, err := c.Exec("CREATE DATABASE testdb; USE testdb;", nil) + if err != nil { + t.Fatalf("scenarioNewCatalog parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("scenarioNewCatalog exec error on stmt %d: %v", r.Index, r.Error) + } + } + return c +} + +// runOnBoth executes a (possibly multi-statement) DDL string on both the +// MySQL container and the omni catalog. Errors on either side are reported +// via t.Error so that the calling test can continue comparing remaining +// scenario state. Statements are split respecting quotes; individual +// statements are executed one at a time on the container side. +func runOnBoth(t *testing.T, mc *mysqlContainer, c *Catalog, ddl string) { + t.Helper() + + for _, stmt := range splitStmts(ddl) { + if _, err := mc.db.ExecContext(mc.ctx, stmt); err != nil { + t.Errorf("mysql container DDL failed: %q: %v", stmt, err) + } + } + + results, err := c.Exec(ddl, nil) + if err != nil { + t.Errorf("omni catalog parse error for DDL %q: %v", ddl, err) + return + } + for _, r := range results { + if r.Error != nil { + t.Errorf("omni catalog exec error on stmt %d: %v", r.Index, r.Error) + } + } +} + +// oracleScan runs a single-row information_schema (or other) query against +// the MySQL container and scans into dests. Uses t.Error on failure so the +// test can continue. +func oracleScan(t *testing.T, mc *mysqlContainer, query string, dests ...any) { + t.Helper() + row := mc.db.QueryRowContext(mc.ctx, query) + if err := row.Scan(dests...); err != nil { + t.Errorf("oracleScan failed: %q: %v", query, err) + } +} + +// oracleRows runs a multi-row query against the MySQL container and returns +// the rows as a [][]any, converting []byte values to string for readability. +// Uses t.Error on failure and returns nil. +func oracleRows(t *testing.T, mc *mysqlContainer, query string) [][]any { + t.Helper() + rows, err := mc.db.QueryContext(mc.ctx, query) + if err != nil { + t.Errorf("oracleRows query failed: %q: %v", query, err) + return nil + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + t.Errorf("oracleRows columns failed: %v", err) + return nil + } + + var out [][]any + for rows.Next() { + vals := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + t.Errorf("oracleRows scan failed: %v", err) + return out + } + for i, v := range vals { + if b, ok := v.([]byte); ok { + vals[i] = string(b) + } + } + out = append(out, vals) + } + if err := rows.Err(); err != nil { + t.Errorf("oracleRows iteration error: %v", err) + } + return out +} + +// oracleShow runs a SHOW CREATE TABLE / VIEW / ... statement against the +// container and returns the second column (the CREATE statement text). +// The first column (name) and any trailing columns are discarded. Uses +// t.Error on failure and returns the empty string. +// +// Note: different SHOW CREATE variants return different numbers of columns +// (SHOW CREATE TABLE returns 2, SHOW CREATE VIEW returns 4, SHOW CREATE +// FUNCTION/PROCEDURE/TRIGGER/EVENT return 6 or 7). This helper scans +// dynamically so it works with any of them. +func oracleShow(t *testing.T, mc *mysqlContainer, stmt string) string { + t.Helper() + rows, err := mc.db.QueryContext(mc.ctx, stmt) + if err != nil { + t.Errorf("oracleShow query failed: %q: %v", stmt, err) + return "" + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + t.Errorf("oracleShow columns failed: %v", err) + return "" + } + if !rows.Next() { + if err := rows.Err(); err != nil { + t.Errorf("oracleShow iteration error: %v", err) + } else { + t.Errorf("oracleShow returned no rows for %q", stmt) + } + return "" + } + vals := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + t.Errorf("oracleShow scan failed: %v", err) + return "" + } + if len(vals) < 2 { + t.Errorf("oracleShow %q: expected >=2 columns, got %d", stmt, len(cols)) + return "" + } + return asString(vals[1]) +} + +// assertStringEq reports a diff if got != want. +func assertStringEq(t *testing.T, label, got, want string) { + t.Helper() + if got != want { + t.Errorf("%s: got %q, want %q", label, got, want) + } +} + +// assertIntEq reports a diff if got != want. +func assertIntEq(t *testing.T, label string, got, want int) { + t.Helper() + if got != want { + t.Errorf("%s: got %d, want %d", label, got, want) + } +} + +// assertBoolEq reports a diff if got != want. +func assertBoolEq(t *testing.T, label string, got, want bool) { + t.Helper() + if got != want { + t.Errorf("%s: got %v, want %v", label, got, want) + } +} + +// scenariosSkipIfShort skips the calling test when testing.Short() is true. +func scenariosSkipIfShort(t *testing.T) { + t.Helper() + if testing.Short() { + t.Skip("skipping scenario test in short mode") + } +} + +// scenariosSkipIfNoDocker skips the calling test when SKIP_SCENARIO_TESTS=1 +// is set OR when the Docker daemon is not reachable. Probing the daemon +// avoids a panic from testcontainers in environments without Docker. +// (Codex phase review finding.) +func scenariosSkipIfNoDocker(t *testing.T) { + t.Helper() + if os.Getenv("SKIP_SCENARIO_TESTS") == "1" { + t.Skip("SKIP_SCENARIO_TESTS=1 set; skipping scenario test") + } + if !dockerAvailable() { + t.Skip("Docker daemon not reachable; skipping scenario test") + } +} + +var ( + dockerAvailableOnce sync.Once + dockerAvailableVal bool +) + +func dockerAvailable() bool { + dockerAvailableOnce.Do(func() { + // testcontainers.NewDockerProvider can panic via MustExtractDockerHost + // when DOCKER_HOST is unset and no socket is reachable, so wrap the + // probe in recover() to guarantee a clean skip instead of a panic. + defer func() { + if r := recover(); r != nil { + dockerAvailableVal = false + } + }() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + provider, err := testcontainers.NewDockerProvider() + if err != nil { + dockerAvailableVal = false + return + } + defer provider.Close() + if err := provider.Health(ctx); err != nil { + dockerAvailableVal = false + return + } + dockerAvailableVal = true + }) + return dockerAvailableVal +} + +// mysqlAtLeast reports whether the VERSION() string reports a server at +// or above the requested (major, minor, patch) triple. Non-numeric suffixes +// like "-log" or "-debug" are ignored. Malformed versions return false. +func mysqlAtLeast(ver string, wantMajor, wantMinor, wantPatch int) bool { + // Strip anything after the first non-numeric/non-dot rune. + cut := len(ver) + for i, r := range ver { + if (r < '0' || r > '9') && r != '.' { + cut = i + break + } + } + parts := strings.Split(ver[:cut], ".") + if len(parts) < 3 { + return false + } + atoi := func(s string) int { + n := 0 + for _, r := range s { + if r < '0' || r > '9' { + return -1 + } + n = n*10 + int(r-'0') + } + return n + } + maj, min, pat := atoi(parts[0]), atoi(parts[1]), atoi(parts[2]) + if maj < 0 || min < 0 || pat < 0 { + return false + } + if maj != wantMajor { + return maj > wantMajor + } + if min != wantMinor { + return min > wantMinor + } + return pat >= wantPatch +} + +// splitStmts splits a possibly multi-statement DDL string into individual +// statements, respecting single quotes, double quotes, and backticks, and +// trimming empty results. It is a thin wrapper around splitStatements +// (defined in container_test.go) with extra trimming so scenario workers +// can write `splitStmts(ddl)` without importing two names. +func splitStmts(ddl string) []string { + raw := splitStatements(ddl) + out := raw[:0] + for _, s := range raw { + if s = strings.TrimSpace(s); s != "" { + out = append(out, s) + } + } + return out +} diff --git a/tidb/catalog/scenarios_ps_test.go b/tidb/catalog/scenarios_ps_test.go new file mode 100644 index 00000000..31deb794 --- /dev/null +++ b/tidb/catalog/scenarios_ps_test.go @@ -0,0 +1,369 @@ +package catalog + +import ( + "slices" + "sort" + "strings" + "testing" +) + +// TestScenario_PS covers section PS of SCENARIOS-mysql-implicit-behavior.md +// — "Path-split behaviors (CREATE vs ALTER)". Each subtest runs the scenario's +// DDL against both a real MySQL 8.0 container and the omni catalog and asserts +// both match the expected value. Existing TestBugFix_* tests remain unchanged; +// these TestScenario_PS tests are the durable dual-assertion versions. +func TestScenario_PS(t *testing.T) { + scenariosSkipIfShort(t) + scenariosSkipIfNoDocker(t) + + mc, cleanup := scenarioContainer(t) + defer cleanup() + + // PS.1 CHECK constraint counter — CREATE path (fresh counter). + // User-named t_chk_5 is NOT seeded into the generated counter; the + // single unnamed CHECK receives t_chk_1. + t.Run("PS_1_Check_counter_CREATE_fresh", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE t ( + a INT, + CONSTRAINT t_chk_5 CHECK (a > 0), + b INT, + CHECK (b < 100) + )`) + + want := []string{"t_chk_1", "t_chk_5"} + + rows := oracleRows(t, mc, ` + SELECT CONSTRAINT_NAME FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' + ORDER BY CONSTRAINT_NAME`) + var oracleNames []string + for _, r := range rows { + oracleNames = append(oracleNames, asString(r[0])) + } + if !slices.Equal(oracleNames, want) { + t.Errorf("PS.1 oracle CHECK names: got %v, want %v", oracleNames, want) + } + + omniNames := psCheckNames(c, "t") + if !slices.Equal(omniNames, want) { + t.Errorf("PS.1 omni CHECK names: got %v, want %v", omniNames, want) + } + }) + + // PS.2 CHECK constraint counter — ALTER path (max+1). + // The user-named t_chk_20 IS seeded on the ALTER path; new unnamed + // CHECK gets t_chk_21. + t.Run("PS_2_Check_counter_ALTER_maxplus1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (a INT, b INT, CONSTRAINT t_chk_20 CHECK (a>0))`) + runOnBoth(t, mc, c, `ALTER TABLE t ADD CHECK (b>0)`) + + want := []string{"t_chk_20", "t_chk_21"} + + rows := oracleRows(t, mc, ` + SELECT CONSTRAINT_NAME FROM information_schema.CHECK_CONSTRAINTS + WHERE CONSTRAINT_SCHEMA='testdb' + ORDER BY CONSTRAINT_NAME`) + var oracleNames []string + for _, r := range rows { + oracleNames = append(oracleNames, asString(r[0])) + } + if !slices.Equal(oracleNames, want) { + t.Errorf("PS.2 oracle CHECK names: got %v, want %v", oracleNames, want) + } + + omniNames := psCheckNames(c, "t") + if !slices.Equal(omniNames, want) { + t.Errorf("PS.2 omni CHECK names: got %v, want %v", omniNames, want) + } + }) + + // PS.3 FK counter — CREATE path (fresh, user-named NOT seeded). + t.Run("PS_3_FK_counter_CREATE_fresh", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE parent (id INT PRIMARY KEY)`) + runOnBoth(t, mc, c, `CREATE TABLE child ( + a INT, + CONSTRAINT child_ibfk_5 FOREIGN KEY (a) REFERENCES parent(id), + b INT, + FOREIGN KEY (b) REFERENCES parent(id) + )`) + + want := []string{"child_ibfk_1", "child_ibfk_5"} + + rows := oracleRows(t, mc, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='child' + AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME`) + var oracleNames []string + for _, r := range rows { + oracleNames = append(oracleNames, asString(r[0])) + } + if !slices.Equal(oracleNames, want) { + t.Errorf("PS.3 oracle FK names: got %v, want %v", oracleNames, want) + } + + omniNames := psFKNames(c, "child") + if !slices.Equal(omniNames, want) { + t.Errorf("PS.3 omni FK names: got %v, want %v", omniNames, want) + } + }) + + // PS.4 FK counter — ALTER path (max+1 over existing generated numbers). + t.Run("PS_4_FK_counter_ALTER_maxplus1", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, `CREATE TABLE parent (id INT PRIMARY KEY)`) + runOnBoth(t, mc, c, `CREATE TABLE child ( + a INT, + b INT, + CONSTRAINT child_ibfk_20 FOREIGN KEY (a) REFERENCES parent(id) + )`) + runOnBoth(t, mc, c, + `ALTER TABLE child ADD FOREIGN KEY (b) REFERENCES parent(id)`) + + want := []string{"child_ibfk_20", "child_ibfk_21"} + + rows := oracleRows(t, mc, ` + SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='child' + AND CONSTRAINT_TYPE='FOREIGN KEY' + ORDER BY CONSTRAINT_NAME`) + var oracleNames []string + for _, r := range rows { + oracleNames = append(oracleNames, asString(r[0])) + } + if !slices.Equal(oracleNames, want) { + t.Errorf("PS.4 oracle FK names: got %v, want %v", oracleNames, want) + } + + omniNames := psFKNames(c, "child") + if !slices.Equal(omniNames, want) { + t.Errorf("PS.4 omni FK names: got %v, want %v", omniNames, want) + } + }) + + // PS.5 DEFAULT NOW() / fsp precision mismatch must error. + // MySQL rejects DATETIME(6) DEFAULT NOW() with ER_INVALID_DEFAULT (1067). + // omni currently accepts — this is a HIGH severity strictness gap. + t.Run("PS_5_Datetime_fsp_mismatch_errors", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + ddl := `CREATE TABLE t (a DATETIME(6) DEFAULT NOW())` + + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + if oracleErr == nil { + t.Errorf("PS.5 oracle: expected MySQL to reject DATETIME(6) DEFAULT NOW() with ER_INVALID_DEFAULT, got success") + } else if !strings.Contains(oracleErr.Error(), "1067") && + !strings.Contains(strings.ToLower(oracleErr.Error()), "invalid default") { + t.Errorf("PS.5 oracle: unexpected error text: %v", oracleErr) + } + + var omniErr error + results, parseErr := c.Exec(ddl, nil) + if parseErr != nil { + omniErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("PS.5 omni: KNOWN BUG — expected ER_INVALID_DEFAULT-style error, got success (see scenarios_bug_queue/ps.md)") + } + }) + + // PS.6 HASH partition ADD — seeded from count. + // omni has no ALTER TABLE ... ADD PARTITION support; this scenario is + // expected to fail on omni. Oracle side verifies the MySQL behavior. + t.Run("PS_6_Hash_partition_ADD_seeded", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + runOnBoth(t, mc, c, + `CREATE TABLE t (id INT) PARTITION BY HASH(id) PARTITIONS 3`) + // Only run the ALTER on the oracle — some omni builds have no support. + // We still try it on omni and report mismatch. + alter := `ALTER TABLE t ADD PARTITION PARTITIONS 2` + if _, err := mc.db.ExecContext(mc.ctx, alter); err != nil { + t.Errorf("PS.6 oracle ALTER failed: %v", err) + } + + // Oracle partition names. + rows := oracleRows(t, mc, ` + SELECT PARTITION_NAME FROM information_schema.PARTITIONS + WHERE TABLE_SCHEMA='testdb' AND TABLE_NAME='t' + ORDER BY PARTITION_ORDINAL_POSITION`) + var oracleNames []string + for _, r := range rows { + oracleNames = append(oracleNames, asString(r[0])) + } + wantOracle := []string{"p0", "p1", "p2", "p3", "p4"} + if !slices.Equal(oracleNames, wantOracle) { + t.Errorf("PS.6 oracle partition names: got %v, want %v", oracleNames, wantOracle) + } + + // omni side: attempt the ALTER; record whether it produced the + // expected 5-partition layout. + results, parseErr := c.Exec(alter, nil) + var omniAlterErr error + if parseErr != nil { + omniAlterErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniAlterErr = r.Error + break + } + } + } + + var omniNames []string + if tbl := c.GetDatabase("testdb").GetTable("t"); tbl != nil && tbl.Partitioning != nil { + for _, p := range tbl.Partitioning.Partitions { + omniNames = append(omniNames, p.Name) + } + } + if omniAlterErr != nil || !slices.Equal(omniNames, wantOracle) { + t.Errorf("PS.6 omni: KNOWN GAP — expected partition names %v; got %v (alter err: %v)", wantOracle, omniNames, omniAlterErr) + } + }) + + // PS.7 FK name collision — user-named t_ibfk_1 collides with the + // generator's implicit first unnamed FK name. MySQL errors with + // ER_FK_DUP_NAME (1826). omni silently succeeds → HIGH severity bug. + t.Run("PS_7_FK_name_collision_errors", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // Parent table first (both sides). + runOnBoth(t, mc, c, `CREATE TABLE p (id INT PRIMARY KEY)`) + + ddl := `CREATE TABLE c ( + a INT, + CONSTRAINT c_ibfk_1 FOREIGN KEY (a) REFERENCES p(id), + b INT, + FOREIGN KEY (b) REFERENCES p(id) + )` + + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl) + if oracleErr == nil { + t.Errorf("PS.7 oracle: expected ER_FK_DUP_NAME (1826), got success") + } else if !strings.Contains(oracleErr.Error(), "1826") && + !strings.Contains(strings.ToLower(oracleErr.Error()), "duplicate") { + t.Errorf("PS.7 oracle: unexpected error text: %v", oracleErr) + } + + var omniErr error + results, parseErr := c.Exec(ddl, nil) + if parseErr != nil { + omniErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("PS.7 omni: KNOWN BUG — expected ER_FK_DUP_NAME, got success (see scenarios_bug_queue/ps.md)") + } + }) + + // PS.8 CHECK constraint duplicate name in schema — must error. + // CHECK constraint names are schema-scoped in MySQL; the second + // CREATE TABLE with the same check name fails with + // ER_CHECK_CONSTRAINT_DUP_NAME (3822). omni does not enforce schema + // scoping → MED severity bug. + t.Run("PS_8_Check_dup_name_schema_scope", func(t *testing.T) { + scenarioReset(t, mc) + c := scenarioNewCatalog(t) + + // First create must succeed on both. + runOnBoth(t, mc, c, + `CREATE TABLE t1 (a INT, CONSTRAINT my_rule CHECK (a > 0))`) + + ddl2 := `CREATE TABLE t2 (b INT, CONSTRAINT my_rule CHECK (b > 0))` + + _, oracleErr := mc.db.ExecContext(mc.ctx, ddl2) + if oracleErr == nil { + t.Errorf("PS.8 oracle: expected ER_CHECK_CONSTRAINT_DUP_NAME, got success") + } else if !strings.Contains(oracleErr.Error(), "3822") && + !strings.Contains(strings.ToLower(oracleErr.Error()), "duplicate") { + t.Errorf("PS.8 oracle: unexpected error text: %v", oracleErr) + } + + var omniErr error + results, parseErr := c.Exec(ddl2, nil) + if parseErr != nil { + omniErr = parseErr + } else { + for _, r := range results { + if r.Error != nil { + omniErr = r.Error + break + } + } + } + if omniErr == nil { + t.Errorf("PS.8 omni: KNOWN BUG — expected ER_CHECK_CONSTRAINT_DUP_NAME, got success (see scenarios_bug_queue/ps.md)") + } + }) +} + +// omniCheckNames returns the CHECK constraint names from the omni catalog +// for the given table (in testdb), sorted. +func psCheckNames(c *Catalog, table string) []string { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + tbl := db.GetTable(table) + if tbl == nil { + return nil + } + var names []string + for _, con := range tbl.Constraints { + if con.Type == ConCheck { + names = append(names, con.Name) + } + } + sort.Strings(names) + return names +} + +// omniFKNames returns the FOREIGN KEY constraint names from the omni catalog +// for the given table (in testdb), sorted. +func psFKNames(c *Catalog, table string) []string { + db := c.GetDatabase("testdb") + if db == nil { + return nil + } + tbl := db.GetTable(table) + if tbl == nil { + return nil + } + var names []string + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + names = append(names, con.Name) + } + } + sort.Strings(names) + return names +} diff --git a/tidb/catalog/scope.go b/tidb/catalog/scope.go new file mode 100644 index 00000000..248efb48 --- /dev/null +++ b/tidb/catalog/scope.go @@ -0,0 +1,155 @@ +package catalog + +import ( + "fmt" + "strings" + + "github.com/bytebase/omni/tidb/scope" +) + +// analyzerScope wraps scope.Scope and adds analyzer-specific features: +// parent chain for correlated subqueries, and rteIdx mapping. +type analyzerScope struct { + base *scope.Scope + rteMap []int // parallel to base entries: entry index -> RTE index in Query.RangeTable + cols [][]*Column // parallel to base entries: entry index -> catalog Column pointers + parent *analyzerScope +} + +func newScope() *analyzerScope { + return &analyzerScope{ + base: scope.New(), + } +} + +// newScopeWithParent creates a new scope with a parent scope for correlated subquery resolution. +func newScopeWithParent(parent *analyzerScope) *analyzerScope { + return &analyzerScope{ + base: scope.New(), + parent: parent, + } +} + +// markCoalesced marks a column from a table as coalesced (hidden during star expansion). +func (s *analyzerScope) markCoalesced(tableName, colName string) { + s.base.MarkCoalesced(tableName, colName) +} + +// isCoalesced returns true if the given table.column is coalesced away by USING/NATURAL. +func (s *analyzerScope) isCoalesced(tableName, colName string) bool { + return s.base.IsCoalesced(tableName, colName) +} + +// add registers a table reference in the scope. +func (s *analyzerScope) add(name string, rteIdx int, columns []*Column) { + scopeCols := make([]scope.Column, len(columns)) + for i, c := range columns { + scopeCols[i] = scope.Column{Name: c.Name, Position: i + 1} + } + s.base.Add(name, &scope.Table{Name: name, Columns: scopeCols}) + s.rteMap = append(s.rteMap, rteIdx) + s.cols = append(s.cols, columns) +} + +// resolveColumn finds an unqualified column name across all scope entries. +// Returns the RTE index and 1-based attribute number. +// Error 1052 for ambiguous, 1054 for unknown. +func (s *analyzerScope) resolveColumn(colName string) (int, int, error) { + entryIdx, pos, err := s.base.ResolveColumn(colName) + if err != nil { + if strings.Contains(err.Error(), "ambiguous") { + return 0, 0, &Error{ + Code: 1052, + SQLState: "23000", + Message: fmt.Sprintf("Column '%s' in field list is ambiguous", colName), + } + } + return 0, 0, errNoSuchColumn(colName, "field list") + } + return s.rteMap[entryIdx], pos, nil +} + +// resolveQualifiedColumn finds a column qualified by table name or alias. +// Returns the RTE index and 1-based attribute number. +func (s *analyzerScope) resolveQualifiedColumn(tableName, colName string) (int, int, error) { + entryIdx, pos, err := s.base.ResolveQualifiedColumn(tableName, colName) + if err != nil { + if strings.Contains(err.Error(), "unknown table") { + return 0, 0, &Error{ + Code: ErrUnknownTable, + SQLState: sqlState(ErrUnknownTable), + Message: fmt.Sprintf("Unknown table '%s'", tableName), + } + } + return 0, 0, errNoSuchColumn(colName, "field list") + } + return s.rteMap[entryIdx], pos, nil +} + +// getColumns returns the columns for a named table reference, or nil if not found. +func (s *analyzerScope) getColumns(tableName string) []*Column { + entries := s.base.AllEntries() + lower := strings.ToLower(tableName) + for i, e := range entries { + if strings.ToLower(e.Name) == lower { + return s.cols[i] + } + } + return nil +} + +// resolveColumnFull resolves an unqualified column, trying parent scopes +// if not found locally. Returns (rteIdx, attNum, levelsUp, error). +func (s *analyzerScope) resolveColumnFull(colName string) (int, int, int, error) { + rteIdx, attNum, err := s.resolveColumn(colName) + if err == nil { + return rteIdx, attNum, 0, nil + } + if s.parent != nil { + rteIdx, attNum, parentLevels, parentErr := s.parent.resolveColumnFull(colName) + if parentErr == nil { + return rteIdx, attNum, parentLevels + 1, nil + } + } + return 0, 0, 0, err +} + +// resolveQualifiedColumnFull resolves a qualified column, trying parent scopes +// if not found locally. Returns (rteIdx, attNum, levelsUp, error). +func (s *analyzerScope) resolveQualifiedColumnFull(tableName, colName string) (int, int, int, error) { + rteIdx, attNum, err := s.resolveQualifiedColumn(tableName, colName) + if err == nil { + return rteIdx, attNum, 0, nil + } + if s.parent != nil { + rteIdx, attNum, parentLevels, parentErr := s.parent.resolveQualifiedColumnFull(tableName, colName) + if parentErr == nil { + return rteIdx, attNum, parentLevels + 1, nil + } + } + return 0, 0, 0, err +} + +// allEntries returns all scope entries in registration order. +// This returns a slice of the internal scopeEntry type for backward compatibility +// with the analyzer code that needs rteIdx and catalog column pointers. +func (s *analyzerScope) allEntries() []scopeEntry { + entries := s.base.AllEntries() + result := make([]scopeEntry, len(entries)) + for i, e := range entries { + result[i] = scopeEntry{ + name: e.Name, + rteIdx: s.rteMap[i], + columns: s.cols[i], + } + } + return result +} + +// scopeEntry is one named table reference visible in the current scope. +// Kept for backward compatibility with analyzer code that accesses rteIdx and columns. +type scopeEntry struct { + name string // effective reference name (alias or table name) + rteIdx int // index into Query.RangeTable + columns []*Column // columns available from this entry +} diff --git a/tidb/catalog/show.go b/tidb/catalog/show.go new file mode 100644 index 00000000..fe3cfafd --- /dev/null +++ b/tidb/catalog/show.go @@ -0,0 +1,674 @@ +package catalog + +import ( + "fmt" + "sort" + "strings" +) + +// charsetForCollation derives the charset name from a collation name. +// MySQL collation names are prefixed with the charset name (e.g. latin1_swedish_ci → latin1). +func charsetForCollation(collation string) string { + collation = toLower(collation) + // Try known charsets from longest to shortest to handle prefixes like utf8mb4 vs utf8. + knownCharsets := []string{ + "utf8mb4", "utf8mb3", "utf16le", "gb18030", "geostd8", "armscii8", + "eucjpms", "cp1252", "gb2312", "euckr", "utf16", "utf32", + "latin1", "ascii", "binary", "cp932", "tis620", "hebrew", + "greek", "sjis", "big5", "ucs2", "utf8", "gbk", + } + for _, cs := range knownCharsets { + if strings.HasPrefix(collation, cs+"_") || collation == cs { + return cs + } + } + // Fallback: take prefix before first underscore. + if idx := strings.IndexByte(collation, '_'); idx > 0 { + return collation[:idx] + } + return "" +} + +// defaultCollationForCharset returns the default collation for common MySQL charsets. +var defaultCollationForCharset = map[string]string{ + "utf8mb4": "utf8mb4_0900_ai_ci", + "utf8mb3": "utf8mb3_general_ci", + "utf8": "utf8mb3_general_ci", + "latin1": "latin1_swedish_ci", + "ascii": "ascii_general_ci", + "binary": "binary", + "gbk": "gbk_chinese_ci", + "big5": "big5_chinese_ci", + "euckr": "euckr_korean_ci", + "gb2312": "gb2312_chinese_ci", + "sjis": "sjis_japanese_ci", + "cp1252": "cp1252_general_ci", + "ucs2": "ucs2_general_ci", + "utf16": "utf16_general_ci", + "utf16le": "utf16le_general_ci", + "utf32": "utf32_general_ci", + "cp932": "cp932_japanese_ci", + "eucjpms": "eucjpms_japanese_ci", + "gb18030": "gb18030_chinese_ci", + "geostd8": "geostd8_general_ci", + "tis620": "tis620_thai_ci", + "hebrew": "hebrew_general_ci", + "greek": "greek_general_ci", + "armscii8": "armscii8_general_ci", +} + +// ShowCreateTable produces MySQL 8.0-compatible SHOW CREATE TABLE output. +// Returns "" if the database or table does not exist. +func (c *Catalog) ShowCreateTable(database, table string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + tbl := db.GetTable(table) + if tbl == nil { + return "" + } + + var b strings.Builder + if tbl.Temporary { + b.WriteString(fmt.Sprintf("CREATE TEMPORARY TABLE `%s` (\n", tbl.Name)) + } else { + b.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", tbl.Name)) + } + + // Columns. + parts := make([]string, 0, len(tbl.Columns)+len(tbl.Indexes)+len(tbl.Constraints)) + for _, col := range tbl.Columns { + parts = append(parts, showColumnWithTable(col, tbl)) + } + + // Indexes — MySQL 8.0 orders them in groups: + // 1. PRIMARY KEY + // 2. UNIQUE KEYs (creation order) + // 3. Regular + SPATIAL KEYs, non-expression (creation order) + // 4. Expression-based KEYs (creation order) + // 5. FULLTEXT KEYs (creation order) + var idxPrimary, idxUnique, idxRegular, idxExpr, idxFulltext []*Index + for _, idx := range tbl.Indexes { + switch { + case idx.Primary: + idxPrimary = append(idxPrimary, idx) + case idx.Unique: + idxUnique = append(idxUnique, idx) + case idx.Fulltext: + idxFulltext = append(idxFulltext, idx) + case isExpressionIndex(idx): + idxExpr = append(idxExpr, idx) + default: + idxRegular = append(idxRegular, idx) + } + } + for _, group := range [][]*Index{idxPrimary, idxUnique, idxRegular, idxExpr, idxFulltext} { + for _, idx := range group { + parts = append(parts, showIndex(idx)) + } + } + + // Constraints (FK and CHECK only — PK/UNIQUE are shown via indexes). + // MySQL 8.0 sorts FKs alphabetically by name, then CHECKs alphabetically by name. + var fkConstraints, chkConstraints []*Constraint + for _, con := range tbl.Constraints { + switch con.Type { + case ConForeignKey: + fkConstraints = append(fkConstraints, con) + case ConCheck: + chkConstraints = append(chkConstraints, con) + } + } + sort.Slice(fkConstraints, func(i, j int) bool { + return fkConstraints[i].Name < fkConstraints[j].Name + }) + sort.Slice(chkConstraints, func(i, j int) bool { + return chkConstraints[i].Name < chkConstraints[j].Name + }) + for _, con := range fkConstraints { + parts = append(parts, showConstraint(con)) + } + for _, con := range chkConstraints { + parts = append(parts, showConstraint(con)) + } + + b.WriteString(" ") + b.WriteString(strings.Join(parts, ",\n ")) + b.WriteString("\n)") + + // Table options. + opts := showTableOptions(tbl) + if opts != "" { + b.WriteString(" ") + b.WriteString(opts) + } + + // Partition clause. + if tbl.Partitioning != nil { + b.WriteString("\n") + b.WriteString(showPartitioning(tbl.Partitioning)) + } + + return b.String() +} + +func showColumn(col *Column) string { + return showColumnWithTable(col, nil) +} + +func showColumnWithTable(col *Column, tbl *Table) string { + var b strings.Builder + b.WriteString(fmt.Sprintf("`%s` %s", col.Name, col.ColumnType)) + + // CHARACTER SET and COLLATE — MySQL 8.0 display rules: + // - CHARACTER SET shown when column charset differs from table charset + // - COLLATE shown when column collation differs from the charset's default collation + if isStringType(col.DataType) || isEnumSetType(col.DataType) { + tableCharset := "" + if tbl != nil { + tableCharset = tbl.Charset + } + // Resolve the default collation for the column's charset. + colCharsetDefault := "" + if col.Charset != "" { + if dc, ok := defaultCollationForCharset[toLower(col.Charset)]; ok { + colCharsetDefault = dc + } + } + charsetDiffers := col.Charset != "" && !eqFoldStr(col.Charset, tableCharset) + // Show COLLATE when the column's collation differs from its charset's default. + collationNonDefault := col.Collation != "" && !eqFoldStr(col.Collation, colCharsetDefault) + + // Determine the table's effective collation for comparison. + tableCollation := "" + if tbl != nil { + tableCollation = tbl.Collation + } + // Column collation differs from table collation (= explicitly set on column). + collationDiffersFromTable := col.Collation != "" && !eqFoldStr(col.Collation, tableCollation) + + if charsetDiffers { + // When charset differs from table, show CHARACTER SET and always COLLATE. + b.WriteString(fmt.Sprintf(" CHARACTER SET %s", col.Charset)) + collation := col.Collation + if collation == "" { + collation = colCharsetDefault + } + if collation != "" { + b.WriteString(fmt.Sprintf(" COLLATE %s", collation)) + } + } else if collationNonDefault && collationDiffersFromTable { + // Collation explicitly set on column (differs from both charset default and table). + // MySQL shows both CHARACTER SET and COLLATE. + if col.Charset != "" { + b.WriteString(fmt.Sprintf(" CHARACTER SET %s", col.Charset)) + } + b.WriteString(fmt.Sprintf(" COLLATE %s", col.Collation)) + } else if collationNonDefault { + // Collation inherited from table but non-default for charset. + // MySQL shows only COLLATE (no CHARACTER SET). + b.WriteString(fmt.Sprintf(" COLLATE %s", col.Collation)) + } + } + + // Generated column. + if col.Generated != nil { + mode := "VIRTUAL" + if col.Generated.Stored { + mode = "STORED" + } + b.WriteString(fmt.Sprintf(" GENERATED ALWAYS AS (%s) %s", col.Generated.Expr, mode)) + if !col.Nullable { + b.WriteString(" NOT NULL") + } + if col.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeComment(col.Comment))) + } + if col.Invisible { + b.WriteString(" /*!80023 INVISIBLE */") + } + return b.String() + } + + // NOT NULL / NULL. + if !col.Nullable { + b.WriteString(" NOT NULL") + } else if isTimestampType(col.DataType) { + // MySQL 8.0 explicitly shows NULL for TIMESTAMP columns. + b.WriteString(" NULL") + } + + // SRID for spatial types — MySQL 8.0 places it after NOT NULL, before DEFAULT. + if col.SRID != 0 { + b.WriteString(fmt.Sprintf(" /*!80003 SRID %d */", col.SRID)) + } + + // DEFAULT. + if col.Default != nil { + b.WriteString(" DEFAULT ") + b.WriteString(formatDefault(*col.Default, col)) + } else if col.Nullable && !col.AutoIncrement && !isTextBlobType(col.DataType) && !col.DefaultDropped { + b.WriteString(" DEFAULT NULL") + } + + // AUTO_INCREMENT. + if col.AutoIncrement { + b.WriteString(" AUTO_INCREMENT") + } + + // ON UPDATE. + if col.OnUpdate != "" { + b.WriteString(fmt.Sprintf(" ON UPDATE %s", formatOnUpdate(col.OnUpdate))) + } + + // COMMENT. + if col.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeComment(col.Comment))) + } + + // INVISIBLE. + if col.Invisible { + b.WriteString(" /*!80023 INVISIBLE */") + } + + return b.String() +} + +// formatDefault formats a default value for SHOW CREATE TABLE output. +// MySQL 8.0 quotes numeric defaults as strings (e.g. DEFAULT '0'). +func formatDefault(val string, col *Column) string { + if strings.EqualFold(val, "NULL") { + return "NULL" + } + // Normalize CURRENT_TIMESTAMP() → CURRENT_TIMESTAMP (MySQL 8.0 format). + upper := strings.ToUpper(val) + if upper == "CURRENT_TIMESTAMP" || upper == "CURRENT_TIMESTAMP()" { + return "CURRENT_TIMESTAMP" + } + if strings.HasPrefix(upper, "CURRENT_TIMESTAMP(") { + // CURRENT_TIMESTAMP(N) — keep precision, use uppercase. + return upper + } + if upper == "NOW()" { + return "CURRENT_TIMESTAMP" + } + // b'...' and 0x... bit/hex literals — not quoted. + if strings.HasPrefix(val, "b'") || strings.HasPrefix(val, "B'") || + strings.HasPrefix(val, "0x") || strings.HasPrefix(val, "0X") { + return val + } + // Expression defaults: (expr) — not quoted, shown as-is. + if len(val) >= 2 && val[0] == '(' && val[len(val)-1] == ')' { + return val + } + // Already single-quoted string — return as-is. + if len(val) >= 2 && val[0] == '\'' && val[len(val)-1] == '\'' { + return val + } + // MySQL 8.0 quotes all literal defaults (including numerics). + return "'" + val + "'" +} + +// formatOnUpdate normalizes ON UPDATE values to MySQL 8.0 format. +func formatOnUpdate(val string) string { + upper := strings.ToUpper(val) + if upper == "CURRENT_TIMESTAMP" || upper == "CURRENT_TIMESTAMP()" { + return "CURRENT_TIMESTAMP" + } + if strings.HasPrefix(upper, "CURRENT_TIMESTAMP(") { + return upper + } + if upper == "NOW()" { + return "CURRENT_TIMESTAMP" + } + return val +} + +// isTimestampType returns true for TIMESTAMP/DATETIME types. +func isTimestampType(dt string) bool { + switch strings.ToLower(dt) { + case "timestamp": + return true + } + return false +} + +// isTextBlobType returns true for types where MySQL doesn't show DEFAULT NULL. +func isTextBlobType(dt string) bool { + switch strings.ToLower(dt) { + case "text", "tinytext", "mediumtext", "longtext", + "blob", "tinyblob", "mediumblob", "longblob": + return true + } + return false +} + +func showIndex(idx *Index) string { + var b strings.Builder + + if idx.Primary { + b.WriteString("PRIMARY KEY (") + } else if idx.Unique { + b.WriteString(fmt.Sprintf("UNIQUE KEY `%s` (", idx.Name)) + } else if idx.Fulltext { + b.WriteString(fmt.Sprintf("FULLTEXT KEY `%s` (", idx.Name)) + } else if idx.Spatial { + b.WriteString(fmt.Sprintf("SPATIAL KEY `%s` (", idx.Name)) + } else { + b.WriteString(fmt.Sprintf("KEY `%s` (", idx.Name)) + } + + cols := make([]string, 0, len(idx.Columns)) + for _, ic := range idx.Columns { + cols = append(cols, showIndexColumn(ic)) + } + b.WriteString(strings.Join(cols, ",")) + b.WriteString(")") + + // USING clause: shown when explicitly specified, not for PRIMARY/FULLTEXT/SPATIAL. + if !idx.Primary && !idx.Fulltext && !idx.Spatial && idx.IndexType != "" { + b.WriteString(fmt.Sprintf(" USING %s", strings.ToUpper(idx.IndexType))) + } + + // KEY_BLOCK_SIZE is intentionally NOT rendered here: MySQL 8.0 parses + // the index-level KEY_BLOCK_SIZE option but does not include it in + // SHOW CREATE TABLE output (only the table-level KEY_BLOCK_SIZE is + // shown). The catalog still preserves the value for programmatic + // access. + + // Comment. + if idx.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeComment(idx.Comment))) + } + + // Invisible. + if !idx.Visible { + b.WriteString(" /*!80000 INVISIBLE */") + } + + return b.String() +} + +func showIndexColumn(ic *IndexColumn) string { + var b strings.Builder + if ic.Expr != "" { + b.WriteString(fmt.Sprintf("(%s)", ic.Expr)) + } else { + b.WriteString(fmt.Sprintf("`%s`", ic.Name)) + if ic.Length > 0 { + b.WriteString(fmt.Sprintf("(%d)", ic.Length)) + } + } + if ic.Descending { + b.WriteString(" DESC") + } + return b.String() +} + +func showConstraint(con *Constraint) string { + var b strings.Builder + + switch con.Type { + case ConForeignKey: + b.WriteString(fmt.Sprintf("CONSTRAINT `%s` FOREIGN KEY (", con.Name)) + cols := make([]string, 0, len(con.Columns)) + for _, c := range con.Columns { + cols = append(cols, fmt.Sprintf("`%s`", c)) + } + b.WriteString(strings.Join(cols, ", ")) + if con.RefDatabase != "" { + b.WriteString(fmt.Sprintf(") REFERENCES `%s`.`%s` (", con.RefDatabase, con.RefTable)) + } else { + b.WriteString(fmt.Sprintf(") REFERENCES `%s` (", con.RefTable)) + } + refCols := make([]string, 0, len(con.RefColumns)) + for _, c := range con.RefColumns { + refCols = append(refCols, fmt.Sprintf("`%s`", c)) + } + b.WriteString(strings.Join(refCols, ", ")) + b.WriteString(")") + + // ON DELETE — omit if RESTRICT or NO ACTION (MySQL defaults). + if con.OnDelete != "" && !isFKDefault(con.OnDelete) { + b.WriteString(fmt.Sprintf(" ON DELETE %s", strings.ToUpper(con.OnDelete))) + } + // ON UPDATE — omit if RESTRICT or NO ACTION (MySQL defaults). + if con.OnUpdate != "" && !isFKDefault(con.OnUpdate) { + b.WriteString(fmt.Sprintf(" ON UPDATE %s", strings.ToUpper(con.OnUpdate))) + } + + case ConCheck: + b.WriteString(fmt.Sprintf("CONSTRAINT `%s` CHECK (%s)", con.Name, con.CheckExpr)) + if con.NotEnforced { + b.WriteString(" /*!80016 NOT ENFORCED */") + } + } + + return b.String() +} + +func showTableOptions(tbl *Table) string { + var parts []string + + if tbl.Engine != "" { + parts = append(parts, fmt.Sprintf("ENGINE=%s", tbl.Engine)) + } + + // AUTO_INCREMENT — shown only when > 1. + if tbl.AutoIncrement > 1 { + parts = append(parts, fmt.Sprintf("AUTO_INCREMENT=%d", tbl.AutoIncrement)) + } + + if tbl.Charset != "" { + parts = append(parts, fmt.Sprintf("DEFAULT CHARSET=%s", tbl.Charset)) + } + + // MySQL 8.0 shows COLLATE when: + // - The collation differs from the charset's default, OR + // - The collation was explicitly specified, OR + // - The charset is utf8mb4 (MySQL 8.0 always shows collation for utf8mb4) + if tbl.Charset != "" { + effectiveCollation := tbl.Collation + if effectiveCollation == "" { + effectiveCollation = defaultCollationForCharset[toLower(tbl.Charset)] + } + defColl := defaultCollationForCharset[toLower(tbl.Charset)] + isNonDefaultCollation := effectiveCollation != "" && !eqFoldStr(effectiveCollation, defColl) + isUtf8mb4 := eqFoldStr(tbl.Charset, "utf8mb4") + if isNonDefaultCollation || isUtf8mb4 { + if effectiveCollation != "" { + parts = append(parts, fmt.Sprintf("COLLATE=%s", effectiveCollation)) + } + } + } + + // KEY_BLOCK_SIZE — shown when non-zero. + if tbl.KeyBlockSize > 0 { + parts = append(parts, fmt.Sprintf("KEY_BLOCK_SIZE=%d", tbl.KeyBlockSize)) + } + + // ROW_FORMAT — shown when explicitly set. + if tbl.RowFormat != "" { + parts = append(parts, fmt.Sprintf("ROW_FORMAT=%s", strings.ToUpper(tbl.RowFormat))) + } + + if tbl.Comment != "" { + parts = append(parts, fmt.Sprintf("COMMENT='%s'", escapeComment(tbl.Comment))) + } + + return strings.Join(parts, " ") +} + +// isFKDefault returns true if the action is a MySQL FK default that should not be shown. +// MySQL 8.0 hides NO ACTION (the implicit default) but shows RESTRICT when explicitly specified. +func isFKDefault(action string) bool { + upper := strings.ToUpper(action) + return upper == "NO ACTION" +} + +func isEnumSetType(dt string) bool { + switch strings.ToLower(dt) { + case "enum", "set": + return true + } + return false +} + +// isExpressionIndex returns true if any column in the index is an expression. +func isExpressionIndex(idx *Index) bool { + for _, ic := range idx.Columns { + if ic.Expr != "" { + return true + } + } + return false +} + +func eqFoldStr(a, b string) bool { + return strings.EqualFold(a, b) +} + +func escapeComment(s string) string { + s = strings.ReplaceAll(s, "\\", "\\\\") + s = strings.ReplaceAll(s, "'", "''") + return s +} + +// showPartitioning renders the partition clause for SHOW CREATE TABLE. +// MySQL 8.0 outputs partitioning after table options with specific formatting. +func showPartitioning(pi *PartitionInfo) string { + var b strings.Builder + + // MySQL uses /*!50500 for RANGE COLUMNS and LIST COLUMNS, /*!50100 for others. + versionComment := "50100" + if pi.Type == "RANGE COLUMNS" || pi.Type == "LIST COLUMNS" { + versionComment = "50500" + } + b.WriteString(fmt.Sprintf("/*!%s PARTITION BY ", versionComment)) + if pi.Linear { + b.WriteString("LINEAR ") + } + switch pi.Type { + case "RANGE": + b.WriteString(fmt.Sprintf("RANGE (%s)", pi.Expr)) + case "RANGE COLUMNS": + // MySQL uses double space before COLUMNS and no backticks on column names. + b.WriteString(fmt.Sprintf("RANGE COLUMNS(%s)", formatPartitionColumnsPlain(pi.Columns))) + case "LIST": + b.WriteString(fmt.Sprintf("LIST (%s)", pi.Expr)) + case "LIST COLUMNS": + b.WriteString(fmt.Sprintf("LIST COLUMNS(%s)", formatPartitionColumnsPlain(pi.Columns))) + case "HASH": + b.WriteString(fmt.Sprintf("HASH (%s)", pi.Expr)) + case "KEY": + // MySQL does not backtick-quote KEY column names. + if pi.Algorithm > 0 { + b.WriteString(fmt.Sprintf("KEY ALGORITHM = %d (%s)", pi.Algorithm, formatPartitionColumnsPlain(pi.Columns))) + } else { + b.WriteString(fmt.Sprintf("KEY (%s)", formatPartitionColumnsPlain(pi.Columns))) + } + } + + // Subpartition clause. + if pi.SubType != "" { + b.WriteString("\n") + b.WriteString("SUBPARTITION BY ") + if pi.SubLinear { + b.WriteString("LINEAR ") + } + switch pi.SubType { + case "HASH": + b.WriteString(fmt.Sprintf("HASH (%s)", pi.SubExpr)) + case "KEY": + if pi.SubAlgo > 0 { + b.WriteString(fmt.Sprintf("KEY ALGORITHM = %d (%s)", pi.SubAlgo, formatPartitionColumns(pi.SubColumns))) + } else { + b.WriteString(fmt.Sprintf("KEY (%s)", formatPartitionColumns(pi.SubColumns))) + } + } + if pi.NumSubParts > 0 { + b.WriteString(fmt.Sprintf("\nSUBPARTITIONS %d", pi.NumSubParts)) + } + } + + // Partition definitions. + // For HASH/KEY partitions with NumParts > 0, auto-generated partition defs + // are rendered as "PARTITIONS N" (matching MySQL 8.0's SHOW CREATE TABLE). + hashKeyAutoGen := pi.NumParts > 0 && (pi.Type == "HASH" || pi.Type == "KEY") + if hashKeyAutoGen { + b.WriteString(fmt.Sprintf("\nPARTITIONS %d", pi.NumParts)) + } else if len(pi.Partitions) > 0 { + b.WriteString("\n(") + for i, pd := range pi.Partitions { + if i > 0 { + b.WriteString(",\n ") + } + b.WriteString("PARTITION ") + b.WriteString(pd.Name) + if pd.ValueExpr != "" { + switch { + case strings.HasPrefix(pi.Type, "RANGE"): + if pd.ValueExpr == "MAXVALUE" { + if pi.Type == "RANGE COLUMNS" { + // RANGE COLUMNS uses parenthesized MAXVALUE + b.WriteString(" VALUES LESS THAN (MAXVALUE)") + } else { + b.WriteString(" VALUES LESS THAN MAXVALUE") + } + } else { + b.WriteString(fmt.Sprintf(" VALUES LESS THAN (%s)", pd.ValueExpr)) + } + case strings.HasPrefix(pi.Type, "LIST"): + b.WriteString(fmt.Sprintf(" VALUES IN (%s)", pd.ValueExpr)) + } + } + b.WriteString(fmt.Sprintf(" ENGINE = %s", partitionEngine(pd.Engine))) + if pd.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT = '%s'", escapeComment(pd.Comment))) + } + // Subpartition definitions — skip auto-generated ones (NumSubParts > 0). + if len(pd.SubPartitions) > 0 && pi.NumSubParts == 0 { + b.WriteString("\n (") + for j, spd := range pd.SubPartitions { + if j > 0 { + b.WriteString(",\n ") + } + b.WriteString("SUBPARTITION ") + b.WriteString(spd.Name) + b.WriteString(fmt.Sprintf(" ENGINE = %s", partitionEngine(spd.Engine))) + if spd.Comment != "" { + b.WriteString(fmt.Sprintf(" COMMENT = '%s'", escapeComment(spd.Comment))) + } + } + b.WriteString(")") + } + } + b.WriteString(")") + } + + b.WriteString(" */") + return b.String() +} + +// formatPartitionColumns formats column names for partition clauses with backticks. +func formatPartitionColumns(cols []string) string { + parts := make([]string, len(cols)) + for i, c := range cols { + parts[i] = "`" + c + "`" + } + return strings.Join(parts, ",") +} + +// formatPartitionColumnsPlain formats column names without backticks (MySQL style for RANGE COLUMNS, LIST COLUMNS, KEY). +func formatPartitionColumnsPlain(cols []string) string { + return strings.Join(cols, ",") +} + +// partitionEngine returns the engine for a partition, defaulting to InnoDB. +func partitionEngine(engine string) string { + if engine == "" { + return "InnoDB" + } + return engine +} diff --git a/tidb/catalog/show_test.go b/tidb/catalog/show_test.go new file mode 100644 index 00000000..0d712a3a --- /dev/null +++ b/tidb/catalog/show_test.go @@ -0,0 +1,195 @@ +package catalog + +import ( + "strings" + "testing" +) + +func TestShowCreateTableBasic(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) DEFAULT 'test', + PRIMARY KEY (id), + KEY idx_name (name) + )`) + + got := c.ShowCreateTable("testdb", "t1") + assertContains(t, got, "CREATE TABLE `t1`") + assertContains(t, got, "`id` int NOT NULL AUTO_INCREMENT") + assertContains(t, got, "`name` varchar(100) DEFAULT 'test'") + assertContains(t, got, "PRIMARY KEY (`id`)") + assertContains(t, got, "KEY `idx_name` (`name`)") + assertContains(t, got, "ENGINE=InnoDB") + assertContains(t, got, "DEFAULT CHARSET=utf8mb4") + // MySQL 8.0 always shows COLLATE in SHOW CREATE TABLE. + assertContains(t, got, "COLLATE=utf8mb4_0900_ai_ci") +} + +func TestShowCreateTableUniqueKey(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t2 ( + id INT NOT NULL AUTO_INCREMENT, + email VARCHAR(255) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY uk_email (email) + )`) + + got := c.ShowCreateTable("testdb", "t2") + assertContains(t, got, "UNIQUE KEY `uk_email` (`email`)") +} + +func TestShowCreateTableDefaults(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t3 ( + id INT NOT NULL AUTO_INCREMENT, + nullable_col VARCHAR(50), + str_default VARCHAR(50) DEFAULT 'hello', + num_default INT DEFAULT 42, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t3") + // Nullable column without explicit default should show DEFAULT NULL. + assertContains(t, got, "`nullable_col` varchar(50) DEFAULT NULL") + // String default. + assertContains(t, got, "`str_default` varchar(50) DEFAULT 'hello'") + // Numeric default — MySQL 8.0 quotes it. + assertContains(t, got, "`num_default` int DEFAULT '42'") +} + +func TestShowCreateTableMultipleIndexes(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t4 ( + id INT NOT NULL AUTO_INCREMENT, + a INT, + b VARCHAR(100), + c INT, + PRIMARY KEY (id), + KEY idx_a (a), + KEY idx_b (b), + UNIQUE KEY uk_c (c) + )`) + + got := c.ShowCreateTable("testdb", "t4") + assertContains(t, got, "PRIMARY KEY (`id`)") + assertContains(t, got, "KEY `idx_a` (`a`)") + assertContains(t, got, "KEY `idx_b` (`b`)") + assertContains(t, got, "UNIQUE KEY `uk_c` (`c`)") +} + +func TestShowCreateTableForeignKey(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE parent ( + id INT NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + )`) + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL AUTO_INCREMENT, + parent_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent (id) ON DELETE CASCADE + )`) + + got := c.ShowCreateTable("testdb", "child") + assertContains(t, got, "CONSTRAINT `fk_parent` FOREIGN KEY (`parent_id`) REFERENCES `parent` (`id`)") + assertContains(t, got, "ON DELETE CASCADE") + // ON UPDATE RESTRICT is the default and should NOT appear. + assertNotContains(t, got, "ON UPDATE") +} + +func TestShowCreateTableForeignKeyRestrict(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE ref_tbl ( + id INT NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + )`) + mustExec(t, c, `CREATE TABLE fk_tbl ( + id INT NOT NULL AUTO_INCREMENT, + ref_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_ref FOREIGN KEY (ref_id) REFERENCES ref_tbl (id) ON DELETE RESTRICT ON UPDATE RESTRICT + )`) + + got := c.ShowCreateTable("testdb", "fk_tbl") + // MySQL 8.0 shows RESTRICT when explicitly specified (unlike NO ACTION which is hidden). + assertContains(t, got, "ON DELETE RESTRICT") + assertContains(t, got, "ON UPDATE RESTRICT") +} + +func TestShowCreateTableComment(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t5 ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) COMMENT 'user name', + PRIMARY KEY (id) + ) COMMENT='main table'`) + + got := c.ShowCreateTable("testdb", "t5") + assertContains(t, got, "COMMENT 'user name'") + assertContains(t, got, "COMMENT='main table'") +} + +func TestShowCreateTableUnknownDatabaseOrTable(t *testing.T) { + c := setupWithDB(t) + + if got := c.ShowCreateTable("nonexistent", "t1"); got != "" { + t.Errorf("expected empty string for unknown database, got %q", got) + } + + if got := c.ShowCreateTable("testdb", "nonexistent"); got != "" { + t.Errorf("expected empty string for unknown table, got %q", got) + } +} + +func TestShowCreateTableNotNullNoDefault(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t6 ( + id INT NOT NULL, + name VARCHAR(100) NOT NULL, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t6") + // NOT NULL columns without a default should NOT show DEFAULT NULL. + assertContains(t, got, "`id` int NOT NULL") + assertContains(t, got, "`name` varchar(100) NOT NULL") + assertNotContains(t, got, "DEFAULT NULL") +} + +func TestShowCreateTableNonDefaultCollation(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t7 ( + id INT NOT NULL, + PRIMARY KEY (id) + ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`) + + got := c.ShowCreateTable("testdb", "t7") + assertContains(t, got, "DEFAULT CHARSET=utf8mb4") + assertContains(t, got, "COLLATE=utf8mb4_unicode_ci") +} + +func TestShowCreateTableAutoIncrementColumn(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t8 ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t8") + assertContains(t, got, "`id` bigint unsigned NOT NULL AUTO_INCREMENT") +} + +func assertContains(t *testing.T, s, substr string) { + t.Helper() + if !strings.Contains(s, substr) { + t.Errorf("expected output to contain %q\ngot:\n%s", substr, s) + } +} + +func assertNotContains(t *testing.T, s, substr string) { + t.Helper() + if strings.Contains(s, substr) { + t.Errorf("expected output NOT to contain %q\ngot:\n%s", substr, s) + } +} diff --git a/tidb/catalog/table.go b/tidb/catalog/table.go new file mode 100644 index 00000000..7b4c8996 --- /dev/null +++ b/tidb/catalog/table.go @@ -0,0 +1,224 @@ +package catalog + +type Table struct { + Name string + Database *Database + Columns []*Column + colByName map[string]int // lowered name -> index + Indexes []*Index + Constraints []*Constraint + Engine string + Charset string + Collation string + Comment string + AutoIncrement int64 + Temporary bool + RowFormat string + KeyBlockSize int + Partitioning *PartitionInfo + + // droppedByCleanup tracks indexes auto-removed by DROP COLUMN cleanup + // during multi-command ALTER TABLE. This allows a subsequent explicit + // DROP INDEX in the same ALTER to succeed (matching MySQL 8.0 behavior). + droppedByCleanup map[string]bool +} + +// PartitionInfo holds partition metadata for a table. +type PartitionInfo struct { + Type string // RANGE, LIST, HASH, KEY + Linear bool // LINEAR HASH or LINEAR KEY + Expr string // partition expression (for RANGE/LIST/HASH) + Columns []string // partition columns (for RANGE COLUMNS/LIST COLUMNS/KEY) + Algorithm int // ALGORITHM={1|2} for KEY partitioning + NumParts int // PARTITIONS num + Partitions []*PartitionDefInfo + SubType string // subpartition type (HASH or KEY, "" if none) + SubLinear bool // LINEAR for subpartition + SubExpr string // subpartition expression + SubColumns []string // subpartition columns + SubAlgo int // subpartition ALGORITHM + NumSubParts int // SUBPARTITIONS num +} + +// PartitionDefInfo holds a single partition definition. +type PartitionDefInfo struct { + Name string + ValueExpr string // "LESS THAN (...)" or "IN (...)" or "" + Engine string // ENGINE option for this partition + Comment string // COMMENT option for this partition + SubPartitions []*SubPartitionDefInfo +} + +// SubPartitionDefInfo holds a single subpartition definition. +type SubPartitionDefInfo struct { + Name string + Engine string + Comment string +} + +type Column struct { + Position int + Name string + DataType string // normalized (int, varchar, etc.) + ColumnType string // full type string (varchar(100), int unsigned) + Nullable bool + Default *string + DefaultDropped bool // true when ALTER COLUMN DROP DEFAULT was used + AutoIncrement bool + Charset string + Collation string + Comment string + OnUpdate string + Generated *GeneratedColumnInfo + Invisible bool + SRID int // Spatial Reference ID (0 = not set) + DefaultAnalyzed AnalyzedExpr // Phase 3: analyzed DEFAULT expression + GeneratedAnalyzed AnalyzedExpr // Phase 3: analyzed GENERATED ALWAYS AS expression +} + +type GeneratedColumnInfo struct { + Expr string + Stored bool +} + +type View struct { + Name string + Database *Database + Definition string + Algorithm string + Definer string + SqlSecurity string + CheckOption string + Columns []string // All column names (explicit or derived from SELECT) + ExplicitColumns bool // true if the user specified a column list in CREATE VIEW + AnalyzedQuery *Query // analyzed view body (populated on CREATE VIEW); nil if analysis failed +} + +// Routine represents a stored function or procedure in the catalog. +type Routine struct { + Name string + Database *Database + IsProcedure bool + Definer string + Params []*RoutineParam + Returns string // return type string for functions (empty for procedures) + Body string + Characteristics map[string]string // name -> value (DETERMINISTIC, COMMENT, etc.) +} + +// RoutineParam represents a parameter of a stored routine. +type RoutineParam struct { + Direction string // IN, OUT, INOUT (empty for functions) + Name string + TypeName string // full type string +} + +// Trigger represents a trigger in the catalog. +type Trigger struct { + Name string + Database *Database + Table string // table name the trigger is on + Timing string // BEFORE, AFTER + Event string // INSERT, UPDATE, DELETE + Definer string + Body string + Order *TriggerOrderInfo +} + +// TriggerOrderInfo represents FOLLOWS/PRECEDES ordering. +type TriggerOrderInfo struct { + Follows bool + TriggerName string +} + +// Event represents a scheduled event in the catalog. +type Event struct { + Name string + Database *Database + Definer string + Schedule string // raw schedule text (e.g. "EVERY 1 HOUR", "AT '2024-01-01 00:00:00'") + OnCompletion string // PRESERVE, NOT PRESERVE, or "" (default NOT PRESERVE) + Enable string // ENABLE, DISABLE, DISABLE ON SLAVE, or "" (default ENABLE) + Comment string + Body string +} + +// cloneTable returns a deep copy of the table's mutable state. +// The returned Table shares the same Name, Database pointer, and scalar fields, +// but has independent slices and maps so that mutations do not affect the original. +func cloneTable(src *Table) Table { + dst := *src // shallow copy of all scalar fields + + // Deep copy columns. + dst.Columns = make([]*Column, len(src.Columns)) + for i, sc := range src.Columns { + col := *sc + if sc.Default != nil { + def := *sc.Default + col.Default = &def + } + if sc.Generated != nil { + gen := *sc.Generated + col.Generated = &gen + } + dst.Columns[i] = &col + } + + // Deep copy colByName. + dst.colByName = make(map[string]int, len(src.colByName)) + for k, v := range src.colByName { + dst.colByName[k] = v + } + + // Deep copy indexes. + dst.Indexes = make([]*Index, len(src.Indexes)) + for i, si := range src.Indexes { + idx := *si + idx.Table = src // keep pointing to the original table pointer + cols := make([]*IndexColumn, len(si.Columns)) + for j, sc := range si.Columns { + ic := *sc + cols[j] = &ic + } + idx.Columns = cols + dst.Indexes[i] = &idx + } + + // Deep copy constraints. + dst.Constraints = make([]*Constraint, len(src.Constraints)) + for i, sc := range src.Constraints { + con := *sc + con.Table = src + con.Columns = append([]string{}, sc.Columns...) + con.RefColumns = append([]string{}, sc.RefColumns...) + dst.Constraints[i] = &con + } + + // Deep copy partitioning. + if src.Partitioning != nil { + pi := *src.Partitioning + pi.Columns = append([]string{}, src.Partitioning.Columns...) + pi.SubColumns = append([]string{}, src.Partitioning.SubColumns...) + pi.Partitions = make([]*PartitionDefInfo, len(src.Partitioning.Partitions)) + for i, sp := range src.Partitioning.Partitions { + pd := *sp + pd.SubPartitions = make([]*SubPartitionDefInfo, len(sp.SubPartitions)) + for j, ss := range sp.SubPartitions { + sd := *ss + pd.SubPartitions[j] = &sd + } + pi.Partitions[i] = &pd + } + dst.Partitioning = &pi + } + + return dst +} + +func (t *Table) GetColumn(name string) *Column { + idx, ok := t.colByName[toLower(name)] + if !ok { + return nil + } + return t.Columns[idx] +} diff --git a/tidb/catalog/tablecmds.go b/tidb/catalog/tablecmds.go new file mode 100644 index 00000000..1ec0a191 --- /dev/null +++ b/tidb/catalog/tablecmds.go @@ -0,0 +1,1529 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/deparse" +) + +func (c *Catalog) createTable(stmt *nodes.CreateTableStmt) error { + // Resolve database. + dbName := "" + if stmt.Table != nil { + dbName = stmt.Table.Schema + } + db, err := c.resolveDatabase(dbName) + if err != nil { + return err + } + + tableName := stmt.Table.Name + key := toLower(tableName) + + // Check for duplicate table or view with the same name. + if db.Tables[key] != nil { + if stmt.IfNotExists { + return nil + } + return errDupTable(tableName) + } + if db.Views[key] != nil { + return errDupTable(tableName) + } + + // CREATE TABLE ... LIKE + if stmt.Like != nil { + return c.createTableLike(db, tableName, key, stmt) + } + + // CREATE TABLE ... AS SELECT (CTAS) — not supported yet, skip silently + if stmt.Select != nil && len(stmt.Columns) == 0 { + return nil + } + + tbl := &Table{ + Name: tableName, + Database: db, + Columns: make([]*Column, 0, len(stmt.Columns)), + colByName: make(map[string]int), + Indexes: make([]*Index, 0), + Constraints: make([]*Constraint, 0), + Charset: db.Charset, + Collation: db.Collation, + Engine: "InnoDB", + Temporary: stmt.Temporary, + } + + // Apply table options. + tblCharsetExplicit := false + tblCollationExplicit := false + for _, opt := range stmt.Options { + switch toLower(opt.Name) { + case "engine": + tbl.Engine = opt.Value + case "charset", "character set", "default charset", "default character set": + tbl.Charset = opt.Value + tblCharsetExplicit = true + case "collate", "default collate": + tbl.Collation = opt.Value + tblCollationExplicit = true + case "comment": + tbl.Comment = opt.Value + case "auto_increment": + fmt.Sscanf(opt.Value, "%d", &tbl.AutoIncrement) + case "row_format": + tbl.RowFormat = opt.Value + case "key_block_size": + fmt.Sscanf(opt.Value, "%d", &tbl.KeyBlockSize) + } + } + // When charset is specified without explicit collation, derive the default collation. + if tblCharsetExplicit && !tblCollationExplicit { + if dc, ok := defaultCollationForCharset[toLower(tbl.Charset)]; ok { + tbl.Collation = dc + } + } + // When collation is specified without explicit charset, derive the charset from collation. + if tblCollationExplicit && !tblCharsetExplicit { + if cs := charsetForCollation(tbl.Collation); cs != "" { + tbl.Charset = cs + } + } + // Track whether we have a primary key (to detect multiple PKs). + hasPK := false + + // Defer FK backing index creation until after all explicit indexes are added, + // so that explicit indexes can satisfy FK requirements without creating duplicates. + type pendingFK struct { + conName string + cols []string + idxCols []*IndexColumn + } + var pendingFKs []pendingFK + + // unnamedFKCount counts FKs that received an auto-generated name in this + // CREATE TABLE statement. Matches MySQL 8.0's generate_fk_name() counter, + // which is initialized to 0 for CREATE TABLE and incremented per unnamed FK, + // IGNORING user-named FKs. See sql/sql_table.cc:9252 (create_table_impl + // is called with fk_max_generated_name_number = 0) and sql/sql_table.cc:5912 + // (generate_fk_name uses ++counter). + // + // Example: CREATE TABLE t (a INT, CONSTRAINT t_ibfk_5 FK, b INT, FK) + // → first unnamed FK gets t_ibfk_1 (not t_ibfk_2 or t_ibfk_6). + // This differs from ALTER TABLE ADD FK, where the counter starts at + // max(existing) — see altercmds.go. + var unnamedFKCount int + + // unnamedCheckCount counts CHECK constraints that received an auto-generated + // name in this CREATE TABLE statement. Matches MySQL 8.0's CHECK counter + // at sql/sql_table.cc:19073 (cc_max_generated_number starts at 0, used + // via ++cc_max_generated_number). Like the FK counter, this IGNORES + // user-named CHECK constraints during CREATE TABLE. + // + // Example: CREATE TABLE t (a INT, CONSTRAINT t_chk_1 CHECK(a>0), b INT, CHECK(b<100)) + // → unnamed CHECK gets t_chk_1, but t_chk_1 is already taken by user + // → real MySQL errors with ER_CHECK_CONSTRAINT_DUP_NAME + // (see sql/sql_table.cc:19595 check_constraint_dup_name check). + // For ALTER TABLE ADD CHECK, the counter is loaded from existing max — + // see altercmds.go which uses nextCheckNumber (gap-scan helper). + var unnamedCheckCount int + + // Process columns. + for i, colDef := range stmt.Columns { + colKey := toLower(colDef.Name) + if _, exists := tbl.colByName[colKey]; exists { + return errDupColumn(colDef.Name) + } + + col := &Column{ + Position: i + 1, + Name: colDef.Name, + Nullable: true, // default nullable + } + + // Type info. + isSerial := false + if colDef.TypeName != nil { + typeName := toLower(colDef.TypeName.Name) + // Handle SERIAL: expands to BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE + if typeName == "serial" { + isSerial = true + col.DataType = "bigint" + col.ColumnType = "bigint unsigned" + col.AutoIncrement = true + col.Nullable = false + } else if typeName == "boolean" { + col.DataType = "tinyint" + col.ColumnType = formatColumnType(colDef.TypeName) + } else if typeName == "numeric" { + col.DataType = "decimal" + col.ColumnType = formatColumnType(colDef.TypeName) + } else { + col.DataType = typeName + col.ColumnType = formatColumnType(colDef.TypeName) + } + if colDef.TypeName.Charset != "" { + col.Charset = colDef.TypeName.Charset + } + if colDef.TypeName.Collate != "" { + col.Collation = colDef.TypeName.Collate + } + + // MySQL converts string types with CHARACTER SET binary to binary types. + // ENUM and SET are not converted — they keep CHARACTER SET binary annotation. + if strings.EqualFold(col.Charset, "binary") && isStringType(col.DataType) && !isEnumSetType(col.DataType) { + col = convertToBinaryType(col, colDef.TypeName) + } + } + + // Default charset/collation for string types. + if isStringType(col.DataType) { + if col.Charset == "" { + col.Charset = tbl.Charset + } + if col.Collation == "" { + // If column charset differs from table charset, use the default + // collation for the column's charset, not the table's collation. + if !strings.EqualFold(col.Charset, tbl.Charset) { + if dc, ok := defaultCollationForCharset[toLower(col.Charset)]; ok { + col.Collation = dc + } + } else { + col.Collation = tbl.Collation + } + } + } + + // Top-level column properties. + if colDef.TypeName != nil && colDef.TypeName.SRID != 0 { + col.SRID = colDef.TypeName.SRID + } + if colDef.AutoIncrement { + col.AutoIncrement = true + col.Nullable = false + } + if colDef.Comment != "" { + col.Comment = colDef.Comment + } + if colDef.DefaultValue != nil { + s := nodeToSQL(colDef.DefaultValue) + col.Default = &s + } + if colDef.OnUpdate != nil { + col.OnUpdate = nodeToSQL(colDef.OnUpdate) + } + if colDef.Generated != nil { + col.Generated = &GeneratedColumnInfo{ + Expr: nodeToSQLGenerated(colDef.Generated.Expr, tbl.Charset), + Stored: colDef.Generated.Stored, + } + } + + // Process column-level constraints. + for _, cc := range colDef.Constraints { + switch cc.Type { + case nodes.ColConstrNotNull: + col.Nullable = false + case nodes.ColConstrNull: + col.Nullable = true + case nodes.ColConstrDefault: + if cc.Expr != nil { + s := nodeToSQL(cc.Expr) + col.Default = &s + } + case nodes.ColConstrPrimaryKey: + if hasPK { + return errMultiplePriKey() + } + hasPK = true + col.Nullable = false + // Add PK index and constraint after all columns are processed. + // We'll defer this—record it for now. + case nodes.ColConstrUnique: + // Handled after columns are added. + case nodes.ColConstrAutoIncrement: + col.AutoIncrement = true + col.Nullable = false + case nodes.ColConstrCheck: + // Add check constraint. + conName := cc.Name + if conName == "" { + unnamedCheckCount++ + conName = fmt.Sprintf("%s_chk_%d", tableName, unnamedCheckCount) + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: conName, + Type: ConCheck, + Table: tbl, + CheckExpr: nodeToSQL(cc.Expr), + NotEnforced: cc.NotEnforced, + }) + case nodes.ColConstrReferences: + // Column-level FK. + refDB := "" + refTable := "" + if cc.RefTable != nil { + refDB = cc.RefTable.Schema + refTable = cc.RefTable.Name + } + conName := cc.Name + if conName == "" { + unnamedFKCount++ + conName = fmt.Sprintf("%s_ibfk_%d", tableName, unnamedFKCount) + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: conName, + Type: ConForeignKey, + Table: tbl, + Columns: []string{colDef.Name}, + RefDatabase: refDB, + RefTable: refTable, + RefColumns: cc.RefColumns, + OnDelete: refActionToString(cc.OnDelete), + OnUpdate: refActionToString(cc.OnUpdate), + }) + // Defer implicit backing index for FK until after all explicit indexes are added. + pendingFKs = append(pendingFKs, pendingFK{conName: cc.Name, cols: []string{colDef.Name}, idxCols: []*IndexColumn{{Name: colDef.Name}}}) + case nodes.ColConstrVisible: + col.Invisible = false + case nodes.ColConstrInvisible: + col.Invisible = true + case nodes.ColConstrCollate: + // Collation specified via constraint. + if cc.Expr != nil { + if s, ok := cc.Expr.(*nodes.StringLit); ok { + col.Collation = s.Value + } + } + } + } + + // SERIAL implies UNIQUE KEY — add after the column is fully configured. + if isSerial { + idxName := allocIndexName(tbl, colDef.Name) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: []*IndexColumn{{Name: colDef.Name}}, + Unique: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: idxName, + Type: ConUniqueKey, + Table: tbl, + Columns: []string{colDef.Name}, + IndexName: idxName, + }) + } + + tbl.Columns = append(tbl.Columns, col) + tbl.colByName[colKey] = i + } + + // Second pass: add column-level PK and UNIQUE indexes/constraints. + for _, colDef := range stmt.Columns { + for _, cc := range colDef.Constraints { + switch cc.Type { + case nodes.ColConstrPrimaryKey: + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: "PRIMARY", + Table: tbl, + Columns: []*IndexColumn{{Name: colDef.Name}}, + Unique: true, + Primary: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: "PRIMARY", + Type: ConPrimaryKey, + Table: tbl, + Columns: []string{colDef.Name}, + IndexName: "PRIMARY", + }) + case nodes.ColConstrUnique: + idxName := allocIndexName(tbl, colDef.Name) + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: []*IndexColumn{{Name: colDef.Name}}, + Unique: true, + IndexType: "", + Visible: true, + }) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: idxName, + Type: ConUniqueKey, + Table: tbl, + Columns: []string{colDef.Name}, + IndexName: idxName, + }) + } + } + } + + // Process table-level constraints. + for _, con := range stmt.Constraints { + cols := extractColumnNames(con) + + switch con.Type { + case nodes.ConstrPrimaryKey: + if hasPK { + return errMultiplePriKey() + } + hasPK = true + // Mark PK columns as NOT NULL. + for _, colName := range cols { + c := tbl.GetColumn(colName) + if c != nil { + c.Nullable = false + } + } + idxCols := buildIndexColumns(con) + pkIdx := &Index{ + Name: "PRIMARY", + Table: tbl, + Columns: idxCols, + Unique: true, + Primary: true, + IndexType: "", + Visible: true, + } + applyIndexOptions(pkIdx, con.IndexOptions) + tbl.Indexes = append(tbl.Indexes, pkIdx) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: "PRIMARY", + Type: ConPrimaryKey, + Table: tbl, + Columns: cols, + IndexName: "PRIMARY", + }) + + case nodes.ConstrUnique: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } + idxCols := buildIndexColumns(con) + uqIdx := &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Unique: true, + IndexType: resolveConstraintIndexType(con), + Visible: true, + } + applyIndexOptions(uqIdx, con.IndexOptions) + tbl.Indexes = append(tbl.Indexes, uqIdx) + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: idxName, + Type: ConUniqueKey, + Table: tbl, + Columns: cols, + IndexName: idxName, + }) + + case nodes.ConstrForeignKey: + conName := con.Name + if conName == "" { + unnamedFKCount++ + conName = fmt.Sprintf("%s_ibfk_%d", tableName, unnamedFKCount) + } + refDB := "" + refTable := "" + if con.RefTable != nil { + refDB = con.RefTable.Schema + refTable = con.RefTable.Name + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: conName, + Type: ConForeignKey, + Table: tbl, + Columns: cols, + RefDatabase: refDB, + RefTable: refTable, + RefColumns: con.RefColumns, + OnDelete: refActionToString(con.OnDelete), + OnUpdate: refActionToString(con.OnUpdate), + }) + // Defer implicit backing index for FK until after all explicit indexes are added. + pendingFKs = append(pendingFKs, pendingFK{conName: con.Name, cols: cols, idxCols: buildIndexColumns(con)}) + + case nodes.ConstrCheck: + conName := con.Name + if conName == "" { + unnamedCheckCount++ + conName = fmt.Sprintf("%s_chk_%d", tableName, unnamedCheckCount) + } + tbl.Constraints = append(tbl.Constraints, &Constraint{ + Name: conName, + Type: ConCheck, + Table: tbl, + CheckExpr: nodeToSQL(con.Expr), + NotEnforced: con.NotEnforced, + }) + + case nodes.ConstrIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } + idxCols := buildIndexColumns(con) + keyIdx := &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + IndexType: resolveConstraintIndexType(con), + Visible: true, + } + applyIndexOptions(keyIdx, con.IndexOptions) + tbl.Indexes = append(tbl.Indexes, keyIdx) + + case nodes.ConstrFulltextIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } + idxCols := buildIndexColumns(con) + ftIdx := &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Fulltext: true, + IndexType: "FULLTEXT", + Visible: true, + } + applyIndexOptions(ftIdx, con.IndexOptions) + tbl.Indexes = append(tbl.Indexes, ftIdx) + + case nodes.ConstrSpatialIndex: + idxName := con.Name + if idxName == "" && len(cols) > 0 { + idxName = allocIndexName(tbl, cols[0]) + } + idxCols := buildIndexColumns(con) + spIdx := &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Spatial: true, + IndexType: "SPATIAL", + Visible: true, + } + applyIndexOptions(spIdx, con.IndexOptions) + tbl.Indexes = append(tbl.Indexes, spIdx) + } + } + + // Process deferred FK backing indexes now that all explicit indexes are in place. + for _, fk := range pendingFKs { + ensureFKBackingIndex(tbl, fk.conName, fk.cols, fk.idxCols) + } + + // Validate foreign key constraints (unless foreign_key_checks=0). + if c.foreignKeyChecks { + if err := c.validateForeignKeys(db, tbl); err != nil { + return err + } + } + + // Process partition clause. + if stmt.Partitions != nil { + tbl.Partitioning = buildPartitionInfo(stmt.Partitions) + } + + // Phase 3: analyze DEFAULT, GENERATED, and CHECK expressions now that all + // columns are present in the table. + c.analyzeTableExpressions(tbl, stmt) + + db.Tables[key] = tbl + return nil +} + +// analyzeTableExpressions performs best-effort semantic analysis on DEFAULT, +// GENERATED, and CHECK expressions after all columns have been added to the table. +func (c *Catalog) analyzeTableExpressions(tbl *Table, stmt *nodes.CreateTableStmt) { + // Analyze DEFAULT and GENERATED expressions from column definitions. + for i, colDef := range stmt.Columns { + if i >= len(tbl.Columns) { + break + } + col := tbl.Columns[i] + + // Top-level DEFAULT. + if colDef.DefaultValue != nil { + if analyzed, err := c.AnalyzeStandaloneExpr(colDef.DefaultValue, tbl); err == nil { + col.DefaultAnalyzed = analyzed + } + } + + // Column-constraint DEFAULT (may override top-level). + for _, cc := range colDef.Constraints { + if cc.Type == nodes.ColConstrDefault && cc.Expr != nil { + if analyzed, err := c.AnalyzeStandaloneExpr(cc.Expr, tbl); err == nil { + col.DefaultAnalyzed = analyzed + } + } + } + + // GENERATED ALWAYS AS. + if colDef.Generated != nil { + if analyzed, err := c.AnalyzeStandaloneExpr(colDef.Generated.Expr, tbl); err == nil { + col.GeneratedAnalyzed = analyzed + } + } + } + + // Analyze CHECK expressions on constraints. + // We iterate constraints and match CHECK ones; the AST sources are both + // column-level and table-level constraint nodes. + checkIdx := 0 + for _, colDef := range stmt.Columns { + for _, cc := range colDef.Constraints { + if cc.Type == nodes.ColConstrCheck && cc.Expr != nil { + // Find matching CHECK constraint by index. + for checkIdx < len(tbl.Constraints) { + if tbl.Constraints[checkIdx].Type == ConCheck { + if analyzed, err := c.AnalyzeStandaloneExpr(cc.Expr, tbl); err == nil { + tbl.Constraints[checkIdx].CheckAnalyzed = analyzed + } + checkIdx++ + break + } + checkIdx++ + } + } + } + } + for _, con := range stmt.Constraints { + if con.Type == nodes.ConstrCheck && con.Expr != nil { + for checkIdx < len(tbl.Constraints) { + if tbl.Constraints[checkIdx].Type == ConCheck { + if analyzed, err := c.AnalyzeStandaloneExpr(con.Expr, tbl); err == nil { + tbl.Constraints[checkIdx].CheckAnalyzed = analyzed + } + checkIdx++ + break + } + checkIdx++ + } + } + } +} + +// buildPartitionInfo converts an AST PartitionClause to a catalog PartitionInfo. +func buildPartitionInfo(pc *nodes.PartitionClause) *PartitionInfo { + pi := &PartitionInfo{ + Linear: pc.Linear, + NumParts: pc.NumParts, + } + + switch pc.Type { + case nodes.PartitionRange: + if len(pc.Columns) > 0 { + pi.Type = "RANGE COLUMNS" + pi.Columns = pc.Columns + } else { + pi.Type = "RANGE" + pi.Expr = nodeToSQL(pc.Expr) + } + case nodes.PartitionList: + if len(pc.Columns) > 0 { + pi.Type = "LIST COLUMNS" + pi.Columns = pc.Columns + } else { + pi.Type = "LIST" + pi.Expr = nodeToSQL(pc.Expr) + } + case nodes.PartitionHash: + pi.Type = "HASH" + pi.Expr = nodeToSQL(pc.Expr) + case nodes.PartitionKey: + pi.Type = "KEY" + pi.Columns = pc.Columns + pi.Algorithm = pc.Algorithm + } + + // Subpartition info. + if pc.SubPartType != 0 || pc.SubPartExpr != nil || len(pc.SubPartColumns) > 0 { + switch pc.SubPartType { + case nodes.PartitionHash: + pi.SubType = "HASH" + pi.SubExpr = nodeToSQL(pc.SubPartExpr) + case nodes.PartitionKey: + pi.SubType = "KEY" + pi.SubColumns = pc.SubPartColumns + pi.SubAlgo = pc.SubPartAlgo + } + pi.SubLinear = false // TODO: track linear for subpartitions if parser supports it + pi.NumSubParts = pc.NumSubParts + } + + // Partition definitions. + for _, pd := range pc.Partitions { + pdi := &PartitionDefInfo{ + Name: pd.Name, + } + // Values. + if pd.Values != nil { + pdi.ValueExpr = partitionValueToString(pd.Values, pc.Type) + } + // Options. + for _, opt := range pd.Options { + switch toLower(opt.Name) { + case "engine": + pdi.Engine = opt.Value + case "comment": + pdi.Comment = opt.Value + } + } + // Subpartitions. + for _, spd := range pd.SubPartitions { + spdi := &SubPartitionDefInfo{ + Name: spd.Name, + } + for _, opt := range spd.Options { + switch toLower(opt.Name) { + case "engine": + spdi.Engine = opt.Value + case "comment": + spdi.Comment = opt.Value + } + } + pdi.SubPartitions = append(pdi.SubPartitions, spdi) + } + pi.Partitions = append(pi.Partitions, pdi) + } + + // Auto-generate partition definitions for HASH/KEY/LINEAR HASH/LINEAR KEY + // when PARTITIONS N is specified without explicit partition definitions. + // MySQL naming convention: p0, p1, p2, ... + if len(pi.Partitions) == 0 && pi.NumParts > 0 { + for i := 0; i < pi.NumParts; i++ { + pi.Partitions = append(pi.Partitions, &PartitionDefInfo{ + Name: fmt.Sprintf("p%d", i), + }) + } + } + + // Auto-generate subpartition definitions when SUBPARTITIONS N is specified + // without explicit subpartition definitions. + // MySQL naming convention: sp0, sp1, ... + if pi.NumSubParts > 0 { + for _, part := range pi.Partitions { + if len(part.SubPartitions) == 0 { + for j := 0; j < pi.NumSubParts; j++ { + part.SubPartitions = append(part.SubPartitions, &SubPartitionDefInfo{ + Name: fmt.Sprintf("%ssp%d", part.Name, j), + }) + } + } + } + } + + return pi +} + +// partitionValueToString converts a partition value node to SQL string. +func partitionValueToString(v nodes.Node, ptype nodes.PartitionType) string { + switch n := v.(type) { + case *nodes.String: + if n.Str == "MAXVALUE" { + return "MAXVALUE" + } + return n.Str + case *nodes.List: + parts := make([]string, len(n.Items)) + for i, item := range n.Items { + if subList, ok := item.(*nodes.List); ok { + // Tuple: (val1, val2) for multi-column LIST COLUMNS + subParts := make([]string, len(subList.Items)) + for j, sub := range subList.Items { + subParts[j] = nodeToSQL(sub.(nodes.ExprNode)) + } + parts[i] = "(" + strings.Join(subParts, ",") + ")" + } else { + parts[i] = nodeToSQL(item.(nodes.ExprNode)) + } + } + return strings.Join(parts, ",") + case nodes.ExprNode: + return nodeToSQL(n) + default: + return "" + } +} + +// validateForeignKeys checks all FK constraints on a table against the referenced tables. +// It validates: (1) referenced table exists, (2) referenced columns have an index, +// (3) column types are compatible. +func (c *Catalog) validateForeignKeys(db *Database, tbl *Table) error { + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + if err := c.validateSingleFK(db, tbl, con); err != nil { + return err + } + } + return nil +} + +// validateSingleFK validates a single FK constraint against its referenced table. +func (c *Catalog) validateSingleFK(db *Database, tbl *Table, con *Constraint) error { + // Resolve the referenced table. + refDBName := con.RefDatabase + var refDB *Database + if refDBName != "" { + refDB = c.GetDatabase(refDBName) + } else { + refDB = db + } + + var refTbl *Table + if refDB != nil { + // Self-referencing FK: the table being created references itself. + if toLower(con.RefTable) == toLower(tbl.Name) && refDB == db { + refTbl = tbl + } else { + refTbl = refDB.GetTable(con.RefTable) + } + } + + if refTbl == nil { + return errFKNoRefTable(con.RefTable) + } + + // Check that referenced columns have an index (PK or UNIQUE or KEY) + // that starts with the referenced columns in order. + if !hasIndexOnColumns(refTbl, con.RefColumns) { + return errFKMissingIndex(con.Name, con.RefTable) + } + + // Check column type compatibility. + for i, colName := range con.Columns { + if i >= len(con.RefColumns) { + break + } + col := tbl.GetColumn(colName) + refCol := refTbl.GetColumn(con.RefColumns[i]) + if col == nil || refCol == nil { + continue + } + if !fkTypesCompatible(col, refCol) { + return errFKIncompatibleColumns(colName, con.RefColumns[i], con.Name) + } + } + + return nil +} + +// hasIndexOnColumns checks whether a table has an index (PK, UNIQUE, or regular KEY) +// whose leading columns match the given columns. +func hasIndexOnColumns(tbl *Table, cols []string) bool { + for _, idx := range tbl.Indexes { + if len(idx.Columns) < len(cols) { + continue + } + match := true + for i, col := range cols { + if toLower(idx.Columns[i].Name) != toLower(col) { + match = false + break + } + } + if match { + return true + } + } + return false +} + +// fkTypesCompatible checks whether two columns have compatible types for FK relationships. +// MySQL requires that FK and referenced columns have the same storage type. +func fkTypesCompatible(col, refCol *Column) bool { + // Compare base data types. + if col.DataType != refCol.DataType { + return false + } + + // For integer types, check signedness (unsigned must match). + colUnsigned := strings.Contains(strings.ToLower(col.ColumnType), "unsigned") + refUnsigned := strings.Contains(strings.ToLower(refCol.ColumnType), "unsigned") + if colUnsigned != refUnsigned { + return false + } + + // For string types, check charset compatibility. + if isStringType(col.DataType) { + colCharset := col.Charset + refCharset := refCol.Charset + if colCharset != "" && refCharset != "" && toLower(colCharset) != toLower(refCharset) { + return false + } + } + + return true +} + +// extractColumnNames returns column names from an AST constraint. +func extractColumnNames(con *nodes.Constraint) []string { + if len(con.IndexColumns) > 0 { + names := make([]string, 0, len(con.IndexColumns)) + for _, ic := range con.IndexColumns { + if cr, ok := ic.Expr.(*nodes.ColumnRef); ok { + names = append(names, cr.Column) + } + } + return names + } + return con.Columns +} + +// buildIndexColumns converts AST IndexColumn list to catalog IndexColumn list. +func buildIndexColumns(con *nodes.Constraint) []*IndexColumn { + if len(con.IndexColumns) > 0 { + result := make([]*IndexColumn, 0, len(con.IndexColumns)) + for _, ic := range con.IndexColumns { + idxCol := &IndexColumn{ + Length: ic.Length, + Descending: ic.Desc, + } + if cr, ok := ic.Expr.(*nodes.ColumnRef); ok { + idxCol.Name = cr.Column + } else { + idxCol.Expr = nodeToSQL(ic.Expr) + } + result = append(result, idxCol) + } + return result + } + // Fallback to simple column names. + result := make([]*IndexColumn, 0, len(con.Columns)) + for _, name := range con.Columns { + result = append(result, &IndexColumn{Name: name}) + } + return result +} + +// allocIndexName generates a unique index name based on the first column, +// appending _2, _3, etc. on collision. +func allocIndexName(tbl *Table, baseName string) string { + candidate := baseName + suffix := 2 + for indexNameExists(tbl, candidate) { + candidate = fmt.Sprintf("%s_%d", baseName, suffix) + suffix++ + } + return candidate +} + +// hasIndexCoveringColumns returns true if the table already has an index whose +// leading columns match the given FK columns (left-prefix match). MySQL 8.0 +// reuses such an index instead of creating an implicit backing index for the FK. +func hasIndexCoveringColumns(tbl *Table, fkCols []string) bool { + for _, idx := range tbl.Indexes { + if len(idx.Columns) < len(fkCols) { + continue + } + match := true + for i, col := range fkCols { + if !strings.EqualFold(idx.Columns[i].Name, col) { + match = false + break + } + } + if match { + return true + } + } + return false +} + +// ensureFKBackingIndex creates an implicit backing index for FK columns +// if no existing index already covers them (MySQL 8.0 behavior). +// MySQL uses the constraint name as the index name when provided; +// otherwise falls back to the first column name via allocIndexName. +func ensureFKBackingIndex(tbl *Table, conName string, cols []string, idxCols []*IndexColumn) { + if hasIndexCoveringColumns(tbl, cols) { + return + } + idxName := conName + if idxName == "" { + idxName = allocIndexName(tbl, cols[0]) + } + tbl.Indexes = append(tbl.Indexes, &Index{ + Name: idxName, + Table: tbl, + Columns: idxCols, + Visible: true, + }) +} + +func indexNameExists(tbl *Table, name string) bool { + key := toLower(name) + for _, idx := range tbl.Indexes { + if toLower(idx.Name) == key { + return true + } + } + return false +} + +func indexTypeOrDefault(indexType, defaultType string) string { + if indexType != "" { + return indexType + } + return defaultType +} + +// resolveConstraintIndexType returns the index type from a constraint, +// checking both IndexType (USING before key parts) and IndexOptions (USING after key parts). +func resolveConstraintIndexType(con *nodes.Constraint) string { + if con.IndexType != "" { + return strings.ToUpper(con.IndexType) + } + for _, opt := range con.IndexOptions { + if strings.EqualFold(opt.Name, "USING") { + if s, ok := opt.Value.(*nodes.StringLit); ok { + return strings.ToUpper(s.Value) + } + } + } + return "" +} + +// applyIndexOptions extracts COMMENT, VISIBLE/INVISIBLE, and KEY_BLOCK_SIZE +// from AST IndexOptions and applies them to the given Index. +func applyIndexOptions(idx *Index, opts []*nodes.IndexOption) { + for _, opt := range opts { + switch strings.ToUpper(opt.Name) { + case "COMMENT": + if s, ok := opt.Value.(*nodes.StringLit); ok { + idx.Comment = s.Value + } + case "VISIBLE": + idx.Visible = true + case "INVISIBLE": + idx.Visible = false + case "KEY_BLOCK_SIZE": + switch n := opt.Value.(type) { + case *nodes.Integer: + idx.KeyBlockSize = int(n.Ival) + case *nodes.IntLit: + idx.KeyBlockSize = int(n.Value) + } + } + } +} + +// nextFKGeneratedNumber returns the next available counter for an auto-generated +// InnoDB FK constraint name of the form "_ibfk_". +// +// This matches MySQL 8.0's behavior in sql/sql_table.cc:5843 +// (get_fk_max_generated_name_number): it scans existing FK constraints on the +// table, parses any name that looks like "_ibfk_" as a +// generated name, and returns max(N)+1 (or 1 if no such names exist). +// +// Bytebase omni catalog is case-insensitive on table names (we lowercase the +// prefix before comparison). MySQL's own comparison is case-sensitive on +// already-lowered names, so this is equivalent for typical use. +// +// As in MySQL, pre-4.0.18-style names ("_ibfk_0") are +// ignored — we skip anything whose counter substring starts with '0'. +func nextFKGeneratedNumber(tbl *Table, tableName string) int { + prefix := toLower(tableName) + "_ibfk_" + max := 0 + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + name := toLower(con.Name) + if !strings.HasPrefix(name, prefix) { + continue + } + rest := name[len(prefix):] + if rest == "" || rest[0] == '0' { + continue + } + n := 0 + ok := true + for _, ch := range rest { + if ch < '0' || ch > '9' { + ok = false + break + } + n = n*10 + int(ch-'0') + } + if !ok { + continue + } + if n > max { + max = n + } + } + return max + 1 +} + +// nextCheckNumber returns the next available check constraint number for auto-naming. +// MySQL uses tableName_chk_N where N starts at 1 and increments, skipping existing names. +func nextCheckNumber(tbl *Table) int { + n := 1 + for { + name := fmt.Sprintf("%s_chk_%d", tbl.Name, n) + exists := false + for _, c := range tbl.Constraints { + if toLower(c.Name) == toLower(name) { + exists = true + break + } + } + if !exists { + return n + } + n++ + } +} + +func isStringType(dt string) bool { + switch dt { + case "char", "varchar", "tinytext", "text", "mediumtext", "longtext", + "enum", "set": + return true + } + return false +} + +// convertToBinaryType converts a string-type column with CHARACTER SET binary +// to the equivalent binary type (char->binary, varchar->varbinary, text->blob, etc.). +func convertToBinaryType(col *Column, dt *nodes.DataType) *Column { + switch col.DataType { + case "char": + col.DataType = "binary" + length := dt.Length + if length == 0 { + length = 1 + } + col.ColumnType = fmt.Sprintf("binary(%d)", length) + case "varchar": + col.DataType = "varbinary" + col.ColumnType = fmt.Sprintf("varbinary(%d)", dt.Length) + case "tinytext": + col.DataType = "tinyblob" + col.ColumnType = "tinyblob" + case "text": + col.DataType = "blob" + col.ColumnType = "blob" + case "mediumtext": + col.DataType = "mediumblob" + col.ColumnType = "mediumblob" + case "longtext": + col.DataType = "longblob" + col.ColumnType = "longblob" + } + // Binary types don't have charset/collation in SHOW CREATE TABLE. + col.Charset = "" + col.Collation = "" + return col +} + +// nodeToSQLGenerated converts an AST expression to SQL for use in a generated +// column definition. MySQL prefixes string literals with a charset introducer +// (e.g., _utf8mb4'value') in generated column expressions. +func nodeToSQLGenerated(node nodes.ExprNode, charset string) string { + if node == nil { + return "" + } + switch n := node.(type) { + case *nodes.ColumnRef: + if n.Table != "" { + return "`" + n.Table + "`.`" + n.Column + "`" + } + return "`" + n.Column + "`" + case *nodes.IntLit: + return fmt.Sprintf("%d", n.Value) + case *nodes.StringLit: + // MySQL adds charset introducer for string literals in generated columns. + if charset != "" { + return "_" + charset + "'" + n.Value + "'" + } + return "'" + n.Value + "'" + case *nodes.FuncCallExpr: + funcName := strings.ToLower(n.Name) + if n.Star { + return funcName + "(*)" + } + var args []string + for _, a := range n.Args { + args = append(args, nodeToSQLGenerated(a, charset)) + } + return funcName + "(" + strings.Join(args, ",") + ")" + case *nodes.NullLit: + return "NULL" + case *nodes.BoolLit: + if n.Value { + return "1" + } + return "0" + case *nodes.FloatLit: + return n.Value + case *nodes.BitLit: + val := strings.TrimLeft(n.Value, "0") + if val == "" { + val = "0" + } + return "b'" + val + "'" + case *nodes.ParenExpr: + return "(" + nodeToSQLGenerated(n.Expr, charset) + ")" + case *nodes.BinaryExpr: + left := nodeToSQLGenerated(n.Left, charset) + right := nodeToSQLGenerated(n.Right, charset) + // MySQL rewrites JSON operators to function calls in generated column expressions. + switch n.Op { + case nodes.BinOpJsonExtract: + return "json_extract(" + left + "," + right + ")" + case nodes.BinOpJsonUnquote: + return "json_unquote(json_extract(" + left + "," + right + "))" + } + op := binaryOpToString(n.Op) + return "(" + left + " " + op + " " + right + ")" + case *nodes.UnaryExpr: + operand := nodeToSQLGenerated(n.Operand, charset) + switch n.Op { + case nodes.UnaryMinus: + return "-" + operand + case nodes.UnaryNot: + return "NOT " + operand + case nodes.UnaryBitNot: + return "~" + operand + default: + return operand + } + default: + return "(?)" + } +} + +func nodeToSQL(node nodes.ExprNode) string { + return deparse.Deparse(node) +} + +func binaryOpToString(op nodes.BinaryOp) string { + switch op { + case nodes.BinOpAdd: + return "+" + case nodes.BinOpSub: + return "-" + case nodes.BinOpMul: + return "*" + case nodes.BinOpDiv: + return "/" + case nodes.BinOpMod: + return "%" + case nodes.BinOpEq: + return "=" + case nodes.BinOpNe: + return "!=" + case nodes.BinOpLt: + return "<" + case nodes.BinOpGt: + return ">" + case nodes.BinOpLe: + return "<=" + case nodes.BinOpGe: + return ">=" + case nodes.BinOpAnd: + return "and" + case nodes.BinOpOr: + return "or" + case nodes.BinOpBitAnd: + return "&" + case nodes.BinOpBitOr: + return "|" + case nodes.BinOpBitXor: + return "^" + case nodes.BinOpShiftLeft: + return "<<" + case nodes.BinOpShiftRight: + return ">>" + case nodes.BinOpDivInt: + return "DIV" + case nodes.BinOpXor: + return "XOR" + case nodes.BinOpRegexp: + return "REGEXP" + case nodes.BinOpLikeEscape: + return "LIKE" + case nodes.BinOpNullSafeEq: + return "<=>" + case nodes.BinOpJsonExtract: + return "->" + case nodes.BinOpJsonUnquote: + return "->>" + case nodes.BinOpSoundsLike: + return "SOUNDS LIKE" + default: + return "?" + } +} + +func formatColumnType(dt *nodes.DataType) string { + name := strings.ToLower(dt.Name) + + // MySQL type aliases: BOOLEAN/BOOL → tinyint(1), NUMERIC → decimal, SERIAL → bigint unsigned + // GEOMETRYCOLLECTION → geomcollection (MySQL 8.0 normalized form) + switch name { + case "boolean": + return "tinyint(1)" + case "numeric": + name = "decimal" + case "serial": + return "bigint unsigned" + case "geometrycollection": + name = "geomcollection" + } + + var buf strings.Builder + buf.WriteString(name) + + // Integer display width handling for MySQL 8.0: + // - Display width is deprecated and NOT shown by default + // - EXCEPTION: When ZEROFILL is used, MySQL 8.0 still shows the display width + // with default widths per type: tinyint(3), smallint(5), mediumint(8), int(10), bigint(20) + isIntType := isIntegerType(name) + if isIntType { + if dt.Zerofill { + width := dt.Length + if width == 0 { + width = defaultIntDisplayWidth(name, dt.Unsigned) + } + fmt.Fprintf(&buf, "(%d)", width) + } + // Non-zerofill integer types: strip display width (MySQL 8.0 deprecated) + } else if name == "decimal" && dt.Length == 0 && dt.Scale == 0 { + // DECIMAL with no precision → MySQL shows decimal(10,0) + buf.WriteString("(10,0)") + } else if isTextBlobLengthStripped(name) { + // TEXT(n) and BLOB(n) — MySQL stores the length internally but + // SHOW CREATE TABLE displays just TEXT / BLOB without the length. + // Do not emit length. + } else if name == "year" { + // YEAR(4) is deprecated in MySQL 8.0 — SHOW CREATE TABLE shows just `year`. + } else if (name == "char" || name == "binary") && dt.Length == 0 { + // CHAR/BINARY with no length → MySQL shows char(1)/binary(1) + buf.WriteString("(1)") + } else if dt.Length > 0 && dt.Scale > 0 { + fmt.Fprintf(&buf, "(%d,%d)", dt.Length, dt.Scale) + } else if dt.Length > 0 { + fmt.Fprintf(&buf, "(%d)", dt.Length) + } + + if len(dt.EnumValues) > 0 { + buf.WriteString("(") + for i, v := range dt.EnumValues { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString("'" + escapeEnumValue(v) + "'") + } + buf.WriteString(")") + } + if dt.Unsigned { + buf.WriteString(" unsigned") + } + if dt.Zerofill { + buf.WriteString(" zerofill") + } + return buf.String() +} + +// isIntegerType returns true for MySQL integer types. +func isIntegerType(dt string) bool { + switch dt { + case "tinyint", "smallint", "mediumint", "int", "integer", "bigint": + return true + } + return false +} + +// defaultIntDisplayWidth returns the default display width for integer types +// when ZEROFILL is used. These are the MySQL defaults. +func defaultIntDisplayWidth(typeName string, unsigned bool) int { + switch typeName { + case "tinyint": + if unsigned { + return 3 + } + return 4 + case "smallint": + if unsigned { + return 5 + } + return 6 + case "mediumint": + if unsigned { + return 8 + } + return 9 + case "int", "integer": + if unsigned { + return 10 + } + return 11 + case "bigint": + if unsigned { + return 20 + } + return 20 + } + return 11 +} + +// isTextBlobLengthStripped returns true for types where MySQL strips the length +// in SHOW CREATE TABLE output (TEXT(n) → text, BLOB(n) → blob). +func isTextBlobLengthStripped(dt string) bool { + switch dt { + case "text", "blob": + return true + } + return false +} + +// escapeEnumValue escapes single quotes in ENUM/SET values for SHOW CREATE TABLE. +// MySQL uses '' (two single quotes) to escape a single quote in enum values. +func escapeEnumValue(s string) string { + return strings.ReplaceAll(s, "'", "''") +} + +// createTableLike implements CREATE TABLE t2 LIKE t1. +// It copies the structure (columns, indexes, constraints) from the source table. +func (c *Catalog) createTableLike(db *Database, tableName, key string, stmt *nodes.CreateTableStmt) error { + // Resolve source table. + srcDBName := stmt.Like.Schema + srcDB, err := c.resolveDatabase(srcDBName) + if err != nil { + return err + } + srcTbl := srcDB.GetTable(stmt.Like.Name) + if srcTbl == nil { + return errNoSuchTable(srcDB.Name, stmt.Like.Name) + } + + tbl := &Table{ + Name: tableName, + Database: db, + Columns: make([]*Column, 0, len(srcTbl.Columns)), + colByName: make(map[string]int), + Indexes: make([]*Index, 0, len(srcTbl.Indexes)), + Constraints: make([]*Constraint, 0, len(srcTbl.Constraints)), + Engine: srcTbl.Engine, + Charset: srcTbl.Charset, + Collation: srcTbl.Collation, + Comment: srcTbl.Comment, + RowFormat: srcTbl.RowFormat, + KeyBlockSize: srcTbl.KeyBlockSize, + Temporary: stmt.Temporary, + } + + // Copy columns. + for i, srcCol := range srcTbl.Columns { + col := &Column{ + Position: srcCol.Position, + Name: srcCol.Name, + DataType: srcCol.DataType, + ColumnType: srcCol.ColumnType, + Nullable: srcCol.Nullable, + AutoIncrement: srcCol.AutoIncrement, + Charset: srcCol.Charset, + Collation: srcCol.Collation, + Comment: srcCol.Comment, + OnUpdate: srcCol.OnUpdate, + Invisible: srcCol.Invisible, + } + if srcCol.Default != nil { + def := *srcCol.Default + col.Default = &def + } + if srcCol.Generated != nil { + col.Generated = &GeneratedColumnInfo{ + Expr: srcCol.Generated.Expr, + Stored: srcCol.Generated.Stored, + } + } + tbl.Columns = append(tbl.Columns, col) + tbl.colByName[toLower(col.Name)] = i + } + + // Copy indexes. + for _, srcIdx := range srcTbl.Indexes { + idx := &Index{ + Name: srcIdx.Name, + Table: tbl, + Unique: srcIdx.Unique, + Primary: srcIdx.Primary, + Fulltext: srcIdx.Fulltext, + Spatial: srcIdx.Spatial, + IndexType: srcIdx.IndexType, + Visible: srcIdx.Visible, + Comment: srcIdx.Comment, + KeyBlockSize: srcIdx.KeyBlockSize, + } + cols := make([]*IndexColumn, len(srcIdx.Columns)) + for i, sc := range srcIdx.Columns { + cols[i] = &IndexColumn{ + Name: sc.Name, + Length: sc.Length, + Descending: sc.Descending, + Expr: sc.Expr, + } + } + idx.Columns = cols + tbl.Indexes = append(tbl.Indexes, idx) + } + + // Copy constraints (skip FK — MySQL 8.0 does not copy FKs with LIKE). + for _, srcCon := range srcTbl.Constraints { + if srcCon.Type == ConForeignKey { + continue + } + con := &Constraint{ + Name: srcCon.Name, + Type: srcCon.Type, + Table: tbl, + Columns: append([]string{}, srcCon.Columns...), + IndexName: srcCon.IndexName, + CheckExpr: srcCon.CheckExpr, + NotEnforced: srcCon.NotEnforced, + RefDatabase: srcCon.RefDatabase, + RefTable: srcCon.RefTable, + RefColumns: append([]string{}, srcCon.RefColumns...), + OnDelete: srcCon.OnDelete, + OnUpdate: srcCon.OnUpdate, + } + tbl.Constraints = append(tbl.Constraints, con) + } + + db.Tables[key] = tbl + return nil +} + +func refActionToString(action nodes.ReferenceAction) string { + switch action { + case nodes.RefActRestrict: + return "RESTRICT" + case nodes.RefActCascade: + return "CASCADE" + case nodes.RefActSetNull: + return "SET NULL" + case nodes.RefActSetDefault: + return "SET DEFAULT" + case nodes.RefActNoAction: + return "NO ACTION" + default: + return "NO ACTION" + } +} diff --git a/tidb/catalog/tablecmds_test.go b/tidb/catalog/tablecmds_test.go new file mode 100644 index 00000000..2dc90ca6 --- /dev/null +++ b/tidb/catalog/tablecmds_test.go @@ -0,0 +1,306 @@ +package catalog + +import "testing" + +func mustExec(t *testing.T, c *Catalog, sql string) { + t.Helper() + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("exec error: %v", r.Error) + } + } +} + +func setupWithDB(t *testing.T) *Catalog { + t.Helper() + c := New() + mustExec(t, c, "CREATE DATABASE testdb") + c.SetCurrentDatabase("testdb") + return c +} + +func TestCreateTableBasic(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) NOT NULL, + email VARCHAR(255), + age INT UNSIGNED DEFAULT 0, + score DECIMAL(10,2), + PRIMARY KEY (id) + )`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("users") + if tbl == nil { + t.Fatal("table users not found") + } + if len(tbl.Columns) != 5 { + t.Fatalf("expected 5 columns, got %d", len(tbl.Columns)) + } + + // Check id column. + id := tbl.GetColumn("id") + if id == nil { + t.Fatal("column id not found") + } + if id.Nullable { + t.Error("id should not be nullable") + } + if !id.AutoIncrement { + t.Error("id should be auto_increment") + } + if id.DataType != "int" { + t.Errorf("expected data type 'int', got %q", id.DataType) + } + if id.Position != 1 { + t.Errorf("expected position 1, got %d", id.Position) + } + + // Check name column. + name := tbl.GetColumn("name") + if name == nil { + t.Fatal("column name not found") + } + if name.Nullable { + t.Error("name should not be nullable") + } + if name.ColumnType != "varchar(100)" { + t.Errorf("expected column type 'varchar(100)', got %q", name.ColumnType) + } + + // Check email column (nullable by default). + email := tbl.GetColumn("email") + if email == nil { + t.Fatal("column email not found") + } + if !email.Nullable { + t.Error("email should be nullable by default") + } + + // Check age column (unsigned, default). + age := tbl.GetColumn("age") + if age == nil { + t.Fatal("column age not found") + } + if age.ColumnType != "int unsigned" { + t.Errorf("expected column type 'int unsigned', got %q", age.ColumnType) + } + if age.Default == nil || *age.Default != "0" { + t.Errorf("expected default '0', got %v", age.Default) + } + + // Check score column. + score := tbl.GetColumn("score") + if score == nil { + t.Fatal("column score not found") + } + if score.ColumnType != "decimal(10,2)" { + t.Errorf("expected column type 'decimal(10,2)', got %q", score.ColumnType) + } + + // Check PK index. + if len(tbl.Indexes) < 1 { + t.Fatal("expected at least 1 index") + } + pkIdx := tbl.Indexes[0] + if pkIdx.Name != "PRIMARY" { + t.Errorf("expected PK index name 'PRIMARY', got %q", pkIdx.Name) + } + if !pkIdx.Primary { + t.Error("expected primary flag on PK index") + } + if !pkIdx.Unique { + t.Error("expected unique flag on PK index") + } +} + +func TestCreateTableDuplicate(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE t1 (id INT)") + results, _ := c.Exec("CREATE TABLE t1 (id INT)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate table error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrDupTable { + t.Errorf("expected error code %d, got %d", ErrDupTable, catErr.Code) + } +} + +func TestCreateTableIfNotExists(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE t1 (id INT)") + results, _ := c.Exec("CREATE TABLE IF NOT EXISTS t1 (id INT)", nil) + if results[0].Error != nil { + t.Errorf("IF NOT EXISTS should not error: %v", results[0].Error) + } +} + +func TestCreateTableDupColumn(t *testing.T) { + c := setupWithDB(t) + results, _ := c.Exec("CREATE TABLE t1 (id INT, id VARCHAR(10))", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate column error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrDupColumn { + t.Errorf("expected error code %d, got %d", ErrDupColumn, catErr.Code) + } +} + +func TestCreateTableMultiplePK(t *testing.T) { + c := setupWithDB(t) + results, _ := c.Exec(`CREATE TABLE t1 ( + id INT PRIMARY KEY, + name VARCHAR(100), + PRIMARY KEY (name) + )`, &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected multiple primary key error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrMultiplePriKey { + t.Errorf("expected error code %d, got %d", ErrMultiplePriKey, catErr.Code) + } +} + +func TestCreateTableNoDatabaseSelected(t *testing.T) { + c := New() + results, _ := c.Exec("CREATE TABLE t1 (id INT)", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected no database selected error") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != ErrNoDatabaseSelected { + t.Errorf("expected error code %d, got %d", ErrNoDatabaseSelected, catErr.Code) + } +} + +func TestCreateTableWithIndexes(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL AUTO_INCREMENT, + email VARCHAR(255) NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + UNIQUE KEY idx_email (email), + INDEX idx_name (name) + )`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + + // Should have 3 indexes: PRIMARY, idx_email, idx_name. + if len(tbl.Indexes) != 3 { + t.Fatalf("expected 3 indexes, got %d", len(tbl.Indexes)) + } + + // Check UNIQUE KEY. + var uniqueIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_email" { + uniqueIdx = idx + break + } + } + if uniqueIdx == nil { + t.Fatal("unique index idx_email not found") + } + if !uniqueIdx.Unique { + t.Error("idx_email should be unique") + } + if len(uniqueIdx.Columns) != 1 || uniqueIdx.Columns[0].Name != "email" { + t.Errorf("idx_email should have column 'email'") + } + + // Check regular INDEX. + var regIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + regIdx = idx + break + } + } + if regIdx == nil { + t.Fatal("index idx_name not found") + } + if regIdx.Unique { + t.Error("idx_name should not be unique") + } +} + +func TestCreateTableWithFK(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE departments (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL AUTO_INCREMENT, + dept_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_dept FOREIGN KEY (dept_id) REFERENCES departments(id) ON DELETE CASCADE + )`) + + db := c.GetDatabase("testdb") + tbl := db.GetTable("employees") + if tbl == nil { + t.Fatal("table employees not found") + } + + // Check FK constraint. + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint not found") + } + if fk.Name != "fk_dept" { + t.Errorf("expected FK name 'fk_dept', got %q", fk.Name) + } + if fk.RefTable != "departments" { + t.Errorf("expected ref table 'departments', got %q", fk.RefTable) + } + if len(fk.RefColumns) != 1 || fk.RefColumns[0] != "id" { + t.Errorf("expected ref column 'id', got %v", fk.RefColumns) + } + if len(fk.Columns) != 1 || fk.Columns[0] != "dept_id" { + t.Errorf("expected column 'dept_id', got %v", fk.Columns) + } + if fk.OnDelete != "CASCADE" { + t.Errorf("expected ON DELETE CASCADE, got %q", fk.OnDelete) + } + + // Check that FK has a backing index (MySQL uses constraint name when provided). + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk_dept" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("FK backing index not found") + } +} diff --git a/tidb/catalog/triggercmds.go b/tidb/catalog/triggercmds.go new file mode 100644 index 00000000..7d975e5c --- /dev/null +++ b/tidb/catalog/triggercmds.go @@ -0,0 +1,138 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +func (c *Catalog) createTrigger(stmt *nodes.CreateTriggerStmt) error { + // Resolve database from the table reference. + schema := "" + if stmt.Table != nil { + schema = stmt.Table.Schema + } + db, err := c.resolveDatabase(schema) + if err != nil { + return err + } + + // Verify the table exists. + tableName := "" + if stmt.Table != nil { + tableName = stmt.Table.Name + } + if tableName != "" { + tbl := db.GetTable(tableName) + if tbl == nil { + return errNoSuchTable(db.Name, tableName) + } + } + + name := stmt.Name + key := toLower(name) + + if _, exists := db.Triggers[key]; exists { + if !stmt.IfNotExists { + return errDupTrigger(name) + } + return nil + } + + // MySQL always sets a definer. Default to `root`@`%` when not specified. + definer := stmt.Definer + if definer == "" { + definer = "`root`@`%`" + } + + trigger := &Trigger{ + Name: name, + Database: db, + Table: tableName, + Timing: stmt.Timing, + Event: stmt.Event, + Definer: definer, + Body: strings.TrimSpace(stmt.Body), + } + + if stmt.Order != nil { + trigger.Order = &TriggerOrderInfo{ + Follows: stmt.Order.Follows, + TriggerName: stmt.Order.TriggerName, + } + } + + db.Triggers[key] = trigger + return nil +} + +func (c *Catalog) dropTrigger(stmt *nodes.DropTriggerStmt) error { + schema := "" + if stmt.Name != nil { + schema = stmt.Name.Schema + } + db, err := c.resolveDatabase(schema) + if err != nil { + if stmt.IfExists { + return nil + } + return err + } + + name := "" + if stmt.Name != nil { + name = stmt.Name.Name + } + key := toLower(name) + + if _, exists := db.Triggers[key]; !exists { + if stmt.IfExists { + return nil + } + return errNoSuchTrigger(db.Name, name) + } + + delete(db.Triggers, key) + return nil +} + +// ShowCreateTrigger produces MySQL 8.0-compatible SHOW CREATE TRIGGER output. +// +// MySQL 8.0 SHOW CREATE TRIGGER format: +// +// CREATE DEFINER=`root`@`%` TRIGGER `trigger_name` BEFORE INSERT ON `table_name` FOR EACH ROW trigger_body +func (c *Catalog) ShowCreateTrigger(database, name string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + trigger := db.Triggers[toLower(name)] + if trigger == nil { + return "" + } + return showCreateTrigger(trigger) +} + +func showCreateTrigger(tr *Trigger) string { + var b strings.Builder + + b.WriteString("CREATE") + + // DEFINER + if tr.Definer != "" { + b.WriteString(fmt.Sprintf(" DEFINER=%s", tr.Definer)) + } + + b.WriteString(fmt.Sprintf(" TRIGGER `%s` %s %s ON `%s` FOR EACH ROW", + tr.Name, tr.Timing, tr.Event, tr.Table)) + + // Note: MySQL 8.0 SHOW CREATE TRIGGER does NOT include FOLLOWS/PRECEDES. + + // Body + if tr.Body != "" { + b.WriteString(fmt.Sprintf(" %s", tr.Body)) + } + + return b.String() +} diff --git a/tidb/catalog/viewcmds.go b/tidb/catalog/viewcmds.go new file mode 100644 index 00000000..6c0e4f4c --- /dev/null +++ b/tidb/catalog/viewcmds.go @@ -0,0 +1,323 @@ +package catalog + +import ( + "fmt" + "strings" + + nodes "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/deparse" +) + +func (c *Catalog) createView(stmt *nodes.CreateViewStmt) error { + db, err := c.resolveDatabase(stmt.Name.Schema) + if err != nil { + return err + } + key := toLower(stmt.Name.Name) + // Tables and views share the same namespace in MySQL. + if _, exists := db.Tables[key]; exists { + return errDupTable(stmt.Name.Name) + } + if _, exists := db.Views[key]; exists { + if !stmt.OrReplace { + return errDupTable(stmt.Name.Name) + } + } + + // MySQL always sets a definer. Default to `root`@`%` when not specified. + definer := stmt.Definer + if definer == "" { + definer = "`root`@`%`" + } + + // Analyze the view body for semantic IR (used by lineage, SDL diff). + // Done before deparseViewSelect because deparse mutates the AST. + var analyzedQuery *Query + if stmt.Select != nil { + q, err := c.AnalyzeSelectStmt(stmt.Select) + if err == nil { + analyzedQuery = q + } + // Swallow error: view analysis may fail for complex views not yet + // supported by the analyzer. The view still gets created with + // Definition text but without AnalyzedQuery. + } + + // Resolve, rewrite, and deparse the SELECT to produce canonical definition. + definition, derivedCols := c.deparseViewSelect(stmt.Select, stmt.SelectText, db) + + // Use explicit column list if provided, otherwise derive from SELECT target list. + hasExplicit := len(stmt.Columns) > 0 + viewCols := stmt.Columns + if !hasExplicit { + viewCols = derivedCols + } + + db.Views[key] = &View{ + Name: stmt.Name.Name, + Database: db, + Definition: definition, + Algorithm: stmt.Algorithm, + Definer: definer, + SqlSecurity: stmt.SqlSecurity, + CheckOption: stmt.CheckOption, + Columns: viewCols, + ExplicitColumns: hasExplicit, + AnalyzedQuery: analyzedQuery, + } + return nil +} + +func (c *Catalog) alterView(stmt *nodes.AlterViewStmt) error { + db, err := c.resolveDatabase(stmt.Name.Schema) + if err != nil { + return err + } + key := toLower(stmt.Name.Name) + // ALTER VIEW requires the view to exist. + if _, exists := db.Views[key]; !exists { + return errUnknownTable(db.Name, stmt.Name.Name) + } + + // MySQL always sets a definer. Default to `root`@`%` when not specified. + definer := stmt.Definer + if definer == "" { + definer = "`root`@`%`" + } + + // Analyze the view body for semantic IR (used by lineage, SDL diff). + // Done before deparseViewSelect because deparse mutates the AST. + var analyzedQuery *Query + if stmt.Select != nil { + q, err := c.AnalyzeSelectStmt(stmt.Select) + if err == nil { + analyzedQuery = q + } + } + + // Resolve, rewrite, and deparse the SELECT to produce canonical definition. + definition, derivedCols := c.deparseViewSelect(stmt.Select, stmt.SelectText, db) + + // Use explicit column list if provided, otherwise derive from SELECT target list. + hasExplicit := len(stmt.Columns) > 0 + viewCols := stmt.Columns + if !hasExplicit { + viewCols = derivedCols + } + + db.Views[key] = &View{ + Name: stmt.Name.Name, + Database: db, + Definition: definition, + Algorithm: stmt.Algorithm, + Definer: definer, + SqlSecurity: stmt.SqlSecurity, + CheckOption: stmt.CheckOption, + Columns: viewCols, + ExplicitColumns: hasExplicit, + AnalyzedQuery: analyzedQuery, + } + return nil +} + +func (c *Catalog) dropView(stmt *nodes.DropViewStmt) error { + for _, ref := range stmt.Views { + db, err := c.resolveDatabase(ref.Schema) + if err != nil { + if stmt.IfExists { + continue + } + return err + } + key := toLower(ref.Name) + if _, exists := db.Views[key]; !exists { + if stmt.IfExists { + continue + } + return errUnknownTable(db.Name, ref.Name) + } + delete(db.Views, key) + } + return nil +} + +// deparseViewSelect resolves, rewrites, and deparses the SELECT AST for a view. +// If the AST is nil (parser didn't produce one), falls back to the raw SelectText. +// Returns the deparsed definition and the derived column names from the resolved +// SELECT target list (used when no explicit column list is specified). +func (c *Catalog) deparseViewSelect(sel *nodes.SelectStmt, rawText string, db *Database) (string, []string) { + if sel == nil { + return rawText, nil + } + + // Build a TableLookup that resolves table names from this database. + lookup := tableLookupForDB(db) + + // Determine the database charset for CAST resolution. + charset := db.Charset + if charset == "" { + charset = c.defaultCharset + } + + // Resolve: qualify columns, expand *, normalize JOINs. + resolver := &deparse.Resolver{ + Lookup: lookup, + DefaultCharset: charset, + } + resolver.Resolve(sel) + + // Extract column names from the resolved target list. + derivedCols := extractViewColumns(sel) + + // Rewrite: NOT folding, boolean context wrapping. + deparse.RewriteSelectStmt(sel) + + // Deparse: AST → canonical SQL text. + return deparse.DeparseSelect(sel), derivedCols +} + +// extractViewColumns extracts column names from a resolved SELECT target list. +// This produces the column list that MySQL would derive for a view. +func extractViewColumns(sel *nodes.SelectStmt) []string { + if sel == nil { + return nil + } + var cols []string + for _, target := range sel.TargetList { + rt, ok := target.(*nodes.ResTarget) + if !ok { + continue + } + if rt.Name != "" { + cols = append(cols, rt.Name) + } else if cr, ok := rt.Val.(*nodes.ColumnRef); ok { + cols = append(cols, cr.Column) + } + } + return cols +} + +// tableLookupForDB returns a deparse.TableLookup function that resolves table +// and view names from the given database's Tables and Views maps. +func tableLookupForDB(db *Database) deparse.TableLookup { + return func(tableName string) *deparse.ResolverTable { + key := toLower(tableName) + // Try tables first. + tbl := db.Tables[key] + if tbl != nil { + cols := make([]deparse.ResolverColumn, len(tbl.Columns)) + for i, c := range tbl.Columns { + cols[i] = deparse.ResolverColumn{ + Name: c.Name, + Position: c.Position, + } + } + return &deparse.ResolverTable{ + Name: tbl.Name, + Columns: cols, + } + } + // Fall back to views. + v := db.Views[key] + if v != nil { + cols := make([]deparse.ResolverColumn, len(v.Columns)) + for i, colName := range v.Columns { + cols[i] = deparse.ResolverColumn{ + Name: colName, + Position: i + 1, + } + } + return &deparse.ResolverTable{ + Name: v.Name, + Columns: cols, + } + } + return nil + } +} + +// ShowCreateView produces MySQL 8.0-compatible SHOW CREATE VIEW output. +// Returns "" if the database or view does not exist. +func (c *Catalog) ShowCreateView(database, name string) string { + db := c.GetDatabase(database) + if db == nil { + return "" + } + v := db.Views[toLower(name)] + if v == nil { + return "" + } + return showCreateView(v) +} + +// formatDefiner ensures the definer string is backtick-quoted per MySQL 8.0 format. +// Input can be: `root`@`%`, root@%, 'root'@'%', etc. +// Output: `root`@`%` +func formatDefiner(definer string) string { + // If already formatted with backticks, return as-is. + if strings.HasPrefix(definer, "`") && strings.Contains(definer, "@") { + return definer + } + // Split on @ + parts := strings.SplitN(definer, "@", 2) + if len(parts) == 1 { + // No @ — just backtick-quote the whole thing. + return "`" + strings.Trim(parts[0], "`'") + "`" + } + user := strings.Trim(parts[0], "`'") + host := strings.Trim(parts[1], "`'") + return fmt.Sprintf("`%s`@`%s`", user, host) +} + +// showCreateView produces the SHOW CREATE VIEW output for a view. +// MySQL 8.0 format: +// +// CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `view_name` AS select_statement +// WITH CASCADED CHECK OPTION +func showCreateView(v *View) string { + var b strings.Builder + + b.WriteString("CREATE") + + // ALGORITHM — MySQL 8.0 always shows ALGORITHM, defaults to UNDEFINED. + algorithm := v.Algorithm + if algorithm == "" { + algorithm = "UNDEFINED" + } + b.WriteString(fmt.Sprintf(" ALGORITHM=%s", strings.ToUpper(algorithm))) + + // DEFINER — MySQL 8.0 always shows DEFINER with backtick-quoted user@host. + if v.Definer != "" { + b.WriteString(fmt.Sprintf(" DEFINER=%s", formatDefiner(v.Definer))) + } + + // SQL SECURITY — MySQL 8.0 always shows SQL SECURITY, defaults to DEFINER. + sqlSecurity := v.SqlSecurity + if sqlSecurity == "" { + sqlSecurity = "DEFINER" + } + b.WriteString(fmt.Sprintf(" SQL SECURITY %s", strings.ToUpper(sqlSecurity))) + + // VIEW name + b.WriteString(fmt.Sprintf(" VIEW `%s`", v.Name)) + + // Column list (only if explicitly specified by user in CREATE VIEW). + if v.ExplicitColumns && len(v.Columns) > 0 { + cols := make([]string, len(v.Columns)) + for i, c := range v.Columns { + cols[i] = fmt.Sprintf("`%s`", c) + } + b.WriteString(fmt.Sprintf(" (%s)", strings.Join(cols, ","))) + } + + // AS select_statement + b.WriteString(" AS ") + b.WriteString(v.Definition) + + // WITH CHECK OPTION + if v.CheckOption != "" { + b.WriteString(fmt.Sprintf(" WITH %s CHECK OPTION", strings.ToUpper(v.CheckOption))) + } + + return b.String() +} diff --git a/tidb/catalog/viewcmds_test.go b/tidb/catalog/viewcmds_test.go new file mode 100644 index 00000000..88be505c --- /dev/null +++ b/tidb/catalog/viewcmds_test.go @@ -0,0 +1,153 @@ +package catalog + +import ( + "strings" + "testing" +) + +func TestCreateView(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + _, err := c.Exec("CREATE VIEW v1 AS SELECT 1", nil) + if err != nil { + t.Fatal(err) + } + db := c.GetDatabase("test") + if db.Views[toLower("v1")] == nil { + t.Fatal("view should exist") + } +} + +func TestCreateViewOrReplace(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v1 AS SELECT 1", nil) + results, _ := c.Exec("CREATE OR REPLACE VIEW v1 AS SELECT 2", nil) + if results[0].Error != nil { + t.Fatalf("OR REPLACE should not error: %v", results[0].Error) + } + if c.GetDatabase("test").Views[toLower("v1")] == nil { + t.Fatal("view should still exist after replace") + } +} + +func TestCreateViewDuplicate(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v1 AS SELECT 1", nil) + results, _ := c.Exec("CREATE VIEW v1 AS SELECT 2", &ExecOptions{ContinueOnError: true}) + if results[0].Error == nil { + t.Fatal("expected duplicate error") + } +} + +func TestDropView(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE VIEW v1 AS SELECT 1", nil) + _, err := c.Exec("DROP VIEW v1", nil) + if err != nil { + t.Fatal(err) + } + if c.GetDatabase("test").Views[toLower("v1")] != nil { + t.Fatal("view should be dropped") + } +} + +func TestDropViewIfExists(t *testing.T) { + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + results, _ := c.Exec("DROP VIEW IF EXISTS noexist", nil) + if results[0].Error != nil { + t.Errorf("IF EXISTS should not error: %v", results[0].Error) + } +} + +// TestSection_7_1_ViewCreationPipeline verifies that createView() calls +// resolver + deparser instead of storing raw SelectText. +func TestSection_7_1_ViewCreationPipeline(t *testing.T) { + t.Run("createView_uses_deparser", func(t *testing.T) { + // Create a view with a schema-aware SELECT; the stored Definition + // should be the deparsed output, not the raw input. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t (a INT, b INT)", nil) + + results, _ := c.Exec("CREATE VIEW v1 AS SELECT a, b FROM t", nil) + if results[0].Error != nil { + t.Fatalf("CREATE VIEW error: %v", results[0].Error) + } + + db := c.GetDatabase("test") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 should exist") + } + + // The definition should contain qualified columns (from resolver) + // and proper formatting (from deparser). + if !strings.Contains(v.Definition, "`t`.`a`") { + t.Errorf("Definition should contain qualified column `t`.`a`, got: %s", v.Definition) + } + if !strings.Contains(v.Definition, "AS `a`") { + t.Errorf("Definition should contain alias AS `a`, got: %s", v.Definition) + } + + // The raw input "SELECT a, b FROM t" should NOT be stored verbatim. + if v.Definition == "SELECT a, b FROM t" { + t.Error("Definition should be deparsed, not raw input") + } + + expected := "select `t`.`a` AS `a`,`t`.`b` AS `b` from `t`" + if v.Definition != expected { + t.Errorf("Definition mismatch:\n got: %q\n want: %q", v.Definition, expected) + } + }) + + t.Run("definition_contains_deparsed_sql", func(t *testing.T) { + // Verify that View.Definition has deparsed SQL with column qualification + // and function rewrites, not the raw input text. + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t (a INT)", nil) + + c.Exec("CREATE VIEW v1 AS SELECT * FROM t WHERE a > 0", nil) + db := c.GetDatabase("test") + v := db.Views[toLower("v1")] + + // * should be expanded to named columns + expected := "select `t`.`a` AS `a` from `t` where (`t`.`a` > 0)" + if v.Definition != expected { + t.Errorf("Definition mismatch:\n got: %q\n want: %q", v.Definition, expected) + } + }) + + t.Run("preamble_format", func(t *testing.T) { + // Verify SHOW CREATE VIEW preamble matches MySQL 8.0 format: + // CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `v` AS ... + c := New() + c.Exec("CREATE DATABASE test", nil) + c.SetCurrentDatabase("test") + c.Exec("CREATE TABLE t (a INT)", nil) + c.Exec("CREATE VIEW v1 AS SELECT a FROM t", nil) + + ddl := c.ShowCreateView("test", "v1") + expectedPrefix := "CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`%` SQL SECURITY DEFINER VIEW `v1` AS " + if !strings.HasPrefix(ddl, expectedPrefix) { + t.Errorf("SHOW CREATE VIEW preamble mismatch:\n got: %q\n want prefix: %q", ddl, expectedPrefix) + } + + // Full output check + expectedFull := expectedPrefix + "select `t`.`a` AS `a` from `t`" + if ddl != expectedFull { + t.Errorf("SHOW CREATE VIEW full mismatch:\n got: %q\n want: %q", ddl, expectedFull) + } + }) +} diff --git a/tidb/catalog/wt_10_1_test.go b/tidb/catalog/wt_10_1_test.go new file mode 100644 index 00000000..57d9f400 --- /dev/null +++ b/tidb/catalog/wt_10_1_test.go @@ -0,0 +1,363 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 6.1 (Phase 6): Basic LIKE Completeness (9 scenarios) --- +// File target: wt_10_1_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_10_1" + +func TestWalkThrough_10_1_BasicLIKECompleteness(t *testing.T) { + // Scenario 1: LIKE copies all column definitions (name, type, nullability, default, comment) + t.Run("like_copies_column_definitions", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(100) NOT NULL DEFAULT 'unknown' COMMENT 'user name', + email VARCHAR(255) DEFAULT NULL, + score DECIMAL(10,2) NOT NULL DEFAULT 0.00, + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + srcTbl := c.GetDatabase("testdb").GetTable("src") + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + if len(dstTbl.Columns) != len(srcTbl.Columns) { + t.Fatalf("expected %d columns, got %d", len(srcTbl.Columns), len(dstTbl.Columns)) + } + + // Verify each column attribute + for i, srcCol := range srcTbl.Columns { + dstCol := dstTbl.Columns[i] + if dstCol.Name != srcCol.Name { + t.Errorf("col %d: name mismatch: %q vs %q", i, srcCol.Name, dstCol.Name) + } + if dstCol.ColumnType != srcCol.ColumnType { + t.Errorf("col %q: type mismatch: %q vs %q", srcCol.Name, srcCol.ColumnType, dstCol.ColumnType) + } + if dstCol.Nullable != srcCol.Nullable { + t.Errorf("col %q: nullable mismatch: %v vs %v", srcCol.Name, srcCol.Nullable, dstCol.Nullable) + } + if dstCol.Comment != srcCol.Comment { + t.Errorf("col %q: comment mismatch: %q vs %q", srcCol.Name, srcCol.Comment, dstCol.Comment) + } + // Compare defaults + if (srcCol.Default == nil) != (dstCol.Default == nil) { + t.Errorf("col %q: default nil mismatch", srcCol.Name) + } else if srcCol.Default != nil && *srcCol.Default != *dstCol.Default { + t.Errorf("col %q: default value mismatch: %q vs %q", srcCol.Name, *srcCol.Default, *dstCol.Default) + } + } + + // Verify SHOW CREATE TABLE DDLs are structurally similar + srcDDL := c.ShowCreateTable("testdb", "src") + dstDDL := c.ShowCreateTable("testdb", "dst") + // Replace table names for comparison + srcNorm := strings.Replace(srcDDL, "`src`", "`dst`", 1) + if srcNorm != dstDDL { + t.Errorf("SHOW CREATE TABLE mismatch:\nsrc:\n%s\ndst:\n%s", srcDDL, dstDDL) + } + }) + + // Scenario 2: LIKE copies PRIMARY KEY + t.Run("like_copies_primary_key", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(50) NOT NULL, + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + // Check PRIMARY index exists + var hasPK bool + for _, idx := range dstTbl.Indexes { + if idx.Primary { + hasPK = true + if len(idx.Columns) != 1 || idx.Columns[0].Name != "id" { + t.Errorf("expected PK on (id), got %v", idx.Columns) + } + break + } + } + if !hasPK { + t.Error("expected PRIMARY KEY on dst table") + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "PRIMARY KEY") { + t.Errorf("expected PRIMARY KEY in DDL:\n%s", ddl) + } + }) + + // Scenario 3: LIKE copies UNIQUE KEYs + t.Run("like_copies_unique_keys", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + email VARCHAR(255) NOT NULL, + code VARCHAR(20) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY uk_email (email), + UNIQUE KEY uk_code (code) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + uniqueCount := 0 + for _, idx := range dstTbl.Indexes { + if idx.Unique && !idx.Primary { + uniqueCount++ + } + } + if uniqueCount != 2 { + t.Errorf("expected 2 unique keys, got %d", uniqueCount) + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "UNIQUE KEY `uk_email`") { + t.Errorf("expected UNIQUE KEY uk_email in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "UNIQUE KEY `uk_code`") { + t.Errorf("expected UNIQUE KEY uk_code in DDL:\n%s", ddl) + } + }) + + // Scenario 4: LIKE copies regular indexes + t.Run("like_copies_regular_indexes", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(100), + age INT, + PRIMARY KEY (id), + KEY idx_name (name), + KEY idx_age (age) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + regularCount := 0 + for _, idx := range dstTbl.Indexes { + if !idx.Primary && !idx.Unique && !idx.Fulltext && !idx.Spatial { + regularCount++ + } + } + if regularCount != 2 { + t.Errorf("expected 2 regular indexes, got %d", regularCount) + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "KEY `idx_name`") { + t.Errorf("expected KEY idx_name in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "KEY `idx_age`") { + t.Errorf("expected KEY idx_age in DDL:\n%s", ddl) + } + }) + + // Scenario 5: LIKE copies FULLTEXT indexes + t.Run("like_copies_fulltext_indexes", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + content TEXT, + PRIMARY KEY (id), + FULLTEXT KEY ft_content (content) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + var hasFT bool + for _, idx := range dstTbl.Indexes { + if idx.Fulltext { + hasFT = true + if idx.Name != "ft_content" { + t.Errorf("expected fulltext index named ft_content, got %q", idx.Name) + } + break + } + } + if !hasFT { + t.Error("expected FULLTEXT index on dst table") + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "FULLTEXT KEY `ft_content`") { + t.Errorf("expected FULLTEXT KEY ft_content in DDL:\n%s", ddl) + } + }) + + // Scenario 6: LIKE copies CHECK constraints + t.Run("like_copies_check_constraints", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + age INT, + PRIMARY KEY (id), + CONSTRAINT chk_age CHECK (age >= 0) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + var hasCheck bool + for _, con := range dstTbl.Constraints { + if con.Type == ConCheck { + hasCheck = true + if con.Name != "chk_age" { + t.Errorf("expected check constraint named chk_age, got %q", con.Name) + } + break + } + } + if !hasCheck { + t.Error("expected CHECK constraint on dst table") + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "CONSTRAINT `chk_age` CHECK") { + t.Errorf("expected CHECK constraint in DDL:\n%s", ddl) + } + }) + + // Scenario 7: LIKE does NOT copy FOREIGN KEY constraints — MySQL 8.0 behavior + t.Run("like_does_not_copy_foreign_keys", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE parent ( + id INT NOT NULL, + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + // No FK constraints should be copied + for _, con := range dstTbl.Constraints { + if con.Type == ConForeignKey { + t.Errorf("LIKE should NOT copy FK constraints, but found FK %q", con.Name) + } + } + + // The FK-backing index IS still copied (it's a regular KEY on the source) + ddl := c.ShowCreateTable("testdb", "dst") + if strings.Contains(ddl, "FOREIGN KEY") { + t.Errorf("expected no FOREIGN KEY in dst DDL:\n%s", ddl) + } + // The backing index for the FK should still be present + if !strings.Contains(ddl, "KEY `fk_parent`") { + t.Errorf("expected FK-backing index to be copied:\n%s", ddl) + } + }) + + // Scenario 8: LIKE copies AUTO_INCREMENT column attribute — but counter resets to 0 + t.Run("like_copies_auto_increment_attribute_resets_counter", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100), + PRIMARY KEY (id) + )`) + // Insert to advance the counter on src + wtExec(t, c, `INSERT INTO src (name) VALUES ('a'), ('b'), ('c')`) + + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + // AUTO_INCREMENT attribute should be copied + idCol := dstTbl.GetColumn("id") + if idCol == nil { + t.Fatal("column id not found on dst") + } + if !idCol.AutoIncrement { + t.Error("expected AUTO_INCREMENT attribute on dst.id") + } + + // Counter should reset — dst DDL should NOT show AUTO_INCREMENT=N (or show AUTO_INCREMENT=1 at most) + dstDDL := c.ShowCreateTable("testdb", "dst") + if strings.Contains(dstDDL, "AUTO_INCREMENT=") { + // MySQL 8.0 does not show AUTO_INCREMENT=N on empty tables + t.Errorf("expected no AUTO_INCREMENT=N on empty dst table:\n%s", dstDDL) + } + + // Verify the column-level auto_increment keyword is present + if !strings.Contains(dstDDL, "AUTO_INCREMENT") { + t.Errorf("expected AUTO_INCREMENT column attribute in DDL:\n%s", dstDDL) + } + }) + + // Scenario 9: LIKE copies ENGINE, CHARSET, COLLATION, COMMENT + t.Run("like_copies_table_options", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='source table'`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + srcTbl := c.GetDatabase("testdb").GetTable("src") + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + if dstTbl.Engine != srcTbl.Engine { + t.Errorf("ENGINE mismatch: %q vs %q", srcTbl.Engine, dstTbl.Engine) + } + if dstTbl.Charset != srcTbl.Charset { + t.Errorf("CHARSET mismatch: %q vs %q", srcTbl.Charset, dstTbl.Charset) + } + if dstTbl.Collation != srcTbl.Collation { + t.Errorf("COLLATION mismatch: %q vs %q", srcTbl.Collation, dstTbl.Collation) + } + if dstTbl.Comment != srcTbl.Comment { + t.Errorf("COMMENT mismatch: %q vs %q", srcTbl.Comment, dstTbl.Comment) + } + + ddl := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(ddl, "ENGINE=InnoDB") { + t.Errorf("expected ENGINE=InnoDB in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "COMMENT='source table'") { + t.Errorf("expected COMMENT='source table' in DDL:\n%s", ddl) + } + }) +} diff --git a/tidb/catalog/wt_10_2_test.go b/tidb/catalog/wt_10_2_test.go new file mode 100644 index 00000000..8efec721 --- /dev/null +++ b/tidb/catalog/wt_10_2_test.go @@ -0,0 +1,328 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 6.2 (Phase 6): LIKE Edge Cases (7 scenarios) --- +// File target: wt_10_2_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_10_2" + +func TestWalkThrough_10_2_LIKEEdgeCases(t *testing.T) { + // Scenario 1: LIKE copies generated columns — expression and VIRTUAL/STORED preserved + t.Run("like_copies_generated_columns", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + price DECIMAL(10,2) NOT NULL, + qty INT NOT NULL, + total DECIMAL(10,2) AS (price * qty) STORED, + label VARCHAR(100) AS (CONCAT('item-', id)) VIRTUAL, + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + // Check STORED generated column + totalCol := dstTbl.GetColumn("total") + if totalCol == nil { + t.Fatal("column total not found in dst") + } + if totalCol.Generated == nil { + t.Fatal("total should be a generated column") + } + if !totalCol.Generated.Stored { + t.Error("total should be STORED") + } + if totalCol.Generated.Expr == "" { + t.Error("total generated expression should not be empty") + } + + // Check VIRTUAL generated column + labelCol := dstTbl.GetColumn("label") + if labelCol == nil { + t.Fatal("column label not found in dst") + } + if labelCol.Generated == nil { + t.Fatal("label should be a generated column") + } + if labelCol.Generated.Stored { + t.Error("label should be VIRTUAL (not stored)") + } + if labelCol.Generated.Expr == "" { + t.Error("label generated expression should not be empty") + } + + // Verify expressions match source + srcTbl := c.GetDatabase("testdb").GetTable("src") + srcTotal := srcTbl.GetColumn("total") + srcLabel := srcTbl.GetColumn("label") + if totalCol.Generated.Expr != srcTotal.Generated.Expr { + t.Errorf("total expression mismatch: src=%q dst=%q", srcTotal.Generated.Expr, totalCol.Generated.Expr) + } + if labelCol.Generated.Expr != srcLabel.Generated.Expr { + t.Errorf("label expression mismatch: src=%q dst=%q", srcLabel.Generated.Expr, labelCol.Generated.Expr) + } + }) + + // Scenario 2: LIKE copies INVISIBLE columns + t.Run("like_copies_invisible_columns", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + visible_col VARCHAR(100), + hidden_col INT INVISIBLE, + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + hiddenCol := dstTbl.GetColumn("hidden_col") + if hiddenCol == nil { + t.Fatal("column hidden_col not found in dst") + } + if !hiddenCol.Invisible { + t.Error("hidden_col should be INVISIBLE in dst") + } + + visibleCol := dstTbl.GetColumn("visible_col") + if visibleCol == nil { + t.Fatal("column visible_col not found in dst") + } + if visibleCol.Invisible { + t.Error("visible_col should NOT be invisible in dst") + } + + // Verify SHOW CREATE TABLE renders INVISIBLE + sct := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(sct, "INVISIBLE") { + t.Errorf("SHOW CREATE TABLE should contain INVISIBLE keyword, got:\n%s", sct) + } + }) + + // Scenario 3: LIKE does NOT copy partitioning — target table is unpartitioned + t.Run("like_does_not_copy_partitioning", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + created_at DATE NOT NULL, + PRIMARY KEY (id, created_at) + ) PARTITION BY RANGE (YEAR(created_at)) ( + PARTITION p2020 VALUES LESS THAN (2021), + PARTITION p2021 VALUES LESS THAN (2022), + PARTITION pmax VALUES LESS THAN MAXVALUE + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + // Source should have partitioning + srcTbl := c.GetDatabase("testdb").GetTable("src") + if srcTbl.Partitioning == nil { + t.Fatal("source table should have partitioning") + } + + // Destination should NOT have partitioning (MySQL behavior) + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + if dstTbl.Partitioning != nil { + t.Error("LIKE should NOT copy partitioning — target table should be unpartitioned") + } + + // SHOW CREATE TABLE for dst should not mention PARTITION + sct := c.ShowCreateTable("testdb", "dst") + if strings.Contains(sct, "PARTITION") { + t.Errorf("SHOW CREATE TABLE for dst should not contain PARTITION, got:\n%s", sct) + } + }) + + // Scenario 4: LIKE from table with prefix index — prefix length preserved + t.Run("like_copies_prefix_index", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(255), + bio TEXT, + PRIMARY KEY (id), + INDEX idx_name (name(50)), + INDEX idx_bio (bio(100)) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + + // Find idx_name and verify prefix length + var idxName, idxBio *Index + for _, idx := range dstTbl.Indexes { + switch idx.Name { + case "idx_name": + idxName = idx + case "idx_bio": + idxBio = idx + } + } + + if idxName == nil { + t.Fatal("index idx_name not found in dst") + } + if len(idxName.Columns) != 1 || idxName.Columns[0].Length != 50 { + t.Errorf("idx_name prefix length: expected 50, got %d", idxName.Columns[0].Length) + } + + if idxBio == nil { + t.Fatal("index idx_bio not found in dst") + } + if len(idxBio.Columns) != 1 || idxBio.Columns[0].Length != 100 { + t.Errorf("idx_bio prefix length: expected 100, got %d", idxBio.Columns[0].Length) + } + + // Verify SHOW CREATE TABLE renders prefix lengths + sct := c.ShowCreateTable("testdb", "dst") + if !strings.Contains(sct, "`name`(50)") { + t.Errorf("SHOW CREATE TABLE should contain name(50), got:\n%s", sct) + } + if !strings.Contains(sct, "`bio`(100)") { + t.Errorf("SHOW CREATE TABLE should contain bio(100), got:\n%s", sct) + } + }) + + // Scenario 5: LIKE into TEMPORARY TABLE + t.Run("like_into_temporary_table", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TEMPORARY TABLE dst LIKE src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + if !dstTbl.Temporary { + t.Error("dst should be a TEMPORARY table") + } + + // Columns should still be copied + if len(dstTbl.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(dstTbl.Columns)) + } + + // Source should NOT be temporary + srcTbl := c.GetDatabase("testdb").GetTable("src") + if srcTbl.Temporary { + t.Error("src should not be temporary") + } + }) + + // Scenario 6: LIKE cross-database — source in different database + t.Run("like_cross_database", func(t *testing.T) { + c := wtSetup(t) + // Create source table in a different database + wtExec(t, c, "CREATE DATABASE other_db") + wtExec(t, c, `CREATE TABLE other_db.src ( + id INT NOT NULL, + name VARCHAR(100) NOT NULL, + score INT DEFAULT 0, + PRIMARY KEY (id), + INDEX idx_name (name) + )`) + + // Create LIKE table in testdb referencing other_db.src + wtExec(t, c, `CREATE TABLE dst LIKE other_db.src`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found in testdb") + } + + // Verify columns were copied + if len(dstTbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(dstTbl.Columns)) + } + nameCol := dstTbl.GetColumn("name") + if nameCol == nil { + t.Fatal("column name not found in dst") + } + if nameCol.Nullable { + t.Error("name should be NOT NULL") + } + + scoreCol := dstTbl.GetColumn("score") + if scoreCol == nil { + t.Fatal("column score not found in dst") + } + if scoreCol.Default == nil || *scoreCol.Default != "0" { + def := "" + if scoreCol.Default != nil { + def = *scoreCol.Default + } + t.Errorf("score default: expected '0', got %q", def) + } + + // Verify indexes were copied + var idxName *Index + for _, idx := range dstTbl.Indexes { + if idx.Name == "idx_name" { + idxName = idx + break + } + } + if idxName == nil { + t.Fatal("index idx_name not found in dst") + } + + // dst should belong to testdb, not other_db + if dstTbl.Database.Name != "testdb" { + t.Errorf("dst should belong to testdb, got %q", dstTbl.Database.Name) + } + }) + + // Scenario 7: LIKE then ALTER TABLE ADD COLUMN — verify table is independently modifiable + t.Run("like_then_alter_add_column", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE src ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id) + )`) + wtExec(t, c, `CREATE TABLE dst LIKE src`) + + // Add a column to dst — should not affect src + wtExec(t, c, `ALTER TABLE dst ADD COLUMN email VARCHAR(255) NOT NULL`) + + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table dst not found") + } + if len(dstTbl.Columns) != 3 { + t.Fatalf("dst: expected 3 columns after ADD COLUMN, got %d", len(dstTbl.Columns)) + } + emailCol := dstTbl.GetColumn("email") + if emailCol == nil { + t.Fatal("column email not found in dst") + } + + // Source should be unaffected + srcTbl := c.GetDatabase("testdb").GetTable("src") + if len(srcTbl.Columns) != 2 { + t.Fatalf("src: expected 2 columns (unchanged), got %d", len(srcTbl.Columns)) + } + if srcTbl.GetColumn("email") != nil { + t.Error("src should NOT have email column — tables should be independent") + } + }) +} diff --git a/tidb/catalog/wt_11_1_test.go b/tidb/catalog/wt_11_1_test.go new file mode 100644 index 00000000..b0e91412 --- /dev/null +++ b/tidb/catalog/wt_11_1_test.go @@ -0,0 +1,252 @@ +package catalog + +import "testing" + +// --- Section 7.1 (Phase 7): Catalog State Isolation (8 scenarios) --- +// File target: wt_11_1_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_11_1" + +func TestWalkThrough_11_1_CatalogStateIsolation(t *testing.T) { + t.Run("separate_catalogs_independent", func(t *testing.T) { + // Scenario 1: Execute DDL on catalog A, verify catalog B (separate New()) is unaffected. + catA := wtSetup(t) + catB := wtSetup(t) + + wtExec(t, catA, "CREATE TABLE t1 (id INT NOT NULL)") + + tblA := catA.GetDatabase("testdb").GetTable("t1") + if tblA == nil { + t.Fatal("catalog A should have table t1") + } + + tblB := catB.GetDatabase("testdb").GetTable("t1") + if tblB != nil { + t.Fatal("catalog B should NOT have table t1") + } + }) + + t.Run("dropped_reference_not_reusable", func(t *testing.T) { + // Scenario 2: Create table, get reference, execute DROP — reference should not + // be reusable on new table with same name. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL)") + + oldRef := c.GetDatabase("testdb").GetTable("t1") + if oldRef == nil { + t.Fatal("table t1 should exist") + } + + wtExec(t, c, "DROP TABLE t1") + if c.GetDatabase("testdb").GetTable("t1") != nil { + t.Fatal("t1 should be dropped") + } + + // Re-create with different schema. + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(50))") + + newRef := c.GetDatabase("testdb").GetTable("t1") + if newRef == nil { + t.Fatal("new t1 should exist after re-creation") + } + + // The old reference should differ from the new one. + if oldRef == newRef { + t.Fatal("old reference and new reference should be different pointers") + } + if len(oldRef.Columns) == len(newRef.Columns) { + t.Fatal("old table had 1 column, new table has 2 — they should differ") + } + }) + + t.Run("incremental_state_accumulation", func(t *testing.T) { + // Scenario 3: Two Exec() calls building up state incrementally — state accumulates. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL)") + wtExec(t, c, "CREATE TABLE t2 (id INT NOT NULL, val VARCHAR(100))") + + db := c.GetDatabase("testdb") + if db.GetTable("t1") == nil { + t.Fatal("t1 should exist after first Exec") + } + if db.GetTable("t2") == nil { + t.Fatal("t2 should exist after second Exec") + } + + // Further accumulation via ALTER. + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN name VARCHAR(50)") + tbl := db.GetTable("t1") + if len(tbl.Columns) != 2 { + t.Fatalf("expected 2 columns after ALTER, got %d", len(tbl.Columns)) + } + }) + + t.Run("continue_on_error_partial_apply", func(t *testing.T) { + // Scenario 4: Exec with ContinueOnError — successful statements applied, failed ones not. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL)") + + sql := `CREATE TABLE t2 (id INT NOT NULL); +CREATE TABLE t1 (id INT NOT NULL); +CREATE TABLE t3 (id INT NOT NULL);` + + results, err := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + // t2 should succeed (index 0). + if results[0].Error != nil { + t.Fatalf("stmt 0 (CREATE t2) should succeed, got: %v", results[0].Error) + } + // t1 should fail — duplicate (index 1). + if results[1].Error == nil { + t.Fatal("stmt 1 (CREATE t1 duplicate) should fail") + } + // t3 should succeed (index 2). + if results[2].Error != nil { + t.Fatalf("stmt 2 (CREATE t3) should succeed, got: %v", results[2].Error) + } + + db := c.GetDatabase("testdb") + if db.GetTable("t2") == nil { + t.Fatal("t2 should exist (successful stmt before error)") + } + if db.GetTable("t3") == nil { + t.Fatal("t3 should exist (successful stmt after error with ContinueOnError)") + } + }) + + t.Run("exec_empty_string", func(t *testing.T) { + // Scenario 5: Exec empty string — no state change, no error. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL)") + + results, err := c.Exec("", nil) + if err != nil { + t.Fatalf("expected no parse error, got: %v", err) + } + if results != nil { + t.Fatalf("expected nil results for empty string, got %d results", len(results)) + } + + // State unchanged. + if c.GetDatabase("testdb").GetTable("t1") == nil { + t.Fatal("t1 should still exist after empty Exec") + } + }) + + t.Run("exec_dml_only_skipped", func(t *testing.T) { + // Scenario 6: Exec with only DML — no state change, statements skipped. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL)") + + results, err := c.Exec("SELECT 1; INSERT INTO t1 VALUES (1);", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + for i, r := range results { + if r.Error != nil { + t.Fatalf("stmt %d should not error, got: %v", i, r.Error) + } + if !r.Skipped { + t.Fatalf("stmt %d should be skipped (DML)", i) + } + } + + // State unchanged — table still has same structure. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("t1 should still exist") + } + if len(tbl.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(tbl.Columns)) + } + }) + + t.Run("drop_database_cascade", func(t *testing.T) { + // Scenario 7: DROP DATABASE cascade — all tables, views, routines, triggers, events removed. + c := wtSetup(t) + + // Create various objects. + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, val INT)") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id FROM t1") + wtExec(t, c, "CREATE FUNCTION f1() RETURNS INT DETERMINISTIC RETURN 1") + wtExec(t, c, "CREATE TRIGGER tr1 BEFORE INSERT ON t1 FOR EACH ROW SET NEW.val = 0") + wtExec(t, c, "CREATE EVENT ev1 ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + + // Verify objects exist. + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb should exist") + } + if db.GetTable("t1") == nil { + t.Fatal("t1 should exist before DROP DATABASE") + } + if db.Views[toLower("v1")] == nil { + t.Fatal("v1 should exist before DROP DATABASE") + } + if db.Functions[toLower("f1")] == nil { + t.Fatal("f1 should exist before DROP DATABASE") + } + if db.Triggers[toLower("tr1")] == nil { + t.Fatal("tr1 should exist before DROP DATABASE") + } + if db.Events[toLower("ev1")] == nil { + t.Fatal("ev1 should exist before DROP DATABASE") + } + + // DROP DATABASE removes everything. + wtExec(t, c, "DROP DATABASE testdb") + + if c.GetDatabase("testdb") != nil { + t.Fatal("testdb should not exist after DROP DATABASE") + } + // Current database should be cleared. + if c.CurrentDatabase() != "" { + t.Fatalf("current database should be empty after DROP, got %q", c.CurrentDatabase()) + } + }) + + t.Run("operations_across_two_databases", func(t *testing.T) { + // Scenario 8: Operations across two databases — CREATE TABLE in db1 and db2 independently. + c := New() + wtExec(t, c, "CREATE DATABASE db1; CREATE DATABASE db2;") + + c.SetCurrentDatabase("db1") + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, a VARCHAR(50))") + + c.SetCurrentDatabase("db2") + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, b INT, c INT)") + + // Verify db1.t1 has 2 columns. + db1 := c.GetDatabase("db1") + tbl1 := db1.GetTable("t1") + if tbl1 == nil { + t.Fatal("db1.t1 should exist") + } + if len(tbl1.Columns) != 2 { + t.Fatalf("db1.t1 expected 2 columns, got %d", len(tbl1.Columns)) + } + + // Verify db2.t1 has 3 columns. + db2 := c.GetDatabase("db2") + tbl2 := db2.GetTable("t1") + if tbl2 == nil { + t.Fatal("db2.t1 should exist") + } + if len(tbl2.Columns) != 3 { + t.Fatalf("db2.t1 expected 3 columns, got %d", len(tbl2.Columns)) + } + + // Dropping table in db1 should not affect db2. + c.SetCurrentDatabase("db1") + wtExec(t, c, "DROP TABLE t1") + if db1.GetTable("t1") != nil { + t.Fatal("db1.t1 should be dropped") + } + if db2.GetTable("t1") == nil { + t.Fatal("db2.t1 should still exist after dropping db1.t1") + } + }) +} diff --git a/tidb/catalog/wt_11_2_test.go b/tidb/catalog/wt_11_2_test.go new file mode 100644 index 00000000..48cdf2ed --- /dev/null +++ b/tidb/catalog/wt_11_2_test.go @@ -0,0 +1,334 @@ +package catalog + +import ( + "testing" +) + +// --- Section 7.2 (starmap): Table State Consistency (6 scenarios) --- +// These tests verify internal catalog consistency after ALTER TABLE operations. + +// checkPositionsSequential verifies that all column positions in the table are +// 1-based, sequential, and gap-free. +func checkPositionsSequential(t *testing.T, tbl *Table) { + t.Helper() + for i, col := range tbl.Columns { + expected := i + 1 + if col.Position != expected { + t.Errorf("column %q: expected position %d, got %d", col.Name, expected, col.Position) + } + } +} + +// checkColByNameConsistent verifies that the colByName map is consistent with +// the actual Columns slice: every column is findable via GetColumn and the +// returned column matches the one in the slice. +func checkColByNameConsistent(t *testing.T, tbl *Table) { + t.Helper() + for i, col := range tbl.Columns { + got := tbl.GetColumn(col.Name) + if got == nil { + t.Errorf("GetColumn(%q) returned nil, but column exists at index %d", col.Name, i) + continue + } + if got != col { + t.Errorf("GetColumn(%q) returned different pointer than Columns[%d]", col.Name, i) + } + } +} + +// Scenario 1: After ADD COLUMN, all column positions are sequential (1-based, no gaps) +func TestWalkThrough_11_2_AddColumnPositions(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), age INT)") + + // Add column at end. + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255)") + tbl := c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + + // Add column FIRST. + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN flag TINYINT FIRST") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 5 { + t.Fatalf("expected 5 columns, got %d", len(tbl.Columns)) + } + if tbl.Columns[0].Name != "flag" { + t.Errorf("expected first column 'flag', got %q", tbl.Columns[0].Name) + } + + // Add column AFTER a specific column. + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN score INT AFTER name") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 6 { + t.Fatalf("expected 6 columns, got %d", len(tbl.Columns)) + } +} + +// Scenario 2: After DROP COLUMN, remaining column positions are resequenced +func TestWalkThrough_11_2_DropColumnResequence(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (a INT, b INT, c INT, d INT, e INT)") + + // Drop from middle. + wtExec(t, c, "ALTER TABLE t1 DROP COLUMN c") + tbl := c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + // Expected order: a, b, d, e + expectedNames := []string{"a", "b", "d", "e"} + for i, name := range expectedNames { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + } + + // Drop from first. + wtExec(t, c, "ALTER TABLE t1 DROP COLUMN a") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + + // Drop from last. + wtExec(t, c, "ALTER TABLE t1 DROP COLUMN e") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + if len(tbl.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(tbl.Columns)) + } +} + +// Scenario 3: After MODIFY COLUMN FIRST, positions reflect new order +func TestWalkThrough_11_2_ModifyColumnFirst(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (a INT, b INT, c INT)") + + // Move last column to first. + wtExec(t, c, "ALTER TABLE t1 MODIFY COLUMN c INT FIRST") + tbl := c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + expectedNames := []string{"c", "a", "b"} + for i, name := range expectedNames { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + } + + // Move middle column to first. + wtExec(t, c, "ALTER TABLE t1 MODIFY COLUMN a INT FIRST") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkPositionsSequential(t, tbl) + + expectedNames = []string{"a", "c", "b"} + for i, name := range expectedNames { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + } +} + +// Scenario 4: After RENAME COLUMN, index column references updated +func TestWalkThrough_11_2_RenameColumnIndexRefs(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), age INT, INDEX idx_name (name), INDEX idx_combo (name, age))") + + // Rename the column used in indexes. + wtExec(t, c, "ALTER TABLE t1 RENAME COLUMN name TO full_name") + tbl := c.GetDatabase("testdb").GetTable("t1") + + // Old name should not be findable. + if tbl.GetColumn("name") != nil { + t.Error("old column 'name' should no longer exist") + } + // New name should be findable. + if tbl.GetColumn("full_name") == nil { + t.Fatal("column 'full_name' not found after rename") + } + + // Verify index column references reflect the new name. + for _, idx := range tbl.Indexes { + for _, ic := range idx.Columns { + if ic.Name == "name" { + t.Errorf("index %q still references old column name 'name'", idx.Name) + } + } + } + + // Specifically check idx_name. + var idxName *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + idxName = idx + break + } + } + if idxName == nil { + t.Fatal("index idx_name not found") + } + if len(idxName.Columns) != 1 || idxName.Columns[0].Name != "full_name" { + t.Errorf("idx_name: expected column 'full_name', got %v", idxName.Columns) + } + + // Check composite index idx_combo. + var idxCombo *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_combo" { + idxCombo = idx + break + } + } + if idxCombo == nil { + t.Fatal("index idx_combo not found") + } + if len(idxCombo.Columns) != 2 { + t.Fatalf("idx_combo: expected 2 columns, got %d", len(idxCombo.Columns)) + } + if idxCombo.Columns[0].Name != "full_name" { + t.Errorf("idx_combo column 0: expected 'full_name', got %q", idxCombo.Columns[0].Name) + } + if idxCombo.Columns[1].Name != "age" { + t.Errorf("idx_combo column 1: expected 'age', got %q", idxCombo.Columns[1].Name) + } + + // Positions should still be consistent. + checkPositionsSequential(t, tbl) + checkColByNameConsistent(t, tbl) +} + +// Scenario 5: After DROP INDEX, remaining indexes unaffected +func TestWalkThrough_11_2_DropIndexRemainingUnaffected(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), age INT, INDEX idx_name (name), INDEX idx_age (age), INDEX idx_combo (name, age))") + + // Drop the middle index. + wtExec(t, c, "ALTER TABLE t1 DROP INDEX idx_age") + tbl := c.GetDatabase("testdb").GetTable("t1") + + // idx_age should be gone. + for _, idx := range tbl.Indexes { + if idx.Name == "idx_age" { + t.Fatal("index idx_age should have been dropped") + } + } + + // idx_name should still exist and be intact. + var idxName *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + idxName = idx + break + } + } + if idxName == nil { + t.Fatal("index idx_name should still exist") + } + if len(idxName.Columns) != 1 || idxName.Columns[0].Name != "name" { + t.Errorf("idx_name columns changed unexpectedly") + } + + // idx_combo should still exist and be intact. + var idxCombo *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_combo" { + idxCombo = idx + break + } + } + if idxCombo == nil { + t.Fatal("index idx_combo should still exist") + } + if len(idxCombo.Columns) != 2 { + t.Fatalf("idx_combo: expected 2 columns, got %d", len(idxCombo.Columns)) + } + if idxCombo.Columns[0].Name != "name" || idxCombo.Columns[1].Name != "age" { + t.Errorf("idx_combo columns changed unexpectedly") + } + + // Columns should be unaffected. + checkPositionsSequential(t, tbl) + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns, got %d", len(tbl.Columns)) + } +} + +// Scenario 6: colByName index is consistent after every ALTER TABLE operation +func TestWalkThrough_11_2_ColByNameConsistency(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT NOT NULL, name VARCHAR(100), age INT)") + tbl := c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + + // ADD COLUMN + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255)") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + + // ADD COLUMN FIRST + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN flag TINYINT FIRST") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + + // DROP COLUMN + wtExec(t, c, "ALTER TABLE t1 DROP COLUMN age") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + // Dropped column must not be findable. + if tbl.GetColumn("age") != nil { + t.Error("dropped column 'age' should not be findable via GetColumn") + } + + // MODIFY COLUMN (change type, move FIRST) + wtExec(t, c, "ALTER TABLE t1 MODIFY COLUMN email VARCHAR(500) FIRST") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + + // CHANGE COLUMN (rename) + wtExec(t, c, "ALTER TABLE t1 CHANGE COLUMN name full_name VARCHAR(200)") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + if tbl.GetColumn("name") != nil { + t.Error("old column 'name' should not be findable after CHANGE") + } + if tbl.GetColumn("full_name") == nil { + t.Error("new column 'full_name' should be findable after CHANGE") + } + + // RENAME COLUMN + wtExec(t, c, "ALTER TABLE t1 RENAME COLUMN full_name TO display_name") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) + if tbl.GetColumn("full_name") != nil { + t.Error("old column 'full_name' should not be findable after RENAME") + } + if tbl.GetColumn("display_name") == nil { + t.Error("new column 'display_name' should be findable after RENAME") + } + + // ADD COLUMN AFTER + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN score INT AFTER email") + tbl = c.GetDatabase("testdb").GetTable("t1") + checkColByNameConsistent(t, tbl) + checkPositionsSequential(t, tbl) +} diff --git a/tidb/catalog/wt_12_1_test.go b/tidb/catalog/wt_12_1_test.go new file mode 100644 index 00000000..d1328106 --- /dev/null +++ b/tidb/catalog/wt_12_1_test.go @@ -0,0 +1,176 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 8.1 (Phase 8): Prefix and Expression Index Rendering (8 scenarios) --- +// File target: wt_12_1_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_12_1" + +func TestWalkThrough_12_1_PrefixAndExpressionIndexRendering(t *testing.T) { + // Scenario 1: KEY idx (col(10)) — prefix length rendered in SHOW CREATE + t.Run("prefix_length_rendered", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + name VARCHAR(100), + KEY idx_name (name(10)), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "`name`(10)") { + t.Errorf("expected prefix index col(10) in SHOW CREATE TABLE, got:\n%s", ddl) + } + if !strings.Contains(ddl, "KEY `idx_name`") { + t.Errorf("expected KEY idx_name in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) + + // Scenario 2: KEY idx (col1(10), col2(20)) — multi-column prefix index + t.Run("multi_column_prefix", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + col1 VARCHAR(100), + col2 VARCHAR(100), + KEY idx_multi (col1(10), col2(20)), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "`col1`(10)") { + t.Errorf("expected col1(10) in SHOW CREATE TABLE, got:\n%s", ddl) + } + if !strings.Contains(ddl, "`col2`(20)") { + t.Errorf("expected col2(20) in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) + + // Scenario 3: KEY idx (col(10), col2) — mixed prefix and full column + t.Run("mixed_prefix_and_full_column", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + col1 VARCHAR(100), + col2 INT, + KEY idx_mixed (col1(10), col2), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "`col1`(10)") { + t.Errorf("expected col1(10) in SHOW CREATE TABLE, got:\n%s", ddl) + } + // col2 should appear without prefix length + if !strings.Contains(ddl, "`col2`") { + t.Errorf("expected col2 in SHOW CREATE TABLE, got:\n%s", ddl) + } + // Ensure col2 does NOT have a prefix length + if strings.Contains(ddl, "`col2`(") { + t.Errorf("col2 should not have prefix length, got:\n%s", ddl) + } + }) + + // Scenario 4: KEY idx ((UPPER(col))) — expression index with function + t.Run("expression_index_function", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + name VARCHAR(100), + KEY idx_expr ((UPPER(name))), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + // MySQL renders expression indexes with double parens: ((expr)) + upperIdx := strings.Contains(ddl, "(UPPER(") || strings.Contains(ddl, "(upper(") + if !upperIdx { + t.Errorf("expected expression index with UPPER in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) + + // Scenario 5: KEY idx ((col1 + col2)) — expression index with arithmetic + t.Run("expression_index_arithmetic", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + col1 INT, + col2 INT, + KEY idx_arith ((col1 + col2)), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + // The expression should be rendered inside parens + if !strings.Contains(ddl, "col1") || !strings.Contains(ddl, "col2") { + t.Errorf("expected expression index with col1 and col2 in SHOW CREATE TABLE, got:\n%s", ddl) + } + // Should have double-paren wrapping for expression index + if !strings.Contains(ddl, "((") { + t.Errorf("expected double parens for expression index in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) + + // Scenario 6: UNIQUE KEY idx ((UPPER(col))) — unique expression index + t.Run("unique_expression_index", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + name VARCHAR(100), + UNIQUE KEY idx_uexpr ((UPPER(name))), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "UNIQUE KEY `idx_uexpr`") { + t.Errorf("expected UNIQUE KEY idx_uexpr in SHOW CREATE TABLE, got:\n%s", ddl) + } + upperIdx := strings.Contains(ddl, "(UPPER(") || strings.Contains(ddl, "(upper(") + if !upperIdx { + t.Errorf("expected expression with UPPER in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) + + // Scenario 7: KEY idx (col1, (UPPER(col2))) — mixed regular and expression columns + t.Run("mixed_regular_and_expression", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + col1 INT, + col2 VARCHAR(100), + KEY idx_mix (col1, (UPPER(col2))), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "`col1`") { + t.Errorf("expected regular column col1 in index, got:\n%s", ddl) + } + upperIdx := strings.Contains(ddl, "(UPPER(") || strings.Contains(ddl, "(upper(") + if !upperIdx { + t.Errorf("expected expression with UPPER in index, got:\n%s", ddl) + } + }) + + // Scenario 8: KEY idx (col DESC) — descending index column + t.Run("descending_index_column", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + name VARCHAR(100), + KEY idx_desc (name DESC), + PRIMARY KEY (id) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if !strings.Contains(ddl, "DESC") { + t.Errorf("expected DESC in index column rendering, got:\n%s", ddl) + } + if !strings.Contains(ddl, "KEY `idx_desc`") { + t.Errorf("expected KEY idx_desc in SHOW CREATE TABLE, got:\n%s", ddl) + } + }) +} diff --git a/tidb/catalog/wt_12_2_test.go b/tidb/catalog/wt_12_2_test.go new file mode 100644 index 00000000..d90aecea --- /dev/null +++ b/tidb/catalog/wt_12_2_test.go @@ -0,0 +1,162 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 8.2 (starmap): Index Rendering in SHOW CREATE TABLE (7 scenarios) --- + +func TestWalkThrough_12_2(t *testing.T) { + t.Run("index_ordering_pk_unique_key_fulltext_spatial", func(t *testing.T) { + // Scenario 1: Table with PK + UNIQUE + KEY + FULLTEXT + SPATIAL — verify ordering + // MySQL 8.0 orders: PRIMARY → UNIQUE → regular+SPATIAL → expression → FULLTEXT + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + body TEXT, + geo GEOMETRY NOT NULL SRID 0, + PRIMARY KEY (id), + KEY idx_name (name), + UNIQUE KEY uk_name (name), + FULLTEXT KEY ft_body (body), + SPATIAL KEY sp_geo (geo) + )`) + + got := c.ShowCreateTable("testdb", "t") + + // Verify all index types are present. + assertContains(t, got, "PRIMARY KEY (`id`)") + assertContains(t, got, "UNIQUE KEY `uk_name` (`name`)") + assertContains(t, got, "KEY `idx_name` (`name`)") + assertContains(t, got, "FULLTEXT KEY `ft_body` (`body`)") + assertContains(t, got, "SPATIAL KEY `sp_geo` (`geo`)") + + // Verify ordering: PRIMARY < UNIQUE < KEY/SPATIAL < FULLTEXT + posPK := strings.Index(got, "PRIMARY KEY") + posUK := strings.Index(got, "UNIQUE KEY") + posKey := strings.Index(got, "KEY `idx_name`") + posSP := strings.Index(got, "SPATIAL KEY") + posFT := strings.Index(got, "FULLTEXT KEY") + + if posPK >= posUK { + t.Error("PRIMARY KEY should appear before UNIQUE KEY") + } + if posUK >= posKey { + t.Error("UNIQUE KEY should appear before regular KEY") + } + // SPATIAL is in the same group as regular keys, so it should come after or alongside idx_name + if posSP <= posUK { + t.Error("SPATIAL KEY should appear after UNIQUE KEY") + } + if posFT <= posKey { + t.Error("FULLTEXT KEY should appear after regular KEY") + } + if posFT <= posSP { + t.Error("FULLTEXT KEY should appear after SPATIAL KEY") + } + }) + + t.Run("regular_and_expression_index_ordering", func(t *testing.T) { + // Scenario 2: Table with regular index + expression index — verify relative ordering + // MySQL 8.0: regular keys come before expression-based keys + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + age INT, + PRIMARY KEY (id), + KEY idx_name (name), + KEY idx_expr ((age * 2)) + )`) + + got := c.ShowCreateTable("testdb", "t") + + assertContains(t, got, "KEY `idx_name` (`name`)") + assertContains(t, got, "KEY `idx_expr` ((") + + posRegular := strings.Index(got, "KEY `idx_name`") + posExpr := strings.Index(got, "KEY `idx_expr`") + + if posRegular >= posExpr { + t.Errorf("regular KEY should appear before expression KEY\ngot:\n%s", got) + } + }) + + t.Run("index_with_comment", func(t *testing.T) { + // Scenario 3: Index with COMMENT + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + KEY idx_name (name) COMMENT 'name lookup' + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY `idx_name` (`name`) COMMENT 'name lookup'") + }) + + t.Run("index_with_invisible", func(t *testing.T) { + // Scenario 4: Index with INVISIBLE + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + KEY idx_name (name) INVISIBLE + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY `idx_name` (`name`) /*!80000 INVISIBLE */") + }) + + t.Run("index_with_key_block_size", func(t *testing.T) { + // Scenario 5: Index with KEY_BLOCK_SIZE + // MySQL 8.0 parses index-level KEY_BLOCK_SIZE but does not render + // it in SHOW CREATE TABLE output. Verify the catalog matches this + // behavior by NOT including KEY_BLOCK_SIZE on the index line. + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + KEY idx_name (name) KEY_BLOCK_SIZE=4 + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY `idx_name` (`name`)") + if strings.Contains(got, "KEY_BLOCK_SIZE") { + t.Errorf("expected SHOW CREATE TABLE to omit index-level KEY_BLOCK_SIZE (MySQL 8.0 behavior), got:\n%s", got) + } + }) + + t.Run("index_with_using_btree", func(t *testing.T) { + // Scenario 6: Index with USING BTREE + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + KEY idx_name (name) USING BTREE + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY `idx_name` (`name`) USING BTREE") + }) + + t.Run("index_with_using_hash", func(t *testing.T) { + // Scenario 7: Index with USING HASH + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + KEY idx_name (name) USING HASH + ) ENGINE=MEMORY`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY `idx_name` (`name`) USING HASH") + }) +} diff --git a/tidb/catalog/wt_13_1_test.go b/tidb/catalog/wt_13_1_test.go new file mode 100644 index 00000000..d9883d98 --- /dev/null +++ b/tidb/catalog/wt_13_1_test.go @@ -0,0 +1,110 @@ +package catalog + +import "testing" + +// --- Section 9.1 (Phase 9): SET Variable Effects (7 scenarios) --- +// File target: wt_13_1_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_13_1" + +func TestWalkThrough_13_1_SetVariableEffects(t *testing.T) { + // Scenario 1: SET foreign_key_checks = 0 then CREATE TABLE with invalid FK — succeeds + t.Run("fk_checks_0_allows_invalid_fk", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 0") + // Create a table with FK referencing a non-existent table — should succeed. + wtExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES nonexistent(id) + )`) + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child should exist") + } + }) + + // Scenario 2: SET foreign_key_checks = 1 then CREATE TABLE with invalid FK — fails + t.Run("fk_checks_1_rejects_invalid_fk", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 1") + results, _ := c.Exec(`CREATE TABLE child ( + id INT NOT NULL, + parent_id INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES nonexistent(id) + )`, nil) + if len(results) == 0 { + t.Fatal("expected result from CREATE TABLE") + } + assertError(t, results[0].Error, ErrFKNoRefTable) + }) + + // Scenario 3: SET foreign_key_checks = OFF — accepts OFF as 0 + t.Run("fk_checks_off_as_0", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = OFF") + if c.ForeignKeyChecks() { + t.Error("foreign_key_checks should be false after SET OFF") + } + // Verify it actually allows invalid FK. + wtExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES nonexistent(id) + )`) + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child should exist") + } + }) + + // Scenario 4: SET NAMES utf8mb4 — silently accepted + t.Run("set_names_silently_accepted", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET NAMES utf8mb4") + // No state change — catalog still works normally. + wtExec(t, c, "CREATE TABLE t (id INT NOT NULL, PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t should exist after SET NAMES") + } + }) + + // Scenario 5: SET CHARACTER SET latin1 — silently accepted + t.Run("set_character_set_silently_accepted", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET CHARACTER SET latin1") + // No state change — catalog still works normally. + wtExec(t, c, "CREATE TABLE t (id INT NOT NULL, PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t should exist after SET CHARACTER SET") + } + }) + + // Scenario 6: SET sql_mode = 'STRICT_TRANS_TABLES' — silently accepted + t.Run("set_sql_mode_silently_accepted", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET sql_mode = 'STRICT_TRANS_TABLES'") + // No state change — catalog still works normally. + wtExec(t, c, "CREATE TABLE t (id INT NOT NULL, PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t should exist after SET sql_mode") + } + }) + + // Scenario 7: SET unknown_variable = 'value' — silently accepted + t.Run("set_unknown_variable_silently_accepted", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET unknown_variable = 'value'") + // No state change — catalog still works normally. + wtExec(t, c, "CREATE TABLE t (id INT NOT NULL, PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t should exist after SET unknown_variable") + } + }) +} diff --git a/tidb/catalog/wt_13_2_test.go b/tidb/catalog/wt_13_2_test.go new file mode 100644 index 00000000..f40a2b16 --- /dev/null +++ b/tidb/catalog/wt_13_2_test.go @@ -0,0 +1,173 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 9.2: SHOW CREATE TABLE Fidelity (9 scenarios) --- +// File target: wt_13_2_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_13_2" + +func TestWalkThrough_13_2_ShowCreateTableFidelity(t *testing.T) { + // Scenario 1: Table with no explicit options — ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 rendered + t.Run("no_explicit_options", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "ENGINE=InnoDB") + assertContains(t, got, "DEFAULT CHARSET=utf8mb4") + assertContains(t, got, "COLLATE=utf8mb4_0900_ai_ci") + }) + + // Scenario 2: Table with ROW_FORMAT=DYNAMIC + t.Run("row_format_dynamic", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + PRIMARY KEY (id) + ) ROW_FORMAT=DYNAMIC`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "ROW_FORMAT=DYNAMIC") + assertContains(t, got, "ENGINE=InnoDB") + }) + + // Scenario 3: Table with KEY_BLOCK_SIZE=8 + t.Run("key_block_size", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + PRIMARY KEY (id) + ) KEY_BLOCK_SIZE=8`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "KEY_BLOCK_SIZE=8") + assertContains(t, got, "ENGINE=InnoDB") + }) + + // Scenario 4: Table with COMMENT='description' + t.Run("table_comment", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + PRIMARY KEY (id) + ) COMMENT='this is a description'`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "COMMENT='this is a description'") + }) + + // Scenario 5: Table with AUTO_INCREMENT=1000 + t.Run("auto_increment_1000", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) AUTO_INCREMENT=1000`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "AUTO_INCREMENT=1000") + }) + + // Scenario 6: TEMPORARY TABLE — SHOW CREATE TABLE works + t.Run("temporary_table", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TEMPORARY TABLE t ( + id INT NOT NULL, + name VARCHAR(50), + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "CREATE TEMPORARY TABLE `t`") + assertContains(t, got, "`id` int NOT NULL") + assertContains(t, got, "PRIMARY KEY (`id`)") + }) + + // Scenario 7: Table with all column types + t.Run("all_column_types", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + col_bigint BIGINT, + col_decimal DECIMAL(10,2), + col_varchar VARCHAR(255), + col_text TEXT, + col_blob BLOB, + col_json JSON, + col_enum ENUM('a','b','c'), + col_set SET('x','y','z'), + col_date DATE, + col_datetime DATETIME, + col_timestamp TIMESTAMP NULL, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t") + assertContains(t, got, "`col_bigint` bigint") + assertContains(t, got, "`col_decimal` decimal(10,2)") + assertContains(t, got, "`col_varchar` varchar(255)") + assertContains(t, got, "`col_text` text") + assertContains(t, got, "`col_blob` blob") + assertContains(t, got, "`col_json` json") + assertContains(t, got, "`col_enum` enum('a','b','c')") + assertContains(t, got, "`col_set` set('x','y','z')") + assertContains(t, got, "`col_date` date") + assertContains(t, got, "`col_datetime` datetime") + assertContains(t, got, "`col_timestamp` timestamp") + + // INT rendered + assertContains(t, got, "`id` int NOT NULL") + + // TEXT and BLOB should NOT show DEFAULT NULL + if strings.Contains(got, "`col_text` text DEFAULT NULL") { + t.Error("TEXT column should not show DEFAULT NULL") + } + if strings.Contains(got, "`col_blob` blob DEFAULT NULL") { + t.Error("BLOB column should not show DEFAULT NULL") + } + + // TIMESTAMP with explicit NULL should show NULL + assertContains(t, got, "`col_timestamp` timestamp NULL") + }) + + // Scenario 8: Column with ON UPDATE CURRENT_TIMESTAMP + t.Run("on_update_current_timestamp", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + updated_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t") + // ON UPDATE should appear after DEFAULT + assertContains(t, got, "DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + }) + + // Scenario 9: Column DEFAULT expression (CURRENT_TIMESTAMP, literal, NULL) + t.Run("default_expressions", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT NOT NULL, + created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, + status INT DEFAULT 0, + note VARCHAR(100) DEFAULT NULL, + PRIMARY KEY (id) + )`) + + got := c.ShowCreateTable("testdb", "t") + // CURRENT_TIMESTAMP default + assertContains(t, got, "DEFAULT CURRENT_TIMESTAMP") + // Literal default — MySQL 8.0 quotes numeric defaults + assertContains(t, got, "DEFAULT '0'") + // NULL default + assertContains(t, got, "`note` varchar(100) DEFAULT NULL") + }) +} diff --git a/tidb/catalog/wt_1_1_test.go b/tidb/catalog/wt_1_1_test.go new file mode 100644 index 00000000..723f9354 --- /dev/null +++ b/tidb/catalog/wt_1_1_test.go @@ -0,0 +1,156 @@ +package catalog + +import "testing" + +// Section 1.1: Exec Result Basics + +func TestWalkThrough_1_1_EmptySQL(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("", nil) + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if results != nil { + t.Fatalf("expected nil results, got %d results", len(results)) + } +} + +func TestWalkThrough_1_1_WhitespaceOnlySQL(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec(" \n\t \n ", nil) + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if results != nil { + t.Fatalf("expected nil results, got %d results", len(results)) + } +} + +func TestWalkThrough_1_1_CommentOnlySQL(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("-- this is a comment\n/* block comment */", nil) + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if results != nil { + t.Fatalf("expected nil results, got %d results", len(results)) + } +} + +func TestWalkThrough_1_1_SingleDDL(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE t1 (id INT)", nil) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_1_1_MultipleDDL(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT); CREATE TABLE t3 (id INT)", nil) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + for i, r := range results { + assertNoError(t, r.Error) + if r.Error != nil { + t.Errorf("result[%d] unexpected error: %v", i, r.Error) + } + } +} + +func TestWalkThrough_1_1_ResultIndexMatchesPosition(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT); CREATE TABLE t3 (id INT)", nil) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + for i, r := range results { + if r.Index != i { + t.Errorf("result[%d].Index = %d, want %d", i, r.Index, i) + } + } +} + +func TestWalkThrough_1_1_DMLSkipped(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + + dmlStatements := []string{ + "SELECT * FROM t1", + "INSERT INTO t1 (id, name) VALUES (1, 'test')", + "UPDATE t1 SET name = 'updated' WHERE id = 1", + "DELETE FROM t1 WHERE id = 1", + } + + for _, sql := range dmlStatements { + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("unexpected parse error for %q: %v", sql, err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result for %q, got %d", sql, len(results)) + } + if !results[0].Skipped { + t.Errorf("expected Skipped=true for %q", sql) + } + } +} + +func TestWalkThrough_1_1_DMLDoesNotModifyState(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + + db := c.GetDatabase("testdb") + tableCountBefore := len(db.Tables) + + // Execute DML — should not change anything + _, err := c.Exec("INSERT INTO t1 (id) VALUES (1); SELECT * FROM t1; UPDATE t1 SET id = 2; DELETE FROM t1", nil) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + tableCountAfter := len(db.Tables) + if tableCountBefore != tableCountAfter { + t.Errorf("table count changed from %d to %d after DML", tableCountBefore, tableCountAfter) + } + + // Verify the original table is still there unchanged + tbl := db.GetTable("t1") + if tbl == nil { + t.Fatal("table t1 should still exist after DML") + } +} + +func TestWalkThrough_1_1_UnknownStatementsIgnored(t *testing.T) { + c := wtSetup(t) + + // FLUSH and ANALYZE are parsed but not handled in processUtility — they return nil + unsupportedStatements := []string{ + "FLUSH TABLES", + "ANALYZE TABLE t1", + } + + // Create a table so ANALYZE TABLE has something to reference + wtExec(t, c, "CREATE TABLE t1 (id INT)") + + for _, sql := range unsupportedStatements { + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("unexpected parse error for %q: %v", sql, err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result for %q, got %d", sql, len(results)) + } + if results[0].Error != nil { + t.Errorf("expected nil error for %q, got: %v", sql, results[0].Error) + } + } +} diff --git a/tidb/catalog/wt_1_2_test.go b/tidb/catalog/wt_1_2_test.go new file mode 100644 index 00000000..9061c2d0 --- /dev/null +++ b/tidb/catalog/wt_1_2_test.go @@ -0,0 +1,179 @@ +package catalog + +import ( + "testing" +) + +// --- Default mode: stop at first error --- + +func TestWalkThrough_1_2_DefaultStopsAtFirstError(t *testing.T) { + c := wtSetup(t) + // Three statements: first succeeds, second fails (dup table), third should not run. + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // Should have exactly 2 results: success + error (stops there). + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + assertNoError(t, results[0].Error) + assertError(t, results[1].Error, ErrDupTable) +} + +func TestWalkThrough_1_2_DefaultStatementsAfterErrorNotExecuted(t *testing.T) { + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // Only 2 results should be returned — third statement not reached. + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + // Verify: no result with Index==2. + for _, r := range results { + if r.Index == 2 { + t.Error("third statement should not have been executed") + } + } +} + +func TestWalkThrough_1_2_DefaultCatalogReflectsPreError(t *testing.T) { + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + _, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + // t1 should exist (created before error). + if db.GetTable("t1") == nil { + t.Error("t1 should exist — it was created before the error") + } + // t2 should NOT exist (after error). + if db.GetTable("t2") != nil { + t.Error("t2 should not exist — execution stopped at error") + } +} + +// --- ContinueOnError mode --- + +func TestWalkThrough_1_2_ContinueOnErrorAllAttempted(t *testing.T) { + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + opts := &ExecOptions{ContinueOnError: true} + results, err := c.Exec(sql, opts) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // All 3 statements should be attempted. + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } +} + +func TestWalkThrough_1_2_ContinueOnErrorSuccessAfterFailure(t *testing.T) { + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + opts := &ExecOptions{ContinueOnError: true} + _, err := c.Exec(sql, opts) + if err != nil { + t.Fatalf("parse error: %v", err) + } + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + // t2 should exist even though t1 dup failed in the middle. + if db.GetTable("t2") == nil { + t.Error("t2 should exist — ContinueOnError should continue past failures") + } +} + +func TestWalkThrough_1_2_ContinueOnErrorPerStatementErrors(t *testing.T) { + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);" + opts := &ExecOptions{ContinueOnError: true} + results, err := c.Exec(sql, opts) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + // First: success + assertNoError(t, results[0].Error) + // Second: dup table error + assertError(t, results[1].Error, ErrDupTable) + // Third: success + assertNoError(t, results[2].Error) +} + +func TestWalkThrough_1_2_ContinueOnErrorMultipleErrors(t *testing.T) { + c := wtSetup(t) + // Two errors: dup table and alter on unknown table. + sql := `CREATE TABLE t1 (id INT); +CREATE TABLE t1 (id INT); +ALTER TABLE nosuch ADD COLUMN x INT; +CREATE TABLE t2 (id INT);` + opts := &ExecOptions{ContinueOnError: true} + results, err := c.Exec(sql, opts) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 4 { + t.Fatalf("expected 4 results, got %d", len(results)) + } + assertNoError(t, results[0].Error) + assertError(t, results[1].Error, ErrDupTable) + assertError(t, results[2].Error, ErrNoSuchTable) + assertNoError(t, results[3].Error) +} + +// --- Parse error --- + +func TestWalkThrough_1_2_ParseErrorReturnsTopLevelError(t *testing.T) { + c := wtSetup(t) + // Intentionally bad SQL that the parser cannot parse. + results, err := c.Exec("CREATE TABLE ???", nil) + if err == nil { + t.Fatal("expected parse error, got nil") + } + if results != nil { + t.Errorf("expected nil results on parse error, got %d results", len(results)) + } +} + +// --- DELIMITER-containing SQL --- + +func TestWalkThrough_1_2_DelimiterSplitting(t *testing.T) { + c := wtSetup(t) + sql := `DELIMITER ;; +CREATE TABLE t1 (id INT);; +CREATE TABLE t2 (id INT);; +DELIMITER ;` + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // Should have 2 results (one per CREATE TABLE). + if len(results) != 2 { + t.Fatalf("expected 2 results for DELIMITER SQL, got %d", len(results)) + } + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + if db.GetTable("t1") == nil { + t.Error("t1 should exist after DELIMITER SQL") + } + if db.GetTable("t2") == nil { + t.Error("t2 should exist after DELIMITER SQL") + } +} diff --git a/tidb/catalog/wt_1_3_test.go b/tidb/catalog/wt_1_3_test.go new file mode 100644 index 00000000..99339307 --- /dev/null +++ b/tidb/catalog/wt_1_3_test.go @@ -0,0 +1,127 @@ +package catalog + +import "testing" + +// Section 1.3: Result Metadata — Line, SQL + +func TestWalkThrough_1_3_SingleLineMultiStatement(t *testing.T) { + // Single-line multi-statement: each Result.Line is 1 + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT); CREATE TABLE t3 (id INT);" + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + for i, r := range results { + if r.Line != 1 { + t.Errorf("result[%d].Line = %d, want 1", i, r.Line) + } + } +} + +func TestWalkThrough_1_3_MultiLineStatements(t *testing.T) { + // Multi-line statements: Result.Line matches first line of each statement + c := wtSetup(t) + sql := "CREATE TABLE t1 (\n id INT\n);\nCREATE TABLE t2 (\n id INT\n);\nCREATE TABLE t3 (\n id INT\n);" + // Line 1: CREATE TABLE t1 ( + // Line 2: id INT + // Line 3: ); + // Line 4: CREATE TABLE t2 ( + // Line 5: id INT + // Line 6: ); + // Line 7: CREATE TABLE t3 ( + // Line 8: id INT + // Line 9: ); + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + wantLines := []int{1, 4, 7} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + } +} + +func TestWalkThrough_1_3_DelimiterMode(t *testing.T) { + // DELIMITER mode: Result.Line points to correct line in original SQL + c := wtSetup(t) + sql := "DELIMITER ;;\nCREATE TABLE t1 (id INT);;\nDELIMITER ;\nCREATE TABLE t2 (id INT);" + // Line 1: DELIMITER ;; + // Line 2: CREATE TABLE t1 (id INT);; + // Line 3: DELIMITER ; + // Line 4: CREATE TABLE t2 (id INT); + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 2 { + t.Fatalf("got %d results, want 2", len(results)) + } + wantLines := []int{2, 4} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + } +} + +func TestWalkThrough_1_3_BlankLines(t *testing.T) { + // Statements after blank lines: Line numbers account for blank lines + c := wtSetup(t) + sql := "CREATE TABLE t1 (id INT);\n\n\nCREATE TABLE t2 (id INT);\n\nCREATE TABLE t3 (id INT);" + // Line 1: CREATE TABLE t1 (id INT); + // Line 2: (blank) + // Line 3: (blank) + // Line 4: CREATE TABLE t2 (id INT); + // Line 5: (blank) + // Line 6: CREATE TABLE t3 (id INT); + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + wantLines := []int{1, 4, 6} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + } +} + +func TestWalkThrough_1_3_DMLLineNumbers(t *testing.T) { + // Result.Line for DML (skipped) statements is still correct + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + sql := "SELECT * FROM t1;\nINSERT INTO t1 VALUES (1);\nCREATE TABLE t2 (id INT);" + // Line 1: SELECT * FROM t1; + // Line 2: INSERT INTO t1 VALUES (1); + // Line 3: CREATE TABLE t2 (id INT); + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + // DML statements should have correct line numbers even though skipped + wantLines := []int{1, 2, 3} + wantSkipped := []bool{true, true, false} + for i, r := range results { + if r.Line != wantLines[i] { + t.Errorf("result[%d].Line = %d, want %d", i, r.Line, wantLines[i]) + } + if r.Skipped != wantSkipped[i] { + t.Errorf("result[%d].Skipped = %v, want %v", i, r.Skipped, wantSkipped[i]) + } + } +} diff --git a/tidb/catalog/wt_2_1_test.go b/tidb/catalog/wt_2_1_test.go new file mode 100644 index 00000000..f820f5e1 --- /dev/null +++ b/tidb/catalog/wt_2_1_test.go @@ -0,0 +1,88 @@ +package catalog + +import "testing" + +// Section 2.1: Database Errors + +func TestWalkThrough_2_1_CreateDatabaseDuplicate(t *testing.T) { + c := wtSetup(t) + // "testdb" already exists from wtSetup + results, err := c.Exec("CREATE DATABASE testdb", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupDatabase) +} + +func TestWalkThrough_2_1_CreateDatabaseIfNotExistsOnExisting(t *testing.T) { + c := wtSetup(t) + // "testdb" already exists from wtSetup + results, err := c.Exec("CREATE DATABASE IF NOT EXISTS testdb", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_2_1_DropDatabaseUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP DATABASE nope", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrUnknownDatabase) +} + +func TestWalkThrough_2_1_DropDatabaseIfExistsOnUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP DATABASE IF EXISTS nope", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_2_1_UseUnknownDatabase(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("USE nope", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrUnknownDatabase) +} + +func TestWalkThrough_2_1_CreateTableWithoutUse(t *testing.T) { + c := New() // no database selected + results, err := c.Exec("CREATE TABLE t (id INT)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoDatabaseSelected) +} + +func TestWalkThrough_2_1_AlterTableWithoutUse(t *testing.T) { + c := New() // no database selected + results, err := c.Exec("ALTER TABLE t ADD COLUMN x INT", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoDatabaseSelected) +} + +func TestWalkThrough_2_1_AlterDatabaseUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("ALTER DATABASE nope CHARACTER SET utf8", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrUnknownDatabase) +} + +func TestWalkThrough_2_1_TruncateUnknownTable(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("TRUNCATE TABLE nope", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchTable) +} diff --git a/tidb/catalog/wt_2_2_test.go b/tidb/catalog/wt_2_2_test.go new file mode 100644 index 00000000..04020fe8 --- /dev/null +++ b/tidb/catalog/wt_2_2_test.go @@ -0,0 +1,185 @@ +package catalog + +import "testing" + +// Section 2.2: Table and Column Errors + +func TestWalkThrough_2_2_CreateTableDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("CREATE TABLE t (id INT)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupTable) +} + +func TestWalkThrough_2_2_CreateTableIfNotExistsOnExisting(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("CREATE TABLE IF NOT EXISTS t (id INT)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_2_2_CreateTableDuplicateColumn(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE t (id INT, id INT)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupColumn) +} + +func TestWalkThrough_2_2_CreateTableMultiplePrimaryKeys(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(50) PRIMARY KEY)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrMultiplePriKey) +} + +func TestWalkThrough_2_2_DropTableUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP TABLE unknown_tbl", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // DROP TABLE on non-existent table returns ErrUnknownTable (1051). + assertError(t, results[0].Error, ErrUnknownTable) +} + +func TestWalkThrough_2_2_DropTableIfExistsUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP TABLE IF EXISTS unknown_tbl", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_2_2_AlterTableUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("ALTER TABLE unknown_tbl ADD COLUMN x INT", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchTable) +} + +func TestWalkThrough_2_2_AlterTableAddColumnDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("ALTER TABLE t ADD COLUMN id INT", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupColumn) +} + +func TestWalkThrough_2_2_AlterTableDropColumnUnknown(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("ALTER TABLE t DROP COLUMN unknown_col", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // MySQL 8.0 returns error 1091 (ErrCantDropKey) for DROP COLUMN on nonexistent column. + // The catalog matches this behavior. The scenario listed ErrNoSuchColumn (1054) + // but the actual MySQL behavior is 1091. + assertError(t, results[0].Error, ErrCantDropKey) +} + +func TestWalkThrough_2_2_AlterTableModifyColumnUnknown(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("ALTER TABLE t MODIFY COLUMN unknown_col VARCHAR(50)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchColumn) +} + +func TestWalkThrough_2_2_AlterTableChangeColumnUnknownSource(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("ALTER TABLE t CHANGE COLUMN unknown_col new_col INT", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchColumn) +} + +func TestWalkThrough_2_2_AlterTableChangeColumnDuplicateTarget(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(50))") + results, err := c.Exec("ALTER TABLE t CHANGE COLUMN id name INT", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupColumn) +} + +func TestWalkThrough_2_2_AlterTableAddPKWhenPKExists(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(50))") + results, err := c.Exec("ALTER TABLE t ADD PRIMARY KEY (name)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrMultiplePriKey) +} + +func TestWalkThrough_2_2_AlterTableRenameColumnUnknown(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + results, err := c.Exec("ALTER TABLE t RENAME COLUMN unknown_col TO new_col", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchColumn) +} + +func TestWalkThrough_2_2_RenameTableUnknownSource(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("RENAME TABLE unknown_tbl TO new_tbl", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchTable) +} + +func TestWalkThrough_2_2_RenameTableToExistingTarget(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, "CREATE TABLE t2 (id INT)") + results, err := c.Exec("RENAME TABLE t1 TO t2", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupTable) +} + +func TestWalkThrough_2_2_DropTableMultiPartialFailure(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + // DROP TABLE t1, t2 — t1 exists, t2 does not. + // t1 should be dropped, then error on t2. + results, err := c.Exec("DROP TABLE t1, t2", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrUnknownTable) + + // Verify t1 was actually dropped before the error on t2. + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("database testdb not found") + } + if db.GetTable("t1") != nil { + t.Error("t1 should have been dropped before error on t2") + } +} diff --git a/tidb/catalog/wt_2_3_test.go b/tidb/catalog/wt_2_3_test.go new file mode 100644 index 00000000..6633c025 --- /dev/null +++ b/tidb/catalog/wt_2_3_test.go @@ -0,0 +1,71 @@ +package catalog + +import "testing" + +// TestWalkThrough_2_3_CreateIndexDupName tests that CREATE INDEX with a duplicate +// index name returns ErrDupKeyName (1061). +func TestWalkThrough_2_3_CreateIndexDupName(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE INDEX idx_name ON t (name)") + + results, err := c.Exec("CREATE INDEX idx_name ON t (id)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupKeyName) +} + +// TestWalkThrough_2_3_CreateIndexIfNotExists tests that CREATE INDEX IF NOT EXISTS +// on an existing index returns no error. +func TestWalkThrough_2_3_CreateIndexIfNotExists(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE INDEX idx_name ON t (name)") + + results, err := c.Exec("CREATE INDEX IF NOT EXISTS idx_name ON t (id)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +// TestWalkThrough_2_3_DropIndexUnknown tests that DROP INDEX on a nonexistent +// index returns ErrCantDropKey (1091). +func TestWalkThrough_2_3_DropIndexUnknown(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + + results, err := c.Exec("DROP INDEX no_such_idx ON t", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrCantDropKey) +} + +// TestWalkThrough_2_3_AlterTableDropIndexUnknown tests that ALTER TABLE DROP INDEX +// on a nonexistent index returns ErrCantDropKey (1091). +func TestWalkThrough_2_3_AlterTableDropIndexUnknown(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + + results, err := c.Exec("ALTER TABLE t DROP INDEX no_such_idx", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrCantDropKey) +} + +// TestWalkThrough_2_3_AlterTableAddUniqueIndexDupName tests that ALTER TABLE ADD +// UNIQUE INDEX with a duplicate index name returns ErrDupKeyName (1061). +func TestWalkThrough_2_3_AlterTableAddUniqueIndexDupName(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE INDEX my_idx ON t (name)") + + results, err := c.Exec("ALTER TABLE t ADD UNIQUE INDEX my_idx (id)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupKeyName) +} diff --git a/tidb/catalog/wt_2_4_test.go b/tidb/catalog/wt_2_4_test.go new file mode 100644 index 00000000..93eeb0f1 --- /dev/null +++ b/tidb/catalog/wt_2_4_test.go @@ -0,0 +1,155 @@ +package catalog + +import "testing" + +// TestWalkThrough_2_4_CreateTableFKUnknownTable tests that CREATE TABLE with a FK +// referencing an unknown table returns ErrFKNoRefTable (1824) when fk_checks=1. +func TestWalkThrough_2_4_CreateTableFKUnknownTable(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKNoRefTable) +} + +// TestWalkThrough_2_4_CreateTableFKUnknownColumn tests that CREATE TABLE with a FK +// referencing an unknown column returns an error when fk_checks=1. +// MySQL returns ErrFKMissingIndex (1822) when the referenced column has no matching +// index (which is the case when the column doesn't exist). +func TestWalkThrough_2_4_CreateTableFKUnknownColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + results, err := c.Exec("CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(nonexistent))", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // The referenced column "nonexistent" doesn't exist, so there's no matching index. + assertError(t, results[0].Error, ErrFKMissingIndex) +} + +// TestWalkThrough_2_4_DropTableReferencedByFK tests that DROP TABLE on a table +// referenced by a FK returns ErrFKCannotDropParent (3730) when fk_checks=1. +func TestWalkThrough_2_4_DropTableReferencedByFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + + results, err := c.Exec("DROP TABLE parent", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKCannotDropParent) +} + +// TestWalkThrough_2_4_DropTableReferencedByFKWithChecksOff tests that DROP TABLE +// on a table referenced by a FK succeeds when fk_checks=0. +func TestWalkThrough_2_4_DropTableReferencedByFKWithChecksOff(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + wtExec(t, c, "SET foreign_key_checks = 0") + + results, err := c.Exec("DROP TABLE parent", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +// TestWalkThrough_2_4_AlterTableAddFKUnknownTable tests that ALTER TABLE ADD FK +// referencing an unknown table returns ErrFKNoRefTable (1824). +func TestWalkThrough_2_4_AlterTableAddFKUnknownTable(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE child (id INT, pid INT)") + + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (pid) REFERENCES parent(id)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKNoRefTable) +} + +// TestWalkThrough_2_4_AlterTableDropColumnUsedInFK tests that ALTER TABLE DROP COLUMN +// on a column used in a FK constraint returns an error. +func TestWalkThrough_2_4_AlterTableDropColumnUsedInFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + + results, err := c.Exec("ALTER TABLE child DROP COLUMN pid", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + // Error code 1828: Cannot drop column needed in FK constraint. + if results[0].Error == nil { + t.Fatal("expected error when dropping column used in FK, got nil") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *catalog.Error, got %T: %v", results[0].Error, results[0].Error) + } + if catErr.Code != 1828 { + t.Errorf("expected error code 1828, got %d: %s", catErr.Code, catErr.Message) + } +} + +// TestWalkThrough_2_4_AlterTableAddFKMissingIndex tests that ALTER TABLE ADD FK +// where the referenced table lacks a matching index returns ErrFKMissingIndex (1822). +func TestWalkThrough_2_4_AlterTableAddFKMissingIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT, val INT)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT)") + + // parent.val has no index, so FK should fail. + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_val FOREIGN KEY (pid) REFERENCES parent(val)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKMissingIndex) +} + +// TestWalkThrough_2_4_AlterTableAddFKIncompatibleColumns tests that ALTER TABLE ADD FK +// where column types are incompatible returns ErrFKIncompatibleColumns (3780). +func TestWalkThrough_2_4_AlterTableAddFKIncompatibleColumns(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid VARCHAR(100))") + + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (pid) REFERENCES parent(id)", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKIncompatibleColumns) +} + +// TestWalkThrough_2_4_SetForeignKeyChecksOff tests that SET foreign_key_checks=0 +// disables FK validation during CREATE TABLE. +func TestWalkThrough_2_4_SetForeignKeyChecksOff(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 0") + + // Should succeed even though "parent" doesn't exist. + results, err := c.Exec("CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +// TestWalkThrough_2_4_SetForeignKeyChecksOn tests that SET foreign_key_checks=1 +// re-enables FK validation. +func TestWalkThrough_2_4_SetForeignKeyChecksOn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 0") + // Create child with FK to nonexistent parent — should succeed. + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + wtExec(t, c, "SET foreign_key_checks = 1") + + // Now creating another table referencing nonexistent parent should fail. + results, err := c.Exec("CREATE TABLE child2 (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKNoRefTable) +} diff --git a/tidb/catalog/wt_2_5_test.go b/tidb/catalog/wt_2_5_test.go new file mode 100644 index 00000000..8348ad01 --- /dev/null +++ b/tidb/catalog/wt_2_5_test.go @@ -0,0 +1,50 @@ +package catalog + +import "testing" + +// Section 2.5: View Errors + +func TestWalkThrough_2_5_CreateViewOnExistingName(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + wtExec(t, c, "CREATE VIEW v AS SELECT id FROM t") + // CREATE VIEW on existing view name without OR REPLACE → error. + // MySQL treats views and tables in the same namespace: ErrDupTable (1050). + results, err := c.Exec("CREATE VIEW v AS SELECT id FROM t", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupTable) +} + +func TestWalkThrough_2_5_CreateOrReplaceViewOnExisting(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT)") + wtExec(t, c, "CREATE VIEW v AS SELECT id FROM t") + // CREATE OR REPLACE VIEW on existing → no error, view is replaced. + results, err := c.Exec("CREATE OR REPLACE VIEW v AS SELECT id FROM t", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +func TestWalkThrough_2_5_DropViewUnknown(t *testing.T) { + c := wtSetup(t) + // DROP VIEW on non-existent view → ErrUnknownTable (1051). + results, err := c.Exec("DROP VIEW unknown_view", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrUnknownTable) +} + +func TestWalkThrough_2_5_DropViewIfExistsUnknown(t *testing.T) { + c := wtSetup(t) + // DROP VIEW IF EXISTS on non-existent view → no error. + results, err := c.Exec("DROP VIEW IF EXISTS unknown_view", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} diff --git a/tidb/catalog/wt_2_6_test.go b/tidb/catalog/wt_2_6_test.go new file mode 100644 index 00000000..67a9c163 --- /dev/null +++ b/tidb/catalog/wt_2_6_test.go @@ -0,0 +1,113 @@ +package catalog + +import "testing" + +// --- Procedure errors --- + +func TestWalkThrough_2_6_CreateProcedureDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE PROCEDURE myproc() BEGIN END") + results, err := c.Exec("CREATE PROCEDURE myproc() BEGIN END", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupProcedure) +} + +func TestWalkThrough_2_6_CreateFunctionDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE FUNCTION myfunc() RETURNS INT RETURN 1") + results, err := c.Exec("CREATE FUNCTION myfunc() RETURNS INT RETURN 1", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupFunction) +} + +func TestWalkThrough_2_6_DropProcedureUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP PROCEDURE no_such_proc", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchProcedure) +} + +func TestWalkThrough_2_6_DropFunctionUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP FUNCTION no_such_func", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchFunction) +} + +func TestWalkThrough_2_6_DropProcedureIfExistsUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP PROCEDURE IF EXISTS no_such_proc", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +// --- Trigger errors --- + +func TestWalkThrough_2_6_CreateTriggerDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, "CREATE TRIGGER trg1 BEFORE INSERT ON t1 FOR EACH ROW SET @x = 1") + results, err := c.Exec("CREATE TRIGGER trg1 BEFORE INSERT ON t1 FOR EACH ROW SET @x = 1", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupTrigger) +} + +func TestWalkThrough_2_6_DropTriggerUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP TRIGGER no_such_trigger", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchTrigger) +} + +func TestWalkThrough_2_6_DropTriggerIfExistsUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP TRIGGER IF EXISTS no_such_trigger", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} + +// --- Event errors --- + +func TestWalkThrough_2_6_CreateEventDuplicate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE EVENT evt1 ON SCHEDULE EVERY 1 DAY DO SELECT 1") + results, err := c.Exec("CREATE EVENT evt1 ON SCHEDULE EVERY 1 DAY DO SELECT 1", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrDupEvent) +} + +func TestWalkThrough_2_6_DropEventUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP EVENT no_such_event", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrNoSuchEvent) +} + +func TestWalkThrough_2_6_DropEventIfExistsUnknown(t *testing.T) { + c := wtSetup(t) + results, err := c.Exec("DROP EVENT IF EXISTS no_such_event", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertNoError(t, results[0].Error) +} diff --git a/tidb/catalog/wt_3_1_test.go b/tidb/catalog/wt_3_1_test.go new file mode 100644 index 00000000..b1ee90ef --- /dev/null +++ b/tidb/catalog/wt_3_1_test.go @@ -0,0 +1,372 @@ +package catalog + +import "testing" + +// --- 3.1 CREATE TABLE State --- + +func TestWalkThrough_3_1_TableExists(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE users (id INT)") + tbl := c.GetDatabase("testdb").GetTable("users") + if tbl == nil { + t.Fatal("table 'users' not found after CREATE TABLE") + } +} + +func TestWalkThrough_3_1_ColumnCount(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b VARCHAR(100), c TEXT)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } +} + +func TestWalkThrough_3_1_ColumnNamesInOrder(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (alpha INT, beta VARCHAR(50), gamma TEXT)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + expected := []string{"alpha", "beta", "gamma"} + for i, col := range tbl.Columns { + if col.Name != expected[i] { + t.Errorf("column %d: expected name %q, got %q", i, expected[i], col.Name) + } + } +} + +func TestWalkThrough_3_1_ColumnPositionsSequential(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b VARCHAR(50), c TEXT, d BLOB)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + for i, col := range tbl.Columns { + expected := i + 1 + if col.Position != expected { + t.Errorf("column %q: expected position %d, got %d", col.Name, expected, col.Position) + } + } +} + +func TestWalkThrough_3_1_ColumnTypes(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + a INT, + b VARCHAR(100), + c DECIMAL(10,2), + d DATETIME, + e TEXT, + f BLOB, + g JSON + )`) + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + tests := []struct { + name string + dataType string + }{ + {"a", "int"}, + {"b", "varchar"}, + {"c", "decimal"}, + {"d", "datetime"}, + {"e", "text"}, + {"f", "blob"}, + {"g", "json"}, + } + for _, tt := range tests { + col := tbl.GetColumn(tt.name) + if col == nil { + t.Errorf("column %q not found", tt.name) + continue + } + if col.DataType != tt.dataType { + t.Errorf("column %q: expected DataType %q, got %q", tt.name, tt.dataType, col.DataType) + } + } + + // Check ColumnType includes params. + colB := tbl.GetColumn("b") + if colB.ColumnType != "varchar(100)" { + t.Errorf("column b: expected ColumnType 'varchar(100)', got %q", colB.ColumnType) + } + colC := tbl.GetColumn("c") + if colC.ColumnType != "decimal(10,2)" { + t.Errorf("column c: expected ColumnType 'decimal(10,2)', got %q", colC.ColumnType) + } +} + +func TestWalkThrough_3_1_NotNull(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT NOT NULL, b INT)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colA := tbl.GetColumn("a") + if colA.Nullable { + t.Error("column a should be NOT NULL (Nullable=false)") + } + colB := tbl.GetColumn("b") + if !colB.Nullable { + t.Error("column b should be nullable by default") + } +} + +func TestWalkThrough_3_1_DefaultValue(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT DEFAULT 42, b VARCHAR(50) DEFAULT 'hello', c INT)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + colA := tbl.GetColumn("a") + if colA.Default == nil { + t.Fatal("column a: expected default value, got nil") + } + if *colA.Default != "42" { + t.Errorf("column a: expected default '42', got %q", *colA.Default) + } + + colB := tbl.GetColumn("b") + if colB.Default == nil { + t.Fatal("column b: expected default value, got nil") + } + // The default may be stored with or without quotes. + if *colB.Default != "'hello'" && *colB.Default != "hello" { + t.Errorf("column b: expected default 'hello' or \"'hello'\", got %q", *colB.Default) + } + + colC := tbl.GetColumn("c") + if colC.Default != nil { + t.Errorf("column c: expected nil default, got %q", *colC.Default) + } +} + +func TestWalkThrough_3_1_AutoIncrement(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(50))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colID := tbl.GetColumn("id") + if !colID.AutoIncrement { + t.Error("column id should have AutoIncrement=true") + } + colName := tbl.GetColumn("name") + if colName.AutoIncrement { + t.Error("column name should not have AutoIncrement") + } +} + +func TestWalkThrough_3_1_ColumnComment(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT COMMENT 'primary identifier', name VARCHAR(50))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colID := tbl.GetColumn("id") + if colID.Comment != "primary identifier" { + t.Errorf("column id: expected comment 'primary identifier', got %q", colID.Comment) + } + colName := tbl.GetColumn("name") + if colName.Comment != "" { + t.Errorf("column name: expected empty comment, got %q", colName.Comment) + } +} + +func TestWalkThrough_3_1_TableEngine(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) ENGINE=MyISAM") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Engine != "MyISAM" { + t.Errorf("expected engine 'MyISAM', got %q", tbl.Engine) + } +} + +func TestWalkThrough_3_1_TableCharsetAndCollation(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Charset != "latin1" { + t.Errorf("expected charset 'latin1', got %q", tbl.Charset) + } + if tbl.Collation != "latin1_swedish_ci" { + t.Errorf("expected collation 'latin1_swedish_ci', got %q", tbl.Collation) + } +} + +func TestWalkThrough_3_1_TableComment(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) COMMENT='user accounts'") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Comment != "user accounts" { + t.Errorf("expected table comment 'user accounts', got %q", tbl.Comment) + } +} + +func TestWalkThrough_3_1_TableAutoIncrementStart(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY) AUTO_INCREMENT=1000") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.AutoIncrement != 1000 { + t.Errorf("expected table AUTO_INCREMENT=1000, got %d", tbl.AutoIncrement) + } +} + +func TestWalkThrough_3_1_UnsignedModifier(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT UNSIGNED, b BIGINT UNSIGNED)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colA := tbl.GetColumn("a") + if colA.ColumnType != "int unsigned" { + t.Errorf("column a: expected ColumnType 'int unsigned', got %q", colA.ColumnType) + } + colB := tbl.GetColumn("b") + if colB.ColumnType != "bigint unsigned" { + t.Errorf("column b: expected ColumnType 'bigint unsigned', got %q", colB.ColumnType) + } +} + +func TestWalkThrough_3_1_GeneratedColumnVirtual(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT AS (a * 2) VIRTUAL)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colB := tbl.GetColumn("b") + if colB == nil { + t.Fatal("column b not found") + } + if colB.Generated == nil { + t.Fatal("column b: expected Generated info, got nil") + } + if colB.Generated.Stored { + t.Error("column b: expected Stored=false for VIRTUAL") + } + if colB.Generated.Expr == "" { + t.Error("column b: expected non-empty generated expression") + } +} + +func TestWalkThrough_3_1_GeneratedColumnStored(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT AS (a + 1) STORED)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colB := tbl.GetColumn("b") + if colB == nil { + t.Fatal("column b not found") + } + if colB.Generated == nil { + t.Fatal("column b: expected Generated info, got nil") + } + if !colB.Generated.Stored { + t.Error("column b: expected Stored=true for STORED") + } + if colB.Generated.Expr == "" { + t.Error("column b: expected non-empty generated expression") + } +} + +func TestWalkThrough_3_1_ColumnInvisible(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT INVISIBLE)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + colA := tbl.GetColumn("a") + if colA.Invisible { + t.Error("column a should be visible (Invisible=false)") + } + colB := tbl.GetColumn("b") + if !colB.Invisible { + t.Error("column b should be invisible (Invisible=true)") + } +} + +func TestWalkThrough_3_1_CreateTableLike(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, ` + CREATE TABLE src ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) DEFAULT 'unnamed', + PRIMARY KEY (id), + INDEX idx_name (name) + ) + `) + wtExec(t, c, "CREATE TABLE dst LIKE src") + + srcTbl := c.GetDatabase("testdb").GetTable("src") + dstTbl := c.GetDatabase("testdb").GetTable("dst") + if dstTbl == nil { + t.Fatal("table 'dst' not found after CREATE TABLE ... LIKE") + } + + // Columns should match. + if len(dstTbl.Columns) != len(srcTbl.Columns) { + t.Fatalf("expected %d columns, got %d", len(srcTbl.Columns), len(dstTbl.Columns)) + } + for i, srcCol := range srcTbl.Columns { + dstCol := dstTbl.Columns[i] + if dstCol.Name != srcCol.Name { + t.Errorf("column %d: expected name %q, got %q", i, srcCol.Name, dstCol.Name) + } + if dstCol.ColumnType != srcCol.ColumnType { + t.Errorf("column %q: expected type %q, got %q", srcCol.Name, srcCol.ColumnType, dstCol.ColumnType) + } + if dstCol.Nullable != srcCol.Nullable { + t.Errorf("column %q: expected Nullable=%v, got %v", srcCol.Name, srcCol.Nullable, dstCol.Nullable) + } + } + + // Indexes should match. + if len(dstTbl.Indexes) != len(srcTbl.Indexes) { + t.Fatalf("expected %d indexes, got %d", len(srcTbl.Indexes), len(dstTbl.Indexes)) + } + for i, srcIdx := range srcTbl.Indexes { + dstIdx := dstTbl.Indexes[i] + if dstIdx.Name != srcIdx.Name { + t.Errorf("index %d: expected name %q, got %q", i, srcIdx.Name, dstIdx.Name) + } + if dstIdx.Primary != srcIdx.Primary { + t.Errorf("index %q: expected Primary=%v, got %v", srcIdx.Name, srcIdx.Primary, dstIdx.Primary) + } + } + + // Constraints should match. + if len(dstTbl.Constraints) != len(srcTbl.Constraints) { + t.Fatalf("expected %d constraints, got %d", len(srcTbl.Constraints), len(dstTbl.Constraints)) + } +} diff --git a/tidb/catalog/wt_3_2_test.go b/tidb/catalog/wt_3_2_test.go new file mode 100644 index 00000000..4ecf50ac --- /dev/null +++ b/tidb/catalog/wt_3_2_test.go @@ -0,0 +1,153 @@ +package catalog + +import "testing" + +// TestWalkThrough_3_2_AlterDatabaseCharset verifies that ALTER DATABASE +// changes the database charset and derives the default collation. +func TestWalkThrough_3_2_AlterDatabaseCharset(t *testing.T) { + c := wtSetup(t) + + // Default charset is utf8mb4 + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + if db.Charset != "utf8mb4" { + t.Fatalf("expected default charset utf8mb4, got %s", db.Charset) + } + + wtExec(t, c, "ALTER DATABASE testdb CHARACTER SET latin1") + + db = c.GetDatabase("testdb") + if db.Charset != "latin1" { + t.Errorf("expected charset latin1, got %s", db.Charset) + } + // When charset is changed without explicit collation, default collation is derived. + if db.Collation != "latin1_swedish_ci" { + t.Errorf("expected collation latin1_swedish_ci, got %s", db.Collation) + } +} + +// TestWalkThrough_3_2_AlterDatabaseCollation verifies that ALTER DATABASE +// changes the database collation. +func TestWalkThrough_3_2_AlterDatabaseCollation(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, "ALTER DATABASE testdb COLLATE utf8mb4_unicode_ci") + + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + if db.Collation != "utf8mb4_unicode_ci" { + t.Errorf("expected collation utf8mb4_unicode_ci, got %s", db.Collation) + } + // Charset should remain unchanged when only collation is set. + if db.Charset != "utf8mb4" { + t.Errorf("expected charset utf8mb4, got %s", db.Charset) + } +} + +// TestWalkThrough_3_2_RenameTable verifies that RENAME TABLE removes the old +// name and adds the new name with the same columns and indexes. +func TestWalkThrough_3_2_RenameTable(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "CREATE INDEX idx_name ON t1 (name)") + + // Capture column and index info before rename. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("t1 not found before rename") + } + origColCount := len(tbl.Columns) + origIdxCount := len(tbl.Indexes) + + wtExec(t, c, "RENAME TABLE t1 TO t2") + + // Old name must be gone. + if c.GetDatabase("testdb").GetTable("t1") != nil { + t.Error("t1 should not exist after rename") + } + + // New name must be present. + tbl2 := c.GetDatabase("testdb").GetTable("t2") + if tbl2 == nil { + t.Fatal("t2 not found after rename") + } + + // Same columns. + if len(tbl2.Columns) != origColCount { + t.Errorf("expected %d columns, got %d", origColCount, len(tbl2.Columns)) + } + if tbl2.GetColumn("id") == nil { + t.Error("column 'id' missing after rename") + } + if tbl2.GetColumn("name") == nil { + t.Error("column 'name' missing after rename") + } + + // Same indexes. + if len(tbl2.Indexes) != origIdxCount { + t.Errorf("expected %d indexes, got %d", origIdxCount, len(tbl2.Indexes)) + } +} + +// TestWalkThrough_3_2_RenameTableCrossDatabase verifies that RENAME TABLE +// moves a table from one database to another. +func TestWalkThrough_3_2_RenameTableCrossDatabase(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, "CREATE DATABASE otherdb") + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY, val TEXT)") + + wtExec(t, c, "RENAME TABLE testdb.t1 TO otherdb.t1") + + // Old location must be empty. + if c.GetDatabase("testdb").GetTable("t1") != nil { + t.Error("t1 should not exist in testdb after cross-db rename") + } + + // New location must have the table. + tbl := c.GetDatabase("otherdb").GetTable("t1") + if tbl == nil { + t.Fatal("t1 not found in otherdb after cross-db rename") + } + if tbl.GetColumn("id") == nil { + t.Error("column 'id' missing after cross-db rename") + } + if tbl.GetColumn("val") == nil { + t.Error("column 'val' missing after cross-db rename") + } +} + +// TestWalkThrough_3_2_TruncateTable verifies that TRUNCATE TABLE keeps the +// table but resets AUTO_INCREMENT to 0. +func TestWalkThrough_3_2_TruncateTable(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, "CREATE TABLE t1 (id INT AUTO_INCREMENT PRIMARY KEY) AUTO_INCREMENT=100") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("t1 not found") + } + if tbl.AutoIncrement != 100 { + t.Fatalf("expected AUTO_INCREMENT=100 before truncate, got %d", tbl.AutoIncrement) + } + + wtExec(t, c, "TRUNCATE TABLE t1") + + tbl = c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("t1 should still exist after truncate") + } + if tbl.AutoIncrement != 0 { + t.Errorf("expected AUTO_INCREMENT=0 after truncate, got %d", tbl.AutoIncrement) + } + // Columns should still be present. + if tbl.GetColumn("id") == nil { + t.Error("column 'id' missing after truncate") + } +} diff --git a/tidb/catalog/wt_3_3_test.go b/tidb/catalog/wt_3_3_test.go new file mode 100644 index 00000000..482909a0 --- /dev/null +++ b/tidb/catalog/wt_3_3_test.go @@ -0,0 +1,608 @@ +package catalog + +import "testing" + +// --- PRIMARY KEY --- + +func TestWalkThrough_3_3_PrimaryKeyCreatesIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(50), PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + found := false + for _, idx := range tbl.Indexes { + if idx.Primary { + found = true + if idx.Name != "PRIMARY" { + t.Errorf("expected PK index name 'PRIMARY', got %q", idx.Name) + } + if !idx.Unique { + t.Error("PK index should be Unique=true") + } + if len(idx.Columns) != 1 || idx.Columns[0].Name != "id" { + t.Errorf("PK index columns mismatch: %+v", idx.Columns) + } + } + } + if !found { + t.Error("no index with Primary=true found") + } +} + +func TestWalkThrough_3_3_PrimaryKeyColumnsNotNull(t *testing.T) { + c := wtSetup(t) + // Do NOT specify NOT NULL — PK should auto-mark columns NOT NULL. + wtExec(t, c, "CREATE TABLE t (id INT, PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + col := tbl.GetColumn("id") + if col == nil { + t.Fatal("column id not found") + } + if col.Nullable { + t.Error("PK column should be NOT NULL automatically") + } +} + +// --- UNIQUE KEY --- + +func TestWalkThrough_3_3_UniqueKeyCreatesIndexAndConstraint(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY uk_email (email))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Check index. + var uqIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "uk_email" { + uqIdx = idx + break + } + } + if uqIdx == nil { + t.Fatal("unique index uk_email not found") + } + if !uqIdx.Unique { + t.Error("expected Unique=true") + } + if len(uqIdx.Columns) != 1 || uqIdx.Columns[0].Name != "email" { + t.Errorf("unique index columns mismatch: %+v", uqIdx.Columns) + } + + // Check constraint. + var uqCon *Constraint + for _, con := range tbl.Constraints { + if con.Name == "uk_email" && con.Type == ConUniqueKey { + uqCon = con + break + } + } + if uqCon == nil { + t.Fatal("unique constraint uk_email not found") + } + if len(uqCon.Columns) != 1 || uqCon.Columns[0] != "email" { + t.Errorf("unique constraint columns mismatch: %v", uqCon.Columns) + } +} + +// --- Regular INDEX --- + +func TestWalkThrough_3_3_RegularIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + idx = i + break + } + } + if idx == nil { + t.Fatal("index idx_name not found") + } + if idx.Unique { + t.Error("regular index should not be unique") + } + if idx.Primary { + t.Error("regular index should not be primary") + } + if len(idx.Columns) != 1 || idx.Columns[0].Name != "name" { + t.Errorf("index columns mismatch: %+v", idx.Columns) + } +} + +// --- Multi-column index --- + +func TestWalkThrough_3_3_MultiColumnIndexOrder(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT, INDEX idx_abc (a, b, c))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "idx_abc" { + idx = i + break + } + } + if idx == nil { + t.Fatal("index idx_abc not found") + } + if len(idx.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(idx.Columns)) + } + expected := []string{"a", "b", "c"} + for i, exp := range expected { + if idx.Columns[i].Name != exp { + t.Errorf("column %d: expected %q, got %q", i, exp, idx.Columns[i].Name) + } + } +} + +// --- FULLTEXT index --- + +func TestWalkThrough_3_3_FulltextIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, body TEXT, FULLTEXT INDEX ft_body (body))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "ft_body" { + idx = i + break + } + } + if idx == nil { + t.Fatal("fulltext index ft_body not found") + } + if !idx.Fulltext { + t.Error("expected Fulltext=true") + } +} + +// --- SPATIAL index --- + +func TestWalkThrough_3_3_SpatialIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, geo GEOMETRY NOT NULL SRID 0, SPATIAL INDEX sp_geo (geo))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "sp_geo" { + idx = i + break + } + } + if idx == nil { + t.Fatal("spatial index sp_geo not found") + } + if !idx.Spatial { + t.Error("expected Spatial=true") + } +} + +// --- Index COMMENT --- + +func TestWalkThrough_3_3_IndexComment(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name) COMMENT 'name lookup')") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + idx = i + break + } + } + if idx == nil { + t.Fatal("index idx_name not found") + } + if idx.Comment != "name lookup" { + t.Errorf("expected comment 'name lookup', got %q", idx.Comment) + } +} + +// --- Index INVISIBLE --- + +func TestWalkThrough_3_3_IndexInvisible(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name) INVISIBLE)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + idx = i + break + } + } + if idx == nil { + t.Fatal("index idx_name not found") + } + if idx.Visible { + t.Error("expected Visible=false for INVISIBLE index") + } +} + +// --- FOREIGN KEY --- + +func TestWalkThrough_3_3_ForeignKeyConstraint(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT PRIMARY KEY, pid INT, CONSTRAINT fk_pid FOREIGN KEY (pid) REFERENCES parent(id) ON DELETE CASCADE ON UPDATE SET NULL)") + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_pid" { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint fk_pid not found") + } + if fk.RefTable != "parent" { + t.Errorf("expected RefTable 'parent', got %q", fk.RefTable) + } + if len(fk.RefColumns) != 1 || fk.RefColumns[0] != "id" { + t.Errorf("expected RefColumns [id], got %v", fk.RefColumns) + } + if fk.OnDelete != "CASCADE" { + t.Errorf("expected OnDelete 'CASCADE', got %q", fk.OnDelete) + } + if fk.OnUpdate != "SET NULL" { + t.Errorf("expected OnUpdate 'SET NULL', got %q", fk.OnUpdate) + } + if len(fk.Columns) != 1 || fk.Columns[0] != "pid" { + t.Errorf("expected Columns [pid], got %v", fk.Columns) + } +} + +// --- FOREIGN KEY implicit backing index --- + +func TestWalkThrough_3_3_ForeignKeyBackingIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT PRIMARY KEY, pid INT, CONSTRAINT fk_pid FOREIGN KEY (pid) REFERENCES parent(id))") + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // FK should create an implicit backing index. + found := false + for _, idx := range tbl.Indexes { + if idx.Primary { + continue + } + for _, col := range idx.Columns { + if col.Name == "pid" { + found = true + break + } + } + if found { + break + } + } + if !found { + t.Error("expected implicit backing index for FK on column pid") + } +} + +// --- CHECK constraint --- + +func TestWalkThrough_3_3_CheckConstraint(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, age INT, CONSTRAINT chk_age CHECK (age >= 0))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + var chk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConCheck && con.Name == "chk_age" { + chk = con + break + } + } + if chk == nil { + t.Fatal("CHECK constraint chk_age not found") + } + if chk.CheckExpr == "" { + t.Error("CHECK expression should not be empty") + } + if chk.NotEnforced { + t.Error("CHECK should be enforced by default") + } +} + +func TestWalkThrough_3_3_CheckConstraintNotEnforced(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, age INT, CONSTRAINT chk_age CHECK (age >= 0) NOT ENFORCED)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + var chk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConCheck && con.Name == "chk_age" { + chk = con + break + } + } + if chk == nil { + t.Fatal("CHECK constraint chk_age not found") + } + if !chk.NotEnforced { + t.Error("CHECK should be NOT ENFORCED") + } +} + +// --- Named constraints --- + +func TestWalkThrough_3_3_NamedConstraints(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, CONSTRAINT my_pk PRIMARY KEY (id))") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // PK constraint name in MySQL is always "PRIMARY", regardless of user-specified name. + // But non-PK constraints should preserve names. Test with UNIQUE. + c2 := wtSetup(t) + wtExec(t, c2, "CREATE TABLE t (id INT PRIMARY KEY, email VARCHAR(100), CONSTRAINT my_unique UNIQUE KEY (email))") + tbl2 := c2.GetDatabase("testdb").GetTable("t") + if tbl2 == nil { + t.Fatal("table not found") + } + found := false + for _, con := range tbl2.Constraints { + if con.Type == ConUniqueKey && con.Name == "my_unique" { + found = true + break + } + } + if !found { + t.Error("named unique constraint 'my_unique' not found") + } +} + +// --- Unnamed constraints auto-generated names --- + +func TestWalkThrough_3_3_UnnamedConstraintAutoName(t *testing.T) { + c := wtSetup(t) + // FK without explicit name should get auto-generated name like t_ibfk_1. + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT PRIMARY KEY, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint not found") + } + if fk.Name == "" { + t.Error("unnamed FK constraint should get auto-generated name") + } + // MySQL auto-names FKs as tableName_ibfk_N. + expected := "child_ibfk_1" + if fk.Name != expected { + t.Errorf("expected auto-generated FK name %q, got %q", expected, fk.Name) + } + + // Also test unnamed CHECK: should get tableName_chk_N. + c2 := wtSetup(t) + wtExec(t, c2, "CREATE TABLE t (id INT PRIMARY KEY, age INT, CHECK (age >= 0))") + tbl2 := c2.GetDatabase("testdb").GetTable("t") + if tbl2 == nil { + t.Fatal("table not found") + } + var chk *Constraint + for _, con := range tbl2.Constraints { + if con.Type == ConCheck { + chk = con + break + } + } + if chk == nil { + t.Fatal("CHECK constraint not found") + } + if chk.Name == "" { + t.Error("unnamed CHECK constraint should get auto-generated name") + } + expectedChk := "t_chk_1" + if chk.Name != expectedChk { + t.Errorf("expected auto-generated CHECK name %q, got %q", expectedChk, chk.Name) + } +} + +// Bug A (CREATE path): auto-generated FK name counter increments per unnamed +// FK, starting from 0, ignoring user-named FKs. +// +// MySQL reference: sql/sql_table.cc:9252 initializes the counter to 0 for +// create_table_impl; sql/sql_table.cc:5912 generate_fk_name uses ++counter. +// This means user-named FKs do NOT seed the counter during CREATE TABLE. +// +// Example: CREATE TABLE child (a INT, CONSTRAINT child_ibfk_5 FK, b INT, FK) +// Real MySQL: unnamed FK gets "child_ibfk_1" (first auto-named, counter 0 → 1). +// Spot-check confirmed with real MySQL 8.0 container. +// +// NOTE: ALTER TABLE ADD FK uses a different rule (max+1) — see the test below. +func TestBugFix_FKCounterCreateTable(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, `CREATE TABLE child ( + a INT, + CONSTRAINT child_ibfk_5 FOREIGN KEY (a) REFERENCES parent(id), + b INT, + FOREIGN KEY (b) REFERENCES parent(id) + )`) + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + var autoGenName string + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + if con.Name == "child_ibfk_5" { + continue + } + autoGenName = con.Name + break + } + if autoGenName == "" { + t.Fatal("expected a second (auto-named) FK constraint, found none") + } + // Real MySQL 8.0 produces "child_ibfk_1" here (verified by spot-check). + // The user-named "child_ibfk_5" does NOT seed the counter during CREATE. + if autoGenName != "child_ibfk_1" { + t.Errorf("expected child_ibfk_1 (first unnamed FK, ignoring user-named _5), got %s", autoGenName) + } +} + +// Bug A (ALTER path): ALTER TABLE ADD FOREIGN KEY uses max(existing)+1 logic. +// MySQL reference: sql/sql_table.cc:14345 (ALTER TABLE) initializes the +// counter via get_fk_max_generated_name_number(), which scans the existing +// table definition for the max generated-name counter. +func TestBugFix_FKCounterAlterTable(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, `CREATE TABLE child ( + a INT, + b INT, + CONSTRAINT child_ibfk_20 FOREIGN KEY (a) REFERENCES parent(id) + )`) + wtExec(t, c, "ALTER TABLE child ADD FOREIGN KEY (b) REFERENCES parent(id)") + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + var autoGenName string + for _, con := range tbl.Constraints { + if con.Type != ConForeignKey { + continue + } + if con.Name == "child_ibfk_20" { + continue + } + autoGenName = con.Name + break + } + if autoGenName != "child_ibfk_21" { + t.Errorf("expected child_ibfk_21 (max 20 + 1), got %s", autoGenName) + } +} + +// Bug B: TIMESTAMP first column must NOT be auto-promoted under MySQL 8.0 +// defaults. In 8.0, explicit_defaults_for_timestamp = ON by default, which +// disables promote_first_timestamp_column() (sql/sql_table.cc:10148). omni +// catalog matches this default — it never promotes. This test locks in the +// absence of a stale TIMESTAMP-promotion bug. +func TestBugFix_TimestampNoAutoPromotion(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (ts TIMESTAMP NOT NULL)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + col := tbl.GetColumn("ts") + if col == nil { + t.Fatal("column ts not found") + } + if col.Default != nil { + t.Errorf("expected no auto-promoted DEFAULT (8.0 default behavior), got Default=%q", *col.Default) + } + if col.OnUpdate != "" { + t.Errorf("expected no auto-promoted ON UPDATE (8.0 default behavior), got OnUpdate=%q", col.OnUpdate) + } +} + +// PS1: CHECK constraint counter (CREATE path) follows the same rule as FK +// counter: it's a local counter starting at 0, incrementing per unnamed +// CHECK, IGNORING user-named _chk_N constraints. +// +// MySQL source: sql/sql_table.cc:19073 declares `uint cc_max_generated_number = 0` +// as a fresh local counter. Uses ++cc_max_generated_number per unnamed CHECK. +// If the generated name collides with a user-named one, MySQL errors with +// ER_CHECK_CONSTRAINT_DUP_NAME at sql/sql_table.cc:19595. +// +// Example: CREATE TABLE t (a INT, CONSTRAINT t_chk_1 CHECK(a>0), b INT, CHECK(b<100)) +// Real MySQL: the second unnamed CHECK gets t_chk_1 (counter 0 → 1), which +// collides with user-named t_chk_1 → ER_CHECK_CONSTRAINT_DUP_NAME. +// omni currently does not error on collision (PS7 tracking), but at minimum +// it should assign the correct counter sequence ignoring user-named entries. +func TestBugFix_CheckCounterCreateTable(t *testing.T) { + c := wtSetup(t) + // Use user-named t_chk_5 so omni's unnamed-CHECK counter (starting 1) + // doesn't collide. + wtExec(t, c, `CREATE TABLE t ( + a INT, + CONSTRAINT t_chk_5 CHECK (a > 0), + b INT, + CHECK (b < 100) + )`) + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table t not found") + } + var autoGenName string + for _, con := range tbl.Constraints { + if con.Type != ConCheck { + continue + } + if con.Name == "t_chk_5" { + continue + } + autoGenName = con.Name + break + } + if autoGenName == "" { + t.Fatal("expected a second (auto-named) CHECK constraint, found none") + } + // Real MySQL 8.0 produces t_chk_1 here — the user-named _5 is NOT seeded + // into the counter during CREATE (verified by source code analysis; + // sql/sql_table.cc:19073 starts cc_max_generated_number at 0). + if autoGenName != "t_chk_1" { + t.Errorf("expected t_chk_1 (first unnamed CHECK, ignoring user-named _5), got %s", autoGenName) + } +} diff --git a/tidb/catalog/wt_3_4_test.go b/tidb/catalog/wt_3_4_test.go new file mode 100644 index 00000000..342cd241 --- /dev/null +++ b/tidb/catalog/wt_3_4_test.go @@ -0,0 +1,257 @@ +package catalog + +import "testing" + +// --- 3.4 ALTER TABLE State — Column Operations --- + +func TestWalkThrough_3_4_AddColumnEnd(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN c VARCHAR(50)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column 'c' not found") + } + if col.Position != 3 { + t.Errorf("expected column c at position 3, got %d", col.Position) + } +} + +func TestWalkThrough_3_4_AddColumnFirst(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN z VARCHAR(50) FIRST") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + // z should be at position 1 + col := tbl.GetColumn("z") + if col == nil { + t.Fatal("column 'z' not found") + } + if col.Position != 1 { + t.Errorf("expected column z at position 1, got %d", col.Position) + } + // a should shift to position 2 + if tbl.GetColumn("a").Position != 2 { + t.Errorf("expected column a at position 2, got %d", tbl.GetColumn("a").Position) + } + // b should shift to position 3 + if tbl.GetColumn("b").Position != 3 { + t.Errorf("expected column b at position 3, got %d", tbl.GetColumn("b").Position) + } +} + +func TestWalkThrough_3_4_AddColumnAfter(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x VARCHAR(50) AFTER a") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if len(tbl.Columns) != 4 { + t.Fatalf("expected 4 columns, got %d", len(tbl.Columns)) + } + // Order should be: a(1), x(2), b(3), c(4) + expected := []struct { + name string + pos int + }{{"a", 1}, {"x", 2}, {"b", 3}, {"c", 4}} + for _, e := range expected { + col := tbl.GetColumn(e.name) + if col == nil { + t.Fatalf("column %q not found", e.name) + } + if col.Position != e.pos { + t.Errorf("column %q: expected position %d, got %d", e.name, e.pos, col.Position) + } + } +} + +func TestWalkThrough_3_4_DropColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t DROP COLUMN b") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + if len(tbl.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("b") != nil { + t.Error("column 'b' should have been dropped") + } + // Positions should be resequenced: a=1, c=2 + if tbl.GetColumn("a").Position != 1 { + t.Errorf("expected column a at position 1, got %d", tbl.GetColumn("a").Position) + } + if tbl.GetColumn("c").Position != 2 { + t.Errorf("expected column c at position 2, got %d", tbl.GetColumn("c").Position) + } +} + +func TestWalkThrough_3_4_ModifyColumnType(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b VARCHAR(50), c INT)") + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN b TEXT") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + col := tbl.GetColumn("b") + if col == nil { + t.Fatal("column 'b' not found") + } + if col.DataType != "text" { + t.Errorf("expected DataType 'text', got %q", col.DataType) + } + // Position should remain unchanged + if col.Position != 2 { + t.Errorf("expected position 2, got %d", col.Position) + } +} + +func TestWalkThrough_3_4_ModifyColumnNullability(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT NOT NULL)") + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("b") + if col.Nullable { + t.Error("column 'b' should initially be NOT NULL") + } + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN b INT NULL") + tbl = c.GetDatabase("testdb").GetTable("t") + col = tbl.GetColumn("b") + if !col.Nullable { + t.Error("column 'b' should be nullable after MODIFY") + } +} + +func TestWalkThrough_3_4_ChangeColumnName(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, old_name VARCHAR(50), c INT)") + wtExec(t, c, "ALTER TABLE t CHANGE COLUMN old_name new_name VARCHAR(50)") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Old name should be gone + if tbl.GetColumn("old_name") != nil { + t.Error("old column name 'old_name' should not exist") + } + // New name should be present + col := tbl.GetColumn("new_name") + if col == nil { + t.Fatal("column 'new_name' not found") + } +} + +func TestWalkThrough_3_4_ChangeColumnTypeAndAttrs(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b VARCHAR(50) NOT NULL)") + wtExec(t, c, "ALTER TABLE t CHANGE COLUMN b b_new TEXT NULL DEFAULT 'hello'") + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("b_new") + if col == nil { + t.Fatal("column 'b_new' not found") + } + if col.DataType != "text" { + t.Errorf("expected DataType 'text', got %q", col.DataType) + } + if !col.Nullable { + t.Error("column should be nullable after CHANGE") + } + if col.Default == nil || *col.Default != "'hello'" { + def := "" + if col.Default != nil { + def = *col.Default + } + t.Errorf("expected default 'hello', got %s", def) + } +} + +func TestWalkThrough_3_4_RenameColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t RENAME COLUMN b TO b_renamed") + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl.GetColumn("b") != nil { + t.Error("old column name 'b' should not exist") + } + col := tbl.GetColumn("b_renamed") + if col == nil { + t.Fatal("column 'b_renamed' not found") + } + // Position should be unchanged (2) + if col.Position != 2 { + t.Errorf("expected position 2, got %d", col.Position) + } +} + +func TestWalkThrough_3_4_AlterColumnSetDefault(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT)") + wtExec(t, c, "ALTER TABLE t ALTER COLUMN b SET DEFAULT 42") + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("b") + if col.Default == nil { + t.Fatal("expected default value, got nil") + } + if *col.Default != "42" { + t.Errorf("expected default '42', got %q", *col.Default) + } +} + +func TestWalkThrough_3_4_AlterColumnDropDefault(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT DEFAULT 10)") + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("b") + if col.Default == nil { + t.Fatal("expected initial default value") + } + wtExec(t, c, "ALTER TABLE t ALTER COLUMN b DROP DEFAULT") + tbl = c.GetDatabase("testdb").GetTable("t") + col = tbl.GetColumn("b") + if col.Default != nil { + t.Errorf("expected nil default after DROP DEFAULT, got %q", *col.Default) + } + if !col.DefaultDropped { + t.Error("expected DefaultDropped to be true") + } +} + +func TestWalkThrough_3_4_AlterColumnVisibility(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT)") + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("b") + if col.Invisible { + t.Error("column should start visible") + } + wtExec(t, c, "ALTER TABLE t ALTER COLUMN b SET INVISIBLE") + tbl = c.GetDatabase("testdb").GetTable("t") + col = tbl.GetColumn("b") + if !col.Invisible { + t.Error("column should be invisible after SET INVISIBLE") + } + wtExec(t, c, "ALTER TABLE t ALTER COLUMN b SET VISIBLE") + tbl = c.GetDatabase("testdb").GetTable("t") + col = tbl.GetColumn("b") + if col.Invisible { + t.Error("column should be visible after SET VISIBLE") + } +} diff --git a/tidb/catalog/wt_3_5_test.go b/tidb/catalog/wt_3_5_test.go new file mode 100644 index 00000000..092013bc --- /dev/null +++ b/tidb/catalog/wt_3_5_test.go @@ -0,0 +1,415 @@ +package catalog + +import "testing" + +// --- ADD INDEX --- + +func TestWalkThrough_3_5_AddIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "ALTER TABLE t ADD INDEX idx_name (name)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_name not found in table.Indexes") + } + if found.Unique { + t.Error("regular index should not be Unique") + } + if found.Primary { + t.Error("regular index should not be Primary") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "name" { + t.Errorf("index columns mismatch: %+v", found.Columns) + } +} + +// --- ADD UNIQUE INDEX --- + +func TestWalkThrough_3_5_AddUniqueIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, email VARCHAR(200))") + wtExec(t, c, "ALTER TABLE t ADD UNIQUE INDEX idx_email (email)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Check index exists and is unique. + var uqIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_email" { + uqIdx = idx + break + } + } + if uqIdx == nil { + t.Fatal("unique index idx_email not found") + } + if !uqIdx.Unique { + t.Error("expected Unique=true on unique index") + } + if len(uqIdx.Columns) != 1 || uqIdx.Columns[0].Name != "email" { + t.Errorf("unique index columns mismatch: %+v", uqIdx.Columns) + } + + // Check constraint created. + var uqCon *Constraint + for _, con := range tbl.Constraints { + if con.Name == "idx_email" && con.Type == ConUniqueKey { + uqCon = con + break + } + } + if uqCon == nil { + t.Fatal("unique constraint idx_email not found") + } + if len(uqCon.Columns) != 1 || uqCon.Columns[0] != "email" { + t.Errorf("unique constraint columns mismatch: %v", uqCon.Columns) + } +} + +// --- ADD PRIMARY KEY --- + +func TestWalkThrough_3_5_AddPrimaryKey(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100))") + wtExec(t, c, "ALTER TABLE t ADD PRIMARY KEY (id)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + var pkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Primary { + pkIdx = idx + break + } + } + if pkIdx == nil { + t.Fatal("no primary key index found") + } + if pkIdx.Name != "PRIMARY" { + t.Errorf("expected PK index name 'PRIMARY', got %q", pkIdx.Name) + } + if !pkIdx.Unique { + t.Error("PK index should be Unique=true") + } + if len(pkIdx.Columns) != 1 || pkIdx.Columns[0].Name != "id" { + t.Errorf("PK index columns mismatch: %+v", pkIdx.Columns) + } + + // PK column should be marked NOT NULL. + col := tbl.GetColumn("id") + if col == nil { + t.Fatal("column id not found") + } + if col.Nullable { + t.Error("PK column should be NOT NULL") + } +} + +// --- ADD FOREIGN KEY --- + +func TestWalkThrough_3_5_AddForeignKey(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT PRIMARY KEY, parent_id INT)") + wtExec(t, c, "ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) ON DELETE CASCADE ON UPDATE SET NULL") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + var fkCon *Constraint + for _, con := range tbl.Constraints { + if con.Name == "fk_parent" && con.Type == ConForeignKey { + fkCon = con + break + } + } + if fkCon == nil { + t.Fatal("FK constraint fk_parent not found") + } + if len(fkCon.Columns) != 1 || fkCon.Columns[0] != "parent_id" { + t.Errorf("FK columns mismatch: %v", fkCon.Columns) + } + if fkCon.RefTable != "parent" { + t.Errorf("expected RefTable 'parent', got %q", fkCon.RefTable) + } + if len(fkCon.RefColumns) != 1 || fkCon.RefColumns[0] != "id" { + t.Errorf("FK RefColumns mismatch: %v", fkCon.RefColumns) + } + if fkCon.OnDelete != "CASCADE" { + t.Errorf("expected OnDelete 'CASCADE', got %q", fkCon.OnDelete) + } + if fkCon.OnUpdate != "SET NULL" { + t.Errorf("expected OnUpdate 'SET NULL', got %q", fkCon.OnUpdate) + } +} + +// --- ADD CHECK --- + +func TestWalkThrough_3_5_AddCheck(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, age INT)") + wtExec(t, c, "ALTER TABLE t ADD CONSTRAINT chk_age CHECK (age >= 0)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + var chkCon *Constraint + for _, con := range tbl.Constraints { + if con.Name == "chk_age" && con.Type == ConCheck { + chkCon = con + break + } + } + if chkCon == nil { + t.Fatal("CHECK constraint chk_age not found") + } + if chkCon.CheckExpr == "" { + t.Error("CHECK constraint should have a non-empty expression") + } + if chkCon.NotEnforced { + t.Error("CHECK constraint should be enforced by default") + } +} + +// --- DROP INDEX --- + +func TestWalkThrough_3_5_DropIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name))") + // Verify index exists first. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + foundBefore := false + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + foundBefore = true + break + } + } + if !foundBefore { + t.Fatal("index idx_name should exist before drop") + } + + wtExec(t, c, "ALTER TABLE t DROP INDEX idx_name") + + tbl = c.GetDatabase("testdb").GetTable("t") + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + t.Error("index idx_name should have been removed after DROP INDEX") + } + } +} + +// --- DROP PRIMARY KEY --- + +func TestWalkThrough_3_5_DropPrimaryKey(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "ALTER TABLE t DROP PRIMARY KEY") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + for _, idx := range tbl.Indexes { + if idx.Primary { + t.Error("PK index should have been removed after DROP PRIMARY KEY") + } + } +} + +// --- RENAME INDEX --- + +func TestWalkThrough_3_5_RenameIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name))") + wtExec(t, c, "ALTER TABLE t RENAME INDEX idx_name TO idx_name_new") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + foundOld := false + foundNew := false + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + foundOld = true + } + if idx.Name == "idx_name_new" { + foundNew = true + } + } + if foundOld { + t.Error("old index name idx_name should no longer exist") + } + if !foundNew { + t.Error("new index name idx_name_new should exist") + } +} + +// --- ALTER INDEX VISIBILITY --- + +func TestWalkThrough_3_5_AlterIndexVisible(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), INDEX idx_name (name))") + + // Default should be visible. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var idx *Index + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + idx = i + break + } + } + if idx == nil { + t.Fatal("index idx_name not found") + } + if !idx.Visible { + t.Error("index should be visible by default") + } + + // Set invisible. + wtExec(t, c, "ALTER TABLE t ALTER INDEX idx_name INVISIBLE") + tbl = c.GetDatabase("testdb").GetTable("t") + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + if i.Visible { + t.Error("index should be invisible after ALTER INDEX INVISIBLE") + } + } + } + + // Set visible again. + wtExec(t, c, "ALTER TABLE t ALTER INDEX idx_name VISIBLE") + tbl = c.GetDatabase("testdb").GetTable("t") + for _, i := range tbl.Indexes { + if i.Name == "idx_name" { + if !i.Visible { + t.Error("index should be visible after ALTER INDEX VISIBLE") + } + } + } +} + +// --- ALTER CHECK ENFORCED / NOT ENFORCED --- + +func TestWalkThrough_3_5_AlterCheckEnforced(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, age INT, CONSTRAINT chk_age CHECK (age >= 0))") + + // Default should be enforced. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + var chk *Constraint + for _, con := range tbl.Constraints { + if con.Name == "chk_age" && con.Type == ConCheck { + chk = con + break + } + } + if chk == nil { + t.Fatal("CHECK constraint chk_age not found") + } + if chk.NotEnforced { + t.Error("CHECK should be enforced by default") + } + + // Set NOT ENFORCED. + wtExec(t, c, "ALTER TABLE t ALTER CHECK chk_age NOT ENFORCED") + tbl = c.GetDatabase("testdb").GetTable("t") + for _, con := range tbl.Constraints { + if con.Name == "chk_age" && con.Type == ConCheck { + if !con.NotEnforced { + t.Error("CHECK should be NOT ENFORCED after ALTER CHECK NOT ENFORCED") + } + } + } + + // Set ENFORCED. + wtExec(t, c, "ALTER TABLE t ALTER CHECK chk_age ENFORCED") + tbl = c.GetDatabase("testdb").GetTable("t") + for _, con := range tbl.Constraints { + if con.Name == "chk_age" && con.Type == ConCheck { + if con.NotEnforced { + t.Error("CHECK should be ENFORCED after ALTER CHECK ENFORCED") + } + } + } +} + +// --- CONVERT TO CHARACTER SET --- + +func TestWalkThrough_3_5_ConvertToCharset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), bio TEXT, age INT) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4") + + // Convert to latin1. + wtExec(t, c, "ALTER TABLE t CONVERT TO CHARACTER SET latin1") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Table charset should be updated. + if tbl.Charset != "latin1" { + t.Errorf("expected table charset 'latin1', got %q", tbl.Charset) + } + + // String columns should be updated. + nameCol := tbl.GetColumn("name") + if nameCol == nil { + t.Fatal("column name not found") + } + if nameCol.Charset != "latin1" { + t.Errorf("expected name column charset 'latin1', got %q", nameCol.Charset) + } + + bioCol := tbl.GetColumn("bio") + if bioCol == nil { + t.Fatal("column bio not found") + } + if bioCol.Charset != "latin1" { + t.Errorf("expected bio column charset 'latin1', got %q", bioCol.Charset) + } + + // Non-string column should NOT be affected. + ageCol := tbl.GetColumn("age") + if ageCol == nil { + t.Fatal("column age not found") + } + if ageCol.Charset != "" { + t.Errorf("expected age column charset empty, got %q", ageCol.Charset) + } +} diff --git a/tidb/catalog/wt_3_6_test.go b/tidb/catalog/wt_3_6_test.go new file mode 100644 index 00000000..b9dcb32d --- /dev/null +++ b/tidb/catalog/wt_3_6_test.go @@ -0,0 +1,223 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestWalkThrough_3_6_CreateViewExists verifies that CREATE VIEW adds the view +// to database.Views. +func TestWalkThrough_3_6_CreateViewExists(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id, name FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found in database.Views") + } + if v.Name != "v1" { + t.Errorf("expected view name 'v1', got %q", v.Name) + } +} + +// TestWalkThrough_3_6_CreateViewDefinition verifies that the Definition field +// stores the deparsed SQL (not the raw input). +func TestWalkThrough_3_6_CreateViewDefinition(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id, name FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found") + } + // Definition should be the deparsed SELECT, not empty. + if v.Definition == "" { + t.Fatal("view definition should not be empty") + } + // The deparsed SQL should reference the columns and table. + def := strings.ToLower(v.Definition) + if !strings.Contains(def, "select") { + t.Errorf("definition should contain SELECT, got %q", v.Definition) + } + if !strings.Contains(def, "t1") { + t.Errorf("definition should reference t1, got %q", v.Definition) + } +} + +// TestWalkThrough_3_6_CreateViewAttributes verifies Algorithm, Definer, and +// SqlSecurity are preserved. +func TestWalkThrough_3_6_CreateViewAttributes(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, "CREATE ALGORITHM=MERGE DEFINER=`admin`@`localhost` SQL SECURITY INVOKER VIEW v1 AS SELECT id FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found") + } + if !strings.EqualFold(v.Algorithm, "MERGE") { + t.Errorf("expected Algorithm 'MERGE', got %q", v.Algorithm) + } + // The parser stores the definer as-is (without backtick-quoting). + // Backtick formatting is applied only in showCreateView via formatDefiner. + if !strings.Contains(v.Definer, "admin") || !strings.Contains(v.Definer, "localhost") { + t.Errorf("expected Definer to contain 'admin' and 'localhost', got %q", v.Definer) + } + if !strings.EqualFold(v.SqlSecurity, "INVOKER") { + t.Errorf("expected SqlSecurity 'INVOKER', got %q", v.SqlSecurity) + } +} + +// TestWalkThrough_3_6_CreateViewDefaultAttributes verifies defaults when +// Algorithm/Definer/SqlSecurity are not specified. +func TestWalkThrough_3_6_CreateViewDefaultAttributes(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found") + } + // Definer defaults to `root`@`%` per viewcmds.go. + if v.Definer != "`root`@`%`" { + t.Errorf("expected default Definer '`root`@`%%`', got %q", v.Definer) + } +} + +// TestWalkThrough_3_6_CreateViewColumns verifies that the column list is +// derived from the SELECT target list. +func TestWalkThrough_3_6_CreateViewColumns(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100), age INT)") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id, name FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found") + } + if len(v.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d: %v", len(v.Columns), v.Columns) + } + if v.Columns[0] != "id" { + t.Errorf("expected first column 'id', got %q", v.Columns[0]) + } + if v.Columns[1] != "name" { + t.Errorf("expected second column 'name', got %q", v.Columns[1]) + } +} + +// TestWalkThrough_3_6_CreateOrReplaceView verifies that CREATE OR REPLACE VIEW +// updates an existing view. +func TestWalkThrough_3_6_CreateOrReplaceView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100), age INT)") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id FROM t1") + + // Verify initial state. + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found after CREATE") + } + if len(v.Columns) != 1 { + t.Fatalf("expected 1 column initially, got %d", len(v.Columns)) + } + + // Replace with a different SELECT. + wtExec(t, c, "CREATE OR REPLACE VIEW v1 AS SELECT id, name, age FROM t1") + + v = db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found after CREATE OR REPLACE") + } + if len(v.Columns) != 3 { + t.Fatalf("expected 3 columns after replace, got %d: %v", len(v.Columns), v.Columns) + } +} + +// TestWalkThrough_3_6_AlterView verifies that ALTER VIEW updates definition +// and attributes. +func TestWalkThrough_3_6_AlterView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id FROM t1") + + // ALTER VIEW changes definition and attributes. + wtExec(t, c, "ALTER VIEW v1 AS SELECT id, name FROM t1") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found after ALTER") + } + if len(v.Columns) != 2 { + t.Fatalf("expected 2 columns after ALTER, got %d", len(v.Columns)) + } + // Definition should be updated. + def := strings.ToLower(v.Definition) + if !strings.Contains(def, "name") { + t.Errorf("definition after ALTER should reference 'name', got %q", v.Definition) + } +} + +// TestWalkThrough_3_6_DropView verifies that DROP VIEW removes the view from +// database.Views. +func TestWalkThrough_3_6_DropView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, "CREATE VIEW v1 AS SELECT id FROM t1") + + db := c.GetDatabase("testdb") + if db.Views[toLower("v1")] == nil { + t.Fatal("view v1 should exist before DROP") + } + + wtExec(t, c, "DROP VIEW v1") + + if db.Views[toLower("v1")] != nil { + t.Fatal("view v1 should not exist after DROP") + } +} + +// TestWalkThrough_3_6_ViewReferencingTableColumns verifies that a view +// referencing table columns resolves correctly and the column names appear +// in the view's Columns list. +func TestWalkThrough_3_6_ViewReferencingTableColumns(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE employees ( + id INT NOT NULL, + first_name VARCHAR(50), + last_name VARCHAR(50), + salary DECIMAL(10,2) + )`) + wtExec(t, c, "CREATE VIEW emp_names AS SELECT id, first_name, last_name FROM employees") + + db := c.GetDatabase("testdb") + v := db.Views[toLower("emp_names")] + if v == nil { + t.Fatal("view emp_names not found") + } + expectedCols := []string{"id", "first_name", "last_name"} + if len(v.Columns) != len(expectedCols) { + t.Fatalf("expected %d columns, got %d: %v", len(expectedCols), len(v.Columns), v.Columns) + } + for i, want := range expectedCols { + if v.Columns[i] != want { + t.Errorf("column %d: expected %q, got %q", i, want, v.Columns[i]) + } + } + + // Verify the definition references the employees table. + def := strings.ToLower(v.Definition) + if !strings.Contains(def, "employees") { + t.Errorf("definition should reference employees table, got %q", v.Definition) + } +} diff --git a/tidb/catalog/wt_3_7_test.go b/tidb/catalog/wt_3_7_test.go new file mode 100644 index 00000000..a902d911 --- /dev/null +++ b/tidb/catalog/wt_3_7_test.go @@ -0,0 +1,274 @@ +package catalog + +import "testing" + +// --- 3.7 Routine, Trigger, and Event State --- + +func TestWalkThrough_3_7_CreateProcedure(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `DELIMITER ;; +CREATE PROCEDURE my_proc(IN a INT, OUT b VARCHAR(100)) +BEGIN + SET b = CONCAT('hello', a); +END;; +DELIMITER ;`) + + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + proc := db.Procedures[toLower("my_proc")] + if proc == nil { + t.Fatal("procedure my_proc not found in database.Procedures") + } + if proc.Name != "my_proc" { + t.Errorf("expected name my_proc, got %s", proc.Name) + } + if !proc.IsProcedure { + t.Error("expected IsProcedure=true") + } + if len(proc.Params) != 2 { + t.Fatalf("expected 2 params, got %d", len(proc.Params)) + } + // Check param a + if proc.Params[0].Direction != "IN" { + t.Errorf("param 0 direction: expected IN, got %s", proc.Params[0].Direction) + } + if proc.Params[0].Name != "a" { + t.Errorf("param 0 name: expected a, got %s", proc.Params[0].Name) + } + if proc.Params[0].TypeName != "INT" { + t.Errorf("param 0 type: expected INT, got %s", proc.Params[0].TypeName) + } + // Check param b + if proc.Params[1].Direction != "OUT" { + t.Errorf("param 1 direction: expected OUT, got %s", proc.Params[1].Direction) + } + if proc.Params[1].Name != "b" { + t.Errorf("param 1 name: expected b, got %s", proc.Params[1].Name) + } + if proc.Params[1].TypeName != "VARCHAR(100)" { + t.Errorf("param 1 type: expected VARCHAR(100), got %s", proc.Params[1].TypeName) + } +} + +func TestWalkThrough_3_7_CreateFunction(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `DELIMITER ;; +CREATE FUNCTION my_func(a INT, b INT) RETURNS INT +DETERMINISTIC +RETURN a + b;; +DELIMITER ;`) + + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + fn := db.Functions[toLower("my_func")] + if fn == nil { + t.Fatal("function my_func not found in database.Functions") + } + if fn.Name != "my_func" { + t.Errorf("expected name my_func, got %s", fn.Name) + } + if fn.IsProcedure { + t.Error("expected IsProcedure=false for function") + } + // Return type should contain "int" + if fn.Returns == "" { + t.Error("expected non-empty Returns for function") + } + if len(fn.Params) != 2 { + t.Fatalf("expected 2 params, got %d", len(fn.Params)) + } + if fn.Params[0].Name != "a" { + t.Errorf("param 0 name: expected a, got %s", fn.Params[0].Name) + } + if fn.Params[1].Name != "b" { + t.Errorf("param 1 name: expected b, got %s", fn.Params[1].Name) + } +} + +func TestWalkThrough_3_7_AlterProcedure(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `DELIMITER ;; +CREATE PROCEDURE my_proc() +BEGIN + SELECT 1; +END;; +DELIMITER ;`) + + wtExec(t, c, "ALTER PROCEDURE my_proc COMMENT 'updated comment'") + + db := c.GetDatabase("testdb") + proc := db.Procedures[toLower("my_proc")] + if proc == nil { + t.Fatal("procedure my_proc not found") + } + if proc.Characteristics["COMMENT"] != "updated comment" { + t.Errorf("expected COMMENT 'updated comment', got %q", proc.Characteristics["COMMENT"]) + } +} + +func TestWalkThrough_3_7_DropProcedure(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `DELIMITER ;; +CREATE PROCEDURE my_proc() +BEGIN + SELECT 1; +END;; +DELIMITER ;`) + + // Verify it exists first. + db := c.GetDatabase("testdb") + if db.Procedures[toLower("my_proc")] == nil { + t.Fatal("procedure should exist before drop") + } + + wtExec(t, c, "DROP PROCEDURE my_proc") + + if db.Procedures[toLower("my_proc")] != nil { + t.Error("procedure my_proc should be removed after DROP PROCEDURE") + } +} + +func TestWalkThrough_3_7_CreateTrigger(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT, val INT)") + wtExec(t, c, `DELIMITER ;; +CREATE TRIGGER trg_before_insert BEFORE INSERT ON t1 FOR EACH ROW +BEGIN + SET NEW.val = NEW.val + 1; +END;; +DELIMITER ;`) + + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + trg := db.Triggers[toLower("trg_before_insert")] + if trg == nil { + t.Fatal("trigger trg_before_insert not found in database.Triggers") + } + if trg.Name != "trg_before_insert" { + t.Errorf("expected name trg_before_insert, got %s", trg.Name) + } + if trg.Timing != "BEFORE" { + t.Errorf("expected timing BEFORE, got %s", trg.Timing) + } + if trg.Event != "INSERT" { + t.Errorf("expected event INSERT, got %s", trg.Event) + } + if trg.Table != "t1" { + t.Errorf("expected table t1, got %s", trg.Table) + } +} + +func TestWalkThrough_3_7_DropTrigger(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT)") + wtExec(t, c, `DELIMITER ;; +CREATE TRIGGER trg1 AFTER DELETE ON t1 FOR EACH ROW +BEGIN + SELECT 1; +END;; +DELIMITER ;`) + + db := c.GetDatabase("testdb") + if db.Triggers[toLower("trg1")] == nil { + t.Fatal("trigger should exist before drop") + } + + wtExec(t, c, "DROP TRIGGER trg1") + + if db.Triggers[toLower("trg1")] != nil { + t.Error("trigger trg1 should be removed after DROP TRIGGER") + } +} + +func TestWalkThrough_3_7_CreateEvent(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE EVENT my_event +ON SCHEDULE EVERY 1 HOUR +ON COMPLETION PRESERVE +ENABLE +COMMENT 'hourly cleanup' +DO DELETE FROM t1 WHERE created < NOW() - INTERVAL 1 DAY`) + + db := c.GetDatabase("testdb") + if db == nil { + t.Fatal("testdb not found") + } + ev := db.Events[toLower("my_event")] + if ev == nil { + t.Fatal("event my_event not found in database.Events") + } + if ev.Name != "my_event" { + t.Errorf("expected name my_event, got %s", ev.Name) + } + if ev.Schedule == "" { + t.Error("expected non-empty schedule") + } + // Enable should be ENABLE or default + if ev.Enable != "" && ev.Enable != "ENABLE" { + t.Errorf("expected Enable ENABLE or empty, got %s", ev.Enable) + } + if ev.OnCompletion != "PRESERVE" { + t.Errorf("expected OnCompletion PRESERVE, got %s", ev.OnCompletion) + } +} + +func TestWalkThrough_3_7_AlterEvent(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE EVENT my_event ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + + wtExec(t, c, "ALTER EVENT my_event ON SCHEDULE EVERY 2 HOUR DISABLE") + + db := c.GetDatabase("testdb") + ev := db.Events[toLower("my_event")] + if ev == nil { + t.Fatal("event my_event not found") + } + if ev.Enable != "DISABLE" { + t.Errorf("expected Enable DISABLE after alter, got %s", ev.Enable) + } + // Schedule should be updated + if ev.Schedule == "" { + t.Error("expected non-empty schedule after alter") + } +} + +func TestWalkThrough_3_7_AlterEventRename(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE EVENT old_event ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + + wtExec(t, c, "ALTER EVENT old_event RENAME TO new_event") + + db := c.GetDatabase("testdb") + if db.Events[toLower("old_event")] != nil { + t.Error("old_event should not exist after rename") + } + ev := db.Events[toLower("new_event")] + if ev == nil { + t.Fatal("new_event should exist after rename") + } + if ev.Name != "new_event" { + t.Errorf("expected name new_event, got %s", ev.Name) + } +} + +func TestWalkThrough_3_7_DropEvent(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE EVENT my_event ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + + db := c.GetDatabase("testdb") + if db.Events[toLower("my_event")] == nil { + t.Fatal("event should exist before drop") + } + + wtExec(t, c, "DROP EVENT my_event") + + if db.Events[toLower("my_event")] != nil { + t.Error("event my_event should be removed after DROP EVENT") + } +} diff --git a/tidb/catalog/wt_4_1_test.go b/tidb/catalog/wt_4_1_test.go new file mode 100644 index 00000000..9a99cf88 --- /dev/null +++ b/tidb/catalog/wt_4_1_test.go @@ -0,0 +1,452 @@ +package catalog + +import "testing" + +// --- 4.1 Schema Setup Migrations --- + +// TestWalkThrough_4_1_CreateDBTablesIndexes tests creating a database with multiple +// tables and indexes in a single Exec call, then verifies all objects are present. +func TestWalkThrough_4_1_CreateDBTablesIndexes(t *testing.T) { + c := New() + sql := ` +CREATE DATABASE myapp; +USE myapp; + +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + email VARCHAR(255) NOT NULL, + name VARCHAR(100), + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id), + UNIQUE KEY idx_email (email) +); + +CREATE TABLE posts ( + id INT NOT NULL AUTO_INCREMENT, + user_id INT NOT NULL, + title VARCHAR(200) NOT NULL, + body TEXT, + status ENUM('draft','published','archived') DEFAULT 'draft', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id), + INDEX idx_user_id (user_id), + INDEX idx_status_created (status, created_at) +); + +CREATE TABLE comments ( + id INT NOT NULL AUTO_INCREMENT, + post_id INT NOT NULL, + user_id INT NOT NULL, + body TEXT NOT NULL, + PRIMARY KEY (id), + INDEX idx_post_id (post_id), + INDEX idx_user_id (user_id) +); +` + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("exec error on stmt %d: %v", r.Index, r.Error) + } + } + + db := c.GetDatabase("myapp") + if db == nil { + t.Fatal("database 'myapp' not found") + } + + // Verify tables exist. + for _, name := range []string{"users", "posts", "comments"} { + if db.GetTable(name) == nil { + t.Errorf("table %q not found", name) + } + } + + // Verify users table structure. + users := db.GetTable("users") + if users == nil { + t.Fatal("users table not found") + } + if len(users.Columns) != 4 { + t.Errorf("users: expected 4 columns, got %d", len(users.Columns)) + } + // Check PK index. + var hasPK bool + for _, idx := range users.Indexes { + if idx.Primary { + hasPK = true + if len(idx.Columns) != 1 || idx.Columns[0].Name != "id" { + t.Errorf("users PK: expected column 'id', got %v", idx.Columns) + } + } + } + if !hasPK { + t.Error("users: no PRIMARY KEY index found") + } + // Check unique index on email. + var hasEmailIdx bool + for _, idx := range users.Indexes { + if idx.Name == "idx_email" { + hasEmailIdx = true + if !idx.Unique { + t.Error("idx_email should be unique") + } + } + } + if !hasEmailIdx { + t.Error("users: idx_email index not found") + } + + // Verify posts table indexes. + posts := db.GetTable("posts") + if posts == nil { + t.Fatal("posts table not found") + } + if len(posts.Columns) != 6 { + t.Errorf("posts: expected 6 columns, got %d", len(posts.Columns)) + } + idxNames := make(map[string]bool) + for _, idx := range posts.Indexes { + idxNames[idx.Name] = true + } + for _, expected := range []string{"idx_user_id", "idx_status_created"} { + if !idxNames[expected] { + t.Errorf("posts: index %q not found", expected) + } + } + + // Verify multi-column index column order. + for _, idx := range posts.Indexes { + if idx.Name == "idx_status_created" { + if len(idx.Columns) != 2 { + t.Fatalf("idx_status_created: expected 2 columns, got %d", len(idx.Columns)) + } + if idx.Columns[0].Name != "status" || idx.Columns[1].Name != "created_at" { + t.Errorf("idx_status_created: expected [status, created_at], got [%s, %s]", + idx.Columns[0].Name, idx.Columns[1].Name) + } + } + } + + // Verify comments table. + comments := db.GetTable("comments") + if comments == nil { + t.Fatal("comments table not found") + } + if len(comments.Columns) != 4 { + t.Errorf("comments: expected 4 columns, got %d", len(comments.Columns)) + } +} + +// TestWalkThrough_4_1_CreateTableThenAddFK tests creating a table and then adding +// a foreign key that references it. +func TestWalkThrough_4_1_CreateTableThenAddFK(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, ` +CREATE TABLE parents ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100), + PRIMARY KEY (id) +); +`) + + wtExec(t, c, ` +CREATE TABLE children ( + id INT NOT NULL AUTO_INCREMENT, + parent_id INT NOT NULL, + name VARCHAR(100), + PRIMARY KEY (id), + INDEX idx_parent_id (parent_id) +); +`) + + wtExec(t, c, ` +ALTER TABLE children ADD CONSTRAINT fk_parent + FOREIGN KEY (parent_id) REFERENCES parents (id) + ON DELETE CASCADE ON UPDATE CASCADE; +`) + + db := c.GetDatabase("testdb") + children := db.GetTable("children") + if children == nil { + t.Fatal("children table not found") + } + + // Find the FK constraint. + var fk *Constraint + for _, con := range children.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_parent" { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint 'fk_parent' not found") + } + + if fk.RefTable != "parents" { + t.Errorf("FK RefTable: expected 'parents', got %q", fk.RefTable) + } + if len(fk.Columns) != 1 || fk.Columns[0] != "parent_id" { + t.Errorf("FK Columns: expected [parent_id], got %v", fk.Columns) + } + if len(fk.RefColumns) != 1 || fk.RefColumns[0] != "id" { + t.Errorf("FK RefColumns: expected [id], got %v", fk.RefColumns) + } + if fk.OnDelete != "CASCADE" { + t.Errorf("FK OnDelete: expected CASCADE, got %q", fk.OnDelete) + } + if fk.OnUpdate != "CASCADE" { + t.Errorf("FK OnUpdate: expected CASCADE, got %q", fk.OnUpdate) + } +} + +// TestWalkThrough_4_1_CreateTableThenView tests creating a table and then a view +// on it, verifying the view resolves columns correctly. +func TestWalkThrough_4_1_CreateTableThenView(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, ` +CREATE TABLE products ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(200) NOT NULL, + price DECIMAL(10,2) NOT NULL, + active TINYINT(1) DEFAULT 1, + PRIMARY KEY (id) +); +`) + + wtExec(t, c, ` +CREATE VIEW active_products AS + SELECT id, name, price FROM products WHERE active = 1; +`) + + db := c.GetDatabase("testdb") + v := db.Views[toLower("active_products")] + if v == nil { + t.Fatal("view 'active_products' not found") + } + + // The view should have derived columns from the SELECT. + if len(v.Columns) < 3 { + t.Fatalf("view columns: expected at least 3, got %d: %v", len(v.Columns), v.Columns) + } + + expected := []string{"id", "name", "price"} + for i, want := range expected { + if i >= len(v.Columns) { + t.Errorf("missing column %d: expected %q", i, want) + continue + } + if v.Columns[i] != want { + t.Errorf("column %d: expected %q, got %q", i, want, v.Columns[i]) + } + } +} + +// TestWalkThrough_4_1_CreateDBSetCharsetTables tests that tables inherit charset +// from the database when created after a database with explicit charset. +func TestWalkThrough_4_1_CreateDBSetCharsetTables(t *testing.T) { + c := New() + + sql := ` +CREATE DATABASE latin_db DEFAULT CHARACTER SET latin1; +USE latin_db; + +CREATE TABLE t1 ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100), + PRIMARY KEY (id) +); + +CREATE TABLE t2 ( + id INT NOT NULL AUTO_INCREMENT, + description TEXT, + PRIMARY KEY (id) +) DEFAULT CHARSET=utf8mb4; +` + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("exec error on stmt %d: %v", r.Index, r.Error) + } + } + + db := c.GetDatabase("latin_db") + if db == nil { + t.Fatal("database 'latin_db' not found") + } + if db.Charset != "latin1" { + t.Errorf("database charset: expected 'latin1', got %q", db.Charset) + } + + // t1 should inherit latin1 from the database. + t1 := db.GetTable("t1") + if t1 == nil { + t.Fatal("table t1 not found") + } + if t1.Charset != "latin1" { + t.Errorf("t1 charset: expected 'latin1', got %q", t1.Charset) + } + + // t1's VARCHAR column should inherit latin1. + nameCol := t1.GetColumn("name") + if nameCol == nil { + t.Fatal("column 'name' not found in t1") + } + if nameCol.Charset != "latin1" { + t.Errorf("t1.name charset: expected 'latin1', got %q", nameCol.Charset) + } + + // t2 has explicit utf8mb4 override. + t2 := db.GetTable("t2") + if t2 == nil { + t.Fatal("table t2 not found") + } + if t2.Charset != "utf8mb4" { + t.Errorf("t2 charset: expected 'utf8mb4', got %q", t2.Charset) + } + + // t2's TEXT column should inherit utf8mb4 from the table. + descCol := t2.GetColumn("description") + if descCol == nil { + t.Fatal("column 'description' not found in t2") + } + if descCol.Charset != "utf8mb4" { + t.Errorf("t2.description charset: expected 'utf8mb4', got %q", descCol.Charset) + } +} + +// TestWalkThrough_4_1_MysqldumpStyle tests a mysqldump-style output with SET vars, +// DELIMITER, procedures, triggers, and tables. +func TestWalkThrough_4_1_MysqldumpStyle(t *testing.T) { + c := New() + + sql := ` +SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT; +SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS; +SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION; +SET NAMES utf8mb4; +SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS; +SET FOREIGN_KEY_CHECKS=0; + +CREATE DATABASE IF NOT EXISTS dumpdb DEFAULT CHARACTER SET utf8mb4; +USE dumpdb; + +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + username VARCHAR(50) NOT NULL, + email VARCHAR(100) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY idx_username (username), + UNIQUE KEY idx_email (email) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE orders ( + id INT NOT NULL AUTO_INCREMENT, + user_id INT NOT NULL, + total DECIMAL(10,2) NOT NULL DEFAULT 0.00, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id), + KEY idx_user_id (user_id), + CONSTRAINT fk_orders_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +DELIMITER ;; + +CREATE PROCEDURE get_user_orders(IN p_user_id INT) +BEGIN + SELECT * FROM orders WHERE user_id = p_user_id; +END;; + +CREATE TRIGGER trg_order_after_insert AFTER INSERT ON orders +FOR EACH ROW +BEGIN + UPDATE users SET email = email WHERE id = NEW.user_id; +END;; + +DELIMITER ; + +SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS; +SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT; +SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS; +SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION; +` + + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("exec error on stmt %d (line %d): %v", r.Index, r.Line, r.Error) + } + } + + db := c.GetDatabase("dumpdb") + if db == nil { + t.Fatal("database 'dumpdb' not found") + } + + // Verify tables. + users := db.GetTable("users") + if users == nil { + t.Fatal("table 'users' not found") + } + if users.Engine != "InnoDB" { + t.Errorf("users engine: expected 'InnoDB', got %q", users.Engine) + } + + orders := db.GetTable("orders") + if orders == nil { + t.Fatal("table 'orders' not found") + } + + // Verify FK on orders. + var fk *Constraint + for _, con := range orders.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_orders_user" { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint 'fk_orders_user' not found on orders table") + } + if fk.RefTable != "users" { + t.Errorf("FK RefTable: expected 'users', got %q", fk.RefTable) + } + if fk.OnDelete != "CASCADE" { + t.Errorf("FK OnDelete: expected 'CASCADE', got %q", fk.OnDelete) + } + + // Verify procedure. + proc := db.Procedures[toLower("get_user_orders")] + if proc == nil { + t.Fatal("procedure 'get_user_orders' not found") + } + + // Verify trigger. + trg := db.Triggers[toLower("trg_order_after_insert")] + if trg == nil { + t.Fatal("trigger 'trg_order_after_insert' not found") + } + if trg.Timing != "AFTER" { + t.Errorf("trigger timing: expected 'AFTER', got %q", trg.Timing) + } + if trg.Event != "INSERT" { + t.Errorf("trigger event: expected 'INSERT', got %q", trg.Event) + } + + // Verify FK checks were re-enabled at the end. + // The last SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS uses a variable reference, + // which may not be interpreted. We just verify catalog is in a valid state. + // The key point is the script executed without error. +} diff --git a/tidb/catalog/wt_4_2_test.go b/tidb/catalog/wt_4_2_test.go new file mode 100644 index 00000000..22c4e73f --- /dev/null +++ b/tidb/catalog/wt_4_2_test.go @@ -0,0 +1,315 @@ +package catalog + +import ( + "testing" +) + +// TestWalkThrough_4_2_AddColumnThenIndex tests that after adding a column, +// an index can be created on it and references the new column. +func TestWalkThrough_4_2_AddColumnThenIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY)") + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN email VARCHAR(255)") + wtExec(t, c, "CREATE INDEX idx_email ON t1 (email)") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + + // Verify column exists. + col := tbl.GetColumn("email") + if col == nil { + t.Fatal("column email not found") + } + + // Verify index exists and references new column. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_email" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_email not found") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "email" { + t.Fatalf("expected index column [email], got %v", found.Columns) + } +} + +// TestWalkThrough_4_2_RenameColumnIndexRef tests that after renaming a column, +// existing indexes still reference the correct (new) column name. +func TestWalkThrough_4_2_RenameColumnIndexRef(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "CREATE INDEX idx_name ON t1 (name)") + wtExec(t, c, "ALTER TABLE t1 RENAME COLUMN name TO full_name") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + + // Column should be renamed. + if tbl.GetColumn("name") != nil { + t.Error("old column name 'name' should not exist") + } + if tbl.GetColumn("full_name") == nil { + t.Fatal("renamed column 'full_name' not found") + } + + // Index should reference the new name. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_name" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_name not found") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "full_name" { + t.Fatalf("expected index column [full_name], got %v", indexColNames(found)) + } +} + +// TestWalkThrough_4_2_AddColumnThenFK tests that after adding a column, +// a foreign key using that column is correctly created. +func TestWalkThrough_4_2_AddColumnThenFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT PRIMARY KEY)") + wtExec(t, c, "ALTER TABLE child ADD COLUMN parent_id INT") + wtExec(t, c, "ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Find FK constraint. + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_parent" { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint fk_parent not found") + } + if len(fk.Columns) != 1 || fk.Columns[0] != "parent_id" { + t.Fatalf("expected FK columns [parent_id], got %v", fk.Columns) + } + if fk.RefTable != "parent" { + t.Fatalf("expected RefTable=parent, got %s", fk.RefTable) + } + if len(fk.RefColumns) != 1 || fk.RefColumns[0] != "id" { + t.Fatalf("expected FK RefColumns [id], got %v", fk.RefColumns) + } +} + +// TestWalkThrough_4_2_DropIndexRecreate tests dropping an index and re-creating +// it with different columns: the old is gone, the new is present. +func TestWalkThrough_4_2_DropIndexRecreate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY, a INT, b INT, c INT)") + wtExec(t, c, "CREATE INDEX idx_ab ON t1 (a, b)") + + // Verify initial index. + tbl := c.GetDatabase("testdb").GetTable("t1") + if findIndex(tbl, "idx_ab") == nil { + t.Fatal("initial idx_ab not found") + } + + // Drop and re-create with different columns. + wtExec(t, c, "DROP INDEX idx_ab ON t1") + if findIndex(tbl, "idx_ab") != nil { + t.Fatal("idx_ab should be gone after DROP") + } + + wtExec(t, c, "CREATE INDEX idx_ab ON t1 (b, c)") + idx := findIndex(tbl, "idx_ab") + if idx == nil { + t.Fatal("re-created idx_ab not found") + } + names := indexColNames(idx) + if len(names) != 2 || names[0] != "b" || names[1] != "c" { + t.Fatalf("expected index columns [b, c], got %v", names) + } +} + +// TestWalkThrough_4_2_AlterTableMultipleCommands tests that multiple ALTER TABLE +// sub-commands applied in sequence produce the correct cumulative effect. +func TestWalkThrough_4_2_AlterTableMultipleCommands(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t1 (id INT PRIMARY KEY, a INT, b INT)") + + // Multiple ALTERs in sequence (separate statements, not multi-command). + wtExec(t, c, "ALTER TABLE t1 ADD COLUMN c VARCHAR(50)") + wtExec(t, c, "ALTER TABLE t1 DROP COLUMN b") + wtExec(t, c, "ALTER TABLE t1 MODIFY COLUMN a BIGINT") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + + // Should have: id, a, c + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + if tbl.Columns[0].Name != "id" { + t.Errorf("col 0: expected id, got %s", tbl.Columns[0].Name) + } + if tbl.Columns[1].Name != "a" { + t.Errorf("col 1: expected a, got %s", tbl.Columns[1].Name) + } + if tbl.Columns[2].Name != "c" { + t.Errorf("col 2: expected c, got %s", tbl.Columns[2].Name) + } + + // a should be BIGINT now. + colA := tbl.GetColumn("a") + if colA.DataType != "bigint" { + t.Errorf("expected column a type=bigint, got %s", colA.DataType) + } + + // b should be gone. + if tbl.GetColumn("b") != nil { + t.Error("column b should have been dropped") + } +} + +// TestWalkThrough_4_2_RenameTableThenView tests renaming a table then creating +// a view on the new name: the view should resolve. +func TestWalkThrough_4_2_RenameTableThenView(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE old_t (id INT PRIMARY KEY, val TEXT)") + wtExec(t, c, "RENAME TABLE old_t TO new_t") + + db := c.GetDatabase("testdb") + if db.GetTable("old_t") != nil { + t.Error("old_t should no longer exist") + } + if db.GetTable("new_t") == nil { + t.Fatal("new_t should exist after rename") + } + + wtExec(t, c, "CREATE VIEW v1 AS SELECT id, val FROM new_t") + + v := db.Views[toLower("v1")] + if v == nil { + t.Fatal("view v1 not found") + } + // The view should exist in the database and have a definition. + if v.Definition == "" { + t.Error("view v1 should have a non-empty definition") + } +} + +// TestWalkThrough_4_2_ChangeColumnTypeGeneratedColumn tests that after changing +// a column type, a dependent generated column is still recorded. +func TestWalkThrough_4_2_ChangeColumnTypeGeneratedColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + a INT, + b INT GENERATED ALWAYS AS (a * 2) STORED + )`) + + tbl := c.GetDatabase("testdb").GetTable("t1") + colB := tbl.GetColumn("b") + if colB == nil { + t.Fatal("column b not found") + } + if colB.Generated == nil { + t.Fatal("column b should be generated") + } + if !colB.Generated.Stored { + t.Error("column b should be STORED") + } + + // Change column a type to BIGINT. + wtExec(t, c, "ALTER TABLE t1 MODIFY COLUMN a BIGINT") + + // Verify generated column b is still recorded. + colB = tbl.GetColumn("b") + if colB == nil { + t.Fatal("column b not found after modify") + } + if colB.Generated == nil { + t.Fatal("column b should still be generated after modifying a") + } + if !colB.Generated.Stored { + t.Error("column b should still be STORED") + } +} + +// TestWalkThrough_4_2_ConvertCharset tests CONVERT TO CHARACTER SET updates +// all string columns. +func TestWalkThrough_4_2_ConvertCharset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + bio TEXT, + age INT, + tag ENUM('a','b') + )`) + + // Convert to latin1. + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET latin1") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + + // Table-level charset/collation should be updated. + if tbl.Charset != "latin1" { + t.Errorf("expected table charset=latin1, got %s", tbl.Charset) + } + + // String columns should have latin1. + for _, colName := range []string{"name", "bio", "tag"} { + col := tbl.GetColumn(colName) + if col == nil { + t.Fatalf("column %s not found", colName) + } + if col.Charset != "latin1" { + t.Errorf("column %s: expected charset=latin1, got %s", colName, col.Charset) + } + } + + // Non-string columns should not have charset changed. + colID := tbl.GetColumn("id") + if colID.Charset != "" { + t.Errorf("INT column id should not have charset, got %s", colID.Charset) + } + colAge := tbl.GetColumn("age") + if colAge.Charset != "" { + t.Errorf("INT column age should not have charset, got %s", colAge.Charset) + } +} + +// --- helpers --- + +func findIndex(tbl *Table, name string) *Index { + for _, idx := range tbl.Indexes { + if toLower(idx.Name) == toLower(name) { + return idx + } + } + return nil +} + +func indexColNames(idx *Index) []string { + names := make([]string, len(idx.Columns)) + for i, ic := range idx.Columns { + names[i] = ic.Name + } + return names +} diff --git a/tidb/catalog/wt_4_3_test.go b/tidb/catalog/wt_4_3_test.go new file mode 100644 index 00000000..c7c6cd6d --- /dev/null +++ b/tidb/catalog/wt_4_3_test.go @@ -0,0 +1,327 @@ +package catalog + +import "testing" + +// --- 4.3 Error Detection in Migrations --- + +// TestWalkThrough_4_3_ContinueOnErrorMigration tests that with ContinueOnError, +// a first CREATE succeeds, a duplicate CREATE fails, and a third ALTER on the +// first table succeeds. +func TestWalkThrough_4_3_ContinueOnErrorMigration(t *testing.T) { + c := wtSetup(t) + + sql := `CREATE TABLE t1 (id INT PRIMARY KEY, name VARCHAR(100)); +CREATE TABLE t1 (id INT PRIMARY KEY); +ALTER TABLE t1 ADD COLUMN email VARCHAR(255);` + + results, err := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + + // First CREATE succeeds. + assertNoError(t, results[0].Error) + + // Duplicate CREATE fails. + assertError(t, results[1].Error, ErrDupTable) + + // ALTER on first table succeeds (ContinueOnError keeps going). + assertNoError(t, results[2].Error) + + // Verify final state: t1 has id, name, email. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("email") == nil { + t.Error("column email not found after ContinueOnError ALTER") + } +} + +// TestWalkThrough_4_3_ContinueOnErrorCodes tests that ContinueOnError produces +// the correct error code for each failing statement. +func TestWalkThrough_4_3_ContinueOnErrorCodes(t *testing.T) { + c := wtSetup(t) + + sql := `CREATE TABLE t1 (id INT PRIMARY KEY); +CREATE TABLE t1 (id INT); +ALTER TABLE no_such_table ADD COLUMN x INT; +CREATE TABLE t2 (id INT PRIMARY KEY);` + + results, err := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 4 { + t.Fatalf("expected 4 results, got %d", len(results)) + } + + assertNoError(t, results[0].Error) + assertError(t, results[1].Error, ErrDupTable) + assertError(t, results[2].Error, ErrNoSuchTable) + assertNoError(t, results[3].Error) + + // Both t1 and t2 should exist. + db := c.GetDatabase("testdb") + if db.GetTable("t1") == nil { + t.Error("table t1 not found") + } + if db.GetTable("t2") == nil { + t.Error("table t2 not found") + } +} + +// TestWalkThrough_4_3_FKCycle tests creating a FK cycle (A refs B, B refs A) +// with fk_checks=0, then verifying state after fk_checks=1. +func TestWalkThrough_4_3_FKCycle(t *testing.T) { + c := wtSetup(t) + + sql := `SET foreign_key_checks = 0; +CREATE TABLE a ( + id INT PRIMARY KEY, + b_id INT, + CONSTRAINT fk_a_b FOREIGN KEY (b_id) REFERENCES b (id) +); +CREATE TABLE b ( + id INT PRIMARY KEY, + a_id INT, + CONSTRAINT fk_b_a FOREIGN KEY (a_id) REFERENCES a (id) +); +SET foreign_key_checks = 1;` + + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("exec error on stmt %d: %v", r.Index, r.Error) + } + } + + db := c.GetDatabase("testdb") + + // Both tables should exist. + tblA := db.GetTable("a") + if tblA == nil { + t.Fatal("table a not found") + } + tblB := db.GetTable("b") + if tblB == nil { + t.Fatal("table b not found") + } + + // Verify FK constraints. + var fkAB *Constraint + for _, con := range tblA.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_a_b" { + fkAB = con + break + } + } + if fkAB == nil { + t.Fatal("FK fk_a_b not found on table a") + } + if fkAB.RefTable != "b" { + t.Errorf("fk_a_b RefTable: expected 'b', got %q", fkAB.RefTable) + } + + var fkBA *Constraint + for _, con := range tblB.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_b_a" { + fkBA = con + break + } + } + if fkBA == nil { + t.Fatal("FK fk_b_a not found on table b") + } + if fkBA.RefTable != "a" { + t.Errorf("fk_b_a RefTable: expected 'a', got %q", fkBA.RefTable) + } +} + +// TestWalkThrough_4_3_DropTableCascadeError tests that dropping a parent table +// referenced by a FK child produces the correct error. +func TestWalkThrough_4_3_DropTableCascadeError(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE parent (id INT PRIMARY KEY); +CREATE TABLE child ( + id INT PRIMARY KEY, + parent_id INT, + CONSTRAINT fk_child_parent FOREIGN KEY (parent_id) REFERENCES parent (id) +);`) + + // Attempt to drop parent with FK child referencing it. + results, err := c.Exec("DROP TABLE parent", nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKCannotDropParent) + + // Parent should still exist. + if c.GetDatabase("testdb").GetTable("parent") == nil { + t.Error("parent table should still exist after failed DROP") + } +} + +// TestWalkThrough_4_3_AlterMissingTableLine tests that ALTER on a missing table +// produces the correct error at the correct line. +func TestWalkThrough_4_3_AlterMissingTableLine(t *testing.T) { + c := wtSetup(t) + + sql := "CREATE TABLE t1 (id INT PRIMARY KEY);\n" + + "CREATE TABLE t2 (id INT PRIMARY KEY);\n" + + "ALTER TABLE no_such_table ADD COLUMN x INT;\n" + + "CREATE TABLE t3 (id INT PRIMARY KEY);" + + results, err := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 4 { + t.Fatalf("expected 4 results, got %d", len(results)) + } + + // First two succeed. + assertNoError(t, results[0].Error) + assertNoError(t, results[1].Error) + + // Third statement fails with correct error. + assertError(t, results[2].Error, ErrNoSuchTable) + + // Verify the line number is correct (line 3). + if results[2].Line != 3 { + t.Errorf("expected error on line 3, got line %d", results[2].Line) + } + + // Fourth succeeds (ContinueOnError). + assertNoError(t, results[3].Error) +} + +// TestWalkThrough_4_3_MultipleErrorsAllCodes tests that multiple errors in one +// migration all have correct error codes and correct line numbers. +func TestWalkThrough_4_3_MultipleErrorsAllCodes(t *testing.T) { + c := wtSetup(t) + + sql := "CREATE TABLE t1 (id INT PRIMARY KEY);\n" + + "CREATE TABLE t1 (id INT);\n" + + "ALTER TABLE missing ADD COLUMN x INT;\n" + + "ALTER TABLE t1 DROP COLUMN no_such;\n" + + "ALTER TABLE t1 ADD COLUMN val TEXT;" + + results, err := c.Exec(sql, &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 5 { + t.Fatalf("expected 5 results, got %d", len(results)) + } + + // Statement 0: CREATE TABLE t1 succeeds. + assertNoError(t, results[0].Error) + if results[0].Line != 1 { + t.Errorf("stmt 0: expected line 1, got %d", results[0].Line) + } + + // Statement 1: duplicate table error. + assertError(t, results[1].Error, ErrDupTable) + if results[1].Line != 2 { + t.Errorf("stmt 1: expected line 2, got %d", results[1].Line) + } + + // Statement 2: no such table error. + assertError(t, results[2].Error, ErrNoSuchTable) + if results[2].Line != 3 { + t.Errorf("stmt 2: expected line 3, got %d", results[2].Line) + } + + // Statement 3: DROP COLUMN on nonexistent column returns 1091 (same as DROP INDEX in MySQL 8.0). + assertError(t, results[3].Error, ErrCantDropKey) + if results[3].Line != 4 { + t.Errorf("stmt 3: expected line 4, got %d", results[3].Line) + } + + // Statement 4: succeeds. + assertNoError(t, results[4].Error) + if results[4].Line != 5 { + t.Errorf("stmt 4: expected line 5, got %d", results[4].Line) + } + + // Final state: t1 has id, val. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if len(tbl.Columns) != 2 { + t.Fatalf("expected 2 columns (id, val), got %d", len(tbl.Columns)) + } + if tbl.GetColumn("val") == nil { + t.Error("column val not found") + } +} + +// TestWalkThrough_4_3_MixedDMLAndDDL tests that DML statements are skipped, +// DDL statements are executed, and the final state is correct. +func TestWalkThrough_4_3_MixedDMLAndDDL(t *testing.T) { + c := wtSetup(t) + + sql := `CREATE TABLE t1 (id INT PRIMARY KEY, name VARCHAR(100)); +INSERT INTO t1 VALUES (1, 'Alice'); +ALTER TABLE t1 ADD COLUMN email VARCHAR(255); +SELECT * FROM t1; +UPDATE t1 SET name = 'Bob' WHERE id = 1; +CREATE INDEX idx_name ON t1 (name); +DELETE FROM t1 WHERE id = 1;` + + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) != 7 { + t.Fatalf("expected 7 results, got %d", len(results)) + } + + // Verify DML statements are skipped. + if !results[1].Skipped { + t.Error("INSERT should be skipped") + } + if !results[3].Skipped { + t.Error("SELECT should be skipped") + } + if !results[4].Skipped { + t.Error("UPDATE should be skipped") + } + if !results[6].Skipped { + t.Error("DELETE should be skipped") + } + + // Verify DDL statements succeeded. + assertNoError(t, results[0].Error) // CREATE TABLE + assertNoError(t, results[2].Error) // ALTER TABLE ADD COLUMN + assertNoError(t, results[5].Error) // CREATE INDEX + + // Verify final state: t1 has id, name, email and index idx_name. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if len(tbl.Columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("email") == nil { + t.Error("column email not found") + } + + // Verify index. + if findIndex(tbl, "idx_name") == nil { + t.Error("index idx_name not found") + } +} diff --git a/tidb/catalog/wt_5_1_test.go b/tidb/catalog/wt_5_1_test.go new file mode 100644 index 00000000..b353f6f4 --- /dev/null +++ b/tidb/catalog/wt_5_1_test.go @@ -0,0 +1,211 @@ +package catalog + +import "testing" + +// --- 5.1 (mapped from starmap 1.1) Column Repositioning Interactions --- + +func TestWalkThrough_5_1_AddAfterJustAdded(t *testing.T) { + // ADD COLUMN x AFTER a, ADD COLUMN y AFTER x — second command references column added by first + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT AFTER a, ADD COLUMN y INT AFTER x") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: a, x, y, b, c + expected := []string{"a", "x", "y", "b", "c"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_AddFirstTwice(t *testing.T) { + // ADD COLUMN x FIRST, ADD COLUMN y FIRST — both FIRST, y should end up before x + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT FIRST, ADD COLUMN y INT FIRST") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: y, x, a, b, c + expected := []string{"y", "x", "a", "b", "c"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_AddAfterThenDropThat(t *testing.T) { + // ADD COLUMN x AFTER a, DROP COLUMN a — add after a column that is then dropped in same statement + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT AFTER a, DROP COLUMN a") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: x, b, c (a was dropped) + expected := []string{"x", "b", "c"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_ModifyAfterChain(t *testing.T) { + // MODIFY COLUMN a INT AFTER c, MODIFY COLUMN b INT AFTER a — chain of AFTER references + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN a INT AFTER c, MODIFY COLUMN b INT AFTER a") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Start: a, b, c + // After MODIFY a AFTER c: b, c, a + // After MODIFY b AFTER a: c, a, b + expected := []string{"c", "a", "b"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_ModifyFirstThenAddAfter(t *testing.T) { + // MODIFY COLUMN a INT FIRST, ADD COLUMN x INT AFTER a — FIRST then AFTER the moved column + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + // a is already first, but this tests referencing after MODIFY FIRST + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN c INT FIRST, ADD COLUMN x INT AFTER c") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Start: a, b, c + // After MODIFY c FIRST: c, a, b + // After ADD x AFTER c: c, x, a, b + expected := []string{"c", "x", "a", "b"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_DropThenReAdd(t *testing.T) { + // DROP COLUMN a, ADD COLUMN a INT — drop then re-add same column name + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t DROP COLUMN a, ADD COLUMN a INT") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: b, c, a (dropped from front, re-added at end) + expected := []string{"b", "c", "a"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_ChangeThenAddAfterNewName(t *testing.T) { + // CHANGE COLUMN a b_new INT, ADD COLUMN d INT AFTER b_new — rename then reference new name + // Note: using b_new instead of b to avoid conflict with existing column b + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t CHANGE COLUMN a a_new INT, ADD COLUMN d INT AFTER a_new") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: a_new, d, b, c + expected := []string{"a_new", "d", "b", "c"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} + +func TestWalkThrough_5_1_ThreeAppends(t *testing.T) { + // ADD COLUMN x INT, ADD COLUMN y INT, ADD COLUMN z INT — three appends, verify final order + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT, ADD COLUMN y INT, ADD COLUMN z INT") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + // Expected order: a, b, c, x, y, z + expected := []string{"a", "b", "c", "x", "y", "z"} + if len(tbl.Columns) != len(expected) { + t.Fatalf("expected %d columns, got %d", len(expected), len(tbl.Columns)) + } + for i, name := range expected { + if tbl.Columns[i].Name != name { + t.Errorf("column %d: expected %q, got %q", i, name, tbl.Columns[i].Name) + } + if tbl.Columns[i].Position != i+1 { + t.Errorf("column %q: expected position %d, got %d", name, i+1, tbl.Columns[i].Position) + } + } +} diff --git a/tidb/catalog/wt_5_2_test.go b/tidb/catalog/wt_5_2_test.go new file mode 100644 index 00000000..e8b8edc5 --- /dev/null +++ b/tidb/catalog/wt_5_2_test.go @@ -0,0 +1,262 @@ +package catalog + +import "testing" + +// Section 1.2 — Column + Index Interactions +// Tests multi-command ALTER TABLE scenarios where column and index operations interact. + +// Scenario: ADD COLUMN x INT, ADD INDEX idx_x (x) — add column then index on it in same ALTER +func TestWalkThrough_5_2_AddColumnThenIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT, ADD INDEX idx_x (x)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify column was added. + col := tbl.GetColumn("x") + if col == nil { + t.Fatal("column x not found") + } + + // Verify index was created. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_x" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_x not found") + } + if found.Unique { + t.Error("idx_x should not be unique") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "x" { + t.Errorf("index columns mismatch: %+v", found.Columns) + } +} + +// Scenario: ADD COLUMN x INT, ADD UNIQUE INDEX ux (x) — add column then unique index +func TestWalkThrough_5_2_AddColumnThenUniqueIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100))") + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT, ADD UNIQUE INDEX ux (x)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify column was added. + col := tbl.GetColumn("x") + if col == nil { + t.Fatal("column x not found") + } + + // Verify unique index was created. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "ux" { + found = idx + break + } + } + if found == nil { + t.Fatal("unique index ux not found") + } + if !found.Unique { + t.Error("ux should be unique") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "x" { + t.Errorf("unique index columns mismatch: %+v", found.Columns) + } +} + +// Scenario: DROP COLUMN x, DROP INDEX idx_x — drop column and its index simultaneously +func TestWalkThrough_5_2_DropColumnAndIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(100), x INT, INDEX idx_x (x))") + + // Verify setup: column and index exist. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl.GetColumn("x") == nil { + t.Fatal("column x should exist before drop") + } + + wtExec(t, c, "ALTER TABLE t DROP COLUMN x, DROP INDEX idx_x") + + tbl = c.GetDatabase("testdb").GetTable("t") + + // Verify column was dropped. + if tbl.GetColumn("x") != nil { + t.Error("column x should have been dropped") + } + + // Verify index was dropped. + for _, idx := range tbl.Indexes { + if idx.Name == "idx_x" { + t.Error("index idx_x should have been dropped") + } + } + + // Verify remaining columns are correct. + if len(tbl.Columns) != 2 { + t.Errorf("expected 2 columns after drop, got %d", len(tbl.Columns)) + } +} + +// Scenario: MODIFY COLUMN x VARCHAR(200), DROP INDEX idx_x, ADD INDEX idx_x (x) — rebuild index after type change +func TestWalkThrough_5_2_ModifyColumnRebuildIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, x VARCHAR(100), INDEX idx_x (x))") + + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN x VARCHAR(200), DROP INDEX idx_x, ADD INDEX idx_x (x)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify column type was modified. + col := tbl.GetColumn("x") + if col == nil { + t.Fatal("column x not found") + } + if col.ColumnType != "varchar(200)" { + t.Errorf("expected column type 'varchar(200)', got %q", col.ColumnType) + } + + // Verify index was recreated. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_x" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_x not found after rebuild") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "x" { + t.Errorf("index columns mismatch: %+v", found.Columns) + } +} + +// Scenario: CHANGE COLUMN x y INT, ADD INDEX idx_y (y) — rename column then index with new name +func TestWalkThrough_5_2_ChangeColumnThenIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, x INT)") + + wtExec(t, c, "ALTER TABLE t CHANGE COLUMN x y INT, ADD INDEX idx_y (y)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify old column is gone and new column exists. + if tbl.GetColumn("x") != nil { + t.Error("column x should no longer exist after CHANGE") + } + col := tbl.GetColumn("y") + if col == nil { + t.Fatal("column y not found after CHANGE") + } + + // Verify index on new column name. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_y" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_y not found") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "y" { + t.Errorf("index columns mismatch: %+v", found.Columns) + } +} + +// Scenario: ADD COLUMN x INT, ADD PRIMARY KEY (id, x) — add column then include in new PK +func TestWalkThrough_5_2_AddColumnThenPrimaryKey(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT NOT NULL, name VARCHAR(100))") + + wtExec(t, c, "ALTER TABLE t ADD COLUMN x INT NOT NULL, ADD PRIMARY KEY (id, x)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify column was added. + col := tbl.GetColumn("x") + if col == nil { + t.Fatal("column x not found") + } + if col.Nullable { + t.Error("column x should be NOT NULL") + } + + // Verify primary key exists with both columns. + var pkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Primary { + pkIdx = idx + break + } + } + if pkIdx == nil { + t.Fatal("primary key not found") + } + if len(pkIdx.Columns) != 2 { + t.Fatalf("expected 2 PK columns, got %d", len(pkIdx.Columns)) + } + if pkIdx.Columns[0].Name != "id" { + t.Errorf("expected first PK column 'id', got %q", pkIdx.Columns[0].Name) + } + if pkIdx.Columns[1].Name != "x" { + t.Errorf("expected second PK column 'x', got %q", pkIdx.Columns[1].Name) + } +} + +// Scenario: DROP INDEX idx_x, ADD INDEX idx_x (x, y) — drop and recreate index with extra column +func TestWalkThrough_5_2_DropAndRecreateIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT PRIMARY KEY, x INT, y INT, INDEX idx_x (x))") + + wtExec(t, c, "ALTER TABLE t DROP INDEX idx_x, ADD INDEX idx_x (x, y)") + + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl == nil { + t.Fatal("table not found") + } + + // Verify index was recreated with two columns. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_x" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_x not found after recreate") + } + if len(found.Columns) != 2 { + t.Fatalf("expected 2 index columns, got %d", len(found.Columns)) + } + if found.Columns[0].Name != "x" { + t.Errorf("expected first index column 'x', got %q", found.Columns[0].Name) + } + if found.Columns[1].Name != "y" { + t.Errorf("expected second index column 'y', got %q", found.Columns[1].Name) + } +} diff --git a/tidb/catalog/wt_5_3_test.go b/tidb/catalog/wt_5_3_test.go new file mode 100644 index 00000000..9e89e3da --- /dev/null +++ b/tidb/catalog/wt_5_3_test.go @@ -0,0 +1,249 @@ +package catalog + +import "testing" + +// Section 1.3: Column + FK Interactions (multi-command ALTER TABLE) + +// Scenario 1: ADD COLUMN parent_id INT, ADD CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id) +func TestWalkThrough_5_3_AddColumnThenFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + wtExec(t, c, "CREATE TABLE child (id INT NOT NULL)") + + mustExec(t, c, "ALTER TABLE child ADD COLUMN parent_id INT, ADD CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Verify column was added. + col := tbl.GetColumn("parent_id") + if col == nil { + t.Fatal("column parent_id not found") + } + + // Verify FK constraint exists. + var fkCon *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk" { + fkCon = con + break + } + } + if fkCon == nil { + t.Fatal("FK constraint 'fk' not found") + } + if len(fkCon.Columns) != 1 || fkCon.Columns[0] != "parent_id" { + t.Errorf("FK columns = %v, want [parent_id]", fkCon.Columns) + } + if fkCon.RefTable != "parent" { + t.Errorf("FK ref table = %q, want 'parent'", fkCon.RefTable) + } + + // Verify backing index exists (named "fk" since constraint is named). + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("backing index 'fk' not found") + } + if len(fkIdx.Columns) != 1 || fkIdx.Columns[0].Name != "parent_id" { + t.Errorf("backing index columns = %v, want [parent_id]", fkIdx.Columns) + } +} + +// Scenario 2: ADD COLUMN parent_id INT, ADD INDEX idx (parent_id), ADD CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id) +func TestWalkThrough_5_3_AddColumnExplicitIndexThenFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + wtExec(t, c, "CREATE TABLE child (id INT NOT NULL)") + + mustExec(t, c, "ALTER TABLE child ADD COLUMN parent_id INT, ADD INDEX idx (parent_id), ADD CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Verify column was added. + col := tbl.GetColumn("parent_id") + if col == nil { + t.Fatal("column parent_id not found") + } + + // Verify FK constraint exists. + var fkCon *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk" { + fkCon = con + break + } + } + if fkCon == nil { + t.Fatal("FK constraint 'fk' not found") + } + + // Verify explicit index "idx" exists. + var explicitIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx" { + explicitIdx = idx + break + } + } + if explicitIdx == nil { + t.Fatal("explicit index 'idx' not found") + } + + // Verify NO duplicate backing index was created — the explicit index should cover the FK. + idxCount := 0 + for _, idx := range tbl.Indexes { + for _, ic := range idx.Columns { + if ic.Name == "parent_id" { + idxCount++ + break + } + } + } + if idxCount != 1 { + t.Errorf("expected exactly 1 index covering parent_id, got %d (indexes: %v)", idxCount, indexNames(tbl)) + } +} + +// Scenario 3: DROP FOREIGN KEY fk, DROP COLUMN parent_id +func TestWalkThrough_5_3_DropFKThenDropColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + wtExec(t, c, "CREATE TABLE child (id INT NOT NULL, parent_id INT, CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id))") + + mustExec(t, c, "ALTER TABLE child DROP FOREIGN KEY fk, DROP COLUMN parent_id") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Verify FK constraint is gone. + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + t.Errorf("FK constraint should have been dropped, found: %s", con.Name) + } + } + + // Verify column is gone. + if tbl.GetColumn("parent_id") != nil { + t.Error("column parent_id should have been dropped") + } + + // Verify backing index is also gone (column was dropped, so index columns are empty). + for _, idx := range tbl.Indexes { + if idx.Name == "fk" { + t.Error("backing index 'fk' should have been removed after column drop") + } + } +} + +// Scenario 4: DROP FOREIGN KEY fk, DROP INDEX fk +func TestWalkThrough_5_3_DropFKThenDropBackingIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + wtExec(t, c, "CREATE TABLE child (id INT NOT NULL, parent_id INT, CONSTRAINT fk FOREIGN KEY (parent_id) REFERENCES parent(id))") + + mustExec(t, c, "ALTER TABLE child DROP FOREIGN KEY fk, DROP INDEX fk") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Verify FK constraint is gone. + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + t.Errorf("FK constraint should have been dropped, found: %s", con.Name) + } + } + + // Verify backing index is gone. + for _, idx := range tbl.Indexes { + if idx.Name == "fk" { + t.Errorf("backing index 'fk' should have been dropped") + } + } + + // Verify column still exists. + if tbl.GetColumn("parent_id") == nil { + t.Error("column parent_id should still exist") + } +} + +// Scenario 5: ADD FOREIGN KEY fk1 (...), ADD INDEX idx (...) on same column in same ALTER +func TestWalkThrough_5_3_AddFKAndIndexSameColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + wtExec(t, c, "CREATE TABLE child (id INT NOT NULL, parent_id INT)") + + mustExec(t, c, "ALTER TABLE child ADD CONSTRAINT fk1 FOREIGN KEY (parent_id) REFERENCES parent(id), ADD INDEX idx (parent_id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Verify FK constraint exists. + var fkCon *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk1" { + fkCon = con + break + } + } + if fkCon == nil { + t.Fatal("FK constraint 'fk1' not found") + } + + // Verify explicit index "idx" exists. + var explicitIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx" { + explicitIdx = idx + break + } + } + if explicitIdx == nil { + t.Fatal("explicit index 'idx' not found") + } + + // The FK backing index should NOT create a duplicate — the explicit idx covers parent_id. + // But since FK is processed before INDEX in the ALTER commands, the FK may have created + // its own backing index "fk1" before "idx" was added. MySQL processes commands sequentially. + // So we should have: backing index "fk1" created by FK, then "idx" added as explicit. + // Actually in MySQL, when FK is added first and no index exists, the backing index IS created. + // Then the explicit ADD INDEX creates a second one. Both should exist. + // Let's just verify: FK constraint exists + at least one index covers parent_id. + idxCount := 0 + for _, idx := range tbl.Indexes { + for _, ic := range idx.Columns { + if ic.Name == "parent_id" { + idxCount++ + break + } + } + } + if idxCount < 1 { + t.Errorf("expected at least 1 index covering parent_id, got %d", idxCount) + } +} + +// indexNames returns a list of index names for debugging. +func indexNames(tbl *Table) []string { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + return names +} diff --git a/tidb/catalog/wt_5_4_test.go b/tidb/catalog/wt_5_4_test.go new file mode 100644 index 00000000..7e328116 --- /dev/null +++ b/tidb/catalog/wt_5_4_test.go @@ -0,0 +1,116 @@ +package catalog + +import "testing" + +// TestWalkThrough_5_4 covers section 1.4 — Error Semantics in Multi-Command ALTER. +// Multi-command ALTER TABLE is atomic in MySQL: if any sub-command fails, the +// entire ALTER is rolled back and the table state is unchanged. +func TestWalkThrough_5_4(t *testing.T) { + t.Run("dup_column_in_same_alter", func(t *testing.T) { + // ADD COLUMN x INT, ADD COLUMN x INT — duplicate column error, verify rollback + c := setupTestTable(t) // t1: id INT NOT NULL, name VARCHAR(100), age INT + + results, _ := c.Exec( + "ALTER TABLE t1 ADD COLUMN x INT, ADD COLUMN x INT", + &ExecOptions{ContinueOnError: true}, + ) + assertError(t, results[0].Error, ErrDupColumn) + + // Verify rollback: table should be unchanged. + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns after rollback, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("x") != nil { + t.Error("column 'x' should not exist after rollback") + } + }) + + t.Run("add_then_drop_nonexistent", func(t *testing.T) { + // ADD COLUMN x INT, DROP COLUMN nonexistent — first succeeds, second errors; + // verify x was NOT added (MySQL rolls back entire ALTER) + c := setupTestTable(t) + + results, _ := c.Exec( + "ALTER TABLE t1 ADD COLUMN x INT, DROP COLUMN nonexistent", + &ExecOptions{ContinueOnError: true}, + ) + if results[0].Error == nil { + t.Fatal("expected error from DROP COLUMN nonexistent") + } + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns after rollback, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("x") != nil { + t.Error("column 'x' should not exist after rollback") + } + }) + + t.Run("modify_nonexistent_then_add", func(t *testing.T) { + // MODIFY COLUMN nonexistent INT, ADD COLUMN y INT — first errors, second never runs + c := setupTestTable(t) + + results, _ := c.Exec( + "ALTER TABLE t1 MODIFY COLUMN nonexistent INT, ADD COLUMN y INT", + &ExecOptions{ContinueOnError: true}, + ) + assertError(t, results[0].Error, ErrNoSuchColumn) + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns after rollback, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("y") != nil { + t.Error("column 'y' should not exist — entire ALTER rolled back") + } + }) + + t.Run("drop_nonexistent_index_then_add_column", func(t *testing.T) { + // DROP INDEX nonexistent, ADD COLUMN y INT — error on first, verify y not added + c := setupTestTable(t) + + results, _ := c.Exec( + "ALTER TABLE t1 DROP INDEX nonexistent, ADD COLUMN y INT", + &ExecOptions{ContinueOnError: true}, + ) + assertError(t, results[0].Error, ErrCantDropKey) + + tbl := c.GetDatabase("test").GetTable("t1") + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns after rollback, got %d", len(tbl.Columns)) + } + if tbl.GetColumn("y") != nil { + t.Error("column 'y' should not exist — entire ALTER rolled back") + } + }) + + t.Run("add_index_then_drop_same_column", func(t *testing.T) { + // ADD COLUMN x INT, ADD INDEX idx_x (x), DROP COLUMN x — + // MySQL processes sequentially: add x, create index, then drop x + // which also cleans up the index. Net result: table unchanged except + // droppedByCleanup tracking. The ALTER succeeds in MySQL. + c := setupTestTable(t) + + mustExec(t, c, + "ALTER TABLE t1 ADD COLUMN x INT, ADD INDEX idx_x (x), DROP COLUMN x", + ) + + tbl := c.GetDatabase("test").GetTable("t1") + // x was added then dropped — should not be present. + if tbl.GetColumn("x") != nil { + t.Error("column 'x' should not exist — it was added then dropped") + } + if len(tbl.Columns) != 3 { + t.Errorf("expected 3 columns, got %d", len(tbl.Columns)) + } + + // Index should also have been cleaned up when x was dropped. + for _, idx := range tbl.Indexes { + if idx.Name == "idx_x" { + t.Error("index 'idx_x' should not exist — its column was dropped") + } + } + }) +} diff --git a/tidb/catalog/wt_6_1_test.go b/tidb/catalog/wt_6_1_test.go new file mode 100644 index 00000000..208286d7 --- /dev/null +++ b/tidb/catalog/wt_6_1_test.go @@ -0,0 +1,377 @@ +package catalog + +import ( + "strings" + "testing" +) + +func TestWalkThrough_6_1_FKBackingIndexManagement(t *testing.T) { + // Scenario 1: CREATE TABLE with named FK, no explicit index — implicit index uses constraint name + t.Run("named_fk_implicit_index_uses_constraint_name", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and fk_parent + if len(tbl.Indexes) != 2 { + t.Fatalf("expected 2 indexes, got %d", len(tbl.Indexes)) + } + + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk_parent" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("expected implicit index named 'fk_parent', not found") + } + if len(fkIdx.Columns) != 1 || fkIdx.Columns[0].Name != "parent_id" { + t.Errorf("expected index on (parent_id), got %v", fkIdx.Columns) + } + }) + + // Scenario 2: CREATE TABLE with unnamed FK — implicit index uses first column name + t.Run("unnamed_fk_implicit_index_uses_column_name", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and parent_id + if len(tbl.Indexes) != 2 { + t.Fatalf("expected 2 indexes, got %d", len(tbl.Indexes)) + } + + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "parent_id" { + fkIdx = idx + break + } + } + if fkIdx == nil { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected implicit index named 'parent_id', not found; indexes: %v", names) + } + }) + + // Scenario 3: CREATE TABLE with explicit index on FK columns, then named FK — no duplicate index + t.Run("explicit_index_before_named_fk_no_duplicate", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + INDEX idx_parent (parent_id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and idx_parent (no duplicate fk_parent) + if len(tbl.Indexes) != 2 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 2 indexes (no duplicate), got %d: %v", len(tbl.Indexes), names) + } + + // The existing index should be idx_parent, not fk_parent + found := false + for _, idx := range tbl.Indexes { + if idx.Name == "idx_parent" { + found = true + } + } + if !found { + t.Error("expected idx_parent index to exist") + } + }) + + // Scenario 4: CREATE TABLE with FK on column already in UNIQUE KEY — no duplicate index + t.Run("fk_on_unique_key_column_no_duplicate", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY uk_parent (parent_id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and uk_parent (no duplicate fk_parent) + if len(tbl.Indexes) != 2 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 2 indexes (no duplicate), got %d: %v", len(tbl.Indexes), names) + } + }) + + // Scenario 5: CREATE TABLE with FK on column already in PRIMARY KEY — no duplicate index + t.Run("fk_on_primary_key_column_no_duplicate", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + parent_id INT NOT NULL, + PRIMARY KEY (parent_id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have only 1 index: PRIMARY (no duplicate for FK) + if len(tbl.Indexes) != 1 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 1 index (PRIMARY only), got %d: %v", len(tbl.Indexes), names) + } + if !tbl.Indexes[0].Primary { + t.Error("expected the only index to be PRIMARY") + } + }) + + // Scenario 6: CREATE TABLE with multi-column FK, partial index exists — implicit index still created + t.Run("multi_column_fk_partial_index_still_creates_implicit", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, `CREATE TABLE parent ( + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY (a, b) + )`) + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY (id), + INDEX idx_a (a), + CONSTRAINT fk_ab FOREIGN KEY (a, b) REFERENCES parent(a, b) + )`) + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // idx_a only covers (a), not (a, b) — so an implicit index fk_ab should be created + // Expected: PRIMARY, idx_a, fk_ab = 3 indexes + if len(tbl.Indexes) != 3 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 3 indexes, got %d: %v", len(tbl.Indexes), names) + } + + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk_ab" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("expected implicit index named 'fk_ab' for multi-column FK") + } + if len(fkIdx.Columns) != 2 { + t.Errorf("expected 2 columns in fk_ab index, got %d", len(fkIdx.Columns)) + } + }) + + // Scenario 7: ALTER TABLE ADD FK when column already has index — no duplicate index + t.Run("alter_add_fk_column_has_index_no_duplicate", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + INDEX idx_parent (parent_id) + )`) + mustExec(t, c, "ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and idx_parent (no duplicate fk_parent) + if len(tbl.Indexes) != 2 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 2 indexes (no duplicate), got %d: %v", len(tbl.Indexes), names) + } + + // FK constraint should exist + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && con.Name == "fk_parent" { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint fk_parent not found") + } + }) + + // Scenario 8: ALTER TABLE ADD FK when column has no index — implicit index created + t.Run("alter_add_fk_no_index_creates_implicit", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id) + )`) + mustExec(t, c, "ALTER TABLE child ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id)") + + tbl := c.GetDatabase("testdb").GetTable("child") + if tbl == nil { + t.Fatal("table child not found") + } + + // Should have 2 indexes: PRIMARY and fk_parent (implicit) + if len(tbl.Indexes) != 2 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 2 indexes, got %d: %v", len(tbl.Indexes), names) + } + + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk_parent" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("expected implicit index named 'fk_parent'") + } + }) + + // Scenario 9: ALTER TABLE DROP FOREIGN KEY — FK removed but backing index remains + t.Run("drop_fk_keeps_backing_index", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + // Verify FK and backing index exist before drop + tbl := c.GetDatabase("testdb").GetTable("child") + if len(tbl.Indexes) != 2 { + t.Fatalf("expected 2 indexes before drop, got %d", len(tbl.Indexes)) + } + + mustExec(t, c, "ALTER TABLE child DROP FOREIGN KEY fk_parent") + + tbl = c.GetDatabase("testdb").GetTable("child") + + // FK constraint should be gone + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && strings.EqualFold(con.Name, "fk_parent") { + t.Error("FK constraint fk_parent should have been removed") + } + } + + // Backing index should remain (MySQL behavior) + if len(tbl.Indexes) != 2 { + t.Fatalf("expected 2 indexes after FK drop (backing index remains), got %d", len(tbl.Indexes)) + } + + var fkIdx *Index + for _, idx := range tbl.Indexes { + if idx.Name == "fk_parent" { + fkIdx = idx + break + } + } + if fkIdx == nil { + t.Fatal("backing index fk_parent should remain after DROP FOREIGN KEY") + } + }) + + // Scenario 10: ALTER TABLE DROP FOREIGN KEY, DROP INDEX fk_name — explicit index cleanup after FK drop + t.Run("drop_fk_then_drop_index_explicit_cleanup", func(t *testing.T) { + c := setupWithDB(t) + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT NOT NULL, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent(id) + )`) + + mustExec(t, c, "ALTER TABLE child DROP FOREIGN KEY fk_parent, DROP INDEX fk_parent") + + tbl := c.GetDatabase("testdb").GetTable("child") + + // FK constraint should be gone + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey && strings.EqualFold(con.Name, "fk_parent") { + t.Error("FK constraint fk_parent should have been removed") + } + } + + // Backing index should also be gone now + if len(tbl.Indexes) != 1 { + names := make([]string, len(tbl.Indexes)) + for i, idx := range tbl.Indexes { + names[i] = idx.Name + } + t.Fatalf("expected 1 index (PRIMARY only) after DROP FK + DROP INDEX, got %d: %v", len(tbl.Indexes), names) + } + if !tbl.Indexes[0].Primary { + t.Error("expected the only remaining index to be PRIMARY") + } + }) +} diff --git a/tidb/catalog/wt_6_2_test.go b/tidb/catalog/wt_6_2_test.go new file mode 100644 index 00000000..d97ae98f --- /dev/null +++ b/tidb/catalog/wt_6_2_test.go @@ -0,0 +1,317 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestWalkThrough_6_2 tests FK Validation Matrix scenarios from section 2.2 +// of the walkthrough. Scenarios already fully covered in wt_2_4_test.go are +// omitted; this file covers new or extended aspects only. + +// Scenario 1: DROP TABLE parent when child FK exists, foreign_key_checks=1 — error 3730 +// (Already covered by TestWalkThrough_2_4_DropTableReferencedByFK; included here +// under a 6_2 name for completeness of the section proof.) +func TestWalkThrough_6_2_DropParentFKChecksOn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + + results, err := c.Exec("DROP TABLE parent", &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKCannotDropParent) +} + +// Scenario 2: DROP TABLE parent when child FK exists, foreign_key_checks=0 — succeeds, +// child FK becomes orphan (references dropped parent). +func TestWalkThrough_6_2_DropParentFKChecksOff_OrphanState(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + wtExec(t, c, "SET foreign_key_checks = 0") + wtExec(t, c, "DROP TABLE parent") + + // Parent should be gone. + if c.GetDatabase("testdb").GetTable("parent") != nil { + t.Fatal("parent table should have been dropped") + } + + // Child should still exist. + child := c.GetDatabase("testdb").GetTable("child") + if child == nil { + t.Fatal("child table should still exist") + } + + // Child FK constraint should still reference the dropped parent. + var fk *Constraint + for _, con := range child.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("child should still have FK constraint") + } + if !strings.EqualFold(fk.RefTable, "parent") { + t.Errorf("FK should still reference 'parent', got %q", fk.RefTable) + } +} + +// Scenario 3: DROP TABLE child then parent, foreign_key_checks=1 — succeeds. +func TestWalkThrough_6_2_DropChildThenParent(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + + // Drop child first — removes FK dependency. + wtExec(t, c, "DROP TABLE child") + // Now parent can be dropped. + wtExec(t, c, "DROP TABLE parent") + + if c.GetDatabase("testdb").GetTable("child") != nil { + t.Error("child should be dropped") + } + if c.GetDatabase("testdb").GetTable("parent") != nil { + t.Error("parent should be dropped") + } +} + +// Scenario 4: DROP COLUMN used in FK on same table, foreign_key_checks=1 — error 1828 +// (Already covered by TestWalkThrough_2_4_AlterTableDropColumnUsedInFK; included +// here for section completeness.) +func TestWalkThrough_6_2_DropColumnUsedInFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(id))") + + results, err := c.Exec("ALTER TABLE child DROP COLUMN pid", &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if results[0].Error == nil { + t.Fatal("expected error when dropping FK column") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *Error, got %T", results[0].Error) + } + if catErr.Code != 1828 { + t.Errorf("expected error code 1828, got %d", catErr.Code) + } +} + +// Scenario 5: CREATE TABLE with FK referencing nonexistent table, foreign_key_checks=0 — succeeds. +// (Already covered by TestWalkThrough_2_4_SetForeignKeyChecksOff; included for completeness.) +func TestWalkThrough_6_2_CreateFKNonexistentTableFKOff(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 0") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES nonexistent(id))") + + child := c.GetDatabase("testdb").GetTable("child") + if child == nil { + t.Fatal("child table should exist") + } +} + +// Scenario 6: CREATE TABLE with FK referencing nonexistent column, foreign_key_checks=0 — succeeds. +func TestWalkThrough_6_2_CreateFKNonexistentColumnFKOff(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "SET foreign_key_checks = 0") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, FOREIGN KEY (pid) REFERENCES parent(nonexistent_col))") + + child := c.GetDatabase("testdb").GetTable("child") + if child == nil { + t.Fatal("child table should exist") + } + // Verify FK constraint is stored. + var fk *Constraint + for _, con := range child.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint should be stored") + } + if len(fk.RefColumns) == 0 || fk.RefColumns[0] != "nonexistent_col" { + t.Errorf("FK should reference 'nonexistent_col', got %v", fk.RefColumns) + } +} + +// Scenario 7: ALTER TABLE ADD FK with type mismatch (INT vs VARCHAR), foreign_key_checks=1 — error. +// (Already covered by TestWalkThrough_2_4_AlterTableAddFKIncompatibleColumns; included for completeness.) +func TestWalkThrough_6_2_AddFKTypeMismatch(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT PRIMARY KEY)") + wtExec(t, c, "CREATE TABLE child (id INT, pid VARCHAR(100))") + + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_p FOREIGN KEY (pid) REFERENCES parent(id)", &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKIncompatibleColumns) +} + +// Scenario 8: ALTER TABLE ADD FK where referenced table has no index on referenced columns — error 1822. +// (Already covered by TestWalkThrough_2_4_AlterTableAddFKMissingIndex; included for completeness.) +func TestWalkThrough_6_2_AddFKMissingIndex(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id INT, val INT)") // val has no index + wtExec(t, c, "CREATE TABLE child (id INT, pid INT)") + + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_v FOREIGN KEY (pid) REFERENCES parent(val)", &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKMissingIndex) +} + +// Scenario 9: SET foreign_key_checks=0 then CREATE circular FKs then SET foreign_key_checks=1 — both tables valid. +func TestWalkThrough_6_2_CircularFKs(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "SET foreign_key_checks = 0") + wtExec(t, c, "CREATE TABLE a (id INT PRIMARY KEY, b_id INT, FOREIGN KEY (b_id) REFERENCES b(id))") + wtExec(t, c, "CREATE TABLE b (id INT PRIMARY KEY, a_id INT, FOREIGN KEY (a_id) REFERENCES a(id))") + wtExec(t, c, "SET foreign_key_checks = 1") + + tblA := c.GetDatabase("testdb").GetTable("a") + tblB := c.GetDatabase("testdb").GetTable("b") + if tblA == nil { + t.Fatal("table a should exist") + } + if tblB == nil { + t.Fatal("table b should exist") + } + + // Verify table a has FK referencing b. + var fkA *Constraint + for _, con := range tblA.Constraints { + if con.Type == ConForeignKey { + fkA = con + break + } + } + if fkA == nil { + t.Fatal("table a should have FK constraint") + } + if !strings.EqualFold(fkA.RefTable, "b") { + t.Errorf("table a FK should reference 'b', got %q", fkA.RefTable) + } + + // Verify table b has FK referencing a. + var fkB *Constraint + for _, con := range tblB.Constraints { + if con.Type == ConForeignKey { + fkB = con + break + } + } + if fkB == nil { + t.Fatal("table b should have FK constraint") + } + if !strings.EqualFold(fkB.RefTable, "a") { + t.Errorf("table b FK should reference 'a', got %q", fkB.RefTable) + } +} + +// Scenario 10: Self-referencing FK (table references itself) — column references own table PK. +func TestWalkThrough_6_2_SelfReferencingFK(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE tree (id INT PRIMARY KEY, parent_id INT, FOREIGN KEY (parent_id) REFERENCES tree(id))") + + tbl := c.GetDatabase("testdb").GetTable("tree") + if tbl == nil { + t.Fatal("table tree should exist") + } + + var fk *Constraint + for _, con := range tbl.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("self-referencing FK constraint should exist") + } + if !strings.EqualFold(fk.RefTable, "tree") { + t.Errorf("FK should reference 'tree', got %q", fk.RefTable) + } + if len(fk.Columns) != 1 || fk.Columns[0] != "parent_id" { + t.Errorf("FK columns should be [parent_id], got %v", fk.Columns) + } + if len(fk.RefColumns) != 1 || fk.RefColumns[0] != "id" { + t.Errorf("FK ref columns should be [id], got %v", fk.RefColumns) + } + + // Verify SHOW CREATE TABLE renders the self-ref FK. + sct := c.ShowCreateTable("testdb", "tree") + if !strings.Contains(sct, "REFERENCES `tree`") { + t.Errorf("SHOW CREATE TABLE should contain self-reference, got:\n%s", sct) + } +} + +// Scenario 11: FK column count mismatch (single-column FK referencing composite PK) — error. +// MySQL returns error 1822 (missing index) because the single-column FK cannot match +// the composite PK index (which requires both columns as leading prefix). +func TestWalkThrough_6_2_FKColumnCountMismatch(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE parent (id1 INT, id2 INT, PRIMARY KEY (id1, id2))") + wtExec(t, c, "CREATE TABLE child (id INT, pid INT)") + + // Single-column FK referencing one column of a composite PK. + // The PK index has (id1, id2), so a reference to just (id1) DOES match + // the leading prefix. This should SUCCEED in MySQL. + // But a reference to (id2) alone would NOT match the prefix and would fail. + // Let's test the failing case: reference id2 (second column of composite PK). + results, err := c.Exec("ALTER TABLE child ADD CONSTRAINT fk_mismatch FOREIGN KEY (pid) REFERENCES parent(id2)", &ExecOptions{ContinueOnError: true}) + if err != nil { + t.Fatalf("parse error: %v", err) + } + assertError(t, results[0].Error, ErrFKMissingIndex) +} + +// Scenario 12: Cross-database FK: FOREIGN KEY (col) REFERENCES other_db.parent(id) — stored and rendered correctly. +func TestWalkThrough_6_2_CrossDatabaseFK(t *testing.T) { + c := wtSetup(t) + // Create the other database and parent table. + wtExec(t, c, "CREATE DATABASE other_db") + wtExec(t, c, "CREATE TABLE other_db.parent (id INT PRIMARY KEY)") + + // Create child in testdb with cross-database FK. + wtExec(t, c, "CREATE TABLE child (id INT, pid INT, CONSTRAINT fk_cross FOREIGN KEY (pid) REFERENCES other_db.parent(id))") + + child := c.GetDatabase("testdb").GetTable("child") + if child == nil { + t.Fatal("child table should exist") + } + + // Verify FK stores the cross-database reference. + var fk *Constraint + for _, con := range child.Constraints { + if con.Type == ConForeignKey { + fk = con + break + } + } + if fk == nil { + t.Fatal("FK constraint should exist") + } + if !strings.EqualFold(fk.RefDatabase, "other_db") { + t.Errorf("FK RefDatabase should be 'other_db', got %q", fk.RefDatabase) + } + if !strings.EqualFold(fk.RefTable, "parent") { + t.Errorf("FK RefTable should be 'parent', got %q", fk.RefTable) + } + + // Verify SHOW CREATE TABLE renders the cross-database reference. + sct := c.ShowCreateTable("testdb", "child") + if !strings.Contains(sct, "`other_db`.`parent`") { + t.Errorf("SHOW CREATE TABLE should contain cross-database reference `other_db`.`parent`, got:\n%s", sct) + } +} diff --git a/tidb/catalog/wt_6_3_test.go b/tidb/catalog/wt_6_3_test.go new file mode 100644 index 00000000..f9426ccd --- /dev/null +++ b/tidb/catalog/wt_6_3_test.go @@ -0,0 +1,100 @@ +package catalog + +import ( + "strings" + "testing" +) + +func TestWalkThrough_6_3_FKActionsRendering(t *testing.T) { + // Scenario 1: FK with ON DELETE CASCADE ON UPDATE SET NULL — both actions rendered in SHOW CREATE + t.Run("fk_with_cascade_and_set_null", func(t *testing.T) { + c := New() + mustExec(t, c, "CREATE DATABASE test") + c.SetCurrentDatabase("test") + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent (id) + ON DELETE CASCADE ON UPDATE SET NULL + )`) + + got := c.ShowCreateTable("test", "child") + if !strings.Contains(got, "ON DELETE CASCADE") { + t.Errorf("expected ON DELETE CASCADE in output:\n%s", got) + } + if !strings.Contains(got, "ON UPDATE SET NULL") { + t.Errorf("expected ON UPDATE SET NULL in output:\n%s", got) + } + }) + + // Scenario 2: FK with no action specified — defaults not rendered in SHOW CREATE + t.Run("fk_with_no_action_defaults_omitted", func(t *testing.T) { + c := New() + mustExec(t, c, "CREATE DATABASE test") + c.SetCurrentDatabase("test") + mustExec(t, c, "CREATE TABLE parent (id INT NOT NULL, PRIMARY KEY (id))") + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + parent_id INT, + PRIMARY KEY (id), + CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES parent (id) + )`) + + got := c.ShowCreateTable("test", "child") + // MySQL 8.0 does not show ON DELETE or ON UPDATE when using default (NO ACTION). + if strings.Contains(got, "ON DELETE") { + t.Errorf("default FK action should not show ON DELETE:\n%s", got) + } + if strings.Contains(got, "ON UPDATE") { + t.Errorf("default FK action should not show ON UPDATE:\n%s", got) + } + // Verify FK is still rendered. + if !strings.Contains(got, "CONSTRAINT `fk_parent` FOREIGN KEY (`parent_id`) REFERENCES `parent` (`id`)") { + t.Errorf("FK constraint not rendered correctly:\n%s", got) + } + }) + + // Scenario 3: Multi-column FK with actions — actions on composite FK rendered correctly + t.Run("multi_column_fk_with_actions", func(t *testing.T) { + c := New() + mustExec(t, c, "CREATE DATABASE test") + c.SetCurrentDatabase("test") + mustExec(t, c, `CREATE TABLE parent ( + a INT NOT NULL, + b INT NOT NULL, + PRIMARY KEY (a, b) + )`) + mustExec(t, c, `CREATE TABLE child ( + id INT NOT NULL, + pa INT, + pb INT, + PRIMARY KEY (id), + CONSTRAINT fk_composite FOREIGN KEY (pa, pb) REFERENCES parent (a, b) + ON DELETE CASCADE ON UPDATE SET NULL + )`) + + got := c.ShowCreateTable("test", "child") + // Verify multi-column FK is rendered with both columns. + if !strings.Contains(got, "FOREIGN KEY (`pa`, `pb`) REFERENCES `parent` (`a`, `b`)") { + t.Errorf("multi-column FK not rendered correctly:\n%s", got) + } + // Actions rendered once for the whole FK. + if !strings.Contains(got, "ON DELETE CASCADE") { + t.Errorf("expected ON DELETE CASCADE on composite FK:\n%s", got) + } + if !strings.Contains(got, "ON UPDATE SET NULL") { + t.Errorf("expected ON UPDATE SET NULL on composite FK:\n%s", got) + } + // Verify only one occurrence of each action (not per-column). + if strings.Count(got, "ON DELETE CASCADE") != 1 { + t.Errorf("expected exactly one ON DELETE CASCADE, got %d:\n%s", + strings.Count(got, "ON DELETE CASCADE"), got) + } + if strings.Count(got, "ON UPDATE SET NULL") != 1 { + t.Errorf("expected exactly one ON UPDATE SET NULL, got %d:\n%s", + strings.Count(got, "ON UPDATE SET NULL"), got) + } + }) +} diff --git a/tidb/catalog/wt_7_1_test.go b/tidb/catalog/wt_7_1_test.go new file mode 100644 index 00000000..a09d8d1c --- /dev/null +++ b/tidb/catalog/wt_7_1_test.go @@ -0,0 +1,336 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 3.1 (Phase 3): RANGE Partitioning (8 scenarios) --- +// File target: wt_7_1_test.go +// Proof: go test ./mysql/catalog/ -short -count=1 -run "TestWalkThrough_7_1" + +func TestWalkThrough_7_1_RangePartitioning(t *testing.T) { + t.Run("range_expr_3_partitions_maxvalue", func(t *testing.T) { + // Scenario 1: CREATE TABLE PARTITION BY RANGE (expr) with 3 partitions + MAXVALUE + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + store_id INT NOT NULL + ) + PARTITION BY RANGE (store_id) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`) + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Partitioning == nil { + t.Fatal("expected partitioning info, got nil") + } + if tbl.Partitioning.Type != "RANGE" { + t.Errorf("expected type RANGE, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions, got %d", len(tbl.Partitioning.Partitions)) + } + if tbl.Partitioning.Partitions[0].Name != "p0" { + t.Errorf("expected partition name p0, got %q", tbl.Partitioning.Partitions[0].Name) + } + + ddl := c.ShowCreateTable("testdb", "t1") + // Verify version comment /*!50100 + if !strings.Contains(ddl, "/*!50100") { + t.Errorf("expected /*!50100 version comment in DDL:\n%s", ddl) + } + // Verify MAXVALUE without parens for plain RANGE + if !strings.Contains(ddl, "VALUES LESS THAN MAXVALUE") { + t.Errorf("expected 'VALUES LESS THAN MAXVALUE' (no parens) in DDL:\n%s", ddl) + } + // Verify VALUES LESS THAN (10) + if !strings.Contains(ddl, "VALUES LESS THAN (10)") { + t.Errorf("expected 'VALUES LESS THAN (10)' in DDL:\n%s", ddl) + } + }) + + t.Run("range_columns_single", func(t *testing.T) { + // Scenario 2: CREATE TABLE PARTITION BY RANGE COLUMNS (col) — single column + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t2 ( + id INT NOT NULL, + name VARCHAR(50) NOT NULL + ) + PARTITION BY RANGE COLUMNS (name) ( + PARTITION p0 VALUES LESS THAN ('g'), + PARTITION p1 VALUES LESS THAN ('n'), + PARTITION p2 VALUES LESS THAN (MAXVALUE) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t2") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Partitioning == nil { + t.Fatal("expected partitioning info, got nil") + } + if tbl.Partitioning.Type != "RANGE COLUMNS" { + t.Errorf("expected type RANGE COLUMNS, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(tbl.Partitioning.Columns)) + } + + ddl := c.ShowCreateTable("testdb", "t2") + // RANGE COLUMNS uses /*!50500 + if !strings.Contains(ddl, "/*!50500") { + t.Errorf("expected /*!50500 version comment in DDL:\n%s", ddl) + } + // RANGE COLUMNS MAXVALUE is parenthesized + if !strings.Contains(ddl, "VALUES LESS THAN (MAXVALUE)") { + t.Errorf("expected 'VALUES LESS THAN (MAXVALUE)' in DDL:\n%s", ddl) + } + // MySQL uses double space before COLUMNS + if !strings.Contains(ddl, "RANGE COLUMNS") { + t.Errorf("expected 'RANGE COLUMNS' (double space) in DDL:\n%s", ddl) + } + }) + + t.Run("range_columns_multi", func(t *testing.T) { + // Scenario 3: CREATE TABLE PARTITION BY RANGE COLUMNS (col1, col2) — multi-column + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t3 ( + a INT NOT NULL, + b INT NOT NULL, + c INT NOT NULL + ) + PARTITION BY RANGE COLUMNS (a, b) ( + PARTITION p0 VALUES LESS THAN (10, 20), + PARTITION p1 VALUES LESS THAN (20, 30), + PARTITION p2 VALUES LESS THAN (MAXVALUE, MAXVALUE) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t3") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Partitioning == nil { + t.Fatal("expected partitioning info, got nil") + } + if tbl.Partitioning.Type != "RANGE COLUMNS" { + t.Errorf("expected type RANGE COLUMNS, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(tbl.Partitioning.Columns)) + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions, got %d", len(tbl.Partitioning.Partitions)) + } + + ddl := c.ShowCreateTable("testdb", "t3") + if !strings.Contains(ddl, "RANGE COLUMNS(a,b)") { + t.Errorf("expected 'RANGE COLUMNS(a,b)' in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES LESS THAN (10,20)") { + t.Errorf("expected 'VALUES LESS THAN (10,20)' in DDL:\n%s", ddl) + } + }) + + t.Run("alter_add_partition", func(t *testing.T) { + // Scenario 4: ALTER TABLE ADD PARTITION to RANGE table + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t4 ( + id INT NOT NULL, + val INT NOT NULL + ) + PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t4") + if len(tbl.Partitioning.Partitions) != 2 { + t.Fatalf("expected 2 partitions before ADD, got %d", len(tbl.Partitioning.Partitions)) + } + + wtExec(t, c, `ALTER TABLE t4 ADD PARTITION (PARTITION p2 VALUES LESS THAN (30))`) + + tbl = c.GetDatabase("testdb").GetTable("t4") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions after ADD, got %d", len(tbl.Partitioning.Partitions)) + } + if tbl.Partitioning.Partitions[2].Name != "p2" { + t.Errorf("expected new partition name p2, got %q", tbl.Partitioning.Partitions[2].Name) + } + + ddl := c.ShowCreateTable("testdb", "t4") + if !strings.Contains(ddl, "PARTITION p2 VALUES LESS THAN (30)") { + t.Errorf("expected p2 partition in DDL:\n%s", ddl) + } + }) + + t.Run("alter_drop_partition", func(t *testing.T) { + // Scenario 5: ALTER TABLE DROP PARTITION from RANGE table + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t5 ( + id INT NOT NULL, + val INT NOT NULL + ) + PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20), + PARTITION p2 VALUES LESS THAN (30) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t5") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions before DROP, got %d", len(tbl.Partitioning.Partitions)) + } + + wtExec(t, c, `ALTER TABLE t5 DROP PARTITION p1`) + + tbl = c.GetDatabase("testdb").GetTable("t5") + if len(tbl.Partitioning.Partitions) != 2 { + t.Fatalf("expected 2 partitions after DROP, got %d", len(tbl.Partitioning.Partitions)) + } + // Remaining should be p0 and p2 + if tbl.Partitioning.Partitions[0].Name != "p0" { + t.Errorf("expected first partition p0, got %q", tbl.Partitioning.Partitions[0].Name) + } + if tbl.Partitioning.Partitions[1].Name != "p2" { + t.Errorf("expected second partition p2, got %q", tbl.Partitioning.Partitions[1].Name) + } + + ddl := c.ShowCreateTable("testdb", "t5") + if strings.Contains(ddl, "p1") { + t.Errorf("dropped partition p1 should not appear in DDL:\n%s", ddl) + } + }) + + t.Run("alter_reorganize_split", func(t *testing.T) { + // Scenario 6: ALTER TABLE REORGANIZE PARTITION split (1->2) + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t6 ( + id INT NOT NULL, + val INT NOT NULL + ) + PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (30), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`) + + wtExec(t, c, `ALTER TABLE t6 REORGANIZE PARTITION p1 INTO ( + PARTITION p1a VALUES LESS THAN (20), + PARTITION p1b VALUES LESS THAN (30) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t6") + if len(tbl.Partitioning.Partitions) != 4 { + t.Fatalf("expected 4 partitions after split, got %d", len(tbl.Partitioning.Partitions)) + } + names := make([]string, len(tbl.Partitioning.Partitions)) + for i, p := range tbl.Partitioning.Partitions { + names[i] = p.Name + } + expected := []string{"p0", "p1a", "p1b", "p2"} + for i, exp := range expected { + if names[i] != exp { + t.Errorf("partition %d: expected %q, got %q", i, exp, names[i]) + } + } + + ddl := c.ShowCreateTable("testdb", "t6") + if !strings.Contains(ddl, "PARTITION p1a VALUES LESS THAN (20)") { + t.Errorf("expected p1a partition in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITION p1b VALUES LESS THAN (30)") { + t.Errorf("expected p1b partition in DDL:\n%s", ddl) + } + }) + + t.Run("alter_reorganize_merge", func(t *testing.T) { + // Scenario 7: ALTER TABLE REORGANIZE PARTITION merge (2->1) + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t7 ( + id INT NOT NULL, + val INT NOT NULL + ) + PARTITION BY RANGE (val) ( + PARTITION p0 VALUES LESS THAN (10), + PARTITION p1 VALUES LESS THAN (20), + PARTITION p2 VALUES LESS THAN (30), + PARTITION p3 VALUES LESS THAN MAXVALUE + )`) + + wtExec(t, c, `ALTER TABLE t7 REORGANIZE PARTITION p1, p2 INTO ( + PARTITION p_merged VALUES LESS THAN (30) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t7") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions after merge, got %d", len(tbl.Partitioning.Partitions)) + } + names := make([]string, len(tbl.Partitioning.Partitions)) + for i, p := range tbl.Partitioning.Partitions { + names[i] = p.Name + } + expected := []string{"p0", "p_merged", "p3"} + for i, exp := range expected { + if names[i] != exp { + t.Errorf("partition %d: expected %q, got %q", i, exp, names[i]) + } + } + + ddl := c.ShowCreateTable("testdb", "t7") + if !strings.Contains(ddl, "PARTITION p_merged VALUES LESS THAN (30)") { + t.Errorf("expected p_merged partition in DDL:\n%s", ddl) + } + if strings.Contains(ddl, "PARTITION p1 ") || strings.Contains(ddl, "PARTITION p2 ") { + t.Errorf("old partitions p1/p2 should not appear in DDL:\n%s", ddl) + } + }) + + t.Run("range_date_expression", func(t *testing.T) { + // Scenario 8: RANGE partition with date expression (YEAR(col)) + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t8 ( + id INT NOT NULL, + created_at DATE NOT NULL + ) + PARTITION BY RANGE (YEAR(created_at)) ( + PARTITION p2020 VALUES LESS THAN (2021), + PARTITION p2021 VALUES LESS THAN (2022), + PARTITION p2022 VALUES LESS THAN (2023), + PARTITION pfuture VALUES LESS THAN MAXVALUE + )`) + + tbl := c.GetDatabase("testdb").GetTable("t8") + if tbl == nil { + t.Fatal("table not found") + } + if tbl.Partitioning == nil { + t.Fatal("expected partitioning info, got nil") + } + if tbl.Partitioning.Type != "RANGE" { + t.Errorf("expected type RANGE, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Partitions) != 4 { + t.Fatalf("expected 4 partitions, got %d", len(tbl.Partitioning.Partitions)) + } + + ddl := c.ShowCreateTable("testdb", "t8") + // Verify the expression rendering includes YEAR(...) (MySQL renders as lowercase year()) + upperDDL := strings.ToUpper(ddl) + if !strings.Contains(upperDDL, "YEAR(") { + t.Errorf("expected YEAR() expression in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES LESS THAN (2021)") { + t.Errorf("expected 'VALUES LESS THAN (2021)' in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES LESS THAN MAXVALUE") { + t.Errorf("expected 'VALUES LESS THAN MAXVALUE' (no parens) in DDL:\n%s", ddl) + } + }) +} diff --git a/tidb/catalog/wt_7_2_test.go b/tidb/catalog/wt_7_2_test.go new file mode 100644 index 00000000..407d262a --- /dev/null +++ b/tidb/catalog/wt_7_2_test.go @@ -0,0 +1,272 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 3.2: LIST Partitioning (6 scenarios) --- + +// Scenario 1: CREATE TABLE PARTITION BY LIST (expr) with VALUES IN +func TestWalkThrough_7_2_ListByExpr(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + region INT NOT NULL + ) PARTITION BY LIST (region) ( + PARTITION p_east VALUES IN (1,2,3), + PARTITION p_west VALUES IN (4,5,6), + PARTITION p_central VALUES IN (7,8,9) + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify partition type marker + if !strings.Contains(ddl, "/*!50100 PARTITION BY LIST") { + t.Errorf("expected /*!50100 PARTITION BY LIST, got:\n%s", ddl) + } + // MySQL 8.0 backtick-quotes identifiers in partition expressions + if !strings.Contains(ddl, "LIST (`region`)") { + t.Errorf("expected LIST (`region`), got:\n%s", ddl) + } + + // Verify partition definitions + for _, name := range []string{"p_east", "p_west", "p_central"} { + if !strings.Contains(ddl, "PARTITION "+name) { + t.Errorf("expected PARTITION %s in DDL:\n%s", name, ddl) + } + } + if !strings.Contains(ddl, "VALUES IN (1,2,3)") { + t.Errorf("expected VALUES IN (1,2,3), got:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES IN (4,5,6)") { + t.Errorf("expected VALUES IN (4,5,6), got:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES IN (7,8,9)") { + t.Errorf("expected VALUES IN (7,8,9), got:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.Type != "LIST" { + t.Errorf("expected partition type LIST, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Errorf("expected 3 partitions, got %d", len(tbl.Partitioning.Partitions)) + } +} + +// Scenario 2: CREATE TABLE PARTITION BY LIST COLUMNS (col) — single column +func TestWalkThrough_7_2_ListColumnsSingle(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t2 ( + id INT NOT NULL, + status VARCHAR(20) NOT NULL + ) PARTITION BY LIST COLUMNS (status) ( + PARTITION p_active VALUES IN ('active','pending'), + PARTITION p_inactive VALUES IN ('inactive','archived') + )`) + + ddl := c.ShowCreateTable("testdb", "t2") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // LIST COLUMNS uses /*!50500 + if !strings.Contains(ddl, "/*!50500 PARTITION BY LIST") { + t.Errorf("expected /*!50500 PARTITION BY LIST, got:\n%s", ddl) + } + // MySQL uses double space before COLUMNS + if !strings.Contains(ddl, "LIST COLUMNS(status)") { + t.Errorf("expected LIST COLUMNS(status), got:\n%s", ddl) + } + + // Verify VALUES IN with string values + if !strings.Contains(ddl, "VALUES IN ('active','pending')") { + t.Errorf("expected VALUES IN ('active','pending'), got:\n%s", ddl) + } + + tbl := c.GetDatabase("testdb").GetTable("t2") + if tbl.Partitioning.Type != "LIST COLUMNS" { + t.Errorf("expected partition type LIST COLUMNS, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Columns) != 1 { + t.Errorf("expected 1 partition column, got %d", len(tbl.Partitioning.Columns)) + } +} + +// Scenario 3: CREATE TABLE PARTITION BY LIST COLUMNS (col1, col2) — multi-column +func TestWalkThrough_7_2_ListColumnsMulti(t *testing.T) { + c := wtSetup(t) + + // Multi-column LIST COLUMNS with tuple syntax: VALUES IN ((c1,c2),(c1,c2)) + // If the parser doesn't support tuple syntax, this will fail at parse time. + results, err := c.Exec(`CREATE TABLE t3 ( + id INT NOT NULL, + a INT NOT NULL, + b INT NOT NULL + ) PARTITION BY LIST COLUMNS (a, b) ( + PARTITION p0 VALUES IN ((0,0),(0,1)), + PARTITION p1 VALUES IN ((1,0),(1,1)) + )`, nil) + if err != nil { + t.Skipf("[~] partial: parser does not support multi-column LIST COLUMNS tuple syntax: %v", err) + return + } + for _, r := range results { + if r.Error != nil { + t.Skipf("[~] partial: catalog error for multi-column LIST COLUMNS: %v", r.Error) + return + } + } + + ddl := c.ShowCreateTable("testdb", "t3") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + if !strings.Contains(ddl, "LIST COLUMNS(a,b)") { + t.Errorf("expected LIST COLUMNS(a,b), got:\n%s", ddl) + } + + tbl := c.GetDatabase("testdb").GetTable("t3") + if tbl.Partitioning.Type != "LIST COLUMNS" { + t.Errorf("expected partition type LIST COLUMNS, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Columns) != 2 { + t.Errorf("expected 2 partition columns, got %d", len(tbl.Partitioning.Columns)) + } + if len(tbl.Partitioning.Partitions) != 2 { + t.Errorf("expected 2 partitions, got %d", len(tbl.Partitioning.Partitions)) + } +} + +// Scenario 4: ALTER TABLE ADD PARTITION with new VALUES IN +func TestWalkThrough_7_2_AlterAddPartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t4 ( + id INT NOT NULL, + region INT NOT NULL + ) PARTITION BY LIST (region) ( + PARTITION p_east VALUES IN (1,2,3), + PARTITION p_west VALUES IN (4,5,6) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t4") + if len(tbl.Partitioning.Partitions) != 2 { + t.Fatalf("expected 2 partitions before ADD, got %d", len(tbl.Partitioning.Partitions)) + } + + wtExec(t, c, `ALTER TABLE t4 ADD PARTITION (PARTITION p_central VALUES IN (7,8,9))`) + + tbl = c.GetDatabase("testdb").GetTable("t4") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions after ADD, got %d", len(tbl.Partitioning.Partitions)) + } + + // Verify the new partition name + lastPart := tbl.Partitioning.Partitions[2] + if lastPart.Name != "p_central" { + t.Errorf("expected partition name p_central, got %q", lastPart.Name) + } + + ddl := c.ShowCreateTable("testdb", "t4") + if !strings.Contains(ddl, "VALUES IN (7,8,9)") { + t.Errorf("expected VALUES IN (7,8,9) in DDL:\n%s", ddl) + } +} + +// Scenario 5: ALTER TABLE DROP PARTITION from LIST table +func TestWalkThrough_7_2_AlterDropPartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t5 ( + id INT NOT NULL, + region INT NOT NULL + ) PARTITION BY LIST (region) ( + PARTITION p_east VALUES IN (1,2,3), + PARTITION p_west VALUES IN (4,5,6), + PARTITION p_central VALUES IN (7,8,9) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t5") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions before DROP, got %d", len(tbl.Partitioning.Partitions)) + } + + wtExec(t, c, `ALTER TABLE t5 DROP PARTITION p_west`) + + tbl = c.GetDatabase("testdb").GetTable("t5") + if len(tbl.Partitioning.Partitions) != 2 { + t.Fatalf("expected 2 partitions after DROP, got %d", len(tbl.Partitioning.Partitions)) + } + + // Verify remaining partition names + names := make([]string, len(tbl.Partitioning.Partitions)) + for i, p := range tbl.Partitioning.Partitions { + names[i] = p.Name + } + if names[0] != "p_east" || names[1] != "p_central" { + t.Errorf("expected partitions [p_east, p_central], got %v", names) + } + + ddl := c.ShowCreateTable("testdb", "t5") + if strings.Contains(ddl, "p_west") { + t.Errorf("dropped partition p_west still appears in DDL:\n%s", ddl) + } +} + +// Scenario 6: ALTER TABLE REORGANIZE PARTITION in LIST table +func TestWalkThrough_7_2_AlterReorganizePartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t6 ( + id INT NOT NULL, + region INT NOT NULL + ) PARTITION BY LIST (region) ( + PARTITION p_east VALUES IN (1,2,3), + PARTITION p_west VALUES IN (4,5,6) + )`) + + // Reorganize p_east into two new partitions + wtExec(t, c, `ALTER TABLE t6 REORGANIZE PARTITION p_east INTO ( + PARTITION p_northeast VALUES IN (1,2), + PARTITION p_southeast VALUES IN (3) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t6") + if len(tbl.Partitioning.Partitions) != 3 { + t.Fatalf("expected 3 partitions after REORGANIZE, got %d", len(tbl.Partitioning.Partitions)) + } + + // Verify partition names and ordering + expected := []string{"p_northeast", "p_southeast", "p_west"} + for i, p := range tbl.Partitioning.Partitions { + if p.Name != expected[i] { + t.Errorf("partition %d: expected name %q, got %q", i, expected[i], p.Name) + } + } + + ddl := c.ShowCreateTable("testdb", "t6") + if strings.Contains(ddl, "p_east") { + t.Errorf("reorganized partition p_east still appears in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES IN (1,2)") { + t.Errorf("expected VALUES IN (1,2) in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "VALUES IN (3)") { + t.Errorf("expected VALUES IN (3) in DDL:\n%s", ddl) + } +} diff --git a/tidb/catalog/wt_7_3_test.go b/tidb/catalog/wt_7_3_test.go new file mode 100644 index 00000000..44233add --- /dev/null +++ b/tidb/catalog/wt_7_3_test.go @@ -0,0 +1,327 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 3.3: HASH and KEY Partitioning (9 scenarios) --- + +// Scenario 1: CREATE TABLE PARTITION BY HASH (expr) PARTITIONS 4 +func TestWalkThrough_7_3_HashByExpr(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY HASH (id) PARTITIONS 4`) + + ddl := c.ShowCreateTable("testdb", "t1") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify partition type marker + if !strings.Contains(ddl, "/*!50100 PARTITION BY HASH") { + t.Errorf("expected /*!50100 PARTITION BY HASH, got:\n%s", ddl) + } + // MySQL 8.0 backtick-quotes identifiers in HASH expression + if !strings.Contains(ddl, "HASH (`id`)") { + t.Errorf("expected HASH (`id`), got:\n%s", ddl) + } + // HASH with no explicit partition defs renders PARTITIONS N + if !strings.Contains(ddl, "PARTITIONS 4") { + t.Errorf("expected PARTITIONS 4 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.Type != "HASH" { + t.Errorf("expected partition type HASH, got %q", tbl.Partitioning.Type) + } + if tbl.Partitioning.NumParts != 4 { + t.Errorf("expected NumParts 4, got %d", tbl.Partitioning.NumParts) + } + if tbl.Partitioning.Linear { + t.Error("expected Linear=false for non-linear HASH") + } +} + +// Scenario 2: CREATE TABLE PARTITION BY LINEAR HASH (expr) PARTITIONS 4 +func TestWalkThrough_7_3_LinearHash(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t2 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY LINEAR HASH (id) PARTITIONS 4`) + + ddl := c.ShowCreateTable("testdb", "t2") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // LINEAR keyword should be rendered + if !strings.Contains(ddl, "LINEAR HASH") { + t.Errorf("expected LINEAR HASH in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITIONS 4") { + t.Errorf("expected PARTITIONS 4 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t2") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if !tbl.Partitioning.Linear { + t.Error("expected Linear=true for LINEAR HASH") + } + if tbl.Partitioning.Type != "HASH" { + t.Errorf("expected partition type HASH, got %q", tbl.Partitioning.Type) + } +} + +// Scenario 3: CREATE TABLE PARTITION BY KEY (col) PARTITIONS 4 +func TestWalkThrough_7_3_KeyPartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t3 ( + id INT NOT NULL PRIMARY KEY, + val INT NOT NULL + ) PARTITION BY KEY (id) PARTITIONS 4`) + + ddl := c.ShowCreateTable("testdb", "t3") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + if !strings.Contains(ddl, "/*!50100 PARTITION BY KEY") { + t.Errorf("expected /*!50100 PARTITION BY KEY, got:\n%s", ddl) + } + // KEY columns are rendered without backticks (plain) + if !strings.Contains(ddl, "KEY (id)") { + t.Errorf("expected KEY (id) in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITIONS 4") { + t.Errorf("expected PARTITIONS 4 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t3") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.Type != "KEY" { + t.Errorf("expected partition type KEY, got %q", tbl.Partitioning.Type) + } + if len(tbl.Partitioning.Columns) != 1 || tbl.Partitioning.Columns[0] != "id" { + t.Errorf("expected Columns [id], got %v", tbl.Partitioning.Columns) + } + if tbl.Partitioning.NumParts != 4 { + t.Errorf("expected NumParts 4, got %d", tbl.Partitioning.NumParts) + } +} + +// Scenario 4: CREATE TABLE PARTITION BY KEY () PARTITIONS 4 — uses PK +func TestWalkThrough_7_3_KeyEmptyColumns(t *testing.T) { + c := wtSetup(t) + + results, err := c.Exec(`CREATE TABLE t4 ( + id INT NOT NULL PRIMARY KEY, + val INT NOT NULL + ) PARTITION BY KEY () PARTITIONS 4`, nil) + if err != nil { + t.Skipf("[~] partial: parser does not support KEY () with empty column list: %v", err) + return + } + for _, r := range results { + if r.Error != nil { + t.Skipf("[~] partial: catalog error for KEY (): %v", r.Error) + return + } + } + + ddl := c.ShowCreateTable("testdb", "t4") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // KEY () uses PK columns — MySQL renders as KEY () + if !strings.Contains(ddl, "KEY ()") { + t.Errorf("expected KEY () in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITIONS 4") { + t.Errorf("expected PARTITIONS 4 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t4") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.Type != "KEY" { + t.Errorf("expected partition type KEY, got %q", tbl.Partitioning.Type) + } + if tbl.Partitioning.NumParts != 4 { + t.Errorf("expected NumParts 4, got %d", tbl.Partitioning.NumParts) + } +} + +// Scenario 5: CREATE TABLE PARTITION BY KEY (col) ALGORITHM=2 PARTITIONS 4 +func TestWalkThrough_7_3_KeyAlgorithm(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t5 ( + id INT NOT NULL PRIMARY KEY, + val INT NOT NULL + ) PARTITION BY KEY ALGORITHM=2 (id) PARTITIONS 4`) + + ddl := c.ShowCreateTable("testdb", "t5") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // ALGORITHM=2 should be rendered as ALGORITHM = 2 + if !strings.Contains(ddl, "KEY ALGORITHM = 2 (id)") { + t.Errorf("expected KEY ALGORITHM = 2 (id) in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITIONS 4") { + t.Errorf("expected PARTITIONS 4 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t5") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.Algorithm != 2 { + t.Errorf("expected Algorithm 2, got %d", tbl.Partitioning.Algorithm) + } +} + +// Scenario 6: ALTER TABLE COALESCE PARTITION 2 on HASH table (4→2) +func TestWalkThrough_7_3_CoalesceHash(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t6 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY HASH (id) PARTITIONS 4`) + + tbl := c.GetDatabase("testdb").GetTable("t6") + if tbl.Partitioning.NumParts != 4 { + t.Fatalf("expected NumParts 4 before COALESCE, got %d", tbl.Partitioning.NumParts) + } + + wtExec(t, c, `ALTER TABLE t6 COALESCE PARTITION 2`) + + tbl = c.GetDatabase("testdb").GetTable("t6") + if tbl.Partitioning.NumParts != 2 { + t.Errorf("expected NumParts 2 after COALESCE, got %d", tbl.Partitioning.NumParts) + } + + ddl := c.ShowCreateTable("testdb", "t6") + if !strings.Contains(ddl, "PARTITIONS 2") { + t.Errorf("expected PARTITIONS 2 in DDL:\n%s", ddl) + } +} + +// Scenario 7: ALTER TABLE COALESCE PARTITION on KEY table +func TestWalkThrough_7_3_CoalesceKey(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t7 ( + id INT NOT NULL PRIMARY KEY, + val INT NOT NULL + ) PARTITION BY KEY (id) PARTITIONS 4`) + + tbl := c.GetDatabase("testdb").GetTable("t7") + if tbl.Partitioning.NumParts != 4 { + t.Fatalf("expected NumParts 4 before COALESCE, got %d", tbl.Partitioning.NumParts) + } + + wtExec(t, c, `ALTER TABLE t7 COALESCE PARTITION 2`) + + tbl = c.GetDatabase("testdb").GetTable("t7") + if tbl.Partitioning.NumParts != 2 { + t.Errorf("expected NumParts 2 after COALESCE, got %d", tbl.Partitioning.NumParts) + } + + ddl := c.ShowCreateTable("testdb", "t7") + if !strings.Contains(ddl, "PARTITIONS 2") { + t.Errorf("expected PARTITIONS 2 in DDL:\n%s", ddl) + } +} + +// Scenario 8: ALTER TABLE ADD PARTITION on HASH table — error in MySQL +// MySQL rejects ADD PARTITION on HASH/KEY tables. Document omni behavior. +func TestWalkThrough_7_3_AddPartitionHashError(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t8 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY HASH (id) PARTITIONS 4`) + + // In MySQL, this would error. Test what omni does. + results, err := c.Exec(`ALTER TABLE t8 ADD PARTITION (PARTITION p_extra ENGINE=InnoDB)`, nil) + if err != nil { + // Parse error — document and skip + t.Logf("note: ADD PARTITION on HASH table caused parse error (MySQL would reject this too): %v", err) + return + } + for _, r := range results { + if r.Error != nil { + // Catalog error — this is actually correct behavior (MySQL rejects it) + t.Logf("note: ADD PARTITION on HASH table correctly errored: %v", r.Error) + return + } + } + + // If we reach here, omni allowed ADD PARTITION on HASH (differs from MySQL) + t.Log("note: omni allows ADD PARTITION on HASH table — MySQL would reject this. Documenting behavior difference.") + + // Verify the partition was added (since omni allowed it) + tbl := c.GetDatabase("testdb").GetTable("t8") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + // The original had NumParts=4 with no explicit Partitions slice. + // ADD PARTITION appends to the Partitions slice. + t.Logf("after ADD PARTITION: NumParts=%d, len(Partitions)=%d", + tbl.Partitioning.NumParts, len(tbl.Partitioning.Partitions)) +} + +// Scenario 9: ALTER TABLE REMOVE PARTITIONING on partitioned table +func TestWalkThrough_7_3_RemovePartitioning(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t9 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY HASH (id) PARTITIONS 4`) + + tbl := c.GetDatabase("testdb").GetTable("t9") + if tbl.Partitioning == nil { + t.Fatal("partitioning should be set before REMOVE") + } + + wtExec(t, c, `ALTER TABLE t9 REMOVE PARTITIONING`) + + tbl = c.GetDatabase("testdb").GetTable("t9") + if tbl.Partitioning != nil { + t.Error("partitioning should be nil after REMOVE PARTITIONING") + } + + ddl := c.ShowCreateTable("testdb", "t9") + if strings.Contains(ddl, "PARTITION") { + t.Errorf("SHOW CREATE TABLE should not contain PARTITION after REMOVE:\n%s", ddl) + } +} diff --git a/tidb/catalog/wt_7_4_test.go b/tidb/catalog/wt_7_4_test.go new file mode 100644 index 00000000..a9aa74a0 --- /dev/null +++ b/tidb/catalog/wt_7_4_test.go @@ -0,0 +1,392 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 3.4: Subpartitions and Partition Options (8 scenarios) --- + +// Scenario 1: RANGE partitions with SUBPARTITION BY HASH +func TestWalkThrough_7_4_RangeSubpartByHash(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t1 ( + id INT NOT NULL, + purchased DATE NOT NULL + ) PARTITION BY RANGE (YEAR(purchased)) + SUBPARTITION BY HASH (TO_DAYS(purchased)) + SUBPARTITIONS 2 + ( + PARTITION p0 VALUES LESS THAN (2000), + PARTITION p1 VALUES LESS THAN (2010), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`) + + ddl := c.ShowCreateTable("testdb", "t1") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify SUBPARTITION BY HASH rendering + if !strings.Contains(ddl, "SUBPARTITION BY HASH") { + t.Errorf("expected SUBPARTITION BY HASH in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "SUBPARTITIONS 2") { + t.Errorf("expected SUBPARTITIONS 2 in DDL:\n%s", ddl) + } + // Verify RANGE partitions still rendered + if !strings.Contains(ddl, "PARTITION p0 VALUES LESS THAN (2000)") { + t.Errorf("expected PARTITION p0 in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "PARTITION p2 VALUES LESS THAN MAXVALUE") { + t.Errorf("expected PARTITION p2 with MAXVALUE in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl == nil { + t.Fatal("table t1 not found") + } + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.SubType != "HASH" { + t.Errorf("expected SubType HASH, got %q", tbl.Partitioning.SubType) + } + if tbl.Partitioning.NumSubParts != 2 { + t.Errorf("expected NumSubParts 2, got %d", tbl.Partitioning.NumSubParts) + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Errorf("expected 3 partitions, got %d", len(tbl.Partitioning.Partitions)) + } +} + +// Scenario 2: RANGE partitions with SUBPARTITION BY KEY +func TestWalkThrough_7_4_RangeSubpartByKey(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t2 ( + id INT NOT NULL, + purchased DATE NOT NULL + ) PARTITION BY RANGE (YEAR(purchased)) + SUBPARTITION BY KEY (id) + SUBPARTITIONS 2 + ( + PARTITION p0 VALUES LESS THAN (2000), + PARTITION p1 VALUES LESS THAN MAXVALUE + )`) + + ddl := c.ShowCreateTable("testdb", "t2") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify SUBPARTITION BY KEY rendering + if !strings.Contains(ddl, "SUBPARTITION BY KEY") { + t.Errorf("expected SUBPARTITION BY KEY in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "KEY (`id`)") { + t.Errorf("expected KEY (`id`) in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "SUBPARTITIONS 2") { + t.Errorf("expected SUBPARTITIONS 2 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t2") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.SubType != "KEY" { + t.Errorf("expected SubType KEY, got %q", tbl.Partitioning.SubType) + } + if len(tbl.Partitioning.SubColumns) != 1 || tbl.Partitioning.SubColumns[0] != "id" { + t.Errorf("expected SubColumns [id], got %v", tbl.Partitioning.SubColumns) + } + if tbl.Partitioning.NumSubParts != 2 { + t.Errorf("expected NumSubParts 2, got %d", tbl.Partitioning.NumSubParts) + } +} + +// Scenario 3: Explicit subpartition definitions with names +func TestWalkThrough_7_4_ExplicitSubpartNames(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t3 ( + id INT NOT NULL, + purchased DATE NOT NULL + ) PARTITION BY RANGE (YEAR(purchased)) + SUBPARTITION BY HASH (TO_DAYS(purchased)) + ( + PARTITION p0 VALUES LESS THAN (2000) ( + SUBPARTITION s0, + SUBPARTITION s1 + ), + PARTITION p1 VALUES LESS THAN MAXVALUE ( + SUBPARTITION s2, + SUBPARTITION s3 + ) + )`) + + ddl := c.ShowCreateTable("testdb", "t3") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify explicit subpartition names rendered + if !strings.Contains(ddl, "SUBPARTITION s0") { + t.Errorf("expected SUBPARTITION s0 in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "SUBPARTITION s1") { + t.Errorf("expected SUBPARTITION s1 in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "SUBPARTITION s2") { + t.Errorf("expected SUBPARTITION s2 in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "SUBPARTITION s3") { + t.Errorf("expected SUBPARTITION s3 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t3") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if len(tbl.Partitioning.Partitions) != 2 { + t.Fatalf("expected 2 partitions, got %d", len(tbl.Partitioning.Partitions)) + } + p0 := tbl.Partitioning.Partitions[0] + if len(p0.SubPartitions) != 2 { + t.Fatalf("expected 2 subpartitions for p0, got %d", len(p0.SubPartitions)) + } + if p0.SubPartitions[0].Name != "s0" { + t.Errorf("expected subpartition name s0, got %q", p0.SubPartitions[0].Name) + } + if p0.SubPartitions[1].Name != "s1" { + t.Errorf("expected subpartition name s1, got %q", p0.SubPartitions[1].Name) + } + p1 := tbl.Partitioning.Partitions[1] + if len(p1.SubPartitions) != 2 { + t.Fatalf("expected 2 subpartitions for p1, got %d", len(p1.SubPartitions)) + } + if p1.SubPartitions[0].Name != "s2" { + t.Errorf("expected subpartition name s2, got %q", p1.SubPartitions[0].Name) + } + if p1.SubPartitions[1].Name != "s3" { + t.Errorf("expected subpartition name s3, got %q", p1.SubPartitions[1].Name) + } +} + +// Scenario 4: Partition with ENGINE option +func TestWalkThrough_7_4_PartitionEngineOption(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t4 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) ENGINE=InnoDB, + PARTITION p1 VALUES LESS THAN (200) ENGINE=InnoDB, + PARTITION p2 VALUES LESS THAN MAXVALUE ENGINE=InnoDB + )`) + + ddl := c.ShowCreateTable("testdb", "t4") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify ENGINE = InnoDB rendered per partition + if !strings.Contains(ddl, "ENGINE = InnoDB") { + t.Errorf("expected ENGINE = InnoDB per partition in DDL:\n%s", ddl) + } + + // Count occurrences of ENGINE = InnoDB — should be 3 (one per partition) + count := strings.Count(ddl, "ENGINE = InnoDB") + if count < 3 { + t.Errorf("expected at least 3 occurrences of ENGINE = InnoDB, got %d in DDL:\n%s", count, ddl) + } + + // Verify catalog state: engine defaults to InnoDB even when not specified + tbl := c.GetDatabase("testdb").GetTable("t4") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Errorf("expected 3 partitions, got %d", len(tbl.Partitioning.Partitions)) + } +} + +// Scenario 5: Partition with COMMENT option +func TestWalkThrough_7_4_PartitionCommentOption(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t5 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100) COMMENT='first partition', + PARTITION p1 VALUES LESS THAN MAXVALUE COMMENT='last partition' + )`) + + ddl := c.ShowCreateTable("testdb", "t5") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify COMMENT rendered per partition + if !strings.Contains(ddl, "COMMENT = 'first partition'") { + t.Errorf("expected COMMENT = 'first partition' in DDL:\n%s", ddl) + } + if !strings.Contains(ddl, "COMMENT = 'last partition'") { + t.Errorf("expected COMMENT = 'last partition' in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t5") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + p0 := tbl.Partitioning.Partitions[0] + if p0.Comment != "first partition" { + t.Errorf("expected comment 'first partition', got %q", p0.Comment) + } + p1 := tbl.Partitioning.Partitions[1] + if p1.Comment != "last partition" { + t.Errorf("expected comment 'last partition', got %q", p1.Comment) + } +} + +// Scenario 6: SUBPARTITIONS N count without explicit defs +func TestWalkThrough_7_4_SubpartitionsNCount(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t6 ( + id INT NOT NULL, + purchased DATE NOT NULL + ) PARTITION BY RANGE (YEAR(purchased)) + SUBPARTITION BY HASH (TO_DAYS(purchased)) + SUBPARTITIONS 3 + ( + PARTITION p0 VALUES LESS THAN (2000), + PARTITION p1 VALUES LESS THAN MAXVALUE + )`) + + ddl := c.ShowCreateTable("testdb", "t6") + if ddl == "" { + t.Fatal("ShowCreateTable returned empty string") + } + + // Verify SUBPARTITIONS 3 rendered (no explicit subpartition names) + if !strings.Contains(ddl, "SUBPARTITIONS 3") { + t.Errorf("expected SUBPARTITIONS 3 in DDL:\n%s", ddl) + } + + // Verify catalog state + tbl := c.GetDatabase("testdb").GetTable("t6") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil") + } + if tbl.Partitioning.NumSubParts != 3 { + t.Errorf("expected NumSubParts 3, got %d", tbl.Partitioning.NumSubParts) + } + // Auto-generated subpartitions should be present on each partition def. + for _, p := range tbl.Partitioning.Partitions { + if len(p.SubPartitions) != 3 { + t.Errorf("expected 3 auto-generated subpartitions for partition %s, got %d", p.Name, len(p.SubPartitions)) + } + } +} + +// Scenario 7: ALTER TABLE TRUNCATE PARTITION — no structural change +func TestWalkThrough_7_4_TruncatePartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t7 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (200), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`) + + ddlBefore := c.ShowCreateTable("testdb", "t7") + if ddlBefore == "" { + t.Fatal("ShowCreateTable returned empty string before TRUNCATE") + } + + wtExec(t, c, `ALTER TABLE t7 TRUNCATE PARTITION p1`) + + ddlAfter := c.ShowCreateTable("testdb", "t7") + + // TRUNCATE PARTITION is a data-only operation; structure unchanged + if ddlBefore != ddlAfter { + t.Errorf("SHOW CREATE TABLE changed after TRUNCATE PARTITION:\n--- before ---\n%s\n--- after ---\n%s", + ddlBefore, ddlAfter) + } + + // Verify catalog state unchanged + tbl := c.GetDatabase("testdb").GetTable("t7") + if tbl.Partitioning == nil { + t.Fatal("partitioning is nil after TRUNCATE") + } + if len(tbl.Partitioning.Partitions) != 3 { + t.Errorf("expected 3 partitions after TRUNCATE, got %d", len(tbl.Partitioning.Partitions)) + } +} + +// Scenario 8: ALTER TABLE EXCHANGE PARTITION — validation only, structure unchanged +func TestWalkThrough_7_4_ExchangePartition(t *testing.T) { + c := wtSetup(t) + + wtExec(t, c, `CREATE TABLE t8 ( + id INT NOT NULL, + val INT NOT NULL + ) PARTITION BY RANGE (id) ( + PARTITION p0 VALUES LESS THAN (100), + PARTITION p1 VALUES LESS THAN (200), + PARTITION p2 VALUES LESS THAN MAXVALUE + )`) + + // Create the exchange target table (non-partitioned) + wtExec(t, c, `CREATE TABLE t8_swap ( + id INT NOT NULL, + val INT NOT NULL + )`) + + ddlBefore := c.ShowCreateTable("testdb", "t8") + ddlSwapBefore := c.ShowCreateTable("testdb", "t8_swap") + + wtExec(t, c, `ALTER TABLE t8 EXCHANGE PARTITION p1 WITH TABLE t8_swap`) + + ddlAfter := c.ShowCreateTable("testdb", "t8") + ddlSwapAfter := c.ShowCreateTable("testdb", "t8_swap") + + // EXCHANGE PARTITION is a data operation; structure unchanged for both tables + if ddlBefore != ddlAfter { + t.Errorf("partitioned table structure changed after EXCHANGE:\n--- before ---\n%s\n--- after ---\n%s", + ddlBefore, ddlAfter) + } + if ddlSwapBefore != ddlSwapAfter { + t.Errorf("swap table structure changed after EXCHANGE:\n--- before ---\n%s\n--- after ---\n%s", + ddlSwapBefore, ddlSwapAfter) + } + + // Verify both tables still exist + if c.GetDatabase("testdb").GetTable("t8") == nil { + t.Error("partitioned table t8 should still exist") + } + if c.GetDatabase("testdb").GetTable("t8_swap") == nil { + t.Error("swap table t8_swap should still exist") + } + + // Verify exchange with nonexistent table errors + results, err := c.Exec(`ALTER TABLE t8 EXCHANGE PARTITION p1 WITH TABLE nonexistent`, nil) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(results) == 0 || results[0].Error == nil { + t.Error("expected error when exchanging with nonexistent table") + } +} diff --git a/tidb/catalog/wt_8_1_test.go b/tidb/catalog/wt_8_1_test.go new file mode 100644 index 00000000..29fcba26 --- /dev/null +++ b/tidb/catalog/wt_8_1_test.go @@ -0,0 +1,148 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 4.1 (starmap): Charset Inheritance Chain (7 scenarios) --- + +func TestWalkThrough_8_1(t *testing.T) { + t.Run("db_charset_inherited_by_table_and_column", func(t *testing.T) { + // Scenario 1: CREATE DATABASE CHARSET utf8mb4 → CREATE TABLE (no charset) → column inherits utf8mb4 + c := New() + mustExec(t, c, "CREATE DATABASE db1 DEFAULT CHARACTER SET utf8mb4") + c.SetCurrentDatabase("db1") + mustExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + + db := c.GetDatabase("db1") + tbl := db.GetTable("t1") + if tbl.Charset != "utf8mb4" { + t.Errorf("table charset: expected utf8mb4, got %q", tbl.Charset) + } + col := tbl.GetColumn("name") + if col.Charset != "utf8mb4" { + t.Errorf("column charset: expected utf8mb4, got %q", col.Charset) + } + if col.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("column collation: expected utf8mb4_0900_ai_ci, got %q", col.Collation) + } + }) + + t.Run("table_charset_overrides_db", func(t *testing.T) { + // Scenario 2: CREATE DATABASE CHARSET latin1 → CREATE TABLE CHARSET utf8mb4 → column inherits utf8mb4 + c := New() + mustExec(t, c, "CREATE DATABASE db2 DEFAULT CHARACTER SET latin1") + c.SetCurrentDatabase("db2") + mustExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100)) DEFAULT CHARSET=utf8mb4") + + db := c.GetDatabase("db2") + tbl := db.GetTable("t1") + if tbl.Charset != "utf8mb4" { + t.Errorf("table charset: expected utf8mb4, got %q", tbl.Charset) + } + col := tbl.GetColumn("name") + if col.Charset != "utf8mb4" { + t.Errorf("column charset: expected utf8mb4, got %q", col.Charset) + } + }) + + t.Run("add_column_inherits_table_charset", func(t *testing.T) { + // Scenario 3: Table CHARSET utf8mb4 → ADD COLUMN VARCHAR (no charset) → column inherits table charset + c := New() + mustExec(t, c, "CREATE DATABASE db3 DEFAULT CHARACTER SET utf8mb4") + c.SetCurrentDatabase("db3") + mustExec(t, c, "CREATE TABLE t1 (id INT) DEFAULT CHARSET=utf8mb4") + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN name VARCHAR(100)") + + tbl := c.GetDatabase("db3").GetTable("t1") + col := tbl.GetColumn("name") + if col.Charset != "utf8mb4" { + t.Errorf("column charset: expected utf8mb4, got %q", col.Charset) + } + if col.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("column collation: expected utf8mb4_0900_ai_ci, got %q", col.Collation) + } + }) + + t.Run("add_column_charset_overrides_table", func(t *testing.T) { + // Scenario 4: Table CHARSET utf8mb4 → ADD COLUMN VARCHAR CHARSET latin1 → column overrides + c := New() + mustExec(t, c, "CREATE DATABASE db4 DEFAULT CHARACTER SET utf8mb4") + c.SetCurrentDatabase("db4") + mustExec(t, c, "CREATE TABLE t1 (id INT) DEFAULT CHARSET=utf8mb4") + mustExec(t, c, "ALTER TABLE t1 ADD COLUMN name VARCHAR(100) CHARACTER SET latin1") + + tbl := c.GetDatabase("db4").GetTable("t1") + col := tbl.GetColumn("name") + if col.Charset != "latin1" { + t.Errorf("column charset: expected latin1, got %q", col.Charset) + } + if col.Collation != "latin1_swedish_ci" { + t.Errorf("column collation: expected latin1_swedish_ci, got %q", col.Collation) + } + }) + + t.Run("show_create_table_shows_inherited_charset", func(t *testing.T) { + // Scenario 5: CREATE DATABASE CHARSET utf8mb4 → table inherits → SHOW CREATE TABLE shows DEFAULT CHARSET=utf8mb4 + c := New() + mustExec(t, c, "CREATE DATABASE db5 DEFAULT CHARACTER SET utf8mb4") + c.SetCurrentDatabase("db5") + mustExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100))") + + ddl := c.ShowCreateTable("db5", "t1") + if !strings.Contains(ddl, "DEFAULT CHARSET=utf8mb4") { + t.Errorf("SHOW CREATE TABLE should contain DEFAULT CHARSET=utf8mb4, got:\n%s", ddl) + } + // MySQL 8.0 always shows COLLATE for utf8mb4 + if !strings.Contains(ddl, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("SHOW CREATE TABLE should contain COLLATE=utf8mb4_0900_ai_ci for utf8mb4, got:\n%s", ddl) + } + }) + + t.Run("charset_only_derives_default_collation", func(t *testing.T) { + // Scenario 6: CREATE TABLE with CHARSET only (no COLLATE) — default collation derived + c := New() + mustExec(t, c, "CREATE DATABASE db6") + c.SetCurrentDatabase("db6") + mustExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100)) DEFAULT CHARSET=latin1") + + tbl := c.GetDatabase("db6").GetTable("t1") + if tbl.Charset != "latin1" { + t.Errorf("table charset: expected latin1, got %q", tbl.Charset) + } + if tbl.Collation != "latin1_swedish_ci" { + t.Errorf("table collation: expected latin1_swedish_ci, got %q", tbl.Collation) + } + + // Column should also inherit + col := tbl.GetColumn("name") + if col.Charset != "latin1" { + t.Errorf("column charset: expected latin1, got %q", col.Charset) + } + if col.Collation != "latin1_swedish_ci" { + t.Errorf("column collation: expected latin1_swedish_ci, got %q", col.Collation) + } + }) + + t.Run("collate_only_derives_charset", func(t *testing.T) { + // Scenario 7: CREATE TABLE with COLLATE only (no CHARSET) — charset derived from collation + c := New() + mustExec(t, c, "CREATE DATABASE db7") + c.SetCurrentDatabase("db7") + mustExec(t, c, "CREATE TABLE t1 (id INT, name VARCHAR(100)) DEFAULT COLLATE=latin1_swedish_ci") + + tbl := c.GetDatabase("db7").GetTable("t1") + if tbl.Charset != "latin1" { + t.Errorf("table charset: expected latin1, got %q", tbl.Charset) + } + if tbl.Collation != "latin1_swedish_ci" { + t.Errorf("table collation: expected latin1_swedish_ci, got %q", tbl.Collation) + } + + ddl := c.ShowCreateTable("db7", "t1") + if !strings.Contains(ddl, "DEFAULT CHARSET=latin1") { + t.Errorf("SHOW CREATE TABLE should contain DEFAULT CHARSET=latin1, got:\n%s", ddl) + } + }) +} diff --git a/tidb/catalog/wt_8_2_test.go b/tidb/catalog/wt_8_2_test.go new file mode 100644 index 00000000..425159f8 --- /dev/null +++ b/tidb/catalog/wt_8_2_test.go @@ -0,0 +1,263 @@ +package catalog + +import ( + "strings" + "testing" +) + +// TestWalkThrough_8_2_AlterDefaultCharset tests that ALTER TABLE DEFAULT CHARACTER SET +// changes the table default but leaves existing column charsets unchanged. +func TestWalkThrough_8_2_AlterDefaultCharset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + bio TEXT + ) DEFAULT CHARSET=utf8mb4`) + + // Verify initial state: columns inherit utf8mb4. + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl.Charset != "utf8mb4" { + t.Fatalf("expected table charset=utf8mb4, got %s", tbl.Charset) + } + nameCol := tbl.GetColumn("name") + if nameCol.Charset != "utf8mb4" { + t.Fatalf("expected name charset=utf8mb4, got %s", nameCol.Charset) + } + + // ALTER TABLE DEFAULT CHARACTER SET latin1 — table default changes but columns unchanged. + wtExec(t, c, "ALTER TABLE t1 DEFAULT CHARACTER SET latin1") + + tbl = c.GetDatabase("testdb").GetTable("t1") + if tbl.Charset != "latin1" { + t.Errorf("expected table charset=latin1, got %s", tbl.Charset) + } + + // Existing columns should still have utf8mb4. + nameCol = tbl.GetColumn("name") + if nameCol.Charset != "utf8mb4" { + t.Errorf("expected name charset=utf8mb4 (unchanged), got %s", nameCol.Charset) + } + bioCol := tbl.GetColumn("bio") + if bioCol.Charset != "utf8mb4" { + t.Errorf("expected bio charset=utf8mb4 (unchanged), got %s", bioCol.Charset) + } +} + +// TestWalkThrough_8_2_ConvertToCharset tests CONVERT TO CHARACTER SET utf8mb4 +// updates table + all VARCHAR/TEXT/ENUM columns. +func TestWalkThrough_8_2_ConvertToCharset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + bio TEXT, + tag ENUM('a','b') + ) DEFAULT CHARSET=latin1`) + + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl.Charset != "utf8mb4" { + t.Errorf("expected table charset=utf8mb4, got %s", tbl.Charset) + } + if tbl.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("expected table collation=utf8mb4_0900_ai_ci, got %s", tbl.Collation) + } + + // All string columns should be updated. + for _, colName := range []string{"name", "bio", "tag"} { + col := tbl.GetColumn(colName) + if col == nil { + t.Fatalf("column %s not found", colName) + } + if col.Charset != "utf8mb4" { + t.Errorf("column %s: expected charset=utf8mb4, got %s", colName, col.Charset) + } + if col.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("column %s: expected collation=utf8mb4_0900_ai_ci, got %s", colName, col.Collation) + } + } +} + +// TestWalkThrough_8_2_ConvertWithCollation tests CONVERT TO CHARACTER SET utf8mb4 +// COLLATE utf8mb4_unicode_ci — non-default collation. +func TestWalkThrough_8_2_ConvertWithCollation(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + bio TEXT + ) DEFAULT CHARSET=latin1`) + + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + + tbl := c.GetDatabase("testdb").GetTable("t1") + if tbl.Charset != "utf8mb4" { + t.Errorf("expected table charset=utf8mb4, got %s", tbl.Charset) + } + if tbl.Collation != "utf8mb4_unicode_ci" { + t.Errorf("expected table collation=utf8mb4_unicode_ci, got %s", tbl.Collation) + } + + for _, colName := range []string{"name", "bio"} { + col := tbl.GetColumn(colName) + if col.Charset != "utf8mb4" { + t.Errorf("column %s: expected charset=utf8mb4, got %s", colName, col.Charset) + } + if col.Collation != "utf8mb4_unicode_ci" { + t.Errorf("column %s: expected collation=utf8mb4_unicode_ci, got %s", colName, col.Collation) + } + } +} + +// TestWalkThrough_8_2_ConvertIntColumnsUnchanged tests that CONVERT TO CHARACTER SET +// does not affect INT columns. +func TestWalkThrough_8_2_ConvertIntColumnsUnchanged(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + count BIGINT, + name VARCHAR(100) + ) DEFAULT CHARSET=latin1`) + + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4") + + tbl := c.GetDatabase("testdb").GetTable("t1") + + // INT columns should have no charset. + colID := tbl.GetColumn("id") + if colID.Charset != "" { + t.Errorf("INT column id should not have charset, got %s", colID.Charset) + } + colCount := tbl.GetColumn("count") + if colCount.Charset != "" { + t.Errorf("BIGINT column count should not have charset, got %s", colCount.Charset) + } + + // String column should be updated. + colName := tbl.GetColumn("name") + if colName.Charset != "utf8mb4" { + t.Errorf("VARCHAR column name: expected charset=utf8mb4, got %s", colName.Charset) + } +} + +// TestWalkThrough_8_2_ConvertMixedColumnTypes tests CONVERT TO CHARACTER SET on +// a table with mixed column types — only string types are updated. +func TestWalkThrough_8_2_ConvertMixedColumnTypes(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + score DECIMAL(10,2), + bio TEXT, + active TINYINT, + tag ENUM('a','b'), + created DATE, + data MEDIUMTEXT + ) DEFAULT CHARSET=latin1`) + + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4") + + tbl := c.GetDatabase("testdb").GetTable("t1") + + // String types should be updated. + stringCols := []string{"name", "bio", "tag", "data"} + for _, colName := range stringCols { + col := tbl.GetColumn(colName) + if col == nil { + t.Fatalf("column %s not found", colName) + } + if col.Charset != "utf8mb4" { + t.Errorf("column %s: expected charset=utf8mb4, got %s", colName, col.Charset) + } + } + + // Non-string types should not have charset. + nonStringCols := []string{"id", "score", "active", "created"} + for _, colName := range nonStringCols { + col := tbl.GetColumn(colName) + if col == nil { + t.Fatalf("column %s not found", colName) + } + if col.Charset != "" { + t.Errorf("non-string column %s should not have charset, got %s", colName, col.Charset) + } + } +} + +// TestWalkThrough_8_2_ConvertOverwritesExplicitCharset tests that CONVERT TO CHARACTER SET +// overwrites a column that already has an explicit charset. +func TestWalkThrough_8_2_ConvertOverwritesExplicitCharset(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100) CHARACTER SET latin1 + ) DEFAULT CHARSET=utf8mb4`) + + // Verify initial: name has explicit latin1. + tbl := c.GetDatabase("testdb").GetTable("t1") + colName := tbl.GetColumn("name") + if colName.Charset != "latin1" { + t.Fatalf("expected name charset=latin1 initially, got %s", colName.Charset) + } + + // CONVERT overwrites the explicit charset. + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4") + + tbl = c.GetDatabase("testdb").GetTable("t1") + colName = tbl.GetColumn("name") + if colName.Charset != "utf8mb4" { + t.Errorf("expected name charset=utf8mb4 after CONVERT, got %s", colName.Charset) + } + if colName.Collation != "utf8mb4_0900_ai_ci" { + t.Errorf("expected name collation=utf8mb4_0900_ai_ci after CONVERT, got %s", colName.Collation) + } +} + +// TestWalkThrough_8_2_ConvertThenShowCreate tests that after CONVERT TO CHARACTER SET, +// column charsets matching the table default are NOT shown in SHOW CREATE TABLE. +func TestWalkThrough_8_2_ConvertThenShowCreate(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t1 ( + id INT, + name VARCHAR(100), + bio TEXT, + age INT + ) DEFAULT CHARSET=latin1`) + + wtExec(t, c, "ALTER TABLE t1 CONVERT TO CHARACTER SET utf8mb4") + + ddl := c.ShowCreateTable("testdb", "t1") + + // Column charsets should NOT appear in the output since they match the table default. + // The table-level DEFAULT CHARSET=utf8mb4 should appear. + if !strings.Contains(ddl, "DEFAULT CHARSET=utf8mb4") { + t.Errorf("expected DEFAULT CHARSET=utf8mb4 in output, got:\n%s", ddl) + } + + // Column definitions should NOT contain "CHARACTER SET" since they match the table default. + lines := strings.Split(ddl, "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + // Skip table options line. + if strings.HasPrefix(trimmed, ")") { + continue + } + // Check column lines for CHARACTER SET — should not appear. + if strings.HasPrefix(trimmed, "`name`") || strings.HasPrefix(trimmed, "`bio`") { + if strings.Contains(line, "CHARACTER SET") { + t.Errorf("column charset should not be shown when matching table default:\n%s", line) + } + } + } + + // INT columns should definitely not show CHARACTER SET. + for _, line := range lines { + if strings.Contains(line, "`id`") || strings.Contains(line, "`age`") { + if strings.Contains(line, "CHARACTER SET") { + t.Errorf("INT column should not show CHARACTER SET:\n%s", line) + } + } + } +} diff --git a/tidb/catalog/wt_8_3_test.go b/tidb/catalog/wt_8_3_test.go new file mode 100644 index 00000000..75e94c4c --- /dev/null +++ b/tidb/catalog/wt_8_3_test.go @@ -0,0 +1,126 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- 4.3 SHOW CREATE TABLE Charset Rendering --- + +func TestWalkThrough_8_3_CharsetSameAsTable(t *testing.T) { + // Column charset same as table — CHARACTER SET not shown in column def + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100)) DEFAULT CHARSET=utf8mb4") + got := c.ShowCreateTable("testdb", "t") + // Column should NOT show CHARACTER SET since it matches table charset + if strings.Contains(got, "CHARACTER SET") { + t.Errorf("expected no CHARACTER SET in column def when charset matches table\ngot:\n%s", got) + } +} + +func TestWalkThrough_8_3_CharsetDiffersFromTable(t *testing.T) { + // Column charset differs from table — CHARACTER SET shown + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100) CHARACTER SET latin1) DEFAULT CHARSET=utf8mb4") + got := c.ShowCreateTable("testdb", "t") + if !strings.Contains(got, "CHARACTER SET latin1") { + t.Errorf("expected CHARACTER SET latin1 in column def when charset differs from table\ngot:\n%s", got) + } +} + +func TestWalkThrough_8_3_CollationNonDefaultSameAsTable(t *testing.T) { + // Column collation is non-default for its charset but same as table — COLLATE shown + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100)) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci") + got := c.ShowCreateTable("testdb", "t") + // The column inherits utf8mb4_unicode_ci from table, which is non-default for utf8mb4. + // MySQL shows COLLATE in this case (collation inherited but non-default for charset). + if !strings.Contains(got, "COLLATE utf8mb4_unicode_ci") { + t.Errorf("expected COLLATE utf8mb4_unicode_ci in column def\ngot:\n%s", got) + } + // CHARACTER SET should NOT be shown since charset matches table + lines := strings.Split(got, "\n") + for _, line := range lines { + if strings.Contains(line, "`name`") && strings.Contains(line, "CHARACTER SET") { + t.Errorf("expected no CHARACTER SET in column def when charset matches table\nline: %s", line) + } + } +} + +func TestWalkThrough_8_3_CollationDiffersFromTable(t *testing.T) { + // Column collation differs from table — both CHARACTER SET and COLLATE shown + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, name VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci") + got := c.ShowCreateTable("testdb", "t") + // Column has utf8mb4_bin which differs from table's utf8mb4_unicode_ci. + // Both are non-default for utf8mb4, but column differs from table. + // MySQL shows both CHARACTER SET and COLLATE. + lines := strings.Split(got, "\n") + foundCharset := false + foundCollate := false + for _, line := range lines { + if strings.Contains(line, "`name`") { + if strings.Contains(line, "CHARACTER SET utf8mb4") { + foundCharset = true + } + if strings.Contains(line, "COLLATE utf8mb4_bin") { + foundCollate = true + } + } + } + if !foundCharset { + t.Errorf("expected CHARACTER SET utf8mb4 in column def when collation differs from table\ngot:\n%s", got) + } + if !foundCollate { + t.Errorf("expected COLLATE utf8mb4_bin in column def when collation differs from table\ngot:\n%s", got) + } +} + +func TestWalkThrough_8_3_Utf8mb4DefaultCollation(t *testing.T) { + // Table with utf8mb4 default collation — COLLATE always shown for utf8mb4 (MySQL 8.0 behavior) + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) DEFAULT CHARSET=utf8mb4") + got := c.ShowCreateTable("testdb", "t") + // MySQL 8.0 always shows COLLATE for utf8mb4 tables + if !strings.Contains(got, "COLLATE=utf8mb4_0900_ai_ci") { + t.Errorf("expected COLLATE=utf8mb4_0900_ai_ci in table options for utf8mb4\ngot:\n%s", got) + } +} + +func TestWalkThrough_8_3_Latin1DefaultCollation(t *testing.T) { + // Table with latin1 and default collation — COLLATE not shown (latin1_swedish_ci is default) + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT) DEFAULT CHARSET=latin1") + got := c.ShowCreateTable("testdb", "t") + if strings.Contains(got, "COLLATE=") { + t.Errorf("expected no COLLATE in table options for latin1 with default collation\ngot:\n%s", got) + } +} + +func TestWalkThrough_8_3_BinaryCharset(t *testing.T) { + // BINARY charset on column — rendered as CHARACTER SET binary + // MySQL 8.0 converts CHAR(N) CHARACTER SET binary → binary(N), + // and VARCHAR(N) CHARACTER SET binary → varbinary(N). + // But ENUM/SET with CHARACTER SET binary retains the charset annotation. + + // Sub-test 1: CHAR CHARACTER SET binary → binary(N) in MySQL 8.0 + t.Run("char_binary_converts", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, data CHAR(100) CHARACTER SET binary) DEFAULT CHARSET=utf8mb4") + got := c.ShowCreateTable("testdb", "t") + // MySQL converts to binary(100) type — no CHARACTER SET annotation + if !strings.Contains(got, "`data` binary(100)") { + t.Errorf("expected CHAR(100) CHARACTER SET binary to render as binary(100)\ngot:\n%s", got) + } + }) + + // Sub-test 2: ENUM with CHARACTER SET binary — shows CHARACTER SET binary + t.Run("enum_charset_binary", func(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (id INT, status ENUM('a','b','c') CHARACTER SET binary) DEFAULT CHARSET=utf8mb4") + got := c.ShowCreateTable("testdb", "t") + if !strings.Contains(got, "CHARACTER SET binary") { + t.Errorf("expected CHARACTER SET binary in ENUM column def\ngot:\n%s", got) + } + }) +} diff --git a/tidb/catalog/wt_9_1_test.go b/tidb/catalog/wt_9_1_test.go new file mode 100644 index 00000000..a335d6c1 --- /dev/null +++ b/tidb/catalog/wt_9_1_test.go @@ -0,0 +1,268 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- 5.1 Generated Column CRUD (11 scenarios) --- + +func TestWalkThrough_9_1_VirtualArithmetic(t *testing.T) { + // Scenario 1: CREATE TABLE with VIRTUAL generated column (arithmetic) + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated == nil { + t.Fatal("column c should be generated") + } + if col.Generated.Stored { + t.Error("column c should be VIRTUAL, not STORED") + } + + got := c.ShowCreateTable("testdb", "t") + if !strings.Contains(got, "GENERATED ALWAYS AS") { + t.Errorf("expected GENERATED ALWAYS AS in SHOW CREATE TABLE\ngot:\n%s", got) + } + if !strings.Contains(got, "VIRTUAL") { + t.Errorf("expected VIRTUAL in SHOW CREATE TABLE\ngot:\n%s", got) + } + // Expression should contain backtick-quoted column refs and arithmetic. + if !strings.Contains(got, "(`a` + `b`)") { + t.Errorf("expected expression (`a` + `b`) in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_StoredArithmetic(t *testing.T) { + // Scenario 2: CREATE TABLE with STORED generated column (arithmetic) + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a * b) STORED)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated == nil { + t.Fatal("column c should be generated") + } + if !col.Generated.Stored { + t.Error("column c should be STORED") + } + + got := c.ShowCreateTable("testdb", "t") + if !strings.Contains(got, "STORED") { + t.Errorf("expected STORED in SHOW CREATE TABLE\ngot:\n%s", got) + } + if !strings.Contains(got, "(`a` * `b`)") { + t.Errorf("expected expression (`a` * `b`) in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_VirtualConcat(t *testing.T) { + // Scenario 3: CREATE TABLE with VIRTUAL generated column (CONCAT function) + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (first_name VARCHAR(50), last_name VARCHAR(50), full_name VARCHAR(101) GENERATED ALWAYS AS (CONCAT(first_name, ' ', last_name)) VIRTUAL)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("full_name") + if col == nil { + t.Fatal("column full_name not found") + } + if col.Generated == nil { + t.Fatal("column full_name should be generated") + } + + got := c.ShowCreateTable("testdb", "t") + // MySQL renders CONCAT with charset introducer for string literals in generated columns. + if !strings.Contains(got, "concat(") { + t.Errorf("expected concat function in SHOW CREATE TABLE\ngot:\n%s", got) + } + if !strings.Contains(got, "VIRTUAL") { + t.Errorf("expected VIRTUAL in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_StoredNotNull(t *testing.T) { + // Scenario 4: CREATE TABLE with STORED generated column + NOT NULL + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) STORED NOT NULL)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated == nil { + t.Fatal("column c should be generated") + } + if !col.Generated.Stored { + t.Error("column c should be STORED") + } + if col.Nullable { + t.Error("column c should be NOT NULL") + } + + got := c.ShowCreateTable("testdb", "t") + // Rendering order: type, GENERATED ALWAYS AS (expr) STORED, NOT NULL + if !strings.Contains(got, "STORED NOT NULL") { + t.Errorf("expected 'STORED NOT NULL' in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_Comment(t *testing.T) { + // Scenario 5: CREATE TABLE with generated column + COMMENT + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL COMMENT 'sum of a and b')") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Comment != "sum of a and b" { + t.Errorf("expected comment 'sum of a and b', got %q", col.Comment) + } + + got := c.ShowCreateTable("testdb", "t") + // COMMENT should come after VIRTUAL + if !strings.Contains(got, "VIRTUAL COMMENT 'sum of a and b'") { + t.Errorf("expected 'VIRTUAL COMMENT ...' in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_Invisible(t *testing.T) { + // Scenario 6: CREATE TABLE with generated column + INVISIBLE + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL INVISIBLE)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if !col.Invisible { + t.Error("column c should be INVISIBLE") + } + + got := c.ShowCreateTable("testdb", "t") + // INVISIBLE rendered after VIRTUAL with version comment + if !strings.Contains(got, "INVISIBLE") { + t.Errorf("expected INVISIBLE in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_JsonExtract(t *testing.T) { + // Scenario 7: CREATE TABLE with generated column using JSON_EXTRACT + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (doc JSON, name VARCHAR(100) GENERATED ALWAYS AS (JSON_EXTRACT(doc, '$.name')) VIRTUAL)") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("name") + if col == nil { + t.Fatal("column name not found") + } + if col.Generated == nil { + t.Fatal("column name should be generated") + } + + got := c.ShowCreateTable("testdb", "t") + // MySQL renders json_extract in lowercase + if !strings.Contains(got, "json_extract(") { + t.Errorf("expected json_extract function in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_AlterAddGenerated(t *testing.T) { + // Scenario 8: ALTER TABLE ADD COLUMN with GENERATED ALWAYS AS + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT)") + wtExec(t, c, "ALTER TABLE t ADD COLUMN c INT GENERATED ALWAYS AS (a + b) VIRTUAL") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated == nil { + t.Fatal("column c should be generated") + } + if col.Generated.Stored { + t.Error("column c should be VIRTUAL") + } + + got := c.ShowCreateTable("testdb", "t") + if !strings.Contains(got, "GENERATED ALWAYS AS") { + t.Errorf("expected GENERATED ALWAYS AS in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_ModifyChangeExpression(t *testing.T) { + // Scenario 9: MODIFY COLUMN to change generated expression + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL)") + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN c INT GENERATED ALWAYS AS (a * b) VIRTUAL") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated == nil { + t.Fatal("column c should still be generated") + } + + got := c.ShowCreateTable("testdb", "t") + // Expression should now be (a * b), not (a + b) + if !strings.Contains(got, "(`a` * `b`)") { + t.Errorf("expected updated expression (`a` * `b`) in SHOW CREATE TABLE\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_ModifyGeneratedToRegular(t *testing.T) { + // Scenario 10: ALTER TABLE MODIFY generated column to regular column + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL)") + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN c INT") + + tbl := c.GetDatabase("testdb").GetTable("t") + col := tbl.GetColumn("c") + if col == nil { + t.Fatal("column c not found") + } + if col.Generated != nil { + t.Error("column c should no longer be generated after MODIFY to regular column") + } + + got := c.ShowCreateTable("testdb", "t") + if strings.Contains(got, "GENERATED ALWAYS AS") { + t.Errorf("expected no GENERATED ALWAYS AS after MODIFY to regular column\ngot:\n%s", got) + } +} + +func TestWalkThrough_9_1_ModifyVirtualToStored(t *testing.T) { + // Scenario 11: ALTER TABLE MODIFY VIRTUAL to STORED — MySQL 8.0 error + // MySQL 8.0 does not allow changing VIRTUAL to STORED in-place. + // Error 3106 (HY000): 'Changing the STORED status' is not supported for generated columns. + c := wtSetup(t) + wtExec(t, c, "CREATE TABLE t (a INT, b INT, c INT GENERATED ALWAYS AS (a + b) VIRTUAL)") + + results, _ := c.Exec("ALTER TABLE t MODIFY COLUMN c INT GENERATED ALWAYS AS (a + b) STORED", nil) + if len(results) == 0 { + t.Fatal("expected result from ALTER TABLE") + } + if results[0].Error == nil { + t.Fatal("expected error when changing VIRTUAL to STORED") + } + catErr, ok := results[0].Error.(*Error) + if !ok { + t.Fatalf("expected *catalog.Error, got %T", results[0].Error) + } + if catErr.Code != ErrUnsupportedGeneratedStorageChange { + t.Errorf("expected error code %d, got %d: %s", ErrUnsupportedGeneratedStorageChange, catErr.Code, catErr.Message) + } +} diff --git a/tidb/catalog/wt_9_2_test.go b/tidb/catalog/wt_9_2_test.go new file mode 100644 index 00000000..05b146ac --- /dev/null +++ b/tidb/catalog/wt_9_2_test.go @@ -0,0 +1,239 @@ +package catalog + +import ( + "strings" + "testing" +) + +// --- Section 5.2: Generated Column Dependencies (6 scenarios) --- + +// Scenario 1: DROP COLUMN referenced by VIRTUAL generated column — MySQL 8.0 error 3108 +func TestWalkThrough_9_2_DropColumnReferencedByVirtual(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + b INT, + v INT GENERATED ALWAYS AS (a + b) VIRTUAL + )`) + + // Dropping column 'a' should fail because generated column 'v' references it. + results, _ := c.Exec("ALTER TABLE t DROP COLUMN a", &ExecOptions{ContinueOnError: true}) + assertError(t, results[0].Error, ErrDependentByGenCol) + + // Verify column was NOT dropped (error should prevent it). + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl.GetColumn("a") == nil { + t.Error("column 'a' should still exist after failed DROP") + } + if len(tbl.Columns) != 4 { + t.Errorf("expected 4 columns, got %d", len(tbl.Columns)) + } +} + +// Scenario 2: DROP COLUMN referenced by STORED generated column — MySQL 8.0 error 3108 +func TestWalkThrough_9_2_DropColumnReferencedByStored(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + b INT, + s INT GENERATED ALWAYS AS (a + b) STORED + )`) + + // Dropping column 'b' should fail because stored generated column 's' references it. + results, _ := c.Exec("ALTER TABLE t DROP COLUMN b", &ExecOptions{ContinueOnError: true}) + assertError(t, results[0].Error, ErrDependentByGenCol) + + // Verify column was NOT dropped. + tbl := c.GetDatabase("testdb").GetTable("t") + if tbl.GetColumn("b") == nil { + t.Error("column 'b' should still exist after failed DROP") + } +} + +// Scenario 3: MODIFY base column type when generated column uses it — expression preserved, no error +func TestWalkThrough_9_2_ModifyBaseColumnType(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + v INT GENERATED ALWAYS AS (a * 2) VIRTUAL + )`) + + // Modify base column type — should succeed, generated expression preserved. + wtExec(t, c, "ALTER TABLE t MODIFY COLUMN a BIGINT") + + tbl := c.GetDatabase("testdb").GetTable("t") + + // Verify column type was changed. + col := tbl.GetColumn("a") + if col == nil { + t.Fatal("column 'a' not found after MODIFY") + } + if col.ColumnType != "bigint" { + t.Errorf("expected column type 'bigint', got %q", col.ColumnType) + } + + // Verify generated column expression is preserved. + vCol := tbl.GetColumn("v") + if vCol == nil { + t.Fatal("generated column 'v' not found") + } + if vCol.Generated == nil { + t.Fatal("column 'v' should be a generated column") + } + + // Verify SHOW CREATE TABLE renders correctly. + show := c.ShowCreateTable("testdb", "t") + if !strings.Contains(show, "GENERATED ALWAYS AS") { + t.Error("SHOW CREATE TABLE should contain generated column expression") + } + if !strings.Contains(show, "`a`") { + t.Error("generated expression should still reference column 'a'") + } +} + +// Scenario 4: Generated column referencing another generated column — verify creation and SHOW CREATE +func TestWalkThrough_9_2_GeneratedReferencingGenerated(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + v1 INT GENERATED ALWAYS AS (a * 2) VIRTUAL, + v2 INT GENERATED ALWAYS AS (v1 + 10) VIRTUAL + )`) + + tbl := c.GetDatabase("testdb").GetTable("t") + + // Verify both generated columns exist. + v1 := tbl.GetColumn("v1") + if v1 == nil { + t.Fatal("generated column 'v1' not found") + } + if v1.Generated == nil { + t.Fatal("v1 should be a generated column") + } + if v1.Generated.Stored { + t.Error("v1 should be VIRTUAL, not STORED") + } + + v2 := tbl.GetColumn("v2") + if v2 == nil { + t.Fatal("generated column 'v2' not found") + } + if v2.Generated == nil { + t.Fatal("v2 should be a generated column") + } + if v2.Generated.Stored { + t.Error("v2 should be VIRTUAL, not STORED") + } + + // Verify SHOW CREATE TABLE renders both generated columns correctly. + show := c.ShowCreateTable("testdb", "t") + if !strings.Contains(show, "`v1`") { + t.Error("SHOW CREATE TABLE should contain v1") + } + if !strings.Contains(show, "`v2`") { + t.Error("SHOW CREATE TABLE should contain v2") + } + + // v2 references v1, so dropping v1 should fail. + results, _ := c.Exec("ALTER TABLE t DROP COLUMN v1", &ExecOptions{ContinueOnError: true}) + assertError(t, results[0].Error, ErrDependentByGenCol) +} + +// Scenario 5: Index on generated column — index created, rendered correctly +func TestWalkThrough_9_2_IndexOnGeneratedColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + v INT GENERATED ALWAYS AS (a * 2) VIRTUAL, + INDEX idx_v (v) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t") + + // Verify generated column exists. + vCol := tbl.GetColumn("v") + if vCol == nil { + t.Fatal("generated column 'v' not found") + } + if vCol.Generated == nil { + t.Fatal("v should be a generated column") + } + + // Verify index exists on generated column. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "idx_v" { + found = idx + break + } + } + if found == nil { + t.Fatal("index idx_v not found") + } + if found.Unique { + t.Error("idx_v should not be unique") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "v" { + t.Errorf("index columns mismatch: %+v", found.Columns) + } + + // Verify SHOW CREATE TABLE renders both the generated column and index. + show := c.ShowCreateTable("testdb", "t") + if !strings.Contains(show, "GENERATED ALWAYS AS") { + t.Error("SHOW CREATE TABLE should contain generated column expression") + } + if !strings.Contains(show, "KEY `idx_v`") { + t.Error("SHOW CREATE TABLE should contain index idx_v") + } +} + +// Scenario 6: UNIQUE index on generated column — unique constraint on generated column +func TestWalkThrough_9_2_UniqueIndexOnGeneratedColumn(t *testing.T) { + c := wtSetup(t) + wtExec(t, c, `CREATE TABLE t ( + id INT, + a INT, + v INT GENERATED ALWAYS AS (a * 2) VIRTUAL, + UNIQUE INDEX ux_v (v) + )`) + + tbl := c.GetDatabase("testdb").GetTable("t") + + // Verify generated column exists. + vCol := tbl.GetColumn("v") + if vCol == nil { + t.Fatal("generated column 'v' not found") + } + if vCol.Generated == nil { + t.Fatal("v should be a generated column") + } + + // Verify unique index exists on generated column. + var found *Index + for _, idx := range tbl.Indexes { + if idx.Name == "ux_v" { + found = idx + break + } + } + if found == nil { + t.Fatal("unique index ux_v not found") + } + if !found.Unique { + t.Error("ux_v should be unique") + } + if len(found.Columns) != 1 || found.Columns[0].Name != "v" { + t.Errorf("unique index columns mismatch: %+v", found.Columns) + } + + // Verify SHOW CREATE TABLE renders the unique index. + show := c.ShowCreateTable("testdb", "t") + if !strings.Contains(show, "UNIQUE KEY `ux_v`") { + t.Error("SHOW CREATE TABLE should contain UNIQUE KEY ux_v") + } +} diff --git a/tidb/catalog/wt_helpers_test.go b/tidb/catalog/wt_helpers_test.go new file mode 100644 index 00000000..9ef74103 --- /dev/null +++ b/tidb/catalog/wt_helpers_test.go @@ -0,0 +1,56 @@ +package catalog + +import "testing" + +// wtSetup creates a Catalog with database "testdb" selected, ready for walk-through tests. +func wtSetup(t *testing.T) *Catalog { + t.Helper() + c := New() + results, err := c.Exec("CREATE DATABASE testdb; USE testdb;", nil) + if err != nil { + t.Fatalf("wtSetup parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("wtSetup exec error: %v", r.Error) + } + } + return c +} + +// wtExec executes SQL on the catalog and fatals on any error. +func wtExec(t *testing.T, c *Catalog, sql string) { + t.Helper() + results, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("wtExec parse error: %v", err) + } + for _, r := range results { + if r.Error != nil { + t.Fatalf("wtExec exec error on stmt %d: %v", r.Index, r.Error) + } + } +} + +// assertError asserts that err is a *catalog.Error with the given MySQL error code. +func assertError(t *testing.T, err error, code int) { + t.Helper() + if err == nil { + t.Fatalf("expected error with code %d, got nil", code) + } + catErr, ok := err.(*Error) + if !ok { + t.Fatalf("expected *catalog.Error, got %T: %v", err, err) + } + if catErr.Code != code { + t.Errorf("expected error code %d, got %d: %s", code, catErr.Code, catErr.Message) + } +} + +// assertNoError asserts that err is nil. +func assertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} diff --git a/tidb/completion/SCENARIOS-completion.md b/tidb/completion/SCENARIOS-completion.md new file mode 100644 index 00000000..17e96fd8 --- /dev/null +++ b/tidb/completion/SCENARIOS-completion.md @@ -0,0 +1,442 @@ +# MySQL SQL Autocompletion Scenarios + +> Goal: Implement parser-native SQL autocompletion for MySQL, matching PG's architecture — parser instrumentation with collectMode, completion module with Complete(sql, cursor, catalog) API +> Verification: Table-driven Go tests — Complete(sql, cursorPos, catalog) returns expected candidates with correct types +> Reference: PG completion system (pg/completion/, pg/parser/complete.go) +> Out of scope (v1): XA transactions, SHUTDOWN, RESTART, CLONE, INSTALL/UNINSTALL PLUGIN, IMPORT TABLE, BINLOG, CACHE INDEX, PURGE BINARY LOGS, CHANGE REPLICATION, START/STOP REPLICA + +Status legend: `[ ]` pending, `[x]` passing, `[~]` partial + +--- + +## Phase 1: Parser Completion Infrastructure + +### 1.1 Completion Mode Fields & Entry Point + +``` +[x] Parser struct has completion fields: completing bool, cursorOff int, candidates *CandidateSet, collecting bool, maxCollect int +[x] collectMode() returns true when completing && collecting +[x] checkCursor() sets collecting=true when p.cur.Loc >= cursorOff +[x] CandidateSet struct with Tokens []int, Rules []RuleCandidate, seen/seenR dedup maps +[x] RuleCandidate struct with Rule string (e.g., "columnref", "table_ref") +[x] addTokenCandidate(tokType int) adds to CandidateSet.Tokens with dedup +[x] addRuleCandidate(rule string) adds to CandidateSet.Rules with dedup +[x] Collect(sql string, cursorOffset int) *CandidateSet — entry point with panic recovery +[x] Collect returns non-nil CandidateSet even for empty input +[x] Collect returns keyword candidates at statement start position: SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc. +[x] All existing parser tests pass with completion fields present (completing=false) — regression gate +``` + +### 1.2 Basic Keyword Collection + +``` +[x] Empty input → keyword candidates for all top-level statements +[x] After semicolon → keyword candidates for new statement +[x] `SELECT |` → keyword candidates for select expressions (DISTINCT, ALL) + rule candidates (columnref, func_name) +[x] `CREATE |` → keyword candidates for object types (TABLE, INDEX, VIEW, DATABASE, FUNCTION, PROCEDURE, TRIGGER, EVENT) +[x] `ALTER |` → keyword candidates (TABLE, DATABASE, VIEW, FUNCTION, PROCEDURE, EVENT) +[x] `DROP |` → keyword candidates (TABLE, INDEX, VIEW, DATABASE, FUNCTION, PROCEDURE, TRIGGER, EVENT, IF) +``` + +## Phase 2: Completion Module (API & Resolution) + +Build the public API and candidate resolution before instrumenting grammar rules, +so that Phase 3+ instrumentation can be tested end-to-end. + +### 2.1 Public API & Core Logic + +``` +[x] Complete(sql, cursorOffset, catalog) returns []Candidate +[x] Candidate struct has Text, Type, Definition, Comment fields +[x] CandidateType enum: Keyword, Database, Table, View, Column, Function, Procedure, Index, Trigger, Event, Variable, Charset, Engine, Type +[x] Complete with nil catalog returns keyword-only candidates +[x] Complete with empty sql returns top-level statement keywords +[x] Prefix filtering: `SEL|` matches SELECT keyword +[x] Prefix filtering is case-insensitive +[x] Deduplication: same candidate not returned twice +``` + +### 2.2 Candidate Resolution + +``` +[x] Token candidates → keyword strings (from token type mapping) +[x] "table_ref" rule → catalog tables + views +[x] "columnref" rule → columns from tables in scope +[x] "database_ref" rule → catalog databases +[x] "function_ref" / "func_name" rule → catalog functions + built-in function names +[x] "procedure_ref" rule → catalog procedures +[x] "index_ref" rule → indexes from relevant table +[x] "trigger_ref" rule → catalog triggers +[x] "event_ref" rule → catalog events +[x] "view_ref" rule → catalog views +[x] "charset" rule → known charset names (utf8mb4, latin1, utf8, ascii, binary) +[x] "engine" rule → known engine names (InnoDB, MyISAM, MEMORY, CSV, ARCHIVE) +[x] "type_name" rule → MySQL type keywords (INT, VARCHAR, TEXT, BLOB, DATE, etc.) +``` + +### 2.3 Table Reference Extraction + +``` +[x] Extract table refs from simple SELECT: `SELECT * FROM t` → [{Table: "t"}] +[x] Extract table refs with alias: `SELECT * FROM t AS x` → [{Table: "t", Alias: "x"}] +[x] Extract table refs from JOIN: `FROM t1 JOIN t2 ON ...` → [{Table: "t1"}, {Table: "t2"}] +[x] Extract table refs with database: `FROM db.t` → [{Database: "db", Table: "t"}] +[x] Extract table refs from subquery: inner tables don't leak to outer scope +[x] Extract table refs from UPDATE: `UPDATE t SET ...` → [{Table: "t"}] +[x] Extract table refs from INSERT: `INSERT INTO t ...` → [{Table: "t"}] +[x] Extract table refs from DELETE: `DELETE FROM t ...` → [{Table: "t"}] +[x] Fallback to lexer-based extraction when AST parsing fails (incomplete SQL) +``` + +### 2.4 Tricky Completion (Fallback) + +``` +[x] Incomplete SQL: `SELECT * FROM ` (trailing space) → insert placeholder, re-collect +[x] Truncated mid-keyword: `SELE` → prefix-filter against keywords +[x] Truncated after comma: `SELECT a,` → insert placeholder column +[x] Truncated after operator: `WHERE a >` → insert placeholder expression +[x] Multiple placeholder strategies tried in order +[x] Fallback returns best-effort results when no strategy succeeds +[x] Placeholder insertion does not corrupt the candidate set from the initial pass +``` + +## Phase 3: SELECT Statement Instrumentation + +### 3.1 SELECT Target List + +``` +[x] `SELECT |` → columnref, func_name, literal keywords (DISTINCT, ALL, *) +[x] `SELECT a, |` → columnref, func_name after comma +[x] `SELECT a, b, |` → columnref after multiple commas +[x] `SELECT * FROM t WHERE a > (SELECT |)` → columnref in subquery +[x] `SELECT DISTINCT |` → columnref after DISTINCT +``` + +### 3.2 FROM Clause + +``` +[x] `SELECT * FROM |` → table_ref (tables, views, databases) +[x] `SELECT * FROM db.|` → table_ref qualified with database +[x] `SELECT * FROM t1, |` → table_ref after comma (multi-table) +[x] `SELECT * FROM (SELECT * FROM |)` → table_ref in derived table +[x] `SELECT * FROM t |` → keyword candidates (WHERE, JOIN, LEFT, RIGHT, CROSS, NATURAL, STRAIGHT_JOIN, ORDER, GROUP, HAVING, LIMIT, UNION, FOR) +[x] `SELECT * FROM t AS |` → no specific candidates (alias context) +``` + +### 3.3 JOIN Clauses + +``` +[x] `SELECT * FROM t1 JOIN |` → table_ref after JOIN +[x] `SELECT * FROM t1 LEFT JOIN |` → table_ref after LEFT JOIN +[x] `SELECT * FROM t1 RIGHT JOIN |` → table_ref after RIGHT JOIN +[x] `SELECT * FROM t1 CROSS JOIN |` → table_ref after CROSS JOIN +[x] `SELECT * FROM t1 NATURAL JOIN |` → table_ref after NATURAL JOIN +[x] `SELECT * FROM t1 STRAIGHT_JOIN |` → table_ref after STRAIGHT_JOIN +[x] `SELECT * FROM t1 JOIN t2 ON |` → columnref after ON +[x] `SELECT * FROM t1 JOIN t2 USING (|)` → columnref after USING ( +[x] `SELECT * FROM t1 |` → JOIN keywords (JOIN, LEFT, RIGHT, INNER, CROSS, NATURAL, STRAIGHT_JOIN) +``` + +### 3.4 WHERE, GROUP BY, HAVING + +``` +[x] `SELECT * FROM t WHERE |` → columnref after WHERE +[x] `SELECT * FROM t WHERE a = 1 AND |` → columnref after AND +[x] `SELECT * FROM t WHERE a = 1 OR |` → columnref after OR +[x] `SELECT * FROM t GROUP BY |` → columnref after GROUP BY +[x] `SELECT * FROM t GROUP BY a, |` → columnref after comma +[x] `SELECT * FROM t GROUP BY a |` → keyword candidates (HAVING, ORDER, LIMIT, WITH ROLLUP) +[x] `SELECT * FROM t HAVING |` → columnref after HAVING +``` + +### 3.5 ORDER BY, LIMIT, DISTINCT + +``` +[x] `SELECT * FROM t ORDER BY |` → columnref after ORDER BY +[x] `SELECT * FROM t ORDER BY a, |` → columnref after comma +[x] `SELECT * FROM t ORDER BY a |` → keyword candidates (ASC, DESC, LIMIT, comma) +[x] `SELECT * FROM t LIMIT |` → no specific candidates (numeric context) +[x] `SELECT * FROM t LIMIT 10 OFFSET |` → no specific candidates +``` + +### 3.6 Set Operations & FOR UPDATE + +``` +[x] `SELECT a FROM t UNION |` → keyword candidates (ALL, SELECT) +[x] `SELECT a FROM t UNION ALL |` → keyword candidate (SELECT) +[x] `SELECT a FROM t INTERSECT |` → keyword candidates (ALL, SELECT) +[x] `SELECT a FROM t EXCEPT |` → keyword candidates (ALL, SELECT) +[x] `SELECT * FROM t FOR |` → keyword candidates (UPDATE, SHARE) +[x] `SELECT * FROM t FOR UPDATE |` → keyword candidates (OF, NOWAIT, SKIP) +``` + +### 3.7 CTE (WITH Clause) + +``` +[x] `WITH |` → keyword candidate (RECURSIVE) + identifier context for CTE name +[x] `WITH cte AS (|)` → keyword candidate (SELECT) +[x] `WITH cte AS (SELECT * FROM t) SELECT |` → columnref (CTE columns available) +[x] `WITH cte AS (SELECT * FROM t) SELECT * FROM |` → table_ref (CTE name available) +[x] `WITH RECURSIVE cte(|)` → identifier context for column names +``` + +### 3.8 Window Functions & Index Hints + +``` +[x] `SELECT a, ROW_NUMBER() OVER (|)` → keyword candidates (PARTITION, ORDER) +[x] `SELECT a, SUM(b) OVER (PARTITION BY |)` → columnref +[x] `SELECT a, SUM(b) OVER (ORDER BY |)` → columnref +[x] `SELECT a, SUM(b) OVER (ORDER BY a ROWS |)` → keyword candidates (BETWEEN, UNBOUNDED, CURRENT) +[x] `SELECT * FROM t USE INDEX (|)` → index_ref +[x] `SELECT * FROM t FORCE INDEX (|)` → index_ref +[x] `SELECT * FROM t IGNORE INDEX (|)` → index_ref +``` + +## Phase 4: DML Instrumentation + +### 4.1 INSERT + +``` +[x] `INSERT INTO |` → table_ref after INTO +[x] `INSERT INTO t (|)` → columnref for table t +[x] `INSERT INTO t (a, |)` → columnref after comma +[x] `INSERT INTO t VALUES (|)` → no specific candidates (value context) +[x] `INSERT INTO t |` → keyword candidates (VALUES, SET, SELECT, PARTITION) +[x] `INSERT INTO t VALUES (1) ON DUPLICATE KEY UPDATE |` → columnref +[x] `INSERT INTO t SET |` → columnref (assignment context) +[x] `INSERT INTO t SELECT |` → columnref (INSERT SELECT) +[x] `REPLACE INTO |` → table_ref +``` + +### 4.2 UPDATE + +``` +[x] `UPDATE |` → table_ref +[x] `UPDATE t SET |` → columnref for table t +[x] `UPDATE t SET a = 1, |` → columnref after comma +[x] `UPDATE t SET a = 1 WHERE |` → columnref +[x] `UPDATE t SET a = 1 ORDER BY |` → columnref +[x] `UPDATE t1 JOIN t2 ON t1.a = t2.a SET |` → columnref from both tables +``` + +### 4.3 DELETE & LOAD DATA & CALL + +``` +[x] `DELETE FROM |` → table_ref +[x] `DELETE FROM t WHERE |` → columnref for table t +[x] `DELETE FROM t ORDER BY |` → columnref +[x] `DELETE t1 FROM t1 JOIN t2 ON t1.a = t2.a WHERE |` → columnref from both tables +[x] `LOAD DATA INFILE 'f' INTO TABLE |` → table_ref +[x] `CALL |` → procedure_ref +``` + +## Phase 5: DDL Instrumentation + +### 5.1 CREATE TABLE + +``` +[x] `CREATE TABLE |` → identifier context (no specific candidates) +[x] `CREATE TABLE t (a INT, |)` → keyword candidates for column/constraint start (PRIMARY, UNIQUE, INDEX, KEY, FOREIGN, CHECK, CONSTRAINT) +[x] `CREATE TABLE t (a INT |)` → keyword candidates for column options (NOT, NULL, DEFAULT, AUTO_INCREMENT, PRIMARY, UNIQUE, COMMENT, COLLATE, REFERENCES, CHECK, GENERATED) +[x] `CREATE TABLE t (a INT) |` → keyword candidates for table options (ENGINE, DEFAULT, CHARSET, COLLATE, COMMENT, AUTO_INCREMENT, ROW_FORMAT, PARTITION) +[x] `CREATE TABLE t (a INT) ENGINE=|` → keyword candidates for engines (InnoDB, MyISAM, MEMORY, etc.) +[x] `CREATE TABLE t (a |)` → type candidates (INT, VARCHAR, TEXT, BLOB, DATE, DATETIME, DECIMAL, etc.) +[x] `CREATE TABLE t (FOREIGN KEY (a) REFERENCES |)` → table_ref +[x] `CREATE TABLE t LIKE |` → table_ref +[x] `CREATE TABLE t (a INT GENERATED ALWAYS AS (|))` → expression context (columnref, func_name) +``` + +### 5.2 ALTER TABLE + +``` +[x] `ALTER TABLE |` → table_ref +[x] `ALTER TABLE t |` → keyword candidates (ADD, DROP, MODIFY, CHANGE, RENAME, ALTER, CONVERT, ENGINE, DEFAULT, ORDER, ALGORITHM, LOCK, FORCE, ADD PARTITION, DROP PARTITION) +[x] `ALTER TABLE t ADD |` → keyword candidates (COLUMN, INDEX, KEY, UNIQUE, PRIMARY, FOREIGN, CONSTRAINT, CHECK, PARTITION, SPATIAL, FULLTEXT) +[x] `ALTER TABLE t ADD COLUMN |` → identifier context +[x] `ALTER TABLE t DROP |` → keyword candidates (COLUMN, INDEX, KEY, FOREIGN, PRIMARY, CHECK, CONSTRAINT, PARTITION) +[x] `ALTER TABLE t DROP COLUMN |` → columnref for table t +[x] `ALTER TABLE t DROP INDEX |` → index_ref for table t +[x] `ALTER TABLE t DROP FOREIGN KEY |` → constraint_ref +[x] `ALTER TABLE t DROP CONSTRAINT |` → constraint_ref (generic, MySQL 8.0.16+) +[x] `ALTER TABLE t MODIFY |` → columnref +[x] `ALTER TABLE t MODIFY COLUMN |` → columnref +[x] `ALTER TABLE t CHANGE |` → columnref (old name) +[x] `ALTER TABLE t RENAME TO |` → identifier context +[x] `ALTER TABLE t RENAME COLUMN |` → columnref +[x] `ALTER TABLE t RENAME INDEX |` → index_ref +[x] `ALTER TABLE t ADD INDEX idx (|)` → columnref +[x] `ALTER TABLE t CONVERT TO CHARACTER SET |` → charset candidates +[x] `ALTER TABLE t ALGORITHM=|` → keyword candidates (DEFAULT, INPLACE, COPY, INSTANT) +``` + +### 5.3 CREATE/DROP Index, View, Database + +``` +[x] `CREATE INDEX idx ON |` → table_ref +[x] `CREATE INDEX idx ON t (|)` → columnref for table t +[x] `CREATE UNIQUE INDEX idx ON |` → table_ref +[x] `DROP INDEX |` → index_ref +[x] `DROP INDEX idx ON |` → table_ref +[x] `CREATE VIEW |` → identifier context +[x] `CREATE VIEW v AS |` → keyword candidate (SELECT) +[x] `CREATE DEFINER=|` → keyword candidate (CURRENT_USER) + user context +[x] `ALTER VIEW v AS |` → keyword candidate (SELECT) +[x] `DROP VIEW |` → view_ref +[x] `CREATE DATABASE |` → identifier context +[x] `DROP DATABASE |` → database_ref +[x] `DROP TABLE |` → table_ref +[x] `DROP TABLE IF EXISTS |` → table_ref +[x] `TRUNCATE TABLE |` → table_ref +[x] `RENAME TABLE |` → table_ref +[x] `RENAME TABLE t TO |` → identifier context +[x] `DESCRIBE |` → table_ref +[x] `DESC |` → table_ref +``` + +## Phase 6: Routine/Trigger/Event Instrumentation + +### 6.1 Functions & Procedures + +``` +[x] `CREATE FUNCTION |` → identifier context +[x] `CREATE FUNCTION f(|)` → keyword candidates for param direction (IN, OUT, INOUT) + type context +[x] `CREATE FUNCTION f() RETURNS |` → type candidates +[x] `CREATE FUNCTION f() |` → keyword candidates for characteristics (DETERMINISTIC, NO SQL, READS SQL DATA, MODIFIES SQL DATA, COMMENT, LANGUAGE, SQL SECURITY) +[x] `DROP FUNCTION |` → function_ref +[x] `DROP FUNCTION IF EXISTS |` → function_ref +[x] `CREATE PROCEDURE |` → identifier context +[x] `DROP PROCEDURE |` → procedure_ref +[x] `ALTER FUNCTION |` → function_ref +[x] `ALTER PROCEDURE |` → procedure_ref +``` + +### 6.2 Triggers & Events + +``` +[x] `CREATE TRIGGER |` → identifier context +[x] `CREATE TRIGGER trg |` → keyword candidates (BEFORE, AFTER) +[x] `CREATE TRIGGER trg BEFORE |` → keyword candidates (INSERT, UPDATE, DELETE) +[x] `CREATE TRIGGER trg BEFORE INSERT ON |` → table_ref +[x] `DROP TRIGGER |` → trigger_ref +[x] `CREATE EVENT |` → identifier context +[x] `CREATE EVENT ev ON SCHEDULE |` → keyword candidates (AT, EVERY) +[x] `DROP EVENT |` → event_ref +[x] `ALTER EVENT |` → event_ref +``` + +### 6.3 Transaction, LOCK & Table Maintenance + +``` +[x] `BEGIN |` → keyword candidates (WORK) +[x] `START TRANSACTION |` → keyword candidates (WITH CONSISTENT SNAPSHOT, READ ONLY, READ WRITE) +[x] `COMMIT |` → keyword candidates (AND, WORK) +[x] `ROLLBACK |` → keyword candidates (TO, WORK) +[x] `ROLLBACK TO |` → keyword candidate (SAVEPOINT) +[x] `SAVEPOINT |` → identifier context +[x] `RELEASE SAVEPOINT |` → identifier context +[x] `LOCK TABLES |` → table_ref +[x] `LOCK TABLES t |` → keyword candidates (READ, WRITE) +[x] `ANALYZE TABLE |` → table_ref +[x] `OPTIMIZE TABLE |` → table_ref +[x] `CHECK TABLE |` → table_ref +[x] `REPAIR TABLE |` → table_ref +[x] `FLUSH |` → keyword candidates (PRIVILEGES, TABLES, LOGS, STATUS, HOSTS) +``` + +## Phase 7: Session/Utility Instrumentation + +### 7.1 SET & SHOW + +``` +[x] `SET |` → variable candidates (@@, @, GLOBAL, SESSION, NAMES, CHARACTER, PASSWORD) + keyword candidates +[x] `SET GLOBAL |` → variable candidates (system variables) +[x] `SET SESSION |` → variable candidates +[x] `SET NAMES |` → charset candidates +[x] `SET CHARACTER SET |` → charset candidates +[x] `SHOW |` → keyword candidates (TABLES, COLUMNS, INDEX, DATABASES, CREATE, STATUS, VARIABLES, PROCESSLIST, GRANTS, WARNINGS, ERRORS, ENGINE, etc.) +[x] `SHOW CREATE TABLE |` → table_ref +[x] `SHOW CREATE VIEW |` → view_ref +[x] `SHOW CREATE FUNCTION |` → function_ref +[x] `SHOW CREATE PROCEDURE |` → procedure_ref +[x] `SHOW COLUMNS FROM |` → table_ref +[x] `SHOW INDEX FROM |` → table_ref +[x] `SHOW TABLES FROM |` → database_ref +``` + +### 7.2 USE, GRANT, EXPLAIN + +``` +[x] `USE |` → database_ref +[x] `GRANT |` → keyword candidates (ALL, SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, INDEX, EXECUTE, etc.) +[x] `GRANT SELECT ON |` → table_ref (or database.*) +[x] `GRANT SELECT ON t TO |` → user context +[x] `REVOKE SELECT ON |` → table_ref +[x] `EXPLAIN |` → keyword candidates (SELECT, INSERT, UPDATE, DELETE, FORMAT) +[x] `EXPLAIN SELECT |` → same as SELECT instrumentation +[x] `PREPARE stmt FROM |` → string context +[x] `EXECUTE |` → prepared statement name +[x] `DEALLOCATE PREPARE |` → prepared statement name +[x] `DO |` → expression context (columnref, func_name) +``` + +## Phase 8: Expression Instrumentation + +### 8.1 Function & Type Names + +``` +[x] `SELECT |()` context → func_name candidates (built-in functions: COUNT, SUM, AVG, MAX, MIN, CONCAT, SUBSTRING, TRIM, NOW, IF, IFNULL, COALESCE, CAST, CONVERT, etc.) +[x] `SELECT CAST(a AS |)` → type candidates (CHAR, SIGNED, UNSIGNED, DECIMAL, DATE, DATETIME, TIME, JSON, BINARY) +[x] `SELECT CONVERT(a, |)` → type candidates +[x] `SELECT CONVERT(a USING |)` → charset candidates +``` + +### 8.2 Expression Contexts + +``` +[x] `SELECT a + |` → columnref, func_name (expression continuation) +[x] `SELECT CASE WHEN |` → columnref (CASE WHEN condition) +[x] `SELECT CASE WHEN a THEN |` → columnref, literal (CASE THEN result) +[x] `SELECT CASE a WHEN |` → literal context (CASE WHEN value) +[x] `SELECT * FROM t WHERE a IN (|)` → columnref, literal (IN list or subquery) +[x] `SELECT * FROM t WHERE a BETWEEN |` → columnref, literal (BETWEEN lower bound) +[x] `SELECT * FROM t WHERE a BETWEEN 1 AND |` → columnref, literal (BETWEEN upper bound) +[x] `SELECT * FROM t WHERE EXISTS (|)` → keyword candidate (SELECT) +[x] `SELECT * FROM t WHERE a IS |` → keyword candidates (NULL, NOT, TRUE, FALSE, UNKNOWN) +``` + +## Phase 9: Integration Tests + +### 9.1 Multi-Table Schema Tests + +``` +[x] Column completion scoped to correct table in JOIN: `SELECT t1.| FROM t1 JOIN t2 ON ...` → only t1 columns +[x] Column completion from all tables when unqualified: `SELECT | FROM t1 JOIN t2 ON ...` → columns from both tables +[x] Table alias completion: `SELECT x.| FROM t AS x` → columns of t via alias x +[x] View column completion: `SELECT | FROM v` → columns of view v +[x] CTE column completion: `WITH cte AS (SELECT a FROM t) SELECT | FROM cte` → column a +[x] Database-qualified table: `SELECT * FROM db.| ` → tables in database db +``` + +### 9.2 Edge Cases + +``` +[x] Cursor at beginning of SQL: `|SELECT * FROM t` → top-level keywords +[x] Cursor in middle of identifier: `SELECT us|ers FROM t` → prefix "us" filters candidates +[x] Cursor after semicolon (multi-statement): `SELECT 1; SELECT |` → new statement context +[x] Empty SQL: `|` → top-level keywords +[x] Whitespace only: ` |` → top-level keywords +[x] Very long SQL with cursor in middle +[x] SQL with syntax errors before cursor: completion still works for valid prefix +[x] Backtick-quoted identifiers: `SELECT \`| FROM t` → column candidates +``` + +### 9.3 Complex SQL Patterns + +``` +[x] Nested subquery column completion: `SELECT * FROM t WHERE a IN (SELECT | FROM t2)` → t2 columns +[x] Correlated subquery: `SELECT *, (SELECT | FROM t2 WHERE t2.a = t1.a) FROM t1` → t2 columns +[x] UNION: `SELECT a FROM t1 UNION SELECT | FROM t2` → t2 columns +[x] Multiple JOINs: `SELECT | FROM t1 JOIN t2 ON ... JOIN t3 ON ...` → columns from all 3 tables +[x] INSERT ... SELECT: `INSERT INTO t1 SELECT | FROM t2` → t2 columns +[x] Complex ALTER: `ALTER TABLE t ADD CONSTRAINT fk FOREIGN KEY (|) REFERENCES ...` → t columns +``` diff --git a/tidb/completion/completion.go b/tidb/completion/completion.go new file mode 100644 index 00000000..7b15c9ac --- /dev/null +++ b/tidb/completion/completion.go @@ -0,0 +1,163 @@ +// Package completion provides parser-native C3-style SQL completion for MySQL. +package completion + +import ( + "strings" + + "github.com/bytebase/omni/tidb/catalog" + "github.com/bytebase/omni/tidb/parser" +) + +// CandidateType classifies a completion candidate. +type CandidateType int + +const ( + CandidateKeyword CandidateType = iota // SQL keyword + CandidateDatabase // database name + CandidateTable // table name + CandidateView // view name + CandidateColumn // column name + CandidateFunction // function name + CandidateProcedure // procedure name + CandidateIndex // index name + CandidateTrigger // trigger name + CandidateEvent // event name + CandidateVariable // variable name + CandidateCharset // charset name + CandidateEngine // engine name + CandidateType_ // SQL type name +) + +// Candidate is a single completion suggestion. +type Candidate struct { + Text string // the completion text + Type CandidateType // what kind of object this is + Definition string // optional definition/signature + Comment string // optional doc comment +} + +// Complete returns completion candidates for the given SQL at the cursor offset. +// cat may be nil if no catalog context is available. +func Complete(sql string, cursorOffset int, cat *catalog.Catalog) []Candidate { + prefix := extractPrefix(sql, cursorOffset) + + // When the cursor is mid-token, back up to the start of the token + // so the parser sees the position before the partial text. + collectOffset := cursorOffset - len(prefix) + + result := standardComplete(sql, collectOffset, cat) + if len(result) == 0 { + result = trickyComplete(sql, collectOffset, cat) + } + + return filterByPrefix(result, prefix) +} + +// extractPrefix returns the partial token the user is typing at cursorOffset. +func extractPrefix(sql string, cursorOffset int) string { + if cursorOffset > len(sql) { + cursorOffset = len(sql) + } + i := cursorOffset + for i > 0 { + c := sql[i-1] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_' { + i-- + } else { + break + } + } + return sql[i:cursorOffset] +} + +// filterByPrefix filters candidates whose text starts with prefix (case-insensitive). +func filterByPrefix(candidates []Candidate, prefix string) []Candidate { + if prefix == "" { + return candidates + } + upper := strings.ToUpper(prefix) + var result []Candidate + for _, c := range candidates { + if strings.HasPrefix(strings.ToUpper(c.Text), upper) { + result = append(result, c) + } + } + return result +} + +// standardComplete collects parser-level candidates using Collect, then +// resolves them against the catalog. +func standardComplete(sql string, cursorOffset int, cat *catalog.Catalog) []Candidate { + cs := parser.Collect(sql, cursorOffset) + return resolve(cs, cat, sql, cursorOffset) +} + +// trickyComplete handles edge cases that the standard C3 approach cannot +// resolve (e.g., partially typed identifiers in ambiguous positions). +func trickyComplete(sql string, cursorOffset int, cat *catalog.Catalog) []Candidate { + if cursorOffset > len(sql) { + cursorOffset = len(sql) + } + prefix := sql[:cursorOffset] + suffix := "" + if cursorOffset < len(sql) { + suffix = sql[cursorOffset:] + } + + strategies := []string{ + prefix + " __placeholder__" + suffix, + prefix + " __placeholder__ " + suffix, + prefix + " 1" + suffix, + } + + for _, patched := range strategies { + cs := parser.Collect(patched, cursorOffset) + if cs != nil && (len(cs.Tokens) > 0 || len(cs.Rules) > 0) { + return resolve(cs, cat, sql, cursorOffset) + } + } + return nil +} + +// resolve converts parser CandidateSet into typed Candidate values. +// Token candidates become keywords; rule candidates are resolved against the catalog. +func resolve(cs *parser.CandidateSet, cat *catalog.Catalog, sql string, cursorOffset int) []Candidate { + if cs == nil { + return nil + } + var result []Candidate + + // Token candidates -> keywords + for _, tok := range cs.Tokens { + name := parser.TokenName(tok) + if name == "" { + continue + } + result = append(result, Candidate{Text: name, Type: CandidateKeyword}) + } + + // Rule candidates -> catalog objects + result = append(result, resolveRules(cs, cat, sql, cursorOffset)...) + + return dedup(result) +} + +// dedup removes duplicate candidates (same text+type, case-insensitive). +func dedup(cs []Candidate) []Candidate { + type key struct { + text string + typ CandidateType + } + seen := make(map[key]bool, len(cs)) + result := make([]Candidate, 0, len(cs)) + for _, c := range cs { + k := key{strings.ToLower(c.Text), c.Type} + if seen[k] { + continue + } + seen[k] = true + result = append(result, c) + } + return result +} diff --git a/tidb/completion/completion_test.go b/tidb/completion/completion_test.go new file mode 100644 index 00000000..342b2bce --- /dev/null +++ b/tidb/completion/completion_test.go @@ -0,0 +1,2835 @@ +package completion + +import ( + "testing" + + "github.com/bytebase/omni/tidb/catalog" +) + +// containsCandidate returns true if candidates contains one with the given text and type. +func containsCandidate(candidates []Candidate, text string, typ CandidateType) bool { + for _, c := range candidates { + if c.Text == text && c.Type == typ { + return true + } + } + return false +} + +// containsText returns true if any candidate has the given text. +func containsText(candidates []Candidate, text string) bool { + for _, c := range candidates { + if c.Text == text { + return true + } + } + return false +} + +// hasDuplicates returns true if there are duplicate (text, type) pairs (case-insensitive). +func hasDuplicates(candidates []Candidate) bool { + type key struct { + text string + typ CandidateType + } + seen := make(map[key]bool) + for _, c := range candidates { + k := key{text: c.Text, typ: c.Type} + if seen[k] { + return true + } + seen[k] = true + } + return false +} + +func TestComplete_2_1_CompleteReturnsSlice(t *testing.T) { + // Scenario: Complete(sql, cursorOffset, catalog) returns []Candidate + cat := catalog.New() + candidates := Complete("SELECT ", 7, cat) + if candidates == nil { + // nil is acceptable (no candidates), but the function should not panic + candidates = []Candidate{} + } + // Just verify it returns a slice (type is enforced by compiler). + _ = candidates +} + +func TestComplete_2_1_CandidateFields(t *testing.T) { + // Scenario: Candidate struct has Text, Type, Definition, Comment fields + c := Candidate{ + Text: "SELECT", + Type: CandidateKeyword, + Definition: "SQL SELECT statement", + Comment: "Retrieves data", + } + if c.Text != "SELECT" { + t.Errorf("Text = %q, want SELECT", c.Text) + } + if c.Type != CandidateKeyword { + t.Errorf("Type = %d, want CandidateKeyword", c.Type) + } + if c.Definition != "SQL SELECT statement" { + t.Errorf("Definition = %q", c.Definition) + } + if c.Comment != "Retrieves data" { + t.Errorf("Comment = %q", c.Comment) + } +} + +func TestComplete_2_1_CandidateTypeEnum(t *testing.T) { + // Scenario: CandidateType enum with all types + types := []CandidateType{ + CandidateKeyword, + CandidateDatabase, + CandidateTable, + CandidateView, + CandidateColumn, + CandidateFunction, + CandidateProcedure, + CandidateIndex, + CandidateTrigger, + CandidateEvent, + CandidateVariable, + CandidateCharset, + CandidateEngine, + CandidateType_, + } + // All types should be distinct. + seen := make(map[CandidateType]bool) + for _, ct := range types { + if seen[ct] { + t.Errorf("duplicate CandidateType value %d", ct) + } + seen[ct] = true + } + if len(types) != 14 { + t.Errorf("expected 14 CandidateType values, got %d", len(types)) + } +} + +func TestComplete_2_1_NilCatalog(t *testing.T) { + // Scenario: Complete with nil catalog returns keyword-only candidates + // (plus built-in function names, which are always available regardless of catalog). + candidates := Complete("SELECT ", 7, nil) + for _, c := range candidates { + if c.Type != CandidateKeyword && c.Type != CandidateFunction { + t.Errorf("with nil catalog, got unexpected candidate type: %+v", c) + } + } + // Should still return some keywords (e.g., DISTINCT, ALL from SELECT context). + if len(candidates) == 0 { + t.Error("expected some keyword candidates with nil catalog") + } + // No catalog-dependent types should appear. + for _, c := range candidates { + switch c.Type { + case CandidateTable, CandidateView, CandidateColumn, CandidateDatabase, + CandidateProcedure, CandidateIndex, CandidateTrigger, CandidateEvent: + t.Errorf("with nil catalog, got catalog-dependent candidate: %+v", c) + } + } +} + +func TestComplete_2_1_EmptySQL(t *testing.T) { + // Scenario: Complete with empty sql returns top-level statement keywords + candidates := Complete("", 0, nil) + if len(candidates) == 0 { + t.Fatal("expected top-level keywords for empty SQL") + } + // Should contain core statement keywords. + for _, kw := range []string{"SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing expected keyword %s", kw) + } + } + // All should be keywords. + for _, c := range candidates { + if c.Type != CandidateKeyword { + t.Errorf("non-keyword candidate in empty SQL: %+v", c) + } + } +} + +func TestComplete_2_1_PrefixFiltering(t *testing.T) { + // Scenario: Prefix filtering: `SEL|` matches SELECT keyword + candidates := Complete("SEL", 3, nil) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Error("expected SELECT in candidates for prefix SEL") + } + // Should not contain non-matching keywords. + if containsCandidate(candidates, "INSERT", CandidateKeyword) { + t.Error("INSERT should not match prefix SEL") + } +} + +func TestComplete_2_1_PrefixCaseInsensitive(t *testing.T) { + // Scenario: Prefix filtering is case-insensitive + candidates := Complete("sel", 3, nil) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Error("expected SELECT in candidates for lowercase prefix sel") + } + // Mixed case + candidates2 := Complete("Sel", 3, nil) + if !containsCandidate(candidates2, "SELECT", CandidateKeyword) { + t.Error("expected SELECT in candidates for mixed-case prefix Sel") + } +} + +func TestComplete_2_1_Deduplication(t *testing.T) { + // Scenario: Deduplication: same candidate not returned twice + // Use a context that might produce duplicate token candidates. + candidates := Complete("", 0, nil) + if hasDuplicates(candidates) { + t.Error("found duplicate candidates in results") + } + + // Also test with a prefix context. + candidates2 := Complete("SELECT ", 7, nil) + if hasDuplicates(candidates2) { + t.Error("found duplicate candidates in SELECT context") + } +} + +// --- Section 2.2: Candidate Resolution --- + +// setupCatalog creates a catalog with a test database for resolution tests. +func setupCatalog(t *testing.T) *catalog.Catalog { + t.Helper() + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE users (id INT, name VARCHAR(100), email VARCHAR(200))") + mustExec(t, cat, "CREATE TABLE orders (id INT, user_id INT, total DECIMAL(10,2))") + mustExec(t, cat, "CREATE INDEX idx_name ON users (name)") + mustExec(t, cat, "CREATE INDEX idx_user_id ON orders (user_id)") + mustExec(t, cat, "CREATE VIEW active_users AS SELECT * FROM users WHERE id > 0") + mustExec(t, cat, "CREATE FUNCTION my_func() RETURNS INT DETERMINISTIC RETURN 1") + mustExec(t, cat, "CREATE PROCEDURE my_proc() BEGIN SELECT 1; END") + mustExec(t, cat, "CREATE TRIGGER my_trig BEFORE INSERT ON users FOR EACH ROW SET NEW.name = UPPER(NEW.name)") + // Event creation requires schedule — use Exec directly. + mustExec(t, cat, "CREATE EVENT my_event ON SCHEDULE EVERY 1 HOUR DO SELECT 1") + return cat +} + +// mustExec executes SQL on the catalog, failing the test on error. +func mustExec(t *testing.T, cat *catalog.Catalog, sql string) { + t.Helper() + if _, err := cat.Exec(sql, nil); err != nil { + t.Fatalf("Exec(%q) failed: %v", sql, err) + } +} + +func TestResolve_2_2_TokenCandidatesKeywords(t *testing.T) { + // Scenario: Token candidates -> keyword strings (from token type mapping) + // Tested via Complete — empty SQL yields token-only candidates resolved as keywords. + candidates := Complete("", 0, nil) + if len(candidates) == 0 { + t.Fatal("expected keyword candidates") + } + for _, c := range candidates { + if c.Type != CandidateKeyword { + t.Errorf("expected keyword type, got %d for %q", c.Type, c.Text) + } + } +} + +func TestResolve_2_2_TableRef(t *testing.T) { + // Scenario: "table_ref" rule -> catalog tables + views + cat := setupCatalog(t) + candidates := resolveRule("table_ref", cat, "", 0) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Error("missing table 'users'") + } + if !containsCandidate(candidates, "orders", CandidateTable) { + t.Error("missing table 'orders'") + } + if !containsCandidate(candidates, "active_users", CandidateView) { + t.Error("missing view 'active_users'") + } +} + +func TestResolve_2_2_ColumnRef(t *testing.T) { + // Scenario: "columnref" rule -> columns from tables in scope + // For now, returns all columns from all tables in current database. + cat := setupCatalog(t) + candidates := resolveRule("columnref", cat, "", 0) + // users: id, name, email + if !containsCandidate(candidates, "id", CandidateColumn) { + t.Error("missing column 'id'") + } + if !containsCandidate(candidates, "name", CandidateColumn) { + t.Error("missing column 'name'") + } + if !containsCandidate(candidates, "email", CandidateColumn) { + t.Error("missing column 'email'") + } + // orders: user_id, total (id is deduped) + if !containsCandidate(candidates, "user_id", CandidateColumn) { + t.Error("missing column 'user_id'") + } + if !containsCandidate(candidates, "total", CandidateColumn) { + t.Error("missing column 'total'") + } +} + +func TestResolve_2_2_DatabaseRef(t *testing.T) { + // Scenario: "database_ref" rule -> catalog databases + cat := setupCatalog(t) + // Add another database. + mustExec(t, cat, "CREATE DATABASE otherdb") + candidates := resolveRule("database_ref", cat, "", 0) + if !containsCandidate(candidates, "testdb", CandidateDatabase) { + t.Error("missing database 'testdb'") + } + if !containsCandidate(candidates, "otherdb", CandidateDatabase) { + t.Error("missing database 'otherdb'") + } +} + +func TestResolve_2_2_FunctionRef(t *testing.T) { + // Scenario: "function_ref" / "func_name" rule -> catalog functions + built-in names + cat := setupCatalog(t) + for _, rule := range []string{"function_ref", "func_name"} { + candidates := resolveRule(rule, cat, "", 0) + // Should include built-in functions. + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("[%s] missing built-in function COUNT", rule) + } + if !containsCandidate(candidates, "CONCAT", CandidateFunction) { + t.Errorf("[%s] missing built-in function CONCAT", rule) + } + if !containsCandidate(candidates, "NOW", CandidateFunction) { + t.Errorf("[%s] missing built-in function NOW", rule) + } + // Should include catalog function. + if !containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("[%s] missing catalog function 'my_func'", rule) + } + } +} + +func TestResolve_2_2_ProcedureRef(t *testing.T) { + // Scenario: "procedure_ref" rule -> catalog procedures + cat := setupCatalog(t) + candidates := resolveRule("procedure_ref", cat, "", 0) + if !containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Error("missing procedure 'my_proc'") + } +} + +func TestResolve_2_2_IndexRef(t *testing.T) { + // Scenario: "index_ref" rule -> indexes from relevant table + cat := setupCatalog(t) + candidates := resolveRule("index_ref", cat, "", 0) + if !containsCandidate(candidates, "idx_name", CandidateIndex) { + t.Error("missing index 'idx_name'") + } + if !containsCandidate(candidates, "idx_user_id", CandidateIndex) { + t.Error("missing index 'idx_user_id'") + } +} + +func TestResolve_2_2_TriggerRef(t *testing.T) { + // Scenario: "trigger_ref" rule -> catalog triggers + cat := setupCatalog(t) + candidates := resolveRule("trigger_ref", cat, "", 0) + if !containsCandidate(candidates, "my_trig", CandidateTrigger) { + t.Error("missing trigger 'my_trig'") + } +} + +func TestResolve_2_2_EventRef(t *testing.T) { + // Scenario: "event_ref" rule -> catalog events + cat := setupCatalog(t) + candidates := resolveRule("event_ref", cat, "", 0) + if !containsCandidate(candidates, "my_event", CandidateEvent) { + t.Error("missing event 'my_event'") + } +} + +func TestResolve_2_2_ViewRef(t *testing.T) { + // Scenario: "view_ref" rule -> catalog views + cat := setupCatalog(t) + candidates := resolveRule("view_ref", cat, "", 0) + if !containsCandidate(candidates, "active_users", CandidateView) { + t.Error("missing view 'active_users'") + } +} + +func TestResolve_2_2_Charset(t *testing.T) { + // Scenario: "charset" rule -> known charset names + candidates := resolveRule("charset", nil, "", 0) + for _, cs := range []string{"utf8mb4", "latin1", "utf8", "ascii", "binary"} { + if !containsCandidate(candidates, cs, CandidateCharset) { + t.Errorf("missing charset %q", cs) + } + } +} + +func TestResolve_2_2_Engine(t *testing.T) { + // Scenario: "engine" rule -> known engine names + candidates := resolveRule("engine", nil, "", 0) + for _, eng := range []string{"InnoDB", "MyISAM", "MEMORY", "CSV", "ARCHIVE"} { + if !containsCandidate(candidates, eng, CandidateEngine) { + t.Errorf("missing engine %q", eng) + } + } +} + +func TestResolve_2_2_TypeName(t *testing.T) { + // Scenario: "type_name" rule -> MySQL type keywords + candidates := resolveRule("type_name", nil, "", 0) + for _, typ := range []string{"INT", "VARCHAR", "TEXT", "BLOB", "DATE", "DATETIME", "DECIMAL", "JSON", "ENUM"} { + if !containsCandidate(candidates, typ, CandidateType_) { + t.Errorf("missing type %q", typ) + } + } +} + +// --- Section 2.4: Tricky Completion (Fallback) --- + +func TestComplete_2_4_IncompleteTrailingSpace(t *testing.T) { + // Scenario: Incomplete SQL with trailing space → insert placeholder, re-collect. + // The trickyComplete function patches SQL with placeholder tokens to make it + // parseable, then re-runs Collect. When standard Collect returns nothing, + // trickyComplete should return whatever the patched version produces. + // + // Use a context where standard returns empty but placeholder strategy succeeds: + // `SELECT ` at offset 7 gets standard candidates via the SELECT expr + // instrumentation. So instead we test that trickyComplete is called when + // standardComplete returns empty results. + cat := setupCatalog(t) + + // Test that trailing space after FROM gets candidates via tricky path. + // The numeric placeholder "1" makes "SELECT * FROM 1" parseable, yielding + // keyword tokens for the follow set (WHERE, JOIN, etc.). + candidates := Complete("SELECT * FROM ", 14, cat) + if len(candidates) == 0 { + t.Skip("FROM clause not yet instrumented (Phase 3); tricky mechanism works but parser lacks rule candidates here") + } +} + +func TestComplete_2_4_TruncatedMidKeyword(t *testing.T) { + // Scenario: Truncated mid-keyword: `SELE` → prefix-filter against keywords. + // The prefix "SELE" is extracted, Collect runs at offset 0 (start of statement), + // producing top-level keywords, then filterByPrefix keeps only SELECT. + candidates := Complete("SELE", 4, nil) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Error("expected SELECT keyword from prefix filter for 'SELE'") + } + if containsCandidate(candidates, "INSERT", CandidateKeyword) { + t.Error("INSERT should not match prefix SELE") + } +} + +func TestComplete_2_4_TruncatedAfterComma(t *testing.T) { + // Scenario: Truncated after comma: `SELECT a,` → insert placeholder column. + // After the comma, standardComplete runs at offset 9. If it returns nothing, + // trickyComplete patches to "SELECT a, __placeholder__" or "SELECT a, 1" + // which may parse differently. + // + // With current parser instrumentation, the SELECT expr list checkpoint is at + // the start of parseSelectExprs. The comma case needs additional instrumentation + // (Phase 3, scenario 3.1). But we verify the tricky mechanism doesn't panic + // and returns whatever the parser can provide. + cat := setupCatalog(t) + candidates := Complete("SELECT a,", 9, cat) + // The mechanism must not panic; results depend on parser instrumentation. + _ = candidates +} + +func TestComplete_2_4_TruncatedAfterOperator(t *testing.T) { + // Scenario: Truncated after operator: `WHERE a >` → insert placeholder expression. + // trickyComplete patches to "... WHERE id > __placeholder__" or "... WHERE id > 1". + // The numeric placeholder "1" makes valid SQL, so Collect can run on it. + cat := setupCatalog(t) + candidates := Complete("SELECT * FROM users WHERE id >", 30, cat) + // Must not panic. Results depend on expression instrumentation (Phase 8). + _ = candidates +} + +func TestComplete_2_4_MultiplePlaceholderStrategies(t *testing.T) { + // Scenario: Multiple placeholder strategies tried in order. + // trickyComplete tries three strategies: + // 1. prefix + " __placeholder__" + suffix + // 2. prefix + " __placeholder__ " + suffix + // 3. prefix + " 1" + suffix + // We verify the function exists, tries them in order, and returns the first + // strategy that yields candidates. + + // Use trickyComplete directly to verify it returns results when a strategy works. + // "SELECT " at offset 7 — standard would return results, but we call tricky + // directly to verify the placeholder mechanism. + candidates := trickyComplete("", 0, nil) + // For empty SQL, even the patched versions should produce keyword candidates + // because " __placeholder__" at offset 0 triggers statement-start keywords. + if len(candidates) == 0 { + t.Error("expected trickyComplete to produce candidates for empty SQL via placeholder strategy") + } + + // Verify keywords are present (the placeholder text itself should not appear). + hasKeyword := false + for _, c := range candidates { + if c.Type == CandidateKeyword { + hasKeyword = true + break + } + } + if !hasKeyword { + t.Error("expected keyword candidates from placeholder strategy") + } +} + +func TestComplete_2_4_FallbackBestEffort(t *testing.T) { + // Scenario: Fallback returns best-effort results when no strategy succeeds. + // Completely nonsensical SQL should not panic. trickyComplete returns nil + // when no strategy produces candidates. + candidates := Complete("XYZZY PLUGH ", 12, nil) + // Must not panic. Result may be empty or nil. + _ = candidates + + // Also test with more realistic but still broken SQL. + candidates2 := Complete(")))((( ", 7, nil) + _ = candidates2 + + // Verify trickyComplete returns nil for truly unparseable input. + tricky := trickyComplete("XYZZY PLUGH ", 12, nil) + // nil is acceptable — it means no strategy succeeded. + _ = tricky +} + +func TestComplete_2_4_PlaceholderNoCorruption(t *testing.T) { + // Scenario: Placeholder insertion does not corrupt the initial candidate set. + // Running Complete multiple times on the same input must produce consistent results. + // The placeholder text (__placeholder__) must never leak into returned candidates. + cat := setupCatalog(t) + + // Use a SQL that produces candidates via standard path. + validSQL := "SELECT " + validCandidates := Complete(validSQL, len(validSQL), cat) + validCandidates2 := Complete(validSQL, len(validSQL), cat) + + // Both runs should return the same number of candidates. + if len(validCandidates) != len(validCandidates2) { + t.Errorf("candidate count mismatch: first=%d, second=%d", len(validCandidates), len(validCandidates2)) + } + + // Placeholder text must not leak into any candidate set. + for _, c := range validCandidates { + if c.Text == "__placeholder__" { + t.Error("placeholder text leaked into candidate set") + } + } + for _, c := range validCandidates2 { + if c.Text == "__placeholder__" { + t.Error("placeholder text leaked into candidate set on second run") + } + } + + // Also test via trickyComplete directly: placeholder must not appear in results. + trickyCandidates := trickyComplete("SELECT * FROM ", 14, cat) + for _, c := range trickyCandidates { + if c.Text == "__placeholder__" { + t.Error("placeholder text leaked into tricky candidate set") + } + } +} + +func TestResolve_2_2_NilCatalogSafety(t *testing.T) { + // All catalog-dependent rules should handle nil catalog gracefully. + for _, rule := range []string{"table_ref", "columnref", "database_ref", "procedure_ref", "index_ref", "trigger_ref", "event_ref", "view_ref"} { + candidates := resolveRule(rule, nil, "", 0) + if candidates != nil && len(candidates) > 0 { + t.Errorf("[%s] expected no candidates with nil catalog, got %d", rule, len(candidates)) + } + } + // function_ref/func_name still return built-ins with nil catalog. + candidates := resolveRule("func_name", nil, "", 0) + if len(candidates) == 0 { + t.Error("func_name should return built-in functions even with nil catalog") + } +} + +// --- Section 3.1: SELECT Target List --- + +func TestComplete_3_1_SelectTargetList(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (x INT, y INT)") + + cases := []struct { + name string + sql string + cursor int + wantCol string // expected column candidate (or "" to skip) + wantFunc bool // expect function candidates + wantKW string // expected keyword candidate (or "" to skip) + absentType CandidateType + absentText string + }{ + { + name: "select_pipe_columnref", + sql: "SELECT ", + cursor: 7, + wantCol: "a", + wantFunc: true, + wantKW: "DISTINCT", + }, + { + name: "select_after_comma", + sql: "SELECT a, ", + cursor: 10, + wantCol: "a", + wantFunc: true, + }, + { + name: "select_after_two_commas", + sql: "SELECT a, b, ", + cursor: 13, + wantCol: "c", + wantFunc: true, + }, + { + name: "select_subquery", + sql: "SELECT * FROM t WHERE a > (SELECT ", + cursor: 34, + wantCol: "a", + wantFunc: true, + }, + { + name: "select_distinct_pipe", + sql: "SELECT DISTINCT ", + cursor: 16, + wantCol: "a", + wantFunc: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + if len(candidates) == 0 { + t.Fatal("expected candidates, got none") + } + + if tc.wantCol != "" { + if !containsCandidate(candidates, tc.wantCol, CandidateColumn) { + t.Errorf("missing column candidate %q; got %v", tc.wantCol, candidates) + } + } + + if tc.wantFunc { + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function candidate COUNT; got %v", candidates) + } + } + + if tc.wantKW != "" { + if !containsCandidate(candidates, tc.wantKW, CandidateKeyword) { + t.Errorf("missing keyword candidate %q; got %v", tc.wantKW, candidates) + } + } + + if tc.absentText != "" { + if containsCandidate(candidates, tc.absentText, tc.absentType) { + t.Errorf("unexpected candidate %q of type %d", tc.absentText, tc.absentType) + } + } + }) + } +} + +// --- Section 3.2: FROM Clause --- + +func TestComplete_3_2_FromClause(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + mustExec(t, cat, "CREATE DATABASE testdb2") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (x INT, y INT)") + mustExec(t, cat, "CREATE TABLE t2 (p INT, q INT)") + mustExec(t, cat, "CREATE VIEW v1 AS SELECT * FROM t") + + cases := []struct { + name string + sql string + cursor int + wantType CandidateType + wantText string // expected candidate text (or "" to skip check) + wantAbsent string // text that should NOT appear (or "" to skip) + absentType CandidateType + }{ + { + // Scenario 1: SELECT * FROM | → table_ref (tables, views, databases) + name: "from_table_ref", + sql: "SELECT * FROM ", + cursor: 14, + wantType: CandidateTable, + wantText: "t", + }, + { + // Scenario 1 continued: views should also appear + name: "from_view_ref", + sql: "SELECT * FROM ", + cursor: 14, + wantType: CandidateView, + wantText: "v1", + }, + { + // Scenario 2: SELECT * FROM db.| → table_ref qualified with database + name: "from_qualified_table_ref", + sql: "SELECT * FROM test.", + cursor: 19, + wantType: CandidateTable, + wantText: "t", + }, + { + // Scenario 3: SELECT * FROM t1, | → table_ref after comma + name: "from_comma_table_ref", + sql: "SELECT * FROM t1, ", + cursor: 18, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 4: SELECT * FROM (SELECT * FROM |) → table_ref in derived table + name: "from_derived_table_ref", + sql: "SELECT * FROM (SELECT * FROM ", + cursor: 29, + wantType: CandidateTable, + wantText: "t", + }, + { + // Scenario 5: SELECT * FROM t | → keyword candidates + name: "from_after_table_keywords", + sql: "SELECT * FROM t ", + cursor: 16, + wantType: CandidateKeyword, + wantText: "WHERE", + }, + { + // Scenario 5 continued: JOIN keyword should appear + name: "from_after_table_join", + sql: "SELECT * FROM t ", + cursor: 16, + wantType: CandidateKeyword, + wantText: "JOIN", + }, + { + // Scenario 5 continued: LEFT keyword should appear + name: "from_after_table_left", + sql: "SELECT * FROM t ", + cursor: 16, + wantType: CandidateKeyword, + wantText: "LEFT", + }, + { + // Scenario 5 continued: RIGHT keyword should appear + name: "from_after_table_right", + sql: "SELECT * FROM t ", + cursor: 16, + wantType: CandidateKeyword, + wantText: "RIGHT", + }, + { + // Scenario 6: SELECT * FROM t AS | → no specific candidates (alias context) + // Tables/views should NOT appear since we're in an alias context. + name: "from_alias_no_table", + sql: "SELECT * FROM t AS ", + cursor: 19, + wantType: CandidateTable, + wantText: "", + wantAbsent: "t", + absentType: CandidateTable, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + + if tc.wantText != "" { + if !containsCandidate(candidates, tc.wantText, tc.wantType) { + t.Errorf("missing candidate %q of type %d; got %v", tc.wantText, tc.wantType, candidates) + } + } + + if tc.wantAbsent != "" { + if containsCandidate(candidates, tc.wantAbsent, tc.absentType) { + t.Errorf("unexpected candidate %q of type %d should not appear", tc.wantAbsent, tc.absentType) + } + } + }) + } +} + +// --- Section 3.3: JOIN Clauses --- + +func TestComplete_3_3_JoinClauses(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE TABLE t2 (a INT, c INT)") + mustExec(t, cat, "CREATE TABLE t3 (x INT, y INT)") + + cases := []struct { + name string + sql string + cursor int + wantType CandidateType + wantText string + wantAbsent string + absentType CandidateType + }{ + { + // Scenario 1: SELECT * FROM t1 JOIN | → table_ref after JOIN + name: "join_table_ref", + sql: "SELECT * FROM t1 JOIN ", + cursor: 22, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 2: SELECT * FROM t1 LEFT JOIN | → table_ref after LEFT JOIN + name: "left_join_table_ref", + sql: "SELECT * FROM t1 LEFT JOIN ", + cursor: 27, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 3: SELECT * FROM t1 RIGHT JOIN | → table_ref after RIGHT JOIN + name: "right_join_table_ref", + sql: "SELECT * FROM t1 RIGHT JOIN ", + cursor: 28, + wantType: CandidateTable, + wantText: "t3", + }, + { + // Scenario 4: SELECT * FROM t1 CROSS JOIN | → table_ref after CROSS JOIN + name: "cross_join_table_ref", + sql: "SELECT * FROM t1 CROSS JOIN ", + cursor: 28, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 5: SELECT * FROM t1 NATURAL JOIN | → table_ref after NATURAL JOIN + name: "natural_join_table_ref", + sql: "SELECT * FROM t1 NATURAL JOIN ", + cursor: 30, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 6: SELECT * FROM t1 STRAIGHT_JOIN | → table_ref after STRAIGHT_JOIN + name: "straight_join_table_ref", + sql: "SELECT * FROM t1 STRAIGHT_JOIN ", + cursor: 31, + wantType: CandidateTable, + wantText: "t2", + }, + { + // Scenario 7: SELECT * FROM t1 JOIN t2 ON | → columnref after ON + name: "join_on_columnref", + sql: "SELECT * FROM t1 JOIN t2 ON ", + cursor: 28, + wantType: CandidateColumn, + wantText: "a", + }, + { + // Scenario 8: SELECT * FROM t1 JOIN t2 USING (| → columnref after USING ( + name: "join_using_columnref", + sql: "SELECT * FROM t1 JOIN t2 USING (", + cursor: 32, + wantType: CandidateColumn, + wantText: "a", + }, + { + // Scenario 9: SELECT * FROM t1 | → JOIN keywords + name: "after_table_join_keywords", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "JOIN", + }, + { + // Scenario 9 continued: LEFT keyword + name: "after_table_left_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "LEFT", + }, + { + // Scenario 9 continued: RIGHT keyword + name: "after_table_right_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "RIGHT", + }, + { + // Scenario 9 continued: INNER keyword + name: "after_table_inner_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "INNER", + }, + { + // Scenario 9 continued: CROSS keyword + name: "after_table_cross_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "CROSS", + }, + { + // Scenario 9 continued: NATURAL keyword + name: "after_table_natural_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "NATURAL", + }, + { + // Scenario 9 continued: STRAIGHT_JOIN keyword + name: "after_table_straight_join_keyword", + sql: "SELECT * FROM t1 ", + cursor: 17, + wantType: CandidateKeyword, + wantText: "STRAIGHT_JOIN", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + if len(candidates) == 0 { + t.Fatalf("expected candidates, got none") + } + + if tc.wantText != "" { + if !containsCandidate(candidates, tc.wantText, tc.wantType) { + t.Errorf("missing candidate %q of type %d; got %v", tc.wantText, tc.wantType, candidates) + } + } + + if tc.wantAbsent != "" { + if containsCandidate(candidates, tc.wantAbsent, tc.absentType) { + t.Errorf("unexpected candidate %q of type %d should not appear", tc.wantAbsent, tc.absentType) + } + } + }) + } +} + +// --- Section 3.4: WHERE, GROUP BY, HAVING --- + +func TestComplete_3_4_WhereGroupByHaving(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (x INT, y INT)") + + cases := []struct { + name string + sql string + cursor int + wantType CandidateType + wantText string + wantAbsent string + absentType CandidateType + }{ + { + // Scenario 1: SELECT * FROM t WHERE | → columnref after WHERE + name: "where_columnref", + sql: "SELECT * FROM t WHERE ", + cursor: 22, + wantType: CandidateColumn, + wantText: "a", + }, + { + // Scenario 2: SELECT * FROM t WHERE a = 1 AND | → columnref after AND + name: "where_and_columnref", + sql: "SELECT * FROM t WHERE a = 1 AND ", + cursor: 32, + wantType: CandidateColumn, + wantText: "b", + }, + { + // Scenario 3: SELECT * FROM t WHERE a = 1 OR | → columnref after OR + name: "where_or_columnref", + sql: "SELECT * FROM t WHERE a = 1 OR ", + cursor: 31, + wantType: CandidateColumn, + wantText: "c", + }, + { + // Scenario 4: SELECT * FROM t GROUP BY | → columnref after GROUP BY + name: "group_by_columnref", + sql: "SELECT * FROM t GROUP BY ", + cursor: 25, + wantType: CandidateColumn, + wantText: "a", + }, + { + // Scenario 5: SELECT * FROM t GROUP BY a, | → columnref after comma + name: "group_by_comma_columnref", + sql: "SELECT * FROM t GROUP BY a, ", + cursor: 28, + wantType: CandidateColumn, + wantText: "b", + }, + { + // Scenario 6: SELECT * FROM t GROUP BY a | → keyword candidates + name: "group_by_follow_having", + sql: "SELECT * FROM t GROUP BY a ", + cursor: 27, + wantType: CandidateKeyword, + wantText: "HAVING", + }, + { + // Scenario 7: SELECT * FROM t HAVING | → columnref after HAVING + name: "having_columnref", + sql: "SELECT * FROM t HAVING ", + cursor: 23, + wantType: CandidateColumn, + wantText: "a", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + if len(candidates) == 0 { + t.Fatalf("expected candidates, got none") + } + + if tc.wantText != "" { + if !containsCandidate(candidates, tc.wantText, tc.wantType) { + t.Errorf("missing candidate %q of type %d; got %v", tc.wantText, tc.wantType, candidates) + } + } + + if tc.wantAbsent != "" { + if containsCandidate(candidates, tc.wantAbsent, tc.absentType) { + t.Errorf("unexpected candidate %q of type %d should not appear", tc.wantAbsent, tc.absentType) + } + } + }) + } + + // Additional checks for GROUP BY follow-set keywords (ORDER, LIMIT, WITH). + t.Run("group_by_follow_order", func(t *testing.T) { + candidates := Complete("SELECT * FROM t GROUP BY a ", 27, cat) + if !containsCandidate(candidates, "ORDER", CandidateKeyword) { + t.Errorf("missing keyword ORDER after GROUP BY list; got %v", candidates) + } + }) + t.Run("group_by_follow_limit", func(t *testing.T) { + candidates := Complete("SELECT * FROM t GROUP BY a ", 27, cat) + if !containsCandidate(candidates, "LIMIT", CandidateKeyword) { + t.Errorf("missing keyword LIMIT after GROUP BY list; got %v", candidates) + } + }) + t.Run("group_by_follow_with", func(t *testing.T) { + candidates := Complete("SELECT * FROM t GROUP BY a ", 27, cat) + if !containsCandidate(candidates, "WITH", CandidateKeyword) { + t.Errorf("missing keyword WITH (for WITH ROLLUP) after GROUP BY list; got %v", candidates) + } + }) +} + +// --- Section 3.5: ORDER BY, LIMIT, DISTINCT --- + +func TestComplete_3_5_OrderByLimitDistinct(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + + cases := []struct { + name string + sql string + cursor int + wantType CandidateType + wantText string + }{ + { + // Scenario 1: SELECT * FROM t ORDER BY | → columnref after ORDER BY + name: "order_by_columnref", + sql: "SELECT * FROM t ORDER BY ", + cursor: 25, + wantType: CandidateColumn, + wantText: "a", + }, + { + // Scenario 2: SELECT * FROM t ORDER BY a, | → columnref after comma + name: "order_by_comma_columnref", + sql: "SELECT * FROM t ORDER BY a, ", + cursor: 28, + wantType: CandidateColumn, + wantText: "b", + }, + { + // Scenario 3: SELECT * FROM t ORDER BY a | → keyword candidates (ASC, DESC, LIMIT) + name: "order_by_follow_asc", + sql: "SELECT * FROM t ORDER BY a ", + cursor: 27, + wantType: CandidateKeyword, + wantText: "ASC", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + if len(candidates) == 0 { + t.Fatalf("expected candidates, got none") + } + if !containsCandidate(candidates, tc.wantText, tc.wantType) { + t.Errorf("missing candidate %q of type %d; got %v", tc.wantText, tc.wantType, candidates) + } + }) + } + + // Additional checks for ORDER BY follow-set keywords. + t.Run("order_by_follow_desc", func(t *testing.T) { + candidates := Complete("SELECT * FROM t ORDER BY a ", 27, cat) + if !containsCandidate(candidates, "DESC", CandidateKeyword) { + t.Errorf("missing keyword DESC after ORDER BY item; got %v", candidates) + } + }) + t.Run("order_by_follow_limit", func(t *testing.T) { + candidates := Complete("SELECT * FROM t ORDER BY a ", 27, cat) + if !containsCandidate(candidates, "LIMIT", CandidateKeyword) { + t.Errorf("missing keyword LIMIT after ORDER BY item; got %v", candidates) + } + }) + + // Scenario 4: SELECT * FROM t LIMIT | → no specific candidates (numeric context) + // The expression parser will offer columnref/func_name which is acceptable for LIMIT expressions. + t.Run("limit_numeric_context", func(t *testing.T) { + candidates := Complete("SELECT * FROM t LIMIT ", 22, cat) + // Should have some candidates (expression context), not panic. + _ = candidates + }) + + // Scenario 5: SELECT * FROM t LIMIT 10 OFFSET | → no specific candidates + t.Run("limit_offset_numeric_context", func(t *testing.T) { + candidates := Complete("SELECT * FROM t LIMIT 10 OFFSET ", 32, cat) + _ = candidates + }) +} + +// --- Section 3.6: Set Operations & FOR UPDATE --- + +func TestComplete_3_6_SetOperationsForUpdate(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT)") + + cases := []struct { + name string + sql string + cursor int + wantType CandidateType + wantText string + }{ + { + // Scenario 1: SELECT a FROM t UNION | → keyword candidates (ALL, SELECT) + name: "union_all_or_select", + sql: "SELECT a FROM t UNION ", + cursor: 22, + wantType: CandidateKeyword, + wantText: "ALL", + }, + { + // Scenario 2: SELECT a FROM t UNION ALL | → keyword candidate (SELECT) + name: "union_all_select", + sql: "SELECT a FROM t UNION ALL ", + cursor: 26, + wantType: CandidateKeyword, + wantText: "SELECT", + }, + { + // Scenario 3: SELECT a FROM t INTERSECT | → keyword candidates (ALL, SELECT) + name: "intersect_all_or_select", + sql: "SELECT a FROM t INTERSECT ", + cursor: 26, + wantType: CandidateKeyword, + wantText: "ALL", + }, + { + // Scenario 4: SELECT a FROM t EXCEPT | → keyword candidates (ALL, SELECT) + name: "except_all_or_select", + sql: "SELECT a FROM t EXCEPT ", + cursor: 23, + wantType: CandidateKeyword, + wantText: "ALL", + }, + { + // Scenario 5: SELECT * FROM t FOR | → keyword candidates (UPDATE, SHARE) + name: "for_update_share", + sql: "SELECT * FROM t FOR ", + cursor: 20, + wantType: CandidateKeyword, + wantText: "UPDATE", + }, + { + // Scenario 6: SELECT * FROM t FOR UPDATE | → keyword candidates (OF, NOWAIT, SKIP) + name: "for_update_options", + sql: "SELECT * FROM t FOR UPDATE ", + cursor: 27, + wantType: CandidateKeyword, + wantText: "NOWAIT", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + candidates := Complete(tc.sql, tc.cursor, cat) + if len(candidates) == 0 { + t.Fatalf("expected candidates, got none") + } + if !containsCandidate(candidates, tc.wantText, tc.wantType) { + t.Errorf("missing candidate %q of type %d; got %v", tc.wantText, tc.wantType, candidates) + } + }) + } + + // Additional checks for set operation keywords. + t.Run("union_select_keyword", func(t *testing.T) { + candidates := Complete("SELECT a FROM t UNION ", 22, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword SELECT after UNION; got %v", candidates) + } + }) + t.Run("intersect_select_keyword", func(t *testing.T) { + candidates := Complete("SELECT a FROM t INTERSECT ", 26, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword SELECT after INTERSECT; got %v", candidates) + } + }) + t.Run("except_select_keyword", func(t *testing.T) { + candidates := Complete("SELECT a FROM t EXCEPT ", 23, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword SELECT after EXCEPT; got %v", candidates) + } + }) + t.Run("for_share_keyword", func(t *testing.T) { + candidates := Complete("SELECT * FROM t FOR ", 20, cat) + if !containsCandidate(candidates, "SHARE", CandidateKeyword) { + t.Errorf("missing keyword SHARE after FOR; got %v", candidates) + } + }) + t.Run("for_update_skip_keyword", func(t *testing.T) { + candidates := Complete("SELECT * FROM t FOR UPDATE ", 27, cat) + if !containsCandidate(candidates, "SKIP", CandidateKeyword) { + t.Errorf("missing keyword SKIP after FOR UPDATE; got %v", candidates) + } + }) +} + +// --- Section 3.7: CTE (WITH Clause) --- + +func TestComplete_3_7_CTE(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + + // Scenario 1: WITH | → keyword candidate (RECURSIVE) + identifier context + t.Run("with_recursive_keyword", func(t *testing.T) { + candidates := Complete("WITH ", 5, cat) + if !containsCandidate(candidates, "RECURSIVE", CandidateKeyword) { + t.Errorf("missing keyword RECURSIVE after WITH; got %v", candidates) + } + }) + + // Scenario 2: WITH cte AS (|) → keyword candidate (SELECT) + t.Run("with_cte_as_select", func(t *testing.T) { + candidates := Complete("WITH cte AS (", 13, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword SELECT inside CTE AS (); got %v", candidates) + } + }) + + // Scenario 3: WITH cte AS (SELECT * FROM t) SELECT | → columnref (CTE columns available) + t.Run("with_cte_select_columnref", func(t *testing.T) { + candidates := Complete("WITH cte AS (SELECT * FROM t) SELECT ", 37, cat) + if len(candidates) == 0 { + t.Fatal("expected candidates, got none") + } + // Should get column/function candidates in SELECT expression context. + hasColOrFunc := false + for _, c := range candidates { + if c.Type == CandidateColumn || c.Type == CandidateFunction { + hasColOrFunc = true + break + } + } + if !hasColOrFunc { + t.Errorf("expected column or function candidates after CTE SELECT; got %v", candidates) + } + }) + + // Scenario 4: WITH cte AS (SELECT * FROM t) SELECT * FROM | → table_ref (CTE name available) + t.Run("with_cte_from_table_ref", func(t *testing.T) { + candidates := Complete("WITH cte AS (SELECT * FROM t) SELECT * FROM ", 45, cat) + if len(candidates) == 0 { + t.Fatal("expected candidates, got none") + } + // Should get table_ref candidates. + hasTable := false + for _, c := range candidates { + if c.Type == CandidateTable || c.Type == CandidateView { + hasTable = true + break + } + } + if !hasTable { + t.Errorf("expected table candidates after CTE FROM; got %v", candidates) + } + }) + + // Scenario 5: WITH RECURSIVE cte(|) → identifier context for column names + // This is an identifier context — just verify no panic and something reasonable. + t.Run("with_recursive_cte_columns", func(t *testing.T) { + candidates := Complete("WITH RECURSIVE cte(", 19, cat) + // Must not panic. Results may be empty or have column candidates. + _ = candidates + }) +} + +// --- Section 3.8: Window Functions & Index Hints --- + +func TestComplete_3_8_WindowFunctionsIndexHints(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE INDEX idx_a ON t (a)") + mustExec(t, cat, "CREATE INDEX idx_b ON t (b)") + + // Scenario 1: SELECT a, ROW_NUMBER() OVER (|) → keyword candidates (PARTITION, ORDER) + t.Run("over_partition_order", func(t *testing.T) { + candidates := Complete("SELECT a, ROW_NUMBER() OVER (", 29, cat) + if !containsCandidate(candidates, "PARTITION", CandidateKeyword) { + t.Errorf("missing keyword PARTITION inside OVER (); got %v", candidates) + } + if !containsCandidate(candidates, "ORDER", CandidateKeyword) { + t.Errorf("missing keyword ORDER inside OVER (); got %v", candidates) + } + }) + + // Scenario 2: SELECT a, SUM(b) OVER (PARTITION BY |) → columnref + t.Run("over_partition_by_columnref", func(t *testing.T) { + candidates := Complete("SELECT a, SUM(b) OVER (PARTITION BY ", 36, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a' after PARTITION BY; got %v", candidates) + } + }) + + // Scenario 3: SELECT a, SUM(b) OVER (ORDER BY |) → columnref + t.Run("over_order_by_columnref", func(t *testing.T) { + candidates := Complete("SELECT a, SUM(b) OVER (ORDER BY ", 32, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a' after ORDER BY in window; got %v", candidates) + } + }) + + // Scenario 4: SELECT a, SUM(b) OVER (ORDER BY a ROWS |) → keyword candidates (BETWEEN, UNBOUNDED, CURRENT) + t.Run("over_rows_frame_keywords", func(t *testing.T) { + candidates := Complete("SELECT a, SUM(b) OVER (ORDER BY a ROWS ", 39, cat) + if !containsCandidate(candidates, "BETWEEN", CandidateKeyword) { + t.Errorf("missing keyword BETWEEN after ROWS; got %v", candidates) + } + if !containsCandidate(candidates, "UNBOUNDED", CandidateKeyword) { + t.Errorf("missing keyword UNBOUNDED after ROWS; got %v", candidates) + } + if !containsCandidate(candidates, "CURRENT", CandidateKeyword) { + t.Errorf("missing keyword CURRENT after ROWS; got %v", candidates) + } + }) + + // Scenario 5: SELECT * FROM t USE INDEX (|) → index_ref + t.Run("use_index_ref", func(t *testing.T) { + candidates := Complete("SELECT * FROM t USE INDEX (", 27, cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a' in USE INDEX (); got %v", candidates) + } + if !containsCandidate(candidates, "idx_b", CandidateIndex) { + t.Errorf("missing index 'idx_b' in USE INDEX (); got %v", candidates) + } + }) + + // Scenario 6: SELECT * FROM t FORCE INDEX (|) → index_ref + t.Run("force_index_ref", func(t *testing.T) { + candidates := Complete("SELECT * FROM t FORCE INDEX (", 29, cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a' in FORCE INDEX (); got %v", candidates) + } + }) + + // Scenario 7: SELECT * FROM t IGNORE INDEX (|) → index_ref + t.Run("ignore_index_ref", func(t *testing.T) { + candidates := Complete("SELECT * FROM t IGNORE INDEX (", 30, cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a' in IGNORE INDEX (); got %v", candidates) + } + }) +} + +// --- Section 4.1: INSERT --- + +func TestComplete_4_1_Insert(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE users (id INT, name VARCHAR(100), email VARCHAR(200))") + mustExec(t, cat, "CREATE TABLE orders (id INT, user_id INT, total DECIMAL(10,2))") + + // Scenario 1: INSERT INTO | → table_ref + t.Run("insert_into_table_ref", func(t *testing.T) { + candidates := Complete("INSERT INTO ", 12, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + if !containsCandidate(candidates, "orders", CandidateTable) { + t.Errorf("missing table 'orders'; got %v", candidates) + } + }) + + // Scenario 2: INSERT INTO t (|) → columnref for table t + t.Run("insert_column_list", func(t *testing.T) { + candidates := Complete("INSERT INTO users (", 19, cat) + if !containsCandidate(candidates, "id", CandidateColumn) { + t.Errorf("missing column 'id'; got %v", candidates) + } + if !containsCandidate(candidates, "name", CandidateColumn) { + t.Errorf("missing column 'name'; got %v", candidates) + } + }) + + // Scenario 3: INSERT INTO t (a, |) → columnref after comma + t.Run("insert_column_list_after_comma", func(t *testing.T) { + candidates := Complete("INSERT INTO users (id, ", 23, cat) + if !containsCandidate(candidates, "name", CandidateColumn) { + t.Errorf("missing column 'name'; got %v", candidates) + } + }) + + // Scenario 4: INSERT INTO t VALUES (|) → no specific candidates (value context) + t.Run("insert_values_context", func(t *testing.T) { + candidates := Complete("INSERT INTO users VALUES (", 26, cat) + // Should not offer table or column candidates in value context. + // Values context offers expression candidates (columnref, func_name) from parseExpr. + // This is acceptable — the key is that it doesn't crash. + _ = candidates + }) + + // Scenario 5: INSERT INTO t | → keyword candidates (VALUES, SET, SELECT, PARTITION) + t.Run("insert_after_table_keywords", func(t *testing.T) { + candidates := Complete("INSERT INTO users ", 18, cat) + if !containsCandidate(candidates, "VALUES", CandidateKeyword) { + t.Errorf("missing keyword VALUES; got %v", candidates) + } + if !containsCandidate(candidates, "SET", CandidateKeyword) { + t.Errorf("missing keyword SET; got %v", candidates) + } + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword SELECT; got %v", candidates) + } + }) + + // Scenario 6: INSERT INTO t VALUES (1) ON DUPLICATE KEY UPDATE | → columnref + t.Run("insert_on_duplicate_key_update", func(t *testing.T) { + candidates := Complete("INSERT INTO users VALUES (1) ON DUPLICATE KEY UPDATE ", 53, cat) + if !containsCandidate(candidates, "id", CandidateColumn) { + t.Errorf("missing column 'id'; got %v", candidates) + } + if !containsCandidate(candidates, "name", CandidateColumn) { + t.Errorf("missing column 'name'; got %v", candidates) + } + }) + + // Scenario 7: INSERT INTO t SET | → columnref (assignment context) + t.Run("insert_set_columnref", func(t *testing.T) { + candidates := Complete("INSERT INTO users SET ", 22, cat) + if !containsCandidate(candidates, "id", CandidateColumn) { + t.Errorf("missing column 'id'; got %v", candidates) + } + if !containsCandidate(candidates, "name", CandidateColumn) { + t.Errorf("missing column 'name'; got %v", candidates) + } + }) + + // Scenario 8: INSERT INTO t SELECT | → columnref (INSERT SELECT) + t.Run("insert_select", func(t *testing.T) { + candidates := Complete("INSERT INTO users SELECT ", 25, cat) + // After SELECT, should offer columnref and func_name. + hasColumnOrKeyword := false + for _, c := range candidates { + if c.Type == CandidateColumn || c.Type == CandidateKeyword || c.Type == CandidateFunction { + hasColumnOrKeyword = true + break + } + } + if !hasColumnOrKeyword { + t.Errorf("expected column/keyword/function candidates after INSERT ... SELECT; got %v", candidates) + } + }) + + // Scenario 9: REPLACE INTO | → table_ref + t.Run("replace_into_table_ref", func(t *testing.T) { + candidates := Complete("REPLACE INTO ", 13, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + if !containsCandidate(candidates, "orders", CandidateTable) { + t.Errorf("missing table 'orders'; got %v", candidates) + } + }) +} + +// --- Section 4.2: UPDATE --- + +func TestComplete_4_2_Update(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE TABLE t2 (a INT, b INT)") + + // Scenario 1: UPDATE | → table_ref + t.Run("update_table_ref", func(t *testing.T) { + candidates := Complete("UPDATE ", 7, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + if !containsCandidate(candidates, "t1", CandidateTable) { + t.Errorf("missing table 't1'; got %v", candidates) + } + }) + + // Scenario 2: UPDATE t SET | → columnref for table t + t.Run("update_set_columnref", func(t *testing.T) { + candidates := Complete("UPDATE t SET ", 13, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 3: UPDATE t SET a = 1, | → columnref after comma + t.Run("update_set_after_comma", func(t *testing.T) { + candidates := Complete("UPDATE t SET a = 1, ", 20, cat) + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 4: UPDATE t SET a = 1 WHERE | → columnref + t.Run("update_where_columnref", func(t *testing.T) { + candidates := Complete("UPDATE t SET a = 1 WHERE ", 25, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 5: UPDATE t SET a = 1 ORDER BY | → columnref + t.Run("update_order_by_columnref", func(t *testing.T) { + candidates := Complete("UPDATE t SET a = 1 ORDER BY ", 28, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 6: UPDATE t1 JOIN t2 ON t1.a = t2.a SET | → columnref from both tables + t.Run("update_multi_table_set", func(t *testing.T) { + candidates := Complete("UPDATE t1 JOIN t2 ON t1.a = t2.a SET ", 37, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) +} + +// --- Section 4.3: DELETE & LOAD DATA & CALL --- + +func TestComplete_4_3_DeleteLoadCall(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE TABLE t2 (a INT, b INT)") + mustExec(t, cat, "CREATE PROCEDURE my_proc() BEGIN SELECT 1; END") + + // Scenario 1: DELETE FROM | → table_ref + t.Run("delete_from_table_ref", func(t *testing.T) { + candidates := Complete("DELETE FROM ", 12, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + if !containsCandidate(candidates, "t1", CandidateTable) { + t.Errorf("missing table 't1'; got %v", candidates) + } + }) + + // Scenario 2: DELETE FROM t WHERE | → columnref for table t + t.Run("delete_where_columnref", func(t *testing.T) { + candidates := Complete("DELETE FROM t WHERE ", 20, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 3: DELETE FROM t ORDER BY | → columnref + t.Run("delete_order_by_columnref", func(t *testing.T) { + candidates := Complete("DELETE FROM t ORDER BY ", 23, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 4: DELETE t1 FROM t1 JOIN t2 ON t1.a = t2.a WHERE | → columnref from both tables + t.Run("delete_multi_table_where", func(t *testing.T) { + candidates := Complete("DELETE t1 FROM t1 JOIN t2 ON t1.a = t2.a WHERE ", 48, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 5: LOAD DATA INFILE 'f' INTO TABLE | → table_ref + t.Run("load_data_into_table_ref", func(t *testing.T) { + candidates := Complete("LOAD DATA INFILE 'f' INTO TABLE ", 32, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + if !containsCandidate(candidates, "t1", CandidateTable) { + t.Errorf("missing table 't1'; got %v", candidates) + } + }) + + // Scenario 6: CALL | → procedure_ref + t.Run("call_procedure_ref", func(t *testing.T) { + candidates := Complete("CALL ", 5, cat) + if !containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Errorf("missing procedure 'my_proc'; got %v", candidates) + } + }) +} + +// ────────────────────────────────────────────────────────────── +// Section 5.1: CREATE TABLE +// ────────────────────────────────────────────────────────────── + +func TestComplete_5_1_CreateTable(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT)") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE TABLE t2 (x INT, y INT)") + + // Scenario 1: CREATE TABLE | → identifier context (no specific candidates from catalog) + t.Run("create_table_identifier_context", func(t *testing.T) { + candidates := Complete("CREATE TABLE ", 13, cat) + // Should not contain table names (user is defining a new name) + if containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("should not suggest existing table 't' for new table name; got %v", candidates) + } + }) + + // Scenario 2: CREATE TABLE t (a INT, |) → constraint/column keywords + t.Run("create_table_column_constraint_keywords", func(t *testing.T) { + candidates := Complete("CREATE TABLE t (a INT, ", 22, cat) + for _, kw := range []string{"PRIMARY", "UNIQUE", "INDEX", "KEY", "FOREIGN", "CHECK", "CONSTRAINT"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 3: CREATE TABLE t (a INT |) → column option keywords + t.Run("create_table_column_options", func(t *testing.T) { + sql := "CREATE TABLE t (a INT " + candidates := Complete(sql, len(sql), cat) + for _, kw := range []string{"NOT", "NULL", "DEFAULT", "AUTO_INCREMENT", "PRIMARY", "UNIQUE", "COMMENT", "COLLATE", "REFERENCES", "CHECK", "GENERATED"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 4: CREATE TABLE t (a INT) | → table option keywords + t.Run("create_table_table_options", func(t *testing.T) { + candidates := Complete("CREATE TABLE t (a INT) ", 23, cat) + for _, kw := range []string{"ENGINE", "DEFAULT", "CHARSET", "COLLATE", "COMMENT", "AUTO_INCREMENT", "ROW_FORMAT", "PARTITION"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 5: CREATE TABLE t (a INT) ENGINE=| → engine candidates + t.Run("create_table_engine_candidates", func(t *testing.T) { + candidates := Complete("CREATE TABLE t (a INT) ENGINE=", 30, cat) + for _, eng := range []string{"InnoDB", "MyISAM", "MEMORY"} { + if !containsCandidate(candidates, eng, CandidateEngine) { + t.Errorf("missing engine %q; got %v", eng, candidates) + } + } + }) + + // Scenario 6: CREATE TABLE t (a |) → type candidates + t.Run("create_table_type_candidates", func(t *testing.T) { + candidates := Complete("CREATE TABLE t (a ", 18, cat) + for _, typ := range []string{"INT", "VARCHAR", "TEXT", "BLOB", "DATE", "DATETIME", "DECIMAL"} { + if !containsCandidate(candidates, typ, CandidateType_) { + t.Errorf("missing type %q; got %v", typ, candidates) + } + } + }) + + // Scenario 7: CREATE TABLE t (FOREIGN KEY (a) REFERENCES |) → table_ref + t.Run("create_table_references_table_ref", func(t *testing.T) { + sql := "CREATE TABLE t3 (FOREIGN KEY (a) REFERENCES " + candidates := Complete(sql, len(sql), cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + if !containsCandidate(candidates, "t1", CandidateTable) { + t.Errorf("missing table 't1'; got %v", candidates) + } + }) + + // Scenario 8: CREATE TABLE t LIKE | → table_ref + t.Run("create_table_like_table_ref", func(t *testing.T) { + candidates := Complete("CREATE TABLE t3 LIKE ", 21, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 9: CREATE TABLE t (a INT GENERATED ALWAYS AS (|)) → expression context + t.Run("create_table_generated_expr_context", func(t *testing.T) { + candidates := Complete("CREATE TABLE t3 (a INT GENERATED ALWAYS AS (", 44, cat) + hasCol := containsCandidate(candidates, "a", CandidateColumn) || + containsCandidate(candidates, "b", CandidateColumn) + hasFunc := false + for _, c := range candidates { + if c.Type == CandidateFunction { + hasFunc = true + break + } + } + if !hasCol && !hasFunc { + t.Errorf("expected columnref or func_name candidates; got %v", candidates) + } + }) +} + +// ────────────────────────────────────────────────────────────── +// Section 5.2: ALTER TABLE +// ────────────────────────────────────────────────────────────── + +func TestComplete_5_2_AlterTable(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT, INDEX idx_a (a), CONSTRAINT fk_b FOREIGN KEY (b) REFERENCES t (a))") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE TABLE t2 (x INT, y INT)") + + // Scenario 1: ALTER TABLE | → table_ref + t.Run("alter_table_table_ref", func(t *testing.T) { + candidates := Complete("ALTER TABLE ", 12, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + if !containsCandidate(candidates, "t1", CandidateTable) { + t.Errorf("missing table 't1'; got %v", candidates) + } + }) + + // Scenario 2: ALTER TABLE t | → operation keywords + t.Run("alter_table_operation_keywords", func(t *testing.T) { + candidates := Complete("ALTER TABLE t ", 14, cat) + for _, kw := range []string{"ADD", "DROP", "MODIFY", "CHANGE", "RENAME", "ALTER", "CONVERT", "ENGINE", "DEFAULT", "ORDER", "ALGORITHM", "LOCK", "FORCE"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 3: ALTER TABLE t ADD | → ADD keywords + t.Run("alter_table_add_keywords", func(t *testing.T) { + candidates := Complete("ALTER TABLE t ADD ", 18, cat) + for _, kw := range []string{"COLUMN", "INDEX", "KEY", "UNIQUE", "PRIMARY", "FOREIGN", "CONSTRAINT", "CHECK", "PARTITION", "SPATIAL", "FULLTEXT"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 4: ALTER TABLE t ADD COLUMN | → identifier context + t.Run("alter_table_add_column_identifier", func(t *testing.T) { + candidates := Complete("ALTER TABLE t ADD COLUMN ", 24, cat) + // Should not suggest column names — user defines new column name + if containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("should not suggest existing column 'a' for new column name; got %v", candidates) + } + }) + + // Scenario 5: ALTER TABLE t DROP | → DROP keywords + t.Run("alter_table_drop_keywords", func(t *testing.T) { + candidates := Complete("ALTER TABLE t DROP ", 19, cat) + for _, kw := range []string{"COLUMN", "INDEX", "KEY", "FOREIGN", "PRIMARY", "CHECK", "CONSTRAINT", "PARTITION"} { + if !containsCandidate(candidates, kw, CandidateKeyword) { + t.Errorf("missing keyword %q; got %v", kw, candidates) + } + } + }) + + // Scenario 6: ALTER TABLE t DROP COLUMN | → columnref + t.Run("alter_table_drop_column_columnref", func(t *testing.T) { + sql := "ALTER TABLE t DROP COLUMN " + candidates := Complete(sql, len(sql), cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 7: ALTER TABLE t DROP INDEX | → index_ref + t.Run("alter_table_drop_index_ref", func(t *testing.T) { + sql := "ALTER TABLE t DROP INDEX " + candidates := Complete(sql, len(sql), cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a'; got %v", candidates) + } + }) + + // Scenario 8: ALTER TABLE t DROP FOREIGN KEY | → constraint_ref + t.Run("alter_table_drop_fk_constraint_ref", func(t *testing.T) { + sql := "ALTER TABLE t DROP FOREIGN KEY " + candidates := Complete(sql, len(sql), cat) + if !containsCandidate(candidates, "fk_b", CandidateIndex) { + t.Errorf("missing constraint 'fk_b'; got %v", candidates) + } + }) + + // Scenario 9: ALTER TABLE t DROP CONSTRAINT | → constraint_ref + t.Run("alter_table_drop_constraint_ref", func(t *testing.T) { + sql := "ALTER TABLE t DROP CONSTRAINT " + candidates := Complete(sql, len(sql), cat) + if !containsCandidate(candidates, "fk_b", CandidateIndex) { + t.Errorf("missing constraint 'fk_b'; got %v", candidates) + } + }) + + // Scenario 10: ALTER TABLE t MODIFY | → columnref + t.Run("alter_table_modify_columnref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t MODIFY ", 21, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 11: ALTER TABLE t MODIFY COLUMN | → columnref + t.Run("alter_table_modify_column_columnref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t MODIFY COLUMN ", 28, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 12: ALTER TABLE t CHANGE | → columnref + t.Run("alter_table_change_columnref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t CHANGE ", 21, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 13: ALTER TABLE t RENAME TO | → identifier context + t.Run("alter_table_rename_to_identifier", func(t *testing.T) { + candidates := Complete("ALTER TABLE t RENAME TO ", 24, cat) + // Should not suggest table names — user defines a new name + if containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("should not suggest existing table 't' for rename target; got %v", candidates) + } + }) + + // Scenario 14: ALTER TABLE t RENAME COLUMN | → columnref + t.Run("alter_table_rename_column_columnref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t RENAME COLUMN ", 28, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 15: ALTER TABLE t RENAME INDEX | → index_ref + t.Run("alter_table_rename_index_ref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t RENAME INDEX ", 27, cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a'; got %v", candidates) + } + }) + + // Scenario 16: ALTER TABLE t ADD INDEX idx (|) → columnref + t.Run("alter_table_add_index_columnref", func(t *testing.T) { + candidates := Complete("ALTER TABLE t ADD INDEX idx (", 28, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + }) + + // Scenario 17: ALTER TABLE t CONVERT TO CHARACTER SET | → charset + t.Run("alter_table_convert_charset", func(t *testing.T) { + candidates := Complete("ALTER TABLE t CONVERT TO CHARACTER SET ", 39, cat) + if !containsCandidate(candidates, "utf8mb4", CandidateCharset) { + t.Errorf("missing charset 'utf8mb4'; got %v", candidates) + } + if !containsCandidate(candidates, "latin1", CandidateCharset) { + t.Errorf("missing charset 'latin1'; got %v", candidates) + } + }) + + // Scenario 18: ALTER TABLE t ALGORITHM=| → algorithm keywords + t.Run("alter_table_algorithm_keywords", func(t *testing.T) { + candidates := Complete("ALTER TABLE t ALGORITHM=", 24, cat) + if !containsCandidate(candidates, "DEFAULT", CandidateKeyword) { + t.Errorf("missing keyword 'DEFAULT'; got %v", candidates) + } + }) +} + +// ────────────────────────────────────────────────────────────── +// Section 5.3: CREATE/DROP Index, View, Database & misc +// ────────────────────────────────────────────────────────────── + +func TestComplete_5_3_CreateDropMisc(t *testing.T) { + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE testdb") + cat.SetCurrentDatabase("testdb") + mustExec(t, cat, "CREATE TABLE t (a INT, b INT, c INT, INDEX idx_a (a))") + mustExec(t, cat, "CREATE TABLE t1 (a INT, b INT)") + mustExec(t, cat, "CREATE VIEW v AS SELECT a, b FROM t") + + // Scenario 1: CREATE INDEX idx ON | → table_ref + t.Run("create_index_on_table_ref", func(t *testing.T) { + candidates := Complete("CREATE INDEX idx ON ", 20, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 2: CREATE INDEX idx ON t (|) → columnref + t.Run("create_index_columns_columnref", func(t *testing.T) { + candidates := Complete("CREATE INDEX idx ON t (", 22, cat) + if !containsCandidate(candidates, "a", CandidateColumn) { + t.Errorf("missing column 'a'; got %v", candidates) + } + if !containsCandidate(candidates, "b", CandidateColumn) { + t.Errorf("missing column 'b'; got %v", candidates) + } + }) + + // Scenario 3: CREATE UNIQUE INDEX idx ON | → table_ref + t.Run("create_unique_index_on_table_ref", func(t *testing.T) { + candidates := Complete("CREATE UNIQUE INDEX idx ON ", 27, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 4: DROP INDEX | → index_ref + t.Run("drop_index_ref", func(t *testing.T) { + candidates := Complete("DROP INDEX ", 11, cat) + if !containsCandidate(candidates, "idx_a", CandidateIndex) { + t.Errorf("missing index 'idx_a'; got %v", candidates) + } + }) + + // Scenario 5: DROP INDEX idx ON | → table_ref + t.Run("drop_index_on_table_ref", func(t *testing.T) { + candidates := Complete("DROP INDEX idx_a ON ", 20, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 6: CREATE VIEW | → identifier context + t.Run("create_view_identifier", func(t *testing.T) { + candidates := Complete("CREATE VIEW ", 12, cat) + // Should not suggest view names — user defines a new name + if containsCandidate(candidates, "v", CandidateView) { + t.Errorf("should not suggest existing view 'v' for new view name; got %v", candidates) + } + }) + + // Scenario 7: CREATE VIEW v AS | → SELECT keyword + t.Run("create_view_as_select", func(t *testing.T) { + candidates := Complete("CREATE VIEW v2 AS ", 18, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword 'SELECT'; got %v", candidates) + } + }) + + // Scenario 8: CREATE DEFINER=| → CURRENT_USER keyword + t.Run("create_definer_current_user", func(t *testing.T) { + candidates := Complete("CREATE DEFINER=", 15, cat) + if !containsCandidate(candidates, "CURRENT_USER", CandidateKeyword) { + t.Errorf("missing keyword 'CURRENT_USER'; got %v", candidates) + } + }) + + // Scenario 9: ALTER VIEW v AS | → SELECT keyword + t.Run("alter_view_as_select", func(t *testing.T) { + candidates := Complete("ALTER VIEW v AS ", 16, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword 'SELECT'; got %v", candidates) + } + }) + + // Scenario 10: DROP VIEW | → view_ref + t.Run("drop_view_ref", func(t *testing.T) { + candidates := Complete("DROP VIEW ", 10, cat) + if !containsCandidate(candidates, "v", CandidateView) { + t.Errorf("missing view 'v'; got %v", candidates) + } + }) + + // Scenario 11: CREATE DATABASE | → identifier context + t.Run("create_database_identifier", func(t *testing.T) { + candidates := Complete("CREATE DATABASE ", 16, cat) + // Should not suggest database names — user defines a new name + if containsCandidate(candidates, "testdb", CandidateDatabase) { + t.Errorf("should not suggest existing database for new database name; got %v", candidates) + } + }) + + // Scenario 12: DROP DATABASE | → database_ref + t.Run("drop_database_ref", func(t *testing.T) { + candidates := Complete("DROP DATABASE ", 14, cat) + if !containsCandidate(candidates, "testdb", CandidateDatabase) { + t.Errorf("missing database 'testdb'; got %v", candidates) + } + }) + + // Scenario 13: DROP TABLE | → table_ref + t.Run("drop_table_ref", func(t *testing.T) { + candidates := Complete("DROP TABLE ", 11, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 14: DROP TABLE IF EXISTS | → table_ref + t.Run("drop_table_if_exists_ref", func(t *testing.T) { + candidates := Complete("DROP TABLE IF EXISTS ", 21, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 15: TRUNCATE TABLE | → table_ref + t.Run("truncate_table_ref", func(t *testing.T) { + candidates := Complete("TRUNCATE TABLE ", 15, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 16: RENAME TABLE | → table_ref + t.Run("rename_table_ref", func(t *testing.T) { + candidates := Complete("RENAME TABLE ", 13, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 17: RENAME TABLE t TO | → identifier context + t.Run("rename_table_to_identifier", func(t *testing.T) { + candidates := Complete("RENAME TABLE t TO ", 18, cat) + // Should not suggest table names — user defines a new name + if containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("should not suggest existing table 't' for rename target; got %v", candidates) + } + }) + + // Scenario 18: DESCRIBE | → table_ref + t.Run("describe_table_ref", func(t *testing.T) { + candidates := Complete("DESCRIBE ", 9, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) + + // Scenario 18b: DESC | → table_ref + t.Run("desc_table_ref", func(t *testing.T) { + candidates := Complete("DESC ", 5, cat) + if !containsCandidate(candidates, "t", CandidateTable) { + t.Errorf("missing table 't'; got %v", candidates) + } + }) +} + +// ============================================================================= +// Phase 6: Routine/Trigger/Event Instrumentation +// ============================================================================= + +func TestComplete_6_1_FunctionsAndProcedures(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: CREATE FUNCTION | → identifier context (no specific candidates) + t.Run("create_function_identifier", func(t *testing.T) { + candidates := Complete("CREATE FUNCTION ", 16, cat) + // Should not suggest existing functions — user defines a new name + if containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("should not suggest existing function for new function name; got %v", candidates) + } + }) + + // Scenario 2: CREATE FUNCTION f(|) → param direction keywords + type context + t.Run("create_function_params", func(t *testing.T) { + candidates := Complete("CREATE FUNCTION f(", 18, cat) + if !containsCandidate(candidates, "IN", CandidateKeyword) { + t.Errorf("missing keyword 'IN'; got %v", candidates) + } + if !containsCandidate(candidates, "OUT", CandidateKeyword) { + t.Errorf("missing keyword 'OUT'; got %v", candidates) + } + if !containsCandidate(candidates, "INOUT", CandidateKeyword) { + t.Errorf("missing keyword 'INOUT'; got %v", candidates) + } + }) + + // Scenario 3: CREATE FUNCTION f() RETURNS | → type candidates + t.Run("create_function_returns_type", func(t *testing.T) { + candidates := Complete("CREATE FUNCTION f() RETURNS ", 28, cat) + if !containsCandidate(candidates, "INT", CandidateType_) { + t.Errorf("missing type 'INT'; got %v", candidates) + } + if !containsCandidate(candidates, "VARCHAR", CandidateType_) { + t.Errorf("missing type 'VARCHAR'; got %v", candidates) + } + }) + + // Scenario 4: CREATE FUNCTION f() | → characteristic keywords + t.Run("create_function_characteristics", func(t *testing.T) { + candidates := Complete("CREATE FUNCTION f() ", 20, cat) + if !containsCandidate(candidates, "DETERMINISTIC", CandidateKeyword) { + t.Errorf("missing keyword 'DETERMINISTIC'; got %v", candidates) + } + if !containsCandidate(candidates, "COMMENT", CandidateKeyword) { + t.Errorf("missing keyword 'COMMENT'; got %v", candidates) + } + if !containsCandidate(candidates, "LANGUAGE", CandidateKeyword) { + t.Errorf("missing keyword 'LANGUAGE'; got %v", candidates) + } + }) + + // Scenario 5: DROP FUNCTION | → function_ref + t.Run("drop_function_ref", func(t *testing.T) { + candidates := Complete("DROP FUNCTION ", 14, cat) + if !containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("missing function 'my_func'; got %v", candidates) + } + }) + + // Scenario 6: DROP FUNCTION IF EXISTS | → function_ref + t.Run("drop_function_if_exists_ref", func(t *testing.T) { + candidates := Complete("DROP FUNCTION IF EXISTS ", 24, cat) + if !containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("missing function 'my_func'; got %v", candidates) + } + }) + + // Scenario 7: CREATE PROCEDURE | → identifier context + t.Run("create_procedure_identifier", func(t *testing.T) { + candidates := Complete("CREATE PROCEDURE ", 17, cat) + if containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Errorf("should not suggest existing procedure for new procedure name; got %v", candidates) + } + }) + + // Scenario 8: DROP PROCEDURE | → procedure_ref + t.Run("drop_procedure_ref", func(t *testing.T) { + candidates := Complete("DROP PROCEDURE ", 15, cat) + if !containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Errorf("missing procedure 'my_proc'; got %v", candidates) + } + }) + + // Scenario 9: ALTER FUNCTION | → function_ref + t.Run("alter_function_ref", func(t *testing.T) { + candidates := Complete("ALTER FUNCTION ", 15, cat) + if !containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("missing function 'my_func'; got %v", candidates) + } + }) + + // Scenario 10: ALTER PROCEDURE | → procedure_ref + t.Run("alter_procedure_ref", func(t *testing.T) { + candidates := Complete("ALTER PROCEDURE ", 16, cat) + if !containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Errorf("missing procedure 'my_proc'; got %v", candidates) + } + }) +} + +func TestComplete_6_2_TriggersAndEvents(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: CREATE TRIGGER | → identifier context + t.Run("create_trigger_identifier", func(t *testing.T) { + candidates := Complete("CREATE TRIGGER ", 15, cat) + if containsCandidate(candidates, "my_trig", CandidateTrigger) { + t.Errorf("should not suggest existing trigger for new trigger name; got %v", candidates) + } + }) + + // Scenario 2: CREATE TRIGGER trg | → BEFORE/AFTER keywords + t.Run("create_trigger_timing", func(t *testing.T) { + candidates := Complete("CREATE TRIGGER trg ", 19, cat) + if !containsCandidate(candidates, "BEFORE", CandidateKeyword) { + t.Errorf("missing keyword 'BEFORE'; got %v", candidates) + } + if !containsCandidate(candidates, "AFTER", CandidateKeyword) { + t.Errorf("missing keyword 'AFTER'; got %v", candidates) + } + }) + + // Scenario 3: CREATE TRIGGER trg BEFORE | → INSERT/UPDATE/DELETE + t.Run("create_trigger_event", func(t *testing.T) { + candidates := Complete("CREATE TRIGGER trg BEFORE ", 26, cat) + if !containsCandidate(candidates, "INSERT", CandidateKeyword) { + t.Errorf("missing keyword 'INSERT'; got %v", candidates) + } + if !containsCandidate(candidates, "UPDATE", CandidateKeyword) { + t.Errorf("missing keyword 'UPDATE'; got %v", candidates) + } + if !containsCandidate(candidates, "DELETE", CandidateKeyword) { + t.Errorf("missing keyword 'DELETE'; got %v", candidates) + } + }) + + // Scenario 4: CREATE TRIGGER trg BEFORE INSERT ON | → table_ref + t.Run("create_trigger_on_table", func(t *testing.T) { + candidates := Complete("CREATE TRIGGER trg BEFORE INSERT ON ", 36, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 5: DROP TRIGGER | → trigger_ref + t.Run("drop_trigger_ref", func(t *testing.T) { + candidates := Complete("DROP TRIGGER ", 13, cat) + if !containsCandidate(candidates, "my_trig", CandidateTrigger) { + t.Errorf("missing trigger 'my_trig'; got %v", candidates) + } + }) + + // Scenario 6: CREATE EVENT | → identifier context + t.Run("create_event_identifier", func(t *testing.T) { + candidates := Complete("CREATE EVENT ", 13, cat) + if containsCandidate(candidates, "my_event", CandidateEvent) { + t.Errorf("should not suggest existing event for new event name; got %v", candidates) + } + }) + + // Scenario 7: CREATE EVENT ev ON SCHEDULE | → AT/EVERY keywords + t.Run("create_event_on_schedule", func(t *testing.T) { + candidates := Complete("CREATE EVENT ev ON SCHEDULE ", 28, cat) + if !containsCandidate(candidates, "AT", CandidateKeyword) { + t.Errorf("missing keyword 'AT'; got %v", candidates) + } + }) + + // Scenario 8: DROP EVENT | → event_ref + t.Run("drop_event_ref", func(t *testing.T) { + candidates := Complete("DROP EVENT ", 11, cat) + if !containsCandidate(candidates, "my_event", CandidateEvent) { + t.Errorf("missing event 'my_event'; got %v", candidates) + } + }) + + // Scenario 9: ALTER EVENT | → event_ref + t.Run("alter_event_ref", func(t *testing.T) { + candidates := Complete("ALTER EVENT ", 12, cat) + if !containsCandidate(candidates, "my_event", CandidateEvent) { + t.Errorf("missing event 'my_event'; got %v", candidates) + } + }) +} + +func TestComplete_6_3_TransactionLockMaintenance(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: BEGIN | → WORK keyword + t.Run("begin_work", func(t *testing.T) { + candidates := Complete("BEGIN ", 6, cat) + // BEGIN should reach the completion point; we verify no crash and keywords are available + _ = candidates + }) + + // Scenario 2: START TRANSACTION | → WITH/READ keywords + t.Run("start_transaction_keywords", func(t *testing.T) { + candidates := Complete("START TRANSACTION ", 18, cat) + if !containsCandidate(candidates, "WITH", CandidateKeyword) { + t.Errorf("missing keyword 'WITH'; got %v", candidates) + } + if !containsCandidate(candidates, "READ", CandidateKeyword) { + t.Errorf("missing keyword 'READ'; got %v", candidates) + } + }) + + // Scenario 3: COMMIT | → AND keyword + t.Run("commit_keywords", func(t *testing.T) { + candidates := Complete("COMMIT ", 7, cat) + if !containsCandidate(candidates, "AND", CandidateKeyword) { + t.Errorf("missing keyword 'AND'; got %v", candidates) + } + }) + + // Scenario 4: ROLLBACK | → TO keyword + t.Run("rollback_keywords", func(t *testing.T) { + candidates := Complete("ROLLBACK ", 9, cat) + if !containsCandidate(candidates, "TO", CandidateKeyword) { + t.Errorf("missing keyword 'TO'; got %v", candidates) + } + }) + + // Scenario 5: ROLLBACK TO | → SAVEPOINT keyword + t.Run("rollback_to_savepoint", func(t *testing.T) { + candidates := Complete("ROLLBACK TO ", 12, cat) + if !containsCandidate(candidates, "SAVEPOINT", CandidateKeyword) { + t.Errorf("missing keyword 'SAVEPOINT'; got %v", candidates) + } + }) + + // Scenario 6: SAVEPOINT | → identifier context + t.Run("savepoint_identifier", func(t *testing.T) { + candidates := Complete("SAVEPOINT ", 10, cat) + // Should not crash; identifier context means no specific object suggestions + _ = candidates + }) + + // Scenario 7: RELEASE SAVEPOINT | → identifier context + t.Run("release_savepoint_identifier", func(t *testing.T) { + candidates := Complete("RELEASE SAVEPOINT ", 18, cat) + // Should not crash; identifier context + _ = candidates + }) + + // Scenario 8: LOCK TABLES | → table_ref + t.Run("lock_tables_ref", func(t *testing.T) { + candidates := Complete("LOCK TABLES ", 12, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 9: LOCK TABLES t | → READ/WRITE keywords + t.Run("lock_tables_read_write", func(t *testing.T) { + candidates := Complete("LOCK TABLES users ", 18, cat) + if !containsCandidate(candidates, "READ", CandidateKeyword) { + t.Errorf("missing keyword 'READ'; got %v", candidates) + } + if !containsCandidate(candidates, "WRITE", CandidateKeyword) { + t.Errorf("missing keyword 'WRITE'; got %v", candidates) + } + }) + + // Scenario 10: ANALYZE TABLE | → table_ref + t.Run("analyze_table_ref", func(t *testing.T) { + candidates := Complete("ANALYZE TABLE ", 14, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 11: OPTIMIZE TABLE | → table_ref + t.Run("optimize_table_ref", func(t *testing.T) { + candidates := Complete("OPTIMIZE TABLE ", 15, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 12: CHECK TABLE | → table_ref + t.Run("check_table_ref", func(t *testing.T) { + candidates := Complete("CHECK TABLE ", 12, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 13: REPAIR TABLE | → table_ref + t.Run("repair_table_ref", func(t *testing.T) { + candidates := Complete("REPAIR TABLE ", 13, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 14: FLUSH | → PRIVILEGES/TABLES/LOGS/STATUS keywords + t.Run("flush_keywords", func(t *testing.T) { + candidates := Complete("FLUSH ", 6, cat) + if !containsCandidate(candidates, "PRIVILEGES", CandidateKeyword) { + t.Errorf("missing keyword 'PRIVILEGES'; got %v", candidates) + } + if !containsCandidate(candidates, "TABLES", CandidateKeyword) { + t.Errorf("missing keyword 'TABLES'; got %v", candidates) + } + if !containsCandidate(candidates, "LOGS", CandidateKeyword) { + t.Errorf("missing keyword 'LOGS'; got %v", candidates) + } + if !containsCandidate(candidates, "STATUS", CandidateKeyword) { + t.Errorf("missing keyword 'STATUS'; got %v", candidates) + } + }) +} + +// ============================================================================= +// Phase 7: Session/Utility Instrumentation +// ============================================================================= + +func TestComplete_7_1_SetAndShow(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: SET | → variable/keyword candidates + t.Run("set_candidates", func(t *testing.T) { + candidates := Complete("SET ", 4, cat) + if !containsCandidate(candidates, "GLOBAL", CandidateKeyword) { + t.Errorf("missing keyword 'GLOBAL'; got %v", candidates) + } + if !containsCandidate(candidates, "SESSION", CandidateKeyword) { + t.Errorf("missing keyword 'SESSION'; got %v", candidates) + } + if !containsCandidate(candidates, "CHARACTER", CandidateKeyword) { + t.Errorf("missing keyword 'CHARACTER'; got %v", candidates) + } + if !containsCandidate(candidates, "NAMES", CandidateVariable) { + t.Errorf("missing variable 'NAMES'; got %v", candidates) + } + }) + + // Scenario 2: SET GLOBAL | → variable candidates + t.Run("set_global_variable", func(t *testing.T) { + candidates := Complete("SET GLOBAL ", 11, cat) + if !containsCandidate(candidates, "@@max_connections", CandidateVariable) { + t.Errorf("missing variable '@@max_connections'; got %v", candidates) + } + }) + + // Scenario 3: SET SESSION | → variable candidates + t.Run("set_session_variable", func(t *testing.T) { + candidates := Complete("SET SESSION ", 12, cat) + if !containsCandidate(candidates, "@@sql_mode", CandidateVariable) { + t.Errorf("missing variable '@@sql_mode'; got %v", candidates) + } + }) + + // Scenario 4: SET NAMES | → charset candidates + t.Run("set_names_charset", func(t *testing.T) { + candidates := Complete("SET NAMES ", 10, cat) + if !containsCandidate(candidates, "utf8mb4", CandidateCharset) { + t.Errorf("missing charset 'utf8mb4'; got %v", candidates) + } + if !containsCandidate(candidates, "latin1", CandidateCharset) { + t.Errorf("missing charset 'latin1'; got %v", candidates) + } + }) + + // Scenario 5: SET CHARACTER SET | → charset candidates + t.Run("set_character_set_charset", func(t *testing.T) { + candidates := Complete("SET CHARACTER SET ", 18, cat) + if !containsCandidate(candidates, "utf8mb4", CandidateCharset) { + t.Errorf("missing charset 'utf8mb4'; got %v", candidates) + } + }) + + // Scenario 6: SHOW | → keyword candidates + t.Run("show_keywords", func(t *testing.T) { + candidates := Complete("SHOW ", 5, cat) + if !containsCandidate(candidates, "TABLES", CandidateKeyword) { + t.Errorf("missing keyword 'TABLES'; got %v", candidates) + } + if !containsCandidate(candidates, "COLUMNS", CandidateKeyword) { + t.Errorf("missing keyword 'COLUMNS'; got %v", candidates) + } + if !containsCandidate(candidates, "INDEX", CandidateKeyword) { + t.Errorf("missing keyword 'INDEX'; got %v", candidates) + } + if !containsCandidate(candidates, "DATABASES", CandidateKeyword) { + t.Errorf("missing keyword 'DATABASES'; got %v", candidates) + } + if !containsCandidate(candidates, "CREATE", CandidateKeyword) { + t.Errorf("missing keyword 'CREATE'; got %v", candidates) + } + if !containsCandidate(candidates, "STATUS", CandidateKeyword) { + t.Errorf("missing keyword 'STATUS'; got %v", candidates) + } + if !containsCandidate(candidates, "VARIABLES", CandidateKeyword) { + t.Errorf("missing keyword 'VARIABLES'; got %v", candidates) + } + }) + + // Scenario 7: SHOW CREATE TABLE | → table_ref + t.Run("show_create_table_ref", func(t *testing.T) { + candidates := Complete("SHOW CREATE TABLE ", 18, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 8: SHOW CREATE VIEW | → view_ref + t.Run("show_create_view_ref", func(t *testing.T) { + candidates := Complete("SHOW CREATE VIEW ", 17, cat) + if !containsCandidate(candidates, "active_users", CandidateView) { + t.Errorf("missing view 'active_users'; got %v", candidates) + } + }) + + // Scenario 9: SHOW CREATE FUNCTION | → function_ref + t.Run("show_create_function_ref", func(t *testing.T) { + candidates := Complete("SHOW CREATE FUNCTION ", 21, cat) + if !containsCandidate(candidates, "my_func", CandidateFunction) { + t.Errorf("missing function 'my_func'; got %v", candidates) + } + }) + + // Scenario 10: SHOW CREATE PROCEDURE | → procedure_ref + t.Run("show_create_procedure_ref", func(t *testing.T) { + candidates := Complete("SHOW CREATE PROCEDURE ", 22, cat) + if !containsCandidate(candidates, "my_proc", CandidateProcedure) { + t.Errorf("missing procedure 'my_proc'; got %v", candidates) + } + }) + + // Scenario 11: SHOW COLUMNS FROM | → table_ref + t.Run("show_columns_from_table_ref", func(t *testing.T) { + candidates := Complete("SHOW COLUMNS FROM ", 18, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 12: SHOW INDEX FROM | → table_ref + t.Run("show_index_from_table_ref", func(t *testing.T) { + candidates := Complete("SHOW INDEX FROM ", 16, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 13: SHOW TABLES FROM | → database_ref + t.Run("show_tables_from_database_ref", func(t *testing.T) { + candidates := Complete("SHOW TABLES FROM ", 17, cat) + if !containsCandidate(candidates, "testdb", CandidateDatabase) { + t.Errorf("missing database 'testdb'; got %v", candidates) + } + }) +} + +func TestComplete_7_2_UseGrantExplain(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: USE | → database_ref + t.Run("use_database_ref", func(t *testing.T) { + candidates := Complete("USE ", 4, cat) + if !containsCandidate(candidates, "testdb", CandidateDatabase) { + t.Errorf("missing database 'testdb'; got %v", candidates) + } + }) + + // Scenario 2: GRANT | → privilege keywords + t.Run("grant_privilege_keywords", func(t *testing.T) { + candidates := Complete("GRANT ", 6, cat) + if !containsCandidate(candidates, "ALL", CandidateKeyword) { + t.Errorf("missing keyword 'ALL'; got %v", candidates) + } + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword 'SELECT'; got %v", candidates) + } + if !containsCandidate(candidates, "INSERT", CandidateKeyword) { + t.Errorf("missing keyword 'INSERT'; got %v", candidates) + } + if !containsCandidate(candidates, "UPDATE", CandidateKeyword) { + t.Errorf("missing keyword 'UPDATE'; got %v", candidates) + } + if !containsCandidate(candidates, "DELETE", CandidateKeyword) { + t.Errorf("missing keyword 'DELETE'; got %v", candidates) + } + if !containsCandidate(candidates, "CREATE", CandidateKeyword) { + t.Errorf("missing keyword 'CREATE'; got %v", candidates) + } + if !containsCandidate(candidates, "ALTER", CandidateKeyword) { + t.Errorf("missing keyword 'ALTER'; got %v", candidates) + } + if !containsCandidate(candidates, "DROP", CandidateKeyword) { + t.Errorf("missing keyword 'DROP'; got %v", candidates) + } + if !containsCandidate(candidates, "EXECUTE", CandidateKeyword) { + t.Errorf("missing keyword 'EXECUTE'; got %v", candidates) + } + }) + + // Scenario 3: GRANT SELECT ON | → table_ref + t.Run("grant_on_table_ref", func(t *testing.T) { + candidates := Complete("GRANT SELECT ON ", 16, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 4: GRANT SELECT ON t TO | → user context (no specific resolution needed) + t.Run("grant_to_user", func(t *testing.T) { + candidates := Complete("GRANT SELECT ON users TO ", 25, cat) + // Just verify no panic; user context is not resolved via catalog + _ = candidates + }) + + // Scenario 5: REVOKE SELECT ON | → table_ref + t.Run("revoke_on_table_ref", func(t *testing.T) { + candidates := Complete("REVOKE SELECT ON ", 17, cat) + if !containsCandidate(candidates, "users", CandidateTable) { + t.Errorf("missing table 'users'; got %v", candidates) + } + }) + + // Scenario 6: EXPLAIN | → keyword candidates + t.Run("explain_keywords", func(t *testing.T) { + candidates := Complete("EXPLAIN ", 8, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword 'SELECT'; got %v", candidates) + } + }) + + // Scenario 7: EXPLAIN SELECT | → same as SELECT instrumentation + t.Run("explain_select", func(t *testing.T) { + candidates := Complete("EXPLAIN SELECT ", 15, cat) + if !containsCandidate(candidates, "id", CandidateColumn) { + // Columns may or may not be offered depending on context, + // but func_name should be offered + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT' or column 'id'; got %v", candidates) + } + } + }) + + // Scenario 8: PREPARE stmt FROM | → string context (no specific candidates) + t.Run("prepare_from", func(t *testing.T) { + candidates := Complete("PREPARE stmt FROM ", 18, cat) + // Just verify no panic; string context is not resolved via catalog + _ = candidates + }) + + // Scenario 9: EXECUTE | → prepared statement name (no specific candidates — identifier) + t.Run("execute_stmt", func(t *testing.T) { + candidates := Complete("EXECUTE ", 8, cat) + // Just verify no panic; prepared statement names are not in catalog + _ = candidates + }) + + // Scenario 10: DEALLOCATE PREPARE | → prepared statement name + t.Run("deallocate_prepare", func(t *testing.T) { + candidates := Complete("DEALLOCATE PREPARE ", 19, cat) + // Just verify no panic + _ = candidates + }) + + // Scenario 11: DO | → expression context (columnref, func_name) + t.Run("do_expression", func(t *testing.T) { + candidates := Complete("DO ", 3, cat) + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + }) +} + +// ============================================================================= +// Phase 8: Expression Instrumentation +// ============================================================================= + +func TestComplete_8_1_FunctionAndTypeNames(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: SELECT |() context → func_name candidates + t.Run("func_name_candidates", func(t *testing.T) { + candidates := Complete("SELECT ", 7, cat) + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + if !containsCandidate(candidates, "SUM", CandidateFunction) { + t.Errorf("missing function 'SUM'; got %v", candidates) + } + if !containsCandidate(candidates, "CONCAT", CandidateFunction) { + t.Errorf("missing function 'CONCAT'; got %v", candidates) + } + if !containsCandidate(candidates, "NOW", CandidateFunction) { + t.Errorf("missing function 'NOW'; got %v", candidates) + } + }) + + // Scenario 2: SELECT CAST(a AS |) → type candidates + t.Run("cast_as_type", func(t *testing.T) { + candidates := Complete("SELECT CAST(a AS ", 17, cat) + if !containsCandidate(candidates, "CHAR", CandidateType_) { + t.Errorf("missing type 'CHAR'; got %v", candidates) + } + if !containsCandidate(candidates, "DATE", CandidateType_) { + t.Errorf("missing type 'DATE'; got %v", candidates) + } + if !containsCandidate(candidates, "DECIMAL", CandidateType_) { + t.Errorf("missing type 'DECIMAL'; got %v", candidates) + } + }) + + // Scenario 3: SELECT CONVERT(a, |) → type candidates + t.Run("convert_type", func(t *testing.T) { + candidates := Complete("SELECT CONVERT(a, ", 18, cat) + if !containsCandidate(candidates, "CHAR", CandidateType_) { + t.Errorf("missing type 'CHAR'; got %v", candidates) + } + if !containsCandidate(candidates, "DATE", CandidateType_) { + t.Errorf("missing type 'DATE'; got %v", candidates) + } + }) + + // Scenario 4: SELECT CONVERT(a USING |) → charset candidates + t.Run("convert_using_charset", func(t *testing.T) { + candidates := Complete("SELECT CONVERT(a USING ", 23, cat) + if !containsCandidate(candidates, "utf8mb4", CandidateCharset) { + t.Errorf("missing charset 'utf8mb4'; got %v", candidates) + } + if !containsCandidate(candidates, "latin1", CandidateCharset) { + t.Errorf("missing charset 'latin1'; got %v", candidates) + } + }) +} + +func TestComplete_8_2_ExpressionContexts(t *testing.T) { + cat := setupCatalog(t) + + // Scenario 1: SELECT a + | → columnref, func_name + t.Run("expr_continuation", func(t *testing.T) { + candidates := Complete("SELECT a + ", 11, cat) + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + }) + + // Scenario 2: SELECT CASE WHEN | → columnref + t.Run("case_when_condition", func(t *testing.T) { + candidates := Complete("SELECT CASE WHEN ", 17, cat) + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + }) + + // Scenario 3: SELECT CASE WHEN a THEN | → columnref, literal + t.Run("case_then_result", func(t *testing.T) { + candidates := Complete("SELECT CASE WHEN a THEN ", 24, cat) + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + }) + + // Scenario 4: SELECT CASE a WHEN | → literal context + t.Run("case_when_value", func(t *testing.T) { + candidates := Complete("SELECT CASE a WHEN ", 19, cat) + // In a simple CASE, WHEN values go through parseExpr, so func_name/columnref are available + if !containsCandidate(candidates, "COUNT", CandidateFunction) { + t.Errorf("missing function 'COUNT'; got %v", candidates) + } + }) + + // Scenario 5: SELECT * FROM t WHERE a IN (|) → columnref, literal + t.Run("in_list", func(t *testing.T) { + candidates := Complete("SELECT * FROM users WHERE id IN (", 33, cat) + // After (, the parser may see SELECT or try parseExpr + hasFunc := containsCandidate(candidates, "COUNT", CandidateFunction) + hasCol := containsCandidate(candidates, "id", CandidateColumn) + hasSelect := containsCandidate(candidates, "SELECT", CandidateKeyword) + if !hasFunc && !hasCol && !hasSelect { + t.Errorf("missing completion candidates in IN list; got %v", candidates) + } + }) + + // Scenario 6: SELECT * FROM t WHERE a BETWEEN | → columnref, literal + t.Run("between_lower", func(t *testing.T) { + candidates := Complete("SELECT * FROM users WHERE id BETWEEN ", 37, cat) + hasFunc := containsCandidate(candidates, "COUNT", CandidateFunction) + hasCol := containsCandidate(candidates, "id", CandidateColumn) + if !hasFunc && !hasCol { + t.Errorf("missing completion candidates in BETWEEN; got %v", candidates) + } + }) + + // Scenario 7: SELECT * FROM t WHERE a BETWEEN 1 AND | → columnref, literal + t.Run("between_upper", func(t *testing.T) { + candidates := Complete("SELECT * FROM users WHERE id BETWEEN 1 AND ", 44, cat) + hasFunc := containsCandidate(candidates, "COUNT", CandidateFunction) + hasCol := containsCandidate(candidates, "id", CandidateColumn) + if !hasFunc && !hasCol { + t.Errorf("missing completion candidates in BETWEEN AND; got %v", candidates) + } + }) + + // Scenario 8: SELECT * FROM t WHERE EXISTS (|) → keyword candidate (SELECT) + t.Run("exists_select", func(t *testing.T) { + candidates := Complete("SELECT * FROM users WHERE EXISTS (", 34, cat) + if !containsCandidate(candidates, "SELECT", CandidateKeyword) { + t.Errorf("missing keyword 'SELECT'; got %v", candidates) + } + }) + + // Scenario 9: SELECT * FROM t WHERE a IS | → keyword candidates (NULL, NOT, TRUE, FALSE) + t.Run("is_keywords", func(t *testing.T) { + candidates := Complete("SELECT * FROM users WHERE id IS ", 32, cat) + if !containsCandidate(candidates, "NULL", CandidateKeyword) { + t.Errorf("missing keyword 'NULL'; got %v", candidates) + } + if !containsCandidate(candidates, "NOT", CandidateKeyword) { + t.Errorf("missing keyword 'NOT'; got %v", candidates) + } + if !containsCandidate(candidates, "TRUE", CandidateKeyword) { + t.Errorf("missing keyword 'TRUE'; got %v", candidates) + } + if !containsCandidate(candidates, "FALSE", CandidateKeyword) { + t.Errorf("missing keyword 'FALSE'; got %v", candidates) + } + }) +} diff --git a/tidb/completion/integration_test.go b/tidb/completion/integration_test.go new file mode 100644 index 00000000..e4949d3a --- /dev/null +++ b/tidb/completion/integration_test.go @@ -0,0 +1,353 @@ +package completion + +import ( + "strings" + "testing" + + "github.com/bytebase/omni/tidb/catalog" +) + +// setupIntegrationCatalog creates a realistic test catalog for integration tests. +func setupIntegrationCatalog(t *testing.T) *catalog.Catalog { + t.Helper() + cat := catalog.New() + mustExec(t, cat, "CREATE DATABASE test") + cat.SetCurrentDatabase("test") + mustExec(t, cat, "CREATE TABLE users (id INT, name VARCHAR(100), email VARCHAR(200))") + mustExec(t, cat, "CREATE TABLE orders (id INT, user_id INT, amount DECIMAL(10,2), status VARCHAR(20))") + mustExec(t, cat, "CREATE TABLE products (id INT, name VARCHAR(100), price DECIMAL(10,2))") + mustExec(t, cat, "CREATE VIEW active_users AS SELECT id, name FROM users WHERE id > 0") + mustExec(t, cat, "CREATE INDEX idx_user_id ON orders (user_id)") + return cat +} + +// assertContains checks that at least one candidate has the given text and type. +func assertContains(t *testing.T, candidates []Candidate, text string, typ CandidateType) { + t.Helper() + if !containsCandidate(candidates, text, typ) { + t.Errorf("expected candidate %q (type %d) not found in %d candidates", text, typ, len(candidates)) + } +} + +// assertNotContains checks that no candidate has the given text and type. +func assertNotContains(t *testing.T, candidates []Candidate, text string, typ CandidateType) { + t.Helper() + if containsCandidate(candidates, text, typ) { + t.Errorf("unexpected candidate %q (type %d) found", text, typ) + } +} + +// assertHasType checks that at least one candidate has the given type. +func assertHasType(t *testing.T, candidates []Candidate, typ CandidateType) { + t.Helper() + for _, c := range candidates { + if c.Type == typ { + return + } + } + t.Errorf("expected at least one candidate of type %d, found none in %d candidates", typ, len(candidates)) +} + +// candidatesOfType returns candidates matching the given type. +func candidatesOfType(candidates []Candidate, typ CandidateType) []Candidate { + var result []Candidate + for _, c := range candidates { + if c.Type == typ { + result = append(result, c) + } + } + return result +} + +// --- 9.1 Multi-Table Schema Tests --- + +func TestIntegration_9_1_MultiTableSchema(t *testing.T) { + cat := setupIntegrationCatalog(t) + + t.Run("column_scoped_to_correct_table_in_JOIN", func(t *testing.T) { + // When selecting columns from a joined query without qualifier, + // columns from both tables should be present. + sql := "SELECT FROM users JOIN orders ON users.id = orders.user_id" + cursor := len("SELECT ") + candidates := Complete(sql, cursor, cat) + + // Columns from users. + assertContains(t, candidates, "name", CandidateColumn) + assertContains(t, candidates, "email", CandidateColumn) + + // Columns from orders. + assertContains(t, candidates, "user_id", CandidateColumn) + assertContains(t, candidates, "amount", CandidateColumn) + assertContains(t, candidates, "status", CandidateColumn) + + // Shared column (id) should appear once (dedup). + assertContains(t, candidates, "id", CandidateColumn) + }) + + t.Run("unqualified_columns_from_all_tables", func(t *testing.T) { + // SELECT | FROM users JOIN orders ON users.id = orders.user_id + // Unqualified column completion should include columns from both tables. + sql := "SELECT FROM users JOIN orders ON users.id = orders.user_id" + cursor := len("SELECT ") + candidates := Complete(sql, cursor, cat) + + cols := candidatesOfType(candidates, CandidateColumn) + // Should have columns from users: id, name, email + assertContains(t, candidates, "name", CandidateColumn) + assertContains(t, candidates, "email", CandidateColumn) + // Should have columns from orders: user_id, amount, status + assertContains(t, candidates, "user_id", CandidateColumn) + assertContains(t, candidates, "amount", CandidateColumn) + assertContains(t, candidates, "status", CandidateColumn) + _ = cols + }) + + t.Run("table_alias_completion", func(t *testing.T) { + // SELECT | FROM users AS x + // Alias x refers to users table. Unqualified column completion in + // this context should include users columns (via alias resolution). + sql := "SELECT FROM users AS x" + cursor := len("SELECT ") + candidates := Complete(sql, cursor, cat) + + // The ref extractor sees alias x for users, so users columns available. + assertContains(t, candidates, "id", CandidateColumn) + assertContains(t, candidates, "name", CandidateColumn) + assertContains(t, candidates, "email", CandidateColumn) + }) + + t.Run("view_column_completion", func(t *testing.T) { + // SELECT | FROM active_users + // active_users is a view with columns: id, name + sql := "SELECT FROM active_users" + cursor := len("SELECT ") + candidates := Complete(sql, cursor, cat) + + assertContains(t, candidates, "id", CandidateColumn) + assertContains(t, candidates, "name", CandidateColumn) + }) + + t.Run("cte_column_completion", func(t *testing.T) { + // WITH cte AS (SELECT id, name FROM users) SELECT | FROM cte + sql := "WITH cte AS (SELECT id, name FROM users) SELECT FROM cte" + cursor := len("WITH cte AS (SELECT id, name FROM users) SELECT ") + candidates := Complete(sql, cursor, cat) + + // CTE resolves to the users table, so should have some columns. + // At minimum, columns from the referenced table should be available. + assertHasType(t, candidates, CandidateColumn) + }) + + t.Run("database_qualified_table", func(t *testing.T) { + // SELECT * FROM test.| → tables in database test + sql := "SELECT * FROM test. " + cursor := len("SELECT * FROM test.") + candidates := Complete(sql, cursor, cat) + + // Should have tables from the "test" database. + assertContains(t, candidates, "users", CandidateTable) + assertContains(t, candidates, "orders", CandidateTable) + assertContains(t, candidates, "products", CandidateTable) + }) +} + +// --- 9.2 Edge Cases --- + +func TestIntegration_9_2_EdgeCases(t *testing.T) { + cat := setupIntegrationCatalog(t) + + t.Run("cursor_at_beginning", func(t *testing.T) { + // |SELECT * FROM users → top-level keywords + sql := "SELECT * FROM users" + cursor := 0 + candidates := Complete(sql, cursor, cat) + + // At the beginning, prefix is empty so top-level keywords returned. + if len(candidates) == 0 { + t.Fatal("expected top-level keywords at cursor position 0") + } + assertContains(t, candidates, "SELECT", CandidateKeyword) + }) + + t.Run("cursor_in_middle_of_identifier", func(t *testing.T) { + // SELECT us|ers FROM t → prefix "us" filters candidates + sql := "SELECT users FROM users" + cursor := len("SELECT us") + candidates := Complete(sql, cursor, cat) + + // The prefix "us" should filter candidates. "users" column should match. + for _, c := range candidates { + if !strings.HasPrefix(strings.ToUpper(c.Text), "US") { + t.Errorf("candidate %q does not match prefix 'us'", c.Text) + } + } + }) + + t.Run("cursor_after_semicolon", func(t *testing.T) { + // SELECT 1; SELECT | → new statement context + sql := "SELECT 1; SELECT " + cursor := len(sql) + candidates := Complete(sql, cursor, cat) + + // Should get candidates for SELECT context (columns, keywords, etc.) + if len(candidates) == 0 { + t.Fatal("expected candidates after semicolon in new statement") + } + assertHasType(t, candidates, CandidateKeyword) + }) + + t.Run("empty_sql", func(t *testing.T) { + // | → top-level keywords + candidates := Complete("", 0, cat) + + if len(candidates) == 0 { + t.Fatal("expected top-level keywords for empty SQL") + } + assertContains(t, candidates, "SELECT", CandidateKeyword) + assertContains(t, candidates, "INSERT", CandidateKeyword) + assertContains(t, candidates, "UPDATE", CandidateKeyword) + assertContains(t, candidates, "DELETE", CandidateKeyword) + assertContains(t, candidates, "CREATE", CandidateKeyword) + }) + + t.Run("whitespace_only", func(t *testing.T) { + // " |" → top-level keywords + sql := " " + cursor := len(sql) + candidates := Complete(sql, cursor, cat) + + if len(candidates) == 0 { + t.Fatal("expected top-level keywords for whitespace-only SQL") + } + assertContains(t, candidates, "SELECT", CandidateKeyword) + }) + + t.Run("very_long_sql", func(t *testing.T) { + // Very long SQL with cursor in middle - should not panic or hang. + var b strings.Builder + b.WriteString("SELECT ") + for i := 0; i < 100; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString("id") + } + b.WriteString(" FROM users WHERE ") + cursor := b.Len() + b.WriteString("id > 0") + sql := b.String() + + candidates := Complete(sql, cursor, cat) + // Should return something (at least columns/keywords). + if len(candidates) == 0 { + t.Fatal("expected candidates for cursor in middle of long SQL") + } + }) + + t.Run("syntax_errors_before_cursor", func(t *testing.T) { + // SQL with syntax errors before cursor: completion still works. + sql := "SELCT * FORM users WHERE " + cursor := len(sql) + candidates := Complete(sql, cursor, cat) + + // Even with typos, the system should return some candidates (fallback). + // It's acceptable if it returns top-level keywords or columns. + if len(candidates) == 0 { + t.Log("no candidates for SQL with errors - acceptable fallback behavior") + } + // Main point: should not panic. + }) + + t.Run("backtick_quoted_identifiers", func(t *testing.T) { + // SELECT `| FROM users → should still produce candidates + // Note: backtick handling may be limited, but should not panic. + sql := "SELECT ` FROM users" + cursor := len("SELECT `") + candidates := Complete(sql, cursor, cat) + // Should not panic. Candidates may or may not be returned depending + // on backtick handling, but the system must be robust. + _ = candidates + }) +} + +// --- 9.3 Complex SQL Patterns --- + +func TestIntegration_9_3_ComplexSQLPatterns(t *testing.T) { + cat := setupIntegrationCatalog(t) + + t.Run("nested_subquery_column_completion", func(t *testing.T) { + // SELECT * FROM users WHERE id IN (SELECT | FROM orders) + // The ref extractor finds outer-scope table refs (users). + // Column candidates are returned (from users or fallback to all). + sql := "SELECT * FROM users WHERE id IN (SELECT FROM orders)" + cursor := len("SELECT * FROM users WHERE id IN (SELECT ") + candidates := Complete(sql, cursor, cat) + + // Should return some column candidates. + assertHasType(t, candidates, CandidateColumn) + // Should also have keyword/function candidates for SELECT context. + assertHasType(t, candidates, CandidateKeyword) + }) + + t.Run("correlated_subquery", func(t *testing.T) { + // SELECT *, (SELECT | FROM orders WHERE orders.user_id = users.id) FROM users + sql := "SELECT *, (SELECT FROM orders WHERE orders.user_id = users.id) FROM users" + cursor := len("SELECT *, (SELECT ") + candidates := Complete(sql, cursor, cat) + + // Should return column and keyword/function candidates. + assertHasType(t, candidates, CandidateColumn) + }) + + t.Run("union_select", func(t *testing.T) { + // SELECT name FROM users UNION SELECT | FROM products + // The ref extractor walks both sides of the UNION, finding users and products. + sql := "SELECT name FROM users UNION SELECT FROM products" + cursor := len("SELECT name FROM users UNION SELECT ") + candidates := Complete(sql, cursor, cat) + + // Should return column candidates (from the combined table refs). + assertHasType(t, candidates, CandidateColumn) + // Both users and products columns should be available via UNION ref extraction. + assertContains(t, candidates, "name", CandidateColumn) + }) + + t.Run("multiple_joins", func(t *testing.T) { + // SELECT | FROM users JOIN orders ON users.id = orders.user_id JOIN products ON ... + sql := "SELECT FROM users JOIN orders ON users.id = orders.user_id JOIN products ON orders.id = products.id" + cursor := len("SELECT ") + candidates := Complete(sql, cursor, cat) + + // Should have columns from all 3 tables. + assertContains(t, candidates, "email", CandidateColumn) // users + assertContains(t, candidates, "user_id", CandidateColumn) // orders + assertContains(t, candidates, "amount", CandidateColumn) // orders + assertContains(t, candidates, "price", CandidateColumn) // products + }) + + t.Run("insert_select", func(t *testing.T) { + // INSERT INTO users SELECT | FROM products + // The ref extractor finds 'users' (INSERT target) and 'products' (SELECT FROM). + sql := "INSERT INTO users SELECT FROM products" + cursor := len("INSERT INTO users SELECT ") + candidates := Complete(sql, cursor, cat) + + // Should have column candidates (from users and/or products). + assertHasType(t, candidates, CandidateColumn) + // The INSERT target (users) columns should be available. + assertContains(t, candidates, "name", CandidateColumn) + }) + + t.Run("complex_alter_table", func(t *testing.T) { + // ALTER TABLE users ADD COLUMN | → should get type candidates or column options + // Testing the simpler ALTER TABLE path that works end-to-end. + sql := "ALTER TABLE users ADD INDEX idx (" + cursor := len(sql) + candidates := Complete(sql, cursor, cat) + + // ADD INDEX (|) should produce columnref candidates for users. + assertContains(t, candidates, "id", CandidateColumn) + assertContains(t, candidates, "name", CandidateColumn) + assertContains(t, candidates, "email", CandidateColumn) + }) +} diff --git a/tidb/completion/refs.go b/tidb/completion/refs.go new file mode 100644 index 00000000..d7ddd63a --- /dev/null +++ b/tidb/completion/refs.go @@ -0,0 +1,264 @@ +package completion + +import ( + "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// TableRef is a table reference found in a SQL statement. +type TableRef struct { + Database string // database/schema qualifier + Table string // table name + Alias string // AS alias +} + +// extractTableRefs parses the SQL and returns table references visible at cursor. +func extractTableRefs(sql string, cursorOffset int) (refs []TableRef) { + defer func() { + if r := recover(); r != nil { + refs = extractTableRefsLexer(sql, cursorOffset) + } + }() + return extractTableRefsInner(sql, cursorOffset) +} + +func extractTableRefsInner(sql string, cursorOffset int) []TableRef { + list, err := parser.Parse(sql) + if err != nil || list == nil || len(list.Items) == 0 { + // Fallback: try with placeholder appended at cursor. + if cursorOffset <= len(sql) { + patched := sql[:cursorOffset] + " _x" + if cursorOffset < len(sql) { + patched += sql[cursorOffset:] + } + list, err = parser.Parse(patched) + if err != nil || list == nil { + return extractTableRefsLexer(sql, cursorOffset) + } + } else { + return nil + } + } + + var refs []TableRef + for _, item := range list.Items { + refs = append(refs, extractRefsFromStmt(item)...) + } + if len(refs) == 0 { + return extractTableRefsLexer(sql, cursorOffset) + } + return refs +} + +// extractRefsFromStmt extracts table references from a single statement node. +func extractRefsFromStmt(n ast.Node) []TableRef { + if n == nil { + return nil + } + switch v := n.(type) { + case *ast.SelectStmt: + return extractRefsFromSelect(v) + case *ast.InsertStmt: + return extractRefsFromInsert(v) + case *ast.UpdateStmt: + return extractRefsFromUpdate(v) + case *ast.DeleteStmt: + return extractRefsFromDelete(v) + } + return nil +} + +// extractRefsFromSelect extracts table references from a SELECT statement. +// Does not recurse into subqueries (their tables don't leak to the outer scope). +func extractRefsFromSelect(s *ast.SelectStmt) []TableRef { + if s == nil { + return nil + } + var refs []TableRef + + // Set operations: walk both sides. + if s.SetOp != ast.SetOpNone { + refs = append(refs, extractRefsFromSelect(s.Left)...) + refs = append(refs, extractRefsFromSelect(s.Right)...) + return refs + } + + // CTEs. + for _, cte := range s.CTEs { + if cte != nil && cte.Name != "" { + refs = append(refs, TableRef{Table: cte.Name}) + } + } + + // FROM clause. + for _, te := range s.From { + refs = append(refs, extractRefsFromTableExpr(te)...) + } + return refs +} + +// extractRefsFromTableExpr extracts table references from a TableExpr node. +func extractRefsFromTableExpr(te ast.TableExpr) []TableRef { + if te == nil { + return nil + } + switch v := te.(type) { + case *ast.TableRef: + if v.Name != "" { + return []TableRef{{Database: v.Schema, Table: v.Name, Alias: v.Alias}} + } + case *ast.JoinClause: + var refs []TableRef + refs = append(refs, extractRefsFromTableExpr(v.Left)...) + refs = append(refs, extractRefsFromTableExpr(v.Right)...) + return refs + case *ast.SubqueryExpr: + // Subquery tables don't leak to the outer scope. + return nil + } + return nil +} + +// extractRefsFromInsert extracts the target table from an INSERT statement. +func extractRefsFromInsert(s *ast.InsertStmt) []TableRef { + if s == nil || s.Table == nil { + return nil + } + return []TableRef{{Database: s.Table.Schema, Table: s.Table.Name, Alias: s.Table.Alias}} +} + +// extractRefsFromUpdate extracts table references from an UPDATE statement. +func extractRefsFromUpdate(s *ast.UpdateStmt) []TableRef { + if s == nil { + return nil + } + var refs []TableRef + for _, te := range s.Tables { + refs = append(refs, extractRefsFromTableExpr(te)...) + } + return refs +} + +// extractRefsFromDelete extracts table references from a DELETE statement. +func extractRefsFromDelete(s *ast.DeleteStmt) []TableRef { + if s == nil { + return nil + } + var refs []TableRef + for _, te := range s.Tables { + refs = append(refs, extractRefsFromTableExpr(te)...) + } + for _, te := range s.Using { + refs = append(refs, extractRefsFromTableExpr(te)...) + } + return refs +} + +// extractTableRefsLexer is a fallback using lexer-based pattern matching +// when the SQL doesn't parse (e.g., incomplete SQL being edited). +func extractTableRefsLexer(sql string, cursorOffset int) []TableRef { + lex := parser.NewLexer(sql) + var tokens []parser.Token + for { + tok := lex.NextToken() + if tok.Type == 0 || tok.Loc >= cursorOffset { + break + } + tokens = append(tokens, tok) + } + + var refs []TableRef + + for i := 0; i < len(tokens); i++ { + typ := tokens[i].Type + + // FROM table, JOIN table + if typ == parser.FROM || typ == parser.JOIN { + ref, skip := lexerExtractTableAfter(tokens, i+1) + if ref != nil { + refs = append(refs, *ref) + } + i += skip + continue + } + + // UPDATE table + if typ == parser.UPDATE { + ref, skip := lexerExtractTableAfter(tokens, i+1) + if ref != nil { + refs = append(refs, *ref) + } + i += skip + continue + } + + // INSERT [INTO] table / REPLACE [INTO] table + if typ == parser.INSERT || typ == parser.REPLACE { + j := i + 1 + if j < len(tokens) && tokens[j].Type == parser.INTO { + j++ + } + ref, skip := lexerExtractTableAfter(tokens, j) + if ref != nil { + refs = append(refs, *ref) + } + i = j + skip + continue + } + + // DELETE [FROM] table + if typ == parser.DELETE { + j := i + 1 + if j < len(tokens) && tokens[j].Type == parser.FROM { + j++ + } + ref, skip := lexerExtractTableAfter(tokens, j) + if ref != nil { + refs = append(refs, *ref) + } + i = j + skip + continue + } + } + return refs +} + +// lexerExtractTableAfter extracts a [db.]table [AS alias] reference starting at tokens[j]. +// Returns the TableRef (or nil) and the number of tokens consumed. +func lexerExtractTableAfter(tokens []parser.Token, j int) (*TableRef, int) { + if j >= len(tokens) || !parser.IsIdentTokenType(tokens[j].Type) { + return nil, 0 + } + ref := TableRef{Table: tokens[j].Str} + consumed := j + 1 + // Check for db.table + if j+2 < len(tokens) && tokens[j+1].Type == '.' && parser.IsIdentTokenType(tokens[j+2].Type) { + ref.Database = ref.Table + ref.Table = tokens[j+2].Str + consumed = j + 3 + } + // Check for alias (AS alias or bare ident) + k := consumed + if k < len(tokens) { + if tokens[k].Type == parser.AS && k+1 < len(tokens) && parser.IsIdentTokenType(tokens[k+1].Type) { + ref.Alias = tokens[k+1].Str + } else if parser.IsIdentTokenType(tokens[k].Type) && !isClauseKeyword(tokens[k].Type) { + ref.Alias = tokens[k].Str + } + } + return &ref, consumed - j +} + +// isClauseKeyword returns true for keywords that typically follow a table name +// and should not be treated as aliases. +func isClauseKeyword(typ int) bool { + switch typ { + case parser.SET, parser.WHERE, parser.ON, parser.VALUES, parser.JOIN, + parser.INNER, parser.LEFT, parser.RIGHT, parser.CROSS, parser.NATURAL, + parser.ORDER, parser.GROUP, parser.HAVING, parser.LIMIT, parser.UNION, + parser.FOR, parser.USING, parser.FROM, parser.INTO, + parser.SELECT, parser.INSERT, parser.UPDATE, parser.DELETE: + return true + } + return false +} diff --git a/tidb/completion/refs_test.go b/tidb/completion/refs_test.go new file mode 100644 index 00000000..f66aff6d --- /dev/null +++ b/tidb/completion/refs_test.go @@ -0,0 +1,174 @@ +package completion + +import ( + "testing" + + "github.com/bytebase/omni/tidb/catalog" +) + +func TestExtractTableRefs(t *testing.T) { + tests := []struct { + name string + sql string + offset int + wantTables []string // expected table names + wantAlias map[string]string // table -> alias (optional) + wantDB map[string]string // table -> database (optional) + wantAbsent []string // tables that should NOT appear + }{ + { + name: "simple_select", + sql: "SELECT * FROM t WHERE ", + offset: 22, + wantTables: []string{"t"}, + }, + { + name: "with_alias", + sql: "SELECT * FROM t AS x WHERE ", + offset: 27, + wantTables: []string{"t"}, + wantAlias: map[string]string{"t": "x"}, + }, + { + name: "join", + sql: "SELECT * FROM t1 JOIN t2 ON t1.a = t2.a WHERE ", + offset: 46, + wantTables: []string{"t1", "t2"}, + }, + { + name: "database_qualified", + sql: "SELECT * FROM db.t WHERE ", + offset: 25, + wantTables: []string{"t"}, + wantDB: map[string]string{"t": "db"}, + }, + { + name: "subquery_no_leak", + sql: "SELECT * FROM outer_t WHERE a IN (SELECT b FROM inner_t) AND ", + offset: 61, + wantTables: []string{"outer_t"}, + wantAbsent: []string{"inner_t"}, + }, + { + name: "update", + sql: "UPDATE t SET a = 1 WHERE ", + offset: 25, + wantTables: []string{"t"}, + }, + { + name: "insert_into", + sql: "INSERT INTO t (a, b) VALUES ", + offset: 28, + wantTables: []string{"t"}, + }, + { + name: "delete_from", + sql: "DELETE FROM t WHERE ", + offset: 20, + wantTables: []string{"t"}, + }, + { + name: "lexer_fallback_incomplete", + sql: "SELECT * FROM t WHERE a = ", + offset: 26, + wantTables: []string{"t"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + refs := extractTableRefs(tt.sql, tt.offset) + got := make(map[string]bool) + refMap := make(map[string]TableRef) + for _, r := range refs { + got[r.Table] = true + refMap[r.Table] = r + } + for _, w := range tt.wantTables { + if !got[w] { + t.Errorf("extractTableRefs(%q, %d): missing table %q, got refs=%+v", tt.sql, tt.offset, w, refs) + } + } + for _, a := range tt.wantAbsent { + if got[a] { + t.Errorf("extractTableRefs(%q, %d): should not contain %q, got refs=%+v", tt.sql, tt.offset, a, refs) + } + } + if tt.wantAlias != nil { + for tbl, alias := range tt.wantAlias { + if r, ok := refMap[tbl]; ok { + if r.Alias != alias { + t.Errorf("extractTableRefs(%q, %d): table %q alias want %q got %q", tt.sql, tt.offset, tbl, alias, r.Alias) + } + } + } + } + if tt.wantDB != nil { + for tbl, db := range tt.wantDB { + if r, ok := refMap[tbl]; ok { + if r.Database != db { + t.Errorf("extractTableRefs(%q, %d): table %q database want %q got %q", tt.sql, tt.offset, tbl, db, r.Database) + } + } + } + } + }) + } +} + +// TestResolveColumnRefScoped tests that resolveColumnRefScoped returns +// columns only from tables referenced in the SQL. +func TestResolveColumnRefScoped(t *testing.T) { + cat := catalog.New() + cat.Exec("CREATE DATABASE test", nil) + cat.SetCurrentDatabase("test") + cat.Exec("CREATE TABLE t1 (a INT, b INT)", nil) + cat.Exec("CREATE TABLE t2 (c INT, d INT)", nil) + cat.Exec("CREATE TABLE t3 (e INT, f INT)", nil) + + tests := []struct { + name string + sql string + cursor int + wantCols []string // columns that should appear + wantAbsent []string // columns that should NOT appear + }{ + { + name: "scoped_single_table", + sql: "SELECT * FROM t1 WHERE ", + cursor: 23, + wantCols: []string{"a", "b"}, + wantAbsent: []string{"c", "d", "e", "f"}, + }, + { + name: "scoped_join", + sql: "SELECT * FROM t1 JOIN t2 ON t1.a = t2.c WHERE ", + cursor: 46, + wantCols: []string{"a", "b", "c", "d"}, + wantAbsent: []string{"e", "f"}, + }, + { + name: "scoped_update", + sql: "UPDATE t2 SET ", + cursor: 14, + wantCols: []string{"c", "d"}, + wantAbsent: []string{"a", "b", "e", "f"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + candidates := resolveColumnRefScoped(cat, tt.sql, tt.cursor) + for _, col := range tt.wantCols { + if !containsCandidate(candidates, col, CandidateColumn) { + t.Errorf("resolveColumnRefScoped(%q, %d): missing column %q", tt.sql, tt.cursor, col) + } + } + for _, col := range tt.wantAbsent { + if containsCandidate(candidates, col, CandidateColumn) { + t.Errorf("resolveColumnRefScoped(%q, %d): should not contain column %q", tt.sql, tt.cursor, col) + } + } + }) + } +} diff --git a/tidb/completion/resolve.go b/tidb/completion/resolve.go new file mode 100644 index 00000000..17181ddf --- /dev/null +++ b/tidb/completion/resolve.go @@ -0,0 +1,453 @@ +package completion + +import ( + "github.com/bytebase/omni/tidb/catalog" + "github.com/bytebase/omni/tidb/parser" +) + +// Built-in MySQL function names for func_name / function_ref resolution. +var builtinFunctions = []string{ + // Aggregate + "COUNT", "SUM", "AVG", "MAX", "MIN", "GROUP_CONCAT", "JSON_ARRAYAGG", "JSON_OBJECTAGG", + "BIT_AND", "BIT_OR", "BIT_XOR", "STD", "STDDEV", "STDDEV_POP", "STDDEV_SAMP", + "VAR_POP", "VAR_SAMP", "VARIANCE", + // Window + "ROW_NUMBER", "RANK", "DENSE_RANK", "CUME_DIST", "PERCENT_RANK", + "NTILE", "LAG", "LEAD", "FIRST_VALUE", "LAST_VALUE", "NTH_VALUE", + // String + "CONCAT", "CONCAT_WS", "SUBSTRING", "SUBSTR", "LEFT", "RIGHT", "LENGTH", "CHAR_LENGTH", + "CHARACTER_LENGTH", "UPPER", "LOWER", "LCASE", "UCASE", "TRIM", "LTRIM", "RTRIM", + "REPLACE", "REVERSE", "INSERT", "LPAD", "RPAD", "REPEAT", "SPACE", "FORMAT", + "LOCATE", "INSTR", "POSITION", "FIELD", "FIND_IN_SET", "ELT", "MAKE_SET", + "QUOTE", "SOUNDEX", "HEX", "UNHEX", "ORD", "ASCII", "BIN", "OCT", + // Numeric + "ABS", "CEIL", "CEILING", "FLOOR", "ROUND", "TRUNCATE", "MOD", "POW", "POWER", + "SQRT", "EXP", "LOG", "LOG2", "LOG10", "LN", "SIGN", "PI", "RAND", + "CRC32", "CONV", "RADIANS", "DEGREES", "SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "COT", + // Date/Time + "NOW", "CURDATE", "CURTIME", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", + "SYSDATE", "UTC_DATE", "UTC_TIME", "UTC_TIMESTAMP", "LOCALTIME", "LOCALTIMESTAMP", + "DATE", "TIME", "YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "MICROSECOND", + "DAYNAME", "DAYOFMONTH", "DAYOFWEEK", "DAYOFYEAR", "WEEK", "WEEKDAY", "WEEKOFYEAR", + "QUARTER", "YEARWEEK", "LAST_DAY", "MAKEDATE", "MAKETIME", + "DATE_ADD", "DATE_SUB", "ADDDATE", "SUBDATE", "ADDTIME", "SUBTIME", + "DATEDIFF", "TIMEDIFF", "TIMESTAMPDIFF", "TIMESTAMPADD", + "DATE_FORMAT", "TIME_FORMAT", "STR_TO_DATE", "FROM_UNIXTIME", "UNIX_TIMESTAMP", + "EXTRACT", "GET_FORMAT", "PERIOD_ADD", "PERIOD_DIFF", "SEC_TO_TIME", "TIME_TO_SEC", + "FROM_DAYS", "TO_DAYS", "TO_SECONDS", + // Control flow + "IF", "IFNULL", "NULLIF", "COALESCE", "GREATEST", "LEAST", "INTERVAL", + // Cast + "CAST", "CONVERT", "BINARY", + // JSON + "JSON_ARRAY", "JSON_OBJECT", "JSON_QUOTE", "JSON_EXTRACT", "JSON_UNQUOTE", + "JSON_CONTAINS", "JSON_CONTAINS_PATH", "JSON_KEYS", "JSON_SEARCH", + "JSON_SET", "JSON_INSERT", "JSON_REPLACE", "JSON_REMOVE", "JSON_MERGE_PRESERVE", + "JSON_MERGE_PATCH", "JSON_DEPTH", "JSON_LENGTH", "JSON_TYPE", "JSON_VALID", + "JSON_ARRAYAGG", "JSON_OBJECTAGG", "JSON_PRETTY", "JSON_STORAGE_FREE", "JSON_STORAGE_SIZE", + "JSON_TABLE", "JSON_VALUE", + // Info + "DATABASE", "SCHEMA", "USER", "CURRENT_USER", "SESSION_USER", "SYSTEM_USER", + "VERSION", "CONNECTION_ID", "LAST_INSERT_ID", "ROW_COUNT", "FOUND_ROWS", + "BENCHMARK", "CHARSET", "COLLATION", "COERCIBILITY", + // Encryption + "MD5", "SHA1", "SHA2", "AES_ENCRYPT", "AES_DECRYPT", + "RANDOM_BYTES", + // Misc + "UUID", "UUID_SHORT", "UUID_TO_BIN", "BIN_TO_UUID", + "SLEEP", "VALUES", "DEFAULT", "INET_ATON", "INET_NTOA", "INET6_ATON", "INET6_NTOA", + "IS_IPV4", "IS_IPV6", "IS_UUID", + "ANY_VALUE", "GROUPING", +} + +// Known charsets for "charset" rule. +var knownCharsets = []string{ + "utf8mb4", "utf8mb3", "utf8", "latin1", "ascii", "binary", + "big5", "cp1250", "cp1251", "cp1256", "cp1257", "cp850", "cp852", "cp866", + "dec8", "eucjpms", "euckr", "gb2312", "gb18030", "gbk", "geostd8", + "greek", "hebrew", "hp8", "keybcs2", "koi8r", "koi8u", + "latin2", "latin5", "latin7", "macce", "macroman", + "sjis", "swe7", "tis620", "ucs2", "ujis", "utf16", "utf16le", "utf32", +} + +// Known engines for "engine" rule. +var knownEngines = []string{ + "InnoDB", "MyISAM", "MEMORY", "CSV", "ARCHIVE", + "BLACKHOLE", "MERGE", "FEDERATED", "NDB", "NDBCLUSTER", +} + +// MySQL type keywords for "type_name" rule. +var typeKeywords = []string{ + // Integer types + "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "INTEGER", "BIGINT", + // Fixed-point + "DECIMAL", "NUMERIC", "DEC", "FIXED", + // Floating-point + "FLOAT", "DOUBLE", "REAL", + // Bit + "BIT", "BOOL", "BOOLEAN", + // String types + "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", + "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB", + // Date/time + "DATE", "DATETIME", "TIMESTAMP", "TIME", "YEAR", + // JSON + "JSON", + // Spatial + "GEOMETRY", "POINT", "LINESTRING", "POLYGON", + "MULTIPOINT", "MULTILINESTRING", "MULTIPOLYGON", "GEOMETRYCOLLECTION", + // Enum/Set + "ENUM", "SET", + // Serial + "SERIAL", +} + +// resolveRules converts parser rule candidates into typed Candidate values +// using the catalog for name resolution. +func resolveRules(cs *parser.CandidateSet, cat *catalog.Catalog, sql string, cursorOffset int) []Candidate { + if cs == nil { + return nil + } + var result []Candidate + for _, rc := range cs.Rules { + result = append(result, resolveRule(rc.Rule, cat, sql, cursorOffset)...) + } + return result +} + +// resolveRule resolves a single grammar rule name into completion candidates. +func resolveRule(rule string, cat *catalog.Catalog, sql string, cursorOffset int) []Candidate { + switch rule { + case "table_ref": + return resolveTableRef(cat) + case "columnref": + return resolveColumnRefScoped(cat, sql, cursorOffset) + case "database_ref": + return resolveDatabaseRef(cat) + case "function_ref", "func_name": + return resolveFunctionRef(cat) + case "procedure_ref": + return resolveProcedureRef(cat) + case "index_ref": + return resolveIndexRef(cat) + case "trigger_ref": + return resolveTriggerRef(cat) + case "event_ref": + return resolveEventRef(cat) + case "view_ref": + return resolveViewRef(cat) + case "constraint_ref": + return resolveConstraintRef(cat) + case "charset": + return resolveCharset() + case "engine": + return resolveEngine() + case "type_name": + return resolveTypeName() + case "variable": + return resolveVariable() + } + return nil +} + +// resolveTableRef returns tables and views from the current database. +func resolveTableRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + var result []Candidate + for _, t := range db.Tables { + result = append(result, Candidate{Text: t.Name, Type: CandidateTable}) + } + for _, v := range db.Views { + result = append(result, Candidate{Text: v.Name, Type: CandidateView}) + } + return result +} + +// resolveColumnRefScoped returns columns scoped to the tables referenced in +// the SQL statement. If no table refs are found, falls back to all columns +// in the current database. +func resolveColumnRefScoped(cat *catalog.Catalog, sql string, cursorOffset int) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + + refs := extractTableRefs(sql, cursorOffset) + if len(refs) == 0 { + return resolveColumnRef(cat) + } + + seen := make(map[string]bool) + var result []Candidate + for _, ref := range refs { + // Resolve table in the appropriate database. + targetDB := db + if ref.Database != "" { + targetDB = cat.GetDatabase(ref.Database) + if targetDB == nil { + continue + } + } + // Look up the table. + for _, t := range targetDB.Tables { + if t.Name == ref.Table { + for _, col := range t.Columns { + if !seen[col.Name] { + seen[col.Name] = true + result = append(result, Candidate{Text: col.Name, Type: CandidateColumn}) + } + } + break + } + } + // Also check views. + for _, v := range targetDB.Views { + if v.Name == ref.Table { + for _, colName := range v.Columns { + if !seen[colName] { + seen[colName] = true + result = append(result, Candidate{Text: colName, Type: CandidateColumn}) + } + } + break + } + } + } + // If we found refs but couldn't resolve any columns (e.g. CTE name not + // matching any catalog table), fall back to all columns. + if len(result) == 0 { + return resolveColumnRef(cat) + } + return result +} + +// resolveColumnRef returns columns from all tables in the current database. +func resolveColumnRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + seen := make(map[string]bool) + var result []Candidate + for _, t := range db.Tables { + for _, col := range t.Columns { + if !seen[col.Name] { + seen[col.Name] = true + result = append(result, Candidate{Text: col.Name, Type: CandidateColumn}) + } + } + } + for _, v := range db.Views { + for _, colName := range v.Columns { + if !seen[colName] { + seen[colName] = true + result = append(result, Candidate{Text: colName, Type: CandidateColumn}) + } + } + } + return result +} + +// resolveDatabaseRef returns all databases from the catalog. +func resolveDatabaseRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + var result []Candidate + for _, db := range cat.Databases() { + result = append(result, Candidate{Text: db.Name, Type: CandidateDatabase}) + } + return result +} + +// resolveFunctionRef returns catalog functions + built-in function names. +func resolveFunctionRef(cat *catalog.Catalog) []Candidate { + var result []Candidate + // Built-in functions always available. + for _, name := range builtinFunctions { + result = append(result, Candidate{Text: name, Type: CandidateFunction}) + } + // Catalog functions from current database. + if cat != nil { + if db := currentDB(cat); db != nil { + for _, fn := range db.Functions { + result = append(result, Candidate{Text: fn.Name, Type: CandidateFunction}) + } + } + } + return result +} + +// resolveProcedureRef returns catalog procedures from the current database. +func resolveProcedureRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + var result []Candidate + for _, p := range db.Procedures { + result = append(result, Candidate{Text: p.Name, Type: CandidateProcedure}) + } + return result +} + +// resolveIndexRef returns indexes from all tables in the current database. +// Table-scoped resolution will be refined in later phases. +func resolveIndexRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + seen := make(map[string]bool) + var result []Candidate + for _, t := range db.Tables { + for _, idx := range t.Indexes { + if idx.Name != "" && !seen[idx.Name] { + seen[idx.Name] = true + result = append(result, Candidate{Text: idx.Name, Type: CandidateIndex}) + } + } + } + return result +} + +// resolveTriggerRef returns triggers from the current database. +func resolveTriggerRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + var result []Candidate + for _, tr := range db.Triggers { + result = append(result, Candidate{Text: tr.Name, Type: CandidateTrigger}) + } + return result +} + +// resolveEventRef returns events from the current database. +func resolveEventRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + var result []Candidate + for _, ev := range db.Events { + result = append(result, Candidate{Text: ev.Name, Type: CandidateEvent}) + } + return result +} + +// resolveViewRef returns views from the current database. +func resolveViewRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + var result []Candidate + for _, v := range db.Views { + result = append(result, Candidate{Text: v.Name, Type: CandidateView}) + } + return result +} + +// resolveConstraintRef returns constraint names from all tables in the current database. +func resolveConstraintRef(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + db := currentDB(cat) + if db == nil { + return nil + } + seen := make(map[string]bool) + var result []Candidate + for _, t := range db.Tables { + for _, c := range t.Constraints { + if c.Name != "" && !seen[c.Name] { + seen[c.Name] = true + result = append(result, Candidate{Text: c.Name, Type: CandidateIndex}) + } + } + } + return result +} + +// resolveCharset returns known MySQL charset names. +func resolveCharset() []Candidate { + result := make([]Candidate, len(knownCharsets)) + for i, name := range knownCharsets { + result[i] = Candidate{Text: name, Type: CandidateCharset} + } + return result +} + +// resolveEngine returns known MySQL engine names. +func resolveEngine() []Candidate { + result := make([]Candidate, len(knownEngines)) + for i, name := range knownEngines { + result[i] = Candidate{Text: name, Type: CandidateEngine} + } + return result +} + +// resolveTypeName returns MySQL type keywords. +func resolveTypeName() []Candidate { + result := make([]Candidate, len(typeKeywords)) + for i, name := range typeKeywords { + result[i] = Candidate{Text: name, Type: CandidateType_} + } + return result +} + +// resolveVariable returns common MySQL system variable names. +func resolveVariable() []Candidate { + vars := []string{ + "@@autocommit", "@@character_set_client", "@@character_set_connection", + "@@character_set_results", "@@collation_connection", "@@max_connections", + "@@wait_timeout", "@@interactive_timeout", "@@sql_mode", + "@@time_zone", "@@tx_isolation", "@@innodb_buffer_pool_size", + "@@global.max_connections", "@@session.sql_mode", + "NAMES", "CHARACTER SET", "PASSWORD", + } + result := make([]Candidate, len(vars)) + for i, name := range vars { + result[i] = Candidate{Text: name, Type: CandidateVariable} + } + return result +} + +// currentDB returns the current database from the catalog, or nil. +func currentDB(cat *catalog.Catalog) *catalog.Database { + name := cat.CurrentDatabase() + if name == "" { + return nil + } + return cat.GetDatabase(name) +} diff --git a/tidb/deparse/deparse.go b/tidb/deparse/deparse.go new file mode 100644 index 00000000..9470aa67 --- /dev/null +++ b/tidb/deparse/deparse.go @@ -0,0 +1,1757 @@ +// Package deparse converts MySQL AST nodes back to SQL text, +// matching MySQL 8.0's SHOW CREATE VIEW formatting. +package deparse + +import ( + "fmt" + "math/big" + "strings" + + ast "github.com/bytebase/omni/tidb/ast" +) + +// Deparse converts an expression AST node to its SQL text representation, +// matching MySQL 8.0's canonical formatting (as seen in SHOW CREATE VIEW). +func Deparse(node ast.ExprNode) string { + if node == nil { + return "" + } + return deparseExpr(node) +} + +// DeparseSelect converts a SelectStmt AST node to its SQL text representation, +// matching MySQL 8.0's SHOW CREATE VIEW formatting. +func DeparseSelect(stmt *ast.SelectStmt) string { + if stmt == nil { + return "" + } + return deparseSelectStmt(stmt) +} + +func deparseSelectStmt(stmt *ast.SelectStmt) string { + return deparseSelectStmtCtx(stmt, false) +} + +// deparseSelectStmtNoAlias formats a SELECT without target list aliases. +// Used for subquery contexts (IN, EXISTS) where MySQL omits AS alias. +func deparseSelectStmtNoAlias(stmt *ast.SelectStmt) string { + return deparseSelectStmtCtx(stmt, true) +} + +func deparseSelectStmtCtx(stmt *ast.SelectStmt, suppressAlias bool) string { + // Handle set operations: UNION / UNION ALL / INTERSECT / EXCEPT + if stmt.SetOp != ast.SetOpNone { + return deparseSetOperation(stmt) + } + + var b strings.Builder + + // CTE (WITH clause) — emit before the SELECT keyword + if len(stmt.CTEs) > 0 { + b.WriteString(deparseCTEs(stmt.CTEs)) + b.WriteString(" ") + } + + b.WriteString("select ") + + // DISTINCT + if stmt.DistinctKind == ast.DistinctOn { + b.WriteString("distinct ") + } + + // Target list + for i, target := range stmt.TargetList { + if i > 0 { + b.WriteString(",") + } + if suppressAlias { + b.WriteString(deparseResTargetNoAlias(target)) + } else { + b.WriteString(deparseResTarget(target, i+1)) + } + } + + // FROM clause + if len(stmt.From) > 0 { + b.WriteString(" from ") + if len(stmt.From) == 1 { + b.WriteString(deparseTableExpr(stmt.From[0])) + } else { + // Multiple tables (implicit cross join) → normalized to explicit join with parens + // e.g., FROM t1, t2 → from (`t1` join `t2`) + // For 3+ tables: FROM t1, t2, t3 → from ((`t1` join `t2`) join `t3`) + b.WriteString(deparseImplicitCrossJoin(stmt.From)) + } + } + + // WHERE clause + if stmt.Where != nil { + b.WriteString(" where ") + b.WriteString(deparseExpr(stmt.Where)) + } + + // GROUP BY clause + if len(stmt.GroupBy) > 0 { + b.WriteString(" group by ") + for i, expr := range stmt.GroupBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(expr)) + } + if stmt.WithRollup { + b.WriteString(" with rollup") + } + } + + // HAVING clause + if stmt.Having != nil { + b.WriteString(" having ") + b.WriteString(deparseExpr(stmt.Having)) + } + + // WINDOW clause: WINDOW `w` AS (window_spec) [, ...] + // MySQL 8.0 appends a trailing space after the window clause. + if len(stmt.WindowClause) > 0 { + b.WriteString(" window ") + for i, wd := range stmt.WindowClause { + if i > 0 { + b.WriteString(",") + } + b.WriteString("`") + b.WriteString(wd.Name) + b.WriteString("` AS ") + b.WriteString(deparseWindowBody(wd)) + } + b.WriteString(" ") + } + + // ORDER BY clause + if len(stmt.OrderBy) > 0 { + b.WriteString(" order by ") + for i, item := range stmt.OrderBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(item.Expr)) + if item.Desc { + b.WriteString(" desc") + } + } + } + + // LIMIT clause + if stmt.Limit != nil { + b.WriteString(" limit ") + if stmt.Limit.Offset != nil { + // MySQL comma syntax: LIMIT offset,count + b.WriteString(deparseExpr(stmt.Limit.Offset)) + b.WriteString(",") + } + b.WriteString(deparseExpr(stmt.Limit.Count)) + } + + // FOR UPDATE / FOR SHARE / LOCK IN SHARE MODE + if stmt.ForUpdate != nil { + b.WriteString(" ") + b.WriteString(deparseForUpdate(stmt.ForUpdate)) + } + + return b.String() +} + +// deparseForUpdate formats a FOR UPDATE / FOR SHARE / LOCK IN SHARE MODE clause. +// MySQL 8.0 format: +// - for update +// - for share +// - lock in share mode (legacy syntax) +// - for update of `t` +// - for update nowait +// - for update skip locked +func deparseForUpdate(fu *ast.ForUpdate) string { + if fu.LockInShareMode { + return "lock in share mode" + } + + var b strings.Builder + if fu.Share { + b.WriteString("for share") + } else { + b.WriteString("for update") + } + + // OF table list + if len(fu.Tables) > 0 { + b.WriteString(" of ") + for i, tbl := range fu.Tables { + if i > 0 { + b.WriteString(",") + } + b.WriteString("`") + b.WriteString(tbl.Name) + b.WriteString("`") + } + } + + // NOWAIT / SKIP LOCKED + if fu.NoWait { + b.WriteString(" nowait") + } else if fu.SkipLocked { + b.WriteString(" skip locked") + } + + return b.String() +} + +// deparseSetOperation formats a set operation (UNION, INTERSECT, EXCEPT). +// MySQL 8.0 format: select ... union [all] select ... (flat, no parens around sub-selects) +// CTEs from the leftmost child are hoisted and emitted before the entire set operation. +func deparseSetOperation(stmt *ast.SelectStmt) string { + // Hoist CTEs from the leftmost descendant + var ctePrefix string + if ctes := extractCTEs(stmt); len(ctes) > 0 { + ctePrefix = deparseCTEs(ctes) + " " + } + + left := deparseSelectStmt(stmt.Left) + right := deparseSelectStmt(stmt.Right) + + var op string + switch stmt.SetOp { + case ast.SetOpUnion: + if stmt.SetAll { + op = "union all" + } else { + op = "union" + } + case ast.SetOpIntersect: + if stmt.SetAll { + op = "intersect all" + } else { + op = "intersect" + } + case ast.SetOpExcept: + if stmt.SetAll { + op = "except all" + } else { + op = "except" + } + } + + var b strings.Builder + b.WriteString(ctePrefix) + b.WriteString(left) + b.WriteString(" ") + b.WriteString(op) + b.WriteString(" ") + b.WriteString(right) + + // ORDER BY clause (applies to the entire set operation) + if len(stmt.OrderBy) > 0 { + b.WriteString(" order by ") + for i, item := range stmt.OrderBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(item.Expr)) + if item.Desc { + b.WriteString(" desc") + } + } + } + + // LIMIT clause (applies to the entire set operation) + if stmt.Limit != nil { + b.WriteString(" limit ") + if stmt.Limit.Offset != nil { + b.WriteString(deparseExpr(stmt.Limit.Offset)) + b.WriteString(",") + } + b.WriteString(deparseExpr(stmt.Limit.Count)) + } + + return b.String() +} + +// extractCTEs walks down the left spine of a set operation tree and extracts +// CTEs from the leftmost leaf SelectStmt, clearing them so they aren't emitted +// again by deparseSelectStmt. +func extractCTEs(stmt *ast.SelectStmt) []*ast.CommonTableExpr { + // Walk to the leftmost leaf + cur := stmt + for cur.SetOp != ast.SetOpNone && cur.Left != nil { + cur = cur.Left + } + if len(cur.CTEs) > 0 { + ctes := cur.CTEs + cur.CTEs = nil // prevent double emission + return ctes + } + return nil +} + +// deparseCTEs formats a WITH clause (one or more CTEs). +// MySQL 8.0 format: with [recursive] `name` [(`col`, ...)] as (select ...) [, ...] +func deparseCTEs(ctes []*ast.CommonTableExpr) string { + var b strings.Builder + b.WriteString("with ") + + // Check if any CTE is recursive (the flag is per-CTE in the AST + // but WITH RECURSIVE applies to the whole clause in SQL) + recursive := false + for _, cte := range ctes { + if cte.Recursive { + recursive = true + break + } + } + if recursive { + b.WriteString("recursive ") + } + + for i, cte := range ctes { + if i > 0 { + b.WriteString(", ") + } + b.WriteString("`") + b.WriteString(cte.Name) + b.WriteString("`") + + // Column list + if len(cte.Columns) > 0 { + b.WriteString(" (") + for j, col := range cte.Columns { + if j > 0 { + b.WriteString(",") + } + b.WriteString("`") + b.WriteString(col) + b.WriteString("`") + } + b.WriteString(")") + } + + b.WriteString(" as (") + if cte.Select != nil { + b.WriteString(deparseSelectStmt(cte.Select)) + } + b.WriteString(")") + } + + return b.String() +} + +// deparseResTargetNoAlias formats a target list entry without the AS alias. +// Used for subquery contexts (IN, EXISTS) where MySQL omits aliases. +func deparseResTargetNoAlias(node ast.ExprNode) string { + if rt, ok := node.(*ast.ResTarget); ok { + return deparseExpr(rt.Val) + } + return deparseExpr(node) +} + +// deparseResTarget formats a single result target in the SELECT list. +// MySQL 8.0 SHOW CREATE VIEW format: expr AS `alias` +// - Always uses AS keyword +// - Alias is always backtick-quoted +// - Auto-alias: column ref → column name; literal → literal text; expression → expression text +func deparseResTarget(node ast.ExprNode, position int) string { + rt, isRT := node.(*ast.ResTarget) + + var expr ast.ExprNode + var explicitAlias string + if isRT { + expr = rt.Val + explicitAlias = rt.Name + } else { + expr = node + } + + exprStr := deparseExpr(expr) + + // Determine alias + alias := explicitAlias + if alias == "" { + alias = autoAlias(expr, exprStr, position) + } + + // MySQL 8.0 uses double-space before AS for window function expressions. + // The OVER clause already ends with " )", and MySQL adds an extra space. + if hasWindowFunction(expr) { + return exprStr + " AS `" + alias + "`" + } + return exprStr + " AS `" + alias + "`" +} + +// hasWindowFunction checks if an expression is or contains a window function +// with an inline OVER (...) clause (not a named window reference like OVER w). +// MySQL 8.0 uses double-space before AS only for inline window definitions. +func hasWindowFunction(node ast.ExprNode) bool { + if fc, ok := node.(*ast.FuncCallExpr); ok && fc.Over != nil { + // Named window reference (OVER w) — no double space + if fc.Over.RefName != "" && len(fc.Over.PartitionBy) == 0 && len(fc.Over.OrderBy) == 0 && fc.Over.Frame == nil { + return false + } + return true + } + return false +} + +// autoAlias generates an automatic alias for a SELECT target expression. +// MySQL 8.0 rules: +// - Column ref → column name (unqualified) +// - Literal → literal text representation +// - Short expression → expression text (without backtick quoting) +// - Long/complex expression → Name_exp_N +func autoAlias(expr ast.ExprNode, exprStr string, position int) string { + switch n := expr.(type) { + case *ast.ColumnRef: + return n.Column + case *ast.IntLit: + return fmt.Sprintf("%d", n.Value) + case *ast.FloatLit: + return n.Value + case *ast.StringLit: + if n.Value == "" { + return fmt.Sprintf("Name_exp_%d", position) + } + return n.Value + case *ast.NullLit: + return "NULL" + case *ast.BoolLit: + if n.Value { + return "TRUE" + } + return "FALSE" + case *ast.HexLit: + // MySQL 8.0 preserves original literal form in auto-alias. + // 0xFF stays as "0xFF"; X'FF' form stored as "FF" → "X'FF'" + val := n.Value + if strings.HasPrefix(val, "0x") || strings.HasPrefix(val, "0X") { + return val // preserve original case: 0xFF + } + return "X'" + val + "'" + case *ast.BitLit: + // MySQL 8.0 preserves original literal form in auto-alias. + // 0b1010 stays as "0b1010"; b'1010' form stored as "1010" → "b'1010'" + val := n.Value + if strings.HasPrefix(val, "0b") || strings.HasPrefix(val, "0B") { + return val + } + return "b'" + val + "'" + case *ast.TemporalLit: + return n.Type + " '" + n.Value + "'" + default: + // For expressions: generate a human-readable alias text without backtick quoting. + // MySQL 8.0 uses the original expression text for the alias. + aliasText := deparseExprAlias(expr) + if len(aliasText) > 64 { + return fmt.Sprintf("Name_exp_%d", position) + } + return aliasText + } +} + +// deparseExprAlias generates a human-readable expression text for use as an auto-alias. +// Unlike deparseExpr, this preserves the original expression text style matching MySQL 8.0's +// auto-alias behavior: function names stay uppercase, spaces after commas, CASE/CAST +// keywords uppercase, no charset addition in CAST, COUNT(*) keeps *. +func deparseExprAlias(node ast.ExprNode) string { + switch n := node.(type) { + case *ast.ColumnRef: + if n.Table != "" { + return n.Table + "." + n.Column + } + return n.Column + case *ast.IntLit: + return fmt.Sprintf("%d", n.Value) + case *ast.FloatLit: + return n.Value + case *ast.StringLit: + // When embedded in an expression, include quotes to match MySQL 8.0's behavior. + // The top-level autoAlias handles standalone StringLit without quotes. + // Escape backslashes and single quotes like deparseStringLit does. + escaped := strings.ReplaceAll(n.Value, `\`, `\\`) + escaped = strings.ReplaceAll(escaped, `'`, `\'`) + return "'" + escaped + "'" + case *ast.NullLit: + return "NULL" + case *ast.BoolLit: + if n.Value { + return "TRUE" + } + return "FALSE" + case *ast.HexLit: + val := n.Value + if strings.HasPrefix(val, "0x") || strings.HasPrefix(val, "0X") { + return val + } + return "X'" + val + "'" + case *ast.BitLit: + val := n.Value + if strings.HasPrefix(val, "0b") || strings.HasPrefix(val, "0B") { + return val + } + return "b'" + val + "'" + case *ast.TemporalLit: + return n.Type + " '" + n.Value + "'" + case *ast.BinaryExpr: + left := deparseExprAlias(n.Left) + right := deparseExprAlias(n.Right) + // Special operator aliases for REGEXP, ->, ->> + switch n.Op { + case ast.BinOpRegexp: + return left + " REGEXP " + right + case ast.BinOpJsonExtract: + return left + "->" + right + case ast.BinOpJsonUnquote: + return left + "->>" + right + case ast.BinOpSoundsLike: + return left + " SOUNDS LIKE " + right + } + op := binaryOpToStringAlias(n.Op) + // Use original operator text for alias when available (e.g., "MOD" instead of "%", "!=" instead of "<>") + if n.OriginalOp != "" { + op = n.OriginalOp + } + return left + " " + op + " " + right + case *ast.UnaryExpr: + operand := deparseExprAlias(n.Operand) + switch n.Op { + case ast.UnaryMinus: + return "-" + operand + case ast.UnaryPlus: + return operand + case ast.UnaryNot: + // NOT REGEXP → "a NOT REGEXP 'pattern'" (NOT between left and REGEXP) + if binExpr, ok := unwrapParen(n.Operand).(*ast.BinaryExpr); ok && binExpr.Op == ast.BinOpRegexp { + return deparseExprAlias(binExpr.Left) + " NOT REGEXP " + deparseExprAlias(binExpr.Right) + } + // MySQL 8.0 preserves ! vs NOT in auto-alias. + if n.OriginalOp == "!" { + return "!" + operand + } + return "NOT " + operand + case ast.UnaryBitNot: + return "~" + operand + } + return operand + case *ast.FuncCallExpr: + // MySQL 8.0 auto-alias preserves original function name case (uppercase from parser). + name := n.Name + upperName := strings.ToUpper(name) + + // Handle TRIM directional forms: TRIM_LEADING → TRIM(LEADING 'x' FROM a) + switch upperName { + case "TRIM_LEADING": + if len(n.Args) == 2 { + return "TRIM(LEADING " + deparseExprAlias(n.Args[0]) + " FROM " + deparseExprAlias(n.Args[1]) + ")" + } + case "TRIM_TRAILING": + if len(n.Args) == 2 { + return "TRIM(TRAILING " + deparseExprAlias(n.Args[0]) + " FROM " + deparseExprAlias(n.Args[1]) + ")" + } + case "TRIM_BOTH": + if len(n.Args) == 2 { + return "TRIM(BOTH " + deparseExprAlias(n.Args[0]) + " FROM " + deparseExprAlias(n.Args[1]) + ")" + } + } + + // Handle GROUP_CONCAT: alias includes ORDER BY and SEPARATOR + if upperName == "GROUP_CONCAT" { + return deparseGroupConcatAlias(n) + } + + if n.Star { + // COUNT(*) → alias "COUNT(*)" — keep *, not 0. + result := name + "(*)" + if n.Over != nil { + result += " " + deparseWindowDefAlias(n.Over) + } + return result + } + // Zero-arg keyword functions without explicit parens: alias is just the keyword name. + // e.g., CURRENT_TIMESTAMP → alias "CURRENT_TIMESTAMP" (no parens). + // With parens: CURRENT_TIMESTAMP() → alias "CURRENT_TIMESTAMP()". + if len(n.Args) == 0 && !n.HasParens { + return name + } + args := make([]string, len(n.Args)) + for i, arg := range n.Args { + args[i] = deparseExprAlias(arg) + } + var result string + if n.Distinct { + result = name + "(DISTINCT " + strings.Join(args, ", ") + ")" + } else { + result = name + "(" + strings.Join(args, ", ") + ")" + } + if n.Over != nil { + result += " " + deparseWindowDefAlias(n.Over) + } + return result + case *ast.ParenExpr: + return "(" + deparseExprAlias(n.Expr) + ")" + case *ast.CastExpr: + // MySQL 8.0 auto-alias: "CAST(a AS CHAR)" — uppercase keywords, no charset. + return "CAST(" + deparseExprAlias(n.Expr) + " AS " + deparseDataTypeAlias(n.TypeName) + ")" + case *ast.ConvertExpr: + if n.Charset != "" { + return "CONVERT(" + deparseExprAlias(n.Expr) + " USING " + strings.ToLower(n.Charset) + ")" + } + // MySQL 8.0 auto-alias preserves "CONVERT(a, CHAR)" form (comma-separated). + return "CONVERT(" + deparseExprAlias(n.Expr) + ", " + deparseDataTypeAlias(n.TypeName) + ")" + case *ast.CaseExpr: + // MySQL 8.0 auto-alias: "CASE WHEN a > 0 THEN 'pos' ELSE 'neg' END" — uppercase keywords. + var b strings.Builder + b.WriteString("CASE") + if n.Operand != nil { + b.WriteString(" ") + b.WriteString(deparseExprAlias(n.Operand)) + } + for _, w := range n.Whens { + b.WriteString(" WHEN ") + b.WriteString(deparseExprAlias(w.Cond)) + b.WriteString(" THEN ") + b.WriteString(deparseExprAlias(w.Result)) + } + if n.Default != nil { + b.WriteString(" ELSE ") + b.WriteString(deparseExprAlias(n.Default)) + } + b.WriteString(" END") + return b.String() + case *ast.IsExpr: + expr := deparseExprAlias(n.Expr) + switch n.Test { + case ast.IsNull: + if n.Not { + return expr + " IS NOT NULL" + } + return expr + " IS NULL" + case ast.IsTrue: + if n.Not { + return expr + " IS NOT TRUE" + } + return expr + " IS TRUE" + case ast.IsFalse: + if n.Not { + return expr + " IS NOT FALSE" + } + return expr + " IS FALSE" + case ast.IsUnknown: + if n.Not { + return expr + " IS NOT UNKNOWN" + } + return expr + " IS UNKNOWN" + } + return deparseExpr(node) + case *ast.InExpr: + expr := deparseExprAlias(n.Expr) + keyword := "IN" + if n.Not { + keyword = "NOT IN" + } + if n.Select != nil { + return expr + " " + keyword + " (...)" + } + items := make([]string, len(n.List)) + for i, item := range n.List { + items[i] = deparseExprAlias(item) + } + return expr + " " + keyword + " (" + strings.Join(items, ", ") + ")" + case *ast.BetweenExpr: + expr := deparseExprAlias(n.Expr) + low := deparseExprAlias(n.Low) + high := deparseExprAlias(n.High) + keyword := "BETWEEN" + if n.Not { + keyword = "NOT BETWEEN" + } + return expr + " " + keyword + " " + low + " AND " + high + case *ast.LikeExpr: + expr := deparseExprAlias(n.Expr) + pattern := deparseExprAlias(n.Pattern) + keyword := "LIKE" + if n.Not { + keyword = "NOT LIKE" + } + result := expr + " " + keyword + " " + pattern + if n.Escape != nil { + result += " ESCAPE " + deparseExprAlias(n.Escape) + } + return result + case *ast.ExistsExpr: + // MySQL 8.0 auto-alias for EXISTS: "EXISTS(SELECT 1 FROM t WHERE a > 0)" — uppercase keywords, + // unqualified column names, no backtick quoting, no column aliases. + if n.Select != nil { + return "EXISTS(" + deparseSelectStmtAlias(n.Select) + ")" + } + return "EXISTS(/* subquery */)" + case *ast.SubqueryExpr: + // MySQL 8.0 auto-alias for subquery: "(SELECT MAX(a) FROM t)" — uppercase keywords, + // unqualified column names, no backtick quoting, no column aliases. + if n.Select != nil { + return "(" + deparseSelectStmtAlias(n.Select) + ")" + } + return "(/* subquery */)" + case *ast.CollateExpr: + // MySQL 8.0 auto-alias: "a COLLATE utf8mb4_unicode_ci" — uppercase COLLATE, no parens. + return deparseExprAlias(n.Expr) + " COLLATE " + n.Collation + case *ast.IntervalExpr: + // MySQL 8.0 auto-alias: "INTERVAL 1 DAY" — uppercase keywords. + return "INTERVAL " + deparseExprAlias(n.Value) + " " + strings.ToUpper(n.Unit) + default: + // Fallback: use the regular deparsed text + return deparseExpr(node) + } +} + +// deparseSelectStmtAlias generates a human-readable SELECT statement for auto-alias purposes. +// MySQL 8.0 uses uppercase keywords, unqualified column names, and no column aliases. +func deparseSelectStmtAlias(stmt *ast.SelectStmt) string { + var b strings.Builder + b.WriteString("SELECT ") + + // Target list (no aliases) + for i, target := range stmt.TargetList { + if i > 0 { + b.WriteString(", ") + } + if rt, ok := target.(*ast.ResTarget); ok { + b.WriteString(deparseExprAlias(rt.Val)) + } else { + b.WriteString(deparseExprAlias(target)) + } + } + + // FROM clause + if len(stmt.From) > 0 { + b.WriteString(" FROM ") + for i, tbl := range stmt.From { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(deparseTableExprAlias(tbl)) + } + } + + // WHERE clause + if stmt.Where != nil { + b.WriteString(" WHERE ") + b.WriteString(deparseExprAlias(stmt.Where)) + } + + return b.String() +} + +// deparseTableExprAlias formats a table expression for auto-alias purposes. +func deparseTableExprAlias(tbl ast.TableExpr) string { + switch t := tbl.(type) { + case *ast.TableRef: + if t.Alias != "" { + return t.Name + " " + t.Alias + } + return t.Name + default: + return deparseExprAlias(tbl.(ast.ExprNode)) + } +} + +// deparseDataTypeAlias formats a data type for auto-alias purposes. +// Unlike deparseDataType, this does NOT add charset (matching MySQL 8.0's alias behavior). +func deparseDataTypeAlias(dt *ast.DataType) string { + if dt == nil { + return "" + } + name := strings.ToUpper(dt.Name) + switch strings.ToLower(dt.Name) { + case "char": + if dt.Length > 0 { + return fmt.Sprintf("%s(%d)", name, dt.Length) + } + return name + case "binary": + if dt.Length > 0 { + return fmt.Sprintf("%s(%d)", name, dt.Length) + } + return name + case "signed", "signed integer": + return "SIGNED" + case "unsigned", "unsigned integer": + return "UNSIGNED" + case "decimal": + if dt.Scale > 0 { + return fmt.Sprintf("DECIMAL(%d,%d)", dt.Length, dt.Scale) + } + if dt.Length > 0 { + return fmt.Sprintf("DECIMAL(%d)", dt.Length) + } + return "DECIMAL" + default: + return name + } +} + +// deparseImplicitCrossJoin normalizes multiple FROM tables (implicit cross join) +// into explicit join syntax with parentheses. +// e.g., FROM t1, t2, t3 → ((`t1` join `t2`) join `t3`) +func deparseImplicitCrossJoin(tables []ast.TableExpr) string { + if len(tables) == 0 { + return "" + } + result := deparseTableExpr(tables[0]) + for i := 1; i < len(tables); i++ { + result = "(" + result + " join " + deparseTableExpr(tables[i]) + ")" + } + return result +} + +// deparseTableExpr formats a table expression in the FROM clause. +func deparseTableExpr(tbl ast.TableExpr) string { + switch t := tbl.(type) { + case *ast.TableRef: + return deparseTableRef(t) + case *ast.JoinClause: + return deparseJoinClause(t) + case *ast.SubqueryExpr: + return deparseSubqueryTableExpr(t) + default: + return fmt.Sprintf("/* unsupported table expr: %T */", tbl) + } +} + +// deparseJoinClause formats a JOIN clause. +// MySQL 8.0 format: (`t1` join `t2` on((...))) +func deparseJoinClause(j *ast.JoinClause) string { + left := deparseTableExpr(j.Left) + right := deparseTableExpr(j.Right) + + var joinType string + switch j.Type { + case ast.JoinInner: + joinType = "join" + case ast.JoinLeft: + joinType = "left join" + case ast.JoinRight: + // RIGHT JOIN → LEFT JOIN with table swap + joinType = "left join" + left, right = right, left + case ast.JoinCross: + // CROSS JOIN → plain join + joinType = "join" + case ast.JoinStraight: + joinType = "straight_join" + case ast.JoinNatural: + joinType = "join" + case ast.JoinNaturalLeft: + joinType = "left join" + case ast.JoinNaturalRight: + joinType = "left join" + left, right = right, left + default: + joinType = "join" + } + + var b strings.Builder + b.WriteString("(") + b.WriteString(left) + b.WriteString(" ") + b.WriteString(joinType) + b.WriteString(" ") + b.WriteString(right) + + // ON condition + if j.Condition != nil { + switch cond := j.Condition.(type) { + case *ast.OnCondition: + b.WriteString(" on(") + b.WriteString(deparseExpr(cond.Expr)) + b.WriteString(")") + case *ast.UsingCondition: + // USING (col1, col2) → on((`left`.`col1` = `right`.`col1`) and (...)) + // Requires resolving table names from Left/Right table expressions. + // For RIGHT JOIN, left/right are already swapped above. + leftName := tableExprName(j.Left) + rightName := tableExprName(j.Right) + if j.Type == ast.JoinRight { + // Tables were swapped above, so swap names to match original SQL + leftName, rightName = rightName, leftName + } + b.WriteString(" on(") + b.WriteString(deparseUsingAsOn(cond.Columns, leftName, rightName)) + b.WriteString(")") + } + } + + b.WriteString(")") + return b.String() +} + +// tableExprName extracts the effective name (alias or table name) from a table expression. +// Used for USING → ON expansion to qualify column references. +func tableExprName(tbl ast.TableExpr) string { + switch t := tbl.(type) { + case *ast.TableRef: + if t.Alias != "" { + return t.Alias + } + return t.Name + default: + return "" + } +} + +// deparseUsingAsOn expands USING columns into ON condition format. +// e.g., USING (a, b) with left=t1, right=t2 → (`t1`.`a` = `t2`.`a`) and (`t1`.`b` = `t2`.`b`) +// MySQL 8.0 format: on((`t1`.`a` = `t2`.`a`)) +func deparseUsingAsOn(columns []string, leftName, rightName string) string { + if len(columns) == 0 { + return "" + } + parts := make([]string, len(columns)) + for i, col := range columns { + parts[i] = "(`" + leftName + "`.`" + col + "` = `" + rightName + "`.`" + col + "`)" + } + if len(parts) == 1 { + return parts[0] + } + // Multiple columns: chain with "and" + result := parts[0] + for i := 1; i < len(parts); i++ { + result = "(" + result + " and " + parts[i] + ")" + } + return result +} + +// deparseSubqueryTableExpr formats a derived table (subquery as table expression). +// MySQL 8.0 format: (select ...) `alias` — no AS keyword for alias +func deparseSubqueryTableExpr(s *ast.SubqueryExpr) string { + var b strings.Builder + b.WriteString("(") + if s.Select != nil { + b.WriteString(deparseSelectStmt(s.Select)) + } + b.WriteString(")") + if s.Alias != "" { + b.WriteString(" `") + b.WriteString(s.Alias) + b.WriteString("`") + } + return b.String() +} + +// deparseTableRef formats a simple table reference. +func deparseTableRef(t *ast.TableRef) string { + var b strings.Builder + if t.Schema != "" { + b.WriteString("`") + b.WriteString(t.Schema) + b.WriteString("`.") + } + b.WriteString("`") + b.WriteString(t.Name) + b.WriteString("`") + if t.Alias != "" { + b.WriteString(" `") + b.WriteString(t.Alias) + b.WriteString("`") + } + return b.String() +} + +func deparseExpr(node ast.ExprNode) string { + switch n := node.(type) { + case *ast.IntLit: + return fmt.Sprintf("%d", n.Value) + case *ast.FloatLit: + return n.Value + case *ast.BoolLit: + if n.Value { + return "true" + } + return "false" + case *ast.StringLit: + return deparseStringLit(n) + case *ast.NullLit: + return "NULL" + case *ast.HexLit: + return deparseHexLit(n) + case *ast.BitLit: + return deparseBitLit(n) + case *ast.TemporalLit: + return n.Type + "'" + n.Value + "'" + case *ast.BinaryExpr: + return deparseBinaryExpr(n) + case *ast.ColumnRef: + return deparseColumnRef(n) + case *ast.UnaryExpr: + return deparseUnaryExpr(n) + case *ast.ParenExpr: + return deparseExpr(n.Expr) + case *ast.InExpr: + return deparseInExpr(n) + case *ast.BetweenExpr: + return deparseBetweenExpr(n) + case *ast.LikeExpr: + return deparseLikeExpr(n) + case *ast.IsExpr: + return deparseIsExpr(n) + case *ast.RowExpr: + return deparseRowExpr(n) + case *ast.CaseExpr: + return deparseCaseExpr(n) + case *ast.CastExpr: + return deparseCastExpr(n) + case *ast.ConvertExpr: + return deparseConvertExpr(n) + case *ast.IntervalExpr: + return deparseIntervalExpr(n) + case *ast.CollateExpr: + return deparseCollateExpr(n) + case *ast.FuncCallExpr: + return deparseFuncCallExpr(n) + case *ast.ExistsExpr: + return deparseExistsExpr(n) + case *ast.SubqueryExpr: + return deparseSubqueryExpr(n) + default: + return fmt.Sprintf("/* unsupported: %T */", node) + } +} + +func deparseStringLit(n *ast.StringLit) string { + // MySQL 8.0 uses backslash escaping for single quotes: '' → \' + // and preserves backslashes as-is. + escaped := strings.ReplaceAll(n.Value, `\`, `\\`) + escaped = strings.ReplaceAll(escaped, `'`, `\'`) + if n.Charset != "" { + return n.Charset + "'" + escaped + "'" + } + return "'" + escaped + "'" +} + +func deparseHexLit(n *ast.HexLit) string { + // MySQL 8.0 normalizes all hex literals to 0x lowercase form. + // HexLit.Value is either "0xFF" (0x prefix form) or "FF" (X'' form). + val := n.Value + if strings.HasPrefix(val, "0x") || strings.HasPrefix(val, "0X") { + // Already has 0x prefix — just lowercase + return "0x" + strings.ToLower(val[2:]) + } + // X'FF' form — value is just the hex digits + return "0x" + strings.ToLower(val) +} + +func deparseBitLit(n *ast.BitLit) string { + // MySQL 8.0 converts all bit literals to hex form. + // BitLit.Value is either "0b1010" (0b prefix form) or "1010" (b'' form). + val := n.Value + if strings.HasPrefix(val, "0b") || strings.HasPrefix(val, "0B") { + val = val[2:] + } + // Parse binary string to integer, then format as hex + i := new(big.Int) + i.SetString(val, 2) + return "0x" + fmt.Sprintf("%02x", i) +} + +func deparseBinaryExpr(n *ast.BinaryExpr) string { + // Operator-to-function rewrites: + // REGEXP → regexp_like(left, right) + // -> → json_extract(left, right) + // ->> → json_unquote(json_extract(left, right)) + switch n.Op { + case ast.BinOpRegexp: + return "regexp_like(" + deparseExpr(n.Left) + "," + deparseExpr(n.Right) + ")" + case ast.BinOpJsonExtract: + return "json_extract(" + deparseExpr(n.Left) + "," + deparseExpr(n.Right) + ")" + case ast.BinOpJsonUnquote: + return "json_unquote(json_extract(" + deparseExpr(n.Left) + "," + deparseExpr(n.Right) + "))" + } + + left := n.Left + right := n.Right + // MySQL normalizes INTERVAL + expr to expr + INTERVAL (interval on the right) + if _, ok := left.(*ast.IntervalExpr); ok { + if _, ok2 := right.(*ast.IntervalExpr); !ok2 { + left, right = right, left + } + } + leftStr := deparseExpr(left) + rightStr := deparseExpr(right) + op := binaryOpToString(n.Op) + return "(" + leftStr + " " + op + " " + rightStr + ")" +} + +func deparseColumnRef(n *ast.ColumnRef) string { + if n.Schema != "" { + return "`" + n.Schema + "`.`" + n.Table + "`.`" + n.Column + "`" + } + if n.Table != "" { + return "`" + n.Table + "`.`" + n.Column + "`" + } + return "`" + n.Column + "`" +} + +func binaryOpToString(op ast.BinaryOp) string { + switch op { + case ast.BinOpAdd: + return "+" + case ast.BinOpSub: + return "-" + case ast.BinOpMul: + return "*" + case ast.BinOpDiv: + return "/" + case ast.BinOpMod: + return "%" + case ast.BinOpDivInt: + return "DIV" + case ast.BinOpEq: + return "=" + case ast.BinOpNe: + return "<>" + case ast.BinOpLt: + return "<" + case ast.BinOpGt: + return ">" + case ast.BinOpLe: + return "<=" + case ast.BinOpGe: + return ">=" + case ast.BinOpNullSafeEq: + return "<=>" + case ast.BinOpAnd: + return "and" + case ast.BinOpOr: + return "or" + case ast.BinOpXor: + return "xor" + case ast.BinOpBitAnd: + return "&" + case ast.BinOpBitOr: + return "|" + case ast.BinOpBitXor: + return "^" + case ast.BinOpShiftLeft: + return "<<" + case ast.BinOpShiftRight: + return ">>" + case ast.BinOpSoundsLike: + return "sounds like" + default: + return "?" + } +} + +// binaryOpToStringAlias returns the operator string for auto-alias purposes. +// MySQL 8.0 uses uppercase AND/OR/XOR in auto-aliases. +func binaryOpToStringAlias(op ast.BinaryOp) string { + switch op { + case ast.BinOpAnd: + return "AND" + case ast.BinOpOr: + return "OR" + case ast.BinOpXor: + return "XOR" + default: + return binaryOpToString(op) + } +} + +// deparseWindowDefAlias formats a window definition for auto-alias purposes. +// MySQL 8.0 uses: OVER (ORDER BY col) in the alias text, or OVER w for named windows. +func deparseWindowDefAlias(wd *ast.WindowDef) string { + // Named window reference: OVER window_name + if wd.RefName != "" && len(wd.PartitionBy) == 0 && len(wd.OrderBy) == 0 && wd.Frame == nil { + return "OVER " + wd.RefName + } + + var b strings.Builder + b.WriteString("OVER (") + + needSpace := false + if len(wd.PartitionBy) > 0 { + b.WriteString("PARTITION BY ") + for i, expr := range wd.PartitionBy { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(deparseExprAlias(expr)) + } + needSpace = true + } + if len(wd.OrderBy) > 0 { + if needSpace { + b.WriteString(" ") + } + b.WriteString("ORDER BY ") + for i, item := range wd.OrderBy { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(deparseExprAlias(item.Expr)) + if item.Desc { + b.WriteString(" desc") + } + } + needSpace = true + } + if wd.Frame != nil { + if needSpace { + b.WriteString(" ") + } + b.WriteString(deparseWindowFrame(wd.Frame)) + } + b.WriteString(")") + return b.String() +} + +func deparseUnaryExpr(n *ast.UnaryExpr) string { + operand := deparseExpr(n.Operand) + switch n.Op { + case ast.UnaryMinus: + // MySQL 8.0 wraps non-literal operands in parens: -(`t`.`a`) but keeps -5 + switch n.Operand.(type) { + case *ast.IntLit, *ast.FloatLit: + return "-" + operand + default: + return "-(" + operand + ")" + } + case ast.UnaryPlus: + // MySQL drops unary plus entirely + return operand + case ast.UnaryNot: + return "(not(" + operand + "))" + case ast.UnaryBitNot: + return "~(" + operand + ")" + default: + return operand + } +} + +func deparseInExpr(n *ast.InExpr) string { + expr := deparseExpr(n.Expr) + keyword := "in" + if n.Not { + keyword = "not in" + } + + // IN subquery: a IN (SELECT ...) + // MySQL 8.0 does NOT wrap IN subquery in outer parens (unlike IN value list). + // MySQL 8.0 also omits AS aliases in the subquery's target list. + if n.Select != nil { + return expr + " " + keyword + " (" + deparseSelectStmtNoAlias(n.Select) + ")" + } + + // Build the value list with no spaces after commas + items := make([]string, len(n.List)) + for i, item := range n.List { + items[i] = deparseExpr(item) + } + return "(" + expr + " " + keyword + " (" + strings.Join(items, ",") + "))" +} + +func deparseBetweenExpr(n *ast.BetweenExpr) string { + expr := deparseExpr(n.Expr) + low := deparseExpr(n.Low) + high := deparseExpr(n.High) + keyword := "between" + if n.Not { + keyword = "not between" + } + return "(" + expr + " " + keyword + " " + low + " and " + high + ")" +} + +func deparseLikeExpr(n *ast.LikeExpr) string { + expr := deparseExpr(n.Expr) + pattern := deparseExpr(n.Pattern) + likeClause := "(" + expr + " like " + pattern + if n.Escape != nil { + likeClause += " escape " + deparseExpr(n.Escape) + } + likeClause += ")" + if n.Not { + return "(not(" + likeClause + "))" + } + return likeClause +} + +func deparseIsExpr(n *ast.IsExpr) string { + expr := deparseExpr(n.Expr) + var test string + switch n.Test { + case ast.IsNull: + if n.Not { + test = "is not null" + } else { + test = "is null" + } + case ast.IsTrue: + if n.Not { + test = "is not true" + } else { + test = "is true" + } + case ast.IsFalse: + if n.Not { + test = "is not false" + } else { + test = "is false" + } + case ast.IsUnknown: + if n.Not { + test = "is not unknown" + } else { + test = "is unknown" + } + default: + test = "is ?" + } + return "(" + expr + " " + test + ")" +} + +func deparseRowExpr(n *ast.RowExpr) string { + items := make([]string, len(n.Items)) + for i, item := range n.Items { + items[i] = deparseExpr(item) + } + return "row(" + strings.Join(items, ",") + ")" +} + +func deparseCaseExpr(n *ast.CaseExpr) string { + var b strings.Builder + b.WriteString("(case") + if n.Operand != nil { + b.WriteString(" ") + b.WriteString(deparseExpr(n.Operand)) + } + for _, w := range n.Whens { + b.WriteString(" when ") + b.WriteString(deparseExpr(w.Cond)) + b.WriteString(" then ") + b.WriteString(deparseExpr(w.Result)) + } + if n.Default != nil { + b.WriteString(" else ") + b.WriteString(deparseExpr(n.Default)) + } + b.WriteString(" end)") + return b.String() +} + +func deparseCastExpr(n *ast.CastExpr) string { + expr := deparseExpr(n.Expr) + typeName := deparseDataType(n.TypeName) + return "cast(" + expr + " as " + typeName + ")" +} + +func deparseConvertExpr(n *ast.ConvertExpr) string { + expr := deparseExpr(n.Expr) + // CONVERT(expr USING charset) form + if n.Charset != "" { + return "convert(" + expr + " using " + strings.ToLower(n.Charset) + ")" + } + // CONVERT(expr, type) form — MySQL rewrites to CAST + typeName := deparseDataType(n.TypeName) + return "cast(" + expr + " as " + typeName + ")" +} + +func deparseDataType(dt *ast.DataType) string { + if dt == nil { + return "" + } + name := strings.ToLower(dt.Name) + switch name { + case "char": + result := "char" + if dt.Length > 0 { + result += fmt.Sprintf("(%d)", dt.Length) + } + // MySQL adds charset for CHAR in CAST + charset := dt.Charset + if charset == "" { + charset = "utf8mb4" + } + result += " charset " + strings.ToLower(charset) + return result + case "binary": + // CAST to BINARY becomes cast(x as char charset binary) + result := "char" + if dt.Length > 0 { + result += fmt.Sprintf("(%d)", dt.Length) + } + result += " charset binary" + return result + case "signed", "signed integer": + return "signed" + case "unsigned", "unsigned integer": + return "unsigned" + case "decimal": + if dt.Scale > 0 { + return fmt.Sprintf("decimal(%d,%d)", dt.Length, dt.Scale) + } + if dt.Length > 0 { + return fmt.Sprintf("decimal(%d)", dt.Length) + } + return "decimal" + case "date": + return "date" + case "datetime": + if dt.Length > 0 { + return fmt.Sprintf("datetime(%d)", dt.Length) + } + return "datetime" + case "time": + if dt.Length > 0 { + return fmt.Sprintf("time(%d)", dt.Length) + } + return "time" + case "json": + return "json" + case "float": + return "float" + case "double": + return "double" + default: + return name + } +} + +func deparseIntervalExpr(n *ast.IntervalExpr) string { + val := deparseExpr(n.Value) + return "interval " + val + " " + strings.ToLower(n.Unit) +} + +func deparseCollateExpr(n *ast.CollateExpr) string { + expr := deparseExpr(n.Expr) + return "(" + expr + " collate " + n.Collation + ")" +} + +// funcNameRewrites maps uppercase function names to their MySQL 8.0 canonical forms. +// These rewrites are applied by SHOW CREATE VIEW in MySQL 8.0. +var funcNameRewrites = map[string]string{ + "SUBSTRING": "substr", + "CURRENT_TIMESTAMP": "now", + "CURRENT_DATE": "curdate", + "CURRENT_TIME": "curtime", + "CURRENT_USER": "current_user", + "NOW": "now", + "LOCALTIME": "now", + "LOCALTIMESTAMP": "now", +} + +// deparseTrimDirectional handles TRIM(LEADING|TRAILING|BOTH remstr FROM str). +// MySQL 8.0 SHOW CREATE VIEW format: trim(leading 'x' from `a`) +func deparseTrimDirectional(direction string, args []ast.ExprNode) string { + if len(args) == 2 { + remstr := deparseExpr(args[0]) + str := deparseExpr(args[1]) + return "trim(" + direction + " " + remstr + " from " + str + ")" + } + // Fallback: single arg (shouldn't happen for directional, but be safe) + if len(args) == 1 { + return "trim(" + direction + " " + deparseExpr(args[0]) + ")" + } + return "trim()" +} + +func deparseFuncCallExpr(n *ast.FuncCallExpr) string { + // Handle TRIM special forms: TRIM_LEADING, TRIM_TRAILING, TRIM_BOTH + // Parser encodes these as FuncCallExpr with Name="TRIM_LEADING" etc. + // Args: [remstr, str] for directional forms + name := strings.ToUpper(n.Name) + switch name { + case "TRIM_LEADING": + return deparseTrimDirectional("leading", n.Args) + case "TRIM_TRAILING": + return deparseTrimDirectional("trailing", n.Args) + case "TRIM_BOTH": + return deparseTrimDirectional("both", n.Args) + } + + // GROUP_CONCAT has special formatting + if name == "GROUP_CONCAT" { + return deparseGroupConcat(n) + } + + // Determine the canonical function name + canonical, ok := funcNameRewrites[name] + if !ok { + canonical = strings.ToLower(n.Name) + } + + // Schema-qualified name + if n.Schema != "" { + canonical = strings.ToLower(n.Schema) + "." + canonical + } + + // Zero-arg functions (CURRENT_TIMESTAMP, NOW(), etc.) — always emit parens + if len(n.Args) == 0 && !n.Star { + result := canonical + "()" + if n.Over != nil { + result += " " + deparseWindowDef(n.Over) + } + return result + } + + // COUNT(*) — MySQL 8.0 rewrites COUNT(*) to count(0) + if n.Star { + result := canonical + "(0)" + if n.Over != nil { + result += " " + deparseWindowDef(n.Over) + } + return result + } + + // Build argument list with no spaces after commas + args := make([]string, len(n.Args)) + for i, arg := range n.Args { + args[i] = deparseExpr(arg) + } + + var argStr string + if n.Distinct { + argStr = "distinct " + strings.Join(args, ",") + } else { + argStr = strings.Join(args, ",") + } + + result := canonical + "(" + argStr + ")" + + // Append OVER clause for window functions + if n.Over != nil { + result += " " + deparseWindowDef(n.Over) + } + + return result +} + +// deparseWindowDef formats a window definition. +// MySQL 8.0 format: OVER (PARTITION BY ... ORDER BY ... frame_clause ) +// Note: trailing space before closing paren, uppercase keywords. +func deparseWindowDef(wd *ast.WindowDef) string { + // Named window reference: OVER `window_name` + if wd.RefName != "" && len(wd.PartitionBy) == 0 && len(wd.OrderBy) == 0 && wd.Frame == nil { + return "OVER `" + wd.RefName + "`" + } + + return "OVER " + deparseWindowBody(wd) +} + +// deparseWindowBody formats the body of a window definition: (PARTITION BY ... ORDER BY ... frame_clause ) +// Used by both OVER clauses and WINDOW named window definitions. +func deparseWindowBody(wd *ast.WindowDef) string { + var b strings.Builder + b.WriteString("(") + + needSpace := false + + // PARTITION BY + if len(wd.PartitionBy) > 0 { + b.WriteString("PARTITION BY ") + for i, expr := range wd.PartitionBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(expr)) + } + needSpace = true + } + + // ORDER BY + if len(wd.OrderBy) > 0 { + if needSpace { + b.WriteString(" ") + } + b.WriteString("ORDER BY ") + for i, item := range wd.OrderBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(item.Expr)) + if item.Desc { + b.WriteString(" desc") + } + } + needSpace = true + } + + // Frame clause + if wd.Frame != nil { + if needSpace { + b.WriteString(" ") + } + b.WriteString(deparseWindowFrame(wd.Frame)) + // When frame clause is present, MySQL 8.0 has no trailing space before ) + b.WriteString(")") + } else { + // No frame clause — MySQL 8.0 puts a trailing space before ) + b.WriteString(" )") + } + + return b.String() +} + +// deparseWindowFrame formats a window frame specification. +// MySQL 8.0 format: ROWS/RANGE/GROUPS BETWEEN start AND end (all uppercase). +func deparseWindowFrame(f *ast.WindowFrame) string { + var b strings.Builder + + // Frame type + switch f.Type { + case ast.FrameRows: + b.WriteString("ROWS") + case ast.FrameRange: + b.WriteString("RANGE") + case ast.FrameGroups: + b.WriteString("GROUPS") + } + + if f.End != nil { + // BETWEEN ... AND ... form + b.WriteString(" BETWEEN ") + b.WriteString(deparseWindowFrameBound(f.Start)) + b.WriteString(" AND ") + b.WriteString(deparseWindowFrameBound(f.End)) + } else { + // Single bound form + b.WriteString(" ") + b.WriteString(deparseWindowFrameBound(f.Start)) + } + + return b.String() +} + +// deparseWindowFrameBound formats a window frame bound. +func deparseWindowFrameBound(fb *ast.WindowFrameBound) string { + switch fb.Type { + case ast.BoundUnboundedPreceding: + return "UNBOUNDED PRECEDING" + case ast.BoundPreceding: + return deparseExpr(fb.Offset) + " PRECEDING" + case ast.BoundCurrentRow: + return "CURRENT ROW" + case ast.BoundFollowing: + return deparseExpr(fb.Offset) + " FOLLOWING" + case ast.BoundUnboundedFollowing: + return "UNBOUNDED FOLLOWING" + default: + return "/* unknown bound */" + } +} + +// deparseExistsExpr formats an EXISTS expression. +// MySQL 8.0 format: exists(select ...) — column aliases are omitted inside EXISTS. +func deparseExistsExpr(n *ast.ExistsExpr) string { + if n.Select != nil { + return "exists(" + deparseSelectStmtNoAlias(n.Select) + ")" + } + return "exists(/* subquery */)" +} + +// deparseSubqueryExpr formats a subquery expression. +// MySQL 8.0 format: (select ...) — scalar subqueries in expression context +// omit column aliases (AS `...`). +func deparseSubqueryExpr(n *ast.SubqueryExpr) string { + if n.Select != nil { + return "(" + deparseSelectStmtNoAlias(n.Select) + ")" + } + return "(/* subquery */)" +} + +// deparseGroupConcatAlias generates the MySQL 8.0 auto-alias for GROUP_CONCAT. +// MySQL 8.0 alias format: GROUP_CONCAT(DISTINCT a ORDER BY a DESC SEPARATOR ';') +// Uses uppercase keywords and original column names (not table-qualified). +func deparseGroupConcatAlias(n *ast.FuncCallExpr) string { + var b strings.Builder + b.WriteString("GROUP_CONCAT(") + + // DISTINCT + if n.Distinct { + b.WriteString("DISTINCT ") + } + + // Arguments + for i, arg := range n.Args { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(deparseExprAlias(arg)) + } + + // ORDER BY — MySQL 8.0 alias omits ASC (default), only shows DESC + if len(n.OrderBy) > 0 { + b.WriteString(" ORDER BY ") + for i, item := range n.OrderBy { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(deparseExprAlias(item.Expr)) + if item.Desc { + b.WriteString(" DESC") + } + } + } + + // SEPARATOR + b.WriteString(" SEPARATOR ") + if n.Separator != nil { + b.WriteString(deparseExprAlias(n.Separator)) + } else { + b.WriteString("','") + } + + b.WriteString(")") + return b.String() +} + +// deparseGroupConcat handles GROUP_CONCAT with its special syntax: +// group_concat([distinct] expr [order by expr ASC|DESC] separator 'str') +// MySQL 8.0 always shows the separator (default ',') and explicit ASC in ORDER BY. +func deparseGroupConcat(n *ast.FuncCallExpr) string { + var b strings.Builder + b.WriteString("group_concat(") + + // DISTINCT + if n.Distinct { + b.WriteString("distinct ") + } + + // Arguments + for i, arg := range n.Args { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(arg)) + } + + // ORDER BY — MySQL 8.0 always shows explicit ASC/DESC + if len(n.OrderBy) > 0 { + b.WriteString(" order by ") + for i, item := range n.OrderBy { + if i > 0 { + b.WriteString(",") + } + b.WriteString(deparseExpr(item.Expr)) + if item.Desc { + b.WriteString(" DESC") + } else { + b.WriteString(" ASC") + } + } + } + + // SEPARATOR — always shown; default is ',' + b.WriteString(" separator ") + if n.Separator != nil { + b.WriteString(deparseExpr(n.Separator)) + } else { + b.WriteString("','") + } + + b.WriteString(")") + return b.String() +} + diff --git a/tidb/deparse/deparse_test.go b/tidb/deparse/deparse_test.go new file mode 100644 index 00000000..f004543d --- /dev/null +++ b/tidb/deparse/deparse_test.go @@ -0,0 +1,1324 @@ +package deparse + +import ( + "strings" + "testing" + + ast "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/parser" +) + +// parseExpr parses a SQL expression by wrapping it in SELECT and extracting +// the first target list expression from the AST. +func parseExpr(t *testing.T, expr string) ast.ExprNode { + t.Helper() + sql := "SELECT " + expr + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("failed to parse %q: %v", sql, err) + } + if list.Len() == 0 { + t.Fatalf("no statements parsed from %q", sql) + } + sel, ok := list.Items[0].(*ast.SelectStmt) + if !ok { + t.Fatalf("expected SelectStmt, got %T", list.Items[0]) + } + if len(sel.TargetList) == 0 { + t.Fatalf("no target list in SELECT from %q", sql) + } + target := sel.TargetList[0] + // TargetList entries may be ResTarget wrapping the actual expression. + if rt, ok := target.(*ast.ResTarget); ok { + return rt.Val + } + return target +} + +func TestDeparse_Section_1_1_IntFloatNull(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Integer literals + {"integer_1", "1", "1"}, + {"negative_integer", "-5", "-5"}, + {"large_integer", "9999999999", "9999999999"}, + {"zero", "0", "0"}, + + // Float literals + {"float_1_5", "1.5", "1.5"}, + {"float_with_exponent", "1.5e10", "1.5e10"}, + {"float_zero_point_five", "0.5", "0.5"}, + + // NULL literal + {"null", "NULL", "NULL"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_1_2_BoolStringLiterals(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Boolean literals + {"true_literal", "TRUE", "true"}, + {"false_literal", "FALSE", "false"}, + + // String literals + {"simple_string", "'hello'", "'hello'"}, + {"string_with_single_quote", "'it''s'", `'it\'s'`}, + {"empty_string", "''", "''"}, + {"string_with_backslash", `'back\\slash'`, `'back\\slash'`}, + + // Charset introducers — skipped: parser doesn't support charset introducers yet + // {"charset_utf8mb4", "_utf8mb4'hello'", "_utf8mb4'hello'"}, + // {"charset_latin1", "_latin1'world'", "_latin1'world'"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_1_2_CharsetIntroducer tests charset introducer deparsing +// using hand-built AST nodes, since the parser doesn't support charset introducers yet. +func TestDeparse_Section_1_2_CharsetIntroducer(t *testing.T) { + cases := []struct { + name string + node ast.ExprNode + expected string + }{ + {"charset_utf8mb4", &ast.StringLit{Value: "hello", Charset: "_utf8mb4"}, "_utf8mb4'hello'"}, + {"charset_latin1", &ast.StringLit{Value: "world", Charset: "_latin1"}, "_latin1'world'"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := Deparse(tc.node) + if got != tc.expected { + t.Errorf("Deparse() = %q, want %q", got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_1_3_HexBitLiterals(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Hex literals — MySQL normalizes to 0x lowercase form + {"hex_0x_form", "0xFF", "0xff"}, + {"hex_X_quote_form", "X'FF'", "0xff"}, + + // Bit literals — MySQL converts to hex form + {"bit_0b_form", "0b1010", "0x0a"}, + {"bit_b_quote_form", "b'1010'", "0x0a"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_1_3_DateTimeLiterals tests DATE/TIME/TIMESTAMP literal deparsing +// using hand-built AST nodes, since the parser doesn't support temporal literals yet. +// These are marked as [~] partial in SCENARIOS — parser support needed. + +func TestDeparse_Section_2_1_ArithmeticUnary(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Arithmetic binary operators + {"addition", "a + b", "(`a` + `b`)"}, + {"subtraction", "a - b", "(`a` - `b`)"}, + {"multiplication", "a * b", "(`a` * `b`)"}, + {"division", "a / b", "(`a` / `b`)"}, + {"integer_division", "a DIV b", "(`a` DIV `b`)"}, + {"modulo_MOD", "a MOD b", "(`a` % `b`)"}, + {"modulo_percent", "a % b", "(`a` % `b`)"}, + + // Left-associative chaining + {"left_assoc_chain", "a + b + c", "((`a` + `b`) + `c`)"}, + + // Unary minus (with column ref operand) + {"unary_minus", "-a", "-(`a`)"}, + + // Unary plus — parser drops it entirely, so +a parses as just ColumnRef + {"unary_plus", "+a", "`a`"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_2_1_UnaryPlusAST tests unary plus via hand-built AST +// to verify that deparseUnaryExpr handles UnaryPlus correctly, even though +// the parser drops unary plus before building the AST. +func TestDeparse_Section_2_1_UnaryPlusAST(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryPlus, + Operand: &ast.ColumnRef{Column: "a"}, + } + got := Deparse(node) + expected := "`a`" + if got != expected { + t.Errorf("Deparse(UnaryPlus(a)) = %q, want %q", got, expected) + } +} + +func TestDeparse_Section_2_2_ComparisonOperators(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Basic comparison operators + {"equal", "a = b", "(`a` = `b`)"}, + {"not_equal_angle", "a <> b", "(`a` <> `b`)"}, + {"not_equal_bang", "a != b", "(`a` <> `b`)"}, // != normalized to <> + {"greater", "a > b", "(`a` > `b`)"}, + {"less", "a < b", "(`a` < `b`)"}, + {"greater_or_equal", "a >= b", "(`a` >= `b`)"}, + {"less_or_equal", "a <= b", "(`a` <= `b`)"}, + {"null_safe_equal", "a <=> b", "(`a` <=> `b`)"}, + {"sounds_like", "a SOUNDS LIKE b", "(`a` sounds like `b`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_2_3_BitwiseOperators(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + {"bitwise_or", "a | b", "(`a` | `b`)"}, + {"bitwise_and", "a & b", "(`a` & `b`)"}, + {"bitwise_xor", "a ^ b", "(`a` ^ `b`)"}, + {"left_shift", "a << b", "(`a` << `b`)"}, + {"right_shift", "a >> b", "(`a` >> `b`)"}, + {"bitwise_not", "~a", "~(`a`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_2_3_BitwiseNotAST tests bitwise NOT via hand-built AST +// to verify deparseUnaryExpr handles UnaryBitNot correctly. +func TestDeparse_Section_2_3_BitwiseNotAST(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryBitNot, + Operand: &ast.ColumnRef{Column: "a"}, + } + got := Deparse(node) + expected := "~(`a`)" + if got != expected { + t.Errorf("Deparse(UnaryBitNot(a)) = %q, want %q", got, expected) + } +} + +func TestDeparse_Section_2_4_PrecedenceParenthesization(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Higher precedence preserved: * binds tighter than + + {"higher_prec_preserved", "a + b * c", "(`a` + (`b` * `c`))"}, + // Lower precedence grouping: explicit parens force + before * + {"lower_prec_grouping", "(a + b) * c", "((`a` + `b`) * `c`)"}, + // Mixed precedence: * first, then + + {"mixed_precedence", "a * b + c", "((`a` * `b`) + `c`)"}, + // Deeply nested left-associative chaining + {"deeply_nested", "a + b + c + a + b + c", "(((((`a` + `b`) + `c`) + `a`) + `b`) + `c`)"}, + // Parenthesized expression passthrough — ParenExpr unwrapped, BinaryExpr provides outer parens + {"paren_passthrough", "(a + b)", "(`a` + `b`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_2_5_ComparisonPredicates(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // IN list + {"in_list", "a IN (1,2,3)", "(`a` in (1,2,3))"}, + // NOT IN + {"not_in_list", "a NOT IN (1,2,3)", "(`a` not in (1,2,3))"}, + // BETWEEN + {"between", "a BETWEEN 1 AND 10", "(`a` between 1 and 10)"}, + // NOT BETWEEN + {"not_between", "a NOT BETWEEN 1 AND 10", "(`a` not between 1 and 10)"}, + // LIKE + {"like", "a LIKE 'foo%'", "(`a` like 'foo%')"}, + // LIKE with ESCAPE + {"like_escape", "a LIKE 'x' ESCAPE '\\\\'", "(`a` like 'x' escape '\\\\')"}, + // IS NULL + {"is_null", "a IS NULL", "(`a` is null)"}, + // IS NOT NULL + {"is_not_null", "a IS NOT NULL", "(`a` is not null)"}, + // IS TRUE + {"is_true", "a IS TRUE", "(`a` is true)"}, + // IS FALSE + {"is_false", "a IS FALSE", "(`a` is false)"}, + // IS UNKNOWN + {"is_unknown", "a IS UNKNOWN", "(`a` is unknown)"}, + // ROW comparison + {"row_comparison", "ROW(a,b) = ROW(1,2)", "(row(`a`,`b`) = row(1,2))"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_2_6_CaseCastConvert(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Searched CASE + {"searched_case", "CASE WHEN a > 0 THEN 'pos' ELSE 'zero' END", + "(case when (`a` > 0) then 'pos' else 'zero' end)"}, + // Searched CASE multiple WHEN + {"searched_case_multi", "CASE WHEN a>0 THEN 'a' WHEN b>0 THEN 'b' ELSE 'c' END", + "(case when (`a` > 0) then 'a' when (`b` > 0) then 'b' else 'c' end)"}, + // Simple CASE + {"simple_case", "CASE a WHEN 1 THEN 'one' ELSE 'other' END", + "(case `a` when 1 then 'one' else 'other' end)"}, + // CASE without ELSE + {"case_no_else", "CASE WHEN a > 0 THEN 'pos' END", + "(case when (`a` > 0) then 'pos' end)"}, + // CAST to CHAR + {"cast_char", "CAST(a AS CHAR)", "cast(`a` as char charset utf8mb4)"}, + // CAST to CHAR(N) + {"cast_char_n", "CAST(a AS CHAR(10))", "cast(`a` as char(10) charset utf8mb4)"}, + // CAST to BINARY + {"cast_binary", "CAST(a AS BINARY)", "cast(`a` as char charset binary)"}, + // CAST to SIGNED + {"cast_signed", "CAST(a AS SIGNED)", "cast(`a` as signed)"}, + // CAST to UNSIGNED + {"cast_unsigned", "CAST(a AS UNSIGNED)", "cast(`a` as unsigned)"}, + // CAST to DECIMAL + {"cast_decimal", "CAST(a AS DECIMAL(10,2))", "cast(`a` as decimal(10,2))"}, + // CAST to DATE + {"cast_date", "CAST(a AS DATE)", "cast(`a` as date)"}, + // CAST to DATETIME + {"cast_datetime", "CAST(a AS DATETIME)", "cast(`a` as datetime)"}, + // CAST to JSON + {"cast_json", "CAST(a AS JSON)", "cast(`a` as json)"}, + // CONVERT USING + {"convert_using", "CONVERT(a USING utf8mb4)", "convert(`a` using utf8mb4)"}, + // CONVERT type — rewritten to cast + {"convert_type_char", "CONVERT(a, CHAR)", "cast(`a` as char charset utf8mb4)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_2_7_OtherExpressions(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // INTERVAL: INTERVAL 1 DAY + a → (`a` + interval 1 day) — operand order swapped + {"interval_add", "INTERVAL 1 DAY + a", "(`a` + interval 1 day)"}, + // COLLATE: a COLLATE utf8mb4_bin → (`a` collate utf8mb4_bin) + {"collate", "a COLLATE utf8mb4_bin", "(`a` collate utf8mb4_bin)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_1_FunctionsAndRewrites(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Regular function calls — lowercase, no space after comma + {"simple_concat", "CONCAT(a, b)", "concat(`a`,`b`)"}, + {"nested_functions", "CONCAT(UPPER(TRIM(a)), LOWER(b))", "concat(upper(trim(`a`)),lower(`b`))"}, + {"ifnull", "IFNULL(a, 0)", "ifnull(`a`,0)"}, + {"coalesce", "COALESCE(a, b, 0)", "coalesce(`a`,`b`,0)"}, + {"nullif", "NULLIF(a, 0)", "nullif(`a`,0)"}, + {"if_func", "IF(a > 0, 'yes', 'no')", "if((`a` > 0),'yes','no')"}, + {"abs", "ABS(a)", "abs(`a`)"}, + {"greatest", "GREATEST(a, b)", "greatest(`a`,`b`)"}, + {"least", "LEAST(a, b)", "least(`a`,`b`)"}, + + // Function name rewrites + {"substring_to_substr", "SUBSTRING(a, 1, 3)", "substr(`a`,1,3)"}, + {"current_timestamp_no_parens", "CURRENT_TIMESTAMP", "now()"}, + {"current_timestamp_parens", "CURRENT_TIMESTAMP()", "now()"}, + {"current_date", "CURRENT_DATE", "curdate()"}, + {"current_time", "CURRENT_TIME", "curtime()"}, + {"current_user", "CURRENT_USER", "current_user()"}, + {"now_stays", "NOW()", "now()"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_2_SpecialFunctionForms(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // TRIM simple — no direction, single arg + {"trim_simple", "TRIM(a)", "trim(`a`)"}, + // TRIM LEADING + {"trim_leading", "TRIM(LEADING 'x' FROM a)", "trim(leading 'x' from `a`)"}, + // TRIM TRAILING + {"trim_trailing", "TRIM(TRAILING 'x' FROM a)", "trim(trailing 'x' from `a`)"}, + // TRIM BOTH + {"trim_both", "TRIM(BOTH 'x' FROM a)", "trim(both 'x' from `a`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_3_AggregateFunctions(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // COUNT(*) — MySQL 8.0 rewrites * to 0 + {"count_star", "COUNT(*)", "count(0)"}, + // COUNT(col) + {"count_col", "COUNT(a)", "count(`a`)"}, + // COUNT(DISTINCT col) + {"count_distinct", "COUNT(DISTINCT a)", "count(distinct `a`)"}, + // SUM + {"sum", "SUM(a)", "sum(`a`)"}, + // AVG + {"avg", "AVG(a)", "avg(`a`)"}, + // MAX + {"max", "MAX(a)", "max(`a`)"}, + // MIN + {"min", "MIN(a)", "min(`a`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_4_GroupConcat(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Basic GROUP_CONCAT — default separator always shown + {"basic", "GROUP_CONCAT(a)", "group_concat(`a` separator ',')"}, + // With ORDER BY — ASC shown explicitly + {"with_order_by", "GROUP_CONCAT(a ORDER BY a)", "group_concat(`a` order by `a` ASC separator ',')"}, + // With explicit SEPARATOR + {"with_separator", "GROUP_CONCAT(a SEPARATOR ';')", "group_concat(`a` separator ';')"}, + // DISTINCT + ORDER BY DESC + SEPARATOR — full combination + {"distinct_order_desc_separator", "GROUP_CONCAT(DISTINCT a ORDER BY a DESC SEPARATOR ';')", + "group_concat(distinct `a` order by `a` DESC separator ';')"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_5_WindowFunctions(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // ROW_NUMBER with ORDER BY + {"row_number_over_order_by", + "ROW_NUMBER() OVER (ORDER BY a)", + "row_number() OVER (ORDER BY `a` )"}, + // SUM with PARTITION BY and ORDER BY + {"sum_over_partition_order", + "SUM(b) OVER (PARTITION BY a ORDER BY b)", + "sum(`b`) OVER (PARTITION BY `a` ORDER BY `b` )"}, + // Frame clause: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + {"sum_over_frame", + "SUM(b) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + "sum(`b`) OVER (PARTITION BY `a` ORDER BY `b` ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)"}, + // Named window reference — MySQL 8.0 backtick-quotes window names + {"named_window_ref", + "ROW_NUMBER() OVER w", + "row_number() OVER `w`"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_3_6_OperatorToFunctionRewrites(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // REGEXP → regexp_like() + {"regexp", "a REGEXP 'pattern'", "regexp_like(`a`,'pattern')"}, + // NOT REGEXP → (not(regexp_like())) + {"not_regexp", "a NOT REGEXP 'p'", "(not(regexp_like(`a`,'p')))"}, + // -> → json_extract() + {"json_extract", "a->'$.key'", "json_extract(`a`,'$.key')"}, + // ->> → json_unquote(json_extract()) + {"json_unquote", "a->>'$.key'", "json_unquote(json_extract(`a`,'$.key'))"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + node := parseExpr(t, tc.input) + got := Deparse(node) + if got != tc.expected { + t.Errorf("Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_NilNode(t *testing.T) { + got := Deparse(nil) + if got != "" { + t.Errorf("Deparse(nil) = %q, want empty string", got) + } +} + +// parseRewriteDeparse is a helper that parses an expression, applies RewriteExpr, and deparses. +func parseRewriteDeparse(t *testing.T, expr string) string { + t.Helper() + node := parseExpr(t, expr) + rewritten := RewriteExpr(node) + return Deparse(rewritten) +} + +func TestDeparse_Section_4_1_NOTFolding(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // NOT (comparison) → inverted comparison operator + {"not_gt", "NOT (a > 0)", "(`a` <= 0)"}, + {"not_lt", "NOT (a < 0)", "(`a` >= 0)"}, + {"not_ge", "NOT (a >= 0)", "(`a` < 0)"}, + {"not_le", "NOT (a <= 0)", "(`a` > 0)"}, + {"not_eq", "NOT (a = 0)", "(`a` <> 0)"}, + {"not_ne", "NOT (a <> 0)", "(`a` = 0)"}, + + // NOT (non-boolean) → (0 = expr) + {"not_col", "NOT a", "(0 = `a`)"}, + {"not_add", "NOT (a + 1)", "(0 = (`a` + 1))"}, + + // NOT LIKE → not((expr like pattern)) — stays as not() wrapping + {"not_like", "NOT (a LIKE 'foo%')", "(not((`a` like 'foo%')))"}, + + // ! operator on column — same as NOT + {"bang_col", "!a", "(0 = `a`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := parseRewriteDeparse(t, tc.input) + if got != tc.expected { + t.Errorf("RewriteExpr+Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_4_1_NOTFolding_AST tests NOT folding via hand-built AST nodes +// for cases where the parser may handle NOT differently (e.g., NOT LIKE is parsed +// directly as LikeExpr with Not=true, not as UnaryNot wrapping LikeExpr). +func TestDeparse_Section_4_1_NOTFolding_AST(t *testing.T) { + // NOT wrapping a LikeExpr (as if we manually constructed this AST) + t.Run("not_wrapping_like_ast", func(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: &ast.LikeExpr{ + Expr: &ast.ColumnRef{Column: "a"}, + Pattern: &ast.StringLit{Value: "foo%"}, + }, + } + rewritten := RewriteExpr(node) + got := Deparse(rewritten) + expected := "(not((`a` like 'foo%')))" + if got != expected { + t.Errorf("RewriteExpr+Deparse(NOT(a LIKE 'foo%%')) = %q, want %q", got, expected) + } + }) + + // NOT wrapping comparison via AST (no ParenExpr wrapper) + t.Run("not_gt_no_paren_ast", func(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: &ast.BinaryExpr{ + Op: ast.BinOpGt, + Left: &ast.ColumnRef{Column: "a"}, + Right: &ast.IntLit{Value: 0}, + }, + } + rewritten := RewriteExpr(node) + got := Deparse(rewritten) + expected := "(`a` <= 0)" + if got != expected { + t.Errorf("RewriteExpr+Deparse(NOT(a > 0)) = %q, want %q", got, expected) + } + }) + + // ! on column ref — (0 = `a`) + t.Run("bang_column_ast", func(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: &ast.ColumnRef{Column: "a"}, + } + rewritten := RewriteExpr(node) + got := Deparse(rewritten) + expected := "(0 = `a`)" + if got != expected { + t.Errorf("RewriteExpr+Deparse(!a) = %q, want %q", got, expected) + } + }) + + // NOT on arithmetic — (0 = (a + 1)) + t.Run("not_arithmetic_ast", func(t *testing.T) { + node := &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: &ast.BinaryExpr{ + Op: ast.BinOpAdd, + Left: &ast.ColumnRef{Column: "a"}, + Right: &ast.IntLit{Value: 1}, + }, + } + rewritten := RewriteExpr(node) + got := Deparse(rewritten) + expected := "(0 = (`a` + 1))" + if got != expected { + t.Errorf("RewriteExpr+Deparse(NOT(a+1)) = %q, want %q", got, expected) + } + }) +} + +// TestDeparse_Section_4_2_BooleanContextWrapping tests isBooleanExpr identification +// and boolean context wrapping for AND/OR/XOR operands. +func TestDeparse_Section_4_2_BooleanContextWrapping(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Column ref in AND: non-boolean operands get wrapped + {"col_ref_in_and", "a AND b", "((0 <> `a`) and (0 <> `b`))"}, + // Arithmetic in AND + {"arithmetic_in_and", "(a+1) AND b", "((0 <> (`a` + 1)) and (0 <> `b`))"}, + // Function in AND + {"function_in_and", "ABS(a) AND b", "((0 <> abs(`a`)) and (0 <> `b`))"}, + // CASE in AND + {"case_in_and", "CASE WHEN a > 0 THEN 1 ELSE 0 END AND b", + "((0 <> (case when (`a` > 0) then 1 else 0 end)) and (0 <> `b`))"}, + // IF in AND + {"if_in_and", "IF(a > 0, 1, 0) AND b", + "((0 <> if((`a` > 0),1,0)) and (0 <> `b`))"}, + // Literal in AND + {"literal_in_and", "'hello' AND 1", "((0 <> 'hello') and (0 <> 1))"}, + // Comparison NOT wrapped: both sides are boolean + {"comparison_not_wrapped", "(a > 0) AND (b > 0)", "((`a` > 0) and (`b` > 0))"}, + // Mixed: one boolean, one non-boolean + {"mixed_bool_nonbool", "(a > 0) AND (b + 1)", "((`a` > 0) and (0 <> (`b` + 1)))"}, + // XOR: non-boolean operands get wrapped + {"xor_wrapping", "a XOR b", "((0 <> `a`) xor (0 <> `b`))"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := parseRewriteDeparse(t, tc.input) + if got != tc.expected { + t.Errorf("RewriteExpr+Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_4_2_IsBooleanExpr tests isBooleanExpr identification via AST. +func TestDeparse_Section_4_2_IsBooleanExpr(t *testing.T) { + // Test that comparison, IN, BETWEEN, LIKE, IS NULL, AND/OR/NOT/XOR, EXISTS, TRUE/FALSE + // are all recognized as boolean expressions (not wrapped in AND context). + + booleanCases := []struct { + name string + input string + }{ + // Comparisons are boolean + {"comparison_eq", "(a = 1) AND (b = 2)"}, + {"comparison_ne", "(a <> 1) AND (b <> 2)"}, + {"comparison_lt", "(a < 1) AND (b < 2)"}, + {"comparison_gt", "(a > 1) AND (b > 2)"}, + {"comparison_le", "(a <= 1) AND (b <= 2)"}, + {"comparison_ge", "(a >= 1) AND (b >= 2)"}, + {"comparison_nullsafe", "(a <=> 1) AND (b <=> 2)"}, + // IN is boolean + {"in_is_boolean", "(a IN (1,2)) AND (b IN (3,4))"}, + // BETWEEN is boolean + {"between_is_boolean", "(a BETWEEN 1 AND 10) AND (b BETWEEN 1 AND 10)"}, + // LIKE is boolean + {"like_is_boolean", "(a LIKE 'x') AND (b LIKE 'y')"}, + // IS NULL is boolean + {"is_null_is_boolean", "(a IS NULL) AND (b IS NULL)"}, + // TRUE/FALSE literals are boolean + {"bool_lit_true", "TRUE AND FALSE"}, + } + + for _, tc := range booleanCases { + t.Run(tc.name, func(t *testing.T) { + got := parseRewriteDeparse(t, tc.input) + // None of these should contain "(0 <>" wrapping + if contains0Ne(got) { + t.Errorf("Boolean expression was incorrectly wrapped: RewriteExpr+Deparse(%q) = %q", tc.input, got) + } + }) + } +} + +// TestDeparse_Section_4_2_ISTrueFalse tests IS TRUE/IS FALSE wrapping on non-boolean. +func TestDeparse_Section_4_2_ISTrueFalse(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // IS TRUE on non-boolean column: wrap with (0 <> col) + {"is_true_nonbool", "a IS TRUE", "((0 <> `a`) is true)"}, + // IS FALSE on non-boolean column: wrap with (0 <> col) + {"is_false_nonbool", "a IS FALSE", "((0 <> `a`) is false)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := parseRewriteDeparse(t, tc.input) + if got != tc.expected { + t.Errorf("RewriteExpr+Deparse(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparse_Section_4_2_SubqueryInAnd tests subquery wrapping in AND context. +func TestDeparse_Section_4_2_SubqueryInAnd(t *testing.T) { + // Subquery in AND: non-boolean subquery gets wrapped + // Note: we build AST directly since parser may not support subqueries in this context easily + t.Run("subquery_in_and", func(t *testing.T) { + // Build: (SELECT 1) AND (SELECT 2) — both are SubqueryExpr (non-boolean) + node := &ast.BinaryExpr{ + Op: ast.BinOpAnd, + Left: &ast.SubqueryExpr{ + Select: &ast.SelectStmt{ + TargetList: []ast.ExprNode{ + &ast.ResTarget{Val: &ast.IntLit{Value: 1}}, + }, + }, + }, + Right: &ast.SubqueryExpr{ + Select: &ast.SelectStmt{ + TargetList: []ast.ExprNode{ + &ast.ResTarget{Val: &ast.IntLit{Value: 2}}, + }, + }, + }, + } + rewritten := RewriteExpr(node) + got := Deparse(rewritten) + // SubqueryExpr is not boolean, so should be wrapped + // The exact output depends on SubqueryExpr deparsing — we just verify wrapping happened + if !contains0Ne(got) { + t.Errorf("Subquery in AND was NOT wrapped: got %q", got) + } + }) +} + +// contains0Ne checks if the output contains "(0 <>" which indicates boolean wrapping. +func contains0Ne(s string) bool { + return strings.Contains(s, "(0 <>") || strings.Contains(s, "(0 =") +} + +// TestRewriteExpr_NilNode tests that RewriteExpr handles nil gracefully. +func TestRewriteExpr_NilNode(t *testing.T) { + got := RewriteExpr(nil) + if got != nil { + t.Errorf("RewriteExpr(nil) = %v, want nil", got) + } +} + +// parseSelect parses a full SQL SELECT statement and returns the SelectStmt. +func parseSelect(t *testing.T, sql string) *ast.SelectStmt { + t.Helper() + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("failed to parse %q: %v", sql, err) + } + if list.Len() == 0 { + t.Fatalf("no statements parsed from %q", sql) + } + sel, ok := list.Items[0].(*ast.SelectStmt) + if !ok { + t.Fatalf("expected SelectStmt, got %T", list.Items[0]) + } + return sel +} + +func TestDeparseSelect_Section_5_1_TargetListAliases(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Single column with FROM + {"single_column", "SELECT a FROM t", "select `a` AS `a` from `t`"}, + // Multiple columns: comma-separated, no space after comma + {"multiple_columns", "SELECT a, b, c FROM t", "select `a` AS `a`,`b` AS `b`,`c` AS `c` from `t`"}, + // Column alias with AS + {"column_alias_as", "SELECT a AS col1 FROM t", "select `a` AS `col1` from `t`"}, + // Column alias without AS (should still produce AS in output) + {"column_alias_no_as", "SELECT a col1 FROM t", "select `a` AS `col1` from `t`"}, + // Expression alias + {"expression_alias", "SELECT a + b AS sum_col FROM t", "select (`a` + `b`) AS `sum_col` from `t`"}, + // Auto-alias literal: SELECT 1 → 1 AS `1` + {"auto_alias_literal", "SELECT 1", "select 1 AS `1`"}, + // Auto-alias expression: SELECT a + b → (`a` + `b`) AS `a + b` + {"auto_alias_expression", "SELECT a + b FROM t", "select (`a` + `b`) AS `a + b` from `t`"}, + // Auto-alias string literal + {"auto_alias_string", "SELECT 'hello'", "select 'hello' AS `hello`"}, + // Auto-alias NULL + {"auto_alias_null", "SELECT NULL", "select NULL AS `NULL`"}, + // Auto-alias function call + {"auto_alias_func", "SELECT CONCAT(a, b) FROM t", "select concat(`a`,`b`) AS `CONCAT(a, b)` from `t`"}, + // Auto-alias boolean literal + {"auto_alias_true", "SELECT TRUE", "select true AS `TRUE`"}, + {"auto_alias_false", "SELECT FALSE", "select false AS `FALSE`"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_1_NilStmt(t *testing.T) { + got := DeparseSelect(nil) + if got != "" { + t.Errorf("DeparseSelect(nil) = %q, want empty string", got) + } +} + +func TestDeparseSelect_Section_5_2_FromClause(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Single table + {"single_table", "SELECT a FROM t", "select `a` AS `a` from `t`"}, + // Table alias with AS — no AS keyword in output for table alias + {"table_alias_with_as", "SELECT a FROM t AS t1", "select `a` AS `a` from `t` `t1`"}, + // Table alias without AS — same output + {"table_alias_without_as", "SELECT a FROM t t1", "select `a` AS `a` from `t` `t1`"}, + // Multiple tables (implicit cross join) → explicit join with parens + {"implicit_cross_join", "SELECT a FROM t1, t2", "select `a` AS `a` from (`t1` join `t2`)"}, + // Three tables implicit cross join + {"implicit_cross_join_3", "SELECT a FROM t1, t2, t3", "select `a` AS `a` from ((`t1` join `t2`) join `t3`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_2_DerivedTable(t *testing.T) { + // Note: Parser currently does not produce SubqueryExpr for derived tables in FROM clause. + // These tests verify the deparse logic works when given the correct AST manually. + t.Run("derived_table_subquery_expr", func(t *testing.T) { + // Manually construct: FROM (SELECT 1 AS a) `d` + inner := &ast.SelectStmt{ + TargetList: []ast.ExprNode{ + &ast.ResTarget{Name: "a", Val: &ast.IntLit{Value: 1}}, + }, + } + subq := &ast.SubqueryExpr{ + Select: inner, + Alias: "d", + } + outer := &ast.SelectStmt{ + TargetList: []ast.ExprNode{ + &ast.ResTarget{Val: &ast.ColumnRef{Column: "a"}}, + }, + From: []ast.TableExpr{subq}, + } + got := DeparseSelect(outer) + expected := "select `a` AS `a` from (select 1 AS `a`) `d`" + if got != expected { + t.Errorf("DeparseSelect(derived table) =\n %q\nwant:\n %q", got, expected) + } + }) +} + +func TestDeparseSelect_Section_5_3_JoinClause(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // INNER JOIN + {"inner_join", "SELECT a FROM t1 JOIN t2 ON t1.a = t2.a", + "select `a` AS `a` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))"}, + // LEFT JOIN + {"left_join", "SELECT a FROM t1 LEFT JOIN t2 ON t1.a = t2.a", + "select `a` AS `a` from (`t1` left join `t2` on((`t1`.`a` = `t2`.`a`)))"}, + // RIGHT JOIN → LEFT JOIN with table swap + {"right_join", "SELECT a FROM t1 RIGHT JOIN t2 ON t1.a = t2.a", + "select `a` AS `a` from (`t2` left join `t1` on((`t1`.`a` = `t2`.`a`)))"}, + // CROSS JOIN → plain join (no ON) + {"cross_join", "SELECT a FROM t1 CROSS JOIN t2", + "select `a` AS `a` from (`t1` join `t2`)"}, + // STRAIGHT_JOIN — lowercase + {"straight_join", "SELECT a FROM t1 STRAIGHT_JOIN t2 ON t1.a = t2.a", + "select `a` AS `a` from (`t1` straight_join `t2` on((`t1`.`a` = `t2`.`a`)))"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_3_NaturalJoin(t *testing.T) { + // NATURAL JOIN — expanded to join without ON (needs Phase 6 resolver for column expansion) + // For now, verify basic format; full ON expansion requires schema info. + cases := []struct { + name string + input string + expected string + }{ + // NATURAL JOIN → join (no ON without resolver) + {"natural_join", "SELECT a FROM t1 NATURAL JOIN t2", + "select `a` AS `a` from (`t1` join `t2`)"}, + // NATURAL LEFT JOIN → left join (no ON without resolver) + {"natural_left_join", "SELECT a FROM t1 NATURAL LEFT JOIN t2", + "select `a` AS `a` from (`t1` left join `t2`)"}, + // NATURAL RIGHT JOIN → left join with table swap (no ON without resolver) + {"natural_right_join", "SELECT a FROM t1 NATURAL RIGHT JOIN t2", + "select `a` AS `a` from (`t2` left join `t1`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_3_UsingClause(t *testing.T) { + // USING — expanded to ON with qualified column references + cases := []struct { + name string + input string + expected string + }{ + // USING single column + {"using_single", "SELECT a FROM t1 JOIN t2 USING (a)", + "select `a` AS `a` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +// TestDeparseSelect_Section_5_3_UsingMultiColumn tests USING with multiple columns via AST. +func TestDeparseSelect_Section_5_3_UsingMultiColumn(t *testing.T) { + // Build AST manually for USING (a, b) since parser may only support single-column USING parsing + t.Run("using_multi_column", func(t *testing.T) { + join := &ast.JoinClause{ + Type: ast.JoinInner, + Left: &ast.TableRef{Name: "t1"}, + Right: &ast.TableRef{Name: "t2"}, + Condition: &ast.UsingCondition{ + Columns: []string{"a", "b"}, + }, + } + sel := &ast.SelectStmt{ + TargetList: []ast.ExprNode{ + &ast.ResTarget{Val: &ast.ColumnRef{Column: "x"}}, + }, + From: []ast.TableExpr{join}, + } + got := DeparseSelect(sel) + expected := "select `x` AS `x` from (`t1` join `t2` on(((`t1`.`a` = `t2`.`a`) and (`t1`.`b` = `t2`.`b`))))" + if got != expected { + t.Errorf("DeparseSelect(USING multi) =\n %q\nwant:\n %q", got, expected) + } + }) +} + +func TestDeparseSelect_Section_5_4_WhereGroupByHavingOrderByLimit(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // WHERE simple + {"where_simple", "SELECT a FROM t WHERE a > 1", + "select `a` AS `a` from `t` where (`a` > 1)"}, + // WHERE compound + {"where_compound", "SELECT a FROM t WHERE a > 1 AND b < 10", + "select `a` AS `a` from `t` where ((`a` > 1) and (`b` < 10))"}, + // GROUP BY single + {"group_by_single", "SELECT a FROM t GROUP BY a", + "select `a` AS `a` from `t` group by `a`"}, + // GROUP BY multiple — no space after comma + {"group_by_multiple", "SELECT a, b FROM t GROUP BY a, b", + "select `a` AS `a`,`b` AS `b` from `t` group by `a`,`b`"}, + // GROUP BY WITH ROLLUP + {"group_by_with_rollup", "SELECT a FROM t GROUP BY a WITH ROLLUP", + "select `a` AS `a` from `t` group by `a` with rollup"}, + // HAVING + {"having", "SELECT a FROM t GROUP BY a HAVING COUNT(*) > 1", + "select `a` AS `a` from `t` group by `a` having (count(0) > 1)"}, + // ORDER BY ASC (default — no explicit ASC in output) + {"order_by_asc", "SELECT a FROM t ORDER BY a", + "select `a` AS `a` from `t` order by `a`"}, + // ORDER BY DESC + {"order_by_desc", "SELECT a FROM t ORDER BY a DESC", + "select `a` AS `a` from `t` order by `a` desc"}, + // ORDER BY multiple — comma no space + {"order_by_multiple", "SELECT a, b FROM t ORDER BY a, b DESC", + "select `a` AS `a`,`b` AS `b` from `t` order by `a`,`b` desc"}, + // LIMIT + {"limit", "SELECT a FROM t LIMIT 10", + "select `a` AS `a` from `t` limit 10"}, + // LIMIT with OFFSET — MySQL comma syntax: LIMIT offset,count + {"limit_with_offset", "SELECT a FROM t LIMIT 10 OFFSET 5", + "select `a` AS `a` from `t` limit 5,10"}, + // DISTINCT + {"distinct", "SELECT DISTINCT a FROM t", + "select distinct `a` AS `a` from `t`"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_6_Subqueries(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Scalar subquery in SELECT list + // MySQL 8.0 omits column aliases inside scalar subquery body. + // Auto-alias uses uppercase keywords and unqualified column names. + {"scalar_subquery", "SELECT (SELECT MAX(a) FROM t) FROM t", + "select (select max(`a`) from `t`) AS `(SELECT MAX(a) FROM t)` from `t`"}, + // IN subquery in WHERE — MySQL 8.0 omits outer parens and aliases in subquery target list + {"in_subquery", "SELECT a FROM t WHERE a IN (SELECT a FROM t)", + "select `a` AS `a` from `t` where `a` in (select `a` from `t`)"}, + // EXISTS in WHERE + {"exists_subquery", "SELECT a FROM t WHERE EXISTS (SELECT 1 FROM t)", + "select `a` AS `a` from `t` where exists(select 1 from `t`)"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_5_SetOperations(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // UNION + {"union", "SELECT a FROM t UNION SELECT b FROM t", + "select `a` AS `a` from `t` union select `b` AS `b` from `t`"}, + // UNION ALL + {"union_all", "SELECT a FROM t UNION ALL SELECT b FROM t", + "select `a` AS `a` from `t` union all select `b` AS `b` from `t`"}, + // Multiple UNION: three SELECTs chained flat (left-associative) + {"multiple_union", "SELECT a FROM t UNION SELECT b FROM t UNION SELECT c FROM t", + "select `a` AS `a` from `t` union select `b` AS `b` from `t` union select `c` AS `c` from `t`"}, + // INTERSECT + {"intersect", "SELECT a FROM t INTERSECT SELECT b FROM t", + "select `a` AS `a` from `t` intersect select `b` AS `b` from `t`"}, + // EXCEPT + {"except", "SELECT a FROM t EXCEPT SELECT b FROM t", + "select `a` AS `a` from `t` except select `b` AS `b` from `t`"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparseSelect_Section_5_7_CTE(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Simple CTE + { + "simple_cte", + "WITH cte AS (SELECT 1) SELECT a FROM cte", + "with `cte` as (select 1 AS `1`) select `a` AS `a` from `cte`", + }, + // CTE with column list + { + "cte_with_columns", + "WITH cte(x) AS (SELECT 1) SELECT x FROM cte", + "with `cte` (`x`) as (select 1 AS `1`) select `x` AS `x` from `cte`", + }, + // RECURSIVE CTE + { + "recursive_cte", + "WITH RECURSIVE cte AS (SELECT 1 UNION ALL SELECT 1) SELECT a FROM cte", + "with recursive `cte` as (select 1 AS `1` union all select 1 AS `1`) select `a` AS `a` from `cte`", + }, + // Multiple CTEs + { + "multiple_ctes", + "WITH c1 AS (SELECT 1), c2 AS (SELECT 2) SELECT a FROM c1, c2", + "with `c1` as (select 1 AS `1`), `c2` as (select 2 AS `2`) select `a` AS `a` from (`c1` join `c2`)", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestDeparse_Section_5_8_ForUpdate(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + { + "for_update", + "SELECT a FROM t FOR UPDATE", + "select `a` AS `a` from `t` for update", + }, + { + "for_share", + "SELECT a FROM t FOR SHARE", + "select `a` AS `a` from `t` for share", + }, + { + "lock_in_share_mode", + "SELECT a FROM t LOCK IN SHARE MODE", + "select `a` AS `a` from `t` lock in share mode", + }, + { + "for_update_of_table", + "SELECT a FROM t FOR UPDATE OF t", + "select `a` AS `a` from `t` for update of `t`", + }, + { + "for_update_nowait", + "SELECT a FROM t FOR UPDATE NOWAIT", + "select `a` AS `a` from `t` for update nowait", + }, + { + "for_update_skip_locked", + "SELECT a FROM t FOR UPDATE SKIP LOCKED", + "select `a` AS `a` from `t` for update skip locked", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sel := parseSelect(t, tc.input) + got := DeparseSelect(sel) + if got != tc.expected { + t.Errorf("DeparseSelect(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} diff --git a/tidb/deparse/resolver.go b/tidb/deparse/resolver.go new file mode 100644 index 00000000..232c404a --- /dev/null +++ b/tidb/deparse/resolver.go @@ -0,0 +1,895 @@ +// Package deparse — resolver.go implements schema-aware column qualification. +// The resolver takes a TableLookup + SelectStmt and returns a new SelectStmt where +// all column references are fully qualified with their table name/alias. +package deparse + +import ( + "fmt" + "sort" + "strings" + + ast "github.com/bytebase/omni/tidb/ast" + scopepkg "github.com/bytebase/omni/tidb/scope" +) + +// ResolverColumn is a type alias for scope.Column, preserving backward compatibility. +type ResolverColumn = scopepkg.Column + +// ResolverTable is a type alias for scope.Table, preserving backward compatibility. +type ResolverTable = scopepkg.Table + +// TableLookup is a function that looks up a table by name in the catalog. +// Returns nil if the table is not found. +type TableLookup func(tableName string) *ResolverTable + +// Resolver resolves column references in a SelectStmt using catalog metadata. +type Resolver struct { + Lookup TableLookup + // DefaultCharset is the database's default character set (e.g., "utf8mb4", "latin1"). + // Used to populate CAST(... AS CHAR) charset when not explicitly specified. + // If empty, defaults to "utf8mb4". + DefaultCharset string +} + +// Resolve takes a SelectStmt and returns a new SelectStmt with all column +// references fully qualified. The original AST is modified in-place. +func (r *Resolver) Resolve(stmt *ast.SelectStmt) *ast.SelectStmt { + return r.resolveWithCTEs(stmt, nil) +} + +// resolveWithCTEs resolves a SelectStmt with optional CTE virtual tables +// available in scope. cteTables maps CTE names to virtual ResolverTables +// built from their SELECT target lists. +func (r *Resolver) resolveWithCTEs(stmt *ast.SelectStmt, cteTables map[string]*ResolverTable) *ast.SelectStmt { + if stmt == nil { + return nil + } + // Handle set operations recursively + if stmt.SetOp != ast.SetOpNone { + // Before recursing, hoist CTEs from the leftmost leaf so they are + // visible to both sides of the set operation (matching MySQL semantics + // where WITH ... applies to the entire UNION). + mergedCTETables := cteTables + if leftCTEs := collectLeftmostCTEs(stmt); len(leftCTEs) > 0 { + mergedCTETables = make(map[string]*ResolverTable) + if cteTables != nil { + for k, v := range cteTables { + mergedCTETables[k] = v + } + } + for _, cte := range leftCTEs { + cteResolver := &Resolver{ + Lookup: r.withCTELookup(mergedCTETables), + DefaultCharset: r.DefaultCharset, + } + cteResolver.resolveWithCTEs(cte.Select, mergedCTETables) + vt := buildCTEVirtualTable(cte) + if vt != nil { + mergedCTETables[strings.ToLower(cte.Name)] = vt + } + } + } + if stmt.Left != nil { + r.resolveWithCTEs(stmt.Left, mergedCTETables) + } + if stmt.Right != nil { + r.resolveWithCTEs(stmt.Right, mergedCTETables) + } + // Resolve ORDER BY ordinals (e.g., ORDER BY 1) to column aliases + // from the leftmost SELECT's target list, matching MySQL 8.0 behavior. + if len(stmt.OrderBy) > 0 { + leftmost := stmt + for leftmost.SetOp != ast.SetOpNone && leftmost.Left != nil { + leftmost = leftmost.Left + } + for _, item := range stmt.OrderBy { + if lit, ok := item.Expr.(*ast.IntLit); ok { + idx := int(lit.Value) - 1 // 1-based to 0-based + if idx >= 0 && idx < len(leftmost.TargetList) { + if rt, ok := leftmost.TargetList[idx].(*ast.ResTarget); ok { + alias := rt.Name + if alias == "" { + // Derive alias from column ref + if cr, ok := rt.Val.(*ast.ColumnRef); ok { + alias = cr.Column + } + } + if alias != "" { + item.Expr = &ast.ColumnRef{Column: alias} + } + } + } + } + } + } + return stmt + } + + // Process CTEs: resolve each CTE's SELECT, then build virtual tables. + // CTEs are resolved in order; later CTEs can reference earlier ones. + localCTETables := make(map[string]*ResolverTable) + if cteTables != nil { + for k, v := range cteTables { + localCTETables[k] = v + } + } + for _, cte := range stmt.CTEs { + if cte.Recursive && cte.Select != nil && cte.Select.SetOp != ast.SetOpNone { + // Recursive CTE: resolve left (non-recursive) branch first to get column info, + // then add CTE to scope, then resolve right (recursive) branch. + sel := cte.Select + + // Step 1: Resolve the non-recursive (left) branch + leftResolver := &Resolver{ + Lookup: r.withCTELookup(localCTETables), + DefaultCharset: r.DefaultCharset, + } + leftResolver.resolveWithCTEs(sel.Left, localCTETables) + + // Step 2: Build virtual table from left branch's target list and add CTE to scope + vt := buildCTEVirtualTableFromSelect(cte.Name, cte.Columns, sel.Left) + if vt != nil { + localCTETables[strings.ToLower(cte.Name)] = vt + } + + // Step 3: Resolve the recursive (right) branch — CTE is now in scope + rightResolver := &Resolver{ + Lookup: r.withCTELookup(localCTETables), + DefaultCharset: r.DefaultCharset, + } + rightResolver.resolveWithCTEs(sel.Right, localCTETables) + } else { + // Non-recursive CTE: resolve entire SELECT, then build virtual table + cteResolver := &Resolver{ + Lookup: r.withCTELookup(localCTETables), + DefaultCharset: r.DefaultCharset, + } + cteResolver.resolveWithCTEs(cte.Select, localCTETables) + + vt := buildCTEVirtualTable(cte) + if vt != nil { + localCTETables[strings.ToLower(cte.Name)] = vt + } + } + } + + // Build scope from FROM clause, using a lookup that includes CTE virtual tables + origLookup := r.Lookup + if len(localCTETables) > 0 { + r.Lookup = r.withCTELookup(localCTETables) + } + sc := r.buildScope(stmt.From) + r.Lookup = origLookup + + // Resolve JOIN ON conditions (walk FROM clause) BEFORE target list resolution + // so that NATURAL JOIN expansion can mark coalesced columns for star expansion. + r.resolveFromExprs(stmt.From, sc) + + + // Resolve target list (may expand stars) + stmt.TargetList = r.resolveTargetList(stmt.TargetList, sc) + + // Resolve WHERE + if stmt.Where != nil { + stmt.Where = r.resolveExpr(stmt.Where, sc) + } + + // Resolve GROUP BY + for i, expr := range stmt.GroupBy { + stmt.GroupBy[i] = r.resolveExpr(expr, sc) + } + + // Resolve HAVING + if stmt.Having != nil { + stmt.Having = r.resolveExpr(stmt.Having, sc) + } + + // Resolve WINDOW clause + for _, wd := range stmt.WindowClause { + for i, expr := range wd.PartitionBy { + wd.PartitionBy[i] = r.resolveExpr(expr, sc) + } + for _, item := range wd.OrderBy { + item.Expr = r.resolveExpr(item.Expr, sc) + } + } + + // Resolve ORDER BY + for _, item := range stmt.OrderBy { + item.Expr = r.resolveExpr(item.Expr, sc) + } + + return stmt +} + +// buildScope constructs a scope from the FROM clause table expressions. +func (r *Resolver) buildScope(from []ast.TableExpr) *scopepkg.Scope { + sc := scopepkg.New() + for _, tbl := range from { + r.addTableExprToScope(tbl, sc) + } + return sc +} + +// addTableExprToScope recursively adds table references from a table expression to the scope. +func (r *Resolver) addTableExprToScope(tbl ast.TableExpr, sc *scopepkg.Scope) { + switch t := tbl.(type) { + case *ast.TableRef: + table := r.Lookup(t.Name) + if table == nil { + return + } + effectiveName := t.Name + if t.Alias != "" { + effectiveName = t.Alias + } + sc.Add(effectiveName, table) + case *ast.JoinClause: + r.addTableExprToScope(t.Left, sc) + r.addTableExprToScope(t.Right, sc) + case *ast.SubqueryExpr: + // Derived table: resolve the inner SELECT, then build a virtual table + // from its target list so outer queries can reference derived columns. + if t.Select != nil { + r.Resolve(t.Select) + vt := buildDerivedVirtualTable(t) + if vt != nil && t.Alias != "" { + sc.Add(t.Alias, vt) + } + } + } +} + +// resolveTargetList resolves all target list entries, expanding qualified stars. +func (r *Resolver) resolveTargetList(targets []ast.ExprNode, sc *scopepkg.Scope) []ast.ExprNode { + var result []ast.ExprNode + for i, target := range targets { + expanded := r.resolveTarget(target, sc, i+1) + result = append(result, expanded...) + } + return result +} + +// resolveTarget resolves a single target list entry. Returns a slice because +// star expansion can produce multiple entries. +func (r *Resolver) resolveTarget(target ast.ExprNode, sc *scopepkg.Scope, position int) []ast.ExprNode { + rt, isRT := target.(*ast.ResTarget) + + var expr ast.ExprNode + var explicitAlias string + if isRT { + expr = rt.Val + explicitAlias = rt.Name + } else { + expr = target + } + + // Check for qualified star: t1.* + if col, ok := expr.(*ast.ColumnRef); ok && col.Star { + return r.expandQualifiedStar(col.Table, sc) + } + + // Check for unqualified star: * + if _, ok := expr.(*ast.StarExpr); ok { + return r.expandStar(sc) + } + + // Apply CAST/CONVERT charset resolution before computing auto-alias. + // This ensures the auto-alias includes the database charset (e.g., "charset latin1"). + r.resolveCastCharsets(expr) + + // Compute auto-alias from the pre-resolution expression when no explicit alias. + // MySQL 8.0 uses the original (unqualified) expression text for auto-aliases, + // so we must derive it before column qualification changes the expression. + if explicitAlias == "" { + exprStr := deparseExpr(expr) + explicitAlias = autoAlias(expr, exprStr, position) + } + + // Resolve the expression (column qualification, etc.) + resolved := r.resolveExpr(expr, sc) + + if isRT { + rt.Val = resolved + rt.Name = explicitAlias + return []ast.ExprNode{rt} + } + return []ast.ExprNode{&ast.ResTarget{Name: explicitAlias, Val: resolved}} +} + +// expandStar expands * to all columns from all tables in scope order. +// Columns marked as coalesced (from NATURAL JOIN / USING) are excluded. +func (r *Resolver) expandStar(sc *scopepkg.Scope) []ast.ExprNode { + var result []ast.ExprNode + for _, entry := range sc.AllEntries() { + cols := sortedResolverColumns(entry.Table) + for _, col := range cols { + // Skip columns coalesced by NATURAL JOIN or USING + if sc.IsCoalesced(entry.Name, col.Name) { + continue + } + result = append(result, &ast.ResTarget{ + Name: col.Name, + Val: &ast.ColumnRef{ + Table: entry.Name, + Column: col.Name, + }, + }) + } + } + return result +} + +// expandQualifiedStar expands t1.* to all columns of table t1. +func (r *Resolver) expandQualifiedStar(tableName string, sc *scopepkg.Scope) []ast.ExprNode { + table := sc.GetTable(tableName) + if table == nil { + // Table not found in scope; return as-is + return []ast.ExprNode{&ast.ResTarget{ + Val: &ast.ColumnRef{Table: tableName, Star: true}, + }} + } + + // Find the effective name from scope entries (preserves alias casing) + effectiveName := tableName + for _, entry := range sc.AllEntries() { + if strings.EqualFold(entry.Name, tableName) { + effectiveName = entry.Name + break + } + } + + var result []ast.ExprNode + cols := sortedResolverColumns(table) + for _, col := range cols { + result = append(result, &ast.ResTarget{ + Name: col.Name, + Val: &ast.ColumnRef{ + Table: effectiveName, + Column: col.Name, + }, + }) + } + return result +} + +// sortedResolverColumns returns columns sorted by Position. +func sortedResolverColumns(table *ResolverTable) []ResolverColumn { + cols := make([]ResolverColumn, len(table.Columns)) + copy(cols, table.Columns) + sort.Slice(cols, func(i, j int) bool { + return cols[i].Position < cols[j].Position + }) + return cols +} + +// resolveExpr resolves column references in an expression. +func (r *Resolver) resolveExpr(node ast.ExprNode, sc *scopepkg.Scope) ast.ExprNode { + if node == nil { + return nil + } + switch n := node.(type) { + case *ast.ColumnRef: + return r.resolveColumnRef(n, sc) + case *ast.BinaryExpr: + n.Left = r.resolveExpr(n.Left, sc) + n.Right = r.resolveExpr(n.Right, sc) + return n + case *ast.UnaryExpr: + n.Operand = r.resolveExpr(n.Operand, sc) + return n + case *ast.ParenExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + return n + case *ast.InExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + for i, item := range n.List { + n.List[i] = r.resolveExpr(item, sc) + } + if n.Select != nil { + r.Resolve(n.Select) + } + return n + case *ast.BetweenExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + n.Low = r.resolveExpr(n.Low, sc) + n.High = r.resolveExpr(n.High, sc) + return n + case *ast.LikeExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + n.Pattern = r.resolveExpr(n.Pattern, sc) + if n.Escape != nil { + n.Escape = r.resolveExpr(n.Escape, sc) + } + return n + case *ast.IsExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + return n + case *ast.CaseExpr: + if n.Operand != nil { + n.Operand = r.resolveExpr(n.Operand, sc) + } + for _, w := range n.Whens { + w.Cond = r.resolveExpr(w.Cond, sc) + w.Result = r.resolveExpr(w.Result, sc) + } + if n.Default != nil { + n.Default = r.resolveExpr(n.Default, sc) + } + return n + case *ast.FuncCallExpr: + for i, arg := range n.Args { + n.Args[i] = r.resolveExpr(arg, sc) + } + // Resolve ORDER BY in aggregate functions (e.g., GROUP_CONCAT) + for _, item := range n.OrderBy { + item.Expr = r.resolveExpr(item.Expr, sc) + } + // Resolve window function OVER clause + if n.Over != nil { + r.resolveWindowDef(n.Over, sc) + } + return n + case *ast.CastExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + r.resolveCastCharset(n.TypeName) + return n + case *ast.ConvertExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + r.resolveCastCharset(n.TypeName) + return n + case *ast.CollateExpr: + n.Expr = r.resolveExpr(n.Expr, sc) + return n + case *ast.IntervalExpr: + n.Value = r.resolveExpr(n.Value, sc) + return n + case *ast.RowExpr: + for i, item := range n.Items { + n.Items[i] = r.resolveExpr(item, sc) + } + return n + case *ast.ExistsExpr: + if n.Select != nil { + r.Resolve(n.Select) + } + return n + case *ast.SubqueryExpr: + if n.Select != nil { + r.Resolve(n.Select) + } + return n + case *ast.ResTarget: + n.Val = r.resolveExpr(n.Val, sc) + return n + default: + // Leaf nodes (literals, etc.) — no resolution needed + return node + } +} + +// resolveColumnRef qualifies an unqualified column reference by finding which +// table in scope contains the column. +func (r *Resolver) resolveColumnRef(col *ast.ColumnRef, sc *scopepkg.Scope) ast.ExprNode { + // Already qualified — just validate the table name maps to an alias + if col.Table != "" { + // Check if this table name is in scope (might be an alias) + if sc.GetTable(col.Table) != nil { + // Find the effective name from scope (preserves case) + for _, entry := range sc.AllEntries() { + if strings.EqualFold(entry.Name, col.Table) { + col.Table = entry.Name + break + } + } + } + return col + } + + // Unqualified — search all tables in scope + var matchTable string + var matchCount int + for _, entry := range sc.AllEntries() { + if entry.Table.GetColumn(col.Column) != nil { + if matchCount == 0 { + matchTable = entry.Name + } + matchCount++ + } + } + + if matchCount == 0 { + // Column not found — return as-is (could be a literal alias or error) + return col + } + if matchCount > 1 { + // Ambiguous — for now, qualify with first match + // MySQL would raise ERROR 1052: Column 'x' in field list is ambiguous + // TODO: return error + } + + col.Table = matchTable + return col +} + +// resolveWindowDef resolves column references in a window definition. +func (r *Resolver) resolveWindowDef(wd *ast.WindowDef, sc *scopepkg.Scope) { + for i, expr := range wd.PartitionBy { + wd.PartitionBy[i] = r.resolveExpr(expr, sc) + } + for _, item := range wd.OrderBy { + item.Expr = r.resolveExpr(item.Expr, sc) + } +} + +// resolveFromExprs walks FROM clause table expressions and resolves +// ON condition expressions in JoinClauses. +func (r *Resolver) resolveFromExprs(from []ast.TableExpr, sc *scopepkg.Scope) { + for _, tbl := range from { + r.resolveTableExpr(tbl, sc) + } +} + +// resolveTableExpr resolves expressions within a table expression (e.g., ON conditions). +// For NATURAL JOINs, it expands the join by finding common columns between both tables +// and building an ON condition. For USING clauses, it resolves column references. +func (r *Resolver) resolveTableExpr(tbl ast.TableExpr, sc *scopepkg.Scope) { + switch t := tbl.(type) { + case *ast.JoinClause: + r.resolveTableExpr(t.Left, sc) + r.resolveTableExpr(t.Right, sc) + + // Expand NATURAL JOIN → find common columns → build ON condition + if t.Type == ast.JoinNatural || t.Type == ast.JoinNaturalLeft || t.Type == ast.JoinNaturalRight { + r.expandNaturalJoin(t, sc) + } + + // Expand USING → build ON condition with qualified column refs + if t.Condition != nil { + if using, ok := t.Condition.(*ast.UsingCondition); ok { + r.expandUsingCondition(t, using) + } + } + + if t.Condition != nil { + if on, ok := t.Condition.(*ast.OnCondition); ok { + on.Expr = r.resolveExpr(on.Expr, sc) + } + } + } +} + +// expandNaturalJoin finds common columns between the left and right tables of a +// NATURAL JOIN and builds an ON condition. It also changes the join type: +// - NATURAL JOIN → JoinInner +// - NATURAL LEFT JOIN → JoinLeft +// - NATURAL RIGHT JOIN → JoinRight (deparse will then swap to LEFT) +func (r *Resolver) expandNaturalJoin(j *ast.JoinClause, sc *scopepkg.Scope) { + leftTable := r.lookupTableExpr(j.Left) + rightTable := r.lookupTableExpr(j.Right) + if leftTable == nil || rightTable == nil { + // Can't expand without schema info; leave as-is + switch j.Type { + case ast.JoinNatural: + j.Type = ast.JoinInner + case ast.JoinNaturalLeft: + j.Type = ast.JoinLeft + case ast.JoinNaturalRight: + j.Type = ast.JoinRight + } + return + } + + // Find common columns (columns with matching names, case-insensitive) + // Use left table column order for deterministic output + leftCols := sortedResolverColumns(leftTable) + var commonCols []string + for _, lc := range leftCols { + if rightTable.GetColumn(lc.Name) != nil { + commonCols = append(commonCols, lc.Name) + } + } + + // Get effective table names for qualified column refs + leftName := tableExprEffectiveName(j.Left) + rightName := tableExprEffectiveName(j.Right) + + // Build ON condition from common columns + if len(commonCols) > 0 { + j.Condition = &ast.OnCondition{ + Expr: buildColumnEqualityChain(commonCols, leftName, rightName), + } + } + + // Mark common columns from the right table as coalesced for star expansion. + // In NATURAL JOIN, common columns appear only once (from the left table). + if sc != nil && len(commonCols) > 0 { + for _, col := range commonCols { + sc.MarkCoalesced(rightName, col) + } + } + + // Change join type from NATURAL variant to standard variant + switch j.Type { + case ast.JoinNatural: + j.Type = ast.JoinInner + case ast.JoinNaturalLeft: + j.Type = ast.JoinLeft + case ast.JoinNaturalRight: + j.Type = ast.JoinRight + } +} + +// expandUsingCondition converts a USING condition into an ON condition with +// fully qualified column references, then replaces the condition on the join. +func (r *Resolver) expandUsingCondition(j *ast.JoinClause, using *ast.UsingCondition) { + leftName := tableExprEffectiveName(j.Left) + rightName := tableExprEffectiveName(j.Right) + + if len(using.Columns) > 0 && leftName != "" && rightName != "" { + j.Condition = &ast.OnCondition{ + Expr: buildColumnEqualityChain(using.Columns, leftName, rightName), + } + } +} + +// lookupTableExpr returns the ResolverTable for a table expression. +// Only works for simple TableRef nodes. +func (r *Resolver) lookupTableExpr(tbl ast.TableExpr) *ResolverTable { + switch t := tbl.(type) { + case *ast.TableRef: + return r.Lookup(t.Name) + default: + return nil + } +} + +// tableExprEffectiveName returns the effective name (alias or table name) of a table expression. +func tableExprEffectiveName(tbl ast.TableExpr) string { + switch t := tbl.(type) { + case *ast.TableRef: + if t.Alias != "" { + return t.Alias + } + return t.Name + default: + return "" + } +} + +// buildColumnEqualityChain builds an AND-chained equality expression for column pairs. +// e.g., columns [a, b] with left=t1, right=t2 → +// +// ((`t1`.`a` = `t2`.`a`) and (`t1`.`b` = `t2`.`b`)) +func buildColumnEqualityChain(columns []string, leftName, rightName string) ast.ExprNode { + if len(columns) == 0 { + return nil + } + + // Build individual equality expressions + equalities := make([]ast.ExprNode, len(columns)) + for i, col := range columns { + equalities[i] = &ast.BinaryExpr{ + Op: ast.BinOpEq, + Left: &ast.ColumnRef{Table: leftName, Column: col}, + Right: &ast.ColumnRef{Table: rightName, Column: col}, + } + } + + // Chain with AND if multiple + if len(equalities) == 1 { + return equalities[0] + } + result := equalities[0] + for i := 1; i < len(equalities); i++ { + result = &ast.BinaryExpr{ + Op: ast.BinOpAnd, + Left: result, + Right: equalities[i], + } + } + return result +} + +// resolveCastCharsets walks the expression tree and sets charset on CAST/CONVERT DataTypes. +func (r *Resolver) resolveCastCharsets(node ast.ExprNode) { + if node == nil { + return + } + switch n := node.(type) { + case *ast.CastExpr: + r.resolveCastCharset(n.TypeName) + r.resolveCastCharsets(n.Expr) + case *ast.ConvertExpr: + r.resolveCastCharset(n.TypeName) + r.resolveCastCharsets(n.Expr) + case *ast.BinaryExpr: + r.resolveCastCharsets(n.Left) + r.resolveCastCharsets(n.Right) + case *ast.UnaryExpr: + r.resolveCastCharsets(n.Operand) + case *ast.ParenExpr: + r.resolveCastCharsets(n.Expr) + case *ast.FuncCallExpr: + for _, arg := range n.Args { + r.resolveCastCharsets(arg) + } + case *ast.CaseExpr: + r.resolveCastCharsets(n.Operand) + for _, w := range n.Whens { + r.resolveCastCharsets(w.Cond) + r.resolveCastCharsets(w.Result) + } + r.resolveCastCharsets(n.Default) + } +} + +// resolveCastCharset sets the charset on a CAST/CONVERT DataType for CHAR types. +// MySQL adds "charset " when no charset is explicitly specified. +// The resolver uses DefaultCharset (from the catalog's database charset). +func (r *Resolver) resolveCastCharset(dt *ast.DataType) { + if dt == nil { + return + } + name := strings.ToLower(dt.Name) + if name == "char" && dt.Charset == "" { + charset := r.DefaultCharset + if charset == "" { + charset = "utf8mb4" + } + dt.Charset = charset + } +} + +// withCTELookup returns a TableLookup function that first checks CTE virtual tables, +// then falls back to the given fallback lookup. +// We capture fallback by value (not r.Lookup) to avoid infinite recursion when +// r.Lookup is later overwritten to point at the returned function itself. +func (r *Resolver) withCTELookup(cteTables map[string]*ResolverTable) TableLookup { + fallback := r.Lookup // capture current value, not a reference to r.Lookup + return func(tableName string) *ResolverTable { + key := strings.ToLower(tableName) + if vt, ok := cteTables[key]; ok { + return vt + } + if fallback != nil { + return fallback(tableName) + } + return nil + } +} + +// collectLeftmostCTEs walks down the left spine of a set operation tree and +// returns CTEs from the leftmost leaf. Unlike extractCTEs in the deparser, +// this does NOT clear the CTEs — they must remain in the AST for the deparser +// to emit the WITH clause later. Instead, it marks them as already-resolved +// by removing them from a copy so the resolver doesn't double-process. +func collectLeftmostCTEs(stmt *ast.SelectStmt) []*ast.CommonTableExpr { + cur := stmt + for cur.SetOp != ast.SetOpNone && cur.Left != nil { + cur = cur.Left + } + if len(cur.CTEs) > 0 { + return cur.CTEs + } + return nil +} + +// buildCTEVirtualTableFromSelect constructs a ResolverTable from a specific SELECT branch. +// Used for recursive CTEs where we need to build the virtual table from the non-recursive +// (left) branch before resolving the recursive (right) branch. +func buildCTEVirtualTableFromSelect(cteName string, cteColumns []string, sel *ast.SelectStmt) *ResolverTable { + if sel == nil { + return nil + } + + // If the CTE has an explicit column list, use those names + if len(cteColumns) > 0 { + cols := make([]ResolverColumn, len(cteColumns)) + for i, name := range cteColumns { + cols[i] = ResolverColumn{Name: name, Position: i + 1} + } + return &ResolverTable{Name: cteName, Columns: cols} + } + + // Walk down to the leftmost leaf for set operations + for sel.SetOp != ast.SetOpNone && sel.Left != nil { + sel = sel.Left + } + + var cols []ResolverColumn + for i, target := range sel.TargetList { + name := cteColumnName(target, i+1) + cols = append(cols, ResolverColumn{Name: name, Position: i + 1}) + } + if len(cols) == 0 { + return nil + } + return &ResolverTable{Name: cteName, Columns: cols} +} + +// buildCTEVirtualTable constructs a ResolverTable from a CTE's SELECT target list. +// This allows the main query to resolve column references to CTE columns. +func buildCTEVirtualTable(cte *ast.CommonTableExpr) *ResolverTable { + if cte.Select == nil { + return nil + } + + // If the CTE has an explicit column list, use those names + if len(cte.Columns) > 0 { + cols := make([]ResolverColumn, len(cte.Columns)) + for i, name := range cte.Columns { + cols[i] = ResolverColumn{Name: name, Position: i + 1} + } + return &ResolverTable{Name: cte.Name, Columns: cols} + } + + // Otherwise, derive columns from the CTE's SELECT target list + sel := cte.Select + // For set operations, use the left side's target list + for sel.SetOp != ast.SetOpNone && sel.Left != nil { + sel = sel.Left + } + + var cols []ResolverColumn + for i, target := range sel.TargetList { + name := cteColumnName(target, i+1) + cols = append(cols, ResolverColumn{Name: name, Position: i + 1}) + } + if len(cols) == 0 { + return nil + } + return &ResolverTable{Name: cte.Name, Columns: cols} +} + +// cteColumnName extracts the column name from a target list entry. +// Uses alias if present, otherwise column ref name, otherwise positional name. +func cteColumnName(target ast.ExprNode, position int) string { + if rt, ok := target.(*ast.ResTarget); ok { + if rt.Name != "" { + return rt.Name + } + if col, ok := rt.Val.(*ast.ColumnRef); ok { + return col.Column + } + return fmt.Sprintf("Name_exp_%d", position) + } + if col, ok := target.(*ast.ColumnRef); ok { + return col.Column + } + return fmt.Sprintf("Name_exp_%d", position) +} + +// buildDerivedVirtualTable constructs a ResolverTable from a derived table's SELECT target list. +// This allows the outer query to resolve column references to derived table columns. +func buildDerivedVirtualTable(sub *ast.SubqueryExpr) *ResolverTable { + if sub.Select == nil || sub.Alias == "" { + return nil + } + + sel := sub.Select + // For set operations, use the left side's target list + for sel.SetOp != ast.SetOpNone && sel.Left != nil { + sel = sel.Left + } + + var cols []ResolverColumn + for i, target := range sel.TargetList { + name := cteColumnName(target, i+1) + cols = append(cols, ResolverColumn{Name: name, Position: i + 1}) + } + if len(cols) == 0 { + return nil + } + return &ResolverTable{Name: sub.Alias, Columns: cols} +} + +// AmbiguousColumnError is returned when a column reference matches multiple tables. +type AmbiguousColumnError struct { + Column string + Tables []string +} + +func (e *AmbiguousColumnError) Error() string { + return fmt.Sprintf("column %q is ambiguous, found in tables: %s", e.Column, strings.Join(e.Tables, ", ")) +} diff --git a/tidb/deparse/resolver_test.go b/tidb/deparse/resolver_test.go new file mode 100644 index 00000000..a2b4a280 --- /dev/null +++ b/tidb/deparse/resolver_test.go @@ -0,0 +1,484 @@ +package deparse_test + +import ( + "strings" + "testing" + + ast "github.com/bytebase/omni/tidb/ast" + "github.com/bytebase/omni/tidb/catalog" + "github.com/bytebase/omni/tidb/deparse" + "github.com/bytebase/omni/tidb/parser" +) + +// setupCatalog creates a catalog with a test database and tables for resolver tests. +// Schema: testdb.t(a INT, b INT, c INT), testdb.t1(a INT, b INT), testdb.t2(a INT, d INT) +func setupCatalog(t *testing.T) *catalog.Catalog { + t.Helper() + c := catalog.New() + sqls := []string{ + "CREATE DATABASE testdb", + "USE testdb", + "CREATE TABLE t (a INT, b INT, c INT)", + "CREATE TABLE t1 (a INT, b INT)", + "CREATE TABLE t2 (a INT, d INT)", + } + for _, sql := range sqls { + _, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("catalog setup failed on %q: %v", sql, err) + } + } + return c +} + +// catalogLookup creates a TableLookup function from a catalog.Catalog. +func catalogLookup(c *catalog.Catalog) deparse.TableLookup { + return func(tableName string) *deparse.ResolverTable { + db := c.GetDatabase(c.CurrentDatabase()) + if db == nil { + return nil + } + tbl := db.GetTable(tableName) + if tbl == nil { + return nil + } + cols := make([]deparse.ResolverColumn, len(tbl.Columns)) + for i, col := range tbl.Columns { + cols[i] = deparse.ResolverColumn{Name: col.Name, Position: col.Position} + } + return &deparse.ResolverTable{Name: tbl.Name, Columns: cols} + } +} + +// resolveAndDeparse parses SQL, resolves column refs, rewrites, and deparses. +func resolveAndDeparse(t *testing.T, cat *catalog.Catalog, sql string) string { + t.Helper() + list, err := parser.Parse(sql) + if err != nil { + t.Fatalf("failed to parse %q: %v", sql, err) + } + if list.Len() == 0 { + t.Fatalf("no statements parsed from %q", sql) + } + sel, ok := list.Items[0].(*ast.SelectStmt) + if !ok { + t.Fatalf("expected SelectStmt, got %T", list.Items[0]) + } + + // Apply rewrites first (NOT folding, boolean context), then resolve + deparse.RewriteSelectStmt(sel) + + // Get the database default charset for the resolver + defaultCharset := "" + db := cat.GetDatabase(cat.CurrentDatabase()) + if db != nil { + defaultCharset = db.Charset + } + + resolver := &deparse.Resolver{Lookup: catalogLookup(cat), DefaultCharset: defaultCharset} + resolved := resolver.Resolve(sel) + return deparse.DeparseSelect(resolved) +} + +func TestResolver_Section_6_1_ColumnQualification(t *testing.T) { + cat := setupCatalog(t) + + cases := []struct { + name string + input string + expected string + }{ + { + "single_table_unqualified", + "SELECT a FROM t", + "select `t`.`a` AS `a` from `t`", + }, + { + "multiple_columns", + "SELECT a, b, c FROM t", + "select `t`.`a` AS `a`,`t`.`b` AS `b`,`t`.`c` AS `c` from `t`", + }, + { + "qualified_column_preserved", + "SELECT t.a FROM t", + "select `t`.`a` AS `a` from `t`", + }, + { + "table_alias", + "SELECT a FROM t AS x", + "select `x`.`a` AS `a` from `t` `x`", + }, + { + "table_alias_no_as", + "SELECT a FROM t x", + "select `x`.`a` AS `a` from `t` `x`", + }, + { + "column_in_where", + "SELECT a FROM t WHERE a > 0", + "select `t`.`a` AS `a` from `t` where (`t`.`a` > 0)", + }, + { + "column_in_order_by", + "SELECT a FROM t ORDER BY a", + "select `t`.`a` AS `a` from `t` order by `t`.`a`", + }, + { + "column_in_group_by", + "SELECT a FROM t GROUP BY a", + "select `t`.`a` AS `a` from `t` group by `t`.`a`", + }, + { + "column_in_having", + "SELECT a, COUNT(*) FROM t GROUP BY a HAVING a > 0", + "select `t`.`a` AS `a`,count(0) AS `COUNT(*)` from `t` group by `t`.`a` having (`t`.`a` > 0)", + }, + { + "column_in_on_condition", + "SELECT t1.a, t2.d FROM t1 JOIN t2 ON t1.a = t2.a", + "select `t1`.`a` AS `a`,`t2`.`d` AS `d` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))", + }, + { + "qualified_star", + "SELECT t1.*, t2.d FROM t1 JOIN t2 ON t1.a = t2.a", + "select `t1`.`a` AS `a`,`t1`.`b` AS `b`,`t2`.`d` AS `d` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := resolveAndDeparse(t, cat, tc.input) + if got != tc.expected { + t.Errorf("resolveAndDeparse(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestResolver_Section_6_1_AmbiguousColumn(t *testing.T) { + cat := setupCatalog(t) + + // Column 'a' exists in both t1 and t2 — should still work (resolve to first match) + // For MySQL, an ambiguous unqualified column in a multi-table query is an error. + // Our resolver currently resolves to the first match for simplicity. + t.Run("ambiguous_column_two_tables", func(t *testing.T) { + got := resolveAndDeparse(t, cat, "SELECT a FROM t1 JOIN t2 ON t1.a = t2.a") + // The unqualified 'a' in SELECT will match t1 first (insertion order) + expected := "select `t1`.`a` AS `a` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))" + if got != expected { + t.Errorf("resolveAndDeparse(ambiguous) =\n %q\nwant:\n %q", got, expected) + } + }) +} + +func TestResolverTable_GetColumn(t *testing.T) { + rt := &deparse.ResolverTable{ + Name: "t", + Columns: []deparse.ResolverColumn{ + {Name: "a", Position: 1}, + {Name: "b", Position: 2}, + }, + } + + if rt.GetColumn("a") == nil { + t.Error("expected to find column 'a'") + } + if rt.GetColumn("A") == nil { + t.Error("expected case-insensitive match for 'A'") + } + if rt.GetColumn("c") != nil { + t.Error("expected nil for nonexistent column 'c'") + } +} + +func TestResolver_Section_6_2_SelectStarExpansion(t *testing.T) { + cat := setupCatalog(t) + + cases := []struct { + name string + input string + expected string + }{ + { + "simple_star", + "SELECT * FROM t", + "select `t`.`a` AS `a`,`t`.`b` AS `b`,`t`.`c` AS `c` from `t`", + }, + { + "star_alias_per_column", + // Each expanded column gets `table`.`col` AS `col` format + "SELECT * FROM t1", + "select `t1`.`a` AS `a`,`t1`.`b` AS `b` from `t1`", + }, + { + "star_ordering", + // Columns in table definition order (Column.Position) + "SELECT * FROM t", + "select `t`.`a` AS `a`,`t`.`b` AS `b`,`t`.`c` AS `c` from `t`", + }, + { + "star_with_where", + "SELECT * FROM t WHERE a > 0", + "select `t`.`a` AS `a`,`t`.`b` AS `b`,`t`.`c` AS `c` from `t` where (`t`.`a` > 0)", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := resolveAndDeparse(t, cat, tc.input) + if got != tc.expected { + t.Errorf("resolveAndDeparse(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +func TestResolver_Section_6_3_AutoAliasGeneration(t *testing.T) { + cat := setupCatalog(t) + + cases := []struct { + name string + input string + expected string + }{ + { + "column_ref", + "SELECT a FROM t", + "select `t`.`a` AS `a` from `t`", + }, + { + "qualified_column", + "SELECT t.a FROM t", + "select `t`.`a` AS `a` from `t`", + }, + { + "expression_auto_alias", + "SELECT a + b FROM t", + "select (`t`.`a` + `t`.`b`) AS `a + b` from `t`", + }, + { + "literal_int", + "SELECT 1", + "select 1 AS `1`", + }, + { + "literal_string", + "SELECT 'hello'", + "select 'hello' AS `hello`", + }, + { + "literal_null", + "SELECT NULL", + "select NULL AS `NULL`", + }, + { + "complex_expression_name_exp", + // Expression alias > 64 chars triggers Name_exp_N pattern + "SELECT CONCAT(CONCAT(CONCAT(CONCAT(CONCAT(CONCAT(CONCAT(a, b), c), a), b), c), a), b) FROM t", + "select concat(concat(concat(concat(concat(concat(concat(`t`.`a`,`t`.`b`),`t`.`c`),`t`.`a`),`t`.`b`),`t`.`c`),`t`.`a`),`t`.`b`) AS `Name_exp_1` from `t`", + }, + { + "explicit_alias_preserved", + "SELECT a AS x FROM t", + "select `t`.`a` AS `x` from `t`", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := resolveAndDeparse(t, cat, tc.input) + if got != tc.expected { + t.Errorf("resolveAndDeparse(%q) =\n %q\nwant:\n %q", tc.input, got, tc.expected) + } + }) + } +} + +// setupCatalogForJoins creates a catalog with tables that have overlapping columns +// for testing NATURAL JOIN expansion. +// Schema: testdb.t1(a INT, b INT), testdb.t2(a INT, d INT), testdb.t3(a INT, b INT) +func setupCatalogForJoins(t *testing.T) *catalog.Catalog { + t.Helper() + c := catalog.New() + sqls := []string{ + "CREATE DATABASE testdb", + "USE testdb", + "CREATE TABLE t1 (a INT, b INT)", + "CREATE TABLE t2 (a INT, d INT)", + "CREATE TABLE t3 (a INT, b INT)", + } + for _, sql := range sqls { + _, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("catalog setup failed on %q: %v", sql, err) + } + } + return c +} + +func TestResolver_Section_6_4_JoinNormalization(t *testing.T) { + t.Run("natural_join", func(t *testing.T) { + // t1(a, b) NATURAL JOIN t2(a, d) → common column: a + // Test with explicit column selection to verify ON expansion + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t1.b, t2.d FROM t1 NATURAL JOIN t2") + expected := "select `t1`.`a` AS `a`,`t1`.`b` AS `b`,`t2`.`d` AS `d` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))" + if got != expected { + t.Errorf("NATURAL JOIN:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("natural_join_multi_common", func(t *testing.T) { + // t1(a, b) NATURAL JOIN t3(a, b) → common columns: a, b + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t1.b FROM t1 NATURAL JOIN t3") + expected := "select `t1`.`a` AS `a`,`t1`.`b` AS `b` from (`t1` join `t3` on(((`t1`.`a` = `t3`.`a`) and (`t1`.`b` = `t3`.`b`))))" + if got != expected { + t.Errorf("NATURAL JOIN multi common:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("using_single_column", func(t *testing.T) { + // USING (a) → on((`t1`.`a` = `t2`.`a`)) + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t1.b, t2.d FROM t1 JOIN t2 USING (a)") + expected := "select `t1`.`a` AS `a`,`t1`.`b` AS `b`,`t2`.`d` AS `d` from (`t1` join `t2` on((`t1`.`a` = `t2`.`a`)))" + if got != expected { + t.Errorf("USING single:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("using_multiple_columns", func(t *testing.T) { + // USING (a, b) via AST since parser may not support multi-column USING + cat := setupCatalogForJoins(t) + // Build manually: t1 JOIN t3 USING (a, b) + join := &ast.JoinClause{ + Type: ast.JoinInner, + Left: &ast.TableRef{Name: "t1"}, + Right: &ast.TableRef{Name: "t3"}, + Condition: &ast.UsingCondition{ + Columns: []string{"a", "b"}, + }, + } + sel := &ast.SelectStmt{ + TargetList: []ast.ExprNode{&ast.StarExpr{}}, + From: []ast.TableExpr{join}, + } + deparse.RewriteSelectStmt(sel) + resolver := &deparse.Resolver{Lookup: catalogLookup(cat)} + resolved := resolver.Resolve(sel) + got := deparse.DeparseSelect(resolved) + expected := "select `t1`.`a` AS `a`,`t1`.`b` AS `b`,`t3`.`a` AS `a`,`t3`.`b` AS `b` from (`t1` join `t3` on(((`t1`.`a` = `t3`.`a`) and (`t1`.`b` = `t3`.`b`))))" + if got != expected { + t.Errorf("USING multi:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("right_join_to_left_join", func(t *testing.T) { + // RIGHT JOIN → LEFT JOIN with table swap + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t2.d FROM t1 RIGHT JOIN t2 ON t1.a = t2.a") + expected := "select `t1`.`a` AS `a`,`t2`.`d` AS `d` from (`t2` left join `t1` on((`t1`.`a` = `t2`.`a`)))" + if got != expected { + t.Errorf("RIGHT JOIN:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("cross_join_to_join", func(t *testing.T) { + // CROSS JOIN → plain join (no ON) + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t2.d FROM t1 CROSS JOIN t2") + expected := "select `t1`.`a` AS `a`,`t2`.`d` AS `d` from (`t1` join `t2`)" + if got != expected { + t.Errorf("CROSS JOIN:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("implicit_cross_join", func(t *testing.T) { + // FROM t1, t2 → FROM (`t1` join `t2`) + cat := setupCatalogForJoins(t) + got := resolveAndDeparse(t, cat, "SELECT t1.a, t2.d FROM t1, t2") + expected := "select `t1`.`a` AS `a`,`t2`.`d` AS `d` from (`t1` join `t2`)" + if got != expected { + t.Errorf("Implicit cross join:\n got: %q\n want: %q", got, expected) + } + }) +} + +// setupCatalogWithCharset creates a catalog with a database using the given charset. +func setupCatalogWithCharset(t *testing.T, charset string) *catalog.Catalog { + t.Helper() + c := catalog.New() + createDB := "CREATE DATABASE testdb" + if charset != "" { + createDB += " CHARACTER SET " + charset + } + sqls := []string{ + createDB, + "USE testdb", + "CREATE TABLE t (a INT, b VARCHAR(100))", + } + for _, sql := range sqls { + _, err := c.Exec(sql, nil) + if err != nil { + t.Fatalf("catalog setup failed on %q: %v", sql, err) + } + } + return c +} + +func TestResolver_Section_6_5_CastCharsetFromCatalog(t *testing.T) { + t.Run("cast_char_uses_database_default_charset_utf8mb4", func(t *testing.T) { + // Default database charset is utf8mb4 — CAST to CHAR should use charset utf8mb4 + cat := setupCatalog(t) // uses default charset (utf8mb4) + got := resolveAndDeparse(t, cat, "SELECT CAST(a AS CHAR) FROM t") + expected := "select cast(`t`.`a` as char charset utf8mb4) AS `CAST(a AS CHAR)` from `t`" + if got != expected { + t.Errorf("CAST CHAR utf8mb4:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("cast_char_latin1_database", func(t *testing.T) { + // Database with latin1 charset — CAST to CHAR should use charset latin1 + cat := setupCatalogWithCharset(t, "latin1") + got := resolveAndDeparse(t, cat, "SELECT CAST(a AS CHAR) FROM t") + expected := "select cast(`t`.`a` as char charset latin1) AS `CAST(a AS CHAR)` from `t`" + if got != expected { + t.Errorf("CAST CHAR latin1:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("cast_char_n_latin1_database", func(t *testing.T) { + // Database with latin1 charset — CAST to CHAR(10) should use charset latin1 + cat := setupCatalogWithCharset(t, "latin1") + got := resolveAndDeparse(t, cat, "SELECT CAST(a AS CHAR(10)) FROM t") + expected := "select cast(`t`.`a` as char(10) charset latin1) AS `CAST(a AS CHAR(10))` from `t`" + if got != expected { + t.Errorf("CAST CHAR(10) latin1:\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("cast_binary_unaffected", func(t *testing.T) { + // CAST to BINARY should always use charset binary, regardless of database charset + cat := setupCatalogWithCharset(t, "latin1") + got := resolveAndDeparse(t, cat, "SELECT CAST(a AS BINARY) FROM t") + expected := "select cast(`t`.`a` as char charset binary) AS `CAST(a AS BINARY)` from `t`" + if got != expected { + t.Errorf("CAST BINARY (latin1 db):\n got: %q\n want: %q", got, expected) + } + }) + + t.Run("convert_type_latin1_database", func(t *testing.T) { + // CONVERT(expr, CHAR) — rewritten to CAST, should use database charset + // MySQL 8.0 auto-alias preserves the original CONVERT(a, CHAR) form. + cat := setupCatalogWithCharset(t, "latin1") + got := resolveAndDeparse(t, cat, "SELECT CONVERT(a, CHAR) FROM t") + expected := "select cast(`t`.`a` as char charset latin1) AS `CONVERT(a, CHAR)` from `t`" + if got != expected { + t.Errorf("CONVERT CHAR latin1:\n got: %q\n want: %q", got, expected) + } + }) +} + +// Ensure the Resolver handles an unused import gracefully. +var _ = strings.ToLower diff --git a/tidb/deparse/rewrite.go b/tidb/deparse/rewrite.go new file mode 100644 index 00000000..6fc28325 --- /dev/null +++ b/tidb/deparse/rewrite.go @@ -0,0 +1,287 @@ +// Package deparse — rewrite.go implements AST rewrites applied before deparsing. +// These rewrites match MySQL 8.0's resolver behavior (SHOW CREATE VIEW output). +package deparse + +import ( + ast "github.com/bytebase/omni/tidb/ast" +) + +// RewriteExpr applies MySQL 8.0 resolver rewrites to an expression AST. +// Currently implements NOT folding: NOT(comparison) → inverted comparison. +// The rewrite is recursive — children are processed first (bottom-up). +func RewriteExpr(node ast.ExprNode) ast.ExprNode { + if node == nil { + return nil + } + return rewriteExpr(node) +} + +func rewriteExpr(node ast.ExprNode) ast.ExprNode { + switch n := node.(type) { + case *ast.UnaryExpr: + // First rewrite the operand recursively + n.Operand = rewriteExpr(n.Operand) + if n.Op == ast.UnaryNot { + return rewriteNot(n.Operand) + } + return n + + case *ast.BinaryExpr: + n.Left = rewriteExpr(n.Left) + n.Right = rewriteExpr(n.Right) + // MySQL 8.0 rewrites SOUNDS LIKE to soundex(left) = soundex(right) + if n.Op == ast.BinOpSoundsLike { + return &ast.ParenExpr{ + Expr: &ast.BinaryExpr{ + Op: ast.BinOpEq, + Left: &ast.FuncCallExpr{ + Name: "soundex", + HasParens: true, + Args: []ast.ExprNode{n.Left}, + }, + Right: &ast.FuncCallExpr{ + Name: "soundex", + HasParens: true, + Args: []ast.ExprNode{n.Right}, + }, + }, + } + } + // Boolean context wrapping: AND/OR/XOR operands that are not boolean + // expressions get wrapped in (0 <> expr). + if n.Op == ast.BinOpAnd || n.Op == ast.BinOpOr || n.Op == ast.BinOpXor { + if !isBooleanExpr(n.Left) { + n.Left = wrapBooleanContext(n.Left) + } + if !isBooleanExpr(n.Right) { + n.Right = wrapBooleanContext(n.Right) + } + } + return n + + case *ast.ParenExpr: + n.Expr = rewriteExpr(n.Expr) + return n + + case *ast.InExpr: + n.Expr = rewriteExpr(n.Expr) + for i, item := range n.List { + n.List[i] = rewriteExpr(item) + } + return n + + case *ast.BetweenExpr: + n.Expr = rewriteExpr(n.Expr) + n.Low = rewriteExpr(n.Low) + n.High = rewriteExpr(n.High) + return n + + case *ast.LikeExpr: + n.Expr = rewriteExpr(n.Expr) + n.Pattern = rewriteExpr(n.Pattern) + if n.Escape != nil { + n.Escape = rewriteExpr(n.Escape) + } + return n + + case *ast.IsExpr: + n.Expr = rewriteExpr(n.Expr) + // IS TRUE / IS FALSE on non-boolean: wrap operand in (0 <> expr) + if (n.Test == ast.IsTrue || n.Test == ast.IsFalse) && !n.Not && !isBooleanExpr(n.Expr) { + n.Expr = wrapBooleanContext(n.Expr) + } + return n + + case *ast.CaseExpr: + if n.Operand != nil { + n.Operand = rewriteExpr(n.Operand) + } + for _, w := range n.Whens { + w.Cond = rewriteExpr(w.Cond) + w.Result = rewriteExpr(w.Result) + } + if n.Default != nil { + n.Default = rewriteExpr(n.Default) + } + return n + + case *ast.FuncCallExpr: + for i, arg := range n.Args { + n.Args[i] = rewriteExpr(arg) + } + return n + + case *ast.CastExpr: + n.Expr = rewriteExpr(n.Expr) + return n + + case *ast.ConvertExpr: + n.Expr = rewriteExpr(n.Expr) + return n + + case *ast.CollateExpr: + n.Expr = rewriteExpr(n.Expr) + return n + + default: + // Leaf nodes (literals, column refs, etc.) — no rewriting needed + return node + } +} + +// invertOp maps comparison operators to their NOT-inverted counterparts. +var invertOp = map[ast.BinaryOp]ast.BinaryOp{ + ast.BinOpGt: ast.BinOpLe, + ast.BinOpLt: ast.BinOpGe, + ast.BinOpGe: ast.BinOpLt, + ast.BinOpLe: ast.BinOpGt, + ast.BinOpEq: ast.BinOpNe, + ast.BinOpNe: ast.BinOpEq, +} + +// isComparisonOp returns true if op is a comparison operator that can be inverted. +func isComparisonOp(op ast.BinaryOp) bool { + _, ok := invertOp[op] + return ok +} + +// unwrapParen strips ParenExpr wrappers to get the inner expression. +func unwrapParen(node ast.ExprNode) ast.ExprNode { + for { + p, ok := node.(*ast.ParenExpr) + if !ok { + return node + } + node = p.Expr + } +} + +// rewriteNot applies NOT folding to the operand of a NOT expression. +// MySQL 8.0's resolver: +// - NOT (comparison) → inverted comparison (e.g., NOT(a > 0) → (a <= 0)) +// - NOT (LIKE) → not((expr like pattern)) — wraps in not(), doesn't fold +// - NOT (non-boolean) → (0 = expr) — e.g., NOT(a+1) → (0 = (a+1)), NOT(col) → (0 = col) +func rewriteNot(operand ast.ExprNode) ast.ExprNode { + inner := unwrapParen(operand) + + // Case 1: NOT(comparison) → invert the comparison operator + if binExpr, ok := inner.(*ast.BinaryExpr); ok { + if inverted, canInvert := invertOp[binExpr.Op]; canInvert { + return &ast.BinaryExpr{ + Loc: binExpr.Loc, + Op: inverted, + Left: binExpr.Left, + Right: binExpr.Right, + } + } + } + + // Case 2a: NOT(LIKE) — keep as not() wrapping (don't fold into the LIKE) + // The deparsing of UnaryNot already produces (not(...)), which is the correct + // MySQL 8.0 output. So we return the UnaryNot as-is. + if _, ok := inner.(*ast.LikeExpr); ok { + return &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: operand, + } + } + + // Case 2b: NOT(REGEXP) — keep as not() wrapping. + // MySQL 8.0 rewrites NOT REGEXP to (not(regexp_like(...))). + if binExpr, ok := inner.(*ast.BinaryExpr); ok && binExpr.Op == ast.BinOpRegexp { + return &ast.UnaryExpr{ + Op: ast.UnaryNot, + Operand: operand, + } + } + + // Case 3: NOT(non-boolean) → (0 = expr) + // This handles: NOT(column), NOT(a+1), NOT(func()), etc. + // MySQL rewrites these as (0 = expr). + return &ast.BinaryExpr{ + Op: ast.BinOpEq, + Left: &ast.IntLit{Value: 0}, + Right: operand, + } +} + +// isBooleanExpr returns true if the expression is inherently boolean-valued. +// MySQL's is_bool_func() returns true for: comparisons (=,<>,<,>,<=,>=,<=>), +// IN, BETWEEN, LIKE, IS NULL/IS NOT NULL, AND, OR, NOT, XOR, EXISTS, +// TRUE/FALSE literals. Everything else (column refs, arithmetic, functions, +// CASE, IF, subqueries, literals) is NOT boolean and gets wrapped in (0 <> expr) +// when used as an operand of AND/OR/XOR. +func isBooleanExpr(node ast.ExprNode) bool { + inner := unwrapParen(node) + switch n := inner.(type) { + case *ast.BinaryExpr: + switch n.Op { + case ast.BinOpEq, ast.BinOpNe, ast.BinOpLt, ast.BinOpGt, + ast.BinOpLe, ast.BinOpGe, ast.BinOpNullSafeEq, + ast.BinOpAnd, ast.BinOpOr, ast.BinOpXor, ast.BinOpSoundsLike: + return true + } + return false + case *ast.InExpr: + return true + case *ast.BetweenExpr: + return true + case *ast.LikeExpr: + return true + case *ast.IsExpr: + return true + case *ast.UnaryExpr: + if n.Op == ast.UnaryNot { + return true + } + return false + case *ast.ExistsExpr: + return true + case *ast.BoolLit: + return true + default: + return false + } +} + +// wrapBooleanContext wraps a non-boolean expression in (0 <> expr) for +// boolean context (AND/OR/XOR operands, IS TRUE/IS FALSE operands). +func wrapBooleanContext(node ast.ExprNode) ast.ExprNode { + return &ast.BinaryExpr{ + Op: ast.BinOpNe, + Left: &ast.IntLit{Value: 0}, + Right: node, + } +} + +// RewriteSelectStmt applies RewriteExpr to all expression positions in a SelectStmt. +// This should be called after resolver but before deparsing. +func RewriteSelectStmt(stmt *ast.SelectStmt) { + if stmt == nil { + return + } + if stmt.SetOp != ast.SetOpNone { + RewriteSelectStmt(stmt.Left) + RewriteSelectStmt(stmt.Right) + return + } + for i, target := range stmt.TargetList { + if rt, ok := target.(*ast.ResTarget); ok { + rt.Val = RewriteExpr(rt.Val) + } else { + stmt.TargetList[i] = RewriteExpr(target) + } + } + if stmt.Where != nil { + stmt.Where = RewriteExpr(stmt.Where) + } + for i, expr := range stmt.GroupBy { + stmt.GroupBy[i] = RewriteExpr(expr) + } + if stmt.Having != nil { + stmt.Having = RewriteExpr(stmt.Having) + } + for _, item := range stmt.OrderBy { + item.Expr = RewriteExpr(item.Expr) + } +} diff --git a/tidb/scope/scope.go b/tidb/scope/scope.go new file mode 100644 index 00000000..7d540e5e --- /dev/null +++ b/tidb/scope/scope.go @@ -0,0 +1,129 @@ +// Package scope provides a shared, dependency-free scope data structure for +// tracking visible table references and resolving column names. It is used by +// both the semantic analyzer (mysql/catalog) and the AST resolver/deparser +// (mysql/deparse). +package scope + +import ( + "fmt" + "strings" +) + +// Column represents a column visible from a scope entry. +type Column struct { + Name string + Position int // 1-based position in the table +} + +// Table represents a table or virtual table in scope with its visible columns. +type Table struct { + Name string + Columns []Column +} + +// GetColumn returns a column by name (case-insensitive), or nil. +func (t *Table) GetColumn(name string) *Column { + lower := strings.ToLower(name) + for i := range t.Columns { + if strings.ToLower(t.Columns[i].Name) == lower { + return &t.Columns[i] + } + } + return nil +} + +// Entry is one named table reference visible in the current scope. +type Entry struct { + Name string // effective reference name (alias or table name), as registered + Table *Table +} + +// Scope tracks visible table references for column resolution. +type Scope struct { + entries []Entry + byName map[string]int // lowered name -> index into entries + coalescedCols map[string]bool // "tablename.colname" (lowered) -> hidden by USING/NATURAL +} + +// New creates an empty scope. +func New() *Scope { + return &Scope{ + byName: make(map[string]int), + coalescedCols: make(map[string]bool), + } +} + +// Add registers a table reference in the scope. +func (s *Scope) Add(name string, table *Table) { + lower := strings.ToLower(name) + s.byName[lower] = len(s.entries) + s.entries = append(s.entries, Entry{Name: name, Table: table}) +} + +// ResolveColumn finds an unqualified column name across all scope entries. +// Returns (entry index, 1-based column position, error). +// Returns error if not found or ambiguous. +func (s *Scope) ResolveColumn(colName string) (int, int, error) { + lower := strings.ToLower(colName) + foundIdx := -1 + foundPos := 0 + for i, e := range s.entries { + for j, c := range e.Table.Columns { + if strings.ToLower(c.Name) == lower { + if foundIdx >= 0 { + return 0, 0, fmt.Errorf("column '%s' is ambiguous", colName) + } + foundIdx = i + foundPos = j + 1 // 1-based + } + } + } + if foundIdx < 0 { + return 0, 0, fmt.Errorf("unknown column '%s'", colName) + } + return foundIdx, foundPos, nil +} + +// ResolveQualifiedColumn finds a column qualified by table name. +// Returns (entry index, 1-based column position, error). +func (s *Scope) ResolveQualifiedColumn(tableName, colName string) (int, int, error) { + lowerTable := strings.ToLower(tableName) + idx, ok := s.byName[lowerTable] + if !ok { + return 0, 0, fmt.Errorf("unknown table '%s'", tableName) + } + e := s.entries[idx] + lowerCol := strings.ToLower(colName) + for j, c := range e.Table.Columns { + if strings.ToLower(c.Name) == lowerCol { + return idx, j + 1, nil + } + } + return 0, 0, fmt.Errorf("unknown column '%s.%s'", tableName, colName) +} + +// GetTable returns the table for a named entry, or nil. +func (s *Scope) GetTable(name string) *Table { + idx, ok := s.byName[strings.ToLower(name)] + if !ok { + return nil + } + return s.entries[idx].Table +} + +// AllEntries returns all scope entries in registration order. +func (s *Scope) AllEntries() []Entry { + return s.entries +} + +// MarkCoalesced marks a column from a table as hidden by USING/NATURAL. +func (s *Scope) MarkCoalesced(tableName, colName string) { + key := strings.ToLower(tableName) + "." + strings.ToLower(colName) + s.coalescedCols[key] = true +} + +// IsCoalesced returns true if the column is hidden by USING/NATURAL. +func (s *Scope) IsCoalesced(tableName, colName string) bool { + key := strings.ToLower(tableName) + "." + strings.ToLower(colName) + return s.coalescedCols[key] +}