diff --git a/conn.go b/conn.go index 7f9ece77..698c8274 100644 --- a/conn.go +++ b/conn.go @@ -222,6 +222,9 @@ type SpannerConn interface { // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + // DetectStatementType returns the type of SQL statement. + DetectStatementType(query string) parser.StatementType + // resetTransactionForRetry resets the current transaction after it has // been aborted by Spanner. Calling this function on a transaction that // has not been aborted is not supported and will cause an error to be @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) { return c.client, nil } +func (c *conn) DetectStatementType(query string) parser.StatementType { + info := c.parser.DetectStatementType(query) + return info.StatementType +} + func (c *conn) CommitTimestamp() (time.Time, error) { ts := propertyCommitTimestamp.GetValueOrDefault(c.state) if ts == nil { @@ -675,6 +683,27 @@ func sum(affected []int64) int64 { return sum } +// WriteMutations is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction, +// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction. +// +// The function returns an error if the connection currently has a read-only transaction. +// +// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be +// applied to Spanner when the transaction commits. +func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) { + if c.inTransaction() { + return nil, c.BufferWrite(ms) + } + ts, err := c.Apply(ctx, ms) + if err != nil { + return nil, err + } + return &spanner.CommitResponse{CommitTs: ts}, nil +} + func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) { if c.inTransaction() { return time.Time{}, spanner.ToSpannerError( @@ -1071,6 +1100,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()} } +// BeginReadOnlyTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadOnlyTransaction starts a new read-only transaction on this connection. +func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) { + c.withTempReadOnlyTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true}) + if err != nil { + c.withTempReadOnlyTransactionOptions(nil) + return nil, err + } + return tx, nil +} + +// BeginReadWriteTransaction is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// BeginReadWriteTransaction starts a new read/write transaction on this connection. +func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) { + c.withTempTransactionOptions(options) + tx, err := c.BeginTx(ctx, driver.TxOptions{}) + if err != nil { + c.withTempTransactionOptions(nil) + return nil, err + } + return tx, nil +} + func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } @@ -1254,7 +1311,11 @@ func (c *conn) inReadWriteTransaction() bool { return false } -func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { +// Commit is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Commit commits the current transaction on this connection. +func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) { if !c.inTransaction() { return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } @@ -1262,10 +1323,17 @@ func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) { if err := c.tx.Commit(); err != nil { return nil, err } - return c.CommitResponse() + + // This will return either the commit response or nil, depending on whether the transaction was a + // read/write transaction or a read-only transaction. + return propertyCommitResponse.GetValueOrDefault(c.state), nil } -func (c *conn) rollback(ctx context.Context) error { +// Rollback is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// Rollback rollbacks the current transaction on this connection. +func (c *conn) Rollback(ctx context.Context) error { if !c.inTransaction() { return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction") } diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 848ee24a..5ceba7ed 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) { } } +func TestTwoQueriesOnOneConn(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + c, _ := db.Conn(ctx) + defer silentClose(c) + + for range 2 { + r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + _ = r.Next() + defer silentClose(r) + } +} + func TestExplicitBeginTx(t *testing.T) { t.Parallel() diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go new file mode 100644 index 00000000..1ae9af14 --- /dev/null +++ b/spannerlib/api/batch_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestExecuteDmlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DML batch. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + for i, result := range resp.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DDL batch. This also uses a DML batch request. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + // The response should contain an 'update count' per DDL statement. + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + // There is no update count for DDL statements. + for i, result := range resp.ResultSets { + emptyStats := &spannerpb.ResultSetStats{} + if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + // There should also be no ExecuteBatchDml requests. + batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchDmlRequests), 0; g != w { + t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + + adminRequests := server.TestDatabaseAdmin.Reqs() + if g, w := len(adminRequests), 1; g != w { + t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w) + } + ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest) + if g, w := len(ddlRequest.Statements), 2; g != w { + t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteMixedBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with mixed DML and DDL statements. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "update my_table set value = 100 where true"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatchInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to execute a DDL batch in a transaction. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index b697dc6d..907212ec 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -17,8 +17,19 @@ package api import ( "context" "database/sql" + "database/sql/driver" + "fmt" + "strings" "sync" "sync/atomic" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/parser" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) // CloseConnection looks up the connection with the given poolId and connId and closes it. @@ -36,6 +47,67 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.writeMutations(ctx, mutations) +} + +// BeginTransaction starts a new transaction on the given connection. +// A connection can have at most one transaction at any time. This function therefore returns an error if the +// connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.BeginTransaction(ctx, txOpts) +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.commit(ctx) +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + return conn.rollback(ctx) +} + +func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return 0, err + } + return conn.Execute(ctx, executeSqlRequest) +} + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.ExecuteBatch(ctx, statements.Statements) +} + type Connection struct { // results contains the open query results for this connection. results *sync.Map @@ -45,8 +117,26 @@ type Connection struct { backend *sql.Conn } +// spannerConn is an internal interface that contains the internal functions that are used by this API. +// It is implemented by the spannerdriver.conn struct. +type spannerConn interface { + WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) + BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) + BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) + Commit(ctx context.Context) (*spanner.CommitResponse, error) + Rollback(ctx context.Context) error +} + +type queryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + func (conn *Connection) close(ctx context.Context) error { conn.closeResults(ctx) + // Rollback any open transactions on the connection. + _ = conn.rollback(ctx) + err := conn.backend.Close() if err != nil { return err @@ -54,9 +144,332 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations)) + for _, m := range mutation.Mutations { + spannerMutation, err := spanner.WrapMutation(m) + if err != nil { + return nil, err + } + mutations = append(mutations, spannerMutation) + } + var commitResponse *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + commitResponse, err = sc.WriteMutations(ctx, mutations) + return err + }); err != nil { + return nil, err + } + + // The commit response is nil if the connection is currently in a transaction. + if commitResponse == nil { + return nil, nil + } + response := spannerpb.CommitResponse{ + CommitTimestamp: timestamppb.New(commitResponse.CommitTs), + } + return &response, nil +} + +func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { + var err error + if txOpts.GetReadOnly() != nil { + return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts)) + } else if txOpts.GetPartitionedDml() != nil { + err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported")) + } else { + return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts)) + } + if err != nil { + return err + } + return nil +} + +func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadOnlyTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error { + return conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + _, err = sc.BeginReadWriteTransaction(ctx, opts) + return err + }) +} + +func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) { + var response *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + response, err = spannerConn.Commit(ctx) + if err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + + // The commit response is nil for read-only transactions. + if response == nil { + return nil, nil + } + // TODO: Include commit stats + return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil +} + +func (conn *Connection) rollback(ctx context.Context) error { + return conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerConn) + return spannerConn.Rollback(ctx) + }) +} + +func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions { + return &spannerdriver.ReadOnlyTransactionOptions{ + TimestampBound: convertTimestampBound(txOpts), + } +} + +func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.TimestampBound { + ro := txOpts.GetReadOnly() + if ro.GetStrong() { + return spanner.StrongRead() + } else if ro.GetReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetReadTimestamp().AsTime()) + } else if ro.GetMinReadTimestamp() != nil { + return spanner.ReadTimestamp(ro.GetMinReadTimestamp().AsTime()) + } else if ro.GetExactStaleness() != nil { + return spanner.ExactStaleness(ro.GetExactStaleness().AsDuration()) + } else if ro.GetMaxStaleness() != nil { + return spanner.MaxStaleness(ro.GetMaxStaleness().AsDuration()) + } + return spanner.TimestampBound{} +} + +func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions { + readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED + if txOpts.GetReadWrite() != nil { + readLockMode = txOpts.GetReadWrite().GetReadLockMode() + } + return &spannerdriver.ReadWriteTransactionOptions{ + TransactionOptions: spanner.TransactionOptions{ + IsolationLevel: txOpts.GetIsolationLevel(), + ReadLockMode: readLockMode, + }, + } +} + +func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel { + switch level { + case spannerpb.TransactionOptions_SERIALIZABLE: + return sql.LevelSerializable + case spannerpb.TransactionOptions_REPEATABLE_READ: + return sql.LevelRepeatableRead + } + return sql.LevelDefault +} + func (conn *Connection) closeResults(ctx context.Context) { conn.results.Range(func(key, value interface{}) bool { - // TODO: Implement + if r, ok := value.(*rows); ok { + _ = r.Close(ctx) + } return true }) } + +func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + return execute(ctx, conn, conn.backend, statement) +} + +func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + return executeBatch(ctx, conn, conn.backend, statements) +} + +func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + params := extractParams(statement) + it, err := executor.QueryContext(ctx, statement.Sql, params...) + if err != nil { + return 0, err + } + // The first result set should contain the metadata. + if !it.Next() { + return 0, fmt.Errorf("query returned no metadata") + } + metadata := &spannerpb.ResultSetMetadata{} + if err := it.Scan(&metadata); err != nil { + return 0, err + } + // Move to the next result set, which contains the normal data. + if !it.NextResultSet() { + return 0, fmt.Errorf("no results found after metadata") + } + id := conn.resultsIdx.Add(1) + res := &rows{ + backend: it, + metadata: metadata, + } + if len(metadata.RowType.Fields) == 0 { + // No rows returned. Read the stats now. + res.readStats(ctx) + } + conn.results.Store(id, res) + return id, nil +} + +func executeBatch(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + // Determine the type of batch that should be executed based on the type of statements. + batchType, err := determineBatchType(conn, statements) + if err != nil { + return nil, err + } + switch batchType { + case parser.BatchTypeDml: + return executeBatchDml(ctx, conn, executor, statements) + case parser.BatchTypeDdl: + return executeBatchDdl(ctx, conn, executor, statements) + default: + return nil, status.Errorf(codes.InvalidArgument, "unsupported batch type: %v", batchType) + } +} + +func executeBatchDdl(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDDL() + }); err != nil { + return nil, err + } + for _, statement := range statements { + _, err := executor.ExecContext(ctx, statement.Sql) + if err != nil { + return nil, err + } + } + // TODO: Add support for getting the actual Batch DDL response. + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.RunBatch(ctx) + }); err != nil { + return nil, err + } + + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(statements)) + for i := range statements { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{}} + } + return &response, nil +} + +func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + return nil, err + } + for _, statement := range statements { + request := &spannerpb.ExecuteSqlRequest{ + Sql: statement.Sql, + Params: statement.Params, + ParamTypes: statement.ParamTypes, + } + params := extractParams(request) + _, err := executor.ExecContext(ctx, statement.Sql, params...) + if err != nil { + return nil, err + } + } + var spannerResult spannerdriver.SpannerResult + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + spannerResult, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + return nil, err + } + affected, err := spannerResult.BatchRowsAffected() + if err != nil { + return nil, err + } + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(affected)) + for i, aff := range affected { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: aff}}} + } + return &response, nil +} + +func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { + paramsLen := 1 + if statement.Params != nil { + paramsLen = 1 + len(statement.Params.Fields) + } + params := make([]any, paramsLen) + params = append(params, spannerdriver.ExecOptions{ + DecodeOption: spannerdriver.DecodeOptionProto, + // TODO: Implement support for passing in stale query options + // TimestampBound: extractTimestampBound(statement), + ReturnResultSetMetadata: true, + ReturnResultSetStats: true, + DirectExecuteQuery: true, + }) + if statement.Params != nil { + if statement.ParamTypes == nil { + statement.ParamTypes = make(map[string]*spannerpb.Type) + } + for param, value := range statement.Params.Fields { + genericValue := spanner.GenericColumnValue{ + Value: value, + Type: statement.ParamTypes[param], + } + if strings.HasPrefix(param, "_") { + // Prefix the parameter name with a 'p' to work around the fact that database/sql does not allow + // named arguments to start with anything else than a letter. + params = append(params, sql.Named("p"+param, spannerdriver.SpannerNamedArg{NameInQuery: param, Value: genericValue})) + } else { + params = append(params, sql.Named(param, genericValue)) + } + } + } + return params +} + +func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (parser.BatchType, error) { + if len(statements) == 0 { + return parser.BatchTypeDdl, status.Errorf(codes.InvalidArgument, "cannot determine type of an empty batch") + } + var batchType parser.BatchType + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + firstStatementType := spannerConn.DetectStatementType(statements[0].Sql) + if firstStatementType == parser.StatementTypeDml { + batchType = parser.BatchTypeDml + } else if firstStatementType == parser.StatementTypeDdl { + batchType = parser.BatchTypeDdl + } else { + return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + } + for i, statement := range statements { + if i > 0 { + tp := spannerConn.DetectStatementType(statement.Sql) + if tp != firstStatementType { + return status.Errorf(codes.InvalidArgument, "Batches may not contain different types of statements. The first statement is of type %v. The statement on position %d is of type %v.", firstStatementType, i, tp) + } + } + } + return nil + }); err != nil { + return parser.BatchTypeDdl, err + } + + return batchType, nil +} diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 64d19ab3..b4625bb3 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + {Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}}, + }, + }}}, + }} + resp, err := WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest := commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + // Write the same mutations in a transaction. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + resp, err = WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("WriteMutations returned unexpected response: %v", resp) + } + resp, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("Commit returned nil response") + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests = server.TestSpanner.DrainRequestsFromServer() + beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest = commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestWriteMutationsInReadOnlyTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Start a read-only transaction and try to write mutations to that transaction. That should return an error. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}}, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + }, + }}}, + }} + _, err = WriteMutations(ctx, poolId, connId, mutations) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w) + } + + // Committing the read-only transaction should not lead to any commits on Spanner. + _, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + requests := server.TestSpanner.DrainRequestsFromServer() + // There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started + // by a query or other statement. + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/pool.go b/spannerlib/api/pool.go index 75e9eaa5..9468e85a 100644 --- a/spannerlib/api/pool.go +++ b/spannerlib/api/pool.go @@ -17,6 +17,7 @@ package api import ( "context" "database/sql" + "fmt" "sync" "sync/atomic" @@ -131,3 +132,16 @@ func findConnection(poolId, connId int64) (*Connection, error) { conn := c.(*Connection) return conn, nil } + +func findRows(poolId, connId, rowsId int64) (*rows, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + r, ok := conn.results.Load(rowsId) + if !ok { + return nil, fmt.Errorf("rows %v not found", rowsId) + } + res := r.(*rows) + return res, nil +} diff --git a/spannerlib/api/rows.go b/spannerlib/api/rows.go new file mode 100644 index 00000000..52bedbe1 --- /dev/null +++ b/spannerlib/api/rows.go @@ -0,0 +1,203 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "database/sql" + "errors" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" +) + +type EncodeRowOption int32 + +const ( + EncodeRowOptionProto EncodeRowOption = iota +) + +// Metadata returns the ResultSetMetadata of the given rows. +// This function can be called for any type of statement (queries, DML, DDL). +func Metadata(_ context.Context, poolId, connId, rowsId int64) (*spannerpb.ResultSetMetadata, error) { + res, err := findRows(poolId, connId, rowsId) + if err != nil { + return nil, err + } + return res.Metadata() +} + +// ResultSetStats returns the result statistics of the given rows. +// This function can only be called once all data in the rows have been fetched. +// The stats are empty for queries and DDL statements. +func ResultSetStats(ctx context.Context, poolId, connId, rowsId int64) (*spannerpb.ResultSetStats, error) { + res, err := findRows(poolId, connId, rowsId) + if err != nil { + return nil, err + } + return res.ResultSetStats(ctx) +} + +// NextEncoded returns the next row data in encoded form. +// Using NextEncoded instead of Next can be more efficient for large result sets, +// as it allows the library to re-use the encoding buffer. +// TODO: Add an encoder function as input argument, instead of hardcoding protobuf encoding here. +func NextEncoded(ctx context.Context, poolId, connId, rowsId int64) ([]byte, error) { + _, bytes, err := next(ctx, poolId, connId, rowsId, true) + if err != nil { + return nil, err + } + return bytes, nil +} + +// Next returns the next row as a protobuf ListValue. +func Next(ctx context.Context, poolId, connId, rowsId int64) (*structpb.ListValue, error) { + values, _, err := next(ctx, poolId, connId, rowsId, false) + if err != nil { + return nil, err + } + return values, nil +} + +// next returns the next row of data. +// The row is returned as a protobuf ListValue if marshalResult==false. +// The row is returned as a byte slice if marshalResult==true. +// TODO: Add generics to the function and add input arguments for encoding instead of hardcoding it. +func next(ctx context.Context, poolId, connId, rowsId int64, marshalResult bool) (*structpb.ListValue, []byte, error) { + rows, err := findRows(poolId, connId, rowsId) + if err != nil { + return nil, nil, err + } + values, err := rows.Next(ctx) + if err != nil { + return nil, nil, err + } + if !marshalResult || values == nil { + return values, nil, nil + } + + rows.marshalBuffer, err = proto.MarshalOptions{}.MarshalAppend(rows.marshalBuffer[:0], rows.values) + if err != nil { + return nil, nil, err + } + return values, rows.marshalBuffer, nil +} + +// CloseRows closes the given rows. Callers must always call this to clean up any resources +// that are held by the underlying cursor. +func CloseRows(ctx context.Context, poolId, connId, rowsId int64) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + r, ok := conn.results.LoadAndDelete(rowsId) + if !ok { + return nil + } + res := r.(*rows) + return res.Close(ctx) +} + +type rows struct { + backend *sql.Rows + metadata *spannerpb.ResultSetMetadata + stats *spannerpb.ResultSetStats + done bool + + buffer []any + values *structpb.ListValue + marshalBuffer []byte +} + +func (rows *rows) Close(ctx context.Context) error { + err := rows.backend.Close() + if err != nil { + return err + } + return nil +} + +func (rows *rows) Metadata() (*spannerpb.ResultSetMetadata, error) { + return rows.metadata, nil +} + +func (rows *rows) ResultSetStats(ctx context.Context) (*spannerpb.ResultSetStats, error) { + if rows.stats == nil { + rows.readStats(ctx) + } + return rows.stats, nil +} + +type genericValue struct { + v *structpb.Value +} + +func (gv *genericValue) Scan(src any) error { + if v, ok := src.(spanner.GenericColumnValue); ok { + gv.v = v.Value + return nil + } + return errors.New("cannot convert value to generic column value") +} + +func (rows *rows) Next(ctx context.Context) (*structpb.ListValue, error) { + // No columns means no rows, so just return nil to indicate that there are no (more) rows. + if len(rows.metadata.RowType.Fields) == 0 || rows.done { + return nil, nil + } + if rows.stats != nil { + return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot read more data after returning stats")) + } + ok := rows.backend.Next() + if !ok { + rows.done = true + // No more rows. Read stats and return nil. + rows.readStats(ctx) + // nil indicates no more rows. + return nil, nil + } + + if rows.buffer == nil { + rows.buffer = make([]any, len(rows.metadata.RowType.Fields)) + for i := range rows.buffer { + rows.buffer[i] = &genericValue{} + } + rows.values = &structpb.ListValue{ + Values: make([]*structpb.Value, len(rows.buffer)), + } + rows.marshalBuffer = make([]byte, 0) + } + if err := rows.backend.Scan(rows.buffer...); err != nil { + return nil, err + } + for i := range rows.buffer { + rows.values.Values[i] = rows.buffer[i].(*genericValue).v + } + return rows.values, nil +} + +func (rows *rows) readStats(ctx context.Context) { + rows.stats = &spannerpb.ResultSetStats{} + if !rows.backend.NextResultSet() { + return + } + if rows.backend.Next() { + _ = rows.backend.Scan(&rows.stats) + } +} diff --git a/spannerlib/api/rows_test.go b/spannerlib/api/rows_test.go new file mode 100644 index 00000000..508dffb1 --- /dev/null +++ b/spannerlib/api/rows_test.go @@ -0,0 +1,128 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" +) + +func TestExecute(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{ + Sql: testutil.SelectFooFromBar, + }) + if rowsId == 0 { + t.Fatal("Execute returned unexpected zero id") + } + p, ok := pools.Load(poolId) + if !ok { + t.Fatal("pool not found in map") + } + pool := p.(*Pool) + c, ok := pool.connections.Load(connId) + if !ok { + t.Fatal("connection not in map") + } + connection, ok := c.(*Connection) + if _, ok := connection.results.Load(rowsId); !ok { + t.Fatal("result not in map") + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("num ExecuteSql requests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + if _, ok := connection.results.Load(rowsId); ok { + t.Fatal("rows still in results map") + } + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteUnknownConnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + _, err := Execute(ctx, -1, -1, &spannerpb.ExecuteSqlRequest{ + Sql: testutil.SelectFooFromBar, + }) + if g, w := spanner.ErrCode(err), codes.NotFound; g != w { + t.Fatalf("error code mismatch\n Got: %d\nWant: %d", g, w) + } +} + +func TestCloseRowsTwice(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{ + Sql: testutil.SelectFooFromBar, + }) + + for range 2 { + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + } + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/transaction_test.go b/spannerlib/api/transaction_test.go new file mode 100644 index 00000000..c629a2b4 --- /dev/null +++ b/spannerlib/api/transaction_test.go @@ -0,0 +1,411 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" +) + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used the transaction, and that the transaction was started using an inlined begin + // option on the ExecuteSqlRequest. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadWrite() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + stats, err := ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + t.Fatalf("ResultSetStats returned unexpected error: %v", err) + } + if g, w := stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("row count mismatch\n Got: %v\nWant: %v", g, w) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Rollback the transaction. + if err := Rollback(ctx, poolId, connId); err != nil { + t.Fatalf("Rollback returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCommitWithOpenRows(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Try to commit the transaction without closing the Rows object that was returned during the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestCloseConnectionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + // Execute a statement in the transaction to activate it. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + + // Close the connection while a transaction is still active. + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } + + // Verify that the transaction was rolled back when the connection was closed. + requests := server.TestSpanner.DrainRequestsFromServer() + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("Rollback request count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTransactionWithOpenTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to start a transaction when one is already active. + err = BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestCommitWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + _, err = Commit(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestRollbackWithoutTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to commit when there is no transaction. + err = Rollback(ctx, poolId, connId) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestReadOnlyTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}, + }, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a statement in the transaction. + rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}) + if err != nil { + t.Fatalf("Execute returned unexpected error: %v", err) + } + if err := CloseRows(ctx, poolId, connId, rowsId); err != nil { + t.Fatalf("CloseRows returned unexpected error: %v", err) + } + + // Commit the transaction. + if _, err := Commit(ctx, poolId, connId); err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + + // Verify that the statement used a read-only transaction, that the transaction was started using an inlined + // begin option on the ExecuteSqlRequest, and that the Commit call was a no-op. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransaction request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if executeRequest.GetTransaction() == nil || executeRequest.GetTransaction().GetBegin() == nil || executeRequest.GetTransaction().GetBegin().GetReadOnly() == nil { + t.Fatalf("missing BeginTransaction option on request: %v", executeRequest) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestDdlInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Execute a DDL statement in the transaction. This should fail. + _, err = Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: "create table my_table (id int64 primary key)"}) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/go.mod b/spannerlib/go.mod index 4b0a8f8b..6d6912b0 100644 --- a/spannerlib/go.mod +++ b/spannerlib/go.mod @@ -8,8 +8,8 @@ replace github.com/googleapis/go-sql-spanner => .. require ( cloud.google.com/go/spanner v1.85.1 + github.com/google/go-cmp v0.7.0 github.com/googleapis/go-sql-spanner v1.18.0 - google.golang.org/api v0.249.0 google.golang.org/grpc v1.75.1 google.golang.org/protobuf v1.36.9 ) @@ -60,6 +60,7 @@ require ( golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect golang.org/x/time v0.12.0 // indirect + google.golang.org/api v0.249.0 // indirect google.golang.org/genproto v0.0.0-20250804133106-a7a43d27e69b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250908214217-97024824d090 // indirect diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index bbac0bb4..75c6bc3a 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -17,7 +17,10 @@ package lib import "C" import ( "context" + "fmt" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/protobuf/proto" "spannerlib/api" ) @@ -30,3 +33,99 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { } return &Message{} } + +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message { + mutations := spannerpb.BatchWriteRequest_MutationGroup{} + if err := proto.Unmarshal(mutationBytes, &mutations); err != nil { + return errMessage(err) + } + response, err := api.WriteMutations(ctx, poolId, connId, &mutations) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + +// BeginTransaction starts a new transaction on the given connection. A connection can have at most one active +// transaction at any time. This function therefore returns an error if the connection has an active transaction. +func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { + txOpts := spannerpb.TransactionOptions{} + if err := proto.Unmarshal(txOptsBytes, &txOpts); err != nil { + return errMessage(err) + } + err := api.BeginTransaction(ctx, poolId, connId, &txOpts) + if err != nil { + return errMessage(err) + } + return &Message{} +} + +// Commit commits the current transaction on the given connection. +func Commit(ctx context.Context, poolId, connId int64) *Message { + response, err := api.Commit(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + if response == nil { + return &Message{} + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + +// Rollback rollbacks the current transaction on the given connection. +func Rollback(ctx context.Context, poolId, connId int64) *Message { + err := api.Rollback(ctx, poolId, connId) + if err != nil { + return errMessage(err) + } + return &Message{} +} + +func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes []byte) (msg *Message) { + defer func() { + if r := recover(); r != nil { + msg = errMessage(fmt.Errorf("panic for message with size %d: %v", len(executeSqlRequestBytes), r)) + } + }() + statement := spannerpb.ExecuteSqlRequest{} + if err := proto.Unmarshal(executeSqlRequestBytes, &statement); err != nil { + return errMessage(err) + } + id, err := api.Execute(ctx, poolId, connId, &statement) + if err != nil { + return errMessage(err) + } + return idMessage(id) +} + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statementsBytes []byte) *Message { + statements := spannerpb.ExecuteBatchDmlRequest{} + if err := proto.Unmarshal(statementsBytes, &statements); err != nil { + return errMessage(err) + } + response, err := api.ExecuteBatch(ctx, poolId, connId, &statements) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 2a51777e..193e3fd9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -19,7 +19,11 @@ import ( "fmt" "testing" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -65,3 +69,256 @@ func TestCreateConnectionWithUnknownPool(t *testing.T) { t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestExecute(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteSqlRequest{ + Sql: testutil.SelectFooFromBar, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := Execute(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.ObjectId <= 0 { + t.Fatalf("rowsId mismatch: %v", rowsMsg.ObjectId) + } + if g, w := rowsMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseRows(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := ExecuteBatch(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("ExecuteBatch result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.Length() == 0 { + t.Fatal("ExecuteBatch returned no data") + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + commitMsg := Commit(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := commitMsg.Code, int32(0); g != w { + t.Fatalf("Commit result mismatch\n Got: %v\nWant: %v", g, w) + } + if commitMsg.Length() == 0 { + t.Fatal("Commit return zero length") + } + resp := &spannerpb.CommitResponse{} + if err := proto.Unmarshal(commitMsg.Res, resp); err != nil { + t.Fatalf("Failed to unmarshal commit response: %v", err) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatalf("Failed to marshal transaction options: %v", err) + } + txMsg := BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, txOptsBytes) + if g, w := txMsg.Code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.ObjectId, int64(0); g != w { + t.Fatalf("object ID result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := txMsg.Length(), int32(0); g != w { + t.Fatalf("result length mismatch\n Got: %v\nWant: %v", g, w) + } + + rollbackMsg := Rollback(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := rollbackMsg.Code, int32(0); g != w { + t.Fatalf("Rollback result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rollbackMsg.Length(), int32(0); g != w { + t.Fatalf("Rollback length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if mutationsMsg.Length() == 0 { + t.Fatal("WriteMutations returned no data") + } + + // Write mutations in a transaction. + mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + // The response should now be an empty message, as the mutations were only buffered in the transaction. + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := mutationsMsg.Length(), int32(0); g != w { + t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/lib/rows.go b/spannerlib/lib/rows.go new file mode 100644 index 00000000..04068a98 --- /dev/null +++ b/spannerlib/lib/rows.go @@ -0,0 +1,62 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lib + +import ( + "context" + + "google.golang.org/protobuf/proto" + "spannerlib/api" +) + +func Metadata(ctx context.Context, poolId, connId, rowsId int64) *Message { + metadata, err := api.Metadata(ctx, poolId, connId, rowsId) + if err != nil { + return errMessage(err) + } + metadataBytes, err := proto.Marshal(metadata) + if err != nil { + return errMessage(err) + } + return &Message{Res: metadataBytes} +} + +func ResultSetStats(ctx context.Context, poolId, connId, rowsId int64) *Message { + stats, err := api.ResultSetStats(ctx, poolId, connId, rowsId) + if err != nil { + return errMessage(err) + } + statsBytes, err := proto.Marshal(stats) + if err != nil { + return errMessage(err) + } + return &Message{Res: statsBytes} +} + +func Next(ctx context.Context, poolId, connId, rowsId int64) *Message { + valuesBytes, err := api.NextEncoded(ctx, poolId, connId, rowsId) + if err != nil { + return errMessage(err) + } + return &Message{Res: valuesBytes} +} + +func CloseRows(ctx context.Context, poolId, connId, rowsId int64) *Message { + err := api.CloseRows(ctx, poolId, connId, rowsId) + if err != nil { + return errMessage(err) + } + return &Message{} +} diff --git a/spannerlib/lib/rows_test.go b/spannerlib/lib/rows_test.go new file mode 100644 index 00000000..d1168928 --- /dev/null +++ b/spannerlib/lib/rows_test.go @@ -0,0 +1,314 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lib + +import ( + "context" + "fmt" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestQuery(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteSqlRequest{ + Sql: testutil.SelectFooFromBar, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := Execute(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + + numRows := 0 + for { + rowMsg := Next(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := rowMsg.Code, int32(0); g != w { + t.Fatalf("Next result mismatch\n Got: %v\nWant: %v", g, w) + } + // Data length == 0 means end of data. + if rowMsg.Length() == 0 { + break + } + numRows++ + values := &structpb.ListValue{} + if err := proto.Unmarshal(rowMsg.Res, values); err != nil { + t.Fatalf("Failed to unmarshal values: %v", err) + } + if g, w := len(values.Values), 1; g != w { + t.Fatalf("num values mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := values.Values[0].GetStringValue(), fmt.Sprintf("%d", numRows); g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := numRows, 2; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + + // The ResultSetStats should be empty for queries. + statsMsg := ResultSetStats(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := statsMsg.Code, int32(0); g != w { + t.Fatalf("ResultSetStats result mismatch\n Got: %v\nWant: %v", g, w) + } + stats := &spannerpb.ResultSetStats{} + if err := proto.Unmarshal(statsMsg.Res, stats); err != nil { + t.Fatalf("Failed to unmarshal ResultSetStats: %v", err) + } + // TODO: Enable when this branch is up to date with main + //emptyStats := &spannerpb.ResultSetStats{} + //if g, w := stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + // t.Fatalf("ResultSetStats mismatch\n Got: %v\nWant: %v", g, w) + //} + + closeMsg := CloseRows(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestDml(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteSqlRequest{ + Sql: testutil.UpdateBarSetFoo, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + // Execute is used for all types of statements. + rowsMsg := Execute(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Next should return no rows for a DML statement. + rowMsg := Next(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := rowMsg.Code, int32(0); g != w { + t.Fatalf("Next result mismatch\n Got: %v\nWant: %v", g, w) + } + // Data length == 0 means end of data (or no data). + if g, w := rowMsg.Length(), int32(0); g != w { + t.Fatalf("row length mismatch\n Got: %v\nWant: %v", g, w) + } + + // The ResultSetStats should contain the update count for DML statements. + statsMsg := ResultSetStats(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := statsMsg.Code, int32(0); g != w { + t.Fatalf("ResultSetStats result mismatch\n Got: %v\nWant: %v", g, w) + } + stats := &spannerpb.ResultSetStats{} + if err := proto.Unmarshal(statsMsg.Res, stats); err != nil { + t.Fatalf("Failed to unmarshal ResultSetStats: %v", err) + } + wantStats := &spannerpb.ResultSetStats{ + RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: testutil.UpdateBarSetFooRowCount}, + } + if g, w := stats, wantStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("ResultSetStats mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseRows(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestDdl(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + // The input argument for the library is always an ExecuteSqlRequest, also for DDL statements. + request := &spannerpb.ExecuteSqlRequest{ + Sql: "CREATE TABLE my_table (id INT64 PRIMARY KEY, value STRING(MAX))", + } + + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + // Execute is used for all types of statements, including DDL. + rowsMsg := Execute(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Next should return no rows for a DDL statement. + rowMsg := Next(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := rowMsg.Code, int32(0); g != w { + t.Fatalf("Next result mismatch\n Got: %v\nWant: %v", g, w) + } + // Data length == 0 means end of data (or no data). + if g, w := rowMsg.Length(), int32(0); g != w { + t.Fatalf("row length mismatch\n Got: %v\nWant: %v", g, w) + } + + // The ResultSetStats should be empty for DDL statements. + statsMsg := ResultSetStats(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := statsMsg.Code, int32(0); g != w { + t.Fatalf("ResultSetStats result mismatch\n Got: %v\nWant: %v", g, w) + } + stats := &spannerpb.ResultSetStats{} + if err := proto.Unmarshal(statsMsg.Res, stats); err != nil { + t.Fatalf("Failed to unmarshal ResultSetStats: %v", err) + } + emptyStats := &spannerpb.ResultSetStats{} + if g, w := stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("ResultSetStats mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseRows(ctx, poolMsg.ObjectId, connMsg.ObjectId, rowsMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestQueryError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Set up a query that returns an error. + query := "select * from non_existing_table" + _ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{ + Type: testutil.StatementResultError, + Err: status.Error(codes.NotFound, "Table not found"), + }) + // Execute the query that will return an error. + request := &spannerpb.ExecuteSqlRequest{ + Sql: query, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := Execute(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(codes.NotFound); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rowsMsg.ObjectId, int64(0); g != w { + t.Fatalf("rowsId mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/build-java-darwin-aarch64.sh b/spannerlib/shared/build-java-darwin-aarch64.sh new file mode 100755 index 00000000..7198a4d9 --- /dev/null +++ b/spannerlib/shared/build-java-darwin-aarch64.sh @@ -0,0 +1,2 @@ +go build -o spannerlib.so -buildmode=c-shared shared_lib.go +cp spannerlib.so ../wrappers/spannerlib-java/src/main/resources/darwin-aarch64/libspanner.dylib diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index f39b3b2c..8645f219 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -110,3 +110,124 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P msg := lib.CloseConnection(ctx, poolId, connId) return pin(msg) } + +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +// +//export WriteMutations +func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes) + return pin(msg) +} + +// Execute executes a SQL statement on the given connection. +// The return type is an identifier for a Rows object. This identifier can be used to +// call the functions Metadata and Next to get respectively the metadata of the result +// and the next row of results. +// +//export Execute +func Execute(poolId, connectionId int64, statement []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Execute(ctx, poolId, connectionId, statement) + return pin(msg) +} + +// ExecuteBatch executes a batch of statements on the given connection. The statements must all be either DML or DDL +// statements. Mixing DML and DDL in a batch is not supported. Executing queries in a batch is also not supported. +// The batch will use the current transaction on the given connection, or execute as a single auto-commit statement +// if the connection does not have a transaction. +// +//export ExecuteBatch +func ExecuteBatch(poolId, connectionId int64, statements []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ExecuteBatch(ctx, poolId, connectionId, statements) + return pin(msg) +} + +// Metadata returns the metadata of a Rows object. +// +//export Metadata +func Metadata(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Metadata(ctx, poolId, connId, rowsId) + return pin(msg) +} + +// ResultSetStats returns the statistics for a statement that has been executed. This includes +// the number of rows affected in case of a DML statement. +// Statistics are only available once all rows have been consumed. +// +//export ResultSetStats +func ResultSetStats(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ResultSetStats(ctx, poolId, connId, rowsId) + return pin(msg) +} + +// Next returns the next row in a Rows object. The returned message contains a protobuf +// ListValue that contains all the columns of the row. The message is empty if there are +// no more rows in the Rows object. +// +//export Next +func Next(poolId, connId, rowsId int64, numRows int32, encodeRowOption int32) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + // TODO: Implement support for: + // 1. Fetching more than one row at a time. + // 2. Specifying the return type (e.g. proto, struct, ...) + msg := lib.Next(ctx, poolId, connId, rowsId) + return pin(msg) +} + +// CloseRows closes and cleans up all memory held by a Rows object. This must be called +// when the application is done with the Rows object. +// +//export CloseRows +func CloseRows(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.CloseRows(ctx, poolId, connId, rowsId) + return pin(msg) +} + +// BeginTransaction begins a new transaction on the given connection. +// The txOpts byte slice contains a serialized protobuf TransactionOptions object. +// +//export BeginTransaction +func BeginTransaction(poolId, connectionId int64, txOpts []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.BeginTransaction(ctx, poolId, connectionId, txOpts) + return pin(msg) +} + +// Commit commits the current transaction on a connection. All transactions must be +// either committed or rolled back, including read-only transactions. This to ensure +// that all resources that are held by a transaction are cleaned up. +// +//export Commit +func Commit(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Commit(ctx, poolId, connectionId) + return pin(msg) +} + +// Rollback rolls back a previously started transaction. All transactions must be either +// committed or rolled back, including read-only transactions. This to ensure that +// all resources that are held by a transaction are cleaned up. +// +// Spanner does not require read-only transactions to be committed or rolled back, but +// this library requires that all transactions are committed or rolled back to clean up +// all resources. Commit and Rollback are semantically the same for read-only transactions +// on Spanner, and both functions just close the transaction. +// +//export Rollback +func Rollback(poolId, connectionId int64) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.Rollback(ctx, poolId, connectionId) + return pin(msg) +} diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 7824aeda..b6d63467 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -16,11 +16,16 @@ package main import ( "fmt" + "reflect" "testing" "unsafe" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + "spannerlib/api" ) // The tests in this file only verify the happy flow to ensure that everything compiles. @@ -94,6 +99,433 @@ func TestCreateConnection(t *testing.T) { } } +func TestExecute(t *testing.T) { + // This test is intentionally not marked as Parallel, as it checks the number of open memory pointers. + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteSqlRequest{ + // This query returns a result set with one column and two rows. + // The values in the two rows are 1 and 2. + Sql: testutil.SelectFooFromBar, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // Execute returns a reference to a Rows object, not the actual data. + mem, code, rowsId, length, data := Execute(poolId, connId, requestBytes) + if g, w := mem, int64(0); g != w { + t.Fatalf("Execute mem mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsId <= int64(0) { + t.Fatalf("rowsId mismatch: %v", rowsId) + } + if g, w := length, int32(0); g != w { + t.Fatalf("Execute length mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := unsafe.Pointer(nil), data; g != w { + t.Fatalf("Execute data mismatch\n Got: %v\nWant: %v", g, w) + } + + // Get the metadata of the selected rows. + mem, code, _, length, data = Metadata(poolId, connId, rowsId) + // Metadata returns actual data, and should therefore return a memory ID that needs to be released. + if mem == int64(0) { + t.Fatalf("Metadata mem mismatch: %v", mem) + } + if length == int32(0) { + t.Fatalf("Metadata length mismatch: %v", length) + } + // Get a []byte from the pointer to the data and the length. + metadataBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + metadata := &spannerpb.ResultSetMetadata{} + if err := proto.Unmarshal(metadataBytes, metadata); err != nil { + t.Fatal(err) + } + if g, w := len(metadata.RowType.Fields), 1; g != w { + t.Fatalf("Metadata field count mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := metadata.RowType.Fields[0].Name, "FOO"; g != w { + t.Fatalf("Metadata field name mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := metadata.RowType.Fields[0].Type.Code, spannerpb.TypeCode_INT64; g != w { + t.Fatalf("Metadata type code mismatch\n Got: %v\nWant: %v", g, w) + } + // Release the memory. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Iterate over the rows. + numRows := 0 + for { + mem, code, _, length, data = Next(poolId, connId, rowsId /*numRows = */, 1, int32(api.EncodeRowOptionProto)) + // Next returns an empty message if it is the end of the query results. + if length == 0 { + break + } + numRows++ + // Decode the row. + rowBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + row := &structpb.ListValue{} + if err := proto.Unmarshal(rowBytes, row); err != nil { + t.Fatal(err) + } + // Release the memory that was held for the row. We can do that as soon as it has + // been copied into a data structure that is maintained by the 'application'. + // The 'application' in this case is the test. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + // Verify the row data. + if g, w := len(row.GetValues()), 1; g != w { + t.Fatalf("num row values mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := row.GetValues()[0].GetStringValue(), fmt.Sprintf("%d", numRows); g != w { + t.Fatalf("row values mismatch\n Got: %v\nWant: %v", g, w) + } + } + // The result should contain two rows. + if g, w := numRows, 2; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + + // Get the ResultSetStats. For queries, this is an empty instance. + mem, code, _, length, data = ResultSetStats(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("ResultSetStats result code mismatch\n Got: %v\nWant: %v", g, w) + } + if length == int32(0) { + t.Fatalf("ResultSetStats length mismatch: %v", length) + } + statsBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + stats := &spannerpb.ResultSetStats{} + if err := proto.Unmarshal(statsBytes, stats); err != nil { + t.Fatal(err) + } + // TODO: Enable when this branch is up to date with main + // emptyStats := &spannerpb.ResultSetStats{} + //if g, w := stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + // t.Fatalf("ResultSetStats mismatch\n Got: %v\nWant: %v", g, w) + //} + if res := Release(mem); res != 0 { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", res, 0) + } + + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } + + if g, w := countOpenMemoryPointers(), 0; g != w { + t.Fatalf("countOpenMemoryPointers() result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // ExecuteBatch returns a ExecuteBatchDml response. + mem, code, batchId, length, data := ExecuteBatch(poolId, connId, requestBytes) + verifyDataMessage(t, "ExecuteBatch", mem, code, batchId, length, data) + response := &spannerpb.ExecuteBatchDmlResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if g, w := len(response.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %v\nWant: %v", g, w) + } + for i, result := range response.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndCommitTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Commit returns the CommitResponse (if any). + mem, code, id, length, res = Commit(poolId, connId) + verifyDataMessage(t, "Commit", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginAndRollbackTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + if err != nil { + t.Fatal(err) + } + mem, code, id, length, res := BeginTransaction(poolId, connId, txOptsBytes) + // BeginTransaction should return an empty message. + // That is, there should be no error code, no ObjectID, and no data. + verifyEmptyMessage(t, "BeginTransaction", mem, code, id, length, res) + + // Execute a statement in the transaction. + request := &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + _, code, rowsId, _, _ := Execute(poolId, connId, requestBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("Execute result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = CloseRows(poolId, connId, rowsId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseRows result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Rollback returns nothing. + mem, code, id, length, res = Rollback(poolId, connId) + verifyEmptyMessage(t, "Rollback", mem, code, id, length, res) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + // WriteMutations returns a CommitResponse or nil, depending on whether the connection has an active transaction. + mem, code, id, length, data := WriteMutations(poolId, connId, mutationBytes) + verifyDataMessage(t, "WriteMutations", mem, code, id, length, data) + + response := &spannerpb.CommitResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if response.CommitTimestamp == nil { + t.Fatal("CommitTimestamp is nil") + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Start a transaction on the connection and write the mutations to that transaction. + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + _, code, _, _, _ = BeginTransaction(poolId, connId, txOptsBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + mem, code, id, length, data = WriteMutations(poolId, connId, mutationBytes) + // The response should now be an empty message, as the mutations were buffered in the current transaction. + verifyEmptyMessage(t, "WriteMutations in tx", mem, code, id, length, data) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := mem, int64(0); g != w { + t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := length, int32(0); g != w { + t.Fatalf("%s: length mismatch\n Got: %v\nWant: %v", name, g, w) + } + if g, w := res, unsafe.Pointer(nil); g != w { + t.Fatalf("%s: ptr mismatch\n Got: %v\nWant: %v", name, g, w) + } +} + +// verifyDataMessage verifies that the result contains a data message. +func verifyDataMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { + if g, w := code, int32(0); g != w { + t.Fatalf("%s: result mismatch\n Got: %v\nWant: %v", name, g, w) + } + if mem == int64(0) { + t.Fatalf("%s: No memory identifier returned", name) + } + if g, w := id, int64(0); g != w { + t.Fatalf("%s: ID mismatch\n Got: %v\nWant: %v", name, g, w) + } + if length == int32(0) { + t.Fatalf("%s: zero length returned", name) + } + if res == unsafe.Pointer(nil) { + t.Fatalf("%s: nil pointer returned", name) + } +} + +func countOpenMemoryPointers() (c int) { + pinners.Range(func(key, value any) bool { + c++ + return true + }) + return +} + func setupMockServer(t *testing.T) (server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupMockServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) } diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index ddd558ad..18a9ca50 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -18,6 +18,17 @@ import static com.google.cloud.spannerlib.internal.SpannerLibrary.executeAndRelease; +import com.google.cloud.spannerlib.internal.MessageHandler; +import com.google.cloud.spannerlib.internal.WrappedGoBytes; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.TransactionOptions; +import java.nio.ByteBuffer; + /** A {@link Connection} that has been created by SpannerLib. */ public class Connection extends AbstractLibraryObject { private final Pool pool; @@ -31,6 +42,96 @@ public Pool getPool() { return this.pool; } + /** + * Writes a group of mutations to Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner using a new read/write transaction. Returns a {@link + * CommitResponse} if the mutations were written directly to Spanner, and otherwise null if the + * mutations were buffered in the current transaction. + */ + public CommitResponse WriteMutations(MutationGroup mutations) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(mutations); + MessageHandler message = + getLibrary() + .execute( + library -> + library.WriteMutations( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Starts a transaction on this connection. */ + public void beginTransaction(TransactionOptions options) { + try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { + executeAndRelease( + getLibrary(), + library -> + library.BeginTransaction(pool.getId(), getId(), serializedOptions.getGoBytes())); + } + } + + /** + * Commits the current transaction on this connection and returns the {@link CommitResponse} or + * null if there is no {@link CommitResponse} (e.g. for read-only transactions). + */ + public CommitResponse commit() { + try (MessageHandler message = + getLibrary().execute(library -> library.Commit(pool.getId(), getId()))) { + // Return null in case there is no CommitResponse. + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Rollbacks the current transaction on this connection. */ + public void rollback() { + executeAndRelease(getLibrary(), library -> library.Rollback(pool.getId(), getId())); + } + + /** Executes the given SQL statement on this connection. */ + public Rows execute(ExecuteSqlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.Execute(pool.getId(), getId(), serializedRequest.getGoBytes()))) { + return new Rows(this, message.getObjectId()); + } + } + + /** + * Executes the given batch of DML or DDL statements on this connection. The statements must all + * be of the same type. + */ + public ExecuteBatchDmlResponse executeBatch(ExecuteBatchDmlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.ExecuteBatch( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ExecuteBatchDmlResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { executeAndRelease(getLibrary(), library -> library.CloseConnection(pool.getId(), getId())); diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Rows.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Rows.java new file mode 100644 index 00000000..6861891b --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Rows.java @@ -0,0 +1,113 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static com.google.cloud.spannerlib.internal.SpannerLibrary.executeAndRelease; + +import com.google.cloud.spannerlib.internal.MessageHandler; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.ListValue; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import java.nio.ByteBuffer; +import java.sql.Statement; + +public class Rows extends AbstractLibraryObject { + public enum Encoding { + PROTOBUF, + } + + private final Connection connection; + + Rows(Connection connection, long id) { + super(connection.getLibrary(), id); + this.connection = connection; + } + + @Override + public void close() { + executeAndRelease( + getLibrary(), + library -> library.CloseRows(connection.getPool().getId(), connection.getId(), getId())); + } + + public ResultSetMetadata getMetadata() { + try (MessageHandler message = + getLibrary() + .execute( + library -> + library.Metadata(connection.getPool().getId(), connection.getId(), getId()))) { + if (message.getLength() == 0) { + return ResultSetMetadata.getDefaultInstance(); + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ResultSetMetadata.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + public ResultSetStats getResultSetStats() { + try (MessageHandler message = + getLibrary() + .execute( + library -> + library.ResultSetStats( + connection.getPool().getId(), connection.getId(), getId()))) { + if (message.getLength() == 0) { + return ResultSetStats.getDefaultInstance(); + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ResultSetStats.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + + public long getUpdateCount() { + ResultSetStats stats = getResultSetStats(); + if (stats.hasRowCountExact()) { + return stats.getRowCountExact(); + } else if (stats.hasRowCountLowerBound()) { + return stats.getRowCountLowerBound(); + } + return Statement.SUCCESS_NO_INFO; + } + + /** Returns the next row in this {@link Rows} instance, or null if there are no more rows. */ + public ListValue next() { + try (MessageHandler message = + getLibrary() + .execute( + library -> + library.Next( + connection.getPool().getId(), + connection.getId(), + getId(), + /* numRows= */ 1, + Encoding.PROTOBUF.ordinal()))) { + // An empty message means that we have reached the end of the iterator. + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ListValue.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/GoBytes.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/GoBytes.java index f2e4b02c..ac524f62 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/GoBytes.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/GoBytes.java @@ -17,6 +17,7 @@ package com.google.cloud.spannerlib.internal; import com.google.cloud.spannerlib.SpannerLibException; +import com.google.common.base.Preconditions; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.Message; import com.google.rpc.Code; @@ -30,7 +31,7 @@ import java.util.List; /** {@link GoBytes} is the Java representation of a Go byte slice ([]byte). */ -public class GoBytes extends Structure implements Structure.ByReference, AutoCloseable { +public class GoBytes extends Structure implements Structure.ByValue, AutoCloseable { // JNA does not allow these fields to be final. /** The pointer to the actual data. */ @@ -60,7 +61,7 @@ public static GoBytes serialize(Message message) { } GoBytes(ByteBuffer buffer, long size) { - this.p = Native.getDirectBufferPointer(buffer); + this.p = Preconditions.checkNotNull(Native.getDirectBufferPointer(buffer)); this.n = size; this.c = size; } diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index e7810878..7da5a013 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -26,10 +26,6 @@ * SpannerLib library. */ public interface SpannerLibrary extends Library { - // String LIBRARY_PATH = - // System.getProperty( - // "spanner.library", SpannerLibrary.class.getResource("spannerlib.so").getPath()); - // SpannerLibrary LIBRARY = Native.load(LIBRARY_PATH, SpannerLibrary.class); SpannerLibrary LIBRARY = Native.load("spanner", SpannerLibrary.class); /** Returns the singleton instance of the library. */ @@ -68,4 +64,46 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + + /** + * Writes a group of mutations on Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner in a new read/write transaction. Returns a {@link + * com.google.spanner.v1.CommitResponse} if the mutations were written directly to Spanner, and an + * empty message if the mutations were only buffered in the current transaction. + */ + Message WriteMutations(long poolId, long connectionId, GoBytes mutations); + + /** Starts a new transaction on the given Connection. */ + Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); + + /** + * Commits the current transaction on the given Connection and returns a {@link + * com.google.spanner.v1.CommitResponse}. + */ + Message Commit(long poolId, long connectionId); + + /** Rollbacks the current transaction on the given Connection. */ + Message Rollback(long poolId, long connectionId); + + /** Executes a SQL statement on the given Connection. */ + Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); + + /** + * Executes a batch of DML or DDL statements on the given Connection. Returns an {@link + * com.google.spanner.v1.ExecuteBatchDmlResponse} for both DML and DDL batches. + */ + Message ExecuteBatch(long poolId, long connectionId, GoBytes executeBatchDmlRequest); + + /** Returns the {@link com.google.spanner.v1.ResultSetMetadata} of the given Rows object. */ + Message Metadata(long poolId, long connectionId, long rowsId); + + /** Returns the next row from the given Rows object. */ + Message Next(long poolId, long connectionId, long rowsId, int numRows, int encoding); + + /** Returns the {@link com.google.spanner.v1.ResultSetStats} of the given Rows object. */ + Message ResultSetStats(long poolId, long connectionId, long rowsId); + + /** Closes the given Rows object. */ + Message CloseRows(long poolId, long connectionId, long rowsId); } diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/WrappedGoBytes.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/WrappedGoBytes.java index 6aa245fd..5e8a7a32 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/WrappedGoBytes.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/WrappedGoBytes.java @@ -33,7 +33,7 @@ public class WrappedGoBytes implements AutoCloseable { /** Serializes a protobuf {@link Message} into a {@link WrappedGoBytes} instance. */ public static WrappedGoBytes serialize(Message message) { int size = message.getSerializedSize(); - byte[] bytes = message.toByteArray(); + // TODO: Use a pool of direct byte buffers to prevent creating new buffers for every request. ByteBuffer buffer = ByteBuffer.allocateDirect(size); try { message.writeTo(CodedOutputStream.newInstance(buffer)); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java new file mode 100644 index 00000000..9740d541 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest.Statement; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.List; +import org.junit.Test; + +public class BatchTest extends AbstractMockServerTest { + + @Test + public void testBatchDml() { + String insert = "insert into test (id, value) values (@id, @value)"; + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(1L) + .bind("value") + .to("One") + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(2L) + .bind("value") + .to("Two") + .build(), + 1L)); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", Value.newBuilder().setStringValue("One").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("2").build()) + .putFields( + "value", Value.newBuilder().setStringValue("Two").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertEquals(1L, response.getResultSets(0).getStats().getRowCountExact()); + assertEquals(1L, response.getResultSets(1).getStats().getRowCountExact()); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testBatchDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql("create table my_table (id int64 primary key, value string(max))") + .build()) + .addStatements( + Statement.newBuilder() + .setSql("create index my_index on my_table (value)") + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertFalse(response.getResultSets(0).getStats().hasRowCountExact()); + } + + List requests = mockDatabaseAdmin.getRequests(); + assertEquals(1, requests.size()); + UpdateDatabaseDdlRequest request = (UpdateDatabaseDdlRequest) requests.get(0); + assertEquals(2, request.getStatementsCount()); + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java index 5415b867..8efc22b5 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java @@ -18,10 +18,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.rpc.Code; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.Write; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; import org.junit.Test; public class ConnectionTest extends AbstractMockServerTest { @@ -53,4 +68,138 @@ public void testCreateTwoConnections() { } } } + + @Test + public void testWriteMutations() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues( + Value.newBuilder().setStringValue("Two").build()) + .build()) + .build()) + .build()) + .addMutations( + Mutation.newBuilder() + .setInsertOrUpdate( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("0").build()) + .addValues( + Value.newBuilder().setStringValue("Zero").build()) + .build()) + .build()) + .build()) + .build()); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(2, request.getMutationsCount()); + assertEquals(2, request.getMutations(0).getInsert().getValuesCount()); + assertEquals(1, request.getMutations(1).getInsertOrUpdate().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .build()) + .build()) + .build()); + // The mutations are only buffered in the current transaction, so there should be no response. + assertNull(response); + + // Committing the transaction should return a CommitResponse. + response = connection.commit(); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(1, request.getMutationsCount()); + assertEquals(1, request.getMutations(0).getInsert().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInReadOnlyTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setStringValue("1") + .build()) + .addValues( + Value.newBuilder() + .setStringValue("One") + .build()) + .build()) + .build()) + .build()) + .build())); + assertEquals(Code.FAILED_PRECONDITION.getNumber(), exception.getStatus().getCode()); + } + } } diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/RowsTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/RowsTest.java new file mode 100644 index 00000000..827ba51d --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/RowsTest.java @@ -0,0 +1,307 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import org.junit.Test; + +public class RowsTest extends AbstractMockServerTest { + + @Test + public void testExecuteSelect1() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + try (Rows rows = + connection.execute(ExecuteSqlRequest.newBuilder().setSql("SELECT 1").build())) { + ListValue row; + int numRows = 0; + while ((row = rows.next()) != null) { + numRows++; + assertEquals(1, row.getValuesList().size()); + assertTrue(row.getValues(0).hasStringValue()); + assertEquals("1", row.getValues(0).getStringValue()); + } + assertEquals(1, numRows); + } + } + } + + @Test + public void testRandomResults() { + String sql = "select * from random"; + int numRows = 100; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + int numCols = RandomResultSetGenerator.generateAllTypes(Dialect.GOOGLE_STANDARD_SQL).length; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), generator.generate())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql(sql).build())) { + ListValue row; + int rowCount = 0; + ResultSetMetadata metadata = rows.getMetadata(); + assertEquals(numCols, metadata.getRowType().getFieldsCount()); + while ((row = rows.next()) != null) { + rowCount++; + assertEquals(numCols, row.getValuesList().size()); + } + assertEquals(numRows, rowCount); + } + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(sql, request.getSql()); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasSingleUse()); + assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + assertTrue(request.getTransaction().getSingleUse().getReadOnly().hasStrong()); + } + + @Test + public void testExecuteDml() { + String sql = "update my_table set my_val=1 where id=2"; + // Set up the result as a ResultSet, as the Java mock Spanner server does not return the + // transaction ID correctly when ExecuteStreamingSql is used for an update count result. + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // The Execute method is used for all types of statements. + // The return type is always Rows. + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql(sql).build())) { + // A DML statement without a THEN RETURN clause does not return any rows. + assertNull(rows.next()); + // The ResultSetStats contains the update count. The Rows wrapper contains a util method + // for getting it. + assertEquals(1L, rows.getUpdateCount()); + } + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(sql, request.getSql()); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testExecuteDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // The Execute method is used for all types of statements. + // The input type is always an ExecuteSqlRequest, even for DDL statements. + // The return type is always Rows. + try (Rows rows = + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql( + "create table my_table (" + "id int64 primary key, " + "value string(max))") + .build())) { + // A DDL statement does not return any rows. + assertNull(rows.next()); + // There is no update count for DDL statements. + // The library returns the Java standard constant for the update count for DDL statements. + assertEquals(java.sql.Statement.SUCCESS_NO_INFO, rows.getUpdateCount()); + } + } + + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals(1, mockDatabaseAdmin.getRequests().size()); + } + + @Test + public void testExecuteCustomSql() { + // The Execute method can be used to execute any type of SQL statement, including statements + // that are handled internally by the Spanner library. This includes for example statements to + // start and commit a transaction. + + String selectSql = "select value from my_val where id=@id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(selectSql).bind("id").to(1L).build(), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName("value") + .build()) + .build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues( + com.google.protobuf.Value.newBuilder().setStringValue("bar").build()) + .build()) + .build())); + + String updateSql = "update my_table set my_val=@value where id=@id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(updateSql).bind("value").to("foo").bind("id").to(1L).build(), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // The Execute method is used for all types of statements. + // This starts a new transaction on the connection. + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql("begin").build())) { + assertNull(rows.next()); + assertEquals(java.sql.Statement.SUCCESS_NO_INFO, rows.getUpdateCount()); + } + + // Execute a parameterized query using the current transaction. + try (Rows rows = + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql(selectSql) + .setParams( + Struct.newBuilder() + .putFields( + "id", + com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of("id", Type.newBuilder().setCode(TypeCode.INT64).build())) + .build())) { + ListValue row = rows.next(); + assertNotNull(row); + assertEquals(1, row.getValuesList().size()); + assertTrue(row.getValues(0).hasStringValue()); + assertEquals("bar", row.getValues(0).getStringValue()); + assertNull(rows.next()); + } + + // Execute a DML statement using the current transaction. + try (Rows rows = + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql(updateSql) + .setParams( + Struct.newBuilder() + .putFields( + "id", + com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", + com.google.protobuf.Value.newBuilder().setStringValue("foo").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build())) { + // There should be no rows. + assertNull(rows.next()); + // There should be an update count. + assertEquals(1, rows.getUpdateCount()); + } + + // Commit the transaction. + try (Rows rows = + connection.execute(ExecuteSqlRequest.newBuilder().setSql("commit").build())) { + assertNull(rows.next()); + assertEquals(java.sql.Statement.SUCCESS_NO_INFO, rows.getUpdateCount()); + } + } + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(selectSql, selectRequest.getSql()); + assertTrue(selectRequest.hasTransaction()); + // The library uses inline-begin-transaction, even if a transaction is started by executing a + // BEGIN statement. + assertTrue(selectRequest.getTransaction().hasBegin()); + assertTrue(selectRequest.getTransaction().getBegin().hasReadWrite()); + + ExecuteSqlRequest updateRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(updateSql, updateRequest.getSql()); + assertTrue(updateRequest.hasTransaction()); + assertTrue(updateRequest.getTransaction().hasId()); + + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } +} diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java new file mode 100644 index 00000000..f2eef624 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import io.grpc.Status.Code; +import org.junit.Test; + +public class TransactionTest extends AbstractMockServerTest { + @Test + public void testBeginAndCommit() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.commit(); + + // TODO: The library should take a shortcut and just skip committing empty transactions. + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testBeginAndRollback() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.rollback(); + + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + } + + @Test + public void testReadWriteTransaction() { + String updateSql = "update my_table set my_val=@value where id=@id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(updateSql).bind("value").to("foo").bind("id").to(1L).build(), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + connection.execute( + ExecuteSqlRequest.newBuilder() + .setSql(updateSql) + .setParams( + Struct.newBuilder() + .putFields("value", Value.newBuilder().setStringValue("foo").build()) + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "value", Type.newBuilder().setCode(TypeCode.STRING).build(), + "id", Type.newBuilder().setCode(TypeCode.INT64).build())) + .build()); + connection.commit(); + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + } + + @Test + public void testReadOnlyTransaction() { + String sql = "select * from random"; + int numRows = 5; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), generator.generate())); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + try (Rows rows = connection.execute(ExecuteSqlRequest.newBuilder().setSql(sql).build())) { + int rowCount = 0; + while (rows.next() != null) { + rowCount++; + } + assertEquals(numRows, rowCount); + } + connection.commit(); + } + + // There should be no BeginTransaction requests, as the transaction start is inlined with the + // first statement. + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertTrue(request.hasTransaction()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadOnly()); + // There should be no CommitRequests on the server, as committing a read-only transaction is a + // no-op on Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testBeginTwice() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + // Try to start two transactions on a connection. + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> connection.beginTransaction(TransactionOptions.getDefaultInstance())); + assertEquals(Code.FAILED_PRECONDITION.value(), exception.getStatus().getCode()); + } + } +} diff --git a/statements.go b/statements.go index 6920e7f2..f7232a7b 100644 --- a/statements.go +++ b/statements.go @@ -286,7 +286,7 @@ type executableCommitStatement struct { } func (s *executableCommitStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - _, err := c.commit(ctx) + _, err := c.Commit(ctx) if err != nil { return nil, err } @@ -305,7 +305,7 @@ type executableRollbackStatement struct { } func (s *executableRollbackStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.rollback(ctx); err != nil { + if err := c.Rollback(ctx); err != nil { return nil, err } return driver.ResultNoRows, nil diff --git a/transaction.go b/transaction.go index dcc61d1b..5f93f94e 100644 --- a/transaction.go +++ b/transaction.go @@ -407,6 +407,7 @@ func (tx *readWriteTransaction) Commit() (err error) { return err } var commitResponse spanner.CommitResponse + // TODO: Optimize this to skip the Commit also if the transaction has not yet been used. if tx.rwTx != nil { if !tx.retryAborts() { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx)