From 860b1cc99dfe5a7d1a08d9f65ab5e365eec18c68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 15 Sep 2025 14:46:23 +0200 Subject: [PATCH 1/2] chore: add transaction support for SpannerLib --- conn.go | 45 +- conn_with_mockserver_test.go | 20 + spannerlib/api/connection.go | 153 ++++++- spannerlib/api/transaction_test.go | 411 ++++++++++++++++++ spannerlib/lib/connection.go | 39 ++ spannerlib/lib/connection_test.go | 104 +++++ spannerlib/shared/shared_lib.go | 37 ++ spannerlib/shared/shared_lib_test.go | 149 ++++++- .../google/cloud/spannerlib/Connection.java | 38 ++ .../spannerlib/internal/SpannerLibrary.java | 12 + .../cloud/spannerlib/TransactionTest.java | 178 ++++++++ statements.go | 4 +- transaction.go | 1 + 13 files changed, 1184 insertions(+), 7 deletions(-) create mode 100644 spannerlib/api/transaction_test.go create mode 100644 spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java diff --git a/conn.go b/conn.go index 7f9ece77..9ee4efb1 100644 --- a/conn.go +++ b/conn.go @@ -1071,6 +1071,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()} } +// BeginReadOnlyTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadOnlyTransaction starts a new read-only transaction on this connection. +func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) { + c.withTempReadOnlyTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true}) + if err != nil { + c.withTempReadOnlyTransactionOptions(nil) + return nil, err + } + return tx, nil +} + +// BeginReadWriteTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadWriteTransaction starts a new read/write transaction on this connection. +func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) { + c.withTempTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{}) + if err != nil { + c.withTempTransactionOptions(nil) + return nil, err + } + return tx, nil +} + func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } @@ -1254,7 +1282,11 @@ func (c *conn) inReadWriteTransaction() bool { return false } -func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { +// Commit is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Commit commits the current transaction on this connection. +func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) { if !c.inTransaction() { return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } @@ -1262,10 +1294,17 @@ func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { if err := c.tx.Commit(); err != nil { return nil, err } - return c.CommitResponse() + + // This will return either the commit response or nil, depending on whether the transaction was a + // read/write transaction or a read-only transaction. + return propertyCommitResponse.GetValueOrDefault(c.state), nil } -func (c *conn) rollback(ctx context.Context) error { +// Rollback is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Rollback rollbacks the current transaction on this connection. +func (c *conn) Rollback(ctx context.Context) error { if !c.inTransaction() { return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 848ee24a..5ceba7ed 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) { } } +func TestTwoQueriesOnOneConn(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + c, _ := db.Conn(ctx) + defer silentClose(c) + + for range 2 { + r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + _ = r.Next() + defer silentClose(r) + } +} + func TestExplicitBeginTx(t *testing.T) { t.Parallel() diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 66a9e6d4..f5f7643a 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -17,6 +17,7 @@ package api import ( "context" "database/sql" + "database/sql/driver" "fmt" "strings" "sync" @@ -25,6 +26,9 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" spannerdriver "github.com/googleapis/go-sql-spanner" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) // CloseConnection looks up the connection with the given poolId and connId and closes it. @@ -42,6 +46,35 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// BeginTransaction starts a new transaction on the given connection. +// A connection can have at most one transaction at any time. This function therefore returns an error if the +// connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.BeginTransaction(ctx, txOpts) +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.commit(ctx) +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.rollback(ctx) +} + func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { conn, err := findConnection(poolId, connId) if err != nil { @@ -59,6 +92,15 @@ type Connection struct { backend *sql.Conn } +// spannerConn is an internal interface that contains the internal functions that are used by this API. +// It is implemented by the spannerdriver.conn struct. +type spannerConn interface { + BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) + BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) + Commit(ctx context.Context) (*spanner.CommitResponse, error) + Rollback(ctx context.Context) error +} + type queryExecutor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) @@ -66,6 +108,9 @@ type queryExecutor interface { func (conn *Connection) close(ctx context.Context) error { conn.closeResults(ctx) + // Rollback any open transactions on the connection. + _ = conn.rollback(ctx) + err := conn.backend.Close() if err != nil { return err @@ -73,9 +118,115 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { + var err error + if txOpts.GetReadOnly() != nil { + return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts)) + } else if txOpts.GetPartitionedDml() != nil { + err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported")) + } else { + return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts)) + } + if err != nil { + return err + } + return nil +} + +func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadOnlyTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadWriteTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) { + var response *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + response, err = spannerConn.Commit(ctx) + if err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + + // The commit response is nil for read-only transactions. + if response == nil { + return nil, nil + } + // TODO: Include commit stats + return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil +} + +func (conn *Connection) rollback(ctx context.Context) error { + return conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + return spannerConn.Rollback(ctx) + }) +} + +func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions { + return &spannerdriver.ReadOnlyTransactionOptions{ + TimestampBound: convertTimestampBound(txOpts), + } +} + +func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.TimestampBound { + ro := txOpts.GetReadOnly() + if ro.GetStrong() { + return spanner.StrongRead() + } else if ro.GetReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetReadTimestamp().AsTime()) + } else if ro.GetMinReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetMinReadTimestamp().AsTime()) + } else if ro.GetExactStaleness() != nil { + return spanner.ExactStaleness(ro.GetExactStaleness().AsDuration()) + } else if ro.GetMaxStaleness() != nil { + return spanner.MaxStaleness(ro.GetMaxStaleness().AsDuration()) + } + return spanner.TimestampBound{} +} + +func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions { + readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED + if txOpts.GetReadWrite() != nil { + readLockMode = txOpts.GetReadWrite().GetReadLockMode() + } + return &spannerdriver.ReadWriteTransactionOptions{ + TransactionOptions: spanner.TransactionOptions{ + IsolationLevel: txOpts.GetIsolationLevel(), + ReadLockMode: readLockMode, + }, + } +} + +func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel { + switch level { + case spannerpb.TransactionOptions_SERIALIZABLE: + return sql.LevelSerializable + case spannerpb.TransactionOptions_REPEATABLE_READ: + return sql.LevelRepeatableRead + } + return sql.LevelDefault +} + func (conn *Connection) closeResults(ctx context.Context) { conn.results.Range(func(key, value interface{}) bool { - // TODO: Implement + if r, ok := value.(*rows); ok { + _ = r.Close(ctx) + } return true }) } diff --git a/spannerlib/api/transaction_test.go b/spannerlib/api/transaction_test.go new file mode 100644 index 00000000..c629a2b4 --- /dev/null +++ b/spannerlib/api/transaction_test.go @@ -0,0 +1,411 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" +) + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used the transaction, and that the transaction was started using an inlined begin + // option on the ExecuteSqlRequest. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadWrite() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Rollback the transaction. + if err := Rollback(ctx, poolId, connId); err != nil { + t.Fatalf("Rollback returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCommitWithOpenRows(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Try to commit the transaction without closing the Rows object that was returned during the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCloseConnectionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + // Execute a statement in the transaction to activate it. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Close the connection while a transaction is still active. + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back when the connection was closed. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTransactionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to start a transaction when one is already active. + err = BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestCommitWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + _, err = Commit(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestRollbackWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + err = Rollback(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestReadOnlyTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}, + }, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used a read-only transaction, that the transaction was started using an inlined + // begin option on the ExecuteSqlRequest, and that the Commit call was a no-op. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadOnly() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestDdlInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a DDL statement in the transaction. This should fail. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: "create table my_table (id int64 primary key)"}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index 8834b92c..d41b6a6e 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -34,6 +34,45 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { return &Message{} } +// BeginTransaction starts a new transaction on the given connection. A connection can have at most one active +// transaction at any time. This function therefore returns an error if the connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { + txOpts := spannerpb.TransactionOptions{} + if err := proto.Unmarshal(txOptsBytes, &txOpts); err != nil { + return errMessage(err) + } + err := api.BeginTransaction(ctx, poolId, connId, &txOpts) + if err != nil { + return errMessage(err) + } + return &Message{} +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) *Message { + response, err := api.Commit(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + if response == nil { + return &Message{} + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) *Message { + err := api.Rollback(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + return &Message{} +} + func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes []byte) (msg *Message) { defer func() { if r := recover(); r != nil { diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index d71da8c7..02ba453c 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -116,3 +116,107 @@ func TestExecute(t *testing.T) { t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + commitMsg := Commit(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := commitMsg.Code, int32(0); g != w { + t.Fatalf("Commit result mismatch\n Got: %v\nWant: %v", g, w) + } + if commitMsg.Length() == 0 { + t.Fatal("Commit return zero length") + } + resp := &spannerpb.CommitResponse{} + if err := proto.Unmarshal(commitMsg.Res, resp); err != nil { + t.Fatalf("Failed to unmarshal commit response: %v", err) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + rollbackMsg := Rollback(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := rollbackMsg.Code, int32(0); g != w { + t.Fatalf("Rollback result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rollbackMsg.Length(), int32(0); g != w { + t.Fatalf("Rollback length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index 57a10e18..e9c9be6a 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -166,3 +166,40 @@ func CloseRows(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe msg := lib.CloseRows(ctx, poolId, connId, rowsId) return pin(msg) } + +// BeginTransaction begins a new transaction on the given connection. +// The txOpts byte slice contains a serialized protobuf TransactionOptions object. +// +//export BeginTransaction +func BeginTransaction(poolId, connectionId int64, txOpts []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.BeginTransaction(ctx, poolId, connectionId, txOpts) + return pin(msg) +} + +// Commit commits the current transaction on a connection. All transactions must be +// either committed or rolled back, including read-only transactions. This to ensure +// that all resources that are held by a transaction are cleaned up. +// +//export Commit +func Commit(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Commit(ctx, poolId, connectionId) + return pin(msg) +} + +// Rollback rolls back a previously started transaction. All transactions must be either +// committed or rolled back, including read-only transactions. This to ensure that +// all resources that are held by a transaction are cleaned up. +// +// Spanner does not require read-only transactions to be committed or rolled back, but +// this library requires that all transactions are committed or rolled back to clean up +// all resources. Commit and Rollback are semantically the same for read-only transactions +// on Spanner, and both functions just close the transaction. +// +//export Rollback +func Rollback(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Rollback(ctx, poolId, connectionId) + return pin(msg) +} diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 9c2e3087..82d010bd 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -100,7 +100,7 @@ func TestCreateConnection(t *testing.T) { } func TestExecute(t *testing.T) { - t.Parallel() + // This test is intentionally not marked as Parallel, as it checks the number of open memory pointers. server, teardown := setupMockServer(t) defer teardown() @@ -245,6 +245,153 @@ func TestExecute(t *testing.T) { } } +func TestBeginAndCommitTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Commit returns the CommitResponse (if any). + mem, code, id, length, res = Commit(poolId, connId) + verifyDataMessage(t, "Commit", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollbackTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Rollback returns nothing. + mem, code, id, length, res = Rollback(poolId, connId) + verifyEmptyMessage(t, "Rollback", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := mem, int64(0); g != w { + t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := length, int32(0); g != w { + t.Fatalf("%s: length mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := res, unsafe.Pointer(nil); g != w { + t.Fatalf("%s: ptr mismatch\n Got: %v\nWant: %v", name, g, w) + } +} + +// verifyDataMessage verifies that the result contains a data message. +func verifyDataMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if mem == int64(0) { + t.Fatalf("%s: No memory identifier returned", name) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if length == int32(0) { + t.Fatalf("%s: zero length returned", name) + } + if res == unsafe.Pointer(nil) { + t.Fatalf("%s: nil pointer returned", name) + } +} + func countOpenMemoryPointers() (c int) { pinners.Range(func(key, value any) bool { c++ diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index 6e11fd14..90ac04f9 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -20,7 +20,11 @@ import com.google.cloud.spannerlib.internal.MessageHandler; import com.google.cloud.spannerlib.internal.WrappedGoBytes; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.TransactionOptions; +import java.nio.ByteBuffer; /** A {@link Connection} that has been created by SpannerLib. */ public class Connection extends AbstractLibraryObject { @@ -35,6 +39,39 @@ public Pool getPool() { return this.pool; } + /** Starts a transaction on this connection. */ + public void beginTransaction(TransactionOptions options) { + try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { + executeAndRelease( + getLibrary(), + library -> + library.BeginTransaction(pool.getId(), getId(), serializedOptions.getGoBytes())); + } + } + + /** + * Commits the current transaction on this connection and returns the {@link CommitResponse} or + * null if there is no {@link CommitResponse} (e.g. for read-only transactions). + */ + public CommitResponse commit() { + try (MessageHandler message = + getLibrary().execute(library -> library.Commit(pool.getId(), getId()))) { + // Return null in case there is no CommitResponse. + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Rollbacks the current transaction on this connection. */ + public void rollback() { + executeAndRelease(getLibrary(), library -> library.Rollback(pool.getId(), getId())); + } + /** Executes the given SQL statement on this connection. */ public Rows execute(ExecuteSqlRequest request) { try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); @@ -47,6 +84,7 @@ public Rows execute(ExecuteSqlRequest request) { } } + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { executeAndRelease(getLibrary(), library -> library.CloseConnection(pool.getId(), getId())); diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 73394581..1ad3f14e 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -65,6 +65,18 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + /** Starts a new transaction on the given Connection. */ + Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); + + /** + * Commits the current transaction on the given Connection and returns a {@link + * com.google.spanner.v1.CommitResponse}. + */ + Message Commit(long poolId, long connectionId); + + /** Rollbacks the current transaction on the given Connection. */ + Message Rollback(long poolId, long connectionId); + /** Executes a SQL statement on the given Connection. */ Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java new file mode 100644 index 00000000..f2eef624 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import io.grpc.Status.Code; +import org.junit.Test; + +public class TransactionTest extends AbstractMockServerTest { + @Test + public void testBeginAndCommit() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.commit(); + + // TODO: The library should take a shortcut and just skip committing empty transactions. + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testBeginAndRollback() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.rollback(); + + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + } + + @Test + public void testReadWriteTransaction() { + String updateSql = "update my_table set my_val=@value where id=@id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(updateSql).bind("value").to("foo").bind("id").to(1L).build(), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql(updateSql) + .setParams( + Struct.newBuilder() + .putFields("value", Value.newBuilder().setStringValue("foo").build()) + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "value", Type.newBuilder().setCode(TypeCode.STRING).build(), + "id", Type.newBuilder().setCode(TypeCode.INT64).build())) + .build()); + connection.commit(); + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testReadOnlyTransaction() { + String sql = "select * from random"; + int numRows = 5; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), generator.generate())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql(sql).build())) { + int rowCount = 0; + while (rows.next() != null) { + rowCount++; + } + assertEquals(numRows, rowCount); + } + connection.commit(); + } + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadOnly()); + // There should be no CommitRequests on the server, as committing a read-only transaction is a + // no-op on Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testBeginTwice() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // Try to start two transactions on a connection. + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> connection.beginTransaction(TransactionOptions.getDefaultInstance())); + assertEquals(Code.FAILED_PRECONDITION.value(), exception.getStatus().getCode()); + } + } +} diff --git a/statements.go b/statements.go index 6920e7f2..f7232a7b 100644 --- a/statements.go +++ b/statements.go @@ -286,7 +286,7 @@ type executableCommitStatement struct { } func (s *executableCommitStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - _, err := c.commit(ctx) + _, err := c.Commit(ctx) if err != nil { return nil, err } @@ -305,7 +305,7 @@ type executableRollbackStatement struct { } func (s *executableRollbackStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.rollback(ctx); err != nil { + if err := c.Rollback(ctx); err != nil { return nil, err } return driver.ResultNoRows, nil diff --git a/transaction.go b/transaction.go index dcc61d1b..5f93f94e 100644 --- a/transaction.go +++ b/transaction.go @@ -407,6 +407,7 @@ func (tx *readWriteTransaction) Commit() (err error) { return err } var commitResponse spanner.CommitResponse + // TODO: Optimize this to skip the Commit also if the transaction has not yet been used. if tx.rwTx != nil { if !tx.retryAborts() { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx) From e72e994567cb00f916a5de3df635551c7a7694aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 20 Sep 2025 12:14:07 +0200 Subject: [PATCH 2/2] chore: add ExecuteBatch to SpannerLib (#531) * chore: add ExecuteBatch to SpannerLib Adds an ExecuteBatch function to SpannerLib that supports executing DML or DDL statements as a single batch. The function accepts an ExecuteBatchDml request for both types of batches. The type of batch that is actually being executed is determined based on the statements in the batch. Mixing DML and DDL in the same batch is not supported. Queries are also not supported in batches. * chore: add WriteMutations function for SpannerLib (#532) Adds a WriteMutations function for SpannerLib. This function can be used to write mutations to Spanner in two ways: 1. In a transaction: The mutations are buffered in the current read/write transaction. The returned message is empty. 2. Outside a transaction: The mutations are written to Spanner directly in a new read/write transaction. The returned message contains the CommitResponse. --- conn.go | 29 +++ spannerlib/api/batch_test.go | 238 ++++++++++++++++++ spannerlib/api/connection.go | 173 +++++++++++++ spannerlib/api/connection_test.go | 157 ++++++++++++ spannerlib/lib/connection.go | 40 +++ spannerlib/lib/connection_test.go | 102 ++++++++ spannerlib/shared/shared_lib.go | 28 +++ spannerlib/shared/shared_lib_test.go | 126 ++++++++++ .../google/cloud/spannerlib/Connection.java | 47 ++++ .../spannerlib/internal/SpannerLibrary.java | 15 ++ .../google/cloud/spannerlib/BatchTest.java | 152 +++++++++++ .../cloud/spannerlib/ConnectionTest.java | 149 +++++++++++ 12 files changed, 1256 insertions(+) create mode 100644 spannerlib/api/batch_test.go create mode 100644 spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java diff --git a/conn.go b/conn.go index 9ee4efb1..698c8274 100644 --- a/conn.go +++ b/conn.go @@ -222,6 +222,9 @@ type SpannerConn interface { // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + // DetectStatementType returns the type of SQL statement. + DetectStatementType(query string) parser.StatementType + // resetTransactionForRetry resets the current transaction after it has // been aborted by Spanner. Calling this function on a transaction that // has not been aborted is not supported and will cause an error to be @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) { return c.client, nil } +func (c *conn) DetectStatementType(query string) parser.StatementType { + info := c.parser.DetectStatementType(query) + return info.StatementType +} + func (c *conn) CommitTimestamp() (time.Time, error) { ts := propertyCommitTimestamp.GetValueOrDefault(c.state) if ts == nil { @@ -675,6 +683,27 @@ func sum(affected []int64) int64 { return sum } +// WriteMutations is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction, +// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction. +// +// The function returns an error if the connection currently has a read-only transaction. +// +// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be +// applied to Spanner when the transaction commits. +func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) { + if c.inTransaction() { + return nil, c.BufferWrite(ms) + } + ts, err := c.Apply(ctx, ms) + if err != nil { + return nil, err + } + return &spanner.CommitResponse{CommitTs: ts}, nil +} + func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) { if c.inTransaction() { return time.Time{}, spanner.ToSpannerError( diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go new file mode 100644 index 00000000..1ae9af14 --- /dev/null +++ b/spannerlib/api/batch_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestExecuteDmlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DML batch. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + for i, result := range resp.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DDL batch. This also uses a DML batch request. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + // The response should contain an 'update count' per DDL statement. + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + // There is no update count for DDL statements. + for i, result := range resp.ResultSets { + emptyStats := &spannerpb.ResultSetStats{} + if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + // There should also be no ExecuteBatchDml requests. + batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchDmlRequests), 0; g != w { + t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + + adminRequests := server.TestDatabaseAdmin.Reqs() + if g, w := len(adminRequests), 1; g != w { + t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w) + } + ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest) + if g, w := len(ddlRequest.Statements), 2; g != w { + t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteMixedBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with mixed DML and DDL statements. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "update my_table set value = 100 where true"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatchInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to execute a DDL batch in a transaction. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index f5f7643a..907212ec 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -46,6 +47,22 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.writeMutations(ctx, mutations) +} + // BeginTransaction starts a new transaction on the given connection. // A connection can have at most one transaction at any time. This function therefore returns an error if the // connection has an active transaction. @@ -83,6 +100,14 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spann return conn.Execute(ctx, executeSqlRequest) } +func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.ExecuteBatch(ctx, statements.Statements) +} + type Connection struct { // results contains the open query results for this connection. results *sync.Map @@ -95,6 +120,7 @@ type Connection struct { // spannerConn is an internal interface that contains the internal functions that are used by this API. // It is implemented by the spannerdriver.conn struct. type spannerConn interface { + WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) Commit(ctx context.Context) (*spanner.CommitResponse, error) @@ -118,6 +144,34 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations)) + for _, m := range mutation.Mutations { + spannerMutation, err := spanner.WrapMutation(m) + if err != nil { + return nil, err + } + mutations = append(mutations, spannerMutation) + } + var commitResponse *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + commitResponse, err = sc.WriteMutations(ctx, mutations) + return err + }); err != nil { + return nil, err + } + + // The commit response is nil if the connection is currently in a transaction. + if commitResponse == nil { + return nil, nil + } + response := spannerpb.CommitResponse{ + CommitTimestamp: timestamppb.New(commitResponse.CommitTs), + } + return &response, nil +} + func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { var err error if txOpts.GetReadOnly() != nil { @@ -235,6 +289,10 @@ func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.Execut return execute(ctx, conn, conn.backend, statement) } +func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + return executeBatch(ctx, conn, conn.backend, statements) +} + func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { params := extractParams(statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) @@ -266,6 +324,90 @@ func execute(ctx context.Context, conn *Connection, executor queryExecutor, stat return id, nil } +func executeBatch(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + // Determine the type of batch that should be executed based on the type of statements. + batchType, err := determineBatchType(conn, statements) + if err != nil { + return nil, err + } + switch batchType { + case parser.BatchTypeDml: + return executeBatchDml(ctx, conn, executor, statements) + case parser.BatchTypeDdl: + return executeBatchDdl(ctx, conn, executor, statements) + default: + return nil, status.Errorf(codes.InvalidArgument, "unsupported batch type: %v", batchType) + } +} + +func executeBatchDdl(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDDL() + }); err != nil { + return nil, err + } + for _, statement := range statements { + _, err := executor.ExecContext(ctx, statement.Sql) + if err != nil { + return nil, err + } + } + // TODO: Add support for getting the actual Batch DDL response. + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.RunBatch(ctx) + }); err != nil { + return nil, err + } + + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(statements)) + for i := range statements { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{}} + } + return &response, nil +} + +func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + return nil, err + } + for _, statement := range statements { + request := &spannerpb.ExecuteSqlRequest{ + Sql: statement.Sql, + Params: statement.Params, + ParamTypes: statement.ParamTypes, + } + params := extractParams(request) + _, err := executor.ExecContext(ctx, statement.Sql, params...) + if err != nil { + return nil, err + } + } + var spannerResult spannerdriver.SpannerResult + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + spannerResult, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + return nil, err + } + affected, err := spannerResult.BatchRowsAffected() + if err != nil { + return nil, err + } + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(affected)) + for i, aff := range affected { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: aff}}} + } + return &response, nil +} + func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { @@ -300,3 +442,34 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { } return params } + +func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (parser.BatchType, error) { + if len(statements) == 0 { + return parser.BatchTypeDdl, status.Errorf(codes.InvalidArgument, "cannot determine type of an empty batch") + } + var batchType parser.BatchType + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + firstStatementType := spannerConn.DetectStatementType(statements[0].Sql) + if firstStatementType == parser.StatementTypeDml { + batchType = parser.BatchTypeDml + } else if firstStatementType == parser.StatementTypeDdl { + batchType = parser.BatchTypeDdl + } else { + return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + } + for i, statement := range statements { + if i > 0 { + tp := spannerConn.DetectStatementType(statement.Sql) + if tp != firstStatementType { + return status.Errorf(codes.InvalidArgument, "Batches may not contain different types of statements. The first statement is of type %v. The statement on position %d is of type %v.", firstStatementType, i, tp) + } + } + } + return nil + }); err != nil { + return parser.BatchTypeDdl, err + } + + return batchType, nil +} diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 64d19ab3..b4625bb3 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + {Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}}, + }, + }}}, + }} + resp, err := WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest := commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + // Write the same mutations in a transaction. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + resp, err = WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("WriteMutations returned unexpected response: %v", resp) + } + resp, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("Commit returned nil response") + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests = server.TestSpanner.DrainRequestsFromServer() + beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest = commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestWriteMutationsInReadOnlyTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Start a read-only transaction and try to write mutations to that transaction. That should return an error. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}}, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + }, + }}}, + }} + _, err = WriteMutations(ctx, poolId, connId, mutations) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w) + } + + // Committing the read-only transaction should not lead to any commits on Spanner. + _, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + requests := server.TestSpanner.DrainRequestsFromServer() + // There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started + // by a query or other statement. + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index d41b6a6e..75c6bc3a 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -34,6 +34,30 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { return &Message{} } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message { + mutations := spannerpb.BatchWriteRequest_MutationGroup{} + if err := proto.Unmarshal(mutationBytes, &mutations); err != nil { + return errMessage(err) + } + response, err := api.WriteMutations(ctx, poolId, connId, &mutations) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + // BeginTransaction starts a new transaction on the given connection. A connection can have at most one active // transaction at any time. This function therefore returns an error if the connection has an active transaction. func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { @@ -89,3 +113,19 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes [ } return idMessage(id) } + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statementsBytes []byte) *Message { + statements := spannerpb.ExecuteBatchDmlRequest{} + if err := proto.Unmarshal(statementsBytes, &statements); err != nil { + return errMessage(err) + } + response, err := api.ExecuteBatch(ctx, poolId, connId, &statements) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 02ba453c..193e3fd9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -117,6 +118,48 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := ExecuteBatch(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("ExecuteBatch result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.Length() == 0 { + t.Fatal("ExecuteBatch returned no data") + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommit(t *testing.T) { t.Parallel() @@ -220,3 +263,62 @@ func TestBeginAndRollback(t *testing.T) { t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if mutationsMsg.Length() == 0 { + t.Fatal("WriteMutations returned no data") + } + + // Write mutations in a transaction. + mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + // The response should now be an empty message, as the mutations were only buffered in the transaction. + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := mutationsMsg.Length(), int32(0); g != w { + t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index e9c9be6a..8645f219 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -111,6 +111,22 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P return pin(msg) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +// +//export WriteMutations +func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes) + return pin(msg) +} + // Execute executes a SQL statement on the given connection. // The return type is an identifier for a Rows object. This identifier can be used to // call the functions Metadata and Next to get respectively the metadata of the result @@ -123,6 +139,18 @@ func Execute(poolId, connectionId int64, statement []byte) (int64, int32, int64, return pin(msg) } +// ExecuteBatch executes a batch of statements on the given connection. The statements must all be either DML or DDL +// statements. Mixing DML and DDL in a batch is not supported. Executing queries in a batch is also not supported. +// The batch will use the current transaction on the given connection, or execute as a single auto-commit statement +// if the connection does not have a transaction. +// +//export ExecuteBatch +func ExecuteBatch(poolId, connectionId int64, statements []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ExecuteBatch(ctx, poolId, connectionId, statements) + return pin(msg) +} + // Metadata returns the metadata of a Rows object. // //export Metadata diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 82d010bd..b6d63467 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -245,6 +245,63 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // ExecuteBatch returns a ExecuteBatchDml response. + mem, code, batchId, length, data := ExecuteBatch(poolId, connId, requestBytes) + verifyDataMessage(t, "ExecuteBatch", mem, code, batchId, length, data) + response := &spannerpb.ExecuteBatchDmlResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if g, w := len(response.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %v\nWant: %v", g, w) + } + for i, result := range response.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommitTransaction(t *testing.T) { t.Parallel() @@ -355,6 +412,75 @@ func TestBeginAndRollbackTransaction(t *testing.T) { } } +func TestWriteMutations(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + // WriteMutations returns a CommitResponse or nil, depending on whether the connection has an active transaction. + mem, code, id, length, data := WriteMutations(poolId, connId, mutationBytes) + verifyDataMessage(t, "WriteMutations", mem, code, id, length, data) + + response := &spannerpb.CommitResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if response.CommitTimestamp == nil { + t.Fatal("CommitTimestamp is nil") + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Start a transaction on the connection and write the mutations to that transaction. + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + _, code, _, _, _ = BeginTransaction(poolId, connId, txOptsBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + mem, code, id, length, data = WriteMutations(poolId, connId, mutationBytes) + // The response should now be an empty message, as the mutations were buffered in the current transaction. + verifyEmptyMessage(t, "WriteMutations in tx", mem, code, id, length, data) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { if g, w := mem, int64(0); g != w { t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index 90ac04f9..18a9ca50 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -21,7 +21,10 @@ import com.google.cloud.spannerlib.internal.MessageHandler; import com.google.cloud.spannerlib.internal.WrappedGoBytes; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.TransactionOptions; import java.nio.ByteBuffer; @@ -39,6 +42,31 @@ public Pool getPool() { return this.pool; } + /** + * Writes a group of mutations to Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner using a new read/write transaction. Returns a {@link + * CommitResponse} if the mutations were written directly to Spanner, and otherwise null if the + * mutations were buffered in the current transaction. + */ + public CommitResponse WriteMutations(MutationGroup mutations) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(mutations); + MessageHandler message = + getLibrary() + .execute( + library -> + library.WriteMutations( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Starts a transaction on this connection. */ public void beginTransaction(TransactionOptions options) { try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { @@ -84,6 +112,25 @@ public Rows execute(ExecuteSqlRequest request) { } } + /** + * Executes the given batch of DML or DDL statements on this connection. The statements must all + * be of the same type. + */ + public ExecuteBatchDmlResponse executeBatch(ExecuteBatchDmlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.ExecuteBatch( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ExecuteBatchDmlResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 1ad3f14e..7da5a013 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -65,6 +65,15 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + /** + * Writes a group of mutations on Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner in a new read/write transaction. Returns a {@link + * com.google.spanner.v1.CommitResponse} if the mutations were written directly to Spanner, and an + * empty message if the mutations were only buffered in the current transaction. + */ + Message WriteMutations(long poolId, long connectionId, GoBytes mutations); + /** Starts a new transaction on the given Connection. */ Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); @@ -80,6 +89,12 @@ default MessageHandler execute(Function function) /** Executes a SQL statement on the given Connection. */ Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); + /** + * Executes a batch of DML or DDL statements on the given Connection. Returns an {@link + * com.google.spanner.v1.ExecuteBatchDmlResponse} for both DML and DDL batches. + */ + Message ExecuteBatch(long poolId, long connectionId, GoBytes executeBatchDmlRequest); + /** Returns the {@link com.google.spanner.v1.ResultSetMetadata} of the given Rows object. */ Message Metadata(long poolId, long connectionId, long rowsId); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java new file mode 100644 index 00000000..9740d541 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest.Statement; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.List; +import org.junit.Test; + +public class BatchTest extends AbstractMockServerTest { + + @Test + public void testBatchDml() { + String insert = "insert into test (id, value) values (@id, @value)"; + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(1L) + .bind("value") + .to("One") + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(2L) + .bind("value") + .to("Two") + .build(), + 1L)); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", Value.newBuilder().setStringValue("One").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("2").build()) + .putFields( + "value", Value.newBuilder().setStringValue("Two").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertEquals(1L, response.getResultSets(0).getStats().getRowCountExact()); + assertEquals(1L, response.getResultSets(1).getStats().getRowCountExact()); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testBatchDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql("create table my_table (id int64 primary key, value string(max))") + .build()) + .addStatements( + Statement.newBuilder() + .setSql("create index my_index on my_table (value)") + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertFalse(response.getResultSets(0).getStats().hasRowCountExact()); + } + + List requests = mockDatabaseAdmin.getRequests(); + assertEquals(1, requests.size()); + UpdateDatabaseDdlRequest request = (UpdateDatabaseDdlRequest) requests.get(0); + assertEquals(2, request.getStatementsCount()); + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java index 5415b867..8efc22b5 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java @@ -18,10 +18,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.rpc.Code; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.Write; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; import org.junit.Test; public class ConnectionTest extends AbstractMockServerTest { @@ -53,4 +68,138 @@ public void testCreateTwoConnections() { } } } + + @Test + public void testWriteMutations() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues( + Value.newBuilder().setStringValue("Two").build()) + .build()) + .build()) + .build()) + .addMutations( + Mutation.newBuilder() + .setInsertOrUpdate( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("0").build()) + .addValues( + Value.newBuilder().setStringValue("Zero").build()) + .build()) + .build()) + .build()) + .build()); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(2, request.getMutationsCount()); + assertEquals(2, request.getMutations(0).getInsert().getValuesCount()); + assertEquals(1, request.getMutations(1).getInsertOrUpdate().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .build()) + .build()) + .build()); + // The mutations are only buffered in the current transaction, so there should be no response. + assertNull(response); + + // Committing the transaction should return a CommitResponse. + response = connection.commit(); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(1, request.getMutationsCount()); + assertEquals(1, request.getMutations(0).getInsert().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInReadOnlyTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setStringValue("1") + .build()) + .addValues( + Value.newBuilder() + .setStringValue("One") + .build()) + .build()) + .build()) + .build()) + .build())); + assertEquals(Code.FAILED_PRECONDITION.getNumber(), exception.getStatus().getCode()); + } + } }