diff --git a/conn.go b/conn.go index 7f9ece77..698c8274 100644 --- a/conn.go +++ b/conn.go @@ -222,6 +222,9 @@ type SpannerConn interface { // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + // DetectStatementType returns the type of SQL statement. + DetectStatementType(query string) parser.StatementType + // resetTransactionForRetry resets the current transaction after it has // been aborted by Spanner. Calling this function on a transaction that // has not been aborted is not supported and will cause an error to be @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) { return c.client, nil } +func (c *conn) DetectStatementType(query string) parser.StatementType { + info := c.parser.DetectStatementType(query) + return info.StatementType +} + func (c *conn) CommitTimestamp() (time.Time, error) { ts := propertyCommitTimestamp.GetValueOrDefault(c.state) if ts == nil { @@ -675,6 +683,27 @@ func sum(affected []int64) int64 { return sum } +// WriteMutations is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction, +// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction. +// +// The function returns an error if the connection currently has a read-only transaction. +// +// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be +// applied to Spanner when the transaction commits. +func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) { + if c.inTransaction() { + return nil, c.BufferWrite(ms) + } + ts, err := c.Apply(ctx, ms) + if err != nil { + return nil, err + } + return &spanner.CommitResponse{CommitTs: ts}, nil +} + func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) { if c.inTransaction() { return time.Time{}, spanner.ToSpannerError( @@ -1071,6 +1100,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()} } +// BeginReadOnlyTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadOnlyTransaction starts a new read-only transaction on this connection. +func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) { + c.withTempReadOnlyTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true}) + if err != nil { + c.withTempReadOnlyTransactionOptions(nil) + return nil, err + } + return tx, nil +} + +// BeginReadWriteTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadWriteTransaction starts a new read/write transaction on this connection. +func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) { + c.withTempTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{}) + if err != nil { + c.withTempTransactionOptions(nil) + return nil, err + } + return tx, nil +} + func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } @@ -1254,7 +1311,11 @@ func (c *conn) inReadWriteTransaction() bool { return false } -func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { +// Commit is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Commit commits the current transaction on this connection. +func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) { if !c.inTransaction() { return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } @@ -1262,10 +1323,17 @@ func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { if err := c.tx.Commit(); err != nil { return nil, err } - return c.CommitResponse() + + // This will return either the commit response or nil, depending on whether the transaction was a + // read/write transaction or a read-only transaction. + return propertyCommitResponse.GetValueOrDefault(c.state), nil } -func (c *conn) rollback(ctx context.Context) error { +// Rollback is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Rollback rollbacks the current transaction on this connection. +func (c *conn) Rollback(ctx context.Context) error { if !c.inTransaction() { return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 848ee24a..5ceba7ed 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) { } } +func TestTwoQueriesOnOneConn(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + c, _ := db.Conn(ctx) + defer silentClose(c) + + for range 2 { + r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + _ = r.Next() + defer silentClose(r) + } +} + func TestExplicitBeginTx(t *testing.T) { t.Parallel() diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go new file mode 100644 index 00000000..1ae9af14 --- /dev/null +++ b/spannerlib/api/batch_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestExecuteDmlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DML batch. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + for i, result := range resp.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DDL batch. This also uses a DML batch request. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + // The response should contain an 'update count' per DDL statement. + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + // There is no update count for DDL statements. + for i, result := range resp.ResultSets { + emptyStats := &spannerpb.ResultSetStats{} + if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + // There should also be no ExecuteBatchDml requests. + batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchDmlRequests), 0; g != w { + t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + + adminRequests := server.TestDatabaseAdmin.Reqs() + if g, w := len(adminRequests), 1; g != w { + t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w) + } + ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest) + if g, w := len(ddlRequest.Statements), 2; g != w { + t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteMixedBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with mixed DML and DDL statements. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "update my_table set value = 100 where true"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatchInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to execute a DDL batch in a transaction. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 66a9e6d4..907212ec 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -17,6 +17,7 @@ package api import ( "context" "database/sql" + "database/sql/driver" "fmt" "strings" "sync" @@ -25,6 +26,10 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/parser" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) // CloseConnection looks up the connection with the given poolId and connId and closes it. @@ -42,6 +47,51 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.writeMutations(ctx, mutations) +} + +// BeginTransaction starts a new transaction on the given connection. +// A connection can have at most one transaction at any time. This function therefore returns an error if the +// connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.BeginTransaction(ctx, txOpts) +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.commit(ctx) +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.rollback(ctx) +} + func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { conn, err := findConnection(poolId, connId) if err != nil { @@ -50,6 +100,14 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spann return conn.Execute(ctx, executeSqlRequest) } +func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.ExecuteBatch(ctx, statements.Statements) +} + type Connection struct { // results contains the open query results for this connection. results *sync.Map @@ -59,6 +117,16 @@ type Connection struct { backend *sql.Conn } +// spannerConn is an internal interface that contains the internal functions that are used by this API. +// It is implemented by the spannerdriver.conn struct. +type spannerConn interface { + WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) + BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) + BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) + Commit(ctx context.Context) (*spanner.CommitResponse, error) + Rollback(ctx context.Context) error +} + type queryExecutor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) @@ -66,6 +134,9 @@ type queryExecutor interface { func (conn *Connection) close(ctx context.Context) error { conn.closeResults(ctx) + // Rollback any open transactions on the connection. + _ = conn.rollback(ctx) + err := conn.backend.Close() if err != nil { return err @@ -73,9 +144,143 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations)) + for _, m := range mutation.Mutations { + spannerMutation, err := spanner.WrapMutation(m) + if err != nil { + return nil, err + } + mutations = append(mutations, spannerMutation) + } + var commitResponse *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + commitResponse, err = sc.WriteMutations(ctx, mutations) + return err + }); err != nil { + return nil, err + } + + // The commit response is nil if the connection is currently in a transaction. + if commitResponse == nil { + return nil, nil + } + response := spannerpb.CommitResponse{ + CommitTimestamp: timestamppb.New(commitResponse.CommitTs), + } + return &response, nil +} + +func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { + var err error + if txOpts.GetReadOnly() != nil { + return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts)) + } else if txOpts.GetPartitionedDml() != nil { + err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported")) + } else { + return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts)) + } + if err != nil { + return err + } + return nil +} + +func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadOnlyTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadWriteTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) { + var response *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + response, err = spannerConn.Commit(ctx) + if err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + + // The commit response is nil for read-only transactions. + if response == nil { + return nil, nil + } + // TODO: Include commit stats + return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil +} + +func (conn *Connection) rollback(ctx context.Context) error { + return conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + return spannerConn.Rollback(ctx) + }) +} + +func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions { + return &spannerdriver.ReadOnlyTransactionOptions{ + TimestampBound: convertTimestampBound(txOpts), + } +} + +func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.TimestampBound { + ro := txOpts.GetReadOnly() + if ro.GetStrong() { + return spanner.StrongRead() + } else if ro.GetReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetReadTimestamp().AsTime()) + } else if ro.GetMinReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetMinReadTimestamp().AsTime()) + } else if ro.GetExactStaleness() != nil { + return spanner.ExactStaleness(ro.GetExactStaleness().AsDuration()) + } else if ro.GetMaxStaleness() != nil { + return spanner.MaxStaleness(ro.GetMaxStaleness().AsDuration()) + } + return spanner.TimestampBound{} +} + +func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions { + readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED + if txOpts.GetReadWrite() != nil { + readLockMode = txOpts.GetReadWrite().GetReadLockMode() + } + return &spannerdriver.ReadWriteTransactionOptions{ + TransactionOptions: spanner.TransactionOptions{ + IsolationLevel: txOpts.GetIsolationLevel(), + ReadLockMode: readLockMode, + }, + } +} + +func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel { + switch level { + case spannerpb.TransactionOptions_SERIALIZABLE: + return sql.LevelSerializable + case spannerpb.TransactionOptions_REPEATABLE_READ: + return sql.LevelRepeatableRead + } + return sql.LevelDefault +} + func (conn *Connection) closeResults(ctx context.Context) { conn.results.Range(func(key, value interface{}) bool { - // TODO: Implement + if r, ok := value.(*rows); ok { + _ = r.Close(ctx) + } return true }) } @@ -84,6 +289,10 @@ func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.Execut return execute(ctx, conn, conn.backend, statement) } +func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + return executeBatch(ctx, conn, conn.backend, statements) +} + func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { params := extractParams(statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) @@ -115,6 +324,90 @@ func execute(ctx context.Context, conn *Connection, executor queryExecutor, stat return id, nil } +func executeBatch(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + // Determine the type of batch that should be executed based on the type of statements. + batchType, err := determineBatchType(conn, statements) + if err != nil { + return nil, err + } + switch batchType { + case parser.BatchTypeDml: + return executeBatchDml(ctx, conn, executor, statements) + case parser.BatchTypeDdl: + return executeBatchDdl(ctx, conn, executor, statements) + default: + return nil, status.Errorf(codes.InvalidArgument, "unsupported batch type: %v", batchType) + } +} + +func executeBatchDdl(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDDL() + }); err != nil { + return nil, err + } + for _, statement := range statements { + _, err := executor.ExecContext(ctx, statement.Sql) + if err != nil { + return nil, err + } + } + // TODO: Add support for getting the actual Batch DDL response. + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.RunBatch(ctx) + }); err != nil { + return nil, err + } + + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(statements)) + for i := range statements { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{}} + } + return &response, nil +} + +func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + return nil, err + } + for _, statement := range statements { + request := &spannerpb.ExecuteSqlRequest{ + Sql: statement.Sql, + Params: statement.Params, + ParamTypes: statement.ParamTypes, + } + params := extractParams(request) + _, err := executor.ExecContext(ctx, statement.Sql, params...) + if err != nil { + return nil, err + } + } + var spannerResult spannerdriver.SpannerResult + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + spannerResult, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + return nil, err + } + affected, err := spannerResult.BatchRowsAffected() + if err != nil { + return nil, err + } + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(affected)) + for i, aff := range affected { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: aff}}} + } + return &response, nil +} + func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { @@ -149,3 +442,34 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { } return params } + +func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (parser.BatchType, error) { + if len(statements) == 0 { + return parser.BatchTypeDdl, status.Errorf(codes.InvalidArgument, "cannot determine type of an empty batch") + } + var batchType parser.BatchType + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + firstStatementType := spannerConn.DetectStatementType(statements[0].Sql) + if firstStatementType == parser.StatementTypeDml { + batchType = parser.BatchTypeDml + } else if firstStatementType == parser.StatementTypeDdl { + batchType = parser.BatchTypeDdl + } else { + return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + } + for i, statement := range statements { + if i > 0 { + tp := spannerConn.DetectStatementType(statement.Sql) + if tp != firstStatementType { + return status.Errorf(codes.InvalidArgument, "Batches may not contain different types of statements. The first statement is of type %v. The statement on position %d is of type %v.", firstStatementType, i, tp) + } + } + } + return nil + }); err != nil { + return parser.BatchTypeDdl, err + } + + return batchType, nil +} diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 64d19ab3..b4625bb3 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + {Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}}, + }, + }}}, + }} + resp, err := WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest := commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + // Write the same mutations in a transaction. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + resp, err = WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("WriteMutations returned unexpected response: %v", resp) + } + resp, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("Commit returned nil response") + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests = server.TestSpanner.DrainRequestsFromServer() + beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest = commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestWriteMutationsInReadOnlyTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Start a read-only transaction and try to write mutations to that transaction. That should return an error. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}}, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + }, + }}}, + }} + _, err = WriteMutations(ctx, poolId, connId, mutations) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w) + } + + // Committing the read-only transaction should not lead to any commits on Spanner. + _, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + requests := server.TestSpanner.DrainRequestsFromServer() + // There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started + // by a query or other statement. + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/transaction_test.go b/spannerlib/api/transaction_test.go new file mode 100644 index 00000000..c629a2b4 --- /dev/null +++ b/spannerlib/api/transaction_test.go @@ -0,0 +1,411 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" +) + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used the transaction, and that the transaction was started using an inlined begin + // option on the ExecuteSqlRequest. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadWrite() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Rollback the transaction. + if err := Rollback(ctx, poolId, connId); err != nil { + t.Fatalf("Rollback returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCommitWithOpenRows(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Try to commit the transaction without closing the Rows object that was returned during the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCloseConnectionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + // Execute a statement in the transaction to activate it. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Close the connection while a transaction is still active. + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back when the connection was closed. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTransactionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to start a transaction when one is already active. + err = BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestCommitWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + _, err = Commit(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestRollbackWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + err = Rollback(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestReadOnlyTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}, + }, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used a read-only transaction, that the transaction was started using an inlined + // begin option on the ExecuteSqlRequest, and that the Commit call was a no-op. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadOnly() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestDdlInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a DDL statement in the transaction. This should fail. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: "create table my_table (id int64 primary key)"}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index 8834b92c..75c6bc3a 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -34,6 +34,69 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { return &Message{} } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message { + mutations := spannerpb.BatchWriteRequest_MutationGroup{} + if err := proto.Unmarshal(mutationBytes, &mutations); err != nil { + return errMessage(err) + } + response, err := api.WriteMutations(ctx, poolId, connId, &mutations) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + +// BeginTransaction starts a new transaction on the given connection. A connection can have at most one active +// transaction at any time. This function therefore returns an error if the connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { + txOpts := spannerpb.TransactionOptions{} + if err := proto.Unmarshal(txOptsBytes, &txOpts); err != nil { + return errMessage(err) + } + err := api.BeginTransaction(ctx, poolId, connId, &txOpts) + if err != nil { + return errMessage(err) + } + return &Message{} +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) *Message { + response, err := api.Commit(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + if response == nil { + return &Message{} + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) *Message { + err := api.Rollback(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + return &Message{} +} + func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes []byte) (msg *Message) { defer func() { if r := recover(); r != nil { @@ -50,3 +113,19 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes [ } return idMessage(id) } + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statementsBytes []byte) *Message { + statements := spannerpb.ExecuteBatchDmlRequest{} + if err := proto.Unmarshal(statementsBytes, &statements); err != nil { + return errMessage(err) + } + response, err := api.ExecuteBatch(ctx, poolId, connId, &statements) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index d71da8c7..193e3fd9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -116,3 +117,208 @@ func TestExecute(t *testing.T) { t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := ExecuteBatch(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("ExecuteBatch result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.Length() == 0 { + t.Fatal("ExecuteBatch returned no data") + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + commitMsg := Commit(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := commitMsg.Code, int32(0); g != w { + t.Fatalf("Commit result mismatch\n Got: %v\nWant: %v", g, w) + } + if commitMsg.Length() == 0 { + t.Fatal("Commit return zero length") + } + resp := &spannerpb.CommitResponse{} + if err := proto.Unmarshal(commitMsg.Res, resp); err != nil { + t.Fatalf("Failed to unmarshal commit response: %v", err) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + rollbackMsg := Rollback(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := rollbackMsg.Code, int32(0); g != w { + t.Fatalf("Rollback result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rollbackMsg.Length(), int32(0); g != w { + t.Fatalf("Rollback length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if mutationsMsg.Length() == 0 { + t.Fatal("WriteMutations returned no data") + } + + // Write mutations in a transaction. + mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + // The response should now be an empty message, as the mutations were only buffered in the transaction. + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := mutationsMsg.Length(), int32(0); g != w { + t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index 57a10e18..8645f219 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -111,6 +111,22 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P return pin(msg) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +// +//export WriteMutations +func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes) + return pin(msg) +} + // Execute executes a SQL statement on the given connection. // The return type is an identifier for a Rows object. This identifier can be used to // call the functions Metadata and Next to get respectively the metadata of the result @@ -123,6 +139,18 @@ func Execute(poolId, connectionId int64, statement []byte) (int64, int32, int64, return pin(msg) } +// ExecuteBatch executes a batch of statements on the given connection. The statements must all be either DML or DDL +// statements. Mixing DML and DDL in a batch is not supported. Executing queries in a batch is also not supported. +// The batch will use the current transaction on the given connection, or execute as a single auto-commit statement +// if the connection does not have a transaction. +// +//export ExecuteBatch +func ExecuteBatch(poolId, connectionId int64, statements []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ExecuteBatch(ctx, poolId, connectionId, statements) + return pin(msg) +} + // Metadata returns the metadata of a Rows object. // //export Metadata @@ -166,3 +194,40 @@ func CloseRows(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe msg := lib.CloseRows(ctx, poolId, connId, rowsId) return pin(msg) } + +// BeginTransaction begins a new transaction on the given connection. +// The txOpts byte slice contains a serialized protobuf TransactionOptions object. +// +//export BeginTransaction +func BeginTransaction(poolId, connectionId int64, txOpts []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.BeginTransaction(ctx, poolId, connectionId, txOpts) + return pin(msg) +} + +// Commit commits the current transaction on a connection. All transactions must be +// either committed or rolled back, including read-only transactions. This to ensure +// that all resources that are held by a transaction are cleaned up. +// +//export Commit +func Commit(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Commit(ctx, poolId, connectionId) + return pin(msg) +} + +// Rollback rolls back a previously started transaction. All transactions must be either +// committed or rolled back, including read-only transactions. This to ensure that +// all resources that are held by a transaction are cleaned up. +// +// Spanner does not require read-only transactions to be committed or rolled back, but +// this library requires that all transactions are committed or rolled back to clean up +// all resources. Commit and Rollback are semantically the same for read-only transactions +// on Spanner, and both functions just close the transaction. +// +//export Rollback +func Rollback(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Rollback(ctx, poolId, connectionId) + return pin(msg) +} diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 9c2e3087..b6d63467 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -100,7 +100,7 @@ func TestCreateConnection(t *testing.T) { } func TestExecute(t *testing.T) { - t.Parallel() + // This test is intentionally not marked as Parallel, as it checks the number of open memory pointers. server, teardown := setupMockServer(t) defer teardown() @@ -245,6 +245,279 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // ExecuteBatch returns a ExecuteBatchDml response. + mem, code, batchId, length, data := ExecuteBatch(poolId, connId, requestBytes) + verifyDataMessage(t, "ExecuteBatch", mem, code, batchId, length, data) + response := &spannerpb.ExecuteBatchDmlResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if g, w := len(response.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %v\nWant: %v", g, w) + } + for i, result := range response.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndCommitTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Commit returns the CommitResponse (if any). + mem, code, id, length, res = Commit(poolId, connId) + verifyDataMessage(t, "Commit", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollbackTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Rollback returns nothing. + mem, code, id, length, res = Rollback(poolId, connId) + verifyEmptyMessage(t, "Rollback", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + // WriteMutations returns a CommitResponse or nil, depending on whether the connection has an active transaction. + mem, code, id, length, data := WriteMutations(poolId, connId, mutationBytes) + verifyDataMessage(t, "WriteMutations", mem, code, id, length, data) + + response := &spannerpb.CommitResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if response.CommitTimestamp == nil { + t.Fatal("CommitTimestamp is nil") + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Start a transaction on the connection and write the mutations to that transaction. + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + _, code, _, _, _ = BeginTransaction(poolId, connId, txOptsBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + mem, code, id, length, data = WriteMutations(poolId, connId, mutationBytes) + // The response should now be an empty message, as the mutations were buffered in the current transaction. + verifyEmptyMessage(t, "WriteMutations in tx", mem, code, id, length, data) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := mem, int64(0); g != w { + t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := length, int32(0); g != w { + t.Fatalf("%s: length mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := res, unsafe.Pointer(nil); g != w { + t.Fatalf("%s: ptr mismatch\n Got: %v\nWant: %v", name, g, w) + } +} + +// verifyDataMessage verifies that the result contains a data message. +func verifyDataMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if mem == int64(0) { + t.Fatalf("%s: No memory identifier returned", name) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if length == int32(0) { + t.Fatalf("%s: zero length returned", name) + } + if res == unsafe.Pointer(nil) { + t.Fatalf("%s: nil pointer returned", name) + } +} + func countOpenMemoryPointers() (c int) { pinners.Range(func(key, value any) bool { c++ diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index 6e11fd14..18a9ca50 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -20,7 +20,14 @@ import com.google.cloud.spannerlib.internal.MessageHandler; import com.google.cloud.spannerlib.internal.WrappedGoBytes; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.TransactionOptions; +import java.nio.ByteBuffer; /** A {@link Connection} that has been created by SpannerLib. */ public class Connection extends AbstractLibraryObject { @@ -35,6 +42,64 @@ public Pool getPool() { return this.pool; } + /** + * Writes a group of mutations to Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner using a new read/write transaction. Returns a {@link + * CommitResponse} if the mutations were written directly to Spanner, and otherwise null if the + * mutations were buffered in the current transaction. + */ + public CommitResponse WriteMutations(MutationGroup mutations) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(mutations); + MessageHandler message = + getLibrary() + .execute( + library -> + library.WriteMutations( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Starts a transaction on this connection. */ + public void beginTransaction(TransactionOptions options) { + try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { + executeAndRelease( + getLibrary(), + library -> + library.BeginTransaction(pool.getId(), getId(), serializedOptions.getGoBytes())); + } + } + + /** + * Commits the current transaction on this connection and returns the {@link CommitResponse} or + * null if there is no {@link CommitResponse} (e.g. for read-only transactions). + */ + public CommitResponse commit() { + try (MessageHandler message = + getLibrary().execute(library -> library.Commit(pool.getId(), getId()))) { + // Return null in case there is no CommitResponse. + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Rollbacks the current transaction on this connection. */ + public void rollback() { + executeAndRelease(getLibrary(), library -> library.Rollback(pool.getId(), getId())); + } + /** Executes the given SQL statement on this connection. */ public Rows execute(ExecuteSqlRequest request) { try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); @@ -47,6 +112,26 @@ public Rows execute(ExecuteSqlRequest request) { } } + /** + * Executes the given batch of DML or DDL statements on this connection. The statements must all + * be of the same type. + */ + public ExecuteBatchDmlResponse executeBatch(ExecuteBatchDmlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.ExecuteBatch( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ExecuteBatchDmlResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { executeAndRelease(getLibrary(), library -> library.CloseConnection(pool.getId(), getId())); diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 73394581..7da5a013 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -65,9 +65,36 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + /** + * Writes a group of mutations on Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner in a new read/write transaction. Returns a {@link + * com.google.spanner.v1.CommitResponse} if the mutations were written directly to Spanner, and an + * empty message if the mutations were only buffered in the current transaction. + */ + Message WriteMutations(long poolId, long connectionId, GoBytes mutations); + + /** Starts a new transaction on the given Connection. */ + Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); + + /** + * Commits the current transaction on the given Connection and returns a {@link + * com.google.spanner.v1.CommitResponse}. + */ + Message Commit(long poolId, long connectionId); + + /** Rollbacks the current transaction on the given Connection. */ + Message Rollback(long poolId, long connectionId); + /** Executes a SQL statement on the given Connection. */ Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); + /** + * Executes a batch of DML or DDL statements on the given Connection. Returns an {@link + * com.google.spanner.v1.ExecuteBatchDmlResponse} for both DML and DDL batches. + */ + Message ExecuteBatch(long poolId, long connectionId, GoBytes executeBatchDmlRequest); + /** Returns the {@link com.google.spanner.v1.ResultSetMetadata} of the given Rows object. */ Message Metadata(long poolId, long connectionId, long rowsId); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java new file mode 100644 index 00000000..9740d541 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest.Statement; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.List; +import org.junit.Test; + +public class BatchTest extends AbstractMockServerTest { + + @Test + public void testBatchDml() { + String insert = "insert into test (id, value) values (@id, @value)"; + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(1L) + .bind("value") + .to("One") + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(2L) + .bind("value") + .to("Two") + .build(), + 1L)); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", Value.newBuilder().setStringValue("One").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("2").build()) + .putFields( + "value", Value.newBuilder().setStringValue("Two").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertEquals(1L, response.getResultSets(0).getStats().getRowCountExact()); + assertEquals(1L, response.getResultSets(1).getStats().getRowCountExact()); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testBatchDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql("create table my_table (id int64 primary key, value string(max))") + .build()) + .addStatements( + Statement.newBuilder() + .setSql("create index my_index on my_table (value)") + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertFalse(response.getResultSets(0).getStats().hasRowCountExact()); + } + + List requests = mockDatabaseAdmin.getRequests(); + assertEquals(1, requests.size()); + UpdateDatabaseDdlRequest request = (UpdateDatabaseDdlRequest) requests.get(0); + assertEquals(2, request.getStatementsCount()); + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java index 5415b867..8efc22b5 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java @@ -18,10 +18,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.rpc.Code; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.Write; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; import org.junit.Test; public class ConnectionTest extends AbstractMockServerTest { @@ -53,4 +68,138 @@ public void testCreateTwoConnections() { } } } + + @Test + public void testWriteMutations() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues( + Value.newBuilder().setStringValue("Two").build()) + .build()) + .build()) + .build()) + .addMutations( + Mutation.newBuilder() + .setInsertOrUpdate( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("0").build()) + .addValues( + Value.newBuilder().setStringValue("Zero").build()) + .build()) + .build()) + .build()) + .build()); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(2, request.getMutationsCount()); + assertEquals(2, request.getMutations(0).getInsert().getValuesCount()); + assertEquals(1, request.getMutations(1).getInsertOrUpdate().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .build()) + .build()) + .build()); + // The mutations are only buffered in the current transaction, so there should be no response. + assertNull(response); + + // Committing the transaction should return a CommitResponse. + response = connection.commit(); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(1, request.getMutationsCount()); + assertEquals(1, request.getMutations(0).getInsert().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInReadOnlyTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setStringValue("1") + .build()) + .addValues( + Value.newBuilder() + .setStringValue("One") + .build()) + .build()) + .build()) + .build()) + .build())); + assertEquals(Code.FAILED_PRECONDITION.getNumber(), exception.getStatus().getCode()); + } + } } diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java new file mode 100644 index 00000000..f2eef624 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import io.grpc.Status.Code; +import org.junit.Test; + +public class TransactionTest extends AbstractMockServerTest { + @Test + public void testBeginAndCommit() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.commit(); + + // TODO: The library should take a shortcut and just skip committing empty transactions. + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testBeginAndRollback() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.rollback(); + + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + } + + @Test + public void testReadWriteTransaction() { + String updateSql = "update my_table set my_val=@value where id=@id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(updateSql).bind("value").to("foo").bind("id").to(1L).build(), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql(updateSql) + .setParams( + Struct.newBuilder() + .putFields("value", Value.newBuilder().setStringValue("foo").build()) + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "value", Type.newBuilder().setCode(TypeCode.STRING).build(), + "id", Type.newBuilder().setCode(TypeCode.INT64).build())) + .build()); + connection.commit(); + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testReadOnlyTransaction() { + String sql = "select * from random"; + int numRows = 5; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), generator.generate())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql(sql).build())) { + int rowCount = 0; + while (rows.next() != null) { + rowCount++; + } + assertEquals(numRows, rowCount); + } + connection.commit(); + } + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadOnly()); + // There should be no CommitRequests on the server, as committing a read-only transaction is a + // no-op on Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testBeginTwice() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // Try to start two transactions on a connection. + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> connection.beginTransaction(TransactionOptions.getDefaultInstance())); + assertEquals(Code.FAILED_PRECONDITION.value(), exception.getStatus().getCode()); + } + } +} diff --git a/statements.go b/statements.go index 6920e7f2..f7232a7b 100644 --- a/statements.go +++ b/statements.go @@ -286,7 +286,7 @@ type executableCommitStatement struct { } func (s *executableCommitStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - _, err := c.commit(ctx) + _, err := c.Commit(ctx) if err != nil { return nil, err } @@ -305,7 +305,7 @@ type executableRollbackStatement struct { } func (s *executableRollbackStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.rollback(ctx); err != nil { + if err := c.Rollback(ctx); err != nil { return nil, err } return driver.ResultNoRows, nil diff --git a/transaction.go b/transaction.go index dcc61d1b..5f93f94e 100644 --- a/transaction.go +++ b/transaction.go @@ -407,6 +407,7 @@ func (tx *readWriteTransaction) Commit() (err error) { return err } var commitResponse spanner.CommitResponse + // TODO: Optimize this to skip the Commit also if the transaction has not yet been used. if tx.rwTx != nil { if !tx.retryAborts() { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx)