diff --git a/column.go b/column.go index b21d0fec..31f0d1c6 100644 --- a/column.go +++ b/column.go @@ -34,9 +34,22 @@ const ( Time ColumnType = "TIME" ) +// AlterColumnMode enum. +type AlterColumnMode uint16 + +const ( + // AlterColumnType operation. + AlterColumnType AlterColumnMode = iota + 1 + // AlterColumnRequired operation. + AlterColumnRequired + // AlterColumnDefault operation. + AlterColumnDefault +) + // Column definition. type Column struct { Op SchemaOp + AlterMode AlterColumnMode Name string Type ColumnType Rename string @@ -75,6 +88,39 @@ func renameColumn(name string, newName string, options []ColumnOption) Column { return column } +func alterColumnType(name string, typ ColumnType, options []ColumnOption) []Column { + column := Column{ + Op: SchemaAlter, + Name: name, + Type: typ, + AlterMode: AlterColumnType, + } + for _, option := range options { + if option.isConstraint() { + continue + } + option.applyColumn(&column) + } + + return append([]Column{column}, alterColumn(name, options)...) +} + +func alterColumn(name string, options []ColumnOption) []Column { + columns := make([]Column, 0, len(options)) + for _, option := range options { + if !option.isConstraint() { + continue + } + column := Column{ + Op: SchemaAlter, + Name: name, + } + option.applyColumn(&column) + columns = append(columns, column) + } + return columns +} + func dropColumn(name string, options []ColumnOption) Column { column := Column{ Op: SchemaDrop, diff --git a/column_test.go b/column_test.go index 6c2a6d14..40a146c7 100644 --- a/column_test.go +++ b/column_test.go @@ -65,6 +65,32 @@ func TestRenameColumn(t *testing.T) { }, column) } +func TestAlterColumn(t *testing.T) { + var ( + options = []ColumnOption{ + Required(true), + Limit(1000), + } + columns = alterColumnType("alter", String, options) + ) + + assert.Equal(t, []Column{ + { + Op: SchemaAlter, + AlterMode: AlterColumnType, + Type: String, + Name: "alter", + Limit: 1000, + }, + { + Op: SchemaAlter, + AlterMode: AlterColumnRequired, + Name: "alter", + Required: true, + }, + }, columns) +} + func TestDropColumn(t *testing.T) { var ( options = []ColumnOption{ diff --git a/query.go b/query.go index 58445b62..9be2a3b1 100644 --- a/query.go +++ b/query.go @@ -16,9 +16,7 @@ type QueryPopulator interface { // Build for given table using given queriers. func Build(table string, queriers ...Querier) Query { - var ( - query = newQuery() - ) + query := newQuery() if len(queriers) > 0 { _, query.empty = queriers[0].(Query) @@ -255,9 +253,7 @@ func (q Query) Sort(fields ...string) Query { // SortAsc query. func (q Query) SortAsc(fields ...string) Query { - var ( - offset = len(q.SortQuery) - ) + offset := len(q.SortQuery) q.SortQuery = append(q.SortQuery, make([]SortQuery, len(fields))...) for i := range fields { @@ -269,9 +265,7 @@ func (q Query) SortAsc(fields ...string) Query { // SortDesc query. func (q Query) SortDesc(fields ...string) Query { - var ( - offset = len(q.SortQuery) - ) + offset := len(q.SortQuery) q.SortQuery = append(q.SortQuery, make([]SortQuery, len(fields))...) for i := range fields { @@ -538,6 +532,10 @@ func (l Limit) applyColumn(column *Column) { column.Limit = int(l) } +func (l Limit) isConstraint() bool { + return false +} + // Lock query. // This query will be ignored if used outside of transaction. type Lock string diff --git a/schema.go b/schema.go index b861d36a..6f2cc0f7 100644 --- a/schema.go +++ b/schema.go @@ -81,6 +81,31 @@ func (s *Schema) AddColumn(table string, name string, typ ColumnType, options .. s.add(at.Table) } +// AlterColumnType with name. +// +// Allows also changing other constraints like [rel.Default] and [rel.Required]. +// +// WARNING: Not supported by SQLite driver. +func (s *Schema) AlterColumnType(table string, name string, typ ColumnType, options ...ColumnOption) { + at := alterTable(table, nil) + at.AlterColumnType(name, typ, options...) + s.add(at.Table) +} + +// AlterColumn with name. +// +// Only [rel.Default] and [rel.Required] are supported. +// Support for underlying drivers might wary. For example PostgreSQL supports both, +// while Microsoft SQL Server and MySQL/MariaDB only supports [rel.Default]. +// See [Schema.AlterColumnType] if other constraints need to be changed also. +// +// WARNING: Not supported by SQLite driver. +func (s *Schema) AlterColumn(table string, name string, options ...ColumnOption) { + at := alterTable(table, nil) + at.AlterColumn(name, options...) + s.add(at.Table) +} + // RenameColumn by name. func (s *Schema) RenameColumn(table string, name string, newName string, options ...ColumnOption) { at := alterTable(table, nil) diff --git a/schema_options.go b/schema_options.go index 5ffe6e56..9ad48f35 100644 --- a/schema_options.go +++ b/schema_options.go @@ -16,6 +16,7 @@ func applyTableOptions(table *Table, options []TableOption) { // Available options are: Nil, Unsigned, Limit, Precision, Scale, Default, Comment, Options. type ColumnOption interface { applyColumn(column *Column) + isConstraint() bool } func applyColumnOptions(column *Column, options []ColumnOption) { @@ -43,6 +44,10 @@ func (r Primary) applyColumn(column *Column) { column.Primary = bool(r) } +func (r Primary) isConstraint() bool { + return false +} + // Unique set column as unique. type Unique bool @@ -50,6 +55,10 @@ func (r Unique) applyColumn(column *Column) { column.Unique = bool(r) } +func (r Unique) isConstraint() bool { + return false +} + func (r Unique) applyIndex(index *Index) { index.Unique = bool(r) } @@ -59,6 +68,13 @@ type Required bool func (r Required) applyColumn(column *Column) { column.Required = bool(r) + if column.Op == SchemaAlter { + column.AlterMode = AlterColumnRequired + } +} + +func (r Required) isConstraint() bool { + return true } // Unsigned sets integer column to be unsigned. @@ -68,6 +84,10 @@ func (u Unsigned) applyColumn(column *Column) { column.Unsigned = bool(u) } +func (r Unsigned) isConstraint() bool { + return false +} + // Precision defines the precision for the decimal fields, representing the total number of digits in the number. type Precision int @@ -75,6 +95,10 @@ func (p Precision) applyColumn(column *Column) { column.Precision = int(p) } +func (p Precision) isConstraint() bool { + return false +} + // Scale Defines the scale for the decimal fields, representing the number of digits after the decimal point. type Scale int @@ -82,12 +106,23 @@ func (s Scale) applyColumn(column *Column) { column.Scale = int(s) } +func (s Scale) isConstraint() bool { + return false +} + type defaultValue struct { value any } func (d defaultValue) applyColumn(column *Column) { column.Default = d.value + if column.Op == SchemaAlter { + column.AlterMode = AlterColumnDefault + } +} + +func (d defaultValue) isConstraint() bool { + return true } // Default allows to set a default value on the column.). @@ -120,6 +155,10 @@ func (o Options) applyColumn(column *Column) { column.Options = string(o) } +func (o Options) isConstraint() bool { + return false +} + func (o Options) applyIndex(index *Index) { index.Options = string(o) } diff --git a/schema_test.go b/schema_test.go index 9e0b7219..320aeef9 100644 --- a/schema_test.go +++ b/schema_test.go @@ -118,6 +118,49 @@ func TestSchema_AddColumn(t *testing.T) { }, schema.Migrations[0]) } +func TestSchema_AlterColumnTypeString(t *testing.T) { + var schema Schema + + schema.AlterColumnType("products", "description", String, Limit(100), Unique(false), Primary(false)) + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "products", + Definitions: []TableDefinition{ + Column{Name: "description", Type: String, Op: SchemaAlter, Limit: 100, AlterMode: AlterColumnType}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_AlterColumnTypeNumber(t *testing.T) { + var schema Schema + + schema.AlterColumnType("products", "description", Decimal, Scale(10), Precision(2), Unsigned(true), Options("")) + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "products", + Definitions: []TableDefinition{ + Column{Name: "description", Type: Decimal, Op: SchemaAlter, Scale: 10, Precision: 2, Unsigned: true, AlterMode: AlterColumnType}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_AlterColumn(t *testing.T) { + var schema Schema + + schema.AlterColumn("products", "description", Required(true), Default("")) + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "products", + Definitions: []TableDefinition{ + Column{Name: "description", Op: SchemaAlter, Required: true, AlterMode: AlterColumnRequired}, + Column{Name: "description", Op: SchemaAlter, Default: "", AlterMode: AlterColumnDefault}, + }, + }, schema.Migrations[0]) +} + func TestSchema_RenameColumn(t *testing.T) { var schema Schema @@ -218,9 +261,7 @@ func TestRaw_InternalTableDefinition(t *testing.T) { } func TestDo(t *testing.T) { - var ( - schema Schema - ) + var schema Schema schema.Do(func(ctx context.Context, repo Repository) error { return nil }) assert.NotNil(t, schema.Migrations[0]) diff --git a/table.go b/table.go index 1c0f23a2..a15120be 100644 --- a/table.go +++ b/table.go @@ -136,6 +136,22 @@ func (at *AlterTable) RenameColumn(name string, newName string, options ...Colum at.Definitions = append(at.Definitions, renameColumn(name, newName, options)) } +// AlterColumnType with name. +func (at *AlterTable) AlterColumnType(name string, typ ColumnType, options ...ColumnOption) { + defs := alterColumnType(name, typ, options) + for _, def := range defs { + at.Definitions = append(at.Definitions, def) + } +} + +// AlterColumn with name. +func (at *AlterTable) AlterColumn(name string, options ...ColumnOption) { + defs := alterColumn(name, options) + for _, def := range defs { + at.Definitions = append(at.Definitions, def) + } +} + // DropColumn from this table. func (at *AlterTable) DropColumn(name string, options ...ColumnOption) { at.Definitions = append(at.Definitions, dropColumn(name, options))