Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ type SpannerConn interface {
// return the same Spanner client.
UnderlyingClient() (client *spanner.Client, err error)

// DetectStatementType returns the type of SQL statement.
DetectStatementType(query string) parser.StatementType

// resetTransactionForRetry resets the current transaction after it has
// been aborted by Spanner. Calling this function on a transaction that
// has not been aborted is not supported and will cause an error to be
Expand Down Expand Up @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) {
return c.client, nil
}

func (c *conn) DetectStatementType(query string) parser.StatementType {
info := c.parser.DetectStatementType(query)
return info.StatementType
}

func (c *conn) CommitTimestamp() (time.Time, error) {
ts := propertyCommitTimestamp.GetValueOrDefault(c.state)
if ts == nil {
Expand Down Expand Up @@ -675,6 +683,27 @@ func sum(affected []int64) int64 {
return sum
}

// WriteMutations is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction,
// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction.
//
// The function returns an error if the connection currently has a read-only transaction.
//
// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be
// applied to Spanner when the transaction commits.
func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) {
if c.inTransaction() {
return nil, c.BufferWrite(ms)
}
ts, err := c.Apply(ctx, ms)
if err != nil {
return nil, err
}
return &spanner.CommitResponse{CommitTs: ts}, nil
}

func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) {
if c.inTransaction() {
return time.Time{}, spanner.ToSpannerError(
Expand Down Expand Up @@ -1071,6 +1100,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
}

// BeginReadOnlyTransaction is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// BeginReadOnlyTransaction starts a new read-only transaction on this connection.
func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) {
c.withTempReadOnlyTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
if err != nil {
c.withTempReadOnlyTransactionOptions(nil)
return nil, err
}
return tx, nil
}

// BeginReadWriteTransaction is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// BeginReadWriteTransaction starts a new read/write transaction on this connection.
func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) {
c.withTempTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{})
if err != nil {
c.withTempTransactionOptions(nil)
return nil, err
}
return tx, nil
}

func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
Expand Down Expand Up @@ -1254,18 +1311,29 @@ func (c *conn) inReadWriteTransaction() bool {
return false
}

func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
// Commit is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// Commit commits the current transaction on this connection.
func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) {
if !c.inTransaction() {
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
// TODO: Pass in context to the tx.Commit() function.
if err := c.tx.Commit(); err != nil {
return nil, err
}
return c.CommitResponse()

// This will return either the commit response or nil, depending on whether the transaction was a
// read/write transaction or a read-only transaction.
return propertyCommitResponse.GetValueOrDefault(c.state), nil
}

func (c *conn) rollback(ctx context.Context) error {
// Rollback is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// Rollback rollbacks the current transaction on this connection.
func (c *conn) Rollback(ctx context.Context) error {
if !c.inTransaction() {
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
Expand Down
20 changes: 20 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) {
}
}

func TestTwoQueriesOnOneConn(t *testing.T) {
t.Parallel()

db, _, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

c, _ := db.Conn(ctx)
defer silentClose(c)

for range 2 {
r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar)
if err != nil {
t.Fatal(err)
}
_ = r.Next()
defer silentClose(r)
}
}

func TestExplicitBeginTx(t *testing.T) {
t.Parallel()

Expand Down
238 changes: 238 additions & 0 deletions spannerlib/api/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
"context"
"fmt"
"reflect"
"testing"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
)

func TestExecuteDmlBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Execute a DML batch.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: testutil.UpdateBarSetFoo},
{Sql: testutil.UpdateBarSetFoo},
}}
resp, err := ExecuteBatch(ctx, poolId, connId, request)
if err != nil {
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
}
if g, w := len(resp.ResultSets), 2; g != w {
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
}
for i, result := range resp.ResultSets {
if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w {
t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w)
}
}

requests := server.TestSpanner.DrainRequestsFromServer()
// There should be no ExecuteSql requests.
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 0; g != w {
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
}
batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
if g, w := len(batchRequests), 1; g != w {
t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteDdlBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
// Set up a result for a DDL statement on the mock server.
var expectedResponse = &emptypb.Empty{}
anyMsg, _ := anypb.New(expectedResponse)
server.TestDatabaseAdmin.SetResps([]proto.Message{
&longrunningpb.Operation{
Done: true,
Result: &longrunningpb.Operation_Response{Response: anyMsg},
Name: "test-operation",
},
})

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Execute a DDL batch. This also uses a DML batch request.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "create index my_index on my_table (value)"},
}}
resp, err := ExecuteBatch(ctx, poolId, connId, request)
if err != nil {
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
}
// The response should contain an 'update count' per DDL statement.
if g, w := len(resp.ResultSets), 2; g != w {
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
}
// There is no update count for DDL statements.
for i, result := range resp.ResultSets {
emptyStats := &spannerpb.ResultSetStats{}
if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) {
t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w)
}
}

requests := server.TestSpanner.DrainRequestsFromServer()
// There should be no ExecuteSql requests.
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 0; g != w {
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
}
// There should also be no ExecuteBatchDml requests.
batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
if g, w := len(batchDmlRequests), 0; g != w {
t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w)
}

adminRequests := server.TestDatabaseAdmin.Reqs()
if g, w := len(adminRequests), 1; g != w {
t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w)
}
ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest)
if g, w := len(ddlRequest.Statements), 2; g != w {
t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteMixedBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Try to execute a batch with mixed DML and DDL statements. This should fail.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "update my_table set value = 100 where true"},
}}
_, err = ExecuteBatch(ctx, poolId, connId, request)
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteDdlBatchInTransaction(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil {
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
}

// Try to execute a DDL batch in a transaction. This should fail.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "create index my_index on my_table (value)"},
}}
_, err = ExecuteBatch(ctx, poolId, connId, request)
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}
Loading
Loading