diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index adbb282c7a..ab10c975bb 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2060,6 +2060,10 @@ func TestCreateTable(t *testing.T, harness Harness) { TestScriptPrepared(t, harness, script) } + for _, script := range queries.CreateTableAutoIncrementTests { + TestScript(t, harness, script) + } + harness.Setup(setup.MydbData, setup.MytableData) e := mustNewEngine(t, harness) defer e.Close() @@ -5415,86 +5419,6 @@ func TestAddDropPks(t *testing.T, harness Harness) { }) } -func TestAddAutoIncrementColumn(t *testing.T, harness Harness) { - harness.Setup([]setup.SetupScript{{ - "create database mydb", - "use mydb", - }}) - e := mustNewEngine(t, harness) - defer e.Close() - ctx := NewContext(harness) - - t.Run("Add primary key column with auto increment", func(t *testing.T) { - ctx.SetCurrentDatabase("mydb") - RunQuery(t, e, harness, "CREATE TABLE t1 (i int, j int);") - RunQuery(t, e, harness, "insert into t1 values (1,1), (2,2), (3,3)") - AssertErr( - t, e, harness, - "alter table t1 add column pk int primary key;", - sql.ErrPrimaryKeyViolation, - ) - - TestQueryWithContext( - t, ctx, e, harness, - "alter table t1 add column pk int primary key auto_increment;", - []sql.Row{{types.NewOkResult(0)}}, - nil, nil, - ) - - TestQueryWithContext( - t, ctx, e, harness, - "select pk from t1;", - []sql.Row{ - {1}, - {2}, - {3}, - }, - nil, nil, - ) - - TestQueryWithContext( - t, ctx, e, harness, - "show create table t1;", - []sql.Row{ - {"t1", "CREATE TABLE `t1` (\n `i` int,\n `j` int,\n `pk` int NOT NULL AUTO_INCREMENT,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, - }, - nil, nil, - ) - }) - - t.Run("Add primary key column with auto increment first", func(t *testing.T) { - ctx.SetCurrentDatabase("mydb") - RunQuery(t, e, harness, "CREATE TABLE t2 (i int, j int);") - RunQuery(t, e, harness, "insert into t2 values (1,1), (2,2), (3,3)") - TestQueryWithContext( - t, ctx, e, harness, - "alter table t2 add column pk int primary key auto_increment first;", - []sql.Row{{types.NewOkResult(0)}}, - nil, nil, - ) - - TestQueryWithContext( - t, ctx, e, harness, - "select pk from t2;", - []sql.Row{ - {1}, - {2}, - {3}, - }, - nil, nil, - ) - - TestQueryWithContext( - t, ctx, e, harness, - "show create table t2;", - []sql.Row{ - {"t2", "CREATE TABLE `t2` (\n `pk` int NOT NULL AUTO_INCREMENT,\n `i` int,\n `j` int,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, - }, - nil, nil, - ) - }) -} - func TestNullRanges(t *testing.T, harness Harness) { harness.Setup(setup.NullsSetup...) for _, tt := range queries.NullRangeTests { diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index a8e674c9ae..d0528c1a4e 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -271,33 +271,7 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { t.Skip() - var scripts = []queries.ScriptTest{ - { - Name: "renaming views with RENAME TABLE ... TO .. statement", - SetUpScript: []string{ - "create table t1 (id int primary key, v1 int);", - "create view v1 as select * from t1;", - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "rename table v1 to view1", - Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, - }, - { - Query: "show tables;", - Expected: []sql.Row{{"myview"}, {"t1"}, {"view1"}}, - }, - { - Query: "rename table view1 to newViewName, t1 to newTableName", - Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, - }, - { - Query: "show tables;", - Expected: []sql.Row{{"myview"}, {"newTableName"}, {"newViewName"}}, - }, - }, - }, - } + var scripts = []queries.ScriptTest{} for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) @@ -1307,7 +1281,9 @@ func TestAddDropPks_Exp(t *testing.T) { func TestAddAutoIncrementColumn(t *testing.T) { t.Skip("in memory tables don't implement sql.RewritableTable yet") - enginetest.TestAddAutoIncrementColumn(t, enginetest.NewDefaultMemoryHarness()) + for _, script := range queries.AlterTableAddAutoIncrementScripts { + enginetest.TestScript(t, enginetest.NewDefaultMemoryHarness(), script) + } } func TestNullRanges(t *testing.T) { diff --git a/enginetest/queries/alter_table_queries.go b/enginetest/queries/alter_table_queries.go index e37e595e5b..f9c7886dc5 100755 --- a/enginetest/queries/alter_table_queries.go +++ b/enginetest/queries/alter_table_queries.go @@ -355,4 +355,182 @@ var AlterTableScripts = []ScriptTest{ }, }, }, + { + Name: "multi-alter ddl column errors", + SetUpScript: []string{ + "create table tbl_i (i int primary key)", + "create table tbl_ij (i int primary key, j int)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table tbl_i add column j int, drop column j", + ExpectedErr: sql.ErrTableColumnNotFound, + }, + { + Query: "alter table tbl_i add column j int, rename column j to k;", + ExpectedErr: sql.ErrTableColumnNotFound, + }, + { + Query: "alter table tbl_i add column j int, modify column j varchar(10)", + ExpectedErr: sql.ErrTableColumnNotFound, + }, + { + Query: "alter table tbl_ij drop column j, rename column j to k;", + ExpectedErr: sql.ErrTableColumnNotFound, + }, + { + Query: "alter table tbl_ij drop column k, rename column j to k;", + ExpectedErr: sql.ErrTableColumnNotFound, + }, + { + Query: "alter table tbl_i add index(j), add column j int;", + ExpectedErr: sql.ErrKeyColumnDoesNotExist, + }, + }, + }, + { + Name: "Add column and make unique in separate clauses", + SetUpScript: []string{ + "create table t (c1 int primary key, c2 int, c3 int)", + "insert into t values (1, 1, 1), (2, 2, 2), (3, 3, 3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table t add column c4 int null, add unique index uniq(c4)", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "show create table t", + Expected: []sql.Row{sql.Row{"t", + "CREATE TABLE `t` (\n" + + " `c1` int NOT NULL,\n" + + " `c2` int,\n" + + " `c3` int,\n" + + " `c4` int,\n" + + " PRIMARY KEY (`c1`),\n" + + " UNIQUE KEY `uniq` (`c4`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {1, 1, 1, nil}, + {2, 2, 2, nil}, + {3, 3, 3, nil}, + }, + }, + }, + }, +} + +var AlterTableAddAutoIncrementScripts = []ScriptTest{ + { + Name: "Add primary key column with auto increment", + SetUpScript: []string{ + "CREATE TABLE t1 (i int, j int);", + "insert into t1 values (1,1), (2,2), (3,3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table t1 add column pk int primary key auto_increment;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", + "CREATE TABLE `t1` (\n" + + " `i` int,\n" + + " `j` int,\n" + + " `pk` int NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`pk`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select pk from t1 order by pk", + Expected: []sql.Row{ + {1}, {2}, {3}, + }, + }, + }, + }, + { + Name: "Add primary key column with auto increment, first", + SetUpScript: []string{ + "CREATE TABLE t1 (i int, j int);", + "insert into t1 values (1,1), (2,2), (3,3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table t1 add column pk int primary key", + ExpectedErr: sql.ErrPrimaryKeyViolation, + }, + { + Query: "alter table t1 add column pk int primary key auto_increment first", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", + "CREATE TABLE `t1` (\n" + + " `pk` int NOT NULL AUTO_INCREMENT,\n" + + " `i` int,\n" + + " `j` int,\n" + + " PRIMARY KEY (`pk`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select pk from t1 order by pk", + Expected: []sql.Row{ + {1}, {2}, {3}, + }, + }, + }, + }, + { + Name: "add column auto_increment, non primary key", + SetUpScript: []string{ + "CREATE TABLE t1 (i bigint primary key, s varchar(20))", + "INSERT INTO t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table t1 add column j int auto_increment unique", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", + "CREATE TABLE `t1` (\n" + + " `i` bigint NOT NULL,\n" + + " `s` varchar(20),\n" + + " `j` int AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`),\n" + + " UNIQUE KEY `j` (`j`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select * from t1 order by i", + Expected: []sql.Row{ + {1, "a", 1}, + {2, "b", 2}, + {3, "c", 3}, + }, + }, + }, + }, + { + Name: "add column auto_increment, non key", + SetUpScript: []string{ + "CREATE TABLE t1 (i bigint primary key, s varchar(20))", + "INSERT INTO t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table t1 add column j int auto_increment", + ExpectedErr: sql.ErrInvalidAutoIncCols, + }, + }, + }, } diff --git a/enginetest/queries/create_table_queries.go b/enginetest/queries/create_table_queries.go index b7e206554f..75a4862076 100644 --- a/enginetest/queries/create_table_queries.go +++ b/enginetest/queries/create_table_queries.go @@ -248,6 +248,89 @@ var CreateTableScriptTests = []ScriptTest{ }, } +var CreateTableAutoIncrementTests = []ScriptTest{ + { + Name: "create table with non primary auto_increment column", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table t1 (a int auto_increment unique, b int, primary key(b))", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "insert into t1 (b) values (1), (2)", + Expected: []sql.Row{ + { + types.OkResult{ + RowsAffected: 2, + InsertID: 1, + }, + }, + }, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", + "CREATE TABLE `t1` (\n" + + " `a` int AUTO_INCREMENT,\n" + + " `b` int NOT NULL,\n" + + " PRIMARY KEY (`b`),\n" + + " UNIQUE KEY `a` (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select * from t1 order by b", + Expected: []sql.Row{{1, 1}, {2, 2}}, + }, + }, + }, + { + Name: "create table with non primary auto_increment column, separate unique key", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table t1 (a int auto_increment, b int, primary key(b), unique key(a))", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "insert into t1 (b) values (1), (2)", + Expected: []sql.Row{ + { + types.OkResult{ + RowsAffected: 2, + InsertID: 1, + }, + }, + }, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", + "CREATE TABLE `t1` (\n" + + " `a` int AUTO_INCREMENT,\n" + + " `b` int NOT NULL,\n" + + " PRIMARY KEY (`b`),\n" + + " UNIQUE KEY `a` (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "select * from t1 order by b", + Expected: []sql.Row{{1, 1}, {2, 2}}, + }, + }, + }, + { + Name: "create table with non primary auto_increment column, missing unique key", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table t1 (a int auto_increment, b int, primary key(b))", + ExpectedErr: sql.ErrInvalidAutoIncCols, + }, + }, + }, +} + var BrokenCreateTableQueries = []WriteQueryTest{ { WriteQuery: `create table t1 (b blob, primary key(b(1)))`, diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index e73873fa9d..39cfd723f7 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -3100,63 +3100,6 @@ var ScriptTests = []ScriptTest{ }, }, }, - { - Name: "multi-alter ddl column statements", - SetUpScript: []string{ - "create table tbl_i (i int primary key)", - "create table tbl_ij (i int primary key, j int)", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "alter table tbl_i add column j int, drop column j", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - { - Query: "alter table tbl_i add column j int, rename column j to k;", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - { - Query: "alter table tbl_i add column j int, modify column j varchar(10)", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - { - Query: "alter table tbl_ij add index (j), drop column j;", - ExpectedErr: sql.ErrKeyColumnDoesNotExist, - }, - { - Query: "alter table tbl_ij drop column j, rename column j to k;", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - { - Query: "alter table tbl_ij drop column k, rename column j to k;", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - { - Query: "alter table tbl_i add index(j), add column j int;", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, - }, - { - Query: "show create table tbl_i", - Expected: []sql.Row{ - {"tbl_i", "CREATE TABLE `tbl_i` (\n `i` int NOT NULL,\n `j` int,\n PRIMARY KEY (`i`),\n KEY `j` (`j`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, - }, - }, - { - Query: "alter table tbl_ij add index (j), drop column j, add column j int;", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, - }, - { - Query: "show create table tbl_ij", - Expected: []sql.Row{ - {"tbl_ij", "CREATE TABLE `tbl_ij` (\n `i` int NOT NULL,\n `j` int,\n PRIMARY KEY (`i`),\n KEY `j` (`j`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, - }, - }, - }, - }, { Name: "Keyless Table with Unique Index", SetUpScript: []string{ diff --git a/sql/analyzer/validate_create_table.go b/sql/analyzer/validate_create_table.go index db6d42cb8e..eae43545c2 100644 --- a/sql/analyzer/validate_create_table.go +++ b/sql/analyzer/validate_create_table.go @@ -48,7 +48,7 @@ func validateCreateTable(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan. } } - err = validateAutoIncrement(ct.CreateSchema.Schema, keyedColumns) + err = validateAutoIncrementModify(ct.CreateSchema.Schema, keyedColumns) if err != nil { return nil, transform.SameTree, err } @@ -59,7 +59,7 @@ func validateCreateTable(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan. func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { var sch sql.Schema var indexes []string - var keyedColumns map[string]bool + keyedColumns := make(map[string]bool) var err error transform.Inspect(n, func(n sql.Node) bool { switch n := n.(type) { @@ -102,15 +102,18 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S sch = sch.Copy() // Make a copy of the original schema to deal with any references to the original table. initialSch := sch + addedColumn := false + // Need a TransformUp here because multiple of these statement types can be nested under a Block node. // It doesn't look it, but this is actually an iterative loop over all the independent clauses in an ALTER statement - return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { switch nn := n.(type) { case *plan.ModifyColumn: n, err := nn.WithTargetSchema(sch.Copy()) if err != nil { return nil, transform.SameTree, err } + sch, err = validateModifyColumn(ctx, initialSch, sch, n.(*plan.ModifyColumn), keyedColumns) if err != nil { return nil, transform.SameTree, err @@ -127,16 +130,17 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S } return n, transform.NewTree, nil case *plan.AddColumn: - // TODO: can't `alter table add column j int unique auto_increment` as it ignores unique - // TODO: when above works, need to make sure unique index exists first then do what we did for modify n, err := nn.WithTargetSchema(sch.Copy()) if err != nil { return nil, transform.SameTree, err } - sch, err = validateAddColumn(initialSch, sch, n.(*plan.AddColumn), keyedColumns) + + sch, err = validateAddColumn(initialSch, sch, n.(*plan.AddColumn)) if err != nil { return nil, transform.SameTree, err } + + addedColumn = true return n, transform.NewTree, nil case *plan.DropColumn: n, err := nn.WithTargetSchema(sch.Copy()) @@ -147,6 +151,8 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S if err != nil { return nil, transform.SameTree, err } + delete(keyedColumns, nn.Column) + return n, transform.NewTree, nil case *plan.AlterIndex: n, err := nn.WithTargetSchema(sch.Copy()) @@ -157,6 +163,8 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S if err != nil { return nil, transform.SameTree, err } + + keyedColumns = updateKeyedColumns(keyedColumns, nn) return n, transform.NewTree, nil case *plan.AlterPK: n, err := nn.WithTargetSchema(sch.Copy()) @@ -191,6 +199,37 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S } return n, transform.SameTree, nil }) + + if err != nil { + return nil, transform.SameTree, err + } + + // We can't evaluate auto-increment until the end of the analysis, since we break adding a new auto-increment unique + // column into two steps: first add the column, then create the index. If there was no index created, that's an error. + if addedColumn { + err = validateAutoIncrementAdd(ctx, sch, keyedColumns) + if err != nil { + return nil, false, err + } + } + + return n, same, nil +} + +// updateKeyedColumns updates the keyedColumns map based on the action of the AlterIndex node +func updateKeyedColumns(keyedColumns map[string]bool, n *plan.AlterIndex) map[string]bool { + switch n.Action { + case plan.IndexAction_Create: + for _, col := range n.Columns { + keyedColumns[col.Name] = true + } + case plan.IndexAction_Drop: + for _, col := range n.Columns { + delete(keyedColumns, col.Name) + } + } + + return keyedColumns } // validateRenameColumn checks that a DDL RenameColumn node can be safely executed (e.g. no collision with other @@ -222,7 +261,7 @@ func validateRenameColumn(initialSch, sch sql.Schema, rc *plan.RenameColumn) (sq return renameInSchema(sch, rc.ColumnName, rc.NewColumnName, nameable.Name()), nil } -func validateAddColumn(initialSch sql.Schema, schema sql.Schema, ac *plan.AddColumn, keyedColumns map[string]bool) (sql.Schema, error) { +func validateAddColumn(initialSch sql.Schema, schema sql.Schema, ac *plan.AddColumn) (sql.Schema, error) { table := ac.Table nameable := table.(sql.Nameable) @@ -251,12 +290,6 @@ func validateAddColumn(initialSch sql.Schema, schema sql.Schema, ac *plan.AddCol newSch = append(newSch, ac.Column().Copy()) } - // TODO: more validation possible to do here - err := validateAutoIncrement(newSch, keyedColumns) - if err != nil { - return nil, err - } - return newSch, nil } @@ -273,7 +306,7 @@ func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Sc newSch := replaceInSchema(schema, mc.NewColumn(), nameable.Name()) - err := validateAutoIncrement(newSch, keyedColumns) + err := validateAutoIncrementModify(newSch, keyedColumns) if err != nil { return nil, err } @@ -597,7 +630,8 @@ func removeInSchema(sch sql.Schema, colName, tableName string) sql.Schema { return schCopy } -func validateAutoIncrement(schema sql.Schema, keyedColumns map[string]bool) error { +// TODO: make this work for CREATE TABLE statements where there's a non-pk auto increment column +func validateAutoIncrementModify(schema sql.Schema, keyedColumns map[string]bool) error { seen := false for _, col := range schema { if col.AutoIncrement { @@ -620,6 +654,30 @@ func validateAutoIncrement(schema sql.Schema, keyedColumns map[string]bool) erro return nil } +func validateAutoIncrementAdd(ctx *sql.Context, schema sql.Schema, keyColumns map[string]bool) error { + seen := false + for _, col := range schema { + if col.AutoIncrement { + { + if !col.PrimaryKey && !keyColumns[col.Name] { + // AUTO_INCREMENT col must be a key + return sql.ErrInvalidAutoIncCols.New() + } + if col.Default != nil { + // AUTO_INCREMENT col cannot have default + return sql.ErrInvalidAutoIncCols.New() + } + if seen { + // there can be at most one AUTO_INCREMENT col + return sql.ErrInvalidAutoIncCols.New() + } + seen = true + } + } + } + return nil +} + const textIndexPrefix = 1000 // validateIndexes prevents creating tables with blob/text primary keys and indexes without a specified length diff --git a/sql/parse/parse.go b/sql/parse/parse.go index b1415e00f5..31c1ffcba1 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -19,7 +19,6 @@ import ( goerrors "errors" "fmt" "regexp" - "sort" "strconv" "strings" "time" @@ -1306,18 +1305,14 @@ func convertDDL(ctx *sql.Context, query string, c *sqlparser.DDL) (sql.Node, err } } -// convertAlterTable converts AlterTable AST nodes -// If there are multiple alter statements, they are sorted in order of their precedence and placed inside a plan.Block -// Currently, the precedence of DDL statements is: -// 1. RENAME COLUMN -// 2. DROP COLUMN -// 3. MODIFY COLUMN -// 4. ADD COLUMN -// 5. DROP CHECK/CONSTRAINT -// 7. CREATE CHECK/CONSTRAINT -// 8. RENAME INDEX -// 9. DROP INDEX -// 10. ADD INDEX +// convertAlterTable converts AlterTable AST nodes. If there is a single clause in the statement, it is returned as +// the appropriate node type. Otherwise, a plan.Block is returned with children representing all the various clauses. +// Our validation rules for what counts as a legal set of alter clauses differs from mysql's here. MySQL seems to apply +// some form of precedence rules to the clauses in an ALTER TABLE so that e.g. DROP COLUMN always happens before other +// kinds of statements. So in MySQL, statements like `ALTER TABLE t ADD KEY (a), DROP COLUMN a` fails, whereas our +// analyzer happily produces a plan that adds an index and then drops that column. We do this in part for simplicity, +// and also because we construct more than one node per clause in some cases and really want them executed in a +// particular order in that case. func convertAlterTable(ctx *sql.Context, query string, c *sqlparser.AlterTable) (sql.Node, error) { statements := make([]sql.Node, 0, len(c.Statements)) for i := 0; i < len(c.Statements); i++ { @@ -1332,77 +1327,6 @@ func convertAlterTable(ctx *sql.Context, query string, c *sqlparser.AlterTable) return statements[0], nil } - // TODO: add correct precedence for ADD/DROP PRIMARY KEY and (maybe) FOREIGN KEY - // certain alter statements need to happen before others - sort.Slice(statements, func(i, j int) bool { - switch ii := statements[i].(type) { - case *plan.RenameColumn: - switch statements[j].(type) { - case *plan.DropColumn, - *plan.ModifyColumn, - *plan.AddColumn, - *plan.DropConstraint, - *plan.DropCheck, - *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.DropColumn: - switch statements[j].(type) { - case *plan.ModifyColumn, - *plan.AddColumn, - *plan.DropConstraint, - *plan.DropCheck, - *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.ModifyColumn: - switch statements[j].(type) { - case *plan.AddColumn, - *plan.DropConstraint, - *plan.DropCheck, - *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.AddColumn: - switch statements[j].(type) { - case *plan.DropConstraint, - *plan.DropCheck, - *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.DropConstraint: - switch statements[j].(type) { - case *plan.DropCheck, - *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.DropCheck: - switch statements[j].(type) { - case *plan.CreateCheck, - *plan.AlterIndex: - return true - } - case *plan.CreateCheck: - switch statements[j].(type) { - case *plan.AlterIndex: - return true - } - // AlterIndex precedence is Rename, Drop, then Create - // So statement[i] < statement[j] = statement[i].action > statement[j].action - case *plan.AlterIndex: - switch jj := statements[j].(type) { - case *plan.AlterIndex: - return ii.Action > jj.Action - } - } - return false - }) - return plan.NewBlock(statements), nil } @@ -2093,16 +2017,9 @@ func convertRenameTable(ctx *sql.Context, ddl *sqlparser.DDL, alterTbl bool) (sq return plan.NewRenameTable(sql.UnresolvedDatabase(""), fromTables, toTables, alterTbl), nil } - -func isUniqueColumn(tableSpec *sqlparser.TableSpec, columnName string) (bool, error) { - for _, column := range tableSpec.Columns { - if column.Name.String() == columnName { - return column.Type.KeyOpt == colKeyUnique || - column.Type.KeyOpt == colKeyUniqueKey, nil - } - } - return false, fmt.Errorf("unknown column name %s", columnName) - +func isUniqueColumn(column *sqlparser.ColumnDefinition) bool { + return column.Type.KeyOpt == colKeyUnique || + column.Type.KeyOpt == colKeyUniqueKey } func newColumnAction(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { @@ -2112,6 +2029,7 @@ func newColumnAction(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { if err != nil { return nil, err } + return plan.NewAddColumn(sql.UnresolvedDatabase(ddl.Table.Qualifier.String()), tableNameToUnresolvedTable(ddl.Table), sch.Schema[0], columnOrderToColumnOrder(ddl.ColumnOrder)), nil case sqlparser.DropStr: return plan.NewDropColumn(sql.UnresolvedDatabase(ddl.Table.Qualifier.String()), tableNameToUnresolvedTable(ddl.Table), ddl.Column.String()), nil @@ -2146,10 +2064,7 @@ func convertAlterTableClause(ctx *sql.Context, query string, ddl *sqlparser.DDL) } column := ddl.TableSpec.Columns[0] - isUnique, err := isUniqueColumn(ddl.TableSpec, column.Name.String()) - if err != nil { - return nil, fmt.Errorf("on table %s, %w", ddl.Table.String(), err) - } + isUnique := isUniqueColumn(column) if isUnique { createIndex := plan.NewAlterCreateIndex( diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index da0abbb93c..151e5b64c4 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -2879,15 +2879,15 @@ CREATE TABLE t2 { input: `alter table t add index (i), drop index i, add check (i = 0), drop check chk, drop constraint c, add column i int, modify column i text, drop column i, rename column i to j`, plan: plan.NewBlock([]sql.Node{ - plan.NewRenameColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i", "j"), - plan.NewDropColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i"), - plan.NewModifyColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i", &sql.Column{Name: "i", Type: types.CreateText(sql.Collation_Unspecified), Nullable: true, Source: "t"}, nil), - plan.NewAddColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), &sql.Column{Name: "i", Type: types.Int32, Nullable: true, Source: "t"}, nil), - plan.NewDropConstraint(plan.NewUnresolvedTable("t", ""), "c"), - plan.NewAlterDropCheck(plan.NewUnresolvedTable("t", ""), "chk"), - plan.NewAlterAddCheck(plan.NewUnresolvedTable("t", ""), &sql.CheckConstraint{Name: "", Expr: expression.NewEquals(expression.NewUnresolvedColumn("i"), expression.NewLiteral(int8(0), types.Int8)), Enforced: true}), - plan.NewAlterDropIndex(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i"), plan.NewAlterCreateIndex(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "", sql.IndexUsing_BTree, sql.IndexConstraint_None, []sql.IndexColumn{{Name: "i", Length: 0}}, ""), + plan.NewAlterDropIndex(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i"), + plan.NewAlterAddCheck(plan.NewUnresolvedTable("t", ""), &sql.CheckConstraint{Name: "", Expr: expression.NewEquals(expression.NewUnresolvedColumn("i"), expression.NewLiteral(int8(0), types.Int8)), Enforced: true}), + plan.NewAlterDropCheck(plan.NewUnresolvedTable("t", ""), "chk"), + plan.NewDropConstraint(plan.NewUnresolvedTable("t", ""), "c"), + plan.NewAddColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), &sql.Column{Name: "i", Type: types.Int32, Nullable: true, Source: "t"}, nil), + plan.NewModifyColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i", &sql.Column{Name: "i", Type: types.CreateText(sql.Collation_Unspecified), Nullable: true, Source: "t"}, nil), + plan.NewDropColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i"), + plan.NewRenameColumn(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t", ""), "i", "j"), }), }, { diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 0af0aa3001..02a6e9f853 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -2,7 +2,6 @@ package planbuilder import ( "fmt" - "sort" "strconv" "strings" @@ -33,6 +32,14 @@ func (b *Builder) resolveDb(name string) sql.Database { return database } +// buildAlterTable converts AlterTable AST nodes. If there is a single clause in the statement, it is returned as +// the appropriate node type. Otherwise, a plan.Block is returned with children representing all the various clauses. +// Our validation rules for what counts as a legal set of alter clauses differs from mysql's here. MySQL seems to apply +// some form of precedence rules to the clauses in an ALTER TABLE so that e.g. DROP COLUMN always happens before other +// kinds of statements. So in MySQL, statements like `ALTER TABLE t ADD KEY (a), DROP COLUMN a` fails, whereas our +// analyzer happily produces a plan that adds an index and then drops that column. We do this in part for simplicity, +// and also because we construct more than one node per clause in some cases and really want them executed in a +// particular order in that case. func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTable) (outScope *scope) { b.multiDDL = true defer func() { @@ -53,30 +60,6 @@ func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTabl return outScope } - // certain alter statements need to happen before others - sort.Slice(statements, func(i, j int) bool { - switch statements[i].(type) { - case *plan.RenameColumn: - switch statements[j].(type) { - case *plan.DropColumn, - *plan.AddColumn, - *plan.AlterIndex: - return true - } - case *plan.DropColumn: - switch statements[j].(type) { - case *plan.AddColumn, - *plan.AlterIndex: - return true - } - case *plan.AddColumn: - switch statements[j].(type) { - case *plan.AlterIndex: - return true - } - } - return false - }) outScope = inScope.push() outScope.node = plan.NewBlock(statements) return diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 391a5d42b1..1be7f588a6 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -1333,7 +1333,7 @@ func (i *addColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) oldPkSchema, newPkSchema := sql.SchemaToPrimaryKeySchema(rwt, rwt.Schema()), sql.SchemaToPrimaryKeySchema(rwt, newSch) rewriteRequired := false - if i.a.Column().Default != nil || !i.a.Column().Nullable { + if i.a.Column().Default != nil || !i.a.Column().Nullable || i.a.Column().AutoIncrement { rewriteRequired = true }