Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/rain/coverage_target_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,17 +748,17 @@ func TestCoverageDDLMethodsAndHelpers(t *testing.T) {
{value: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC), want: "'2026-01-02T03:04:05Z'"},
{value: []byte("abc"), want: "'abc'"},
} {
if got, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", Default: tc.value}); err != nil || got != tc.want {
if got, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", Default: tc.value}); err != nil || got != tc.want {
t.Fatalf("unexpected columnDefaultSQL for %#v: %q err=%v", tc.value, got, err)
}
if got, err := literalDDLSQL(pg, tc.value); err != nil || got != tc.want {
t.Fatalf("unexpected literalDDLSQL for %#v: %q err=%v", tc.value, got, err)
}
}
if got, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", DefaultSQL: "NOW()"}); err != nil || got != "NOW()" {
if got, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", DefaultSQL: "NOW()"}); err != nil || got != "NOW()" {
t.Fatalf("unexpected DefaultSQL passthrough: %q err=%v", got, err)
}
if _, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", Default: struct{}{}}); err == nil {
if _, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", Default: struct{}{}}); err == nil {
t.Fatalf("expected unsupported default type to fail")
}
if _, err := literalDDLSQL(pg, struct{}{}); err == nil {
Expand Down
14 changes: 9 additions & 5 deletions pkg/rain/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ func (db *DB) ColumnDefaultSQL(table schema.TableReference, columnName string) (
if !ok {
return "", fmt.Errorf("rain: table %q has no column %q", tableDef.Name, columnName)
}
if !column.HasDefault && column.DefaultSQL == "" {
if !column.HasDefault && column.DefaultSQL == "" && column.DefaultExpr == nil {
return "", nil
}

return columnDefaultSQL(db.dialect, column)
return columnDefaultSQL(db.dialect, tableDef, column)
}

func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
Expand Down Expand Up @@ -342,8 +342,8 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche
if column.Unique {
parts = append(parts, "UNIQUE")
}
if column.HasDefault || column.DefaultSQL != "" {
defaultSQL, err := columnDefaultSQL(d, column)
if column.HasDefault || column.DefaultSQL != "" || column.DefaultExpr != nil {
defaultSQL, err := columnDefaultSQL(d, table, column)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -396,7 +396,11 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef,
}
}

func columnDefaultSQL(d dialect.Dialect, column *schema.ColumnDef) (string, error) {
func columnDefaultSQL(d dialect.Dialect, table *schema.TableDef, column *schema.ColumnDef) (string, error) {
if column.DefaultExpr != nil {
return expressionDDLSQL(d, table, column.DefaultExpr)
}

if column.DefaultSQL != "" {
return column.DefaultSQL, nil
}
Expand Down
71 changes: 71 additions & 0 deletions pkg/rain/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,77 @@ func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) {
return users, posts, memberships
}

type ddlDefaultRawTable struct {
schema.TableModel
ID *schema.Column[int64]
CreatedAt *schema.Column[time.Time]
Random *schema.Column[float64]
}

func TestCreateTableSQLWithDefaultRaw(t *testing.T) {
t.Parallel()

table := schema.Define("default_raw_test", func(t *ddlDefaultRawTable) {
t.ID = t.BigSerial("id").PrimaryKey()
t.CreatedAt = t.TimestampTZ("created_at").NotNull().DefaultRaw(schema.Raw("now()"))
t.Random = t.Double("random").NotNull().DefaultRaw(schema.Raw("random()"))
})

cases := []struct {
name string
dialect string
fragments []string
}{
{
name: "postgres default raw",
dialect: "postgres",
fragments: []string{
`"created_at" TIMESTAMPTZ NOT NULL DEFAULT now()`,
`"random" DOUBLE PRECISION NOT NULL DEFAULT random()`,
},
},
{
name: "mysql default raw",
dialect: "mysql",
fragments: []string{
"`created_at` DATETIME NOT NULL DEFAULT now()",
"`random` DOUBLE NOT NULL DEFAULT random()",
},
},
{
name: "sqlite default raw",
dialect: "sqlite",
fragments: []string{
`"created_at" TEXT NOT NULL DEFAULT now()`,
`"random" REAL NOT NULL DEFAULT random()`,
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect(tc.dialect)
if err != nil {
t.Fatalf("OpenDialect(%q): %v", tc.dialect, err)
}

sql, err := db.CreateTableSQL(table)
if err != nil {
t.Fatalf("CreateTableSQL: %v", err)
}

for _, fragment := range tc.fragments {
if !strings.Contains(sql, fragment) {
t.Fatalf("expected SQL to contain %q, got:\n%s", fragment, sql)
}
}
})
}
}

func TestCreateTableSQLAcrossDialects(t *testing.T) {
t.Parallel()

Expand Down
17 changes: 17 additions & 0 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ type ColumnDef struct {
Default any
HasDefault bool
DefaultSQL string
DefaultExpr Expression
PrimaryKey bool
AutoIncrement bool
Unique bool
Expand Down Expand Up @@ -628,13 +629,29 @@ func (c *Column[T]) Nullable() *Column[T] {
func (c *Column[T]) Default(value T) *Column[T] {
c.def.HasDefault = true
c.def.Default = value
c.def.DefaultSQL = ""
c.def.DefaultExpr = nil
return c
}

// DefaultNow sets CURRENT_TIMESTAMP as the default value.
func (c *Column[T]) DefaultNow() *Column[T] {
c.def.HasDefault = true
c.def.DefaultSQL = "CURRENT_TIMESTAMP"
c.def.Default = nil
c.def.DefaultExpr = nil
return c
}

// DefaultRaw sets a raw SQL expression as the default value.
func (c *Column[T]) DefaultRaw(expr Expression) *Column[T] {
if expr == nil {
panic("schema: DefaultRaw requires a non-nil expression")
}
c.def.HasDefault = true
c.def.DefaultExpr = expr
c.def.Default = nil
c.def.DefaultSQL = ""
return c
}

Expand Down