From 31e8285d48d484cf47d020c3cd669b606320ad06 Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 23:14:14 +1100 Subject: [PATCH 1/4] . --- README.md | 57 +++ examples/basic/main.go | 19 +- pkg/rain/rain.go | 205 ++++++++++- pkg/rain/read_replica_internal_test.go | 460 +++++++++++++++++++++++++ 4 files changed, 727 insertions(+), 14 deletions(-) create mode 100644 pkg/rain/read_replica_internal_test.go diff --git a/README.md b/README.md index 905d16c..be9dec7 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,63 @@ if err != nil { `RunInTx` commits when the callback returns `nil` and rolls back when it returns an error. Nested `RunInTx` calls use savepoints on dialects that support them. Inside a nested callback, call patterns should return errors instead of calling `Commit`/`Rollback` directly. +## Read Replicas + +Rain can route builder-based reads to replicas while keeping writes, raw SQL, and transactions on primary: + +```go +primaryDB, err := rain.Open("postgres", "postgres://user:pass@localhost/primary") +if err != nil { + panic(err) +} + +read1, err := rain.Open("postgres", "postgres://user:pass@localhost/read1") +if err != nil { + panic(err) +} + +read2, err := rain.Open("postgres", "postgres://user:pass@localhost/read2") +if err != nil { + panic(err) +} + +db, err := rain.WithReplicas(primaryDB, []*rain.DB{read1, read2}, nil) +if err != nil { + panic(err) +} +defer func() { _ = db.Close() }() + +var replicaRows []User +if err := db.Select(). + Table(Users). + Where(Users.Active.Eq(true)). + Scan(context.Background(), &replicaRows); err != nil { + panic(err) +} + +var primaryRows []User +if err := db.Primary().Select(). + Table(Users). + Where(Users.Active.Eq(true)). + Scan(context.Background(), &primaryRows); err != nil { + panic(err) +} + +// Writes stay on primary. +if _, err := db.Insert(). + Table(Users). + Model(&User{Email: "replica-aware@example.com", Name: "Replica Aware"}). + Exec(context.Background()); err != nil { + panic(err) +} +``` + +Notes: +- `Select()` uses a replica by default. +- `Primary().Select()` forces reads to the primary database. +- `Insert`, `Update`, `Delete`, `Exec`, `Query`, `QueryRow`, `Begin`, and `RunInTx` always use primary. +- v1 does not hide replica lag automatically; use `Primary()` when you need read-after-write consistency. + ## Opt-in Query Cache (v1) Rain supports opt-in caching for `SELECT` helpers (`Scan`, `Count`, and `Exists`). Caching is disabled unless you set a cache backend on `DB`. diff --git a/examples/basic/main.go b/examples/basic/main.go index 5d4097f..828dfce 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -48,7 +48,17 @@ var Posts = schema.Define("posts", func(t *PostsTable) { }) func main() { - db, err := rain.Open("postgres", "postgres://example") + primaryDB, err := rain.Open("postgres", "postgres://primary") + if err != nil { + panic(err) + } + + readReplica, err := rain.Open("postgres", "postgres://read-replica") + if err != nil { + panic(err) + } + + db, err := rain.WithReplicas(primaryDB, []*rain.DB{readReplica}, nil) if err != nil { panic(err) } @@ -94,9 +104,16 @@ func main() { Where(Users.ID.Eq(int64(99))). ToSQL() + primarySelectSQL, primarySelectArgs, _ := db.Primary().Select(). + Table(Users). + Column(Users.ID, Users.Email). + Limit(5). + ToSQL() + fmt.Println(selectSQL, selectArgs) fmt.Println(aggSQL, aggArgs) fmt.Println(insertSQL, insertArgs) fmt.Println(updateSQL, updateArgs) fmt.Println(deleteSQL, deleteArgs) + fmt.Println(primarySelectSQL, primarySelectArgs) } diff --git a/pkg/rain/rain.go b/pkg/rain/rain.go index 7722174..5cf81be 100644 --- a/pkg/rain/rain.go +++ b/pkg/rain/rain.go @@ -6,7 +6,9 @@ import ( "database/sql" "errors" "fmt" + "math/rand" "strings" + "sync" "sync/atomic" "github.com/hyperlocalise/rain-orm/pkg/dialect" @@ -21,11 +23,30 @@ var ErrNestedTxNotSupported = errors.New("rain: nested transactions are not supp // ErrNestedTxControlNotAllowed is returned when nested callbacks attempt to commit or roll back the outer transaction directly. var ErrNestedTxControlNotAllowed = errors.New("rain: nested RunInTx callbacks cannot call Commit or Rollback directly") +// ReplicaSelector chooses which read replica should serve a SELECT query. +type ReplicaSelector func(replicas []*DB) *DB + +type dbSharedState struct { + queryCache QueryCache +} + +type replicaRoute struct { + primary *DB + replicas []*DB + selector ReplicaSelector + all []*DB + + closeOnce sync.Once + closeErr error +} + // DB represents a database connection pool. type DB struct { - db *sql.DB - dialect dialect.Dialect - queryCache QueryCache + db *sql.DB + dialect dialect.Dialect + shared *dbSharedState + replicaRoute *replicaRoute + forcePrimaryReads bool } // Open creates a database handle for the selected dialect. @@ -46,6 +67,7 @@ func Open(driver, dsn string) (*DB, error) { return &DB{ db: db, dialect: d, + shared: &dbSharedState{}, }, nil } @@ -58,11 +80,76 @@ func OpenDialect(driver string) (*DB, error) { return &DB{ dialect: d, + shared: &dbSharedState{}, + }, nil +} + +// WithReplicas returns a DB handle that routes SELECT queries to read replicas while +// keeping writes, raw SQL, and transactions on the primary handle. +func WithReplicas(primary *DB, replicas []*DB, selector ReplicaSelector) (*DB, error) { + if primary == nil { + return nil, errors.New("rain: read replicas require a non-nil primary database") + } + if len(replicas) == 0 { + return nil, errors.New("rain: read replicas require at least one replica database") + } + + shared := primary.ensureSharedState() + validatedReplicas := make([]*DB, 0, len(replicas)) + seen := make(map[*DB]struct{}, len(replicas)+1) + underlying := make([]*DB, 0, len(replicas)+1) + + if _, ok := seen[primary]; !ok { + seen[primary] = struct{}{} + underlying = append(underlying, primary) + } + + for idx, replica := range replicas { + if replica == nil { + return nil, fmt.Errorf("rain: read replica %d is nil", idx+1) + } + if replica.Dialect().Name() != primary.Dialect().Name() { + return nil, fmt.Errorf( + "rain: read replica %d uses dialect %q, expected %q", + idx+1, + replica.Dialect().Name(), + primary.Dialect().Name(), + ) + } + replica.shared = shared + validatedReplicas = append(validatedReplicas, replica) + if _, ok := seen[replica]; ok { + continue + } + seen[replica] = struct{}{} + underlying = append(underlying, replica) + } + + if selector == nil { + selector = randomReplicaSelector + } + + primary.shared = shared + route := &replicaRoute{ + primary: primary, + replicas: validatedReplicas, + selector: selector, + all: underlying, + } + + return &DB{ + db: primary.db, + dialect: primary.dialect, + shared: shared, + replicaRoute: route, }, nil } // Close closes the database connection. func (db *DB) Close() error { + if db.replicaRoute != nil { + return db.replicaRoute.close() + } if db.db == nil { return nil } @@ -75,38 +162,58 @@ func (db *DB) Dialect() dialect.Dialect { return db.dialect } +// Primary returns a DB view that forces reads to use the primary handle. +func (db *DB) Primary() *DB { + if db == nil || db.replicaRoute == nil { + return db + } + + return &DB{ + db: db.replicaRoute.primary.db, + dialect: db.replicaRoute.primary.dialect, + shared: db.shared, + replicaRoute: db.replicaRoute, + forcePrimaryReads: true, + } +} + // Select starts a typed SELECT query builder. func (db *DB) Select() *SelectQuery { - return &SelectQuery{runner: db, dialect: db.dialect, cache: db.queryCache} + runner := db.selectRunner() + cache := db.queryCache() + if routed, ok := runner.(*DB); ok { + return &SelectQuery{runner: routed, dialect: db.dialect, cache: cache} + } + return &SelectQuery{runner: runner, dialect: db.dialect, cache: cache} } // WithQueryCache sets the shared SELECT query cache backend on DB. func (db *DB) WithQueryCache(cache QueryCache) *DB { - db.queryCache = cache + db.ensureSharedState().queryCache = cache return db } // InvalidateQueryCache removes cached query entries associated with any provided tag. func (db *DB) InvalidateQueryCache(ctx context.Context, tags ...string) error { - if db.queryCache == nil { + if db.queryCache() == nil { return nil } - return db.queryCache.InvalidateTags(ctx, tags...) + return db.queryCache().InvalidateTags(ctx, tags...) } // Insert starts a typed INSERT query builder. func (db *DB) Insert() *InsertQuery { - return &InsertQuery{runner: db, dialect: db.dialect} + return &InsertQuery{runner: db.primaryRunner(), dialect: db.dialect} } // Update starts a typed UPDATE query builder. func (db *DB) Update() *UpdateQuery { - return &UpdateQuery{runner: db, dialect: db.dialect} + return &UpdateQuery{runner: db.primaryRunner(), dialect: db.dialect} } // Delete starts a typed DELETE query builder. func (db *DB) Delete() *DeleteQuery { - return &DeleteQuery{runner: db, dialect: db.dialect} + return &DeleteQuery{runner: db.primaryRunner(), dialect: db.dialect} } // Exec executes raw SQL against the database. @@ -146,16 +253,17 @@ func (db *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row // Begin starts a new transaction. func (db *DB) Begin(ctx context.Context) (*Tx, error) { - if db.db == nil { + primary := db.primaryHandle() + if primary.db == nil { return nil, ErrNoConnection } - tx, err := db.db.BeginTx(ctx, nil) + tx, err := primary.db.BeginTx(ctx, nil) if err != nil { return nil, err } - return &Tx{tx: tx, dialect: db.dialect, savepointSeq: new(int64), canControlTx: true, queryCache: db.queryCache}, nil + return &Tx{tx: tx, dialect: db.dialect, savepointSeq: new(int64), canControlTx: true, queryCache: db.queryCache()}, nil } // RunInTx executes fn in a transaction, rolling back on error and committing on success. @@ -323,3 +431,74 @@ func (tx *Tx) rollbackSavepoint(ctx context.Context, savepoint string) error { return nil } + +func (db *DB) ensureSharedState() *dbSharedState { + if db.shared == nil { + db.shared = &dbSharedState{} + } + return db.shared +} + +func (db *DB) queryCache() QueryCache { + if db.shared == nil { + return nil + } + return db.shared.queryCache +} + +func (db *DB) primaryHandle() *DB { + if db == nil || db.replicaRoute == nil { + return db + } + return db.replicaRoute.primary +} + +func (db *DB) primaryRunner() queryRunner { + return db.primaryHandle() +} + +func (db *DB) selectRunner() queryRunner { + if db == nil || db.replicaRoute == nil || db.forcePrimaryReads { + return db.primaryRunner() + } + return db.replicaRoute.pickReplica() +} + +func randomReplicaSelector(replicas []*DB) *DB { + if len(replicas) == 0 { + return nil + } + return replicas[rand.Intn(len(replicas))] +} + +func (r *replicaRoute) pickReplica() *DB { + if r == nil || len(r.replicas) == 0 { + return nil + } + chosen := r.selector(r.replicas) + for _, replica := range r.replicas { + if replica == chosen { + return replica + } + } + return randomReplicaSelector(r.replicas) +} + +func (r *replicaRoute) close() error { + if r == nil { + return nil + } + r.closeOnce.Do(func() { + var errs []error + for _, handle := range r.all { + if handle == nil || handle.db == nil { + continue + } + if err := handle.db.Close(); err != nil { + errs = append(errs, err) + } + } + r.closeErr = errors.Join(errs...) + }) + return r.closeErr +} diff --git a/pkg/rain/read_replica_internal_test.go b/pkg/rain/read_replica_internal_test.go new file mode 100644 index 0000000..670853b --- /dev/null +++ b/pkg/rain/read_replica_internal_test.go @@ -0,0 +1,460 @@ +package rain + +import ( + "context" + "path/filepath" + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/schema" + _ "modernc.org/sqlite" +) + +func openReplicaTestDB(t *testing.T, name string) *DB { + t.Helper() + + dbPath := filepath.Join(t.TempDir(), name+".sqlite") + db, err := Open("sqlite", dbPath) + if err != nil { + t.Fatalf("open sqlite db %q: %v", name, err) + } + t.Cleanup(func() { + _ = db.Close() + }) + + return db +} + +func createInternalQuerySchemaForTables( + t *testing.T, + ctx context.Context, + db *DB, + users *internalQueryUsersTable, + posts *internalQueryPostsTable, +) { + t.Helper() + + for _, table := range []schema.TableReference{users, posts} { + statement, err := db.CreateTableSQL(table) + if err != nil { + t.Fatalf("compile schema for %q: %v", table.TableDef().Name, err) + } + if _, err := db.Exec(ctx, statement); err != nil { + t.Fatalf("exec schema statement %q: %v", statement, err) + } + } +} + +func insertReplicaTestUser(t *testing.T, ctx context.Context, db *DB, users *internalQueryUsersTable, email, name string) int64 { + t.Helper() + + result, err := db.Insert(). + Table(users). + Model(&internalInsertModel{Email: email, Name: name}). + Exec(ctx) + if err != nil { + t.Fatalf("insert user %q: %v", email, err) + } + id, err := result.LastInsertId() + if err != nil { + t.Fatalf("last insert id for %q: %v", email, err) + } + + return id +} + +func insertReplicaTestPost(t *testing.T, ctx context.Context, db *DB, posts *internalQueryPostsTable, userID int64, title string) { + t.Helper() + + if _, err := db.Insert(). + Table(posts). + Set(posts.UserID, userID). + Set(posts.Title, title). + Exec(ctx); err != nil { + t.Fatalf("insert post %q: %v", title, err) + } +} + +func TestWithReplicasValidation(t *testing.T) { + t.Parallel() + + primary, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + replica, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + mysqlReplica, err := OpenDialect("mysql") + if err != nil { + t.Fatalf("OpenDialect(mysql): %v", err) + } + + if _, err := WithReplicas(nil, []*DB{replica}, nil); err == nil { + t.Fatalf("expected nil primary validation error") + } + if _, err := WithReplicas(primary, nil, nil); err == nil { + t.Fatalf("expected empty replica validation error") + } + if _, err := WithReplicas(primary, []*DB{nil}, nil); err == nil { + t.Fatalf("expected nil replica validation error") + } + if _, err := WithReplicas(primary, []*DB{mysqlReplica}, nil); err == nil { + t.Fatalf("expected mixed dialect validation error") + } +} + +func TestWithReplicasDefaultSelectorUsesOnlyReplicas(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-default") + replica1 := openReplicaTestDB(t, "replica-default-1") + replica2 := openReplicaTestDB(t, "replica-default-2") + + for _, db := range []*DB{primary, replica1, replica2} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + insertReplicaTestUser(t, ctx, primary, users, "primary@example.com", "Primary") + insertReplicaTestUser(t, ctx, replica1, users, "replica1-a@example.com", "Replica 1A") + insertReplicaTestUser(t, ctx, replica1, users, "replica1-b@example.com", "Replica 1B") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-a@example.com", "Replica 2A") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-b@example.com", "Replica 2B") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-c@example.com", "Replica 2C") + + routed, err := WithReplicas(primary, []*DB{replica1, replica2}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + for range 24 { + count, err := routed.Select().Table(users).Count(ctx) + if err != nil { + t.Fatalf("Count: %v", err) + } + if count == 1 { + t.Fatalf("expected default selector to avoid primary row count, got %d", count) + } + if count != 2 && count != 3 { + t.Fatalf("expected replica row count, got %d", count) + } + } +} + +func TestWithReplicasCustomSelectorAndPrimaryView(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-custom") + replica1 := openReplicaTestDB(t, "replica-custom-1") + replica2 := openReplicaTestDB(t, "replica-custom-2") + + for _, db := range []*DB{primary, replica1, replica2} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + insertReplicaTestUser(t, ctx, primary, users, "primary@example.com", "Primary") + insertReplicaTestUser(t, ctx, replica1, users, "replica1@example.com", "Replica 1") + insertReplicaTestUser(t, ctx, replica2, users, "replica2@example.com", "Replica 2") + + routed, err := WithReplicas(primary, []*DB{replica1, replica2}, func(replicas []*DB) *DB { + return replicas[1] + }) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + var row internalUserRow + if err := routed.Select(). + Table(users). + Scan(ctx, &row); err != nil { + t.Fatalf("replica select scan: %v", err) + } + if row.Email != "replica2@example.com" { + t.Fatalf("expected custom selector to use replica2, got %#v", row) + } + + row = internalUserRow{} + if err := routed.Primary().Select(). + Table(users). + Scan(ctx, &row); err != nil { + t.Fatalf("primary select scan: %v", err) + } + if row.Email != "primary@example.com" { + t.Fatalf("expected primary view to use primary, got %#v", row) + } +} + +func TestWithReplicasRelationLoadingUsesSelectedReplica(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-relations") + replica := openReplicaTestDB(t, "replica-relations") + + for _, db := range []*DB{primary, replica} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + primaryUserID := insertReplicaTestUser(t, ctx, primary, users, "shared@example.com", "Primary User") + insertReplicaTestPost(t, ctx, primary, posts, primaryUserID, "primary-post") + + replicaUserID := insertReplicaTestUser(t, ctx, replica, users, "shared@example.com", "Replica User") + insertReplicaTestPost(t, ctx, replica, posts, replicaUserID, "replica-post") + + routed, err := WithReplicas(primary, []*DB{replica}, func(replicas []*DB) *DB { + return replicas[0] + }) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + var rows []internalUserWithPostsRow + if err := routed.Select(). + Table(users). + WithRelations("posts"). + Scan(ctx, &rows); err != nil { + t.Fatalf("select with relations: %v", err) + } + if len(rows) != 1 || len(rows[0].Posts) != 1 { + t.Fatalf("expected one related post from replica, got %#v", rows) + } + if rows[0].Posts[0].Title != "replica-post" { + t.Fatalf("expected replica relation row, got %#v", rows[0].Posts[0]) + } +} + +func TestWithReplicasWritesUsePrimary(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-writes") + replica := openReplicaTestDB(t, "replica-writes") + + for _, db := range []*DB{primary, replica} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + routed, err := WithReplicas(primary, []*DB{replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + var inserted internalUserRow + if err := routed.Insert(). + Table(users). + Model(&internalInsertModel{Email: "inserted@example.com", Name: "Inserted"}). + Returning(users.ID, users.Email, users.Name). + Scan(ctx, &inserted); err != nil { + t.Fatalf("insert returning scan: %v", err) + } + + var updated internalUserRow + if err := routed.Update(). + Table(users). + Set(users.Name, "Updated"). + Where(users.ID.Eq(inserted.ID)). + Returning(users.ID, users.Email, users.Name). + Scan(ctx, &updated); err != nil { + t.Fatalf("update returning scan: %v", err) + } + if updated.Name != "Updated" { + t.Fatalf("expected updated row from primary, got %#v", updated) + } + + var deleted internalUserRow + if err := routed.Delete(). + Table(users). + Where(users.ID.Eq(inserted.ID)). + Returning(users.ID, users.Email). + Scan(ctx, &deleted); err != nil { + t.Fatalf("delete returning scan: %v", err) + } + if deleted.Email != "inserted@example.com" { + t.Fatalf("expected deleted row from primary, got %#v", deleted) + } + + primaryCount, err := primary.Select().Table(users).Count(ctx) + if err != nil && primaryCount != 0 { + t.Fatalf("primary count after delete: %v", err) + } + replicaCount, err := replica.Select().Table(users).Count(ctx) + if err != nil && replicaCount != 0 { + t.Fatalf("replica count after delete: %v", err) + } + if primaryCount != 0 || replicaCount != 0 { + t.Fatalf("expected writes to affect only primary and leave no rows, got primary=%d replica=%d", primaryCount, replicaCount) + } +} + +func TestWithReplicasTransactionsAndRawSQLUsePrimary(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-tx") + replica := openReplicaTestDB(t, "replica-tx") + + for _, db := range []*DB{primary, replica} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + insertReplicaTestUser(t, ctx, primary, users, "primary@example.com", "Primary") + insertReplicaTestUser(t, ctx, replica, users, "replica-a@example.com", "Replica A") + insertReplicaTestUser(t, ctx, replica, users, "replica-b@example.com", "Replica B") + + routed, err := WithReplicas(primary, []*DB{replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + tx, err := routed.Begin(ctx) + if err != nil { + t.Fatalf("Begin: %v", err) + } + count, err := tx.Select().Table(users).Count(ctx) + if err != nil { + t.Fatalf("tx count: %v", err) + } + if count != 1 { + t.Fatalf("expected tx read to use primary, got %d", count) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("Rollback: %v", err) + } + + if err := routed.RunInTx(ctx, func(tx *Tx) error { + txCount, err := tx.Select().Table(users).Count(ctx) + if err != nil { + return err + } + if txCount != 1 { + t.Fatalf("expected RunInTx read to use primary, got %d", txCount) + } + _, err = tx.Insert().Table(users).Model(&internalInsertModel{ + Email: "tx-write@example.com", + Name: "Tx Write", + }).Exec(ctx) + return err + }); err != nil { + t.Fatalf("RunInTx: %v", err) + } + + row := routed.QueryRow(ctx, `SELECT COUNT(*) FROM users`) + if row == nil { + t.Fatalf("QueryRow returned nil") + } + var queryRowCount int + if err := row.Scan(&queryRowCount); err != nil { + t.Fatalf("scan QueryRow count: %v", err) + } + if queryRowCount != 2 { + t.Fatalf("expected QueryRow to use primary, got %d", queryRowCount) + } + + rows, err := routed.Query(ctx, `SELECT email FROM users ORDER BY id`) + if err != nil { + t.Fatalf("Query: %v", err) + } + defer func() { + _ = rows.Close() + }() + + var emails []string + for rows.Next() { + var email string + if err := rows.Scan(&email); err != nil { + t.Fatalf("scan Query row: %v", err) + } + emails = append(emails, email) + } + if err := rows.Err(); err != nil { + t.Fatalf("rows err: %v", err) + } + if len(emails) != 2 || emails[0] != "primary@example.com" || emails[1] != "tx-write@example.com" { + t.Fatalf("expected Query to use primary, got %#v", emails) + } + + if _, err := routed.Exec(ctx, `INSERT INTO users (email, name, active, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)`, "exec-write@example.com", "Exec Write", true); err != nil { + t.Fatalf("Exec: %v", err) + } + + primaryCount, err := primary.Select().Table(users).Count(ctx) + if err != nil { + t.Fatalf("primary count: %v", err) + } + replicaCount, err := replica.Select().Table(users).Count(ctx) + if err != nil { + t.Fatalf("replica count: %v", err) + } + if primaryCount != 3 || replicaCount != 2 { + t.Fatalf("expected raw SQL and tx writes on primary, got primary=%d replica=%d", primaryCount, replicaCount) + } +} + +func TestWithReplicasSharesQueryCacheAndPrimaryView(t *testing.T) { + t.Parallel() + + primary, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + replica, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + + routed, err := WithReplicas(primary, []*DB{replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + cache := NewMemoryQueryCache() + routed.WithQueryCache(cache) + + if routed.Select().cache != cache { + t.Fatalf("expected routed Select cache to be shared") + } + if routed.Primary().Select().cache != cache { + t.Fatalf("expected primary view Select cache to be shared") + } + if primary.Select().cache != cache { + t.Fatalf("expected primary handle Select cache to be shared") + } + if replica.Select().cache != cache { + t.Fatalf("expected replica handle Select cache to be shared") + } +} + +func TestWithReplicasCloseDeduplicatesUnderlyingHandles(t *testing.T) { + t.Parallel() + + primary, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + replica, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + + routed, err := WithReplicas(primary, []*DB{replica, replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + if got := len(routed.replicaRoute.all); got != 2 { + t.Fatalf("expected two unique underlying handles, got %d", got) + } + if err := routed.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if err := routed.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } +} From 4fcd961426b706ac25e7b83d819d22bbceea0afe Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 23:27:44 +1100 Subject: [PATCH 2/4] . --- pkg/rain/rain.go | 13 +-- pkg/rain/read_replica_internal_test.go | 147 ++++++++++++++++++++++++- 2 files changed, 148 insertions(+), 12 deletions(-) diff --git a/pkg/rain/rain.go b/pkg/rain/rain.go index 5cf81be..d8df272 100644 --- a/pkg/rain/rain.go +++ b/pkg/rain/rain.go @@ -99,10 +99,8 @@ func WithReplicas(primary *DB, replicas []*DB, selector ReplicaSelector) (*DB, e seen := make(map[*DB]struct{}, len(replicas)+1) underlying := make([]*DB, 0, len(replicas)+1) - if _, ok := seen[primary]; !ok { - seen[primary] = struct{}{} - underlying = append(underlying, primary) - } + seen[primary] = struct{}{} + underlying = append(underlying, primary) for idx, replica := range replicas { if replica == nil { @@ -179,12 +177,7 @@ func (db *DB) Primary() *DB { // Select starts a typed SELECT query builder. func (db *DB) Select() *SelectQuery { - runner := db.selectRunner() - cache := db.queryCache() - if routed, ok := runner.(*DB); ok { - return &SelectQuery{runner: routed, dialect: db.dialect, cache: cache} - } - return &SelectQuery{runner: runner, dialect: db.dialect, cache: cache} + return &SelectQuery{runner: db.selectRunner(), dialect: db.dialect, cache: db.queryCache()} } // WithQueryCache sets the shared SELECT query cache backend on DB. diff --git a/pkg/rain/read_replica_internal_test.go b/pkg/rain/read_replica_internal_test.go index 670853b..2cee55d 100644 --- a/pkg/rain/read_replica_internal_test.go +++ b/pkg/rain/read_replica_internal_test.go @@ -3,6 +3,8 @@ package rain import ( "context" "path/filepath" + "strconv" + "sync" "testing" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -188,6 +190,42 @@ func TestWithReplicasCustomSelectorAndPrimaryView(t *testing.T) { } } +func TestWithReplicasFallsBackWhenSelectorReturnsUnknownHandle(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-selector-fallback") + replica := openReplicaTestDB(t, "replica-selector-fallback") + unknown := openReplicaTestDB(t, "unknown-selector-fallback") + + for _, db := range []*DB{primary, replica, unknown} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + insertReplicaTestUser(t, ctx, primary, users, "primary@example.com", "Primary") + insertReplicaTestUser(t, ctx, replica, users, "replica-a@example.com", "Replica A") + insertReplicaTestUser(t, ctx, replica, users, "replica-b@example.com", "Replica B") + insertReplicaTestUser(t, ctx, unknown, users, "unknown@example.com", "Unknown") + + routed, err := WithReplicas(primary, []*DB{replica}, func(_ []*DB) *DB { + return unknown + }) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + for range 8 { + count, err := routed.Select().Table(users).Count(ctx) + if err != nil { + t.Fatalf("Count: %v", err) + } + if count != 2 { + t.Fatalf("expected fallback to configured replica count 2, got %d", count) + } + } +} + func TestWithReplicasRelationLoadingUsesSelectedReplica(t *testing.T) { t.Parallel() @@ -280,11 +318,11 @@ func TestWithReplicasWritesUsePrimary(t *testing.T) { } primaryCount, err := primary.Select().Table(users).Count(ctx) - if err != nil && primaryCount != 0 { + if err != nil { t.Fatalf("primary count after delete: %v", err) } replicaCount, err := replica.Select().Table(users).Count(ctx) - if err != nil && replicaCount != 0 { + if err != nil { t.Fatalf("replica count after delete: %v", err) } if primaryCount != 0 || replicaCount != 0 { @@ -429,6 +467,26 @@ func TestWithReplicasSharesQueryCacheAndPrimaryView(t *testing.T) { if replica.Select().cache != cache { t.Fatalf("expected replica handle Select cache to be shared") } + + primaryView := routed.Primary() + if primaryView == nil { + t.Fatalf("expected Primary view") + } + if primaryView == routed { + t.Fatalf("expected Primary to return a distinct view") + } + if primaryView.Primary() == primaryView { + t.Fatalf("expected repeated Primary calls to return a fresh primary view") + } + if !primaryView.forcePrimaryReads { + t.Fatalf("expected primary view to force primary reads") + } + if primaryView.replicaRoute != routed.replicaRoute { + t.Fatalf("expected primary view to share replica route") + } + if primaryView.shared != routed.shared { + t.Fatalf("expected primary view to share query cache state") + } } func TestWithReplicasCloseDeduplicatesUnderlyingHandles(t *testing.T) { @@ -458,3 +516,88 @@ func TestWithReplicasCloseDeduplicatesUnderlyingHandles(t *testing.T) { t.Fatalf("second Close: %v", err) } } + +func TestWithReplicasCloseSharesDedupAcrossViews(t *testing.T) { + t.Parallel() + + primary, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + replica, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + + routed, err := WithReplicas(primary, []*DB{replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + primaryView := routed.Primary() + if err := primaryView.Close(); err != nil { + t.Fatalf("Primary().Close: %v", err) + } + if err := routed.Close(); err != nil { + t.Fatalf("Close after Primary().Close: %v", err) + } +} + +func TestWithReplicasConcurrentReadsStayOnReplicas(t *testing.T) { + t.Parallel() + + ctx := context.Background() + users, posts := defineInternalQueryTables() + primary := openReplicaTestDB(t, "primary-concurrent") + replica1 := openReplicaTestDB(t, "replica-concurrent-1") + replica2 := openReplicaTestDB(t, "replica-concurrent-2") + + for _, db := range []*DB{primary, replica1, replica2} { + createInternalQuerySchemaForTables(t, ctx, db, users, posts) + } + + insertReplicaTestUser(t, ctx, primary, users, "primary@example.com", "Primary") + insertReplicaTestUser(t, ctx, replica1, users, "replica1-a@example.com", "Replica 1A") + insertReplicaTestUser(t, ctx, replica1, users, "replica1-b@example.com", "Replica 1B") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-a@example.com", "Replica 2A") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-b@example.com", "Replica 2B") + insertReplicaTestUser(t, ctx, replica2, users, "replica2-c@example.com", "Replica 2C") + + routed, err := WithReplicas(primary, []*DB{replica1, replica2}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + errCh := make(chan error, 32) + var wg sync.WaitGroup + for range 32 { + wg.Add(1) + go func() { + defer wg.Done() + count, err := routed.Select().Table(users).Count(ctx) + if err != nil { + errCh <- err + return + } + if count != 2 && count != 3 { + errCh <- &replicaCountError{count: count} + } + }() + } + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + t.Fatalf("concurrent routed read failed: %v", err) + } + } +} + +type replicaCountError struct { + count int64 +} + +func (e *replicaCountError) Error() string { + return "unexpected replica count: " + strconv.FormatInt(e.count, 10) +} From c3929c5595c69d764ec3c2c224bf5bc8c50a03a2 Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 23:31:48 +1100 Subject: [PATCH 3/4] . --- pkg/rain/query_runtime_internal_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/rain/query_runtime_internal_test.go b/pkg/rain/query_runtime_internal_test.go index ec66d5a..bd33e05 100644 --- a/pkg/rain/query_runtime_internal_test.go +++ b/pkg/rain/query_runtime_internal_test.go @@ -173,7 +173,7 @@ func TestSelectQueryCacheArgsAndManualInvalidation(t *testing.T) { counter := &countingRunner{base: db} queryFor := func(email string) *SelectQuery { - return (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). + return (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache()}). Table(users). Where(users.Email.Eq(email)). Cache(QueryCacheOptions{TTL: 5 * time.Minute, Tags: []string{"users"}}) @@ -269,7 +269,7 @@ func TestSelectAggregateCacheForCountAndExists(t *testing.T) { } counter := &countingRunner{base: db} - query := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache}). + query := (&SelectQuery{runner: counter, dialect: db.Dialect(), cache: db.queryCache()}). Table(users). Where(users.Email.Eq("agg@example.com")). Cache(QueryCacheOptions{TTL: time.Minute, Tags: []string{"users"}}) From 351920440ba03c6e3df4690c18a9bfd54febc008 Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 23:41:17 +1100 Subject: [PATCH 4/4] . --- pkg/rain/rain.go | 34 +++++++++++++++++++++++++- pkg/rain/read_replica_internal_test.go | 31 +++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/pkg/rain/rain.go b/pkg/rain/rain.go index d8df272..694d9fa 100644 --- a/pkg/rain/rain.go +++ b/pkg/rain/rain.go @@ -94,7 +94,7 @@ func WithReplicas(primary *DB, replicas []*DB, selector ReplicaSelector) (*DB, e return nil, errors.New("rain: read replicas require at least one replica database") } - shared := primary.ensureSharedState() + shared := resolveReplicaSharedState(primary, replicas) validatedReplicas := make([]*DB, 0, len(replicas)) seen := make(map[*DB]struct{}, len(replicas)+1) underlying := make([]*DB, 0, len(replicas)+1) @@ -432,6 +432,38 @@ func (db *DB) ensureSharedState() *dbSharedState { return db.shared } +func resolveReplicaSharedState(primary *DB, replicas []*DB) *dbSharedState { + var shared *dbSharedState + if primary != nil && primary.shared != nil { + shared = primary.shared + } + if shared == nil { + for _, replica := range replicas { + if replica != nil && replica.shared != nil { + shared = replica.shared + break + } + } + } + if shared == nil { + shared = &dbSharedState{} + } + if shared.queryCache != nil { + return shared + } + if primary != nil && primary.queryCache() != nil { + shared.queryCache = primary.queryCache() + return shared + } + for _, replica := range replicas { + if replica != nil && replica.queryCache() != nil { + shared.queryCache = replica.queryCache() + return shared + } + } + return shared +} + func (db *DB) queryCache() QueryCache { if db.shared == nil { return nil diff --git a/pkg/rain/read_replica_internal_test.go b/pkg/rain/read_replica_internal_test.go index 2cee55d..905bd6d 100644 --- a/pkg/rain/read_replica_internal_test.go +++ b/pkg/rain/read_replica_internal_test.go @@ -489,6 +489,37 @@ func TestWithReplicasSharesQueryCacheAndPrimaryView(t *testing.T) { } } +func TestWithReplicasPreservesPreconfiguredReplicaCache(t *testing.T) { + t.Parallel() + + primary, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + replica, err := OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + + replicaCache := NewMemoryQueryCache() + replica.WithQueryCache(replicaCache) + + routed, err := WithReplicas(primary, []*DB{replica}, nil) + if err != nil { + t.Fatalf("WithReplicas: %v", err) + } + + if routed.Select().cache != replicaCache { + t.Fatalf("expected routed Select cache to preserve preconfigured replica cache") + } + if routed.Primary().Select().cache != replicaCache { + t.Fatalf("expected primary view Select cache to preserve preconfigured replica cache") + } + if primary.Select().cache != replicaCache { + t.Fatalf("expected primary handle Select cache to preserve preconfigured replica cache") + } +} + func TestWithReplicasCloseDeduplicatesUnderlyingHandles(t *testing.T) { t.Parallel()