Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ jobs:
uses: actions/checkout@v5
- name: Run unit tests
run: go test -race -short
- name: Run connection state unit tests
run: go test -race -short
working-directory: connectionstate

lint:
runs-on: ubuntu-latest
Expand Down
21 changes: 11 additions & 10 deletions client_side_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ import (
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/go-sql-spanner/connectionstate"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/structpb"
)

func TestStatementExecutor_StartBatchDdl(t *testing.T) {
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()

Expand Down Expand Up @@ -61,7 +62,7 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) {
}

func TestStatementExecutor_StartBatchDml(t *testing.T) {
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()

Expand Down Expand Up @@ -98,7 +99,7 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) {
}

func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()
for i, test := range []struct {
Expand Down Expand Up @@ -154,7 +155,7 @@ func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
}

func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
c := &conn{logger: noopLogger, connector: &connector{}}
c := &conn{logger: noopLogger, connector: &connector{}, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
_ = c.ResetSession(context.Background())
s := &statementExecutor{}
ctx := context.Background()
Expand Down Expand Up @@ -211,7 +212,7 @@ func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
}

func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
c := &conn{logger: noopLogger}
c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()
for i, test := range []struct {
Expand Down Expand Up @@ -282,7 +283,7 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
func TestShowCommitTimestamp(t *testing.T) {
t.Parallel()

c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()

Expand Down Expand Up @@ -328,7 +329,7 @@ func TestShowCommitTimestamp(t *testing.T) {
}

func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()
for i, test := range []struct {
Expand Down Expand Up @@ -384,7 +385,7 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
}

func TestStatementExecutor_MaxCommitDelay(t *testing.T) {
c := &conn{logger: noopLogger}
c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}
ctx := context.Background()
for i, test := range []struct {
Expand Down Expand Up @@ -457,7 +458,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {
{"", "tag-with-missing-opening-quote'", true},
{"", "'tag-with-missing-closing-quote", true},
} {
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}

it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{}, nil)
Expand Down Expand Up @@ -517,7 +518,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {

func TestStatementExecutor_UsesExecOptions(t *testing.T) {
ctx := context.Background()
c := &conn{retryAborts: true, logger: noopLogger}
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
s := &statementExecutor{}

it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true}, nil)
Expand Down
42 changes: 31 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/connectionstate"
"google.golang.org/api/iterator"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -231,6 +232,8 @@ type conn struct {
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error)

// state contains the current ConnectionState for this connection.
state *connectionstate.ConnectionState
// batch is the currently active DDL or DML batch on this connection.
batch *batch
// autoBatchDml determines whether DML statements should automatically
Expand All @@ -244,11 +247,6 @@ type conn struct {
// statements was correct.
autoBatchDmlUpdateCountVerification bool

// autocommitDMLMode determines the type of DML to use when a single DML
// statement is executed on a connection. The default is Transactional, but
// it can also be set to PartitionedNonAtomic to execute the statement as
// Partitioned DML.
autocommitDMLMode AutocommitDMLMode
// readOnlyStaleness is used for queries in autocommit mode and for read-only transactions.
readOnlyStaleness spanner.TimestampBound
// isolationLevel determines the default isolation level that is used for read/write
Expand Down Expand Up @@ -308,7 +306,7 @@ func (c *conn) setRetryAbortsInternally(retry bool) (driver.Result, error) {
}

func (c *conn) AutocommitDMLMode() AutocommitDMLMode {
return c.autocommitDMLMode
return propertyAutocommitDmlMode.GetValueOrDefault(c.state)
}

func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error {
Expand All @@ -320,7 +318,9 @@ func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error {
}

func (c *conn) setAutocommitDMLMode(mode AutocommitDMLMode) (driver.Result, error) {
c.autocommitDMLMode = mode
if err := propertyAutocommitDmlMode.SetValue(c.state, mode, connectionstate.ContextUser); err != nil {
return nil, err
}
return driver.ResultNoRows, nil
}

Expand Down Expand Up @@ -689,8 +689,9 @@ func (c *conn) ResetSession(_ context.Context) error {
c.retryAborts = c.connector.retryAbortsInternally
c.isolationLevel = c.connector.connectorConfig.IsolationLevel
c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption

_ = c.state.Reset(connectionstate.ContextUser)
// TODO: Reset the following fields to the connector default
c.autocommitDMLMode = Transactional
c.readOnlyStaleness = spanner.TimestampBound{}
c.execOptions = ExecOptions{
DecodeToNativeArrays: c.connector.connectorConfig.DecodeToNativeArrays,
Expand Down Expand Up @@ -887,7 +888,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
c.batch.statements = append(c.batch.statements, ss)
res = &result{}
} else {
dmlMode := c.autocommitDMLMode
dmlMode := c.AutocommitDMLMode()
if execOptions.AutocommitDMLMode != Unspecified {
dmlMode = execOptions.AutocommitDMLMode
}
Expand Down Expand Up @@ -1015,6 +1016,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
c.resetForRetry = false
return c.tx, nil
}
// Also start a transaction on the ConnectionState if the BeginTx call was successful.
defer func() {
if c.tx != nil {
_ = c.state.Begin()
}
}()

readOnlyTxOpts := c.getReadOnlyTransactionOptions()
batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions()
readWriteTransactionOptions := c.getTransactionOptions()
Expand Down Expand Up @@ -1072,13 +1080,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
roTx: ro,
boTx: bo,
logger: logger,
close: func() {
close: func(result txResult) {
if batchReadOnlyTxOpts.close != nil {
batchReadOnlyTxOpts.close()
}
if readOnlyTxOpts.close != nil {
readOnlyTxOpts.close()
}
if result == txResultCommit {
_ = c.state.Commit()
} else {
_ = c.state.Rollback()
}
c.tx = nil
},
}
Expand All @@ -1095,14 +1108,21 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
conn: c,
logger: logger,
rwTx: tx,
close: func(commitResponse *spanner.CommitResponse, commitErr error) {
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
if readWriteTransactionOptions.close != nil {
readWriteTransactionOptions.close()
}
c.prevTx = c.tx
c.tx = nil
if commitErr == nil {
c.commitResponse = commitResponse
if result == txResultCommit {
_ = c.state.Commit()
} else {
_ = c.state.Rollback()
}
} else {
_ = c.state.Rollback()
}
},
// Disable internal retries if any of these options have been set.
Expand Down
103 changes: 103 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/connectionstate"
"github.com/googleapis/go-sql-spanner/testutil"
)

Expand Down Expand Up @@ -448,3 +449,105 @@ func TestSetRetryAbortsInternallyInActiveTransaction(t *testing.T) {
}
_ = tx.Rollback()
}

func TestSetAutocommitDMLMode(t *testing.T) {
t.Parallel()

for _, tp := range []connectionstate.Type{connectionstate.TypeTransactional, connectionstate.TypeNonTransactional} {
db, _, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{
Project: "p",
Instance: "i",
Database: "d",
ConnectionStateType: tp,
})
defer teardown()

conn, err := db.Conn(context.Background())
if err != nil {
t.Fatal(err)
}
defer func() { _ = conn.Close() }()

_ = conn.Raw(func(driverConn interface{}) error {
c, _ := driverConn.(SpannerConn)
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w)
}
if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil {
t.Fatal(err)
}
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
return nil
})

// Set the value in a transaction and commit.
tx, err := conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
t.Fatal(err)
}
_ = conn.Raw(func(driverConn interface{}) error {
c, _ := driverConn.(SpannerConn)
// The value should be the same as before the transaction started.
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
// Changes in a transaction should be visible in the transaction.
if err := c.SetAutocommitDMLMode(Transactional); err != nil {
t.Fatal(err)
}
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
return nil
})
// Committing the transaction should make the change durable (and is a no-op if the connection state type is
// non-transactional).
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
_ = conn.Raw(func(driverConn interface{}) error {
c, _ := driverConn.(SpannerConn)
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
return nil
})

// Set the value in a transaction and rollback.
tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
t.Fatal(err)
}
_ = conn.Raw(func(driverConn interface{}) error {
c, _ := driverConn.(SpannerConn)
if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil {
t.Fatal(err)
}
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
return nil
})
// Rolling back the transaction will undo the change if the connection state is transactional.
// In case of non-transactional state, the rollback does not have an effect, as the state change was persisted
// directly when SetAutocommitDMLMode was called.
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}
_ = conn.Raw(func(driverConn interface{}) error {
c, _ := driverConn.(SpannerConn)
var expected AutocommitDMLMode
if tp == connectionstate.TypeTransactional {
expected = Transactional
} else {
expected = PartitionedNonAtomic
}
if g, w := c.AutocommitDMLMode(), expected; g != w {
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
}
return nil
})
}
}
53 changes: 53 additions & 0 deletions connection_properties.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spannerdriver

