From 4825aa46f92ec061ed62f62d76b37f483b25ca10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 21 Aug 2025 20:55:59 +0200 Subject: [PATCH 1/2] chore: add generic transactional connection state Adds data structures for generic transactional connection state. These structures will be used to keep all connection state in one place, making it easier to add new connection variables. This also adds support for transactional connection state; Changes that are made during a transaction are only persisted if the transaction is committed. It also allows for setting temporary (local) values during a transaction. This change is the first step in a multi-step process for moving all connection variables into a generic structure. Following changes will move the other connection variables into this structure, and will add support for executing `set local ...` statements. --- .github/workflows/unit-tests.yml | 3 + client_side_statement_test.go | 21 +- conn.go | 42 +- conn_with_mockserver_test.go | 103 +++++ connection_properties.go | 53 +++ connectionstate/connection_property.go | 344 ++++++++++++++++ connectionstate/connection_state.go | 163 ++++++++ connectionstate/connection_state_test.go | 491 +++++++++++++++++++++++ driver.go | 19 + driver_test.go | 19 +- transaction.go | 21 +- 11 files changed, 1243 insertions(+), 36 deletions(-) create mode 100644 connection_properties.go create mode 100644 connectionstate/connection_property.go create mode 100644 connectionstate/connection_state.go create mode 100644 connectionstate/connection_state_test.go diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 90b3fa55..942a4273 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -19,6 +19,9 @@ jobs: uses: actions/checkout@v5 - name: Run unit tests run: go test -race -short + - name: Run connection state unit tests + run: go test -race -short + working-directory: connectionstate lint: runs-on: ubuntu-latest diff --git a/client_side_statement_test.go b/client_side_statement_test.go index f99609e6..f20067e2 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -25,12 +25,13 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/structpb" ) func TestStatementExecutor_StartBatchDdl(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -61,7 +62,7 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) { } func TestStatementExecutor_StartBatchDml(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -98,7 +99,7 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) { } func TestStatementExecutor_RetryAbortsInternally(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -154,7 +155,7 @@ func TestStatementExecutor_RetryAbortsInternally(t *testing.T) { } func TestStatementExecutor_AutocommitDmlMode(t *testing.T) { - c := &conn{logger: noopLogger, connector: &connector{}} + c := &conn{logger: noopLogger, connector: &connector{}, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} _ = c.ResetSession(context.Background()) s := &statementExecutor{} ctx := context.Background() @@ -211,7 +212,7 @@ func TestStatementExecutor_AutocommitDmlMode(t *testing.T) { } func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) { - c := &conn{logger: noopLogger} + c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -282,7 +283,7 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) { func TestShowCommitTimestamp(t *testing.T) { t.Parallel() - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -328,7 +329,7 @@ func TestShowCommitTimestamp(t *testing.T) { } func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -384,7 +385,7 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { } func TestStatementExecutor_MaxCommitDelay(t *testing.T) { - c := &conn{logger: noopLogger} + c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -457,7 +458,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) { {"", "tag-with-missing-opening-quote'", true}, {"", "'tag-with-missing-closing-quote", true}, } { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{}, nil) @@ -517,7 +518,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) { func TestStatementExecutor_UsesExecOptions(t *testing.T) { ctx := context.Background() - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true}, nil) diff --git a/conn.go b/conn.go index addca502..929294bc 100644 --- a/conn.go +++ b/conn.go @@ -27,6 +27,7 @@ import ( adminapi "cloud.google.com/go/spanner/admin/database/apiv1" adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -231,6 +232,8 @@ type conn struct { execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error) execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) + // state contains the current ConnectionState for this connection. + state *connectionstate.ConnectionState // batch is the currently active DDL or DML batch on this connection. batch *batch // autoBatchDml determines whether DML statements should automatically @@ -244,11 +247,6 @@ type conn struct { // statements was correct. autoBatchDmlUpdateCountVerification bool - // autocommitDMLMode determines the type of DML to use when a single DML - // statement is executed on a connection. The default is Transactional, but - // it can also be set to PartitionedNonAtomic to execute the statement as - // Partitioned DML. - autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound // isolationLevel determines the default isolation level that is used for read/write @@ -308,7 +306,7 @@ func (c *conn) setRetryAbortsInternally(retry bool) (driver.Result, error) { } func (c *conn) AutocommitDMLMode() AutocommitDMLMode { - return c.autocommitDMLMode + return propertyAutocommitDmlMode.GetValueOrDefault(c.state) } func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error { @@ -320,7 +318,9 @@ func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error { } func (c *conn) setAutocommitDMLMode(mode AutocommitDMLMode) (driver.Result, error) { - c.autocommitDMLMode = mode + if err := propertyAutocommitDmlMode.SetValue(c.state, mode, connectionstate.ContextUser); err != nil { + return nil, err + } return driver.ResultNoRows, nil } @@ -689,8 +689,9 @@ func (c *conn) ResetSession(_ context.Context) error { c.retryAborts = c.connector.retryAbortsInternally c.isolationLevel = c.connector.connectorConfig.IsolationLevel c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption + + _ = c.state.Reset(connectionstate.ContextUser) // TODO: Reset the following fields to the connector default - c.autocommitDMLMode = Transactional c.readOnlyStaleness = spanner.TimestampBound{} c.execOptions = ExecOptions{ DecodeToNativeArrays: c.connector.connectorConfig.DecodeToNativeArrays, @@ -887,7 +888,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp c.batch.statements = append(c.batch.statements, ss) res = &result{} } else { - dmlMode := c.autocommitDMLMode + dmlMode := c.AutocommitDMLMode() if execOptions.AutocommitDMLMode != Unspecified { dmlMode = execOptions.AutocommitDMLMode } @@ -1015,6 +1016,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e c.resetForRetry = false return c.tx, nil } + // Also start a transaction on the ConnectionState if the BeginTx call was successful. + defer func() { + if c.tx != nil { + _ = c.state.Begin() + } + }() + readOnlyTxOpts := c.getReadOnlyTransactionOptions() batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions() readWriteTransactionOptions := c.getTransactionOptions() @@ -1072,13 +1080,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e roTx: ro, boTx: bo, logger: logger, - close: func() { + close: func(result txResult) { if batchReadOnlyTxOpts.close != nil { batchReadOnlyTxOpts.close() } if readOnlyTxOpts.close != nil { readOnlyTxOpts.close() } + if result == txResultCommit { + _ = c.state.Commit() + } else { + _ = c.state.Rollback() + } c.tx = nil }, } @@ -1095,7 +1108,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e conn: c, logger: logger, rwTx: tx, - close: func(commitResponse *spanner.CommitResponse, commitErr error) { + close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) { if readWriteTransactionOptions.close != nil { readWriteTransactionOptions.close() } @@ -1103,6 +1116,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e c.tx = nil if commitErr == nil { c.commitResponse = commitResponse + if result == txResultCommit { + _ = c.state.Commit() + } else { + _ = c.state.Rollback() + } + } else { + _ = c.state.Rollback() } }, // Disable internal retries if any of these options have been set. diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index eaffb39d..8560acba 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/connectionstate" "github.com/googleapis/go-sql-spanner/testutil" ) @@ -448,3 +449,105 @@ func TestSetRetryAbortsInternallyInActiveTransaction(t *testing.T) { } _ = tx.Rollback() } + +func TestSetAutocommitDMLMode(t *testing.T) { + t.Parallel() + + for _, tp := range []connectionstate.Type{connectionstate.TypeTransactional, connectionstate.TypeNonTransactional} { + db, _, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + ConnectionStateType: tp, + }) + defer teardown() + + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + + // Set the value in a transaction and commit. + tx, err := conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + // The value should be the same as before the transaction started. + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Changes in a transaction should be visible in the transaction. + if err := c.SetAutocommitDMLMode(Transactional); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + // Committing the transaction should make the change durable (and is a no-op if the connection state type is + // non-transactional). + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + + // Set the value in a transaction and rollback. + tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + // Rolling back the transaction will undo the change if the connection state is transactional. + // In case of non-transactional state, the rollback does not have an effect, as the state change was persisted + // directly when SetAutocommitDMLMode was called. + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + var expected AutocommitDMLMode + if tp == connectionstate.TypeTransactional { + expected = Transactional + } else { + expected = PartitionedNonAtomic + } + if g, w := c.AutocommitDMLMode(), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + } +} diff --git a/connection_properties.go b/connection_properties.go new file mode 100644 index 00000000..bd38d042 --- /dev/null +++ b/connection_properties.go @@ -0,0 +1,53 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import "github.com/googleapis/go-sql-spanner/connectionstate" + +// connectionProperties contains all supported connection properties for Spanner. +// These properties are added to all connectionstate.ConnectionState instances that are created for Spanner connections. +var connectionProperties = map[string]connectionstate.ConnectionProperty{} + +// The following variables define the various connectionstate.ConnectionProperty instances that are supported and used +// by the Spanner database/sql driver. They are defined as global variables, so they can be used directly in the driver +// to get/set the state of exactly that property. + +var propertyConnectionStateType = createConnectionProperty( + "connection_state_type", + "The type of connection state to use for this connection. Can only be set at start up. "+ + "If no value is set, then the database dialect default will be used, "+ + "which is NON_TRANSACTIONAL for GoogleSQL and TRANSACTIONAL for PostgreSQL.", + connectionstate.TypeDefault, + []connectionstate.Type{connectionstate.TypeDefault, connectionstate.TypeTransactional, connectionstate.TypeNonTransactional}, + connectionstate.ContextStartup, +) +var propertyAutocommitDmlMode = createConnectionProperty( + "autocommit_dml_mode", + "Determines the transaction type that is used to execute DML statements when the connection is in auto-commit mode.", + Transactional, + []AutocommitDMLMode{Transactional, PartitionedNonAtomic}, + connectionstate.ContextUser, +) + +func createConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context connectionstate.Context) *connectionstate.TypedConnectionProperty[T] { + prop := connectionstate.CreateConnectionProperty(name, description, defaultValue, validValues, context) + connectionProperties[prop.Key()] = prop + return prop +} + +func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { + state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues) + return state +} diff --git a/connectionstate/connection_property.go b/connectionstate/connection_property.go new file mode 100644 index 00000000..66db44b1 --- /dev/null +++ b/connectionstate/connection_property.go @@ -0,0 +1,344 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Context indicates when a ConnectionProperty may be set. +type Context int + +const ( + // ContextStartup is used for ConnectionProperty instances that may only be set at startup and may not be changed + // during the lifetime of a connection. + ContextStartup = iota + // ContextUser is used for ConnectionProperty instances that may be set both at startup and during the lifetime of + // a connection. + ContextUser +) + +func (c Context) String() string { + switch c { + case ContextStartup: + return "Startup" + case ContextUser: + return "User" + default: + return "Unknown" + } +} + +// ConnectionProperty defines the public interface for connection properties. +type ConnectionProperty interface { + // Key returns the unique key of the ConnectionProperty. This is equal to the name of the property for properties + // without an extension, and equal to `extension.name` for properties with an extension. + Key() string + // Context returns the Context where the property is allowed to be updated (e.g. only at startup or during the + // lifetime of a connection). + Context() Context + // CreateInitialValue creates an initial value of the property with the default value of the property as the current + // and reset value. + CreateInitialValue() ConnectionPropertyValue +} + +// CreateConnectionProperty is used to create a new ConnectionProperty with a specific type. This function is intended +// for use by driver implementations at initialization time to define the properties that the driver supports. +func CreateConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context Context) *TypedConnectionProperty[T] { + return CreateConnectionPropertyWithExtension("", name, description, defaultValue, validValues, context) +} + +// CreateConnectionPropertyWithExtension is used to create a new ConnectionProperty with a specific type and an +// extension. Properties with an extension can be created dynamically during the lifetime of a connection. These are +// lost when the connection is reset to its original state. +func CreateConnectionPropertyWithExtension[T comparable](extension, name, description string, defaultValue T, validValues []T, context Context) *TypedConnectionProperty[T] { + var key string + if extension == "" { + key = name + } else { + key = extension + "." + name + } + return &TypedConnectionProperty[T]{ + key: key, + extension: extension, + name: name, + description: description, + defaultValue: defaultValue, + validValues: validValues, + context: context, + } +} + +var _ ConnectionProperty = (*TypedConnectionProperty[any])(nil) + +// TypedConnectionProperty implements the ConnectionProperty interface. +// All fields are unexported to ensure that the values can only be updated in accordance with the semantics of the +// chosen ConnectionState Type. +type TypedConnectionProperty[T comparable] struct { + key string + extension string + name string + description string + defaultValue T + validValues []T + context Context +} + +func (p *TypedConnectionProperty[T]) String() string { + return p.Key() +} + +func (p *TypedConnectionProperty[T]) Key() string { + return p.key +} + +func (p *TypedConnectionProperty[T]) Context() Context { + return p.context +} + +func (p *TypedConnectionProperty[T]) CreateInitialValue() ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: p, + resetValue: p.defaultValue, + value: p.defaultValue, + removeAtReset: false, + } +} + +// GetValueOrDefault returns the current value of the property in the given ConnectionState. +// It returns the default of the property if no value is found. +func (p *TypedConnectionProperty[T]) GetValueOrDefault(state *ConnectionState) T { + value, _ := state.value(p /*returnErrForUnknownProperty=*/, false) + if value == nil { + return p.defaultValue + } + if typedValue, ok := value.(*connectionPropertyValue[T]); ok { + return typedValue.value + } + return p.defaultValue +} + +// GetValueOrError returns the current value of the property in the given ConnectionState. +// It returns an error if no value is found. +func (p *TypedConnectionProperty[T]) GetValueOrError(state *ConnectionState) (T, error) { + value, err := state.value(p /*returnErrForUnknownProperty=*/, true) + if err != nil { + return p.zeroAndErr(err) + } + if value == nil { + return p.zeroAndErr(status.Errorf(codes.InvalidArgument, "no value found for property: %q", p)) + } + if typedValue, ok := value.(*connectionPropertyValue[T]); ok { + return typedValue.value, nil + } + return p.zeroAndErr(status.Errorf(codes.InvalidArgument, "value has wrong type: %s", value)) +} + +// ResetValue resets the value of the property in the given ConnectionState to its original value. +// +// The given Context should indicate the current context where the application tries to reset the value, e.g. it should +// be ContextUser if the reset happens during the lifetime of a connection, and ContextStartup if the reset happens at +// the creation of a connection. +func (p *TypedConnectionProperty[T]) ResetValue(state *ConnectionState, context Context) error { + value, _ := state.value(p /*returnErrForUnknownProperty=*/, false) + if value == nil { + var t T + return p.SetValue(state, t, context) + } else { + resetValue := value.GetResetValue() + typedResetValue, ok := resetValue.(T) + if !ok { + return status.Errorf(codes.InvalidArgument, "value has wrong type: %T", resetValue) + } + return p.SetValue(state, typedResetValue, context) + } +} + +// SetValue sets the value of the property in the given ConnectionState. +// +// The given Context should indicate the current context where the application tries to reset the value, e.g. it should +// be ContextUser if the reset happens during the lifetime of a connection, and ContextStartup if the reset happens at +// the creation of a connection. +func (p *TypedConnectionProperty[T]) SetValue(state *ConnectionState, value T, context Context) error { + if p.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", p.context, context) + } + if !state.inTransaction || state.connectionStateType == TypeNonTransactional || context < ContextUser { + // Set the value in non-transactional mode. + if err := p.setValue(state, state.properties, value, context); err != nil { + return err + } + // Remove the setting from the local settings if it's there, as the new setting is + // the one that should be used. + if state.localProperties != nil { + delete(state.localProperties, p.key) + } + return nil + } + // Set the value in a transaction. + if state.transactionProperties == nil { + state.transactionProperties = make(map[string]ConnectionPropertyValue) + } + if err := p.setValue(state, state.transactionProperties, value, context); err != nil { + return err + } + // Remove the setting from the local settings if it's there, as the new transaction setting is + // the one that should be used. + if state.localProperties != nil { + delete(state.localProperties, p.key) + } + return nil +} + +// SetLocalValue sets the local value of the property in the given ConnectionState. A local value is only visible +// for the remainder of the current transaction. The value is reset to the value it had before the transaction when the +// transaction ends, regardless whether the transaction committed or rolled back. +// +// Setting a local value outside a transaction is a no-op. +func (p *TypedConnectionProperty[T]) SetLocalValue(state *ConnectionState, value T) error { + if p.context < ContextUser { + return status.Error(codes.FailedPrecondition, "SetLocalValue is only supported for properties with context USER or higher") + } + if !state.inTransaction { + // SetLocalValue outside a transaction is a no-op. + return nil + } + if state.localProperties == nil { + state.localProperties = make(map[string]ConnectionPropertyValue) + } + return p.setValue(state, state.localProperties, value, ContextUser) +} + +func (p *TypedConnectionProperty[T]) setValue(state *ConnectionState, currentProperties map[string]ConnectionPropertyValue, value T, context Context) error { + if err := p.checkValidValue(value); err != nil { + return err + } + newValue, ok := currentProperties[p.key] + if !ok { + existingValue, ok := state.properties[p.key] + if !ok { + if p.extension == "" { + return unknownPropertyErr(p) + } + newValue = &connectionPropertyValue[T]{connectionProperty: p, removeAtReset: true} + } else { + newValue = existingValue.Copy() + } + } + if err := newValue.SetValue(value, context); err != nil { + return err + } + currentProperties[p.key] = newValue + return nil +} + +func (p *TypedConnectionProperty[T]) zeroAndErr(err error) (T, error) { + var t T + return t, err +} + +func (p *TypedConnectionProperty[T]) checkValidValue(value T) error { + if p.validValues == nil { + return nil + } + for _, validValue := range p.validValues { + if value == validValue { + return nil + } + } + return nil +} + +func unknownPropertyErr(p ConnectionProperty) error { + return status.Errorf(codes.InvalidArgument, "unrecognized configuration property %q", p.Key()) +} + +// ConnectionPropertyValue is the public interface for connection state property values. +type ConnectionPropertyValue interface { + // ConnectionProperty returns the property that this value is for. + ConnectionProperty() ConnectionProperty + // Copy creates a shallow copy of the ConnectionPropertyValue. + Copy() ConnectionPropertyValue + // SetValue sets the value of the property. The given value must be a valid value for the property. + SetValue(value any, context Context) error + // ResetValue resets the value of the property to the value it had at the creation of the connection. + ResetValue(context Context) error + // RemoveAtReset indicates whether the value should be removed from the ConnectionState when the ConnectionState is + // reset. This function should return true for property values that have been added to the set after the connection + // was created, for example because the user executed `set my_extension.my_property='some-value'`. + RemoveAtReset() bool + // GetResetValue returns the value that will be assigned to this property value if the value is reset. + GetResetValue() any +} + +// CreateInitialValue creates an initial value for a property. Both the current and the reset value are set to the given +// value. +func CreateInitialValue[T comparable](property *TypedConnectionProperty[T], value T) ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: property, + value: value, + resetValue: value, + removeAtReset: false, + } +} + +type connectionPropertyValue[T comparable] struct { + connectionProperty *TypedConnectionProperty[T] + resetValue T + value T + removeAtReset bool +} + +func (v *connectionPropertyValue[T]) ConnectionProperty() ConnectionProperty { + return v.connectionProperty +} + +func (v *connectionPropertyValue[T]) Copy() ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: v.connectionProperty, + resetValue: v.resetValue, + value: v.value, + removeAtReset: v.removeAtReset, + } +} + +func (v *connectionPropertyValue[T]) SetValue(value any, context Context) error { + if v.connectionProperty.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", v.connectionProperty.context, context) + } + typedValue, ok := value.(T) + if !ok { + return status.Errorf(codes.InvalidArgument, "value has wrong type: %T", value) + } + v.value = typedValue + return nil +} + +func (v *connectionPropertyValue[T]) ResetValue(context Context) error { + if v.connectionProperty.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", v.connectionProperty.context, context) + } + v.value = v.resetValue + return nil +} + +func (v *connectionPropertyValue[T]) RemoveAtReset() bool { + return v.removeAtReset +} + +func (v *connectionPropertyValue[T]) GetResetValue() any { + return v.resetValue +} diff --git a/connectionstate/connection_state.go b/connectionstate/connection_state.go new file mode 100644 index 00000000..a09959ab --- /dev/null +++ b/connectionstate/connection_state.go @@ -0,0 +1,163 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Type represents the type of ConnectionState that is used. +// ConnectionState can be either transactional or non-transactional. +// Transactional ConnectionState means that any changes to the state +// during a transaction will only be persisted and visible after the +// transaction if the transaction is committed. +// Non-transactional ConnectionState means that changes that are made +// to the state during a transaction are persisted directly and will +// be visible after the transaction regardless whether the transaction +// was committed or rolled back. +type Type int + +const ( + // TypeDefault indicates that the default ConnectionState type of the database dialect should be used. + // GoogleSQL uses non-transactional ConnectionState by default. + // PostgreSQL uses transactional ConnectionState by default. + TypeDefault Type = iota + // TypeTransactional ConnectionState means that changes to the state during a transaction will only be + // persisted and visible after the transaction if the transaction is committed. + TypeTransactional + // TypeNonTransactional ConnectionState means that changes to the state during a transaction are persisted + // directly and remain visible after the transaction, regardless whether the transaction committed or not. + TypeNonTransactional +) + +// ConnectionState contains connection the state of a connection in a map. +type ConnectionState struct { + connectionStateType Type + inTransaction bool + properties map[string]ConnectionPropertyValue + transactionProperties map[string]ConnectionPropertyValue + localProperties map[string]ConnectionPropertyValue +} + +// NewConnectionState creates a new ConnectionState instance with the given initial values. +// The Type must be either TypeTransactional or TypeNonTransactional. +func NewConnectionState(connectionStateType Type, properties map[string]ConnectionProperty, initialValues map[string]ConnectionPropertyValue) (*ConnectionState, error) { + if connectionStateType == TypeDefault { + return nil, status.Error(codes.InvalidArgument, "connection state type cannot be TypeDefault") + } + state := &ConnectionState{ + connectionStateType: connectionStateType, + properties: make(map[string]ConnectionPropertyValue), + transactionProperties: nil, + localProperties: nil, + } + for key, value := range initialValues { + state.properties[key] = value.Copy() + } + for key, value := range properties { + if _, ok := state.properties[key]; !ok { + state.properties[key] = value.CreateInitialValue() + } + } + return state, nil +} + +// Begin starts a new transaction for this ConnectionState. +func (cs *ConnectionState) Begin() error { + if cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is already in transaction") + } + cs.inTransaction = true + return nil +} + +// Commit the current ConnectionState. +// This resets all local property values to the value they had before the transaction. +// If the Type is TypeTransactional, then all pending state changes are committed. +// If the Type is TypeNonTransactional, all state changes during the transaction have already been persisted during the +// transaction. +func (cs *ConnectionState) Commit() error { + if !cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is not in a transaction") + } + cs.inTransaction = false + if cs.transactionProperties != nil { + for key, value := range cs.transactionProperties { + cs.properties[key] = value + } + } + cs.transactionProperties = nil + cs.localProperties = nil + return nil +} + +// Rollback the current transactional state. +// This resets all local property values to the value they had before the transaction. +// If the Type is TypeTransactional, then all pending state changes are rolled back. +// If the Type is TypeNonTransactional, all state changes during the transaction have already been persisted during the +// transaction. +func (cs *ConnectionState) Rollback() error { + if !cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is not in a transaction") + } + cs.inTransaction = false + cs.transactionProperties = nil + cs.localProperties = nil + return nil +} + +// Reset the state to the initial values. Only the properties with a Context equal to or higher +// than the given Context will be reset. E.g. if the given Context is ContextUser, then properties +// with ContextStartup will not be reset. +func (cs *ConnectionState) Reset(context Context) error { + cs.transactionProperties = nil + cs.localProperties = nil + var remove map[string]bool + for _, value := range cs.properties { + if value.ConnectionProperty().Context() >= context { + if value.RemoveAtReset() { + if remove == nil { + remove = make(map[string]bool) + } + remove[value.ConnectionProperty().Key()] = true + } else { + if err := value.ResetValue(context); err != nil { + return err + } + } + } + } + for key := range remove { + delete(cs.properties, key) + } + return nil +} + +func (cs *ConnectionState) value(property ConnectionProperty, returnErrForUnknownProperty bool) (ConnectionPropertyValue, error) { + if val, ok := cs.localProperties[property.Key()]; ok { + return val, nil + } + if val, ok := cs.transactionProperties[property.Key()]; ok { + return val, nil + } + if val, ok := cs.properties[property.Key()]; ok { + return val, nil + } + if returnErrForUnknownProperty { + return nil, status.Errorf(codes.InvalidArgument, "unrecognized configuration property %q", property.Key()) + } + return nil, nil +} diff --git a/connectionstate/connection_state_test.go b/connectionstate/connection_state_test.go new file mode 100644 index 00000000..6a04369c --- /dev/null +++ b/connectionstate/connection_state_test.go @@ -0,0 +1,491 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "testing" + + "cloud.google.com/go/spanner" + "google.golang.org/grpc/codes" +) + +func TestSetValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestSetValueInTransactionAndCommit(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is persisted if the transaction is committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestSetValueInTransactionAndRollback(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is rolled back if the transaction is rolled back and the connection + // state is transactional. + if err := state.Rollback(); err != nil { + t.Fatal(err) + } + var expected string + if tp == TypeTransactional { + expected = prop.defaultValue + } else { + expected = setToValue + } + if g, w := prop.GetValueOrDefault(state), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestResetValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = prop.SetValue(state, "new-value", ContextUser) + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestResetValueInTransactionAndCommit(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + + // Change the value to something else than the default and commit. + _ = prop.SetValue(state, "new-value", ContextUser) + _ = state.Commit() + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Now reset the value to the default in a new transaction. + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is persisted if the transaction is committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestResetValueInTransactionAndRollback(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + + // Change the value to something else than the default and commit. + _ = prop.SetValue(state, "new-value", ContextUser) + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Now reset the value to the default in a new transaction. + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is rolled back if the transaction is rolled back and the connection + // state is transactional. + if err := state.Rollback(); err != nil { + t.Fatal(err) + } + // The value should be rolled back to "new-value" if the state is transactional. + // The value should be "initial-value" if the state is non-transactional, as the Reset is persisted regardless + // whether the transaction committed or not. + var expected string + if tp == TypeTransactional { + expected = "new-value" + } else { + expected = prop.defaultValue + } + if g, w := prop.GetValueOrDefault(state), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValue(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetLocalValue(state, "new-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is no longer visible once the transaction has ended, even if the + // transaction was committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + // Setting a local value outside a transaction is a no-op. + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetLocalValue(state, "new-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValueForStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + err := prop.SetLocalValue(state, "new-value") + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetInTransactionForStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = state.Begin() + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetNormalAndLocalValue(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, commit := range []bool{true, false} { + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = state.Begin() + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + // First set a local value. + if err := prop.SetLocalValue(state, "local-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "local-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Then set a 'standard' value in a transaction. + // This should override the local value. + if err := prop.SetValue(state, "new-value", ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Then set yet another local value. This should take precedence within the current transaction. + if err := prop.SetLocalValue(state, "second-local-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "second-local-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the last local value is lost when the transaction ends. + // The value that was set in the transaction should be persisted if + // the transaction is committed, or if we are using non-transactional state. + if commit { + _ = state.Commit() + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + _ = state.Rollback() + if tp == TypeNonTransactional { + // The transaction was rolled back, but this should have no impact on the + // SET statement in the transaction. So the new-value should be persisted. + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + // The transaction was rolled back. The value should have been reset to + // the value it had before the transaction. + if g, w := prop.GetValueOrDefault(state), "initial-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + } + } +} + +func TestSetUnknownProperty(t *testing.T) { + properties := map[string]ConnectionProperty{} + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + // Try to add a new property without an extension to the connection state. + // This should fail. + prop := CreateConnectionProperty("prop", "Test property", "initial-value", nil, ContextUser) + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + // Adding a new property with an extension to the connection state should work. + propWithExtension := CreateConnectionPropertyWithExtension("spanner", "prop3", "Test property 3", "initial-value-3", nil, ContextUser) + if err := propWithExtension.SetValue(state, "new-value", ContextUser); err != nil { + t.Fatal(err) + } + } +} + +func TestReset(t *testing.T) { + prop1 := CreateConnectionProperty("prop1", "Test property 1", "initial-value-1", nil, ContextUser) + prop2 := CreateConnectionProperty("prop2", "Test property 2", "initial-value-2", nil, ContextUser) + properties := map[string]ConnectionProperty{ + prop1.key: prop1, + prop2.key: prop2, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = prop1.SetValue(state, "new-value-1", ContextUser) + _ = prop2.SetValue(state, "new-value-2", ContextUser) + // Add a new property to the connection state. This will be removed when the state is reset. + prop3 := CreateConnectionPropertyWithExtension("spanner", "prop3", "Test property 3", "initial-value-3", nil, ContextUser) + if err := prop3.SetValue(state, "new-value-3", ContextUser); err != nil { + t.Fatal(err) + } + + if g, w := prop1.GetValueOrDefault(state), "new-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "new-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop3.GetValueOrDefault(state), "new-value-3"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := state.Reset(ContextUser); err != nil { + t.Fatal(err) + } + } +} + +func TestResetWithInitialValues(t *testing.T) { + prop1 := CreateConnectionProperty("prop1", "Test property 1", "default-value-1", nil, ContextUser) + prop2 := CreateConnectionProperty("prop2", "Test property 2", "default-value-2", nil, ContextUser) + properties := map[string]ConnectionProperty{ + prop1.key: prop1, + prop2.key: prop2, + } + initialValues := map[string]ConnectionPropertyValue{ + prop2.key: CreateInitialValue(prop2, "initial-value-2"), + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, initialValues) + // Verify that prop1 has the default value of the property, and that prop2 has the initial value that was + // passed in when the ConnectionState was created. + if g, w := prop1.GetValueOrDefault(state), "default-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "initial-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that updating the values works. + _ = prop1.SetValue(state, "new-value-1", ContextUser) + _ = prop2.SetValue(state, "new-value-2", ContextUser) + + if g, w := prop1.GetValueOrDefault(state), "new-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "new-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Resetting the values should bring them back to the original values. + if err := state.Reset(ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop1.GetValueOrDefault(state), "default-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "initial-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} diff --git a/driver.go b/driver.go index 4824f0a1..8fed04c4 100644 --- a/driver.go +++ b/driver.go @@ -39,6 +39,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/uuid" "github.com/googleapis/gax-go/v2" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" @@ -293,6 +294,17 @@ type ConnectorConfig struct { // any of those do not yet exist. AutoConfigEmulator bool + // ConnectionStateType determines the behavior of changes to connection state + // during a transaction. + // connectionstate.TypeTransactional means that changes during a transaction + // are only persisted if the transaction is committed. If the transaction is + // rolled back, any changes to the connection state during the transaction + // will be lost. + // connectionstate.TypeNonTransactional means that changes to the connection + // state during a transaction are persisted directly, and are always visible + // after the transaction, regardless whether the transaction was committed or + // rolled back. + ConnectionStateType connectionstate.Type // Params contains key/value pairs for commonly used configuration parameters // for connections. The valid values are the same as the parameters that can // be added to a connection string. @@ -678,6 +690,11 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { connId := uuid.New().String() logger := c.logger.With("connId", connId) + connectionStateType := c.connectorConfig.ConnectionStateType + if connectionStateType == connectionstate.TypeDefault { + // TODO: Determine the default type of connection state based on the dialect + connectionStateType = connectionstate.TypeNonTransactional + } connection := &conn{ parser: c.parser, connector: c, @@ -688,6 +705,8 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { database: databaseName, retryAborts: c.retryAbortsInternally, + // TODO: Pass in initial values for the connection state + state: createInitialConnectionState(connectionStateType, map[string]connectionstate.ConnectionPropertyValue{}), autoBatchDml: c.connectorConfig.AutoBatchDml, autoBatchDmlUpdateCount: c.connectorConfig.AutoBatchDmlUpdateCount, autoBatchDmlUpdateCountVerification: !c.connectorConfig.DisableAutoBatchDmlUpdateCountVerification, diff --git a/driver_test.go b/driver_test.go index 31c0fcc9..e3c84dfa 100644 --- a/driver_test.go +++ b/driver_test.go @@ -31,6 +31,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/option" "google.golang.org/grpc/codes" ) @@ -486,12 +487,13 @@ func TestConnection_Reset(t *testing.T) { connector: &connector{ connectorConfig: ConnectorConfig{}, }, + state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), readOnlyStaleness: spanner.ExactStaleness(time.Second), batch: &batch{tp: dml}, commitResponse: &spanner.CommitResponse{}, tx: &readOnlyTransaction{ logger: noopLogger, - close: func() { + close: func(_ txResult) { txClosed = true }, }, @@ -517,6 +519,7 @@ func TestConnection_Reset(t *testing.T) { func TestConnection_NoNestedTransactions(t *testing.T) { c := conn{ logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &readOnlyTransaction{}, } _, err := c.BeginTx(context.Background(), driver.TxOptions{}) @@ -618,10 +621,10 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { t.Fatal(err) } c := &conn{ - parser: parser, - logger: noopLogger, - autocommitDMLMode: Transactional, - batch: &batch{tp: ddl}, + parser: parser, + logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), + batch: &batch{tp: ddl}, execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, @@ -749,9 +752,9 @@ func TestConn_GetCommitResponseAfterAutocommitDml(t *testing.T) { } want := &spanner.CommitResponse{CommitTs: time.Now()} c := &conn{ - parser: parser, - logger: noopLogger, - autocommitDMLMode: Transactional, + parser: parser, + logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, diff --git a/transaction.go b/transaction.go index 5fc2b3c3..b83061b8 100644 --- a/transaction.go +++ b/transaction.go @@ -91,13 +91,20 @@ func (ri *readOnlyRowIterator) ResultSetStats() *sppb.ResultSetStats { } } +type txResult int + +const ( + txResultCommit txResult = iota + txResultRollback +) + var _ contextTransaction = &readOnlyTransaction{} type readOnlyTransaction struct { roTx *spanner.ReadOnlyTransaction boTx *spanner.BatchReadOnlyTransaction logger *slog.Logger - close func() + close func(result txResult) } func (tx *readOnlyTransaction) Commit() error { @@ -109,7 +116,7 @@ func (tx *readOnlyTransaction) Commit() error { } else if tx.roTx != nil { tx.roTx.Close() } - tx.close() + tx.close(txResultCommit) return nil } @@ -120,7 +127,7 @@ func (tx *readOnlyTransaction) Rollback() error { if tx.roTx != nil { tx.roTx.Close() } - tx.close() + tx.close(txResultRollback) return nil } @@ -233,7 +240,7 @@ type readWriteTransaction struct { active bool // batch is any DML batch that is active for this transaction. batch *batch - close func(commitResponse *spanner.CommitResponse, commitErr error) + close func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) // retryAborts indicates whether this transaction will automatically retry // the transaction if it is aborted by Spanner. The default is true. retryAborts bool @@ -409,7 +416,7 @@ func (tx *readWriteTransaction) Commit() (err error) { if tx.rwTx != nil { if !tx.retryAborts { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx) - tx.close(&ts, err) + tx.close(txResultCommit, &ts, err) return err } @@ -421,7 +428,7 @@ func (tx *readWriteTransaction) Commit() (err error) { tx.rwTx.Rollback(context.Background()) } } - tx.close(&commitResponse, err) + tx.close(txResultCommit, &commitResponse, err) return err } @@ -439,7 +446,7 @@ func (tx *readWriteTransaction) rollback(ctx context.Context) error { if tx.rwTx != nil { tx.rwTx.Rollback(ctx) } - tx.close(nil, nil) + tx.close(txResultRollback, nil, nil) return nil } From 26fc3676a7bdc4b3a57bfddab18e3b945d3490e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 2 Sep 2025 11:17:21 +0200 Subject: [PATCH 2/2] fix: return error for invalid value --- connectionstate/connection_property.go | 2 +- connectionstate/connection_property_test.go | 46 +++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 connectionstate/connection_property_test.go diff --git a/connectionstate/connection_property.go b/connectionstate/connection_property.go index 66db44b1..af9d46f0 100644 --- a/connectionstate/connection_property.go +++ b/connectionstate/connection_property.go @@ -259,7 +259,7 @@ func (p *TypedConnectionProperty[T]) checkValidValue(value T) error { return nil } } - return nil + return status.Errorf(codes.InvalidArgument, "value %v is not a valid value for %s", value, p) } func unknownPropertyErr(p ConnectionProperty) error { diff --git a/connectionstate/connection_property_test.go b/connectionstate/connection_property_test.go new file mode 100644 index 00000000..232c4b46 --- /dev/null +++ b/connectionstate/connection_property_test.go @@ -0,0 +1,46 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "testing" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCheckValidValue(t *testing.T) { + type testValue int + const ( + testValueUnspecified testValue = iota + testValueTrue + testValueFalse + ) + + p := &TypedConnectionProperty[testValue]{validValues: []testValue{testValueTrue, testValueFalse}} + if err := p.checkValidValue(testValueTrue); err != nil { + t.Fatal(err) + } + if err := p.checkValidValue(testValueFalse); err != nil { + t.Fatal(err) + } + if err := p.checkValidValue(testValueUnspecified); err == nil { + t.Fatalf("expected error for %v", testValueUnspecified) + } else { + if g, w := status.Code(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +}