diff --git a/pkg/rain/postgres_integration_test.go b/pkg/rain/postgres_integration_test.go index 4d12804..4abbaa0 100644 --- a/pkg/rain/postgres_integration_test.go +++ b/pkg/rain/postgres_integration_test.go @@ -7,6 +7,7 @@ import ( "net" "net/url" "os" + "slices" "strconv" "strings" "testing" @@ -44,10 +45,8 @@ type postgresInsertModel struct { func registerPostgresDriverForTests(tb testing.TB) { tb.Helper() - for _, name := range sql.Drivers() { - if name == "postgres" { - return - } + if slices.Contains(sql.Drivers(), "postgres") { + return } sql.Register("postgres", stdlib.GetDefaultDriver()) diff --git a/pkg/rain/query.go b/pkg/rain/query.go deleted file mode 100644 index 8cb7a7d..0000000 --- a/pkg/rain/query.go +++ /dev/null @@ -1,1491 +0,0 @@ -package rain - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "reflect" - "slices" - "strings" - - "github.com/hyperlocalise/rain-orm/pkg/dialect" - "github.com/hyperlocalise/rain-orm/pkg/schema" -) - -type queryRunner interface { - execContext(context.Context, string, ...any) (sql.Result, error) - queryContext(context.Context, string, ...any) (*sql.Rows, error) -} - -type joinClause struct { - kind string - table selectTableSource - on schema.Predicate -} - -type assignment struct { - column schema.ColumnReference - value schema.Expression -} - -type returningClause struct { - feature dialect.Feature - label string -} - -type selectTableSource interface { - writeSQL(*compileContext) error -} - -type tableDefSource struct { - table *schema.TableDef -} - -func (s tableDefSource) writeSQL(ctx *compileContext) error { - ctx.writeTable(s.table) - return nil -} - -type subqueryTableSource struct { - query *SelectQuery - alias string -} - -func (s subqueryTableSource) writeSQL(ctx *compileContext) error { - if strings.TrimSpace(s.alias) == "" { - return errors.New("rain: subquery table source requires a non-empty alias") - } - if s.query == nil { - return fmt.Errorf("rain: subquery table source %q requires a non-nil query", s.alias) - } - ctx.writeByte('(') - if err := s.query.writeSQL(ctx); err != nil { - return err - } - ctx.writeString(") AS ") - ctx.writeQuotedIdentifier(s.alias) - return nil -} - -type cteDefinition struct { - name string - query *SelectQuery -} - -func closeRows(rows *sql.Rows, errp *error) { - if err := rows.Close(); err != nil && *errp == nil { - *errp = err - } -} - -// SelectQuery builds typed SELECT statements. -type SelectQuery struct { - runner queryRunner - dialect dialect.Dialect - cache QueryCache - table selectTableSource - cols []schema.Expression - where []schema.Predicate - joins []joinClause - order []schema.OrderExpr - groupBy []schema.Expression - having []schema.Predicate - ctes []cteDefinition - distinct bool - limit int - offset int - relationNames []string - cacheOptions *queryCacheOptions -} - -// Table sets the table source for the query. -func (q *SelectQuery) Table(table schema.TableReference) *SelectQuery { - q.table = tableDefSource{table: table.TableDef()} - return q -} - -// TableSubquery sets a subquery source for the query's FROM clause. -func (q *SelectQuery) TableSubquery(query *SelectQuery, alias string) *SelectQuery { - q.table = subqueryTableSource{query: query, alias: alias} - return q -} - -// Column sets the selected expressions. -func (q *SelectQuery) Column(cols ...schema.Expression) *SelectQuery { - q.cols = append(q.cols, cols...) - return q -} - -// Where appends a WHERE predicate joined with AND. -func (q *SelectQuery) Where(predicate schema.Predicate) *SelectQuery { - q.where = append(q.where, predicate) - return q -} - -// Join appends an INNER JOIN clause. -func (q *SelectQuery) Join(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) - return q -} - -// LeftJoin appends a LEFT JOIN clause. -func (q *SelectQuery) LeftJoin(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) - return q -} - -// JoinSubquery appends an INNER JOIN against a subquery source. -func (q *SelectQuery) JoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) - return q -} - -// LeftJoinSubquery appends a LEFT JOIN against a subquery source. -func (q *SelectQuery) LeftJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) - return q -} - -// Distinct marks the SELECT query as DISTINCT. -func (q *SelectQuery) Distinct() *SelectQuery { - q.distinct = true - return q -} - -// GroupBy appends GROUP BY expressions. -func (q *SelectQuery) GroupBy(exprs ...schema.Expression) *SelectQuery { - q.groupBy = append(q.groupBy, exprs...) - return q -} - -// Having appends a HAVING predicate joined with AND. -func (q *SelectQuery) Having(predicate schema.Predicate) *SelectQuery { - q.having = append(q.having, predicate) - return q -} - -// With appends a common table expression definition. -func (q *SelectQuery) With(name string, query *SelectQuery) *SelectQuery { - q.ctes = append(q.ctes, cteDefinition{name: name, query: query}) - return q -} - -// OrderBy appends ORDER BY expressions. -func (q *SelectQuery) OrderBy(order ...schema.OrderExpr) *SelectQuery { - q.order = append(q.order, order...) - return q -} - -// Limit sets the LIMIT clause. -func (q *SelectQuery) Limit(limit int) *SelectQuery { - q.limit = limit - return q -} - -// Offset sets the OFFSET clause. -func (q *SelectQuery) Offset(offset int) *SelectQuery { - q.offset = offset - return q -} - -// WithRelations requests one or more named relations to be loaded after scanning base rows. -func (q *SelectQuery) WithRelations(names ...string) *SelectQuery { - q.relationNames = append(q.relationNames, names...) - return q -} - -// Cache enables opt-in query caching for this SELECT with TTL and optional metadata. -func (q *SelectQuery) Cache(options QueryCacheOptions) *SelectQuery { - q.cacheOptions = normalizeQueryCacheOptions(options) - return q -} - -// ToSQL compiles the query into SQL and args. -func (q *SelectQuery) ToSQL() (string, []any, error) { - ctx := newCompileContext(q.dialect) - if err := q.writeSQL(ctx); err != nil { - return "", nil, err - } - return ctx.String(), ctx.args, nil -} - -func (q *SelectQuery) writeSQL(ctx *compileContext) error { - if q.table == nil { - return errors.New("rain: select query requires a table") - } - - if len(q.ctes) > 0 { - if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureCTE) { - return fmt.Errorf("rain: select queries do not support CTEs for %s dialect", ctx.dialect.Name()) - } - ctx.writeString("WITH ") - for idx, cte := range q.ctes { - if idx > 0 { - ctx.writeString(", ") - } - if strings.TrimSpace(cte.name) == "" { - return errors.New("rain: CTE name cannot be empty") - } - if cte.query == nil { - return fmt.Errorf("rain: CTE %q requires a query", cte.name) - } - if len(cte.query.ctes) > 0 { - return fmt.Errorf("rain: CTE %q body cannot itself contain CTEs", cte.name) - } - ctx.writeQuotedIdentifier(cte.name) - ctx.writeString(" AS (") - if err := cte.query.writeSQL(ctx); err != nil { - return err - } - ctx.writeByte(')') - } - ctx.writeByte(' ') - } - - ctx.writeString("SELECT ") - if q.distinct { - ctx.writeString("DISTINCT ") - } - if len(q.cols) == 0 { - ctx.writeString("*") - } else { - for idx, column := range q.cols { - if idx > 0 { - ctx.writeString(", ") - } - if err := ctx.writeSelectExpression(column); err != nil { - return err - } - } - } - - ctx.writeString(" FROM ") - if err := q.table.writeSQL(ctx); err != nil { - return err - } - - for _, join := range q.joins { - ctx.writeByte(' ') - ctx.writeString(join.kind) - ctx.writeByte(' ') - if err := join.table.writeSQL(ctx); err != nil { - return err - } - ctx.writeString(" ON ") - if err := ctx.writePredicate(join.on); err != nil { - return err - } - } - - if len(q.where) > 0 { - ctx.writeString(" WHERE ") - if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return err - } - } - - if len(q.groupBy) > 0 { - ctx.writeString(" GROUP BY ") - for idx, expr := range q.groupBy { - if idx > 0 { - ctx.writeString(", ") - } - if err := ctx.writeExpression(expr); err != nil { - return err - } - } - } - - if len(q.having) > 0 { - ctx.writeString(" HAVING ") - if err := ctx.writePredicate(joinPredicates(q.having)); err != nil { - return err - } - } - - if len(q.order) > 0 { - ctx.writeString(" ORDER BY ") - for idx, item := range q.order { - if idx > 0 { - ctx.writeString(", ") - } - if err := ctx.writeExpression(item.Expr); err != nil { - return err - } - ctx.writeByte(' ') - ctx.writeString(string(item.Direction)) - } - } - - if clause := q.dialect.LimitOffset(q.limit, q.offset); clause != "" { - ctx.writeByte(' ') - ctx.writeString(clause) - } - - return nil -} - -// Scan executes the SELECT query and scans results into dest. -func (q *SelectQuery) Scan(ctx context.Context, dest any) error { - if q.runner == nil { - return ErrNoConnection - } - - query, args, err := q.ToSQL() - if err != nil { - return err - } - - cacheKey, cacheOptions, err := q.resolveCacheKey(query, args) - if err != nil { - return err - } - if cacheOptions != nil && !cacheOptions.bypass { - cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) - if cacheErr != nil { - return cacheErr - } - if ok { - return json.Unmarshal(cached, dest) - } - } - - rows, err := q.runner.queryContext(ctx, query, args...) - if err != nil { - return err - } - defer closeRows(rows, &err) - - if len(q.relationNames) == 0 { - err = scanRows(rows, dest) - } else { - err = q.scanRowsWithRelations(ctx, rows, dest) - } - if err != nil { - return err - } - err = q.writeCachedResult(ctx, cacheKey, cacheOptions, dest) - return err -} - -// Count executes SELECT COUNT(*). -func (q *SelectQuery) Count(ctx context.Context) (int64, error) { - if q.runner == nil { - return 0, ErrNoConnection - } - - query, args, err := q.toAggregateSQL("COUNT(*)") - if err != nil { - return 0, err - } - - cacheKey, cacheOptions, err := q.resolveCacheKey(query, args) - if err != nil { - return 0, err - } - if cacheOptions != nil && !cacheOptions.bypass { - cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) - if cacheErr != nil { - return 0, cacheErr - } - if ok { - var count int64 - if err := json.Unmarshal(cached, &count); err != nil { - return 0, err - } - return count, nil - } - } - - rows, err := q.runner.queryContext(ctx, query, args...) - if err != nil { - return 0, err - } - defer closeRows(rows, &err) - - var count int64 - if !rows.Next() { - err = sql.ErrNoRows - return 0, err - } - if err := rows.Scan(&count); err != nil { - return 0, err - } - - err = rows.Err() - if err != nil { - return 0, err - } - err = q.writeCachedResult(ctx, cacheKey, cacheOptions, count) - return count, err -} - -// Exists executes a SELECT EXISTS query. -func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { - if q.runner == nil { - return false, ErrNoConnection - } - - sqlText, args, err := q.ToSQL() - if err != nil { - return false, err - } - - ctxCompiler := newCompileContext(q.dialect) - ctxCompiler.writeString("SELECT EXISTS(") - ctxCompiler.writeString(sqlText) - ctxCompiler.writeByte(')') - ctxCompiler.args = append(ctxCompiler.args, args...) - - query := ctxCompiler.String() - cacheKey, cacheOptions, err := q.resolveCacheKey(query, ctxCompiler.args) - if err != nil { - return false, err - } - if cacheOptions != nil && !cacheOptions.bypass { - cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) - if cacheErr != nil { - return false, cacheErr - } - if ok { - var exists bool - if err := json.Unmarshal(cached, &exists); err != nil { - return false, err - } - return exists, nil - } - } - - rows, err := q.runner.queryContext(ctx, query, ctxCompiler.args...) - if err != nil { - return false, err - } - defer closeRows(rows, &err) - - var exists bool - if !rows.Next() { - err = sql.ErrNoRows - return false, err - } - if err := rows.Scan(&exists); err != nil { - return false, err - } - - err = rows.Err() - if err != nil { - return false, err - } - err = q.writeCachedResult(ctx, cacheKey, cacheOptions, exists) - return exists, err -} - -func (q *SelectQuery) resolveCacheKey(query string, args []any) (string, *queryCacheOptions, error) { - if q.cacheOptions == nil || q.cache == nil { - return "", nil, nil - } - key, err := buildQueryCacheKey(q.dialect.Name(), query, args, q.relationNames, q.cacheOptions) - if err != nil { - return "", nil, err - } - return key, q.cacheOptions, nil -} - -func (q *SelectQuery) writeCachedResult(ctx context.Context, key string, options *queryCacheOptions, value any) error { - if options == nil || options.bypass { - return nil - } - encoded, err := json.Marshal(value) - if err != nil { - return err - } - return q.cache.Set(ctx, key, encoded, options.ttl, options.tags) -} - -func (q *SelectQuery) toAggregateSQL(selection string) (string, []any, error) { - if q.table == nil { - return "", nil, errors.New("rain: select query requires a table") - } - if len(q.ctes) > 0 { - return "", nil, errors.New("rain: aggregate helpers do not support WITH clauses") - } - if q.distinct || len(q.groupBy) > 0 || len(q.having) > 0 { - return "", nil, errors.New("rain: aggregate helpers do not support DISTINCT, GROUP BY, or HAVING clauses") - } - - ctx := newCompileContext(q.dialect) - ctx.writeString("SELECT ") - ctx.writeString(selection) - ctx.writeString(" FROM ") - if err := q.table.writeSQL(ctx); err != nil { - return "", nil, err - } - - for _, join := range q.joins { - ctx.writeByte(' ') - ctx.writeString(join.kind) - ctx.writeByte(' ') - if err := join.table.writeSQL(ctx); err != nil { - return "", nil, err - } - ctx.writeString(" ON ") - if err := ctx.writePredicate(join.on); err != nil { - return "", nil, err - } - } - - if len(q.where) > 0 { - ctx.writeString(" WHERE ") - if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return "", nil, err - } - } - - return ctx.String(), ctx.args, ctx.err -} - -// InsertQuery builds typed INSERT statements. -type InsertQuery struct { - runner queryRunner - dialect dialect.Dialect - table *schema.TableDef - model any - models any - values []assignment - rows []map[schema.ColumnReference]any - returning []schema.Expression - conflict *insertConflictClause -} - -type insertConflictAction uint8 - -const ( - insertConflictActionNone insertConflictAction = iota - insertConflictActionDoNothing - insertConflictActionDoUpdateSet -) - -type insertConflictClause struct { - columns []schema.ColumnReference - action insertConflictAction - updates []schema.ColumnReference -} - -// InsertConflictBuilder configures conflict behavior for INSERT statements. -type InsertConflictBuilder struct { - query *InsertQuery -} - -// Table sets the INSERT target table. -func (q *InsertQuery) Table(table schema.TableReference) *InsertQuery { - q.table = table.TableDef() - return q -} - -// Model sets a struct payload for the insert. -// Zero-valued fields for columns with schema defaults are omitted so the -// database default applies; use Set to override that behavior explicitly. -func (q *InsertQuery) Model(model any) *InsertQuery { - q.model = model - return q -} - -// Models sets multiple struct payloads for a bulk insert. -func (q *InsertQuery) Models(models any) *InsertQuery { - q.models = models - return q -} - -// Set adds an explicit column assignment. -func (q *InsertQuery) Set(column schema.ColumnReference, value any) *InsertQuery { - q.values = append(q.values, assignment{column: column, value: schema.ValueExpr{Value: value}}) - return q -} - -// Values appends explicit row value sets for a bulk insert. -func (q *InsertQuery) Values(rows ...map[schema.ColumnReference]any) *InsertQuery { - q.rows = append(q.rows, rows...) - return q -} - -// OnConflict starts an upsert clause for PostgreSQL and SQLite dialects. -func (q *InsertQuery) OnConflict(columns ...schema.ColumnReference) *InsertConflictBuilder { - q.conflict = &insertConflictClause{columns: columns} - return &InsertConflictBuilder{query: q} -} - -// DoNothing configures ON CONFLICT ... DO NOTHING. -func (b *InsertConflictBuilder) DoNothing() *InsertQuery { - b.query.conflict.action = insertConflictActionDoNothing - return b.query -} - -// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values. -func (b *InsertConflictBuilder) DoUpdateSet(columns ...schema.ColumnReference) *InsertQuery { - b.query.conflict.action = insertConflictActionDoUpdateSet - b.query.conflict.updates = columns - return b.query -} - -// Returning adds RETURNING expressions when supported by the dialect. -func (q *InsertQuery) Returning(exprs ...schema.Expression) *InsertQuery { - q.returning = append(q.returning, exprs...) - return q -} - -// ToSQL compiles the insert into SQL and args. -func (q *InsertQuery) ToSQL() (string, []any, error) { - rows, err := q.insertAssignments() - if err != nil { - return "", nil, err - } - - ctx := newCompileContext(q.dialect) - ctx.writeString("INSERT INTO ") - ctx.writeTableName(q.table) - ctx.writeString(" (") - for idx, item := range rows[0] { - if idx > 0 { - ctx.writeString(", ") - } - ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) - } - ctx.writeString(") VALUES ") - for rowIdx, row := range rows { - if rowIdx > 0 { - ctx.writeString(", ") - } - ctx.writeByte('(') - for idx, item := range row { - if idx > 0 { - ctx.writeString(", ") - } - if err := ctx.writeExpression(item.value); err != nil { - return "", nil, err - } - } - ctx.writeByte(')') - } - - if err := q.writeConflictClause(ctx); err != nil { - return "", nil, err - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err - } - - return ctx.String(), ctx.args, ctx.err -} - -func (q *InsertQuery) returningClause() returningClause { - return returningClause{ - feature: dialect.FeatureInsertReturning, - label: "insert", - } -} - -// Exec executes the INSERT query. -func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) { - if q.runner == nil { - return nil, ErrNoConnection - } - - query, args, err := q.ToSQL() - if err != nil { - return nil, err - } - - return q.runner.execContext(ctx, query, args...) -} - -// Scan executes an INSERT ... RETURNING query and scans one row into dest. -func (q *InsertQuery) Scan(ctx context.Context, dest any) error { - if q.runner == nil { - return ErrNoConnection - } - if len(q.returning) == 0 { - return errors.New("rain: insert scan requires RETURNING") - } - - query, args, err := q.ToSQL() - if err != nil { - return err - } - - rows, err := q.runner.queryContext(ctx, query, args...) - if err != nil { - return err - } - defer closeRows(rows, &err) - - err = scanRows(rows, dest) - return err -} - -func (q *InsertQuery) insertAssignments() ([][]assignment, error) { - if q.table == nil { - return nil, errors.New("rain: insert query requires a table") - } - - sources := 0 - if q.model != nil || len(q.values) > 0 { - sources++ - } - if q.models != nil { - sources++ - } - if len(q.rows) > 0 { - sources++ - } - if sources == 0 { - return nil, errors.New("rain: insert query requires either explicit values or a model") - } - if sources > 1 { - return nil, errors.New("rain: insert query requires exactly one value source: Model/Set, Models, or Values") - } - - var rows [][]assignment - if q.models != nil { - modelRows, err := q.assignmentsFromModels() - if err != nil { - return nil, err - } - rows = append(rows, modelRows...) - } - if len(q.rows) > 0 { - valueRows, err := q.assignmentsFromRows() - if err != nil { - return nil, err - } - rows = append(rows, valueRows...) - } - if q.model != nil || len(q.values) > 0 { - singleRow, err := q.assignmentsFromModelAndSet() - if err != nil { - return nil, err - } - rows = append(rows, singleRow) - } - - if len(rows) == 0 { - return nil, errors.New("rain: insert query produced no values") - } - - if err := validateInsertRowShape(rows); err != nil { - return nil, err - } - - return rows, nil -} - -func (q *InsertQuery) assignmentsFromModelAndSet() ([]assignment, error) { - var ( - modelAssignments []assignment - err error - ) - if q.model != nil { - modelAssignments, err = assignmentsFromModel(q.table, q.model, true) - if err != nil { - return nil, err - } - } - - assignments, err := mergeAssignments(q.table, modelAssignments, q.values) - if err != nil { - return nil, err - } - if len(assignments) == 0 { - return nil, errors.New("rain: insert query produced no values") - } - - return assignments, nil -} - -func (q *InsertQuery) assignmentsFromModels() ([][]assignment, error) { - value := reflect.ValueOf(q.models) - for value.Kind() == reflect.Pointer { - if value.IsNil() { - return nil, errors.New("rain: insert models cannot be nil") - } - value = value.Elem() - } - if value.Kind() != reflect.Slice && value.Kind() != reflect.Array { - return nil, errors.New("rain: Models expects a slice or array") - } - if value.Len() == 0 { - return nil, errors.New("rain: Models expects at least one model") - } - - rows := make([][]assignment, 0, value.Len()) - for idx := range value.Len() { - assignments, err := assignmentsFromModel(q.table, value.Index(idx).Interface(), true) - if err != nil { - return nil, err - } - if len(assignments) == 0 { - return nil, fmt.Errorf("rain: insert row %d produced no values", idx+1) - } - rows = append(rows, assignments) - } - - return rows, nil -} - -func (q *InsertQuery) assignmentsFromRows() ([][]assignment, error) { - rows := make([][]assignment, 0, len(q.rows)) - for idx, row := range q.rows { - if len(row) == 0 { - return nil, fmt.Errorf("rain: insert row %d has no values", idx+1) - } - - overrides := make([]assignment, 0, len(row)) - for column, value := range row { - overrides = append(overrides, assignment{column: column, value: schema.ValueExpr{Value: value}}) - } - - assignments, err := mergeAssignments(q.table, nil, overrides) - if err != nil { - return nil, err - } - if len(assignments) == 0 { - return nil, fmt.Errorf("rain: insert row %d produced no values", idx+1) - } - rows = append(rows, assignments) - } - - return rows, nil -} - -func validateInsertRowShape(rows [][]assignment) error { - want := rows[0] - wantColumns := make([]string, 0, len(want)) - for _, item := range want { - wantColumns = append(wantColumns, item.column.ColumnDef().Name) - } - - for rowIdx := 1; rowIdx < len(rows); rowIdx++ { - row := rows[rowIdx] - if len(row) != len(want) { - return fmt.Errorf("rain: insert row %d targets %d columns, expected %d", rowIdx+1, len(row), len(want)) - } - for colIdx := range row { - if row[colIdx].column.ColumnDef().Name != wantColumns[colIdx] { - return fmt.Errorf("rain: insert row %d column mismatch at position %d: got %q, expected %q", rowIdx+1, colIdx+1, row[colIdx].column.ColumnDef().Name, wantColumns[colIdx]) - } - } - } - - return nil -} - -func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { - if q.conflict == nil { - return nil - } - if q.conflict.action == insertConflictActionNone { - return errors.New("rain: conflict action is required; call DoNothing() or DoUpdateSet(...)") - } - - if q.dialect.Name() != "postgres" && q.dialect.Name() != "sqlite" { - return fmt.Errorf("rain: insert conflict clauses are not implemented for %s dialect", q.dialect.Name()) - } - - if len(q.conflict.columns) == 0 { - return errors.New("rain: conflict clause requires at least one target column") - } - - ctx.writeString(" ON CONFLICT (") - for idx, col := range q.conflict.columns { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { - return err - } - if idx > 0 { - ctx.writeString(", ") - } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - } - ctx.writeByte(')') - - switch q.conflict.action { - case insertConflictActionDoNothing: - ctx.writeString(" DO NOTHING") - case insertConflictActionDoUpdateSet: - if len(q.conflict.updates) == 0 { - return errors.New("rain: conflict DO UPDATE requires at least one update column") - } - ctx.writeString(" DO UPDATE SET ") - for idx, col := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { - return err - } - if idx > 0 { - ctx.writeString(", ") - } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeString(" = EXCLUDED.") - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - } - } - - return nil -} - -// UpdateQuery builds typed UPDATE statements. -type UpdateQuery struct { - runner queryRunner - dialect dialect.Dialect - table *schema.TableDef - values []assignment - where []schema.Predicate - returning []schema.Expression - unbounded bool -} - -// Table sets the UPDATE target table. -func (q *UpdateQuery) Table(table schema.TableReference) *UpdateQuery { - q.table = table.TableDef() - return q -} - -// Set adds an explicit typed assignment. -func (q *UpdateQuery) Set(column schema.ColumnReference, value any) *UpdateQuery { - q.values = append(q.values, assignment{column: column, value: schema.ValueExpr{Value: value}}) - return q -} - -// Where appends a WHERE predicate joined with AND. -func (q *UpdateQuery) Where(predicate schema.Predicate) *UpdateQuery { - q.where = append(q.where, predicate) - return q -} - -// Returning adds RETURNING expressions when supported by the dialect. -func (q *UpdateQuery) Returning(exprs ...schema.Expression) *UpdateQuery { - q.returning = append(q.returning, exprs...) - return q -} - -// Unbounded allows UPDATE without a WHERE clause. -func (q *UpdateQuery) Unbounded() *UpdateQuery { - q.unbounded = true - return q -} - -// ToSQL compiles the update into SQL and args. -func (q *UpdateQuery) ToSQL() (string, []any, error) { - if q.table == nil { - return "", nil, errors.New("rain: update query requires a table") - } - if len(q.values) == 0 { - return "", nil, errors.New("rain: update query requires at least one assignment") - } - if len(q.where) == 0 && !q.unbounded { - return "", nil, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows") - } - - ctx := newCompileContext(q.dialect) - ctx.writeString("UPDATE ") - ctx.writeTableName(q.table) - ctx.writeString(" SET ") - for idx, item := range q.values { - if idx > 0 { - ctx.writeString(", ") - } - ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) - ctx.writeString(" = ") - if err := ctx.writeExpression(item.value); err != nil { - return "", nil, err - } - } - - if len(q.where) > 0 { - ctx.writeString(" WHERE ") - if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return "", nil, err - } - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err - } - - return ctx.String(), ctx.args, ctx.err -} - -func (q *UpdateQuery) returningClause() returningClause { - return returningClause{ - feature: dialect.FeatureUpdateReturning, - label: "update", - } -} - -// Exec executes the UPDATE query. -func (q *UpdateQuery) Exec(ctx context.Context) (sql.Result, error) { - if q.runner == nil { - return nil, ErrNoConnection - } - - query, args, err := q.ToSQL() - if err != nil { - return nil, err - } - - return q.runner.execContext(ctx, query, args...) -} - -// Scan executes an UPDATE ... RETURNING query and scans results into dest. -func (q *UpdateQuery) Scan(ctx context.Context, dest any) error { - if q.runner == nil { - return ErrNoConnection - } - if len(q.returning) == 0 { - return errors.New("rain: update scan requires RETURNING") - } - - query, args, err := q.ToSQL() - if err != nil { - return err - } - - rows, err := q.runner.queryContext(ctx, query, args...) - if err != nil { - return err - } - defer closeRows(rows, &err) - - err = scanRows(rows, dest) - return err -} - -// DeleteQuery builds typed DELETE statements. -type DeleteQuery struct { - runner queryRunner - dialect dialect.Dialect - table *schema.TableDef - where []schema.Predicate - returning []schema.Expression - unbounded bool -} - -// Table sets the DELETE target table. -func (q *DeleteQuery) Table(table schema.TableReference) *DeleteQuery { - q.table = table.TableDef() - return q -} - -// Where appends a WHERE predicate joined with AND. -func (q *DeleteQuery) Where(predicate schema.Predicate) *DeleteQuery { - q.where = append(q.where, predicate) - return q -} - -// Returning adds RETURNING expressions when supported by the dialect. -func (q *DeleteQuery) Returning(exprs ...schema.Expression) *DeleteQuery { - q.returning = append(q.returning, exprs...) - return q -} - -// Unbounded allows DELETE without a WHERE clause. -func (q *DeleteQuery) Unbounded() *DeleteQuery { - q.unbounded = true - return q -} - -// ToSQL compiles the delete into SQL and args. -func (q *DeleteQuery) ToSQL() (string, []any, error) { - if q.table == nil { - return "", nil, errors.New("rain: delete query requires a table") - } - if len(q.where) == 0 && !q.unbounded { - return "", nil, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows") - } - - ctx := newCompileContext(q.dialect) - ctx.writeString("DELETE FROM ") - ctx.writeTableName(q.table) - if len(q.where) > 0 { - ctx.writeString(" WHERE ") - if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return "", nil, err - } - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err - } - - return ctx.String(), ctx.args, ctx.err -} - -func (q *DeleteQuery) returningClause() returningClause { - return returningClause{ - feature: dialect.FeatureDeleteReturning, - label: "delete", - } -} - -// Exec executes the DELETE query. -func (q *DeleteQuery) Exec(ctx context.Context) (sql.Result, error) { - if q.runner == nil { - return nil, ErrNoConnection - } - - query, args, err := q.ToSQL() - if err != nil { - return nil, err - } - - return q.runner.execContext(ctx, query, args...) -} - -// Scan executes a DELETE ... RETURNING query and scans results into dest. -func (q *DeleteQuery) Scan(ctx context.Context, dest any) error { - if q.runner == nil { - return ErrNoConnection - } - if len(q.returning) == 0 { - return errors.New("rain: delete scan requires RETURNING") - } - - query, args, err := q.ToSQL() - if err != nil { - return err - } - - rows, err := q.runner.queryContext(ctx, query, args...) - if err != nil { - return err - } - defer closeRows(rows, &err) - - err = scanRows(rows, dest) - return err -} - -type compileContext struct { - builder strings.Builder - dialect dialect.Dialect - args []any - err error -} - -func newCompileContext(d dialect.Dialect) *compileContext { - return &compileContext{ - dialect: d, - args: make([]any, 0, 8), - } -} - -func (c *compileContext) String() string { - return c.builder.String() -} - -func (c *compileContext) writeByte(ch byte) { - c.builder.WriteByte(ch) -} - -func (c *compileContext) writeString(value string) { - c.builder.WriteString(value) -} - -func (c *compileContext) writeQuotedIdentifier(name string) { - c.writeString(c.dialect.QuoteIdentifier(name)) -} - -func (c *compileContext) writeTableName(table *schema.TableDef) { - c.writeQuotedIdentifier(table.Name) -} - -func (c *compileContext) writeTable(table *schema.TableDef) { - c.writeTableName(table) - if table.Alias != "" { - c.writeString(" AS ") - c.writeQuotedIdentifier(table.Alias) - } -} - -func (c *compileContext) writeReturning(exprs []schema.Expression, clause returningClause) error { - if len(exprs) == 0 { - return nil - } - if !dialect.HasFeature(c.dialect.Features(), clause.feature) { - return fmt.Errorf("rain: %s queries do not support RETURNING for %s dialect", clause.label, c.dialect.Name()) - } - - c.writeString(" RETURNING ") - for idx, expr := range exprs { - if idx > 0 { - c.writeString(", ") - } - if err := c.writeExpression(expr); err != nil { - return err - } - } - - return nil -} - -func (c *compileContext) writePredicate(predicate schema.Predicate) error { - return c.writeExpression(predicate) -} - -type expressionContext struct { - allowAlias bool -} - -func (c *compileContext) writeExpression(expr schema.Expression) error { - return c.writeExpressionInContext(expr, expressionContext{}) -} - -func (c *compileContext) writeSelectExpression(expr schema.Expression) error { - return c.writeExpressionInContext(expr, expressionContext{allowAlias: true}) -} - -func (c *compileContext) writeExpressionInContext(expr schema.Expression, context expressionContext) error { - switch value := expr.(type) { - case schema.ColumnReference: - c.writeColumn(value) - case schema.ValueExpr: - c.args = append(c.args, value.Value) - c.writeString(c.dialect.Placeholder(len(c.args))) - case schema.ComparisonExpr: - if err := c.writeExpression(value.Left); err != nil { - return err - } - c.writeByte(' ') - c.writeString(value.Operator) - c.writeByte(' ') - if err := c.writeExpression(value.Right); err != nil { - return err - } - case schema.InExpr: - if len(value.Values) == 0 { - return errors.New("rain: IN predicate requires at least one value") - } - if err := c.writeExpression(value.Left); err != nil { - return err - } - c.writeString(" IN (") - for idx, item := range value.Values { - if idx > 0 { - c.writeString(", ") - } - if err := c.writeExpression(item); err != nil { - return err - } - } - c.writeByte(')') - case schema.NullCheckExpr: - if err := c.writeExpression(value.Expr); err != nil { - return err - } - if value.Negated { - c.writeString(" IS NOT NULL") - } else { - c.writeString(" IS NULL") - } - case schema.LogicalExpr: - c.writeByte('(') - for idx, part := range value.Exprs { - if idx > 0 { - c.writeByte(' ') - c.writeString(value.Operator) - c.writeByte(' ') - } - if err := c.writePredicate(part); err != nil { - return err - } - } - c.writeByte(')') - case schema.AggregateExpr: - if value.Function == "" { - return errors.New("rain: aggregate function name cannot be empty") - } - if value.Distinct && value.Star { - return fmt.Errorf("rain: aggregate %s cannot combine DISTINCT with *", value.Function) - } - c.writeString(value.Function) - c.writeByte('(') - if value.Distinct { - c.writeString("DISTINCT ") - } - switch { - case value.Star: - c.writeByte('*') - case value.Expr != nil: - if err := c.writeExpression(value.Expr); err != nil { - return err - } - default: - return fmt.Errorf("rain: aggregate %s requires an expression", value.Function) - } - c.writeByte(')') - case schema.AliasExpr: - if !context.allowAlias { - return errors.New("rain: aliased expressions are only supported in SELECT columns") - } - if err := c.writeExpressionInContext(value.Expr, expressionContext{}); err != nil { - return err - } - c.writeString(" AS ") - c.writeQuotedIdentifier(value.Alias) - case schema.RawExpr: - if err := c.writeRaw(value); err != nil { - return err - } - default: - return fmt.Errorf("rain: unsupported expression type %T", expr) - } - - return nil -} - -func (c *compileContext) writeRaw(raw schema.RawExpr) error { - argIndex := 0 - for idx := range len(raw.SQL) { - if raw.SQL[idx] != '?' { - c.writeByte(raw.SQL[idx]) - continue - } - if argIndex >= len(raw.Args) { - return errors.New("rain: raw SQL placeholder count does not match args") - } - c.args = append(c.args, raw.Args[argIndex]) - c.writeString(c.dialect.Placeholder(len(c.args))) - argIndex++ - } - if argIndex != len(raw.Args) { - return errors.New("rain: raw SQL has unused args") - } - - return nil -} - -func (c *compileContext) writeColumn(column schema.ColumnReference) { - def := column.ColumnDef() - table := def.Table - qualifier := table.Name - if table.Alias != "" { - qualifier = table.Alias - } - - c.writeQuotedIdentifier(qualifier) - c.writeByte('.') - c.writeQuotedIdentifier(def.Name) -} - -func joinPredicates(predicates []schema.Predicate) schema.Predicate { - if len(predicates) == 1 { - return predicates[0] - } - - return schema.And(predicates...) -} - -func assignmentsFromModel(table *schema.TableDef, model any, skipAuto bool) ([]assignment, error) { - meta, value, err := lookupModelMeta(model) - if err != nil { - return nil, err - } - - assignments := make([]assignment, 0, len(table.Columns)) - for _, column := range table.Columns { - field, ok := meta.byColumn[column.Name] - if !ok { - continue - } - - fieldValue := value.FieldByIndex(field.index) - resolvedValue, include := fieldValueForInsert(column, fieldValue, skipAuto) - if !include { - continue - } - - assignments = append(assignments, assignment{ - column: schema.Ref(column), - value: schema.ValueExpr{Value: resolvedValue}, - }) - } - - return assignments, nil -} - -func mergeAssignments(table *schema.TableDef, base, overrides []assignment) ([]assignment, error) { - ordered := make([]assignment, 0, len(table.Columns)) - assignmentsByName := make(map[string]assignment, len(table.Columns)) - - for _, item := range base { - if err := validateAssignmentTarget(table, item); err != nil { - return nil, err - } - assignmentsByName[item.column.ColumnDef().Name] = item - } - for _, item := range overrides { - if err := validateAssignmentTarget(table, item); err != nil { - return nil, err - } - assignmentsByName[item.column.ColumnDef().Name] = item - } - - for _, column := range table.Columns { - item, ok := assignmentsByName[column.Name] - if !ok { - continue - } - ordered = append(ordered, item) - delete(assignmentsByName, column.Name) - } - - if len(assignmentsByName) > 0 { - names := make([]string, 0, len(assignmentsByName)) - for name := range assignmentsByName { - names = append(names, name) - } - slices.Sort(names) - return nil, fmt.Errorf("rain: insert assignments contain unknown target columns: %s", strings.Join(names, ", ")) - } - - return ordered, nil -} - -func validateAssignmentTarget(table *schema.TableDef, item assignment) error { - column := item.column.ColumnDef() - if column.Table.Name != table.Name { - return fmt.Errorf("rain: column %s belongs to table %s, not %s", column.Name, column.Table.Name, table.Name) - } - if _, ok := table.ColumnByName(column.Name); !ok { - return fmt.Errorf("rain: unknown column %s on table %s", column.Name, table.Name) - } - - return nil -} - -func fieldValueForInsert(column *schema.ColumnDef, fieldValue reflect.Value, skipAuto bool) (any, bool) { - resolved, isNil := dereferenceValue(fieldValue) - if isNil { - return nil, false - } - - if skipAuto && column.AutoIncrement && resolved.IsZero() { - return nil, false - } - if column.HasDefault && resolved.IsZero() { - return nil, false - } - - return resolved.Interface(), true -} - -func dereferenceValue(value reflect.Value) (reflect.Value, bool) { - current := value - for current.Kind() == reflect.Pointer { - if current.IsNil() { - return reflect.Value{}, true - } - current = current.Elem() - } - - return current, false -} diff --git a/pkg/rain/query_common.go b/pkg/rain/query_common.go new file mode 100644 index 0000000..d36bfa9 --- /dev/null +++ b/pkg/rain/query_common.go @@ -0,0 +1,78 @@ +package rain + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +type queryRunner interface { + execContext(context.Context, string, ...any) (sql.Result, error) + queryContext(context.Context, string, ...any) (*sql.Rows, error) +} + +type joinClause struct { + kind string + table selectTableSource + on schema.Predicate +} + +type assignment struct { + column schema.ColumnReference + value schema.Expression +} + +type returningClause struct { + feature dialect.Feature + label string +} + +type selectTableSource interface { + writeSQL(*compileContext) error +} + +type tableDefSource struct { + table *schema.TableDef +} + +func (s tableDefSource) writeSQL(ctx *compileContext) error { + ctx.writeTable(s.table) + return nil +} + +type subqueryTableSource struct { + query *SelectQuery + alias string +} + +func (s subqueryTableSource) writeSQL(ctx *compileContext) error { + if strings.TrimSpace(s.alias) == "" { + return errors.New("rain: subquery table source requires a non-empty alias") + } + if s.query == nil { + return fmt.Errorf("rain: subquery table source %q requires a non-nil query", s.alias) + } + ctx.writeByte('(') + if err := s.query.writeSQL(ctx); err != nil { + return err + } + ctx.writeString(") AS ") + ctx.writeQuotedIdentifier(s.alias) + return nil +} + +type cteDefinition struct { + name string + query *SelectQuery +} + +func closeRows(rows *sql.Rows, errp *error) { + if err := rows.Close(); err != nil && *errp == nil { + *errp = err + } +} diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go new file mode 100644 index 0000000..c9e3902 --- /dev/null +++ b/pkg/rain/query_compile.go @@ -0,0 +1,230 @@ +package rain + +import ( + "errors" + "fmt" + "strings" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +type compileContext struct { + builder strings.Builder + dialect dialect.Dialect + args []any + err error +} + +func newCompileContext(d dialect.Dialect) *compileContext { + return &compileContext{ + dialect: d, + args: make([]any, 0, 8), + } +} + +func (c *compileContext) String() string { + return c.builder.String() +} + +func (c *compileContext) writeByte(ch byte) { + c.builder.WriteByte(ch) +} + +func (c *compileContext) writeString(value string) { + c.builder.WriteString(value) +} + +func (c *compileContext) writeQuotedIdentifier(name string) { + c.writeString(c.dialect.QuoteIdentifier(name)) +} + +func (c *compileContext) writeTableName(table *schema.TableDef) { + c.writeQuotedIdentifier(table.Name) +} + +func (c *compileContext) writeTable(table *schema.TableDef) { + c.writeTableName(table) + if table.Alias != "" { + c.writeString(" AS ") + c.writeQuotedIdentifier(table.Alias) + } +} + +func (c *compileContext) writeReturning(exprs []schema.Expression, clause returningClause) error { + if len(exprs) == 0 { + return nil + } + if !dialect.HasFeature(c.dialect.Features(), clause.feature) { + return fmt.Errorf("rain: %s queries do not support RETURNING for %s dialect", clause.label, c.dialect.Name()) + } + + c.writeString(" RETURNING ") + for idx, expr := range exprs { + if idx > 0 { + c.writeString(", ") + } + if err := c.writeExpression(expr); err != nil { + return err + } + } + + return nil +} + +func (c *compileContext) writePredicate(predicate schema.Predicate) error { + return c.writeExpression(predicate) +} + +type expressionContext struct { + allowAlias bool +} + +func (c *compileContext) writeExpression(expr schema.Expression) error { + return c.writeExpressionInContext(expr, expressionContext{}) +} + +func (c *compileContext) writeSelectExpression(expr schema.Expression) error { + return c.writeExpressionInContext(expr, expressionContext{allowAlias: true}) +} + +func (c *compileContext) writeExpressionInContext(expr schema.Expression, context expressionContext) error { + switch value := expr.(type) { + case schema.ColumnReference: + c.writeColumn(value) + case schema.ValueExpr: + c.args = append(c.args, value.Value) + c.writeString(c.dialect.Placeholder(len(c.args))) + case schema.ComparisonExpr: + if err := c.writeExpression(value.Left); err != nil { + return err + } + c.writeByte(' ') + c.writeString(value.Operator) + c.writeByte(' ') + if err := c.writeExpression(value.Right); err != nil { + return err + } + case schema.InExpr: + if len(value.Values) == 0 { + return errors.New("rain: IN predicate requires at least one value") + } + if err := c.writeExpression(value.Left); err != nil { + return err + } + c.writeString(" IN (") + for idx, item := range value.Values { + if idx > 0 { + c.writeString(", ") + } + if err := c.writeExpression(item); err != nil { + return err + } + } + c.writeByte(')') + case schema.NullCheckExpr: + if err := c.writeExpression(value.Expr); err != nil { + return err + } + if value.Negated { + c.writeString(" IS NOT NULL") + } else { + c.writeString(" IS NULL") + } + case schema.LogicalExpr: + c.writeByte('(') + for idx, part := range value.Exprs { + if idx > 0 { + c.writeByte(' ') + c.writeString(value.Operator) + c.writeByte(' ') + } + if err := c.writePredicate(part); err != nil { + return err + } + } + c.writeByte(')') + case schema.AggregateExpr: + if value.Function == "" { + return errors.New("rain: aggregate function name cannot be empty") + } + if value.Distinct && value.Star { + return fmt.Errorf("rain: aggregate %s cannot combine DISTINCT with *", value.Function) + } + c.writeString(value.Function) + c.writeByte('(') + if value.Distinct { + c.writeString("DISTINCT ") + } + switch { + case value.Star: + c.writeByte('*') + case value.Expr != nil: + if err := c.writeExpression(value.Expr); err != nil { + return err + } + default: + return fmt.Errorf("rain: aggregate %s requires an expression", value.Function) + } + c.writeByte(')') + case schema.AliasExpr: + if !context.allowAlias { + return errors.New("rain: aliased expressions are only supported in SELECT columns") + } + if err := c.writeExpressionInContext(value.Expr, expressionContext{}); err != nil { + return err + } + c.writeString(" AS ") + c.writeQuotedIdentifier(value.Alias) + case schema.RawExpr: + if err := c.writeRaw(value); err != nil { + return err + } + default: + return fmt.Errorf("rain: unsupported expression type %T", expr) + } + + return nil +} + +func (c *compileContext) writeRaw(raw schema.RawExpr) error { + argIndex := 0 + for idx := range len(raw.SQL) { + if raw.SQL[idx] != '?' { + c.writeByte(raw.SQL[idx]) + continue + } + if argIndex >= len(raw.Args) { + return errors.New("rain: raw SQL placeholder count does not match args") + } + c.args = append(c.args, raw.Args[argIndex]) + c.writeString(c.dialect.Placeholder(len(c.args))) + argIndex++ + } + if argIndex != len(raw.Args) { + return errors.New("rain: raw SQL has unused args") + } + + return nil +} + +func (c *compileContext) writeColumn(column schema.ColumnReference) { + def := column.ColumnDef() + table := def.Table + qualifier := table.Name + if table.Alias != "" { + qualifier = table.Alias + } + + c.writeQuotedIdentifier(qualifier) + c.writeByte('.') + c.writeQuotedIdentifier(def.Name) +} + +func joinPredicates(predicates []schema.Predicate) schema.Predicate { + if len(predicates) == 1 { + return predicates[0] + } + + return schema.And(predicates...) +} diff --git a/pkg/rain/query_compile_internal_test.go b/pkg/rain/query_compile_internal_test.go new file mode 100644 index 0000000..d9f958e --- /dev/null +++ b/pkg/rain/query_compile_internal_test.go @@ -0,0 +1,259 @@ +package rain + +import ( + "context" + "errors" + "reflect" + "strings" + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestQueryBuilderAndHelperErrors(t *testing.T) { + t.Parallel() + + db, err := OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, posts := defineInternalQueryTables() + + if _, _, err := db.Select().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { + t.Fatalf("expected select table error, got %v", err) + } + selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}} + if err := selectNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected select scan ErrNoConnection, got %v", err) + } + if _, err := (&SelectQuery{dialect: db.Dialect()}).Count(context.Background()); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected select count ErrNoConnection, got %v", err) + } + if _, err := (&SelectQuery{dialect: db.Dialect()}).Exists(context.Background()); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected select exists ErrNoConnection, got %v", err) + } + + if _, _, err := db.Insert().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { + t.Fatalf("expected insert table error, got %v", err) + } + if _, _, err := db.Insert().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires either explicit values or a model") { + t.Fatalf("expected insert values error, got %v", err) + } + insertNoRunner := &InsertQuery{dialect: db.Dialect(), table: users.TableDef(), returning: []schema.Expression{users.ID}} + if err := insertNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected insert returning scan ErrNoConnection, got %v", err) + } + insertNoReturning := &InsertQuery{runner: db, dialect: db.Dialect(), table: users.TableDef()} + if err := insertNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { + t.Fatalf("expected insert returning error, got %v", err) + } + + if _, _, err := db.Update().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { + t.Fatalf("expected update table error, got %v", err) + } + if _, _, err := db.Update().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one assignment") { + t.Fatalf("expected update assignment error, got %v", err) + } + if _, _, err := db.Update().Table(users).Set(users.Name, "Alice").ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one WHERE predicate") { + t.Fatalf("expected update WHERE guard error, got %v", err) + } + if _, _, err := db.Update().Table(users).Set(users.Name, "Alice").Unbounded().ToSQL(); err != nil { + t.Fatalf("expected unbounded update to succeed, got %v", err) + } + updateNoRunner := &UpdateQuery{dialect: db.Dialect(), table: users.TableDef(), values: []assignment{{column: users.Name, value: schema.ValueExpr{Value: "Alice"}}}, returning: []schema.Expression{users.ID}} + if err := updateNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected update scan ErrNoConnection, got %v", err) + } + updateNoReturning := &UpdateQuery{runner: db, dialect: db.Dialect(), table: users.TableDef(), values: []assignment{{column: users.Name, value: schema.ValueExpr{Value: "Alice"}}}} + if err := updateNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { + t.Fatalf("expected update returning error, got %v", err) + } + + if _, _, err := db.Delete().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { + t.Fatalf("expected delete table error, got %v", err) + } + if _, _, err := db.Delete().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one WHERE predicate") { + t.Fatalf("expected delete WHERE guard error, got %v", err) + } + if _, _, err := db.Delete().Table(users).Unbounded().ToSQL(); err != nil { + t.Fatalf("expected unbounded delete to succeed, got %v", err) + } + deleteNoRunner := &DeleteQuery{dialect: db.Dialect(), table: users.TableDef(), returning: []schema.Expression{users.ID}} + if err := deleteNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { + t.Fatalf("expected delete scan ErrNoConnection, got %v", err) + } + deleteNoReturning := &DeleteQuery{runner: db, dialect: db.Dialect(), table: users.TableDef()} + if err := deleteNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { + t.Fatalf("expected delete returning error, got %v", err) + } + + leftJoinSQL, _, err := db.Select(). + Table(users). + Column(users.ID). + LeftJoin(posts, users.ID.EqCol(posts.UserID)). + Where(users.Active.Eq(true)). + Where(users.Email.Eq("alice@example.com")). + OrderBy(users.ID.Asc(), users.Email.Desc()). + Limit(5). + Offset(10). + ToSQL() + if err != nil { + t.Fatalf("left join ToSQL failed: %v", err) + } + if !strings.Contains(leftJoinSQL, "LEFT JOIN") || !strings.Contains(leftJoinSQL, "OFFSET 10") { + t.Fatalf("unexpected left join SQL: %s", leftJoinSQL) + } +} + +func TestCompileContextAndAssignmentsHelpers(t *testing.T) { + t.Parallel() + + users, posts := defineInternalQueryTables() + + ctx := newCompileContext(dialectForTest(t, "postgres")) + if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil { + t.Fatalf("writeRaw without args failed: %v", err) + } + if ctx.String() != "NOW()" { + t.Fatalf("unexpected raw SQL: %s", ctx.String()) + } + + ctx = newCompileContext(dialectForTest(t, "postgres")) + if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil { + t.Fatalf("writeRaw placeholders failed: %v", err) + } + if ctx.String() != "$1 + $2" { + t.Fatalf("unexpected placeholder SQL: %s", ctx.String()) + } + + if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") { + t.Fatalf("expected raw unused args error, got %v", err) + } + if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") { + t.Fatalf("expected raw placeholder mismatch error, got %v", err) + } + if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") { + t.Fatalf("expected empty IN error, got %v", err) + } + if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") { + t.Fatalf("expected unsupported expression error, got %v", err) + } + + merged, err := mergeAssignments(users.TableDef(), + []assignment{ + {column: users.Email, value: schema.ValueExpr{Value: "base@example.com"}}, + {column: users.Name, value: schema.ValueExpr{Value: "Base"}}, + }, + []assignment{ + {column: users.Name, value: schema.ValueExpr{Value: "Override"}}, + {column: users.Active, value: schema.ValueExpr{Value: true}}, + }, + ) + if err != nil { + t.Fatalf("mergeAssignments failed: %v", err) + } + if len(merged) != 3 { + t.Fatalf("expected 3 merged assignments, got %d", len(merged)) + } + if merged[1].column.ColumnDef().Name != "name" || merged[2].column.ColumnDef().Name != "active" { + t.Fatalf("unexpected merged order: %#v", merged) + } + if merged[1].value.(schema.ValueExpr).Value != "Override" { + t.Fatalf("expected override assignment to win, got %#v", merged[1].value) + } + + if _, err := mergeAssignments(users.TableDef(), nil, []assignment{{column: posts.Title, value: schema.ValueExpr{Value: "bad"}}}); err == nil || !strings.Contains(err.Error(), "belongs to table posts") { + t.Fatalf("expected foreign table assignment error, got %v", err) + } + + ghostColumn := schema.Ref(&schema.ColumnDef{Table: users.TableDef(), Name: "ghost"}) + if _, err := mergeAssignments(users.TableDef(), nil, []assignment{{column: ghostColumn, value: schema.ValueExpr{Value: "bad"}}}); err == nil || !strings.Contains(err.Error(), "unknown column ghost") { + t.Fatalf("expected unknown column assignment error, got %v", err) + } + + if got := joinPredicates([]schema.Predicate{users.Active.Eq(true)}); got != users.Active.Eq(true) { + t.Fatalf("expected single predicate to pass through") + } + if _, ok := joinPredicates([]schema.Predicate{users.Active.Eq(true), users.Email.Eq("alice@example.com")}).(schema.LogicalExpr); !ok { + t.Fatalf("expected multiple predicates to produce logical expression") + } +} + +func TestModelAssignmentAndValueHelpers(t *testing.T) { + t.Parallel() + + users, _ := defineInternalQueryTables() + + nickname := "ally" + assignments, err := assignmentsFromModel(users.TableDef(), &internalInsertModel{ + ID: 0, + Email: "alice@example.com", + Name: "", + Active: false, + Nickname: &nickname, + }, true) + if err != nil { + t.Fatalf("assignmentsFromModel failed: %v", err) + } + if len(assignments) != 2 { + t.Fatalf("expected 2 assignments after skipping default-backed zero values, got %d", len(assignments)) + } + if assignments[0].column.ColumnDef().Name != "email" || assignments[1].column.ColumnDef().Name != "nickname" { + t.Fatalf("unexpected assignments: %#v", assignments) + } + + assignments, err = assignmentsFromModel(users.TableDef(), &internalInsertModel{ + ID: 42, + Email: "bob@example.com", + Name: "Bob", + Active: true, + }, false) + if err != nil { + t.Fatalf("assignmentsFromModel skipAuto=false failed: %v", err) + } + if len(assignments) != 4 { + t.Fatalf("expected 4 assignments when auto id is retained, got %d", len(assignments)) + } + + if _, include := fieldValueForInsert(users.ID.ColumnDef(), reflect.ValueOf(int64(0)), true); include { + t.Fatalf("expected zero auto-increment id to be skipped") + } + if _, include := fieldValueForInsert(users.Name.ColumnDef(), reflect.ValueOf(""), true); include { + t.Fatalf("expected default-backed zero string to be skipped") + } + if _, include := fieldValueForInsert(users.Nickname.ColumnDef(), reflect.ValueOf((*string)(nil)), true); include { + t.Fatalf("expected nil pointer to be skipped") + } + if value, include := fieldValueForInsert(users.Name.ColumnDef(), reflect.ValueOf("Alice"), true); !include || value != "Alice" { + t.Fatalf("expected non-zero string to be included, got %#v include=%v", value, include) + } + + type pointerHolder struct { + Value **string + } + var nilStringPtr *string + holder := pointerHolder{Value: &nilStringPtr} + if _, isNil := dereferenceValue(reflect.ValueOf(holder).Field(0)); !isNil { + t.Fatalf("expected nested nil pointer to be detected") + } + + name := "Alice" + namePtr := &name + holder = pointerHolder{Value: &namePtr} + resolved, isNil := dereferenceValue(reflect.ValueOf(holder).Field(0)) + if isNil || resolved.Kind() != reflect.String || resolved.String() != "Alice" { + t.Fatalf("unexpected dereference result: %#v isNil=%v", resolved, isNil) + } +} + +func dialectForTest(t *testing.T, driver string) dialect.Dialect { + t.Helper() + + db, err := OpenDialect(driver) + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + + return db.Dialect() +} diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go new file mode 100644 index 0000000..f43a3d6 --- /dev/null +++ b/pkg/rain/query_delete.go @@ -0,0 +1,115 @@ +package rain + +import ( + "context" + "database/sql" + "errors" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +// DeleteQuery builds typed DELETE statements. +type DeleteQuery struct { + runner queryRunner + dialect dialect.Dialect + table *schema.TableDef + where []schema.Predicate + returning []schema.Expression + unbounded bool +} + +// Table sets the DELETE target table. +func (q *DeleteQuery) Table(table schema.TableReference) *DeleteQuery { + q.table = table.TableDef() + return q +} + +// Where appends a WHERE predicate joined with AND. +func (q *DeleteQuery) Where(predicate schema.Predicate) *DeleteQuery { + q.where = append(q.where, predicate) + return q +} + +// Returning adds RETURNING expressions when supported by the dialect. +func (q *DeleteQuery) Returning(exprs ...schema.Expression) *DeleteQuery { + q.returning = append(q.returning, exprs...) + return q +} + +// Unbounded allows DELETE without a WHERE clause. +func (q *DeleteQuery) Unbounded() *DeleteQuery { + q.unbounded = true + return q +} + +// ToSQL compiles the delete into SQL and args. +func (q *DeleteQuery) ToSQL() (string, []any, error) { + if q.table == nil { + return "", nil, errors.New("rain: delete query requires a table") + } + if len(q.where) == 0 && !q.unbounded { + return "", nil, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows") + } + + ctx := newCompileContext(q.dialect) + ctx.writeString("DELETE FROM ") + ctx.writeTableName(q.table) + if len(q.where) > 0 { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { + return "", nil, err + } + } + + if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { + return "", nil, err + } + + return ctx.String(), ctx.args, ctx.err +} + +func (q *DeleteQuery) returningClause() returningClause { + return returningClause{ + feature: dialect.FeatureDeleteReturning, + label: "delete", + } +} + +// Exec executes the DELETE query. +func (q *DeleteQuery) Exec(ctx context.Context) (sql.Result, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + query, args, err := q.ToSQL() + if err != nil { + return nil, err + } + + return q.runner.execContext(ctx, query, args...) +} + +// Scan executes a DELETE ... RETURNING query and scans results into dest. +func (q *DeleteQuery) Scan(ctx context.Context, dest any) error { + if q.runner == nil { + return ErrNoConnection + } + if len(q.returning) == 0 { + return errors.New("rain: delete scan requires RETURNING") + } + + query, args, err := q.ToSQL() + if err != nil { + return err + } + + rows, err := q.runner.queryContext(ctx, query, args...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + err = scanRows(rows, dest) + return err +} diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go new file mode 100644 index 0000000..0a38a04 --- /dev/null +++ b/pkg/rain/query_insert.go @@ -0,0 +1,399 @@ +package rain + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +// InsertQuery builds typed INSERT statements. +type InsertQuery struct { + runner queryRunner + dialect dialect.Dialect + table *schema.TableDef + model any + models any + values []assignment + rows []map[schema.ColumnReference]any + returning []schema.Expression + conflict *insertConflictClause +} + +type insertConflictAction uint8 + +const ( + insertConflictActionNone insertConflictAction = iota + insertConflictActionDoNothing + insertConflictActionDoUpdateSet +) + +type insertConflictClause struct { + columns []schema.ColumnReference + action insertConflictAction + updates []schema.ColumnReference +} + +// InsertConflictBuilder configures conflict behavior for INSERT statements. +type InsertConflictBuilder struct { + query *InsertQuery +} + +// Table sets the INSERT target table. +func (q *InsertQuery) Table(table schema.TableReference) *InsertQuery { + q.table = table.TableDef() + return q +} + +// Model sets a struct payload for the insert. +// Zero-valued fields for columns with schema defaults are omitted so the +// database default applies; use Set to override that behavior explicitly. +func (q *InsertQuery) Model(model any) *InsertQuery { + q.model = model + return q +} + +// Models sets multiple struct payloads for a bulk insert. +func (q *InsertQuery) Models(models any) *InsertQuery { + q.models = models + return q +} + +// Set adds an explicit column assignment. +func (q *InsertQuery) Set(column schema.ColumnReference, value any) *InsertQuery { + q.values = append(q.values, assignment{column: column, value: schema.ValueExpr{Value: value}}) + return q +} + +// Values appends explicit row value sets for a bulk insert. +func (q *InsertQuery) Values(rows ...map[schema.ColumnReference]any) *InsertQuery { + q.rows = append(q.rows, rows...) + return q +} + +// OnConflict starts an upsert clause for PostgreSQL and SQLite dialects. +func (q *InsertQuery) OnConflict(columns ...schema.ColumnReference) *InsertConflictBuilder { + q.conflict = &insertConflictClause{columns: columns} + return &InsertConflictBuilder{query: q} +} + +// DoNothing configures ON CONFLICT ... DO NOTHING. +func (b *InsertConflictBuilder) DoNothing() *InsertQuery { + b.query.conflict.action = insertConflictActionDoNothing + return b.query +} + +// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values. +func (b *InsertConflictBuilder) DoUpdateSet(columns ...schema.ColumnReference) *InsertQuery { + b.query.conflict.action = insertConflictActionDoUpdateSet + b.query.conflict.updates = columns + return b.query +} + +// Returning adds RETURNING expressions when supported by the dialect. +func (q *InsertQuery) Returning(exprs ...schema.Expression) *InsertQuery { + q.returning = append(q.returning, exprs...) + return q +} + +// ToSQL compiles the insert into SQL and args. +func (q *InsertQuery) ToSQL() (string, []any, error) { + rows, err := q.insertAssignments() + if err != nil { + return "", nil, err + } + + ctx := newCompileContext(q.dialect) + ctx.writeString("INSERT INTO ") + ctx.writeTableName(q.table) + ctx.writeString(" (") + for idx, item := range rows[0] { + if idx > 0 { + ctx.writeString(", ") + } + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + } + ctx.writeString(") VALUES ") + for rowIdx, row := range rows { + if rowIdx > 0 { + ctx.writeString(", ") + } + ctx.writeByte('(') + for idx, item := range row { + if idx > 0 { + ctx.writeString(", ") + } + if err := ctx.writeExpression(item.value); err != nil { + return "", nil, err + } + } + ctx.writeByte(')') + } + + if err := q.writeConflictClause(ctx); err != nil { + return "", nil, err + } + + if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { + return "", nil, err + } + + return ctx.String(), ctx.args, ctx.err +} + +func (q *InsertQuery) returningClause() returningClause { + return returningClause{ + feature: dialect.FeatureInsertReturning, + label: "insert", + } +} + +// Exec executes the INSERT query. +func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + query, args, err := q.ToSQL() + if err != nil { + return nil, err + } + + return q.runner.execContext(ctx, query, args...) +} + +// Scan executes an INSERT ... RETURNING query and scans one row into dest. +func (q *InsertQuery) Scan(ctx context.Context, dest any) error { + if q.runner == nil { + return ErrNoConnection + } + if len(q.returning) == 0 { + return errors.New("rain: insert scan requires RETURNING") + } + + query, args, err := q.ToSQL() + if err != nil { + return err + } + + rows, err := q.runner.queryContext(ctx, query, args...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + err = scanRows(rows, dest) + return err +} + +func (q *InsertQuery) insertAssignments() ([][]assignment, error) { + if q.table == nil { + return nil, errors.New("rain: insert query requires a table") + } + + sources := 0 + if q.model != nil || len(q.values) > 0 { + sources++ + } + if q.models != nil { + sources++ + } + if len(q.rows) > 0 { + sources++ + } + if sources == 0 { + return nil, errors.New("rain: insert query requires either explicit values or a model") + } + if sources > 1 { + return nil, errors.New("rain: insert query requires exactly one value source: Model/Set, Models, or Values") + } + + var rows [][]assignment + if q.models != nil { + modelRows, err := q.assignmentsFromModels() + if err != nil { + return nil, err + } + rows = append(rows, modelRows...) + } + if len(q.rows) > 0 { + valueRows, err := q.assignmentsFromRows() + if err != nil { + return nil, err + } + rows = append(rows, valueRows...) + } + if q.model != nil || len(q.values) > 0 { + singleRow, err := q.assignmentsFromModelAndSet() + if err != nil { + return nil, err + } + rows = append(rows, singleRow) + } + + if len(rows) == 0 { + return nil, errors.New("rain: insert query produced no values") + } + + if err := validateInsertRowShape(rows); err != nil { + return nil, err + } + + return rows, nil +} + +func (q *InsertQuery) assignmentsFromModelAndSet() ([]assignment, error) { + var ( + modelAssignments []assignment + err error + ) + if q.model != nil { + modelAssignments, err = assignmentsFromModel(q.table, q.model, true) + if err != nil { + return nil, err + } + } + + assignments, err := mergeAssignments(q.table, modelAssignments, q.values) + if err != nil { + return nil, err + } + if len(assignments) == 0 { + return nil, errors.New("rain: insert query produced no values") + } + + return assignments, nil +} + +func (q *InsertQuery) assignmentsFromModels() ([][]assignment, error) { + value := reflect.ValueOf(q.models) + for value.Kind() == reflect.Pointer { + if value.IsNil() { + return nil, errors.New("rain: insert models cannot be nil") + } + value = value.Elem() + } + if value.Kind() != reflect.Slice && value.Kind() != reflect.Array { + return nil, errors.New("rain: Models expects a slice or array") + } + if value.Len() == 0 { + return nil, errors.New("rain: Models expects at least one model") + } + + rows := make([][]assignment, 0, value.Len()) + for idx := range value.Len() { + assignments, err := assignmentsFromModel(q.table, value.Index(idx).Interface(), true) + if err != nil { + return nil, err + } + if len(assignments) == 0 { + return nil, fmt.Errorf("rain: insert row %d produced no values", idx+1) + } + rows = append(rows, assignments) + } + + return rows, nil +} + +func (q *InsertQuery) assignmentsFromRows() ([][]assignment, error) { + rows := make([][]assignment, 0, len(q.rows)) + for idx, row := range q.rows { + if len(row) == 0 { + return nil, fmt.Errorf("rain: insert row %d has no values", idx+1) + } + + overrides := make([]assignment, 0, len(row)) + for column, value := range row { + overrides = append(overrides, assignment{column: column, value: schema.ValueExpr{Value: value}}) + } + + assignments, err := mergeAssignments(q.table, nil, overrides) + if err != nil { + return nil, err + } + if len(assignments) == 0 { + return nil, fmt.Errorf("rain: insert row %d produced no values", idx+1) + } + rows = append(rows, assignments) + } + + return rows, nil +} + +func validateInsertRowShape(rows [][]assignment) error { + want := rows[0] + wantColumns := make([]string, 0, len(want)) + for _, item := range want { + wantColumns = append(wantColumns, item.column.ColumnDef().Name) + } + + for rowIdx := 1; rowIdx < len(rows); rowIdx++ { + row := rows[rowIdx] + if len(row) != len(want) { + return fmt.Errorf("rain: insert row %d targets %d columns, expected %d", rowIdx+1, len(row), len(want)) + } + for colIdx := range row { + if row[colIdx].column.ColumnDef().Name != wantColumns[colIdx] { + return fmt.Errorf("rain: insert row %d column mismatch at position %d: got %q, expected %q", rowIdx+1, colIdx+1, row[colIdx].column.ColumnDef().Name, wantColumns[colIdx]) + } + } + } + + return nil +} + +func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { + if q.conflict == nil { + return nil + } + if q.conflict.action == insertConflictActionNone { + return errors.New("rain: conflict action is required; call DoNothing() or DoUpdateSet(...)") + } + + if q.dialect.Name() != "postgres" && q.dialect.Name() != "sqlite" { + return fmt.Errorf("rain: insert conflict clauses are not implemented for %s dialect", q.dialect.Name()) + } + + if len(q.conflict.columns) == 0 { + return errors.New("rain: conflict clause requires at least one target column") + } + + ctx.writeString(" ON CONFLICT (") + for idx, col := range q.conflict.columns { + if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { + return err + } + if idx > 0 { + ctx.writeString(", ") + } + ctx.writeQuotedIdentifier(col.ColumnDef().Name) + } + ctx.writeByte(')') + + switch q.conflict.action { + case insertConflictActionDoNothing: + ctx.writeString(" DO NOTHING") + case insertConflictActionDoUpdateSet: + if len(q.conflict.updates) == 0 { + return errors.New("rain: conflict DO UPDATE requires at least one update column") + } + ctx.writeString(" DO UPDATE SET ") + for idx, col := range q.conflict.updates { + if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { + return err + } + if idx > 0 { + ctx.writeString(", ") + } + ctx.writeQuotedIdentifier(col.ColumnDef().Name) + ctx.writeString(" = EXCLUDED.") + ctx.writeQuotedIdentifier(col.ColumnDef().Name) + } + } + + return nil +} diff --git a/pkg/rain/query_insert_test.go b/pkg/rain/query_insert_test.go new file mode 100644 index 0000000..dfade54 --- /dev/null +++ b/pkg/rain/query_insert_test.go @@ -0,0 +1,251 @@ +package rain_test + +import ( + "reflect" + "strings" + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestInsertModelAndSetMergeToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Insert(). + Table(users). + Model(&userModel{Email: "alice@example.com", Name: "", Active: false}). + Set(users.Name, "Alice"). + Set(users.Active, false). + ToSQL() + if err != nil { + t.Fatalf("insert merge ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3)` + if sqlText != wantSQL { + t.Fatalf("unexpected merged insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 3 || args[0] != "alice@example.com" || args[1] != "Alice" || args[2] != false { + t.Fatalf("unexpected merged insert args: %#v", args) + } +} + +func TestInsertOmitDefaultBackedZeroValues(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Insert(). + Table(users). + Model(&userModel{Email: "alice@example.com"}). + ToSQL() + if err != nil { + t.Fatalf("insert default omission ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2)` + if sqlText != wantSQL { + t.Fatalf("unexpected default-omitting insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 || args[0] != "alice@example.com" || args[1] != "" { + t.Fatalf("unexpected default-omitting insert args: %#v", args) + } +} + +func TestInsertMultiRowModelsToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Insert(). + Table(users). + Models([]userModel{ + {Email: "alice@example.com", Name: "Alice", Active: true}, + {Email: "bob@example.com", Name: "Bob", Active: true}, + }). + Returning(users.ID). + ToSQL() + if err != nil { + t.Fatalf("insert multi model ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3), ($4, $5, $6) RETURNING "users"."id"` + if sqlText != wantSQL { + t.Fatalf("unexpected multi model insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + wantArgs := []any{"alice@example.com", "Alice", true, "bob@example.com", "Bob", true} + if !reflect.DeepEqual(args, wantArgs) { + t.Fatalf("unexpected multi model insert args: %#v", args) + } +} + +func TestInsertMultiRowValuesToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Insert(). + Table(users). + Values( + map[schema.ColumnReference]any{users.Email: "alice@example.com", users.Name: "Alice", users.Active: true}, + map[schema.ColumnReference]any{users.Email: "bob@example.com", users.Name: "Bob", users.Active: false}, + ). + ToSQL() + if err != nil { + t.Fatalf("insert multi values ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3), ($4, $5, $6)` + if sqlText != wantSQL { + t.Fatalf("unexpected multi values insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + wantArgs := []any{"alice@example.com", "Alice", true, "bob@example.com", "Bob", false} + if !reflect.DeepEqual(args, wantArgs) { + t.Fatalf("unexpected multi values insert args: %#v", args) + } +} + +func TestInsertMultiRowColumnMismatchReturnsError(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + _, _, err = db.Insert(). + Table(users). + Models([]userModel{ + {Email: "alice@example.com", Name: "Alice", Active: true}, + {Email: "bob@example.com", Name: "", Active: false}, + }). + ToSQL() + if err == nil || !strings.Contains(err.Error(), "targets 2 columns, expected 3") { + t.Fatalf("expected column mismatch error, got %v", err) + } +} + +func TestInsertOnConflictPostgres(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("do nothing", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(users.Email). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("insert on conflict do nothing ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected do nothing SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 { + t.Fatalf("unexpected do nothing args: %#v", args) + } + }) + + t.Run("do update set", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + Set(users.Active, true). + OnConflict(users.Email). + DoUpdateSet(users.Name, users.Active). + ToSQL() + if err != nil { + t.Fatalf("insert on conflict do update ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = EXCLUDED."active"` + if sqlText != wantSQL { + t.Fatalf("unexpected do update SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 3 { + t.Fatalf("unexpected do update args: %#v", args) + } + }) +} + +func TestInsertOnConflictSQLite(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + Set(users.Active, true). + OnConflict(users.Email). + DoUpdateSet(users.Name, users.Active). + ToSQL() + if err != nil { + t.Fatalf("insert on conflict sqlite ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES (?, ?, ?) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = EXCLUDED."active"` + if sqlText != wantSQL { + t.Fatalf("unexpected sqlite do update SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + wantArgs := []any{"alice@example.com", "Alice", true} + if !reflect.DeepEqual(args, wantArgs) { + t.Fatalf("unexpected sqlite do update args: %#v", args) + } +} + +func TestInsertOnConflictUnsupportedDialectReturnsError(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("mysql") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + _, _, err = db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(users.Email). + DoUpdateSet(users.Name). + ToSQL() + if err == nil || !strings.Contains(err.Error(), "not implemented") { + t.Fatalf("expected unsupported dialect error, got %v", err) + } +} diff --git a/pkg/rain/query_internal_test.go b/pkg/rain/query_internal_test.go index 43840e8..1b10f3e 100644 --- a/pkg/rain/query_internal_test.go +++ b/pkg/rain/query_internal_test.go @@ -3,15 +3,10 @@ package rain import ( "context" "database/sql" - "errors" - "fmt" "path/filepath" - "reflect" - "strings" "testing" "time" - "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" _ "modernc.org/sqlite" ) @@ -162,791 +157,3 @@ func createInternalQuerySchema(t *testing.T, ctx context.Context, db *DB) { } } } - -func TestQueryExecutionPaths(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, posts := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - - insert := db.Insert(). - Table(users). - Model(&internalInsertModel{Email: "alice@example.com", Name: "Alice"}) - result, err := insert.Exec(ctx) - if err != nil { - t.Fatalf("insert exec failed: %v", err) - } - insertedID, err := result.LastInsertId() - if err != nil { - t.Fatalf("last insert id failed: %v", err) - } - - if _, err := db.Insert(). - Table(posts). - Set(posts.UserID, insertedID). - Set(posts.Title, "Hello"). - Exec(ctx); err != nil { - t.Fatalf("insert post failed: %v", err) - } - - count, err := db.Select().Table(users).Where(users.Active.Eq(true)).Count(ctx) - if err != nil { - t.Fatalf("count failed: %v", err) - } - if count != 1 { - t.Fatalf("expected count 1, got %d", count) - } - - exists, err := db.Select().Table(users).Where(users.Email.Eq("alice@example.com")).Exists(ctx) - if err != nil { - t.Fatalf("exists failed: %v", err) - } - if !exists { - t.Fatalf("expected row to exist") - } - - var row internalUserRow - if err := db.Select(). - Table(users). - Where(users.ID.Eq(insertedID)). - Scan(ctx, &row); err != nil { - t.Fatalf("scan select failed: %v", err) - } - if row.Email != "alice@example.com" || row.Name != "Alice" { - t.Fatalf("unexpected row: %#v", row) - } - - if err := db.Update(). - Table(users). - Set(users.Name, "Alice Updated"). - Where(users.ID.Eq(insertedID)). - Returning(users.ID, users.Name). - Scan(ctx, &row); err != nil { - t.Fatalf("update returning scan failed: %v", err) - } - if row.ID != insertedID || row.Name != "Alice Updated" { - t.Fatalf("unexpected updated row: %#v", row) - } - - if err := db.Delete(). - Table(users). - Where(users.ID.Eq(insertedID)). - Returning(users.ID, users.Email). - Scan(ctx, &row); err != nil { - t.Fatalf("delete returning scan failed: %v", err) - } - if row.ID != insertedID || row.Email != "alice@example.com" { - t.Fatalf("unexpected deleted row: %#v", row) - } - - if _, err := db.Select().Table(users).Where(users.ID.Eq(insertedID)).Count(ctx); !errors.Is(err, sql.ErrNoRows) && err != nil { - t.Fatalf("unexpected count error after delete: %v", err) - } -} - -func TestSelectQueryCacheHitMissExpiryAndBypass(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, _ := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - - cache := NewMemoryQueryCache() - cache.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) } - db.WithQueryCache(cache) - - if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "cache@example.com", Name: "Cache"}).Exec(ctx); err != nil { - t.Fatalf("insert user: %v", err) - } - - counter := &countingRunner{base: db} - q := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: cache}). - Table(users). - Where(users.Email.Eq("cache@example.com")). - Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) - - var first []internalUserRow - if err := q.Scan(ctx, &first); err != nil { - t.Fatalf("first scan: %v", err) - } - if counter.queryCount != 1 { - t.Fatalf("expected first query to hit DB once, got %d", counter.queryCount) - } - - var second []internalUserRow - if err := q.Scan(ctx, &second); err != nil { - t.Fatalf("second scan: %v", err) - } - if counter.queryCount != 1 { - t.Fatalf("expected second query to hit cache, got query count %d", counter.queryCount) - } - if !reflect.DeepEqual(first, second) { - t.Fatalf("cached scan mismatch:\nfirst=%#v\nsecond=%#v", first, second) - } - - cache.now = func() time.Time { return time.Date(2026, 3, 29, 12, 2, 0, 0, time.UTC) } - var third []internalUserRow - if err := q.Scan(ctx, &third); err != nil { - t.Fatalf("third scan after expiry: %v", err) - } - if counter.queryCount != 2 { - t.Fatalf("expected expiry to force DB query, got %d", counter.queryCount) - } - - var bypassed []internalUserRow - if err := q.Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}, Bypass: true}).Scan(ctx, &bypassed); err != nil { - t.Fatalf("bypass scan: %v", err) - } - if counter.queryCount != 3 { - t.Fatalf("expected bypass to force DB query, got %d", counter.queryCount) - } -} - -func TestSelectQueryCacheArgsAndManualInvalidation(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, _ := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - db.WithQueryCache(NewMemoryQueryCache()) - - for _, item := range []internalInsertModel{ - {Email: "alice@example.com", Name: "Alice"}, - {Email: "bob@example.com", Name: "Bob"}, - } { - if _, err := db.Insert().Table(users).Model(&item).Exec(ctx); err != nil { - t.Fatalf("insert user %s: %v", item.Email, err) - } - } - - counter := &countingRunner{base: db} - queryFor := func(email string) *SelectQuery { - return (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). - Table(users). - Where(users.Email.Eq(email)). - Cache(QueryCacheOptions{TTL: 5 * time.Minute, Tags: []string{"users"}}) - } - - var alice []internalUserRow - if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { - t.Fatalf("alice query first run: %v", err) - } - if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { - t.Fatalf("alice query second run: %v", err) - } - if counter.queryCount != 1 { - t.Fatalf("expected repeated identical args to hit cache, query count %d", counter.queryCount) - } - - var bob []internalUserRow - if err := queryFor("bob@example.com").Scan(ctx, &bob); err != nil { - t.Fatalf("bob query first run: %v", err) - } - if counter.queryCount != 2 { - t.Fatalf("expected different args to use different entry, query count %d", counter.queryCount) - } - - if err := db.InvalidateQueryCache(ctx, "users"); err != nil { - t.Fatalf("invalidate tag: %v", err) - } - if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { - t.Fatalf("alice query after invalidation: %v", err) - } - if counter.queryCount != 3 { - t.Fatalf("expected invalidation miss, query count %d", counter.queryCount) - } -} - -func TestSelectQueryCacheDisabledKeepsNormalBehavior(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, _ := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "nocache@example.com", Name: "No Cache"}).Exec(ctx); err != nil { - t.Fatalf("insert user: %v", err) - } - - counter := &countingRunner{base: db} - q := (&SelectQuery{runner: counter, dialect: db.Dialect()}). - Table(users). - Where(users.Email.Eq("nocache@example.com")). - Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) - - var rows []internalUserRow - if err := q.Scan(ctx, &rows); err != nil { - t.Fatalf("first uncached scan: %v", err) - } - if err := q.Scan(ctx, &rows); err != nil { - t.Fatalf("second uncached scan: %v", err) - } - if counter.queryCount != 2 { - t.Fatalf("expected uncached behavior without backend, query count %d", counter.queryCount) - } -} - -func TestBuildQueryCacheKeyIsStableForEquivalentArgs(t *testing.T) { - t.Parallel() - - opts := normalizeQueryCacheOptions(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users", "lookup"}, Namespace: "by-id"}) - keyOne, err := buildQueryCacheKey("sqlite", "SELECT * FROM users WHERE id = ?", []any{int64(1)}, nil, opts) - if err != nil { - t.Fatalf("build key one: %v", err) - } - keyTwo, err := buildQueryCacheKey("sqlite", "SELECT * FROM users WHERE id = ?", []any{int64(1)}, nil, opts) - if err != nil { - t.Fatalf("build key two: %v", err) - } - if keyOne != keyTwo { - t.Fatalf("expected stable key, got %q and %q", keyOne, keyTwo) - } -} - -func TestSelectAggregateCacheForCountAndExists(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, _ := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - db.WithQueryCache(NewMemoryQueryCache()) - - if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "agg@example.com", Name: "Agg"}).Exec(ctx); err != nil { - t.Fatalf("insert user: %v", err) - } - - counter := &countingRunner{base: db} - query := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). - Table(users). - Where(users.Email.Eq("agg@example.com")). - Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) - - count, err := query.Count(ctx) - if err != nil { - t.Fatalf("count first: %v", err) - } - if count != 1 { - t.Fatalf("expected count 1, got %d", count) - } - if _, err := query.Count(ctx); err != nil { - t.Fatalf("count second: %v", err) - } - if counter.queryCount != 1 { - t.Fatalf("expected second count to hit cache, query count %d", counter.queryCount) - } - - exists, err := query.Exists(ctx) - if err != nil { - t.Fatalf("exists first: %v", err) - } - if !exists { - t.Fatalf("expected exists=true") - } - if _, err := query.Exists(ctx); err != nil { - t.Fatalf("exists second: %v", err) - } - if counter.queryCount != 2 { - t.Fatalf("expected second exists to hit cache, query count %d", counter.queryCount) - } -} - -func TestQueryBuilderAndHelperErrors(t *testing.T) { - t.Parallel() - - db, err := OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, posts := defineInternalQueryTables() - - if _, _, err := db.Select().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { - t.Fatalf("expected select table error, got %v", err) - } - selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}} - if err := selectNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected select scan ErrNoConnection, got %v", err) - } - if _, err := (&SelectQuery{dialect: db.Dialect()}).Count(context.Background()); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected select count ErrNoConnection, got %v", err) - } - if _, err := (&SelectQuery{dialect: db.Dialect()}).Exists(context.Background()); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected select exists ErrNoConnection, got %v", err) - } - - if _, _, err := db.Insert().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { - t.Fatalf("expected insert table error, got %v", err) - } - if _, _, err := db.Insert().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires either explicit values or a model") { - t.Fatalf("expected insert values error, got %v", err) - } - if err := (&InsertQuery{dialect: db.Dialect(), table: users.TableDef(), returning: []schema.Expression{users.ID}}).Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected insert scan ErrNoConnection, got %v", err) - } - - insertNoRunner := &InsertQuery{dialect: db.Dialect(), table: users.TableDef(), returning: []schema.Expression{users.ID}} - if err := insertNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected insert returning scan ErrNoConnection, got %v", err) - } - insertNoReturning := &InsertQuery{runner: db, dialect: db.Dialect(), table: users.TableDef()} - if err := insertNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { - t.Fatalf("expected insert returning error, got %v", err) - } - - if _, _, err := db.Update().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { - t.Fatalf("expected update table error, got %v", err) - } - if _, _, err := db.Update().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one assignment") { - t.Fatalf("expected update assignment error, got %v", err) - } - if _, _, err := db.Update().Table(users).Set(users.Name, "Alice").ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one WHERE predicate") { - t.Fatalf("expected update WHERE guard error, got %v", err) - } - if _, _, err := db.Update().Table(users).Set(users.Name, "Alice").Unbounded().ToSQL(); err != nil { - t.Fatalf("expected unbounded update to succeed, got %v", err) - } - updateNoRunner := &UpdateQuery{dialect: db.Dialect(), table: users.TableDef(), values: []assignment{{column: users.Name, value: schema.ValueExpr{Value: "Alice"}}}, returning: []schema.Expression{users.ID}} - if err := updateNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected update scan ErrNoConnection, got %v", err) - } - updateNoReturning := &UpdateQuery{runner: db, dialect: db.Dialect(), table: users.TableDef(), values: []assignment{{column: users.Name, value: schema.ValueExpr{Value: "Alice"}}}} - if err := updateNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { - t.Fatalf("expected update returning error, got %v", err) - } - - if _, _, err := db.Delete().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { - t.Fatalf("expected delete table error, got %v", err) - } - if _, _, err := db.Delete().Table(users).ToSQL(); err == nil || !strings.Contains(err.Error(), "requires at least one WHERE predicate") { - t.Fatalf("expected delete WHERE guard error, got %v", err) - } - if _, _, err := db.Delete().Table(users).Unbounded().ToSQL(); err != nil { - t.Fatalf("expected unbounded delete to succeed, got %v", err) - } - deleteNoRunner := &DeleteQuery{dialect: db.Dialect(), table: users.TableDef(), returning: []schema.Expression{users.ID}} - if err := deleteNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { - t.Fatalf("expected delete scan ErrNoConnection, got %v", err) - } - deleteNoReturning := &DeleteQuery{runner: db, dialect: db.Dialect(), table: users.TableDef()} - if err := deleteNoReturning.Scan(context.Background(), &internalUserRow{}); err == nil || !strings.Contains(err.Error(), "requires RETURNING") { - t.Fatalf("expected delete returning error, got %v", err) - } - - leftJoinSQL, _, err := db.Select(). - Table(users). - Column(users.ID). - LeftJoin(posts, users.ID.EqCol(posts.UserID)). - Where(users.Active.Eq(true)). - Where(users.Email.Eq("alice@example.com")). - OrderBy(users.ID.Asc(), users.Email.Desc()). - Limit(5). - Offset(10). - ToSQL() - if err != nil { - t.Fatalf("left join ToSQL failed: %v", err) - } - if !strings.Contains(leftJoinSQL, "LEFT JOIN") || !strings.Contains(leftJoinSQL, "OFFSET 10") { - t.Fatalf("unexpected left join SQL: %s", leftJoinSQL) - } -} - -func TestSelectWithRelations(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, posts := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - - aliceResult, err := db.Insert().Table(users).Set(users.Email, "alice@example.com").Set(users.Name, "Alice").Exec(ctx) - if err != nil { - t.Fatalf("insert alice failed: %v", err) - } - aliceID, err := aliceResult.LastInsertId() - if err != nil { - t.Fatalf("alice last insert id failed: %v", err) - } - bobResult, err := db.Insert().Table(users).Set(users.Email, "bob@example.com").Set(users.Name, "Bob").Exec(ctx) - if err != nil { - t.Fatalf("insert bob failed: %v", err) - } - bobID, err := bobResult.LastInsertId() - if err != nil { - t.Fatalf("bob last insert id failed: %v", err) - } - - if _, err := db.Insert().Table(posts).Set(posts.UserID, aliceID).Set(posts.Title, "Hello from Alice").Exec(ctx); err != nil { - t.Fatalf("insert alice post failed: %v", err) - } - if _, err := db.Insert().Table(posts).Set(posts.UserID, aliceID).Set(posts.Title, "Second Alice Post").Exec(ctx); err != nil { - t.Fatalf("insert alice post 2 failed: %v", err) - } - if _, err := db.Insert().Table(posts).Set(posts.UserID, bobID).Set(posts.Title, "Bob Post").Exec(ctx); err != nil { - t.Fatalf("insert bob post failed: %v", err) - } - - var postsWithAuthor []internalPostWithAuthorRow - if err := db.Select(). - Table(posts). - Where(posts.Title.Eq("Hello from Alice")). - WithRelations("author"). - Scan(ctx, &postsWithAuthor); err != nil { - t.Fatalf("select with author relation failed: %v", err) - } - if len(postsWithAuthor) != 1 { - t.Fatalf("expected one post row, got %d", len(postsWithAuthor)) - } - if postsWithAuthor[0].Author.Email != "alice@example.com" { - t.Fatalf("expected author alice@example.com, got %#v", postsWithAuthor[0].Author) - } - - var postsWithAuthorPtr []internalPostWithAuthorPointerRow - if err := db.Select(). - Table(posts). - Where(posts.Title.Eq("Hello from Alice")). - WithRelations("author"). - Scan(ctx, &postsWithAuthorPtr); err != nil { - t.Fatalf("select with pointer author relation failed: %v", err) - } - if len(postsWithAuthorPtr) != 1 || postsWithAuthorPtr[0].Author == nil || postsWithAuthorPtr[0].Author.Email != "alice@example.com" { - t.Fatalf("expected pointer author alice@example.com, got %#v", postsWithAuthorPtr) - } - - var usersWithPosts []internalUserWithPostsRow - if err := db.Select(). - Table(users). - Where(users.ID.Eq(aliceID)). - WithRelations("posts"). - Scan(ctx, &usersWithPosts); err != nil { - t.Fatalf("select with posts relation failed: %v", err) - } - if len(usersWithPosts) != 1 { - t.Fatalf("expected one user row, got %d", len(usersWithPosts)) - } - if len(usersWithPosts[0].Posts) != 2 { - t.Fatalf("expected two posts for alice, got %d", len(usersWithPosts[0].Posts)) - } - - var usersWithPostPointers []internalUserWithPostPointersRow - if err := db.Select(). - Table(users). - Where(users.ID.Eq(aliceID)). - WithRelations("posts"). - Scan(ctx, &usersWithPostPointers); err != nil { - t.Fatalf("select with pointer posts relation failed: %v", err) - } - if len(usersWithPostPointers) != 1 || len(usersWithPostPointers[0].Posts) != 2 || usersWithPostPointers[0].Posts[0] == nil { - t.Fatalf("expected pointer posts relation to populate, got %#v", usersWithPostPointers) - } - - var nested []internalUserWithPostsAndAuthorsRow - if err := db.Select(). - Table(users). - Where(users.ID.Eq(aliceID)). - WithRelations("posts.author"). - Scan(ctx, &nested); err != nil { - t.Fatalf("select with nested relations failed: %v", err) - } - if len(nested) != 1 || len(nested[0].Posts) != 2 { - t.Fatalf("expected nested relation rows, got %#v", nested) - } - for _, post := range nested[0].Posts { - if post.Author == nil || post.Author.Email != "alice@example.com" { - t.Fatalf("expected nested author alice@example.com, got %#v", post.Author) - } - } - - var bad []internalUserRow - err = db.Select().Table(users).WithRelations("does_not_exist").Scan(ctx, &bad) - if err == nil || !strings.Contains(err.Error(), "unknown relation") { - t.Fatalf("expected unknown relation error, got %v", err) - } - - var empty []internalUserRow - err = db.Select().Table(users).Where(users.ID.Eq(-999)).WithRelations("does_not_exist").Scan(ctx, &empty) - if err == nil || !strings.Contains(err.Error(), "unknown relation") { - t.Fatalf("expected unknown relation error for empty result, got %v", err) - } - - err = db.Select().Table(users).WithRelations("posts.does_not_exist").Scan(ctx, &bad) - if err == nil || !strings.Contains(err.Error(), "unknown relation") { - t.Fatalf("expected unknown nested relation error, got %v", err) - } -} - -func TestRelationLoadingBatchesQueriesPerRelation(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, posts := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - - aliceResult, err := db.Insert().Table(users).Set(users.Email, "alice@example.com").Set(users.Name, "Alice").Exec(ctx) - if err != nil { - t.Fatalf("insert alice failed: %v", err) - } - aliceID, err := aliceResult.LastInsertId() - if err != nil { - t.Fatalf("alice last insert id failed: %v", err) - } - bobResult, err := db.Insert().Table(users).Set(users.Email, "bob@example.com").Set(users.Name, "Bob").Exec(ctx) - if err != nil { - t.Fatalf("insert bob failed: %v", err) - } - bobID, err := bobResult.LastInsertId() - if err != nil { - t.Fatalf("bob last insert id failed: %v", err) - } - - for _, row := range []struct { - userID int64 - title string - }{ - {userID: aliceID, title: "Alice 1"}, - {userID: aliceID, title: "Alice 2"}, - {userID: bobID, title: "Bob 1"}, - } { - if _, err := db.Insert().Table(posts).Set(posts.UserID, row.userID).Set(posts.Title, row.title).Exec(ctx); err != nil { - t.Fatalf("insert post %q failed: %v", row.title, err) - } - } - - runner := &countingRunner{base: db} - query := &SelectQuery{runner: runner, dialect: db.Dialect()} - - var rows []internalUserWithPostsRow - if err := query.Table(users).WithRelations("posts").Scan(ctx, &rows); err != nil { - t.Fatalf("relation batch scan failed: %v", err) - } - if len(rows) != 2 { - t.Fatalf("expected 2 users, got %d", len(rows)) - } - if runner.queryCount != 2 { - t.Fatalf("expected 2 query executions (base + relation batch), got %d", runner.queryCount) - } - if len(runner.lastQueries) != 2 || !strings.Contains(runner.lastQueries[1], `IN (`) { - t.Fatalf("expected relation load query with IN clause, got %#v", runner.lastQueries) - } -} - -func TestRelationLoadingChunksLargeINQueries(t *testing.T) { - t.Parallel() - - ctx := context.Background() - db := openInternalQueryDB(t) - users, posts := defineInternalQueryTables() - createInternalQuerySchema(t, ctx, db) - - for idx := 0; idx < relationBatchSize+5; idx++ { - result, err := db.Insert(). - Table(users). - Set(users.Email, fmt.Sprintf("user-%d@example.com", idx)). - Set(users.Name, fmt.Sprintf("User %d", idx)). - Exec(ctx) - if err != nil { - t.Fatalf("insert user %d failed: %v", idx, err) - } - userID, err := result.LastInsertId() - if err != nil { - t.Fatalf("last insert id %d failed: %v", idx, err) - } - if _, err := db.Insert().Table(posts).Set(posts.UserID, userID).Set(posts.Title, fmt.Sprintf("Post %d", idx)).Exec(ctx); err != nil { - t.Fatalf("insert post %d failed: %v", idx, err) - } - } - - runner := &countingRunner{base: db} - query := &SelectQuery{runner: runner, dialect: db.Dialect()} - - var rows []internalUserWithPostsRow - if err := query.Table(users).WithRelations("posts").Scan(ctx, &rows); err != nil { - t.Fatalf("chunked relation load failed: %v", err) - } - if len(rows) != relationBatchSize+5 { - t.Fatalf("expected %d users, got %d", relationBatchSize+5, len(rows)) - } - if runner.queryCount != 3 { - t.Fatalf("expected 3 query executions (base + 2 relation batches), got %d", runner.queryCount) - } -} - -func TestRelationElementTypeFromTypeHandlesPointerSlices(t *testing.T) { - t.Parallel() - - users, _ := defineInternalQueryTables() - db, err := OpenDialect("sqlite") - if err != nil { - t.Fatalf("OpenDialect(sqlite): %v", err) - } - - parentsType := reflect.TypeOf([]*internalUserWithPostPointersRow{}) - parentStructType, err := sliceParentStructType(parentsType) - if err != nil { - t.Fatalf("sliceParentStructType failed: %v", err) - } - - relatedType, err := db.Select().relationElementTypeFromType(parentStructType, users.TableDef().Relations[0]) - if err != nil { - t.Fatalf("relationElementTypeFromType failed: %v", err) - } - if relatedType != reflect.TypeOf(internalPostOnlyRow{}) { - t.Fatalf("expected related type %v, got %v", reflect.TypeOf(internalPostOnlyRow{}), relatedType) - } -} - -func TestCompileContextAndAssignmentsHelpers(t *testing.T) { - t.Parallel() - - users, posts := defineInternalQueryTables() - - ctx := newCompileContext(dialectForTest(t, "postgres")) - if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil { - t.Fatalf("writeRaw without args failed: %v", err) - } - if ctx.String() != "NOW()" { - t.Fatalf("unexpected raw SQL: %s", ctx.String()) - } - - ctx = newCompileContext(dialectForTest(t, "postgres")) - if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil { - t.Fatalf("writeRaw placeholders failed: %v", err) - } - if ctx.String() != "$1 + $2" { - t.Fatalf("unexpected placeholder SQL: %s", ctx.String()) - } - - if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") { - t.Fatalf("expected raw unused args error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") { - t.Fatalf("expected raw placeholder mismatch error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") { - t.Fatalf("expected empty IN error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") { - t.Fatalf("expected unsupported expression error, got %v", err) - } - - merged, err := mergeAssignments(users.TableDef(), - []assignment{ - {column: users.Email, value: schema.ValueExpr{Value: "base@example.com"}}, - {column: users.Name, value: schema.ValueExpr{Value: "Base"}}, - }, - []assignment{ - {column: users.Name, value: schema.ValueExpr{Value: "Override"}}, - {column: users.Active, value: schema.ValueExpr{Value: true}}, - }, - ) - if err != nil { - t.Fatalf("mergeAssignments failed: %v", err) - } - if len(merged) != 3 { - t.Fatalf("expected 3 merged assignments, got %d", len(merged)) - } - if merged[1].column.ColumnDef().Name != "name" || merged[2].column.ColumnDef().Name != "active" { - t.Fatalf("unexpected merged order: %#v", merged) - } - if merged[1].value.(schema.ValueExpr).Value != "Override" { - t.Fatalf("expected override assignment to win, got %#v", merged[1].value) - } - - if _, err := mergeAssignments(users.TableDef(), nil, []assignment{{column: posts.Title, value: schema.ValueExpr{Value: "bad"}}}); err == nil || !strings.Contains(err.Error(), "belongs to table posts") { - t.Fatalf("expected foreign table assignment error, got %v", err) - } - - ghostColumn := schema.Ref(&schema.ColumnDef{Table: users.TableDef(), Name: "ghost"}) - if _, err := mergeAssignments(users.TableDef(), nil, []assignment{{column: ghostColumn, value: schema.ValueExpr{Value: "bad"}}}); err == nil || !strings.Contains(err.Error(), "unknown column ghost") { - t.Fatalf("expected unknown column assignment error, got %v", err) - } - - if got := joinPredicates([]schema.Predicate{users.Active.Eq(true)}); got != users.Active.Eq(true) { - t.Fatalf("expected single predicate to pass through") - } - if _, ok := joinPredicates([]schema.Predicate{users.Active.Eq(true), users.Email.Eq("alice@example.com")}).(schema.LogicalExpr); !ok { - t.Fatalf("expected multiple predicates to produce logical expression") - } -} - -func TestModelAssignmentAndValueHelpers(t *testing.T) { - t.Parallel() - - users, _ := defineInternalQueryTables() - - nickname := "ally" - assignments, err := assignmentsFromModel(users.TableDef(), &internalInsertModel{ - ID: 0, - Email: "alice@example.com", - Name: "", - Active: false, - Nickname: &nickname, - }, true) - if err != nil { - t.Fatalf("assignmentsFromModel failed: %v", err) - } - if len(assignments) != 2 { - t.Fatalf("expected 2 assignments after skipping default-backed zero values, got %d", len(assignments)) - } - if assignments[0].column.ColumnDef().Name != "email" || assignments[1].column.ColumnDef().Name != "nickname" { - t.Fatalf("unexpected assignments: %#v", assignments) - } - - assignments, err = assignmentsFromModel(users.TableDef(), &internalInsertModel{ - ID: 42, - Email: "bob@example.com", - Name: "Bob", - Active: true, - }, false) - if err != nil { - t.Fatalf("assignmentsFromModel skipAuto=false failed: %v", err) - } - if len(assignments) != 4 { - t.Fatalf("expected 4 assignments when auto id is retained, got %d", len(assignments)) - } - - if _, include := fieldValueForInsert(users.ID.ColumnDef(), reflect.ValueOf(int64(0)), true); include { - t.Fatalf("expected zero auto-increment id to be skipped") - } - if _, include := fieldValueForInsert(users.Name.ColumnDef(), reflect.ValueOf(""), true); include { - t.Fatalf("expected default-backed zero string to be skipped") - } - if _, include := fieldValueForInsert(users.Nickname.ColumnDef(), reflect.ValueOf((*string)(nil)), true); include { - t.Fatalf("expected nil pointer to be skipped") - } - if value, include := fieldValueForInsert(users.Name.ColumnDef(), reflect.ValueOf("Alice"), true); !include || value != "Alice" { - t.Fatalf("expected non-zero string to be included, got %#v include=%v", value, include) - } - - type pointerHolder struct { - Value **string - } - var nilStringPtr *string - holder := pointerHolder{Value: &nilStringPtr} - if _, isNil := dereferenceValue(reflect.ValueOf(holder).Field(0)); !isNil { - t.Fatalf("expected nested nil pointer to be detected") - } - - name := "Alice" - namePtr := &name - holder = pointerHolder{Value: &namePtr} - resolved, isNil := dereferenceValue(reflect.ValueOf(holder).Field(0)) - if isNil || resolved.Kind() != reflect.String || resolved.String() != "Alice" { - t.Fatalf("unexpected dereference result: %#v isNil=%v", resolved, isNil) - } -} - -func dialectForTest(t *testing.T, driver string) dialect.Dialect { - t.Helper() - - db, err := OpenDialect(driver) - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - - return db.Dialect() -} diff --git a/pkg/rain/query_model_assignments.go b/pkg/rain/query_model_assignments.go new file mode 100644 index 0000000..6b1334d --- /dev/null +++ b/pkg/rain/query_model_assignments.go @@ -0,0 +1,116 @@ +package rain + +import ( + "fmt" + "reflect" + "slices" + "strings" + + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func assignmentsFromModel(table *schema.TableDef, model any, skipAuto bool) ([]assignment, error) { + meta, value, err := lookupModelMeta(model) + if err != nil { + return nil, err + } + + assignments := make([]assignment, 0, len(table.Columns)) + for _, column := range table.Columns { + field, ok := meta.byColumn[column.Name] + if !ok { + continue + } + + fieldValue := value.FieldByIndex(field.index) + resolvedValue, include := fieldValueForInsert(column, fieldValue, skipAuto) + if !include { + continue + } + + assignments = append(assignments, assignment{ + column: schema.Ref(column), + value: schema.ValueExpr{Value: resolvedValue}, + }) + } + + return assignments, nil +} + +func mergeAssignments(table *schema.TableDef, base, overrides []assignment) ([]assignment, error) { + ordered := make([]assignment, 0, len(table.Columns)) + assignmentsByName := make(map[string]assignment, len(table.Columns)) + + for _, item := range base { + if err := validateAssignmentTarget(table, item); err != nil { + return nil, err + } + assignmentsByName[item.column.ColumnDef().Name] = item + } + for _, item := range overrides { + if err := validateAssignmentTarget(table, item); err != nil { + return nil, err + } + assignmentsByName[item.column.ColumnDef().Name] = item + } + + for _, column := range table.Columns { + item, ok := assignmentsByName[column.Name] + if !ok { + continue + } + ordered = append(ordered, item) + delete(assignmentsByName, column.Name) + } + + if len(assignmentsByName) > 0 { + names := make([]string, 0, len(assignmentsByName)) + for name := range assignmentsByName { + names = append(names, name) + } + slices.Sort(names) + return nil, fmt.Errorf("rain: insert assignments contain unknown target columns: %s", strings.Join(names, ", ")) + } + + return ordered, nil +} + +func validateAssignmentTarget(table *schema.TableDef, item assignment) error { + column := item.column.ColumnDef() + if column.Table.Name != table.Name { + return fmt.Errorf("rain: column %s belongs to table %s, not %s", column.Name, column.Table.Name, table.Name) + } + if _, ok := table.ColumnByName(column.Name); !ok { + return fmt.Errorf("rain: unknown column %s on table %s", column.Name, table.Name) + } + + return nil +} + +func fieldValueForInsert(column *schema.ColumnDef, fieldValue reflect.Value, skipAuto bool) (any, bool) { + resolved, isNil := dereferenceValue(fieldValue) + if isNil { + return nil, false + } + + if skipAuto && column.AutoIncrement && resolved.IsZero() { + return nil, false + } + if column.HasDefault && resolved.IsZero() { + return nil, false + } + + return resolved.Interface(), true +} + +func dereferenceValue(value reflect.Value) (reflect.Value, bool) { + current := value + for current.Kind() == reflect.Pointer { + if current.IsNil() { + return reflect.Value{}, true + } + current = current.Elem() + } + + return current, false +} diff --git a/pkg/rain/query_runtime_internal_test.go b/pkg/rain/query_runtime_internal_test.go new file mode 100644 index 0000000..ec66d5a --- /dev/null +++ b/pkg/rain/query_runtime_internal_test.go @@ -0,0 +1,549 @@ +package rain + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "testing" + "time" +) + +func TestQueryExecutionPaths(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, posts := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + + insert := db.Insert(). + Table(users). + Model(&internalInsertModel{Email: "alice@example.com", Name: "Alice"}) + result, err := insert.Exec(ctx) + if err != nil { + t.Fatalf("insert exec failed: %v", err) + } + insertedID, err := result.LastInsertId() + if err != nil { + t.Fatalf("last insert id failed: %v", err) + } + + if _, err := db.Insert(). + Table(posts). + Set(posts.UserID, insertedID). + Set(posts.Title, "Hello"). + Exec(ctx); err != nil { + t.Fatalf("insert post failed: %v", err) + } + + count, err := db.Select().Table(users).Where(users.Active.Eq(true)).Count(ctx) + if err != nil { + t.Fatalf("count failed: %v", err) + } + if count != 1 { + t.Fatalf("expected count 1, got %d", count) + } + + exists, err := db.Select().Table(users).Where(users.Email.Eq("alice@example.com")).Exists(ctx) + if err != nil { + t.Fatalf("exists failed: %v", err) + } + if !exists { + t.Fatalf("expected row to exist") + } + + var row internalUserRow + if err := db.Select(). + Table(users). + Where(users.ID.Eq(insertedID)). + Scan(ctx, &row); err != nil { + t.Fatalf("scan select failed: %v", err) + } + if row.Email != "alice@example.com" || row.Name != "Alice" { + t.Fatalf("unexpected row: %#v", row) + } + + if err := db.Update(). + Table(users). + Set(users.Name, "Alice Updated"). + Where(users.ID.Eq(insertedID)). + Returning(users.ID, users.Name). + Scan(ctx, &row); err != nil { + t.Fatalf("update returning scan failed: %v", err) + } + if row.ID != insertedID || row.Name != "Alice Updated" { + t.Fatalf("unexpected updated row: %#v", row) + } + + if err := db.Delete(). + Table(users). + Where(users.ID.Eq(insertedID)). + Returning(users.ID, users.Email). + Scan(ctx, &row); err != nil { + t.Fatalf("delete returning scan failed: %v", err) + } + if row.ID != insertedID || row.Email != "alice@example.com" { + t.Fatalf("unexpected deleted row: %#v", row) + } + + if _, err := db.Select().Table(users).Where(users.ID.Eq(insertedID)).Count(ctx); !errors.Is(err, sql.ErrNoRows) && err != nil { + t.Fatalf("unexpected count error after delete: %v", err) + } +} + +func TestSelectQueryCacheHitMissExpiryAndBypass(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, _ := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + + cache := NewMemoryQueryCache() + cache.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) } + db.WithQueryCache(cache) + + if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "cache@example.com", Name: "Cache"}).Exec(ctx); err != nil { + t.Fatalf("insert user: %v", err) + } + + counter := &countingRunner{base: db} + q := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: cache}). + Table(users). + Where(users.Email.Eq("cache@example.com")). + Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) + + var first []internalUserRow + if err := q.Scan(ctx, &first); err != nil { + t.Fatalf("first scan: %v", err) + } + if counter.queryCount != 1 { + t.Fatalf("expected first query to hit DB once, got %d", counter.queryCount) + } + + var second []internalUserRow + if err := q.Scan(ctx, &second); err != nil { + t.Fatalf("second scan: %v", err) + } + if counter.queryCount != 1 { + t.Fatalf("expected second query to hit cache, got query count %d", counter.queryCount) + } + if !reflect.DeepEqual(first, second) { + t.Fatalf("cached scan mismatch:\nfirst=%#v\nsecond=%#v", first, second) + } + + cache.now = func() time.Time { return time.Date(2026, 3, 29, 12, 2, 0, 0, time.UTC) } + var third []internalUserRow + if err := q.Scan(ctx, &third); err != nil { + t.Fatalf("third scan after expiry: %v", err) + } + if counter.queryCount != 2 { + t.Fatalf("expected expiry to force DB query, got %d", counter.queryCount) + } + + var bypassed []internalUserRow + if err := q.Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}, Bypass: true}).Scan(ctx, &bypassed); err != nil { + t.Fatalf("bypass scan: %v", err) + } + if counter.queryCount != 3 { + t.Fatalf("expected bypass to force DB query, got %d", counter.queryCount) + } +} + +func TestSelectQueryCacheArgsAndManualInvalidation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, _ := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + db.WithQueryCache(NewMemoryQueryCache()) + + for _, item := range []internalInsertModel{ + {Email: "alice@example.com", Name: "Alice"}, + {Email: "bob@example.com", Name: "Bob"}, + } { + if _, err := db.Insert().Table(users).Model(&item).Exec(ctx); err != nil { + t.Fatalf("insert user %s: %v", item.Email, err) + } + } + + counter := &countingRunner{base: db} + queryFor := func(email string) *SelectQuery { + return (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). + Table(users). + Where(users.Email.Eq(email)). + Cache(QueryCacheOptions{TTL: 5 * time.Minute, Tags: []string{"users"}}) + } + + var alice []internalUserRow + if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { + t.Fatalf("alice query first run: %v", err) + } + if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { + t.Fatalf("alice query second run: %v", err) + } + if counter.queryCount != 1 { + t.Fatalf("expected repeated identical args to hit cache, query count %d", counter.queryCount) + } + + var bob []internalUserRow + if err := queryFor("bob@example.com").Scan(ctx, &bob); err != nil { + t.Fatalf("bob query first run: %v", err) + } + if counter.queryCount != 2 { + t.Fatalf("expected different args to use different entry, query count %d", counter.queryCount) + } + + if err := db.InvalidateQueryCache(ctx, "users"); err != nil { + t.Fatalf("invalidate tag: %v", err) + } + if err := queryFor("alice@example.com").Scan(ctx, &alice); err != nil { + t.Fatalf("alice query after invalidation: %v", err) + } + if counter.queryCount != 3 { + t.Fatalf("expected invalidation miss, query count %d", counter.queryCount) + } +} + +func TestSelectQueryCacheDisabledKeepsNormalBehavior(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, _ := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "nocache@example.com", Name: "No Cache"}).Exec(ctx); err != nil { + t.Fatalf("insert user: %v", err) + } + + counter := &countingRunner{base: db} + q := (&SelectQuery{runner: counter, dialect: db.Dialect()}). + Table(users). + Where(users.Email.Eq("nocache@example.com")). + Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) + + var rows []internalUserRow + if err := q.Scan(ctx, &rows); err != nil { + t.Fatalf("first uncached scan: %v", err) + } + if err := q.Scan(ctx, &rows); err != nil { + t.Fatalf("second uncached scan: %v", err) + } + if counter.queryCount != 2 { + t.Fatalf("expected uncached behavior without backend, query count %d", counter.queryCount) + } +} + +func TestBuildQueryCacheKeyIsStableForEquivalentArgs(t *testing.T) { + t.Parallel() + + opts := normalizeQueryCacheOptions(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users", "lookup"}, Namespace: "by-id"}) + keyOne, err := buildQueryCacheKey("sqlite", "SELECT * FROM users WHERE id = ?", []any{int64(1)}, nil, opts) + if err != nil { + t.Fatalf("build key one: %v", err) + } + keyTwo, err := buildQueryCacheKey("sqlite", "SELECT * FROM users WHERE id = ?", []any{int64(1)}, nil, opts) + if err != nil { + t.Fatalf("build key two: %v", err) + } + if keyOne != keyTwo { + t.Fatalf("expected stable key, got %q and %q", keyOne, keyTwo) + } +} + +func TestSelectAggregateCacheForCountAndExists(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, _ := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + db.WithQueryCache(NewMemoryQueryCache()) + + if _, err := db.Insert().Table(users).Model(&internalInsertModel{Email: "agg@example.com", Name: "Agg"}).Exec(ctx); err != nil { + t.Fatalf("insert user: %v", err) + } + + counter := &countingRunner{base: db} + query := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). + Table(users). + Where(users.Email.Eq("agg@example.com")). + Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) + + count, err := query.Count(ctx) + if err != nil { + t.Fatalf("count first: %v", err) + } + if count != 1 { + t.Fatalf("expected count 1, got %d", count) + } + if _, err := query.Count(ctx); err != nil { + t.Fatalf("count second: %v", err) + } + if counter.queryCount != 1 { + t.Fatalf("expected second count to hit cache, query count %d", counter.queryCount) + } + + exists, err := query.Exists(ctx) + if err != nil { + t.Fatalf("exists first: %v", err) + } + if !exists { + t.Fatalf("expected exists=true") + } + if _, err := query.Exists(ctx); err != nil { + t.Fatalf("exists second: %v", err) + } + if counter.queryCount != 2 { + t.Fatalf("expected second exists to hit cache, query count %d", counter.queryCount) + } +} + +func TestSelectWithRelations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, posts := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + + aliceResult, err := db.Insert().Table(users).Set(users.Email, "alice@example.com").Set(users.Name, "Alice").Exec(ctx) + if err != nil { + t.Fatalf("insert alice failed: %v", err) + } + aliceID, err := aliceResult.LastInsertId() + if err != nil { + t.Fatalf("alice last insert id failed: %v", err) + } + bobResult, err := db.Insert().Table(users).Set(users.Email, "bob@example.com").Set(users.Name, "Bob").Exec(ctx) + if err != nil { + t.Fatalf("insert bob failed: %v", err) + } + bobID, err := bobResult.LastInsertId() + if err != nil { + t.Fatalf("bob last insert id failed: %v", err) + } + + if _, err := db.Insert().Table(posts).Set(posts.UserID, aliceID).Set(posts.Title, "Hello from Alice").Exec(ctx); err != nil { + t.Fatalf("insert alice post failed: %v", err) + } + if _, err := db.Insert().Table(posts).Set(posts.UserID, aliceID).Set(posts.Title, "Second Alice Post").Exec(ctx); err != nil { + t.Fatalf("insert alice post 2 failed: %v", err) + } + if _, err := db.Insert().Table(posts).Set(posts.UserID, bobID).Set(posts.Title, "Bob Post").Exec(ctx); err != nil { + t.Fatalf("insert bob post failed: %v", err) + } + + var postsWithAuthor []internalPostWithAuthorRow + if err := db.Select(). + Table(posts). + Where(posts.Title.Eq("Hello from Alice")). + WithRelations("author"). + Scan(ctx, &postsWithAuthor); err != nil { + t.Fatalf("select with author relation failed: %v", err) + } + if len(postsWithAuthor) != 1 { + t.Fatalf("expected one post row, got %d", len(postsWithAuthor)) + } + if postsWithAuthor[0].Author.Email != "alice@example.com" { + t.Fatalf("expected author alice@example.com, got %#v", postsWithAuthor[0].Author) + } + + var postsWithAuthorPtr []internalPostWithAuthorPointerRow + if err := db.Select(). + Table(posts). + Where(posts.Title.Eq("Hello from Alice")). + WithRelations("author"). + Scan(ctx, &postsWithAuthorPtr); err != nil { + t.Fatalf("select with pointer author relation failed: %v", err) + } + if len(postsWithAuthorPtr) != 1 || postsWithAuthorPtr[0].Author == nil || postsWithAuthorPtr[0].Author.Email != "alice@example.com" { + t.Fatalf("expected pointer author alice@example.com, got %#v", postsWithAuthorPtr) + } + + var usersWithPosts []internalUserWithPostsRow + if err := db.Select(). + Table(users). + Where(users.ID.Eq(aliceID)). + WithRelations("posts"). + Scan(ctx, &usersWithPosts); err != nil { + t.Fatalf("select with posts relation failed: %v", err) + } + if len(usersWithPosts) != 1 { + t.Fatalf("expected one user row, got %d", len(usersWithPosts)) + } + if len(usersWithPosts[0].Posts) != 2 { + t.Fatalf("expected two posts for alice, got %d", len(usersWithPosts[0].Posts)) + } + + var usersWithPostPointers []internalUserWithPostPointersRow + if err := db.Select(). + Table(users). + Where(users.ID.Eq(aliceID)). + WithRelations("posts"). + Scan(ctx, &usersWithPostPointers); err != nil { + t.Fatalf("select with pointer posts relation failed: %v", err) + } + if len(usersWithPostPointers) != 1 || len(usersWithPostPointers[0].Posts) != 2 || usersWithPostPointers[0].Posts[0] == nil { + t.Fatalf("expected pointer posts relation to populate, got %#v", usersWithPostPointers) + } + + var nested []internalUserWithPostsAndAuthorsRow + if err := db.Select(). + Table(users). + Where(users.ID.Eq(aliceID)). + WithRelations("posts.author"). + Scan(ctx, &nested); err != nil { + t.Fatalf("select with nested relations failed: %v", err) + } + if len(nested) != 1 || len(nested[0].Posts) != 2 { + t.Fatalf("expected nested relation rows, got %#v", nested) + } + for _, post := range nested[0].Posts { + if post.Author == nil || post.Author.Email != "alice@example.com" { + t.Fatalf("expected nested author alice@example.com, got %#v", post.Author) + } + } + + var bad []internalUserRow + err = db.Select().Table(users).WithRelations("does_not_exist").Scan(ctx, &bad) + if err == nil || !strings.Contains(err.Error(), "unknown relation") { + t.Fatalf("expected unknown relation error, got %v", err) + } + + var empty []internalUserRow + err = db.Select().Table(users).Where(users.ID.Eq(-999)).WithRelations("does_not_exist").Scan(ctx, &empty) + if err == nil || !strings.Contains(err.Error(), "unknown relation") { + t.Fatalf("expected unknown relation error for empty result, got %v", err) + } + + err = db.Select().Table(users).WithRelations("posts.does_not_exist").Scan(ctx, &bad) + if err == nil || !strings.Contains(err.Error(), "unknown relation") { + t.Fatalf("expected unknown nested relation error, got %v", err) + } +} + +func TestRelationLoadingBatchesQueriesPerRelation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, posts := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + + aliceResult, err := db.Insert().Table(users).Set(users.Email, "alice@example.com").Set(users.Name, "Alice").Exec(ctx) + if err != nil { + t.Fatalf("insert alice failed: %v", err) + } + aliceID, err := aliceResult.LastInsertId() + if err != nil { + t.Fatalf("alice last insert id failed: %v", err) + } + bobResult, err := db.Insert().Table(users).Set(users.Email, "bob@example.com").Set(users.Name, "Bob").Exec(ctx) + if err != nil { + t.Fatalf("insert bob failed: %v", err) + } + bobID, err := bobResult.LastInsertId() + if err != nil { + t.Fatalf("bob last insert id failed: %v", err) + } + + for _, row := range []struct { + userID int64 + title string + }{ + {userID: aliceID, title: "Alice 1"}, + {userID: aliceID, title: "Alice 2"}, + {userID: bobID, title: "Bob 1"}, + } { + if _, err := db.Insert().Table(posts).Set(posts.UserID, row.userID).Set(posts.Title, row.title).Exec(ctx); err != nil { + t.Fatalf("insert post %q failed: %v", row.title, err) + } + } + + runner := &countingRunner{base: db} + query := &SelectQuery{runner: runner, dialect: db.Dialect()} + + var rows []internalUserWithPostsRow + if err := query.Table(users).WithRelations("posts").Scan(ctx, &rows); err != nil { + t.Fatalf("relation batch scan failed: %v", err) + } + if len(rows) != 2 { + t.Fatalf("expected 2 users, got %d", len(rows)) + } + if runner.queryCount != 2 { + t.Fatalf("expected 2 query executions (base + relation batch), got %d", runner.queryCount) + } + if len(runner.lastQueries) != 2 || !strings.Contains(runner.lastQueries[1], `IN (`) { + t.Fatalf("expected relation load query with IN clause, got %#v", runner.lastQueries) + } +} + +func TestRelationLoadingChunksLargeINQueries(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openInternalQueryDB(t) + users, posts := defineInternalQueryTables() + createInternalQuerySchema(t, ctx, db) + + for idx := 0; idx < relationBatchSize+5; idx++ { + result, err := db.Insert(). + Table(users). + Set(users.Email, fmt.Sprintf("user-%d@example.com", idx)). + Set(users.Name, fmt.Sprintf("User %d", idx)). + Exec(ctx) + if err != nil { + t.Fatalf("insert user %d failed: %v", idx, err) + } + userID, err := result.LastInsertId() + if err != nil { + t.Fatalf("last insert id %d failed: %v", idx, err) + } + if _, err := db.Insert().Table(posts).Set(posts.UserID, userID).Set(posts.Title, fmt.Sprintf("Post %d", idx)).Exec(ctx); err != nil { + t.Fatalf("insert post %d failed: %v", idx, err) + } + } + + runner := &countingRunner{base: db} + query := &SelectQuery{runner: runner, dialect: db.Dialect()} + + var rows []internalUserWithPostsRow + if err := query.Table(users).WithRelations("posts").Scan(ctx, &rows); err != nil { + t.Fatalf("chunked relation load failed: %v", err) + } + if len(rows) != relationBatchSize+5 { + t.Fatalf("expected %d users, got %d", relationBatchSize+5, len(rows)) + } + if runner.queryCount != 3 { + t.Fatalf("expected 3 query executions (base + 2 relation batches), got %d", runner.queryCount) + } +} + +func TestRelationElementTypeFromTypeHandlesPointerSlices(t *testing.T) { + t.Parallel() + + users, _ := defineInternalQueryTables() + db, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + + parentsType := reflect.TypeOf([]*internalUserWithPostPointersRow{}) + parentStructType, err := sliceParentStructType(parentsType) + if err != nil { + t.Fatalf("sliceParentStructType failed: %v", err) + } + + relatedType, err := db.Select().relationElementTypeFromType(parentStructType, users.TableDef().Relations[0]) + if err != nil { + t.Fatalf("relationElementTypeFromType failed: %v", err) + } + if relatedType != reflect.TypeOf(internalPostOnlyRow{}) { + t.Fatalf("expected related type %v, got %v", reflect.TypeOf(internalPostOnlyRow{}), relatedType) + } +} diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go new file mode 100644 index 0000000..2390689 --- /dev/null +++ b/pkg/rain/query_select.go @@ -0,0 +1,478 @@ +package rain + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +// SelectQuery builds typed SELECT statements. +type SelectQuery struct { + runner queryRunner + dialect dialect.Dialect + cache QueryCache + table selectTableSource + cols []schema.Expression + where []schema.Predicate + joins []joinClause + order []schema.OrderExpr + groupBy []schema.Expression + having []schema.Predicate + ctes []cteDefinition + distinct bool + limit int + offset int + relationNames []string + cacheOptions *queryCacheOptions +} + +// Table sets the table source for the query. +func (q *SelectQuery) Table(table schema.TableReference) *SelectQuery { + q.table = tableDefSource{table: table.TableDef()} + return q +} + +// TableSubquery sets a subquery source for the query's FROM clause. +func (q *SelectQuery) TableSubquery(query *SelectQuery, alias string) *SelectQuery { + q.table = subqueryTableSource{query: query, alias: alias} + return q +} + +// Column sets the selected expressions. +func (q *SelectQuery) Column(cols ...schema.Expression) *SelectQuery { + q.cols = append(q.cols, cols...) + return q +} + +// Where appends a WHERE predicate joined with AND. +func (q *SelectQuery) Where(predicate schema.Predicate) *SelectQuery { + q.where = append(q.where, predicate) + return q +} + +// Join appends an INNER JOIN clause. +func (q *SelectQuery) Join(table schema.TableReference, on schema.Predicate) *SelectQuery { + q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + return q +} + +// LeftJoin appends a LEFT JOIN clause. +func (q *SelectQuery) LeftJoin(table schema.TableReference, on schema.Predicate) *SelectQuery { + q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + return q +} + +// JoinSubquery appends an INNER JOIN against a subquery source. +func (q *SelectQuery) JoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { + q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + return q +} + +// LeftJoinSubquery appends a LEFT JOIN against a subquery source. +func (q *SelectQuery) LeftJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { + q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + return q +} + +// Distinct marks the SELECT query as DISTINCT. +func (q *SelectQuery) Distinct() *SelectQuery { + q.distinct = true + return q +} + +// GroupBy appends GROUP BY expressions. +func (q *SelectQuery) GroupBy(exprs ...schema.Expression) *SelectQuery { + q.groupBy = append(q.groupBy, exprs...) + return q +} + +// Having appends a HAVING predicate joined with AND. +func (q *SelectQuery) Having(predicate schema.Predicate) *SelectQuery { + q.having = append(q.having, predicate) + return q +} + +// With appends a common table expression definition. +func (q *SelectQuery) With(name string, query *SelectQuery) *SelectQuery { + q.ctes = append(q.ctes, cteDefinition{name: name, query: query}) + return q +} + +// OrderBy appends ORDER BY expressions. +func (q *SelectQuery) OrderBy(order ...schema.OrderExpr) *SelectQuery { + q.order = append(q.order, order...) + return q +} + +// Limit sets the LIMIT clause. +func (q *SelectQuery) Limit(limit int) *SelectQuery { + q.limit = limit + return q +} + +// Offset sets the OFFSET clause. +func (q *SelectQuery) Offset(offset int) *SelectQuery { + q.offset = offset + return q +} + +// WithRelations requests one or more named relations to be loaded after scanning base rows. +func (q *SelectQuery) WithRelations(names ...string) *SelectQuery { + q.relationNames = append(q.relationNames, names...) + return q +} + +// Cache enables opt-in query caching for this SELECT with TTL and optional metadata. +func (q *SelectQuery) Cache(options QueryCacheOptions) *SelectQuery { + q.cacheOptions = normalizeQueryCacheOptions(options) + return q +} + +// ToSQL compiles the query into SQL and args. +func (q *SelectQuery) ToSQL() (string, []any, error) { + ctx := newCompileContext(q.dialect) + if err := q.writeSQL(ctx); err != nil { + return "", nil, err + } + return ctx.String(), ctx.args, nil +} + +func (q *SelectQuery) writeSQL(ctx *compileContext) error { + if q.table == nil { + return errors.New("rain: select query requires a table") + } + + if len(q.ctes) > 0 { + if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureCTE) { + return fmt.Errorf("rain: select queries do not support CTEs for %s dialect", ctx.dialect.Name()) + } + ctx.writeString("WITH ") + for idx, cte := range q.ctes { + if idx > 0 { + ctx.writeString(", ") + } + if strings.TrimSpace(cte.name) == "" { + return errors.New("rain: CTE name cannot be empty") + } + if cte.query == nil { + return fmt.Errorf("rain: CTE %q requires a query", cte.name) + } + if len(cte.query.ctes) > 0 { + return fmt.Errorf("rain: CTE %q body cannot itself contain CTEs", cte.name) + } + ctx.writeQuotedIdentifier(cte.name) + ctx.writeString(" AS (") + if err := cte.query.writeSQL(ctx); err != nil { + return err + } + ctx.writeByte(')') + } + ctx.writeByte(' ') + } + + ctx.writeString("SELECT ") + if q.distinct { + ctx.writeString("DISTINCT ") + } + if len(q.cols) == 0 { + ctx.writeString("*") + } else { + for idx, column := range q.cols { + if idx > 0 { + ctx.writeString(", ") + } + if err := ctx.writeSelectExpression(column); err != nil { + return err + } + } + } + + ctx.writeString(" FROM ") + if err := q.table.writeSQL(ctx); err != nil { + return err + } + + for _, join := range q.joins { + ctx.writeByte(' ') + ctx.writeString(join.kind) + ctx.writeByte(' ') + if err := join.table.writeSQL(ctx); err != nil { + return err + } + ctx.writeString(" ON ") + if err := ctx.writePredicate(join.on); err != nil { + return err + } + } + + if len(q.where) > 0 { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { + return err + } + } + + if len(q.groupBy) > 0 { + ctx.writeString(" GROUP BY ") + for idx, expr := range q.groupBy { + if idx > 0 { + ctx.writeString(", ") + } + if err := ctx.writeExpression(expr); err != nil { + return err + } + } + } + + if len(q.having) > 0 { + ctx.writeString(" HAVING ") + if err := ctx.writePredicate(joinPredicates(q.having)); err != nil { + return err + } + } + + if len(q.order) > 0 { + ctx.writeString(" ORDER BY ") + for idx, item := range q.order { + if idx > 0 { + ctx.writeString(", ") + } + if err := ctx.writeExpression(item.Expr); err != nil { + return err + } + ctx.writeByte(' ') + ctx.writeString(string(item.Direction)) + } + } + + if clause := q.dialect.LimitOffset(q.limit, q.offset); clause != "" { + ctx.writeByte(' ') + ctx.writeString(clause) + } + + return nil +} + +// Scan executes the SELECT query and scans results into dest. +func (q *SelectQuery) Scan(ctx context.Context, dest any) error { + if q.runner == nil { + return ErrNoConnection + } + + query, args, err := q.ToSQL() + if err != nil { + return err + } + + cacheKey, cacheOptions, err := q.resolveCacheKey(query, args) + if err != nil { + return err + } + if cacheOptions != nil && !cacheOptions.bypass { + cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) + if cacheErr != nil { + return cacheErr + } + if ok { + return json.Unmarshal(cached, dest) + } + } + + rows, err := q.runner.queryContext(ctx, query, args...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + if len(q.relationNames) == 0 { + err = scanRows(rows, dest) + } else { + err = q.scanRowsWithRelations(ctx, rows, dest) + } + if err != nil { + return err + } + err = q.writeCachedResult(ctx, cacheKey, cacheOptions, dest) + return err +} + +// Count executes SELECT COUNT(*). +func (q *SelectQuery) Count(ctx context.Context) (int64, error) { + if q.runner == nil { + return 0, ErrNoConnection + } + + query, args, err := q.toAggregateSQL("COUNT(*)") + if err != nil { + return 0, err + } + + cacheKey, cacheOptions, err := q.resolveCacheKey(query, args) + if err != nil { + return 0, err + } + if cacheOptions != nil && !cacheOptions.bypass { + cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) + if cacheErr != nil { + return 0, cacheErr + } + if ok { + var count int64 + if err := json.Unmarshal(cached, &count); err != nil { + return 0, err + } + return count, nil + } + } + + rows, err := q.runner.queryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer closeRows(rows, &err) + + var count int64 + if !rows.Next() { + err = sql.ErrNoRows + return 0, err + } + if err := rows.Scan(&count); err != nil { + return 0, err + } + + err = rows.Err() + if err != nil { + return 0, err + } + err = q.writeCachedResult(ctx, cacheKey, cacheOptions, count) + return count, err +} + +// Exists executes a SELECT EXISTS query. +func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { + if q.runner == nil { + return false, ErrNoConnection + } + + sqlText, args, err := q.ToSQL() + if err != nil { + return false, err + } + + ctxCompiler := newCompileContext(q.dialect) + ctxCompiler.writeString("SELECT EXISTS(") + ctxCompiler.writeString(sqlText) + ctxCompiler.writeByte(')') + ctxCompiler.args = append(ctxCompiler.args, args...) + + query := ctxCompiler.String() + cacheKey, cacheOptions, err := q.resolveCacheKey(query, ctxCompiler.args) + if err != nil { + return false, err + } + if cacheOptions != nil && !cacheOptions.bypass { + cached, ok, cacheErr := q.cache.Get(ctx, cacheKey) + if cacheErr != nil { + return false, cacheErr + } + if ok { + var exists bool + if err := json.Unmarshal(cached, &exists); err != nil { + return false, err + } + return exists, nil + } + } + + rows, err := q.runner.queryContext(ctx, query, ctxCompiler.args...) + if err != nil { + return false, err + } + defer closeRows(rows, &err) + + var exists bool + if !rows.Next() { + err = sql.ErrNoRows + return false, err + } + if err := rows.Scan(&exists); err != nil { + return false, err + } + + err = rows.Err() + if err != nil { + return false, err + } + err = q.writeCachedResult(ctx, cacheKey, cacheOptions, exists) + return exists, err +} + +func (q *SelectQuery) resolveCacheKey(query string, args []any) (string, *queryCacheOptions, error) { + if q.cacheOptions == nil || q.cache == nil { + return "", nil, nil + } + key, err := buildQueryCacheKey(q.dialect.Name(), query, args, q.relationNames, q.cacheOptions) + if err != nil { + return "", nil, err + } + return key, q.cacheOptions, nil +} + +func (q *SelectQuery) writeCachedResult(ctx context.Context, key string, options *queryCacheOptions, value any) error { + if options == nil || options.bypass { + return nil + } + encoded, err := json.Marshal(value) + if err != nil { + return err + } + return q.cache.Set(ctx, key, encoded, options.ttl, options.tags) +} + +func (q *SelectQuery) toAggregateSQL(selection string) (string, []any, error) { + if q.table == nil { + return "", nil, errors.New("rain: select query requires a table") + } + if len(q.ctes) > 0 { + return "", nil, errors.New("rain: aggregate helpers do not support WITH clauses") + } + if q.distinct || len(q.groupBy) > 0 || len(q.having) > 0 { + return "", nil, errors.New("rain: aggregate helpers do not support DISTINCT, GROUP BY, or HAVING clauses") + } + + ctx := newCompileContext(q.dialect) + ctx.writeString("SELECT ") + ctx.writeString(selection) + ctx.writeString(" FROM ") + if err := q.table.writeSQL(ctx); err != nil { + return "", nil, err + } + + for _, join := range q.joins { + ctx.writeByte(' ') + ctx.writeString(join.kind) + ctx.writeByte(' ') + if err := join.table.writeSQL(ctx); err != nil { + return "", nil, err + } + ctx.writeString(" ON ") + if err := ctx.writePredicate(join.on); err != nil { + return "", nil, err + } + } + + if len(q.where) > 0 { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { + return "", nil, err + } + } + + return ctx.String(), ctx.args, ctx.err +} diff --git a/pkg/rain/query_select_test.go b/pkg/rain/query_select_test.go new file mode 100644 index 0000000..6d38cd7 --- /dev/null +++ b/pkg/rain/query_select_test.go @@ -0,0 +1,469 @@ +package rain_test + +import ( + "reflect" + "strings" + "testing" + "time" + + "github.com/hyperlocalise/rain-orm/pkg/rain" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestSelectToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, posts := defineTables() + u := schema.Alias(users, "u") + p := schema.Alias(posts, "p") + + sqlText, args, err := db.Select(). + Table(p). + Column(p.ID, p.Title, u.Email). + Join(u, p.UserID.EqCol(u.ID)). + Where(u.Active.Eq(true)). + OrderBy(p.ID.Desc()). + Limit(10). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `SELECT "p"."id", "p"."title", "u"."email" FROM "posts" AS "p" INNER JOIN "users" AS "u" ON "p"."user_id" = "u"."id" WHERE "u"."active" = $1 ORDER BY "p"."id" DESC LIMIT 10` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 1 || args[0] != true { + t.Fatalf("unexpected args: %#v", args) + } +} + +func TestExpandedTypesCompileToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + expanded := defineExpandedTypesTable() + processedAt := time.Date(2026, 3, 28, 10, 30, 0, 0, time.UTC) + publishedOn := time.Date(2026, 3, 28, 0, 0, 0, 0, time.UTC) + + sqlText, args, err := db.Select(). + Table(expanded). + Column( + expanded.SmallCount, + expanded.Count, + expanded.Score, + expanded.Precise, + expanded.Amount, + expanded.Meta, + expanded.MetaBin, + expanded.ExternalID, + expanded.Payload, + expanded.PublishedOn, + expanded.ProcessedAt, + expanded.Category, + ). + Where(schema.And( + expanded.SmallCount.Eq(3), + expanded.Count.Eq(11), + expanded.Score.Gt(1.5), + expanded.Precise.Lte(7.25), + expanded.Amount.Eq("42.10"), + expanded.Meta.Eq(map[string]any{"enabled": true}), + expanded.MetaBin.Eq(map[string]any{"raw": "yes"}), + expanded.ExternalID.Eq("00000000-0000-0000-0000-000000000042"), + expanded.Payload.Eq([]byte{0xCA, 0xFE}), + expanded.PublishedOn.Eq(publishedOn), + expanded.ProcessedAt.Eq(processedAt), + expanded.Category.Eq("alpha"), + )). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `SELECT "expanded_types"."small_count", "expanded_types"."count", "expanded_types"."score", "expanded_types"."precise", "expanded_types"."amount", "expanded_types"."meta", "expanded_types"."meta_bin", "expanded_types"."external_id", "expanded_types"."payload", "expanded_types"."published_on", "expanded_types"."processed_at", "expanded_types"."category" FROM "expanded_types" WHERE ("expanded_types"."small_count" = $1 AND "expanded_types"."count" = $2 AND "expanded_types"."score" > $3 AND "expanded_types"."precise" <= $4 AND "expanded_types"."amount" = $5 AND "expanded_types"."meta" = $6 AND "expanded_types"."meta_bin" = $7 AND "expanded_types"."external_id" = $8 AND "expanded_types"."payload" = $9 AND "expanded_types"."published_on" = $10 AND "expanded_types"."processed_at" = $11 AND "expanded_types"."category" = $12)` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 12 { + t.Fatalf("unexpected args length: %d", len(args)) + } +} + +func TestSelectAdvancedComposition(t *testing.T) { + t.Parallel() + + users, posts := defineTables() + cteSales := schema.Define("sales_by_user", func(t *struct { + schema.TableModel + UserID *schema.Column[int64] + Total *schema.Column[int64] + }, + ) { + t.UserID = t.BigInt("user_id") + t.Total = t.BigInt("total") + }) + cteFiltered := schema.Define("filtered_sales", func(t *struct { + schema.TableModel + UserID *schema.Column[int64] + Total *schema.Column[int64] + }, + ) { + t.UserID = t.BigInt("user_id") + t.Total = t.BigInt("total") + }) + + type tc struct { + name string + dialect string + build func(*rain.DB) *rain.SelectQuery + wantSQL string + wantArgs []any + wantErr string + } + + cases := []tc{ + { + name: "distinct rendering postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select().Distinct().Table(users).Column(users.ID) + }, + wantSQL: `SELECT DISTINCT "users"."id" FROM "users"`, + }, + { + name: "group by without having mysql", + dialect: "mysql", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("COUNT(*)")). + GroupBy(posts.UserID) + }, + wantSQL: "SELECT `posts`.`user_id`, COUNT(*) FROM `posts` GROUP BY `posts`.`user_id`", + }, + { + name: "aggregate helpers in select postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column( + posts.UserID, + schema.Count().As("post_count"), + schema.Sum(posts.ID).As("id_sum"), + schema.Avg(posts.ID).As("id_avg"), + schema.Min(posts.ID).As("id_min"), + schema.Max(posts.ID).As("id_max"), + ). + GroupBy(posts.UserID) + }, + wantSQL: `SELECT "posts"."user_id", COUNT(*) AS "post_count", SUM("posts"."id") AS "id_sum", AVG("posts"."id") AS "id_avg", MIN("posts"."id") AS "id_min", MAX("posts"."id") AS "id_max" FROM "posts" GROUP BY "posts"."user_id"`, + }, + { + name: "alias helper in where placeholder ordering postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(posts.UserID, schema.Count().As("post_count")). + Where(posts.Title.Eq("hello")). + GroupBy(posts.UserID). + Having(schema.ComparisonExpr{Left: schema.Count(), Operator: ">", Right: schema.ValueExpr{Value: 3}}) + }, + wantSQL: `SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" WHERE "posts"."title" = $1 GROUP BY "posts"."user_id" HAVING COUNT(*) > $2`, + wantArgs: []any{"hello", 3}, + }, + { + name: "aggregate helper mixed with raw placeholders mysql", + dialect: "mysql", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(schema.Sum(posts.ID).As("total_id")). + Where(schema.ComparisonExpr{Left: schema.Raw("COALESCE(?, 0)", 10), Operator: "<", Right: schema.ValueExpr{Value: 50}}) + }, + wantSQL: "SELECT SUM(`posts`.`id`) AS `total_id` FROM `posts` WHERE COALESCE(?, 0) < ?", + wantArgs: []any{10, 50}, + }, + { + name: "column alias helper in select postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(users). + Column(users.Email.As("user_email")) + }, + wantSQL: `SELECT "users"."email" AS "user_email" FROM "users"`, + }, + { + name: "aggregate distinct star is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(schema.AggregateExpr{ + Function: "COUNT", + Distinct: true, + Star: true, + }) + }, + wantErr: "cannot combine DISTINCT with *", + }, + { + name: "aggregate missing function is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(schema.AggregateExpr{Expr: posts.ID}) + }, + wantErr: "function name cannot be empty", + }, + { + name: "alias in group by is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(posts.UserID). + GroupBy(schema.As(posts.UserID, "uid")) + }, + wantErr: "aliased expressions are only supported in SELECT columns", + }, + { + name: "group by with having postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("COUNT(*)")). + GroupBy(posts.UserID). + Having(schema.ComparisonExpr{ + Left: schema.Raw("COUNT(*)"), + Operator: ">", + Right: schema.ValueExpr{Value: 2}, + }) + }, + wantSQL: `SELECT "posts"."user_id", COUNT(*) FROM "posts" GROUP BY "posts"."user_id" HAVING COUNT(*) > $1`, + wantArgs: []any{2}, + }, + { + name: "single cte postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + salesByUser := db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("SUM(amount) AS total")). + GroupBy(posts.UserID) + + return db.Select(). + With("sales_by_user", salesByUser). + Table(cteSales). + Column(cteSales.UserID, cteSales.Total) + }, + wantSQL: `WITH "sales_by_user" AS (SELECT "posts"."user_id", SUM(amount) AS total FROM "posts" GROUP BY "posts"."user_id") SELECT "sales_by_user"."user_id", "sales_by_user"."total" FROM "sales_by_user"`, + }, + { + name: "multiple ctes postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + salesByUser := db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("SUM(amount) AS total")). + GroupBy(posts.UserID) + filtered := db.Select(). + Table(cteSales). + Column(cteSales.UserID, cteSales.Total). + Where(schema.ComparisonExpr{ + Left: schema.Raw("total"), + Operator: ">", + Right: schema.ValueExpr{Value: 100}, + }) + + return db.Select(). + With("sales_by_user", salesByUser). + With("filtered_sales", filtered). + Table(cteFiltered). + Column(cteFiltered.UserID, cteFiltered.Total) + }, + wantSQL: `WITH "sales_by_user" AS (SELECT "posts"."user_id", SUM(amount) AS total FROM "posts" GROUP BY "posts"."user_id"), "filtered_sales" AS (SELECT "sales_by_user"."user_id", "sales_by_user"."total" FROM "sales_by_user" WHERE total > $1) SELECT "filtered_sales"."user_id", "filtered_sales"."total" FROM "filtered_sales"`, + wantArgs: []any{100}, + }, + { + name: "subquery in from placeholder numbering postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + postsByUser := db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). + Where(posts.Title.Eq("hello")). + GroupBy(posts.UserID) + + return db.Select(). + TableSubquery(postsByUser, "pbu"). + Column(schema.Raw("pbu.user_id"), schema.Raw("pbu.post_count")). + Where(schema.ComparisonExpr{ + Left: schema.Raw("pbu.post_count"), + Operator: ">", + Right: schema.ValueExpr{Value: 3}, + }) + }, + wantSQL: `SELECT pbu.user_id, pbu.post_count FROM (SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" WHERE "posts"."title" = $1 GROUP BY "posts"."user_id") AS "pbu" WHERE pbu.post_count > $2`, + wantArgs: []any{"hello", 3}, + }, + { + name: "subquery in join mysql", + dialect: "mysql", + build: func(db *rain.DB) *rain.SelectQuery { + userPosts := db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). + GroupBy(posts.UserID) + + return db.Select(). + Table(users). + Column(users.ID, schema.Raw("up.post_count")). + JoinSubquery(userPosts, "up", schema.ComparisonExpr{ + Left: users.ID, + Operator: "=", + Right: schema.Raw("up.user_id"), + }) + }, + wantSQL: "SELECT `users`.`id`, up.post_count FROM `users` INNER JOIN (SELECT `posts`.`user_id`, COUNT(*) AS `post_count` FROM `posts` GROUP BY `posts`.`user_id`) AS `up` ON `users`.`id` = up.user_id", + }, + { + name: "left join subquery postgres", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + userPosts := db.Select(). + Table(posts). + Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). + GroupBy(posts.UserID) + + return db.Select(). + Table(users). + Column(users.ID, schema.Raw("up.post_count")). + LeftJoinSubquery(userPosts, "up", schema.ComparisonExpr{ + Left: users.ID, + Operator: "=", + Right: schema.Raw("up.user_id"), + }) + }, + wantSQL: `SELECT "users"."id", up.post_count FROM "users" LEFT JOIN (SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" GROUP BY "posts"."user_id") AS "up" ON "users"."id" = up.user_id`, + }, + { + name: "subquery without alias is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + TableSubquery(db.Select().Table(users), ""). + Column(schema.Raw("id")) + }, + wantErr: "requires a non-empty alias", + }, + { + name: "subquery without query is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + return db.Select(). + TableSubquery(nil, "sq"). + Column(schema.Raw("id")) + }, + wantErr: "requires a non-nil query", + }, + { + name: "cte unsupported on mysql", + dialect: "mysql", + build: func(db *rain.DB) *rain.SelectQuery { + base := db.Select().Table(users) + return db.Select().With("u", base).Table(users) + }, + wantErr: "do not support CTEs", + }, + { + name: "nested cte body is invalid", + dialect: "postgres", + build: func(db *rain.DB) *rain.SelectQuery { + inner := db.Select().Table(users).Column(users.ID) + outerBody := db.Select(). + With("inner", inner). + Table(users). + Column(users.ID) + + return db.Select(). + With("outer", outerBody). + Table(users). + Column(users.ID) + }, + wantErr: `CTE "outer" body cannot itself contain CTEs`, + }, + } + + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect(tt.dialect) + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + + sqlText, args, err := tt.build(db).ToSQL() + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got %v", tt.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + if sqlText != tt.wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", tt.wantSQL, sqlText) + } + if len(args) != len(tt.wantArgs) { + t.Fatalf("unexpected arg count: want %d got %d (%#v)", len(tt.wantArgs), len(args), args) + } + for idx := range tt.wantArgs { + if args[idx] != tt.wantArgs[idx] { + t.Fatalf("unexpected arg[%d]: want %#v got %#v", idx, tt.wantArgs[idx], args[idx]) + } + } + }) + } +} + +func TestSelectInPredicateToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + sqlText, args, err := db.Select(). + Table(users). + Where(users.ID.In(int64(3), int64(5), int64(8))). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `SELECT * FROM "users" WHERE "users"."id" IN ($1, $2, $3)` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if !reflect.DeepEqual(args, []any{int64(3), int64(5), int64(8)}) { + t.Fatalf("unexpected args: %#v", args) + } +} diff --git a/pkg/rain/query_test.go b/pkg/rain/query_test.go index 1f385ce..5147a2e 100644 --- a/pkg/rain/query_test.go +++ b/pkg/rain/query_test.go @@ -1,13 +1,8 @@ package rain_test import ( - "reflect" - "strings" - "testing" "time" - "github.com/hyperlocalise/rain-orm/pkg/dialect" - "github.com/hyperlocalise/rain-orm/pkg/rain" "github.com/hyperlocalise/rain-orm/pkg/schema" ) @@ -86,893 +81,3 @@ func defineExpandedTypesTable() *expandedTypesTable { t.Category = t.Enum("category", "alpha", "beta").NotNull() }) } - -func TestSelectToSQL(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, posts := defineTables() - u := schema.Alias(users, "u") - p := schema.Alias(posts, "p") - - sqlText, args, err := db.Select(). - Table(p). - Column(p.ID, p.Title, u.Email). - Join(u, p.UserID.EqCol(u.ID)). - Where(u.Active.Eq(true)). - OrderBy(p.ID.Desc()). - Limit(10). - ToSQL() - if err != nil { - t.Fatalf("ToSQL returned error: %v", err) - } - - wantSQL := `SELECT "p"."id", "p"."title", "u"."email" FROM "posts" AS "p" INNER JOIN "users" AS "u" ON "p"."user_id" = "u"."id" WHERE "u"."active" = $1 ORDER BY "p"."id" DESC LIMIT 10` - if sqlText != wantSQL { - t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 1 || args[0] != true { - t.Fatalf("unexpected args: %#v", args) - } -} - -func TestExpandedTypesCompileToSQL(t *testing.T) { - t.Parallel() - - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - expanded := defineExpandedTypesTable() - processedAt := time.Date(2026, 3, 28, 10, 30, 0, 0, time.UTC) - publishedOn := time.Date(2026, 3, 28, 0, 0, 0, 0, time.UTC) - - sqlText, args, err := db.Select(). - Table(expanded). - Column( - expanded.SmallCount, - expanded.Count, - expanded.Score, - expanded.Precise, - expanded.Amount, - expanded.Meta, - expanded.MetaBin, - expanded.ExternalID, - expanded.Payload, - expanded.PublishedOn, - expanded.ProcessedAt, - expanded.Category, - ). - Where(schema.And( - expanded.SmallCount.Eq(3), - expanded.Count.Eq(11), - expanded.Score.Gt(1.5), - expanded.Precise.Lte(7.25), - expanded.Amount.Eq("42.10"), - expanded.Meta.Eq(map[string]any{"enabled": true}), - expanded.MetaBin.Eq(map[string]any{"raw": "yes"}), - expanded.ExternalID.Eq("00000000-0000-0000-0000-000000000042"), - expanded.Payload.Eq([]byte{0xCA, 0xFE}), - expanded.PublishedOn.Eq(publishedOn), - expanded.ProcessedAt.Eq(processedAt), - expanded.Category.Eq("alpha"), - )). - ToSQL() - if err != nil { - t.Fatalf("ToSQL returned error: %v", err) - } - - wantSQL := `SELECT "expanded_types"."small_count", "expanded_types"."count", "expanded_types"."score", "expanded_types"."precise", "expanded_types"."amount", "expanded_types"."meta", "expanded_types"."meta_bin", "expanded_types"."external_id", "expanded_types"."payload", "expanded_types"."published_on", "expanded_types"."processed_at", "expanded_types"."category" FROM "expanded_types" WHERE ("expanded_types"."small_count" = $1 AND "expanded_types"."count" = $2 AND "expanded_types"."score" > $3 AND "expanded_types"."precise" <= $4 AND "expanded_types"."amount" = $5 AND "expanded_types"."meta" = $6 AND "expanded_types"."meta_bin" = $7 AND "expanded_types"."external_id" = $8 AND "expanded_types"."payload" = $9 AND "expanded_types"."published_on" = $10 AND "expanded_types"."processed_at" = $11 AND "expanded_types"."category" = $12)` - if sqlText != wantSQL { - t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 12 { - t.Fatalf("unexpected args length: %d", len(args)) - } -} - -func TestSelectAdvancedComposition(t *testing.T) { - t.Parallel() - - users, posts := defineTables() - cteSales := schema.Define("sales_by_user", func(t *struct { - schema.TableModel - UserID *schema.Column[int64] - Total *schema.Column[int64] - }, - ) { - t.UserID = t.BigInt("user_id") - t.Total = t.BigInt("total") - }) - cteFiltered := schema.Define("filtered_sales", func(t *struct { - schema.TableModel - UserID *schema.Column[int64] - Total *schema.Column[int64] - }, - ) { - t.UserID = t.BigInt("user_id") - t.Total = t.BigInt("total") - }) - - type tc struct { - name string - dialect string - build func(*rain.DB) *rain.SelectQuery - wantSQL string - wantArgs []any - wantErr string - } - - cases := []tc{ - { - name: "distinct rendering postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select().Distinct().Table(users).Column(users.ID) - }, - wantSQL: `SELECT DISTINCT "users"."id" FROM "users"`, - }, - { - name: "group by without having mysql", - dialect: "mysql", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("COUNT(*)")). - GroupBy(posts.UserID) - }, - wantSQL: "SELECT `posts`.`user_id`, COUNT(*) FROM `posts` GROUP BY `posts`.`user_id`", - }, - { - name: "aggregate helpers in select postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column( - posts.UserID, - schema.Count().As("post_count"), - schema.Sum(posts.ID).As("id_sum"), - schema.Avg(posts.ID).As("id_avg"), - schema.Min(posts.ID).As("id_min"), - schema.Max(posts.ID).As("id_max"), - ). - GroupBy(posts.UserID) - }, - wantSQL: `SELECT "posts"."user_id", COUNT(*) AS "post_count", SUM("posts"."id") AS "id_sum", AVG("posts"."id") AS "id_avg", MIN("posts"."id") AS "id_min", MAX("posts"."id") AS "id_max" FROM "posts" GROUP BY "posts"."user_id"`, - }, - { - name: "alias helper in where placeholder ordering postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(posts.UserID, schema.Count().As("post_count")). - Where(posts.Title.Eq("hello")). - GroupBy(posts.UserID). - Having(schema.ComparisonExpr{Left: schema.Count(), Operator: ">", Right: schema.ValueExpr{Value: 3}}) - }, - wantSQL: `SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" WHERE "posts"."title" = $1 GROUP BY "posts"."user_id" HAVING COUNT(*) > $2`, - wantArgs: []any{"hello", 3}, - }, - { - name: "aggregate helper mixed with raw placeholders mysql", - dialect: "mysql", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(schema.Sum(posts.ID).As("total_id")). - Where(schema.ComparisonExpr{Left: schema.Raw("COALESCE(?, 0)", 10), Operator: "<", Right: schema.ValueExpr{Value: 50}}) - }, - wantSQL: "SELECT SUM(`posts`.`id`) AS `total_id` FROM `posts` WHERE COALESCE(?, 0) < ?", - wantArgs: []any{10, 50}, - }, - { - name: "column alias helper in select postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(users). - Column(users.Email.As("user_email")) - }, - wantSQL: `SELECT "users"."email" AS "user_email" FROM "users"`, - }, - { - name: "aggregate distinct star is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(schema.AggregateExpr{ - Function: "COUNT", - Distinct: true, - Star: true, - }) - }, - wantErr: "cannot combine DISTINCT with *", - }, - { - name: "aggregate missing function is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(schema.AggregateExpr{Expr: posts.ID}) - }, - wantErr: "function name cannot be empty", - }, - { - name: "alias in group by is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(posts.UserID). - GroupBy(schema.As(posts.UserID, "uid")) - }, - wantErr: "aliased expressions are only supported in SELECT columns", - }, - { - name: "group by with having postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("COUNT(*)")). - GroupBy(posts.UserID). - Having(schema.ComparisonExpr{ - Left: schema.Raw("COUNT(*)"), - Operator: ">", - Right: schema.ValueExpr{Value: 2}, - }) - }, - wantSQL: `SELECT "posts"."user_id", COUNT(*) FROM "posts" GROUP BY "posts"."user_id" HAVING COUNT(*) > $1`, - wantArgs: []any{2}, - }, - { - name: "single cte postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - salesByUser := db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("SUM(amount) AS total")). - GroupBy(posts.UserID) - - return db.Select(). - With("sales_by_user", salesByUser). - Table(cteSales). - Column(cteSales.UserID, cteSales.Total) - }, - wantSQL: `WITH "sales_by_user" AS (SELECT "posts"."user_id", SUM(amount) AS total FROM "posts" GROUP BY "posts"."user_id") SELECT "sales_by_user"."user_id", "sales_by_user"."total" FROM "sales_by_user"`, - }, - { - name: "multiple ctes postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - salesByUser := db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("SUM(amount) AS total")). - GroupBy(posts.UserID) - filtered := db.Select(). - Table(cteSales). - Column(cteSales.UserID, cteSales.Total). - Where(schema.ComparisonExpr{ - Left: schema.Raw("total"), - Operator: ">", - Right: schema.ValueExpr{Value: 100}, - }) - - return db.Select(). - With("sales_by_user", salesByUser). - With("filtered_sales", filtered). - Table(cteFiltered). - Column(cteFiltered.UserID, cteFiltered.Total) - }, - wantSQL: `WITH "sales_by_user" AS (SELECT "posts"."user_id", SUM(amount) AS total FROM "posts" GROUP BY "posts"."user_id"), "filtered_sales" AS (SELECT "sales_by_user"."user_id", "sales_by_user"."total" FROM "sales_by_user" WHERE total > $1) SELECT "filtered_sales"."user_id", "filtered_sales"."total" FROM "filtered_sales"`, - wantArgs: []any{100}, - }, - { - name: "subquery in from placeholder numbering postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - postsByUser := db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). - Where(posts.Title.Eq("hello")). - GroupBy(posts.UserID) - - return db.Select(). - TableSubquery(postsByUser, "pbu"). - Column(schema.Raw("pbu.user_id"), schema.Raw("pbu.post_count")). - Where(schema.ComparisonExpr{ - Left: schema.Raw("pbu.post_count"), - Operator: ">", - Right: schema.ValueExpr{Value: 3}, - }) - }, - wantSQL: `SELECT pbu.user_id, pbu.post_count FROM (SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" WHERE "posts"."title" = $1 GROUP BY "posts"."user_id") AS "pbu" WHERE pbu.post_count > $2`, - wantArgs: []any{"hello", 3}, - }, - { - name: "subquery in join mysql", - dialect: "mysql", - build: func(db *rain.DB) *rain.SelectQuery { - userPosts := db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). - GroupBy(posts.UserID) - - return db.Select(). - Table(users). - Column(users.ID, schema.Raw("up.post_count")). - JoinSubquery(userPosts, "up", schema.ComparisonExpr{ - Left: users.ID, - Operator: "=", - Right: schema.Raw("up.user_id"), - }) - }, - wantSQL: "SELECT `users`.`id`, up.post_count FROM `users` INNER JOIN (SELECT `posts`.`user_id`, COUNT(*) AS `post_count` FROM `posts` GROUP BY `posts`.`user_id`) AS `up` ON `users`.`id` = up.user_id", - }, - { - name: "left join subquery postgres", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - userPosts := db.Select(). - Table(posts). - Column(posts.UserID, schema.Raw("COUNT(*)").As("post_count")). - GroupBy(posts.UserID) - - return db.Select(). - Table(users). - Column(users.ID, schema.Raw("up.post_count")). - LeftJoinSubquery(userPosts, "up", schema.ComparisonExpr{ - Left: users.ID, - Operator: "=", - Right: schema.Raw("up.user_id"), - }) - }, - wantSQL: `SELECT "users"."id", up.post_count FROM "users" LEFT JOIN (SELECT "posts"."user_id", COUNT(*) AS "post_count" FROM "posts" GROUP BY "posts"."user_id") AS "up" ON "users"."id" = up.user_id`, - }, - { - name: "subquery without alias is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - TableSubquery(db.Select().Table(users), ""). - Column(schema.Raw("id")) - }, - wantErr: "requires a non-empty alias", - }, - { - name: "subquery without query is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - return db.Select(). - TableSubquery(nil, "sq"). - Column(schema.Raw("id")) - }, - wantErr: "requires a non-nil query", - }, - { - name: "cte unsupported on mysql", - dialect: "mysql", - build: func(db *rain.DB) *rain.SelectQuery { - base := db.Select().Table(users) - return db.Select().With("u", base).Table(users) - }, - wantErr: "do not support CTEs", - }, - { - name: "nested cte body is invalid", - dialect: "postgres", - build: func(db *rain.DB) *rain.SelectQuery { - inner := db.Select().Table(users).Column(users.ID) - outerBody := db.Select(). - With("inner", inner). - Table(users). - Column(users.ID) - - return db.Select(). - With("outer", outerBody). - Table(users). - Column(users.ID) - }, - wantErr: `CTE "outer" body cannot itself contain CTEs`, - }, - } - - for _, tt := range cases { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - db, err := rain.OpenDialect(tt.dialect) - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - - sqlText, args, err := tt.build(db).ToSQL() - if tt.wantErr != "" { - if err == nil || !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("expected error containing %q, got %v", tt.wantErr, err) - } - return - } - if err != nil { - t.Fatalf("ToSQL returned error: %v", err) - } - if sqlText != tt.wantSQL { - t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", tt.wantSQL, sqlText) - } - if len(args) != len(tt.wantArgs) { - t.Fatalf("unexpected arg count: want %d got %d (%#v)", len(tt.wantArgs), len(args), args) - } - for idx := range tt.wantArgs { - if args[idx] != tt.wantArgs[idx] { - t.Fatalf("unexpected arg[%d]: want %#v got %#v", idx, tt.wantArgs[idx], args[idx]) - } - } - }) - } -} - -func TestSelectInPredicateToSQL(t *testing.T) { - t.Parallel() - - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Select(). - Table(users). - Where(users.ID.In(int64(3), int64(5), int64(8))). - ToSQL() - if err != nil { - t.Fatalf("ToSQL returned error: %v", err) - } - - wantSQL := `SELECT * FROM "users" WHERE "users"."id" IN ($1, $2, $3)` - if sqlText != wantSQL { - t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if !reflect.DeepEqual(args, []any{int64(3), int64(5), int64(8)}) { - t.Fatalf("unexpected args: %#v", args) - } -} - -func TestInsertUpdateDeleteToSQL(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - insertSQL, insertArgs, err := db.Insert(). - Table(users). - Model(&userModel{Email: "alice@example.com", Name: "Alice", Active: true}). - Returning(users.ID). - ToSQL() - if err != nil { - t.Fatalf("insert ToSQL returned error: %v", err) - } - wantInsert := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3) RETURNING "users"."id"` - if insertSQL != wantInsert { - t.Fatalf("unexpected insert SQL:\nwant: %s\ngot: %s", wantInsert, insertSQL) - } - if len(insertArgs) != 3 { - t.Fatalf("unexpected insert args: %#v", insertArgs) - } - - updateSQL, updateArgs, err := db.Update(). - Table(users). - Set(users.Name, "Alice Smith"). - Where(users.ID.Eq(int64(1))). - ToSQL() - if err != nil { - t.Fatalf("update ToSQL returned error: %v", err) - } - wantUpdate := `UPDATE "users" SET "name" = $1 WHERE "users"."id" = $2` - if updateSQL != wantUpdate { - t.Fatalf("unexpected update SQL:\nwant: %s\ngot: %s", wantUpdate, updateSQL) - } - if len(updateArgs) != 2 { - t.Fatalf("unexpected update args: %#v", updateArgs) - } - - deleteSQL, deleteArgs, err := db.Delete(). - Table(users). - Where(users.ID.Eq(int64(99))). - ToSQL() - if err != nil { - t.Fatalf("delete ToSQL returned error: %v", err) - } - wantDelete := `DELETE FROM "users" WHERE "users"."id" = $1` - if deleteSQL != wantDelete { - t.Fatalf("unexpected delete SQL:\nwant: %s\ngot: %s", wantDelete, deleteSQL) - } - if len(deleteArgs) != 1 || deleteArgs[0] != int64(99) { - t.Fatalf("unexpected delete args: %#v", deleteArgs) - } -} - -func TestDialectFeatures(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - dialect string - features dialect.Feature - missing []dialect.Feature - }{ - { - name: "postgres", - dialect: "postgres", - features: dialect.FeatureInsertReturning | - dialect.FeatureUpdateReturning | - dialect.FeatureDeleteReturning | - dialect.FeatureOffset | - dialect.FeatureUpsert | - dialect.FeatureCTE | - dialect.FeatureDefaultPlaceholder | - dialect.FeatureSavepoint, - }, - { - name: "mysql", - dialect: "mysql", - features: dialect.FeatureOffset | dialect.FeatureUpsert | dialect.FeatureSavepoint, - missing: []dialect.Feature{ - dialect.FeatureInsertReturning, - dialect.FeatureUpdateReturning, - dialect.FeatureDeleteReturning, - dialect.FeatureCTE, - dialect.FeatureDefaultPlaceholder, - }, - }, - { - name: "sqlite", - dialect: "sqlite", - features: dialect.FeatureInsertReturning | - dialect.FeatureUpdateReturning | - dialect.FeatureDeleteReturning | - dialect.FeatureOffset | - dialect.FeatureUpsert | - dialect.FeatureSavepoint, - missing: []dialect.Feature{ - dialect.FeatureCTE, - dialect.FeatureDefaultPlaceholder, - }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db, err := rain.OpenDialect(tc.dialect) - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - got := db.Dialect().Features() - if got != tc.features { - t.Fatalf("unexpected features: want %b got %b", tc.features, got) - } - for _, feature := range tc.missing { - if dialect.HasFeature(got, feature) { - t.Fatalf("expected feature %b to be absent from %b", feature, got) - } - } - }) - } -} - -func TestOpenDialectUnknownDialectReturnsError(t *testing.T) { - t.Parallel() - - db, err := rain.OpenDialect("postres") - if err == nil { - t.Fatalf("expected unsupported dialect error, got nil") - } - if db != nil { - t.Fatalf("expected nil db for unsupported dialect") - } -} - -func TestReturningUnsupportedDialect(t *testing.T) { - db, err := rain.OpenDialect("mysql") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - _, _, err = db.Insert(). - Table(users). - Set(users.Name, "Alice"). - Returning(users.ID). - ToSQL() - if err == nil || !strings.Contains(err.Error(), "insert queries do not support RETURNING") { - t.Fatalf("expected insert RETURNING to fail on mysql dialect, got %v", err) - } - - _, _, err = db.Update(). - Table(users). - Set(users.Name, "Alice"). - Where(users.ID.Eq(int64(1))). - Returning(users.ID). - ToSQL() - if err == nil || !strings.Contains(err.Error(), "update queries do not support RETURNING") { - t.Fatalf("expected update RETURNING to fail on mysql dialect, got %v", err) - } - - _, _, err = db.Delete(). - Table(users). - Where(users.ID.Eq(int64(1))). - Returning(users.ID). - ToSQL() - if err == nil || !strings.Contains(err.Error(), "delete queries do not support RETURNING") { - t.Fatalf("expected delete RETURNING to fail on mysql dialect, got %v", err) - } -} - -func TestReturningSupportedOperations(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - insertSQL, _, err := db.Insert(). - Table(users). - Set(users.Name, "Alice"). - Returning(users.ID). - ToSQL() - if err != nil || !strings.Contains(insertSQL, "RETURNING") { - t.Fatalf("expected insert RETURNING to compile, got sql=%q err=%v", insertSQL, err) - } - - updateSQL, _, err := db.Update(). - Table(users). - Set(users.Name, "Alice"). - Where(users.ID.Eq(int64(1))). - Returning(users.ID). - ToSQL() - if err != nil || !strings.Contains(updateSQL, "RETURNING") { - t.Fatalf("expected update RETURNING to compile, got sql=%q err=%v", updateSQL, err) - } - - deleteSQL, _, err := db.Delete(). - Table(users). - Where(users.ID.Eq(int64(1))). - Returning(users.ID). - ToSQL() - if err != nil || !strings.Contains(deleteSQL, "RETURNING") { - t.Fatalf("expected delete RETURNING to compile, got sql=%q err=%v", deleteSQL, err) - } -} - -func TestInsertModelAndSetMergeToSQL(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Insert(). - Table(users). - Model(&userModel{Email: "alice@example.com", Name: "", Active: false}). - Set(users.Name, "Alice"). - Set(users.Active, false). - ToSQL() - if err != nil { - t.Fatalf("insert merge ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3)` - if sqlText != wantSQL { - t.Fatalf("unexpected merged insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 3 || args[0] != "alice@example.com" || args[1] != "Alice" || args[2] != false { - t.Fatalf("unexpected merged insert args: %#v", args) - } -} - -func TestInsertOmitDefaultBackedZeroValues(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Insert(). - Table(users). - Model(&userModel{Email: "alice@example.com"}). - ToSQL() - if err != nil { - t.Fatalf("insert default omission ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2)` - if sqlText != wantSQL { - t.Fatalf("unexpected default-omitting insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 2 || args[0] != "alice@example.com" || args[1] != "" { - t.Fatalf("unexpected default-omitting insert args: %#v", args) - } -} - -func TestInsertMultiRowModelsToSQL(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Insert(). - Table(users). - Models([]userModel{ - {Email: "alice@example.com", Name: "Alice", Active: true}, - {Email: "bob@example.com", Name: "Bob", Active: true}, - }). - Returning(users.ID). - ToSQL() - if err != nil { - t.Fatalf("insert multi model ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3), ($4, $5, $6) RETURNING "users"."id"` - if sqlText != wantSQL { - t.Fatalf("unexpected multi model insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - wantArgs := []any{"alice@example.com", "Alice", true, "bob@example.com", "Bob", true} - if !reflect.DeepEqual(args, wantArgs) { - t.Fatalf("unexpected multi model insert args: %#v", args) - } -} - -func TestInsertMultiRowValuesToSQL(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Insert(). - Table(users). - Values( - map[schema.ColumnReference]any{users.Email: "alice@example.com", users.Name: "Alice", users.Active: true}, - map[schema.ColumnReference]any{users.Email: "bob@example.com", users.Name: "Bob", users.Active: false}, - ). - ToSQL() - if err != nil { - t.Fatalf("insert multi values ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3), ($4, $5, $6)` - if sqlText != wantSQL { - t.Fatalf("unexpected multi values insert SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - wantArgs := []any{"alice@example.com", "Alice", true, "bob@example.com", "Bob", false} - if !reflect.DeepEqual(args, wantArgs) { - t.Fatalf("unexpected multi values insert args: %#v", args) - } -} - -func TestInsertMultiRowColumnMismatchReturnsError(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - _, _, err = db.Insert(). - Table(users). - Models([]userModel{ - {Email: "alice@example.com", Name: "Alice", Active: true}, - {Email: "bob@example.com", Name: "", Active: false}, - }). - ToSQL() - if err == nil || !strings.Contains(err.Error(), "targets 2 columns, expected 3") { - t.Fatalf("expected column mismatch error, got %v", err) - } -} - -func TestInsertOnConflictPostgres(t *testing.T) { - db, err := rain.OpenDialect("postgres") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - t.Run("do nothing", func(t *testing.T) { - sqlText, args, err := db.Insert(). - Table(users). - Set(users.Email, "alice@example.com"). - Set(users.Name, "Alice"). - OnConflict(users.Email). - DoNothing(). - ToSQL() - if err != nil { - t.Fatalf("insert on conflict do nothing ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO NOTHING` - if sqlText != wantSQL { - t.Fatalf("unexpected do nothing SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 2 { - t.Fatalf("unexpected do nothing args: %#v", args) - } - }) - - t.Run("do update set", func(t *testing.T) { - sqlText, args, err := db.Insert(). - Table(users). - Set(users.Email, "alice@example.com"). - Set(users.Name, "Alice"). - Set(users.Active, true). - OnConflict(users.Email). - DoUpdateSet(users.Name, users.Active). - ToSQL() - if err != nil { - t.Fatalf("insert on conflict do update ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = EXCLUDED."active"` - if sqlText != wantSQL { - t.Fatalf("unexpected do update SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - if len(args) != 3 { - t.Fatalf("unexpected do update args: %#v", args) - } - }) -} - -func TestInsertOnConflictSQLite(t *testing.T) { - db, err := rain.OpenDialect("sqlite") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - sqlText, args, err := db.Insert(). - Table(users). - Set(users.Email, "alice@example.com"). - Set(users.Name, "Alice"). - Set(users.Active, true). - OnConflict(users.Email). - DoUpdateSet(users.Name, users.Active). - ToSQL() - if err != nil { - t.Fatalf("insert on conflict sqlite ToSQL returned error: %v", err) - } - - wantSQL := `INSERT INTO "users" ("email", "name", "active") VALUES (?, ?, ?) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = EXCLUDED."active"` - if sqlText != wantSQL { - t.Fatalf("unexpected sqlite do update SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) - } - wantArgs := []any{"alice@example.com", "Alice", true} - if !reflect.DeepEqual(args, wantArgs) { - t.Fatalf("unexpected sqlite do update args: %#v", args) - } -} - -func TestInsertOnConflictUnsupportedDialectReturnsError(t *testing.T) { - db, err := rain.OpenDialect("mysql") - if err != nil { - t.Fatalf("OpenDialect returned error: %v", err) - } - users, _ := defineTables() - - _, _, err = db.Insert(). - Table(users). - Set(users.Email, "alice@example.com"). - Set(users.Name, "Alice"). - OnConflict(users.Email). - DoUpdateSet(users.Name). - ToSQL() - if err == nil || !strings.Contains(err.Error(), "not implemented") { - t.Fatalf("expected unsupported dialect error, got %v", err) - } -} diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go new file mode 100644 index 0000000..1b8dba5 --- /dev/null +++ b/pkg/rain/query_update.go @@ -0,0 +1,137 @@ +package rain + +import ( + "context" + "database/sql" + "errors" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +// UpdateQuery builds typed UPDATE statements. +type UpdateQuery struct { + runner queryRunner + dialect dialect.Dialect + table *schema.TableDef + values []assignment + where []schema.Predicate + returning []schema.Expression + unbounded bool +} + +// Table sets the UPDATE target table. +func (q *UpdateQuery) Table(table schema.TableReference) *UpdateQuery { + q.table = table.TableDef() + return q +} + +// Set adds an explicit typed assignment. +func (q *UpdateQuery) Set(column schema.ColumnReference, value any) *UpdateQuery { + q.values = append(q.values, assignment{column: column, value: schema.ValueExpr{Value: value}}) + return q +} + +// Where appends a WHERE predicate joined with AND. +func (q *UpdateQuery) Where(predicate schema.Predicate) *UpdateQuery { + q.where = append(q.where, predicate) + return q +} + +// Returning adds RETURNING expressions when supported by the dialect. +func (q *UpdateQuery) Returning(exprs ...schema.Expression) *UpdateQuery { + q.returning = append(q.returning, exprs...) + return q +} + +// Unbounded allows UPDATE without a WHERE clause. +func (q *UpdateQuery) Unbounded() *UpdateQuery { + q.unbounded = true + return q +} + +// ToSQL compiles the update into SQL and args. +func (q *UpdateQuery) ToSQL() (string, []any, error) { + if q.table == nil { + return "", nil, errors.New("rain: update query requires a table") + } + if len(q.values) == 0 { + return "", nil, errors.New("rain: update query requires at least one assignment") + } + if len(q.where) == 0 && !q.unbounded { + return "", nil, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows") + } + + ctx := newCompileContext(q.dialect) + ctx.writeString("UPDATE ") + ctx.writeTableName(q.table) + ctx.writeString(" SET ") + for idx, item := range q.values { + if idx > 0 { + ctx.writeString(", ") + } + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + ctx.writeString(" = ") + if err := ctx.writeExpression(item.value); err != nil { + return "", nil, err + } + } + + if len(q.where) > 0 { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { + return "", nil, err + } + } + + if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { + return "", nil, err + } + + return ctx.String(), ctx.args, ctx.err +} + +func (q *UpdateQuery) returningClause() returningClause { + return returningClause{ + feature: dialect.FeatureUpdateReturning, + label: "update", + } +} + +// Exec executes the UPDATE query. +func (q *UpdateQuery) Exec(ctx context.Context) (sql.Result, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + query, args, err := q.ToSQL() + if err != nil { + return nil, err + } + + return q.runner.execContext(ctx, query, args...) +} + +// Scan executes an UPDATE ... RETURNING query and scans results into dest. +func (q *UpdateQuery) Scan(ctx context.Context, dest any) error { + if q.runner == nil { + return ErrNoConnection + } + if len(q.returning) == 0 { + return errors.New("rain: update scan requires RETURNING") + } + + query, args, err := q.ToSQL() + if err != nil { + return err + } + + rows, err := q.runner.queryContext(ctx, query, args...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + err = scanRows(rows, dest) + return err +} diff --git a/pkg/rain/query_write_test.go b/pkg/rain/query_write_test.go new file mode 100644 index 0000000..9ae4f6b --- /dev/null +++ b/pkg/rain/query_write_test.go @@ -0,0 +1,224 @@ +package rain_test + +import ( + "strings" + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/dialect" + "github.com/hyperlocalise/rain-orm/pkg/rain" +) + +func TestInsertUpdateDeleteToSQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + insertSQL, insertArgs, err := db.Insert(). + Table(users). + Model(&userModel{Email: "alice@example.com", Name: "Alice", Active: true}). + Returning(users.ID). + ToSQL() + if err != nil { + t.Fatalf("insert ToSQL returned error: %v", err) + } + wantInsert := `INSERT INTO "users" ("email", "name", "active") VALUES ($1, $2, $3) RETURNING "users"."id"` + if insertSQL != wantInsert { + t.Fatalf("unexpected insert SQL:\nwant: %s\ngot: %s", wantInsert, insertSQL) + } + if len(insertArgs) != 3 { + t.Fatalf("unexpected insert args: %#v", insertArgs) + } + + updateSQL, updateArgs, err := db.Update(). + Table(users). + Set(users.Name, "Alice Smith"). + Where(users.ID.Eq(int64(1))). + ToSQL() + if err != nil { + t.Fatalf("update ToSQL returned error: %v", err) + } + wantUpdate := `UPDATE "users" SET "name" = $1 WHERE "users"."id" = $2` + if updateSQL != wantUpdate { + t.Fatalf("unexpected update SQL:\nwant: %s\ngot: %s", wantUpdate, updateSQL) + } + if len(updateArgs) != 2 { + t.Fatalf("unexpected update args: %#v", updateArgs) + } + + deleteSQL, deleteArgs, err := db.Delete(). + Table(users). + Where(users.ID.Eq(int64(99))). + ToSQL() + if err != nil { + t.Fatalf("delete ToSQL returned error: %v", err) + } + wantDelete := `DELETE FROM "users" WHERE "users"."id" = $1` + if deleteSQL != wantDelete { + t.Fatalf("unexpected delete SQL:\nwant: %s\ngot: %s", wantDelete, deleteSQL) + } + if len(deleteArgs) != 1 || deleteArgs[0] != int64(99) { + t.Fatalf("unexpected delete args: %#v", deleteArgs) + } +} + +func TestDialectFeatures(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + dialect string + features dialect.Feature + missing []dialect.Feature + }{ + { + name: "postgres", + dialect: "postgres", + features: dialect.FeatureInsertReturning | + dialect.FeatureUpdateReturning | + dialect.FeatureDeleteReturning | + dialect.FeatureOffset | + dialect.FeatureUpsert | + dialect.FeatureCTE | + dialect.FeatureDefaultPlaceholder | + dialect.FeatureSavepoint, + }, + { + name: "mysql", + dialect: "mysql", + features: dialect.FeatureOffset | dialect.FeatureUpsert | dialect.FeatureSavepoint, + missing: []dialect.Feature{ + dialect.FeatureInsertReturning, + dialect.FeatureUpdateReturning, + dialect.FeatureDeleteReturning, + dialect.FeatureCTE, + dialect.FeatureDefaultPlaceholder, + }, + }, + { + name: "sqlite", + dialect: "sqlite", + features: dialect.FeatureInsertReturning | + dialect.FeatureUpdateReturning | + dialect.FeatureDeleteReturning | + dialect.FeatureOffset | + dialect.FeatureUpsert | + dialect.FeatureSavepoint, + missing: []dialect.Feature{ + dialect.FeatureCTE, + dialect.FeatureDefaultPlaceholder, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect(tc.dialect) + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + got := db.Dialect().Features() + if got != tc.features { + t.Fatalf("unexpected features: want %b got %b", tc.features, got) + } + for _, feature := range tc.missing { + if dialect.HasFeature(got, feature) { + t.Fatalf("expected feature %b to be absent from %b", feature, got) + } + } + }) + } +} + +func TestOpenDialectUnknownDialectReturnsError(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postres") + if err == nil { + t.Fatalf("expected unsupported dialect error, got nil") + } + if db != nil { + t.Fatalf("expected nil db for unsupported dialect") + } +} + +func TestReturningUnsupportedDialect(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("mysql") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + _, _, err = db.Insert(). + Table(users). + Set(users.Name, "Alice"). + Returning(users.ID). + ToSQL() + if err == nil || !strings.Contains(err.Error(), "insert queries do not support RETURNING") { + t.Fatalf("expected insert RETURNING to fail on mysql dialect, got %v", err) + } + + _, _, err = db.Update(). + Table(users). + Set(users.Name, "Alice"). + Where(users.ID.Eq(int64(1))). + Returning(users.ID). + ToSQL() + if err == nil || !strings.Contains(err.Error(), "update queries do not support RETURNING") { + t.Fatalf("expected update RETURNING to fail on mysql dialect, got %v", err) + } + + _, _, err = db.Delete(). + Table(users). + Where(users.ID.Eq(int64(1))). + Returning(users.ID). + ToSQL() + if err == nil || !strings.Contains(err.Error(), "delete queries do not support RETURNING") { + t.Fatalf("expected delete RETURNING to fail on mysql dialect, got %v", err) + } +} + +func TestReturningSupportedOperations(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + insertSQL, _, err := db.Insert(). + Table(users). + Set(users.Name, "Alice"). + Returning(users.ID). + ToSQL() + if err != nil || !strings.Contains(insertSQL, "RETURNING") { + t.Fatalf("expected insert RETURNING to compile, got sql=%q err=%v", insertSQL, err) + } + + updateSQL, _, err := db.Update(). + Table(users). + Set(users.Name, "Alice"). + Where(users.ID.Eq(int64(1))). + Returning(users.ID). + ToSQL() + if err != nil || !strings.Contains(updateSQL, "RETURNING") { + t.Fatalf("expected update RETURNING to compile, got sql=%q err=%v", updateSQL, err) + } + + deleteSQL, _, err := db.Delete(). + Table(users). + Where(users.ID.Eq(int64(1))). + Returning(users.ID). + ToSQL() + if err != nil || !strings.Contains(deleteSQL, "RETURNING") { + t.Fatalf("expected delete RETURNING to compile, got sql=%q err=%v", deleteSQL, err) + } +} diff --git a/pkg/rain/sqlite_benchmark_test.go b/pkg/rain/sqlite_benchmark_test.go index 2939c08..12ae06f 100644 --- a/pkg/rain/sqlite_benchmark_test.go +++ b/pkg/rain/sqlite_benchmark_test.go @@ -181,10 +181,7 @@ func BenchmarkSQLiteSelectJoinScan(b *testing.B) { func BenchmarkSQLiteSelectWithRelations(b *testing.B) { runSQLiteBenchmarkDatasets(b, func(b *testing.B, fixture *benchmarkFixture, dataset benchmarkDataset) { ctx := context.Background() - limit := min(dataset.users/10, 100) - if limit < 1 { - limit = 1 - } + limit := max(min(dataset.users/10, 100), 1) b.ReportAllocs() b.ResetTimer() @@ -209,7 +206,6 @@ func runSQLiteBenchmarkDatasets( b.Helper() for _, dataset := range benchmarkDatasets { - dataset := dataset b.Run(dataset.name, func(b *testing.B) { fixture := newSQLiteBenchmarkFixture(b, dataset) run(b, fixture, dataset)