import "github.com/googleapis/go-sql-spanner/connectionstate"

// connectionProperties contains all supported connection properties for Spanner.
// These properties are added to all connectionstate.ConnectionState instances that are created for Spanner connections.
var connectionProperties = map[string]connectionstate.ConnectionProperty{}

// The following variables define the various connectionstate.ConnectionProperty instances that are supported and used
// by the Spanner database/sql driver. They are defined as global variables, so they can be used directly in the driver
// to get/set the state of exactly that property.

var propertyConnectionStateType = createConnectionProperty(
"connection_state_type",
"The type of connection state to use for this connection. Can only be set at start up. "+
"If no value is set, then the database dialect default will be used, "+
"which is NON_TRANSACTIONAL for GoogleSQL and TRANSACTIONAL for PostgreSQL.",
connectionstate.TypeDefault,
[]connectionstate.Type{connectionstate.TypeDefault, connectionstate.TypeTransactional, connectionstate.TypeNonTransactional},
connectionstate.ContextStartup,
)
var propertyAutocommitDmlMode = createConnectionProperty(
"autocommit_dml_mode",
"Determines the transaction type that is used to execute DML statements when the connection is in auto-commit mode.",
Transactional,
[]AutocommitDMLMode{Transactional, PartitionedNonAtomic},
connectionstate.ContextUser,
)

func createConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context connectionstate.Context) *connectionstate.TypedConnectionProperty[T] {
prop := connectionstate.CreateConnectionProperty(name, description, defaultValue, validValues, context)
connectionProperties[prop.Key()] = prop
return prop
}

func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState {
state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues)
return state
}
Loading
Loading