diff --git a/conn.go b/conn.go index 9ee4efb1..698c8274 100644 --- a/conn.go +++ b/conn.go @@ -222,6 +222,9 @@ type SpannerConn interface { // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + // DetectStatementType returns the type of SQL statement. + DetectStatementType(query string) parser.StatementType + // resetTransactionForRetry resets the current transaction after it has // been aborted by Spanner. Calling this function on a transaction that // has not been aborted is not supported and will cause an error to be @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) { return c.client, nil } +func (c *conn) DetectStatementType(query string) parser.StatementType { + info := c.parser.DetectStatementType(query) + return info.StatementType +} + func (c *conn) CommitTimestamp() (time.Time, error) { ts := propertyCommitTimestamp.GetValueOrDefault(c.state) if ts == nil { @@ -675,6 +683,27 @@ func sum(affected []int64) int64 { return sum } +// WriteMutations is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction, +// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction. +// +// The function returns an error if the connection currently has a read-only transaction. +// +// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be +// applied to Spanner when the transaction commits. +func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) { + if c.inTransaction() { + return nil, c.BufferWrite(ms) + } + ts, err := c.Apply(ctx, ms) + if err != nil { + return nil, err + } + return &spanner.CommitResponse{CommitTs: ts}, nil +} + func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) { if c.inTransaction() { return time.Time{}, spanner.ToSpannerError( diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go new file mode 100644 index 00000000..1ae9af14 --- /dev/null +++ b/spannerlib/api/batch_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestExecuteDmlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DML batch. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + for i, result := range resp.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DDL batch. This also uses a DML batch request. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + // The response should contain an 'update count' per DDL statement. + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + // There is no update count for DDL statements. + for i, result := range resp.ResultSets { + emptyStats := &spannerpb.ResultSetStats{} + if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + // There should also be no ExecuteBatchDml requests. + batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchDmlRequests), 0; g != w { + t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + + adminRequests := server.TestDatabaseAdmin.Reqs() + if g, w := len(adminRequests), 1; g != w { + t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w) + } + ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest) + if g, w := len(ddlRequest.Statements), 2; g != w { + t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteMixedBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with mixed DML and DDL statements. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "update my_table set value = 100 where true"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatchInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to execute a DDL batch in a transaction. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index f5f7643a..907212ec 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -46,6 +47,22 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.writeMutations(ctx, mutations) +} + // BeginTransaction starts a new transaction on the given connection. // A connection can have at most one transaction at any time. This function therefore returns an error if the // connection has an active transaction. @@ -83,6 +100,14 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spann return conn.Execute(ctx, executeSqlRequest) } +func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.ExecuteBatch(ctx, statements.Statements) +} + type Connection struct { // results contains the open query results for this connection. results *sync.Map @@ -95,6 +120,7 @@ type Connection struct { // spannerConn is an internal interface that contains the internal functions that are used by this API. // It is implemented by the spannerdriver.conn struct. type spannerConn interface { + WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) Commit(ctx context.Context) (*spanner.CommitResponse, error) @@ -118,6 +144,34 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations)) + for _, m := range mutation.Mutations { + spannerMutation, err := spanner.WrapMutation(m) + if err != nil { + return nil, err + } + mutations = append(mutations, spannerMutation) + } + var commitResponse *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + commitResponse, err = sc.WriteMutations(ctx, mutations) + return err + }); err != nil { + return nil, err + } + + // The commit response is nil if the connection is currently in a transaction. + if commitResponse == nil { + return nil, nil + } + response := spannerpb.CommitResponse{ + CommitTimestamp: timestamppb.New(commitResponse.CommitTs), + } + return &response, nil +} + func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { var err error if txOpts.GetReadOnly() != nil { @@ -235,6 +289,10 @@ func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.Execut return execute(ctx, conn, conn.backend, statement) } +func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + return executeBatch(ctx, conn, conn.backend, statements) +} + func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { params := extractParams(statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) @@ -266,6 +324,90 @@ func execute(ctx context.Context, conn *Connection, executor queryExecutor, stat return id, nil } +func executeBatch(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + // Determine the type of batch that should be executed based on the type of statements. + batchType, err := determineBatchType(conn, statements) + if err != nil { + return nil, err + } + switch batchType { + case parser.BatchTypeDml: + return executeBatchDml(ctx, conn, executor, statements) + case parser.BatchTypeDdl: + return executeBatchDdl(ctx, conn, executor, statements) + default: + return nil, status.Errorf(codes.InvalidArgument, "unsupported batch type: %v", batchType) + } +} + +func executeBatchDdl(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDDL() + }); err != nil { + return nil, err + } + for _, statement := range statements { + _, err := executor.ExecContext(ctx, statement.Sql) + if err != nil { + return nil, err + } + } + // TODO: Add support for getting the actual Batch DDL response. + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.RunBatch(ctx) + }); err != nil { + return nil, err + } + + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(statements)) + for i := range statements { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{}} + } + return &response, nil +} + +func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + return nil, err + } + for _, statement := range statements { + request := &spannerpb.ExecuteSqlRequest{ + Sql: statement.Sql, + Params: statement.Params, + ParamTypes: statement.ParamTypes, + } + params := extractParams(request) + _, err := executor.ExecContext(ctx, statement.Sql, params...) + if err != nil { + return nil, err + } + } + var spannerResult spannerdriver.SpannerResult + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + spannerResult, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + return nil, err + } + affected, err := spannerResult.BatchRowsAffected() + if err != nil { + return nil, err + } + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(affected)) + for i, aff := range affected { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: aff}}} + } + return &response, nil +} + func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { @@ -300,3 +442,34 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { } return params } + +func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (parser.BatchType, error) { + if len(statements) == 0 { + return parser.BatchTypeDdl, status.Errorf(codes.InvalidArgument, "cannot determine type of an empty batch") + } + var batchType parser.BatchType + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + firstStatementType := spannerConn.DetectStatementType(statements[0].Sql) + if firstStatementType == parser.StatementTypeDml { + batchType = parser.BatchTypeDml + } else if firstStatementType == parser.StatementTypeDdl { + batchType = parser.BatchTypeDdl + } else { + return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + } + for i, statement := range statements { + if i > 0 { + tp := spannerConn.DetectStatementType(statement.Sql) + if tp != firstStatementType { + return status.Errorf(codes.InvalidArgument, "Batches may not contain different types of statements. The first statement is of type %v. The statement on position %d is of type %v.", firstStatementType, i, tp) + } + } + } + return nil + }); err != nil { + return parser.BatchTypeDdl, err + } + + return batchType, nil +} diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 64d19ab3..b4625bb3 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + {Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}}, + }, + }}}, + }} + resp, err := WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest := commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + // Write the same mutations in a transaction. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + resp, err = WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("WriteMutations returned unexpected response: %v", resp) + } + resp, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("Commit returned nil response") + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests = server.TestSpanner.DrainRequestsFromServer() + beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest = commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestWriteMutationsInReadOnlyTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Start a read-only transaction and try to write mutations to that transaction. That should return an error. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}}, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + }, + }}}, + }} + _, err = WriteMutations(ctx, poolId, connId, mutations) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w) + } + + // Committing the read-only transaction should not lead to any commits on Spanner. + _, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + requests := server.TestSpanner.DrainRequestsFromServer() + // There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started + // by a query or other statement. + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index d41b6a6e..75c6bc3a 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -34,6 +34,30 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { return &Message{} } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message { + mutations := spannerpb.BatchWriteRequest_MutationGroup{} + if err := proto.Unmarshal(mutationBytes, &mutations); err != nil { + return errMessage(err) + } + response, err := api.WriteMutations(ctx, poolId, connId, &mutations) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + // BeginTransaction starts a new transaction on the given connection. A connection can have at most one active // transaction at any time. This function therefore returns an error if the connection has an active transaction. func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { @@ -89,3 +113,19 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes [ } return idMessage(id) } + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statementsBytes []byte) *Message { + statements := spannerpb.ExecuteBatchDmlRequest{} + if err := proto.Unmarshal(statementsBytes, &statements); err != nil { + return errMessage(err) + } + response, err := api.ExecuteBatch(ctx, poolId, connId, &statements) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 02ba453c..193e3fd9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -117,6 +118,48 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := ExecuteBatch(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("ExecuteBatch result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.Length() == 0 { + t.Fatal("ExecuteBatch returned no data") + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommit(t *testing.T) { t.Parallel() @@ -220,3 +263,62 @@ func TestBeginAndRollback(t *testing.T) { t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if mutationsMsg.Length() == 0 { + t.Fatal("WriteMutations returned no data") + } + + // Write mutations in a transaction. + mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + // The response should now be an empty message, as the mutations were only buffered in the transaction. + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := mutationsMsg.Length(), int32(0); g != w { + t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index e9c9be6a..8645f219 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -111,6 +111,22 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P return pin(msg) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +// +//export WriteMutations +func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes) + return pin(msg) +} + // Execute executes a SQL statement on the given connection. // The return type is an identifier for a Rows object. This identifier can be used to // call the functions Metadata and Next to get respectively the metadata of the result @@ -123,6 +139,18 @@ func Execute(poolId, connectionId int64, statement []byte) (int64, int32, int64, return pin(msg) } +// ExecuteBatch executes a batch of statements on the given connection. The statements must all be either DML or DDL +// statements. Mixing DML and DDL in a batch is not supported. Executing queries in a batch is also not supported. +// The batch will use the current transaction on the given connection, or execute as a single auto-commit statement +// if the connection does not have a transaction. +// +//export ExecuteBatch +func ExecuteBatch(poolId, connectionId int64, statements []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ExecuteBatch(ctx, poolId, connectionId, statements) + return pin(msg) +} + // Metadata returns the metadata of a Rows object. // //export Metadata diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 82d010bd..b6d63467 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -245,6 +245,63 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // ExecuteBatch returns a ExecuteBatchDml response. + mem, code, batchId, length, data := ExecuteBatch(poolId, connId, requestBytes) + verifyDataMessage(t, "ExecuteBatch", mem, code, batchId, length, data) + response := &spannerpb.ExecuteBatchDmlResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if g, w := len(response.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %v\nWant: %v", g, w) + } + for i, result := range response.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommitTransaction(t *testing.T) { t.Parallel() @@ -355,6 +412,75 @@ func TestBeginAndRollbackTransaction(t *testing.T) { } } +func TestWriteMutations(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + // WriteMutations returns a CommitResponse or nil, depending on whether the connection has an active transaction. + mem, code, id, length, data := WriteMutations(poolId, connId, mutationBytes) + verifyDataMessage(t, "WriteMutations", mem, code, id, length, data) + + response := &spannerpb.CommitResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if response.CommitTimestamp == nil { + t.Fatal("CommitTimestamp is nil") + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Start a transaction on the connection and write the mutations to that transaction. + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + _, code, _, _, _ = BeginTransaction(poolId, connId, txOptsBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + mem, code, id, length, data = WriteMutations(poolId, connId, mutationBytes) + // The response should now be an empty message, as the mutations were buffered in the current transaction. + verifyEmptyMessage(t, "WriteMutations in tx", mem, code, id, length, data) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { if g, w := mem, int64(0); g != w { t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index 90ac04f9..18a9ca50 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -21,7 +21,10 @@ import com.google.cloud.spannerlib.internal.MessageHandler; import com.google.cloud.spannerlib.internal.WrappedGoBytes; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.TransactionOptions; import java.nio.ByteBuffer; @@ -39,6 +42,31 @@ public Pool getPool() { return this.pool; } + /** + * Writes a group of mutations to Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner using a new read/write transaction. Returns a {@link + * CommitResponse} if the mutations were written directly to Spanner, and otherwise null if the + * mutations were buffered in the current transaction. + */ + public CommitResponse WriteMutations(MutationGroup mutations) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(mutations); + MessageHandler message = + getLibrary() + .execute( + library -> + library.WriteMutations( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Starts a transaction on this connection. */ public void beginTransaction(TransactionOptions options) { try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { @@ -84,6 +112,25 @@ public Rows execute(ExecuteSqlRequest request) { } } + /** + * Executes the given batch of DML or DDL statements on this connection. The statements must all + * be of the same type. + */ + public ExecuteBatchDmlResponse executeBatch(ExecuteBatchDmlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.ExecuteBatch( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ExecuteBatchDmlResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 1ad3f14e..7da5a013 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -65,6 +65,15 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + /** + * Writes a group of mutations on Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner in a new read/write transaction. Returns a {@link + * com.google.spanner.v1.CommitResponse} if the mutations were written directly to Spanner, and an + * empty message if the mutations were only buffered in the current transaction. + */ + Message WriteMutations(long poolId, long connectionId, GoBytes mutations); + /** Starts a new transaction on the given Connection. */ Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); @@ -80,6 +89,12 @@ default MessageHandler execute(Function function) /** Executes a SQL statement on the given Connection. */ Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); + /** + * Executes a batch of DML or DDL statements on the given Connection. Returns an {@link + * com.google.spanner.v1.ExecuteBatchDmlResponse} for both DML and DDL batches. + */ + Message ExecuteBatch(long poolId, long connectionId, GoBytes executeBatchDmlRequest); + /** Returns the {@link com.google.spanner.v1.ResultSetMetadata} of the given Rows object. */ Message Metadata(long poolId, long connectionId, long rowsId); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java new file mode 100644 index 00000000..9740d541 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest.Statement; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.List; +import org.junit.Test; + +public class BatchTest extends AbstractMockServerTest { + + @Test + public void testBatchDml() { + String insert = "insert into test (id, value) values (@id, @value)"; + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(1L) + .bind("value") + .to("One") + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(2L) + .bind("value") + .to("Two") + .build(), + 1L)); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", Value.newBuilder().setStringValue("One").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("2").build()) + .putFields( + "value", Value.newBuilder().setStringValue("Two").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertEquals(1L, response.getResultSets(0).getStats().getRowCountExact()); + assertEquals(1L, response.getResultSets(1).getStats().getRowCountExact()); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testBatchDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql("create table my_table (id int64 primary key, value string(max))") + .build()) + .addStatements( + Statement.newBuilder() + .setSql("create index my_index on my_table (value)") + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertFalse(response.getResultSets(0).getStats().hasRowCountExact()); + } + + List requests = mockDatabaseAdmin.getRequests(); + assertEquals(1, requests.size()); + UpdateDatabaseDdlRequest request = (UpdateDatabaseDdlRequest) requests.get(0); + assertEquals(2, request.getStatementsCount()); + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java index 5415b867..8efc22b5 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java @@ -18,10 +18,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.rpc.Code; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.Write; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; import org.junit.Test; public class ConnectionTest extends AbstractMockServerTest { @@ -53,4 +68,138 @@ public void testCreateTwoConnections() { } } } + + @Test + public void testWriteMutations() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues( + Value.newBuilder().setStringValue("Two").build()) + .build()) + .build()) + .build()) + .addMutations( + Mutation.newBuilder() + .setInsertOrUpdate( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("0").build()) + .addValues( + Value.newBuilder().setStringValue("Zero").build()) + .build()) + .build()) + .build()) + .build()); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(2, request.getMutationsCount()); + assertEquals(2, request.getMutations(0).getInsert().getValuesCount()); + assertEquals(1, request.getMutations(1).getInsertOrUpdate().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .build()) + .build()) + .build()); + // The mutations are only buffered in the current transaction, so there should be no response. + assertNull(response); + + // Committing the transaction should return a CommitResponse. + response = connection.commit(); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(1, request.getMutationsCount()); + assertEquals(1, request.getMutations(0).getInsert().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInReadOnlyTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setStringValue("1") + .build()) + .addValues( + Value.newBuilder() + .setStringValue("One") + .build()) + .build()) + .build()) + .build()) + .build())); + assertEquals(Code.FAILED_PRECONDITION.getNumber(), exception.getStatus().getCode()); + } + } }