diff --git a/spanner/batch.go b/spanner/batch.go index 6c999da23740..74088a8dbc75 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -104,13 +104,13 @@ func (t *BatchReadOnlyTransaction) PartitionRead(ctx context.Context, table stri // can be configured using PartitionOptions. Pass a ReadOptions to modify the // read operation. func (t *BatchReadOnlyTransaction) PartitionReadWithOptions(ctx context.Context, table string, keys KeySet, columns []string, opt PartitionOptions, readOptions ReadOptions) ([]*Partition, error) { - return t.PartitionReadUsingIndexWithOptions(ctx, table, "", keys, columns, opt, readOptions) + return t.PartitionReadUsingIndexWithOptions(ctx, table, "", keys, columns, opt, t.ReadOnlyTransaction.txReadOnly.ro.merge(readOptions)) } // PartitionReadUsingIndex returns a list of Partitions that can be used to read // rows from the database using an index. func (t *BatchReadOnlyTransaction) PartitionReadUsingIndex(ctx context.Context, table, index string, keys KeySet, columns []string, opt PartitionOptions) ([]*Partition, error) { - return t.PartitionReadUsingIndexWithOptions(ctx, table, index, keys, columns, opt, ReadOptions{}) + return t.PartitionReadUsingIndexWithOptions(ctx, table, index, keys, columns, opt, t.ReadOnlyTransaction.txReadOnly.ro) } // PartitionReadUsingIndexWithOptions returns a list of Partitions that can be diff --git a/spanner/batch_test.go b/spanner/batch_test.go index 0c77576b94f5..162b5298f9ff 100644 --- a/spanner/batch_test.go +++ b/spanner/batch_test.go @@ -22,8 +22,9 @@ import ( "testing" "time" - . "cloud.google.com/go/spanner/internal/testutil" sppb "google.golang.org/genproto/googleapis/spanner/v1" + + . "cloud.google.com/go/spanner/internal/testutil" ) func TestPartitionRoundTrip(t *testing.T) { @@ -120,6 +121,77 @@ func TestPartitionQuery_QueryOptions(t *testing.T) { } } +func TestPartitionQuery_ReadOptions(t *testing.T) { + testcases := []ReadOptionsTestCase{ + { + name: "Client level", + client: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + // Index and Limit are always ignored + want: &ReadOptions{Index: "", Limit: 0, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + }, + { + name: "Read level", + client: &ReadOptions{}, + read: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + // Index and Limit are always ignored + want: &ReadOptions{Index: "", Limit: 0, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + }, + { + name: "Read level has precedence than client level", + client: &ReadOptions{Index: "clientIndex", Limit: 10, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "clientRequestTag"}, + read: &ReadOptions{Index: "readIndex", Limit: 20, Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "readRequestTag"}, + // Index and Limit are always ignored + want: &ReadOptions{Index: "", Limit: 0, Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "readRequestTag"}, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: *tt.client}) + defer teardown() + + var ( + err error + txn *BatchReadOnlyTransaction + ps []*Partition + ) + + if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil { + t.Fatal(err) + } + defer txn.Cleanup(ctx) + + if tt.read == nil { + ps, err = txn.PartitionRead(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, PartitionOptions{0, 3}) + } else { + ps, err = txn.PartitionReadWithOptions(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, PartitionOptions{0, 3}, *tt.read) + } + if err != nil { + t.Fatal(err) + } + + for _, p := range ps { + req := p.rreq + if got, want := req.Index, tt.want.Index; got != want { + t.Fatalf("Incorrect index: got %v, want %v", got, want) + } + if got, want := req.Limit, int64(tt.want.Limit); got != want { + t.Fatalf("Incorrect limit: got %v, want %v", got, want) + } + + ro := req.RequestOptions + if got, want := ro.Priority, tt.want.Priority; got != want { + t.Fatalf("Incorrect priority: got %v, want %v", got, want) + } + if got, want := ro.RequestTag, tt.want.RequestTag; got != want { + t.Fatalf("Incorrect request tag: got %v, want %v", got, want) + } + } + }) + } +} + func TestPartitionQuery_Parallel(t *testing.T) { ctx := context.Background() server, client, teardown := setupMockedTestServer(t) diff --git a/spanner/client.go b/spanner/client.go index 377949b2dd4d..a3f282cee6f9 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -25,8 +25,6 @@ import ( "time" "cloud.google.com/go/internal/trace" - vkit "cloud.google.com/go/spanner/apiv1" - "cloud.google.com/go/spanner/internal" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" gtransport "google.golang.org/api/transport/grpc" @@ -35,6 +33,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + vkit "cloud.google.com/go/spanner/apiv1" + "cloud.google.com/go/spanner/internal" + // Install google-c2p resolver, which is required for direct path. _ "google.golang.org/grpc/xds/googledirectpath" // Install RLS load balancer policy, which is needed for gRPC RLS. @@ -86,6 +87,9 @@ type Client struct { idleSessions *sessionPool logger *log.Logger qo QueryOptions + ro ReadOptions + ao []ApplyOption + txo TransactionOptions ct *commonTags } @@ -117,6 +121,15 @@ type ClientConfig struct { // QueryOptions is the configuration for executing a sql query. QueryOptions QueryOptions + // ReadOptions is the configuration for reading rows from a database + ReadOptions ReadOptions + + // ApplyOptions is the configuration for applying + ApplyOptions []ApplyOption + + // TransactionOptions is the configuration for a transaction. + TransactionOptions TransactionOptions + // CallOptions is the configuration for providing custom retry settings that // override the default values. CallOptions *vkit.CallOptions @@ -211,6 +224,9 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf idleSessions: sp, logger: config.logger, qo: getQueryOptions(config.QueryOptions), + ro: config.ReadOptions, + ao: config.ApplyOptions, + txo: config.TransactionOptions, ct: getCommonTags(sc), } return c, nil @@ -273,6 +289,7 @@ func (c *Client) Single() *ReadOnlyTransaction { t.txReadOnly.sp = c.idleSessions t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txReadOnly.ro = c.ro t.txReadOnly.replaceSessionFunc = func(ctx context.Context) error { if t.sh == nil { return spannerErrorf(codes.InvalidArgument, "missing session handle on transaction") @@ -309,6 +326,7 @@ func (c *Client) ReadOnlyTransaction() *ReadOnlyTransaction { t.txReadOnly.sp = c.idleSessions t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txReadOnly.ro = c.ro t.ct = c.ct return t } @@ -378,6 +396,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txReadOnly.ro = c.ro t.ct = c.ct return t, nil } @@ -406,6 +425,7 @@ func (c *Client) BatchReadOnlyTransactionFromID(tid BatchReadOnlyTransactionID) t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txReadOnly.ro = c.ro t.ct = c.ct return t } @@ -491,7 +511,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo - t.txOpts = options + t.txReadOnly.ro = c.ro + t.txOpts = c.txo.merge(options) t.ct = c.ct trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())}, @@ -555,6 +576,11 @@ func Priority(priority sppb.RequestOptions_Priority) ApplyOption { // Apply applies a list of mutations atomically to the database. func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) (commitTimestamp time.Time, err error) { ao := &applyOption{} + + for _, opt := range c.ao { + opt(ao) + } + for _, opt := range opts { opt(ao) } diff --git a/spanner/client_test.go b/spanner/client_test.go index e8c780fecee5..8ee27068049d 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -28,15 +28,16 @@ import ( "cloud.google.com/go/civil" itestutil "cloud.google.com/go/internal/testutil" - vkit "cloud.google.com/go/spanner/apiv1" - . "cloud.google.com/go/spanner/internal/testutil" structpb "github.com/golang/protobuf/ptypes/struct" - gax "github.com/googleapis/gax-go/v2" + "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" "google.golang.org/api/option" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + vkit "cloud.google.com/go/spanner/apiv1" + . "cloud.google.com/go/spanner/internal/testutil" ) func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { @@ -416,6 +417,24 @@ func TestClient_Single_QueryOptions(t *testing.T) { } } +func TestClient_Single_ReadOptions(t *testing.T) { + for _, tt := range readOptionsTestCases() { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: *tt.client}) + defer teardown() + + var iter *RowIterator + if tt.read == nil { + iter = client.Single().Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + } else { + iter = client.Single().ReadWithOptions(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, tt.read) + } + testReadOptions(t, iter, server.TestSpanner, *tt.want) + }) + } +} + func TestClient_ReturnDatabaseName(t *testing.T) { t.Parallel() @@ -463,6 +482,76 @@ func checkReqsForQueryOptions(t *testing.T, server InMemSpannerServer, qo QueryO } } +func testReadOptions(t *testing.T, iter *RowIterator, server InMemSpannerServer, ro ReadOptions) { + defer iter.Stop() + + _, err := iter.Next() + if err != nil { + t.Fatalf("Failed to read from the iterator: %v", err) + } + + checkReqsForReadOptions(t, server, ro) +} + +func checkReqsForReadOptions(t *testing.T, server InMemSpannerServer, ro ReadOptions) { + reqs := drainRequestsFromServer(server) + sqlReqs := []*sppb.ReadRequest{} + + for _, req := range reqs { + if sqlReq, ok := req.(*sppb.ReadRequest); ok { + sqlReqs = append(sqlReqs, sqlReq) + } + } + + if got, want := len(sqlReqs), 1; got != want { + t.Fatalf("Length mismatch, got %v, want %v", got, want) + } + + sqlReq := sqlReqs[0] + if got, want := sqlReq.Index, ro.Index; got != want { + t.Fatalf("Index mismatch, got %v, want %v", got, want) + } + if got, want := sqlReq.Limit, ro.Limit; got != int64(want) { + t.Fatalf("Limit mismatch, got %v, want %v", got, want) + } + + reqRequestOptions := sqlReq.RequestOptions + if got, want := reqRequestOptions.Priority, ro.Priority; got != want { + t.Fatalf("Priority mismatch, got %v, want %v", got, want) + } + if got, want := reqRequestOptions.RequestTag, ro.RequestTag; got != want { + t.Fatalf("Request tag mismatch, got %v, want %v", got, want) + } +} + +func checkReqsForTransactionOptions(t *testing.T, server InMemSpannerServer, txo TransactionOptions) { + reqs := drainRequestsFromServer(server) + sqlReqs := []*sppb.CommitRequest{} + + for _, req := range reqs { + if sqlReq, ok := req.(*sppb.CommitRequest); ok { + sqlReqs = append(sqlReqs, sqlReq) + } + } + + if got, want := len(sqlReqs), 1; got != want { + t.Fatalf("Length mismatch, got %v, want %v", got, want) + } + + sqlReq := sqlReqs[0] + if got, want := sqlReq.ReturnCommitStats, txo.CommitOptions.ReturnCommitStats; got != want { + t.Fatalf("Return commit stats mismatch, got %v, want %v", got, want) + } + + reqRequestOptions := sqlReq.RequestOptions + if got, want := reqRequestOptions.Priority, txo.CommitPriority; got != want { + t.Fatalf("Commit priority mismatch, got %v, want %v", got, want) + } + if got, want := reqRequestOptions.TransactionTag, txo.TransactionTag; got != want { + t.Fatalf("Transaction tag mismatch, got %v, want %v", got, want) + } +} + func testSingleQuery(t *testing.T, serverError error) error { ctx := context.Background() server, client, teardown := setupMockedTestServer(t) @@ -648,6 +737,27 @@ func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) { } } +func TestClient_ReadOnlyTransaction_ReadOptions(t *testing.T) { + for _, tt := range readOptionsTestCases() { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: *tt.client}) + defer teardown() + + tx := client.ReadOnlyTransaction() + defer tx.Close() + + var iter *RowIterator + if tt.read == nil { + iter = tx.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + } else { + iter = tx.ReadWithOptions(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, tt.read) + } + testReadOptions(t, iter, server.TestSpanner, *tt.want) + }) + } +} + func setQueryOptionsEnvVars(opts *sppb.ExecuteSqlRequest_QueryOptions) func() { os.Setenv("SPANNER_OPTIMIZER_VERSION", opts.OptimizerVersion) os.Setenv("SPANNER_OPTIMIZER_STATISTICS_PACKAGE", opts.OptimizerStatisticsPackage) @@ -821,6 +931,30 @@ func TestClient_ReadWriteTransaction_Query_QueryOptions(t *testing.T) { } } +func TestClient_ReadWriteTransaction_Query_ReadOptions(t *testing.T) { + for _, tt := range readOptionsTestCases() { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: *tt.client}) + defer teardown() + + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + var iter *RowIterator + if tt.read == nil { + iter = tx.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + } else { + iter = tx.ReadWithOptions(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}, tt.read) + } + testReadOptions(t, iter, server.TestSpanner, *tt.want) + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) { for _, tt := range queryOptionsTestCases() { t.Run(tt.name, func(t *testing.T) { @@ -854,6 +988,32 @@ func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) { } } +func TestClient_ReadWriteTransaction_TransactionOptions(t *testing.T) { + for _, tt := range transactionOptionsTestCases() { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{TransactionOptions: *tt.client}) + defer teardown() + + var err error + if tt.write == nil { + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + return nil + }) + } else { + _, err = client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + return nil + }, *tt.write) + } + + if err != nil { + t.Fatalf("Failed executing a read-write transaction: %v", err) + } + checkReqsForTransactionOptions(t, server.TestSpanner, *tt.want) + }) + } +} + func TestClient_ReadWriteTransactionWithOptions(t *testing.T) { _, client, teardown := setupMockedTestServer(t) defer teardown() @@ -890,6 +1050,32 @@ func TestClient_ReadWriteTransactionWithOptions(t *testing.T) { } } +func TestClient_ReadWriteStmtBasedTransaction_TransactionOptions(t *testing.T) { + for _, tt := range transactionOptionsTestCases() { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{TransactionOptions: *tt.client}) + defer teardown() + + var tx *ReadWriteStmtBasedTransaction + var err error + if tt.write == nil { + tx, err = NewReadWriteStmtBasedTransaction(ctx, client) + } else { + tx, err = NewReadWriteStmtBasedTransactionWithOptions(ctx, client, *tt.write) + } + + if err != nil { + t.Fatalf("Failed initializing a read-write stmt based transaction: %v", err) + } + + if got, want := tx.txOpts, *tt.want; got != want { + t.Fatalf("Transaction options mismatch, got %v, want %v", got, want) + } + }) + } +} + func TestClient_ReadWriteStmtBasedTransactionWithOptions(t *testing.T) { _, client, teardown := setupMockedTestServer(t) defer teardown() @@ -1393,6 +1579,70 @@ func TestClient_ApplyAtLeastOnceInvalidArgument(t *testing.T) { } } +func TestClient_Apply_ApplyOptions(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + client []ApplyOption + apply []ApplyOption + wantTransactionTag string + wantPriority sppb.RequestOptions_Priority + }{ + { + name: "At least once & client level", + client: []ApplyOption{ApplyAtLeastOnce(), TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Not at least once & client level", + client: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "At least once & apply level", + apply: []ApplyOption{ApplyAtLeastOnce(), TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Not at least once & apply level", + apply: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "At least once & query level has precedence than client level", + client: []ApplyOption{ApplyAtLeastOnce(), TransactionTag("clientTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + apply: []ApplyOption{ApplyAtLeastOnce(), TransactionTag("applyTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_MEDIUM)}, + wantTransactionTag: "applyTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_MEDIUM, + }, + { + name: "Not at least once & apply level", + client: []ApplyOption{TransactionTag("clientTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + apply: []ApplyOption{TransactionTag("applyTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_MEDIUM)}, + wantTransactionTag: "applyTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_MEDIUM, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ApplyOptions: tt.client}) + defer teardown() + + _, err := client.Apply(context.Background(), []*Mutation{Insert("foo", []string{"col1"}, []interface{}{"val1"})}, tt.apply...) + if err != nil { + t.Fatalf("failed applying mutations: %v", err) + } + checkCommitForExpectedRequestOptions(t, server.TestSpanner, sppb.RequestOptions{Priority: tt.wantPriority, TransactionTag: tt.wantTransactionTag}) + }) + } +} + func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) { t.Parallel() _, client, teardown := setupMockedTestServer(t) @@ -2296,6 +2546,45 @@ func TestBatchReadOnlyTransactionFromID_QueryOptions(t *testing.T) { } } +func TestBatchReadOnlyTransaction_ReadOptions(t *testing.T) { + ctx := context.Background() + ro := ReadOptions{ + Index: "testIndex", + Limit: 100, + Priority: sppb.RequestOptions_PRIORITY_LOW, + RequestTag: "testRequestTag", + } + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: ro}) + defer teardown() + + txn, err := client.BatchReadOnlyTransaction(ctx, StrongRead()) + if err != nil { + t.Fatal(err) + } + defer txn.Cleanup(ctx) + + if txn.ro != ro { + t.Fatalf("Read options are mismatched: got %v, want %v", txn.ro, ro) + } +} + +func TestBatchReadOnlyTransactionFromID_ReadOptions(t *testing.T) { + ro := ReadOptions{ + Index: "testIndex", + Limit: 100, + Priority: sppb.RequestOptions_PRIORITY_LOW, + RequestTag: "testRequestTag", + } + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ReadOptions: ro}) + defer teardown() + + txn := client.BatchReadOnlyTransactionFromID(BatchReadOnlyTransactionID{}) + + if txn.ro != ro { + t.Fatalf("Read options are mismatched: got %v, want %v", txn.ro, ro) + } +} + type QueryOptionsTestCase struct { name string client QueryOptions @@ -2352,6 +2641,64 @@ func queryOptionsTestCases() []QueryOptionsTestCase { } } +type ReadOptionsTestCase struct { + name string + client *ReadOptions + read *ReadOptions + want *ReadOptions +} + +func readOptionsTestCases() []ReadOptionsTestCase { + return []ReadOptionsTestCase{ + { + name: "Client level", + client: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + want: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + }, + { + name: "Read level", + client: &ReadOptions{}, + read: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + want: &ReadOptions{Index: "testIndex", Limit: 100, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "testRequestTag"}, + }, + { + name: "Read level has precedence than client level", + client: &ReadOptions{Index: "clientIndex", Limit: 10, Priority: sppb.RequestOptions_PRIORITY_LOW, RequestTag: "clientRequestTag"}, + read: &ReadOptions{Index: "readIndex", Limit: 20, Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "readRequestTag"}, + want: &ReadOptions{Index: "readIndex", Limit: 20, Priority: sppb.RequestOptions_PRIORITY_MEDIUM, RequestTag: "readRequestTag"}, + }, + } +} + +type TransactionOptionsTestCase struct { + name string + client *TransactionOptions + write *TransactionOptions + want *TransactionOptions +} + +func transactionOptionsTestCases() []TransactionOptionsTestCase { + return []TransactionOptionsTestCase{ + { + name: "Client level", + client: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "testTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_LOW}, + want: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "testTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_LOW}, + }, + { + name: "Write level", + client: &TransactionOptions{}, + write: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "testTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_LOW}, + want: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "testTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_LOW}, + }, + { + name: "Write level has precedence than client level", + client: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: false}, TransactionTag: "clientTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_LOW}, + write: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "writeTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_MEDIUM}, + want: &TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}, TransactionTag: "writeTransactionTag", CommitPriority: sppb.RequestOptions_PRIORITY_MEDIUM}, + }, + } +} + func TestClient_DoForEachRow_ShouldNotEndSpanWithIteratorDoneError(t *testing.T) { // This test cannot be parallel, as the TestExporter does not support that. te := itestutil.NewTestExporter() diff --git a/spanner/transaction.go b/spanner/transaction.go index 8f9283092346..bf1e10944229 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -23,7 +23,6 @@ import ( "time" "cloud.google.com/go/internal/trace" - vkit "cloud.google.com/go/spanner/apiv1" "github.com/golang/protobuf/proto" "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" @@ -32,6 +31,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + + vkit "cloud.google.com/go/spanner/apiv1" ) // transactionID stores a transaction ID which uniquely identifies a transaction @@ -76,6 +77,9 @@ type txReadOnly struct { // qo provides options for executing a sql query. qo QueryOptions + // ro provides options for reading rows from a database. + ro ReadOptions + // txOpts provides options for a transaction. txOpts TransactionOptions @@ -97,12 +101,21 @@ type TransactionOptions struct { CommitPriority sppb.RequestOptions_Priority } -func (to *TransactionOptions) requestPriority() sppb.RequestOptions_Priority { - return to.CommitPriority -} - -func (to *TransactionOptions) requestTag() string { - return "" +// merge combines two TransactionOptions that the input parameter will have higher +// order of precedence. +func (to TransactionOptions) merge(opts TransactionOptions) TransactionOptions { + merged := TransactionOptions{ + CommitOptions: to.CommitOptions.merge(opts.CommitOptions), + TransactionTag: to.TransactionTag, + CommitPriority: to.CommitPriority, + } + if opts.TransactionTag != "" { + merged.TransactionTag = opts.TransactionTag + } + if opts.CommitPriority != sppb.RequestOptions_PRIORITY_UNSPECIFIED { + merged.CommitPriority = opts.CommitPriority + } + return merged } // errSessionClosed returns error for using a recycled/destroyed session @@ -139,6 +152,30 @@ type ReadOptions struct { RequestTag string } +// merge combines two ReadOptions that the input parameter will have higher +// order of precedence. +func (ro ReadOptions) merge(opts ReadOptions) ReadOptions { + merged := ReadOptions{ + Index: ro.Index, + Limit: ro.Limit, + Priority: ro.Priority, + RequestTag: ro.RequestTag, + } + if opts.Index != "" { + merged.Index = opts.Index + } + if opts.Limit > 0 { + merged.Limit = opts.Limit + } + if opts.Priority != sppb.RequestOptions_PRIORITY_UNSPECIFIED { + merged.Priority = opts.Priority + } + if opts.RequestTag != "" { + merged.RequestTag = opts.RequestTag + } + return merged +} + // ReadWithOptions returns a RowIterator for reading multiple rows from the // database. Pass a ReadOptions to modify the read operation. func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys KeySet, columns []string, opts *ReadOptions) (ri *RowIterator) { @@ -162,10 +199,10 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key // Might happen if transaction is closed in the middle of a API call. return &RowIterator{err: errSessionClosed(sh)} } - index := "" - limit := 0 - prio := sppb.RequestOptions_PRIORITY_UNSPECIFIED - requestTag := "" + index := t.ro.Index + limit := t.ro.Limit + prio := t.ro.Priority + requestTag := t.ro.RequestTag if opts != nil { index = opts.Index if opts.Limit > 0 { @@ -1106,11 +1143,19 @@ type CommitResponse struct { CommitStats *sppb.CommitResponse_CommitStats } -// CommitOptions provides options for commiting a transaction in a database. +// CommitOptions provides options for committing a transaction in a database. type CommitOptions struct { ReturnCommitStats bool } +// merge combines two CommitOptions that the input parameter will have higher +// order of precedence. +func (co CommitOptions) merge(opts CommitOptions) CommitOptions { + return CommitOptions{ + ReturnCommitStats: co.ReturnCommitStats || opts.ReturnCommitStats, + } +} + // commit tries to commit a readwrite transaction to Cloud Spanner. It also // returns the commit response for the transactions. func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) { @@ -1275,7 +1320,8 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo - t.txOpts = options + t.txReadOnly.ro = c.ro + t.txOpts = c.txo.merge(options) t.ct = c.ct if err = t.begin(ctx); err != nil {